2 files changed: +5 −5 lines changed

@@ -185,7 +185,7 @@ def supervised_training_step_amp(
     """
 
     try:
-        from torch.amp import autocast, GradScaler
+        from torch.amp import autocast
     except ImportError:
         raise ImportError("Please install torch>=1.12.0 to use amp_mode='amp'.")
 
@@ -412,7 +412,7 @@ def _check_arg(
        try:
            from torch.amp import GradScaler
        except ImportError:
-            raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
+            raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
        scaler = GradScaler(enabled=True)
 
    if on_tpu:
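
For context: the try/except ImportError guard above is how the minimum torch version is enforced at runtime. torch.amp.autocast needs torch>=1.12.0, while this commit raises the floor for the device-agnostic torch.amp.GradScaler to 2.3.1, which is why the two error messages cite different versions. A minimal standalone sketch of an equivalent explicit gate, mirroring the packaging.Version checks used in the test hunks below (the layout is illustrative, not the library's code):

    import torch
    from packaging.version import Version

    # Fail early with the same message the library raises when torch is too old.
    if Version(torch.__version__) < Version("2.3.1"):
        raise ImportError("Please install torch>=2.3.1 to use scaler argument.")

    from torch.amp import GradScaler

    scaler = GradScaler(enabled=True)  # device-agnostic scaler, torch>=2.3.1
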
@@ -167,7 +167,7 @@ def _():
     trainer.run(data)
 
 
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("2.3.1"), reason="Skip if < 2.3.1")
 def test_create_supervised_training_scalar_assignment():
     with mock.patch("ignite.engine._check_arg") as check_arg_mock:
         check_arg_mock.return_value = None, torch.amp.GradScaler(enabled=False)
@@ -456,11 +456,11 @@ def test_create_supervised_trainer_amp_error(mock_torch_cuda_amp_module):
     _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu", amp_mode="amp")
     with pytest.raises(ImportError, match="Please install torch>=1.12.0 to use amp_mode='amp'."):
         _test_create_supervised_trainer(amp_mode="amp")
-    with pytest.raises(ImportError, match="Please install torch>=1.6.0 to use scaler argument."):
+    with pytest.raises(ImportError, match="Please install torch>=2.3.1 to use scaler argument."):
         _test_create_supervised_trainer(amp_mode="amp", scaler=True)
 
 
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("2.3.1"), reason="Skip if < 2.3.1")
 def test_create_supervised_trainer_scaler_not_amp():
     scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
 
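
For context on what the gated scaler does: a minimal sketch of a mixed-precision training step combining torch.amp.autocast with torch.amp.GradScaler (the model, optimizer, and data here are illustrative placeholders, not ignite's code):

    import torch

    # Illustrative model and optimizer; any nn.Module works the same way.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    model.to(device)

    # Device-agnostic GradScaler (torch>=2.3.1); a no-op when enabled=False.
    scaler = torch.amp.GradScaler(enabled=use_cuda)

    x = torch.randn(8, 4, device=device)
    y = torch.randn(8, 2, device=device)

    optimizer.zero_grad()
    # autocast (torch>=1.12.0) runs the forward pass in reduced precision.
    with torch.amp.autocast(device_type=device, enabled=use_cuda):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scale loss so fp16 gradients don't underflow
    scaler.step(optimizer)         # unscales gradients, then optimizer.step()
    scaler.update()                # adapt the scale factor for the next iteration

Disabling the scaler when no GPU is present, as the test above does with enabled=torch.cuda.is_available(), lets the same code path run safely on CPU-only machines.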