
Commit e9ac6a9

Author: Xu Xiong
Commit message: modify benchmark
Parent: 83cd13e

7 files changed: 92 additions & 67 deletions

benchmarks/benchmark_speculative_decoding.py

Lines changed: 66 additions & 59 deletions
@@ -3,6 +3,7 @@
 import multiprocessing as mp
 import random
 import sys
+import json
 from pathlib import Path
 from time import perf_counter
 
@@ -49,6 +50,8 @@ def main():
         default="Number {i}: ",
         help="Prompt template. Must contain '{i}', e.g. 'Number {i}: ' or 'Topic {i}: '",
     )
+    parser.add_argument("--group_idx", type=int, default=0,
+                        help="Which group to run (0-9)")
     args = parser.parse_args()
 
     if args.n_processes == "n_gpus":
@@ -81,72 +84,76 @@ def benchmark_inference(process_idx, args, result_pipe):
     ).to(device)
     tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
 
-    batch_size = getattr(args, 'batch_size', 8)
+    batch_size = getattr(args, 'batch_size', 32)
+    group_idx = getattr(args, 'group_idx', 0)
+
+    # Load the fixed prompt groups
     dataset = load_dataset("tatsu-lab/alpaca")["train"]
-    indices = random.sample(range(len(dataset)), batch_size)
+    with open("eval_indices.json", "r") as f:
+        groups = json.load(f)
+
+    indices = groups[group_idx]
     sampled = dataset.select(indices)
-    test_prompts = []
-    for item in sampled:
-        test_prompts.append(item["instruction"])
-    # test_prompts.append("Generate a list of the best places to eat in London.")
-
-    # base_prompt = (
-    #     "Quantum mechanics explains the behavior of particles at very small scales. "
-    #     "Neural networks learn patterns by adjusting weights through backpropagation. "
-    #     "Distributed systems require robust consensus mechanisms to maintain state. "
-    #     "Optimization algorithms like gradient descent are fundamental to machine learning. "
-    #     "Transformer architectures rely on attention mechanisms to capture dependencies. "
-    #     "Reinforcement learning optimizes actions by maximizing cumulative rewards. "
-    #     "Bayesian inference updates beliefs based on observed evidence and prior knowledge. "
-    #     "Convex optimization problems guarantee global minima under certain conditions. "
-    #     "Signal processing extracts meaningful information from noisy measurements. "
-    # )
-    # prompts = [
-    #     f"{base_prompt} Example {i + 1} discusses large-scale AI systems and scientific discovery."
-    #     for i in range(batch_size)
-    # ]
-    # prompt_indices = [args.prompt_start_index + i for i in range(batch_size)]
-    # if "{i}" not in args.prompt_template:
-    #     raise ValueError("--prompt_template must include '{i}' placeholder")
-    # prompts = [args.prompt_template.format(i=i) for i in prompt_indices]
-    # test_prompts = prompts
-
+    test_prompts = [item["instruction"] for item in sampled]
+
+    logger.info(f"Running group {group_idx}/{len(groups) - 1}")
+    logger.info(f"Prompts: {test_prompts}")
+
     tokenizer.pad_token = tokenizer.eos_token
-    input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)["input_ids"]
-
-    result = ""
-    start_time = perf_counter()
+    input_ids = tokenizer(
+        test_prompts,
+        return_tensors="pt",
+        padding=True
+    ).to(device)["input_ids"]
 
     max_new_tokens = getattr(args, 'seq_len', 128)
-    result = model.generate(input_ids=input_ids, drafter=drafter, max_new_tokens=max_new_tokens)
-    time = perf_counter() - start_time
-    generated_tokens_nums = []
-    for i in range(batch_size):
-        prompt_mask = input_ids[i].ne(tokenizer.pad_token_id)
-        prompt_length = prompt_mask.sum().item()
-        result_mask = result[i].ne(tokenizer.pad_token_id) & result[i].ne(0)
-        result_length = result_mask.sum().item()
-        generated_tokens_num = result_length - prompt_length
-        generated_tokens_nums.append(generated_tokens_num)
-
-        logger.info(f"result: {result[i]}")
 
-    avg_generated_tokens = sum(generated_tokens_nums) / batch_size
-    speed = avg_generated_tokens / time
-
-    decoded_results = tokenizer.batch_decode(result, skip_special_tokens=True)
-
-    logger.info(f"benchmark_inference batch size: {batch_size}")
-    logger.info(f"Total time: {time:.4f}s, Average speed: {speed:.2f} tokens/s")
-    logger.info(f"Generated tokens per sample: {generated_tokens_nums}")
-
-    for i, (prompt, decoded_result) in enumerate(zip(test_prompts, decoded_results)):
-        logger.info(f"Sample {i}:")
-        logger.info(f"  Prompt: {prompt}")
-        logger.info(f"  Result: {decoded_result}")
-        logger.info(f"  Generated tokens: {generated_tokens_nums[i]}")
+    # Warmup run so one-time setup costs stay out of the timed region
+    logger.info("Warming up...")
+    _ = model.generate(
+        input_ids=input_ids,
+        drafter=drafter,
+        max_new_tokens=10
+    )
+
+    # Timed run
+    logger.info("Starting benchmark...")
+    start_time = perf_counter()
+    result = model.generate(
+        input_ids=input_ids,
+        drafter=drafter,
+        max_new_tokens=max_new_tokens
+    )
+    elapsed_time = perf_counter() - start_time
+
+    original_output_ids = result
+
+    # Assumes every sample generates exactly max_new_tokens tokens
+    total_generated = max_new_tokens * batch_size
+    throughput = total_generated / elapsed_time
+
+    logger.info(f"Group {group_idx} | "
+                f"Total time: {elapsed_time:.4f}s | "
+                f"Throughput: {throughput:.2f} tokens/s | "
+                f"Total generated tokens: {total_generated}")
 
+    # Save results
+    result_label = "pruned" if getattr(args, 'pruning', False) else "unpruned"
+    output_file = f"results_{result_label}_group_{group_idx}.json"
+    result_data = {
+        "group_idx": group_idx,
+        "pruning": getattr(args, 'pruning', False),
+        "throughput": throughput,
+        "elapsed_time": elapsed_time,
+        "total_generated": total_generated,
+        "batch_size": batch_size,
+        "max_new_tokens": max_new_tokens,
+    }
+    with open(output_file, "w") as f:
+        json.dump(result_data, f, indent=2)
+    logger.info(f"Results saved to {output_file}")
 
-    result_pipe.send(speed)
+    result_pipe.send(throughput)
 
 
 if __name__ == "__main__":
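
The per-group result files make it easy to sweep all ten prompt groups and aggregate throughput afterwards. A minimal driver sketch, assuming the script accepts the --group_idx flag added above; any model or pruning flags the CLI needs are not shown here and would have to be appended:

    import json
    import subprocess

    # Run the benchmark once per prompt group, then average throughput.
    throughputs = []
    for group_idx in range(10):
        subprocess.run(
            ["python", "benchmarks/benchmark_speculative_decoding.py",
             "--group_idx", str(group_idx)],
            check=True,
        )
        # File name matches the pattern written by benchmark_inference above.
        with open(f"results_unpruned_group_{group_idx}.json") as f:
            throughputs.append(json.load(f)["throughput"])

    print(f"Mean throughput over {len(throughputs)} groups: "
          f"{sum(throughputs) / len(throughputs):.2f} tokens/s")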

benchmarks/prompts_generate.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import random
+import json
+from datasets import load_dataset
+
+dataset = load_dataset("tatsu-lab/alpaca")["train"]
+batch_size = 32
+num_groups = 10
+
+random.seed(42)
+groups = []
+for i in range(num_groups):
+    indices = random.sample(range(len(dataset)), batch_size)
+    groups.append(indices)
+
+with open("eval_indices.json", "w") as f:
+    json.dump(groups, f)
+
+print(f"Generated {num_groups} groups of {batch_size} prompts each")

src/bloombee/models/llama/speculative_model.py

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ def generate(
     logits_processor: Optional[LogitsProcessorList] = None,
     stopping_criteria: Optional[StoppingCriteriaList] = None,
     streamer: Optional["BaseStreamer"] = None,
-    beam_width: int = 1,
-    max_tree_depth: int = 5,
+    beam_width: int = 2,
+    max_tree_depth: int = 3,
     use_kv_cache: bool = True,
     kv_cache_window: int = 2048,
     max_new_tokens: int = 128,
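
The new defaults swap a deep, linear draft chain for a shallower, wider tree. Assuming the drafter expands beam_width children per node down to max_tree_depth levels (an assumption about the tree shape; the actual construction lives in the generate implementation), the draft-token counts per step compare as follows:

    def draft_tree_nodes(width: int, depth: int) -> int:
        # Full width-ary tree: width + width**2 + ... + width**depth nodes.
        return sum(width ** level for level in range(1, depth + 1))

    print(draft_tree_nodes(width=1, depth=5))  # old defaults: 5 draft tokens
    print(draft_tree_nodes(width=2, depth=3))  # new defaults: 14 draft tokens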

src/bloombee/server/backend.py

Lines changed: 2 additions & 2 deletions
@@ -476,8 +476,8 @@ def _flag_to_bool(value) -> bool:
             self.pruner_manager.train_lm_head(middle_norm_hidden_states, norm_hidden_states)
 
         if not training_mode and self._is_spec_decoding and self._need_pruning and self._is_last_block:
-            norm_hidden_states = self.module.rms_norm(output_hidden_states)
-            keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
+            # norm_hidden_states = self.module.rms_norm(output_hidden_states)
+            # keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
             keep_indices = keep_indices
             # t7 = time.perf_counter()
             # logger.info(f"prune_draft_tree took {t7 - t6:.4f} seconds")
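
Note that with the pruning call commented out, keep_indices = keep_indices is a no-op that only runs cleanly if keep_indices was bound earlier in the method; otherwise this branch raises a NameError. A hedged alternative for the disabled path (a sketch; tensor shapes are assumptions based on the call above) is to explicitly keep every draft position:

    import torch

    # Sketch: when pruning is disabled, keep all draft-token positions.
    # Assumes full_mask's last dimension indexes the draft positions.
    keep_indices = torch.arange(full_mask.shape[-1], device=full_mask.device)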

src/bloombee/server/block_functions.py

Lines changed: 2 additions & 2 deletions
@@ -616,8 +616,8 @@ async def iterate_rpc_inference(
 
         if is_spec_dec:
             rotary_position_ids = _create_tree_position_ids_with_invalid_cache(
-                width=1,
-                depth=5,
+                width=2,
+                depth=3,
                 prefill_length=prefill_length - 1,
                 kv_cache_position_ids=kv_cache_position_ids,
                 batch_offset=0,
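
These hardcoded values mirror the new beam_width=2 / max_tree_depth=3 defaults in speculative_model.py; if the two call sites drift apart, the tree position ids will no longer match the drafter's tree. One way to keep them in sync (a sketch, not the repo's current layout) is a shared constants module imported by both sides:

    # hypothetical module: bloombee/spec_config.py
    DRAFT_TREE_WIDTH = 2   # must equal beam_width in speculative_model.py
    DRAFT_TREE_DEPTH = 3   # must equal max_tree_depth in speculative_model.py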

src/bloombee/server/handler.py

Lines changed: 1 addition & 1 deletion
@@ -795,7 +795,7 @@ async def _cross_stage_push_wrapper(mb_hidden, mb_keep, push_metadata):
         push_tensor_bytes = sum(len(t.buffer) for t in next_tensors)
 
         # Simulate network transfer latency
-        NETWORK_SPEED_BYTES_PER_SEC = 10 * 1024 * 1024  # 10 MB/s
+        NETWORK_SPEED_BYTES_PER_SEC = 50 * 1024 * 1024  # 50 MB/s
         transfer_delay = push_tensor_bytes / NETWORK_SPEED_BYTES_PER_SEC
         await asyncio.sleep(transfer_delay)
         task = asyncio.create_task(self._push_outputs(request, output_tensors, step_metadata))
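
Raising the simulated link speed from 10 MB/s to 50 MB/s shrinks the injected delay fivefold. For intuition, a worked sketch using the same formula, with a hypothetical 4 MiB activation push:

    push_tensor_bytes = 4 * 1024 * 1024  # hypothetical 4 MiB payload

    old_delay = push_tensor_bytes / (10 * 1024 * 1024)  # 0.400 s at 10 MB/s
    new_delay = push_tensor_bytes / (50 * 1024 * 1024)  # 0.080 s at 50 MB/s
    print(f"{old_delay:.3f}s -> {new_delay:.3f}s per push")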

src/bloombee/server/server.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ def __init__(
         # Create configuration
         config = PruningConfig(
             method=PruningMethod.ADAPTIVE_NEURAL,
-            neural_threshold=0.75,
+            neural_threshold=0.5,
             simple_threshold=0.1
         )
