@@ -279,6 +279,8 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
 
         truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool, device=self.device)
 
+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+
         if template.padding_free:
             # In padding_free mode, labels shape is [1, total_seq_len] (rmpad format)
             # Calculate seq_lengths from cu_seq_lens or position_ids
@@ -290,7 +292,7 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
             max_seq_len = seq_lengths.max().item()
 
             # completion_mask in rmpad format [1, total_tokens]
-            completion_mask_rmpad = (labels != -100).float()
+            completion_mask_rmpad = (rolled_labels != -100).float()
             completion_mask, _ = pad_logps_back_to_batch(
                 logps_rmpad=completion_mask_rmpad,
                 logits_to_keep=max_seq_len,
@@ -312,8 +314,8 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
             seq_lengths = torch.full((batch_size, ), labels.shape[-1], dtype=torch.int64, device=self.device)
             max_seq_len = labels.shape[-1]
 
-            # completion_mask is already [batch_size, seq_len] in non-padding_free mode
-            completion_mask = (labels != -100)
+            # completion_mask based on rolled labels for alignment with per_token_logps
+            completion_mask = (rolled_labels != -100)
 
         encoded_batch.update({
             'completion_mask': completion_mask,  # [batch_size, max_seq_len]
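Why the mask moves from `labels` to `rolled_labels`: the per-token log-probs this mask will be multiplied with are next-token scores, i.e. position `t` scores the token at `t + 1`, so the `-100` mask must be shifted left by one to line up. A minimal illustration of the shift (toy tensor, not the trainer's real batch):

```python
import torch

# -100 marks prompt/padding positions; the rest are completion token ids
labels = torch.tensor([[-100, -100, 11, 12, 13]])

# per-token logps at position t score the token at t + 1, so shift left by one
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
completion_mask = (rolled_labels != -100)

print(rolled_labels)    # tensor([[-100,   11,   12,   13, -100]])
print(completion_mask)  # tensor([[False,  True,  True,  True, False]])
```

Note that `torch.roll` wraps the first element around to the last position; since the first label belongs to the prompt and is `-100`, the wrap-around harmlessly masks the final position, which has no next token to score.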
@@ -934,34 +936,31 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
 
         inputs = self._prepare_model_inputs(batch)
         if self.beta != 0.0:
-            with torch.no_grad(), self.null_ref_context() as ref_models:
+            with self.null_ref_context() as ref_models:
                 assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
                 ref_model = ref_models[0]
-                ref_per_token_logps_raw = self.model_forward(
-                    ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
+                ref_per_token_logps_packed = self.compute_per_token_logps(
+                    ref_model, iter([deepcopy(inputs)]), temperature=self.temperature)
             if self.template.padding_free:
-                # In padding_free mode, logps are in rmpad format [1, total_tokens]
-                # Pad to batch format [batch_size, max_seq_len]
                 ref_per_token_logps, _ = pad_logps_back_to_batch(
-                    logps_rmpad=ref_per_token_logps_raw,
+                    logps_rmpad=ref_per_token_logps_packed,
                     logits_to_keep=max_seq_len,
                     batch_size=batch_size,
                     seq_lengths=seq_lengths)
             else:
-                # In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
-                ref_per_token_logps = ref_per_token_logps_raw
+                ref_per_token_logps = ref_per_token_logps_packed
             batch['ref_per_token_logps'] = ref_per_token_logps
 
-        old_per_token_logps_raw = self.model_forward(
-            self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
+        old_per_token_logps_packed = self.compute_per_token_logps(
+            self.unwrapped_models[0], iter([deepcopy(inputs)]), temperature=self.temperature)
         if self.template.padding_free:
             old_per_token_logps, _ = pad_logps_back_to_batch(
-                logps_rmpad=old_per_token_logps_raw,
+                logps_rmpad=old_per_token_logps_packed,
                 logits_to_keep=max_seq_len,
                 batch_size=batch_size,
                 seq_lengths=seq_lengths)
         else:
-            old_per_token_logps = old_per_token_logps_raw
+            old_per_token_logps = old_per_token_logps_packed
         batch['old_per_token_logps'] = old_per_token_logps
 
         return batch
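Both the reference-model and old-policy passes now route through `compute_per_token_logps`, which applies the sampling `temperature` before the log-softmax so the stored logps match the distribution the completions were actually sampled from. A minimal sketch of the core computation, as a hypothetical standalone helper (the real method additionally handles Megatron batching, the no-grad context, and packed sequences):

```python
import torch
import torch.nn.functional as F

def per_token_logps_from_logits(logits: torch.Tensor, labels: torch.Tensor,
                                temperature: float = 1.0) -> torch.Tensor:
    """Log-prob of each label token under temperature-scaled logits.

    logits: [batch, seq, vocab]; labels: [batch, seq], -100 = ignored position.
    """
    if temperature != 1.0:
        logits = logits / temperature
    logps = F.log_softmax(logits, dim=-1)
    safe_labels = labels.clamp(min=0)  # -100 would crash gather
    token_logps = logps.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    return token_logps.masked_fill(labels == -100, 0.0)
```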
@@ -1052,69 +1051,46 @@ def forward_step(self, data_iterator, model):
 
         # Check if this is the PP last stage (only last stage has labels and computes loss)
         is_pp_last_stage = mpu.is_pipeline_last_stage()
-
-        if self.compute_entropy:
-            # Forward without labels to get logits, then compute logps and entropy
-            inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'}
-            output_tensor = model(**inputs_for_logits)
-
-            # Compute per_token_logps and per_token_entropy from logits on PP last stage
-            if is_pp_last_stage and output_tensor is not None:
-                # output_tensor is logits [batch/1, seq, partition_vocab_size]
-                per_token_logps_raw, per_token_entropy_raw = compute_logps_and_entropy_from_logits(
-                    output_tensor, labels, compute_entropy=True)
-
-                # In CP mode, all_gather and reconstruct full sequence
-                if args.context_parallel_size > 1:
-                    num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size
-                    per_token_logps_raw = self._postprocess_packed_tensor_cp(per_token_logps_raw, packed_seq_params,
-                                                                             num_samples)
-                    per_token_entropy_raw = self._postprocess_packed_tensor_cp(per_token_entropy_raw, packed_seq_params,
-                                                                               num_samples)
-
-                if args.padding_free:
-                    # Pad from rmpad [1, total_tokens] to batch format [batch_size, max_seq_len]
-                    per_token_logps, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_logps_raw,
-                        logits_to_keep=max_seq_len,
-                        batch_size=micro_batch_size,
-                        seq_lengths=seq_lengths)
+        inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'}
+        logits_packed = model(**inputs_for_logits)
+        output_tensor = None
+        if is_pp_last_stage and logits_packed is not None:
+            if self.temperature != 1.0:
+                logits_packed.div_(self.temperature)
+            per_token_logps_packed, per_token_entropy_packed = compute_logps_and_entropy_from_logits(
+                logits_packed, labels, compute_entropy=self.compute_entropy)
+
+            # In CP mode, all_gather and reconstruct full sequence
+            if args.context_parallel_size > 1:
+                num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size
+                per_token_logps_packed = self._postprocess_packed_tensor_cp(per_token_logps_packed, packed_seq_params,
+                                                                            num_samples)
+                if per_token_entropy_packed is not None:
+                    per_token_entropy_packed = self._postprocess_packed_tensor_cp(per_token_entropy_packed,
+                                                                                  packed_seq_params, num_samples)
+
+            if args.padding_free:
+                # Pad from rmpad [1, total_tokens] to batch format [batch_size, max_seq_len]
+                per_token_logps, _ = pad_logps_back_to_batch(
+                    logps_rmpad=per_token_logps_packed,
+                    logits_to_keep=max_seq_len,
+                    batch_size=micro_batch_size,
+                    seq_lengths=seq_lengths)
+                if per_token_entropy_packed is not None:
                     per_token_entropy, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_entropy_raw,
+                        logps_rmpad=per_token_entropy_packed,
                         logits_to_keep=max_seq_len,
                         batch_size=micro_batch_size,
                         seq_lengths=seq_lengths,
                         pad_value=float('nan'))
                 else:
-                    per_token_logps = per_token_logps_raw
-                    per_token_entropy = per_token_entropy_raw
-
-                data['per_token_logps'] = per_token_logps
-                data['per_token_entropy'] = per_token_entropy
-        else:
-            # Standard forward with labels, returns per-token loss (more efficient)
-            output_tensor = model(**inputs)
-
-            # Convert output_tensor (per-token loss) to per_token_logps on PP last stage
-            if is_pp_last_stage and output_tensor is not None:
-                per_token_logps_raw = self.get_logps(
-                    output_tensor,
-                    labels,
-                    packed_seq_params,
-                    packed_seq_params.num_samples if args.padding_free else micro_batch_size,
-                    per_token=True)
-
-                if args.padding_free:
-                    per_token_logps, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_logps_raw,
-                        logits_to_keep=max_seq_len,
-                        batch_size=micro_batch_size,
-                        seq_lengths=seq_lengths)
-                else:
-                    per_token_logps = per_token_logps_raw
+                    per_token_entropy = None
+            else:
+                per_token_logps = per_token_logps_packed
+                per_token_entropy = per_token_entropy_packed
 
-            data['per_token_logps'] = per_token_logps
-            data['per_token_entropy'] = None
+            output_tensor = per_token_logps
+            data['per_token_entropy'] = per_token_entropy
 
         return output_tensor, partial(self.loss_func, data=data)
 
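Every padding_free branch above finishes with `pad_logps_back_to_batch`, which scatters the packed `[1, total_tokens]` tensor back into `[batch_size, max_seq_len]`. A simplified sketch of that scatter under an assumed signature (the project's actual helper may differ in details such as alignment or vectorization):

```python
import torch

def pad_back_to_batch_sketch(logps_rmpad: torch.Tensor, logits_to_keep: int,
                             batch_size: int, seq_lengths: torch.Tensor,
                             pad_value: float = 0.0) -> torch.Tensor:
    """Unpack a packed [1, total_tokens] tensor into [batch_size, logits_to_keep]."""
    out = torch.full((batch_size, logits_to_keep), pad_value,
                     dtype=logps_rmpad.dtype, device=logps_rmpad.device)
    offset = 0
    for i, length in enumerate(seq_lengths.tolist()):
        out[i, :length] = logps_rmpad[0, offset:offset + length]
        offset += length
    return out

packed = torch.arange(5, dtype=torch.float).unsqueeze(0)  # two sequences packed: [0, 1] and [2, 3, 4]
print(pad_back_to_batch_sketch(packed, logits_to_keep=3, batch_size=2,
                               seq_lengths=torch.tensor([2, 3])))
# tensor([[0., 1., 0.],
#         [2., 3., 4.]])
```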
@@ -1129,7 +1105,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
 
         # Get pre-computed per_token_logps and per_token_entropy from forward_step
         # These are already in batch format [batch_size, max_seq_len]
-        per_token_logps = data.get('per_token_logps')
+        per_token_logps = output_tensor
         per_token_entropy = data.get('per_token_entropy')
 
         # Get pre-padded ref/old/rollout logps from data
@@ -1409,38 +1385,6 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
 
         return loss, reporting_metric
 
-    def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
-        """Forward pass through model to compute logps.
-
-        Args:
-            model: The model to forward
-            data_iterator: Iterator providing batch data
-            no_grad: Whether to use torch.no_grad() context
-            per_token: Whether to return per-token logps
-
-        Returns:
-            data dict containing 'logps'
-        """
-        # used to calculate model forward (logps) in GRPO
-        data = self.get_batch(data_iterator)
-        data.pop('loss_scale', None)
-        input_ids = data.get('input_ids')
-        labels = data.get('labels')
-        context = torch.no_grad() if no_grad else nullcontext()
-
-        with context:
-            output_tensor = forward_step_helper(self.args, model, data)
-
-        # packed_seq_params only exists in padding_free mode
-        packed_seq_params = data.get('packed_seq_params')
-        if packed_seq_params is not None:
-            num_samples = packed_seq_params.num_samples
-        else:
-            num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
-        data['logps'] = None if labels is None else self.get_logps(
-            output_tensor, labels, packed_seq_params, num_samples, per_token=per_token)
-        return data
-
     def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]:
         """Convert raw input data into RolloutInferRequest objects"""
 