Problem Description
In examples/alphazero/train.py, we compute value_mask as follows:
pgx/examples/alphazero/train.py, line 179 (commit 87278d2):

`value_mask = jnp.cumsum(data.terminated[::-1, :], axis=0)[::-1, :] >= 1`
The purpose is to avoid updating the critic network on incomplete trajectories, as is evident from the masking of the value loss:
pgx/examples/alphazero/train.py, line 211 (commit 87278d2):

`value_loss = jnp.mean(value_loss * samples.mask)  # mask if the episode is truncated`
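As a toy illustration of what this mask computes (a minimal sketch with a hypothetical `terminated` array; only the expression itself is taken from train.py), the reverse-cumsum trick marks every sample that is followed by a termination within the collected batch, and leaves samples from the final, truncated episode unmasked:

```python
import jax.numpy as jnp

# Hypothetical terminated flags for a single environment over 6 steps
# (time-major, shape [T, B]): the first episode ends at step 2;
# steps 3-5 belong to an episode truncated when the rollout ends.
terminated = jnp.array([[0], [0], [1], [0], [0], [0]], dtype=jnp.float32)

# Same expression as in train.py: for each step, count terminations at or
# after that step; >= 1 means the episode eventually terminated.
value_mask = jnp.cumsum(terminated[::-1, :], axis=0)[::-1, :] >= 1
print(value_mask.ravel())  # [ True  True  True False False False]
```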
Now, the critic and actor networks share a torso of residual blocks (defined in network.py), and while we mask the value loss, we do not mask the policy loss for samples from incomplete trajectories:
pgx/examples/alphazero/train.py, line 207 (commit 87278d2):

`policy_loss = optax.softmax_cross_entropy(logits, samples.policy_tgt)`
Therefore, samples from incomplete trajectories still inadvertently influence the policy head directly and, through the shared torso, the value head as well. This seems to go against the intended effect of defining value_mask.
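A toy illustration of this effect (purely hypothetical parameters and shapes, not the actual network from network.py): even when a sample's value loss is masked to zero, the unmasked policy loss still produces a gradient with respect to the shared torso parameters for that sample.

```python
import jax
import jax.numpy as jnp

def loss(params, x, mask):
    # Shared "torso" feeding both a policy-like and a value-like term.
    torso = jnp.tanh(x * params["torso"])
    policy_loss = (torso * params["policy"]) ** 2         # not masked
    value_loss = ((torso * params["value"]) ** 2) * mask   # masked
    return jnp.sum(policy_loss + value_loss)

params = {"torso": 1.0, "policy": 1.0, "value": 1.0}
grads = jax.grad(loss)(params, jnp.array(0.5), jnp.array(0.0))
print(grads["torso"])  # non-zero: the masked sample still updates the torso
```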
Possible Solutions
- Mask out the effect of truncated trajectories in the policy loss as well (see the sketch below).
- Bootstrap the value target for truncated trajectories.
I am not sure which of these (or another solution) is used in the original AlphaZero papers.
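A minimal sketch of the first option, assuming the variables used in train.py's loss function (with `samples.mask` being the value_mask computed during self-play); this is only an illustration, not a tested patch:

```python
import jax.numpy as jnp
import optax

def masked_policy_loss(logits, policy_tgt, mask):
    # Per-sample cross entropy, as in train.py.
    per_sample = optax.softmax_cross_entropy(logits, policy_tgt)
    # Average only over samples from completed episodes; guard against an
    # all-masked batch to avoid division by zero.
    return jnp.sum(per_sample * mask) / jnp.maximum(jnp.sum(mask), 1.0)
```

The second option would instead keep these samples but replace their value target with a bootstrapped estimate of the last observed state's value, which would require an extra network evaluation at the end of self-play.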