70 changes: 47 additions & 23 deletions benchmarks/benchmark_speculative_decoding.py
@@ -60,35 +60,59 @@ def benchmark_inference(process_idx, args, result_pipe):
).to(device)
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

batch_size = 4
dataset = load_dataset("tatsu-lab/alpaca")["train"]
indices = random.sample(range(len(dataset)), 1)
indices = random.sample(range(len(dataset)), batch_size)
sampled = dataset.select(indices)
test_prompts = []
# for item in sampled:
# test_prompts.append(item["instruction"])

test_prompts.append("Hi,")
test_prompts.append("")
test_prompts.append("")
test_prompts.append("")

tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"
input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)["input_ids"]

for item in sampled:
# test_prompt = ""
# bos_token_id = tokenizer.bos_token_id
# if bos_token_id is not None:
# input_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)
# else:
# # If the tokenizer has no bos_token_id, it may need to be fetched or handled manually
# logger.warning("Tokenizer does not have a bos_token_id. Using an empty tensor.")
# input_ids = torch.tensor([[]], dtype=torch.long, device=device)

test_prompt = item["instruction"]
logger.info(f"test_prompt: {test_prompt}")
input_ids = tokenizer.encode(test_prompt, return_tensors="pt", add_special_tokens=False).to(device)

test_prompt = ""
bos_token_id = tokenizer.bos_token_id
if bos_token_id is not None:
input_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)
else:
# If the tokenizer has no bos_token_id, it may need to be fetched or handled manually
logger.warning("Tokenizer does not have a bos_token_id. Using an empty tensor.")
input_ids = torch.tensor([[]], dtype=torch.long, device=device)


result = ""
start_time = perf_counter()
result = model.generate(input_ids=input_ids, ssm=ssm)
time = perf_counter() - start_time
generated_tokens_num = result.shape[1] - input_ids.shape[1]
speed = generated_tokens_num / time
decoded_result = tokenizer.decode(result[0], skip_special_tokens=True)
result = ""
start_time = perf_counter()
result = model.generate(input_ids=input_ids, ssm=ssm)
time = perf_counter() - start_time
generated_tokens_nums = []
for i in range(batch_size):
prompt_length = input_ids[i].ne(tokenizer.pad_token_id).sum().item()
result_length = result[i].ne(tokenizer.pad_token_id).sum().item()
generated_tokens_num = result_length - prompt_length
generated_tokens_nums.append(generated_tokens_num)

avg_generated_tokens = sum(generated_tokens_nums) / batch_size
speed = avg_generated_tokens / time

# Decode all results
decoded_results = tokenizer.batch_decode(result, skip_special_tokens=True)

logger.info(f"benchmark_inference batch size: {batch_size}")
logger.info(f"Total time: {time:.4f}s, Average speed: {speed:.2f} tokens/s")
logger.info(f"Generated tokens per sample: {generated_tokens_nums}")

logger.info(f"benchmark_inference, result: {result}, generated_tokens_num: {generated_tokens_num}, time: {time} speed: {speed}, decoded_result: {decoded_result}")
for i, (prompt, decoded_result) in enumerate(zip(test_prompts, decoded_results)):
logger.info(f"Sample {i}:")
logger.info(f" Prompt: {prompt}")
logger.info(f" Result: {decoded_result}")
logger.info(f" Generated tokens: {generated_tokens_nums[i]}")


result_pipe.send(speed)
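
A minimal sketch (toy values, not from the benchmark run) of the per-sample accounting above: with padded batches, counting non-pad tokens in the prompt and in the output gives the number of newly generated tokens per sample. One caveat: since pad_token is set to eos_token here, generated EOS tokens are excluded from the count as well.

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0],
                          [8, 9, 0, 0]])       # prompts, right-padded
result = torch.tensor([[5, 6, 7, 11, 12, 0],   # prompts + generations
                       [8, 9, 13, 0, 0, 0]])

prompt_lengths = input_ids.ne(pad_token_id).sum(dim=1)  # tensor([3, 2])
result_lengths = result.ne(pad_token_id).sum(dim=1)     # tensor([5, 3])
generated = result_lengths - prompt_lengths             # tensor([2, 1])
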
87 changes: 57 additions & 30 deletions src/bloombee/client/inference_session.py
@@ -146,7 +146,7 @@ def step(
normalize_arg(tree_attention_mask),
normalize_arg(kv_cache_position_ids),
normalize_arg(draft_tokens),
normalize_arg(torch.tensor(prefill_length)),
normalize_arg(prefill_length),
normalize_arg(torch.tensor(1 if is_spec_dec else 0)),
)
logger.info(f"_ServerInferenceSession step id {step_id}")
Expand Down Expand Up @@ -328,6 +328,7 @@ def step( # 执行一次推理步骤,处理输入数据和相应的提示与
kv_cache_position_ids: Optional[torch.Tensor] = None,
draft_tokens: Optional[torch.Tensor] = None,
is_spec_decoding: Optional[torch.Tensor] = None,
prefill_length: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert not self._closed
if torch.is_grad_enabled():
Expand Down Expand Up @@ -360,6 +361,7 @@ def step( # 执行一次推理步骤,处理输入数据和相应的提示与
is_spec_decoding = is_spec_decoding.cpu() if is_spec_decoding is not None else None

step_id = str(uuid.uuid4()) # Generate a unique step ID.
batch_size = inputs.shape[0]

n_input_tokens = inputs.shape[1] if kv_cache_position_ids is None else kv_cache_position_ids.numel()
if self._position + n_input_tokens > self._max_length:
@@ -370,13 +372,15 @@ def step( # Run one inference step, handling the input data and the corresponding prompts and
server_idx = 0
block_idx = 0
inference_step_start = time.perf_counter()
self.prefill_length = inputs.shape[1] - tree_attention_mask.shape[1] if tree_attention_mask is not None and self.first_inference else 0
# self.first_inference = False
if tree_attention_mask is not None:
self.prefill_length = prefill_length.to(inputs.device)
else:
self.prefill_length = torch.zeros(batch_size)
keep_indices = torch.arange(
inputs.shape[1],
dtype=torch.int64,
device=inputs.device
)
inputs.shape[1],
dtype=torch.int64,
device=inputs.device
).unsqueeze(0).expand(inputs.shape[0], -1)
self.keep_indices = keep_indices
if is_spec_decoding is not None and is_spec_decoding.item() == 1:
is_spec_dec = True
@@ -440,8 +444,11 @@ def step( # Run one inference step, handling the input data and the corresponding prompts and
time.sleep(delay)

self._position += n_input_tokens
# logger.info(f"keep_indices: {keep_indices}")
# logger.info(f"before _recover_hidden_states: {inputs}")
if draft_tokens is not None and is_spec_dec:
inputs = self._recover_hidden_states(inputs, self.keep_indices, draft_tokens.shape[1])
inputs = self._restore_hidden_states(inputs, self.keep_indices, draft_tokens.shape[1])
# logger.info(f"after _recover_hidden_states: {inputs}")
outputs = inputs

# 🔍 CLIENT DEBUG: Log inference step end
@@ -454,28 +461,48 @@ def step( # Run one inference step, handling the input data and the corresponding prompts and
# print('client inference session outputs ', outputs.shape)
return outputs

def _recover_hidden_states(self, hidden_states, keep_indices, original_length):
if not torch.is_tensor(keep_indices):
keep_indices = torch.tensor(keep_indices, device=hidden_states.device, dtype=torch.long)

if hidden_states.dim() == 2:
# [S_kept, H]
recovered = hidden_states.new_zeros((original_length, hidden_states.size(-1)))
recovered[keep_indices] = hidden_states

elif hidden_states.dim() == 3:
# [B, S_kept, H]
B, _, H = hidden_states.shape
recovered = hidden_states.new_zeros((B, original_length, H))
recovered[:, keep_indices, :] = hidden_states

else:
raise ValueError(f"Unexpected hidden_states dim {hidden_states.dim()}, expected 2 or 3")
mask = torch.ones(original_length, dtype=torch.bool, device=hidden_states.device)
mask[keep_indices] = False
self._last_padded_mask = mask

return recovered
def _restore_hidden_states(
self,
flattened_hidden_states: torch.Tensor, # [N_total_valid, hidden_size]
keep_indices: torch.Tensor, # [B, max_keep_len], padded with -1
original_seq_len: int, # original sequence length
) -> torch.Tensor:
"""
Restore the flattened hidden states to [B, original_seq_len, hidden_size].

Args:
flattened_hidden_states: [N_total_valid, hidden_size] flattened valid hidden states
keep_indices: [B, max_keep_len] keep indices for each batch entry, padded with -1
original_seq_len: original sequence length

Returns:
restored_hidden_states: [B, original_seq_len, hidden_size], with invalid positions zero-filled
"""
batch_size, max_keep_len = keep_indices.shape
hidden_size = flattened_hidden_states.shape[-1]
device = flattened_hidden_states.device
dtype = flattened_hidden_states.dtype

# Create the output tensor, zero-filled
restored_hidden_states = torch.zeros(
batch_size, original_seq_len, hidden_size,
dtype=dtype, device=device
)

# Build the validity mask: [B, max_keep_len]
valid_mask = keep_indices >= 0

# Build batch indices: [B, max_keep_len]
batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(keep_indices)

# Gather the indices of the valid entries
valid_batch_idx = batch_idx[valid_mask] # [N_total_valid]
valid_seq_idx = keep_indices[valid_mask] # [N_total_valid]

# Scatter back into the restored positions
restored_hidden_states[valid_batch_idx, valid_seq_idx, :] = flattened_hidden_states

return restored_hidden_states
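
For reference, a self-contained sketch of the scatter-restore above with toy shapes (the indexing is lifted directly from the method body):

import torch

hidden_size = 2
# Five valid rows, flattened batch-major: batch 0 keeps 3 positions, batch 1 keeps 2.
flattened = torch.arange(10, dtype=torch.float32).reshape(5, hidden_size)
keep_indices = torch.tensor([[0, 1, 3],
                             [2, 3, -1]])  # [B=2, max_keep_len=3], padded with -1

valid_mask = keep_indices >= 0
batch_idx = torch.arange(2).unsqueeze(1).expand_as(keep_indices)
restored = torch.zeros(2, 4, hidden_size)
restored[batch_idx[valid_mask], keep_indices[valid_mask], :] = flattened
# restored[0] fills rows 0, 1, 3; restored[1] fills rows 2, 3; the rest stay zero.
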

def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
# If there is a failed server session, this code closes it
27 changes: 23 additions & 4 deletions src/bloombee/client/remote_generation.py
@@ -1,7 +1,7 @@
import contextlib
import dataclasses
from contextvars import ContextVar
from typing import Any, ContextManager, Dict, List, Optional, Tuple
from typing import Any, ContextManager, Dict, List, Optional, Tuple, Union

import torch
import transformers
@@ -22,22 +22,38 @@ class RemotePastKeyValues(Cache):

def __init__(self) -> None:
super().__init__()
self._seen_tokens = 0
self._seen_tokens: Optional[torch.Tensor] = None
self.hypo_ids: Optional[torch.LongTensor] = None
self.kv_cache_position_ids: Optional[torch.LongTensor] = None
self.is_spec_decoding: Optional[torch.LongTensor] = None
self.prefill_length: Optional[torch.LongTensor] = None

def __getitem__(self, _index: int) -> List[torch.Tensor]:
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
if self._seen_tokens is None:
return 0
if self._seen_tokens.dim() == 0:
return self._seen_tokens.item()
return self._seen_tokens[0].item()

def get_seq_length_batch(self) -> Optional[torch.Tensor]:
return self._seen_tokens

def get_max_length(self) -> Optional[int]:
return None

def update_seen(self, new_seen: int) -> None:
self._seen_tokens += new_seen
def update_seen(self, new_seen: Union[int, torch.Tensor]) -> None:
if isinstance(new_seen, int):
self._seen_tokens = torch.tensor([new_seen])
elif isinstance(new_seen, torch.Tensor):
if new_seen.dim() == 0:
new_seen = new_seen.unsqueeze(0)
self._seen_tokens = new_seen
else:
raise TypeError(f"new_seen must be int or torch.Tensor, got {type(new_seen)}")


def reorder_cache(self, beam_idx):
raise NotImplementedError("Beam search reordering is not implemented yet")
@@ -47,6 +63,9 @@ def set_kv_cache(self, position_ids: Optional[torch.LongTensor]):

def set_is_spec_decoding(self, is_spec_decoding: Optional[torch.LongTensor]):
self.is_spec_decoding = is_spec_decoding

def set_prefill_length(self, prefill_length: Optional[torch.LongTensor]):
self.prefill_length = prefill_length
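
An illustrative sketch (not part of the diff) of the reworked bookkeeping: update_seen now normalizes int and 0-d tensor inputs into a 1-d tensor and stores the latest value rather than accumulating, get_seq_length reports the first sample's length, and get_seq_length_batch exposes the per-sample view.

import torch
from bloombee.client.remote_generation import RemotePastKeyValues

cache = RemotePastKeyValues()
assert cache.get_seq_length() == 0                 # nothing seen yet

cache.update_seen(7)                               # int -> stored as tensor([7])
assert cache.get_seq_length() == 7

cache.update_seen(torch.tensor([5, 9, 9, 5]))      # per-sample lengths, batch of 4
assert cache.get_seq_length() == 5                 # first sample only
assert cache.get_seq_length_batch().tolist() == [5, 9, 9, 5]
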


_skipped_tokens = ContextVar("skipped_tokens", default=0)
19 changes: 17 additions & 2 deletions src/bloombee/flexgen_utils/pytorch_backend.py
@@ -100,8 +100,8 @@ def precompute_freqs_cis(
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
else:
t = position_ids.float().to(inv_freq.device)
freqs = torch.outer(t, inv_freq).float()
t = position_ids.float().to(inv_freq.device) # [B, S]
freqs = t.unsqueeze(-1) * inv_freq.reshape(1, 1, -1)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
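
A shape sketch of the batched rotary path added above (toy values): per-sample position_ids of shape [B, S] broadcast against inv_freq to yield complex frequencies of shape [B, S, head_dim // 2], instead of the shared [S, head_dim // 2] table produced by the torch.outer branch.

import torch

B, S, half_dim = 2, 5, 4
inv_freq = 1.0 / (10000 ** (torch.arange(half_dim).float() / half_dim))
position_ids = torch.tensor([[0, 1, 2, 3, 4],
                             [0, 0, 1, 2, 3]])          # e.g. a left-padded batch

t = position_ids.float()                                # [B, S]
freqs = t.unsqueeze(-1) * inv_freq.reshape(1, 1, -1)    # [B, S, half_dim]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
assert freqs_cis.shape == (B, S, half_dim)
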

@@ -723,6 +723,8 @@ def mha_gen_llama(self, inputs, attention_mask, w_q, w_k, w_v,

hidden = rms_norm(inputs.data, input_layernorm.data)

# logger.info(f"after norm, hidden states: {hidden}")

# shape: (b, 1, h)
q = F.linear(hidden, w_q.data)
k = F.linear(hidden, w_k.data)
@@ -733,13 +735,26 @@ def mha_gen_llama(self, inputs, attention_mask, w_q, w_k, w_v,
k = k.view(b, tgt_s, n_head, head_dim)
v = v.view(b, tgt_s, n_head, head_dim)

# logger.info(f"after projection, query_states: {q}")
# logger.info(f"after projection, key_states: {k}")
# logger.info(f"after projection, value_states: {v}")

# logger.info(f"attention_mask: {attention_mask.shape}")
# logger.info(f"inputs: {inputs.shape}")
# logger.info(f"freq_cis: {freq_cis.shape}")
# logger.info(f"src_s: {src_s}")
# logger.info(f"tgt_s: {tgt_s}")

if rotary_position_ids is not None:
freqs_slice = freq_cis[-tgt_s:]
else:
freqs_slice = freq_cis[src_s - tgt_s: src_s]

q, k = apply_rotary_emb(q, k, freqs_cis=freqs_slice)

# logger.info(f"after rotary, query_states: {q}")
# logger.info(f"after rotary, key_states: {k}")

# shape: (b * n_head, 1, head_dim)
q = q.permute(0, 2, 1, 3).reshape(b * n_head, tgt_s, head_dim)
# shape: (1, b * n_head, head_dim)
3 changes: 2 additions & 1 deletion src/bloombee/models/llama/model.py
@@ -88,7 +88,7 @@ def forward(
# print('model.py llama model inputs_embeds, ', inputs_embeds) # Temporarily commented for cleaner debug output
output_shape = input_shape + (hidden_states.size(-1),)

# logger.info(f"hidden_states: {hidden_states}")
# logger.info(f"input_ids: {input_ids}")

hidden_states = self.layers(
hidden_states,
@@ -98,6 +98,7 @@ def forward(
kv_cache_position_ids=past_key_values.kv_cache_position_ids if past_key_values is not None else None,
draft_tokens = input_ids,
is_spec_decoding = past_key_values.is_spec_decoding if past_key_values is not None else None,
prefill_length = past_key_values.prefill_length if past_key_values is not None else None,
)

if past_key_values is None:
23 changes: 20 additions & 3 deletions src/bloombee/models/llama/spe_dec_tree.py
@@ -185,8 +185,24 @@ def prepare_incremental_tree_batch(
trees: List[SpeculativeTree],
input_ids: torch.LongTensor,
device: torch.device,
pad_token_id: int = 0
pad_token_id: int = 0,
seq_lengths: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[List[TreeNode]]]]:
"""
Prepare an incremental tree batch, supporting different sequence lengths.

Args:
trees: list of speculative trees
input_ids: [batch_size, max_seq_len] input token ids (may contain padding)
device: target device
pad_token_id: padding token id
seq_lengths: [batch_size] true length of each sequence; if None, all sequences are assumed to share the same length

Returns:
tree_tokens: [batch_size, max_tree_size]
attention_mask: [batch_size, max_tree_size, past_len + max_tree_size]
batch_node_paths: list of node paths for each batch entry
"""
batch_size = len(trees)

if not trees or all(tree.total_nodes <= 1 for tree in trees):
@@ -237,10 +253,11 @@ def prepare_tree_attention_batch(
trees: List[SpeculativeTree],
prefix_tokens: torch.Tensor,
device: torch.device,
pad_token_id: int = 0
pad_token_id: int = 0,
seq_lengths: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[List[TreeNode]]]]:
tree_tokens, attention_mask, batch_node_paths = prepare_incremental_tree_batch(
trees, prefix_tokens, device, pad_token_id
trees, prefix_tokens, device, pad_token_id, seq_lengths
)
if tree_tokens.shape[1] > 0:
full_sequence = torch.cat([prefix_tokens, tree_tokens], dim=-1)
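
A hedged sketch of deriving the new seq_lengths argument for a right-padded batch before calling prepare_tree_attention_batch; tree construction is elided, since it is not shown in this diff:

import torch

pad_token_id = 0
prefix_tokens = torch.tensor([[5, 6, 7, 0],
                              [8, 9, 0, 0]])             # [B, max_seq_len]
seq_lengths = prefix_tokens.ne(pad_token_id).sum(dim=1)  # tensor([3, 2])

# tree_tokens, attention_mask, batch_node_paths = prepare_tree_attention_batch(
#     trees, prefix_tokens, device=prefix_tokens.device,
#     pad_token_id=pad_token_id, seq_lengths=seq_lengths,
# )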