diff --git a/pfrl/replay_buffer.py b/pfrl/replay_buffer.py
index 7da0fd3f9..583528099 100644
--- a/pfrl/replay_buffer.py
+++ b/pfrl/replay_buffer.py
@@ -113,6 +113,12 @@ def stop_current_episode(self, env_id=0):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def clear(self):
+        """Clears the replay buffer.
+        """
+        raise NotImplementedError
+
 
 class AbstractEpisodicReplayBuffer(AbstractReplayBuffer):
     """Defines a common interface of episodic replay buffer.
@@ -145,6 +151,12 @@ def n_episodes(self):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def clear(self):
+        """Clears the replay buffer.
+        """
+        raise NotImplementedError
+
 
 def random_subseq(seq, subseq_len):
     if len(seq) <= subseq_len:
diff --git a/pfrl/replay_buffers/episodic.py b/pfrl/replay_buffers/episodic.py
index 31e88b0e4..c480430b0 100644
--- a/pfrl/replay_buffers/episodic.py
+++ b/pfrl/replay_buffers/episodic.py
@@ -12,9 +12,8 @@ class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):
     capacity: Optional[int] = None
 
     def __init__(self, capacity=None):
-        self.current_episode = collections.defaultdict(list)
-        self.episodic_memory = RandomAccessQueue()
-        self.memory = RandomAccessQueue()
+        # initialize data structures
+        self.initialize_memory()
         self.capacity = capacity
 
     def append(
@@ -97,3 +96,12 @@ def stop_current_episode(self, env_id=0):
                 for _ in range(len(discarded_episode)):
                     self.memory.popleft()
         assert not self.current_episode[env_id]
+
+    def initialize_memory(self):
+        self.current_episode = collections.defaultdict(list)
+        self.episodic_memory = RandomAccessQueue()
+        self.memory = RandomAccessQueue()
+
+    def clear(self):
+        self.initialize_memory()
+
diff --git a/pfrl/replay_buffers/persistent.py b/pfrl/replay_buffers/persistent.py
index 19342df2e..fe7dfea74 100644
--- a/pfrl/replay_buffers/persistent.py
+++ b/pfrl/replay_buffers/persistent.py
@@ -80,6 +80,9 @@ def load(self, _):
             "{}.load() has been ignored, as it is persistent replay buffer".format(self)
         )
 
+    def clear(self):
+        raise NotImplementedError
+
 
 class PersistentEpisodicReplayBuffer(EpisodicReplayBuffer):
     """Episodic version of :py:class:`PersistentReplayBuffer`
@@ -163,3 +166,6 @@ def load(self, _):
         warnings.warn(
             "PersistentEpisodicReplayBuffer.load() is called but it has not effect."
         )
