Merge pull request #1864 from terrelln/dict-fix

Fix 2 bugs in dictionary loading
This commit is contained in:
Nick Terrell 2019-11-01 20:01:12 -07:00 committed by GitHub
commit 332aade370
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 171 additions and 43 deletions

View File

@ -2771,7 +2771,7 @@ static size_t ZSTD_checkDictNCount(short* normalizedCounter, unsigned dictMaxSym
/*! ZSTD_loadZstdDictionary() : /*! ZSTD_loadZstdDictionary() :
* @return : dictID, or an error code * @return : dictID, or an error code
* assumptions : magic number supposed already checked * assumptions : magic number supposed already checked
* dictSize supposed > 8 * dictSize supposed >= 8
*/ */
static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs,
ZSTD_matchState_t* ms, ZSTD_matchState_t* ms,
@ -2788,7 +2788,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs,
size_t dictID; size_t dictID;
ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<<MAX(MLFSELog,LLFSELog))); ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<<MAX(MLFSELog,LLFSELog)));
assert(dictSize > 8); assert(dictSize >= 8);
assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY); assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY);
dictPtr += 4; /* skip magic number */ dictPtr += 4; /* skip magic number */
@ -2890,7 +2890,10 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs,
void* workspace) void* workspace)
{ {
DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize); DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize);
if ((dict==NULL) || (dictSize<=8)) return 0; if ((dict==NULL) || (dictSize<8)) {
RETURN_ERROR_IF(dictContentType == ZSTD_dct_fullDict, dictionary_wrong);
return 0;
}
ZSTD_reset_compressedBlockState(bs); ZSTD_reset_compressedBlockState(bs);
@ -2942,7 +2945,7 @@ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx,
FORWARD_IF_ERROR( ZSTD_resetCCtx_internal(cctx, *params, pledgedSrcSize, FORWARD_IF_ERROR( ZSTD_resetCCtx_internal(cctx, *params, pledgedSrcSize,
ZSTDcrp_makeClean, zbuff) ); ZSTDcrp_makeClean, zbuff) );
{ size_t const dictID = cdict ? { size_t const dictID = cdict ?
ZSTD_compress_insertDictionary( ZSTD_compress_insertDictionary(
cctx->blockState.prevCBlock, &cctx->blockState.matchState, cctx->blockState.prevCBlock, &cctx->blockState.matchState,
&cctx->workspace, params, cdict->dictContent, cdict->dictContentSize, &cctx->workspace, params, cdict->dictContent, cdict->dictContentSize,
@ -3219,7 +3222,7 @@ static size_t ZSTD_initCDict_internal(
ZSTDirp_reset, ZSTDirp_reset,
ZSTD_resetTarget_CDict)); ZSTD_resetTarget_CDict));
/* (Maybe) load the dictionary /* (Maybe) load the dictionary
* Skips loading the dictionary if it is <= 8 bytes. * Skips loading the dictionary if it is < 8 bytes.
*/ */
{ ZSTD_CCtx_params params; { ZSTD_CCtx_params params;
memset(&params, 0, sizeof(params)); memset(&params, 0, sizeof(params));

View File

@ -1096,7 +1096,7 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy,
size_t const dictContentSize = (size_t)(dictEnd - (dictPtr+12)); size_t const dictContentSize = (size_t)(dictEnd - (dictPtr+12));
for (i=0; i<3; i++) { for (i=0; i<3; i++) {
U32 const rep = MEM_readLE32(dictPtr); dictPtr += 4; U32 const rep = MEM_readLE32(dictPtr); dictPtr += 4;
RETURN_ERROR_IF(rep==0 || rep >= dictContentSize, RETURN_ERROR_IF(rep==0 || rep > dictContentSize,
dictionary_corrupted); dictionary_corrupted);
entropy->rep[i] = rep; entropy->rep[i] = rep;
} } } }
@ -1265,7 +1265,7 @@ size_t ZSTD_DCtx_loadDictionary_advanced(ZSTD_DCtx* dctx,
{ {
RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong); RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong);
ZSTD_clearDict(dctx); ZSTD_clearDict(dctx);
if (dict && dictSize >= 8) { if (dict && dictSize != 0) {
dctx->ddictLocal = ZSTD_createDDict_advanced(dict, dictSize, dictLoadMethod, dictContentType, dctx->customMem); dctx->ddictLocal = ZSTD_createDDict_advanced(dict, dictSize, dictLoadMethod, dictContentType, dctx->customMem);
RETURN_ERROR_IF(dctx->ddictLocal == NULL, memory_allocation); RETURN_ERROR_IF(dctx->ddictLocal == NULL, memory_allocation);
dctx->ddict = dctx->ddictLocal; dctx->ddict = dctx->ddictLocal;

View File

@ -73,7 +73,8 @@ FUZZ_TARGETS := \
dictionary_round_trip \ dictionary_round_trip \
dictionary_decompress \ dictionary_decompress \
zstd_frame_info \ zstd_frame_info \
simple_compress simple_compress \
dictionary_loader
all: $(FUZZ_TARGETS) all: $(FUZZ_TARGETS)
@ -110,6 +111,9 @@ simple_compress: $(FUZZ_HEADERS) $(FUZZ_OBJ) simple_compress.o
zstd_frame_info: $(FUZZ_HEADERS) $(FUZZ_OBJ) zstd_frame_info.o zstd_frame_info: $(FUZZ_HEADERS) $(FUZZ_OBJ) zstd_frame_info.o
$(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) zstd_frame_info.o $(LIB_FUZZING_ENGINE) -o $@ $(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) zstd_frame_info.o $(LIB_FUZZING_ENGINE) -o $@
dictionary_loader: $(FUZZ_HEADERS) $(FUZZ_OBJ) dictionary_loader.o
$(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) dictionary_loader.o $(LIB_FUZZING_ENGINE) -o $@
libregression.a: $(FUZZ_HEADERS) $(PRGDIR)/util.h $(PRGDIR)/util.c regression_driver.o libregression.a: $(FUZZ_HEADERS) $(PRGDIR)/util.h $(PRGDIR)/util.c regression_driver.o
$(AR) $(FUZZ_ARFLAGS) $@ regression_driver.o $(AR) $(FUZZ_ARFLAGS) $@ regression_driver.o

View File

@ -0,0 +1,93 @@
/*
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
*/
/**
* This fuzz target makes sure that whenever a compression dictionary can be
* loaded, the data can be round tripped.
*/
#include <stddef.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "fuzz_helpers.h"
#include "zstd_helpers.h"
#include "fuzz_data_producer.h"
/**
 * Compresses `source` using a freshly created compression context into which
 * `dict` has been loaded with the requested load method and content type.
 *
 * The harness aborts (through the FUZZ_* assertion macros) when the context
 * cannot be created or when ZSTD_CCtx_loadDictionary_advanced() reports an
 * error.
 *
 * @param compressed          destination buffer
 * @param compressedCapacity  size of `compressed` in bytes
 * @param source              data to compress
 * @param sourceSize          size of `source` in bytes
 * @param dict                dictionary bytes
 * @param dictSize            size of `dict` in bytes
 * @param dictLoadMethod      forwarded to ZSTD_CCtx_loadDictionary_advanced()
 * @param dictContentType     forwarded to ZSTD_CCtx_loadDictionary_advanced()
 * @return the value returned by ZSTD_compress2(): the compressed size or an
 *         error code, to be tested by the caller with ZSTD_isError().
 */
static size_t compress(void* compressed, size_t compressedCapacity,
                       void const* source, size_t sourceSize,
                       void const* dict, size_t dictSize,
                       ZSTD_dictLoadMethod_e dictLoadMethod,
                       ZSTD_dictContentType_e dictContentType)
{
    ZSTD_CCtx* cctx = ZSTD_createCCtx();
    /* Fail with a clear harness assertion rather than handing a NULL context
     * to the library calls below. */
    FUZZ_ASSERT(cctx);
    FUZZ_ZASSERT(ZSTD_CCtx_loadDictionary_advanced(
            cctx, dict, dictSize, dictLoadMethod, dictContentType));
    {
        /* Errors from the compression itself are returned, not asserted:
         * the caller decides whether a failure is acceptable. */
        size_t const compressedSize = ZSTD_compress2(
                cctx, compressed, compressedCapacity, source, sourceSize);
        ZSTD_freeCCtx(cctx);
        return compressedSize;
    }
}
/**
 * Decompresses `compressed` using a freshly created decompression context
 * into which `dict` has been loaded with the requested load method and
 * content type.
 *
 * Unlike compress(), every failure is fatal here: the harness aborts (through
 * the FUZZ_* assertion macros) when the context cannot be created, when
 * ZSTD_DCtx_loadDictionary_advanced() reports an error, or when
 * ZSTD_decompressDCtx() returns an error code.
 *
 * @param result           destination buffer
 * @param resultCapacity   size of `result` in bytes
 * @param compressed       compressed input
 * @param compressedSize   size of `compressed` in bytes
 * @param dict             dictionary bytes
 * @param dictSize         size of `dict` in bytes
 * @param dictLoadMethod   forwarded to ZSTD_DCtx_loadDictionary_advanced()
 * @param dictContentType  forwarded to ZSTD_DCtx_loadDictionary_advanced()
 * @return the (non-error) value returned by ZSTD_decompressDCtx().
 */
static size_t decompress(void* result, size_t resultCapacity,
                         void const* compressed, size_t compressedSize,
                         void const* dict, size_t dictSize,
                         ZSTD_dictLoadMethod_e dictLoadMethod,
                         ZSTD_dictContentType_e dictContentType)
{
    ZSTD_DCtx* dctx = ZSTD_createDCtx();
    /* Fail with a clear harness assertion rather than handing a NULL context
     * to the library calls below. */
    FUZZ_ASSERT(dctx);
    FUZZ_ZASSERT(ZSTD_DCtx_loadDictionary_advanced(
            dctx, dict, dictSize, dictLoadMethod, dictContentType));
    {
        size_t const resultSize = ZSTD_decompressDCtx(
                dctx, result, resultCapacity, compressed, compressedSize);
        /* Release the context before checking the result; the check only
         * needs the returned code. */
        ZSTD_freeDCtx(dctx);
        FUZZ_ZASSERT(resultSize);
        return resultSize;
    }
}
/**
 * Fuzzer entry point.
 *
 * Two values are drawn from the data producer to pick the dictionary load
 * method (range 0..1) and the dictionary content type (range 0..2). The
 * remaining bytes are then used both as the dictionary and as the data to
 * compress. If compression succeeds, the output is decompressed with the same
 * dictionary settings and must reproduce the input exactly. If compression
 * fails, that is accepted unless the content type is ZSTD_dct_rawContent.
 *
 * @param src   fuzzer-provided input buffer
 * @param size  size of `src` in bytes
 * @return always 0
 */
int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
{
    FUZZ_dataProducer_t *producer = FUZZ_dataProducer_create(src, size);
    /* Draw the two selectors first, in this order; `size` is only updated
     * once, after both have been consumed. */
    ZSTD_dictLoadMethod_e const dlm =
            (ZSTD_dictLoadMethod_e)FUZZ_dataProducer_uint32Range(producer, 0, 1);
    ZSTD_dictContentType_e const dct =
            (ZSTD_dictContentType_e)FUZZ_dataProducer_uint32Range(producer, 0, 2);
    size = FUZZ_dataProducer_remainingBytes(producer);

    DEBUGLOG(2, "Dict load method %d", dlm);
    DEBUGLOG(2, "Dict content type %d", dct);
    DEBUGLOG(2, "Dict size %u", (unsigned)size);

    /* malloc(0) is allowed to return NULL, which would trip the assertion
     * below on an empty input; always request at least one byte. */
    void* const rBuf = malloc(size ? size : 1);
    FUZZ_ASSERT(rBuf);
    size_t const cBufSize = ZSTD_compressBound(size);
    void* const cBuf = malloc(cBufSize ? cBufSize : 1);
    FUZZ_ASSERT(cBuf);

    size_t const cSize =
            compress(cBuf, cBufSize, src, size, src, size, dlm, dct);
    if (ZSTD_isError(cSize)) {
        /* compression failing is okay, except for raw-content dictionaries */
        FUZZ_ASSERT_MSG(dct != ZSTD_dct_rawContent, "Raw must always succeed!");
    } else {
        size_t const rSize =
                decompress(rBuf, size, cBuf, cSize, src, size, dlm, dct);
        FUZZ_ASSERT_MSG(rSize == size, "Incorrect regenerated size");
        FUZZ_ASSERT_MSG(!memcmp(src, rBuf, size), "Corruption!");
    }

    free(cBuf);
    free(rBuf);
    FUZZ_dataProducer_free(producer);
    return 0;
}

View File

@ -27,6 +27,7 @@ def abs_join(a, *p):
class InputType(object): class InputType(object):
RAW_DATA = 1 RAW_DATA = 1
COMPRESSED_DATA = 2 COMPRESSED_DATA = 2
DICTIONARY_DATA = 3
class FrameType(object): class FrameType(object):
@ -54,6 +55,7 @@ TARGET_INFO = {
'dictionary_decompress': TargetInfo(InputType.COMPRESSED_DATA), 'dictionary_decompress': TargetInfo(InputType.COMPRESSED_DATA),
'zstd_frame_info': TargetInfo(InputType.COMPRESSED_DATA), 'zstd_frame_info': TargetInfo(InputType.COMPRESSED_DATA),
'simple_compress': TargetInfo(InputType.RAW_DATA), 'simple_compress': TargetInfo(InputType.RAW_DATA),
'dictionary_loader': TargetInfo(InputType.DICTIONARY_DATA),
} }
TARGETS = list(TARGET_INFO.keys()) TARGETS = list(TARGET_INFO.keys())
ALL_TARGETS = TARGETS + ['all'] ALL_TARGETS = TARGETS + ['all']
@ -73,6 +75,7 @@ LIB_FUZZING_ENGINE = os.environ.get('LIB_FUZZING_ENGINE', 'libregression.a')
AFL_FUZZ = os.environ.get('AFL_FUZZ', 'afl-fuzz') AFL_FUZZ = os.environ.get('AFL_FUZZ', 'afl-fuzz')
DECODECORPUS = os.environ.get('DECODECORPUS', DECODECORPUS = os.environ.get('DECODECORPUS',
abs_join(FUZZ_DIR, '..', 'decodecorpus')) abs_join(FUZZ_DIR, '..', 'decodecorpus'))
ZSTD = os.environ.get('ZSTD', abs_join(FUZZ_DIR, '..', '..', 'zstd'))
# Sanitizer environment variables # Sanitizer environment variables
MSAN_EXTRA_CPPFLAGS = os.environ.get('MSAN_EXTRA_CPPFLAGS', '') MSAN_EXTRA_CPPFLAGS = os.environ.get('MSAN_EXTRA_CPPFLAGS', '')
@ -673,6 +676,11 @@ def gen_parser(args):
default=DECODECORPUS, default=DECODECORPUS,
help="decodecorpus binary (default: $DECODECORPUS='{}')".format( help="decodecorpus binary (default: $DECODECORPUS='{}')".format(
DECODECORPUS)) DECODECORPUS))
parser.add_argument(
'--zstd',
type=str,
default=ZSTD,
help="zstd binary (default: $ZSTD='{}')".format(ZSTD))
parser.add_argument( parser.add_argument(
'--fuzz-rng-seed-size', '--fuzz-rng-seed-size',
type=int, type=int,
@ -707,46 +715,66 @@ def gen(args):
return 1 return 1
seed = create(args.seed) seed = create(args.seed)
with tmpdir() as compressed: with tmpdir() as compressed, tmpdir() as decompressed, tmpdir() as dict:
with tmpdir() as decompressed: info = TARGET_INFO[args.TARGET]
cmd = [
args.decodecorpus, if info.input_type == InputType.DICTIONARY_DATA:
'-n{}'.format(args.number), number = max(args.number, 1000)
'-p{}/'.format(compressed), else:
'-o{}'.format(decompressed), number = args.number
cmd = [
args.decodecorpus,
'-n{}'.format(args.number),
'-p{}/'.format(compressed),
'-o{}'.format(decompressed),
]
if info.frame_type == FrameType.BLOCK:
cmd += [
'--gen-blocks',
'--max-block-size-log={}'.format(min(args.max_size_log, 17))
] ]
else:
cmd += ['--max-content-size-log={}'.format(args.max_size_log)]
info = TARGET_INFO[args.TARGET] print(' '.join(cmd))
if info.frame_type == FrameType.BLOCK: subprocess.check_call(cmd)
cmd += [
'--gen-blocks', if info.input_type == InputType.RAW_DATA:
'--max-block-size-log={}'.format(min(args.max_size_log, 17)) print('using decompressed data in {}'.format(decompressed))
samples = decompressed
elif info.input_type == InputType.COMPRESSED_DATA:
print('using compressed data in {}'.format(compressed))
samples = compressed
else:
assert info.input_type == InputType.DICTIONARY_DATA
print('making dictionary data from {}'.format(decompressed))
samples = dict
min_dict_size_log = 9
max_dict_size_log = max(min_dict_size_log + 1, args.max_size_log)
for dict_size_log in range(min_dict_size_log, max_dict_size_log):
dict_size = 1 << dict_size_log
cmd = [
args.zstd,
'--train',
'-r', decompressed,
'--maxdict={}'.format(dict_size),
'-o', abs_join(dict, '{}.zstd-dict'.format(dict_size))
] ]
else: print(' '.join(cmd))
cmd += ['--max-content-size-log={}'.format(args.max_size_log)] subprocess.check_call(cmd)
print(' '.join(cmd)) # Copy the samples over and prepend the RNG seeds
subprocess.check_call(cmd) for name in os.listdir(samples):
samplename = abs_join(samples, name)
if info.input_type == InputType.RAW_DATA: outname = abs_join(seed, name)
print('using decompressed data in {}'.format(decompressed)) with open(samplename, 'rb') as sample:
samples = decompressed with open(outname, 'wb') as out:
else: CHUNK_SIZE = 131072
assert info.input_type == InputType.COMPRESSED_DATA chunk = sample.read(CHUNK_SIZE)
print('using compressed data in {}'.format(compressed)) while len(chunk) > 0:
samples = compressed out.write(chunk)
# Copy the samples over and prepend the RNG seeds
for name in os.listdir(samples):
samplename = abs_join(samples, name)
outname = abs_join(seed, name)
with open(samplename, 'rb') as sample:
with open(outname, 'wb') as out:
CHUNK_SIZE = 131072
chunk = sample.read(CHUNK_SIZE) chunk = sample.read(CHUNK_SIZE)
while len(chunk) > 0:
out.write(chunk)
chunk = sample.read(CHUNK_SIZE)
return 0 return 0