diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index a8856cd1b..35346b92c 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -2771,7 +2771,7 @@ static size_t ZSTD_checkDictNCount(short* normalizedCounter, unsigned dictMaxSym /*! ZSTD_loadZstdDictionary() : * @return : dictID, or an error code * assumptions : magic number supposed already checked - * dictSize supposed > 8 + * dictSize supposed >= 8 */ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, ZSTD_matchState_t* ms, @@ -2788,7 +2788,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, size_t dictID; ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1< 8); + assert(dictSize >= 8); assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY); dictPtr += 4; /* skip magic number */ @@ -2890,7 +2890,10 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, void* workspace) { DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize); - if ((dict==NULL) || (dictSize<=8)) return 0; + if ((dict==NULL) || (dictSize<8)) { + RETURN_ERROR_IF(dictContentType == ZSTD_dct_fullDict, dictionary_wrong); + return 0; + } ZSTD_reset_compressedBlockState(bs); @@ -2942,7 +2945,7 @@ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx, FORWARD_IF_ERROR( ZSTD_resetCCtx_internal(cctx, *params, pledgedSrcSize, ZSTDcrp_makeClean, zbuff) ); - { size_t const dictID = cdict ? + { size_t const dictID = cdict ? ZSTD_compress_insertDictionary( cctx->blockState.prevCBlock, &cctx->blockState.matchState, &cctx->workspace, params, cdict->dictContent, cdict->dictContentSize, @@ -3219,7 +3222,7 @@ static size_t ZSTD_initCDict_internal( ZSTDirp_reset, ZSTD_resetTarget_CDict)); /* (Maybe) load the dictionary - * Skips loading the dictionary if it is <= 8 bytes. + * Skips loading the dictionary if it is < 8 bytes. */ { ZSTD_CCtx_params params; memset(¶ms, 0, sizeof(params)); diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index ca47a6678..dd4591b7b 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -1096,7 +1096,7 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy, size_t const dictContentSize = (size_t)(dictEnd - (dictPtr+12)); for (i=0; i<3; i++) { U32 const rep = MEM_readLE32(dictPtr); dictPtr += 4; - RETURN_ERROR_IF(rep==0 || rep >= dictContentSize, + RETURN_ERROR_IF(rep==0 || rep > dictContentSize, dictionary_corrupted); entropy->rep[i] = rep; } } @@ -1265,7 +1265,7 @@ size_t ZSTD_DCtx_loadDictionary_advanced(ZSTD_DCtx* dctx, { RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong); ZSTD_clearDict(dctx); - if (dict && dictSize >= 8) { + if (dict && dictSize != 0) { dctx->ddictLocal = ZSTD_createDDict_advanced(dict, dictSize, dictLoadMethod, dictContentType, dctx->customMem); RETURN_ERROR_IF(dctx->ddictLocal == NULL, memory_allocation); dctx->ddict = dctx->ddictLocal; diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile index 83837e620..f66dadef0 100644 --- a/tests/fuzz/Makefile +++ b/tests/fuzz/Makefile @@ -73,7 +73,8 @@ FUZZ_TARGETS := \ dictionary_round_trip \ dictionary_decompress \ zstd_frame_info \ - simple_compress + simple_compress \ + dictionary_loader all: $(FUZZ_TARGETS) @@ -110,6 +111,9 @@ simple_compress: $(FUZZ_HEADERS) $(FUZZ_OBJ) simple_compress.o zstd_frame_info: $(FUZZ_HEADERS) $(FUZZ_OBJ) zstd_frame_info.o $(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) zstd_frame_info.o $(LIB_FUZZING_ENGINE) -o $@ +dictionary_loader: $(FUZZ_HEADERS) $(FUZZ_OBJ) dictionary_loader.o + $(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) dictionary_loader.o $(LIB_FUZZING_ENGINE) -o $@ + libregression.a: $(FUZZ_HEADERS) $(PRGDIR)/util.h $(PRGDIR)/util.c regression_driver.o $(AR) $(FUZZ_ARFLAGS) $@ regression_driver.o diff --git a/tests/fuzz/dictionary_loader.c b/tests/fuzz/dictionary_loader.c new file mode 100644 index 000000000..cb34f5d22 --- /dev/null +++ b/tests/fuzz/dictionary_loader.c @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2016-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + */ + +/** + * This fuzz target makes sure that whenever a compression dictionary can be + * loaded, the data can be round tripped. + */ + +#include +#include +#include +#include +#include "fuzz_helpers.h" +#include "zstd_helpers.h" +#include "fuzz_data_producer.h" + +/** + * Compresses the data and returns the compressed size or an error. + */ +static size_t compress(void* compressed, size_t compressedCapacity, + void const* source, size_t sourceSize, + void const* dict, size_t dictSize, + ZSTD_dictLoadMethod_e dictLoadMethod, + ZSTD_dictContentType_e dictContentType) +{ + ZSTD_CCtx* cctx = ZSTD_createCCtx(); + FUZZ_ZASSERT(ZSTD_CCtx_loadDictionary_advanced( + cctx, dict, dictSize, dictLoadMethod, dictContentType)); + size_t const compressedSize = ZSTD_compress2( + cctx, compressed, compressedCapacity, source, sourceSize); + ZSTD_freeCCtx(cctx); + return compressedSize; +} + +static size_t decompress(void* result, size_t resultCapacity, + void const* compressed, size_t compressedSize, + void const* dict, size_t dictSize, + ZSTD_dictLoadMethod_e dictLoadMethod, + ZSTD_dictContentType_e dictContentType) +{ + ZSTD_DCtx* dctx = ZSTD_createDCtx(); + FUZZ_ZASSERT(ZSTD_DCtx_loadDictionary_advanced( + dctx, dict, dictSize, dictLoadMethod, dictContentType)); + size_t const resultSize = ZSTD_decompressDCtx( + dctx, result, resultCapacity, compressed, compressedSize); + FUZZ_ZASSERT(resultSize); + ZSTD_freeDCtx(dctx); + return resultSize; +} + +int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size) +{ + FUZZ_dataProducer_t *producer = FUZZ_dataProducer_create(src, size); + ZSTD_dictLoadMethod_e const dlm = + size = FUZZ_dataProducer_uint32Range(producer, 0, 1); + ZSTD_dictContentType_e const dct = + FUZZ_dataProducer_uint32Range(producer, 0, 2); + size = FUZZ_dataProducer_remainingBytes(producer); + + DEBUGLOG(2, "Dict load method %d", dlm); + DEBUGLOG(2, "Dict content type %d", dct); + DEBUGLOG(2, "Dict size %u", (unsigned)size); + + void* const rBuf = malloc(size); + FUZZ_ASSERT(rBuf); + size_t const cBufSize = ZSTD_compressBound(size); + void* const cBuf = malloc(cBufSize); + FUZZ_ASSERT(cBuf); + + size_t const cSize = + compress(cBuf, cBufSize, src, size, src, size, dlm, dct); + /* compression failing is okay */ + if (ZSTD_isError(cSize)) { + FUZZ_ASSERT_MSG(dct != ZSTD_dct_rawContent, "Raw must always succeed!"); + goto out; + } + size_t const rSize = + decompress(rBuf, size, cBuf, cSize, src, size, dlm, dct); + FUZZ_ASSERT_MSG(rSize == size, "Incorrect regenerated size"); + FUZZ_ASSERT_MSG(!memcmp(src, rBuf, size), "Corruption!"); + +out: + free(cBuf); + free(rBuf); + FUZZ_dataProducer_free(producer); + return 0; +} diff --git a/tests/fuzz/fuzz.py b/tests/fuzz/fuzz.py index 9df68df0c..87f115afd 100755 --- a/tests/fuzz/fuzz.py +++ b/tests/fuzz/fuzz.py @@ -27,6 +27,7 @@ def abs_join(a, *p): class InputType(object): RAW_DATA = 1 COMPRESSED_DATA = 2 + DICTIONARY_DATA = 3 class FrameType(object): @@ -54,6 +55,7 @@ TARGET_INFO = { 'dictionary_decompress': TargetInfo(InputType.COMPRESSED_DATA), 'zstd_frame_info': TargetInfo(InputType.COMPRESSED_DATA), 'simple_compress': TargetInfo(InputType.RAW_DATA), + 'dictionary_loader': TargetInfo(InputType.DICTIONARY_DATA), } TARGETS = list(TARGET_INFO.keys()) ALL_TARGETS = TARGETS + ['all'] @@ -73,6 +75,7 @@ LIB_FUZZING_ENGINE = os.environ.get('LIB_FUZZING_ENGINE', 'libregression.a') AFL_FUZZ = os.environ.get('AFL_FUZZ', 'afl-fuzz') DECODECORPUS = os.environ.get('DECODECORPUS', abs_join(FUZZ_DIR, '..', 'decodecorpus')) +ZSTD = os.environ.get('ZSTD', abs_join(FUZZ_DIR, '..', '..', 'zstd')) # Sanitizer environment variables MSAN_EXTRA_CPPFLAGS = os.environ.get('MSAN_EXTRA_CPPFLAGS', '') @@ -673,6 +676,11 @@ def gen_parser(args): default=DECODECORPUS, help="decodecorpus binary (default: $DECODECORPUS='{}')".format( DECODECORPUS)) + parser.add_argument( + '--zstd', + type=str, + default=ZSTD, + help="zstd binary (default: $ZSTD='{}')".format(ZSTD)) parser.add_argument( '--fuzz-rng-seed-size', type=int, @@ -707,46 +715,66 @@ def gen(args): return 1 seed = create(args.seed) - with tmpdir() as compressed: - with tmpdir() as decompressed: - cmd = [ - args.decodecorpus, - '-n{}'.format(args.number), - '-p{}/'.format(compressed), - '-o{}'.format(decompressed), + with tmpdir() as compressed, tmpdir() as decompressed, tmpdir() as dict: + info = TARGET_INFO[args.TARGET] + + if info.input_type == InputType.DICTIONARY_DATA: + number = max(args.number, 1000) + else: + number = args.number + cmd = [ + args.decodecorpus, + '-n{}'.format(args.number), + '-p{}/'.format(compressed), + '-o{}'.format(decompressed), + ] + + if info.frame_type == FrameType.BLOCK: + cmd += [ + '--gen-blocks', + '--max-block-size-log={}'.format(min(args.max_size_log, 17)) ] + else: + cmd += ['--max-content-size-log={}'.format(args.max_size_log)] - info = TARGET_INFO[args.TARGET] - if info.frame_type == FrameType.BLOCK: - cmd += [ - '--gen-blocks', - '--max-block-size-log={}'.format(min(args.max_size_log, 17)) + print(' '.join(cmd)) + subprocess.check_call(cmd) + + if info.input_type == InputType.RAW_DATA: + print('using decompressed data in {}'.format(decompressed)) + samples = decompressed + elif info.input_type == InputType.COMPRESSED_DATA: + print('using compressed data in {}'.format(compressed)) + samples = compressed + else: + assert info.input_type == InputType.DICTIONARY_DATA + print('making dictionary data from {}'.format(decompressed)) + samples = dict + min_dict_size_log = 9 + max_dict_size_log = max(min_dict_size_log + 1, args.max_size_log) + for dict_size_log in range(min_dict_size_log, max_dict_size_log): + dict_size = 1 << dict_size_log + cmd = [ + args.zstd, + '--train', + '-r', decompressed, + '--maxdict={}'.format(dict_size), + '-o', abs_join(dict, '{}.zstd-dict'.format(dict_size)) ] - else: - cmd += ['--max-content-size-log={}'.format(args.max_size_log)] + print(' '.join(cmd)) + subprocess.check_call(cmd) - print(' '.join(cmd)) - subprocess.check_call(cmd) - - if info.input_type == InputType.RAW_DATA: - print('using decompressed data in {}'.format(decompressed)) - samples = decompressed - else: - assert info.input_type == InputType.COMPRESSED_DATA - print('using compressed data in {}'.format(compressed)) - samples = compressed - - # Copy the samples over and prepend the RNG seeds - for name in os.listdir(samples): - samplename = abs_join(samples, name) - outname = abs_join(seed, name) - with open(samplename, 'rb') as sample: - with open(outname, 'wb') as out: - CHUNK_SIZE = 131072 + # Copy the samples over and prepend the RNG seeds + for name in os.listdir(samples): + samplename = abs_join(samples, name) + outname = abs_join(seed, name) + with open(samplename, 'rb') as sample: + with open(outname, 'wb') as out: + CHUNK_SIZE = 131072 + chunk = sample.read(CHUNK_SIZE) + while len(chunk) > 0: + out.write(chunk) chunk = sample.read(CHUNK_SIZE) - while len(chunk) > 0: - out.write(chunk) - chunk = sample.read(CHUNK_SIZE) return 0