Flatten ZSTD_row_getMatchMask (#2681)

* Flatten ZSTD_row_getMatchMask

* Remove the SIMD abstraction layer.
* Add big endian support.
* Align `hashTags` within `tagRow` to a 16-byte boundary. 
* Switch SSE2 to use aligned reads.
* Optimize scalar path using SWAR.
* Optimize the NEON path for `n == 32`.
* Work around minor clang issue for NEON (https://bugs.llvm.org/show_bug.cgi?id=49577)

* Replace memcpy with MEM_readST

* Silence alignment warnings

* Fix NEON casts

* Update zstd_lazy.c

* Unify SIMD preprocessor detection (#3)

* Remove duplicate asserts

* Tweak rotates

* Improve endian detection

* Add cast

There is a fun little catch-22 with gcc: the result of pmovmskb has to be cast to uint32_t to avoid a zero-extension, but must be uint16_t for gcc to generate a rotate instruction.
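
As a hedged illustration of that trade-off (a standalone sketch, not code from this diff): the pattern the commit settles on keeps the mask in a uint16_t so gcc can recognize the 16-bit rotate idiom, and only widens afterwards:

    #include <emmintrin.h> /* SSE2 */
    #include <stdint.h>

    static uint16_t rotr16(uint16_t value, uint32_t count) {
        count &= 0x0F;
        return (uint16_t)((value >> count) | (uint16_t)(value << ((0U - count) & 0x0F)));
    }

    uint16_t matchBits(__m128i row, uint8_t tag, uint32_t head) {
        const __m128i eq = _mm_cmpeq_epi8(row, _mm_set1_epi8((char)tag));
        const uint16_t mask = (uint16_t)_mm_movemask_epi8(eq); /* int result narrowed to 16 bits */
        return rotr16(mask, head); /* stays 16-bit, so gcc can emit a 16-bit rotate */
    }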

* More casts

* Fix casts

Better work-around for the (bogus) "unary minus on unsigned" warning.
Author: aqrit, 2021-06-09 01:50:25 -04:00 (committed by GitHub)
parent 8a3bdfaa7b
commit dd4f6aa9e6
5 changed files with 137 additions and 227 deletions

lib/common/compiler.h

@@ -197,6 +197,22 @@
 #define STATIC_BMI2 0
 #endif
 
+/* compile time determination of SIMD support */
+#if !defined(ZSTD_NO_INTRINSICS)
+#  if defined(__SSE2__) || defined(_M_AMD64) || (defined (_M_IX86) && defined(_M_IX86_FP) && (_M_IX86_FP >= 2))
+#    define ZSTD_ARCH_X86_SSE2
+#  endif
+#  if defined(__ARM_NEON) || defined(_M_ARM64)
+#    define ZSTD_ARCH_ARM_NEON
+#  endif
+#
+#  if defined(ZSTD_ARCH_X86_SSE2)
+#    include <emmintrin.h>
+#  elif defined(ZSTD_ARCH_ARM_NEON)
+#    include <arm_neon.h>
+#  endif
+#endif
+
 /* compat. with non-clang compilers */
 #ifndef __has_builtin
 #  define __has_builtin(x) 0

lib/common/mem.h

@@ -153,8 +153,22 @@ MEM_STATIC unsigned MEM_64bits(void) { return sizeof(size_t)==8; }
 MEM_STATIC unsigned MEM_isLittleEndian(void)
 {
+#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+    return 1;
+#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+    return 0;
+#elif defined(__clang__) && __LITTLE_ENDIAN__
+    return 1;
+#elif defined(__clang__) && __BIG_ENDIAN__
+    return 0;
+#elif defined(_MSC_VER) && (_M_AMD64 || _M_IX86)
+    return 1;
+#elif defined(__DMC__) && defined(_M_IX86)
+    return 1;
+#else
     const union { U32 u; BYTE c[4]; } one = { 1 };   /* don't use static : performance detrimental */
     return one.c[0];
+#endif
 }
 
 #if defined(MEM_FORCE_MEMORY_ACCESS) && (MEM_FORCE_MEMORY_ACCESS==2)
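
A quick standalone sketch (illustration only, not part of the commit) showing that the compile-time branches and the union fallback agree on a given machine:

    #include <stdio.h>

    int main(void) {
        /* The union trick inspects the lowest-addressed byte of a multi-byte
         * integer; on a little-endian machine that byte holds the 1. */
        const union { unsigned u; unsigned char c[4]; } one = { 1 };
        printf("runtime union check : %s-endian\n", one.c[0] ? "little" : "big");
    #if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
        printf("compile-time check  : little-endian\n");
    #elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
        printf("compile-time check  : big-endian\n");
    #endif
        return 0;
    }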

lib/common/zstd_internal.h

@@ -19,9 +19,6 @@
 /*-*************************************
 *  Dependencies
 ***************************************/
-#if !defined(ZSTD_NO_INTRINSICS) && defined(__ARM_NEON)
-#include <arm_neon.h>
-#endif
 #include "compiler.h"
 #include "mem.h"
 #include "debug.h"                 /* assert, DEBUGLOG, RAWLOG, g_debuglevel */
@@ -247,7 +244,7 @@ static UNUSED_ATTR const U32 OF_defaultNormLog = OF_DEFAULTNORMLOG;
 *   Shared functions to include for inlining
 *********************************************/
 static void ZSTD_copy8(void* dst, const void* src) {
-#if !defined(ZSTD_NO_INTRINSICS) && defined(__ARM_NEON)
+#if defined(ZSTD_ARCH_ARM_NEON)
     vst1_u8((uint8_t*)dst, vld1_u8((const uint8_t*)src));
 #else
     ZSTD_memcpy(dst, src, 8);
@@ -256,7 +253,7 @@ static void ZSTD_copy8(void* dst, const void* src) {
 #define COPY8(d,s) { ZSTD_copy8(d,s); d+=8; s+=8; }
 
 static void ZSTD_copy16(void* dst, const void* src) {
-#if !defined(ZSTD_NO_INTRINSICS) && defined(__ARM_NEON)
+#if defined(ZSTD_ARCH_ARM_NEON)
     vst1q_u8((uint8_t*)dst, vld1q_u8((const uint8_t*)src));
 #else
     ZSTD_memcpy(dst, src, 16);

lib/compress/zstd_compress.c

@@ -222,7 +222,7 @@ static int ZSTD_rowMatchFinderUsed(const ZSTD_strategy strategy, const ZSTD_useR
 /* Returns row matchfinder usage enum given an initial mode and cParams */
 static ZSTD_useRowMatchFinderMode_e ZSTD_resolveRowMatchFinderMode(ZSTD_useRowMatchFinderMode_e mode,
                                                                    const ZSTD_compressionParameters* const cParams) {
-#if !defined(ZSTD_NO_INTRINSICS) && (defined(__SSE2__) || defined(_M_AMD64) || defined(__ARM_NEON))
+#if defined(ZSTD_ARCH_X86_SSE2) || defined(ZSTD_ARCH_ARM_NEON)
     int const kHasSIMD128 = 1;
 #else
     int const kHasSIMD128 = 0;

lib/compress/zstd_lazy.c

@@ -865,7 +865,7 @@ FORCE_INLINE_TEMPLATE size_t ZSTD_HcFindBestMatch_extDict_selectMLS (
 *  (SIMD) Row-based matchfinder
 ***********************************/
 /* Constants for row-based hash */
-#define ZSTD_ROW_HASH_TAG_OFFSET 1   /* byte offset of hashes in the match state's tagTable from the beginning of a row */
+#define ZSTD_ROW_HASH_TAG_OFFSET 16  /* byte offset of hashes in the match state's tagTable from the beginning of a row */
 #define ZSTD_ROW_HASH_TAG_BITS 8     /* nb bits to use for the tag */
 #define ZSTD_ROW_HASH_TAG_MASK ((1u << ZSTD_ROW_HASH_TAG_BITS) - 1)
 
@@ -873,197 +873,6 @@ FORCE_INLINE_TEMPLATE size_t ZSTD_HcFindBestMatch_extDict_selectMLS (
 typedef U32 ZSTD_VecMask;   /* Clarifies when we are interacting with a U32 representing a mask of matches */
-#if !defined(ZSTD_NO_INTRINSICS) && (defined(__SSE2__) || defined(_M_AMD64)) /* SIMD SSE version*/
-
-#include <emmintrin.h>
-typedef __m128i ZSTD_Vec128;
-
-/* Returns a 128-bit container with 128-bits from src */
-static ZSTD_Vec128 ZSTD_Vec128_read(const void* const src) {
-    return _mm_loadu_si128((ZSTD_Vec128 const*)src);
-}
-
-/* Returns a ZSTD_Vec128 with the byte "val" packed 16 times */
-static ZSTD_Vec128 ZSTD_Vec128_set8(BYTE val) {
-    return _mm_set1_epi8((char)val);
-}
-
-/* Do byte-by-byte comparison result of x and y. Then collapse 128-bit resultant mask
- * into a 32-bit mask that is the MSB of each byte.
- * */
-static ZSTD_VecMask ZSTD_Vec128_cmpMask8(ZSTD_Vec128 x, ZSTD_Vec128 y) {
-    return (ZSTD_VecMask)_mm_movemask_epi8(_mm_cmpeq_epi8(x, y));
-}
-
-typedef struct {
-    __m128i fst;
-    __m128i snd;
-} ZSTD_Vec256;
-
-static ZSTD_Vec256 ZSTD_Vec256_read(const void* const ptr) {
-    ZSTD_Vec256 v;
-    v.fst = ZSTD_Vec128_read(ptr);
-    v.snd = ZSTD_Vec128_read((ZSTD_Vec128 const*)ptr + 1);
-    return v;
-}
-
-static ZSTD_Vec256 ZSTD_Vec256_set8(BYTE val) {
-    ZSTD_Vec256 v;
-    v.fst = ZSTD_Vec128_set8(val);
-    v.snd = ZSTD_Vec128_set8(val);
-    return v;
-}
-
-static ZSTD_VecMask ZSTD_Vec256_cmpMask8(ZSTD_Vec256 x, ZSTD_Vec256 y) {
-    ZSTD_VecMask fstMask;
-    ZSTD_VecMask sndMask;
-    fstMask = ZSTD_Vec128_cmpMask8(x.fst, y.fst);
-    sndMask = ZSTD_Vec128_cmpMask8(x.snd, y.snd);
-    return fstMask | (sndMask << 16);
-}
-
-#elif !defined(ZSTD_NO_INTRINSICS) && defined(__ARM_NEON) /* SIMD ARM NEON Version */
-
-#include <arm_neon.h>
-typedef uint8x16_t ZSTD_Vec128;
-
-static ZSTD_Vec128 ZSTD_Vec128_read(const void* const src) {
-    return vld1q_u8((const BYTE* const)src);
-}
-
-static ZSTD_Vec128 ZSTD_Vec128_set8(BYTE val) {
-    return vdupq_n_u8(val);
-}
-
-/* Mimics '_mm_movemask_epi8()' from SSE */
-static U32 ZSTD_vmovmaskq_u8(ZSTD_Vec128 val) {
-    /* Shift out everything but the MSB bits in each byte */
-    uint16x8_t highBits = vreinterpretq_u16_u8(vshrq_n_u8(val, 7));
-    /* Merge the even lanes together with vsra (right shift and add) */
-    uint32x4_t paired16 = vreinterpretq_u32_u16(vsraq_n_u16(highBits, highBits, 7));
-    uint64x2_t paired32 = vreinterpretq_u64_u32(vsraq_n_u32(paired16, paired16, 14));
-    uint8x16_t paired64 = vreinterpretq_u8_u64(vsraq_n_u64(paired32, paired32, 28));
-    /* Extract the low 8 bits from each lane, merge */
-    return vgetq_lane_u8(paired64, 0) | ((U32)vgetq_lane_u8(paired64, 8) << 8);
-}
-
-static ZSTD_VecMask ZSTD_Vec128_cmpMask8(ZSTD_Vec128 x, ZSTD_Vec128 y) {
-    return (ZSTD_VecMask)ZSTD_vmovmaskq_u8(vceqq_u8(x, y));
-}
-
-typedef struct {
-    uint8x16_t fst;
-    uint8x16_t snd;
-} ZSTD_Vec256;
-
-static ZSTD_Vec256 ZSTD_Vec256_read(const void* const ptr) {
-    ZSTD_Vec256 v;
-    v.fst = ZSTD_Vec128_read(ptr);
-    v.snd = ZSTD_Vec128_read((ZSTD_Vec128 const*)ptr + 1);
-    return v;
-}
-
-static ZSTD_Vec256 ZSTD_Vec256_set8(BYTE val) {
-    ZSTD_Vec256 v;
-    v.fst = ZSTD_Vec128_set8(val);
-    v.snd = ZSTD_Vec128_set8(val);
-    return v;
-}
-
-static ZSTD_VecMask ZSTD_Vec256_cmpMask8(ZSTD_Vec256 x, ZSTD_Vec256 y) {
-    ZSTD_VecMask fstMask;
-    ZSTD_VecMask sndMask;
-    fstMask = ZSTD_Vec128_cmpMask8(x.fst, y.fst);
-    sndMask = ZSTD_Vec128_cmpMask8(x.snd, y.snd);
-    return fstMask | (sndMask << 16);
-}
-
-#else /* Scalar fallback version */
-
-#define VEC128_NB_SIZE_T (16 / sizeof(size_t))
-typedef struct {
-    size_t vec[VEC128_NB_SIZE_T];
-} ZSTD_Vec128;
-
-static ZSTD_Vec128 ZSTD_Vec128_read(const void* const src) {
-    ZSTD_Vec128 ret;
-    ZSTD_memcpy(ret.vec, src, VEC128_NB_SIZE_T*sizeof(size_t));
-    return ret;
-}
-
-static ZSTD_Vec128 ZSTD_Vec128_set8(BYTE val) {
-    ZSTD_Vec128 ret = { {0} };
-    int startBit = sizeof(size_t) * 8 - 8;
-    for (;startBit >= 0; startBit -= 8) {
-        unsigned j = 0;
-        for (;j < VEC128_NB_SIZE_T; ++j) {
-            ret.vec[j] |= ((size_t)val << startBit);
-        }
-    }
-    return ret;
-}
-
-/* Compare x to y, byte by byte, generating a "matches" bitfield */
-static ZSTD_VecMask ZSTD_Vec128_cmpMask8(ZSTD_Vec128 x, ZSTD_Vec128 y) {
-    ZSTD_VecMask res = 0;
-    unsigned i = 0;
-    unsigned l = 0;
-    for (; i < VEC128_NB_SIZE_T; ++i) {
-        const size_t cmp1 = x.vec[i];
-        const size_t cmp2 = y.vec[i];
-        unsigned j = 0;
-        for (; j < sizeof(size_t); ++j, ++l) {
-            if (((cmp1 >> j*8) & 0xFF) == ((cmp2 >> j*8) & 0xFF)) {
-                res |= ((U32)1 << (j+i*sizeof(size_t)));
-            }
-        }
-    }
-    return res;
-}
-
-#define VEC256_NB_SIZE_T 2*VEC128_NB_SIZE_T
-typedef struct {
-    size_t vec[VEC256_NB_SIZE_T];
-} ZSTD_Vec256;
-
-static ZSTD_Vec256 ZSTD_Vec256_read(const void* const src) {
-    ZSTD_Vec256 ret;
-    ZSTD_memcpy(ret.vec, src, VEC256_NB_SIZE_T*sizeof(size_t));
-    return ret;
-}
-
-static ZSTD_Vec256 ZSTD_Vec256_set8(BYTE val) {
-    ZSTD_Vec256 ret = { {0} };
-    int startBit = sizeof(size_t) * 8 - 8;
-    for (;startBit >= 0; startBit -= 8) {
-        unsigned j = 0;
-        for (;j < VEC256_NB_SIZE_T; ++j) {
-            ret.vec[j] |= ((size_t)val << startBit);
-        }
-    }
-    return ret;
-}
-
-/* Compare x to y, byte by byte, generating a "matches" bitfield */
-static ZSTD_VecMask ZSTD_Vec256_cmpMask8(ZSTD_Vec256 x, ZSTD_Vec256 y) {
-    ZSTD_VecMask res = 0;
-    unsigned i = 0;
-    unsigned l = 0;
-    for (; i < VEC256_NB_SIZE_T; ++i) {
-        const size_t cmp1 = x.vec[i];
-        const size_t cmp2 = y.vec[i];
-        unsigned j = 0;
-        for (; j < sizeof(size_t); ++j, ++l) {
-            if (((cmp1 >> j*8) & 0xFF) == ((cmp2 >> j*8) & 0xFF)) {
-                res |= ((U32)1 << (j+i*sizeof(size_t)));
-            }
-        }
-    }
-    return res;
-}
-
-#endif /* !defined(ZSTD_NO_INTRINSICS) && defined(__SSE2__) */
 /* ZSTD_VecMask_next():
  * Starting from the LSB, returns the idx of the next non-zero bit.
  * Basically counting the nb of trailing zeroes.
@@ -1085,22 +894,22 @@ static U32 ZSTD_VecMask_next(ZSTD_VecMask val) {
 #  endif
 }
 
-/* ZSTD_VecMask_rotateRight():
- * Rotates a bitfield to the right by "rotation" bits.
- * If the rotation is greater than totalBits, the returned mask is 0.
+/* ZSTD_rotateRight_U32():
+ * Rotates a bitfield to the right by "count" bits.
+ * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts
  */
-FORCE_INLINE_TEMPLATE ZSTD_VecMask
-ZSTD_VecMask_rotateRight(ZSTD_VecMask mask, U32 const rotation, U32 const totalBits) {
-    if (rotation == 0)
-        return mask;
-    switch (totalBits) {
-    default:
-        assert(0);
-    case 16:
-        return (mask >> rotation) | (U16)(mask << (16 - rotation));
-    case 32:
-        return (mask >> rotation) | (U32)(mask << (32 - rotation));
-    }
+FORCE_INLINE_TEMPLATE
+U32 ZSTD_rotateRight_U32(U32 const value, U32 count) {
+    assert(count < 32);
+    count &= 0x1F; /* for fickle pattern recognition */
+    return (value >> count) | (U32)(value << ((0U - count) & 0x1F));
+}
+
+FORCE_INLINE_TEMPLATE
+U16 ZSTD_rotateRight_U16(U16 const value, U32 count) {
+    assert(count < 16);
+    count &= 0x0F; /* for fickle pattern recognition */
+    return (value >> count) | (U16)(value << ((0U - count) & 0x0F));
+}
 
 /* ZSTD_row_nextIndex():
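
Why the helpers shift by "(0U - count) & 0x1F" instead of "32 - count": a standalone sketch (not part of the diff) of the idiom, which keeps the shift amount in range even when count == 0 (a shift by 32 on a U32 would be undefined behavior). The "0U - count" spelling is also the work-around for the bogus "unary minus on unsigned" warning mentioned above.

    #include <assert.h>
    #include <stdint.h>

    static uint32_t rotr32(uint32_t value, uint32_t count) {
        count &= 0x1F;
        /* for count == 0 the left-shift amount is (0 - 0) & 0x1F == 0,
         * so neither shift ever reaches the operand width */
        return (value >> count) | (uint32_t)(value << ((0U - count) & 0x1F));
    }

    int main(void) {
        assert(rotr32(0x00000001u, 1) == 0x80000000u); /* LSB wraps around to the MSB */
        assert(rotr32(0xDEADBEEFu, 0) == 0xDEADBEEFu); /* rotation by 0 is the identity */
        return 0;
    }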
@@ -1226,24 +1035,98 @@ void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip) {
 /* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches
  * the hash at the nth position in a row of the tagTable.
- */
+ * Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield
+ * to match up with the actual layout of the entries within the hashTable */
 FORCE_INLINE_TEMPLATE
 ZSTD_VecMask ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries) {
-    ZSTD_VecMask matches = 0;
+    const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET;
+    assert((rowEntries == 16) || (rowEntries == 32));
+
+#if defined(ZSTD_ARCH_X86_SSE2)
+
     if (rowEntries == 16) {
-        ZSTD_Vec128 hashes = ZSTD_Vec128_read(tagRow + ZSTD_ROW_HASH_TAG_OFFSET);
-        ZSTD_Vec128 expandedTags = ZSTD_Vec128_set8(tag);
-        matches = ZSTD_Vec128_cmpMask8(hashes, expandedTags);
-    } else if (rowEntries == 32) {
-        ZSTD_Vec256 hashes = ZSTD_Vec256_read(tagRow + ZSTD_ROW_HASH_TAG_OFFSET);
-        ZSTD_Vec256 expandedTags = ZSTD_Vec256_set8(tag);
-        matches = ZSTD_Vec256_cmpMask8(hashes, expandedTags);
-    } else {
-        assert(0);
+        const __m128i chunk = _mm_loadu_si128((const __m128i*)(const void*)src);
+        const __m128i equalMask = _mm_cmpeq_epi8(chunk, _mm_set1_epi8(tag));
+        const U16 matches = (U16)_mm_movemask_epi8(equalMask);
+        return ZSTD_rotateRight_U16(matches, head);
+    } else { /* rowEntries == 32 */
+        const __m128i chunk0 = _mm_loadu_si128((const __m128i*)(const void*)&src[0]);
+        const __m128i chunk1 = _mm_loadu_si128((const __m128i*)(const void*)&src[16]);
+        const __m128i equalMask0 = _mm_cmpeq_epi8(chunk0, _mm_set1_epi8(tag));
+        const __m128i equalMask1 = _mm_cmpeq_epi8(chunk1, _mm_set1_epi8(tag));
+        const U32 lo = (U32)_mm_movemask_epi8(equalMask0);
+        const U32 hi = (U32)_mm_movemask_epi8(equalMask1);
+        return ZSTD_rotateRight_U32((hi << 16) | lo, head);
     }
-    /* Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield
-       to match up with the actual layout of the entries within the hashTable */
-    return ZSTD_VecMask_rotateRight(matches, head, rowEntries);
+
+#else
+
+#  if defined(ZSTD_ARCH_ARM_NEON)
+    if (MEM_isLittleEndian()) {
+        if (rowEntries == 16) {
+            const uint8x16_t chunk = vld1q_u8(src);
+            const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag)));
+            const uint16x8_t t0 = vshlq_n_u16(equalMask, 7);
+            const uint32x4_t t1 = vreinterpretq_u32_u16(vsriq_n_u16(t0, t0, 14));
+            const uint64x2_t t2 = vreinterpretq_u64_u32(vshrq_n_u32(t1, 14));
+            const uint8x16_t t3 = vreinterpretq_u8_u64(vsraq_n_u64(t2, t2, 28));
+            const U16 hi = (U16)vgetq_lane_u8(t3, 8);
+            const U16 lo = (U16)vgetq_lane_u8(t3, 0);
+            return ZSTD_rotateRight_U16((hi << 8) | lo, head);
+        } else { /* rowEntries == 32 */
+            const uint16x8x2_t chunk = vld2q_u16((const U16*)(const void*)src);
+            const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]);
+            const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]);
+            const uint8x16_t equalMask0 = vceqq_u8(chunk0, vdupq_n_u8(tag));
+            const uint8x16_t equalMask1 = vceqq_u8(chunk1, vdupq_n_u8(tag));
+            const int8x8_t pack0 = vqmovn_s16(vreinterpretq_s16_u8(equalMask0));
+            const int8x8_t pack1 = vqmovn_s16(vreinterpretq_s16_u8(equalMask1));
+            const uint8x8_t t0 = vreinterpret_u8_s8(pack0);
+            const uint8x8_t t1 = vreinterpret_u8_s8(pack1);
+            const uint8x8_t t2 = vsri_n_u8(t1, t0, 2);
+            const uint8x8x2_t t3 = vuzp_u8(t2, t0);
+            const uint8x8_t t4 = vsri_n_u8(t3.val[1], t3.val[0], 4);
+            const U32 matches = vget_lane_u32(vreinterpret_u32_u8(t4), 0);
+            return ZSTD_rotateRight_U32(matches, head);
+        }
+    }
+#  endif
+
+    { /* SWAR */
+        const size_t chunkSize = sizeof(size_t);
+        const size_t shiftAmount = ((chunkSize * 8) - chunkSize);
+        const size_t xFF = ~((size_t)0);
+        const size_t x01 = xFF / 0xFF;
+        const size_t x80 = x01 << 7;
+        const size_t splatChar = tag * x01;
+        size_t matches = 0;
+        int i = rowEntries - chunkSize;
+        assert((sizeof(size_t) == 4) || (sizeof(size_t) == 8));
+        if (MEM_isLittleEndian()) { /* runtime check so have two loops */
+            const size_t extractMagic = (xFF / 0x7F) >> chunkSize;
+            do {
+                size_t chunk = MEM_readST(&src[i]);
+                chunk ^= splatChar;
+                chunk = (((chunk | x80) - x01) | chunk) & x80;
+                matches <<= chunkSize;
+                matches |= (chunk * extractMagic) >> shiftAmount;
+                i -= chunkSize;
+            } while (i >= 0);
+        } else { /* big endian: reverse bits during extraction */
+            const size_t msb = xFF ^ (xFF >> 1);
+            const size_t extractMagic = (msb / 0x1FF) | msb;
+            do {
+                size_t chunk = MEM_readST(&src[i]);
+                chunk ^= splatChar;
+                chunk = (((chunk | x80) - x01) | chunk) & x80;
+                matches <<= chunkSize;
+                matches |= ((chunk >> 7) * extractMagic) >> shiftAmount;
+                i -= chunkSize;
+            } while (i >= 0);
+        }
+        matches = ~matches;
+        if (rowEntries == 16) {
+            return ZSTD_rotateRight_U16((U16)matches, head);
+        } else { /* rowEntries == 32 */
+            return ZSTD_rotateRight_U32((U32)matches, head);
+        }
+    }
+#endif
 }
 
 /* The high-level approach of the SIMD row based match finder is as follows:
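
To make the SWAR extraction above easier to audit, here is a hedged standalone sketch (little-endian, 64-bit only; not code from the commit) of the two tricks it combines: the "(((chunk | x80) - x01) | chunk) & x80" zero-byte test and the multiply that gathers the per-byte MSBs into a contiguous bitfield:

    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    /* Returns one bit per byte of "chunk": bit i is set iff byte i equals "tag". */
    static uint64_t swarMatchMask8(uint64_t chunk, uint8_t tag) {
        const uint64_t x01 = UINT64_C(0x0101010101010101);  /* xFF / 0xFF */
        const uint64_t x80 = x01 << 7;
        const uint64_t extractMagic = (~UINT64_C(0) / 0x7F) >> 8;
        chunk ^= (uint64_t)tag * x01;                    /* matching bytes become 0x00 */
        chunk = (((chunk | x80) - x01) | chunk) & x80;   /* MSB set iff byte was nonzero */
        return (~((chunk * extractMagic) >> 56)) & 0xFF; /* gather MSBs, invert: 1 = match */
    }

    int main(void) {
        const uint8_t row[8] = { 0xAB, 0x00, 0xAB, 0x11, 0x22, 0xAB, 0x33, 0xAB };
        uint64_t chunk;
        memcpy(&chunk, row, sizeof(chunk));              /* assumes a little-endian host */
        assert(swarMatchMask8(chunk, 0xAB) == 0xA5);     /* matches at indices 0, 2, 5, 7 */
        return 0;
    }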