@@ -146,7 +146,7 @@ def step(
             normalize_arg(tree_attention_mask),
             normalize_arg(kv_cache_position_ids),
             normalize_arg(draft_tokens),
-            normalize_arg(torch.tensor(prefill_length)),
+            normalize_arg(prefill_length),
             normalize_arg(torch.tensor(1 if is_spec_dec else 0)),
         )
         logger.info(f"_ServerInferenceSession step id {step_id}")
@@ -328,6 +328,7 @@ def step(  # perform one inference step, processing the input data and the corresponding prompts …
         kv_cache_position_ids: Optional[torch.Tensor] = None,
         draft_tokens: Optional[torch.Tensor] = None,
         is_spec_decoding: Optional[torch.Tensor] = None,
+        prefill_length: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
@@ -360,6 +361,7 @@ def step(  # perform one inference step, processing the input data and the corresponding prompts …
         is_spec_decoding = is_spec_decoding.cpu() if is_spec_decoding is not None else None
 
         step_id = str(uuid.uuid4())  # Generate a unique step ID.
+        batch_size = inputs.shape[0]
 
         n_input_tokens = inputs.shape[1] if kv_cache_position_ids is None else kv_cache_position_ids.numel()
         if self._position + n_input_tokens > self._max_length:
@@ -370,13 +372,15 @@ def step(  # perform one inference step, processing the input data and the corresponding prompts …
         server_idx = 0
         block_idx = 0
         inference_step_start = time.perf_counter()
-        self.prefill_length = inputs.shape[1] - tree_attention_mask.shape[1] if tree_attention_mask is not None and self.first_inference else 0
-        # self.first_inference = False
+        if tree_attention_mask is not None:
+            self.prefill_length = prefill_length.to(inputs.device)
+        else:
+            self.prefill_length = torch.zeros(batch_size)
         keep_indices = torch.arange(
-            inputs.shape[1],
-            dtype=torch.int64,
-            device=inputs.device
-        )
+            inputs.shape[1],
+            dtype=torch.int64,
+            device=inputs.device
+        ).unsqueeze(0).expand(inputs.shape[0], -1)
         self.keep_indices = keep_indices
         if is_spec_decoding is not None and is_spec_decoding.item() == 1:
             is_spec_dec = True
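
For orientation, here is a minimal sketch of what the expanded keep_indices from the hunk above evaluates to; the shapes and the dummy inputs tensor are made up for illustration and are not part of the patch.

import torch

batch_size, seq_len = 2, 5                    # illustrative sizes only
inputs = torch.randn(batch_size, seq_len, 8)  # stand-in for the real hidden states

# Same construction as the patched code: one row of position indices per batch element.
keep_indices = (
    torch.arange(seq_len, dtype=torch.int64, device=inputs.device)
    .unsqueeze(0)
    .expand(inputs.shape[0], -1)
)
print(keep_indices.shape)  # torch.Size([2, 5])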
@@ -440,8 +444,11 @@ def step(  # perform one inference step, processing the input data and the corresponding prompts …
                 time.sleep(delay)
 
         self._position += n_input_tokens
+        # logger.info(f"keep_indices: {keep_indices}")
+        # logger.info(f"before _restore_hidden_states: {inputs}")
         if draft_tokens is not None and is_spec_dec:
-            inputs = self._recover_hidden_states(inputs, self.keep_indices, draft_tokens.shape[1])
+            inputs = self._restore_hidden_states(inputs, self.keep_indices, draft_tokens.shape[1])
+            # logger.info(f"after _restore_hidden_states: {inputs}")
         outputs = inputs
 
         # 🔍 CLIENT DEBUG: Log inference step end
@@ -454,28 +461,48 @@ def step(  # perform one inference step, processing the input data and the corresponding prompts …
         # print('client inference session outputs ', outputs.shape)
         return outputs
 
-    def _recover_hidden_states(self, hidden_states, keep_indices, original_length):
-        if not torch.is_tensor(keep_indices):
-            keep_indices = torch.tensor(keep_indices, device=hidden_states.device, dtype=torch.long)
-
-        if hidden_states.dim() == 2:
-            # [S_kept, H]
-            recovered = hidden_states.new_zeros((original_length, hidden_states.size(-1)))
-            recovered[keep_indices] = hidden_states
-
-        elif hidden_states.dim() == 3:
-            # [B, S_kept, H]
-            B, _, H = hidden_states.shape
-            recovered = hidden_states.new_zeros((B, original_length, H))
-            recovered[:, keep_indices, :] = hidden_states
-
-        else:
-            raise ValueError(f"Unexpected hidden_states dim {hidden_states.dim()}, expected 2 or 3")
-        mask = torch.ones(original_length, dtype=torch.bool, device=hidden_states.device)
-        mask[keep_indices] = False
-        self._last_padded_mask = mask
-
-        return recovered
+    def _restore_hidden_states(
+        self,
+        flattened_hidden_states: torch.Tensor,  # [N_total_valid, hidden_size]
+        keep_indices: torch.Tensor,  # [B, max_keep_len], padded with -1
+        original_seq_len: int,  # original sequence length
+    ) -> torch.Tensor:
+        """
+        Restore the flattened hidden states to shape [B, original_seq_len, hidden_size].
+
+        Args:
+            flattened_hidden_states: [N_total_valid, hidden_size] flattened valid hidden states
+            keep_indices: [B, max_keep_len] keep indices per batch element, padded with -1
+            original_seq_len: original sequence length
+
+        Returns:
+            restored_hidden_states: [B, original_seq_len, hidden_size], invalid positions zero-filled
+        """
+        batch_size, max_keep_len = keep_indices.shape
+        hidden_size = flattened_hidden_states.shape[-1]
+        device = flattened_hidden_states.device
+        dtype = flattened_hidden_states.dtype
+
+        # Allocate the output tensor, zero-filled
+        restored_hidden_states = torch.zeros(
+            batch_size, original_seq_len, hidden_size,
+            dtype=dtype, device=device
+        )
+
+        # Validity mask: [B, max_keep_len]
+        valid_mask = keep_indices >= 0
+
+        # Batch indices: [B, max_keep_len]
+        batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(keep_indices)
+
+        # Select the indices of the valid entries
+        valid_batch_idx = batch_idx[valid_mask]  # [N_total_valid]
+        valid_seq_idx = keep_indices[valid_mask]  # [N_total_valid]
+
+        # Scatter the flattened values back into their original positions
+        restored_hidden_states[valid_batch_idx, valid_seq_idx, :] = flattened_hidden_states
+
+        return restored_hidden_states
 
     def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
         # If there is a failed server session, this code closes it
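
As a sanity check on the scatter logic in _restore_hidden_states, a small standalone sketch of the same indexing trick; the tensor sizes and the keep_indices values are invented for illustration.

import torch

hidden_size, original_seq_len = 4, 6
# Two sequences keeping 3 and 2 positions respectively; -1 marks padding.
keep_indices = torch.tensor([[0, 2, 5], [1, 3, -1]])
flattened = torch.randn(5, hidden_size)  # 3 + 2 = 5 valid rows, batch-major order

restored = torch.zeros(keep_indices.shape[0], original_seq_len, hidden_size)
valid_mask = keep_indices >= 0
batch_idx = torch.arange(keep_indices.shape[0]).unsqueeze(1).expand_as(keep_indices)
restored[batch_idx[valid_mask], keep_indices[valid_mask], :] = flattened

print(restored.shape)                     # torch.Size([2, 6, 4])
print(bool((restored[1, 5] == 0).all()))  # True: padded positions stay zero-filled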