Skip to content

Commit 5709900

Browse files
committed
Fixed AMP error messages
1 parent 63550c6 commit 5709900

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

ignite/engine/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def supervised_training_step_amp(
185185
"""
186186

187187
try:
188-
from torch.amp import autocast, GradScaler
188+
from torch.amp import autocast
189189
except ImportError:
190190
raise ImportError("Please install torch>=1.12.0 to use amp_mode='amp'.")
191191

@@ -412,7 +412,7 @@ def _check_arg(
412412
try:
413413
from torch.amp import GradScaler
414414
except ImportError:
415-
raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
415+
raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
416416
scaler = GradScaler(enabled=True)
417417

418418
if on_tpu:

tests/ignite/engine/test_create_supervised.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _():
167167
trainer.run(data)
168168

169169

170-
@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
170+
@pytest.mark.skipif(Version(torch.__version__) < Version("2.3.1"), reason="Skip if < 2.3.1")
171171
def test_create_supervised_training_scalar_assignment():
172172
with mock.patch("ignite.engine._check_arg") as check_arg_mock:
173173
check_arg_mock.return_value = None, torch.amp.GradScaler(enabled=False)
@@ -456,11 +456,11 @@ def test_create_supervised_trainer_amp_error(mock_torch_cuda_amp_module):
456456
_test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu", amp_mode="amp")
457457
with pytest.raises(ImportError, match="Please install torch>=1.12.0 to use amp_mode='amp'."):
458458
_test_create_supervised_trainer(amp_mode="amp")
459-
with pytest.raises(ImportError, match="Please install torch>=1.6.0 to use scaler argument."):
459+
with pytest.raises(ImportError, match="Please install torch>=2.3.1 to use scaler argument."):
460460
_test_create_supervised_trainer(amp_mode="amp", scaler=True)
461461

462462

463-
@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
463+
@pytest.mark.skipif(Version(torch.__version__) < Version("2.3.1"), reason="Skip if < 2.3.1")
464464
def test_create_supervised_trainer_scaler_not_amp():
465465
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
466466

0 commit comments

Comments
 (0)