Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable benchmark script #1554

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions benchmarking/generation_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import argparse

import torch
import torch.utils.benchmark as benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

parser = argparse.ArgumentParser()

parser.add_argument(
"--model_name", default="meta-llama/Llama-3.1-8B-Instruct", required=False, type=str, help="model_name"
)
parser.add_argument("--quant_type", default="int8", type=str, help="quant type", choices=["int8", "nf4", "fp4"])
parser.add_argument("--device_map", default="cpu", type=str, help="device_map", choices=["cpu", "xpu", "cuda"])
args = parser.parse_args()

model_name = args.model_name
device_map = args.device_map
if args.quant_type == "int8":
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type=args.quant_type,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
quantized_model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="auto", device_map=device_map, quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))


# benchmark the performance
def benchmark_fn(f, *args, **kwargs):
# Manual warmup
for _ in range(2):
f(*args, **kwargs)

t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=torch.get_num_threads(),
)
return t0.blocked_autorange().mean


MAX_NEW_TOKENS = 100

quantized_model_latency = benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.bfloat16)
bf16_model_latency = benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)

print(f"bnb model latency: {quantized_model_latency:.3f}")
print(f"bf16 model latency: {bf16_model_latency:.3f}")
print(f"BNB vs. bf16 model speed-up: {(bf16_model_latency / quantized_model_latency):.3f}")

print(f"BNB model memory: {(quantized_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB")
print(f"bf16 model memory: {(bf16_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB")
print(
f"BNB vs. bf16 model memory ratio: {(bf16_model.get_memory_footprint() / quantized_model.get_memory_footprint()):.3f}"
)
17 changes: 12 additions & 5 deletions docs/source/non_cuda_backends.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,25 @@ Thank you for your support!

### Intel

The following performance data is collected from Intel 4th Gen Xeon (SPR) platform. The tables show speed-up and memory compared with different data types of [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf).
The below performance data is collected from the Intel 4th Gen Xeon (SPR) platform. The tables show speed-up and memory compared with different data types of [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct).

You may run `benchmarking/generation_benchmark.py` to reproduce the below model memory and inference results. Please note that you need to bind cores if you are using the CPU to benchmark. For example, run `numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type nf4` on Intel 4th Gen Xeon with single socket.

The finetune results are selected from [peft](https://github.com/huggingface/peft/blob/main/examples/olora_finetuning/olora_finetuning.py).

#### Model memory (CPU)
| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
| Memory (GB) | 15.0 | 8.5 | 5.2 | 5.2 |

#### Inference (CPU)

| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
| Speed-Up (vs BF16) | 1.0x | 0.44x | 1.8x | 0.1x |
| Memory (GB) | 13.1 | 7.6 | 5.0 | 4.6 |
| Speed-Up (vs BF16) | 1.0x | 0.57x | 2.6x | 0.1x |

#### Fine-Tuning (CPU)

| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
| Speed-Up (vs BF16) | 1.0x | 0.38x | 0.1x | 0.1x |
| Memory (GB) | 40 | 9 | 6.6 | 6.6 |
| Speed-Up (vs BF16) | 1.0x | 0.91x | 1.0x | 1.0x |