diff --git a/benchmarks/benchmark_speculative_decoding.py b/benchmarks/benchmark_speculative_decoding.py
index 98c6c110..df57d373 100644
--- a/benchmarks/benchmark_speculative_decoding.py
+++ b/benchmarks/benchmark_speculative_decoding.py
@@ -60,35 +60,59 @@ def benchmark_inference(process_idx, args, result_pipe):
     ).to(device)
     tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
 
+    batch_size = 4
     dataset = load_dataset("tatsu-lab/alpaca")["train"]
-    indices = random.sample(range(len(dataset)), 1)
+    indices = random.sample(range(len(dataset)), batch_size)
     sampled = dataset.select(indices)
+    test_prompts = []
+    # for item in sampled:
+    #     test_prompts.append(item["instruction"])
+
+    test_prompts.append("Hi,")
+    test_prompts.append("")
+    test_prompts.append("")
+    test_prompts.append("")
+
+    tokenizer.pad_token = tokenizer.eos_token
+    # tokenizer.padding_side = "left"
+    input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)["input_ids"]
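+    # Illustrative example (hypothetical token ids, with pad == eos == 2):
+    # ["Hi,", "", "", ""] might tokenize to [[1, 6324, 29892], [1, 2, 2], ...];
+    # the per-sample prompt length is later recovered with
+    # input_ids[i].ne(tokenizer.pad_token_id).sum().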
Using an empty tensor.") - input_ids = torch.tensor([[]], dtype=torch.long, device=device) - - result = "" - start_time = perf_counter() - result = model.generate(input_ids=input_ids, ssm=ssm) - time = perf_counter() - start_time - generated_tokens_num = result.shape[1] - input_ids.shape[1] - speed = generated_tokens_num / time - decoded_result = tokenizer.decode(result[0], skip_special_tokens=True) + result = "" + start_time = perf_counter() + result = model.generate(input_ids=input_ids, ssm=ssm) + time = perf_counter() - start_time + generated_tokens_nums = [] + for i in range(batch_size): + prompt_length = input_ids[i].ne(tokenizer.pad_token_id).sum().item() + result_length = result[i].ne(tokenizer.pad_token_id).sum().item() + generated_tokens_num = result_length - prompt_length + generated_tokens_nums.append(generated_tokens_num) + + avg_generated_tokens = sum(generated_tokens_nums) / batch_size + speed = avg_generated_tokens / time + + # 解码所有结果 + decoded_results = tokenizer.batch_decode(result, skip_special_tokens=True) + + logger.info(f"benchmark_inference batch size: {batch_size}") + logger.info(f"Total time: {time:.4f}s, Average speed: {speed:.2f} tokens/s") + logger.info(f"Generated tokens per sample: {generated_tokens_nums}") - logger.info(f"benchmark_inference, result: {result}, generated_tokens_num: {generated_tokens_num}, time: {time} speed: {speed}, decoded_result: {decoded_result}") + for i, (prompt, decoded_result) in enumerate(zip(test_prompts, decoded_results)): + logger.info(f"Sample {i}:") + logger.info(f" Prompt: {prompt}") + logger.info(f" Result: {decoded_result}") + logger.info(f" Generated tokens: {generated_tokens_nums[i]}") result_pipe.send(speed) diff --git a/src/bloombee/client/inference_session.py b/src/bloombee/client/inference_session.py index 67882e43..d438499b 100644 --- a/src/bloombee/client/inference_session.py +++ b/src/bloombee/client/inference_session.py @@ -146,7 +146,7 @@ def step( normalize_arg(tree_attention_mask), normalize_arg(kv_cache_position_ids), normalize_arg(draft_tokens), - normalize_arg(torch.tensor(prefill_length)), + normalize_arg(prefill_length), normalize_arg(torch.tensor(1 if is_spec_dec else 0)), ) logger.info(f"_ServerInferenceSession step id {step_id}") @@ -328,6 +328,7 @@ def step( # 执行一次推理步骤,处理输入数据和相应的提示与 kv_cache_position_ids: Optional[torch.Tensor] = None, draft_tokens: Optional[torch.Tensor] = None, is_spec_decoding: Optional[torch.Tensor] = None, + prefill_length: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert not self._closed if torch.is_grad_enabled(): @@ -360,6 +361,7 @@ def step( # 执行一次推理步骤,处理输入数据和相应的提示与 is_spec_decoding = is_spec_decoding.cpu() if is_spec_decoding is not None else None step_id = str(uuid.uuid4()) # Generate a unique step ID. 
+        batch_size = inputs.shape[0]
         n_input_tokens = inputs.shape[1] if kv_cache_position_ids is None else kv_cache_position_ids.numel()
         if self._position + n_input_tokens > self._max_length:
@@ -370,13 +372,15 @@ def step(  # Perform one inference step, handling input data and the corresponding prompts and
         server_idx = 0
         block_idx = 0
         inference_step_start = time.perf_counter()
-        self.prefill_length = inputs.shape[1] - tree_attention_mask.shape[1] if tree_attention_mask is not None and self.first_inference else 0
-        # self.first_inference = False
+        if tree_attention_mask is not None:
+            self.prefill_length = prefill_length.to(inputs.device)
+        else:
+            self.prefill_length = torch.zeros(batch_size, dtype=torch.long, device=inputs.device)
         keep_indices = torch.arange(
-            inputs.shape[1],
-            dtype=torch.int64,
-            device=inputs.device
-        )
+            inputs.shape[1],
+            dtype=torch.int64,
+            device=inputs.device
+        ).unsqueeze(0).expand(inputs.shape[0], -1)
         self.keep_indices = keep_indices
         if is_spec_decoding is not None and is_spec_decoding.item() == 1:
             is_spec_dec = True
@@ -440,8 +444,11 @@ def step(  # Perform one inference step, handling input data and the corresponding prompts and
                     time.sleep(delay)
 
         self._position += n_input_tokens
+        # logger.info(f"keep_indices: {keep_indices}")
+        # logger.info(f"before _recover_hidden_states: {inputs}")
         if draft_tokens is not None and is_spec_dec:
-            inputs = self._recover_hidden_states(inputs, self.keep_indices, draft_tokens.shape[1])
+            inputs = self._restore_hidden_states(inputs, self.keep_indices, draft_tokens.shape[1])
+        # logger.info(f"after _recover_hidden_states: {inputs}")
         outputs = inputs
 
         # 🔍 CLIENT DEBUG: Log inference step end
@@ -454,28 +461,48 @@ def step(  # Perform one inference step, handling input data and the corresponding prompts and
         # print('client inference session outputs ', outputs.shape)
         return outputs
 
-    def _recover_hidden_states(self, hidden_states, keep_indices, original_length):
-        if not torch.is_tensor(keep_indices):
-            keep_indices = torch.tensor(keep_indices, device=hidden_states.device, dtype=torch.long)
-
-        if hidden_states.dim() == 2:
-            # [S_kept, H]
-            recovered = hidden_states.new_zeros((original_length, hidden_states.size(-1)))
-            recovered[keep_indices] = hidden_states
-
-        elif hidden_states.dim() == 3:
-            # [B, S_kept, H]
-            B, _, H = hidden_states.shape
-            recovered = hidden_states.new_zeros((B, original_length, H))
-            recovered[:, keep_indices, :] = hidden_states
-
-        else:
-            raise ValueError(f"Unexpected hidden_states dim {hidden_states.dim()}, expected 2 or 3")
-        mask = torch.ones(original_length, dtype=torch.bool, device=hidden_states.device)
-        mask[keep_indices] = False
-        self._last_padded_mask = mask
-
-        return recovered
+    def _restore_hidden_states(
+        self,
+        flattened_hidden_states: torch.Tensor,  # [N_total_valid, hidden_size]
+        keep_indices: torch.Tensor,  # [B, max_keep_len], padded with -1
+        original_seq_len: int,  # original sequence length
+    ) -> torch.Tensor:
+        """
+        Restore flattened hidden states to [B, original_seq_len, hidden_size]
+
+        Args:
+            flattened_hidden_states: [N_total_valid, hidden_size] flattened valid hidden states
+            keep_indices: [B, max_keep_len] keep indices per batch, padded with -1
+            original_seq_len: the original sequence length
+
+        Returns:
+            restored_hidden_states: [B, original_seq_len, hidden_size], invalid positions zero-filled
+        """
+        batch_size, max_keep_len = keep_indices.shape
+        hidden_size = flattened_hidden_states.shape[-1]
+        device = flattened_hidden_states.device
+        dtype = flattened_hidden_states.dtype
+
+        # Create the output tensor, zero-filled
+        restored_hidden_states = torch.zeros(
+            batch_size, original_seq_len, hidden_size,
+            dtype=dtype, device=device
+        )
+
+        # Build the validity mask: [B, max_keep_len]
+        valid_mask = keep_indices >= 0
+
+        # Build batch indices: [B, max_keep_len]
+        batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(keep_indices)
+
+        # Select the valid index pairs
+        valid_batch_idx = batch_idx[valid_mask]  # [N_total_valid]
+        valid_seq_idx = keep_indices[valid_mask]  # [N_total_valid]
+
+        # Scatter back into the restored positions
+        restored_hidden_states[valid_batch_idx, valid_seq_idx, :] = flattened_hidden_states
+
+        return restored_hidden_states
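+    # Illustrative example (hypothetical values): with keep_indices
+    # [[0, 2, -1], [1, -1, -1]] and flattened_hidden_states of shape [3, H],
+    # row 0 scatters to (batch 0, pos 0), row 1 to (batch 0, pos 2), and
+    # row 2 to (batch 1, pos 1); every other position of the
+    # [2, original_seq_len, H] output stays zero.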
 
     def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
         # If there is a failed server session, this code closes it
diff --git a/src/bloombee/client/remote_generation.py b/src/bloombee/client/remote_generation.py
index aabffae3..33ed6743 100644
--- a/src/bloombee/client/remote_generation.py
+++ b/src/bloombee/client/remote_generation.py
@@ -1,7 +1,7 @@
 import contextlib
 import dataclasses
 from contextvars import ContextVar
-from typing import Any, ContextManager, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple, Union
 
 import torch
 import transformers
@@ -22,22 +22,38 @@ class RemotePastKeyValues(Cache):
     def __init__(self) -> None:
         super().__init__()
-        self._seen_tokens = 0
+        self._seen_tokens: Optional[torch.Tensor] = None
         self.hypo_ids: Optional[torch.LongTensor] = None
         self.kv_cache_position_ids: Optional[torch.LongTensor] = None
         self.is_spec_decoding: Optional[torch.LongTensor] = None
+        self.prefill_length: Optional[torch.LongTensor] = None
 
     def __getitem__(self, _index: int) -> List[torch.Tensor]:
         return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        if self._seen_tokens is None:
+            return 0
+        if self._seen_tokens.dim() == 0:
+            return self._seen_tokens.item()
+        return self._seen_tokens[0].item()
+
+    def get_seq_length_batch(self) -> Optional[torch.Tensor]:
         return self._seen_tokens
 
     def get_max_length(self) -> Optional[int]:
         return None
 
-    def update_seen(self, new_seen: int) -> None:
-        self._seen_tokens += new_seen
+    def update_seen(self, new_seen: Union[int, torch.Tensor]) -> None:
+        if isinstance(new_seen, int):
+            self._seen_tokens = torch.tensor([new_seen])
+        elif isinstance(new_seen, torch.Tensor):
+            if new_seen.dim() == 0:
+                new_seen = new_seen.unsqueeze(0)
+            self._seen_tokens = new_seen
+        else:
+            raise TypeError(f"new_seen must be int or torch.Tensor, got {type(new_seen)}")
 
     def reorder_cache(self, beam_idx):
         raise NotImplementedError("Beam search reordering is not implemented yet")
@@ -47,6 +63,9 @@ def set_kv_cache(self, position_ids: Optional[torch.LongTensor]):
 
     def set_is_spec_decoding(self, is_spec_decoding: Optional[torch.LongTensor]):
         self.is_spec_decoding = is_spec_decoding
+
+    def set_prefill_length(self, prefill_length: Optional[torch.LongTensor]):
+        self.prefill_length = prefill_length
 
 _skipped_tokens = ContextVar("skipped_tokens", default=0)
diff --git a/src/bloombee/flexgen_utils/pytorch_backend.py b/src/bloombee/flexgen_utils/pytorch_backend.py
index f6d09aeb..b7c6476b 100644
--- a/src/bloombee/flexgen_utils/pytorch_backend.py
+++ b/src/bloombee/flexgen_utils/pytorch_backend.py
@@ -100,8 +100,8 @@ def precompute_freqs_cis(
         freqs = torch.outer(t, freqs).float()  # type: ignore
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     else:
-        t = position_ids.float().to(inv_freq.device)
-        freqs = torch.outer(t, inv_freq).float()
+        t = position_ids.float().to(inv_freq.device)  # [B, S]
+        freqs = t.unsqueeze(-1) * inv_freq.reshape(1, 1, -1)
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     return freqs_cis
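+# Illustrative shapes (hypothetical values): position_ids of shape [B=2, S=3]
+# and inv_freq of shape [head_dim/2] give freqs of shape [2, 3, head_dim/2],
+# i.e. one rotary angle table per (batch, position), instead of the flat
+# [S, head_dim/2] table produced by torch.outer in the shared-positions path.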
@@ -723,6 +723,8 @@ def mha_gen_llama(self, inputs, attention_mask, w_q, w_k, w_v,
 
         hidden = rms_norm(inputs.data, input_layernorm.data)
+        # logger.info(f"after norm, hidden states: {hidden}")
+
         # shape: (b, 1, h)
         q = F.linear(hidden, w_q.data)
         k = F.linear(hidden, w_k.data)
@@ -733,6 +735,16 @@ def mha_gen_llama(self, inputs, attention_mask, w_q, w_k, w_v,
         k = k.view(b, tgt_s, n_head, head_dim)
         v = v.view(b, tgt_s, n_head, head_dim)
 
+        # logger.info(f"after projection, query_states: {q}")
+        # logger.info(f"after projection, key_states: {k}")
+        # logger.info(f"after projection, value_states: {v}")
+
+        # logger.info(f"attention_mask: {attention_mask.shape}")
+        # logger.info(f"inputs: {inputs.shape}")
+        # logger.info(f"freq_cis: {freq_cis.shape}")
+        # logger.info(f"src_s: {src_s}")
+        # logger.info(f"tgt_s: {tgt_s}")
+
         if rotary_position_ids is not None:
             freqs_slice = freq_cis[-tgt_s:]
         else:
@@ -740,6 +752,9 @@ def mha_gen_llama(self, inputs, attention_mask, w_q, w_k, w_v,
 
         q, k = apply_rotary_emb(q, k, freqs_cis=freqs_slice)
 
+        # logger.info(f"after rotary, query_states: {q}")
+        # logger.info(f"after rotary, key_states: {k}")
+
         # shape: (b * n_head, 1, head_dim)
         q = q.permute(0, 2, 1, 3).reshape(b * n_head, tgt_s, head_dim)
         # shape: (1, b * n_head, head_dim)
diff --git a/src/bloombee/models/llama/model.py b/src/bloombee/models/llama/model.py
index 3761d361..cd8cccdf 100644
--- a/src/bloombee/models/llama/model.py
+++ b/src/bloombee/models/llama/model.py
@@ -88,7 +88,7 @@ def forward(
         # print('model.py llama model inputs_embeds, ', inputs_embeds)  # Temporarily commented for cleaner debug output
         output_shape = input_shape + (hidden_states.size(-1),)
-        # logger.info(f"hidden_states: {hidden_states}")
+        # logger.info(f"input_ids: {input_ids}")
 
         hidden_states = self.layers(
             hidden_states,
@@ -98,6 +98,7 @@ def forward(
             kv_cache_position_ids=past_key_values.kv_cache_position_ids if past_key_values is not None else None,
             draft_tokens = input_ids,
             is_spec_decoding = past_key_values.is_spec_decoding if past_key_values is not None else None,
+            prefill_length = past_key_values.prefill_length if past_key_values is not None else None,
         )
 
         if past_key_values is None:
diff --git a/src/bloombee/models/llama/spe_dec_tree.py b/src/bloombee/models/llama/spe_dec_tree.py
index d7918235..ca73afb8 100644
--- a/src/bloombee/models/llama/spe_dec_tree.py
+++ b/src/bloombee/models/llama/spe_dec_tree.py
@@ -185,8 +185,24 @@ def prepare_incremental_tree_batch(
     trees: List[SpeculativeTree],
     input_ids: torch.LongTensor,
     device: torch.device,
-    pad_token_id: int = 0
+    pad_token_id: int = 0,
+    seq_lengths: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, List[List[List[TreeNode]]]]:
+    """
+    Prepare an incremental tree batch, supporting different sequence lengths
+
+    Args:
+        trees: list of speculative trees
+        input_ids: [batch_size, max_seq_len] input token ids (may contain padding)
+        device: target device
+        pad_token_id: padding token id
+        seq_lengths: [batch_size] true length of each sequence; if None, all sequences are assumed to have the same length
+
+    Returns:
+        tree_tokens: [batch_size, max_tree_size]
+        attention_mask: [batch_size, max_tree_size, past_len + max_tree_size]
+        batch_node_paths: list of node paths for each batch
+    """
     batch_size = len(trees)
 
     if not trees or all(tree.total_nodes <= 1 for tree in trees):
@@ -237,10 +253,11 @@ def prepare_tree_attention_batch(
     trees: List[SpeculativeTree],
     prefix_tokens: torch.Tensor,
     device: torch.device,
-    pad_token_id: int = 0
+    pad_token_id: int = 0,
+    seq_lengths: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, List[List[List[TreeNode]]]]:
     tree_tokens, attention_mask, batch_node_paths = prepare_incremental_tree_batch(
-        trees, prefix_tokens, device, pad_token_id
+        trees, prefix_tokens, device, pad_token_id, seq_lengths
     )
     if tree_tokens.shape[1] > 0:
         full_sequence = torch.cat([prefix_tokens, tree_tokens], dim=-1)
diff --git a/src/bloombee/models/llama/speculative_model.py b/src/bloombee/models/llama/speculative_model.py
index c6d5326b..fb105e26 100644
--- a/src/bloombee/models/llama/speculative_model.py
+++ b/src/bloombee/models/llama/speculative_model.py
@@ -37,7 +37,7 @@ def generate(
     max_tree_depth: int = 4,
     use_kv_cache: bool = True,
     kv_cache_window: int = 2048,
-    max_new_tokens: int = 128,
+    max_new_tokens: int = 64,
     **model_kwargs,
 ) -> torch.LongTensor:
 
@@ -49,7 +49,7 @@ def generate(
     generation_config.return_dict_in_generate = False
 
     # Calculate session max length - this is critical for distributed inference
-    session_max_length = 2048
+    session_max_length = 4096
 
     # Use inference session for proper distributed caching
     with self.transformer.h.inference_session(max_length=session_max_length) as session:
@@ -93,23 +93,45 @@ def _sample_with_session(
 
         # Initialize past_key_values for session tracking
         past_key_values = RemotePastKeyValues()
-        past_key_values.update_seen(session.position)
+        batch_positions = torch.full(
+            (batch_size,),
+            session.position,
+            dtype=torch.long,
+            device="cuda"
+        )
+        past_key_values.update_seen(batch_positions)
         past_key_values.set_is_spec_decoding(torch.tensor([1], dtype=torch.long, device="cuda"))
 
         is_first_iteration = True
         step_idx = 0
         current_input_ids = input_ids
-        result = ""
         llm_generated_token = None
 
-        while not finished and current_input_ids.shape[1] < input_ids.shape[1] + max_new_tokens:
-            # 1. Build speculative trees using SSM
+        # New: track the true length of every sequence
+        seq_lengths = torch.full((batch_size,), input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+        ignore_token_ids: list = [0, 2]
+        valid_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        for token_id in ignore_token_ids:
+            valid_mask = valid_mask & (input_ids != token_id)
+
+        # Count the valid tokens of each sequence
+        seq_lengths = valid_mask.sum(dim=1)  # [batch_size]
+        past_key_values.set_prefill_length(seq_lengths)
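+        # Illustrative example (hypothetical ids, ignore_token_ids = [0, 2]):
+        # input_ids [[1, 5, 9], [1, 7, 0]] yields valid_mask
+        # [[T, T, T], [T, T, F]] and seq_lengths [3, 2], so the trailing
+        # pad/eos token of the second row is excluded from its prefill length.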
+
+        pad_token_id = generation_config.pad_token_id if generation_config.pad_token_id is not None else 0
+        logger.info(f"init input_ids: {input_ids}, seq_lengths: {seq_lengths}")
+        # Changed loop condition: judge based on the shortest sequence's length
+        initial_len = input_ids.shape[1]
+        while not finished and (seq_lengths.min().item() - initial_len) < max_new_tokens:
+            # 1. Build speculative trees using SSM - pass in seq_lengths
             spec_trees = self._build_speculative_trees_batched(
-                current_input_ids, ssm, beam_width, max_tree_depth
+                current_input_ids, ssm, beam_width, max_tree_depth, seq_lengths
             )
 
-            # 2. Verify trees using distributed inference - but through forward() call
-            verified_tokens, verified_tokens_positions, past_key_values, llm_generated_token = self._verify_trees_with_forward(
+            # logger.info(f"spec_trees, {spec_trees}")
+
+            # 2. Verify trees using distributed inference
+            verified_tokens, verified_tokens_positions, past_key_values, llm_generated_token, valid_lengths = self._verify_trees_with_forward(
                 input_ids=current_input_ids,
                 llm_generated_token=llm_generated_token,
                 trees=spec_trees,
@@ -117,28 +139,50 @@ def _sample_with_session(
                 past_key_values=past_key_values,
                 is_first_iteration=is_first_iteration,
                 use_kv_cache=use_kv_cache,
-                kv_cache_window=kv_cache_window
+                kv_cache_window=kv_cache_window,
+                seq_lengths=seq_lengths,
             )
 
+            # logger.info(f"verified_tokens_positions: {verified_tokens_positions}")
+
             past_key_values.set_kv_cache(verified_tokens_positions)
             is_first_iteration = False
 
             # 3. Apply stopping conditions
             if has_eos_stopping_criteria:
-                verified_tokens = verified_tokens * unfinished_sequences + generation_config.pad_token_id * (
-                    1 - unfinished_sequences
+                if verified_tokens is not None:
+                    verified_tokens = verified_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
+                        1 - unfinished_sequences.unsqueeze(-1)
+                    )
+                llm_generated_token = llm_generated_token * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
+                    1 - unfinished_sequences.unsqueeze(-1)
                 )
 
-            # 4. Update input sequence
-            if verified_tokens is not None:
-                current_input_ids = torch.cat([current_input_ids, verified_tokens], dim=-1)
-                current_input_ids = torch.cat([current_input_ids, llm_generated_token.unsqueeze(0)], dim=-1)
+            # 4. Update input sequence with proper padding handling
+            # logger.info(f"current_input_ids: {current_input_ids}")
+            # logger.info(f"verified_tokens: {verified_tokens}")
+            # logger.info(f"llm_generated_token: {llm_generated_token}")
+            # logger.info(f"valid_lengths: {valid_lengths}")
+            # logger.info(f"seq_lengths: {seq_lengths}")
+            current_input_ids, seq_lengths = self._update_input_ids_with_padding(
+                current_input_ids=current_input_ids,
+                verified_tokens=verified_tokens,
+                llm_generated_token=llm_generated_token,
+                valid_lengths=valid_lengths,
+                seq_lengths=seq_lengths,
+                pad_token_id=pad_token_id,
+            )
+
+            # logger.info(f"current_input_ids: {current_input_ids}, seq_lengths: {seq_lengths}")
 
-                if streamer is not None:
-                    streamer.put(verified_tokens.cpu())
-            else :
-                current_input_ids = torch.cat([current_input_ids, llm_generated_token.unsqueeze(0)], dim=-1)
+            if streamer is not None:
+                # When streaming, emit only the valid tokens according to valid_lengths
+                for i in range(batch_size):
+                    if unfinished_sequences[i]:
+                        if verified_tokens is not None and valid_lengths[i] > 0:
+                            streamer.put(verified_tokens[i, :valid_lengths[i]].cpu())
+                        streamer.put(llm_generated_token[i].cpu())
 
             # 5. Check if finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(current_input_ids, None)
 
@@ -147,7 +191,61 @@ def _sample_with_session(
         if streamer is not None:
             streamer.end()
 
+        return current_input_ids
+
+    def _update_input_ids_with_padding(
+        self,
+        current_input_ids: torch.LongTensor,
+        verified_tokens: Optional[torch.LongTensor],
+        llm_generated_token: torch.LongTensor,
+        valid_lengths: torch.LongTensor,
+        seq_lengths: torch.LongTensor,
+        pad_token_id: int,
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+        """
+        Update input_ids, handling sequences that accept different numbers of verified tokens
+
+        Returns:
+            updated_input_ids: the updated input_ids, aligned with padding
+            updated_seq_lengths: the updated true length of each sequence
+        """
+        batch_size = current_input_ids.shape[0]
+        device = current_input_ids.device
+
+        # Number of tokens to append per sequence (verified + 1 llm token)
+        tokens_to_add = valid_lengths + 1  # [batch_size]
+
+        # Compute the new sequence lengths
+        new_seq_lengths = seq_lengths + tokens_to_add
+        new_max_len = new_seq_lengths.max().item()
+
+        # Allocate the new input_ids tensor
+        new_input_ids = torch.full(
+            (batch_size, new_max_len),
+            pad_token_id,
+            dtype=torch.long,
+            device=device
+        )
+
+        for i in range(batch_size):
+            old_len = seq_lengths[i].item()
+            new_len = new_seq_lengths[i].item()
+
+            # Copy the existing valid tokens
+            new_input_ids[i, :old_len] = current_input_ids[i, :old_len]
+
+            # Append the verified tokens
+            v_len = valid_lengths[i].item()
+            if v_len > 0 and verified_tokens is not None:
+                new_input_ids[i, old_len:old_len + v_len] = verified_tokens[i, :v_len]
+                # Append the llm_generated_token
+                new_input_ids[i, old_len + v_len] = llm_generated_token[i, 0]
+            else:
+                # Append only the llm_generated_token
+                new_input_ids[i, old_len] = llm_generated_token[i, 0]
+
+        return new_input_ids, new_seq_lengths
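+    # Illustrative example (hypothetical values): seq_lengths [3, 5] and
+    # valid_lengths [3, 0] give tokens_to_add [4, 1] and new_seq_lengths [7, 6];
+    # row 0 appends three verified tokens plus one llm token, while row 1
+    # appends only its llm token and ends with one pad_token_id at column 6.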
 
     def _verify_trees_with_forward(
         self,
@@ -159,21 +257,38 @@ def _verify_trees_with_forward(
         is_first_iteration: bool,
         use_kv_cache: bool,
         kv_cache_window: int,
-    ) -> Tuple[torch.LongTensor, torch.Tensor, RemotePastKeyValues, torch.Tensor]:
+        seq_lengths: torch.LongTensor,
+    ) -> Tuple[torch.LongTensor, torch.Tensor, RemotePastKeyValues, torch.Tensor, torch.Tensor]:
         """
         Verify speculative trees using standard forward() call within the active session context
+
+        Returns:
+            verified_tokens: [batch_size, max_verified_len] or None
+            kv_cache_position_ids: [batch_size, max_pos_len]
+            past_key_values: the updated past_key_values
+            llm_generated_tokens: [batch_size, 1]
+            valid_lengths: [batch_size] number of verified tokens per sequence
         """
         tree_tokens, attention_mask, batch_node_paths = prepare_incremental_tree_batch(
-            trees, input_ids, input_ids.device
+            trees, input_ids, input_ids.device, seq_lengths=seq_lengths
         )
+
+        # logger.info(f"tree_tokens: {tree_tokens}, attention_mask: {attention_mask.shape}")
+
+        batch_size = input_ids.shape[0]
+        device = input_ids.device
+
         if attention_mask is None or tree_tokens.shape[1] == 0:
             logger.warning("No tree tokens to verify, falling back to regular generation")
-            return self._fallback_generation_with_forward(input_ids, logits_processor, past_key_values), past_key_values
+            fallback_token = self._fallback_generation_with_forward(input_ids, logits_processor, past_key_values, seq_lengths)
+            valid_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
+            return None, torch.zeros(batch_size, 1, dtype=torch.long, device=device), past_key_values, fallback_token, valid_lengths
 
         tree_mask_packed = self.pack_bool_mask_to_int64(attention_mask)
 
+        # logger.info(f"tree_mask_packed: {tree_mask_packed}")
+
         with torch.no_grad():
             if not use_kv_cache:
                 # No cache: process tree tokens directly
@@ -189,27 +304,24 @@ def _verify_trees_with_forward(
             elif is_first_iteration or past_key_values is None:
                 # First iteration: process full sequence to establish cache
-                full_sequence = torch.cat([input_ids, tree_tokens], dim=-1)
+                # Build the correct full_sequence according to seq_lengths
+                max_seq_len = seq_lengths.max().item()
+                full_sequence = torch.cat([input_ids[:, :max_seq_len], tree_tokens], dim=-1)
+
                 outputs = self(
                     input_ids=full_sequence,
-                    attention_mask=tree_mask_packed,  # Let the session handle attention
-                    past_key_values=past_key_values,  # Start fresh
+                    attention_mask=tree_mask_packed,
+                    past_key_values=past_key_values,
                     use_cache=True
                 )
 
-                # Extract only the tree portion of the logits
                 logits = outputs.logits
 
-                # Update past_key_values tracking
                 if past_key_values is None:
                     new_past_key_values = RemotePastKeyValues()
                 else:
                     new_past_key_values = past_key_values
 
-                # The session will automatically handle the KV cache positioning
-                # if self.transformer.h.active_session:
-                #     new_past_key_values.update_seen(self.transformer.h.active_session.position)
-
             else:
                 # Subsequent iterations: use existing cache
                 active_session = self.transformer.h.active_session
@@ -220,12 +332,12 @@ def _verify_trees_with_forward(
                 if active_session.position > kv_cache_window:
                     trim_amount = active_session.position - kv_cache_window
                     active_session.position = kv_cache_window
+
                 if llm_generated_token is None:
                     full_sequence = tree_tokens
                 else:
-                    full_sequence = torch.cat([llm_generated_token.unsqueeze(0), tree_tokens], dim=-1)
+                    full_sequence = torch.cat([llm_generated_token, tree_tokens], dim=-1)
 
-                # Process tree tokens with existing cache
                 outputs = self(
                     input_ids=full_sequence,
                     attention_mask=tree_mask_packed,
@@ -237,11 +349,11 @@ def _verify_trees_with_forward(
                 new_past_key_values = past_key_values
                 new_past_key_values.update_seen(active_session.position)
 
-        # Extract verification results
-        verified_tokens, verified_tokens_positions, llm_generated_token = self._extract_best_verified_paths_fixed(
-            logits, batch_node_paths, input_ids, logits_processor, tree_tokens.shape[1]
+        # Extract verification results - now also returns valid_lengths
+        verified_tokens, kv_cache_position_ids, llm_generated_tokens, valid_lengths = self._extract_best_verified_paths_fixed(
+            logits, batch_node_paths, input_ids, logits_processor, tree_tokens.shape[1], seq_lengths, is_first_iteration
         )
-        return verified_tokens, verified_tokens_positions, new_past_key_values, llm_generated_token
+        return verified_tokens, kv_cache_position_ids, new_past_key_values, llm_generated_tokens, valid_lengths
 
     def pack_bool_mask_to_int64(self, mask_bool: torch.Tensor) -> torch.Tensor:
         assert mask_bool.dtype == torch.bool, "Input must be a bool tensor"
@@ -252,6 +364,7 @@ def _fallback_generation_with_forward(
         input_ids: torch.LongTensor,
         logits_processor: LogitsProcessorList,
         past_key_values: RemotePastKeyValues,
+        seq_lengths: torch.LongTensor,
         temperature: float = 1.0
     ) -> torch.LongTensor:
         """
@@ -260,15 +373,23 @@ def _fallback_generation_with_forward(
         try:
             logger.info("[DEBUG] Using fallback generation")
 
-            # Generate single token using standard forward call
+            batch_size = input_ids.shape[0]
+            device = input_ids.device
+
+            # Get each sequence's last valid token
+            last_tokens = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
+            for i in range(batch_size):
+                last_pos = seq_lengths[i].item() - 1
+                last_tokens[i, 0] = input_ids[i, last_pos]
+
             outputs = self(
-                input_ids=input_ids[:, -1:],  # Just the last token
+                input_ids=last_tokens,
                 attention_mask=None,
                 past_key_values=past_key_values,
                 use_cache=True
            )
 
-            logits = outputs.logits[:, -1, :]  # Last position logits
+            logits = outputs.logits[:, -1, :]  # [batch_size, vocab_size]
 
             # Apply logits processors
             processed_logits = logits
@@ -286,30 +407,34 @@ def _fallback_generation_with_forward(
         except Exception as e:
             logger.error(f"Fallback generation failed: {e}")
-            # Ultimate fallback - return EOS token
             eos_token_id = getattr(self.config, 'eos_token_id', 2)
-            return torch.tensor([[eos_token_id]], device=input_ids.device)
+            return torch.full((input_ids.shape[0], 1), eos_token_id, device=input_ids.device)
 
-    # Keep your existing methods with minimal changes
     def _build_speculative_trees_batched(
         self,
         input_ids: torch.LongTensor,
         ssm: LlamaForCausalLM,
         beam_width: int,
-        max_depth: int
+        max_depth: int,
+        seq_lengths: torch.LongTensor,
    ) -> List[SpeculativeTree]:
         """Build speculative trees using the small model (SSM)"""
-        # start_total = time.time()
         batch_size = input_ids.shape[0]
         trees = []
+
+        pad_token_id = getattr(ssm.config, 'pad_token_id', 0)
 
         for batch_idx in range(batch_size):
-            # start_batch = time.time()
-            root_token = input_ids[batch_idx, -1].item()
+            # Get the true length of this sequence
+            actual_len = seq_lengths[batch_idx].item()
+
+            # Take only the valid part of input_ids
+            valid_input_ids = input_ids[batch_idx, :actual_len]
+
+            root_token = valid_input_ids[-1].item()
             tree = SpeculativeTree(root_token, f"req_{batch_idx}")
+
             for depth in range(max_depth):
-                # start_depth = time.time()
-
                 current_nodes = tree.get_nodes_at_depth(depth)
                 if not current_nodes:
                     break
@@ -319,7 +444,7 @@ def _build_speculative_trees_batched(
                 for node in current_nodes:
                     path_to_node = node.get_path_from_root()
                     context = torch.cat([
-                        input_ids[batch_idx, :-1],
+                        valid_input_ids[:-1],  # use the valid input_ids
                        torch.tensor([root_token] + path_to_node, device=input_ids.device)
                     ])
                     contexts.append(context)
@@ -333,16 +458,16 @@ def _build_speculative_trees_batched(
 
                 for ctx in contexts:
                     pad_len = max_len - len(ctx)
-                    pad_token_id = getattr(ssm.config, 'pad_token_id', 0)
+                    # Left-side padding
                     padded = torch.cat([
                         torch.full((pad_len,), pad_token_id, dtype=torch.long, device=input_ids.device),
                         ctx
                     ])
 
                     mask = torch.cat([
-                        torch.zeros(pad_len, dtype=torch.bool, device=input_ids.device),
-                        torch.ones(len(ctx), dtype=torch.bool, device=input_ids.device)
+                        torch.zeros(pad_len, dtype=torch.long, device=input_ids.device),
+                        torch.ones(len(ctx), dtype=torch.long, device=input_ids.device)
                     ])
 
                     padded_contexts.append(padded)
@@ -352,14 +477,13 @@ def _build_speculative_trees_batched(
                 batch_masks = torch.stack(attention_masks)
 
                 # SSM forward
-                # start_ssm = time.time()
                 with torch.no_grad():
-                    outputs = ssm(batch_contexts, attention_mask=batch_masks)
-                    batch_logits = outputs.logits[:, -1, :]
-                    # end_ssm = time.time()
+                    # logger.info(f"batch_contexts: {batch_contexts}")
+                    # logger.info(f"batch_masks: {batch_masks}")
+                    outputs = ssm(batch_contexts, attention_mask=batch_masks, use_cache=False)
+                    batch_logits = outputs.logits[:, -1, :]  # with left-side padding, -1 is correct
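+                # Illustrative example (hypothetical lengths): contexts of
+                # lengths [5, 7] are left-padded to 7, so the last real token
+                # of every row sits at index -1 and outputs.logits[:, -1, :]
+                # reads the next-token logits for all rows at once.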
 
                 # Generate candidates
                 candidates_per_node = []
                 for i in range(len(current_nodes)):
                     logits = batch_logits[i]
@@ -391,71 +515,143 @@ def _extract_best_verified_paths_fixed(
         batch_node_paths: List[List[List[TreeNode]]],
         input_ids: torch.LongTensor,
         logits_processor: LogitsProcessorList,
-        tree_len: int # do not include root
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        tree_len: int,
+        seq_lengths: torch.LongTensor,
+        is_first_iteration: bool,
+    ) -> Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Extract best verified paths (simplified for batch=1)
-        Returns: (verified_tokens, kv_cache_position_ids)
+        Returns:
+            verified_tokens: [batch_size, max_verified_len] or None
+            kv_cache_position_ids: [batch_size, max_pos_len]
+            llm_generated_tokens: [batch_size, 1]
+            valid_lengths: [batch_size] number of verified tokens per sequence (excluding the llm token)
         """
+        batch_size = logits.shape[0]
+        seq_len = logits.shape[1]
         total_tree_tokens = tree_len
-        fallback_pos = max(0, logits.shape[1] - total_tree_tokens)
-        tree_root_position = input_ids.shape[1] - 1
-
-        node_paths = batch_node_paths[0]
-        best_verified = []
-        best_positions = []
-        best_score = -1
-        for node_path in node_paths:
-            verified_tokens = []
-            verified_positions = []
-            for node in node_path:
-                pos = node.parent.position_in_sequence + 1
-                if pos >= logits.shape[1]:
-                    break
-
-                predicted_token = torch.argmax(logits[0, pos]).item()
+        fallback_pos = max(0, seq_len - total_tree_tokens)
+        device = logits.device
+
+        # Store the per-batch results
+        verified_tokens_list = []
+        positions_list = []
+        llm_tokens_list = []
+        valid_lengths_list = []
+
+        for batch_idx in range(batch_size):
+            actual_len = seq_lengths[batch_idx].item()
+            real_fallback_pos = actual_len if is_first_iteration else fallback_pos
+            tree_root_position = actual_len - 1
+
+            node_paths = batch_node_paths[batch_idx]
+            best_verified = []
+            best_positions = []
+            best_score = -1
+
+            for node_path in node_paths:
+                verified_tokens = []
+                verified_positions = []
 
-                current_logits = logits[0, node.position_in_sequence + 1]
+                for node in node_path:
+                    pos = node.parent.position_in_sequence + 1
+                    if pos >= seq_len:
+                        break
+
+                    predicted_token = torch.argmax(logits[batch_idx, pos]).item()
+
+                    if predicted_token == node.token_id:
+                        verified_tokens.append(node.token_id)
+                        absolute_position = tree_root_position + node.position_in_sequence + 1
+                        verified_positions.append(absolute_position)
+                    else:
+                        break
 
-                if (torch.all(current_logits == 0)):
-                    break
+                if len(verified_tokens) > best_score:
+                    best_score = len(verified_tokens)
+                    best_verified = verified_tokens
+                    best_positions = verified_positions
+
+            # Decide the position from which to take the llm_token
+            if len(best_verified) > 0:
+                pos = best_positions[-1] - tree_root_position
+                final_logits = logits[batch_idx, pos].unsqueeze(0)
 
-                if predicted_token == node.token_id:
-                    verified_tokens.append(node.token_id)
-                    absolute_position = tree_root_position + node.position_in_sequence + 1
-                    verified_positions.append(absolute_position)
+                # Check for all-zero logits (pruned away); fall back if so
+                if torch.all(final_logits == 0):
+                    # Fall back: use the last verified token as the llm_token
+                    llm_token = torch.tensor([best_verified[-1]], device=device)
+                    best_verified = best_verified[:-1]
+                    best_positions = best_positions[:-1]
                 else:
-                    break
+                    # Normal case: sample from the logits
+                    processed_logits = final_logits.clone()
+                    for processor in logits_processor:
+                        processed_logits = processor(
+                            input_ids[batch_idx:batch_idx+1],
+                            processed_logits
+                        )
+                    next_token = torch.argmax(processed_logits[0]).item()
+                    llm_token = torch.tensor([next_token], device=device)
+            else:
+                # Fallback: sample at fallback_pos
+                final_logits = logits[batch_idx, real_fallback_pos - 1:real_fallback_pos]
+                processed_logits = final_logits.clone()
+                for processor in logits_processor:
+                    processed_logits = processor(
+                        input_ids[batch_idx:batch_idx+1],
+                        processed_logits
+                    )
+                next_token = torch.argmax(processed_logits[0]).item()
+                llm_token = torch.tensor([next_token], device=device)
 
-            if len(verified_tokens) > best_score:
-                best_score = len(verified_tokens)
-                best_verified = verified_tokens
-                best_positions = verified_positions
-
-        if len(best_verified) == 0:
-            final_logits = logits[0, fallback_pos-1:fallback_pos]  # take the logits at a single position
-            processed_logits = final_logits
-            for processor in logits_processor:
-                processed_logits = processor(input_ids, processed_logits)
-
-            top5_values, top5_indices = torch.topk(final_logits[0], k=5)
-            next_token = torch.argmax(final_logits[0]).item()
-            kv_cache_position_ids = torch.tensor([tree_root_position],
-                                        device=logits.device)
-            llm_generated_token = torch.tensor([next_token], device=logits.device)
-            return None, kv_cache_position_ids, llm_generated_token
-
-        verified_tensor = torch.tensor([best_verified], device=logits.device)  # [1, num_verified]
-        pos = best_positions[-1] - tree_root_position
-        final_logits = logits[0, pos]  # take the logits at a single position
-
-        processed_logits = final_logits
-        for processor in logits_processor:
-            processed_logits = processor(input_ids, processed_logits)
-        next_token = torch.argmax(final_logits).item()
-        llm_generated_token = torch.tensor([next_token], device=logits.device)
-
-        all_positions = [tree_root_position] + best_positions
-        kv_cache_position_ids = torch.tensor(all_positions, device=logits.device)
-
-        return verified_tensor, kv_cache_position_ids, llm_generated_token
\ No newline at end of file
+            # Build the positions
+            all_positions = [tree_root_position] + best_positions
+            positions = torch.tensor(all_positions, device=device)
+
+            # Build the verified_tensor
+            if len(best_verified) > 0:
+                verified_tensor = torch.tensor(best_verified, dtype=torch.long, device=device)
+            else:
+                verified_tensor = torch.empty(0, dtype=torch.long, device=device)
+
+            verified_tokens_list.append(verified_tensor)
+            positions_list.append(positions)
+            llm_tokens_list.append(llm_token)
+            valid_lengths_list.append(len(best_verified))
+
+        # Pad everything into batch tensors
+
+        # 1. llm_generated_tokens: [batch_size, 1]
+        llm_generated_tokens = torch.stack(llm_tokens_list, dim=0)
+
+        # 2. valid_lengths: [batch_size]
+        valid_lengths = torch.tensor(valid_lengths_list, dtype=torch.long, device=device)
+
+        # 3. positions: [batch_size, max_pos_len]
+        max_pos_len = max(pos.shape[0] for pos in positions_list)
+        kv_cache_position_ids = torch.full(
+            (batch_size, max_pos_len),
+            -1,
+            dtype=torch.long,
+            device=device
+        )
+        for i, pos in enumerate(positions_list):
+            kv_cache_position_ids[i, :pos.shape[0]] = pos
+
+        # 4. verified_tokens: [batch_size, max_verified_len] or None
+        max_verified_len = max(v.shape[0] for v in verified_tokens_list) if verified_tokens_list else 0
+
+        if max_verified_len > 0:
+            verified_tokens = torch.full(
+                (batch_size, max_verified_len),
+                -1,
+                dtype=torch.long,
+                device=device
+            )
+            for i, v in enumerate(verified_tokens_list):
+                if v.shape[0] > 0:
+                    verified_tokens[i, :v.shape[0]] = v
+        else:
+            verified_tokens = None
+
+        return verified_tokens, kv_cache_position_ids, llm_generated_tokens, valid_lengths
\ No newline at end of file
diff --git a/src/bloombee/server/backend.py b/src/bloombee/server/backend.py
index cfbae15f..59a18218 100644
--- a/src/bloombee/server/backend.py
+++ b/src/bloombee/server/backend.py
@@ -140,7 +140,7 @@ def __init__(
                 1, 128, dtype=self.dtype
             ),  # draft_tokens
             BatchTensorDescriptor(
-                1, dtype=torch.int64
+                128, dtype=torch.int64
             ),  # prefill_length
             BatchTensorDescriptor(
                 1, dtype=torch.int64
@@ -194,19 +194,22 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S
             cache_tensors.append(keys)
         return cache_tensors
 
-    def prune_draft_tree(self, original_hidden_states: torch.Tensor, norm_hidden_states: torch.Tensor, draft_tokens: torch.Tensor, tree_attention_mask):
+    def prune_draft_tree(
+        self,
+        norm_hidden_states: torch.Tensor,
+        draft_tokens: torch.Tensor,
+        tree_attention_mask: torch.Tensor
+    ):
         results = self.pruner_manager.prune_speculation_tree(
             norm_hidden_states, draft_tokens, tree_attention_mask
         )
-        keep_indices = results['keep_indices']
-        logger.info(f"keep_indices: {keep_indices}")
+        keep_indices = results['keep_indices']  # [B, max_keep_len], padded with -1
+        # logger.info(f"keep_indices: {keep_indices}")
         self.pruner_manager.middle_keep_indices = keep_indices
 
-        new_hidden_states = original_hidden_states[:, keep_indices, :]
-        keep_indices = torch.tensor(keep_indices, dtype=torch.int64, device=original_hidden_states.device)
-        return new_hidden_states, keep_indices
+        return keep_indices
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
@@ -276,46 +279,71 @@ def inference_step(  # Each block will execute once
         # print("transformer backend inference step : output_hidden_states", output_hidden_states)  # output_hidden_states: None
         # Centralized select: aggregate + reorder + slice
         kv_cache_position_ids = inference_info.kv_cache_position_ids
-        logger.info(f"_last_keep_indices: {self._last_keep_indices}")
-        logger.info(f"keep_indices: {inference_info.keep_indices}")
-        logger.info(f"before format kv_cache_position_ids: {kv_cache_position_ids}")
+        # logger.info(f"_last_keep_indices: {self._last_keep_indices}")
+        # logger.info(f"keep_indices: {inference_info.keep_indices}")
+        # logger.info(f"before format kv_cache_position_ids: {kv_cache_position_ids}")
         # if self._is_spec_decoding and not self._need_pruning:
         #     kv_cache_position_ids = self._update_kv_cache_position_ids(kv_cache_position_ids, self._last_keep_indices)
-        logger.info(f"after format kv_cache_position_ids: {kv_cache_position_ids}")
-        k_pkv, v_pkv, need_reorder = self.cache_manager.select_cache(
-            prefix_length=inference_info.prefix_length,
-            hypo_ids=hypo_ids,
-            kv_cache_position_ids=kv_cache_position_ids
-        )
-        if k_pkv is not None:
-            logger.info(f"kv cache size: {k_pkv.shape}, prefix_length: {inference_info.prefix_length}")
+        # logger.info(f"after format kv_cache_position_ids: {kv_cache_position_ids}")
+        if kv_cache_position_ids is not None and kv_cache_position_ids.numel() > 0:
+            # 1. Select the cache that needs reordering
+            k_pkv_old, v_pkv_old, need_reorder = self.cache_manager.select_cache_for_reorder(
+                kv_cache_position_ids=kv_cache_position_ids
+            )
+
+            if need_reorder and k_pkv_old is not None:
+                # 2. Reorder and write back, obtaining each batch's valid length
+                new_prefix_length, kv_valid_lengths = self.cache_manager.reorder_and_write_cache(
+                    k_pkv=k_pkv_old,
+                    v_pkv=v_pkv_old,
+                    kv_cache_position_ids=kv_cache_position_ids,
+                )
+            else:
+                new_prefix_length = inference_info.prefix_length
+                kv_valid_lengths = torch.full((batch_size,), new_prefix_length, device=hidden_states.device)
+
+            # 3. Select the contiguous cache
+            k_pkv, v_pkv, _ = self.cache_manager.select_cache(
+                prefix_length=new_prefix_length,
+                hypo_ids=hypo_ids,
+                kv_cache_position_ids=None,
+            )
         else:
-            logger.info(f"kv cache is None, prefix_length: {inference_info.prefix_length}")
-        layer_past = (k_pkv, v_pkv) if k_pkv is not None else None
-        if need_reorder:
-            # In speculative decoding mode, inference is not squential so that kv-cache need to be reordered.
-            self.cache_manager.write_pkv_cache(
-                k_pkv=k_pkv,
-                v_pkv=v_pkv,
-                start_position=0
+            k_pkv, v_pkv, _ = self.cache_manager.select_cache(
+                prefix_length=inference_info.prefix_length,
+                hypo_ids=hypo_ids,
+                kv_cache_position_ids=None
             )
-
-        past_key_values_length = k_pkv.shape[2] if k_pkv is not None else 0
-        logger.info(f"past_key_values_length: {past_key_values_length}")
-        logger.info(f"inference_info.tree_attention_mask: {inference_info.tree_attention_mask}")
+            new_prefix_length = k_pkv.shape[2] if k_pkv is not None else 0
+            kv_valid_lengths = torch.full((batch_size,), inference_info.prefix_length, device=hidden_states.device)
+
+        layer_past = (k_pkv, v_pkv) if k_pkv is not None else None
+        # if k_pkv is not None:
+        #     logger.info(f"kv cache size: {k_pkv.shape}, prefix_length: {inference_info.prefix_length}")
+        # else:
+        #     logger.info(f"kv cache is None, prefix_length: {inference_info.prefix_length}")
+
+        # logger.info(f"past_key_values_length: {new_prefix_length}, hidden_states.device: {hidden_states.device}")
+        # logger.info(f"inference_info.tree_attention_mask, shape: {inference_info.tree_attention_mask.shape}")
         full_mask = None
+        device = hidden_states.device
+        # logger.info(f"kv_valid_lengths: {kv_valid_lengths}")
+
         if self._is_spec_decoding:
             full_mask = self._create_attention_mask(
-                tree_attention_mask=inference_info.tree_attention_mask,
-                src_len=seq_len + past_key_values_length,
-                past_key_values_length=past_key_values_length,
+                tree_attention_mask=inference_info.tree_attention_mask.to(device),
+                src_len=seq_len + new_prefix_length,
+                past_key_values_length=new_prefix_length,
+                kv_valid_lengths=kv_valid_lengths.to(device),  # new parameter
+                prefill_lengths=inference_info.prefill_length.to(device),
                 device=hidden_states.device,
             )
         attention_mask = self.convert_mask_to_scores(full_mask) if full_mask is not None else None
         if full_mask == None:
-            full_mask = self._create_causal_attention_mask(batch_size, (seq_len + past_key_values_length), past_key_values_length, hidden_states.device)
+            full_mask = self._create_causal_attention_mask(batch_size, (seq_len + new_prefix_length), new_prefix_length, hidden_states.device)
             attention_mask = self.convert_mask_to_scores(full_mask) if full_mask is not None else None
-        logger.info(f"full_mask: {full_mask}")
+        # logger.info(f"full_mask, shape: {full_mask.shape}, {full_mask}")
+        # logger.info(f"hidden states in backend before compute: {hidden_states}")
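+        # Illustrative first-round mask (hypothetical sizes): with
+        # prefill_lengths [2, 3], prefix_len 3 and tree_len 2, batch 0 attends
+        # causally only within its first 2 prefix positions, its tree tokens
+        # see those same 2 positions plus the tree_mask block, and the padded
+        # prefix column 2 stays masked for batch 0 throughout.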
        for offset in range(0, seq_len, max_chunk_length):  # Iterate through sequence to process hidden states in chunks only run offset=0
            hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]  # Get current hidden states chunk
            # print('transformer backend inference step() offset ', offset)
@@ -331,16 +359,21 @@ def inference_step(  # Each block will execute once
                self._position_ids_cache[cache_key] = base_ids.unsqueeze(0).expand(batch_size, -1)
 
            # Add offset to cached base tensor (avoids creating new tensor)
-            position_ids = self._position_ids_cache[cache_key] + (past_key_values_length + offset)
+            position_ids = self._position_ids_cache[cache_key] + (new_prefix_length + offset)
+            # logger.info(f"position_ids: {position_ids}")
+            # logger.info(f"prefill_length: {inference_info.prefill_length}, kv_valid_lengths: {kv_valid_lengths}")
            if self._is_spec_decoding:
                rotary_position_ids = self._create_tree_position_ids(
-                    2, 4, inference_info.prefill_length - 1, past_key_values_length, device='cuda:0'
+                    2, 4, inference_info.prefill_length - 1, kv_valid_lengths, device='cuda'
                )
            else:
                rotary_position_ids = None
-            rotary_position_ids = rotary_position_ids[inference_info.keep_indices] if rotary_position_ids is not None and inference_info.keep_indices is not None else rotary_position_ids
-            logger.info(f"rotary_position_ids: {rotary_position_ids}")
-            logger.info(f"position_ids: {position_ids}")
+
+            # logger.info(f"before gather rotary_position_ids: {rotary_position_ids}")
+            # logger.info(f"keep_indices: {inference_info.keep_indices}")
+            # logger.info(f"hidden_states: {hidden_states.shape}")
+            # rotary_position_ids = torch.gather(rotary_position_ids, 1, inference_info.keep_indices.to("cuda")) if rotary_position_ids is not None and inference_info.keep_indices is not None else rotary_position_ids
+            # logger.info(f"after gather rotary_position_ids: {rotary_position_ids}")
            try:
                # Fixed: Properly handle forward method return values with position_ids
                # print(f'  About to call module.forward with position_ids...')
@@ -371,8 +404,8 @@ def inference_step(  # Each block will execute once
            except Exception as e:
                print(f'  ERROR in module.forward: {type(e).__name__}: {e}')
-                # import traceback
-                # traceback.print_exc()
+                import traceback
+                traceback.print_exc()
                return (hidden_states, None)  # Return original input as fallback
 
        if seq_len > max_chunk_length:
@@ -383,15 +416,27 @@ def inference_step(  # Each block will execute once
        # logger.info(f"inference_step, output_hidden_states: {output_hidden_states}")
 
        # Centralized KV update via KVCacheManager (logs OFFLOAD: KV write ...)
-        self.cache_manager.update_cache(new_kvs, past_key_values_length)
+        if self._is_spec_decoding:
+            self.cache_manager.update_cache_batched(new_kvs, kv_valid_lengths)
+        else:
+            self.cache_manager.update_cache(new_kvs, new_prefix_length)
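+        # Illustrative example (hypothetical lengths): with kv_valid_lengths
+        # [10, 10] the batched path reduces to one contiguous write at offset
+        # 10; with ragged lengths such as [10, 7] update_cache_batched falls
+        # back to per-sample writes, so each row appends at its own offset.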
        keep_indices = inference_info.keep_indices
        if self._is_spec_decoding and self._need_pruning and self._is_last_block:
            norm_hidden_states = self.module.rms_norm(output_hidden_states)
-            output_hidden_states, keep_indices = self.prune_draft_tree(output_hidden_states, norm_hidden_states, inference_info.draft_tokens, full_mask)
-
-        self._last_keep_indices = keep_indices + past_key_values_length
-        logger.info(f"update _last_keep_indices: {self._last_keep_indices}")
+            keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
+
+        if self._is_spec_decoding and self._is_last_block:
+            original_hidden_states = output_hidden_states
+            batch_size, seq_len, hidden_size = original_hidden_states.shape
+            device = original_hidden_states.device
+            valid_mask = keep_indices >= 0
+            batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(keep_indices)
+            valid_hidden_states = original_hidden_states[batch_idx[valid_mask], keep_indices[valid_mask], :]
+            output_hidden_states = valid_hidden_states.unsqueeze(0)
+
+        self._last_keep_indices = keep_indices + new_prefix_length
+        # logger.info(f"update _last_keep_indices: {self._last_keep_indices}")
 
        # In training mode, you need to deploy your whole model in one device and choose a specific middle layer. After saving the middle_states, you can train the MLP network by comparing the middle states and final states logits.
        training_mode = False
@@ -432,24 +477,60 @@ def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Te
            cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids
 
    def _create_tree_position_ids(
-        self, width: int, depth: int, prefill_length: int, past_len: int, device: torch.device
+        self,
+        width: int,
+        depth: int,
+        prefill_length: torch.Tensor,  # the current valid total length of each batch sample (already reflected in that sample's position in the KV cache)
+        kv_valid_lengths: torch.Tensor,  # the unified starting offset of the KV cache (usually the aligned baseline across all samples)
+        device: torch.device
    ) -> torch.Tensor:
-        position_ids = []
-        depth = depth + 1
-
+
+        batch_size = prefill_length.shape[0]
+
+        # 1. Generate the tree template (relative offsets, root at 0)
+        tree_position_ids_list = []
        def dfs_generate(node_depth, current_depth):
-            position_ids.append(node_depth)
-            if current_depth < depth - 1:
+            tree_position_ids_list.append(node_depth)
+            if current_depth < depth:
                for _ in range(width):
                    dfs_generate(node_depth + 1, current_depth + 1)
 
        dfs_generate(0, 0)
+        tree_len = len(tree_position_ids_list)
+        tree_position_ids = torch.tensor(tree_position_ids_list, device=device)
+
+        # Decide whether this is the prefill stage (judged from the input length or past_len)
+        # Assume prefill if past_len == 0 and the input includes the prefill part
+        is_prefill = (kv_valid_lengths.max().item() == 0)
+
+        if is_prefill:
+            # --- Prefill stage ---
+            max_prefill_len = prefill_length.max().item()
+            total_len = max_prefill_len + tree_len
+            full_position_ids = torch.zeros(batch_size, total_len, dtype=torch.long, device=device)
+
+            for i in range(batch_size):
+                pl = prefill_length[i].item()
+                # Prefill part: [0, 1, ..., pl-1]
+                full_position_ids[i, :pl] = torch.arange(pl, device=device)
+                # Tree part: continues after the valid prefill length pl, pushed to the end of total_len
+                full_position_ids[i, max_prefill_len:] = tree_position_ids + pl
+
+            return full_position_ids
 
-        tree_position_ids = torch.tensor(position_ids, device=device) + past_len + prefill_length
-        prefill_ids = torch.arange(prefill_length, device=device) + past_len
-        full_position_ids = torch.cat([prefill_ids, tree_position_ids], dim=0)
-
-        return full_position_ids
+        else:
+            # --- Generation / verify stage ---
+            # Here input_ids holds only the tree part, of length tree_len
+            full_position_ids = torch.zeros(batch_size, tree_len, dtype=torch.long, device=device)
+
+            for i in range(batch_size):
+                # Note: past_len here should be the total number of tokens this sample had confirmed by the end of the previous round
+                past_len = kv_valid_lengths[i]
+                # The IDs must be offset by past_len (and pl):
+                # if prefill_length[i] already includes past_len, add tree_position_ids directly;
+                # if prefill_length[i] is only the current request's length, use past_len + pl + tree_position_ids
+                full_position_ids[i, :] = tree_position_ids + past_len
+
+            return full_position_ids
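+    # Illustrative template (hypothetical width=2, depth=2): dfs_generate
+    # yields [0, 1, 2, 2, 1, 2, 2] (root at relative position 0, each child
+    # one level deeper), which is then shifted per sample by its own prefill
+    # length or past_len.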
 
    def _update_kv_cache_position_ids(self, kv_cache_position_ids, keep_indices):
        if kv_cache_position_ids is None or keep_indices is None:
@@ -471,36 +552,69 @@ def _create_attention_mask(
        *,
        src_len: int,  # prefix_len + tree_len
        past_key_values_length: int,
+        kv_valid_lengths: Optional[torch.Tensor] = None,  # [B] valid KV length of each batch (used in later rounds)
+        prefill_lengths: Optional[torch.Tensor] = None,  # [B] actual prefill length of each batch (used in the first round)
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        if tree_attention_mask is None or is_dummy(tree_attention_mask):
            return None
+
+        # logger.info(f"tree_attention_mask: {tree_attention_mask.shape}")
+        # logger.info(f"src_len: {src_len}")
+        # logger.info(f"past_key_values_length: {past_key_values_length}")
+        # logger.info(f"kv_valid_lengths: {kv_valid_lengths}")
+        # logger.info(f"prefill_lengths: {prefill_lengths}")
 
        tree_mask = tree_attention_mask
        tree_len = tree_mask.size(1)
        B = tree_mask.size(0)
-        prefix_len = src_len - tree_len
+        prefix_len = src_len - tree_len  # maximum prefix length (including padding)
        current_token_count = src_len - past_key_values_length
 
        if current_token_count <= 0:
            return None
 
        if past_key_values_length == 0:
+            # ============ First round: handle prefill_lengths ============
            full_mask = torch.zeros(B, src_len, src_len, dtype=torch.bool, device=device)
+
+            # If prefill_lengths is not provided, assume every batch shares the same prefill length
+            if prefill_lengths is None:
+                prefill_lengths = torch.full((B,), prefix_len, dtype=torch.long, device=device)
+
            if prefix_len > 0:
-                causal_indices = torch.tril_indices(prefix_len, prefix_len, device=device)
-                full_mask[:, causal_indices[0], causal_indices[1]] = True
+                # Position indices
+                row_idx = torch.arange(prefix_len, device=device).view(1, -1, 1)  # [1, prefix_len, 1]
+                col_idx = torch.arange(prefix_len, device=device).view(1, 1, -1)  # [1, 1, prefix_len]
+                prefill_lens = prefill_lengths.view(B, 1, 1)  # [B, 1, 1]
+
+                # causal mask: row >= col
+                # validity mask: row < prefill_lengths AND col < prefill_lengths
+                causal_mask = row_idx >= col_idx  # [1, prefix_len, prefix_len]
+                row_valid = row_idx < prefill_lens  # [B, prefix_len, 1]
+                col_valid = col_idx < prefill_lens  # [B, 1, prefix_len]
+
+                prefix_mask = causal_mask & row_valid & col_valid  # [B, prefix_len, prefix_len]
+                full_mask[:, :prefix_len, :prefix_len] = prefix_mask
 
            if prefix_len > 0 and tree_len > 0:
-                full_mask[:, prefix_len:, :prefix_len] = True
+                # tree tokens attend to the prefix (only up to the valid prefill positions)
+                col_idx = torch.arange(prefix_len, device=device).view(1, 1, -1)  # [1, 1, prefix_len]
+                prefill_lens = prefill_lengths.view(B, 1, 1)  # [B, 1, 1]
+                col_valid = col_idx < prefill_lens  # [B, 1, prefix_len]
+                col_valid = col_valid.expand(B, tree_len, prefix_len)  # [B, tree_len, prefix_len]
+                full_mask[:, prefix_len:, :prefix_len] = col_valid
 
            if tree_len > 0:
                full_mask[:, prefix_len:, prefix_len:] = tree_mask
+
            return full_mask
 
        else:
+            # ============ Later rounds: handle kv_valid_lengths ============
            current_mask = torch.zeros(B, current_token_count, src_len, dtype=torch.bool, device=device)
            start_pos = past_key_values_length
+
            if start_pos < prefix_len:
                prefix_tokens = min(current_token_count, prefix_len - start_pos)
                for i in range(prefix_tokens):
@@ -515,7 +629,40 @@ def _create_attention_mask(
                if prefix_len > 0:
                    current_mask[:, :, :prefix_len] = True
                current_mask[:, :, prefix_len:] = tree_mask[:, tree_start:tree_start + current_token_count, :]
+
+            # Apply the kv_valid_lengths mask
+            if kv_valid_lengths is not None:
+                current_mask = self._apply_kv_valid_mask(current_mask, kv_valid_lengths, past_key_values_length, device)
+
            return current_mask
+
+
+    def _apply_kv_valid_mask(
+        self,
+        mask: torch.Tensor,
+        kv_valid_lengths: torch.Tensor,
+        kv_len: int,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Apply kv_valid_lengths to the mask, masking out KV positions beyond each batch's valid length
+        """
+        if kv_len <= 0:
+            return mask
+
+        B = mask.shape[0]
+        key_len = mask.shape[2]
+        actual_kv_len = min(kv_len, key_len)
+
+        # [1, actual_kv_len]
+        kv_positions = torch.arange(actual_kv_len, device=device).unsqueeze(0)
+
+        # [B, actual_kv_len] -> [B, 1, actual_kv_len]
+        kv_valid_mask = (kv_positions < kv_valid_lengths.unsqueeze(1)).unsqueeze(1)
+
+        mask[:, :, :actual_kv_len] = mask[:, :, :actual_kv_len] & kv_valid_mask
+
+        return mask
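+    # Illustrative example (hypothetical sizes): for a mask of shape
+    # [B=2, Q, K=8] with kv_len 6 and kv_valid_lengths [3, 5], key columns
+    # 3..5 are zeroed for batch 0 and column 5 for batch 1, while columns
+    # 6..7 (current tokens beyond the cached KV) are left untouched.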
 
    def _create_causal_attention_mask(
        self,
diff --git a/src/bloombee/server/block_functions.py b/src/bloombee/server/block_functions.py
index 86063455..36039531 100644
--- a/src/bloombee/server/block_functions.py
+++ b/src/bloombee/server/block_functions.py
@@ -228,6 +228,49 @@ async def run_rpc_backward(
        grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
 
+
+def restore_hidden_states(
+    flattened_hidden_states: torch.Tensor,  # [N_total_valid, hidden_size]
+    keep_indices: torch.Tensor,  # [B, max_keep_len], padded with -1
+    original_seq_len: int,  # original sequence length
+) -> torch.Tensor:
+    """
+    Restore flattened hidden states to [B, original_seq_len, hidden_size]
+
+    Args:
+        flattened_hidden_states: [N_total_valid, hidden_size] flattened valid hidden states
+        keep_indices: [B, max_keep_len] keep indices per batch, padded with -1
+        original_seq_len: the original sequence length
+
+    Returns:
+        restored_hidden_states: [B, original_seq_len, hidden_size], invalid positions zero-filled
+    """
+    batch_size, max_keep_len = keep_indices.shape
+    hidden_size = flattened_hidden_states.shape[-1]
+    device = flattened_hidden_states.device
+    dtype = flattened_hidden_states.dtype
+
+    # Create the output tensor, zero-filled
+    restored_hidden_states = torch.zeros(
+        batch_size, original_seq_len, hidden_size,
+        dtype=dtype, device=device
+    )
+
+    # Build the validity mask: [B, max_keep_len]
+    valid_mask = keep_indices >= 0
+
+    # Build batch indices: [B, max_keep_len]
+    batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(keep_indices)
+
+    # Select the valid index pairs
+    valid_batch_idx = batch_idx[valid_mask]  # [N_total_valid]
+    valid_seq_idx = keep_indices[valid_mask]  # [N_total_valid]
+
+    # Scatter back into the restored positions
+    restored_hidden_states[valid_batch_idx, valid_seq_idx, :] = flattened_hidden_states
+
+    return restored_hidden_states
+
+
 async def iterate_rpc_inference(
    requested_uids: Sequence[ExpertUID],
    requested_backends: Sequence[TransformerBackend],
@@ -255,9 +298,9 @@ async def iterate_rpc_inference(
        step_receive_time = perf_counter()
        if "start_from_position" in step_metadata:
            start_from_position = step_metadata["start_from_position"]
-            assert (
-                prefix_length >= start_from_position,
-            ), f"prefix_length={prefix_length}, start_from_position={start_from_position}"
+            # assert (
+            #     prefix_length >= start_from_position,
+            # ), f"prefix_length={prefix_length}, start_from_position={start_from_position}"
            prefix_length = start_from_position
        flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
 
@@ -266,7 +309,7 @@ async def iterate_rpc_inference(
            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
 
        hidden_states, keep_indices, need_pruning1, prompts, hypo_ids, tree_attention_mask, kv_cache_position_ids, draft_tokens, prefill_length, is_spec_dec1, *_ = flat_tensors
-        draft_tokens = draft_tokens[0] if draft_tokens is not None and not is_dummy(draft_tokens) else None
+        draft_tokens = draft_tokens if draft_tokens is not None and not is_dummy(draft_tokens) else None
 
        batch_size, length_increment, _ = hidden_states.shape
        # Fix for bus error in cross-machine setups: ensure tensors are contiguous
@@ -294,12 +337,23 @@ async def iterate_rpc_inference(
        need_pruning = need_pruning1 == 1 if need_pruning1 is not None and not is_dummy(need_pruning1) else False
        is_spec_dec = is_spec_dec1 == 1 if is_spec_dec1 is not None and not is_dummy(is_spec_dec1) else False
 
-        if is_spec_dec and not need_pruning:
-            attention_mask_indices = keep_indices[prefill_length:] - prefill_length
-            logger.info(f"prefill_length: {prefill_length}, attention_mask_indices: {attention_mask_indices}")
-            idx = attention_mask_indices
-            tree_attention_mask = tree_attention_mask[:, idx][:, :, idx]
+        if is_spec_dec and draft_tokens.shape[0] != hidden_states.shape[0]:
+            hidden_states = restore_hidden_states(hidden_states, keep_indices, draft_tokens.shape[-1])
+
+        # if is_spec_dec and not need_pruning:
+        #     prefill_len = draft_tokens.shape[-1] - tree_attention_mask.shape[-1]
+        #     attention_mask_indices = keep_indices[:, prefill_len:] - prefill_len
+        #     logger.info(f"keep_indices: {keep_indices}")
+        #     logger.info(f"draft_tokens: {draft_tokens}, tree_attention_mask: {tree_attention_mask.shape}")
+        #     logger.info(f"prefill_length: {prefill_length}, attention_mask_indices: {attention_mask_indices}")
+        #     idx = attention_mask_indices
+        #     idx_2d = idx.unsqueeze(-1).expand(-1, -1, tree_attention_mask.size(-1))
+        #     tree_attention_mask = torch.gather(tree_attention_mask, dim=1, index=idx_2d)
+
         # Cast inputs to backend dtype
         hidden_states = hidden_states.to(requested_backends[0].dtype)
         assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
@@ -459,7 +513,7 @@ async def iterate_rpc_inference(
                 hidden_states.shape[1],
                 dtype=torch.int64,
                 device=hidden_states.device
-            )
+            ).unsqueeze(0).expand(hidden_states.shape[0], -1)

         serialize_start = perf_counter()
         need_pruning_next = torch.tensor(0)
diff --git a/src/bloombee/server/memory_cache_manager.py b/src/bloombee/server/memory_cache_manager.py
index 01e70c5e..855fbd0f 100644
--- a/src/bloombee/server/memory_cache_manager.py
+++ b/src/bloombee/server/memory_cache_manager.py
@@ -341,4 +341,194 @@ def write_pkv_cache(self, k_pkv: torch.Tensor, v_pkv: torch.Tensor, start_positi
         self._write_kvs(
             kvs=(k_write, v_write),
             start_position=start_position
-        )
\ No newline at end of file
+        )
+
+    def update_cache_batched(
+        self,
+        new_kvs: AdaptedKVCache,
+        kv_valid_lengths: torch.Tensor,
+    ) -> None:
+        """
+        Batched speculative decoding only: each batch entry writes into the KV cache from its own position.
+        """
+        # fast path: every batch entry shares the same start_position
+        if (kv_valid_lengths == kv_valid_lengths[0]).all():
+            self._write_kvs(new_kvs, kv_valid_lengths[0].item())
+            return
+
+        # slow path: write batch entry by batch entry
+        assert self._active_cache_tensors_stack, "write called outside of use_cache context"
+        cache_tensors = self._active_cache_tensors_stack[-1]
+        (k_cache, v_cache), = cache_tensors
+        S_total, BH_dst, D_dst = k_cache.shape
+
+        new_kvs_data = new_kvs.kvs if hasattr(new_kvs, "kvs") else new_kvs
+        key, value = new_kvs_data
+
+        def _to_torch(x):
+            if hasattr(x, 'device') and (
+                getattr(getattr(x, 'device', None), 'device_type', None) == DeviceType.COMPRESSED
+                or (hasattr(x, 'data') and isinstance(getattr(x, 'data'), tuple) and len(getattr(x, 'data')) == 3)
+            ):
+                return x.device.decompress(x)
+            return getattr(x, 'data', x)
+
+        key_t = _to_torch(key)  # (B*H, D, s_new)
+        value_t = _to_torch(value)  # (B*H, s_new, D)
+
+        BH_src, D_src, s_new = key_t.shape
+        H = getattr(self.block_config, "num_attention_heads", None)
+        B = BH_src // H
+
+        if key_t.dtype != k_cache.dtype:
+            key_t = key_t.to(dtype=k_cache.dtype)
+        if value_t.dtype != v_cache.dtype:
+            value_t = value_t.to(dtype=v_cache.dtype)
+
+        # (B*H, D, s_new) -> (s_new, B*H, D)
+        k_write = key_t.permute(2, 0, 1)
+        v_write = value_t.permute(1, 0, 2)
+
+        for i in range(B):
+            start_pos = kv_valid_lengths[i].item()
+            end_pos = min(start_pos + s_new, S_total)
+            actual_len = end_pos - start_pos
+
+            if actual_len <= 0:
+                continue
+
+            head_start = i * H
+            head_end = (i + 1) * H
+
+            # slice out batch entry i's heads and write them in
+            k_batch = k_write[:actual_len, head_start:head_end, :].contiguous()
+            v_batch = v_write[:actual_len, head_start:head_end, :].contiguous()
+
+            dst_idx = (slice(start_pos, end_pos), slice(head_start, head_end), slice(0, D_src))
+
+            k_src_tt = TorchTensor.create_from_torch(k_batch, self.attention_compute)
+            v_src_tt = TorchTensor.create_from_torch(v_batch, self.attention_compute)
+
+            general_copy(k_cache, dst_idx, k_src_tt, None)
+            general_copy(v_cache, dst_idx, v_src_tt, None)
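A toy walk-through of the slow path's index arithmetic in `update_cache_batched`: the cache is laid out [S_total, B*H, D], so batch entry i owns the head slice [i*H, (i+1)*H) and writes starting at its own kv_valid_lengths[i]. Shapes here are illustrative; the real method routes the copy through TorchTensor/general_copy.

```python
import torch

B, H, D, S_total, s_new = 2, 3, 4, 16, 5
k_cache = torch.zeros(S_total, B * H, D)
key_t = torch.randn(B * H, D, s_new)     # new keys arrive as (B*H, D, s_new)

k_write = key_t.permute(2, 0, 1)         # -> (s_new, B*H, D), matching the cache layout
kv_valid_lengths = torch.tensor([3, 7])  # per-batch write offsets

for i in range(B):
    start = kv_valid_lengths[i].item()
    end = min(start + s_new, S_total)
    heads = slice(i * H, (i + 1) * H)
    k_cache[start:end, heads, :] = k_write[: end - start, heads, :]
```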
+    def reorder_and_write_cache(
+        self,
+        k_pkv: torch.Tensor,
+        v_pkv: torch.Tensor,
+        kv_cache_position_ids: torch.Tensor,
+    ) -> Tuple[int, torch.Tensor]:
+        """
+        Reorder scattered cache positions into the contiguous range [0, N).
+
+        Args:
+            k_pkv, v_pkv: [B, H, S_old, D] cache tensors as originally fetched
+            kv_cache_position_ids: [B, pos_len] positions to keep per batch entry (-1 marks invalid)
+
+        Returns:
+            max_new_length: the maximum valid length after reordering (for select_cache)
+            valid_lengths: [B] the actual valid length of each batch entry (for the attention mask)
+        """
+        assert self._active_cache_tensors_stack, "write called outside of use_cache"
+
+        B, H, S_old, D = k_pkv.shape
+        device = k_pkv.device
+
+        if kv_cache_position_ids.dim() == 1:
+            kv_cache_position_ids = kv_cache_position_ids.unsqueeze(0)
+
+        # compute each batch entry's full index list (prefix + valid tree positions)
+        batch_all_positions = []
+        valid_lengths_list = []
+
+        for i in range(B):
+            positions = kv_cache_position_ids[i]
+            valid_mask = positions >= 0
+            valid_positions = positions[valid_mask].tolist()
+
+            if len(valid_positions) > 0:
+                # root_position is the first valid position
+                root_position = valid_positions[0]
+                # prefix: [0, 1, ..., root_position - 1]
+                prefix_positions = list(range(root_position))
+                # full sequence: prefix + valid tree positions
+                all_positions = prefix_positions + valid_positions
+            else:
+                # no valid tree positions: keep everything as-is
+                all_positions = list(range(S_old))
+
+            batch_all_positions.append(all_positions)
+            valid_lengths_list.append(len(all_positions))
+
+        valid_lengths = torch.tensor(valid_lengths_list, dtype=torch.long, device=device)
+        max_new_length = valid_lengths.max().item()
+
+        # allocate the reordered cache
+        k_new = torch.zeros(B, H, max_new_length, D, dtype=k_pkv.dtype, device=device)
+        v_new = torch.zeros(B, H, max_new_length, D, dtype=v_pkv.dtype, device=device)
+
+        for i in range(B):
+            all_positions = batch_all_positions[i]
+            valid_len = len(all_positions)
+
+            if valid_len > 0:
+                positions_tensor = torch.tensor(all_positions, dtype=torch.long, device=device)
+                # gather from all_positions in the old cache, write into [0, valid_len) of the new one
+                k_new[i, :, :valid_len, :] = k_pkv[i, :, positions_tensor, :]
+                v_new[i, :, :valid_len, :] = v_pkv[i, :, positions_tensor, :]
+
+        # write back to the cache (starting at position 0)
+        self.write_pkv_cache(k_new, v_new, start_position=0)
+
+        return max_new_length, valid_lengths
+
+    def select_cache_for_reorder(
+        self,
+        kv_cache_position_ids: torch.Tensor,
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
+        """
+        Prepare for a reorder: fetch the union of the positions every batch entry needs.
+        """
+        assert self._active_cache_tensors_stack, "select_cache called outside of use_cache"
+
+        cache_tensors = self._active_cache_tensors_stack[-1]
+        (k_cache, v_cache), = cache_tensors
+        S_full, BH, D = k_cache.shape
+
+        compute_dst = self.attention_compute
+
+        H = getattr(self.block_config, "num_attention_heads", None)
+        B = BH // H
+
+        def _as_torch(x):
+            return x.data if hasattr(x, "data") else x
+
+        if kv_cache_position_ids.dim() == 1:
+            kv_cache_position_ids = kv_cache_position_ids.unsqueeze(0)
+
+        # find the largest position that is needed
+        valid_mask = kv_cache_position_ids >= 0
+        if not valid_mask.any():
+            return None, None, False
+
+        max_position = kv_cache_position_ids[valid_mask].max().item()
+
+        # fetch the range [0, max_position]
+        prefix_length = int(max_position) + 1
+        idx_all = (slice(0, prefix_length), slice(0, BH))
+
+        k_sel, _ = k_cache.smart_copy(compute_dst, idx_all)
+        v_sel, _ = v_cache.smart_copy(compute_dst, idx_all)
+        k_sbh = _as_torch(k_sel)
+        v_sbh = _as_torch(v_sel)
+
+        def _to_pkv(x_sbh: torch.Tensor) -> torch.Tensor:
+            return x_sbh.view(prefix_length, B, H, D).permute(1, 2, 0, 3)
+
+        k_pkv = _to_pkv(k_sbh)
+        v_pkv = _to_pkv(v_sbh)
+
+        # decide whether a reorder is needed
+        need_reorder = True  # any kv_cache_position_ids at all means we must reorder
+
+        return k_pkv, v_pkv, need_reorder
\ No newline at end of file
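The index computation in `reorder_and_write_cache`, reduced to a toy example: the first valid position acts as the root, everything before it is prefix, and the kept positions are compacted to the front of the cache.

```python
import torch

k_pkv = torch.randn(1, 2, 8, 4)                     # [B, H, S_old, D]
kv_cache_position_ids = torch.tensor([[4, 6, -1]])  # root at 4, also keep 6; -1 is padding

positions = kv_cache_position_ids[0]
valid = positions[positions >= 0].tolist()          # [4, 6]
all_positions = list(range(valid[0])) + valid       # [0, 1, 2, 3, 4, 6]

idx = torch.tensor(all_positions)
k_new = k_pkv[0, :, idx, :]                         # compacted: [H, 6, D]
print(k_new.shape)  # torch.Size([2, 6, 4])
```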
diff --git a/src/bloombee/server/server.py b/src/bloombee/server/server.py
index bfcf3e89..5d5c7e55 100644
--- a/src/bloombee/server/server.py
+++ b/src/bloombee/server/server.py
@@ -237,7 +237,7 @@ def __init__(
         if max_batch_size is None:
             max_batch_size = 8192 if is_multiquery_attn else 2048
         if inference_max_length is None:
-            inference_max_length = 8192 if is_multiquery_attn else 2048
+            inference_max_length = 8192 if is_multiquery_attn else 4096
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.max_chunk_size_bytes = max_chunk_size_bytes
@@ -300,7 +300,7 @@ def __init__(
         # Create configuration
         config = PruningConfig(
-            method=PruningMethod.ADAPTIVE_NEURAL,
+            method=PruningMethod.SIMPLE_PROBABILITY,
             neural_threshold=0.5,
             simple_threshold=0.1
         )
diff --git a/src/bloombee/server/speculative_pruner/adaptive_neural_pruner.py b/src/bloombee/server/speculative_pruner/adaptive_neural_pruner.py
index 30816858..36698097 100644
--- a/src/bloombee/server/speculative_pruner/adaptive_neural_pruner.py
+++ b/src/bloombee/server/speculative_pruner/adaptive_neural_pruner.py
@@ -17,7 +17,7 @@ def __init__(self, hidden_size=64):
         super().__init__()

         self.quality_path = nn.Sequential(
-            nn.Linear(3, hidden_size),  # prob(3) + acceptance(1)
+            nn.Linear(3, hidden_size),
             nn.ReLU(),
             nn.Linear(hidden_size, hidden_size // 2),
             nn.ReLU(),
@@ -25,18 +25,19 @@ def __init__(self, hidden_size=64):
         )

         self.threshold_path = nn.Sequential(
-            nn.Linear(4, hidden_size // 2),  # network features(4)
+            nn.Linear(4, hidden_size // 2),
             nn.ReLU(),
             nn.Linear(hidden_size // 2, 1)
         )

     def forward(self, prob_features):
-        quality_score = self.quality_path(prob_features).squeeze(-1)  # [batch]
+        quality_score = self.quality_path(prob_features).squeeze(-1)
         decision_score = quality_score

         decision_prob = torch.sigmoid(decision_score)
         return decision_prob, quality_score

+
 class AdaptiveNeuralPruner:
     def __init__(
         self,
@@ -80,7 +81,7 @@ def __init__(
         self.g_ite = 0
         self.temp_ite_count = 0
-        self.atc = 0  # accept_tokens_count
+        self.atc = 0
         self.after_pruing_atc = 0
         self.keep_count = 0
@@ -101,87 +102,179 @@ def __init__(
             if 'g_ite' in checkpoint:
                 self.g_ite = checkpoint['g_ite']

-    def _get_parent_postion(self, i, mask, prefix):
-        for j in range(i-1, -1, -1):
-            if mask[0, i, j + prefix] == True:
+    def _get_parent_position(self, i, mask, prefix, batch_idx=0):
+        """Find the parent position; batch-aware."""
+        for j in range(i - 1, -1, -1):
+            if mask[batch_idx, i, j + prefix] == True:
                 return j
         return i
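For context, the quality path the decision net runs per node, re-created standalone. The first layers are taken from the hunk above; the final `Linear(hidden_size // 2, 1)` layer and the 0.5 threshold are assumptions inferred from the `squeeze(-1)` usage and the PruningConfig defaults.

```python
import torch
import torch.nn as nn

hidden_size = 64
quality_path = nn.Sequential(
    nn.Linear(3, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size // 2), nn.ReLU(),
    nn.Linear(hidden_size // 2, 1),           # assumed final layer
)

prob_features = torch.tensor([[0.9, 0.2, 0.7]])     # one node's 3 probability features
quality_score = quality_path(prob_features).squeeze(-1)
decision_prob = torch.sigmoid(quality_score)        # forward() returns (decision_prob, quality_score)
keep = decision_prob.item() > 0.5                   # config.neural_threshold
```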
     def prune_branches(
         self,
         norm_hidden_states: torch.Tensor,
-        draft_tokens: Optional[List[int]] = None,
+        draft_tokens: Union[List[int], torch.Tensor] = None,
         tree_attention_mask: torch.Tensor = None,
         network_condition = None,
     ) -> Dict:
-        seq_len = len(draft_tokens)
+        """
+        Batch-aware prune_branches.
+
+        Args:
+            norm_hidden_states: [B, seq_len, hidden_size]
+            draft_tokens: [B, seq_len] or List[List[int]]
+            tree_attention_mask: [B, seq_len, total_len]
+            network_condition: the current network condition
+
+        Returns:
+            keep_indices: [B, max_keep_len], padded with -1,
+            plus the other metadata listed below.
+        """
+        # normalize input dimensions
+        if norm_hidden_states.dim() == 2:
+            norm_hidden_states = norm_hidden_states.unsqueeze(0)
+            tree_attention_mask = tree_attention_mask.unsqueeze(0) if tree_attention_mask.dim() == 2 else tree_attention_mask
+            if isinstance(draft_tokens, list):
+                draft_tokens = [draft_tokens]
+            elif isinstance(draft_tokens, torch.Tensor) and draft_tokens.dim() == 1:
+                draft_tokens = draft_tokens.unsqueeze(0)
+
+        batch_size = norm_hidden_states.shape[0]
+        seq_len = norm_hidden_states.shape[1]
+        device = norm_hidden_states.device
+        prefix_len = tree_attention_mask.shape[2] - seq_len
+
         if network_condition is None:
             network_condition = self.get_network_condition() or NetworkCondition.mock()

+        # logits: [B, seq_len, vocab_size]
         logits = self.lm_head(norm_hidden_states)

-        # Initialize masks
-        keep_mask = torch.ones(seq_len, dtype=torch.bool)
-        discarded = torch.zeros(seq_len, dtype=torch.bool)
-        decision_probs = torch.zeros(seq_len)
-        quality_scores = torch.zeros(seq_len)
-        threshold_adjusts = torch.zeros(seq_len)
-        prob_features_list, _ = self.collect_training_data(
-            logits,
-            tree_attention_mask,
-            draft_tokens
+        # convert draft_tokens to a tensor
+        if isinstance(draft_tokens, list):
+            if isinstance(draft_tokens[0], list):
+                draft_tokens = torch.tensor(draft_tokens, device=device)
+            else:
+                draft_tokens = torch.tensor(draft_tokens, device=device).unsqueeze(0)
+
+        # per-batch results
+        batch_keep_indices = []
+        batch_prune_indices = []
+        batch_decision_probs = []
+        batch_quality_scores = []
+
+        for b in range(batch_size):
+            # collect training data for this batch entry
+            prob_features_list, _ = self.collect_training_data_single(
+                logits[b:b+1],
+                tree_attention_mask[b:b+1],
+                draft_tokens[b]
             )

-        for i in range(seq_len):
-            if i == 0:
-                keep_mask[0] = True
-                decision_probs[0] = 1.0
-                continue
-            if discarded[i]:
-                keep_mask[i] = False
-                decision_probs[i] = 0.0
-                continue
+            # initialize masks
+            keep_mask = torch.ones(seq_len, dtype=torch.bool, device=device)
+            discarded = torch.zeros(seq_len, dtype=torch.bool, device=device)
+            decision_probs = torch.zeros(seq_len, device=device)
+            quality_scores = torch.zeros(seq_len, device=device)

-            prob_features = prob_features_list[i-1]
-            with torch.no_grad():
-                prob, quality = self.decision_net(prob_features.unsqueeze(0))
+            for i in range(seq_len):
+                if i == 0:
+                    keep_mask[0] = True
+                    decision_probs[0] = 1.0
+                    continue

-            decision_probs[i] = prob.item()
-            quality_scores[i] = quality.item()
-            threshold_adjusts[i] = 0
-            keep = prob.item() > self.config.neural_threshold
+                if discarded[i]:
+                    keep_mask[i] = False
+                    decision_probs[i] = 0.0
+                    continue

-            if not keep:
-                keep_mask[i] = False
-                discarded[i] = True
+                prob_features = prob_features_list[i - 1]
+                with torch.no_grad():
+                    prob, quality = self.decision_net(prob_features.unsqueeze(0))
+
+                decision_probs[i] = prob.item()
+                quality_scores[i] = quality.item()
+                keep = prob.item() > self.config.neural_threshold

-                # Discard descendants
-                for j in range(i + 1, seq_len):
-                    if tree_attention_mask[0, j, i + prefix_len] == 1:
-                        discarded[j] = True
-                        keep_mask[j] = False
-
-        keep_indices = torch.where(keep_mask)[0].tolist()
-        prune_indices = torch.where(~keep_mask)[0].tolist()
+                if not keep:
+                    keep_mask[i] = False
+                    discarded[i] = True
+
+                    # Discard descendants
+                    for j in range(i + 1, seq_len):
+                        if tree_attention_mask[b, j, i + prefix_len] == 1:
+                            discarded[j] = True
+                            keep_mask[j] = False
+
+            keep_indices = torch.where(keep_mask)[0].tolist()
+            prune_indices = torch.where(~keep_mask)[0].tolist()
+
+            batch_keep_indices.append(keep_indices)
+            batch_prune_indices.append(prune_indices)
+            batch_decision_probs.append(decision_probs)
+            batch_quality_scores.append(quality_scores)
+
+        # pad keep_indices to a common length with -1
+        max_keep_len = max(len(indices) for indices in batch_keep_indices)
+
+        padded_keep_indices = torch.full(
+            (batch_size, max_keep_len),
+            -1,
+            dtype=torch.long,
+            device=device
+        )
+
+        for b, indices in enumerate(batch_keep_indices):
+            if len(indices) > 0:
+                padded_keep_indices[b, :len(indices)] = torch.tensor(indices, device=device)
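The -1 padding scheme in miniature (this is the layout that `restore_hidden_states` over in block_functions.py consumes):

```python
import torch

batch_keep_indices = [[0, 1, 3], [0, 2]]  # ragged, one list per batch entry
max_keep_len = max(len(ix) for ix in batch_keep_indices)

padded = torch.full((len(batch_keep_indices), max_keep_len), -1, dtype=torch.long)
for b, ix in enumerate(batch_keep_indices):
    padded[b, : len(ix)] = torch.tensor(ix)

valid_lengths = (padded >= 0).sum(dim=1)  # tensor([3, 2])
```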
+        # pad prune_indices
+        max_prune_len = max(len(indices) for indices in batch_prune_indices) if batch_prune_indices else 0
+        if max_prune_len > 0:
+            padded_prune_indices = torch.full(
+                (batch_size, max_prune_len),
+                -1,
+                dtype=torch.long,
+                device=device
+            )
+            for b, indices in enumerate(batch_prune_indices):
+                if len(indices) > 0:
+                    padded_prune_indices[b, :len(indices)] = torch.tensor(indices, device=device)
+        else:
+            padded_prune_indices = torch.empty((batch_size, 0), dtype=torch.long, device=device)
+
+        # stack decision_probs and quality_scores
+        stacked_decision_probs = torch.stack(batch_decision_probs, dim=0)  # [B, seq_len]
+        stacked_quality_scores = torch.stack(batch_quality_scores, dim=0)  # [B, seq_len]
+
+        # compute valid lengths
+        valid_lengths = (padded_keep_indices >= 0).sum(dim=1)  # [B]

         return {
-            'keep_indices': keep_indices,
-            'prune_indices': prune_indices,
-            'decision_probs': decision_probs.cpu().tolist(),
-            'quality_scores': quality_scores.cpu().tolist(),
+            'keep_indices': padded_keep_indices,  # [B, max_keep_len]
+            'prune_indices': padded_prune_indices,  # [B, max_prune_len]
+            'decision_probs': stacked_decision_probs,  # [B, seq_len]
+            'quality_scores': stacked_quality_scores,  # [B, seq_len]
             'threshold_adjusts': 0,
             'network_condition': network_condition,
+            'valid_lengths': valid_lengths,  # [B]
         }

-    def collect_training_data(
+    def collect_training_data_single(
         self,
         logits: torch.Tensor,
         tree_attention_mask: torch.Tensor,
-        draft_tokens: Optional[List[int]] = None
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        draft_tokens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Collect training data for a single batch entry.

-        seq_len = len(draft_tokens)
+        Args:
+            logits: [1, seq_len, vocab_size]
+            tree_attention_mask: [1, seq_len, total_len]
+            draft_tokens: [seq_len]
+        """
+        seq_len = draft_tokens.shape[0]
         prefix_len = tree_attention_mask.shape[2] - seq_len

         with torch.no_grad():
@@ -189,9 +282,9 @@ def collect_training_data(
             labels_list = []

             for i in range(1, seq_len):
-                parent_postion = self._get_parent_postion(i, tree_attention_mask, prefix_len)
+                parent_position = self._get_parent_position(i, tree_attention_mask, prefix_len, batch_idx=0)

-                logits_at_pos = logits[0, parent_postion]
+                logits_at_pos = logits[0, parent_position]
                 probs = F.softmax(logits_at_pos, dim=-1)

                 max_prob = torch.max(probs).item()
@@ -208,14 +301,11 @@ def collect_training_data(
                 else:
                     normalized_entropy = 0.0

-                if draft_tokens is not None:
-                    token_prob = probs[draft_tokens[i]].item()
-                else:
-                    token_prob = torch.topk(probs, k=min(5, self.vocab_size)).values.sum().item()
+                token_prob = probs[draft_tokens[i]].item()

                 eps = 1e-10
                 logp_draft = math.log(token_prob + eps)
                 log_ratio = logp_draft
                 log_ratio = max(log_ratio, -10.0) / 10.0
                 log_ratio = -log_ratio
@@ -250,13 +340,27 @@ def collect_training_data(
             )
             return prob_features, labels
+
+    def collect_training_data(
+        self,
+        logits: torch.Tensor,
+        tree_attention_mask: torch.Tensor,
+        draft_tokens: Union[List[int], torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Original interface, kept for backward compatibility.
+        """
+        if isinstance(draft_tokens, list):
+            draft_tokens = torch.tensor(draft_tokens, device=self.device)
+
+        return self.collect_training_data_single(logits, tree_attention_mask, draft_tokens)
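A toy rendering of the per-node features `collect_training_data_single` derives from the parent position's distribution. The vocab size and token id are made up, and since the entropy branch is only partially visible in the hunk above, its normalization here is a plausible reconstruction rather than the file's exact code:

```python
import math
import torch
import torch.nn.functional as F

logits_at_pos = torch.randn(32000)
probs = F.softmax(logits_at_pos, dim=-1)

max_prob = torch.max(probs).item()
entropy = -(probs * torch.log(probs + 1e-10)).sum().item()
normalized_entropy = entropy / math.log(probs.numel())   # scaled into [0, 1] (assumed)

token_prob = probs[123].item()                           # the draft token's probability
log_ratio = -(max(math.log(token_prob + 1e-10), -10.0) / 10.0)  # clipped, sign-flipped
```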
     def _get_current_accepted_tokens_indices(
-            self,
-            final_logits: torch.Tensor,
-            attention_mask: torch.Tensor,
-            draft_tokens: torch.Tensor,
-        ):
+        self,
+        final_logits: torch.Tensor,
+        attention_mask: torch.Tensor,
+        draft_tokens: torch.Tensor,
+    ):
         seq_len = len(draft_tokens)
         prefix_len = attention_mask.shape[2] - seq_len
@@ -297,8 +401,6 @@ def _get_current_accepted_tokens_indices(
         best_path = path[:validated]
         last_index = best_path[-1]
-        next_token = probs[0, last_index].argmax().item()
-        logger.info(f"next token: {next_token}")

         return best_path, best_validated
@@ -393,4 +495,4 @@ def get_metrics(self) -> Dict[str, float]:
         return {
             'prune_rate': prune_rate,
             'accuracy': accuracy,
-        }
+        }
\ No newline at end of file
diff --git a/src/bloombee/server/speculative_pruner/simple_probability_pruner.py b/src/bloombee/server/speculative_pruner/simple_probability_pruner.py
index 8531a29f..22f57e96 100644
--- a/src/bloombee/server/speculative_pruner/simple_probability_pruner.py
+++ b/src/bloombee/server/speculative_pruner/simple_probability_pruner.py
@@ -42,15 +42,23 @@ def __init__(
         self.correct_prunes = 0

     def _get_parent_postion(self, i, mask, prefix):
+        """Parent-position lookup for a single sequence."""
         for j in range(i-1, -1, -1):
-            if mask[0, i, j + prefix] == True:
+            if mask[i, j + prefix] == True:
+                return j
+        return i
+
+    def _get_parent_postion_batched(self, i, mask, prefix, batch_idx):
+        """Batched parent-position lookup."""
+        for j in range(i-1, -1, -1):
+            if mask[batch_idx, i, j + prefix] == True:
                 return j
         return i
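How the parent lookup reads the mask, on a tiny depth-first tree (toy values: prefix of 2, a 3-node chain; mask[i, j + prefix] == True means tree node i attends to tree node j):

```python
import torch

prefix = 2
mask = torch.tensor([
    [1, 1, 1, 0, 0],   # node 0 (root)
    [1, 1, 1, 1, 0],   # node 1, child of 0
    [1, 1, 1, 1, 1],   # node 2, child of 1
], dtype=torch.bool)

def parent_of(i):
    for j in range(i - 1, -1, -1):  # nearest earlier node that i attends to
        if mask[i, j + prefix]:
            return j
    return i

print(parent_of(2), parent_of(1))  # 1 0
```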
     def prune_branches(
         self,
         middle_hidden_states: torch.Tensor,
-        draft_tokens: List[int],
+        draft_tokens: Union[List[int], torch.Tensor],
         tree_attention_mask: torch.Tensor,
         **kwargs
     ) -> Dict[str, Any]:
@@ -58,110 +66,163 @@ def prune_branches(
         Prune branches based on probability threshold
         Process nodes sequentially in depth-first order
+        Batch-aware.
+
         Args:
-            middle_hidden_states: [seq_len, hidden_size] - hidden states in depth-first order
-            draft_tokens: Token IDs in depth-first order
-            tree_attention_mask: [seq_len, seq_len] - encodes tree structure
+            middle_hidden_states: [B, seq_len, hidden_size] - hidden states in depth-first order
+            draft_tokens: [B, seq_len] Token IDs in depth-first order
+            tree_attention_mask: [B, seq_len, total_len] - encodes tree structure
+
+        Returns:
+            keep_indices: [B, max_keep_len] kept indices, padded with -1,
+            plus other metadata
         """
+        # normalize input dimensions
+        if middle_hidden_states.dim() == 2:
+            # single-sequence case: promote to a batch of one
+            middle_hidden_states = middle_hidden_states.unsqueeze(0)
+            tree_attention_mask = tree_attention_mask.unsqueeze(0) if tree_attention_mask.dim() == 2 else tree_attention_mask
+            if isinstance(draft_tokens, list):
+                draft_tokens = [draft_tokens]
+            elif isinstance(draft_tokens, torch.Tensor) and draft_tokens.dim() == 1:
+                draft_tokens = draft_tokens.unsqueeze(0)
+
+        batch_size = middle_hidden_states.shape[0]
         seq_len = middle_hidden_states.shape[1]
-        # logger.info(f"middle_hidden_states: {middle_hidden_states.shape}")
+        device = middle_hidden_states.device
         prefix_len = tree_attention_mask.shape[2] - seq_len

-        # Get middle layer logits and probabilities
+        # Get middle layer logits and probabilities: [B, seq_len, vocab_size]
         middle_logits = self.lm_head(middle_hidden_states)
-        # probs = F.softmax(middle_logits, dim=-1)
-
-        # Initialize keep mask (all True initially)
-        keep_mask = torch.ones(seq_len, dtype=torch.bool)
-
-        # Track which nodes are discarded (for skipping descendants)
-        discarded = torch.zeros(seq_len, dtype=torch.bool)
-        # Store scores for all nodes (for statistics and fallback)
-        scores = torch.zeros(seq_len)
-
-        # Process each node in depth-first order
-        for i in range(seq_len):
-            if i == 0:
-                keep_mask[0] = True
-                scores[i] = 1.0
-                continue
+        # convert draft_tokens to a tensor
+        if isinstance(draft_tokens, list):
+            if isinstance(draft_tokens[0], list):
+                # list of lists
+                draft_tokens = torch.tensor(draft_tokens, device=device)
+            else:
+                # single flat list
+                draft_tokens = torch.tensor(draft_tokens, device=device).unsqueeze(0)
+
+        # per-batch results
+        batch_keep_indices = []
+        batch_prune_indices = []
+        batch_scores = []
+        batch_keep_masks = []
+
+        for b in range(batch_size):
+            # Initialize keep mask (all True initially)
+            keep_mask = torch.ones(seq_len, dtype=torch.bool, device=device)

-            # Skip if already discarded by ancestor
-            if discarded[i]:
-                keep_mask[i] = False
-                scores[i] = 0.0  # Set score to 0 for discarded nodes
-                continue
+            # Track which nodes are discarded (for skipping descendants)
+            discarded = torch.zeros(seq_len, dtype=torch.bool, device=device)

-            # logger.info(f"draft_tokens[i]: {draft_tokens[i]}")
+            # Store scores for all nodes
+            scores = torch.zeros(seq_len, device=device)

-            # Get token probability
-            parent_postion = self._get_parent_postion(i, tree_attention_mask, prefix_len)
-            logits_at_pos = middle_logits[0, parent_postion]
-            probs = F.softmax(logits_at_pos, dim=-1)
-            topk = 50
+            # Process each node in depth-first order
+            for i in range(seq_len):
+                if i == 0:
+                    keep_mask[0] = True
+                    scores[i] = 1.0
+                    continue
+
+                # Skip if already discarded by ancestor
+                if discarded[i]:
+                    keep_mask[i] = False
+                    scores[i] = 0.0
+                    continue
+
+                # Get token probability
+                parent_position = self._get_parent_postion_batched(i, tree_attention_mask, prefix_len, b)
+                logits_at_pos = middle_logits[b, parent_position]
+                probs = F.softmax(logits_at_pos, dim=-1)
+                topk = 50

-            # take the top-50 token ids (from the parent's distribution)
-            topk_ids = torch.topk(
-                probs,
-                k=min(topk, self.vocab_size),
-                dim=-1
-            ).indices  # shape: [topk]
+                # take the top-50 token ids
+                topk_ids = torch.topk(
+                    probs,
+                    k=min(topk, self.vocab_size),
+                    dim=-1
+                ).indices

-            # check whether the draft token is inside the top-k
-            label = 1.0 if draft_tokens[i] in topk_ids.tolist() else 0.0
+                # check whether the draft token is inside the top-k
+                draft_id = draft_tokens[b, i].item()
+                label = 1.0 if draft_id in topk_ids.tolist() else 0.0
+
+                draft_prob = probs[draft_id].item()
+                scores[i] = draft_prob
+
+                # Check if score meets threshold
+                if label == 0.0:
+                    keep_mask[i] = False
+                    discarded[i] = True
+
+                    # Mark all descendants as discarded
+                    for j in range(i + 1, seq_len):
+                        if tree_attention_mask[b, j, i + prefix_len] == 1:
+                            discarded[j] = True
+                            keep_mask[j] = False

-            draft_id = draft_tokens[i]  # int
-            draft_prob = probs[draft_id].item()
+            # Get final indices for this batch entry
+            keep_indices = torch.where(keep_mask)[0].tolist()
+            prune_indices = torch.where(~keep_mask)[0].tolist()

-            # Check if score meets threshold
-            if label == 0.0:
-                keep_mask[i] = False
-                discarded[i] = True
-
-                # Mark all descendants as discarded
-                # Descendants are nodes j > i where j can attend to i
-                for j in range(i + 1, seq_len):
-                    if tree_attention_mask[0, j, i + prefix_len] == 1:
-                        discarded[j] = True
-                        keep_mask[j] = False
+            batch_keep_indices.append(keep_indices)
+            batch_prune_indices.append(prune_indices)
+            batch_scores.append(scores)
+            batch_keep_masks.append(keep_mask)
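The keep test at the heart of the loop above, isolated: a node survives iff its draft token lands in the parent distribution's top-50 (toy 100-token vocab here):

```python
import torch
import torch.nn.functional as F

logits_at_pos = torch.randn(100)
probs = F.softmax(logits_at_pos, dim=-1)
topk_ids = torch.topk(probs, k=min(50, 100), dim=-1).indices

draft_id = 7
keep = draft_id in topk_ids.tolist()
draft_prob = probs[draft_id].item()  # stored as the node's score either way
```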
-        # Get final indices
-        keep_indices = torch.where(keep_mask)[0].tolist()
-        prune_indices = torch.where(~keep_mask)[0].tolist()
+        # pad keep_indices to a common length with -1
+        max_keep_len = max(len(indices) for indices in batch_keep_indices)
+
+        padded_keep_indices = torch.full(
+            (batch_size, max_keep_len),
+            -1,
+            dtype=torch.long,
+            device=device
+        )
+
+        for b, indices in enumerate(batch_keep_indices):
+            if len(indices) > 0:
+                padded_keep_indices[b, :len(indices)] = torch.tensor(indices, device=device)
+
+        # pad prune_indices
+        max_prune_len = max(len(indices) for indices in batch_prune_indices) if batch_prune_indices else 0
+        if max_prune_len > 0:
+            padded_prune_indices = torch.full(
+                (batch_size, max_prune_len),
+                -1,
+                dtype=torch.long,
+                device=device
+            )
+            for b, indices in enumerate(batch_prune_indices):
+                if len(indices) > 0:
+                    padded_prune_indices[b, :len(indices)] = torch.tensor(indices, device=device)
+        else:
+            padded_prune_indices = torch.empty((batch_size, 0), dtype=torch.long, device=device)
+
+        # stack scores and masks
+        stacked_scores = torch.stack(batch_scores, dim=0)  # [B, seq_len]
+        stacked_keep_masks = torch.stack(batch_keep_masks, dim=0)  # [B, seq_len]
+
+        # compute valid lengths
+        valid_lengths = (padded_keep_indices >= 0).sum(dim=1)  # [B]

         return {
-            'keep_indices': keep_indices,
-            'prune_indices': prune_indices,
-            'keep_probs': scores.tolist(),
-            'keep_mask': keep_mask,
+            'keep_indices': padded_keep_indices,  # [B, max_keep_len], padded with -1
+            'prune_indices': padded_prune_indices,  # [B, max_prune_len], padded with -1
+            'keep_probs': stacked_scores,  # [B, seq_len]
+            'keep_mask': stacked_keep_masks,  # [B, seq_len]
+            'valid_lengths': valid_lengths,  # [B] number of valid keeps per batch entry
             'metadata': {
                 'middle_logits': middle_logits,
-                'avg_score': scores[keep_mask].mean().item() if keep_mask.any() else 0.0,
+                'avg_score': stacked_scores[stacked_keep_masks].mean().item() if stacked_keep_masks.any() else 0.0,
             }
         }

-    def _create_pruned_attention_mask(
-        self,
-        original_mask: torch.Tensor,
-        keep_indices: List[int]
-    ) -> torch.Tensor:
-        """Create new attention mask for kept nodes only"""
-        new_len = len(keep_indices)
-        new_mask = torch.zeros(new_len, new_len, dtype=original_mask.dtype)
-
-        for new_i, old_i in enumerate(keep_indices):
-            for new_j, old_j in enumerate(keep_indices):
-                new_mask[new_i, new_j] = original_mask[old_i, old_j]
-
-        return new_mask
-
     def get_metrics(self) -> Dict[str, float]:
         """Get pruning metrics"""
         prune_rate = self.pruned_branches / max(self.total_branches, 1)
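Reading the batched result: with keep/prune indices padded by -1, aggregate counts (for example a prune rate) come from simple `>= 0` masks. This is an illustrative aggregation over toy tensors, not a method from the file:

```python
import torch

keep_indices = torch.tensor([[0, 2, -1], [0, 1, 3]])  # [B, max_keep_len]
prune_indices = torch.tensor([[1, 3], [2, -1]])       # [B, max_prune_len]

kept = (keep_indices >= 0).sum().item()     # 5
pruned = (prune_indices >= 0).sum().item()  # 3
prune_rate = pruned / max(kept + pruned, 1)
print(f"{prune_rate:.3f}")  # 0.375
```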