mirror of
https://github.com/facebook/zstd.git
synced 2025-10-04 00:02:33 -04:00
control long length within AVX2 implementation
This commit is contained in:
parent
d1f0e5fb97
commit
8d62164589
@ -7121,8 +7121,12 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* cctx,
|
||||
* At the end, instead of extracting two __m128i,
|
||||
* we use _mm256_permute4x64_epi64(..., 0xE8) to move lane2 into lane1,
|
||||
* then store the lower 16 bytes in one go.
|
||||
*
|
||||
* @returns 0 on success, with no long length detected
|
||||
* @returns > 0 if there is one long length (> 65535),
|
||||
* indicating its position and type: (pos + 1) for a matchLength, (pos + nbSequences + 1) for a litLength.
|
||||
*/
|
||||
void convertSequences_noRepcodes(
|
||||
size_t convertSequences_noRepcodes(
|
||||
SeqDef* dstSeqs,
|
||||
const ZSTD_Sequence* inSeqs,
|
||||
size_t nbSequences)
|
||||
@ -7136,6 +7140,9 @@ void convertSequences_noRepcodes(
|
||||
ZSTD_REP_NUM, 0, -MINMATCH, 0 /* for sequence i+1 */
|
||||
);
|
||||
|
||||
/* limit: check if there is a long length */
|
||||
const __m256i limit = _mm256_set1_epi32(65535);
|
||||
|
||||
/*
|
||||
* shuffle mask for byte-level rearrangement in each 128-bit half:
|
||||
*
|
||||
@ -7170,16 +7177,20 @@ void convertSequences_noRepcodes(
|
||||
*/
|
||||
#define PERM_LANE_0X_E8 0xE8 /* [0,2,2,3] in lane indices */
|
||||
|
||||
size_t i = 0;
|
||||
size_t longLen = 0, i = 0;
|
||||
/* Process 2 sequences per loop iteration */
|
||||
for (; i + 1 < nbSequences; i += 2) {
|
||||
/* 1) Load 2 ZSTD_Sequence (32 bytes) */
|
||||
/* Load 2 ZSTD_Sequence (32 bytes) */
|
||||
__m256i vin = _mm256_loadu_si256((__m256i const*)&inSeqs[i]);
|
||||
|
||||
/* 2) Add {2, 0, -3, 0} in each 128-bit half */
|
||||
/* Add {2, 0, -3, 0} in each 128-bit half */
|
||||
__m256i vadd = _mm256_add_epi32(vin, addition);
|
||||
|
||||
/* 3) Shuffle bytes so each half gives us the 8 bytes we need */
|
||||
/* Check for long length */
|
||||
__m256i cmp = _mm256_cmpgt_epi32(vadd, limit); // 0xFFFFFFFF for element > 65535
|
||||
int cmp_res = _mm256_movemask_epi8(cmp);
|
||||
|
||||
/* Shuffle bytes so each half gives us the 8 bytes we need */
|
||||
__m256i vshf = _mm256_shuffle_epi8(vadd, mask);
|
||||
/*
|
||||
* Now:
|
||||
@ -7189,20 +7200,44 @@ void convertSequences_noRepcodes(
|
||||
* Lane3 = 0
|
||||
*/
|
||||
|
||||
/* 4) Permute 64-bit lanes => move Lane2 down into Lane1. */
|
||||
/* Permute 64-bit lanes => move Lane2 down into Lane1. */
|
||||
__m256i vperm = _mm256_permute4x64_epi64(vshf, PERM_LANE_0X_E8);
|
||||
/*
|
||||
* Now the lower 16 bytes (Lane0+Lane1) = [seq0, seq1].
|
||||
* The upper 16 bytes are [Lane2, Lane3] = [seq1, 0], but we won't use them.
|
||||
*/
|
||||
|
||||
/* 5) Store only the lower 16 bytes => 2 SeqDef (8 bytes each) */
|
||||
/* Store only the lower 16 bytes => 2 SeqDef (8 bytes each) */
|
||||
_mm_storeu_si128((__m128i *)&dstSeqs[i], _mm256_castsi256_si128(vperm));
|
||||
/*
|
||||
* This writes out 16 bytes total:
|
||||
* - offset 0..7 => seq0 (offBase, litLength, mlBase)
|
||||
* - offset 8..15 => seq1 (offBase, litLength, mlBase)
|
||||
*/
|
||||
|
||||
/* check (unlikely) long lengths > 65535
|
||||
* indices for lengths correspond to bits [4..7], [8..11], [20..23], [24..27]
|
||||
* => combined mask = 0x0FF00FF0
|
||||
*/
|
||||
if (UNLIKELY((cmp_res & 0x0FF00FF0) != 0)) {
|
||||
/* long length detected: let's figure out which one*/
|
||||
if (inSeqs[i].matchLength > 65535+MINMATCH) {
|
||||
assert(longLen == 0);
|
||||
longLen = i + 1;
|
||||
}
|
||||
if (inSeqs[i].litLength > 65535) {
|
||||
assert(longLen == 0);
|
||||
longLen = i + nbSequences + 1;
|
||||
}
|
||||
if (inSeqs[i+1].matchLength > 65535+MINMATCH) {
|
||||
assert(longLen == 0);
|
||||
longLen = i + 1 + 1;
|
||||
}
|
||||
if (inSeqs[i+1].litLength > 65535) {
|
||||
assert(longLen == 0);
|
||||
longLen = i + 1 + nbSequences + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Handle leftover if @nbSequences is odd */
|
||||
@ -7213,93 +7248,24 @@ void convertSequences_noRepcodes(
|
||||
/* note: doesn't work if one length is > 65535 */
|
||||
dstSeqs[i].litLength = (U16)inSeqs[i].litLength;
|
||||
dstSeqs[i].mlBase = (U16)(inSeqs[i].matchLength - MINMATCH);
|
||||
if (UNLIKELY(inSeqs[i].matchLength > 65535+MINMATCH)) {
|
||||
assert(longLen == 0);
|
||||
longLen = i + 1;
|
||||
}
|
||||
if (UNLIKELY(inSeqs[i].litLength > 65535)) {
|
||||
assert(longLen == 0);
|
||||
longLen = i + nbSequences + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return longLen;
|
||||
}
|
||||
|
||||
#elif defined(__SSSE3__)
|
||||
/* the vector implementation could also be ported to SSSE3,
|
||||
* but since this implementation is targeting modern systems >= Sapphire Rapids,
|
||||
* it's not useful to develop and maintain code for older platforms (before AVX2) */
|
||||
|
||||
#include <tmmintrin.h> /* SSSE3 intrinsics: _mm_shuffle_epi8 */
|
||||
#include <emmintrin.h> /* SSE2 intrinsics: _mm_add_epi32, etc. */
|
||||
|
||||
/*
|
||||
* Convert sequences with SSE.
|
||||
* - offset -> offBase = offset + 2
|
||||
* - litLength (32-bit) -> (U16) litLength
|
||||
* - matchLength (32-bit) -> (U16)(matchLength - 3)
|
||||
* - rep is discarded.
|
||||
*
|
||||
* We shuffle so that only the first 8 bytes in the final 128-bit
|
||||
* register are used; only those low 8 bytes are stored (_mm_storel_epi64), the high 8 bytes are discarded.
|
||||
*/
|
||||
static void convertSequences_noRepcodes(SeqDef* dstSeqs,
|
||||
const ZSTD_Sequence* inSeqs,
|
||||
size_t nbSequences)
|
||||
{
|
||||
/*
|
||||
addition = { offset+2, litLength+0, matchLength-3, rep+0 }
|
||||
setr means the first argument is placed in the lowest 32 bits,
|
||||
second in the next-higher 32 bits, etc.
|
||||
*/
|
||||
const __m128i addition = _mm_setr_epi32(2, 0, -3, 0);
|
||||
|
||||
/*
|
||||
Shuffle mask: we reorder bytes after the addition.
|
||||
|
||||
Input layout in 128-bit register (after addition):
|
||||
Bytes: [ 0..3 | 4..7 | 8..11 | 12..15 ]
|
||||
Fields: offset+2 litLength matchLength rep
|
||||
|
||||
We want in output:
|
||||
Bytes: [ 0..3 | 4..5 | 6..7 | 8..15 ignore ]
|
||||
Fields: offset+2 (U16)litLength (U16)(matchLength)
|
||||
|
||||
_mm_shuffle_epi8 picks bytes from the source. A byte of 0x80 means “zero out”.
|
||||
So we want:
|
||||
out[0] = in[0], out[1] = in[1], out[2] = in[2], out[3] = in[3], // offset+2 (4 bytes)
|
||||
out[4] = in[4], out[5] = in[5], // (U16) litLength
|
||||
out[6] = in[8], out[7] = in[9], // (U16)(matchLength - 3)
|
||||
out[8..15] = 0x80 => won't matter if we only care about first 8 bytes
|
||||
*/
|
||||
const __m128i mask = _mm_setr_epi8(
|
||||
0, 1, 2, 3, /* offset (4 bytes) */
|
||||
4, 5, /* litLength (2 bytes) */
|
||||
8, 9, /* matchLength (2 bytes) */
|
||||
(char)0x80, (char)0x80, (char)0x80, (char)0x80,
|
||||
(char)0x80, (char)0x80, (char)0x80, (char)0x80
|
||||
);
|
||||
size_t i;
|
||||
|
||||
for (i = 0; i + 1 < nbSequences; i += 2) {
|
||||
/*-------------------------*/
|
||||
/* Process inSeqs[i] */
|
||||
/*-------------------------*/
|
||||
__m128i vin0 = _mm_loadu_si128((const __m128i *)(const void*)&inSeqs[i]);
|
||||
__m128i vadd0 = _mm_add_epi32(vin0, addition);
|
||||
__m128i vshf0 = _mm_shuffle_epi8(vadd0, mask);
|
||||
_mm_storel_epi64((__m128i *)(void*)&dstSeqs[i], vshf0);
|
||||
|
||||
/*-------------------------*/
|
||||
/* Process inSeqs[i + 1] */
|
||||
/*-------------------------*/
|
||||
__m128i vin1 = _mm_loadu_si128((__m128i const *)(const void*)&inSeqs[i + 1]);
|
||||
__m128i vadd1 = _mm_add_epi32(vin1, addition);
|
||||
__m128i vshf1 = _mm_shuffle_epi8(vadd1, mask);
|
||||
_mm_storel_epi64((__m128i *)(void*)&dstSeqs[i + 1], vshf1);
|
||||
}
|
||||
|
||||
/* Handle leftover if nbSequences is odd */
|
||||
if (i < nbSequences) {
|
||||
/* Fallback: process last sequence */
|
||||
assert(i == nbSequences - 1);
|
||||
dstSeqs[i].offBase = OFFSET_TO_OFFBASE(inSeqs[i].offset);
|
||||
/* note: doesn't work if one length is > 65535 */
|
||||
dstSeqs[i].litLength = (U16)inSeqs[i].litLength;
|
||||
dstSeqs[i].mlBase = (U16)(inSeqs[i].matchLength - MINMATCH);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#else /* no SSE */
|
||||
#else /* no AVX2 */
|
||||
|
||||
static size_t
|
||||
convertSequences_noRepcodes(SeqDef* dstSeqs,
|
||||
@ -7312,6 +7278,10 @@ convertSequences_noRepcodes(SeqDef* dstSeqs,
|
||||
/* note: doesn't work if one length is > 65535 */
|
||||
dstSeqs[n].litLength = (U16)inSeqs[n].litLength;
|
||||
dstSeqs[n].mlBase = (U16)(inSeqs[n].matchLength - MINMATCH);
|
||||
if (UNLIKELY(inSeqs[n].matchLength > 65535+MINMATCH)) {
|
||||
assert(longLen == 0);
|
||||
longLen = n + 1;
|
||||
}
|
||||
if (UNLIKELY(inSeqs[n].litLength > 65535)) {
|
||||
assert(longLen == 0);
|
||||
longLen = n + nbSequences + 1;
|
||||
|
Loading…
x
Reference in New Issue
Block a user