
Commit 0512183

PaulZhang12 authored and facebook-github-bot committed
Split FPEBC (pytorch#2535)
Summary:

Pull Request resolved: pytorch#2535

Adds an API to split an FPEBC into its feature-processor (FP) and EmbeddingBagCollection (EBC) counterparts. Furthermore, if the FP is an nn.ModuleDict, it is wrapped in an nn.Module so that its forward pass can be called. This helps support eager model split in IEN.

Reviewed By: iamzainhuda

Differential Revision: D65439160

fbshipit-source-id: 42cc148daefca22fce5d67b7adf10834732e6803
1 parent 63d604a commit 0512183
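
For orientation before the diff, here is a minimal eager-mode sketch of the new split API. The single-table config, PositionWeightedModule usage, and KJT construction are illustrative assumptions mirroring the unit tests in this commit (import paths as in recent torchrec versions), not part of the change itself:

import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModule
from torchrec.modules.fp_embedding_modules import (
    FeatureProcessedEmbeddingBagCollection,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Illustrative single-table setup; the EBC must be weighted because the
# feature processors attach per-value weights to the KJT.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=8, num_embeddings=10, feature_names=["f1"]
        )
    ],
    is_weighted=True,
)
fp_ebc = FeatureProcessedEmbeddingBagCollection(
    ebc,
    # A plain dict of processors is stored internally as an nn.ModuleDict,
    # which is the case the new FeatureProcessorDictWrapper handles.
    {"f1": PositionWeightedModule(max_feature_length=10)},
)

features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([0, 1, 2]),
    lengths=torch.tensor([1, 2]),
)

# Split into the FP half and the EBC half; composing the two halves
# reproduces the fused module's forward pass.
fp, ebc_part = fp_ebc.split()
fused = fp_ebc(features)
composed = ebc_part(fp(features))
torch.testing.assert_close(composed.values(), fused.values())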

File tree

2 files changed: +40 -1 lines


torchrec/modules/fp_embedding_modules.py (+22 -1)

@@ -7,7 +7,7 @@

 # pyre-strict

-from typing import Dict, List, Set, Union
+from typing import Dict, List, Set, Tuple, Union

 import torch
 import torch.nn as nn
@@ -55,6 +55,15 @@ def apply_feature_processors_to_kjt(
     )


+class FeatureProcessorDictWrapper(FeatureProcessorsCollection):
+    def __init__(self, feature_processors: nn.ModuleDict) -> None:
+        super().__init__()
+        self._feature_processors = feature_processors
+
+    def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
+        return apply_feature_processors_to_kjt(features, self._feature_processors)
+
+
 class FeatureProcessedEmbeddingBagCollection(nn.Module):
     """
     FeatureProcessedEmbeddingBagCollection represents a EmbeddingBagCollection module and a set of feature processor modules.
@@ -125,6 +134,18 @@ def __init__(
             feature_names_set.update(table_config.feature_names)
         self._feature_names: List[str] = list(feature_names_set)

+    def split(
+        self,
+    ) -> Tuple[FeatureProcessorsCollection, EmbeddingBagCollection]:
+        if isinstance(self._feature_processors, nn.ModuleDict):
+            return (
+                FeatureProcessorDictWrapper(self._feature_processors),
+                self._embedding_bag_collection,
+            )
+        else:
+            assert isinstance(self._feature_processors, FeatureProcessorsCollection)
+            return self._feature_processors, self._embedding_bag_collection
+
     def forward(
         self,
         features: KeyedJaggedTensor,
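
A note on the design here: nn.ModuleDict deliberately defines no forward() of its own, so returning the raw dict from split() would hand callers a module that cannot be invoked. FeatureProcessorDictWrapper closes that gap by routing forward() through the apply_feature_processors_to_kjt helper already present in this file, which appears to be the same path the fused module takes for the dict case, so the split halves stay behaviorally identical to the original FPEBC.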

torchrec/modules/tests/test_fp_embedding_modules.py (+18 -0)

@@ -95,6 +95,15 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None:
         self.assertEqual(pooled_embeddings.values().size(), (3, 16))
         self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16])

+        # Test split method, FP then EBC
+        fp, ebc = fp_ebc.split()
+        fp_kjt = fp(features)
+        pooled_embeddings_split = ebc(fp_kjt)
+
+        self.assertEqual(pooled_embeddings_split.keys(), ["f1", "f2"])
+        self.assertEqual(pooled_embeddings_split.values().size(), (3, 16))
+        self.assertEqual(pooled_embeddings_split.offset_per_key(), [0, 8, 16])
+

 class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase):
     def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection:
@@ -144,3 +153,12 @@ def test_position_weighted_collection_module_ebc(self) -> None:
             pooled_embeddings_gm_script.offset_per_key(),
             pooled_embeddings.offset_per_key(),
         )
+
+        # Test split method, FP then EBC
+        fp, ebc = fp_ebc.split()
+        fp_kjt = fp(features)
+        pooled_embeddings_split = ebc(fp_kjt)
+
+        self.assertEqual(pooled_embeddings_split.keys(), ["f1", "f2"])
+        self.assertEqual(pooled_embeddings_split.values().size(), (3, 16))
+        self.assertEqual(pooled_embeddings_split.offset_per_key(), [0, 8, 16])
