
Commit 7f07384

xiongxu1998 and Xu Xiong authored
add batch support for speculative decoding and its pruning (ai-decentralized#38)
Co-authored-by: Xu Xiong <xiongxu1998@gmail.com>
1 parent 9b894e6 commit 7f07384

13 files changed

Lines changed: 1263 additions & 410 deletions

benchmarks/benchmark_speculative_decoding.py

Lines changed: 47 additions & 23 deletions
@@ -60,35 +60,59 @@ def benchmark_inference(process_idx, args, result_pipe):
     ).to(device)
     tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
 
+    batch_size = 4
     dataset = load_dataset("tatsu-lab/alpaca")["train"]
-    indices = random.sample(range(len(dataset)), 1)
+    indices = random.sample(range(len(dataset)), batch_size)
     sampled = dataset.select(indices)
+    test_prompts = []
+    # for item in sampled:
+    #     test_prompts.append(item["instruction"])
+
+    test_prompts.append("Hi,")
+    test_prompts.append("")
+    test_prompts.append("")
+    test_prompts.append("")
+
+    tokenizer.pad_token = tokenizer.eos_token
+    # tokenizer.padding_side = "left"
+    input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)["input_ids"]
 
-    for item in sampled:
+    # test_prompt = ""
+    # bos_token_id = tokenizer.bos_token_id
+    # if bos_token_id is not None:
+    #     input_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)
+    # else:
+    #     # If the tokenizer has no bos_token_id, it may need to be fetched or handled manually
+    #     logger.warning("Tokenizer does not have a bos_token_id. Using an empty tensor.")
+    #     input_ids = torch.tensor([[]], dtype=torch.long, device=device)
 
-        test_prompt = item["instruction"]
-        logger.info(f"test_prompt: {test_prompt}")
-        input_ids = tokenizer.encode(test_prompt, return_tensors="pt", add_special_tokens=False).to(device)
-
-        test_prompt = ""
-        bos_token_id = tokenizer.bos_token_id
-        if bos_token_id is not None:
-            input_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)
-        else:
-            # If the tokenizer has no bos_token_id, it may need to be fetched or handled manually
-            logger.warning("Tokenizer does not have a bos_token_id. Using an empty tensor.")
-            input_ids = torch.tensor([[]], dtype=torch.long, device=device)
-
-        result = ""
-        start_time = perf_counter()
-        result = model.generate(input_ids=input_ids, ssm=ssm)
-        time = perf_counter() - start_time
-        generated_tokens_num = result.shape[1] - input_ids.shape[1]
-        speed = generated_tokens_num / time
-        decoded_result = tokenizer.decode(result[0], skip_special_tokens=True)
+    result = ""
+    start_time = perf_counter()
+    result = model.generate(input_ids=input_ids, ssm=ssm)
+    time = perf_counter() - start_time
+    generated_tokens_nums = []
+    for i in range(batch_size):
+        prompt_length = input_ids[i].ne(tokenizer.pad_token_id).sum().item()
+        result_length = result[i].ne(tokenizer.pad_token_id).sum().item()
+        generated_tokens_num = result_length - prompt_length
+        generated_tokens_nums.append(generated_tokens_num)
+
+    avg_generated_tokens = sum(generated_tokens_nums) / batch_size
+    speed = avg_generated_tokens / time
+
+    # Decode all results
+    decoded_results = tokenizer.batch_decode(result, skip_special_tokens=True)
+
+    logger.info(f"benchmark_inference batch size: {batch_size}")
+    logger.info(f"Total time: {time:.4f}s, Average speed: {speed:.2f} tokens/s")
+    logger.info(f"Generated tokens per sample: {generated_tokens_nums}")
 
-        logger.info(f"benchmark_inference, result: {result}, generated_tokens_num: {generated_tokens_num}, time: {time} speed: {speed}, decoded_result: {decoded_result}")
+    for i, (prompt, decoded_result) in enumerate(zip(test_prompts, decoded_results)):
+        logger.info(f"Sample {i}:")
+        logger.info(f"  Prompt: {prompt}")
+        logger.info(f"  Result: {decoded_result}")
+        logger.info(f"  Generated tokens: {generated_tokens_nums[i]}")
 
 
     result_pipe.send(speed)
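
Note on the accounting above: with right padding, each sample's generation length is the difference between the non-pad token counts of the output row and the prompt row. A minimal standalone sketch of that bookkeeping (toy tensors and an arbitrary pad_token_id, not the benchmark's real data; since the benchmark aliases pad_token to eos_token, genuine eos tokens are excluded from the count as well):

import torch

# Toy sketch of the batched token accounting, assuming right padding.
pad_token_id = 2
input_ids = torch.tensor([[5, 6, 7, 2],
                          [8, 2, 2, 2]])                 # prompts, right-padded
result = torch.tensor([[5, 6, 7, 9, 10, 2],
                       [8, 11, 2, 2, 2, 2]])             # prompts + generated tokens

generated_tokens_nums = []
for i in range(input_ids.shape[0]):
    prompt_length = input_ids[i].ne(pad_token_id).sum().item()  # non-pad prompt tokens
    result_length = result[i].ne(pad_token_id).sum().item()     # non-pad output tokens
    generated_tokens_nums.append(result_length - prompt_length)

print(generated_tokens_nums)  # [2, 1]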

src/bloombee/client/inference_session.py

Lines changed: 57 additions & 30 deletions
@@ -146,7 +146,7 @@ def step(
             normalize_arg(tree_attention_mask),
             normalize_arg(kv_cache_position_ids),
             normalize_arg(draft_tokens),
-            normalize_arg(torch.tensor(prefill_length)),
+            normalize_arg(prefill_length),
             normalize_arg(torch.tensor(1 if is_spec_dec else 0)),
         )
         logger.info(f"_ServerInferenceSession step id {step_id}")
@@ -328,6 +328,7 @@ def step(  # Perform one inference step, handling the input data and the corresponding prompts
         kv_cache_position_ids: Optional[torch.Tensor] = None,
         draft_tokens: Optional[torch.Tensor] = None,
         is_spec_decoding: Optional[torch.Tensor] = None,
+        prefill_length: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
@@ -360,6 +361,7 @@ def step(  # Perform one inference step, handling the input data and the corresponding prompts
         is_spec_decoding = is_spec_decoding.cpu() if is_spec_decoding is not None else None
 
         step_id = str(uuid.uuid4())  # Generate a unique step ID.
+        batch_size = inputs.shape[0]
 
         n_input_tokens = inputs.shape[1] if kv_cache_position_ids is None else kv_cache_position_ids.numel()
         if self._position + n_input_tokens > self._max_length:
@@ -370,13 +372,15 @@ def step(  # Perform one inference step, handling the input data and the corresponding prompts
         server_idx = 0
         block_idx = 0
         inference_step_start = time.perf_counter()
-        self.prefill_length = inputs.shape[1] - tree_attention_mask.shape[1] if tree_attention_mask is not None and self.first_inference else 0
-        # self.first_inference = False
+        if tree_attention_mask is not None:
+            self.prefill_length = prefill_length.to(inputs.device)
+        else:
+            self.prefill_length = torch.zeros(batch_size)
         keep_indices = torch.arange(
-            inputs.shape[1],
-            dtype=torch.int64,
-            device=inputs.device
-        )
+            inputs.shape[1],
+            dtype=torch.int64,
+            device=inputs.device
+        ).unsqueeze(0).expand(inputs.shape[0], -1)
         self.keep_indices = keep_indices
         if is_spec_decoding is not None and is_spec_decoding.item() == 1:
             is_spec_dec = True
@@ -440,8 +444,11 @@ def step(  # Perform one inference step, handling the input data and the corresponding prompts
             time.sleep(delay)
 
         self._position += n_input_tokens
+        # logger.info(f"keep_indices: {keep_indices}")
+        # logger.info(f"before _recover_hidden_states: {inputs}")
         if draft_tokens is not None and is_spec_dec:
-            inputs = self._recover_hidden_states(inputs, self.keep_indices, draft_tokens.shape[1])
+            inputs = self._restore_hidden_states(inputs, self.keep_indices, draft_tokens.shape[1])
+        # logger.info(f"after _recover_hidden_states: {inputs}")
         outputs = inputs
 
         # 🔍 CLIENT DEBUG: Log inference step end
@@ -454,28 +461,48 @@ def step(  # Perform one inference step, handling the input data and the corresponding prompts
         # print('client inference session outputs ', outputs.shape)
         return outputs
 
-    def _recover_hidden_states(self, hidden_states, keep_indices, original_length):
-        if not torch.is_tensor(keep_indices):
-            keep_indices = torch.tensor(keep_indices, device=hidden_states.device, dtype=torch.long)
-
-        if hidden_states.dim() == 2:
-            # [S_kept, H]
-            recovered = hidden_states.new_zeros((original_length, hidden_states.size(-1)))
-            recovered[keep_indices] = hidden_states
-
-        elif hidden_states.dim() == 3:
-            # [B, S_kept, H]
-            B, _, H = hidden_states.shape
-            recovered = hidden_states.new_zeros((B, original_length, H))
-            recovered[:, keep_indices, :] = hidden_states
-
-        else:
-            raise ValueError(f"Unexpected hidden_states dim {hidden_states.dim()}, expected 2 or 3")
-        mask = torch.ones(original_length, dtype=torch.bool, device=hidden_states.device)
-        mask[keep_indices] = False
-        self._last_padded_mask = mask
-
-        return recovered
+    def _restore_hidden_states(
+        self,
+        flattened_hidden_states: torch.Tensor,  # [N_total_valid, hidden_size]
+        keep_indices: torch.Tensor,  # [B, max_keep_len], padded with -1
+        original_seq_len: int,  # original sequence length
+    ) -> torch.Tensor:
+        """
+        Restore flattened hidden states to [B, original_seq_len, hidden_size].
+
+        Args:
+            flattened_hidden_states: [N_total_valid, hidden_size] flattened valid hidden states
+            keep_indices: [B, max_keep_len] keep indices for each batch element, padded with -1
+            original_seq_len: original sequence length
+
+        Returns:
+            restored_hidden_states: [B, original_seq_len, hidden_size], zero-filled at invalid positions
+        """
+        batch_size, max_keep_len = keep_indices.shape
+        hidden_size = flattened_hidden_states.shape[-1]
+        device = flattened_hidden_states.device
+        dtype = flattened_hidden_states.dtype
+
+        # Create the output tensor, filled with zeros
+        restored_hidden_states = torch.zeros(
+            batch_size, original_seq_len, hidden_size,
+            dtype=dtype, device=device
+        )
+
+        # Build the validity mask: [B, max_keep_len]
+        valid_mask = keep_indices >= 0
+
+        # Build batch indices: [B, max_keep_len]
+        batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(keep_indices)
+
+        # Select the indices of the valid entries
+        valid_batch_idx = batch_idx[valid_mask]  # [N_total_valid]
+        valid_seq_idx = keep_indices[valid_mask]  # [N_total_valid]
+
+        # Scatter back into the restored positions
+        restored_hidden_states[valid_batch_idx, valid_seq_idx, :] = flattened_hidden_states
+
+        return restored_hidden_states
 
     def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
         # If there is a failed server session, this code closes it
src/bloombee/client/remote_generation.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import contextlib
22
import dataclasses
33
from contextvars import ContextVar
4-
from typing import Any, ContextManager, Dict, List, Optional, Tuple
4+
from typing import Any, ContextManager, Dict, List, Optional, Tuple, Union
55

66
import torch
77
import transformers
@@ -22,22 +22,38 @@ class RemotePastKeyValues(Cache):
2222

2323
def __init__(self) -> None:
2424
super().__init__()
25-
self._seen_tokens = 0
25+
self._seen_tokens: Optional[torch.Tensor] = None
2626
self.hypo_ids: Optional[torch.LongTensor] = None
2727
self.kv_cache_position_ids: Optional[torch.LongTensor] = None
2828
self.is_spec_decoding: Optional[torch.LongTensor] = None
29+
self.prefill_length: Optional[torch.LongTensor] = None
2930

3031
def __getitem__(self, _index: int) -> List[torch.Tensor]:
3132
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
3233

3334
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
35+
if self._seen_tokens is None:
36+
return 0
37+
if self._seen_tokens.dim() == 0:
38+
return self._seen_tokens.item()
39+
return self._seen_tokens[0].item()
40+
41+
def get_seq_length_batch(self) -> Optional[torch.Tensor]:
3442
return self._seen_tokens
3543

3644
def get_max_length(self) -> Optional[int]:
3745
return None
3846

39-
def update_seen(self, new_seen: int) -> None:
40-
self._seen_tokens += new_seen
47+
def update_seen(self, new_seen: Union[int, torch.Tensor]) -> None:
48+
if isinstance(new_seen, int):
49+
self._seen_tokens = torch.tensor([new_seen])
50+
elif isinstance(new_seen, torch.Tensor):
51+
if new_seen.dim() == 0:
52+
new_seen = new_seen.unsqueeze(0)
53+
self._seen_tokens = new_seen
54+
else:
55+
raise TypeError(f"new_seen must be int or torch.Tensor, got {type(new_seen)}")
56+
4157

4258
def reorder_cache(self, beam_idx):
4359
raise NotImplementedError("Beam search reordering is not implemented yet")
@@ -47,6 +63,9 @@ def set_kv_cache(self, position_ids: Optional[torch.LongTensor]):
4763

4864
def set_is_spec_decoding(self, is_spec_decoding: Optional[torch.LongTensor]):
4965
self.is_spec_decoding = is_spec_decoding
66+
67+
def set_prefill_length(self, prefill_length: Optional[torch.LongTensor]):
68+
self.prefill_length = prefill_length
5069

5170

5271
_skipped_tokens = ContextVar("skipped_tokens", default=0)
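
The widened _seen_tokens contract stores a per-sample tensor instead of a running int; note that update_seen now overwrites the stored value rather than accumulating, and get_seq_length() collapses to sample 0 to keep the scalar Cache interface. A usage sketch, assuming the class is importable from the module path shown:

import torch
from bloombee.client.remote_generation import RemotePastKeyValues  # assumed import path

cache = RemotePastKeyValues()
assert cache.get_seq_length() == 0            # _seen_tokens starts as None

cache.update_seen(7)                          # int path: stored as a 1-element tensor
assert cache.get_seq_length() == 7

cache.update_seen(torch.tensor([7, 5, 9, 3])) # per-sample path (batch of 4)
assert cache.get_seq_length() == 7            # sample 0 only
assert cache.get_seq_length_batch().tolist() == [7, 5, 9, 3]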

src/bloombee/flexgen_utils/pytorch_backend.py

Lines changed: 17 additions & 2 deletions
@@ -100,8 +100,8 @@ def precompute_freqs_cis(
         freqs = torch.outer(t, freqs).float()  # type: ignore
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     else:
-        t = position_ids.float().to(inv_freq.device)
-        freqs = torch.outer(t, inv_freq).float()
+        t = position_ids.float().to(inv_freq.device)  # [B, S]
+        freqs = t.unsqueeze(-1) * inv_freq.reshape(1, 1, -1)
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     return freqs_cis
 
@@ -723,6 +723,8 @@ def mha_gen_llama(self, inputs, attention_mask, w_q, w_k, w_v,
 
         hidden = rms_norm(inputs.data, input_layernorm.data)
 
+        # logger.info(f"after norm, hidden states: {hidden}")
+
         # shape: (b, 1, h)
         q = F.linear(hidden, w_q.data)
         k = F.linear(hidden, w_k.data)
@@ -733,13 +735,26 @@ def mha_gen_llama(self, inputs, attention_mask, w_q, w_k, w_v,
         k = k.view(b, tgt_s, n_head, head_dim)
         v = v.view(b, tgt_s, n_head, head_dim)
 
+        # logger.info(f"after projection, query_states: {q}")
+        # logger.info(f"after projection, key_states: {k}")
+        # logger.info(f"after projection, value_states: {v}")
+
+        # logger.info(f"attention_mask: {attention_mask.shape}")
+        # logger.info(f"inputs: {inputs.shape}")
+        # logger.info(f"freq_cis: {freq_cis.shape}")
+        # logger.info(f"src_s: {src_s}")
+        # logger.info(f"tgt_s: {tgt_s}")
+
         if rotary_position_ids is not None:
             freqs_slice = freq_cis[-tgt_s:]
         else:
             freqs_slice = freq_cis[src_s - tgt_s: src_s]
 
         q, k = apply_rotary_emb(q, k, freqs_cis=freqs_slice)
 
+        # logger.info(f"after rotary, query_states: {q}")
+        # logger.info(f"after rotary, key_states: {k}")
+
         # shape: (b * n_head, 1, head_dim)
         q = q.permute(0, 2, 1, 3).reshape(b * n_head, tgt_s, head_dim)
         # shape: (1, b * n_head, head_dim)
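
The precompute_freqs_cis change is needed because torch.outer only accepts 1-D inputs: batched [B, S] position ids must instead broadcast against inv_freq to yield [B, S, dim/2] rotary frequencies. A minimal sketch of the equivalence, with invented sizes:

import torch

dim, S = 8, 3
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]

# old path: torch.outer over a flat 1-D position vector
t_flat = torch.arange(S).float()
freqs_1d = torch.outer(t_flat, inv_freq)                             # [S, dim/2]

# new path: per-sample positions, e.g. shifted by each prompt's true length
position_ids = torch.tensor([[0, 1, 2], [4, 5, 6]]).float()          # [B, S]
freqs_2d = position_ids.unsqueeze(-1) * inv_freq.reshape(1, 1, -1)   # [B, S, dim/2]
freqs_cis = torch.polar(torch.ones_like(freqs_2d), freqs_2d)         # complex64

assert torch.allclose(freqs_2d[0], freqs_1d)  # batch row 0 reproduces the 1-D result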

src/bloombee/models/llama/model.py

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,7 @@ def forward(
         # print('model.py llama model inputs_embeds, ', inputs_embeds)  # Temporarily commented for cleaner debug output
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        # logger.info(f"hidden_states: {hidden_states}")
+        # logger.info(f"input_ids: {input_ids}")
 
         hidden_states = self.layers(
             hidden_states,
@@ -98,6 +98,7 @@ def forward(
             kv_cache_position_ids=past_key_values.kv_cache_position_ids if past_key_values is not None else None,
             draft_tokens = input_ids,
             is_spec_decoding = past_key_values.is_spec_decoding if past_key_values is not None else None,
+            prefill_length = past_key_values.prefill_length if past_key_values is not None else None,
         )
 
         if past_key_values is None:

src/bloombee/models/llama/spe_dec_tree.py

Lines changed: 20 additions & 3 deletions
@@ -185,8 +185,24 @@ def prepare_incremental_tree_batch(
     trees: List[SpeculativeTree],
     input_ids: torch.LongTensor,
     device: torch.device,
-    pad_token_id: int = 0
+    pad_token_id: int = 0,
+    seq_lengths: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, List[List[List[TreeNode]]]]:
+    """
+    Prepare an incremental tree batch, supporting different sequence lengths.
+
+    Args:
+        trees: list of speculative trees
+        input_ids: [batch_size, max_seq_len] input token ids (may contain padding)
+        device: target device
+        pad_token_id: padding token id
+        seq_lengths: [batch_size] true length of each sequence; if None, all sequences are assumed to have the same length
+
+    Returns:
+        tree_tokens: [batch_size, max_tree_size]
+        attention_mask: [batch_size, max_tree_size, past_len + max_tree_size]
+        batch_node_paths: list of node paths for each batch element
+    """
     batch_size = len(trees)
 
     if not trees or all(tree.total_nodes <= 1 for tree in trees):
@@ -237,10 +253,11 @@ def prepare_tree_attention_batch(
     trees: List[SpeculativeTree],
     prefix_tokens: torch.Tensor,
     device: torch.device,
-    pad_token_id: int = 0
+    pad_token_id: int = 0,
+    seq_lengths: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, List[List[List[TreeNode]]]]:
     tree_tokens, attention_mask, batch_node_paths = prepare_incremental_tree_batch(
-        trees, prefix_tokens, device, pad_token_id
+        trees, prefix_tokens, device, pad_token_id, seq_lengths
     )
     if tree_tokens.shape[1] > 0:
         full_sequence = torch.cat([prefix_tokens, tree_tokens], dim=-1)
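
With the new seq_lengths parameter, callers batching prompts of different lengths can pass each padded row's true length alongside the prefix tokens. A hypothetical call sketch (only the signature comes from this diff; the trees and drafting state are placeholders):

import torch

pad_token_id = 0
prefix_tokens = torch.tensor([[11, 12, 13, 0],
                              [21, 22, 0, 0]])              # right-padded prefixes
seq_lengths = prefix_tokens.ne(pad_token_id).sum(dim=1)     # tensor([3, 2])

# trees: List[SpeculativeTree] built earlier by the drafting loop (assumed to exist)
tree_tokens, attention_mask, batch_node_paths = prepare_tree_attention_batch(
    trees, prefix_tokens, device=prefix_tokens.device,
    pad_token_id=pad_token_id, seq_lengths=seq_lengths,
)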
