Skip to content

Commit 4a1adb6

Browse files
authored
Fix: log-prob computation did not handle temperature (THUDM#557)
1 parent 79fd101 commit 4a1adb6

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def compute_log_prob(
230230
model_args["pixel_values"] = batch["pixel_values"]
231231
logits = self.model(**model_args).logits
232232
batch[f"{store_prefix}log_probs"] = gather_log_probs_packed(
233-
logits, batch["tokens"], self.args.rollout_temperature
233+
logits, batch["tokens"], temperature=self.args.rollout_temperature
234234
)
235235
return rollout_data
236236

@@ -386,7 +386,9 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
386386
).logits
387387

388388
# Handle packed sequences
389-
log_probs = gather_log_probs_packed(logits, packed_batch["tokens"], packed_batch["cu_seqlens"])
389+
log_probs = gather_log_probs_packed(
390+
logits, packed_batch["tokens"], packed_batch["cu_seqlens"], temperature=self.args.rollout_temperature
391+
)
390392
packed_batch["cur_log_probs"] = log_probs
391393
unpacked_batches = unpack_sequences(packed_batch)
392394

@@ -655,7 +657,10 @@ def gather_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, rollout_temp
655657

656658

657659
def gather_log_probs_packed(
658-
logits: torch.Tensor, input_ids: torch.Tensor, cu_seqlens: torch.Tensor | float | None = None
660+
logits: torch.Tensor,
661+
input_ids: torch.Tensor,
662+
cu_seqlens: torch.Tensor | float | None = None,
663+
temperature: torch.Tensor | None = None,
659664
) -> torch.Tensor:
660665
"""Gather next-token log probabilities for packed sequences.
661666
@@ -674,6 +679,9 @@ def gather_log_probs_packed(
674679
logits = logits.squeeze(0)
675680
input_ids = input_ids.squeeze(0)
676681

682+
if temperature is not None:
683+
logits = logits.div(temperature)
684+
677685
# Shift for next-token prediction: logits[:-1] predicts input_ids[1:]
678686
log_probs = torch.log_softmax(logits[:-1], dim=-1)
679687
targets = input_ids[1:].to(device=log_probs.device)

0 commit comments

Comments (0)