Skip to content

Commit 7e867ad

Browse files
seanx92facebook-github-bot
authored andcommitted
add backward compatible reference for _get_unflattened_lengths (pytorch#2541)
Summary: Pull Request resolved: pytorch#2541 Reviewed By: PaulZhang12 Differential Revision: D65490058 fbshipit-source-id: 7fe44cc56bf5b72abe20e10911c0c288905c2dd1
1 parent 0512183 commit 7e867ad

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

torchrec/modules/utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ def _slice_1d_tensor(tensor: torch.Tensor, start: int, end: int) -> torch.Tensor
4848
return tensor[start:end]
4949

5050

51+
# PLEASE DO NOT USE THIS FUNCTION, THIS FUNCTION IS FOR BACKWARD COMPATIBILITY ONLY
52+
# USE THE ONE IN torchrec/quant/embedding_modules.py
53+
# TODO(@shuaoxiong): remove this function after we make sure all models switch to the new reference
54+
@torch.fx.wrap
55+
def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor:
56+
"""
57+
Unflatten lengths tensor from [F * B] to [F, B].
58+
"""
59+
return lengths.view(num_features, -1)
60+
61+
5162
def extract_module_or_tensor_callable(
5263
module_or_callable: Union[
5364
Callable[[], torch.nn.Module],

0 commit comments

Comments
 (0)