Skip to content

Commit d886c80

Browse files
Merge pull request #162 from yardenas:update-pick-cartesian
PiperOrigin-RevId: 794781579 Change-Id: I79161cd7e0cf92dffe7a93a219f86d1e1c3c471b
2 parents 9811486 + 1e0fe97 commit d886c80

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
216216
f'reward/{k}': 0.0
217217
for k in self._config.reward_config.reward_scales.keys()
218218
},
219+
'reward/success': jp.array(0.0),
220+
'reward/lifted': jp.array(0.0),
219221
}
220222

221223
info = {
@@ -335,9 +337,8 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
335337

336338
# Sparse rewards
337339
box_pos = data.xpos[self._obj_body]
338-
total_reward += (
339-
box_pos[2] > 0.05
340-
) * self._config.reward_config.lifted_reward
340+
lifted = (box_pos[2] > 0.05) * self._config.reward_config.lifted_reward
341+
total_reward += lifted
341342
success = self._get_success(data, state.info)
342343
total_reward += success * self._config.reward_config.success_reward
343344

@@ -354,6 +355,10 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
354355
out_of_bounds |= box_pos[2] < 0.0
355356
state.metrics.update(out_of_bounds=out_of_bounds.astype(float))
356357
state.metrics.update({f'reward/{k}': v for k, v in raw_rewards.items()})
358+
state.metrics.update({
359+
'reward/lifted': lifted.astype(float),
360+
'reward/success': success.astype(float),
361+
})
357362

358363
done = (
359364
out_of_bounds

0 commit comments

Comments
 (0)