-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathframe_buffer.py
More file actions
40 lines (33 loc) · 960 Bytes
/
frame_buffer.py
File metadata and controls
40 lines (33 loc) · 960 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
class FrameBuffer:
    """Rolling window holding the most recent ``history_length`` screens.

    Frames live in ``self.buffer`` with shape
    ``(1, history_length, screen_height, screen_width)`` and dtype ``uint8``.
    The newest frame is always at history index 0; each ``add`` shifts older
    frames towards higher indices and the oldest one drops off the end.
    """

    def __init__(self, args):
        """Allocate an all-zero history.

        Args:
            args: object exposing ``history_length``, ``screen_height`` and
                ``screen_width`` attributes (used as array dimensions).
        """
        self.history_length = args.history_length
        self.screen_size = (args.screen_height, args.screen_width)
        self.buffer = np.zeros(
            (1, self.history_length) + self.screen_size, dtype=np.uint8
        )

    def add(self, screen):
        """Push ``screen`` as the newest frame, discarding the oldest one.

        The frame is fully validated *before* the history is modified, so a
        rejected frame leaves the buffer unchanged. Explicit exceptions are
        used instead of ``assert`` so the checks survive ``python -O``.

        Args:
            screen: ``np.uint8`` array of shape ``(screen_height, screen_width)``.

        Raises:
            ValueError: if ``screen`` has the wrong shape or is entirely zero.
            TypeError: if ``screen`` is not of dtype ``uint8``.
        """
        if screen.shape != self.screen_size:
            raise ValueError(
                "expected screen of shape {}, got {}".format(
                    self.screen_size, screen.shape
                )
            )
        if screen.dtype != np.uint8:
            raise TypeError("expected uint8 screen, got {}".format(screen.dtype))
        # Preserved from the original sanity check: an all-zero frame is
        # rejected. NOTE(review): confirm blank frames can never be legitimate.
        if not np.any(screen):
            raise ValueError("screen is entirely zero")
        # np.roll returns a new array, so states previously returned by
        # get_state()/get_state_as_batch() keep their old contents.
        self.buffer = np.roll(self.buffer, 1, axis=1)
        self.buffer[0, 0] = screen

    def get_state(self):
        """Return the history as a ``(screen_height, screen_width, history_length)`` array.

        Channel 0 of the last axis is the newest frame. The result is a
        transposed view of ``self.buffer`` (no copy), so ``reset()`` zeroes it
        in place.
        """
        return self.buffer[0].transpose(1, 2, 0)

    def get_state_as_batch(self):
        """Return the history as a ``(1, screen_height, screen_width, history_length)`` batch.

        Same channel ordering and view semantics as ``get_state``.
        """
        return self.buffer.transpose(0, 2, 3, 1)

    def reset(self):
        """Zero every stored frame in place; shape and dtype are kept."""
        self.buffer.fill(0)
if __name__ == "__main__":
    # Smoke test: push three distinct frames and print one pixel's history.
    # Frames must be 2-D uint8 arrays of shape (screen_height, screen_width);
    # the earlier draft used np.ndarray([10, 10]), which is an uninitialised
    # float64 array and is rejected by FrameBuffer.add's dtype check.
    from types import SimpleNamespace

    demo_args = SimpleNamespace(screen_height=10, screen_width=10, history_length=3)
    buf = FrameBuffer(demo_args)
    frame = np.zeros((10, 10), dtype=np.uint8)
    for value in (1, 2, 3):
        frame[0] = value
        buf.add(frame)  # add() copies the frame into the buffer
    print(buf.get_state()[0, 0])  # newest first: [3 2 1]