Skip to content

Commit 167a244

Browse files
committed
test: add CUDA coverage for AMP no_grad cache handling
1 parent 3d5a8d6 commit 167a244

1 file changed

Lines changed: 38 additions & 0 deletions

File tree

tests/tests_pytorch/plugins/precision/test_amp.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from lightning.pytorch.plugins import MixedPrecision
2222
from lightning.pytorch.utilities import GradClipAlgorithmType
23+
from tests_pytorch.helpers.runif import RunIf
2324

2425

2526
def test_clip_gradients():
@@ -119,6 +120,43 @@ def test_torch_autocast_cache_behavior_with_no_grad(cache_enabled, expect_grad):
119120
loss.backward()
120121

121122

123+
@RunIf(min_cuda_gpus=1)
124+
@pytest.mark.parametrize(("cache_enabled", "expect_grad"), [(True, False), (False, True)])
125+
def test_torch_autocast_cache_behavior_with_no_grad_cuda(cache_enabled, expect_grad):
126+
"""Document the same autocast cache behavior on CUDA, where the reported regression happens."""
127+
layer = nn.Linear(2, 1, device="cuda")
128+
x = torch.randn(1, 2, device="cuda")
129+
130+
with torch.autocast("cuda", dtype=torch.float16, cache_enabled=cache_enabled):
131+
with torch.no_grad():
132+
_ = layer(x)
133+
134+
loss = layer(x).mean()
135+
if expect_grad:
136+
loss.backward()
137+
assert loss.grad_fn is not None
138+
else:
139+
assert loss.grad_fn is None
140+
with pytest.raises(RuntimeError, match="does not require grad"):
141+
loss.backward()
142+
143+
144+
@RunIf(min_cuda_gpus=1)
145+
def test_amp_with_no_grad_cuda():
146+
"""Test the Lightning workaround on the CUDA path used by the reported regression."""
147+
layer = nn.Linear(2, 1, device="cuda")
148+
x = torch.randn(1, 2, device="cuda")
149+
amp = MixedPrecision(precision="16-mixed", device="cuda")
150+
151+
with amp.forward_context():
152+
with torch.no_grad():
153+
_ = layer(x)
154+
155+
loss = layer(x).mean()
156+
loss.backward()
157+
assert loss.grad_fn is not None
158+
159+
122160
def test_amp_autocast_context_manager_disables_cache():
123161
"""Test that the public autocast context manager preserves the existing no-cache workaround."""
124162
amp = MixedPrecision(precision="bf16-mixed", device="cpu")

0 commit comments

Comments
 (0)