Skip to content

Commit 1120456

Browse files
Fix torch.compile breaking toggle_optimizer / untoggle_optimizer (#21686)
* Fix toggle_optimizer breaking under torch.compile (#21513) `LightningModule.toggle_optimizer` and `untoggle_optimizer` mutate `requires_grad` on parameters to implement multi-optimizer gradient masking. Dynamo/AOTAutograd does not support `setattr()` on `Tensor.requires_grad` because it can change a tensor's leaf-ness mid-graph, so when the `LightningModule` is wrapped with `torch.compile` tracing either graph-breaks with "Unsupported: setattr() on Tensor.requires_grad" or raises a `KeyError` on the internal `param_requires_grad_state` mapping when the traced parameter references diverge from those held by `trainer.optimizers`. Decorate both helpers with `@torch.compiler.disable` (the same pattern already used for logging bookkeeping in `logger_connector/result.py`) so they run as opaque Python when called from a compiled `training_step`. Eager behavior is unchanged. Adds a CPU regression test that compiles a two-optimizer `LightningModule` calling `toggle_optimizer` / `untoggle_optimizer` in `training_step` and exercises one training iteration, plus a CHANGELOG entry. * Narrow test_toggle_untoggle to check compiler.disable attribute (#21513) The previous regression test compiled a `LightningModule` end-to-end and called `self.optimizers()` inside the compiled `training_step`, which unrelated to the toggle_optimizer fix trips a separate Dynamo limitation: tracing `self.trainer.strategy._lightning_optimizers` raises `InternalTorchDynamoError: GetAttrVariable(...) has no type` across all CI platforms and torch versions. The shipped fix — `@torch.compiler.disable` on `toggle_optimizer` / `untoggle_optimizer` — does not require a full compiled trainer run to verify; it only guarantees Dynamo skips those two methods. Replace the integration test with a direct attribute check that both methods carry the `_torchdynamo_disable` marker installed by `torch.compiler.disable`, following the same `has_dynamo(fn)` pattern already used by `tests/utilities/test_compile.py::test_compile_uncompile`. Toggle/untoggle functional correctness remains covered by the existing `test_toggle_untoggle_2_optimizers_no_shared_parameters` and `test_toggle_untoggle_3_optimizers_shared_parameters` tests in this file. --------- Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent b4b5f6d commit 1120456

3 files changed

Lines changed: 39 additions & 0 deletions

File tree

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2828

2929
- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))
3030

31+
- Fixed `LightningModule.toggle_optimizer` / `untoggle_optimizer` breaking under `torch.compile` by disabling Dynamo tracing on these bookkeeping helpers ([#21513](https://github.com/Lightning-AI/pytorch-lightning/issues/21513))
32+
3133
---
3234

3335
## [2.6.4] - 2026-05-20

src/lightning/pytorch/core/module.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,12 +1136,24 @@ def backward(self, loss):
11361136
else:
11371137
loss.backward(*args, **kwargs)
11381138

1139+
@torch.compiler.disable
11391140
def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
11401141
"""Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to
11411142
prevent dangling gradients in multiple-optimizer setup.
11421143
11431144
It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset.
11441145
1146+
.. note::
1147+
This method is decorated with :func:`torch.compiler.disable` so that it is executed as regular
1148+
Python when the ``LightningModule`` is wrapped with :func:`torch.compile`. Mutating
1149+
``requires_grad`` on parameters is not supported by Dynamo/AOTAutograd (it can change a
1150+
tensor's leaf-ness mid-graph), so tracing this bookkeeping helper would either fail with
1151+
``Unsupported: setattr() on Tensor.requires_grad`` or produce a ``KeyError`` on the
1152+
internal ``param_requires_grad_state`` mapping when the traced parameter references diverge
1153+
from those held by ``trainer.optimizers``. Disabling the compiler on this method keeps the
1154+
behavior identical for eager users while making it safe to call from a compiled
1155+
``training_step``.
1156+
11451157
Args:
11461158
optimizer: The optimizer to toggle.
11471159
@@ -1165,9 +1177,13 @@ def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> N
11651177
param.requires_grad = param_requires_grad_state[param]
11661178
self._param_requires_grad_state = param_requires_grad_state
11671179

1180+
@torch.compiler.disable
11681181
def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
11691182
"""Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`.
11701183
1184+
See :meth:`toggle_optimizer` for details on why this method is decorated with
1185+
:func:`torch.compiler.disable`.
1186+
11711187
Args:
11721188
optimizer: The optimizer to untoggle.
11731189

tests/tests_pytorch/core/test_lightning_module.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,27 @@ def configure_optimizers(self):
298298
trainer.fit(model)
299299

300300

301+
@RunIf(dynamo=True)
302+
def test_toggle_untoggle_optimizer_are_compiler_disabled():
303+
"""Regression test for https://github.com/Lightning-AI/pytorch-lightning/issues/21513.
304+
305+
``toggle_optimizer`` / ``untoggle_optimizer`` mutate ``requires_grad`` on Parameters, which
306+
Dynamo/AOTAutograd does not support because it can change a tensor's leaf-ness mid-graph.
307+
Tracing these helpers either graph-breaks with ``Unsupported: setattr() on Tensor.requires_grad``
308+
or raises a ``KeyError`` on the internal ``param_requires_grad_state`` mapping when the traced
309+
parameter references diverge from those held by ``trainer.optimizers``. Both methods are
310+
decorated with ``@torch.compiler.disable`` so that Dynamo never enters them. This test verifies
311+
the decorator is attached via the ``_torchdynamo_disable`` attribute the decorator installs
312+
(the same assertion pattern used by ``tests/utilities/test_compile.py::test_compile_uncompile``).
313+
"""
314+
315+
def is_compiler_disabled(fn):
316+
return any(el.startswith("_torchdynamo_disable") for el in dir(fn))
317+
318+
assert is_compiler_disabled(LightningModule.toggle_optimizer)
319+
assert is_compiler_disabled(LightningModule.untoggle_optimizer)
320+
321+
301322
@pytest.mark.parametrize(
302323
("accelerator", "device"),
303324
[

0 commit comments

Comments
 (0)