@@ -15,6 +15,8 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
 from torch import nn
+
+from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
 from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
@@ -68,14 +70,33 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


-def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
-    # pyre-ignore
-    return ps.sharding_spec.shards[0].placement.device().type
+def get_device_from_parameter_sharding(
+    ps: ParameterSharding,
+) -> Union[str, Tuple[str, ...]]:
+ """
77
+ Returns list of device type per shard if table is sharded across
78
+ different device type, else reutrns single device type for the
79
+ table parameter.
80
+ """
+    if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
+        raise ValueError("Expected EnumerableShardingSpec as input to the function")
+
+    device_type_list: Tuple[str, ...] = tuple(
+        # pyre-fixme[16]: `Optional` has no attribute `device`
+        [shard.placement.device().type for shard in ps.sharding_spec.shards]
+    )
+    if len(set(device_type_list)) == 1:
+        return device_type_list[0]
+    else:
+        assert (
+            ps.sharding_type == "row_wise"
+        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        return device_type_list


 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
-) -> str:
+) -> Union[str, Tuple[str, ...]]:
     res = list(
         {
             get_device_from_parameter_sharding(ps.param_sharding)
@@ -86,6 +107,13 @@ def get_device_from_sharding_infos(
     return res[0]


+def get_device_for_first_shard_from_sharding_infos(
+    emb_shard_infos: List[EmbeddingShardingInfo],
+) -> str:
+    device_type = get_device_from_sharding_infos(emb_shard_infos)
+    return device_type[0] if isinstance(device_type, tuple) else device_type
+
+
 torch.fx.wrap("len")


@@ -103,13 +131,19 @@ def create_infer_embedding_bag_sharding(
     NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor
 ]:
     propogate_device: bool = get_propogate_device()
+    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+        get_device_from_sharding_infos(sharding_infos)
+    )
     if sharding_type == ShardingType.TABLE_WISE.value:
         return InferTwEmbeddingSharding(
             sharding_infos, env, device=device if propogate_device else None
         )
     elif sharding_type == ShardingType.ROW_WISE.value:
         return InferRwPooledEmbeddingSharding(
-            sharding_infos, env, device=device if propogate_device else None
+            sharding_infos,
+            env,
+            device=device if propogate_device else None,
+            device_type_from_sharding_infos=device_type_from_sharding_infos,
         )
     elif sharding_type == ShardingType.COLUMN_WISE.value:
         return InferCwPooledEmbeddingSharding(
@@ -148,12 +182,12 @@ def __init__(
             module.embedding_bag_configs()
         )
         self._sharding_type_device_group_to_sharding_infos: Dict[
-            Tuple[str, str], List[EmbeddingShardingInfo]
+            Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
         ] = create_sharding_infos_by_sharding_device_group(
             module, table_name_to_parameter_sharding, "embedding_bags.", fused_params
         )
         self._sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 NullShardingContext,
                 InputDistOutputs,
@@ -167,7 +201,11 @@ def __init__(
                 (
                     env
                     if not isinstance(env, Dict)
-                    else env[get_device_from_sharding_infos(embedding_configs)]
+                    else env[
+                        get_device_for_first_shard_from_sharding_infos(
+                            embedding_configs
+                        )
+                    ]
                 ),
                 device if get_propogate_device() else None,
             )
@@ -250,7 +288,7 @@ def tbes_configs(

     def sharding_type_device_group_to_sharding_infos(
         self,
-    ) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:
         return self._sharding_type_device_group_to_sharding_infos

     def embedding_bag_configs(self) -> List[EmbeddingBagConfig]:
@@ -329,7 +367,9 @@ def copy(self, device: torch.device) -> nn.Module:
         return super().copy(device)

     @property
-    def shardings(self) -> Dict[Tuple[str, str], FeatureShardingMixIn]:
+    def shardings(
+        self,
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]:
         # pyre-ignore [7]
         return self._sharding_type_device_group_to_sharding

@@ -552,7 +592,7 @@ class ShardedQuantEbcInputDist(torch.nn.Module):
     def __init__(
         self,
         sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 NullShardingContext,
                 InputDistOutputs,
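
For context, a minimal usage sketch of the new get_device_from_parameter_sharding helper introduced above (not part of the diff). The shard offsets, sizes, placements, and the "quant" compute kernel are made-up illustrative values, and the helper is assumed to be in scope from the module this diff modifies:

# Illustrative sketch only: a quantized table sharded row-wise across two device types.
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata
from torchrec.distributed.types import ParameterSharding

ps = ParameterSharding(
    sharding_type="row_wise",
    compute_kernel="quant",  # assumed kernel name, for illustration only
    ranks=[0, 1],
    sharding_spec=EnumerableShardingSpec(
        shards=[
            # a 1024 x 64 table split into two 512-row shards on different device types
            ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement="rank:0/cuda:0"),
            ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement="rank:1/cpu"),
        ]
    ),
)

# Mixed device types yield a tuple, e.g. ("cuda", "cpu"); a table whose shards all
# live on one device type yields a plain string such as "cuda".
device_types = get_device_from_parameter_sharding(ps)

get_device_for_first_shard_from_sharding_infos then collapses such a result to the first shard's device type, which is what the sharding-env lookup in __init__ uses when env is a Dict.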