Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Functions to Read and Write dcz Headers #4272

Draft
wants to merge 4 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions lib/common/sha256.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/

#include "sha256.h"

#include <string.h>

#include "bits.h"
#include "mem.h"
#include "zstd_deps.h"

/* SHA-256 consumes its input in blocks of 64 bytes (sixteen 32-bit words). */
#define ZSTD_SHA256_BLOCK_SIZE 64

/* "Choose": for every bit position, takes the bit of y where x is set,
 * and the bit of z where x is clear. */
#define ZSTD_SHA256_CH(x, y, z) (((x) & (y)) ^ (~(x) & (z)))
/* "Majority": for every bit position, yields the value held by at least
 * two of x, y and z. */
#define ZSTD_SHA256_MAJ(x, y, z) (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z)))
/* Upper-case sigma functions: XOR of three right-rotations of x.
 * Used on the working variables `a` and `e` in every compression round. */
#define ZSTD_SHA256_SIGMA0(x) \
(ZSTD_rotateRight_U32(x, 2) ^ \
ZSTD_rotateRight_U32(x, 13) ^ \
ZSTD_rotateRight_U32(x, 22))
#define ZSTD_SHA256_SIGMA1(x) \
(ZSTD_rotateRight_U32(x, 6) ^ \
ZSTD_rotateRight_U32(x, 11) ^ \
ZSTD_rotateRight_U32(x, 25))
/* Lower-case sigma functions: XOR of two right-rotations and one plain
 * right-shift of x. Used to expand the message schedule.
 * Note: all of these macros evaluate their argument more than once, so the
 * argument must be free of side effects. */
#define ZSTD_SHA256_sigma0(x) \
(ZSTD_rotateRight_U32(x, 7) ^ \
ZSTD_rotateRight_U32(x, 18) ^ \
((x) >> 3))
#define ZSTD_SHA256_sigma1(x) \
(ZSTD_rotateRight_U32(x, 17) ^ \
ZSTD_rotateRight_U32(x, 19) ^ \
((x) >> 10))


/* Initial value of the eight 32-bit hash state words; copied into the working
 * state by ZSTD_SHA256_hash() before any block is processed.
 * These are the SHA-256 initial hash values specified by FIPS 180-4. */
static const uint32_t I[8] = {
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19
};

/* Per-round additive constants: K[t] is mixed into round t of the compression
 * loop in ZSTD_SHA256_block(). These are the 64 SHA-256 round constants
 * specified by FIPS 180-4. */
static const uint32_t K[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
};

/**
 * Folds one 64-byte message block into the running hash state.
 *
 * @param hash  the eight 32-bit state words; updated in place.
 * @param block exactly 64 bytes of message data, read as sixteen
 *              big-endian 32-bit words.
 */
static void ZSTD_SHA256_block(uint32_t hash[8], const uint8_t block[64]) {
    uint32_t schedule[64];
    uint32_t s[8];   /* working variables a..h, stored as s[0]..s[7] */
    uint32_t round;
    uint32_t n;

    /* The first 16 schedule words are the block itself, big-endian. */
    for (round = 0; round < 16; round++) {
        schedule[round] = MEM_readBE32(block + 4 * round);
    }
    /* The remaining 48 are derived from earlier schedule words. */
    for (round = 16; round < 64; round++) {
        schedule[round] = ZSTD_SHA256_sigma1(schedule[round - 2])
                        + schedule[round - 7]
                        + ZSTD_SHA256_sigma0(schedule[round - 15])
                        + schedule[round - 16];
    }

    for (n = 0; n < 8; n++) {
        s[n] = hash[n];
    }

    for (round = 0; round < 64; round++) {
        const uint32_t t1 = s[7]
                          + ZSTD_SHA256_SIGMA1(s[4])
                          + ZSTD_SHA256_CH(s[4], s[5], s[6])
                          + K[round]
                          + schedule[round];
        const uint32_t t2 = ZSTD_SHA256_SIGMA0(s[0])
                          + ZSTD_SHA256_MAJ(s[0], s[1], s[2]);
        /* Shift every working variable down one slot, injecting t1 into
         * the new `e` (s[4]) and t1 + t2 into the new `a` (s[0]). */
        s[7] = s[6];
        s[6] = s[5];
        s[5] = s[4];
        s[4] = s[3] + t1;
        s[3] = s[2];
        s[2] = s[1];
        s[1] = s[0];
        s[0] = t1 + t2;
    }

    /* Feed-forward: add the compressed block into the running state. */
    for (n = 0; n < 8; n++) {
        hash[n] += s[n];
    }
}

