88import tempfile
99import time
1010from dataclasses import asdict
11+ from pathlib import Path
1112from typing import BinaryIO , Optional , cast
1213
14+ import wandb
15+ import yaml
1316from agent_evolution_tournament import AgentEvolutionTournament , AgentEvolutionTournamentParams
1417from agents .alphazero .alphazero import AlphaZeroAgent , AlphaZeroBenchmarkOverrideParams , AlphaZeroParams
1518from agents .alphazero .self_play_manager import GameParams , SelfPlayManager
2225
2326def train_alphazero (
2427 args : argparse .Namespace ,
28+ alphazero_params : str ,
2529 wandb_train_plugin : Optional [WandbTrainPlugin ],
2630):
2731 game_params = GameParams (args .board_size , args .max_walls , args .max_steps )
2832
2933 # Create an agent that we'll use to do training.
30- training_params = parse_subargs (args . params , AlphaZeroParams )
34+ training_params = parse_subargs (alphazero_params , AlphaZeroParams )
3135 assert isinstance (training_params , AlphaZeroParams )
3236 training_params .training_mode = True # We only use this agent for training
3337 training_params .train_every = None # We manually run training at the end of each epoch
@@ -74,7 +78,8 @@ def train_alphazero(
7478 if wandb_train_plugin is not None :
7579 # HACK: the start_game method only cares that "game" has board_size and max_walls
7680 # members, so we pass in a GameParams object. We have to call start_game
77- # because it calls the plugin's internal _intialize method which sets up metrics.
81+ # because it calls the plugin's internal _intialize method which sets up metrics
82+ # and creates the WandB run.
7883 wandb_train_plugin .start_game (game = args , agent1 = training_agent , agent2 = training_agent )
7984 training_agent .set_wandb_run (wandb_train_plugin .run )
8085 # Compute the tournament metrics with the initial model, possibly random initialized, to
@@ -168,25 +173,38 @@ def main(args):
168173
169174 set_deterministic (args .seed )
170175
171- t0 = time . time ()
176+ alphazero_params = args . params
172177
173178 if args .wandb is None :
174179 wandb_train_plugin = None
175180 else :
181+ # Create the benchmarks and evolution tournament, and then create the WandB training plugin
176182 wandb_params = parse_subargs (args .wandb , WandbParams )
177183 assert isinstance (wandb_params , WandbParams )
178184
185+ wandb_run = None
186+ if args .sweep is not None :
187+ # Instead of letting the wandb plugin start its own run, we create one from the sweep config
188+ # and later pass it in to WandbTrainPlugin.start_game().
189+ with open (args .sweep ) as sweep_config_file :
190+ sweep_config = yaml .load (sweep_config_file , Loader = yaml .FullLoader )
191+
192+ wandb_run = wandb .init (config = sweep_config )
193+
194+ # Apply the sweep params on top of the command line params
195+ alphazero_params = override_subargs (alphazero_params , wandb_run .config ["alphazero" ])
196+
179197 metrics = Metrics (
180198 args .board_size , args .max_walls , args .benchmarks , args .benchmarks_t , args .max_steps , args .num_workers
181199 )
182- agent_encoded_name = "alphazero:" + args .params
183200
201+ benchmark_params = alphazero_params
184202 if args .benchmarks_params :
185- benchmarks_params = parse_subargs (args .benchmarks_params , AlphaZeroBenchmarkOverrideParams )
186- assert isinstance (benchmarks_params , AlphaZeroBenchmarkOverrideParams )
187- override_args = {k : v for k , v in asdict (benchmarks_params ).items () if v is not None }
203+ benchmark_param_overrides = parse_subargs (args .benchmarks_params , AlphaZeroBenchmarkOverrideParams )
204+ assert isinstance (benchmark_param_overrides , AlphaZeroBenchmarkOverrideParams )
205+ override_args = {k : v for k , v in asdict (benchmark_param_overrides ).items () if v is not None }
188206
189- agent_encoded_name = "alphazero:" + override_subargs (args . params , override_args )
207+ benchmark_params = override_subargs (benchmark_params , override_args )
190208
191209 if args .agent_evolution is not None :
192210 agent_evolution_params = parse_subargs (args .agent_evolution , AgentEvolutionTournamentParams )
@@ -204,15 +222,16 @@ def main(args):
204222 wandb_train_plugin = WandbTrainPlugin (
205223 wandb_params ,
206224 args .epochs * args .games_per_epoch ,
207- agent_encoded_name ,
225+ "alphazero:" + benchmark_params ,
208226 metrics ,
209227 agent_evolution_tournament ,
210228 include_raw_metrics = True ,
229+ wandb_run = wandb_run ,
211230 )
212231
213232 t0 = time .time ()
214233
215- train_alphazero (args , wandb_train_plugin = wandb_train_plugin )
234+ train_alphazero (args , alphazero_params , wandb_train_plugin )
216235
217236 t1 = time .time ()
218237
@@ -292,14 +311,16 @@ def main(args):
292311 type = str ,
293312 help = "Parameters for the Agent Evolution Tournament" ,
294313 )
314+ parser .add_argument (
315+ "--sweep" ,
316+ default = None ,
317+ type = Path ,
318+ help = "Path to WandB sweep config yaml file" ,
319+ )
320+
295321 args = parser .parse_args ()
296322
297- # Handle deprecated --max-game-length argument
298- if args .max_game_length is not None :
299- if args .max_steps != parser .get_default ("max-steps" ): # Check if --max-steps was also provided (not default)
300- print ("Warning: Both --max-game-length and --max-steps provided. Using --max-steps value." )
301- else :
302- print ("Warning: --max-game-length is deprecated. Please use --max-steps instead." )
303- args .max_steps = args .max_game_length
323+ if args .sweep is not None and args .wandb is None :
324+ print ("Enabling WandB since we're doing a sweep" )
304325
305326 main (args )
0 commit comments