@@ -230,7 +230,7 @@ def compute_log_prob(
230230 model_args ["pixel_values" ] = batch ["pixel_values" ]
231231 logits = self .model (** model_args ).logits
232232 batch [f"{ store_prefix } log_probs" ] = gather_log_probs_packed (
233- logits , batch ["tokens" ], self .args .rollout_temperature
233+ logits , batch ["tokens" ], temperature = self .args .rollout_temperature
234234 )
235235 return rollout_data
236236
@@ -386,7 +386,9 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
386386 ).logits
387387
388388 # Handle packed sequences
389- log_probs = gather_log_probs_packed (logits , packed_batch ["tokens" ], packed_batch ["cu_seqlens" ])
389+ log_probs = gather_log_probs_packed (
390+ logits , packed_batch ["tokens" ], packed_batch ["cu_seqlens" ], temperature = self .args .rollout_temperature
391+ )
390392 packed_batch ["cur_log_probs" ] = log_probs
391393 unpacked_batches = unpack_sequences (packed_batch )
392394
@@ -655,7 +657,10 @@ def gather_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, rollout_temp
655657
656658
657659def gather_log_probs_packed (
658- logits : torch .Tensor , input_ids : torch .Tensor , cu_seqlens : torch .Tensor | float | None = None
660+ logits : torch .Tensor ,
661+ input_ids : torch .Tensor ,
662+ cu_seqlens : torch .Tensor | float | None = None ,
663+ temperature : torch .Tensor | None = None ,
659664) -> torch .Tensor :
660665 """Gather next-token log probabilities for packed sequences.
661666
@@ -674,6 +679,9 @@ def gather_log_probs_packed(
674679 logits = logits .squeeze (0 )
675680 input_ids = input_ids .squeeze (0 )
676681
682+ if temperature is not None :
683+ logits = logits .div (temperature )
684+
677685 # Shift for next-token prediction: logits[:-1] predicts input_ids[1:]
678686 log_probs = torch .log_softmax (logits [:- 1 ], dim = - 1 )
679687 targets = input_ids [1 :].to (device = log_probs .device )
0 commit comments