/**
 * Applies the SHA-256 padding to the final partial block and folds it (and,
 * if needed, one extra block) into the hash state.
 *
 * Padding layout: the leftover message bytes, a single 0x80 byte, zero bytes,
 * and finally the total message length in bits as a 64-bit big-endian value
 * occupying the last 8 bytes of a block.
 *
 * @param hash       the eight 32-bit state words; updated in place.
 * @param msg_tail   the trailing `remaining` bytes of the message that did not
 *                   fill a whole block. May be NULL when `remaining` is 0.
 * @param remaining  number of bytes at `msg_tail`; must be < 64.
 * @param total_size length in bytes of the entire message.
 */
static void ZSTD_SHA256_finish(
        uint32_t hash[8],
        const void* msg_tail,
        const size_t remaining,
        const size_t total_size) {
    uint8_t buf[ZSTD_SHA256_BLOCK_SIZE];
    /* Widen before multiplying: when size_t is 32 bits wide, computing
     * `total_size * 8` in size_t would wrap for messages of 512 MiB or more
     * and produce a wrong length field (and therefore a wrong digest). */
    const uint64_t total_bits = (uint64_t)total_size * 8;
    uint32_t i;
    assert(remaining < ZSTD_SHA256_BLOCK_SIZE);
    /* Skip the copy for an empty tail: passing a NULL source to memcpy is
     * undefined behavior even with a length of 0, and an empty message may
     * legitimately arrive with a NULL pointer. */
    if (remaining > 0) {
        ZSTD_memcpy(buf, msg_tail, remaining);
    }
    buf[remaining] = 0x80;
    ZSTD_memset(buf + remaining + 1, 0, ZSTD_SHA256_BLOCK_SIZE - remaining - 1);
    if (remaining > 55) {
        /* Not enough room left for the 8-byte length field: flush this block
         * and put the length in an additional, otherwise all-zero block. */
        ZSTD_SHA256_block(hash, buf);
        ZSTD_memset(buf, 0, ZSTD_SHA256_BLOCK_SIZE);
    }
    /* Message length in bits, big-endian, in the last 8 bytes. */
    for (i = 0; i < 8; i++) {
        buf[56 + i] = (uint8_t)(total_bits >> ((7 - i) * 8));
    }
    ZSTD_SHA256_block(hash, buf);
}

/**
 * Serializes the eight 32-bit state words into the 32-byte digest, each word
 * written in big-endian byte order, word 0 first.
 */
static ZSTD_SHA256_Result ZSTD_SHA256_digest(uint32_t hash[8]) {
    ZSTD_SHA256_Result out;
    size_t word;
    for (word = 0; word < 8; word++) {
        MEM_writeBE32(out.digest + 4 * word, hash[word]);
    }
    return out;
}

/**
 * One-shot SHA-256: hashes the `size` bytes starting at `msg` and returns the
 * 32-byte digest by value.
 */
