Merge pull request #1659 from terrelln/big-dict

Fix data corruption in niche use case
This commit is contained in:
Nick Terrell 2019-06-24 12:40:58 -07:00 committed by GitHub
commit 9038579ab2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 203 additions and 60 deletions

View File

@ -1826,16 +1826,15 @@ static void ZSTD_reduceTable_btlazy2(U32* const table, U32 const size, U32 const
/*! ZSTD_reduceIndex() :
* rescale all indexes to avoid future overflow (indexes are U32) */
static void ZSTD_reduceIndex (ZSTD_CCtx* zc, const U32 reducerValue)
static void ZSTD_reduceIndex (ZSTD_matchState_t* ms, ZSTD_CCtx_params const* params, const U32 reducerValue)
{
ZSTD_matchState_t* const ms = &zc->blockState.matchState;
{ U32 const hSize = (U32)1 << zc->appliedParams.cParams.hashLog;
{ U32 const hSize = (U32)1 << params->cParams.hashLog;
ZSTD_reduceTable(ms->hashTable, hSize, reducerValue);
}
if (zc->appliedParams.cParams.strategy != ZSTD_fast) {
U32 const chainSize = (U32)1 << zc->appliedParams.cParams.chainLog;
if (zc->appliedParams.cParams.strategy == ZSTD_btlazy2)
if (params->cParams.strategy != ZSTD_fast) {
U32 const chainSize = (U32)1 << params->cParams.chainLog;
if (params->cParams.strategy == ZSTD_btlazy2)
ZSTD_reduceTable_btlazy2(ms->chainTable, chainSize, reducerValue);
else
ZSTD_reduceTable(ms->chainTable, chainSize, reducerValue);
@ -2821,6 +2820,25 @@ out:
}
static void ZSTD_overflowCorrectIfNeeded(ZSTD_matchState_t* ms, ZSTD_CCtx_params const* params, void const* ip, void const* iend)
{
if (ZSTD_window_needOverflowCorrection(ms->window, iend)) {
U32 const maxDist = (U32)1 << params->cParams.windowLog;
U32 const cycleLog = ZSTD_cycleLog(params->cParams.chainLog, params->cParams.strategy);
U32 const correction = ZSTD_window_correctOverflow(&ms->window, cycleLog, maxDist, ip);
ZSTD_STATIC_ASSERT(ZSTD_CHAINLOG_MAX <= 30);
ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX_32 <= 30);
ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX <= 31);
ZSTD_reduceIndex(ms, params, correction);
if (ms->nextToUpdate < correction) ms->nextToUpdate = 0;
else ms->nextToUpdate -= correction;
/* invalidate dictionaries on overflow correction */
ms->loadedDictEnd = 0;
ms->dictMatchState = NULL;
}
}
/*! ZSTD_compress_frameChunk() :
* Compress a chunk of data into one or multiple blocks.
* All blocks will be terminated, all input will be consumed.
@ -2854,20 +2872,7 @@ static size_t ZSTD_compress_frameChunk (ZSTD_CCtx* cctx,
"not enough space to store compressed block");
if (remaining < blockSize) blockSize = remaining;
if (ZSTD_window_needOverflowCorrection(ms->window, ip + blockSize)) {
U32 const cycleLog = ZSTD_cycleLog(cctx->appliedParams.cParams.chainLog, cctx->appliedParams.cParams.strategy);
U32 const correction = ZSTD_window_correctOverflow(&ms->window, cycleLog, maxDist, ip);
ZSTD_STATIC_ASSERT(ZSTD_CHAINLOG_MAX <= 30);
ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX_32 <= 30);
ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX <= 31);
ZSTD_reduceIndex(cctx, correction);
if (ms->nextToUpdate < correction) ms->nextToUpdate = 0;
else ms->nextToUpdate -= correction;
/* invalidate dictionaries on overflow correction */
ms->loadedDictEnd = 0;
ms->dictMatchState = NULL;
}
ZSTD_overflowCorrectIfNeeded(ms, &cctx->appliedParams, ip, ip + blockSize);
ZSTD_checkDictValidity(&ms->window, ip + blockSize, maxDist, &ms->loadedDictEnd, &ms->dictMatchState);
/* Ensure hash/chain table insertion resumes no sooner than lowlimit */
@ -3007,18 +3012,7 @@ static size_t ZSTD_compressContinue_internal (ZSTD_CCtx* cctx,
if (!frame) {
/* overflow check and correction for block mode */
if (ZSTD_window_needOverflowCorrection(ms->window, (const char*)src + srcSize)) {
U32 const cycleLog = ZSTD_cycleLog(cctx->appliedParams.cParams.chainLog, cctx->appliedParams.cParams.strategy);
U32 const correction = ZSTD_window_correctOverflow(&ms->window, cycleLog, 1 << cctx->appliedParams.cParams.windowLog, src);
ZSTD_STATIC_ASSERT(ZSTD_CHAINLOG_MAX <= 30);
ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX_32 <= 30);
ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX <= 31);
ZSTD_reduceIndex(cctx, correction);
if (ms->nextToUpdate < correction) ms->nextToUpdate = 0;
else ms->nextToUpdate -= correction;
ms->loadedDictEnd = 0;
ms->dictMatchState = NULL;
}
ZSTD_overflowCorrectIfNeeded(ms, &cctx->appliedParams, src, (BYTE const*)src + srcSize);
}
DEBUGLOG(5, "ZSTD_compressContinue_internal (blockSize=%u)", (unsigned)cctx->blockSize);
@ -3074,7 +3068,7 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms,
const void* src, size_t srcSize,
ZSTD_dictTableLoadMethod_e dtlm)
{
const BYTE* const ip = (const BYTE*) src;
const BYTE* ip = (const BYTE*) src;
const BYTE* const iend = ip + srcSize;
ZSTD_window_update(&ms->window, src, srcSize);
@ -3085,32 +3079,42 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms,
if (srcSize <= HASH_READ_SIZE) return 0;
switch(params->cParams.strategy)
{
case ZSTD_fast:
ZSTD_fillHashTable(ms, iend, dtlm);
break;
case ZSTD_dfast:
ZSTD_fillDoubleHashTable(ms, iend, dtlm);
break;
while (iend - ip > HASH_READ_SIZE) {
size_t const remaining = iend - ip;
size_t const chunk = MIN(remaining, ZSTD_CHUNKSIZE_MAX);
const BYTE* const ichunk = ip + chunk;
case ZSTD_greedy:
case ZSTD_lazy:
case ZSTD_lazy2:
if (srcSize >= HASH_READ_SIZE)
ZSTD_insertAndFindFirstIndex(ms, iend-HASH_READ_SIZE);
break;
ZSTD_overflowCorrectIfNeeded(ms, params, ip, ichunk);
case ZSTD_btlazy2: /* we want the dictionary table fully sorted */
case ZSTD_btopt:
case ZSTD_btultra:
case ZSTD_btultra2:
if (srcSize >= HASH_READ_SIZE)
ZSTD_updateTree(ms, iend-HASH_READ_SIZE, iend);
break;
switch(params->cParams.strategy)
{
case ZSTD_fast:
ZSTD_fillHashTable(ms, ichunk, dtlm);
break;
case ZSTD_dfast:
ZSTD_fillDoubleHashTable(ms, ichunk, dtlm);
break;
default:
assert(0); /* not possible : not a valid strategy id */
case ZSTD_greedy:
case ZSTD_lazy:
case ZSTD_lazy2:
if (chunk >= HASH_READ_SIZE)
ZSTD_insertAndFindFirstIndex(ms, ichunk-HASH_READ_SIZE);
break;
case ZSTD_btlazy2: /* we want the dictionary table fully sorted */
case ZSTD_btopt:
case ZSTD_btultra:
case ZSTD_btultra2:
if (chunk >= HASH_READ_SIZE)
ZSTD_updateTree(ms, ichunk-HASH_READ_SIZE, ichunk);
break;
default:
assert(0); /* not possible : not a valid strategy id */
}
ip = ichunk;
}
ms->nextToUpdate = (U32)(iend - ms->window.base);

View File

@ -525,8 +525,13 @@ void ZSTD_updateTree_internal(
DEBUGLOG(6, "ZSTD_updateTree_internal, from %u to %u (dictMode:%u)",
idx, target, dictMode);
while(idx < target)
idx += ZSTD_insertBt1(ms, base+idx, iend, mls, dictMode == ZSTD_extDict);
while(idx < target) {
U32 const forward = ZSTD_insertBt1(ms, base+idx, iend, mls, dictMode == ZSTD_extDict);
assert(idx < (U32)(idx + forward));
idx += forward;
}
assert((size_t)(ip - base) <= (size_t)(U32)(-1));
assert((size_t)(iend - base) <= (size_t)(U32)(-1));
ms->nextToUpdate = target;
}

View File

@ -1197,7 +1197,7 @@ static size_t ZSTDMT_computeOverlapSize(ZSTD_CCtx_params const params)
ovLog = MIN(params.cParams.windowLog, ZSTDMT_computeTargetJobLog(params) - 2)
- overlapRLog;
}
assert(0 <= ovLog && ovLog <= 30);
assert(0 <= ovLog && ovLog <= ZSTD_WINDOWLOG_MAX);
DEBUGLOG(4, "overlapLog : %i", params.overlapLog);
DEBUGLOG(4, "overlap size : %i", 1 << ovLog);
return (ovLog==0) ? 0 : (size_t)1 << ovLog;

View File

@ -215,6 +215,9 @@ roundTripCrash : $(ZSTD_OBJECTS) roundTripCrash.c
longmatch : $(ZSTD_OBJECTS) longmatch.c
$(CC) $(FLAGS) $^ -o $@$(EXT)
bigdict: $(ZSTDMT_OBJECTS) $(PRGDIR)/datagen.c bigdict.c
$(CC) $(FLAGS) $(MULTITHREAD) $^ -o $@$(EXT)
invalidDictionaries : $(ZSTD_OBJECTS) invalidDictionaries.c
$(CC) $(FLAGS) $^ -o $@$(EXT)
@ -256,7 +259,7 @@ clean:
zstreamtest$(EXT) zstreamtest32$(EXT) \
datagen$(EXT) paramgrill$(EXT) roundTripCrash$(EXT) longmatch$(EXT) \
symbols$(EXT) invalidDictionaries$(EXT) legacy$(EXT) poolTests$(EXT) \
decodecorpus$(EXT) checkTag$(EXT)
decodecorpus$(EXT) checkTag$(EXT) bigdict$(EXT)
@echo Cleaning completed
@ -397,6 +400,9 @@ test-zstream32: zstreamtest32
test-longmatch: longmatch
$(QEMU_SYS) ./longmatch
test-bigdict: bigdict
$(QEMU_SYS) ./bigdict
test-invalidDictionaries: invalidDictionaries
$(QEMU_SYS) ./invalidDictionaries

128
tests/bigdict.c Normal file
View File

@ -0,0 +1,128 @@
/*
* Copyright (c) 2017-present, Yann Collet, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/
#include <assert.h>
#include <stdio.h>
#include <stddef.h>
#include <stdlib.h>
#include <stdint.h>
#include "datagen.h"
#include "mem.h"
#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
static int
compress(ZSTD_CCtx* cctx, ZSTD_DCtx* dctx,
void* dst, size_t dstCapacity,
void const* src, size_t srcSize,
void* roundtrip, ZSTD_EndDirective end)
{
ZSTD_inBuffer in = {src, srcSize, 0};
ZSTD_outBuffer out = {dst, dstCapacity, 0};
int ended = 0;
while (!ended && (in.pos < in.size || out.pos > 0)) {
size_t rc;
out.pos = 0;
rc = ZSTD_compressStream2(cctx, &out, &in, end);
if (ZSTD_isError(rc))
return 1;
if (end == ZSTD_e_end && rc == 0)
ended = 1;
{
ZSTD_inBuffer rtIn = {dst, out.pos, 0};
ZSTD_outBuffer rtOut = {roundtrip, srcSize, 0};
rc = 1;
while (rtIn.pos < rtIn.size || rtOut.pos > 0) {
rtOut.pos = 0;
rc = ZSTD_decompressStream(dctx, &rtOut, &rtIn);
if (ZSTD_isError(rc)) {
fprintf(stderr, "Decompression error: %s\n", ZSTD_getErrorName(rc));
return 1;
}
if (rc == 0)
break;
}
if (ended && rc != 0) {
fprintf(stderr, "Frame not finished!\n");
return 1;
}
}
}
return 0;
}
int main(int argc, const char** argv)
{
ZSTD_CCtx* cctx = ZSTD_createCCtx();
ZSTD_DCtx* dctx = ZSTD_createDCtx();
const size_t dataSize = (size_t)1 << 30;
const size_t outSize = ZSTD_compressBound(dataSize);
const size_t bufferSize = (size_t)1 << 31;
char* buffer = (char*)malloc(bufferSize);
void* out = malloc(outSize);
void* roundtrip = malloc(dataSize);
(void)argc;
(void)argv;
if (!buffer || !out || !roundtrip || !cctx || !dctx) {
fprintf(stderr, "Allocation failure\n");
return 1;
}
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_windowLog, 31)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_nbWorkers, 1)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_overlapLog, 9)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, 1)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_strategy, ZSTD_btopt)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_targetLength, 7)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_minMatch, 7)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_searchLog, 1)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_hashLog, 10)))
return 1;
if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, ZSTD_c_chainLog, 10)))
return 1;
if (ZSTD_isError(ZSTD_DCtx_setParameter(dctx, ZSTD_d_windowLogMax, 31)))
return 1;
RDG_genBuffer(buffer, bufferSize, 1.0, 0.0, 0xbeefcafe);
/* Compress 30 GB */
{
int i;
for (i = 0; i < 10; ++i) {
fprintf(stderr, "Compressing 1 GB\n");
if (compress(cctx, dctx, out, outSize, buffer, dataSize, roundtrip, ZSTD_e_continue))
return 1;
}
}
fprintf(stderr, "Compressing 1 GB\n");
if (compress(cctx, dctx, out, outSize, buffer, dataSize, roundtrip, ZSTD_e_end))
return 1;
fprintf(stderr, "Success!\n");
free(roundtrip);
free(out);
free(buffer);
ZSTD_freeDCtx(dctx);
ZSTD_freeCCtx(cctx);
return 0;
}