File tree Expand file tree Collapse file tree 2 files changed +3
-7
lines changed
Expand file tree Collapse file tree 2 files changed +3
-7
lines changed Original file line number Diff line number Diff line change @@ -544,7 +544,7 @@ def to_bfloat16(x: Any) -> Any:
544544
545545 def _to_bfloat16_leaf (x : Any ) -> Any :
546546 if isinstance (x , jax .ShapeDtypeStruct ):
547- if jnp .issubdtype (x .dtype , jnp .floating ):
547+ if jnp .issubdtype (x .dtype , jnp .float32 ):
548548 return jax .ShapeDtypeStruct (
549549 x .shape ,
550550 jnp .bfloat16 ,
@@ -555,12 +555,8 @@ def _to_bfloat16_leaf(x: Any) -> Any:
555555 if hasattr (x , 'dtype' ):
556556 if x .dtype == tf .string :
557557 return x
558- if jnp .issubdtype (x .dtype , jnp .floating ):
558+ if jnp .issubdtype (x .dtype , jnp .float32 ):
559559 return x .astype (jnp .bfloat16 )
560-
561- if isinstance (x , float ):
562- return jnp .bfloat16 (x )
563-
564560 return x
565561
566562 flattened_x , treedef = jax .tree_util .tree_flatten (x )
Original file line number Diff line number Diff line change @@ -306,7 +306,7 @@ def test_to_bfloat16(self):
306306 self .assertEqual (y ['a' ].dtype , jnp .bfloat16 )
307307 self .assertEqual (y ['b' ].dtype , tf .string )
308308 self .assertAllEqual (y ['b' ], x ['b' ])
309- self .assertEqual (y ['c' ]. dtype , jnp . bfloat16 )
309+ self .assertIsInstance (y ['c' ], float )
310310 self .assertEqual (y ['d' ].dtype , jnp .bfloat16 )
311311
312312if __name__ == '__main__' :
You can’t perform that action at this time.
0 commit comments