Handling of truncated trajectories in AlphaZero training example #1306

@shehper

Description

Problem Description

In examples/alphazero/train.py, we compute value_mask as follows:

value_mask = jnp.cumsum(data.terminated[::-1, :], axis=0)[::-1, :] >= 1
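
For clarity, this reversed cumulative sum keeps a sample at step t only if some step at or after t within the collected rollout is terminated, i.e. only if the episode actually finishes inside the rollout. A minimal, self-contained sketch with hypothetical toy data:

```python
import jax.numpy as jnp

# One rollout of 6 steps with batch size 1; the episode terminates at step 2,
# and the episode that starts at step 3 is cut off by the end of the rollout.
terminated = jnp.array([[0], [0], [1], [0], [0], [0]], dtype=jnp.float32)

value_mask = jnp.cumsum(terminated[::-1, :], axis=0)[::-1, :] >= 1
print(value_mask[:, 0])  # [ True  True  True False False False]
# Steps 0-2 belong to a finished episode and keep their value targets;
# steps 3-5 belong to a truncated episode and are masked.
```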

The purpose is to avoid updating the critic network on incomplete trajectories, as is evident from the masking of the value loss:

value_loss = jnp.mean(value_loss * samples.mask) # mask if the episode is truncated

Now, the critic and actor networks share a torso of residual blocks, as defined in network.py, and while we mask the value loss, we don't mask the policy loss for samples from incomplete trajectories:

policy_loss = optax.softmax_cross_entropy(logits, samples.policy_tgt)

Therefore, samples from incomplete trajectories still inadvertently influence both the policy head and, through the shared torso, the value head. This seems to run counter to the intended effect of defining value_mask.
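
To make this concrete, here is a schematic sketch (not the actual network.py; the parameter names are hypothetical and a single dense layer stands in for the residual tower) showing that the gradient of the unmasked policy loss with respect to the shared parameters is nonzero even when the value loss is masked out entirely:

```python
import jax
import jax.numpy as jnp
import optax

def shared_torso(params, x):
    # Stand-in for the shared residual tower.
    return jnp.tanh(x @ params["torso"])

def loss_fn(params, x, policy_tgt, value_tgt, value_mask):
    h = shared_torso(params, x)
    logits = h @ params["policy_head"]
    value = (h @ params["value_head"]).squeeze(-1)
    policy_loss = jnp.mean(optax.softmax_cross_entropy(logits, policy_tgt))  # unmasked
    value_loss = jnp.mean((value - value_tgt) ** 2 * value_mask)             # masked
    return policy_loss + value_loss

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "torso": 0.1 * jax.random.normal(k1, (4, 8)),
    "policy_head": 0.1 * jax.random.normal(k2, (8, 3)),
    "value_head": 0.1 * jax.random.normal(k3, (8, 1)),
}
x = jnp.ones((2, 4))                          # two samples, both truncated
policy_tgt = jnp.array([[1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0]])
value_tgt = jnp.zeros((2,))
value_mask = jnp.zeros((2,))                  # value loss contributes nothing

grads = jax.grad(loss_fn)(params, x, policy_tgt, value_tgt, value_mask)
# The torso gradient is nonzero: the truncated samples still move the shared
# parameters (and hence future value predictions) through the policy loss.
print(bool(jnp.abs(grads["torso"]).sum() > 0))  # True
```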


Possible Solutions

  1. Mask the samples from truncated trajectories out of the policy loss computation as well (a sketch follows below).
  2. Bootstrap the value target for truncated trajectories.

I am not sure which of these, or some other solution, is used in the original AlphaZero papers.
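
For what it's worth, option 1 could look something like the sketch below (a hypothetical helper, not a tested patch to train.py), reusing the same mask that the value loss already uses:

```python
import jax.numpy as jnp
import optax

def masked_policy_loss(logits, policy_tgt, mask):
    """Cross-entropy averaged over non-truncated samples only.

    `mask` is the same value_mask used for the value loss
    (1 = episode finished inside the rollout, 0 = truncated).
    """
    per_sample = optax.softmax_cross_entropy(logits, policy_tgt)
    return jnp.sum(per_sample * mask) / jnp.maximum(jnp.sum(mask), 1.0)
```

Dividing by the number of unmasked samples rather than the batch size keeps the loss scale comparable when many samples in a batch are truncated.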
