Skip to content

Commit 584f385

Browse files
committed
Add ctx function for no_sync FSDP when wrapping TiledMLP
1 parent 1d1150b commit 584f385

1 file changed

Lines changed: 6 additions & 9 deletions

File tree

src/liger_kernel/ops/utils.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -153,12 +153,9 @@ def set_large_grf_mode(kernel_args: dict):
153153
kernel_args["grf_mode"] = "large"
154154

155155

156-
def _get_fsdp_ctx(mlp_module: torch.nn.Module):
157-
"""
158-
Return FSDP.summon_full_params context if module is FSDP-wrapped,
159-
otherwise return a no-op context.
160-
"""
161-
if isinstance(mlp_module, FSDP):
162-
return FSDP.summon_full_params(mlp_module, write_back=True)
163-
else:
164-
return contextlib.nullcontext()
156+
def _get_no_sync_context(module):
157+
"""Return no_sync context if module is DDP or FSDP, else a no-op."""
158+
if isinstance(module, (FSDP, parallel.DistributedDataParallel)):
159+
return module.no_sync()
160+
import contextlib
161+
return contextlib.nullcontext()

0 commit comments

Comments
 (0)