@@ -170,7 +170,7 @@ def _():
 @pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
 def test_create_supervised_training_scalar_assignment():
     with mock.patch("ignite.engine._check_arg") as check_arg_mock:
-        check_arg_mock.return_value = None, torch.amp.GradScaler('cuda', enabled=False)
+        check_arg_mock.return_value = None, torch.amp.GradScaler(enabled=False)
         trainer, _ = _default_create_supervised_trainer(model_device="cpu", trainer_device="cpu", scaler=True)
         assert hasattr(trainer.state, "scaler")
         assert isinstance(trainer.state.scaler, torch.amp.GradScaler)
@@ -462,7 +462,7 @@ def test_create_supervised_trainer_amp_error(mock_torch_cuda_amp_module):
 
 @pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
 def test_create_supervised_trainer_scaler_not_amp():
-    scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())
+    scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
 
     with pytest.raises(ValueError, match=f"scaler argument is {scaler}, but amp_mode is None."):
         _test_create_supervised_trainer(amp_mode=None, scaler=scaler)
@@ -540,7 +540,7 @@ def test_create_supervised_trainer_on_cuda_amp_scaler():
     _test_create_mocked_supervised_trainer(
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True
     )
-    scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())
+    scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
     _test_create_supervised_trainer(
         gradient_accumulation_steps=1,
         model_device=model_device,
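
Why the explicit 'cuda' argument can be dropped: torch.amp.GradScaler's device parameter defaults to "cuda" (in the PyTorch releases that provide torch.amp.GradScaler, roughly 2.3 onward), so the two spellings exercised by this diff construct the same scaler. A minimal sketch of the equivalence, assuming such a PyTorch build:

# Sketch, assuming PyTorch >= 2.3 so that torch.amp.GradScaler exists;
# its device parameter defaults to "cuda".
import torch

old_style = torch.amp.GradScaler("cuda", enabled=False)  # spelling removed by this diff
new_style = torch.amp.GradScaler(enabled=False)          # spelling kept by this diff

# With enabled=False the scaler is inert and never touches a CUDA device,
# which is why the CPU-only test above can construct one safely.
assert not old_style.is_enabled()
assert not new_style.is_enabled()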