From 67aaea06872eac43ede44f5bfb28559b64b59199 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Thu, 22 Jan 2026 17:48:04 +0000 Subject: [PATCH] Simpler indexing in JAX DQN --- cleanrl/dqn_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleanrl/dqn_jax.py b/cleanrl/dqn_jax.py index 4bda81920..8e006634b 100644 --- a/cleanrl/dqn_jax.py +++ b/cleanrl/dqn_jax.py @@ -171,7 +171,7 @@ def update(q_state, observations, actions, next_observations, rewards, dones): def mse_loss(params): q_pred = q_network.apply(params, observations) # (batch_size, num_actions) - q_pred = q_pred[jnp.arange(q_pred.shape[0]), actions.squeeze()] # (batch_size,) + q_pred = q_pred[:, actions.squeeze()] # (batch_size,) return ((q_pred - next_q_value) ** 2).mean(), q_pred (loss_value, q_pred), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params)