Description
Hi,
Thank you for your great work. It is really cool to open source such an amazing code base!
TL;DR
@yogesh1q2w and I noticed that the last transitions of a trajectory are not handled properly: when a single terminal transition is given to the accumulator, several ReplayElements with a terminal flag are stored.
This is a problem because the additional terminal states do not correspond to states that can be observed in the environment. Since we use function approximation, training on these unobservable states can also affect the estimates for the real ones.
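For reference, here is a minimal sketch of what a single episode should produce, assuming `update_horizon=1` as in the reproduction below: N observations yield N - 1 transitions, and only the transition into the terminal observation carries `is_terminal=True`. The helper `expected_terminal_flags` is hypothetical and not part of Dopamine.

```python
def expected_terminal_flags(num_observations: int) -> list[bool]:
    """Returns the is_terminal flags expected for one episode.

    Hypothetical helper (not a Dopamine API), assuming update_horizon == 1:
    each pair of consecutive observations forms one transition, and only the
    transition into the final (terminal) observation is terminal.
    """
    num_transitions = num_observations - 1
    return [t == num_transitions - 1 for t in range(num_transitions)]


# The reproduction adds 8 observations (i = 0, ..., 7) with only i == 7 terminal.
flags = expected_terminal_flags(8)
print(len(flags), sum(flags))  # 7 transitions, 1 of them terminal
```

The buffer in the reproduction instead stores 10 elements, 4 of which are terminal.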
How to reproduce?
After forking the repository and running:
```shell
python3.11 -m venv env_cpu  # Python 3.11.5
source env_cpu/bin/activate
pip install --upgrade pip setuptools wheel
pip install -e .
```
I ran:
```python
import numpy as np
from dopamine.jax.replay_memory import accumulator, samplers, replay_buffer, elements

transition_accumulator = accumulator.TransitionAccumulator(stack_size=4, update_horizon=1, gamma=0.99)
sampling_distribution = samplers.UniformSamplingDistribution(seed=1)
rb = replay_buffer.ReplayBuffer(
    transition_accumulator=transition_accumulator,
    sampling_distribution=sampling_distribution,
    batch_size=1,
    max_capacity=50,
    compress=False,
)

# Add one episode of 8 observations; only the last one (i == 7) is terminal.
for i in range(8):
    rb.add(elements.TransitionElement(i * np.ones(1), i, i, i == 7, False))

print(rb._memory)
```
which prints:
```
OrderedDict([(0,
ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
(1,
ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
(2,
ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
(3,
ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
(4,
ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
(5,
ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
(6,
ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
(7,
ReplayElement(state=array([[0., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
(8,
ReplayElement(state=array([[0., 0., 5., 6.]]), action=6, reward=6.0, next_state=array([[0., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
(9,
ReplayElement(state=array([[0., 0., 0., 6.]]), action=6, reward=6.0, next_state=array([[0., 0., 6., 7.]]), is_terminal=True, episode_end=True))])
```
The last three ReplayElements (keys 7, 8, and 9) are incorrect: they should not have been added.
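To see how several terminal elements can come out of one terminal transition, here is a simplified, hypothetical model of the accumulator (not Dopamine's actual code). It assumes the accumulator keeps the last `stack_size + update_horizon` transitions in a bounded deque, checks the validity condition each time a transition is added, and, when the added transition is terminal, keeps emitting elements and popping from the left until the condition fails. Only the number of emitted elements is modeled; `old_is_valid` and `simulate_episode` are illustrative names.

```python
import collections


def old_is_valid(trajectory_len, is_terminal, update_horizon):
    # Validity condition from accumulator.py (lines 74 to 82 in bec5f4e).
    return trajectory_len > update_horizon or (trajectory_len > 1 and is_terminal)


def simulate_episode(num_observations, stack_size, update_horizon):
    """Counts (non-terminal, terminal) elements emitted for one episode.

    Hypothetical model: a deque bounded to stack_size + update_horizon
    entries, drained from the left once the terminal transition arrives.
    """
    trajectory = collections.deque(maxlen=stack_size + update_horizon)
    non_terminal, terminal = 0, 0
    for i in range(num_observations):
        is_terminal = i == num_observations - 1
        trajectory.append(is_terminal)  # only the terminal flag matters here
        if not is_terminal:
            if old_is_valid(len(trajectory), False, update_horizon):
                non_terminal += 1
            continue
        # Terminal transition: emit and pop until the condition fails.
        while trajectory and old_is_valid(len(trajectory), trajectory[-1], update_horizon):
            terminal += 1
            trajectory.popleft()
        trajectory.clear()
    return non_terminal, terminal


print(simulate_episode(num_observations=8, stack_size=4, update_horizon=1))  # (6, 4)
```

The result (6, 4) matches the output above: keys 0 to 5 are non-terminal and keys 6 to 9 are terminal. In this model, the flush is accepted at lengths 5, 4, 3, and 2, because with `update_horizon=1` the first clause `trajectory_len > update_horizon` stays true while the deque shrinks.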
How to fix the bug?
Replacing the following lines of dopamine/dopamine/jax/replay_memory/accumulator.py (lines 74 to 82 in bec5f4e)
```python
# Check if we have a valid transition, i.e. we either
# 1) have accumulated more transitions than the update horizon
# 2) have a trajectory shorter than the update horizon, but the
# last element is terminal
if not (
    trajectory_len > self._update_horizon
    or (trajectory_len > 1 and last_transition.is_terminal)
):
    return None
```
with
```python
# Check if we have a valid transition, i.e. we either
# 1) have accumulated more transitions than the update horizon and the
# last element is not terminal
# 2) have a trajectory shorter than the update horizon, but the
# last element is terminal and we have enough frames to stack
if not (
    (trajectory_len > self._update_horizon and not last_transition.is_terminal)
    or (trajectory_len > self._stack_size and last_transition.is_terminal)
):
    return None
```
solves the issue. Indeed, running the same code again prints:
```
OrderedDict([(0,
ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
(1,
ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
(2,
ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
(3,
ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
(4,
ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
(5,
ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
(6,
ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True))])
```
The extra terminal ReplayElements have been filtered out 🎉
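The effect of the change can also be checked on the validity condition alone. The sketch below is again a hypothetical model rather than Dopamine's code: it assumes the trajectory is full (`stack_size + update_horizon` entries) when the terminal transition arrives, which holds for episodes at least that long, as in the example, and that the accumulator emits one element per length at which the condition holds, popping one entry from the left each time. The function names are illustrative.

```python
def old_is_valid(trajectory_len, is_terminal, update_horizon, stack_size):
    # Condition before the change (stack_size is unused).
    return trajectory_len > update_horizon or (trajectory_len > 1 and is_terminal)


def new_is_valid(trajectory_len, is_terminal, update_horizon, stack_size):
    # Condition proposed in this issue.
    return (trajectory_len > update_horizon and not is_terminal) or (
        trajectory_len > stack_size and is_terminal
    )


def flush_lengths(is_valid, stack_size, update_horizon):
    """Trajectory lengths at which a terminal element is emitted during a flush."""
    lengths = []
    trajectory_len = stack_size + update_horizon  # assumed full at the terminal step
    while trajectory_len > 0 and is_valid(trajectory_len, True, update_horizon, stack_size):
        lengths.append(trajectory_len)
        trajectory_len -= 1  # one popleft per emitted element
    return lengths


for horizon in (1, 3):
    print(
        horizon,
        flush_lengths(old_is_valid, stack_size=4, update_horizon=horizon),
        flush_lengths(new_is_valid, stack_size=4, update_horizon=horizon),
    )
# 1 [5, 4, 3, 2] [5]
# 3 [7, 6, 5, 4, 3, 2] [7, 6, 5]
```

For `update_horizon=1`, the new condition emits a single terminal element, matching the output above. For `update_horizon=3`, a configuration the issue does not exercise, it emits three elements in this model, one for each of the last `update_horizon` start states whose n-step window reaches the terminal observation, which is what truncated n-step returns would call for under the model's assumptions.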