From 65582d1f2cc714f3c89fdba536ac8c97cd41653a Mon Sep 17 00:00:00 2001 From: GDM Neurolab Date: Thu, 24 Apr 2025 20:36:39 -0700 Subject: [PATCH] NA PiperOrigin-RevId: 751249454 --- disentangled_rnns/library/rnn_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/disentangled_rnns/library/rnn_utils.py b/disentangled_rnns/library/rnn_utils.py index 2be6621..c5aebf4 100644 --- a/disentangled_rnns/library/rnn_utils.py +++ b/disentangled_rnns/library/rnn_utils.py @@ -641,7 +641,9 @@ def unroll_network(xs): apply = jax.jit(model.apply) y_hats, states = apply(params, key, xs) - states = np.squeeze(np.array(states)) + if isinstance(states, tuple) or isinstance(states, list) and len(states) == 1: + states = np.array(states[0]) + # States should now be (n_timesteps, n_episodes, n_hidden) assert states.shape[0] == xs.shape[0], ( 'States and inputs should have the same number of timesteps.')