Skip to content

Commit fea75cd

Browse files
Merge pull request #329 from jonbinney/v2_b
Improvements in training
2 parents b426e6e + 5933199 commit fea75cd

File tree

7 files changed

+94
-41
lines changed

7 files changed

+94
-41
lines changed

deep_quoridor/experiments/B5W3/base.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
run_id: cucu-26b
1+
run_id: cucu-28d
22
quoridor:
33
board_size: 5
44
max_walls: 3

deep_quoridor/src/agents/alphazero/nn_evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,16 @@ def train_iteration_v2(self, samples):
284284
self.cache = LRUCache(max_size=self.max_cache_size)
285285

286286
policy_loss, value_loss, total_loss = self.compute_losses(samples)
287+
assert policy_loss, "Expected policy_loss"
288+
assert value_loss, "Expected value_loss"
287289
assert total_loss, "Expected total_loss"
288290

289291
# Backward pass
290292
self.optimizer.zero_grad()
291293
total_loss.backward()
292294
self.optimizer.step()
293295

294-
return total_loss.item()
296+
return policy_loss.item(), value_loss.item(), total_loss.item()
295297

296298
def train_iteration(
297299
self,

deep_quoridor/src/arena.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
plugins: list[ArenaPlugin] = [],
7272
swap_players: bool = True,
7373
max_steps: int = 1000,
74+
verbose: bool = True,
7475
):
7576
self.board_size = board_size
7677
self.max_walls = max_walls
@@ -81,6 +82,7 @@ def __init__(
8182

8283
self.renderers = renderers
8384
self.plugins = CompositeArenaPlugin([p for p in plugins + renderers + [saver] if p is not None])
85+
self.verbose = verbose
8486

8587
def _play_game(self, agent1: Agent, agent2: Agent, game_id: str) -> GameResult:
8688
self.game.reset()
@@ -119,7 +121,7 @@ def _play_game(self, agent1: Agent, agent2: Agent, game_id: str) -> GameResult:
119121
done=True,
120122
)
121123

122-
if truncation:
124+
if truncation and self.verbose:
123125
# Print the game state to help debug.
124126
print(f"\nP1: {agent1.name()} P2: {agent2.name()}")
125127
print(self.game.render())

