Skip to content

Jemalloc Mempool and Adaptation for CPU HASHTABLE #4154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

#pragma once

#include <folly/Synchronized.h>
#include <folly/container/F14Map.h>
#include "folly/Synchronized.h"

#include "fixed_block_pool.h"

namespace kv_mem {

Expand All @@ -29,18 +31,30 @@ class SynchronizedShardedMap {
public:
using iterator = typename folly::F14FastMap<K, V>::const_iterator;

explicit SynchronizedShardedMap(std::size_t numShards) : shards_(numShards) {}
explicit SynchronizedShardedMap(std::size_t numShards,
std::size_t block_size,
std::size_t block_alignment,
std::size_t blocks_per_chunk = 8192)
: shards_(numShards), mempools_(numShards) {
// Init mempools_
for (auto& pool : mempools_) {
pool = std::make_unique<kv_mem::FixedBlockPool>(
block_size, block_alignment, blocks_per_chunk);
}
}

// Get shard map by index
auto& by(int index) {
return shards_.at(index % shards_.size());
}
auto& by(int index) { return shards_.at(index % shards_.size()); }

auto getNumShards() {
return shards_.size();
// Get shard pool by index
auto* pool_by(int index) {
return mempools_.at(index % shards_.size()).get();
}

auto getNumShards() { return shards_.size(); }

private:
std::vector<folly::Synchronized<folly::F14FastMap<K, V>, M>> shards_;
std::vector<std::unique_ptr<FixedBlockPool>> mempools_;
};
} // namespace kv_mem
} // namespace kv_mem
60 changes: 39 additions & 21 deletions fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#include "SynchronizedShardedMap.h"
#include "deeplearning/fbgemm/fbgemm_gpu/src/ssd_split_embeddings_cache/initializer.h"
#include "store_value.h"
#include "store_value_utils.h"

#include <ATen/core/ivalue.h>
#include <caffe2/torch/fb/distributed/wireSerializer/WireSerializer.h>
Expand Down Expand Up @@ -70,8 +70,13 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
max_D_(max_D),
num_shards_(num_shards),
weight_ttl_in_hours_(weight_ttl_in_hours),
kv_store_(SynchronizedShardedMap<int64_t, StoreValue<weight_type>>(
num_shards_)),
block_size_(StoreValueUtils::calculate_block_size<weight_type>(max_D)),
block_alignment_(StoreValueUtils::calculate_block_alignment<weight_type>()),
kv_store_(SynchronizedShardedMap<int64_t, weight_type*>(
num_shards_,
block_size_,
block_alignment_,
/*blocks_per_chunk=*/8192)),
elem_size_(row_storage_bitwidth / 8) {
executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
num_threads, facebook::Proc::getCpuInfo().numCpuCores));
Expand Down Expand Up @@ -185,20 +190,33 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
CHECK_EQ(indices.size(0), weights.size(0));
{
auto wlmap = kv_store_.by(shard_id).wlock();
auto* pool = kv_store_.pool_by(shard_id);
auto indices_data_ptr = indices.data_ptr<index_t>();
for (auto index_iter = indexes.begin();
index_iter != indexes.end();
index_iter++) {
const auto& id_index = *index_iter;
auto id = int64_t(indices_data_ptr[id_index]);
wlmap->try_emplace(
id,
StoreValue<weight_type>(std::vector<weight_type>(
weights[id_index]
.template data_ptr<weight_type>(),
weights[id_index]
.template data_ptr<weight_type>() +
weights[id_index].numel())));
// use mempool
weight_type* block = nullptr;
// First check if the key already exists
auto it = wlmap->find(id);
if (it != wlmap->end()) {
block = it->second;
} else {
// Key doesn't exist, allocate new block and insert.
block = StoreValueUtils::allocate<weight_type>(
block_size_, block_alignment_, pool);
wlmap->insert({id, block});
}
StoreValueUtils::update_timestamp<weight_type>(block);
auto* data_ptr = StoreValueUtils::data_ptr<weight_type>(block);
std::copy(weights[id_index]
.template data_ptr<weight_type>(),
weights[id_index]
.template data_ptr<weight_type>() +
weights[id_index].numel(),
data_ptr);
}
}
});
Expand Down Expand Up @@ -276,16 +294,13 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
row_storage_data_ptr));
continue;
}
const auto& cache_results =
cached_iter->second.getValueAndPromote();
CHECK_EQ(cache_results.size(), max_D_);
// use mempool
const auto* data_ptr = StoreValueUtils::data_ptr<weight_type>(cached_iter->second);
StoreValueUtils::update_timestamp(cached_iter->second);
std::copy(
reinterpret_cast<const weight_type*>(
&(cache_results[0])),
reinterpret_cast<const weight_type*>(
&(cache_results[max_D_])),
&(weights_data_ptr
[id_index * max_D_])); // dst_start
data_ptr,
data_ptr + max_D_,
&(weights_data_ptr[index * max_D_])); // dst_start
}
}
});
Expand Down Expand Up @@ -368,7 +383,10 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
int64_t max_D_;
int64_t num_shards_;
int64_t weight_ttl_in_hours_;
SynchronizedShardedMap<int64_t, StoreValue<weight_type>> kv_store_;
// mempool params
size_t block_size_;
size_t block_alignment_;
SynchronizedShardedMap<int64_t, weight_type*> kv_store_;
std::atomic_bool is_eviction_ongoing_ = false;
std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
int64_t elem_size_;
Expand Down
128 changes: 128 additions & 0 deletions fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#pragma once

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory_resource>
#include <new>
#include <stdexcept>
#include <vector>

