forked from k4zmu2a/SpaceCadetPinball
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathballhandler.py
More file actions
172 lines (151 loc) · 6.12 KB
/
ballhandler.py
File metadata and controls
172 lines (151 loc) · 6.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import subprocess
import os
import numpy as np
from multiprocessing import shared_memory
import torch
import random
from lib import device
class ActionSpace:
    """A discrete action space over the integer action ids 0..num_actions-1."""
    def __init__(self, num_actions):
        # The valid actions are simply every integer below num_actions.
        self.space = [action for action in range(num_actions)]
    def sample(self):
        """Draw one action id uniformly at random."""
        index = random.randrange(len(self.space))
        return self.space[index]
class GameEnvironment:
    """Gym-style environment wrapper around the SpaceCadetPinball binary.

    Communicates with the (modified) C++ game process through shared-memory
    segments: a semaphore word (`sem`), the pending action byte (`action`),
    the current score (`score`), the raw pixel buffer (`pixels`) and a
    ball-info vector (`ball_info`).
    """
    def __init__(self, width, height, n_frames, plotter=None):
        """Create the shared-memory channels, launch the game, skip the intro.

        Args:
            width, height: game window size in pixels.
            n_frames: number of game frames the C side advances per step().
            plotter: optional object with a process_data() method for logging.
        """
        self.width = width
        self.height = height
        # Frames are stored by the game downscaled by 2 in each dimension.
        self.save_width = width // 2
        self.save_height = height // 2
        self.frame_id = 0
        self.same_reward_counter = 0
        self.prev_action = None
        self.plotter = plotter
        self.action_space = ActionSpace(4)
        # Python-side mirror of the toggle state last sent to the game.
        self.left_flipper_up = False
        self.right_flipper_up = False
        self.plunger_down = False
        self.n_frames = n_frames
        self.prev_score = np.array([0], dtype=np.int32)
        ### INIT SHM
        self.shm_objs = []
        # ball_info is not used for control but must still be created because
        # the modified C++ code expects this segment to exist.
        init_ball_info = np.array([-1, -1, -1, -1, -1, -1, -1], dtype=np.float32)
        self.ball_info = self.init_shared_memory("ball_info", init_ball_info, np.float32)
        self.init_sem = np.array([0], dtype=np.int32)
        self.sem = self.init_shared_memory("sem", self.init_sem, np.int32)
        init_action = np.array([33], dtype=np.uint8)  # 33 == ord('!')
        self.action = self.init_shared_memory("action", init_action, np.uint8)
        init_score = np.array([0], dtype=np.int32)
        self.score = self.init_shared_memory("score", init_score, np.int32)
        # RGBA pixel buffer at half resolution.
        init_pixels = np.zeros([height // 2 * width // 2 * 4], dtype=np.uint8)
        self.pixels = self.init_shared_memory("pixels", init_pixels, np.uint8)
        # START GAME AND FAST FORWARD past the intro / table-load frames.
        self.process = self.start_game()
        self.fast_forward_frames(550)
    def __del__(self):
        """Terminate the game process and release all shared memory."""
        # __del__ can run on a partially constructed instance (if __init__
        # raised), so guard every attribute access.
        process = getattr(self, "process", None)
        if process is not None:
            process.kill()
        for shm in getattr(self, "shm_objs", []):
            shm.close()
            shm.unlink()
    def init_shared_memory(self, name, data, dtype):
        """Create a named shared-memory segment initialised from `data`.

        Returns a numpy array view over the segment; the SharedMemory object
        is kept in self.shm_objs so __del__ can close and unlink it.
        """
        shm = shared_memory.SharedMemory(name, create=True, size=data.nbytes)
        self.shm_objs.append(shm)
        arr = np.ndarray(data.shape, dtype=dtype, buffer=shm.buf)
        arr[:] = data[:]
        return arr
    def start_game(self):
        """Launch the pinball binary with N_FRAMES exported in its environment."""
        env = dict(os.environ)
        env["N_FRAMES"] = str(self.n_frames)
        c_program_path = './bin/SpaceCadetPinball'
        # Invoke the binary directly (no shell) so the Popen handle refers to
        # the game itself and kill() in __del__ reaches it; with shell=True
        # only the intermediate shell would be killed.
        return subprocess.Popen([c_program_path], env=env)
    def fast_forward_frames(self, n):
        """Advance the game by n frame batches without taking any action."""
        for _ in range(n):
            # Busy-wait until the C side signals that a frame is ready.
            while self.sem[0] == 0:
                pass
            self.frame_id += 1
            self.sem[:] = self.init_sem[:]  # Tell C to proceed
    def is_done(self):
        """Return True when the game exited (sem < 0) or the ball drained.

        ball_info[1] > 14.0 is interpreted as the ball falling past the
        bottom of the table — NOTE(review): threshold mirrors the C side,
        confirm there.
        """
        return bool(self.sem[0] < 0 or self.ball_info[1] > 14.0)
    def is_stuck(self):
        """Return True when the score has not changed for over 500 steps."""
        return self.same_reward_counter > 500
    def int_to_c_action(self, int_action):
        """Given an integer action representing toggle actions, return the
        byte to send to the game.

        Lower/upper case (or '.'/'!') encodes release vs. press based on the
        current toggle state; 'p' is a no-op.
        """
        if int_action == 0:
            action = "l" if self.left_flipper_up else "L"
        elif int_action == 1:
            action = "r" if self.right_flipper_up else "R"
        elif int_action == 2:
            action = "." if self.plunger_down else "!"
        elif int_action == 3:
            action = "p"
        else:
            raise Exception(f"Unknown int_action: {int_action}")
        return np.array([ord(action)], dtype=np.uint8)
    def get_reward(self):
        """Return (reward tensor on `device`, raw score delta) for this step.

        Reward is 1 when the score increased since the previous call, else 0;
        also maintains same_reward_counter used by is_stuck().
        """
        score_diff = self.score[0] - self.prev_score[0]
        if score_diff == 0:
            self.same_reward_counter += 1
            reward = 0
        else:
            self.same_reward_counter = 0
            reward = 1
        reward = torch.tensor(reward, dtype=torch.float32)
        # BUG FIX: Tensor.to() is not in-place — the original discarded the
        # returned tensor, so the reward never actually moved to `device`.
        reward = reward.to(device)
        self.prev_score[:] = self.score[:]
        return reward, score_diff
    def get_ball_state(self):
        """Get the current ball state (position and velocity, 4 values)."""
        state = self.ball_info.astype(np.float32)
        # Indices 0,1 are position and 4,5 are velocity — TODO confirm the
        # ball_info layout against the C side.
        state = state[[0, 1, 4, 5]]
        # Rough normalisation of position into ~[0, 1] (table is taller
        # than it is wide).
        state[0] = state[0] / 10
        state[1] = state[1] / 20
        state = torch.from_numpy(state)
        state = state.to(device)
        return state
    def get_state(self):
        """Get the complete state: ball position/velocity plus the three
        flipper/plunger toggle values (7 values)."""
        action_state_tensor = torch.tensor(
            [self.left_flipper_up, self.right_flipper_up, self.plunger_down],
            dtype=torch.float32).to(device)
        return torch.cat((self.get_ball_state(), action_state_tensor), dim=0)
    def update_toggles(self, action):
        """Keep the Python-side toggle mirror in sync with the action sent."""
        if action == 0:
            self.left_flipper_up = not self.left_flipper_up
        elif action == 1:
            self.right_flipper_up = not self.right_flipper_up
        elif action == 2:
            self.plunger_down = not self.plunger_down
    def step(self, action):
        """Send one action, wait n_frames game frames, and return the result.

        Action is one of 4 values:
            0: toggle left flipper
            1: toggle right flipper
            2: toggle plunger
            3: do nothing

        Returns (state, reward, score_diff, is_done, is_stuck).
        """
        self.action[:] = self.int_to_c_action(action)[:]
        # Update our internal representation of flipper and plunger.
        self.update_toggles(action)
        # Busy-wait for the C side to render the batch of frames.
        while self.sem[0] != self.n_frames:
            if self.sem[0] < 0:
                break
        # sem is either < 0 (game over) or == n_frames here.
        if self.sem[0] == self.n_frames:
            self.sem[:] = self.init_sem[:]
        state = self.get_state()
        reward, score_diff = self.get_reward()
        is_done, is_stuck = self.is_done(), self.is_stuck()
        # Negative reward if we lose or get stuck.
        if is_done or is_stuck:
            reward -= 1
        # Small negative reward for any non-noop action.
        if action in range(3):
            reward -= 0.1
        self.frame_id += self.n_frames
        return state, reward, score_diff, is_done, is_stuck