We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1d1150b commit 584f385Copy full SHA for 584f385
1 file changed
src/liger_kernel/ops/utils.py
@@ -153,12 +153,9 @@ def set_large_grf_mode(kernel_args: dict):
153
kernel_args["grf_mode"] = "large"
154
155
156
-def _get_fsdp_ctx(mlp_module: torch.nn.Module):
157
- """
158
- Return FSDP.summon_full_params context if module is FSDP-wrapped,
159
- otherwise return a no-op context.
160
161
- if isinstance(mlp_module, FSDP):
162
- return FSDP.summon_full_params(mlp_module, write_back=True)
163
- else:
164
- return contextlib.nullcontext()
+def _get_no_sync_context(module):
+ """Return no_sync context if module is DDP or FSDP, else a no-op."""
+ if isinstance(module, (FSDP, parallel.DistributedDataParallel)):
+ return module.no_sync()
+ import contextlib
+ return contextlib.nullcontext()
0 commit comments