From d6765e15502d3a2fb08b453b1117d22e68f5ff99 Mon Sep 17 00:00:00 2001
From: Hui Kang
Date: Mon, 4 Nov 2024 09:34:01 +0000
Subject: [PATCH] torchrec changes for dynamic embedding

---
 torchrec/distributed/embedding.py            | 20 ++++-
 torchrec/distributed/embedding_lookup.py     | 16 ++++
 .../sharding/rw_sequence_sharding.py         | 36 ++++++---
 torchrec/distributed/sharding/rw_sharding.py | 76 +++++++++++++++----
 4 files changed, 119 insertions(+), 29 deletions(-)

diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 4b16bd0c0..35a36d555 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -170,9 +170,12 @@ def create_sharding_infos_by_sharding(
         if parameter_sharding.compute_kernel not in [
             kernel.value for kernel in EmbeddingComputeKernel
         ]:
-            raise ValueError(
-                f"Compute kernel not supported {parameter_sharding.compute_kernel}"
-            )
+            compute_kernel_params = parameter_sharding.get_params_of_compute_kernel()
+            customized_kernel = compute_kernel_params.get("customized_compute_kernel")
+            if parameter_sharding.compute_kernel != customized_kernel:
+                raise ValueError(
+                    f"Compute kernel not supported {parameter_sharding.compute_kernel}"
+                )
 
         param_name = "embeddings." + config.name + ".weight"
         assert param_name in parameter_by_name or param_name in state_dict
@@ -200,6 +203,7 @@
             per_table_fused_params, parameter_sharding
         )
         per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
+        per_table_fused_params.update(parameter_sharding.get_params_of_compute_kernel())
 
         sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
             (
@@ -507,12 +511,18 @@ def _initialize_torch_state(self) -> None:  # noqa
         self._model_parallel_name_to_local_shards = OrderedDict()
         self._model_parallel_name_to_sharded_tensor = OrderedDict()
         model_parallel_name_to_compute_kernel: Dict[str, str] = {}
+        customized_table_names: List[str] = []
         for (
             table_name,
             parameter_sharding,
         ) in self.module_sharding_plan.items():
             if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
                 continue
+            if parameter_sharding.compute_kernel not in [
+                    kernel.value for kernel in EmbeddingComputeKernel]:
+                customized_table_names.append(table_name)
+                continue
+
             self._model_parallel_name_to_local_shards[table_name] = []
             model_parallel_name_to_compute_kernel[table_name] = (
                 parameter_sharding.compute_kernel
@@ -535,6 +545,8 @@
             # save local_shards for transforming MP params to shardedTensor
             for key, v in lookup.state_dict().items():
                 table_name = key[: -len(".weight")]
+                if table_name in customized_table_names:
+                    continue
                 self._model_parallel_name_to_local_shards[table_name].extend(
                     v.local_shards()
                 )
@@ -798,7 +810,7 @@ def input_dist(
                         features_before_input_dist=features,
                         unbucketize_permute_tensor=(
                             input_dist.unbucketize_permute_tensor
-                            if isinstance(input_dist, RwSparseFeaturesDist)
+                            if hasattr(input_dist, "unbucketize_permute_tensor")
                             else None
                         ),
                     )
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 341416058..ef9136d39 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -128,6 +128,14 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
     ) -> None:
+        def _exist_customized_compute_kernel(config: GroupedEmbeddingConfig) -> bool:
+            # config.compute_kernel is not a built-in kernel; look for a customized one
+            exist_key = "customized_compute_kernel"
+            if config.fused_params and exist_key in config.fused_params:
+                if config.compute_kernel == config.fused_params[exist_key]:
+                    return True
+            return False
+
         # TODO rename to _create_embedding_kernel
         def _create_lookup(
             config: GroupedEmbeddingConfig,
@@ -147,6 +155,14 @@
                 pg=pg,
                 device=device,
             )
+            elif _exist_customized_compute_kernel(config):
+                assert "ComputeKernel" in config.fused_params
+                ComputeKernel = config.fused_params["ComputeKernel"]
+                return ComputeKernel(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                )
             else:
                 raise ValueError(
                     f"Compute kernel not supported {config.compute_kernel}"
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index 5b15373bb..a31f42d32 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -115,17 +115,31 @@ def create_input_dist(
     ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]:
         num_features = self._get_num_features()
         feature_hash_sizes = self._get_feature_hash_sizes()
-        return RwSparseFeaturesDist(
-            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
-            # `Optional[ProcessGroup]`.
-            pg=self._pg,
-            num_features=num_features,
-            feature_hash_sizes=feature_hash_sizes,
-            device=device if device is not None else self._device,
-            is_sequence=True,
-            has_feature_processor=self._has_feature_processor,
-            need_pos=False,
-        )
+        if self._customized_dist:
+            return self._customized_dist(
+                # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
+                # `Optional[ProcessGroup]`.
+                pg=self._pg,
+                num_features=num_features,
+                feature_hash_sizes=feature_hash_sizes,
+                device=device if device is not None else self._device,
+                is_sequence=True,
+                has_feature_processor=self._has_feature_processor,
+                need_pos=False,
+                dist_type_per_feature=self._dist_type_per_feature,
+            )
+        else:
+            return RwSparseFeaturesDist(
+                # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
+                # `Optional[ProcessGroup]`.
+                pg=self._pg,
+                num_features=num_features,
+                feature_hash_sizes=feature_hash_sizes,
+                device=device if device is not None else self._device,
+                is_sequence=True,
+                has_feature_processor=self._has_feature_processor,
+                need_pos=False,
+            )
 
     def create_lookup(
         self,
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index afbfba94c..33b981ce6 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -125,6 +125,7 @@ def __init__(
             device = torch.device("cpu")
         self._device: torch.device = device
         sharded_tables_per_rank = self._shard(sharding_infos)
+        self._init_customized_distributor(sharding_infos)
         self._need_pos = need_pos
         self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = (
             []
@@ -161,6 +162,15 @@ def _shard(
                 ),
             )
 
+            if info.param_sharding.compute_kernel not in [
+                kernel.value for kernel in EmbeddingComputeKernel
+            ]:
+                compute_kernel_params = info.param_sharding.get_params_of_compute_kernel()
+                compute_kernel = compute_kernel_params["customized_compute_kernel"]
+            else:
+                compute_kernel = EmbeddingComputeKernel(
+                    info.param_sharding.compute_kernel
+                )
             for rank in range(self._world_size):
                 tables_per_rank[rank].append(
                     ShardedEmbeddingTable(
@@ -175,9 +185,7 @@
                         has_feature_processor=info.embedding_config.has_feature_processor,
                         local_rows=shards[rank].shard_sizes[0],
                         local_cols=info.embedding_config.embedding_dim,
-                        compute_kernel=EmbeddingComputeKernel(
-                            info.param_sharding.compute_kernel
-                        ),
+                        compute_kernel=compute_kernel,
                         local_metadata=shards[rank],
                         global_metadata=global_metadata,
                         weight_init_max=info.embedding_config.weight_init_max,
@@ -187,6 +195,31 @@
                 )
         return tables_per_rank
 
+    def _init_customized_distributor(self, sharding_infos: List[EmbeddingShardingInfo]) -> None:
+        common_dist_type = None
+        common_customized_dist = None
+
+        self._dist_type_per_feature: Dict[str, str] = {}
+        for sharding_info in sharding_infos:
+            if sharding_info.fused_params and "dist_type" in sharding_info.fused_params:
+                dist_type = sharding_info.fused_params["dist_type"]
+                if common_dist_type is None:
+                    common_dist_type = dist_type
+                else:
+                    assert dist_type == common_dist_type, "Customized distributor type must be the same for all tables."
+                dist_ = sharding_info.fused_params["Distributor"]
+                if common_customized_dist is None:
+                    common_customized_dist = dist_
+                else:
+                    assert common_customized_dist == dist_, "Customized distributor implementation must be the same for all tables."
+            else:
+                dist_type = "continuous"
+            feature_names = sharding_info.embedding_config.feature_names
+            for f in feature_names:
+                self._dist_type_per_feature[f] = dist_type
+
+        self._customized_dist = common_customized_dist
+
     def embedding_dims(self) -> List[int]:
         embedding_dims = []
         for grouped_config in self._grouped_embedding_configs:
@@ -465,17 +498,32 @@ def create_input_dist(
     ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]:
         num_features = self._get_num_features()
         feature_hash_sizes = self._get_feature_hash_sizes()
-        return RwSparseFeaturesDist(
-            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
-            # `Optional[ProcessGroup]`.
-            pg=self._pg,
-            num_features=num_features,
-            feature_hash_sizes=feature_hash_sizes,
-            device=device if device is not None else self._device,
-            is_sequence=False,
-            has_feature_processor=self._has_feature_processor,
-            need_pos=self._need_pos,
-        )
+
+        if self._customized_dist:
+            return self._customized_dist(
+                # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
+                # `Optional[ProcessGroup]`.
+ pg=self._pg, + num_features=num_features, + feature_hash_sizes=feature_hash_sizes, + device=device if device is not None else self._device, + is_sequence=False, + has_feature_processor=self._has_feature_processor, + need_pos=self._need_pos, + dist_type_per_feature=self._dist_type_per_feature, + ) + else: + return RwSparseFeaturesDist( + # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + pg=self._pg, + num_features=num_features, + feature_hash_sizes=feature_hash_sizes, + device=device if device is not None else self._device, + is_sequence=False, + has_feature_processor=self._has_feature_processor, + need_pos=self._need_pos, + ) def create_lookup( self,
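
Usage sketch (illustrative, not part of the diff above): the hooks in this patch are driven entirely by per-table params. create_sharding_infos_by_sharding() merges ParameterSharding.get_params_of_compute_kernel() into per_table_fused_params, _create_lookup() instantiates fused_params["ComputeKernel"], and create_input_dist() instantiates fused_params["Distributor"] with dist_type_per_feature derived from fused_params["dist_type"]. The snippet below only illustrates that contract; the names DynamicComputeKernel, DynamicSparseFeaturesDist, make_dynamic_fused_params and the "hash" dist_type value are hypothetical placeholders, not torchrec APIs.

# Hypothetical sketch of the contract consumed by the patched call sites.
# Only the dict keys ("customized_compute_kernel", "ComputeKernel",
# "Distributor", "dist_type") and the constructor keyword arguments below are
# taken from the patch; every class and helper name here is made up.

from typing import Any, Dict, List, Optional

import torch
import torch.distributed as dist


class DynamicComputeKernel(torch.nn.Module):
    # _create_lookup() in embedding_lookup.py constructs
    # fused_params["ComputeKernel"] with exactly these keyword arguments when
    # the table's compute kernel is not a built-in EmbeddingComputeKernel.
    def __init__(
        self,
        config: Any,  # GroupedEmbeddingConfig
        pg: Optional[dist.ProcessGroup] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self._config = config
        self._device = device


class DynamicSparseFeaturesDist(torch.nn.Module):
    # create_input_dist() in rw_sharding.py / rw_sequence_sharding.py constructs
    # fused_params["Distributor"] with the same keyword arguments it would pass
    # to RwSparseFeaturesDist, plus dist_type_per_feature.
    def __init__(
        self,
        pg: dist.ProcessGroup,
        num_features: int,
        feature_hash_sizes: List[int],
        device: Optional[torch.device] = None,
        is_sequence: bool = False,
        has_feature_processor: bool = False,
        need_pos: bool = False,
        dist_type_per_feature: Optional[Dict[str, str]] = None,
    ) -> None:
        super().__init__()
        # features whose tables carry no "dist_type" are tagged "continuous"
        # by _init_customized_distributor()
        self._dist_type_per_feature = dist_type_per_feature or {}


def make_dynamic_fused_params(kernel_name: str) -> Dict[str, Any]:
    # Per-table params as expected from ParameterSharding.get_params_of_compute_kernel();
    # create_sharding_infos_by_sharding() merges them into per_table_fused_params.
    return {
        # must equal ParameterSharding.compute_kernel for the table
        "customized_compute_kernel": kernel_name,
        # embedding kernel class instantiated instead of a built-in kernel
        "ComputeKernel": DynamicComputeKernel,
        # input distributor class instantiated instead of RwSparseFeaturesDist
        "Distributor": DynamicSparseFeaturesDist,
        # example tag; must be identical across all tables that set it
        "dist_type": "hash",
    }

In a real integration the distributor would subclass BaseSparseFeaturesDist[KeyedJaggedTensor] and implement forward(), as RwSparseFeaturesDist does, and the kernel would implement the lookup interface expected by the lookup module in embedding_lookup.py; plain nn.Module is used here only to keep the sketch self-contained.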