Skip to content

Commit 3ccf74c

Browse files
jerryxyjOrbax Authors
authored andcommitted
Restrict bfloat16 conversion to only float32 types.
PiperOrigin-RevId: 868479414
1 parent 16c42cb commit 3ccf74c

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

export/orbax/export/utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def to_bfloat16(x: Any) -> Any:
544544

545545
def _to_bfloat16_leaf(x: Any) -> Any:
546546
if isinstance(x, jax.ShapeDtypeStruct):
547-
if jnp.issubdtype(x.dtype, jnp.floating):
547+
if jnp.issubdtype(x.dtype, jnp.float32):
548548
return jax.ShapeDtypeStruct(
549549
x.shape,
550550
jnp.bfloat16,
@@ -555,12 +555,8 @@ def _to_bfloat16_leaf(x: Any) -> Any:
555555
if hasattr(x, 'dtype'):
556556
if x.dtype == tf.string:
557557
return x
558-
if jnp.issubdtype(x.dtype, jnp.floating):
558+
if jnp.issubdtype(x.dtype, jnp.float32):
559559
return x.astype(jnp.bfloat16)
560-
561-
if isinstance(x, float):
562-
return jnp.bfloat16(x)
563-
564560
return x
565561

566562
flattened_x, treedef = jax.tree_util.tree_flatten(x)

export/orbax/export/utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_to_bfloat16(self):
306306
self.assertEqual(y['a'].dtype, jnp.bfloat16)
307307
self.assertEqual(y['b'].dtype, tf.string)
308308
self.assertAllEqual(y['b'], x['b'])
309-
self.assertEqual(y['c'].dtype, jnp.bfloat16)
309+
self.assertIsInstance(y['c'], float)
310310
self.assertEqual(y['d'].dtype, jnp.bfloat16)
311311

312312
if __name__ == '__main__':

0 commit comments

Comments
 (0)