 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._tensor import DTensor

 )
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    _to_offsets,
+    KeyedJaggedTensor,
+    KeyedTensor,
+    td_to_kjt,
+)

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
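The newly imported td_to_kjt helper is used later in this diff to turn a TensorDict of per-feature jagged tensors into a KeyedJaggedTensor (one group per sharding type). Its exact TensorDict layout and implementation are not shown here; the sketch below is only a hypothetical stand-in built from a plain dict and KeyedJaggedTensor.from_lengths_sync, to illustrate the kind of conversion being assumed.

import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


# Hypothetical stand-in for td_to_kjt: the input layout is an assumption,
# not taken from this diff.
def dict_to_kjt(per_key_jagged, keys):
    # per_key_jagged maps feature name -> (values, per-sample lengths)
    values = torch.cat([per_key_jagged[k][0] for k in keys])
    lengths = torch.cat([per_key_jagged[k][1] for k in keys])
    return KeyedJaggedTensor.from_lengths_sync(
        keys=keys, values=values, lengths=lengths
    )


features = {
    "f1": (torch.tensor([10, 11, 12]), torch.tensor([2, 1])),  # 2 samples
    "f2": (torch.tensor([20]), torch.tensor([0, 1])),          # 2 samples
}
kjt = dict_to_kjt(features, ["f1", "f2"])
print(kjt.keys(), kjt.lengths())  # ['f1', 'f2'] tensor([2, 1, 0, 1])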
@@ -655,9 +661,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
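The simplification above works because a membership test already evaluates to a bool, so wrapping it in `True if ... else False` was redundant; a minimal illustration with a made-up mapping:

# Made-up stand-in for self._pooling_type_to_rs_features.
pooling_type_to_rs_features = {"MEAN": ["feature_0"]}

has_mean_pooling = "MEAN" in pooling_type_to_rs_features  # already a bool
assert has_mean_pooling is True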
@@ -1161,62 +1165,84 @@ def _create_inverse_indices_permute_indices(
 
     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
-        ctx.variable_batch_per_feature = features.variable_stride_per_key()
-        ctx.inverse_indices = features.inverse_indices_or_none()
+        if isinstance(features, KeyedJaggedTensor):
+            ctx.variable_batch_per_feature = features.variable_stride_per_key()
+            ctx.inverse_indices = features.inverse_indices_or_none()
+            feature_keys = features.keys()
+        else:  # features is TensorDict
+            ctx.variable_batch_per_feature = False  # TD does not support variable batch
+            ctx.inverse_indices = None
+            feature_keys = list(features.keys())  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(features.keys())
+            self._create_input_dist(feature_keys)
             self._has_uninitialized_input_dist = False
             if ctx.variable_batch_per_feature:
                 self._create_inverse_indices_permute_indices(ctx.inverse_indices)
             if self._has_mean_pooling_callback:
-                self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
+                self._init_mean_pooling_callback(feature_keys, ctx.inverse_indices)
         with torch.no_grad():
-            if self._has_features_permute:
-                features = features.permute(
-                    self._features_order,
-                    self._features_order_tensor,
-                )
-            if self._has_mean_pooling_callback:
-                ctx.divisor = _create_mean_pooling_divisor(
-                    lengths=features.lengths(),
-                    stride=features.stride(),
-                    keys=features.keys(),
-                    offsets=features.offsets(),
-                    pooling_type_to_rs_features=self._pooling_type_to_rs_features,
-                    stride_per_key=features.stride_per_key(),
-                    dim_per_key=self._dim_per_key,  # pyre-ignore[6]
-                    embedding_names=self._embedding_names,
-                    embedding_dims=self._embedding_dims,
-                    variable_batch_per_feature=ctx.variable_batch_per_feature,
-                    kjt_inverse_order=self._kjt_inverse_order,  # pyre-ignore[6]
-                    kjt_key_indices=self._kjt_key_indices,
-                    kt_key_ordering=self._kt_key_ordering,  # pyre-ignore[6]
-                    inverse_indices=ctx.inverse_indices,
-                    weights=features.weights_or_none(),
-                )
+            if isinstance(features, KeyedJaggedTensor):
+                if self._has_features_permute:
+                    features = features.permute(
+                        self._features_order,
+                        self._features_order_tensor,
+                    )
+                if self._has_mean_pooling_callback:
+                    ctx.divisor = _create_mean_pooling_divisor(
+                        lengths=features.lengths(),
+                        stride=features.stride(),
+                        keys=features.keys(),
+                        offsets=features.offsets(),
+                        pooling_type_to_rs_features=self._pooling_type_to_rs_features,
+                        stride_per_key=features.stride_per_key(),
+                        dim_per_key=self._dim_per_key,  # pyre-ignore[6]
+                        embedding_names=self._embedding_names,
+                        embedding_dims=self._embedding_dims,
+                        variable_batch_per_feature=ctx.variable_batch_per_feature,
+                        kjt_inverse_order=self._kjt_inverse_order,  # pyre-ignore[6]
+                        kjt_key_indices=self._kjt_key_indices,
+                        kt_key_ordering=self._kt_key_ordering,  # pyre-ignore[6]
+                        inverse_indices=ctx.inverse_indices,
+                        weights=features.weights_or_none(),
+                    )
 
-            features_by_shards = features.split(
-                self._feature_splits,
-            )
+                features_by_sharding_types = features.split(
+                    self._feature_splits,
+                )
+            else:  # features is TensorDict
+                feature_names = [feature_keys[i] for i in self._features_order]
+                feature_name_by_sharding_types: List[List[str]] = []
+                start = 0
+                for length in self._feature_splits:
+                    feature_name_by_sharding_types.append(
+                        feature_names[start : start + length]
+                    )
+                    start += length
+                features_by_sharding_types = [
+                    td_to_kjt(features, names)
+                    for names in feature_name_by_sharding_types
+                ]
             awaitables = []
-            for input_dist, features_by_shard, sharding_type in zip(
+            for input_dist, features_by_sharding_type, sharding_type in zip(
                 self._input_dists,
-                features_by_shards,
+                features_by_sharding_types,
                 self._sharding_types,
             ):
                 with maybe_annotate_embedding_event(
                     EmbeddingEvent.KJT_SPLITS_DIST,
                     self._module_fqn,
                     sharding_type,
                 ):
-                    awaitables.append(input_dist(features_by_shard))
+                    awaitables.append(input_dist(features_by_sharding_type))
 
                 ctx.sharding_contexts.append(
                     EmbeddingShardingContext(
-                        batch_size_per_feature_pre_a2a=features_by_shard.stride_per_key(),
-                        variable_batch_per_feature=features_by_shard.variable_stride_per_key(),
+                        batch_size_per_feature_pre_a2a=features_by_sharding_type.stride_per_key(),
+                        variable_batch_per_feature=features_by_sharding_type.variable_stride_per_key(),
                     )
                 )
             return KJTListSplitsAwaitable(
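In the TensorDict branch above, feature names are first reordered by self._features_order and then partitioned according to self._feature_splits, so that each sharding type receives its own name group before td_to_kjt builds a per-group KJT. A minimal sketch of that partitioning, with made-up names and split sizes:

from typing import List

# Made-up inputs; the real values come from the module's sharding plan.
feature_keys = ["f_a", "f_b", "f_c", "f_d"]
features_order = [2, 0, 3, 1]  # permutation applied to the incoming keys
feature_splits = [3, 1]        # number of features per sharding type

feature_names = [feature_keys[i] for i in features_order]

feature_name_by_sharding_types: List[List[str]] = []
start = 0
for length in feature_splits:
    feature_name_by_sharding_types.append(feature_names[start : start + length])
    start += length

# -> [['f_c', 'f_a', 'f_d'], ['f_b']]; each group would then be converted with
# td_to_kjt(features, names) as in the diff above.
print(feature_name_by_sharding_types)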