namespace kv_mem {

// Fixed-size block allocator built on std::pmr::memory_resource.
//
// Large chunks are obtained from an upstream resource and carved into
// equally sized blocks, which are handed out through an intrusive free
// list: the first sizeof(void*) bytes of a *free* block store the
// pointer to the next free block. Every block shares one size and one
// alignment, so do_allocate() rejects any other request.
// NOT thread-safe; callers must synchronize externally (e.g. one pool
// per shard, guarded by that shard's lock).
class FixedBlockPool : public std::pmr::memory_resource {
 public:
  // @param block_size       Size of each memory block in bytes. Must be at
  //                         least sizeof(void*) so a free block can hold the
  //                         free-list link.
  //                         half type, 2 bytes, minimum embedding length 4;
  //                         float type, 4 bytes, minimum embedding length 2.
  //                         Large objects use the memory pool, small objects
  //                         are placed directly in the hashtable.
  // @param block_alignment  Alignment of each block. Must be a nonzero power
  //                         of two that divides block_size.
  // @param blocks_per_chunk Number of blocks carved out of each chunk.
  // @param upstream         Resource the chunks are allocated from.
  // @throws std::invalid_argument if any parameter constraint is violated.
  explicit FixedBlockPool(
      std::size_t block_size,
      std::size_t block_alignment,
      std::size_t blocks_per_chunk = 8192,
      std::pmr::memory_resource* upstream = std::pmr::new_delete_resource())
      // Minimum block size is sizeof(void*) (8 bytes on 64-bit)
      : block_size_(std::max(block_size, sizeof(void*))),
        block_alignment_(block_alignment),
        blocks_per_chunk_(blocks_per_chunk),
        upstream_(upstream),
        chunks_(upstream) {
    if (block_size < sizeof(void*)) {
      // A free block stores the next free-list pointer in-place, so it
      // must be able to hold at least one pointer.
      throw std::invalid_argument("Block size must be at least sizeof(void*)");
    }

    // Reject zero explicitly: 0 passes the power-of-two bit trick below
    // and would then make the modulo check divide by zero (UB).
    if (block_alignment_ == 0 ||
        (block_alignment_ & (block_alignment_ - 1)) != 0) {
      throw std::invalid_argument("Alignment must be power of two");
    }

    // Validate that block size is a multiple of alignment
    if (block_size_ % block_alignment_ != 0) {
      throw std::invalid_argument("Block size must align with alignment");
    }

    // A zero-block chunk would leave the free list empty after
    // allocate_chunk(), and do_allocate() would then dereference null.
    if (blocks_per_chunk_ == 0) {
      throw std::invalid_argument("blocks_per_chunk must be positive");
    }
  }

