@@ -133,13 +133,43 @@ def compute_log_prob(
133133 padded_batches ,
134134 store_prefix = "" ,
135135 ):
136- rollout_data = {f"{ store_prefix } log_probs" : []}
137- with timer (f"{ store_prefix } log_probs" ) and torch .no_grad ():
138- for batch in padded_batches :
139- with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
140- logits = self .model (input_ids = batch ["tokens" ]).logits
141- batch [f"{ store_prefix } log_probs" ] = gather_log_probs (logits , batch ["tokens" ])
142- return rollout_data
136+ """
137+ Compute log probabilities using the specified model.
138+
139+ Args:
140+ model_tag: "actor" for current model, "ref" for reference model
141+ padded_batches: Input batches
142+ store_prefix: Prefix for storing results (e.g., "ref_")
143+ """
144+ # Save current model parameters if switching to different model
145+ current_params = None
146+ if model_tag != "actor" and model_tag in self .weights :
147+ current_params = {}
148+ with FSDP .state_dict_type (self .model , StateDictType .FULL_STATE_DICT ):
149+ current_state_dict = self .model .state_dict ()
150+ for name , param in current_state_dict .items ():
151+ current_params [name ] = param .clone ()
152+
153+ # Load the specified model parameters
154+ self .update_gpu_params_dict (self .weights [model_tag ])
155+ self .model .eval () # Set to eval mode for ref model
156+
157+ try :
158+ rollout_data = {f"{ store_prefix } log_probs" : []}
159+ with timer (f"{ store_prefix } log_probs" ) and torch .no_grad ():
160+ for batch in padded_batches :
161+ with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
162+ logits = self .model (input_ids = batch ["tokens" ]).logits
163+ batch [f"{ store_prefix } log_probs" ] = gather_log_probs (logits , batch ["tokens" ])
164+ return rollout_data
165+
166+ finally :
167+ # Restore the original (actor) parameters if we temporarily switched models
168+ if current_params is not None :
169+ with FSDP .state_dict_type (self .model , StateDictType .FULL_STATE_DICT ):
170+ self .model .load_state_dict (current_params , strict = True )
171+ self .model .train () # Restore training mode
172+ torch .cuda .synchronize ()
143173
144174 def pad_and_move_to_device (self , rollout_data ):
145175 tokens = rollout_data ["tokens" ]
@@ -188,7 +218,7 @@ def train(self, rollout_id, rollout_data_ref): # type: ignore[override]
188218 ), f"Invalid grad_accum { grad_accum } for micro_batch_size { self .args .micro_batch_size } and global_batch_size { self .args .global_batch_size } "
189219
190220 if "ref" in self .weights :
191- self .compute_ref_log_probs ( padded_batches )
221+ self .compute_log_prob ( "ref" , padded_batches , store_prefix = "ref_" )
192222
193223 self .compute_log_prob ("actor" , padded_batches )
194224
@@ -347,8 +377,6 @@ def train(self, rollout_id, rollout_data_ref): # type: ignore[override]
347377
348378 self .update_cpu_params_dict (self .weights ["actor" ])
349379
350- self ._save_debug_train_data (rollout_id , rollout_data , padded_batches )
351-
352380 Timer ().start ("train_wait" )
353381 return
354382
@@ -430,71 +458,6 @@ def load_ref_model(self, ref_load_path):
430458 self .model .load_state_dict (current_weights , strict = True )
431459 torch .cuda .synchronize ()
432460
433- def compute_ref_log_probs (self , padded_batches ):
434- """
435- Compute log probabilities using reference model parameters.
436-
437- This method temporarily loads ref model parameters from CPU memory
438- (loaded once during initialization) to GPU, computes forward pass,
439- then restores original model parameters. No disk I/O involved.
440- """
441- if "ref" not in self .weights :
442- raise RuntimeError ("Reference model weights not loaded" )
443-
444- current_params = {}
445- with FSDP .state_dict_type (self .model , StateDictType .FULL_STATE_DICT ):
446- current_state_dict = self .model .state_dict ()
447- for name , param in current_state_dict .items ():
448- current_params [name ] = param .clone ()
449-
450- try :
451- self .update_gpu_params_dict (self .weights ["ref" ])
452- self .model .eval ()
453- for batch in padded_batches :
454- with torch .no_grad ():
455- with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
456- logits = self .model (input_ids = batch ["tokens" ]).logits
457- batch ["ref_log_probs" ] = gather_log_probs (logits , batch ["tokens" ])
458-
459- finally :
460- with FSDP .state_dict_type (self .model , StateDictType .FULL_STATE_DICT ):
461- self .model .load_state_dict (current_params , strict = True )
462- self .model .train ()
463- torch .cuda .synchronize ()
464-
465- def _log_debug_rollout_data (self , rollout_id , rollout_data ):
466- """Log rollout data for debugging (similar to Megatron backend)"""
467- print (f"Debug rollout { rollout_id } : logging rollout data" )
468-
469- def _save_debug_train_data (self , rollout_id , rollout_data , padded_batches ):
470- """Save debug train data if requested"""
471- from pathlib import Path
472-
473- if (path_template := getattr (self .args , 'save_debug_train_data' , None )) is not None :
474- rank = dist .get_rank ()
475- path = Path (path_template .format (rollout_id = rollout_id , rank = rank ))
476- print (f"Save debug train data to { path } " )
477- path .parent .mkdir (parents = True , exist_ok = True )
478-
479- debug_data = {
480- 'rollout_id' : rollout_id ,
481- 'rank' : rank ,
482- 'rollout_data' : rollout_data ,
483- 'batch_info' : []
484- }
485-
486- for i , batch in enumerate (padded_batches ):
487- batch_info = {}
488- for key , value in batch .items ():
489- if isinstance (value , torch .Tensor ):
490- batch_info [key ] = value .cpu ().detach ()
491- else :
492- batch_info [key ] = value
493- debug_data ['batch_info' ].append (batch_info )
494-
495- torch .save (debug_data , path )
496-
497-
498461def gather_log_probs (logits : torch .Tensor , input_ids : torch .Tensor ) -> torch .Tensor :
499462 # log_probs: [B, T-1, V]; input_ids: [B, T]
500463 pred_logits = logits [:, :- 1 ]
0 commit comments