Skip to content

Commit e6af835

Browse files
Merge branch 'main' into main
2 parents 8efbc34 + 3b7eb51 commit e6af835

File tree

1 file changed

+12
-7
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -594,25 +594,30 @@ def _prepare_inputs(
594594
req_indices, positions_np)
595595
self.input_batch.block_table.commit_slot_mapping(
596596
total_num_scheduled_tokens)
597+
598+
total_num_pcp_pads = 0
597599
if self.pcp_size > 1:
598600
if not self.vllm_config.model_config.use_mla:
599601
self.generate_kv_idx(scheduler_output)
600602
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
601603
tokens)
602604
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
603605
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
606+
total_num_pcp_pads = torch.sum(self.num_pcp_pads).item()
604607
else:
605608
position_pcp, pcp_unpad_mask = None, None
606609
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
607610

608-
total_num_pcp_pads = sum(self.num_pcp_pads)
609611
max_num_scheduled_tokens = max(tokens)
610-
num_valid_tokens = np.array([
611-
num_tokens -
612-
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
613-
for num_tokens, i in zip(tokens, req_ids)
614-
],
615-
dtype=np.int32)
612+
if not scheduler_output.scheduled_spec_decode_tokens:
613+
num_valid_tokens = np.array(tokens, dtype=np.int32)
614+
else:
615+
num_valid_tokens = np.array([
616+
num_tokens -
617+
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
618+
for num_tokens, i in zip(tokens, req_ids)
619+
],
620+
dtype=np.int32)
616621

617622
if (self.use_aclgraph and total_num_scheduled_tokens
618623
<= self.cudagraph_batch_sizes[-1]):

0 commit comments

Comments (0)