Commit c66e137

enable benchmark script (#1554)
* enable benchmark script

Signed-off-by: jiqing-feng <[email protected]>

* Small fixes to non_cuda_backends.mdx

--------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: Titus <[email protected]>
1 parent 2640753 commit c66e137

File tree

2 files changed (+79, -5 lines)

benchmarking/generation_benchmark.py (+67 lines)
@@ -0,0 +1,67 @@
import argparse

import torch
import torch.utils.benchmark as benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

parser = argparse.ArgumentParser()

parser.add_argument(
    "--model_name", default="meta-llama/Llama-3.1-8B-Instruct", required=False, type=str, help="model_name"
)
parser.add_argument("--quant_type", default="int8", type=str, help="quant type", choices=["int8", "nf4", "fp4"])
parser.add_argument("--device_map", default="cpu", type=str, help="device_map", choices=["cpu", "xpu", "cuda"])
args = parser.parse_args()

model_name = args.model_name
device_map = args.device_map
if args.quant_type == "int8":
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=args.quant_type,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map=device_map, quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))


# benchmark the performance
def benchmark_fn(f, *args, **kwargs):
    # Manual warmup
    for _ in range(2):
        f(*args, **kwargs)

    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return t0.blocked_autorange().mean


MAX_NEW_TOKENS = 100

quantized_model_latency = benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.bfloat16)
bf16_model_latency = benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)

print(f"bnb model latency: {quantized_model_latency:.3f}")
print(f"bf16 model latency: {bf16_model_latency:.3f}")
print(f"BNB vs. bf16 model speed-up: {(bf16_model_latency / quantized_model_latency):.3f}")

print(f"BNB model memory: {(quantized_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB")
print(f"bf16 model memory: {(bf16_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB")
print(
    f"BNB vs. bf16 model memory ratio: {(bf16_model.get_memory_footprint() / quantized_model.get_memory_footprint()):.3f}"
)
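
Note on the timing helper: `benchmark_fn` above wraps `torch.utils.benchmark.Timer`, whose `blocked_autorange()` call returns a measurement with `.mean` giving the average seconds per call. A minimal standalone sketch of that pattern, with a toy matmul standing in for `generate()` (the workload and sizes are illustrative only, not part of the committed script):

import torch
import torch.utils.benchmark as benchmark

# Toy workload; in the script above the timed callable is model.generate().
x = torch.randn(1024, 1024)

# blocked_autorange() chooses the number of runs automatically and returns a
# Measurement; .mean is the average time per call in seconds.
timer = benchmark.Timer(
    stmt="x @ x",
    globals={"x": x},
    num_threads=torch.get_num_threads(),
)
print(f"mean matmul time: {timer.blocked_autorange().mean:.6f} s")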

docs/source/non_cuda_backends.mdx (+12, -5 lines)
@@ -27,18 +27,25 @@ Thank you for your support!
 
 ### Intel
 
-The following performance data is collected from Intel 4th Gen Xeon (SPR) platform. The tables show speed-up and memory compared with different data types of [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf).
+The below performance data is collected from the Intel 4th Gen Xeon (SPR) platform. The tables show speed-up and memory compared with different data types of [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct).
+
+You may run `benchmarking/generation_benchmark.py` to reproduce the below model memory and inference results. Please note that you need to bind cores if you are using the CPU to benchmark. For example, run `numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type nf4` on Intel 4th Gen Xeon with single socket.
+
+The finetune results are selected from [peft](https://github.com/huggingface/peft/blob/main/examples/olora_finetuning/olora_finetuning.py).
+
+#### Model memory (CPU)
+| Data Type | BF16 | INT8 | NF4 | FP4 |
+|---|---|---|---|---|
+| Memory (GB) | 15.0 | 8.5 | 5.2 | 5.2 |
 
 #### Inference (CPU)
 
 | Data Type | BF16 | INT8 | NF4 | FP4 |
 |---|---|---|---|---|
-| Speed-Up (vs BF16) | 1.0x | 0.44x | 1.8x | 0.1x |
-| Memory (GB) | 13.1 | 7.6 | 5.0 | 4.6 |
+| Speed-Up (vs BF16) | 1.0x | 0.57x | 2.6x | 0.1x |
 
 #### Fine-Tuning (CPU)
 
 | Data Type | BF16 | INT8 | NF4 | FP4 |
 |---|---|---|---|---|
-| Speed-Up (vs BF16) | 1.0x | 0.38x | 0.1x | 0.1x |
-| Memory (GB) | 40 | 9 | 6.6 | 6.6 |
+| Speed-Up (vs BF16) | 1.0x | 0.91x | 1.0x | 1.0x |
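
The Fine-Tuning (CPU) rows reference peft's OLoRA example rather than this commit's script. As a rough sketch of how a LoRA adapter could be attached to the NF4-quantized model for such a run, assuming peft is installed; the rank, alpha, dropout, and target modules below are illustrative placeholders, not the settings behind the table:

import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 quantization config, mirroring the benchmark script above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    device_map="cpu",
    quantization_config=bnb_config,
)

# Make the quantized model trainable and attach a LoRA adapter.
# r / lora_alpha / target_modules are illustrative, not the table's settings.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()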

0 commit comments