@@ -594,25 +594,30 @@ def _prepare_inputs(
594594 req_indices , positions_np )
595595 self .input_batch .block_table .commit_slot_mapping (
596596 total_num_scheduled_tokens )
597+
598+ total_num_pcp_pads = 0
597599 if self .pcp_size > 1 :
598600 if not self .vllm_config .model_config .use_mla :
599601 self .generate_kv_idx (scheduler_output )
600602 tokens , position_pcp , pcp_unpad_mask = self ._update_tokens_for_pcp (
601603 tokens )
602604 num_scheduled_tokens = np .array (tokens , dtype = np .int32 )
603605 total_num_scheduled_tokens = sum (num_scheduled_tokens [:num_reqs ])
606+ total_num_pcp_pads = torch .sum (self .num_pcp_pads ).item ()
604607 else :
605608 position_pcp , pcp_unpad_mask = None , None
606609 self .num_pcp_pads = self .num_pcp_pads [:num_reqs ]
607610
608- total_num_pcp_pads = sum (self .num_pcp_pads )
609611 max_num_scheduled_tokens = max (tokens )
610- num_valid_tokens = np .array ([
611- num_tokens -
612- len (scheduler_output .scheduled_spec_decode_tokens .get (i , []))
613- for num_tokens , i in zip (tokens , req_ids )
614- ],
615- dtype = np .int32 )
612+ if not scheduler_output .scheduled_spec_decode_tokens :
613+ num_valid_tokens = np .array (tokens , dtype = np .int32 )
614+ else :
615+ num_valid_tokens = np .array ([
616+ num_tokens -
617+ len (scheduler_output .scheduled_spec_decode_tokens .get (i , []))
618+ for num_tokens , i in zip (tokens , req_ids )
619+ ],
620+ dtype = np .int32 )
616621
617622 if (self .use_aclgraph and total_num_scheduled_tokens
618623 <= self .cudagraph_batch_sizes [- 1 ]):
0 commit comments