-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathrun.py
More file actions
33 lines (24 loc) · 857 Bytes
/
run.py
File metadata and controls
33 lines (24 loc) · 857 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from os import truncate
import gym
import numpy as np
from policy import Policy
from preprocess import preprocess
# Play one episode of Pong with a trained policy, rendered on screen.
#
# A single render-enabled environment is used both to size the policy's input
# and to play the episode: observation_space.shape does not depend on
# render_mode, so a separate headless env just to read it is unnecessary.
env = gym.make("ALE/Pong-v5", full_action_space=False, render_mode="human")
try:
    # Size the network input by preprocessing a blank frame of the raw
    # observation shape. `.numel()` suggests preprocess returns a torch
    # tensor -- TODO confirm in preprocess.py.
    obs_dim = preprocess(np.zeros(env.observation_space.shape)).numel()
    # The policy chooses between two outputs, mapped to ALE action ids below.
    act_dim = 2
    policy = Policy(obs_dim, act_dim)
    policy.load("model.pt")
    # Switch the underlying model to evaluation mode for playback
    # (assumes policy.model is a torch nn.Module -- verify in policy.py).
    policy.model.eval()

    terminated, truncated = False, False
    observation, info = env.reset()
    prev_observation = preprocess(observation)
    while not terminated and not truncated:
        # With render_mode="human" frames are drawn inside env.step(), so no
        # explicit env.render() call is needed here.
        observation = preprocess(observation)
        # The policy is fed the difference of consecutive preprocessed frames
        # (all zeros on the first step), presumably so it can see motion.
        action = policy.act((observation - prev_observation).flatten())
        prev_observation = observation
        # Map policy output {0, 1} to ALE action ids 2 and 3; in Pong's
        # minimal action set these are presumably the paddle up/down moves --
        # NOTE(review): confirm against env.unwrapped.get_action_meanings().
        env_action = [2, 3][action]
        observation, reward, terminated, truncated, _ = env.step(env_action)
finally:
    # Release the emulator and close the display window even if model
    # loading or playback raises (e.g. a KeyboardInterrupt).
    env.close()