#!/usr/bin/env python3
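"""Benchmark the token throughput of bloombee's distributed speculative decoding.

Each worker process pairs a small local draft model (JackFram/llama-68m) with a
distributed target model served by a bloombee swarm and reports tokens/s.

Hypothetical invocation (model name and peer multiaddr are placeholders, not
values taken from this repo):

    python benchmark_speculative_decoding.py \
        --model huggyllama/llama-7b \
        --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/<peer_id> \
        --torch_dtype float32
"""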
import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from bloombee import AutoDistributedSpeculativeModel
from bloombee.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model name or path")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default="1", help="Number of concurrent processes, or 'n_gpus'")
    # NOTE: --seq_len and --warmup_steps are accepted but not currently used by benchmark_inference below
    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    # Each worker sends its measured speed back through this one-way pipe
    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
    logger.info(f"Final result: {speed=:.2f}")


@torch.inference_mode()
def benchmark_inference(process_idx, args, result_pipe):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
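
    # How this benchmark exercises speculative decoding: the small local draft
    # model (ssm) proposes tokens cheaply, and the large distributed model
    # verifies them, so the measured speed reflects accepted tokens per second.
    # (Descriptive note only; the draft length and acceptance rule are internal
    # to AutoDistributedSpeculativeModel.)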
    ssm = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m")
    ssm = ssm.to(device).eval()

    # Warm up the SSM so one-time initialization costs don't skew the timing below
    # (torch.inference_mode() above already disables gradient tracking)
    dummy_input = torch.ones(1, 8, dtype=torch.long, device=device)
    ssm(dummy_input, attention_mask=torch.ones_like(dummy_input))

    model = AutoDistributedSpeculativeModel.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

    batch_size = 4

    # Alternative: sample real prompts from the Alpaca dataset instead of the
    # fixed prompts below (re-enabling this also needs `import random` and
    # `from datasets import load_dataset`):
    # dataset = load_dataset("tatsu-lab/alpaca")["train"]
    # indices = random.sample(range(len(dataset)), batch_size)
    # test_prompts = [item["instruction"] for item in dataset.select(indices)]
    test_prompts = ["Hi,", "", "", ""]

    # The tokenizer may not define a pad token; reuse EOS for batch padding
    tokenizer.pad_token = tokenizer.eos_token
    # tokenizer.padding_side = "left"
    input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)["input_ids"]
    # Alternative: start generation from just the BOS token
    # bos_token_id = tokenizer.bos_token_id
    # if bos_token_id is not None:
    #     input_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)
    # else:
    #     # If the tokenizer has no bos_token_id, it may need to be obtained or handled manually
    #     logger.warning("Tokenizer does not have a bos_token_id. Using an empty tensor.")
    #     input_ids = torch.tensor([[]], dtype=torch.long, device=device)

    start_time = perf_counter()
    result = model.generate(input_ids=input_ids, ssm=ssm)
    elapsed = perf_counter() - start_time

    generated_tokens_nums = []
    for i in range(batch_size):
        # Count non-padding tokens; since pad_token == eos_token, genuine EOS
        # tokens in the output are excluded from the count as well
        prompt_length = input_ids[i].ne(tokenizer.pad_token_id).sum().item()
        result_length = result[i].ne(tokenizer.pad_token_id).sum().item()
        generated_tokens_nums.append(result_length - prompt_length)

    avg_generated_tokens = sum(generated_tokens_nums) / batch_size
    speed = avg_generated_tokens / elapsed
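
    # `speed` is end-to-end generated tokens/s averaged over the batch; with
    # --n_processes > 1, main() then averages these per-process speeds rather
    # than summing them into an aggregate swarm throughput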

    # Decode all results
    decoded_results = tokenizer.batch_decode(result, skip_special_tokens=True)
    logger.info(f"benchmark_inference batch size: {batch_size}")
    logger.info(f"Total time: {elapsed:.4f}s, Average speed: {speed:.2f} tokens/s")
    logger.info(f"Generated tokens per sample: {generated_tokens_nums}")
    for i, (prompt, decoded_result) in enumerate(zip(test_prompts, decoded_results)):
        logger.info(f"Sample {i}:")
        logger.info(f"  Prompt: {prompt}")
        logger.info(f"  Result: {decoded_result}")
        logger.info(f"  Generated tokens: {generated_tokens_nums[i]}")

    result_pipe.send(speed)


if __name__ == "__main__":
    main()