Commit 7e280ac

chouxi authored and facebook-github-bot committed
Pass the updated embeddings to EmbeddingKVDB (#4210)
Summary:
X-link: facebookresearch/FBGEMM#1285

Added the functions `stream_cuda`/`stream`:
- When `blocking_tensor_copy=false`, create a new thread to execute the callback registered via `stream_cuda`; the reasoning is in the code comment. This is based on [the finding that host callbacks are executed in a serialized manner](https://fb.workplace.com/permalink.php?story_fbid=pfbid02s7RWvRZ4g2nS5i42kkyApvLsCbiRpGrdAEPs2p5qr2MDnq5YbgfThQ6PXSB6y13Al&id=100026528794331).
- When `blocking_tensor_copy=true`, just copy and enqueue in the callback thread.

Added profiling to the callback functions.

Added the function `stream_sync_cuda`:
- Explicitly joins the async copy thread to make sure the copy has happened before the buffer gets overwritten again.

Reviewed By: q10

Differential Revision: D73819097
1 parent d592c92 commit 7e280ac
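To make the intent concrete, here is a minimal caller-side sketch (not part of this diff; the function name, include path, and call site are assumptions for illustration) of how the new `stream_cuda`/`stream_sync_cuda` pair is meant to be used together:

```cpp
// Hypothetical usage sketch; only stream_cuda/stream_sync_cuda below come
// from this commit, everything else is illustrative.
#include <memory>
#include <ATen/ATen.h>
#include "kv_db_table_batched_embeddings.h" // assumed include path

void stream_updated_embeddings(
    std::shared_ptr<kv_db::EmbeddingKVDB> db,
    const at::Tensor& indices, // 1D; negative entries are skipped
    const at::Tensor& weights, // 2D; row i pairs with indices[i]
    const at::Tensor& count) { // single-element tensor: #rows to process
  // Enqueue the streaming callback on the current CUDA stream. With
  // blocking_tensor_copy=false the host-side copy runs on a dedicated
  // thread, so other host callbacks (e.g. get/evict) are not blocked.
  db->stream_cuda(indices, weights, count, /*blocking_tensor_copy=*/false);

  // ... more GPU work can be queued here ...

  // Before indices/weights are overwritten (e.g. on the next iteration),
  // enqueue a sync point that joins the async copy thread.
  db->stream_sync_cuda();
}
```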

File tree

7 files changed, +277 -4 lines changed

fbgemm_gpu/src/ps_split_embeddings_cache/ps_split_table_batched_embeddings.cpp

Lines changed: 16 additions & 0 deletions
@@ -52,6 +52,18 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
     return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
   }
 
+  void stream_cuda(
+      const Tensor& indices,
+      const Tensor& weights,
+      const Tensor& count,
+      bool blocking_tensor_copy = true) {
+    return impl_->stream_cuda(indices, weights, count, blocking_tensor_copy);
+  }
+
+  void stream_sync_cuda() {
+    return impl_->stream_sync_cuda();
+  }
+
   void get_cuda(Tensor indices, Tensor weights, Tensor count) {
     return impl_->get_cuda(indices, weights, count);
   }
@@ -95,6 +107,10 @@ static auto embedding_parameter_server_wrapper =
             int64_t,
             int64_t>())
         .def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda)
+        .def("stream_cuda", &EmbeddingParameterServerWrapper::stream_cuda)
+        .def(
+            "stream_sync_cuda",
+            &EmbeddingParameterServerWrapper::stream_sync_cuda)
         .def("get_cuda", &EmbeddingParameterServerWrapper::get_cuda)
         .def("compact", &EmbeddingParameterServerWrapper::compact)
         .def("flush", &EmbeddingParameterServerWrapper::flush)

fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h

Lines changed: 12 additions & 0 deletions
@@ -83,6 +83,18 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
     return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
   }
 
+  void stream_cuda(
+      const at::Tensor& indices,
+      const at::Tensor& weights,
+      const at::Tensor& count,
+      bool blocking_tensor_copy = true) {
+    return impl_->stream_cuda(indices, weights, count, blocking_tensor_copy);
+  }
+
+  void stream_sync_cuda() {
+    return impl_->stream_sync_cuda();
+  }
+
   void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
     return impl_->get_cuda(indices, weights, count);
   }

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp

Lines changed: 107 additions & 1 deletion
@@ -203,7 +203,8 @@ EmbeddingKVDB::~EmbeddingKVDB() {
   }
 #ifdef FBGEMM_FBCODE
   if (enable_raw_embedding_streaming_) {
-    weights_stream_thread_->join();
+    join_stream_tensor_copy_thread();
+    join_weights_stream_thread();
   }
 #endif
 }
@@ -300,6 +301,39 @@ folly::coro::Task<void> EmbeddingKVDB::tensor_stream(
   }
   co_return;
 }
+
+void EmbeddingKVDB::copy_and_enqueue_stream_tensors(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count) {
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## EmbeddingKVDB::copy_and_enqueue_stream_tensors ##");
+  auto stream_item =
+      tensor_copy(indices, weights, count, kv_db::RocksdbWriteMode::STREAM);
+  weights_to_stream_queue_.enqueue(stream_item);
+  rec->record.end();
+}
+
+void EmbeddingKVDB::join_stream_tensor_copy_thread() {
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## EmbeddingKVDB::join_stream_tensor_copy_thread ##");
+  if (stream_tensor_copy_thread_ != nullptr &&
+      stream_tensor_copy_thread_->joinable()) {
+    stream_tensor_copy_thread_->join();
+  }
+  rec->record.end();
+}
+
+void EmbeddingKVDB::join_weights_stream_thread() {
+  if (weights_stream_thread_ != nullptr && weights_stream_thread_->joinable()) {
+    stop_ = true;
+    weights_stream_thread_->join();
+  }
+}
+
+uint64_t EmbeddingKVDB::get_weights_to_stream_queue_size() {
+  return weights_to_stream_queue_.size();
+}
 #endif
 
 void EmbeddingKVDB::update_cache_and_storage(
@@ -389,6 +423,45 @@ void EmbeddingKVDB::set_cuda(
   rec->record.end();
 }
 
+void EmbeddingKVDB::stream_cuda(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count,
+    bool blocking_tensor_copy) {
+#ifdef FBGEMM_FBCODE
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## EmbeddingKVDB::stream_cuda ##");
+  check_tensor_type_consistency(indices, weights);
+  // take a reference to self to avoid lifetime issues
+  auto self = shared_from_this();
+  std::function<void()>* functor = new std::function<void()>(
+      [=]() { self->stream(indices, weights, count, blocking_tensor_copy); });
+  AT_CUDA_CHECK(cudaStreamAddCallback(
+      at::cuda::getCurrentCUDAStream(),
+      kv_db_utils::cuda_callback_func,
+      functor,
+      0));
+  rec->record.end();
+#endif
+}
+
+void EmbeddingKVDB::stream_sync_cuda() {
+#ifdef FBGEMM_FBCODE
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## EmbeddingKVDB::stream_sync_cuda ##");
+  // take a reference to self to avoid lifetime issues
+  auto self = shared_from_this();
+  std::function<void()>* functor = new std::function<void()>(
+      [=]() { self->join_stream_tensor_copy_thread(); });
+  AT_CUDA_CHECK(cudaStreamAddCallback(
+      at::cuda::getCurrentCUDAStream(),
+      kv_db_utils::cuda_callback_func,
+      functor,
+      0));
+  rec->record.end();
+#endif
+}
+
 std::vector<double> EmbeddingKVDB::get_l2cache_perf(
     const int64_t step,
     const int64_t interval) {
@@ -458,6 +531,9 @@ void EmbeddingKVDB::set(
     return;
   }
   CHECK_EQ(max_D_, weights.size(1));
+
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## EmbeddingKVDB::set_callback ##");
   // defer the L2 cache/rocksdb update to the background thread as it could
   // be parallelized with other cuda kernels, as long as all updates are
   // finished before the next L2 cache lookup
@@ -473,6 +549,7 @@ void EmbeddingKVDB::set(
   } else {
     update_cache_and_storage(indices, weights, count, write_mode);
   }
+  rec->record.end();
 }
 
 void EmbeddingKVDB::get(
@@ -486,6 +563,8 @@ void EmbeddingKVDB::get(
         << num_lookups;
     return;
  }
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## EmbeddingKVDB::get_callback ##");
   CHECK_GE(max_D_, weights.size(1));
   auto start_ts = facebook::WallClockUtil::NowInUsecFast();
   wait_util_filling_work_done();
@@ -546,6 +625,33 @@ void EmbeddingKVDB::get(
     get_kv_db_async(indices, weights, count).wait();
   }
   get_total_duration_ += facebook::WallClockUtil::NowInUsecFast() - start_ts;
+  rec->record.end();
+}
+
+void EmbeddingKVDB::stream(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count,
+    bool blocking_tensor_copy) {
+  if (!enable_raw_embedding_streaming_) {
+    return;
+  }
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## EmbeddingKVDB::stream_callback ##");
+  if (blocking_tensor_copy) {
+    copy_and_enqueue_stream_tensors(indices, weights, count);
+    return;
+  }
+  // Make sure the previous thread is done before starting a new one.
+  join_stream_tensor_copy_thread();
+  // CUDA dispatches all host callbacks on the same CPU thread, but the
+  // callbacks don't need to be serialized. So we spin up a new thread to
+  // unblock the CUDA stream, letting CUDA continue executing other host
+  // callbacks, e.g. get/evict.
+  stream_tensor_copy_thread_ = std::make_unique<std::thread>([=, this]() {
+    copy_and_enqueue_stream_tensors(indices, weights, count);
+  });
+  rec->record.end();
 }
 
 std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
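The comment in `stream()` above captures the key constraint: the CUDA runtime dispatches all host callbacks registered via `cudaStreamAddCallback` on a single CPU thread, so a long-running callback delays every callback behind it. The following standalone sketch (plain CUDA runtime API; none of the names below come from this commit) illustrates the same offloading pattern:

```cpp
// Standalone sketch of offloading slow work out of a CUDA host callback.
// Host callbacks are dispatched serially by one driver thread; keeping the
// callback itself cheap (just a thread handoff) unblocks later callbacks,
// mirroring what EmbeddingKVDB::stream() does with its copy thread.
#include <cstdio>
#include <thread>
#include <cuda_runtime.h>

std::thread g_copy_thread; // stand-in for stream_tensor_copy_thread_

void CUDART_CB unblocking_callback(cudaStream_t, cudaError_t, void*) {
  // Make sure the previous copy finished before starting a new one.
  if (g_copy_thread.joinable()) {
    g_copy_thread.join();
  }
  // Hand the slow host-side work (tensor copy + enqueue) to a worker so
  // the serialized callback thread returns immediately.
  g_copy_thread = std::thread(
      [] { std::printf("copy + enqueue running on worker thread\n"); });
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaStreamAddCallback(stream, unblocking_callback, nullptr, 0);
  cudaStreamSynchronize(stream);
  if (g_copy_thread.joinable()) {
    g_copy_thread.join(); // analogous to what stream_sync_cuda() schedules
  }
  cudaStreamDestroy(stream);
  return 0;
}
```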

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h

Lines changed: 65 additions & 3 deletions
@@ -80,15 +80,19 @@ class CacheContext {
 /// BWD_L1_CNFLCT_MISS_WRITE_BACK: L1 conflict miss will insert into L2 for
 /// embedding update on bwd path
 ///
-/// All the L2 cache filling above will potentially trigger rocksdb write once
-/// L2 cache is full
+/// All the L2 cache filling above will
+/// potentially trigger rocksdb write once L2 cache is full
+///
+/// STREAM: placeholder for raw embedding streaming requests; it doesn't
+/// directly interact with L2 or rocksdb
 ///
 /// Additionally we will do ssd io on L2 flush
 enum RocksdbWriteMode {
   FWD_ROCKSDB_READ = 0,
   FWD_L1_EVICTION = 1,
   BWD_L1_CNFLCT_MISS_WRITE_BACK = 2,
   FLUSH = 3,
+  STREAM = 4,
 };
 
 /// @ingroup embedding-ssd
@@ -195,6 +199,33 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
       const at::Tensor& count,
       int64_t sleep_ms = 0);
 
+  /// Stream out the non-negative elements in <indices> and their paired
+  /// embeddings from <weights>, for the first <count> elements in the tensor.
+  /// It spins up a thread that copies all 3 tensors to CPU and injects them
+  /// into the background queue, which is picked up by another set of thread
+  /// pools for streaming out to the thrift server (co-located on the same
+  /// host for now).
+  ///
+  /// This is used in a cuda stream callback, which doesn't need to be
+  /// serialized with other callbacks; a separate thread is used to maximize
+  /// the overlap with other callbacks.
+  ///
+  /// @param indices The 1D embedding index tensor; negative values are
+  /// skipped
+  /// @param weights The 2D tensor in which each row (embedding) is paired
+  /// with the corresponding element in <indices>
+  /// @param count A single-element tensor containing the number of indices
+  /// to be processed
+  /// @param blocking_tensor_copy whether to copy the tensors to be streamed
+  /// in a blocking manner
+  ///
+  /// @return None
+  void stream(
+      const at::Tensor& indices,
+      const at::Tensor& weights,
+      const at::Tensor& count,
+      bool blocking_tensor_copy = true);
+
   /// storage tier counterpart of function get()
   virtual folly::SemiFuture<std::vector<folly::Unit>> get_kv_db_async(
       const at::Tensor& indices,
@@ -233,6 +264,14 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
       const int64_t timestep,
       const bool is_bwd = false);
 
+  void stream_cuda(
+      const at::Tensor& indices,
+      const at::Tensor& weights,
+      const at::Tensor& count,
+      bool blocking_tensor_copy = true);
+
+  void stream_sync_cuda();
+
   /// export internally collected L2 performance metrics out
   ///
   /// @param step the training step that caller side wants to report the stats
@@ -313,6 +352,28 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   folly::coro::Task<void> tensor_stream(
       const at::Tensor& indices,
       const at::Tensor& weights);
+  /*
+   * Copy the indices, weights and count tensors and enqueue them for
+   * asynchronous streaming.
+   */
+  void copy_and_enqueue_stream_tensors(
+      const at::Tensor& indices,
+      const at::Tensor& weights,
+      const at::Tensor& count);
+
+  /*
+   * Join the stream tensor copy thread, making sure the thread has properly
+   * finished before a new one is created.
+   */
+  void join_stream_tensor_copy_thread();
+
+  /*
+   * FOR TESTING: Join the weights stream thread, making sure the thread has
+   * properly finished for destruction and testing.
+   */
+  void join_weights_stream_thread();
+  // FOR TESTING: get queue size.
+  uint64_t get_weights_to_stream_queue_size();
 #endif
 
  private:
@@ -452,7 +513,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   std::vector<int64_t> table_offsets_;
   at::Tensor table_sizes_;
   std::unique_ptr<std::thread> weights_stream_thread_;
-  folly::USPSCQueue<QueueItem, true> weights_to_stream_queue_;
+  folly::UMPSCQueue<QueueItem, true> weights_to_stream_queue_;
+  std::unique_ptr<std::thread> stream_tensor_copy_thread_;
 }; // class EmbeddingKVDB
 
 } // namespace kv_db
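Note the queue change from `folly::USPSCQueue` to `folly::UMPSCQueue` in the last hunk: items can now be enqueued both from the CUDA callback thread (blocking copy) and from `stream_tensor_copy_thread_` (async copy), so the single-producer assumption no longer holds. Below is a minimal multi-producer/single-consumer sketch, assuming folly is available; the thread roles and loop bounds are illustrative only:

```cpp
// Sketch of folly::UMPSCQueue (Unbounded Multi-Producer Single-Consumer),
// standing in for weights_to_stream_queue_. Not code from this commit.
#include <thread>
#include <folly/concurrency/UnboundedQueue.h>

int main() {
  folly::UMPSCQueue<int, /*MayBlock=*/true> queue;

  // Two producers, analogous to the blocking and async-copy enqueue paths.
  std::thread p1([&] { for (int i = 0; i < 100; ++i) queue.enqueue(i); });
  std::thread p2([&] { for (int i = 100; i < 200; ++i) queue.enqueue(i); });

  // Single consumer, analogous to the weights_stream_thread_ drain loop.
  std::thread consumer([&] {
    for (int n = 0; n < 200; ++n) {
      int item;
      queue.dequeue(item); // blocks until an item is available
    }
  });

  p1.join();
  p2.join();
  consumer.join();
  return 0;
}
```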

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 11 additions & 0 deletions
@@ -545,6 +545,17 @@ static auto embedding_rocks_db_wrapper =
             torch::arg("timestep"),
             torch::arg("is_bwd") = false,
         })
+        .def(
+            "stream_cuda",
+            &EmbeddingRocksDBWrapper::stream_cuda,
+            "",
+            {
+                torch::arg("indices"),
+                torch::arg("weights"),
+                torch::arg("count"),
+                torch::arg("blocking_tensor_copy"),
+            })
+        .def("stream_sync_cuda", &EmbeddingRocksDBWrapper::stream_sync_cuda)
         .def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda)
         .def("compact", &EmbeddingRocksDBWrapper::compact)
         .def("flush", &EmbeddingRocksDBWrapper::flush)

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 2 additions & 0 deletions
@@ -483,6 +483,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
       case kv_db::RocksdbWriteMode::FLUSH:
         flush_write_dur_ += duration;
         break;
+      case kv_db::RocksdbWriteMode::STREAM:
+        break;
     }
 #endif
     return folly::collect(futures);
