
FIX: penalty for episode length termination #453

Open

TheWill-Of-D wants to merge 1 commit into kscalelabs:master from TheWill-Of-D:master

Conversation

@TheWill-Of-D

The episode length termination penalized the policy for reaching the episode's end. The policy may learn to avoid it, or may avoid learning gaits that are stable long term.

episode length termination penalized policy for reaching the episode termination
@CLAassistant

CLAassistant commented Jul 5, 2025

CLA assistant check
All committers have signed the CLA.

@b-vm
Contributor

b-vm commented Jul 5, 2025

Have you tested this change?

Whether an episode length termination penalizes or not can be controlled by setting pos_termination in EpisodeLengthTermination

ksim/ksim/terminations.py

Lines 187 to 203 in 88c8c2d

class EpisodeLengthTermination(Termination):
    """Terminates the episode if the robot has been alive for too long.

    This defaults to a positive termination.
    """

    max_length_sec: float = attrs.field(validator=attrs.validators.gt(0.0))
    disable_at_curriculum_level: int = attrs.field(default=None)
    pos_termination: bool = attrs.field(default=True)

    def __call__(self, state: PhysicsData, curriculum_level: Array) -> Array:
        termination_value = 1 if self.pos_termination else -1
        long_episodes = jnp.where(state.time > self.max_length_sec, termination_value, 0)
        if self.disable_at_curriculum_level is not None:
            return jnp.where(curriculum_level < self.disable_at_curriculum_level, 0, long_episodes)
        return long_episodes
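
For example, something along these lines should make the timeout count as a failure instead (a sketch only; the argument values are purely illustrative):

# Illustrative values only; pos_termination=False makes the timeout return -1,
# so it is treated like a failure termination.
termination = EpisodeLengthTermination(
    max_length_sec=20.0,
    pos_termination=False,
)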

@TheWill-Of-D
Author

TheWill-Of-D commented Jul 5, 2025

I ran pytest and am currently training with the modification. The chaotic movements are gone; so far it's okay. I can update you after the training run.

You're right that EpisodeLengthTermination has a pos_termination flag to return 1 on a timeout, but this sets the successes_t flag, not the done_t flag. done_t is used in ppo.py for the advantage calculation to create the value mask: in decay_gamma * values_shifted_t * mask_t, mask_t is incorrectly made zero for an episode-length termination (this shouldn't be done for time-limit terminations in time-unlimited tasks like the ones here).

The equation deltas_t = bootstrapped_rewards_t + decay_gamma * values_shifted_t * mask_t - values_t is calculated in a non-standard way. For time-unlimited tasks you do not zero out values_shifted_t; the agent will interpret that as a punishment. This alteration doesn't affect the value of bootstrapped_rewards_t at all.

Check section 3 in this paper.
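
To make the distinction concrete, here is a minimal JAX sketch of the standard handling I mean (a hypothetical function, not ksim's actual ppo.py; the names rewards_t, values_t, values_shifted_t, done_t, successes_t, and decay_gamma are just carried over from this thread): only failure terminations zero the bootstrap term, while a time-limit termination keeps it.

import jax.numpy as jnp

# Sketch only, not ksim's ppo.py: failure terminations zero the bootstrap term;
# a time-limit termination (done & success) keeps it, so reaching the episode's
# end is not treated as a punishment.
def td_deltas(rewards_t, values_t, values_shifted_t, done_t, successes_t, decay_gamma=0.99):
    failures_t = jnp.logical_and(done_t, jnp.logical_not(successes_t))
    mask_t = 1.0 - failures_t.astype(jnp.float32)
    return rewards_t + decay_gamma * values_shifted_t * mask_t - values_t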

@b-vm
Contributor

b-vm commented Jul 5, 2025

Here is where done is computed:

ksim/ksim/task/rl.py

Lines 937 to 957 in 88c8c2d

terminated = jax.tree.reduce(jnp.logical_or, [t != 0 for t in terminations.values()])
success = jax.tree.reduce(jnp.logical_and, [t != -1 for t in terminations.values()]) & terminated

# Combines all the relevant data into a single object. Lives up here to
# avoid accidentally incorporating information it shouldn't access to.
transition = Trajectory(
    qpos=jnp.array(next_physics_state.data.qpos),
    qvel=jnp.array(next_physics_state.data.qvel),
    xpos=jnp.array(next_physics_state.data.xpos),
    xquat=jnp.array(next_physics_state.data.xquat),
    ctrl=jnp.array(next_physics_state.data.ctrl),
    obs=observations,
    command=env_states.commands,
    event_state=next_physics_state.event_states,
    action=action.action,
    done=terminated,
    success=success,
    timestep=next_physics_state.data.time,
    termination_components=terminations,
    aux_outputs=action.aux_outputs,
)

Seems like done is True if any termination function produces a nonzero value, either 1 (success) or -1 (fail).

You're right that terminating episodes in the absence of a failure will discourage learning. That's why we bootstrap from the value function and add those values to the final reward, if and only if this happens. To the model it looks as if the termination never happened. See this PR #410
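
Roughly, the idea is something like this sketch (hypothetical names carried over from this thread, not the actual code from #410):

import jax.numpy as jnp

# Sketch of the success bootstrapping described above, not the real implementation:
# at a timeout (done & success) step, fold the discounted next-state value into the
# reward, so zeroing the value term at that step no longer discards anything.
def bootstrap_timeout_rewards(rewards_t, values_shifted_t, done_t, successes_t, decay_gamma=0.99):
    timeouts_t = jnp.logical_and(done_t, successes_t).astype(jnp.float32)
    return rewards_t + timeouts_t * decay_gamma * values_shifted_t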

That said, it would be interesting to see how your changes affect the critic loss and the total reward in an A/B test. Interesting paper, btw.

@b-vm
Contributor

b-vm commented Jul 5, 2025

EDIT: We already do this, so this can be disregarded.

This line in the paper you mentioned got me thinking: we should try bootstrapping from the value function for all final rewards in every rollout buffer trajectory, in addition to the success bootstrapping.

"We argue that this insight should be incorporated by bootstrapping from the value of the state at the end of each partial episode."

@TheWill-Of-D
Author

I found the same issue in SB3:
DLR-RM/stable-baselines3#633

gymnasium documentation
https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/?hl=en-GB

https://arxiv.org/pdf/1712.00378

Merging only the values across episodes doesn't negatively affect training, since the properties of the environment itself haven't changed much. This is the standard method for non-finite, long-horizon tasks (as shown in the links above).
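
For illustration, this is the terminated/truncated split that the Gymnasium tutorial above describes (a minimal sketch; the environment is just an example):

import gymnasium as gym

# terminated: a real MDP termination -> do not bootstrap from the next state's value.
# truncated:  the time limit was hit -> keep bootstrapping; the next state still has value.
env = gym.make("Pendulum-v1")
obs, info = env.reset()
for _ in range(200):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()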
