|
27 | 27 |
|
28 | 28 | import torch
|
29 | 29 | from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
|
| 30 | +from tensordict import TensorDict |
30 | 31 | from torch import distributed as dist, nn, Tensor
|
31 | 32 | from torch.autograd.profiler import record_function
|
32 | 33 | from torch.distributed._tensor import DTensor
|
|
92 | 93 | from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
|
93 | 94 | from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
|
94 | 95 | from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
|
| 96 | +from torchrec.sparse.tensor_dict import maybe_td_to_kjt |
95 | 97 |
|
96 | 98 | try:
|
97 | 99 | torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
|
|
100 | 102 | except OSError:
|
101 | 103 | pass
|
102 | 104 |
|
103 |
| -try: |
104 |
| - from tensordict import TensorDict |
105 |
| -except ImportError: |
106 |
| - |
107 |
| - class TensorDict: |
108 |
| - pass |
109 |
| - |
110 | 105 |
|
111 | 106 | def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
|
112 | 107 | return (
|
@@ -661,9 +656,7 @@ def __init__(
|
661 | 656 | self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
|
662 | 657 | # to support mean pooling callback hook
|
663 | 658 | self._has_mean_pooling_callback: bool = (
|
664 |
| - True |
665 |
| - if PoolingType.MEAN.value in self._pooling_type_to_rs_features |
666 |
| - else False |
| 659 | + PoolingType.MEAN.value in self._pooling_type_to_rs_features |
667 | 660 | )
|
668 | 661 | self._dim_per_key: Optional[torch.Tensor] = None
|
669 | 662 | self._kjt_key_indices: Dict[str, int] = {}
|
@@ -1163,17 +1156,27 @@ def _create_inverse_indices_permute_indices(
|
1163 | 1156 |
|
1164 | 1157 | # pyre-ignore [14]
|
1165 | 1158 | def input_dist(
|
1166 |
| - self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor |
| 1159 | + self, |
| 1160 | + ctx: EmbeddingBagCollectionContext, |
| 1161 | + features: Union[KeyedJaggedTensor, TensorDict], |
1167 | 1162 | ) -> Awaitable[Awaitable[KJTList]]:
|
| 1163 | + if isinstance(features, TensorDict): |
| 1164 | + feature_keys = list(features.keys()) # pyre-ignore[6] |
| 1165 | + if len(self._features_order) > 0: |
| 1166 | + feature_keys = [feature_keys[i] for i in self._features_order] |
| 1167 | + self._has_features_permute = False # feature_keys are in order |
| 1168 | + features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] |
| 1169 | + else: |
| 1170 | + feature_keys = features.keys() |
1168 | 1171 | ctx.variable_batch_per_feature = features.variable_stride_per_key()
|
1169 | 1172 | ctx.inverse_indices = features.inverse_indices_or_none()
|
1170 | 1173 | if self._has_uninitialized_input_dist:
|
1171 |
| - self._create_input_dist(features.keys()) |
| 1174 | + self._create_input_dist(feature_keys) |
1172 | 1175 | self._has_uninitialized_input_dist = False
|
1173 | 1176 | if ctx.variable_batch_per_feature:
|
1174 | 1177 | self._create_inverse_indices_permute_indices(ctx.inverse_indices)
|
1175 | 1178 | if self._has_mean_pooling_callback:
|
1176 |
| - self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices) |
| 1179 | + self._init_mean_pooling_callback(feature_keys, ctx.inverse_indices) |
1177 | 1180 | with torch.no_grad():
|
1178 | 1181 | if self._has_features_permute:
|
1179 | 1182 | features = features.permute(
|
|
0 commit comments