|
20 | 20 |
|
21 | 21 | from lightning.pytorch.plugins import MixedPrecision |
22 | 22 | from lightning.pytorch.utilities import GradClipAlgorithmType |
| 23 | +from tests_pytorch.helpers.runif import RunIf |
23 | 24 |
|
24 | 25 |
|
25 | 26 | def test_clip_gradients(): |
@@ -119,6 +120,43 @@ def test_torch_autocast_cache_behavior_with_no_grad(cache_enabled, expect_grad): |
119 | 120 | loss.backward() |
120 | 121 |
|
121 | 122 |
|
@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize(("cache_enabled", "expect_grad"), [(True, False), (False, True)])
def test_torch_autocast_cache_behavior_with_no_grad_cuda(cache_enabled, expect_grad):
    """Pin down how plain ``torch.autocast`` on CUDA behaves after a ``no_grad`` forward pass.

    Inside a single float16 autocast region the layer is first called under ``torch.no_grad()`` and then
    called again with gradients enabled. The parametrization records the expected outcome of that second
    call: with ``cache_enabled=True`` the loss has no ``grad_fn`` and ``backward()`` raises, while with
    ``cache_enabled=False`` the loss is differentiable and ``backward()`` succeeds.

    """
    linear = nn.Linear(2, 1, device="cuda")
    batch = torch.randn(1, 2, device="cuda")
    autocast_region = torch.autocast("cuda", dtype=torch.float16, cache_enabled=cache_enabled)

    with autocast_region:
        # First forward pass runs without gradient tracking; its output is intentionally discarded.
        with torch.no_grad():
            linear(batch)

        # Second forward pass happens in the same autocast region, now with gradients enabled.
        loss = linear(batch).mean()

        if not expect_grad:
            # The loss is detached from the graph, so autograd refuses to run.
            assert loss.grad_fn is None
            with pytest.raises(RuntimeError, match="does not require grad"):
                loss.backward()
            return

        loss.backward()
        assert loss.grad_fn is not None
| 142 | + |
| 143 | + |
@RunIf(min_cuda_gpus=1)
def test_amp_with_no_grad_cuda():
    """Check that ``MixedPrecision.forward_context`` keeps gradients after a ``no_grad`` forward on CUDA.

    The layer is called once under ``torch.no_grad()`` and then again with gradients enabled, both inside
    the plugin's ``forward_context()`` for ``"16-mixed"`` precision. The second call must produce a
    differentiable loss, and backpropagating it must populate the layer's parameter gradients.

    """
    layer = nn.Linear(2, 1, device="cuda")
    x = torch.randn(1, 2, device="cuda")
    amp = MixedPrecision(precision="16-mixed", device="cuda")

    with amp.forward_context():
        with torch.no_grad():
            _ = layer(x)

        loss = layer(x).mean()
        # Assert before calling backward(): if the loss were detached, backward() would raise first and
        # an assertion placed after it could never report the failure itself.
        assert loss.grad_fn is not None
        loss.backward()

    # A successful backward pass must actually reach the layer's parameters.
    assert layer.weight.grad is not None
    assert layer.bias.grad is not None
| 158 | + |
| 159 | + |
122 | 160 | def test_amp_autocast_context_manager_disables_cache(): |
123 | 161 | """Test that the public autocast context manager preserves the existing no-cache workaround.""" |
124 | 162 | amp = MixedPrecision(precision="bf16-mixed", device="cpu") |
|
0 commit comments