Skip to content

Commit c30f9b3

Browse files
hawkinsp authored and copybara-github committed
[JAX] Make haiku tests pass with jax_explicit_x64_dtypes=ERROR
PiperOrigin-RevId: 914349194
1 parent f69e304 commit c30f9b3

8 files changed

Lines changed: 22 additions & 10 deletions

File tree

haiku/_src/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def __call__(
166166
if not inputs.shape:
167167
raise ValueError("Input must not be scalar.")
168168

169+
inputs = jnp.asarray(inputs)
169170
input_size = self.input_size = inputs.shape[-1]
170171
output_size = self.output_size
171172
dtype = inputs.dtype

haiku/_src/conv.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def __call__(
177177
unbatched, or an array of shape ``[N, spatial_dims, output_channels]``
178178
and rank-N+2 if batched.
179179
"""
180+
inputs = jnp.asarray(inputs)
180181
unbatched_rank = self.num_spatial_dims + 1
181182
allowed_ranks = [unbatched_rank, unbatched_rank + 1]
182183
if inputs.ndim not in allowed_ranks:
@@ -574,6 +575,7 @@ def __call__(
574575
unbatched, or an array of shape ``[N, spatial_dims, output_channels]``
575576
and rank-N+2 if batched.
576577
"""
578+
inputs = jnp.asarray(inputs)
577579
unbatched_rank = self.num_spatial_dims + 1
578580
allowed_ranks = [unbatched_rank, unbatched_rank + 1]
579581
if inputs.ndim not in allowed_ranks:

haiku/_src/layer_norm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __call__(
149149
Returns:
150150
The array, normalized.
151151
"""
152+
inputs = jnp.asarray(inputs)
152153
if self.create_scale and scale is not None:
153154
raise ValueError(
154155
"Cannot pass `scale` at call time if `create_scale=True`.")

haiku/_src/layer_norm_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def test_multiple_param_axis(self, param_axis, param_shape):
254254
ln(x)
255255
self.assertEqual(ln.params_dict()["layer_norm/scale"].shape, param_shape)
256256
self.assertEqual(ln.params_dict()["layer_norm/offset"].shape, param_shape)
257-
257+
258258
@staticmethod
259259
def _ln_f(axis, param_axis, create_scale=True, create_offset=True):
260260
def f(x):
@@ -269,7 +269,7 @@ def f(x):
269269
def test_layernorm_dtype_propagation_float64(self):
270270
"""Input dtype -> output dtype; params dtype match input."""
271271
fwd = LayerNormTest._ln_f(axis=-1, param_axis=-1)
272-
x = jnp.arange(12, dtype=jnp.float64).reshape(3, 4)
272+
x = jnp.arange(12, dtype=float).reshape(3, 4)
273273
params = fwd.init(jax.random.PRNGKey(0), x)
274274
y = fwd.apply(params, None, x)
275275
self.assertEqual(y.dtype, x.dtype)

haiku/_src/nets/vqvae_test.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,11 @@ def testNoneBatch(self, constructor, kwargs):
113113
vqvae_module(inputs, is_training=False)
114114

115115
@parameterized.parameters({'use_jit': True, 'dtype': jnp.float32},
116-
{'use_jit': True, 'dtype': jnp.float64},
117-
{'use_jit': False, 'dtype': jnp.float32},
118-
{'use_jit': False, 'dtype': jnp.float64})
116+
{'use_jit': False, 'dtype': jnp.float32})
119117
@test_utils.transform_and_run
120118
def testEmaUpdating(self, use_jit, dtype):
121-
if jax.local_devices()[0].platform == 'tpu' and dtype == jnp.float64:
122-
self.skipTest('F64 not supported by TPU')
123-
124119
embedding_dim = 6
125-
np_dtype = np.float64 if dtype is jnp.float64 else np.float32
126-
decay = np.array(0.1, dtype=np_dtype)
120+
decay = np.array(0.1, dtype=np.float32)
127121
vqvae_module = vqvae.VectorQuantizerEMA(
128122
embedding_dim=embedding_dim,
129123
num_embeddings=7,

haiku/_src/pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def max_pool(
9494
if padding not in ("SAME", "VALID"):
9595
raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
9696

97+
value = jnp.asarray(value)
9798
_warn_if_unsafe(window_shape, strides)
9899
window_shape = _infer_shape(value, window_shape, channel_axis)
99100
strides = _infer_shape(value, strides, channel_axis)
@@ -127,6 +128,7 @@ def avg_pool(
127128
if padding not in ("SAME", "VALID"):
128129
raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
129130

131+
value = jnp.asarray(value)
130132
_warn_if_unsafe(window_shape, strides)
131133
window_shape = _infer_shape(value, window_shape, channel_axis)
132134
strides = _infer_shape(value, strides, channel_axis)

haiku/_src/recurrent.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ def __init__(
260260
self.double_bias = double_bias
261261

262262
def __call__(self, inputs, prev_state):
263+
inputs = jnp.asarray(inputs)
264+
prev_state = jnp.asarray(prev_state)
263265
input_to_hidden = hk.Linear(self.hidden_size)
264266
# TODO(b/173771088): Consider changing default to double_bias=False.
265267
hidden_to_hidden = hk.Linear(self.hidden_size, with_bias=self.double_bias)
@@ -329,6 +331,8 @@ def __call__(
329331
inputs: jax.Array,
330332
prev_state: LSTMState,
331333
) -> tuple[jax.Array, LSTMState]:
334+
inputs = jnp.asarray(inputs)
335+
prev_state = jax.tree.map(jnp.asarray, prev_state)
332336
if len(inputs.shape) > 2 or not inputs.shape:
333337
raise ValueError("LSTM input must be rank-1 or rank-2.")
334338
x_and_h = jnp.concatenate([inputs, prev_state.hidden], axis=-1)
@@ -410,6 +414,8 @@ def __call__(
410414
inputs,
411415
state: LSTMState,
412416
) -> tuple[jax.Array, LSTMState]:
417+
inputs = jnp.asarray(inputs)
418+
state = jax.tree.map(jnp.asarray, state)
413419
input_to_hidden = hk.ConvND(
414420
num_spatial_dims=self.num_spatial_dims,
415421
output_channels=4 * self.output_channels,
@@ -559,6 +565,8 @@ def __init__(
559565
self.b_init = b_init or jnp.zeros
560566

561567
def __call__(self, inputs, state):
568+
inputs = jnp.asarray(inputs)
569+
state = jnp.asarray(state)
562570
if inputs.ndim not in (1, 2):
563571
raise ValueError("GRU input must be rank-1 or rank-2.")
564572

@@ -650,6 +658,9 @@ def __call__(self, inputs, state):
650658
Tuple of the wrapped core's ``output, next_state``.
651659
"""
652660
inputs, should_reset = inputs
661+
inputs = jax.tree.map(jnp.asarray, inputs)
662+
should_reset = jax.tree.map(jnp.asarray, should_reset)
663+
state = jax.tree.map(jnp.asarray, state)
653664
if jax.tree_util.treedef_is_leaf(jax.tree.structure(should_reset)):
654665
# Equivalent to not tree.is_nested, but with support for Jax extensible
655666
# pytrees.

haiku/_src/rms_norm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def __call__(self, inputs: jax.Array):
106106
Returns:
107107
The normalized array, of the same shape as the inputs.
108108
"""
109+
inputs = jnp.asarray(inputs)
109110
axis = self.axis
110111
if isinstance(axis, slice):
111112
axis = tuple(range(inputs.ndim)[axis])

0 commit comments

Comments (0)