 
 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
 
+from torch.autograd.profiler import record_function
+
 try:
     load_torch_module(
         "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_gpu",
@@ -626,6 +628,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     lxu_cache_locations_list: List[Tensor]
     lxu_cache_locations_empty: Tensor
     timesteps_prefetched: List[int]
+    prefetched_info: List[Tuple[Tensor, Tensor]]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
     uvm_cache_stats: torch.Tensor
@@ -690,6 +693,8 @@ def __init__(  # noqa C901
         embedding_table_index_type: torch.dtype = torch.int64,
         embedding_table_offset_type: torch.dtype = torch.int64,
         embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None,
+        enable_raw_embedding_streaming: bool = False,
+        res_params: Optional[RESParams] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
         self.uuid = str(uuid.uuid4())
@@ -700,6 +705,7 @@ def __init__(  # noqa C901
         )
 
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+        self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
         self.pooling_mode = pooling_mode
         self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
 
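For reference, the `RESParams` bundle taken by the new `res_params` argument is assumed to look roughly like the sketch below. The field names are inferred from the accesses later in this diff (`res_store_shards`, `res_server_port`, `table_names`, `table_offsets`, `table_sizes`); the defaults and comments are illustrative, and the real definition lives elsewhere in fbgemm_gpu.

```python
# Assumed shape of RESParams, inferred from the fields this diff reads;
# not the authoritative definition.
from dataclasses import dataclass, field
from typing import List


@dataclass
class RESParams:
    res_server_port: int = 0   # port of the raw embedding stream server
    res_store_shards: int = 1  # number of shards in the remote store
    table_names: List[str] = field(default_factory=list)
    table_offsets: List[int] = field(default_factory=list)
    table_sizes: List[int] = field(default_factory=list)
```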
@@ -1460,6 +1466,30 @@ def __init__(  # noqa C901
         )
         self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type
 
+        self.prefetched_info: List[Tuple[Tensor, Tensor]] = torch.jit.annotate(
+            List[Tuple[Tensor, Tensor]], []
+        )
+        if self.enable_raw_embedding_streaming:
+            self.res_params: RESParams = res_params or RESParams()
+            self.res_params.table_sizes = [0] + list(accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            # pyre-fixme[4]: Attribute must be annotated.
+            self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+                self.uuid,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+            )
+            logging.info(
+                f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+            )
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
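The setup above derives `table_sizes` as a prefix sum over the per-table row counts and reads the server port from the `LOCAL_RES_PORT` environment variable. A standalone sketch of that derivation, with illustrative row counts:

```python
# Standalone sketch of the parameter derivation above; the row counts
# are illustrative. The prefix sum gives each table's starting offset
# in the linearized row space.
import os
from itertools import accumulate

rows = [1000, 2000, 500]  # rows per embedding table (example values)
table_sizes = [0] + list(accumulate(rows))  # -> [0, 1000, 3000, 3500]

res_port_from_env = os.getenv("LOCAL_RES_PORT")
res_server_port = int(res_port_from_env) if res_port_from_env else 0
```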
@@ -2521,7 +2551,13 @@ def _prefetch(
         self.local_uvm_cache_stats.zero_()
         self._report_io_size_count("prefetch_input", indices)
 
+        # Stream out the previously prefetched rows before the cache is updated.
+        self.raw_embedding_stream()
+
         final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+        linear_cache_indices_merged = torch.zeros(
+            0, dtype=indices.dtype, device=indices.device
+        )
         for (
             partial_indices,
             partial_lxu_cache_locations,
@@ -2537,6 +2573,9 @@ def _prefetch(
                 vbe_metadata.max_B if vbe_metadata is not None else -1,
                 base_offset,
             )
+            linear_cache_indices_merged = torch.cat(
+                [linear_cache_indices_merged, linear_cache_indices]
+            )
 
             if (
                 self.record_cache_metrics.record_cache_miss_counter
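Each pass through the loop contributes one chunk of linearized cache indices; starting from an empty tensor and concatenating keeps the merged tensor's dtype and device consistent with the inputs. A minimal illustration of the pattern, using toy chunks:

```python
# Toy version of the merge above: accumulate per-partition index
# chunks into one 1-D tensor.
import torch

merged = torch.zeros(0, dtype=torch.int64)
for chunk in (torch.tensor([3, 7, 7]), torch.tensor([1, 3])):
    merged = torch.cat([merged, chunk])
# merged is now tensor([3, 7, 7, 1, 3])
```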
@@ -2617,6 +2656,23 @@ def _prefetch(
         if self.should_log():
             self.print_uvm_cache_stats(use_local_cache=False)
 
+        if self.enable_raw_embedding_streaming:
+            with record_function(
+                "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+            ):
+                (
+                    linear_unique_indices,
+                    linear_unique_indices_length,
+                    _,
+                ) = torch.ops.fbgemm.get_unique_indices(
+                    linear_cache_indices_merged,
+                    self.total_cache_hash_size,
+                    compute_count=False,
+                )
+                self.prefetched_info.append(
+                    (linear_unique_indices, linear_unique_indices_length)
+                )
+
     def should_log(self) -> bool:
         """Determines if we should log for this step, using exponentially decreasing frequency.
 
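`torch.ops.fbgemm.get_unique_indices` deduplicates the merged indices and returns the unique ids together with a length tensor, so downstream kernels can operate on a fixed-size buffer. Conceptually (setting the fixed-size-buffer detail aside) it behaves like this `torch.unique`-based stand-in:

```python
# Conceptual stand-in for get_unique_indices; the real op also pads to
# a fixed-size buffer and can optionally compute per-id counts.
import torch

merged = torch.tensor([3, 7, 7, 1, 3])
unique_indices = torch.unique(merged)                   # tensor([1, 3, 7])
unique_length = torch.tensor([unique_indices.numel()])  # count of valid entries
```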
@@ -3829,6 +3885,55 @@ def _debug_print_input_stats_factory_null(
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null
 
+    @torch.jit.ignore
+    def raw_embedding_stream(self) -> None:
+        if not self.enable_raw_embedding_streaming:
+            return
+        # When pipelining is enabled, the prefetch for iteration i runs before
+        # the sparse backward of iteration i - 1, so the embeddings for the ids
+        # changed in iteration i - 1 have not been updated yet; only the
+        # indices from iteration i - 2 can be fetched safely.
+        # When pipelining is disabled, the prefetch for iteration i runs before
+        # the forward of iteration i, so the ids changed in iteration i - 1 can
+        # be fetched safely.
+        target_prev_iter = 1
+        if self.prefetch_pipeline:
+            target_prev_iter = 2
+        if len(self.prefetched_info) < target_prev_iter:
+            return
+        with record_function(
+            "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+        ):
+            updated_indices, updated_count = self.prefetched_info.pop(0)
+            updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                updated_indices,
+                self.lxu_cache_state,
+                self.total_cache_hash_size,
+                gather_cache_stats=False,  # not collecting cache stats
+                num_uniq_cache_indices=updated_count,
+            )
+            updated_weights = torch.empty(
+                [updated_indices.size(0), self.max_D_cache],
+                # pyre-ignore[6]: `dtype` expects `Optional[dtype]`.
+                dtype=self.lxu_cache_weights.dtype,
+                # pyre-ignore[6]: `device` expects `Union[None, int, str, device]`.
+                device=self.lxu_cache_weights.device,
+            )
+            torch.ops.fbgemm.masked_index_select(
+                updated_weights,
+                updated_locations,
+                self.lxu_cache_weights,
+                updated_count,
+            )
+            # Stream the updated weights out to the remote store.
+            self._raw_embedding_streamer.stream(
+                updated_indices.to(device=torch.device("cpu")),
+                updated_weights.to(device=torch.device("cpu")),
+                updated_count.to(device=torch.device("cpu")),
+                False,  # require_tensor_copy
+                False,  # blocking_tensor_copy
+            )
+
 
 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """
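The safety rule described in the comments above reduces to a queue-depth check: the prefetched-id queue must hold at least `target_prev_iter` entries before the oldest entry may be streamed. A minimal sketch of that rule; `ready_to_stream` is a hypothetical helper for illustration, not part of the module above:

```python
# Minimal sketch of the iteration-lag rule; ready_to_stream is a
# hypothetical name, not an fbgemm API.
from collections import deque
from typing import Deque, Tuple

import torch


def ready_to_stream(
    prefetched_info: Deque[Tuple[torch.Tensor, torch.Tensor]],
    prefetch_pipeline: bool,
) -> bool:
    # With pipelining, ids from iteration i - 1 are still being updated,
    # so one extra entry must sit in the queue before streaming starts.
    target_prev_iter = 2 if prefetch_pipeline else 1
    return len(prefetched_info) >= target_prev_iter
```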