ZSTD_SHA256_Result ZSTD_SHA256_hash(const void* const msg, const size_t size) {
    const uint8_t* pos = (const uint8_t*)msg;
    uint32_t state[8];
    size_t left;

    /* Start from the fixed initial state. */
    ZSTD_memcpy(state, I, sizeof(state));

    /* Consume every complete 64-byte block. */
    for (left = size; left >= ZSTD_SHA256_BLOCK_SIZE; left -= ZSTD_SHA256_BLOCK_SIZE) {
        ZSTD_SHA256_block(state, pos);
        pos += ZSTD_SHA256_BLOCK_SIZE;
    }

    /* Pad and consume whatever is left (fewer than 64 bytes). */
    ZSTD_SHA256_finish(state, pos, left, size);

    /* Convert the state words into the output digest. */
    return ZSTD_SHA256_digest(state);
}
27 changes: 27 additions & 0 deletions lib/common/sha256.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/

#ifndef ZSTD_SHA256_H
#define ZSTD_SHA256_H

#include "mem.h" /* size_t, uint8_t */

/* Size in bytes of a SHA-256 digest (256 bits). */
#define ZSTD_SHA256_DIGEST_SIZE 32

/* A complete SHA-256 digest, wrapped in a struct so that it can be returned
 * by value. */
typedef struct {
uint8_t digest[ZSTD_SHA256_DIGEST_SIZE];
} ZSTD_SHA256_Result;

/**
 * Returns the SHA-256 hash of the provided data.
 *
 * This is a one-shot interface: the `len` bytes starting at `data` are hashed
 * in a single call and the resulting 32-byte digest is returned by value.
 */
ZSTD_SHA256_Result ZSTD_SHA256_hash(const void* data, size_t len);

#endif /* ZSTD_SHA256_H */
3 changes: 3 additions & 0 deletions lib/common/zstd_deps.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* NULL
* INT_MAX
* UINT_MAX
* ZSTD_memcmp()
* ZSTD_memcpy()
* ZSTD_memset()
* ZSTD_memmove()
Expand All @@ -41,10 +42,12 @@
#include <string.h>

#if defined(__GNUC__) && __GNUC__ >= 4
# define ZSTD_memcmp(a,b,l) __builtin_memcmp((a),(b),(l))
# define ZSTD_memcpy(d,s,l) __builtin_memcpy((d),(s),(l))
# define ZSTD_memmove(d,s,l) __builtin_memmove((d),(s),(l))
# define ZSTD_memset(p,v,l) __builtin_memset((p),(v),(l))
#else
# define ZSTD_memcmp(a,b,l) memcmp((a),(b),(l))
# define ZSTD_memcpy(d,s,l) memcpy((d),(s),(l))
# define ZSTD_memmove(d,s,l) memmove((d),(s),(l))
# define ZSTD_memset(p,v,l) memset((p),(v),(l))
Expand Down
4 changes: 4 additions & 0 deletions lib/common/zstd_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#ifndef XXH_STATIC_LINKING_ONLY
# define XXH_STATIC_LINKING_ONLY /* XXH64_state_t */
#endif
#include "sha256.h" /* ZSTD_SHA256_DIGEST_SIZE */
#include "xxhash.h" /* XXH_reset, update, digest */
#ifndef ZSTD_NO_TRACE
# include "zstd_trace.h"
Expand Down Expand Up @@ -87,6 +88,9 @@ typedef enum { bt_raw, bt_rle, bt_compressed, bt_reserved } blockType_e;

#define ZSTD_FRAMECHECKSUMSIZE 4

#define ZSTD_HTTPDCZ_HEADER_SIZE (ZSTD_SKIPPABLEHEADERSIZE + ZSTD_SHA256_DIGEST_SIZE)
#define ZSTD_HTTPDCZ_HEADER_SKIPPABLE_VARIANT (0x0e)

#define MIN_SEQUENCES_SIZE 1 /* nbSeq==0 */
#define MIN_CBLOCK_SIZE (1 /*litCSize*/ + 1 /* RLE or RAW */) /* for a non-null block */
#define MIN_LITERALS_FOR_4_STREAMS 6
Expand Down
39 changes: 37 additions & 2 deletions lib/compress/zstd_compress.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
#include "zstd_opt.h"
#include "zstd_ldm.h"
#include "zstd_compress_superblock.h"
#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_rotateRight_U64 */
#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_rotateRight_U64 */
#include "../common/sha256.h" /* ZSTD_SHA256_Result, ZSTD_SHA256_hash, ZSTD_SHA256_DIGEST_SIZE */

