Skip to content

Commit 8b2e80e

Browse files
PAenugulaOrbax Authors
authored andcommitted
No public description
PiperOrigin-RevId: 865066915
1 parent 29d5875 commit 8b2e80e

File tree

1 file changed

+4
-0
lines changed
  • model/orbax/experimental/model/jd2obm

1 file changed

+4
-0
lines changed

model/orbax/experimental/model/jd2obm/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,17 @@
2525

2626
def _obm_to_voxel_dtype(t):
2727
if isinstance(t, obm.ShloDType):
28+
if t == obm.ShloDType.str:
29+
return np.dtype(np.str_)
2830
return obm.shlo_dtype_to_np_dtype(t)
2931
return t
3032

3133

3234
def _voxel_to_obm_dtype(t) -> obm.ShloDType:
3335
if not isinstance(t, np.dtype):
3436
raise ValueError(f'Expected a numpy.dtype, got {t!r} of type {type(t)}')
37+
if t == np.dtype(np.str_) or t == np.dtype(np.bytes_):
38+
return obm.ShloDType.str
3539
return obm.np_dtype_to_shlo_dtype(t)
3640

3741

0 commit comments

Comments
 (0)