@@ -40,14 +40,12 @@ def init(self, args, role, wandb_run_id, with_ref: bool = False): # type: ignor
4040 self .args = args
4141 torch .manual_seed (args .seed )
4242
43- # Serialize tokenizer/config loading across ranks to avoid HF cache race
4443 for i in range (dist .get_world_size ()):
4544 if i == dist .get_rank ():
4645 self .hf_config = AutoConfig .from_pretrained (self .args .hf_checkpoint , trust_remote_code = True )
4746 self .tokenizer = AutoTokenizer .from_pretrained (self .args .hf_checkpoint , trust_remote_code = True )
4847 dist .barrier (group = get_gloo_group ())
4948
50- # Load model
5149 with torch .device (f"cuda:{ torch .cuda .current_device ()} " ):
5250 model = AutoModelForCausalLM .from_pretrained (
5351 self .args .hf_checkpoint ,
@@ -141,7 +139,6 @@ def compute_log_prob(
141139 padded_batches: Input batches
142140 store_prefix: Prefix for storing results (e.g., "ref_")
143141 """
144- # Save current model parameters if switching to different model
145142 current_params = None
146143 if model_tag != "actor" and model_tag in self .weights :
147144 current_params = {}
@@ -150,9 +147,8 @@ def compute_log_prob(
150147 for name , param in current_state_dict .items ():
151148 current_params [name ] = param .clone ()
152149
153- # Load the specified model parameters
154150 self .update_gpu_params_dict (self .weights [model_tag ])
155- self .model .eval () # Set to eval mode for ref model
151+ self .model .eval ()
156152
157153 try :
158154 rollout_data = {f"{ store_prefix } log_probs" : []}
@@ -164,11 +160,10 @@ def compute_log_prob(
164160 return rollout_data
165161
166162 finally :
167- # Restore original model parameters if we switched
168163 if current_params is not None :
169164 with FSDP .state_dict_type (self .model , StateDictType .FULL_STATE_DICT ):
170165 self .model .load_state_dict (current_params , strict = True )
171- self .model .train () # Restore training mode
166+ self .model .train ()
172167 torch .cuda .synchronize ()
173168
174169 def pad_and_move_to_device (self , rollout_data ):
@@ -324,11 +319,9 @@ def train(self, rollout_id, rollout_data_ref): # type: ignore[override]
324319 reported ["kl_loss" ] = kl_loss .detach ()
325320 reported ["kl_loss_coef" ] = torch .tensor (self .args .kl_loss_coef , device = kl_loss .device )
326321
327- # Scale loss for gradient accumulation
328322 loss = loss / grad_accum
329323 loss .backward ()
330324
331- # Accumulate reported metrics (store tensors for later mean)
332325 for k , v in reported .items ():
333326 reported_accum .setdefault (k , []).append (v )
334327
@@ -337,12 +330,10 @@ def train(self, rollout_id, rollout_data_ref): # type: ignore[override]
337330 grad_norm = torch .nn .utils .clip_grad_norm_ (self .model .parameters (), self .args .clip_grad )
338331 self .optimizer .step ()
339332 self .optimizer .zero_grad (set_to_none = True )
340- # Aggregate logs
341333 aggregated = {k : torch .stack (v ).mean ().item () for k , v in reported_accum .items ()}
342334 # TODO: change this, this is slow.
343335 reduced_aggregated = [None ] * world_size
344336 dist .all_gather_object (reduced_aggregated , aggregated )
345- # Mean across dp ranks
346337 aggregated = {}
347338 for k in reported_accum .keys ():
348339 aggregated [k ] = sum ([r [k ] for r in reduced_aggregated ]) / world_size
0 commit comments