deep_quoridor/src/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def tournament(self, agent: Agent):
183183
Runs a tournament for the given agent against predefined benchmarks. This was created for v2
184184
architecture, and has some duplicated code from compute().
185185
"""
186-
arena = Arena(self.board_size, self.max_walls, max_steps=self.max_steps, renderers=[MatchResultsRenderer()])
186+
arena = Arena(self.board_size, self.max_walls, max_steps=self.max_steps, verbose=False)
187187
# We store the elos of the opponents playing against each other so we don't have to play those matches
188188
# every time
189189
if not self.stored_elos:

deep_quoridor/src/utils/timer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def start(cls, name: str):
3939
cls.starts[name] = time.perf_counter()
4040

4141
@classmethod
42-
def finish(cls, name: str, episode: Optional[int] = None) -> str:
42+
def finish(cls, name: str, episode: Optional[int] = None) -> float:
4343
if name not in cls.starts:
4444
print(f"TIMER: WARNING - timer for {name} was not started but trying to finish")
4545
return ""
@@ -54,7 +54,7 @@ def finish(cls, name: str, episode: Optional[int] = None) -> str:
5454
if cls.wandb_run:
5555
cls.wandb_run.log({f"time-{name}": elapsed, "Episode": episode})
5656

57-
return format_time(elapsed)
57+
return elapsed
5858

5959
@classmethod
6060
def log_cumulative(cls, x_name: str, x_value: int | float):

deep_quoridor/src/v2/TODO.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# V1 Parity
2+
3+
- Replay buffer length: right now we're not rolling out old games to respect the length
4+
- Agent Evolution benchmark
5+
- Allow to set a finish point (e.g. maximum number of models or games)
6+
- Include a CI test
7+
- Overrides from the command line
8+
- Continuation
9+
- Upload models to wandb
10+
11+
# Other improvements and new features
12+
13+
- Use a schedule for the learning rate
14+
- Use a logger class rather than just printing out.
15+
16+
# Performance
17+
18+
- In train, sample from all the training iterations together
19+
- For the replay buffer files, either:
20+
- Move them in directories based on their game number (e.g. games_1000)
21+
- Join multiple games in one file (a bit more tricky with sampling and race conditions)
22+
23+
# Ideas
24+
25+
- Self-healing processes
26+
- Allow to dynamically change the number of workers and parallel games, to experiment with performance
27+
- Mount the run directory and make other processes play from another computer
28+
- The processes could write status files and we could have a script to watch the status (e.g. elapsed time).

deep_quoridor/src/v2/trainer.py

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44

55
import numpy as np
66
import wandb
7+
from pydantic_yaml import parse_yaml_file_as
8+
from utils import Timer
79
from v2.common import MockWandb, create_alphazero
810
from v2.config import Config
9-
from v2.yaml_models import LatestModel
11+
from v2.yaml_models import GameInfo, LatestModel
1012

1113

1214
def train(config: Config):
13-
global azparams
1415
batch_size = config.training.batch_size
15-
training_iterations = 1
16-
min_new_games = 25
1716

1817
if config.wandb:
1918
run_id = f"{config.run_id}-training"
@@ -38,19 +37,18 @@ def train(config: Config):
3837
alphazero_agent.save_model(filename)
3938
LatestModel.write(config, str(filename), 0)
4039

40+
training_steps = 0
4141
last_game = 0
4242
model_version = 1
4343
moves_per_game = []
4444
game_filename = []
4545

4646
while True:
47-
while True:
48-
ready = [f for f in sorted(config.paths.replay_buffers_ready.glob("*.pkl")) if f.is_file()]
49-
if len(ready) >= min_new_games:
50-
break
51-
time.sleep(1)
47+
Timer.start("waiting-to-train")
48+
49+
# Process new games: find new files, move them and extract the info used for training
50+
ready = [f for f in sorted(config.paths.replay_buffers_ready.glob("*.pkl")) if f.is_file()]
5251

53-
# Process new games
5452
for f in ready:
5553
last_game += 1
5654

@@ -59,44 +57,67 @@ def train(config: Config):
5957
yaml_file = f.with_suffix(".yaml")
6058
new_yaml_name = new_name.with_suffix(".yaml")
6159
yaml_file.rename(new_yaml_name)
60+
game_info = parse_yaml_file_as(GameInfo, new_yaml_name)
6261

6362
f.rename(new_name)
6463
with open(new_name, "rb") as f:
6564
data = pickle.load(f)
66-
game_length = len(list(data))
67-
moves_per_game.append(game_length)
65+
moves_per_game.append(game_info.game_length)
6866
game_filename.append(f.name)
69-
wandb_run.log({"game_length": game_length, "Game num": last_game, "Model version": model_version})
67+
wandb_run.log(
68+
{
69+
"game_length": game_info.game_length,
70+
"model_lag": model_version - 1 - game_info.model_version,
71+
"Game num": last_game,
72+
"Model version": model_version,
73+
}
74+
)
7075

7176
total_moves = sum(moves_per_game)
72-
if total_moves < batch_size:
73-
continue
7477

75-
t0 = time.time()
76-
for _ in range(training_iterations):
77-
# Sample
78-
# TO DO, we need to roll out games when it's longer that the replay buffer size
79-
# TO DO probably we want to sample for all the training iterations together to make it faster
80-
samples = []
78+
games_needed_to_train = config.training.games_per_training_step * (training_steps + 1)
79+
80+
if total_moves < batch_size or games_needed_to_train > last_game:
81+
time.sleep(1)
82+
continue
8183

82-
games = np.random.choice(last_game, batch_size, p=[moves / total_moves for moves in moves_per_game])
83-
samples_per_game = Counter(games)
84-
for game_number in samples_per_game:
85-
file = config.paths.replay_buffers / game_filename[game_number]
86-
with open(file, "rb") as f:
87-
data = pickle.load(f)
84+
time_waiting_to_train = Timer.finish("waiting-to-train")
8885

89-
samples.extend(np.random.choice(list(data), samples_per_game[game_number]))
86+
# Sample moves from the replay buffer files
87+
Timer.start("sample")
88+
samples = []
9089

91-
# print(f"{game_number}: {samples_per_game[game_number]}, {len(entries)}")
90+
games = np.random.choice(last_game, batch_size, p=[moves / total_moves for moves in moves_per_game])
91+
samples_per_game = Counter(games)
92+
for game_number in samples_per_game:
93+
file = config.paths.replay_buffers / game_filename[game_number]
94+
with open(file, "rb") as f:
95+
data = pickle.load(f)
9296

93-
# Train
94-
loss = alphazero_agent.evaluator.train_iteration_v2(samples)
95-
wandb_run.log({"loss": loss, "games_played": last_game, "Model version": model_version}, commit=True)
97+
samples.extend(np.random.choice(list(data), samples_per_game[game_number]))
98+
time_sample = Timer.finish("sample")
99+
100+
# Train the network for one step using the samples
101+
Timer.start("train")
102+
policy_loss, value_loss, total_loss = alphazero_agent.evaluator.train_iteration_v2(samples)
103+
training_steps += 1
104+
time_train = Timer.finish("train")
105+
106+
wandb_run.log(
107+
{
108+
"policy_loss": policy_loss,
109+
"value_loss": value_loss,
110+
"total_loss": total_loss,
111+
"games_played": last_game,
112+
"time-sample": time_sample,
113+
"time-train": time_train,
114+
"time-waiting-to-train": time_waiting_to_train,
115+
"Model version": model_version,
116+
},
117+
commit=True,
118+
)
96119

97-
print(f"Loss: {loss}")
98-
t1 = time.time()
99-
print(f"Sampling and training took {t1 - t0}")
120+
print(f"Sampling and training took {time_sample}, {time_train}")
100121

101122
new_model_filename = config.paths.checkpoints / f"model_{model_version}.pt"
102123
alphazero_agent.save_model(new_model_filename)

0 commit comments

Comments
 (0)