  // The pool owns raw chunks; copying it would double-free them.
  FixedBlockPool(const FixedBlockPool&) = delete;
  FixedBlockPool& operator=(const FixedBlockPool&) = delete;

  // Return every chunk to the upstream resource. Any outstanding block
  // becomes invalid; callers must not use blocks after the pool dies.
  ~FixedBlockPool() override {
    for (auto&& chunk : chunks_) {
      upstream_->deallocate(chunk.ptr, chunk.size, chunk.alignment);
    }
  }

 protected:
  // Core allocation function. Only the exact (block_size, block_alignment)
  // configured at construction is supported.
  void* do_allocate(std::size_t bytes, std::size_t alignment) override {
    if (bytes != block_size_ || alignment != block_alignment_) {
      throw std::bad_alloc();
    }

    // Allocate a new chunk when no blocks are available
    if (!free_list_) {
      allocate_chunk();
    }

    // Pop a block off the head of the free list
    void* result = free_list_;
    free_list_ = *static_cast<void**>(free_list_);
    return result;
  }

  // Core deallocation function: push the block back onto the head of the
  // free list, reusing its first pointer-sized bytes as the link.
  void do_deallocate(void* p,
                     [[maybe_unused]] std::size_t bytes,
                     [[maybe_unused]] std::size_t alignment) override {
    *static_cast<void**>(p) = free_list_;
    free_list_ = p;
  }

  // Pools are stateful; only the very same object compares equal.
  [[nodiscard]] bool do_is_equal(
      const std::pmr::memory_resource& other) const noexcept override {
    return this == &other;
  }

 private:
  // Bookkeeping needed to return one chunk to the upstream resource.
  struct chunk_info {
    void* ptr;              // Chunk base pointer
    std::size_t size;       // Total chunk size in bytes
    std::size_t alignment;  // Alignment the chunk was allocated with
  };

  // Grab one more chunk from upstream and thread all of its blocks onto
  // the free list.
  void allocate_chunk() {
    const std::size_t chunk_size = block_size_ * blocks_per_chunk_;

    // Allocate aligned memory through the upstream resource
    void* chunk_ptr = upstream_->allocate(chunk_size, block_alignment_);

    // Record chunk information so the destructor can release it
    chunks_.push_back({chunk_ptr, chunk_size, block_alignment_});

    // Link blocks in reverse order from chunk end to beginning, so the
    // free-list head is the lowest address (improves locality of
    // consecutive allocations).
    char* current = static_cast<char*>(chunk_ptr) + chunk_size;
    for (std::size_t i = 0; i < blocks_per_chunk_; ++i) {
      current -= block_size_;
      *reinterpret_cast<void**>(current) = free_list_;
      free_list_ = current;
    }
  }

  // Member variables
  const std::size_t block_size_;        // Block size (>= sizeof(void*))
  const std::size_t block_alignment_;   // Block alignment requirement
  const std::size_t blocks_per_chunk_;  // Number of blocks per chunk
  std::pmr::memory_resource* upstream_; // Upstream memory resource
  std::pmr::vector<chunk_info> chunks_; // All allocated chunks (for dtor)
  void* free_list_ = nullptr;           // Free block list head pointer
};
} // namespace kv_mem
82 changes: 82 additions & 0 deletions fbgemm_gpu/src/dram_kv_embedding_cache/store_value_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#pragma once
#include <chrono>

#include "fixed_block_pool.h"

namespace kv_mem {

// Helpers for the memory blocks handed out by FixedBlockPool.
//
// Each block starts with a MetaHeader, immediately followed by the
// embedding-row payload. These helpers compute block size/alignment
// and give typed access to the header and the payload.
class StoreValueUtils {
 public:
  // Metadata structure (publicly accessible)
  // alignas(8) keeps sizeof(MetaHeader) >= sizeof(void*), so a block is
  // never too small to hold the mempool's free-list link.
  struct alignas(8) MetaHeader {
    int64_t timestamp; // 8 bytes; last-touch time in seconds since epoch
    // Can be extended with other fields: uint32_t counter, uint64_t key, etc.
  };

