-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuffer.py
More file actions
40 lines (33 loc) · 1.69 KB
/
Copy pathbuffer.py
File metadata and controls
40 lines (33 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# replay buffer for off-line training
import numpy as np
import random
from collections import deque
from typing import Tuple
class Buffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Each transition is a ``(state, action, reward, next_state, done)`` tuple.
    Storage is a ``deque`` created with ``maxlen=buffer_size``, so once the
    buffer is full every newly added transition evicts the oldest one.
    """

    def __init__(self, buffer_size: int = 10000):
        """Create an empty buffer holding at most ``buffer_size`` transitions."""
        self.buffer_size = buffer_size
        self.buffer = deque(maxlen=buffer_size)
        # Not read by any method of this class; kept (and reset by ``clear``)
        # because it is a public attribute that external code may reference.
        self.buffer_idx = 0

    @staticmethod
    def _column(transitions, index: int) -> np.ndarray:
        """Collect field ``index`` of every transition into a single array."""
        return np.array([transition[index] for transition in transitions])

    def add(self, transition: Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        self.buffer.append(transition)

    def sample(self, batch_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Draw ``batch_size`` distinct stored transitions at random.

        Returns:
            ``(state_batch, action_batch, reward_batch, next_state_batch,
            done_batch)``, each an array with one entry per sampled transition.

        Raises:
            AssertionError: if ``batch_size`` exceeds the number of stored
                transitions.
        """
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # the ones the assert produced, so existing callers are unaffected.
        if batch_size > len(self.buffer):
            raise AssertionError("The batch size is larger than the buffer size.")
        batch = random.sample(self.buffer, batch_size)
        state_batch = self._column(batch, 0)
        action_batch = self._column(batch, 1)
        reward_batch = self._column(batch, 2)
        next_state_batch = self._column(batch, 3)
        done_batch = self._column(batch, 4)
        return state_batch, action_batch, reward_batch, next_state_batch, done_batch

    def get_data(self):
        """Return every stored state, action and next state, oldest first.

        Rewards and done flags are not part of the result.

        Returns:
            ``(state_batch, action_batch, next_state_batch)`` arrays covering
            the whole buffer in insertion order.

        Raises:
            AssertionError: if the buffer holds no transitions.
        """
        # Explicit raise (not ``assert``) so an empty buffer is still rejected
        # under ``python -O`` instead of silently yielding empty arrays.
        if not self.buffer:
            raise AssertionError("The buffer is empty.")
        state_batch = self._column(self.buffer, 0)
        action_batch = self._column(self.buffer, 1)
        next_state_batch = self._column(self.buffer, 3)
        return state_batch, action_batch, next_state_batch

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def clear(self):
        """Remove all stored transitions and reset ``buffer_idx`` to 0."""
        self.buffer.clear()
        self.buffer_idx = 0