/* ***************************************************************
* Tuning parameters
Expand Down Expand Up @@ -5200,7 +5201,7 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs,
void* workspace)
{
DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize);
if ((dict==NULL) || (dictSize<8)) {
if ((dict == NULL) || (dictSize < ZSTD_DICTIONARYSIZE_MIN)) {
RETURN_ERROR_IF(dictContentType == ZSTD_dct_fullDict, dictionary_wrong, "");
return 0;
}
Expand Down Expand Up @@ -7836,3 +7837,37 @@ void ZSTD_CCtxParams_registerSequenceProducer(
params->extSeqProdState = NULL;
}
}

size_t ZSTD_writeHeaderForHTTPDCZ(
void* dst, size_t dstCapacity,
const void* dict, size_t dictSize)
{
ZSTD_SHA256_Result hash;
RETURN_ERROR_IF(dst == NULL, dstBuffer_null, "NULL dst buffer.");
RETURN_ERROR_IF(dstCapacity < ZSTD_HTTPDCZ_HEADER_SIZE,
dstSize_tooSmall, "Too small to write frame header.");
RETURN_ERROR_IF(dict == NULL,
dictionary_corrupted, "Dictionary invalid: NULL pointer.");
RETURN_ERROR_IF(dictSize < ZSTD_DICTIONARYSIZE_MIN,
dictionary_corrupted, "Dictionary invalid: too small.");

hash = ZSTD_SHA256_hash(dict, dictSize);
return ZSTD_writeSkippableFrame(
dst, dstCapacity,
hash.digest, ZSTD_SHA256_DIGEST_SIZE,
ZSTD_HTTPDCZ_HEADER_SKIPPABLE_VARIANT);
}

/**
 * Same as ZSTD_writeHeaderForHTTPDCZ(), but takes the dictionary bytes from an
 * existing CDict. The CDict must be non-NULL and must have been loaded with
 * content type ZSTD_dct_rawContent; otherwise an error code is returned.
 */
