Skip to content

Commit a77cf3b

Browse files
Xinyi Wang and meta-codesync[bot]
authored and committed
Implement a write method in DMP (#3801)
Summary: Pull Request resolved: #3801 TorchRec allows users to create embeddings with custom input. This was done in D78749760. In this diff I expose this method to DistributedDataParallel (DMP), so that for modules with config enable_embedding_update = True, DMP will be able to update the embeddings with custom input. **Approach** We recursively initialize writable modules in `_init_dmp` method and when callers call `write` update all found modules with provided kjt Reviewed By: kausv Differential Revision: D93914739 fbshipit-source-id: 38f2019c079df325dbe246b7dea79de80a6f113f
1 parent bbcc55c commit a77cf3b

File tree

4 files changed

+273
-47
lines changed

4 files changed

+273
-47
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
GroupedEmbeddingConfig,
7575
ShardedEmbeddingTable,
7676
)
77-
from torchrec.distributed.model_tracker.types import IndexedLookup
7877
from torchrec.distributed.shards_wrapper import LocalShardsWrapper
7978
from torchrec.distributed.types import (
8079
LazyAwaitable,

torchrec/distributed/embedding.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ def __init__(
498498
self._write_splits: List[int] = []
499499
self._feature_splits: List[int] = []
500500
self._features_order: List[int] = []
501+
self._writable_embedding_names: set[str] = set()
501502

502503
self._has_uninitialized_input_dist: bool = True
503504
logger.info(f"EC index dedup enabled: {self._use_index_dedup}.")
@@ -1685,6 +1686,7 @@ def _create_write_dist(self) -> None:
16851686
if sharding.enable_embedding_update:
16861687
self._write_dists.append(sharding.create_write_dist())
16871688
self._write_splits.append(sharding._get_num_writable_features())
1689+
self._writable_embedding_names.update(sharding.embedding_names())
16881690

16891691
# pyrefly: ignore[bad-override]
16901692
def write_dist(
@@ -1694,6 +1696,10 @@ def write_dist(
16941696
raise ValueError("enable_embedding_update is False for this collection")
16951697
if not self._write_dists:
16961698
self._create_write_dist()
1699+
if set(embeddings.keys()) != self._writable_embedding_names:
1700+
raise ValueError(
1701+
f"write_dist feature names {embeddings.keys()} do not match expected {self._writable_embedding_names}"
1702+
)
16971703
with torch.no_grad():
16981704
embeddings_by_shards = embeddings.split(self._write_splits)
16991705
awaitables = []

torchrec/distributed/model_parallel.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import copy
1212
import logging as logger
1313
from collections import defaultdict, OrderedDict
14+
from functools import wraps
1415
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Set, Tuple, Type
1516

1617
import torch
@@ -30,6 +31,7 @@
3031
from torch.nn.parallel import DistributedDataParallel
3132
from torchrec.distributed.collective_utils import create_on_rank_and_share_result
3233
from torchrec.distributed.comm import get_local_size
34+
from torchrec.distributed.embedding import ShardedEmbeddingCollection
3335
from torchrec.distributed.model_tracker.model_delta_tracker import (
3436
ModelDeltaTracker,
3537
ModelDeltaTrackerTrec,
@@ -40,7 +42,6 @@
4042
ModelTrackerConfigs,
4143
RawIdTrackerConfig,
4244
Trackers,
43-
UniqueRows,
4445
)
4546
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
4647
from torchrec.distributed.sharding_plan import get_default_sharders
@@ -61,7 +62,6 @@
6162
append_prefix,
6263
copy_to_device,
6364
filter_state_dict,
64-
none_throws,
6565
sharded_model_copy,
6666
)
6767
from torchrec.optim.fused import FusedOptimizerModule
@@ -77,6 +77,40 @@
7777
_DDP_STATE_DICT_PREFIX = "module."
7878

7979

80+
def _populate_updatable_modules(
81+
func: Callable[..., nn.Module],
82+
) -> Callable[..., nn.Module]:
83+
"""
84+
Decorator that populates the list of modules that can be updated with kjt.
85+
Specifically, modules with enable_embedding_update flag set to True.
86+
87+
Applied to _shard_modules_impl to automatically process returned modules.
88+
"""
89+
90+
@wraps(func)
91+
def wrapper(
92+
self: "DistributedModelParallel",
93+
module: nn.Module,
94+
path: str = "",
95+
module_id_cache: Optional[Dict[str, "ShardedModule"]] = None,
96+
) -> nn.Module:
97+
result = func(self, module, path, module_id_cache)
98+
99+
module_id = id(result)
100+
if module_id_cache and module_id in module_id_cache:
101+
# skip adding duplicate one
102+
return result
103+
104+
if isinstance(result, ShardedEmbeddingCollection) and getattr(
105+
result, "enable_embedding_update", False
106+
):
107+
self._writable_sharded_modules.append(result)
108+
109+
return result
110+
111+
return wrapper
112+
113+
80114
class DataParallelWrapper(abc.ABC):
81115
"""
82116
Interface implemented by custom data parallel wrappers.
@@ -297,6 +331,7 @@ def __init__(
297331
# pyrefly: ignore[bad-argument-type, missing-argument]
298332
plan = planner.plan(module, self.sharders)
299333
self._plan: ShardingPlan = plan
334+
self._writable_sharded_modules: list[ShardedEmbeddingCollection] = []
300335
self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
301336
self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
302337

@@ -462,6 +497,7 @@ def _fused_optim_impl(
462497
)
463498
return fused_optims
464499

500+
@_populate_updatable_modules
465501
def _shard_modules_impl(
466502
self,
467503
module: nn.Module,
@@ -613,6 +649,18 @@ def load_state_dict(
613649
) -> _IncompatibleKeys:
614650
return self._load_state_dict(self, state_dict, prefix, strict)
615651

652+
def write(self, *input, **kwargs) -> None:
653+
"""
654+
Write features to the sharded module if it has enable_embedding_update flag.
655+
"""
656+
if len(self._writable_sharded_modules) == 0:
657+
raise RuntimeError(
658+
"No writable sharded modules found. Please check `enable_embedding_update` flag in your embedding config"
659+
)
660+
661+
for module in self._writable_sharded_modules:
662+
module.write(*input, **kwargs)
663+
616664
def _load_state_dict(
617665
self,
618666
module: nn.Module,

0 commit comments

Comments
 (0)