diff --git a/lib/decompress/huf_decompress.c b/lib/decompress/huf_decompress.c index 5bbdef49a..b8795efb5 100644 --- a/lib/decompress/huf_decompress.c +++ b/lib/decompress/huf_decompress.c @@ -164,17 +164,18 @@ static size_t HUF_initFastDStream(BYTE const* ip) { * op [in/out] - The output pointers, must be updated to reflect what is written. * bits [in/out] - The bitstream containers, must be updated to reflect the current state. * dt [in] - The decoding table. - * ilimit [in] - The input limit, stop when any input pointer is below ilimit. + * ilowest [in] - The beginning of the valid range of the input. Decoders may read + * down to this pointer. It may be below iend[0]. * oend [in] - The end of the output stream. op[3] must not cross oend. * iend [in] - The end of each input stream. ip[i] may cross iend[i], - * as long as it is above ilimit, but that indicates corruption. + * as long as it is above ilowest, but that indicates corruption. */ typedef struct { BYTE const* ip[4]; BYTE* op[4]; U64 bits[4]; void const* dt; - BYTE const* ilimit; + BYTE const* ilowest; BYTE* oend; BYTE const* iend[4]; } HUF_DecompressFastArgs; @@ -192,7 +193,7 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds void const* dt = DTable + 1; U32 const dtLog = HUF_getDTableDesc(DTable).tableLog; - const BYTE* const ilimit = (const BYTE*)src + 6 + 8; + const BYTE* const istart = (const BYTE*)src; BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); @@ -215,7 +216,6 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds /* Read the jump table. 
*/ { - const BYTE* const istart = (const BYTE*)src; size_t const length1 = MEM_readLE16(istart); size_t const length2 = MEM_readLE16(istart+2); size_t const length3 = MEM_readLE16(istart+4); @@ -227,10 +227,8 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds /* HUF_initFastDStream() requires this, and this small of an input * won't benefit from the ASM loop anyways. - * length1 must be >= 16 so that ip[0] >= ilimit before the loop - * starts. */ - if (length1 < 16 || length2 < 8 || length3 < 8 || length4 < 8) + if (length1 < 8 || length2 < 8 || length3 < 8 || length4 < 8) return 0; if (length4 > srcSize) return ERROR(corruption_detected); /* overflow */ } @@ -262,11 +260,12 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds args->bits[2] = HUF_initFastDStream(args->ip[2]); args->bits[3] = HUF_initFastDStream(args->ip[3]); - /* If ip[] >= ilimit, it is guaranteed to be safe to - * reload bits[]. It may be beyond its section, but is - * guaranteed to be valid (>= istart). - */ - args->ilimit = ilimit; + /* The decoders must be sure to never read below ilowest. + * This is lower than iend[0], but allowing decoders to read + * down to ilowest can allow an extra iteration or two in the + * fast loop. 
+ */ + args->ilowest = istart; args->oend = oend; args->dt = dt; @@ -291,7 +290,7 @@ static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressFastArg assert(sizeof(size_t) == 8); bit->bitContainer = MEM_readLEST(args->ip[stream]); bit->bitsConsumed = ZSTD_countTrailingZeros64(args->bits[stream]); - bit->start = (const char*)args->iend[0]; + bit->start = (const char*)args->ilowest; bit->limitPtr = bit->start + sizeof(size_t); bit->ptr = (const char*)args->ip[stream]; @@ -717,7 +716,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* BYTE* op[4]; U16 const* const dtable = (U16 const*)args->dt; BYTE* const oend = args->oend; - BYTE const* const ilimit = args->ilimit; + BYTE const* const ilowest = args->ilowest; /* Copy the arguments to local variables */ ZSTD_memcpy(&bits, &args->bits, sizeof(bits)); @@ -735,7 +734,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* #ifndef NDEBUG for (stream = 0; stream < 4; ++stream) { assert(op[stream] <= (stream == 3 ? oend : op[stream + 1])); - assert(ip[stream] >= ilimit); + assert(ip[stream] >= ilowest); } #endif /* Compute olimit */ @@ -745,7 +744,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* /* Each iteration consumes up to 11 bits * 5 = 55 bits < 7 bytes * per stream. */ - size_t const iiters = (size_t)(ip[0] - ilimit) / 7; + size_t const iiters = (size_t)(ip[0] - ilowest) / 7; /* We can safely run iters iterations before running bounds checks */ size_t const iters = MIN(oiters, iiters); size_t const symbols = iters * 5; @@ -756,8 +755,8 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* */ olimit = op[3] + symbols; - /* Exit fast decoding loop once we get close to the end. */ - if (op[3] + 20 > olimit) + /* Exit fast decoding loop once no more iterations are safe (iters == 0). 
*/ + if (op[3] == olimit) break; /* Exit the decoding loop if any input pointer has crossed the @@ -836,7 +835,7 @@ HUF_decompress4X1_usingDTable_internal_fast( HUF_DecompressFastLoopFn loopFn) { void const* dt = DTable + 1; - const BYTE* const iend = (const BYTE*)cSrc + 6; + BYTE const* const ilowest = (BYTE const*)cSrc; BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); HUF_DecompressFastArgs args; { size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); @@ -845,18 +844,22 @@ HUF_decompress4X1_usingDTable_internal_fast( return 0; } - assert(args.ip[0] >= args.ilimit); + assert(args.ip[0] >= args.ilowest); loopFn(&args); - /* Our loop guarantees that ip[] >= ilimit and that we haven't + /* Our loop guarantees that ip[] >= ilowest and that we haven't * overwritten any op[]. */ - assert(args.ip[0] >= iend); - assert(args.ip[1] >= iend); - assert(args.ip[2] >= iend); - assert(args.ip[3] >= iend); + assert(args.ip[0] >= ilowest); + assert(args.ip[0] >= ilowest); + assert(args.ip[1] >= ilowest); + assert(args.ip[2] >= ilowest); + assert(args.ip[3] >= ilowest); assert(args.op[3] <= oend); - (void)iend; + + assert(ilowest == args.ilowest); + assert(ilowest + 6 == args.iend[0]); + (void)ilowest; /* finish bit streams one by one. */ { size_t const segmentSize = (dstSize+3) / 4; @@ -1512,7 +1515,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* BYTE* op[4]; BYTE* oend[4]; HUF_DEltX2 const* const dtable = (HUF_DEltX2 const*)args->dt; - BYTE const* const ilimit = args->ilimit; + BYTE const* const ilowest = args->ilowest; /* Copy the arguments to local registers. 
*/ ZSTD_memcpy(&bits, &args->bits, sizeof(bits)); @@ -1535,7 +1538,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* #ifndef NDEBUG for (stream = 0; stream < 4; ++stream) { assert(op[stream] <= oend[stream]); - assert(ip[stream] >= ilimit); + assert(ip[stream] >= ilowest); } #endif /* Compute olimit */ @@ -1548,7 +1551,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* * We also know that each input pointer is >= ip[0]. So we can run * iters loops before running out of input. */ - size_t iters = (size_t)(ip[0] - ilimit) / 7; + size_t iters = (size_t)(ip[0] - ilowest) / 7; /* Each iteration can produce up to 10 bytes of output per stream. * Each output stream my advance at different rates. So take the * minimum number of safe iterations among all the output streams. @@ -1566,8 +1569,8 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* */ olimit = op[3] + (iters * 5); - /* Exit the fast decoding loop if we are too close to the end. */ - if (op[3] + 10 > olimit) + /* Exit the fast decoding loop once no more iterations are safe (iters == 0). 
*/ + if (op[3] == olimit) break; /* Exit the decoding loop if any input pointer has crossed the @@ -1652,7 +1655,7 @@ HUF_decompress4X2_usingDTable_internal_fast( const HUF_DTable* DTable, HUF_DecompressFastLoopFn loopFn) { void const* dt = DTable + 1; - const BYTE* const iend = (const BYTE*)cSrc + 6; + const BYTE* const ilowest = (const BYTE*)cSrc; BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); HUF_DecompressFastArgs args; { @@ -1662,16 +1665,19 @@ HUF_decompress4X2_usingDTable_internal_fast( return 0; } - assert(args.ip[0] >= args.ilimit); + assert(args.ip[0] >= args.ilowest); loopFn(&args); /* note : op4 already verified within main loop */ - assert(args.ip[0] >= iend); - assert(args.ip[1] >= iend); - assert(args.ip[2] >= iend); - assert(args.ip[3] >= iend); + assert(args.ip[0] >= ilowest); + assert(args.ip[1] >= ilowest); + assert(args.ip[2] >= ilowest); + assert(args.ip[3] >= ilowest); assert(args.op[3] <= oend); - (void)iend; + + assert(ilowest == args.ilowest); + assert(ilowest + 6 == args.iend[0]); + (void)ilowest; /* finish bitStreams one by one */ { diff --git a/lib/decompress/huf_decompress_amd64.S b/lib/decompress/huf_decompress_amd64.S index 671624fe3..3b96b4461 100644 --- a/lib/decompress/huf_decompress_amd64.S +++ b/lib/decompress/huf_decompress_amd64.S @@ -131,7 +131,7 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop: movq 88(%rax), %bits3 movq 96(%rax), %dtable push %rax /* argument */ - push 104(%rax) /* ilimit */ + push 104(%rax) /* ilowest */ push 112(%rax) /* oend */ push %olimit /* olimit space */ @@ -156,11 +156,11 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop: shrq $2, %r15 movq %ip0, %rax /* rax = ip0 */ - movq 40(%rsp), %rdx /* rdx = ilimit */ - subq %rdx, %rax /* rax = ip0 - ilimit */ - movq %rax, %rbx /* rbx = ip0 - ilimit */ + movq 40(%rsp), %rdx /* rdx = ilowest */ + subq %rdx, %rax /* rax = ip0 - ilowest */ + movq %rax, %rbx /* rbx = ip0 - ilowest */ - /* rdx = (ip0 - ilimit) / 7 */ + /* rdx = (ip0 - 
ilowest) / 7 */ movabsq $2635249153387078803, %rdx mulq %rdx subq %rdx, %rbx @@ -183,9 +183,8 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop: /* If (op3 + 20 > olimit) */ movq %op3, %rax /* rax = op3 */ - addq $20, %rax /* rax = op3 + 20 */ - cmpq %rax, %olimit /* op3 + 20 > olimit */ - jb .L_4X1_exit + cmpq %rax, %olimit /* op3 == olimit, replaces the old (op3 + 20 > olimit) test */ + je .L_4X1_exit /* If (ip1 < ip0) go to exit */ cmpq %ip0, %ip1 @@ -316,7 +315,7 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop: /* Restore stack (oend & olimit) */ pop %rax /* olimit */ pop %rax /* oend */ - pop %rax /* ilimit */ + pop %rax /* ilowest */ pop %rax /* arg */ /* Save ip / op / bits */ @@ -387,7 +386,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop: movq 96(%rax), %dtable push %rax /* argument */ push %rax /* olimit */ - push 104(%rax) /* ilimit */ + push 104(%rax) /* ilowest */ movq 112(%rax), %rax push %rax /* oend3 */ @@ -414,9 +413,9 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop: /* We can consume up to 7 input bytes each iteration. 
*/ movq %ip0, %rax /* rax = ip0 */ - movq 40(%rsp), %rdx /* rdx = ilimit */ - subq %rdx, %rax /* rax = ip0 - ilimit */ - movq %rax, %r15 /* r15 = ip0 - ilimit */ + movq 40(%rsp), %rdx /* rdx = ilowest */ + subq %rdx, %rax /* rax = ip0 - ilowest */ + movq %rax, %r15 /* r15 = ip0 - ilowest */ /* rdx = rax / 7 */ movabsq $2635249153387078803, %rdx @@ -426,7 +425,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop: addq %r15, %rdx shrq $2, %rdx - /* r15 = (ip0 - ilimit) / 7 */ + /* r15 = (ip0 - ilowest) / 7 */ movq %rdx, %r15 /* r15 = min(r15, min(oend0 - op0, oend1 - op1, oend2 - op2, oend3 - op3) / 10) */ @@ -467,9 +466,8 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop: /* If (op3 + 10 > olimit) */ movq %op3, %rax /* rax = op3 */ - addq $10, %rax /* rax = op3 + 10 */ - cmpq %rax, %olimit /* op3 + 10 > olimit */ - jb .L_4X2_exit + cmpq %rax, %olimit /* op3 == olimit, replaces the old (op3 + 10 > olimit) test */ + je .L_4X2_exit /* If (ip1 < ip0) go to exit */ cmpq %ip0, %ip1 @@ -537,7 +535,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop: pop %rax /* oend1 */ pop %rax /* oend2 */ pop %rax /* oend3 */ - pop %rax /* ilimit */ + pop %rax /* ilowest */ pop %rax /* olimit */ pop %rax /* arg */