|
27 | 27 |
|
28 | 28 | import torch
|
29 | 29 | from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
|
| 30 | +from tensordict import TensorDict |
30 | 31 | from torch import distributed as dist, nn, Tensor
|
31 | 32 | from torch.autograd.profiler import record_function
|
32 | 33 | from torch.distributed._shard.sharded_tensor import TensorProperties
|
|
94 | 95 | from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
|
95 | 96 | from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
|
96 | 97 | from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
|
| 98 | +from torchrec.sparse.tensor_dict import maybe_td_to_kjt |
97 | 99 |
|
98 | 100 | try:
|
99 | 101 | torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
|
|
102 | 104 | except OSError:
|
103 | 105 | pass
|
104 | 106 |
|
105 |
| -try: |
106 |
| - from tensordict import TensorDict |
107 |
| -except ImportError: |
108 |
| - |
109 |
| - class TensorDict: |
110 |
| - pass |
111 |
| - |
112 | 107 |
|
113 | 108 | def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
|
114 | 109 | return (
|
@@ -663,9 +658,7 @@ def __init__(
|
663 | 658 | self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
|
664 | 659 | # to support mean pooling callback hook
|
665 | 660 | self._has_mean_pooling_callback: bool = (
|
666 |
| - True |
667 |
| - if PoolingType.MEAN.value in self._pooling_type_to_rs_features |
668 |
| - else False |
| 661 | + PoolingType.MEAN.value in self._pooling_type_to_rs_features |
669 | 662 | )
|
670 | 663 | self._dim_per_key: Optional[torch.Tensor] = None
|
671 | 664 | self._kjt_key_indices: Dict[str, int] = {}
|
@@ -1196,8 +1189,16 @@ def _create_inverse_indices_permute_indices(
|
1196 | 1189 |
|
1197 | 1190 | # pyre-ignore [14]
|
1198 | 1191 | def input_dist(
|
1199 |
| - self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor |
| 1192 | + self, |
| 1193 | + ctx: EmbeddingBagCollectionContext, |
| 1194 | + features: Union[KeyedJaggedTensor, TensorDict], |
1200 | 1195 | ) -> Awaitable[Awaitable[KJTList]]:
|
| 1196 | + if isinstance(features, TensorDict): |
| 1197 | + feature_keys = list(features.keys()) # pyre-ignore[6] |
| 1198 | + if len(self._features_order) > 0: |
| 1199 | + feature_keys = [feature_keys[i] for i in self._features_order] |
| 1200 | + self._has_features_permute = False # feature_keys are in order |
| 1201 | + features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] |
1201 | 1202 | ctx.variable_batch_per_feature = features.variable_stride_per_key()
|
1202 | 1203 | ctx.inverse_indices = features.inverse_indices_or_none()
|
1203 | 1204 | if self._has_uninitialized_input_dist:
|
|
0 commit comments