Skip to content

Commit 1bf8291

Browse files
authored
Merge pull request #317 from jonbinney/jdb/wandb-sweeps
Initial implementation of W&B Sweeps
2 parents 086ca6b + 1f9cd97 commit 1bf8291

File tree

2 files changed

+68
-29
lines changed

2 files changed

+68
-29
lines changed

deep_quoridor/src/plugins/wandb_train.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Optional
88

99
import wandb
10+
import wandb.wandb_run
1011
from agent_evolution_tournament import AgentEvolutionTournament
1112
from agents.core.trainable_agent import TrainableAgent
1213
from arena_utils import ArenaPlugin
@@ -54,6 +55,7 @@ def __init__(
5455
metrics: Metrics,
5556
agent_evolution_tournament: Optional[AgentEvolutionTournament] = None,
5657
include_raw_metrics: bool = False,
58+
wandb_run: wandb.wandb_run.Run | None = None,
5759
):
5860
self.params = params
5961
self.total_episodes = total_episodes
@@ -67,8 +69,29 @@ def __init__(
6769
self.agent_evolution_tournament = agent_evolution_tournament
6870
self.include_raw_metrics = include_raw_metrics
6971

70-
def _initialize(self, game):
71-
assert self.agent
72+
if wandb_run is None:
73+
self.run = wandb.init(
74+
project=self.params.project,
75+
job_type="train",
76+
id=self.params.run_id(),
77+
group=f"{self.params.run_id()}",
78+
notes=self.params.notes,
79+
)
80+
else:
81+
self.run = wandb_run
82+
Timer.set_wandb_run(self.run)
83+
84+
self._initialized = False
85+
86+
def _initialize(
87+
self,
88+
game,
89+
):
90+
if self._initialized:
91+
return
92+
93+
assert self.agent is not None
94+
assert self.run is not None
7295

7396
config = {
7497
"board_size": game.board_size,
@@ -78,23 +101,18 @@ def _initialize(self, game):
78101
}
79102
config.update(self.agent.model_hyperparameters())
80103

81-
self.run = wandb.init(
82-
project=self.params.project,
83-
job_type="train",
84-
config=config,
85-
tags=[self.agent.model_id(), f"-{self.params.run_id()}"],
86-
id=self.params.run_id(),
87-
group=f"{self.params.run_id()}",
88-
notes=self.params.notes,
89-
)
90-
Timer.set_wandb_run(self.run)
104+
self.run.config.update(config)
105+
assert self.run.tags is not None
106+
self.run.tags = self.run.tags + (self.agent.model_id(), f"-{self.params.run_id()}")
91107

92108
wandb.define_metric("Loss step", hidden=True)
93109
wandb.define_metric("Epoch", hidden=True)
94110
wandb.define_metric("Episode", hidden=True)
95111
wandb.define_metric("loss_*", "Loss step")
96112
wandb.define_metric("*", "Episode")
97113

114+
self._initialized = True
115+
98116
def start_game(self, game, agent1, agent2):
99117
if (self.agent is not None) and (self.agent != agent1) and (self.agent != agent2):
100118
raise ValueError("WandbTrainPlugin being used for an agent, but another agent is being trained")

deep_quoridor/src/train_alphazero.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88
import tempfile
99
import time
1010
from dataclasses import asdict
11+
from pathlib import Path
1112
from typing import BinaryIO, Optional, cast
1213

14+
import wandb
15+
import yaml
1316
from agent_evolution_tournament import AgentEvolutionTournament, AgentEvolutionTournamentParams
1417
from agents.alphazero.alphazero import AlphaZeroAgent, AlphaZeroBenchmarkOverrideParams, AlphaZeroParams
1518
from agents.alphazero.self_play_manager import GameParams, SelfPlayManager
@@ -22,12 +25,13 @@
2225

2326
def train_alphazero(
2427
args: argparse.Namespace,
28+
alphazero_params: str,
2529
wandb_train_plugin: Optional[WandbTrainPlugin],
2630
):
2731
game_params = GameParams(args.board_size, args.max_walls, args.max_steps)
2832

2933
# Create an agent that we'll use to do training.
30-
training_params = parse_subargs(args.params, AlphaZeroParams)
34+
training_params = parse_subargs(alphazero_params, AlphaZeroParams)
3135
assert isinstance(training_params, AlphaZeroParams)
3236
training_params.training_mode = True # We only use this agent for training
3337
training_params.train_every = None # We manually run training at the end of each epoch
@@ -74,7 +78,8 @@ def train_alphazero(
7478
if wandb_train_plugin is not None:
7579
# HACK: the start_game method only cares that "game" has board_size and max_walls
7680
# members, so we pass in a GameParams object. We have to call start_game
77-
# because it calls the plugin's internal _initialize method which sets up metrics.
81+
# because it calls the plugin's internal _initialize method which sets up metrics
82+
# and creates the WandB run.
7883
wandb_train_plugin.start_game(game=args, agent1=training_agent, agent2=training_agent)
7984
training_agent.set_wandb_run(wandb_train_plugin.run)
8085
# Compute the tournament metrics with the initial model, possibly random initialized, to
@@ -168,25 +173,38 @@ def main(args):
168173

169174
set_deterministic(args.seed)
170175

171-
t0 = time.time()
176+
alphazero_params = args.params
172177

173178
if args.wandb is None:
174179
wandb_train_plugin = None
175180
else:
181+
# Create the benchmarks and evolution tournament, and then create the WandB training plugin
176182
wandb_params = parse_subargs(args.wandb, WandbParams)
177183
assert isinstance(wandb_params, WandbParams)
178184

185+
wandb_run = None
186+
if args.sweep is not None:
187+
# Instead of letting the wandb plugin start its own run, we create one from the sweep config
188+
# and later pass it in to WandbTrainPlugin.start_game().
189+
with open(args.sweep) as sweep_config_file:
190+
sweep_config = yaml.load(sweep_config_file, Loader=yaml.FullLoader)
191+
192+
wandb_run = wandb.init(config=sweep_config)
193+
194+
# Apply the sweep params on top of the command line params
195+
alphazero_params = override_subargs(alphazero_params, wandb_run.config["alphazero"])
196+
179197
metrics = Metrics(
180198
args.board_size, args.max_walls, args.benchmarks, args.benchmarks_t, args.max_steps, args.num_workers
181199
)
182-
agent_encoded_name = "alphazero:" + args.params
183200

201+
benchmark_params = alphazero_params
184202
if args.benchmarks_params:
185-
benchmarks_params = parse_subargs(args.benchmarks_params, AlphaZeroBenchmarkOverrideParams)
186-
assert isinstance(benchmarks_params, AlphaZeroBenchmarkOverrideParams)
187-
override_args = {k: v for k, v in asdict(benchmarks_params).items() if v is not None}
203+
benchmark_param_overrides = parse_subargs(args.benchmarks_params, AlphaZeroBenchmarkOverrideParams)
204+
assert isinstance(benchmark_param_overrides, AlphaZeroBenchmarkOverrideParams)
205+
override_args = {k: v for k, v in asdict(benchmark_param_overrides).items() if v is not None}
188206

189-
agent_encoded_name = "alphazero:" + override_subargs(args.params, override_args)
207+
benchmark_params = override_subargs(benchmark_params, override_args)
190208

191209
if args.agent_evolution is not None:
192210
agent_evolution_params = parse_subargs(args.agent_evolution, AgentEvolutionTournamentParams)
@@ -204,15 +222,16 @@ def main(args):
204222
wandb_train_plugin = WandbTrainPlugin(
205223
wandb_params,
206224
args.epochs * args.games_per_epoch,
207-
agent_encoded_name,
225+
"alphazero:" + benchmark_params,
208226
metrics,
209227
agent_evolution_tournament,
210228
include_raw_metrics=True,
229+
wandb_run=wandb_run,
211230
)
212231

213232
t0 = time.time()
214233

215-
train_alphazero(args, wandb_train_plugin=wandb_train_plugin)
234+
train_alphazero(args, alphazero_params, wandb_train_plugin)
216235

217236
t1 = time.time()
218237

@@ -292,14 +311,16 @@ def main(args):
292311
type=str,
293312
help="Parameters for the Agent Evolution Tournament",
294313
)
314+
parser.add_argument(
315+
"--sweep",
316+
default=None,
317+
type=Path,
318+
help="Path to WandB sweep config yaml file",
319+
)
320+
295321
args = parser.parse_args()
296322

297-
# Handle deprecated --max-game-length argument
298-
if args.max_game_length is not None:
299-
if args.max_steps != parser.get_default("max-steps"): # Check if --max-steps was also provided (not default)
300-
print("Warning: Both --max-game-length and --max-steps provided. Using --max-steps value.")
301-
else:
302-
print("Warning: --max-game-length is deprecated. Please use --max-steps instead.")
303-
args.max_steps = args.max_game_length
323+
if args.sweep is not None and args.wandb is None:
324+
print("Enabling WandB since we're doing a sweep")
304325

305326
main(args)

0 commit comments

Comments
 (0)