Skip to content

Commit 7c0a68a

Browse files
Microvefacebook-github-bot
authored andcommitted
Add SimpleFSDP to leaf node
Summary: Torchrec should not trace into SimpleFSDP which will cause failures. We thus added it to the leaf node Differential Revision: D71450848
1 parent 76446e7 commit 7c0a68a

File tree

1 file changed

+2
-0
lines changed
  • torchrec/distributed/train_pipeline

1 file changed

+2
-0
lines changed

torchrec/distributed/train_pipeline/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
if not torch._running_with_deploy():
4141
from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
42+
from torch.distributed.fb.simple_fsdp import SimpleFSDPModule as SimpleFSDP
4243
else:
4344

4445
class FSDP2:
@@ -743,6 +744,7 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
743744
or module_qualified_name in self._leaf_modules
744745
or isinstance(m, FSDP)
745746
or isinstance(m, FSDP2)
747+
or isinstance(m, SimpleFSDP)
746748
):
747749
return True
748750
return super().is_leaf_module(m, module_qualified_name)

0 commit comments

Comments
 (0)