Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions lib/common/allocations.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ MEM_STATIC void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem)
return ZSTD_malloc(size);
}

MEM_STATIC void* ZSTD_customMalloc2(size_t nmemb, size_t size, ZSTD_customMem customMem)
{
if (nmemb > 0 && size > SIZE_MAX / nmemb) return NULL;
return ZSTD_customMalloc(nmemb * size, customMem);
}

MEM_STATIC void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem)
{
if (customMem.customAlloc) {
Expand All @@ -46,6 +52,12 @@ MEM_STATIC void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem)
return ZSTD_calloc(1, size);
}

MEM_STATIC void* ZSTD_customCalloc2(size_t nmemb, size_t size, ZSTD_customMem customMem)
{
if (nmemb > 0 && size > SIZE_MAX / nmemb) return NULL;
return ZSTD_customCalloc(nmemb * size, customMem);
}

MEM_STATIC void ZSTD_customFree(void* ptr, ZSTD_customMem customMem)
{
if (ptr!=NULL) {
Expand Down
6 changes: 3 additions & 3 deletions lib/common/pool.c
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ POOL_ctx* POOL_create_advanced(size_t numThreads, size_t queueSize,
* empty and full queues.
*/
ctx->queueSize = queueSize + 1;
ctx->queue = (POOL_job*)ZSTD_customCalloc(ctx->queueSize * sizeof(POOL_job), customMem);
ctx->queue = (POOL_job*)ZSTD_customCalloc2(ctx->queueSize, sizeof(POOL_job), customMem);
ctx->queueHead = 0;
ctx->queueTail = 0;
ctx->numThreadsBusy = 0;
Expand All @@ -140,7 +140,7 @@ POOL_ctx* POOL_create_advanced(size_t numThreads, size_t queueSize,
}
ctx->shutdown = 0;
/* Allocate space for the thread handles */
ctx->threads = (ZSTD_pthread_t*)ZSTD_customCalloc(numThreads * sizeof(ZSTD_pthread_t), customMem);
ctx->threads = (ZSTD_pthread_t*)ZSTD_customCalloc2(numThreads, sizeof(ZSTD_pthread_t), customMem);
ctx->threadCapacity = 0;
ctx->threadLimit = numThreads;
ctx->customMem = customMem;
Expand Down Expand Up @@ -219,7 +219,7 @@ static int POOL_resize_internal(POOL_ctx* ctx, size_t numThreads)
}
/* numThreads > threadCapacity */
ctx->threadLimit = numThreads;
{ ZSTD_pthread_t* const threadPool = (ZSTD_pthread_t*)ZSTD_customCalloc(numThreads * sizeof(ZSTD_pthread_t), ctx->customMem);
{ ZSTD_pthread_t* const threadPool = (ZSTD_pthread_t*)ZSTD_customCalloc2(numThreads, sizeof(ZSTD_pthread_t), ctx->customMem);
if (!threadPool) return 1;
/* extend existing thread pool */
ZSTD_memcpy(threadPool, ctx->threads, ctx->threadCapacity * sizeof(ZSTD_pthread_t));
Expand Down
4 changes: 2 additions & 2 deletions lib/compress/zstdmt_compress.c
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ static ZSTDMT_bufferPool* ZSTDMT_createBufferPool(unsigned maxNbBuffers, ZSTD_cu
ZSTD_customFree(bufPool, cMem);
return NULL;
}
bufPool->buffers = (Buffer*)ZSTD_customCalloc(maxNbBuffers * sizeof(Buffer), cMem);
bufPool->buffers = (Buffer*)ZSTD_customCalloc2(maxNbBuffers, sizeof(Buffer), cMem);
if (bufPool->buffers==NULL) {
ZSTDMT_freeBufferPool(bufPool);
return NULL;
Expand Down Expand Up @@ -389,7 +389,7 @@ static ZSTDMT_CCtxPool* ZSTDMT_createCCtxPool(int nbWorkers,
return NULL;
}
cctxPool->totalCCtx = nbWorkers;
cctxPool->cctxs = (ZSTD_CCtx**)ZSTD_customCalloc(nbWorkers * sizeof(ZSTD_CCtx*), cMem);
cctxPool->cctxs = (ZSTD_CCtx**)ZSTD_customCalloc2(nbWorkers, sizeof(ZSTD_CCtx*), cMem);
if (!cctxPool->cctxs) {
ZSTDMT_freeCCtxPool(cctxPool);
return NULL;
Expand Down
12 changes: 9 additions & 3 deletions lib/decompress/zstd_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ static size_t ZSTD_DDictHashSet_emplaceDDict(ZSTD_DDictHashSet* hashSet, const Z
* Returns 0 on success, otherwise a zstd error code.
*/
static size_t ZSTD_DDictHashSet_expand(ZSTD_DDictHashSet* hashSet, ZSTD_customMem customMem) {
size_t newTableSize = hashSet->ddictPtrTableSize * DDICT_HASHSET_RESIZE_FACTOR;
const ZSTD_DDict** newTable = (const ZSTD_DDict**)ZSTD_customCalloc(sizeof(ZSTD_DDict*) * newTableSize, customMem);
size_t const oldTableSize = hashSet->ddictPtrTableSize;
size_t const newTableSize = oldTableSize * DDICT_HASHSET_RESIZE_FACTOR;
const ZSTD_DDict** newTable;

DEBUGLOG(4, "Expanding DDict hash table! Old size: %zu new size: %zu", oldTableSize, newTableSize);
RETURN_ERROR_IF(newTableSize <= oldTableSize, memory_allocation, "Expanded hashset size overflow!");

newTable = (const ZSTD_DDict**)ZSTD_customCalloc2(newTableSize, sizeof(ZSTD_DDict*), customMem);
const ZSTD_DDict** oldTable = hashSet->ddictPtrTable;
size_t oldTableSize = hashSet->ddictPtrTableSize;
size_t i;
Expand Down Expand Up @@ -180,7 +186,7 @@ static ZSTD_DDictHashSet* ZSTD_createDDictHashSet(ZSTD_customMem customMem) {
DEBUGLOG(4, "Allocating new hash set");
if (!ret)
return NULL;
ret->ddictPtrTable = (const ZSTD_DDict**)ZSTD_customCalloc(DDICT_HASHSET_TABLE_BASE_SIZE * sizeof(ZSTD_DDict*), customMem);
ret->ddictPtrTable = (const ZSTD_DDict**)ZSTD_customCalloc2(DDICT_HASHSET_TABLE_BASE_SIZE, sizeof(ZSTD_DDict*), customMem);
if (!ret->ddictPtrTable) {
ZSTD_customFree(ret, customMem);
return NULL;
Expand Down
68 changes: 68 additions & 0 deletions tests/allocationTests.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/

#include <stddef.h>
#include <stdio.h>
#include <stdint.h>
#include "../lib/common/allocations.h"
#include "../lib/common/pool.h"
#include "../lib/common/zstd_internal.h"

#define ASSERT_TRUE(p) \
do { \
if (!(p)) { \
printf("Assertion failed at %s:%d: %s\n", __FILE__, __LINE__, #p); \
return 1; \
} \
} while (0)

#define ASSERT_FALSE(p) ASSERT_TRUE(!(p))

static int test_customCalloc2_overflow(void) {
size_t const nmemb = SIZE_MAX / 2 + 1;
size_t const size = 4;
void* const ptr = ZSTD_customCalloc2(nmemb, size, ZSTD_defaultCMem);
ASSERT_TRUE(ptr == NULL);
return 0;
}

static int test_pool_create_overflow(void) {
/* numThreads * sizeof(thread) should overflow */
size_t const numThreads = SIZE_MAX / 2;
POOL_ctx* const ctx = POOL_create(numThreads, 1);
ASSERT_TRUE(ctx == NULL);
return 0;
}

int main(void) {
int result = 0;

printf("Testing ZSTD_customCalloc2 overflow protection...\n");
if (test_customCalloc2_overflow()) {
printf("FAILED: ZSTD_customCalloc2 overflow test\n");
result = 1;
} else {
printf("SUCCESS: ZSTD_customCalloc2 overflow test\n");
}

printf("Testing POOL_create overflow protection...\n");
if (test_pool_create_overflow()) {
printf("FAILED: POOL_create overflow test\n");
result = 1;
} else {
printf("SUCCESS: POOL_create overflow test\n");
}

if (result == 0) {
printf("PASS: All allocation security tests\n");
}

return result;
}