control long length within AVX2 implementation

Yann Collet 2025-01-07 16:42:36 -08:00 committed by Yann Collet
parent d1f0e5fb97
commit 8d62164589


@ -7121,8 +7121,12 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* cctx,
* At the end, instead of extracting two __m128i,
* we use _mm256_permute4x64_epi64(..., 0xE8) to move lane2 into lane1,
* then store the lower 16 bytes in one go.
*
* @returns 0 on success, with no long length detected
* @returns > 0 if there is one long length (> 65535),
* indicating its position and type (see the decoding sketch after this function).
*/
void convertSequences_noRepcodes(
size_t convertSequences_noRepcodes(
SeqDef* dstSeqs,
const ZSTD_Sequence* inSeqs,
size_t nbSequences)
@ -7136,6 +7140,9 @@ void convertSequences_noRepcodes(
ZSTD_REP_NUM, 0, -MINMATCH, 0 /* for sequence i+1 */
);
/* limit: check if there is a long length */
const __m256i limit = _mm256_set1_epi32(65535);
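/* note: _mm256_cmpgt_epi32() is a *signed* 32-bit comparison.
 * That's fine here: only the litLength / matchLength bytes of the result
 * are inspected (see the 0x0FF00FF0 mask below), and lengths stay far
 * below INT_MAX in practice; the offset fields, which could conceivably
 * turn negative under a signed reinterpretation, are masked out. */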
/*
* shuffle mask for byte-level rearrangement in each 128-bit half:
*
@ -7170,16 +7177,20 @@ void convertSequences_noRepcodes(
*/
#define PERM_LANE_0X_E8 0xE8 /* [0,2,2,3] in lane indices */
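/* 0xE8 = 0b11101000, read two bits per destination lane, low to high:
 * dst lane0 <- src lane0, dst lane1 <- src lane2,
 * dst lane2 <- src lane2, dst lane3 <- src lane3 */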
size_t i = 0;
size_t longLen = 0, i = 0;
/* Process 2 sequences per loop iteration */
for (; i + 1 < nbSequences; i += 2) {
/* 1) Load 2 ZSTD_Sequence (32 bytes) */
/* Load 2 ZSTD_Sequence (32 bytes) */
__m256i vin = _mm256_loadu_si256((__m256i const*)&inSeqs[i]);
/* 2) Add {2, 0, -3, 0} in each 128-bit half */
/* Add {ZSTD_REP_NUM, 0, -MINMATCH, 0} in each 128-bit half */
__m256i vadd = _mm256_add_epi32(vin, addition);
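/* per 32-bit field: offset+ZSTD_REP_NUM => offBase, litLength+0,
 * matchLength-MINMATCH => mlBase, rep+0 (rep is dropped by the shuffle below) */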
/* 3) Shuffle bytes so each half gives us the 8 bytes we need */
/* Check for long length */
__m256i cmp = _mm256_cmpgt_epi32(vadd, limit); /* 0xFFFFFFFF for elements > 65535 */
int cmp_res = _mm256_movemask_epi8(cmp);
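/* @cmp_res holds one bit per byte of @cmp:
 * bit b is set iff byte b belongs to a 32-bit element that exceeded 65535 */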
/* Shuffle bytes so each half gives us the 8 bytes we need */
__m256i vshf = _mm256_shuffle_epi8(vadd, mask);
/*
* Now:
@ -7189,20 +7200,44 @@ void convertSequences_noRepcodes(
* Lane3 = 0
*/
/* 4) Permute 64-bit lanes => move Lane2 down into Lane1. */
/* Permute 64-bit lanes => move Lane2 down into Lane1. */
__m256i vperm = _mm256_permute4x64_epi64(vshf, PERM_LANE_0X_E8);
/*
* Now the lower 16 bytes (Lane0+Lane1) = [seq0, seq1].
* The upper 16 bytes are [Lane2, Lane3] = [seq1, 0], but we won't use them.
*/
/* 5) Store only the lower 16 bytes => 2 SeqDef (8 bytes each) */
/* Store only the lower 16 bytes => 2 SeqDef (8 bytes each) */
_mm_storeu_si128((__m128i *)&dstSeqs[i], _mm256_castsi256_si128(vperm));
/*
* This writes out 16 bytes total:
* - offset 0..7 => seq0 (offBase, litLength, mlBase)
* - offset 8..15 => seq1 (offBase, litLength, mlBase)
*/
/* check (unlikely) long lengths > 65535
* indices for lengths correspond to bits [4..7], [8..11], [20..23], [24..27]
* => combined mask = 0x0FF00FF0
*/
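/* derivation: (0xFFu << 4) | (0xFFu << 20) == 0x0FF00FF0,
 * i.e. bytes 4..11 (lengths of sequence i) and bytes 20..27 (lengths of sequence i+1) */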
if (UNLIKELY((cmp_res & 0x0FF00FF0) != 0)) {
/* long length detected: let's figure out which one */
if (inSeqs[i].matchLength > 65535+MINMATCH) {
assert(longLen == 0);
longLen = i + 1;
}
if (inSeqs[i].litLength > 65535) {
assert(longLen == 0);
longLen = i + nbSequences + 1;
}
if (inSeqs[i+1].matchLength > 65535+MINMATCH) {
assert(longLen == 0);
longLen = i + 1 + 1;
}
if (inSeqs[i+1].litLength > 65535) {
assert(longLen == 0);
longLen = i + 1 + nbSequences + 1;
}
}
}
/* Handle leftover if @nbSequences is odd */
@ -7213,93 +7248,24 @@ void convertSequences_noRepcodes(
/* note: the (U16) casts truncate any length > 65535; such lengths are detected and reported below */
dstSeqs[i].litLength = (U16)inSeqs[i].litLength;
dstSeqs[i].mlBase = (U16)(inSeqs[i].matchLength - MINMATCH);
if (UNLIKELY(inSeqs[i].matchLength > 65535+MINMATCH)) {
assert(longLen == 0);
longLen = i + 1;
}
if (UNLIKELY(inSeqs[i].litLength > 65535)) {
assert(longLen == 0);
longLen = i + nbSequences + 1;
}
}
return longLen;
}
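/* Illustrative sketch (not part of this commit): one way a caller could
 * decode the return value of convertSequences_noRepcodes(), following the
 * contract documented above. The helper name is hypothetical. */
static void decodeLongLen_sketch(size_t longLen, size_t nbSequences,
                                 size_t* seqIndex, int* isLitLength)
{
    assert(longLen > 0);  /* 0 means "no long length detected" */
    if (longLen > nbSequences) {
        /* encoded as i + nbSequences + 1 => long litLength at index i */
        *isLitLength = 1;
        *seqIndex = longLen - nbSequences - 1;
    } else {
        /* encoded as i + 1 => long matchLength at index i */
        *isLitLength = 0;
        *seqIndex = longLen - 1;
    }
}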
#elif defined(__SSSE3__)
/* The vector implementation could also be ported to SSSE3,
 * but since this implementation targets modern systems (>= Sapphire Rapids),
 * it's not useful to develop and maintain code for older platforms (pre-AVX2) */
#include <tmmintrin.h> /* SSSE3 intrinsics: _mm_shuffle_epi8 */
#include <emmintrin.h> /* SSE2 intrinsics: _mm_add_epi32, etc. */
/*
* Convert sequences with SSE.
* - offset -> offBase = offset + 2
* - litLength (32-bit) -> (U16) litLength
* - matchLength (32-bit) -> (U16)(matchLength - 3)
* - rep is discarded.
*
* We shuffle so that only the first 8 bytes in the final 128-bit
* register are meaningful, and store just those 8 bytes per sequence
* (_mm_storel_epi64); the upper bytes are "don't care".
*/
static void convertSequences_noRepcodes(SeqDef* dstSeqs,
const ZSTD_Sequence* inSeqs,
size_t nbSequences)
{
/*
addition = { offset+2, litLength+0, matchLength-3, rep+0 }
setr means the first argument is placed in the lowest 32 bits,
second in next-lower 32 bits, etc.
*/
const __m128i addition = _mm_setr_epi32(2, 0, -3, 0);
/*
Shuffle mask: we reorder bytes after the addition.
Input layout in 128-bit register (after addition):
Bytes: [ 0..3 | 4..7 | 8..11 | 12..15 ]
Fields: offset+2 litLength matchLength rep
We want in output:
Bytes: [ 0..3 | 4..5 | 6..7 | 8..15 ignore ]
Fields: offset+2 (U16)litLength (U16)(matchLength)
_mm_shuffle_epi8 picks bytes from the source. A byte of 0x80 means zero out.
So we want:
out[0] = in[0], out[1] = in[1], out[2] = in[2], out[3] = in[3], // offset+2 (4 bytes)
out[4] = in[4], out[5] = in[5], // (U16) litLength
out[6] = in[8], out[7] = in[9], // (U16) matchLength
out[8..15] = 0x80 => won't matter if we only care about first 8 bytes
*/
const __m128i mask = _mm_setr_epi8(
0, 1, 2, 3, /* offset (4 bytes) */
4, 5, /* litLength (2 bytes) */
8, 9, /* matchLength (2 bytes) */
(char)0x80, (char)0x80, (char)0x80, (char)0x80,
(char)0x80, (char)0x80, (char)0x80, (char)0x80
);
size_t i;
for (i = 0; i + 1 < nbSequences; i += 2) {
/*-------------------------*/
/* Process inSeqs[i] */
/*-------------------------*/
__m128i vin0 = _mm_loadu_si128((const __m128i *)(const void*)&inSeqs[i]);
__m128i vadd0 = _mm_add_epi32(vin0, addition);
__m128i vshf0 = _mm_shuffle_epi8(vadd0, mask);
_mm_storel_epi64((__m128i *)(void*)&dstSeqs[i], vshf0);
/*-------------------------*/
/* Process inSeqs[i + 1] */
/*-------------------------*/
__m128i vin1 = _mm_loadu_si128((__m128i const *)(const void*)&inSeqs[i + 1]);
__m128i vadd1 = _mm_add_epi32(vin1, addition);
__m128i vshf1 = _mm_shuffle_epi8(vadd1, mask);
_mm_storel_epi64((__m128i *)(void*)&dstSeqs[i + 1], vshf1);
}
/* Handle leftover if nbSequences is odd */
if (i < nbSequences) {
/* Fallback: process last sequence */
assert(i == nbSequences - 1);
dstSeqs[i].offBase = OFFSET_TO_OFFBASE(inSeqs[i].offset);
/* note: doesn't work if one length is > 65535 */
dstSeqs[i].litLength = (U16)inSeqs[i].litLength;
dstSeqs[i].mlBase = (U16)(inSeqs[i].matchLength - MINMATCH);
}
}
#else /* no SSE */
#else /* no AVX2 */
static size_t
convertSequences_noRepcodes(SeqDef* dstSeqs,
@ -7312,6 +7278,10 @@ convertSequences_noRepcodes(SeqDef* dstSeqs,
/* note: the (U16) casts truncate any length > 65535; such lengths are detected and reported below */
dstSeqs[n].litLength = (U16)inSeqs[n].litLength;
dstSeqs[n].mlBase = (U16)(inSeqs[n].matchLength - MINMATCH);
if (UNLIKELY(inSeqs[n].matchLength > 65535+MINMATCH)) {
assert(longLen == 0);
longLen = n + 1;
}
if (UNLIKELY(inSeqs[n].litLength > 65535)) {
assert(longLen == 0);
longLen = n + nbSequences + 1;