@@ -260,6 +260,8 @@ def __init__(
260260 self .double_bias = double_bias
261261
262262 def __call__ (self , inputs , prev_state ):
263+ inputs = jnp .asarray (inputs )
264+ prev_state = jnp .asarray (prev_state )
263265 input_to_hidden = hk .Linear (self .hidden_size )
264266 # TODO(b/173771088): Consider changing default to double_bias=False.
265267 hidden_to_hidden = hk .Linear (self .hidden_size , with_bias = self .double_bias )
@@ -329,6 +331,8 @@ def __call__(
329331 inputs : jax .Array ,
330332 prev_state : LSTMState ,
331333 ) -> tuple [jax .Array , LSTMState ]:
334+ inputs = jnp .asarray (inputs )
335+ prev_state = jax .tree .map (jnp .asarray , prev_state )
332336 if len (inputs .shape ) > 2 or not inputs .shape :
333337 raise ValueError ("LSTM input must be rank-1 or rank-2." )
334338 x_and_h = jnp .concatenate ([inputs , prev_state .hidden ], axis = - 1 )
@@ -410,6 +414,8 @@ def __call__(
410414 inputs ,
411415 state : LSTMState ,
412416 ) -> tuple [jax .Array , LSTMState ]:
417+ inputs = jnp .asarray (inputs )
418+ state = jax .tree .map (jnp .asarray , state )
413419 input_to_hidden = hk .ConvND (
414420 num_spatial_dims = self .num_spatial_dims ,
415421 output_channels = 4 * self .output_channels ,
@@ -559,6 +565,8 @@ def __init__(
559565 self .b_init = b_init or jnp .zeros
560566
561567 def __call__ (self , inputs , state ):
568+ inputs = jnp .asarray (inputs )
569+ state = jnp .asarray (state )
562570 if inputs .ndim not in (1 , 2 ):
563571 raise ValueError ("GRU input must be rank-1 or rank-2." )
564572
@@ -650,6 +658,9 @@ def __call__(self, inputs, state):
650658 Tuple of the wrapped core's ``output, next_state``.
651659 """
652660 inputs , should_reset = inputs
661+ inputs = jax .tree .map (jnp .asarray , inputs )
662+ should_reset = jax .tree .map (jnp .asarray , should_reset )
663+ state = jax .tree .map (jnp .asarray , state )
653664 if jax .tree_util .treedef_is_leaf (jax .tree .structure (should_reset )):
654665 # Equivalent to not tree.is_nested, but with support for Jax extensible
655666 # pytrees.
0 commit comments