+
+    def clear(self):
+        raise NotImplementedError
diff --git a/pfrl/replay_buffers/prioritized.py b/pfrl/replay_buffers/prioritized.py
index 5814d2e71..3ca014469 100644
--- a/pfrl/replay_buffers/prioritized.py
+++ b/pfrl/replay_buffers/prioritized.py
@@ -99,6 +99,24 @@ def __init__(
         self.capacity = capacity
         assert num_steps > 0
         self.num_steps = num_steps
+        self.beta0 = beta0
+        self.betasteps = betasteps
+        self.initialize_memory(capacity, num_steps, alpha, beta0, betasteps,
+                               eps, normalize_by_max, error_min, error_max)
+
+    def sample(self, n):
+        assert len(self.memory) >= n
+        sampled, probabilities, min_prob = self.memory.sample(n)
+        weights = self.weights_from_probabilities(probabilities, min_prob)
+        for e, w in zip(sampled, weights):
+            e[0]["weight"] = w
+        return sampled
+
+    def update_errors(self, errors):
+        self.memory.set_last_priority(self.priority_from_errors(errors))
+
+    def initialize_memory(self, capacity, num_steps, alpha, beta0, betasteps,
+                          eps, normalize_by_max, error_min, error_max):
         self.memory = PrioritizedBuffer(capacity=capacity)
         self.last_n_transitions = collections.defaultdict(
             lambda: collections.deque([], maxlen=num_steps)
@@ -114,13 +132,6 @@ def __init__(
             error_max=error_max,
         )
 
-    def sample(self, n):
-        assert len(self.memory) >= n
-        sampled, probabilities, min_prob = self.memory.sample(n)
-        weights = self.weights_from_probabilities(probabilities, min_prob)
-        for e, w in zip(sampled, weights):
-            e[0]["weight"] = w
-        return sampled
-
-    def update_errors(self, errors):
-        self.memory.set_last_priority(self.priority_from_errors(errors))
+    def clear(self):
+        self.initialize_memory(self.capacity, self.num_steps, self.alpha, self.beta0, self.betasteps,
+                               self.eps, self.normalize_by_max, self.error_min, self.error_max)
diff --git a/pfrl/replay_buffers/prioritized_episodic.py b/pfrl/replay_buffers/prioritized_episodic.py
index e31a74863..652ce18cb 100644
--- a/pfrl/replay_buffers/prioritized_episodic.py
+++ b/pfrl/replay_buffers/prioritized_episodic.py
@@ -22,25 +22,26 @@ def __init__(
         error_min=None,
         error_max=None,
     ):
-        self.current_episode = collections.defaultdict(list)
-        self.episodic_memory = PrioritizedBuffer(
-            capacity=None, wait_priority_after_sampling=wait_priority_after_sampling
-        )
-        self.memory = RandomAccessQueue(maxlen=capacity)
+        self.initialize_memory(
+            capacity,
+            wait_priority_after_sampling,
+            alpha,
+            beta0,
+            betasteps,
+            eps,
+            normalize_by_max,
+            error_min,
+            error_max,
+        )
+        self.capacity = capacity
+        self.wait_priority_after_sampling = wait_priority_after_sampling
+        self.beta0 = beta0
+        self.betasteps = betasteps
+
         self.capacity_left = capacity
         self.default_priority_func = default_priority_func
         self.uniform_ratio = uniform_ratio
         self.return_sample_weights = return_sample_weights
-        PriorityWeightError.__init__(
-            self,
-            alpha,
-            beta0,
-            betasteps,
-            eps,
-            normalize_by_max,
-            error_min=error_min,
-            error_max=error_max,
-        )
 
     def sample_episodes(self, n_episodes, max_len=None):
         """Sample n unique samples from this replay buffer"""
@@ -75,3 +76,31 @@ def stop_current_episode(self, env_id=0):
                 discarded_episode = self.episodic_memory.popleft()
                 self.capacity_left += len(discarded_episode)
         assert not self.current_episode[env_id]
+
+    def initialize_memory(self, capacity, wait_priority_after_sampling,
+                          alpha, beta0, betasteps, eps, normalize_by_max,
+                          error_min, error_max):
+        self.current_episode = collections.defaultdict(list)
+        self.episodic_memory = PrioritizedBuffer(
+            capacity=None, wait_priority_after_sampling=wait_priority_after_sampling
+        )
+        self.memory = RandomAccessQueue(maxlen=capacity)
+        PriorityWeightError.__init__(
+            self,
+            alpha,
+            beta0,
+            betasteps,
+            eps,
+            normalize_by_max,
+            error_min=error_min,
+            error_max=error_max,
+        )
+
+    def clear(self):
+        """Discard all stored episodes and restore the free-capacity counter."""
+        self.initialize_memory(self.capacity, self.wait_priority_after_sampling,
+                               self.alpha, self.beta0, self.betasteps, self.eps,
+                               self.normalize_by_max, self.error_min, self.error_max)
+        # stop_current_episode() evicts episodes based on capacity_left, so an
+        # emptied buffer must start again from the full capacity.
+        self.capacity_left = self.capacity
diff --git a/pfrl/replay_buffers/replay_buffer.py b/pfrl/replay_buffers/replay_buffer.py
index 0db496dd0..060d538cd 100644
--- a/pfrl/replay_buffers/replay_buffer.py
+++ b/pfrl/replay_buffers/replay_buffer.py
@@ -25,10 +25,7 @@ def __init__(self, capacity: Optional[int] = None, num_steps: int = 1):
         self.capacity = capacity
         assert num_steps > 0
         self.num_steps = num_steps
-        self.memory = RandomAccessQueue(maxlen=capacity)
-        self.last_n_transitions: collections.defaultdict = collections.defaultdict(
-            lambda: collections.deque([], maxlen=num_steps)
-        )
+        self.initialize_memory(capacity, num_steps)
 
     def append(
         self,
@@ -92,3 +89,12 @@ def load(self, filename):
         if isinstance(self.memory, collections.deque):
             # Load v0.2
             self.memory = RandomAccessQueue(self.memory, maxlen=self.memory.maxlen)
+
+    def initialize_memory(self, capacity, num_steps):
+        self.memory = RandomAccessQueue(maxlen=capacity)
+        self.last_n_transitions: collections.defaultdict = collections.defaultdict(
+            lambda: collections.deque([], maxlen=num_steps)
+        )
+
+    def clear(self):
+        self.initialize_memory(self.capacity, self.num_steps)
diff --git a/tests/replay_buffers_test/test_persistent_replay_buffer.py b/tests/replay_buffers_test/test_persistent_replay_buffer.py
index ce94ccfa4..0a5015788 100644
--- a/tests/replay_buffers_test/test_persistent_replay_buffer.py
+++ b/tests/replay_buffers_test/test_persistent_replay_buffer.py
@@ -51,6 +51,11 @@ def test_append_and_sample(self, capacity):
                 assert t0["next_state"] == t1["state"]
                 assert t0["next_action"] == t1["action"]
 
+    def test_clear(self, capacity):
+        rbuf = PersistentEpisodicReplayBuffer(self.tempdir.name, capacity)
+        with pytest.raises(NotImplementedError):
+            rbuf.clear()
+
     def test_save_and_load(self, capacity):
         tempdir = tempfile.mkdtemp()
 
@@ -172,3 +177,9 @@ def test(self):
 
         # Finally it should have 4 + 2 + 9 = 15 transitions
         assert len(rbuf) == 15
+
+    def test_clear(self):
+        rbuf = PersistentEpisodicReplayBuffer(self.tempdir.name, capacity=None)
+        with pytest.raises(NotImplementedError):
+            rbuf.clear()
+
diff --git a/tests/replay_buffers_test/test_replay_buffer.py b/tests/replay_buffers_test/test_replay_buffer.py
index bf2b2b037..00c90c4c1 100644
--- a/tests/replay_buffers_test/test_replay_buffer.py
+++ b/tests/replay_buffers_test/test_replay_buffer.py
@@ -66,6 +66,30 @@ def test_append_and_sample(self):
             assert s2[1] == list(correct_item)
             assert s2[0] == list(correct_item2)
 
+    def test_clear(self):
+        capacity = self.capacity
+        num_steps = self.num_steps
+        rbuf = replay_buffers.ReplayBuffer(capacity, num_steps)
+
+        assert len(rbuf) == 0
+
+        # Add one and sample one
+        correct_item = collections.deque([], maxlen=num_steps)
+        for _ in range(num_steps):
+            trans1 = dict(
+                state=0,
+                action=1,
+                reward=2,
+                next_state=3,
+                next_action=4,
+                is_state_terminal=False,
+            )
+            correct_item.append(trans1)
+            rbuf.append(**trans1)
+        assert len(rbuf) == 1
+        rbuf.clear()
+        assert len(rbuf) == 0
+
     def test_append_and_terminate(self):
         capacity = self.capacity
         num_steps = self.num_steps
@@ -248,6 +272,32 @@ def test_append_and_sample(self):
                 assert t0["next_state"] == t1["state"]
                 assert t0["next_action"] == t1["action"]
 
+    def test_clear(self):
+        capacity = self.capacity
+        rbuf = replay_buffers.EpisodicReplayBuffer(capacity)
+        assert len(rbuf) == 0
+        assert rbuf.n_episodes == 0
+        for n in [10, 15, 5] * 3:
+            transs = [
+                dict(
+                    state=i,
+                    action=100 + i,
+                    reward=200 + i,
+                    next_state=i + 1,
+                    next_action=101 + i,
+                    is_state_terminal=(i == n - 1),
+                )
+                for i in range(n)
+            ]
+            for trans in transs:
+                rbuf.append(**trans)
+
+        assert len(rbuf) == 90
+        assert rbuf.n_episodes == 9
+        rbuf.clear()
+        assert len(rbuf) == 0
+        assert rbuf.n_episodes == 0
+
     def test_save_and_load(self):
         capacity = self.capacity
 
@@ -398,6 +448,36 @@ def test_append_and_sample(self):
         s4 = rbuf.sample(2)
         np.testing.assert_allclose(s4[0][0]["weight"], s4[1][0]["weight"])
 
+    def test_clear(self):
+        capacity = self.capacity
+        num_steps = self.num_steps
+        rbuf = replay_buffers.PrioritizedReplayBuffer(
+            capacity,
+            normalize_by_max=self.normalize_by_max,
+            error_max=5,
+            num_steps=num_steps,
+        )
+
+        # assert len(rbuf) == 0
+
+        # Add one and sample one
+        correct_item = collections.deque([], maxlen=num_steps)
+        for _ in range(num_steps):
+            trans1 = dict(
+                state=0,
+                action=1,
+                reward=2,
+                next_state=3,
+                next_action=4,
+                is_state_terminal=False,
+            )
+            correct_item.append(trans1)
+            rbuf.append(**trans1)
+        # assert len(rbuf) == 1
+        rbuf.update_errors([3.14])
+        rbuf.clear()
+        # assert len(rbuf) == 0
+
     def test_normalize_by_max(self):
 
         rbuf = replay_buffers.PrioritizedReplayBuffer(
@@ -701,6 +781,9 @@ def test(self):
 
         # Finally it should have 9 + 2 + 4 = 15 transitions
         assert len(rbuf) == 15
+        rbuf.clear()
+        assert len(rbuf) == 0
+        assert len(rbuf.last_n_transitions) == 0
 
 
 @pytest.mark.parametrize(
@@ -767,6 +850,9 @@ def test(self):
 
         # Finally it should have 4 + 2 + 9 = 15 transitions
         assert len(rbuf) == 15
+        rbuf.clear()
+        assert len(rbuf) == 0
+        assert len(rbuf.current_episode) == 0
 
 
 class TestReplayBufferFail(unittest.TestCase):