Track updated rows in SSDTBE #4211

Open · wants to merge 2 commits into main
@@ -22,8 +22,8 @@ def get_unique_indices_v2(
torch.Tensor,
torch.Tensor,
Optional[torch.Tensor],
Tuple[torch.Tensor, torch.Tensor],
],
Tuple[torch.Tensor, torch.Tensor],
]:
"""
A wrapper for get_unique_indices for overloading the return type
@@ -43,7 +43,6 @@ def get_unique_indices_v2(
return ret[:-1]
if compute_inverse_indices:
# Return (unique_indices, length, inverse_indices)
# pyre-fixme[7]: The arity of this return is wrong (3 vs 4)
return ret[0], ret[1], ret[3]
# Return (unique_indices, length)
return ret[:-2]
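As a reading aid for the overload above, here is a minimal sketch of how the two return shapes are consumed. The import path and the example tensors are assumptions for illustration, not part of this PR; the call with both flags off mirrors the one added in training.py below.

```python
import torch
# Assumed import path for illustration; adjust to wherever
# get_unique_indices_v2 lives in your checkout.
from fbgemm_gpu.split_embedding_utils import get_unique_indices_v2

linear_indices = torch.tensor([3, 7, 3, -1, 7], dtype=torch.long, device="cuda")
hash_size = 1024  # stands in for self.total_hash_size

# Both flags off: only (unique_indices, length) comes back.
unique_indices, length = get_unique_indices_v2(
    linear_indices,
    hash_size,
    compute_count=False,
    compute_inverse_indices=False,
)

# compute_inverse_indices=True additionally returns the mapping from each
# input position back to its slot in unique_indices.
unique_indices, length, inverse = get_unique_indices_v2(
    linear_indices,
    hash_size,
    compute_count=False,
    compute_inverse_indices=True,
)
```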
266 changes: 266 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -648,6 +648,62 @@ def __init__(
# SSD scratch pad index queue lookup completion event
self.ssd_event_sp_idxq_lookup: torch.cuda.streams.Event = torch.cuda.Event()

if self.enable_raw_embedding_streaming:
# RES reuses the eviction stream
self.ssd_event_cache_streamed: torch.cuda.streams.Event = torch.cuda.Event()
self.ssd_event_cache_streaming_synced: torch.cuda.streams.Event = (
torch.cuda.Event()
)
self.ssd_event_cache_streaming_computed: torch.cuda.streams.Event = (
torch.cuda.Event()
)
self.ssd_event_sp_streamed: torch.cuda.streams.Event = torch.cuda.Event()

# For storing the updated cache rows (weights) to be streamed
self.register_buffer(
"lxu_cache_updated_weights",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=cache_dtype,
),
self.lxu_cache_weights.shape,
is_host_mapped=self.uvm_host_mapped,
),
)

# For storing the embedding indices of the rows to be updated
self.register_buffer(
"lxu_cache_updated_indices",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=torch.long,
),
(self.lxu_cache_weights.shape[0],),
is_host_mapped=self.uvm_host_mapped,
),
)

# For storing the number of updated rows
self.register_buffer(
"lxu_cache_updated_count",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=torch.int,
),
(1,),
is_host_mapped=self.uvm_host_mapped,
),
)

# (Indices, Count)
self.prefetched_info: List[Tuple[Tensor, Tensor]] = []

self.timesteps_prefetched: List[int] = []
# TODO: add type annotation
# pyre-fixme[4]: Attribute must be annotated.
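The three `register_buffer` calls above share one pattern: a one-element prototype tensor carries the dtype and device, the second argument to `new_unified_tensor` gives the real shape, and `is_host_mapped` asks for UVM memory reachable from both host and device. A standalone sketch of that pattern follows; the helper name and the placeholder sizes are illustrative, not from this PR.

```python
import torch
import fbgemm_gpu  # noqa: F401  (registers the torch.ops.fbgemm operators)

def make_unified_buffer(shape, dtype, device="cuda", host_mapped=True):
    # The 1-element zeros tensor is only a prototype for dtype/device;
    # `shape` determines the actual allocation.
    return torch.ops.fbgemm.new_unified_tensor(
        torch.zeros(1, device=device, dtype=dtype),
        shape,
        is_host_mapped=host_mapped,
    )

# Mirrors the three buffers registered above (sizes are placeholders):
cache_rows, cache_dim = 4096, 128
lxu_cache_updated_weights = make_unified_buffer((cache_rows, cache_dim), torch.float32)
lxu_cache_updated_indices = make_unified_buffer((cache_rows,), torch.long)
lxu_cache_updated_count = make_unified_buffer((1,), torch.int)
```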
@@ -1135,6 +1191,7 @@ def evict(
is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
tensor (which is accessible on both host and
device)
is_bwd (bool): A flag to indicate if the eviction is during backward
Returns:
None
"""
@@ -1160,6 +1217,95 @@
if post_event is not None:
stream.record_event(post_event)

def raw_embedding_stream_sync(
self,
stream: torch.cuda.Stream,
pre_event: Optional[torch.cuda.Event],
post_event: Optional[torch.cuda.Event],
name: Optional[str] = "",
) -> None:
"""
Block until the copy of the tensors being streamed has finished,
to make sure they are not overwritten
Args:
stream (Stream): The CUDA stream that cudaStreamAddCallback will
synchronize the host function with. Moreover, the
asynchronous D->H memory copies will operate on
this stream
pre_event (Event): The CUDA event that the stream has to wait on
post_event (Event): The CUDA event that the stream will record
when the synchronization is done
Returns:
None
"""
with record_function(f"## ssd_stream_{name} ##"):
with torch.cuda.stream(stream):
if pre_event is not None:
stream.wait_event(pre_event)

self.record_function_via_dummy_profile(
f"## ssd_stream_sync_{name} ##",
self.ssd_db.stream_sync_cuda,
)

if post_event is not None:
stream.record_event(post_event)

def raw_embedding_stream(
self,
rows: Tensor,
indices_cpu: Tensor,
actions_count_cpu: Tensor,
stream: torch.cuda.Stream,
pre_event: Optional[torch.cuda.Event],
post_event: Optional[torch.cuda.Event],
is_rows_uvm: bool,
blocking_tensor_copy: bool = True,
name: Optional[str] = "",
) -> None:
"""
Stream data from the given input tensors to a remote service
Args:
rows (Tensor): The 2D tensor that contains rows to evict
indices_cpu (Tensor): The 1D CPU tensor that contains the row
indices that the rows will be streamed to
actions_count_cpu (Tensor): A scalar tensor that contains the
number of rows that the streaming function
has to process
stream (Stream): The CUDA stream that cudaStreamAddCallback will
synchronize the host function with. Moreover, the
asynchronous D->H memory copies will operate on
this stream
pre_event (Event): The CUDA event that the stream has to wait on
post_event (Event): The CUDA event that the stream will record
when the streaming is done
is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
tensor (which is accessible on both host and
device)
Returns:
None
"""
with record_function(f"## ssd_stream_{name} ##"):
with torch.cuda.stream(stream):
if pre_event is not None:
stream.wait_event(pre_event)

rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)

rows.record_stream(stream)

self.record_function_via_dummy_profile(
f"## ssd_stream_{name} ##",
self.ssd_db.stream_cuda,
indices_cpu,
rows_cpu,
actions_count_cpu,
blocking_tensor_copy,
)

if post_event is not None:
stream.record_event(post_event)
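Both helpers above follow the same side-stream hand-off: wait on `pre_event`, do the D->H work on the given stream, then record `post_event` for downstream consumers. Below is a minimal, self-contained sketch of that event pattern; the tensor sizes and the pinned-copy stand-in for `ssd_db.stream_cuda` are illustrative assumptions.

```python
import torch

side_stream = torch.cuda.Stream()
work_ready = torch.cuda.Event()  # plays the role of pre_event
work_done = torch.cuda.Event()   # plays the role of post_event

# Producer: the current stream finishes producing `rows`, then signals.
rows = torch.randn(8, 16, device="cuda")
torch.cuda.current_stream().record_event(work_ready)

# Consumer: the side stream waits, copies to pinned host memory
# (standing in for the real stream_cuda call), then signals back.
with torch.cuda.stream(side_stream):
    side_stream.wait_event(work_ready)
    rows.record_stream(side_stream)  # keep `rows` alive for this stream
    rows_cpu = torch.empty(rows.shape, dtype=rows.dtype, pin_memory=True)
    rows_cpu.copy_(rows, non_blocking=True)
    side_stream.record_event(work_done)

# Whoever reuses the buffers waits on the post event first.
torch.cuda.current_stream().wait_event(work_done)
```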

def _evict_from_scratch_pad(self, grad: Tensor) -> None:
"""
Evict conflict missed rows from a scratch pad
@@ -1196,6 +1342,18 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
if not do_evict:
return

if self.enable_raw_embedding_streaming:
self.raw_embedding_stream(
rows=inserted_rows,
indices_cpu=post_bwd_evicted_indices_cpu,
actions_count_cpu=actions_count_cpu,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_backward,
post_event=self.ssd_event_sp_streamed,
is_rows_uvm=True,
blocking_tensor_copy=True,
name="scratch_pad",
)
self.evict(
rows=inserted_rows,
indices_cpu=post_bwd_evicted_indices_cpu,
@@ -1437,6 +1595,79 @@ def _prefetch( # noqa C901
masks=torch.where(evicted_indices != -1, 1, 0),
count=actions_count_gpu,
)
has_raw_embedding_streaming = False
if self.enable_raw_embedding_streaming:
# When pipelining is enabled, the prefetch in iter i happens before the
# backward of the sparse modules in iter i - 1, so the embeddings for
# iter i - 1's changed ids are not updated yet; we can only stream the
# indices saved in iter i - 2.
# When pipelining is disabled, the prefetch in iter i happens before the
# forward of iter i, so we can safely stream iter i - 1's changed ids.
target_prev_iter = 1
if self.prefetch_pipeline:
target_prev_iter = 2
if len(self.prefetched_info) > (target_prev_iter - 1):
with record_function(
"## ssd_lookup_prefetched_rows {} {} ##".format(
self.timestep, self.tbe_unique_id
)
):
# wait for the copy to finish before overwriting the buffer
self.raw_embedding_stream_sync(
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_cache_streamed,
post_event=self.ssd_event_cache_streaming_synced,
name="cache_update",
)
current_stream.wait_event(self.ssd_event_cache_streaming_synced)
(updated_indices, updated_counts_gpu) = (
self.prefetched_info.pop(0)
)
self.lxu_cache_updated_indices[: updated_indices.size(0)].copy_(
updated_indices,
non_blocking=True,
)
self.lxu_cache_updated_count[:1].copy_(
updated_counts_gpu, non_blocking=True
)
has_raw_embedding_streaming = True

with record_function(
"## ssd_save_prefetched_rows {} {} ##".format(
self.timestep, self.tbe_unique_id
)
):
masked_updated_indices = torch.where(
torch.where(lxu_cache_locations != -1, True, False),
linear_cache_indices,
-1,
)

(
uni_updated_indices,
uni_updated_indices_length,
) = get_unique_indices_v2(
masked_updated_indices,
self.total_hash_size,
compute_count=False,
compute_inverse_indices=False,
)
assert uni_updated_indices is not None
assert uni_updated_indices_length is not None
# The unique indices tensor carries one extra -1 sentinel element,
# which can make its length exceed the pre-allocated buffer.
unique_len = min(
self.lxu_cache_weights.size(0),
uni_updated_indices.size(0),
)
self.prefetched_info.append(
(
uni_updated_indices.narrow(0, 0, unique_len),
uni_updated_indices_length.clamp(max=unique_len),
)
)

with record_function("## ssd_d2h_inserted_indices ##"):
# Transfer actions_count and insert_indices right away to
@@ -1580,6 +1811,41 @@ def _prefetch( # noqa C901
# Ensure that D2H is done
current_stream.wait_event(self.ssd_event_get_inputs_cpy)

if self.enable_raw_embedding_streaming and has_raw_embedding_streaming:
current_stream.wait_event(self.ssd_event_sp_streamed)
with record_function(
"## ssd_compute_updated_rows {} {} ##".format(
self.timestep, self.tbe_unique_id
)
):
# Cache rows that were changed in the previous iteration
updated_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
self.lxu_cache_updated_indices,
self.lxu_cache_state,
self.total_hash_size,
self.gather_ssd_cache_stats,
self.local_ssd_cache_stats,
)
torch.ops.fbgemm.masked_index_select(
self.lxu_cache_updated_weights,
updated_cache_locations,
self.lxu_cache_weights,
self.lxu_cache_updated_count,
)
current_stream.record_event(self.ssd_event_cache_streaming_computed)

self.raw_embedding_stream(
rows=self.lxu_cache_updated_weights,
indices_cpu=self.lxu_cache_updated_indices,
actions_count_cpu=self.lxu_cache_updated_count,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_cache_streaming_computed,
post_event=self.ssd_event_cache_streamed,
is_rows_uvm=True,
blocking_tensor_copy=False,
name="cache_update",
)

if self.gather_ssd_cache_stats:
# Collect the past SSD IO duration right before the next RocksDB IO
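For readers who have not seen `masked_index_select` before, the gather step in the `ssd_compute_updated_rows` block above can be read roughly as the plain-torch sketch below. This is an assumed reference semantic (copy `cache[locations[i]]` into `out[i]` for the first `count` entries, skipping locations of -1), not the actual fbgemm kernel.

```python
import torch

def masked_index_select_ref(out, locations, cache, count):
    # Assumed reference semantics for torch.ops.fbgemm.masked_index_select:
    # process only the first `count` entries and skip rows whose cache
    # location is -1 (i.e. rows not resident in the cache).
    n = int(count.item())
    loc = locations[:n].long()
    valid = loc != -1
    out[:n][valid] = cache[loc[valid]]
    return out

# Illustrative shapes only.
cache = torch.arange(12.0, device="cuda").reshape(6, 2)
locations = torch.tensor([4, -1, 0], dtype=torch.int, device="cuda")
out = torch.zeros(3, 2, device="cuda")
count = torch.tensor([3], dtype=torch.int, device="cuda")
masked_index_select_ref(out, locations, cache, count)
```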

@@ -52,6 +52,18 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
}

void stream_cuda(
const Tensor& indices,
const Tensor& weights,
const Tensor& count,
bool blocking_tensor_copy = true) {
return impl_->stream_cuda(indices, weights, count, blocking_tensor_copy);
}

void stream_sync_cuda() {
return impl_->stream_sync_cuda();
}

void get_cuda(Tensor indices, Tensor weights, Tensor count) {
return impl_->get_cuda(indices, weights, count);
}
@@ -95,6 +107,10 @@ static auto embedding_parameter_server_wrapper =
int64_t,
int64_t>())
.def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda)
.def("stream_cuda", &EmbeddingParameterServerWrapper::stream_cuda)
.def(
"stream_sync_cuda",
&EmbeddingParameterServerWrapper::stream_sync_cuda)
.def("get_cuda", &EmbeddingParameterServerWrapper::get_cuda)
.def("compact", &EmbeddingParameterServerWrapper::compact)
.def("flush", &EmbeddingParameterServerWrapper::flush)
@@ -83,6 +83,18 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
}

void stream_cuda(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count,
bool blocking_tensor_copy = true) {
return impl_->stream_cuda(indices, weights, count, blocking_tensor_copy);
}

void stream_sync_cuda() {
return impl_->stream_sync_cuda();
}

void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
return impl_->get_cuda(indices, weights, count);
}
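On the Python side, `training.py` drives these two entry points in a produce-then-sync pattern: `stream_cuda` hands the updated rows to the backend, and `stream_sync_cuda` is called before the UVM buffers are overwritten on a later iteration. A condensed sketch of that call order follows; `ssd_db` and the tensor arguments are stand-ins.

```python
def stream_updated_rows(ssd_db, indices, weights, count):
    # Hand the updated rows to the backend. With blocking_tensor_copy=False
    # the inputs may still be copied asynchronously after this returns.
    ssd_db.stream_cuda(indices, weights, count, False)

def wait_before_buffer_reuse(ssd_db):
    # Block until the previous copy has finished so the buffers backing
    # `indices` / `weights` / `count` can safely be overwritten.
    ssd_db.stream_sync_cuda()
```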