We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 29d5875 commit 8b2e80eCopy full SHA for 8b2e80e
model/orbax/experimental/model/jd2obm/utils.py
@@ -25,13 +25,17 @@
25
26
def _obm_to_voxel_dtype(t):
27
if isinstance(t, obm.ShloDType):
28
+ if t == obm.ShloDType.str:
29
+ return np.dtype(np.str_)
30
return obm.shlo_dtype_to_np_dtype(t)
31
return t
32
33
34
def _voxel_to_obm_dtype(t) -> obm.ShloDType:
35
if not isinstance(t, np.dtype):
36
raise ValueError(f'Expected a numpy.dtype, got {t!r} of type {type(t)}')
37
+ if t == np.dtype(np.str_) or t == np.dtype(np.bytes_):
38
+ return obm.ShloDType.str
39
return obm.np_dtype_to_shlo_dtype(t)
40
41
0 commit comments