From 13269e873ac823fa1fad813b7e2b84e223312588 Mon Sep 17 00:00:00 2001 From: uwezkhan06 Date: Mon, 20 Apr 2026 16:49:20 +0530 Subject: [PATCH] Hardening: Introduce overflow-checked custom allocation primitives --- lib/common/allocations.h | 12 ++++++ lib/common/pool.c | 6 +-- lib/compress/zstdmt_compress.c | 4 +- lib/decompress/zstd_decompress.c | 12 ++++-- tests/allocationTests.c | 68 ++++++++++++++++++++++++++++++++ 5 files changed, 94 insertions(+), 8 deletions(-) create mode 100644 tests/allocationTests.c diff --git a/lib/common/allocations.h b/lib/common/allocations.h index d4d392998f7..428768d8881 100644 --- a/lib/common/allocations.h +++ b/lib/common/allocations.h @@ -30,6 +30,12 @@ MEM_STATIC void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem) return ZSTD_malloc(size); } +MEM_STATIC void* ZSTD_customMalloc2(size_t nmemb, size_t size, ZSTD_customMem customMem) +{ + if (nmemb > 0 && size > SIZE_MAX / nmemb) return NULL; + return ZSTD_customMalloc(nmemb * size, customMem); +} + MEM_STATIC void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem) { if (customMem.customAlloc) { @@ -46,6 +52,12 @@ MEM_STATIC void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem) return ZSTD_calloc(1, size); } +MEM_STATIC void* ZSTD_customCalloc2(size_t nmemb, size_t size, ZSTD_customMem customMem) +{ + if (nmemb > 0 && size > SIZE_MAX / nmemb) return NULL; + return ZSTD_customCalloc(nmemb * size, customMem); +} + MEM_STATIC void ZSTD_customFree(void* ptr, ZSTD_customMem customMem) { if (ptr!=NULL) { diff --git a/lib/common/pool.c b/lib/common/pool.c index dd5fb0c4d4b..119b2a7cf65 100644 --- a/lib/common/pool.c +++ b/lib/common/pool.c @@ -126,7 +126,7 @@ POOL_ctx* POOL_create_advanced(size_t numThreads, size_t queueSize, * empty and full queues. */ ctx->queueSize = queueSize + 1; - ctx->queue = (POOL_job*)ZSTD_customCalloc(ctx->queueSize * sizeof(POOL_job), customMem); + ctx->queue = (POOL_job*)ZSTD_customCalloc2(ctx->queueSize, sizeof(POOL_job), customMem); ctx->queueHead = 0; ctx->queueTail = 0; ctx->numThreadsBusy = 0; @@ -140,7 +140,7 @@ POOL_ctx* POOL_create_advanced(size_t numThreads, size_t queueSize, } ctx->shutdown = 0; /* Allocate space for the thread handles */ - ctx->threads = (ZSTD_pthread_t*)ZSTD_customCalloc(numThreads * sizeof(ZSTD_pthread_t), customMem); + ctx->threads = (ZSTD_pthread_t*)ZSTD_customCalloc2(numThreads, sizeof(ZSTD_pthread_t), customMem); ctx->threadCapacity = 0; ctx->threadLimit = numThreads; ctx->customMem = customMem; @@ -219,7 +219,7 @@ static int POOL_resize_internal(POOL_ctx* ctx, size_t numThreads) } /* numThreads > threadCapacity */ ctx->threadLimit = numThreads; - { ZSTD_pthread_t* const threadPool = (ZSTD_pthread_t*)ZSTD_customCalloc(numThreads * sizeof(ZSTD_pthread_t), ctx->customMem); + { ZSTD_pthread_t* const threadPool = (ZSTD_pthread_t*)ZSTD_customCalloc2(numThreads, sizeof(ZSTD_pthread_t), ctx->customMem); if (!threadPool) return 1; /* extend existing thread pool */ ZSTD_memcpy(threadPool, ctx->threads, ctx->threadCapacity * sizeof(ZSTD_pthread_t)); diff --git a/lib/compress/zstdmt_compress.c b/lib/compress/zstdmt_compress.c index cb710475c2a..acb4c5e75c4 100644 --- a/lib/compress/zstdmt_compress.c +++ b/lib/compress/zstdmt_compress.c @@ -128,7 +128,7 @@ static ZSTDMT_bufferPool* ZSTDMT_createBufferPool(unsigned maxNbBuffers, ZSTD_cu ZSTD_customFree(bufPool, cMem); return NULL; } - bufPool->buffers = (Buffer*)ZSTD_customCalloc(maxNbBuffers * sizeof(Buffer), cMem); + bufPool->buffers = (Buffer*)ZSTD_customCalloc2(maxNbBuffers, sizeof(Buffer), cMem); if (bufPool->buffers==NULL) { ZSTDMT_freeBufferPool(bufPool); return NULL; @@ -389,7 +389,7 @@ static ZSTDMT_CCtxPool* ZSTDMT_createCCtxPool(int nbWorkers, return NULL; } cctxPool->totalCCtx = nbWorkers; - cctxPool->cctxs = (ZSTD_CCtx**)ZSTD_customCalloc(nbWorkers * sizeof(ZSTD_CCtx*), cMem); + cctxPool->cctxs = (ZSTD_CCtx**)ZSTD_customCalloc2(nbWorkers, sizeof(ZSTD_CCtx*), cMem); if (!cctxPool->cctxs) { ZSTDMT_freeCCtxPool(cctxPool); return NULL; diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index 9eb98327ef3..bae76cfa323 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -129,8 +129,14 @@ static size_t ZSTD_DDictHashSet_emplaceDDict(ZSTD_DDictHashSet* hashSet, const Z * Returns 0 on success, otherwise a zstd error code. */ static size_t ZSTD_DDictHashSet_expand(ZSTD_DDictHashSet* hashSet, ZSTD_customMem customMem) { - size_t newTableSize = hashSet->ddictPtrTableSize * DDICT_HASHSET_RESIZE_FACTOR; - const ZSTD_DDict** newTable = (const ZSTD_DDict**)ZSTD_customCalloc(sizeof(ZSTD_DDict*) * newTableSize, customMem); + size_t const oldTableSize = hashSet->ddictPtrTableSize; + size_t const newTableSize = oldTableSize * DDICT_HASHSET_RESIZE_FACTOR; + const ZSTD_DDict** newTable; + + DEBUGLOG(4, "Expanding DDict hash table! Old size: %zu new size: %zu", oldTableSize, newTableSize); + RETURN_ERROR_IF(newTableSize <= oldTableSize, memory_allocation, "Expanded hashset size overflow!"); + + newTable = (const ZSTD_DDict**)ZSTD_customCalloc2(newTableSize, sizeof(ZSTD_DDict*), customMem); const ZSTD_DDict** oldTable = hashSet->ddictPtrTable; size_t oldTableSize = hashSet->ddictPtrTableSize; size_t i; @@ -180,7 +186,7 @@ static ZSTD_DDictHashSet* ZSTD_createDDictHashSet(ZSTD_customMem customMem) { DEBUGLOG(4, "Allocating new hash set"); if (!ret) return NULL; - ret->ddictPtrTable = (const ZSTD_DDict**)ZSTD_customCalloc(DDICT_HASHSET_TABLE_BASE_SIZE * sizeof(ZSTD_DDict*), customMem); + ret->ddictPtrTable = (const ZSTD_DDict**)ZSTD_customCalloc2(DDICT_HASHSET_TABLE_BASE_SIZE, sizeof(ZSTD_DDict*), customMem); if (!ret->ddictPtrTable) { ZSTD_customFree(ret, customMem); return NULL; diff --git a/tests/allocationTests.c b/tests/allocationTests.c new file mode 100644 index 00000000000..affa8f090b6 --- /dev/null +++ b/tests/allocationTests.c @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +#include +#include +#include +#include "../lib/common/allocations.h" +#include "../lib/common/pool.h" +#include "../lib/common/zstd_internal.h" + +#define ASSERT_TRUE(p) \ + do { \ + if (!(p)) { \ + printf("Assertion failed at %s:%d: %s\n", __FILE__, __LINE__, #p); \ + return 1; \ + } \ + } while (0) + +#define ASSERT_FALSE(p) ASSERT_TRUE(!(p)) + +static int test_customCalloc2_overflow(void) { + size_t const nmemb = SIZE_MAX / 2 + 1; + size_t const size = 4; + void* const ptr = ZSTD_customCalloc2(nmemb, size, ZSTD_defaultCMem); + ASSERT_TRUE(ptr == NULL); + return 0; +} + +static int test_pool_create_overflow(void) { + /* numThreads * sizeof(thread) should overflow */ + size_t const numThreads = SIZE_MAX / 2; + POOL_ctx* const ctx = POOL_create(numThreads, 1); + ASSERT_TRUE(ctx == NULL); + return 0; +} + +int main(void) { + int result = 0; + + printf("Testing ZSTD_customCalloc2 overflow protection...\n"); + if (test_customCalloc2_overflow()) { + printf("FAILED: ZSTD_customCalloc2 overflow test\n"); + result = 1; + } else { + printf("SUCCESS: ZSTD_customCalloc2 overflow test\n"); + } + + printf("Testing POOL_create overflow protection...\n"); + if (test_pool_create_overflow()) { + printf("FAILED: POOL_create overflow test\n"); + result = 1; + } else { + printf("SUCCESS: POOL_create overflow test\n"); + } + + if (result == 0) { + printf("PASS: All allocation security tests\n"); + } + + return result; +}