Bug in the Replay Buffer: end of episodes is not handled correctly #228

@theovincent

Hi,

Thank you for your great work. It is really cool to open source such an amazing code base!

TL;DR

@yogesh1q2w and I noticed that the last transitions of a trajectory are not handled properly: several ReplayElements with a terminal flag are stored when only one terminal transition is given to the accumulator.

This is a problem because the additional terminal states do not correspond to states that can be observed in the environment. Since we use function approximation, learning on these spurious states can also affect the value estimates of the states that do occur.
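To make "cannot be observed" concrete, the sketch below lists the stacked states an agent actually sees during one episode. It assumes 4-frame stacks that are zero-padded at the start of the episode (as the printed buffer suggests) and integer observations 0 to 7, as in the reproduction script; observed_stacks is a hypothetical helper, not part of Dopamine.

```python
from typing import List, Tuple


def observed_stacks(frames: List[int], stack_size: int) -> List[Tuple[int, ...]]:
    """Return the stacked state seen at each step of a single episode.

    Hypothetical helper: assumes the stack is left-padded with zeros until
    `stack_size` frames are available.
    """
    stacks = []
    for t in range(len(frames)):
        window = frames[max(0, t + 1 - stack_size): t + 1]
        stacks.append(tuple([0] * (stack_size - len(window)) + window))
    return stacks


seen = observed_stacks(list(range(8)), stack_size=4)
print(seen[-1])              # (4, 5, 6, 7): the real final state
print((0, 0, 5, 6) in seen)  # False: this stack never occurs in the episode
```

States such as [0, 4, 5, 6] or [0, 0, 5, 6] only arise by truncating the final stack, which is why storing them as transitions is misleading.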

How to reproduce?

After forking the repo and running

python3.11 -m venv env_cpu  # Python 3.11.5
source env_cpu/bin/activate
pip install --upgrade pip setuptools wheel
pip install -e .

I ran

import numpy as np
from dopamine.jax.replay_memory import accumulator, samplers, replay_buffer, elements

# Stack 4 frames and use 1-step returns.
transition_accumulator = accumulator.TransitionAccumulator(stack_size=4, update_horizon=1, gamma=0.99)
sampling_distribution = samplers.UniformSamplingDistribution(seed=1)
rb = replay_buffer.ReplayBuffer(
    transition_accumulator=transition_accumulator,
    sampling_distribution=sampling_distribution,
    batch_size=1,
    max_capacity=50,
    compress=False,
)

# Add one 8-step episode: observation, action and reward all equal i.
# Only the last transition (i == 7) is terminal; episode_end stays False.
for i in range(8):
    rb.add(elements.TransitionElement(i * np.ones(1), i, i, i == 7, False))

print(rb._memory)
OrderedDict([(0,
              ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
             (1,
              ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
             (2,
              ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
             (3,
              ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
             (4,
              ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
             (5,
              ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
             (6,
              ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (7,
              ReplayElement(state=array([[0., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (8,
              ReplayElement(state=array([[0., 0., 5., 6.]]), action=6, reward=6.0, next_state=array([[0., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (9,
              ReplayElement(state=array([[0., 0., 0., 6.]]), action=6, reward=6.0, next_state=array([[0., 0., 6., 7.]]), is_terminal=True, episode_end=True))])

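The last three ReplayElements (keys 7, 8 and 9) are incorrect: their states, such as [0, 0, 5, 6], are zero-padded truncations of the final frame stack and never occur in the environment. They should not have been added.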
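A quick way to catch this in a regression test is to count the stored elements flagged as terminal: for the single episode above it should be 1, whereas the printed buffer has 4. The sketch below uses a hypothetical helper count_terminal_elements and a stand-in namedtuple Flags in place of Dopamine's ReplayElement; with the real buffer one would pass rb._memory.values() (a private attribute, used here only for debugging, as in the script above).

```python
from collections import namedtuple
from typing import Iterable


def count_terminal_elements(elements: Iterable) -> int:
    """Count stored elements whose `is_terminal` flag is set.

    Works on any objects exposing an `is_terminal` attribute.
    """
    return sum(1 for e in elements if e.is_terminal)


# Stand-in for the buffer printed above: keys 6 to 9 are flagged terminal.
Flags = namedtuple("Flags", ["is_terminal", "episode_end"])
memory = {k: Flags(k >= 6, k >= 6) for k in range(10)}

n_terminal = count_terminal_elements(memory.values())
print(n_terminal)  # 4, although only one terminal transition was added
```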

How to fix the bug?

Replacing the following lines in TransitionAccumulator (module dopamine.jax.replay_memory.accumulator)

    # Check if we have a valid transition, i.e. we either
    #   1) have accumulated more transitions than the update horizon
    #   2) have a trajectory shorter than the update horizon, but the
    #      last element is terminal
    if not (
        trajectory_len > self._update_horizon
        or (trajectory_len > 1 and last_transition.is_terminal)
    ):
        return None

with

    # Check if we have a valid transition, i.e. we either
    #   1) have accumulated more transitions than the update horizon and the
    #      last element is not terminal
    #   2) have a trajectory shorter than the update horizon, but the
    #      last element is terminal and we have enough frames to stack
    if not (
        (trajectory_len > self._update_horizon and not last_transition.is_terminal)
        or (trajectory_len > self._stack_size and last_transition.is_terminal)
    ):
        return None

solves the issue. Running the same script again gives:

OrderedDict([(0,
              ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
             (1,
              ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
             (2,
              ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
             (3,
              ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
             (4,
              ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
             (5,
              ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
             (6,
              ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True))])

The spurious terminal ReplayElements have been filtered out 🎉
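To see why the new condition removes exactly the extra elements, here is a simplified, hypothetical model of the accumulator for update_horizon = 1. It is not Dopamine's code; Element, stack, simulate, original_check and proposed_check are illustrative names. It assumes, as the printed buffers suggest, that the accumulator keeps the last stack_size + 1 transitions and, after a terminal transition, keeps emitting elements while dropping the oldest frame until the validity check fails.

```python
from collections import deque
from typing import Callable, List, NamedTuple, Tuple


class Element(NamedTuple):
    state: Tuple[int, ...]
    action: int
    next_state: Tuple[int, ...]
    is_terminal: bool


# (trajectory_len, stack_size, last_is_terminal) -> is the transition valid?
Predicate = Callable[[int, int, bool], bool]


def stack(frames: List[int], stack_size: int) -> Tuple[int, ...]:
    """Keep the last `stack_size` frames, left-padding with zeros."""
    frames = frames[-stack_size:]
    return tuple([0] * (stack_size - len(frames)) + frames)


def original_check(n: int, stack_size: int, terminal: bool) -> bool:
    # Original condition with update_horizon = 1 substituted in.
    return n > 1 or (n > 1 and terminal)


def proposed_check(n: int, stack_size: int, terminal: bool) -> bool:
    # Proposed condition with update_horizon = 1 substituted in.
    return (n > 1 and not terminal) or (n > stack_size and terminal)


def simulate(num_steps: int, stack_size: int, is_valid: Predicate) -> List[Element]:
    """Feed one episode (observation == action == step index) to the model."""
    trajectory: deque = deque(maxlen=stack_size + 1)  # holds step indices
    stored: List[Element] = []
    for step in range(num_steps):
        terminal = step == num_steps - 1  # only the last step is terminal
        trajectory.append(step)
        # Emit once per step; after a terminal step, keep emitting while
        # dropping the oldest frame, as long as the check passes.
        while trajectory and is_valid(len(trajectory), stack_size, terminal):
            frames = list(trajectory)
            stored.append(Element(
                state=stack(frames[:-1], stack_size),
                action=frames[-2],
                next_state=stack(frames, stack_size),
                is_terminal=terminal,
            ))
            if not terminal:
                break
            trajectory.popleft()
        if terminal:
            trajectory.clear()
    return stored


buggy = simulate(8, stack_size=4, is_valid=original_check)
fixed = simulate(8, stack_size=4, is_valid=proposed_check)
print(len(buggy), sum(e.is_terminal for e in buggy))  # 10 4
print(len(fixed), sum(e.is_terminal for e in fixed))  # 7 1
```

With update_horizon = 1 the original check reduces to trajectory_len > 1, so every shrinking suffix of length 2 or more is emitted, which is where the zero-padded states come from. Requiring trajectory_len > stack_size for terminal elements stops this as soon as a full frame stack is no longer available.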
