@@ -3,6 +3,7 @@
 import multiprocessing as mp
 import random
 import sys
+import json
 from pathlib import Path
 from time import perf_counter
 
@@ -49,6 +50,8 @@ def main():
         default="Number {i}: ",
         help="Prompt template. Must contain '{i}', e.g. 'Number {i}: ' or 'Topic {i}: '",
     )
+    parser.add_argument("--group_idx", type=int, default=0,
+                        help="Which group to run (0-9)")
     args = parser.parse_args()
 
     if args.n_processes == "n_gpus":
@@ -81,72 +84,75 @@ def benchmark_inference(process_idx, args, result_pipe):
     ).to(device)
     tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
 
-    batch_size = getattr(args, 'batch_size', 8)
+    batch_size = getattr(args, 'batch_size', 32)
+    group_idx = getattr(args, 'group_idx', 0)
+
+    # Load the fixed prompt group for this run
     dataset = load_dataset("tatsu-lab/alpaca")["train"]
-    indices = random.sample(range(len(dataset)), batch_size)
+    with open("eval_indices.json", "r") as f:
+        groups = json.load(f)
+
+    indices = groups[group_idx]
     sampled = dataset.select(indices)
-    test_prompts = []
-    for item in sampled:
-        test_prompts.append(item["instruction"])
-        # test_prompts.append("Generate a list of the best places to eat in London.")
-
-    # base_prompt = (
-    #     "Quantum mechanics explains the behavior of particles at very small scales. "
-    #     "Neural networks learn patterns by adjusting weights through backpropagation. "
-    #     "Distributed systems require robust consensus mechanisms to maintain state. "
-    #     "Optimization algorithms like gradient descent are fundamental to machine learning. "
-    #     "Transformer architectures rely on attention mechanisms to capture dependencies. "
-    #     "Reinforcement learning optimizes actions by maximizing cumulative rewards. "
-    #     "Bayesian inference updates beliefs based on observed evidence and prior knowledge. "
-    #     "Convex optimization problems guarantee global minima under certain conditions. "
-    #     "Signal processing extracts meaningful information from noisy measurements. "
-    # )
-    # prompts = [
-    #     f"{base_prompt} Example {i + 1} discusses large-scale AI systems and scientific discovery."
-    #     for i in range(batch_size)
-    # ]
-    # prompt_indices = [args.prompt_start_index + i for i in range(batch_size)]
-    # if "{i}" not in args.prompt_template:
-    #     raise ValueError("--prompt_template must include '{i}' placeholder")
-    # prompts = [args.prompt_template.format(i=i) for i in prompt_indices]
-    # test_prompts = prompts
-
+    test_prompts = [item["instruction"] for item in sampled]
+
+    logger.info(f"Running group {group_idx}/{len(groups) - 1}")
+    logger.info(f"Prompts: {test_prompts}")
+
     tokenizer.pad_token = tokenizer.eos_token
-    input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)["input_ids"]
-
-    result = ""
-    start_time = perf_counter()
+    input_ids = tokenizer(
+        test_prompts,
+        return_tensors="pt",
+        padding=True
+    ).to(device)["input_ids"]
+
     max_new_tokens = getattr(args, 'seq_len', 128)
-    result = model.generate(input_ids=input_ids, drafter=drafter, max_new_tokens=max_new_tokens)
-    time = perf_counter() - start_time
-    generated_tokens_nums = []
-    for i in range(batch_size):
-        prompt_mask = input_ids[i].ne(tokenizer.pad_token_id)
-        prompt_length = prompt_mask.sum().item()
-        result_mask = result[i].ne(tokenizer.pad_token_id) & result[i].ne(0)
-        result_length = result_mask.sum().item()
-        generated_tokens_num = result_length - prompt_length
-        generated_tokens_nums.append(generated_tokens_num)
-
-        logger.info(f"result: {result[i]}")
 
-    avg_generated_tokens = sum(generated_tokens_nums) / batch_size
-    speed = avg_generated_tokens / time
-
-    decoded_results = tokenizer.batch_decode(result, skip_special_tokens=True)
-
-    logger.info(f"benchmark_inference batch size: {batch_size}")
-    logger.info(f"Total time: {time:.4f}s, Average speed: {speed:.2f} tokens/s")
-    logger.info(f"Generated tokens per sample: {generated_tokens_nums}")
-
-    for i, (prompt, decoded_result) in enumerate(zip(test_prompts, decoded_results)):
-        logger.info(f"Sample {i}:")
-        logger.info(f"  Prompt: {prompt}")
-        logger.info(f"  Result: {decoded_result}")
-        logger.info(f"  Generated tokens: {generated_tokens_nums[i]}")
+    # Warm-up run so one-time startup cost is excluded from the timing
+    logger.info("Warming up...")
+    _ = model.generate(
+        input_ids=input_ids,
+        drafter=drafter,
+        max_new_tokens=10
+    )
+
+    # Timed run
+    logger.info("Starting benchmark...")
+    start_time = perf_counter()
+    result = model.generate(
+        input_ids=input_ids,
+        drafter=drafter,
+        max_new_tokens=max_new_tokens
+    )
+    elapsed_time = perf_counter() - start_time
+
+    original_output_ids = result
+
+    # Assumes every sample generates the full max_new_tokens tokens
+    total_generated = max_new_tokens * batch_size
+    throughput = total_generated / elapsed_time
+
+    logger.info(f"Group {group_idx} | "
+                f"Total time: {elapsed_time:.4f}s | "
+                f"Throughput: {throughput:.2f} tokens/s | "
+                f"Total generated tokens: {total_generated}")
 
+    # Save results to a per-group JSON file
+    result_label = "pruned" if getattr(args, 'pruning', False) else "unpruned"
+    output_file = f"results_{result_label}_group_{group_idx}.json"
+    result_data = {
+        "group_idx": group_idx,
+        "pruning": getattr(args, 'pruning', False),
+        "throughput": throughput,
+        "elapsed_time": elapsed_time,
+        "total_generated": total_generated,
+        "batch_size": batch_size,
+        "max_new_tokens": max_new_tokens,
+    }
+    with open(output_file, "w") as f:
+        json.dump(result_data, f, indent=2)
+    logger.info(f"Results saved to {output_file}")
 
-    result_pipe.send(speed)
+    result_pipe.send(throughput)
 
 
 if __name__ == "__main__":
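
A note on `eval_indices.json`: the new code reads it, but the commit doesn't include the script that produces it. Below is a minimal generator sketch consistent with the code above, assuming ten groups of 32 non-overlapping Alpaca indices (matching the `--group_idx` help text and the new `batch_size` default) and a fixed seed so runs stay comparable; both choices are assumptions, not part of this commit.

```python
# Hypothetical generator for eval_indices.json -- not part of this commit.
# Assumes 10 groups of 32 random, non-overlapping dataset indices.
import json
import random

from datasets import load_dataset

NUM_GROUPS = 10   # "--group_idx" help text says 0-9
GROUP_SIZE = 32   # matches the new batch_size default

dataset = load_dataset("tatsu-lab/alpaca")["train"]
random.seed(0)  # fixed seed keeps the prompt groups reproducible
flat = random.sample(range(len(dataset)), NUM_GROUPS * GROUP_SIZE)
groups = [flat[g * GROUP_SIZE:(g + 1) * GROUP_SIZE] for g in range(NUM_GROUPS)]

with open("eval_indices.json", "w") as f:
    json.dump(groups, f)
```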
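
One caveat on the measurement: `total_generated = max_new_tokens * batch_size` assumes every sequence emits the full `max_new_tokens`, so the reported throughput is an upper bound whenever a sequence stops early at EOS. If exact counts matter, the deleted per-sample accounting can be restored in vectorized form; this sketch reuses `input_ids`, `result`, and `tokenizer` from `benchmark_inference` and makes the same pad-token assumption the removed loop made.

```python
# Vectorized version of the deleted per-sample token accounting.
# Assumes `result` holds prompt + generated tokens, with pad_token_id
# in all unused positions (the assumption the removed loop also made).
prompt_lengths = input_ids.ne(tokenizer.pad_token_id).sum(dim=1)
output_lengths = result.ne(tokenizer.pad_token_id).sum(dim=1)
generated_tokens_nums = (output_lengths - prompt_lengths).tolist()
total_generated = sum(generated_tokens_nums)  # exact, not max_new_tokens * batch_size
```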
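
Each run writes a single `results_{label}_group_{idx}.json`, so comparing pruned against unpruned still takes a second pass. A small aggregation sketch, assuming groups 0-9 were run for both labels as the `--group_idx` help text suggests:

```python
# Average throughput per configuration across the ten result files.
import json

for label in ("unpruned", "pruned"):
    throughputs = []
    for g in range(10):
        with open(f"results_{label}_group_{g}.json") as f:
            throughputs.append(json.load(f)["throughput"])
    mean = sum(throughputs) / len(throughputs)
    print(f"{label}: {mean:.2f} tokens/s averaged over {len(throughputs)} groups")
```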