size_t ZSTD_writeHeaderForHTTPDCZ_fromCDict(
        void* dst, size_t dstCapacity,
        const ZSTD_CDict* cdict)
{
    RETURN_ERROR_IF(cdict == NULL,
                    parameter_unsupported, "NULL cdict pointer.");
    RETURN_ERROR_IF(cdict->dictContentType != ZSTD_dct_rawContent,
                    dictionary_corrupted,
                    "Dictionary must be loaded as raw content for DCZ.");

    {   const void* const content = cdict->dictContent;
        size_t const contentSize = cdict->dictContentSize;
        return ZSTD_writeHeaderForHTTPDCZ(dst, dstCapacity, content, contentSize);
    }
}
4 changes: 2 additions & 2 deletions lib/compress/zstd_opt.c
Original file line number Diff line number Diff line change
Expand Up @@ -730,11 +730,11 @@ ZSTD_insertBtAndGetAllMatches (
if ((dictMode == ZSTD_noDict) || (dictMode == ZSTD_dictMatchState) || (matchIndex+matchLength >= dictLimit)) {
assert(matchIndex+matchLength >= dictLimit); /* ensure the condition is correct when !extDict */
match = base + matchIndex;
if (matchIndex >= dictLimit) assert(memcmp(match, ip, matchLength) == 0); /* ensure early section of match is equal as expected */
if (matchIndex >= dictLimit) assert(ZSTD_memcmp(match, ip, matchLength) == 0); /* ensure early section of match is equal as expected */
matchLength += ZSTD_count(ip+matchLength, match+matchLength, iLimit);
} else {
match = dictBase + matchIndex;
assert(memcmp(match, ip, matchLength) == 0); /* ensure early section of match is equal as expected */
assert(ZSTD_memcmp(match, ip, matchLength) == 0); /* ensure early section of match is equal as expected */
matchLength += ZSTD_count_2segments(ip+matchLength, match+matchLength, iLimit, dictEnd, prefixStart);
if (matchIndex+matchLength >= dictLimit)
match = base + matchIndex; /* prepare for match[matchLength] read */
Expand Down
6 changes: 6 additions & 0 deletions lib/decompress/zstd_ddict.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ size_t ZSTD_DDict_dictSize(const ZSTD_DDict* ddict)
return ddict->dictSize;
}

/* Reports the content type of a DDict, derived from its `entropyPresent`
 * flag: ZSTD_dct_fullDict when the flag is set, ZSTD_dct_rawContent otherwise.
 * `ddict` must not be NULL. */
ZSTD_dictContentType_e ZSTD_DDict_type(const ZSTD_DDict* ddict)
{
    assert(ddict != NULL);
    if (ddict->entropyPresent) {
        return ZSTD_dct_fullDict;
    }
    return ZSTD_dct_rawContent;
}

void ZSTD_copyDDictParameters(ZSTD_DCtx* dctx, const ZSTD_DDict* ddict)
{
DEBUGLOG(4, "ZSTD_copyDDictParameters");
Expand Down
2 changes: 2 additions & 0 deletions lib/decompress/zstd_ddict.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
const void* ZSTD_DDict_dictContent(const ZSTD_DDict* ddict);
size_t ZSTD_DDict_dictSize(const ZSTD_DDict* ddict);

ZSTD_dictContentType_e ZSTD_DDict_type(const ZSTD_DDict* ddict);

void ZSTD_copyDDictParameters(ZSTD_DCtx* dctx, const ZSTD_DDict* ddict);


Expand Down
55 changes: 54 additions & 1 deletion lib/decompress/zstd_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@
#define FSE_STATIC_LINKING_ONLY
#include "../common/fse.h"
#include "../common/huf.h"
#include "../common/xxhash.h" /* XXH64_reset, XXH64_update, XXH64_digest, XXH64 */
#include "../common/sha256.h" /* ZSTD_SHA256_Result, ZSTD_SHA256_hash, ZSTD_SHA256_DIGEST_SIZE */
#include "../common/xxhash.h" /* XXH64_reset, XXH64_update, XXH64_digest, XXH64 */
#include "zstd_decompress_internal.h" /* ZSTD_DCtx */
#include "zstd_ddict.h" /* ZSTD_DDictDictContent */
#include "zstd_decompress_block.h" /* ZSTD_decompressBlock_internal */
Expand Down Expand Up @@ -623,6 +624,7 @@ size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity,

/* check input validity */
RETURN_ERROR_IF(!ZSTD_isSkippableFrame(src, srcSize), frameParameter_unsupported, "");
FORWARD_IF_ERROR(skippableFrameSize, "");
RETURN_ERROR_IF(skippableFrameSize < ZSTD_SKIPPABLEHEADERSIZE || skippableFrameSize > srcSize, srcSize_wrong, "");
RETURN_ERROR_IF(skippableContentSize > dstCapacity, dstSize_tooSmall, "");

Expand All @@ -635,6 +637,57 @@ size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity,
}
}

/**
 * Reads a DCZ header from `src` through ZSTD_readSkippableFrame(), using `dst`
 * as the output buffer for the frame content.
 *
 * Succeeds only when the skippable frame's variant equals
 * ZSTD_HTTPDCZ_HEADER_SKIPPABLE_VARIANT and the size reported by
 * ZSTD_readSkippableFrame() is exactly ZSTD_SHA256_DIGEST_SIZE.
 *
 * @param dst          output buffer; must not be NULL.
 * @param dstCapacity  size of `dst`; must be at least ZSTD_SHA256_DIGEST_SIZE.
 * @param src          start of the compressed input.
 * @param srcSize      number of bytes available at `src`.
 * @return ZSTD_SHA256_DIGEST_SIZE on success, or an error code.
 */
