
Commit e860a2c

faran928 authored and facebook-github-bot committed
Enable Ebc Heterogenous Sharding (#2837)
Summary:
Pull Request resolved: #2837

Enable Ebc Heterogenous Sharding so that a single Ebc table can be sharded across hbm and cpu.

Reviewed By: jiayisuse

Differential Revision: D70229136

fbshipit-source-id: baf190c311df95df2c17abe0d58b86d615dd4c56
1 parent 44d04b5 commit e860a2c

File tree

5 files changed: +92 -18 lines changed
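For context on what the diffs below enable: a table counts as heterogeneously sharded when its shards are placed on different device types. The following is a minimal, self-contained sketch (hypothetical two-shard row-wise layout, not code from this commit) of the per-shard device-type extraction that the updated get_device_from_parameter_sharding helpers perform:

from typing import Tuple, Union

from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata

# Hypothetical row-wise layout: first half of the table in GPU HBM, second half on host CPU.
spec = EnumerableShardingSpec(
    shards=[
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 128], placement="rank:0/cuda:0"),
        ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 128], placement="rank:1/cpu"),
    ]
)


def device_types_per_shard(spec: EnumerableShardingSpec) -> Union[str, Tuple[str, ...]]:
    # Collect one device type per shard; collapse to a single string when they all match.
    per_shard = tuple(shard.placement.device().type for shard in spec.shards)
    return per_shard[0] if len(set(per_shard)) == 1 else per_shard


print(device_types_per_shard(spec))  # ('cuda', 'cpu') -> heterogeneous table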

torchrec/distributed/embedding_lookup.py

+13 -1

@@ -822,25 +822,30 @@ def __init__(
         device: Optional[torch.device] = None,
         feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         # TODO rename to _create_embedding_kernel
         def _create_lookup(
             config: GroupedEmbeddingConfig,
             device: Optional[torch.device] = None,
             fused_params: Optional[Dict[str, Any]] = None,
+            shard_index: Optional[int] = None,
         ) -> BaseBatchedEmbeddingBag[
             Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
         ]:
             return QuantBatchedEmbeddingBag(
                 config=config,
                 device=device,
                 fused_params=fused_params,
+                shard_index=shard_index,
             )

         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config, device, fused_params))
+            self._emb_modules.append(
+                _create_lookup(config, device, fused_params, shard_index)
+            )

         self._feature_splits: List[int] = [
             config.num_features() for config in grouped_configs
@@ -1030,6 +1035,7 @@ def __init__(
         world_size: int,
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
         self._embedding_lookups_per_rank: List[
@@ -1047,6 +1053,11 @@ def __init__(
         self._is_empty_rank: List[bool] = []
         for rank in range(world_size):
             empty_rank = len(grouped_configs_per_rank[rank]) == 0
+            # Propagate shard index to get the correct runtime_device based on shard metadata
+            # in case of heterogeneous sharding of a single table across different device types
+            shard_index = (
+                rank if isinstance(device_type_from_sharding_infos, tuple) else None
+            )
             self._is_empty_rank.append(empty_rank)
             if not empty_rank:
                 self._embedding_lookups_per_rank.append(
@@ -1055,6 +1066,7 @@ def __init__(
                         grouped_configs=grouped_configs_per_rank[rank],
                         device=rank_device(device_type, rank),
                         fused_params=fused_params,
+                        shard_index=shard_index,
                     )
                 )
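The shard_index propagated above is consumed by _get_runtime_device (see the quant_embedding_kernel.py diff below). A rough sketch of the intended effect, using a hypothetical pick_runtime_device_type helper rather than the real _get_runtime_device internals:

from typing import Optional, Tuple, Union

import torch


def pick_runtime_device_type(
    device_type_from_sharding_infos: Union[str, Tuple[str, ...]],
    shard_index: Optional[int],
) -> torch.device:
    # Hypothetical: with heterogeneous sharding each rank owns one shard, so its runtime
    # device is whatever device type the shard metadata records for that shard.
    if isinstance(device_type_from_sharding_infos, tuple):
        assert shard_index is not None, "heterogeneous table requires a shard index"
        return torch.device(device_type_from_sharding_infos[shard_index])
    return torch.device(device_type_from_sharding_infos)


# Rank 0 hosts the HBM shard, rank 1 the CPU shard of the same table.
print(pick_runtime_device_type(("cuda", "cpu"), shard_index=0))  # cuda
print(pick_runtime_device_type(("cuda", "cpu"), shard_index=1))  # cpu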

torchrec/distributed/embeddingbag.py

+23 -5

@@ -113,9 +113,27 @@ def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     )


-def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
-    # pyre-ignore
-    return ps.sharding_spec.shards[0].placement.device().type
+def get_device_from_parameter_sharding(
+    ps: ParameterSharding,
+) -> Union[str, Tuple[str, ...]]:
+    """
+    Returns list of device type per shard if table is sharded across different
+    device types, else returns single device type for the table parameter
+    """
+    if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
+        raise ValueError("Expected EnumerableShardingSpec as input to the function")
+
+    device_type_list: Tuple[str, ...] = tuple(
+        # pyre-fixme[16]: `Optional` has no attribute `device`
+        [shard.placement.device().type for shard in ps.sharding_spec.shards]
+    )
+    if len(set(device_type_list)) == 1:
+        return device_type_list[0]
+    else:
+        assert (
+            ps.sharding_type == "row_wise"
+        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        return device_type_list


 def replace_placement_with_meta_device(
@@ -319,7 +337,7 @@ def create_sharding_infos_by_sharding_device_group(
     prefix: str,
     fused_params: Optional[Dict[str, Any]],
     suffix: Optional[str] = "weight",
-) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:

     if fused_params is None:
         fused_params = {}
@@ -335,7 +353,7 @@ def create_sharding_infos_by_sharding_device_group(
             shared_feature[feature_name] = True

     sharding_type_device_group_to_sharding_infos: Dict[
-        Tuple[str, str], List[EmbeddingShardingInfo]
+        Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
     ] = {}

     # state_dict returns parameter.Tensor, which loses parameter level attributes
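With get_device_from_parameter_sharding now returning either a string or a tuple, the grouping keys produced by create_sharding_infos_by_sharding_device_group can carry a per-shard device-type tuple. An illustrative (hypothetical) pair of keys:

from typing import Tuple, Union

DeviceGroup = Union[str, Tuple[str, ...]]

# All shards on one device type -> the key's device group stays a plain string.
homogeneous_key: Tuple[str, DeviceGroup] = ("table_wise", "cuda")
# Row-wise table split across HBM and CPU -> the device group is a tuple, one entry per shard.
heterogeneous_key: Tuple[str, DeviceGroup] = ("row_wise", ("cuda", "cpu"))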

torchrec/distributed/quant_embedding_kernel.py

+4 -1

@@ -232,6 +232,7 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         super().__init__(config, pg, device)

@@ -253,7 +254,9 @@ def __init__(
             fused_params
         )

-        self._runtime_device: torch.device = _get_runtime_device(device, config)
+        self._runtime_device: torch.device = _get_runtime_device(
+            device, config, shard_index
+        )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
         embedding_specs = []
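The shard-aware runtime device matters because it drives the TBE row alignment chosen right after it. A quick sketch of the consequence for the hypothetical HBM+CPU table (alignment values taken from the comment in the diff above):

import torch

# With shard_index propagated, the rank holding the HBM shard and the rank holding the
# CPU shard of the same table resolve different runtime devices, hence different alignments.
for runtime_device in (torch.device("cuda"), torch.device("cpu")):
    tbe_row_alignment = 16 if runtime_device.type == "cuda" else 1
    print(runtime_device.type, tbe_row_alignment)  # cuda 16, then cpu 1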

torchrec/distributed/quant_embeddingbag.py

+51 -11

@@ -15,6 +15,8 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
 from torch import nn
+
+from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
 from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
@@ -68,14 +70,33 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


-def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
-    # pyre-ignore
-    return ps.sharding_spec.shards[0].placement.device().type
+def get_device_from_parameter_sharding(
+    ps: ParameterSharding,
+) -> Union[str, Tuple[str, ...]]:
+    """
+    Returns list of device type per shard if table is sharded across
+    different device types, else returns single device type for the
+    table parameter.
+    """
+    if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
+        raise ValueError("Expected EnumerableShardingSpec as input to the function")
+
+    device_type_list: Tuple[str, ...] = tuple(
+        # pyre-fixme[16]: `Optional` has no attribute `device`
+        [shard.placement.device().type for shard in ps.sharding_spec.shards]
+    )
+    if len(set(device_type_list)) == 1:
+        return device_type_list[0]
+    else:
+        assert (
+            ps.sharding_type == "row_wise"
+        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        return device_type_list


 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
-) -> str:
+) -> Union[str, Tuple[str, ...]]:
     res = list(
         {
             get_device_from_parameter_sharding(ps.param_sharding)
@@ -86,6 +107,13 @@ def get_device_from_sharding_infos(
     return res[0]


+def get_device_for_first_shard_from_sharding_infos(
+    emb_shard_infos: List[EmbeddingShardingInfo],
+) -> str:
+    device_type = get_device_from_sharding_infos(emb_shard_infos)
+    return device_type[0] if isinstance(device_type, tuple) else device_type
+
+
 torch.fx.wrap("len")


@@ -103,13 +131,19 @@ def create_infer_embedding_bag_sharding(
     NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor
 ]:
     propogate_device: bool = get_propogate_device()
+    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+        get_device_from_sharding_infos(sharding_infos)
+    )
     if sharding_type == ShardingType.TABLE_WISE.value:
         return InferTwEmbeddingSharding(
             sharding_infos, env, device=device if propogate_device else None
         )
     elif sharding_type == ShardingType.ROW_WISE.value:
         return InferRwPooledEmbeddingSharding(
-            sharding_infos, env, device=device if propogate_device else None
+            sharding_infos,
+            env,
+            device=device if propogate_device else None,
+            device_type_from_sharding_infos=device_type_from_sharding_infos,
         )
     elif sharding_type == ShardingType.COLUMN_WISE.value:
         return InferCwPooledEmbeddingSharding(
@@ -148,12 +182,12 @@ def __init__(
             module.embedding_bag_configs()
         )
         self._sharding_type_device_group_to_sharding_infos: Dict[
-            Tuple[str, str], List[EmbeddingShardingInfo]
+            Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
         ] = create_sharding_infos_by_sharding_device_group(
            module, table_name_to_parameter_sharding, "embedding_bags.", fused_params
         )
         self._sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 NullShardingContext,
                 InputDistOutputs,
@@ -167,7 +201,11 @@ def __init__(
                 (
                     env
                     if not isinstance(env, Dict)
-                    else env[get_device_from_sharding_infos(embedding_configs)]
+                    else env[
+                        get_device_for_first_shard_from_sharding_infos(
+                            embedding_configs
+                        )
+                    ]
                 ),
                 device if get_propogate_device() else None,
             )
@@ -250,7 +288,7 @@ def tbes_configs(

     def sharding_type_device_group_to_sharding_infos(
         self,
-    ) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:
         return self._sharding_type_device_group_to_sharding_infos

     def embedding_bag_configs(self) -> List[EmbeddingBagConfig]:
@@ -329,7 +367,9 @@ def copy(self, device: torch.device) -> nn.Module:
         return super().copy(device)

     @property
-    def shardings(self) -> Dict[Tuple[str, str], FeatureShardingMixIn]:
+    def shardings(
+        self,
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]:
         # pyre-ignore [7]
         return self._sharding_type_device_group_to_sharding

@@ -552,7 +592,7 @@ class ShardedQuantEbcInputDist(torch.nn.Module):
     def __init__(
         self,
         sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 NullShardingContext,
                 InputDistOutputs,
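When env is passed as a dict keyed by device type, the sharder above now selects the entry for the first shard's device type via get_device_for_first_shard_from_sharding_infos. A minimal sketch of that selection with a hypothetical env mapping:

from typing import Dict, Tuple, Union


def first_shard_device_type(device_type: Union[str, Tuple[str, ...]]) -> str:
    # Same selection rule as get_device_for_first_shard_from_sharding_infos above.
    return device_type[0] if isinstance(device_type, tuple) else device_type


# Hypothetical stand-in for a Dict[str, ShardingEnv]; values are just labels here.
env_by_device: Dict[str, str] = {"cuda": "cuda_sharding_env", "cpu": "cpu_sharding_env"}
print(env_by_device[first_shard_device_type(("cuda", "cpu"))])  # cuda_sharding_env
print(env_by_device[first_shard_device_type("cpu")])            # cpu_sharding_env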

torchrec/distributed/sharding/rw_sharding.py

+1

@@ -752,6 +752,7 @@ def create_lookup(
             world_size=self._world_size,
             fused_params=fused_params,
             device=device if device is not None else self._device,
+            device_type_from_sharding_infos=self._device_type_from_sharding_infos,
         )

     def create_output_dist(
