Skip to content

Commit 4a99c90

Browse files
hzx55906zzzzzmeng
authored andcommitted
[BugFix] fix reduce_sampling (vllm-project#9545)
### What this PR does / why we need it? Fix the issue in reduce_sampling where enabling speculative sampling causes an error with a single curl request. vllm-project#8308 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.20.2 - vLLM main: vllm-project/vllm@1ac10f1 --------- Signed-off-by: hzx55906 <513464215@qq.com>
1 parent b478d13 commit 4a99c90

5 files changed

Lines changed: 13 additions & 11 deletions

File tree

vllm_ascend/ops/triton/reject_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def sample_recovered_tokens_kernel(
329329

330330
qv = tl.load(q_ptr + req_idx * C + offs, mask=mask, other=1.0).to(tl.float32)
331331

332-
bad_q = (qv <= 0) | tl.math.isinf(qv)
332+
bad_q = (qv <= 0) | (qv != qv) | (qv == float("inf")) | (qv == -float("inf"))
333333
score = tl.where(bad_q, float("-inf"), prob / qv)
334334
score = tl.where(mask, score, float("-inf"))
335335

vllm_ascend/patch/worker/patch_llama_eagle3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def compute_logits(
77
self,
88
hidden_states: torch.Tensor,
9-
enable_reduce_sample: bool = True,
9+
enable_reduce_sample: bool = False,
1010
) -> torch.Tensor | None:
1111
if enable_reduce_sample:
1212
logits = self.logits_processor(self.lm_head, hidden_states)

vllm_ascend/sample/rejection_sampler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def rejection_sample(
444444
global_vocab_size,
445445
batch_size,
446446
NO_DRAFT_PROBS=draft_probs is None,
447-
enable_reduce_sampling=True,
447+
ENABLE_REDUCE_SAMPLING=True,
448448
BLOCK_SIZE=block_size,
449449
)
450450
else:
@@ -482,7 +482,7 @@ def rejection_sample(
482482
global_vocab_size,
483483
batch_size,
484484
NO_DRAFT_PROBS=draft_probs is None,
485-
enable_reduce_sampling=True,
485+
ENABLE_REDUCE_SAMPLING=True,
486486
BLOCK_SIZE=block_size,
487487
)
488488
else:
@@ -553,7 +553,7 @@ def rejection_sample(
553553
vocab_size, # global_vocab_size
554554
batch_size,
555555
NO_DRAFT_PROBS=draft_probs is None,
556-
enable_reduce_sampling=False,
556+
ENABLE_REDUCE_SAMPLING=False,
557557
BLOCK_SIZE=block_size,
558558
)
559559
else:
@@ -591,7 +591,7 @@ def rejection_sample(
591591
vocab_size, # global_vocab_size
592592
batch_size,
593593
NO_DRAFT_PROBS=draft_probs is None,
594-
enable_reduce_sampling=False,
594+
ENABLE_REDUCE_SAMPLING=False,
595595
BLOCK_SIZE=block_size,
596596
)
597597
else:
@@ -704,7 +704,7 @@ def sample_recovered_tokens(
704704
global_vocab_size if global_vocab_size is not None else vocab_size,
705705
NO_DRAFT_PROBS=draft_probs is None,
706706
BLOCK_VERIFY=use_block_verify,
707-
enable_reduce_sampling=enable_reduce_sampling,
707+
ENABLE_REDUCE_SAMPLING=enable_reduce_sampling,
708708
SUB_BLOCK=512,
709709
# TODO: enable multibuffer when accuracy problem is solved.
710710
multibuffer=False,

vllm_ascend/spec_decode/llm_base_proposer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ def _run_merged_draft(
958958
draft_token_ids = draft_token_ids[:num_indices]
959959
token_indices_to_sample = token_indices_to_sample[:num_indices]
960960
else:
961-
logits = self.model.compute_logits(sample_hidden_states, get_ascend_config().enable_reduce_sample)
961+
logits = self.model.compute_logits(sample_hidden_states)
962962
if lmhead_tp_enable() and num_indices < logits.shape[0]:
963963
logits = logits[:num_indices]
964964
token_indices_to_sample = token_indices_to_sample[:num_indices]
@@ -1089,7 +1089,9 @@ def _run_merged_draft(
10891089

10901090
sample_hidden_states = last_hidden_states[token_indices_to_sample]
10911091
if get_ascend_config().enable_reduce_sample:
1092-
draft_token_ids = self.model.compute_logits(sample_hidden_states)
1092+
draft_token_ids = self.model.compute_logits(
1093+
sample_hidden_states, get_ascend_config().enable_reduce_sample
1094+
)
10931095
if lmhead_tp_enable() and num_indices < draft_token_ids.shape[0]:
10941096
draft_token_ids = draft_token_ids[:num_indices]
10951097
token_indices_to_sample = token_indices_to_sample[:num_indices]

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2287,7 +2287,7 @@ def _sample(self, logits, spec_decode_metadata):
22872287
if spec_decode_metadata is None:
22882288
if lmhead_tp_enable() and logits is not None:
22892289
logits = logits[: self.input_batch.num_reqs]
2290-
if self.input_batch.top_k_cpu is not None and get_ascend_config().enable_reduce_sample:
2290+
if self.input_batch.sampling_metadata.top_k is not None and get_ascend_config().enable_reduce_sample:
22912291
max_topk = self.input_batch.top_k_cpu[self.input_batch.top_k_cpu < logits.shape[1]].max()
22922292
self.sampler.prepare_sampling(max_topk)
22932293
return self.sampler(
@@ -2297,7 +2297,7 @@ def _sample(self, logits, spec_decode_metadata):
22972297

22982298
if lmhead_tp_enable() and logits is not None:
22992299
logits = logits[: len(spec_decode_metadata.logits_indices)]
2300-
if self.input_batch.top_k_cpu is not None and get_ascend_config().enable_reduce_sample:
2300+
if self.input_batch.sampling_metadata.top_k is not None and get_ascend_config().enable_reduce_sample:
23012301
max_topk = self.input_batch.top_k_cpu[self.input_batch.top_k_cpu < logits.shape[1]].max()
23022302
self.rejection_sampler.prepare_sampling(max_topk)
23032303
sampler_output = self.rejection_sampler(

0 commit comments

Comments
 (0)