size_t ZSTD_readHeaderForHTTPDCZ(
        void* dst, size_t dstCapacity,
        const void* src, size_t srcSize)
{
    RETURN_ERROR_IF(dst == NULL, dstBuffer_null, "NULL dst buffer.");
    RETURN_ERROR_IF(dstCapacity < ZSTD_SHA256_DIGEST_SIZE,
                    dstSize_tooSmall, "Too small");

    {   unsigned magicVariant;
        size_t const contentSize =
                ZSTD_readSkippableFrame(dst, dstCapacity, &magicVariant, src, srcSize);

        /* magicVariant is only inspected once the read is known to have succeeded. */
        FORWARD_IF_ERROR(contentSize, "Couldn't read skippable frame.");
        RETURN_ERROR_IF(magicVariant != ZSTD_HTTPDCZ_HEADER_SKIPPABLE_VARIANT,
                        prefix_unknown, "Skippable frame magic is wrong for DCZ.");
        RETURN_ERROR_IF(contentSize != ZSTD_SHA256_DIGEST_SIZE,
                        corruption_detected, "Wrong skippable frame size for DCZ.");
        return contentSize;
    }
}

/**
 * Checks that the digest stored in the DCZ header at `src` equals the SHA-256
 * digest of the supplied dictionary bytes.
 *
 * The header is read first, so header errors are reported before dictionary
 * argument errors.
 *
 * @return 1 when the digests match; otherwise an error code
 *         (dictionary_wrong when the digests differ).
 */
size_t ZSTD_readHeaderForHTTPDCZ_validateDictMatches(
        const void* src, size_t srcSize,
        const void* dict, size_t dictSize)
{
    ZSTD_SHA256_Result expected;
    ZSTD_SHA256_Result actual;
    size_t const readStatus = ZSTD_readHeaderForHTTPDCZ(
            expected.digest, ZSTD_SHA256_DIGEST_SIZE, src, srcSize);
    FORWARD_IF_ERROR(readStatus, "Couldn't read DCZ header.");

    RETURN_ERROR_IF(dict == NULL,
                    dictionary_corrupted, "Dictionary invalid: NULL pointer.");
    RETURN_ERROR_IF(dictSize < ZSTD_DICTIONARYSIZE_MIN,
                    dictionary_corrupted, "Dictionary invalid: too small.");

    actual = ZSTD_SHA256_hash(dict, dictSize);
    RETURN_ERROR_IF(
            ZSTD_memcmp(expected.digest, actual.digest, ZSTD_SHA256_DIGEST_SIZE) != 0,
            dictionary_wrong, "DCZ hashes don't match.");
    return 1;
}

/**
 * Same as ZSTD_readHeaderForHTTPDCZ_validateDictMatches(), but takes the
 * dictionary bytes from an existing DDict. The DDict must be non-NULL and
 * ZSTD_DDict_type() must report ZSTD_dct_rawContent for it; otherwise an error
 * code is returned.
 */
size_t ZSTD_readHeaderForHTTPDCZ_validateDDictMatches(
        const void* src, size_t srcSize,
        const ZSTD_DDict* ddict)
{
    RETURN_ERROR_IF(ddict == NULL, parameter_unsupported, "NULL ddict pointer.");
    RETURN_ERROR_IF(ZSTD_DDict_type(ddict) != ZSTD_dct_rawContent,
                    dictionary_corrupted,
                    "Dictionary must be loaded as raw content for DCZ.");

    {   const void* const content = ZSTD_DDict_dictContent(ddict);
        size_t const contentSize = ZSTD_DDict_dictSize(ddict);
        return ZSTD_readHeaderForHTTPDCZ_validateDictMatches(
                src, srcSize, content, contentSize);
    }
}

/** ZSTD_findDecompressedSize() :
* `srcSize` must be the exact length of some number of ZSTD compressed and/or
* skippable frames
Expand Down
Loading
Loading