diff --git a/disentangled_rnns/library/rnn_utils.py b/disentangled_rnns/library/rnn_utils.py index 2be6621..c5aebf4 100644 --- a/disentangled_rnns/library/rnn_utils.py +++ b/disentangled_rnns/library/rnn_utils.py @@ -641,7 +641,9 @@ def unroll_network(xs): apply = jax.jit(model.apply) y_hats, states = apply(params, key, xs) - states = np.squeeze(np.array(states)) + if isinstance(states, tuple) or isinstance(states, list) and len(states) == 1: + states = np.array(states[0]) + # States should now be (n_timesteps, n_episodes, n_hidden) assert states.shape[0] == xs.shape[0], ( 'States and inputs should have the same number of timesteps.')