Skip to content

Commit c967848

Browse files
Raahul Kalyaan Jakkafacebook-github-bot
Raahul Kalyaan Jakka
authored andcommitted
Adding a mutex lock to set_range function
Summary: X-link: facebookresearch/FBGEMM#1281 **Context:** While we expose KVTensor to external surfaces (i.e., checkpointing), they have the freedom to leverage the KVTensor functions in a concurrent fashion. For example, https://www.internalfb.com/code/fbsource/[5b7b1eef7d69]/fbcode/aiplatform/modelstore/checkpointing/pyper/TensorLoaderCallback.h?lines=85-86 This function here calls set_range to the same KVTensor multiple times because we divide a huge chunk of data into smaller chunks and try to write it in a concurrent fashion. This is a bad practice because in SSD I/O, We also use multi threading to write data in KVTensor. Currently, we use 32 threads (each thread per shard) to write data. Due to this, when we call set_range multiple times, this can lead to thread contention and increase in synchronization overhead **In this Diff:** We introduce a mutex lock on the set_range function, due to this every transaction is locked during execution and the multiple calls are processed serially leading to more efficient use of the threads Differential Revision: D75555658
1 parent bcde9c1 commit c967848

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
9999
int64_t row_offset_;
100100
std::optional<at::Tensor> sorted_indices_ = std::nullopt;
101101
int64_t width_offset_;
102+
std::mutex mtx;
102103
};
103104

104105
} // namespace ssd

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <torch/library.h>
1313

1414
#include <torch/custom_class.h>
15-
15+
#include <mutex>
1616
#include "../dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h"
1717
#include "./ssd_table_batched_embeddings.h"
1818
#include "embedding_rocksdb_wrapper.h"
@@ -377,6 +377,8 @@ void KVTensorWrapper::set_range(
377377
const int64_t start,
378378
const int64_t length,
379379
const at::Tensor& weights) {
380+
// Mutex lock for disabling concurrent writes to the same KVTensor
381+
std::lock_guard<std::mutex> lock(mtx);
380382
CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported";
381383
CHECK_TRUE(db_ != nullptr);
382384
CHECK_GE(db_->get_max_D(), shape_[1]);

0 commit comments

Comments
 (0)