44
55import numpy as np
66import wandb
7+ from pydantic_yaml import parse_yaml_file_as
8+ from utils import Timer
79from v2 .common import MockWandb , create_alphazero
810from v2 .config import Config
9- from v2 .yaml_models import LatestModel
11+ from v2 .yaml_models import GameInfo , LatestModel
1012
1113
1214def train (config : Config ):
13- global azparams
1415 batch_size = config .training .batch_size
15- training_iterations = 1
16- min_new_games = 25
1716
1817 if config .wandb :
1918 run_id = f"{ config .run_id } -training"
@@ -38,19 +37,18 @@ def train(config: Config):
3837 alphazero_agent .save_model (filename )
3938 LatestModel .write (config , str (filename ), 0 )
4039
40+ training_steps = 0
4141 last_game = 0
4242 model_version = 1
4343 moves_per_game = []
4444 game_filename = []
4545
4646 while True :
47- while True :
48- ready = [f for f in sorted (config .paths .replay_buffers_ready .glob ("*.pkl" )) if f .is_file ()]
49- if len (ready ) >= min_new_games :
50- break
51- time .sleep (1 )
47+ Timer .start ("waiting-to-train" )
48+
49+ # Process new games: find new files, move them and extract the info used for training
50+ ready = [f for f in sorted (config .paths .replay_buffers_ready .glob ("*.pkl" )) if f .is_file ()]
5251
53- # Process new games
5452 for f in ready :
5553 last_game += 1
5654
@@ -59,44 +57,67 @@ def train(config: Config):
5957 yaml_file = f .with_suffix (".yaml" )
6058 new_yaml_name = new_name .with_suffix (".yaml" )
6159 yaml_file .rename (new_yaml_name )
60+ game_info = parse_yaml_file_as (GameInfo , new_yaml_name )
6261
6362 f .rename (new_name )
6463 with open (new_name , "rb" ) as f :
6564 data = pickle .load (f )
66- game_length = len (list (data ))
67- moves_per_game .append (game_length )
65+ moves_per_game .append (game_info .game_length )
6866 game_filename .append (f .name )
69- wandb_run .log ({"game_length" : game_length , "Game num" : last_game , "Model version" : model_version })
67+ wandb_run .log (
68+ {
69+ "game_length" : game_info .game_length ,
70+ "model_lag" : model_version - 1 - game_info .model_version ,
71+ "Game num" : last_game ,
72+ "Model version" : model_version ,
73+ }
74+ )
7075
7176 total_moves = sum (moves_per_game )
72- if total_moves < batch_size :
73- continue
7477
75- t0 = time .time ()
76- for _ in range (training_iterations ):
77- # Sample
78- # TO DO, we need to roll out games when it's longer that the replay buffer size
79- # TO DO probably we want to sample for all the training iterations together to make it faster
80- samples = []
78+ games_needed_to_train = config .training .games_per_training_step * (training_steps + 1 )
79+
80+ if total_moves < batch_size or games_needed_to_train > last_game :
81+ time .sleep (1 )
82+ continue
8183
82- games = np .random .choice (last_game , batch_size , p = [moves / total_moves for moves in moves_per_game ])
83- samples_per_game = Counter (games )
84- for game_number in samples_per_game :
85- file = config .paths .replay_buffers / game_filename [game_number ]
86- with open (file , "rb" ) as f :
87- data = pickle .load (f )
84+ time_waiting_to_train = Timer .finish ("waiting-to-train" )
8885
89- samples .extend (np .random .choice (list (data ), samples_per_game [game_number ]))
86+ # Sample moves from the replay buffer files
87+ Timer .start ("sample" )
88+ samples = []
9089
91- # print(f"{game_number}: {samples_per_game[game_number]}, {len(entries)}")
90+ games = np .random .choice (last_game , batch_size , p = [moves / total_moves for moves in moves_per_game ])
91+ samples_per_game = Counter (games )
92+ for game_number in samples_per_game :
93+ file = config .paths .replay_buffers / game_filename [game_number ]
94+ with open (file , "rb" ) as f :
95+ data = pickle .load (f )
9296
93- # Train
94- loss = alphazero_agent .evaluator .train_iteration_v2 (samples )
95- wandb_run .log ({"loss" : loss , "games_played" : last_game , "Model version" : model_version }, commit = True )
97+ samples .extend (np .random .choice (list (data ), samples_per_game [game_number ]))
98+ time_sample = Timer .finish ("sample" )
99+
100+ # Train the network for one step using the samples
101+ Timer .start ("train" )
102+ policy_loss , value_loss , total_loss = alphazero_agent .evaluator .train_iteration_v2 (samples )
103+ training_steps += 1
104+ time_train = Timer .finish ("train" )
105+
106+ wandb_run .log (
107+ {
108+ "policy_loss" : policy_loss ,
109+ "value_loss" : value_loss ,
110+ "total_loss" : total_loss ,
111+ "games_played" : last_game ,
112+ "time-sample" : time_sample ,
113+ "time-train" : time_train ,
114+ "time-waiting-to-train" : time_waiting_to_train ,
115+ "Model version" : model_version ,
116+ },
117+ commit = True ,
118+ )
96119
97- print (f"Loss: { loss } " )
98- t1 = time .time ()
99- print (f"Sampling and training took { t1 - t0 } " )
120+ print (f"Sampling and training took { time_sample } , { time_train } " )
100121
101122 new_model_filename = config .paths .checkpoints / f"model_{ model_version } .pt"
102123 alphazero_agent .save_model (new_model_filename )
0 commit comments