3 changes: 2 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -1819,12 +1819,13 @@ def split_embedding_weights(
         row_offset = 0
         for emb_height, emb_dim in self.embedding_specs:
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                db=self.ssd_db,
                 shape=[emb_height, emb_dim],
                 dtype=dtype,
                 row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
             )
+            # TODO add if else support in the future for dram integration.
+            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
             row_offset += emb_height
             splits.append(PartiallyMaterializedTensor(tensor_wrapper))
         return splits
@@ -107,7 +107,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   /// get all ids in the kvstore
   ///
   /// @return a Tensor contained ids
-  at::Tensor get_keys_in_range(int64_t start, int64_t end) {
+  at::Tensor get_keys_in_range(int64_t start, int64_t end) override {
     std::vector<std::vector<int64_t>> ids;
     for (int i = 0; i < num_shards_; i++) {
       ids.push_back(std::vector<int64_t>());
@@ -8,6 +8,7 @@

 #pragma once
 
+#include <ATen/ATen.h>
 #include "../ssd_split_embeddings_cache/kv_tensor_wrapper.h"
 #include "dram_kv_embedding_cache.h"
 
@@ -58,98 +59,49 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
       at::Tensor count,
       int64_t timestep,
       bool is_bwd) {
-    return std::visit(
-        [&indices, &weights, &count, &timestep](auto& ptr) {
-          if (ptr) {
-            ptr->set_cuda(indices, weights, count, timestep);
-          }
-        },
-        impl_);
+    return impl_->set_cuda(indices, weights, count, timestep);
   }
 
   void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-    return std::visit(
-        [&indices, &weights, &count](auto& ptr) {
-          if (ptr) {
-            ptr->get_cuda(indices, weights, count);
-          }
-        },
-        impl_);
+    return impl_->get_cuda(indices, weights, count);
   }
 
   void set(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-    return std::visit(
-        [&indices, &weights, &count](auto& ptr) {
-          if (ptr) {
-            ptr->set(indices, weights, count);
-          }
-        },
-        impl_);
+    return impl_->set(indices, weights, count);
   }
 
   void flush() {
-    return std::visit(
-        [](auto& ptr) {
-          if (ptr) {
-            ptr->flush();
-          }
-        },
-        impl_);
+    return impl_->flush();
   }
 
   void set_range_to_storage(
       const at::Tensor& weights,
       const int64_t start,
       const int64_t length) {
-    return std::visit(
-        [&weights, &start, &length](auto& ptr) {
-          if (ptr) {
-            ptr->set_range_to_storage(weights, start, length);
-          }
-        },
-        impl_);
+    return impl_->set_range_to_storage(weights, start, length);
   }
 
   void get(
       at::Tensor indices,
       at::Tensor weights,
       at::Tensor count,
       int64_t sleep_ms) {
-    return std::visit(
-        [&indices, &weights, &count, sleep_ms](auto& ptr) {
-          if (ptr) {
-            ptr->get(indices, weights, count, sleep_ms);
-          }
-        },
-        impl_);
+    return impl_->get(indices, weights, count, sleep_ms);
   }
 
   void wait_util_filling_work_done() {
-    return std::visit(
-        [](auto& ptr) {
-          if (ptr) {
-            ptr->wait_util_filling_work_done();
-          }
-        },
-        impl_);
+    return impl_->wait_util_filling_work_done();
   }
 
   at::Tensor get_keys_in_range(int64_t start, int64_t end) {
-    return std::visit(
-        [&start, &end](auto& ptr) {
-          if (ptr) {
-            return ptr->get_keys_in_range(start, end);
-          }
-          return at::empty({0});
-        },
-        impl_);
+    return impl_->get_keys_in_range(start, end);
  }
 
  private:
   // friend class EmbeddingRocksDBWrapper;
   friend class ssd::KVTensorWrapper;
 
-  DramKVEmbeddingCacheVariant impl_;
+  std::shared_ptr<kv_db::EmbeddingKVDB> impl_;
 };
 
 } // namespace kv_mem
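
The hunk above replaces the deleted DramKVEmbeddingCacheVariant member, visited with std::visit on every call, by a plain virtual call through the common kv_db::EmbeddingKVDB base class. A self-contained sketch of the before/after shape of that refactor (type and method names are illustrative stand-ins, not FBGEMM's):

    #include <iostream>
    #include <memory>
    #include <variant>

    struct Backend {                 // stands in for kv_db::EmbeddingKVDB
      virtual ~Backend() = default;
      virtual void flush() = 0;
    };

    struct DramBackend : Backend {   // stands in for DramKVEmbeddingCache
      void flush() override { std::cout << "dram flush\n"; }
    };

    // Before: every wrapper method re-dispatches via std::visit and
    // null-checks each variant alternative.
    using BackendVariant = std::variant<std::shared_ptr<DramBackend>>;
    void flush_before(BackendVariant& impl) {
      std::visit(
          [](auto& ptr) {
            if (ptr) {
              ptr->flush();
            }
          },
          impl);
    }

    // After: a single virtual call; a new backend derives from Backend
    // instead of widening the variant and touching every visitor.
    void flush_after(const std::shared_ptr<Backend>& impl) {
      impl->flush();
    }

    int main() {
      BackendVariant v = std::make_shared<DramBackend>();
      flush_before(v);
      flush_after(std::make_shared<DramBackend>());
    }

The trade-off is that backend-specific operations must now either live on the base class or be reached through setters that know the concrete type, which is exactly what the kv_db and kv_tensor_wrapper.h changes below provide.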
@@ -22,6 +22,7 @@
 
 #include <folly/Random.h>
 #include <folly/concurrency/UnboundedQueue.h>
+#include <folly/coro/BlockingWait.h>
 #include <folly/executors/CPUThreadPoolExecutor.h>
 #include <folly/futures/Future.h>
 #include <folly/hash/Hash.h>
@@ -41,6 +42,10 @@
 #include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 
+namespace ssd {
+class SnapshotHandle;
+}
+
 namespace kv_db {
 
 /// @ingroup embedding-ssd
@@ -240,25 +245,73 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   // finished, it could also be called in unitest to sync
   void wait_util_filling_work_done();
 
+  virtual at::Tensor get_keys_in_range(int64_t start, int64_t end) {
+    (void)start;
+    (void)end;
+    FBEXCEPTION("Not implemented");
+  }
+
+  void set_range_to_storage(
+      const at::Tensor& weights,
+      const int64_t start,
+      const int64_t length) {
+    const auto seq_indices =
+        at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
+    const auto count = at::tensor({length}, at::ScalarType::Long);
+    folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+  }
+
+  virtual void get_range_from_snapshot(
+      const at::Tensor& weights,
+      const int64_t start,
+      const int64_t length,
+      const ssd::SnapshotHandle* snapshot_handle) {
+    (void)weights;
+    (void)start;
+    (void)length;
+    (void)snapshot_handle;
+    FBEXCEPTION("Not implemented");
+  }
+
+  void set_kv_to_storage(const at::Tensor& ids, const at::Tensor& weights) {
+    const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
+    folly::coro::blockingWait(set_kv_db_async(ids, weights, count));
+  }
+
+  virtual void get_kv_from_storage_by_snapshot(
+      const at::Tensor& ids,
+      const at::Tensor& weights,
+      const ssd::SnapshotHandle* snapshot_handle) {
+    (void)ids;
+    (void)weights;
+    (void)snapshot_handle;
+    FBEXCEPTION("Not implemented");
+  }
+
+  virtual int64_t get_max_D() {
+    return max_D_;
+  }
+
  private:
   /// Find non-negative embedding indices in <indices> and shard them into
   /// #cachelib_pools pieces to be lookedup in parallel
   ///
   /// @param indices The 1D embedding index tensor, should skip on negative
   /// value
-  /// @param count A single element tensor that contains the number of indices
-  /// to be processed
+  /// @param count A single element tensor that contains the number of
+  /// indices to be processed
   ///
-  /// @return preallocated list of memory pointer with <count> size, cache miss
-  /// or invalid embedding indices will have sentinel pointer(nullptr)
-  /// @note element in <indices> will be updated to sentinel value on cache hit
+  /// @return preallocated list of memory pointer with <count> size, cache
+  /// miss or invalid embedding indices will have sentinel pointer(nullptr)
+  /// @note element in <indices> will be updated to sentinel value on cache
+  /// hit
   std::shared_ptr<CacheContext> get_cache(
       const at::Tensor& indices,
       const at::Tensor& count);
 
   /// Find non-negative embedding indices in <indices> and shard them into
-  /// #cachelib_pools pieces, insert into Cachelib in parallel with their paired
-  /// embeddings from <weights>
+  /// #cachelib_pools pieces, insert into Cachelib in parallel with their
+  /// paired embeddings from <weights>
   ///
   /// @param indices The 1D embedding index tensor, should skip on negative
   /// value
@@ -268,8 +321,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   /// to be processed
   ///
   /// @return None if L2 is missing or no eviction, other wise return tuple of
-  /// tensors with length of <count> containing L2 evicted embedding indices and
-  /// embeddings, invalid pairs will have sentinel value(-1) on <indices>
+  /// tensors with length of <count> containing L2 evicted embedding indices
+  /// and embeddings, invalid pairs will have sentinel value(-1) on <indices>
   folly::Optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>> set_cache(
       const at::Tensor& indices,
       const at::Tensor& weights,
@@ -284,8 +337,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   /// relative slot in <cached_addr_list>
   ///
   /// @return None
-  /// @note weigths will be updated on the slot that paired up with valid cache
-  /// addr pointer
+  /// @note weigths will be updated on the slot that paired up with valid
+  /// cache addr pointer
   folly::SemiFuture<std::vector<folly::Unit>> cache_memcpy(
       const at::Tensor& weights,
       const std::vector<void*>& cached_addr_list);
@@ -337,11 +390,11 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   // bg L2 write
   // - L1 cache eviction: insert into bg queue for L2 write
   // - ScratchPad update: insert into bg queue for L2 write
-  // in non-prefetch pipeline, cuda synchronization guarantee get_cuda() happen
-  // after SP update
-  // in prefetch pipeline, cuda sync only guarantee get_cuda() happen after L1
-  // cache eviction pipeline case, SP bwd update could happen in parallel with
-  // L2 read mutex is used for l2 cache to do read / write exclusively
+  // in non-prefetch pipeline, cuda synchronization guarantee get_cuda()
+  // happen after SP update in prefetch pipeline, cuda sync only guarantee
+  // get_cuda() happen after L1 cache eviction pipeline case, SP bwd update
+  // could happen in parallel with L2 read mutex is used for l2 cache to do
+  // read / write exclusively
   std::mutex l2_cache_mtx_;
 
   // perf stats
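
The new base-class surface combines two patterns: optional capabilities are virtuals that throw by default and are overridden by backends that support them (as DramKVEmbeddingCache does for get_keys_in_range above), while set_range_to_storage reduces a contiguous row range to the generic (ids, weights, count) write path. A runnable sketch of both, using illustrative names rather than FBGEMM's:

    #include <ATen/ATen.h>

    #include <iostream>
    #include <stdexcept>

    // Stands in for kv_db::EmbeddingKVDB.
    struct KVStoreBase {
      virtual ~KVStoreBase() = default;

      // Backends that cannot enumerate keys inherit the throwing default,
      // mirroring the FBEXCEPTION("Not implemented") stubs above.
      virtual at::Tensor get_keys_in_range(int64_t /*start*/, int64_t /*end*/) {
        throw std::runtime_error("Not implemented");
      }

      virtual void write(
          const at::Tensor& ids,
          const at::Tensor& weights,
          const at::Tensor& count) = 0;

      // Shaped like set_range_to_storage(): rows [start, start + length)
      // become explicit ids plus a single-element count tensor.
      void set_range(const at::Tensor& weights, int64_t start, int64_t length) {
        const auto ids = at::arange(
            start, start + length, at::TensorOptions().dtype(at::kLong));
        const auto count = at::tensor({length}, at::ScalarType::Long);
        write(ids, weights, count);
      }
    };

    // A DRAM-like backend opts into key enumeration by overriding the default.
    struct DramStore : KVStoreBase {
      at::Tensor get_keys_in_range(int64_t start, int64_t end) override {
        return at::arange(start, end, at::TensorOptions().dtype(at::kLong));
      }
      void write(
          const at::Tensor& ids,
          const at::Tensor& /*weights*/,
          const at::Tensor& /*count*/) override {
        std::cout << "writing rows " << ids << "\n";
      }
    };

    int main() {
      DramStore store;
      store.set_range(at::zeros({4, 8}), /*start=*/100, /*length=*/4);  // ids 100..103
      std::cout << store.get_keys_in_range(0, 5) << "\n";
    }

The production version issues the write as folly::coro::blockingWait(set_kv_db_async(...)) rather than a synchronous call; the tensor arithmetic is the same.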
27 changes: 25 additions & 2 deletions fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
@@ -11,6 +11,14 @@
 #include <ATen/Tensor.h> // @manual=//caffe2:ATen-core
 #include <torch/custom_class.h>
 
+namespace kv_mem {
+class DramKVEmbeddingCacheWrapper;
+}
+
+namespace kv_db {
+class EmbeddingKVDB;
+}
+
 namespace ssd {
 
 class EmbeddingRocksDB;
@@ -32,7 +40,6 @@ struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder {
 class KVTensorWrapper : public torch::jit::CustomClassHolder {
  public:
   explicit KVTensorWrapper(
-      c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
       std::vector<int64_t> shape,
       int64_t dtype,
       int64_t row_offset,
@@ -41,6 +48,22 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {

   at::Tensor narrow(int64_t dim, int64_t start, int64_t length);
 
+  /// @brief if the backend storage is SSD, use this function
+  /// to set db_ inside KVTensorWrapper
+  /// this function should be called right after KVTensorWrapper
+  /// initialization
+  /// @param db: the DB wrapper
+  void set_embedding_rocks_dp_wrapper(
+      c10::intrusive_ptr<EmbeddingRocksDBWrapper> db);
+
+  /// @brief if the backend storage is DramKV, use this function
+  /// to set db_ inside KVTensorWrapper
+  /// this function should be called right after KVTensorWrapper
+  /// initialization
+  /// @param db: the DB wrapper
+  void set_dram_db_wrapper(
+      c10::intrusive_ptr<kv_mem::DramKVEmbeddingCacheWrapper> db);
+
   void set_range(
       int64_t dim,
       const int64_t start,
@@ -66,7 +89,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
   std::string layout_str();
 
  private:
-  std::shared_ptr<EmbeddingRocksDB> db_;
+  std::shared_ptr<kv_db::EmbeddingKVDB> db_;
   c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle_;
   at::TensorOptions options_;
   std::vector<int64_t> shape_;
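
Taken together with the training.py change at the top of the diff, the declarations above imply a two-phase initialization: construct the wrapper without a backend, attach one with the matching setter immediately, and only then read or write through it. A sketch of that call order (a fragment, not a full program; the constructor argument list is abridged to the parameters visible in this diff, and rocks_db is assumed to be created elsewhere):

    #include <ATen/ATen.h>
    #include <optional>
    #include <vector>
    // assumes kv_tensor_wrapper.h and the backend wrapper headers are included

    at::Tensor first_rows(
        c10::intrusive_ptr<ssd::EmbeddingRocksDBWrapper> rocks_db,
        int64_t emb_height,
        int64_t emb_dim) {
      auto wrapper = c10::make_intrusive<ssd::KVTensorWrapper>(
          std::vector<int64_t>{emb_height, emb_dim},
          /*dtype=*/static_cast<int64_t>(at::ScalarType::Float),
          /*row_offset=*/0,
          /*snapshot_handle=*/std::nullopt);
      // Attach a backend right after construction, per the @brief comments above:
      wrapper->set_embedding_rocks_dp_wrapper(rocks_db);
      // or, for the in-memory backend:
      // wrapper->set_dram_db_wrapper(dram_db);
      return wrapper->narrow(/*dim=*/0, /*start=*/0, /*length=*/16);
    }

Because db_ is now a std::shared_ptr<kv_db::EmbeddingKVDB>, narrow() and set_range() dispatch identically for both backends; conversely, forgetting to call a setter leaves db_ null, which is why the comments insist the setter run right after construction.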
@@ -29,14 +29,23 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
 class SnapshotHandle {};
 
 KVTensorWrapper::KVTensorWrapper(
-    c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
     std::vector<int64_t> shape,
     [[maybe_unused]] int64_t dtype,
     int64_t row_offset,
     [[maybe_unused]] std::optional<
         c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>> snapshot_handle)
     // @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
-    : db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) {
+    : shape_(std::move(shape)), row_offset_(row_offset) {
   FBEXCEPTION("Not implemented");
 }
 
+void KVTensorWrapper::set_embedding_rocks_dp_wrapper(
+    c10::intrusive_ptr<EmbeddingRocksDBWrapper> db) {
+  FBEXCEPTION("Not implemented");
+}
+
+void KVTensorWrapper::set_dram_db_wrapper(
+    c10::intrusive_ptr<kv_mem::DramKVEmbeddingCacheWrapper> db) {
+  FBEXCEPTION("Not implemented");
+}
+
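
These definitions appear to sit in a CPU-only translation unit: every member is defined so the class still compiles, registers, and links in builds without the SSD/DRAM backends, while every body throws. A minimal sketch of the same stub pattern with illustrative names:

    #include <stdexcept>
    #include <string>

    class KVTensorWrapperStub {
     public:
      KVTensorWrapperStub() { fail("KVTensorWrapper"); }
      void set_backend() { fail("set_backend"); }

     private:
      [[noreturn]] static void fail(const std::string& name) {
        // Defined so the symbol links everywhere; loud at runtime if reached.
        throw std::runtime_error(name + ": not implemented in this build");
      }
    };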