-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrecord_all_gifs.py
More file actions
142 lines (115 loc) · 4.19 KB
/
record_all_gifs.py
File metadata and controls
142 lines (115 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Record GIFs for all 9 configurations (3 algorithms × 3 representations).
Usage:
python record_all_gifs.py
python record_all_gifs.py --algo mlp
python record_all_gifs.py --rep compact
"""
import sys
import os
import argparse
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
import pygame
HAS_PYGAME = True
except ImportError:
print("ERROR: pygame required. pip install pygame")
sys.exit(1)
try:
from PIL import Image
HAS_PIL = True
except ImportError:
print("ERROR: Pillow required. pip install Pillow")
sys.exit(1)
from snake_rl.env.snake_env import SnakeEnv
from snake_rl.utils.save_load import load_agent_weights
from record_gameplay import render_frame, make_agent, weight_name
try:
from snake_rl.agents.double_dqn import DoubleDQNAgent
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
print("WARNING: PyTorch not installed. DQN GIFs will be skipped.")
# Grid size handed to SnakeEnv (as grid_size) for every recording.
GRID_SIZE = 20
# Pixel size of one grid cell; GRID_SIZE * CELL_SIZE is the rendered surface width.
CELL_SIZE = 30
# Directory searched for trained weight files (<name>.pt or <name>.npz).
WEIGHTS_DIR = "weights"
# Directory the GIFs are written to; created on demand by record_one_gif.
RECORDINGS_DIR = "recordings"
def record_one_gif(algo, rep_name, n_episodes=3, max_steps_per_ep=300, fps=10, seed=0):
    """Record gameplay of one trained agent and save it as an animated GIF.

    The weight file for the ``algo`` / ``rep_name`` / ``seed`` combination is
    looked up in ``WEIGHTS_DIR``. A missing file, or any failure while
    building or loading the agent, is reported on stdout and treated as a
    skip so that a batch run over many configurations can continue.

    Args:
        algo: Algorithm key. ``"mlp"`` selects a ``.pt`` weight file and the
            ``double_dqn`` label; any other value selects ``.npz`` and the
            ``<algo>_sarsa`` label.
        rep_name: Representation name, used for weight lookup, the on-frame
            label and the output file name.
        n_episodes: Number of episodes recorded back-to-back into one GIF.
        max_steps_per_ep: Upper bound on recorded steps per episode.
        fps: Playback rate of the GIF; each frame lasts ``1000 // fps`` ms.
            Must be positive.
        seed: Forwarded to ``weight_name`` to pick which seed's weights to load.

    Returns:
        True if a GIF was written, False if the configuration was skipped or
        no frames were produced.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    # Fail fast: otherwise a zero fps only surfaces as ZeroDivisionError after
    # the whole rollout has been recorded, and a negative fps silently yields
    # a negative frame duration.
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    name = weight_name(algo, rep_name, seed=seed)
    ext = ".pt" if algo == "mlp" else ".npz"
    weight_file = os.path.join(WEIGHTS_DIR, f"{name}{ext}")
    if not os.path.exists(weight_file):
        print(f" SKIP: no weights at {weight_file}")
        return False

    try:
        agent = make_agent(algo, rep_name, weights_dir=WEIGHTS_DIR, name=name)
        load_agent_weights(agent, name, WEIGHTS_DIR)
        agent.epsilon = 0.0  # disable exploration for the recording
    except Exception as e:
        # Deliberately broad: one unloadable agent must not abort the batch.
        print(f" SKIP: failed to load agent — {e}")
        return False

    os.makedirs(RECORDINGS_DIR, exist_ok=True)
    # NOTE(review): the env seed is fixed at 999 — presumably so recordings
    # are reproducible; confirm against SnakeEnv.
    env = SnakeEnv(grid_size=GRID_SIZE, max_steps_factor=3, seed=999)

    grid_w = GRID_SIZE * CELL_SIZE
    panel_h = 60  # extra surface height beyond the square grid
    frame_size = (grid_w, grid_w + panel_h)

    algo_label = "double_dqn" if algo == "mlp" else f"{algo}_sarsa"
    label = f"{algo_label} + {rep_name}"

    # Resolve the optional q_values hook once instead of on every step.
    q_values_fn = getattr(agent, "q_values", None)

    frames = []
    pygame.init()
    try:
        pygame.font.init()
        surface = pygame.Surface(frame_size)
        for _episode in range(n_episodes):
            obs, _ = env.reset()
            done = False
            steps = 0
            while not done and steps < max_steps_per_ep:
                # Q-values are read for the observation the action is chosen from.
                q_vals = q_values_fn(obs) if q_values_fn is not None else None
                action = agent.act(obs)
                obs, _, term, trunc, _ = env.step(action)
                done = term or trunc
                steps += 1
                render_frame(surface, env, agent_name=label, q_vals=q_vals)
                frame_data = pygame.image.tostring(surface, "RGB")
                frames.append(Image.frombytes("RGB", frame_size, frame_data))
    finally:
        # Always shut pygame down, even if the env, agent or renderer raised,
        # so a failure here does not leak pygame state into later recordings.
        pygame.quit()

    if not frames:
        return False

    filepath = os.path.join(RECORDINGS_DIR, f"{algo_label}__{rep_name}.gif")
    frames[0].save(
        filepath,
        save_all=True,
        append_images=frames[1:],
        duration=1000 // fps,
        loop=0,
        optimize=True,
    )
    print(f" Saved: {filepath} ({len(frames)} frames)")
    return True
def main():
    """Command-line entry point: record a GIF for each selected configuration.

    ``--algo`` and ``--rep`` each narrow the run to a single algorithm or
    representation; when omitted, all of them are used (``mlp`` is left out
    of the default set when PyTorch is unavailable). ``--seed`` chooses which
    seed's weights are loaded. A saved/skipped summary is printed at the end.
    """
    cli = argparse.ArgumentParser(description="Record GIFs for all trained agents")
    cli.add_argument(
        "--algo",
        type=str,
        default=None,
        choices=["linear", "tile", "mlp"],
    )
    cli.add_argument(
        "--rep",
        type=str,
        default=None,
        choices=["compact", "local", "extended"],
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Which seed's weights to load (default: 0)",
    )
    opts = cli.parse_args()

    # An explicit --algo wins; otherwise fall back to everything that can run.
    if opts.algo:
        algos = [opts.algo]
    elif HAS_TORCH:
        algos = ["linear", "tile", "mlp"]
    else:
        algos = ["linear", "tile"]

    if opts.rep:
        reps = [opts.rep]
    else:
        reps = ["compact", "local", "extended"]

    # Same algorithm-major order as a nested loop over algos then reps.
    outcomes = []
    for algo, rep in [(a, r) for a in algos for r in reps]:
        print(f"\n[{algo} × {rep}] Recording...")
        outcomes.append(bool(record_one_gif(algo, rep, seed=opts.seed)))

    saved = sum(outcomes)
    skipped = len(outcomes) - saved
    print(f"\nDone: {saved} GIFs saved, {skipped} skipped")
    print(f"Output: {os.path.abspath(RECORDINGS_DIR)}/")


# Run the batch recorder only when executed as a script, not on import.
if __name__ == "__main__":
    main()