Skip to content

Commit

Permalink
[API] Fix top_k_top_p_sampling_from_logits param typo (#875)
Browse files Browse the repository at this point in the history
Resolves #873
  • Loading branch information
kasohrab authored Feb 18, 2025
1 parent 78dde79 commit 68a0378
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions flashinfer/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def min_p_sampling_from_probs(


def top_k_top_p_sampling_from_logits(
probs: torch.Tensor,
logits: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
Expand Down Expand Up @@ -798,13 +798,13 @@ def top_k_top_p_sampling_from_logits(
top_p_sampling_from_probs
"""
if filter_apply_order == "top_k_first":
masked_logits = top_k_mask_logits(probs, top_k)
masked_logits = top_k_mask_logits(logits, top_k)
probs = torch.softmax(masked_logits, dim=-1)
return top_p_sampling_from_probs(
probs, uniform_samples, top_p, deterministic, check_nan=check_nan
)
elif filter_apply_order == "joint":
probs = torch.softmax(probs, dim=-1)
probs = torch.softmax(logits, dim=-1)
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
Expand Down

0 comments on commit 68a0378

Please sign in to comment.