@@ -34,9 +34,9 @@ def test_no_op_loss_scale(self):
3434 ("StaticLossScale(2)" , jmp .StaticLossScale , 2 ),
3535 ("StaticLossScale(3)" , jmp .StaticLossScale , 3 ),
3636 ("StaticLossScale(4)" , jmp .StaticLossScale , 4 ),
37- ("DynamicLossScale(2)" , jmp .DynamicLossScale , 2 ),
38- ("DynamicLossScale(3)" , jmp .DynamicLossScale , 3 ),
39- ("DynamicLossScale(4)" , jmp .DynamicLossScale , 4 ),
37+ ("DynamicLossScale(2)" , jmp .DynamicLossScale , 2. ),
38+ ("DynamicLossScale(3)" , jmp .DynamicLossScale , 3. ),
39+ ("DynamicLossScale(4)" , jmp .DynamicLossScale , 4. ),
4040 )
4141 def test_static_loss_scale (self , cls , scale ):
4242 loss_scale = cls (scale )
@@ -98,7 +98,7 @@ def test_dynamic_loss_scale_adjust_reduce_on_non_finite(self, period, factor):
9898 self .assertEqual (loss_scale .period , period )
9999 self .assertEqual (loss_scale .factor , factor )
100100
101- @parameterized .parameters ((20 , 2 , .3125 ), (30 , 3 , .37 ), (5 , 2 , 0 ))
101+ @parameterized .parameters ((20 , 2 , .3125 ), (30 , 3 , .37 ), (5. , 2. , 0. ))
102102 def test_dynamic_loss_scale_explicit_min_loss_scale (self , period , factor ,
103103 min_loss_scale ):
104104 grads_finite = jnp .bool_ (False )
@@ -120,6 +120,17 @@ def test_dynamic_loss_scale_explicit_min_loss_scale(self, period, factor,
120120 def test_dynamic_loss_scale_adjust_requires_scalar_input (self ):
121121 pass
122122
123+ def test_dynamic_loss_scale_raises_type_error_on_int_loss_scale (self ):
124+ expected_message = "Expected floating type for loss_scale"
125+ with self .assertWarnsRegex (Warning , expected_message ):
126+ jmp .DynamicLossScale (jnp .asarray (1 , dtype = jnp .int32 ))
127+
128+ def test_dynamic_loss_scale_raises_type_error_on_int_min_loss_scale (self ):
129+ expected_message = "Expected floating type for min_loss_scale"
130+ with self .assertWarnsRegex (Warning , expected_message ):
131+ jmp .DynamicLossScale (jnp .asarray (1 , dtype = jnp .float32 ),
132+ min_loss_scale = jnp .asarray (1 , dtype = jnp .int32 ))
133+
123134 @parameterized .parameters (jnp .inf , jnp .nan )
124135 def test_all_finite (self , non_finite ):
125136 self .assertTrue (jmp .all_finite (None ))
0 commit comments