  // Create a memory block (header + payload) from `pool`.
  // Sizes are taken by value: the previous `size_t&` parameters rejected
  // const/rvalue arguments and wrongly implied mutation.
  template <typename scalar_t>
  static scalar_t* allocate(size_t block_size,
                            size_t alignment,
                            FixedBlockPool* pool) {
    return reinterpret_cast<scalar_t*>(pool->allocate(block_size, alignment));
  }

  // Destroy a memory block. It must have come from the same pool with the
  // same size/alignment it was allocated with.
  template <typename scalar_t>
  static void deallocate(scalar_t* block,
                         size_t block_size,
                         size_t alignment,
                         FixedBlockPool* pool) {
    pool->deallocate(block, block_size, alignment);
  }

  // Bytes needed for a block holding `dimension` elements plus the header.
  template <typename scalar_t>
  static size_t calculate_block_size(size_t dimension) {
    return sizeof(MetaHeader) + dimension * sizeof(scalar_t);
  }

  // Alignment a block needs so both header and payload are aligned.
  template <typename scalar_t>
  static size_t calculate_block_alignment() {
    // Equivalent to std::max of the two alignments, without needing
    // <algorithm> (the previous version relied on a transitive include).
    return alignof(MetaHeader) > alignof(scalar_t) ? alignof(MetaHeader)
                                                   : alignof(scalar_t);
  }

  // Metadata operations
  template <typename scalar_t>
  static int64_t get_timestamp(const scalar_t* block) {
    return reinterpret_cast<const MetaHeader*>(block)->timestamp;
  }

  template <typename scalar_t>
  static void set_timestamp(scalar_t* block, int64_t ts) {
    reinterpret_cast<MetaHeader*>(block)->timestamp = ts;
  }

  // Stamp the block with the current wall-clock time.
  template <typename scalar_t>
  static void update_timestamp(scalar_t* block) {
    reinterpret_cast<MetaHeader*>(block)->timestamp = current_timestamp();
  }

  // Typed pointer to the payload that follows the header.
  template <typename scalar_t>
  static scalar_t* data_ptr(scalar_t* block) {
    return reinterpret_cast<scalar_t*>(reinterpret_cast<char*>(block) +
                                       sizeof(MetaHeader));
  }

  template <typename scalar_t>
  static const scalar_t* data_ptr(const scalar_t* block) {
    return reinterpret_cast<const scalar_t*>(
        reinterpret_cast<const char*>(block) + sizeof(MetaHeader));
  }

  // Seconds since the Unix epoch.
  static int64_t current_timestamp() {
    return std::chrono::duration_cast<std::chrono::seconds>(
               std::chrono::system_clock::now().time_since_epoch())
        .count();
    // facebook::WallClockUtil::NowInUsecFast();
  }
};
} // namespace kv_mem
14 changes: 14 additions & 0 deletions fbgemm_gpu/test/dram_kv_embedding_cache/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
add_executable(fixed_block_pool_test ${CMAKE_CURRENT_SOURCE_DIR}/fixed_block_pool_test.cpp)
target_compile_features(fixed_block_pool_test PUBLIC cxx_std_17)
target_include_directories(fixed_block_pool_test PUBLIC ${FBGEMM_SOURCE_DIR})
target_link_libraries(fixed_block_pool_test gtest gtest_main)

add_executable(sharded_map_test ${CMAKE_CURRENT_SOURCE_DIR}/sharded_map_test.cpp)
target_compile_features(sharded_map_test PUBLIC cxx_std_17)
# Fixed copy-paste bug: the include directories must be attached to
# sharded_map_test (previously set on fixed_block_pool_test again,
# leaving sharded_map_test without any include path).
target_include_directories(sharded_map_test PUBLIC ${FBGEMM_SOURCE_DIR})
target_link_libraries(sharded_map_test gtest gtest_main Folly::folly)

add_executable(store_value_utils_test ${CMAKE_CURRENT_SOURCE_DIR}/store_value_utils_test.cpp)
target_compile_features(store_value_utils_test PUBLIC cxx_std_17)
target_include_directories(store_value_utils_test PUBLIC ${FBGEMM_SOURCE_DIR})
target_link_libraries(store_value_utils_test gtest gtest_main Folly::folly)
Loading