@@ -357,9 +357,9 @@ def test_layer_norm_grad(self):
357357
358358 gx1 , gw1 , gb1 = mx .grad (f1 , argnums = (0 , 1 , 2 ))(x , w , b , y )
359359 gx2 , gw2 , gb2 = mx .grad (f2 , argnums = (0 , 1 , 2 ))(x , w , b , y )
360- self .assertLess (mx .abs (gx1 - gx2 ).max (), 1e -5 )
361- self .assertLess (mx .abs (gw1 - gw2 ).max () / mx .abs (gw1 ).mean (), 1e -5 )
362- self .assertLess (mx .abs (gb1 - gb2 ).max () / mx .abs (gb1 ).mean (), 1e -5 )
360+ self .assertLess (mx .abs (gx1 - gx2 ).max (), 5e -5 )
361+ self .assertLess (mx .abs (gw1 - gw2 ).max () / mx .abs (gw1 ).mean (), 5e -5 )
362+ self .assertLess (mx .abs (gb1 - gb2 ).max () / mx .abs (gb1 ).mean (), 5e -5 )
363363
364364 def gf (f ):
365365 def inner (x , w , b , y ):
@@ -370,8 +370,8 @@ def inner(x, w, b, y):
370370
371371 gx1 , gw1 , gb1 = mx .grad (gf (f1 ), argnums = (0 , 1 , 2 ))(x , w , b , y )
372372 gx2 , gw2 , gb2 = mx .grad (gf (f2 ), argnums = (0 , 1 , 2 ))(x , w , b , y )
373- self .assertLess (mx .abs (gx1 - gx2 ).max () / mx .abs (gx1 ).mean (), 1e -5 )
374- self .assertLess (mx .abs (gw1 - gw2 ).max () / mx .abs (gw1 ).mean (), 1e -5 )
373+ self .assertLess (mx .abs (gx1 - gx2 ).max () / mx .abs (gx1 ).mean (), 5e -5 )
374+ self .assertLess (mx .abs (gw1 - gw2 ).max () / mx .abs (gw1 ).mean (), 5e -5 )
375375 self .assertLess (mx .abs (gb1 ).max (), 1e-9 )
376376 self .assertLess (mx .abs (gb2 ).max (), 1e-9 )
377377
0 commit comments