Skip to content

Commit 7ac7aa3

Browse files
committed
split
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 75437a6 commit 7ac7aa3

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ Applying quantization with `llmcompressor`:
8686
* [Weight only quantization to `int4` using AWQ](examples/awq/README.md)
8787
* [Weight only quantization to `int4` using AutoRound](examples/autoround/quantization_w4a16/README.md)
8888
* [KV Cache quantization to `fp8`](examples/quantization_kv_cache/README.md)
89-
* [Attention quantization to `fp8` (experimental)](examples/quantization_attention/README.md)
89+
* [KV Cache quantization to `fp8` using per-head](examples/quantization_kv_cache/llama3_fp8_head_kv_example.py)
90+
* [Attention quantization to `fp8`](examples/quantization_attention/README.md)
9091
* [Attention quantization to `nvfp4` with SpinQuant (experimental)](experimental/attention/README.md)
9192
* [Quantizing MoE LLMs](examples/quantizing_moe/README.md)
9293
* [Quantizing Vision-Language Models](examples/multimodal_vision/README.md)

examples/quantization_attention/llama3_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def tokenize(sample):
5656
"attention": QuantizationScheme(
5757
targets=["LlamaAttention"],
5858
input_activations=QuantizationArgs(
59-
num_bits=8, type="float", strategy="attn_head"
59+
num_bits=8, type="float", strategy="tensor"
6060
),
6161
)
6262
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
"""Example: quantize Llama-3 Linears to FP8 (dynamic) with fp8 per-attention-head
KV cache quantization, sanity-check generation, and save the compressed model."""

from compressed_tensors.offload import dispatch_model
from compressed_tensors.quantization import QuantizationArgs
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Model to quantize.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Calibration dataset configuration. 512 samples is a good starting point;
# increasing the number of samples can improve accuracy.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load the calibration slice and shuffle it deterministically.
dataset = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
dataset = dataset.shuffle(seed=42)


def render_chat(example):
    # Flatten the chat messages into a single prompt string via the chat template.
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


dataset = dataset.map(render_chat)


def encode(sample):
    # Tokenize without padding, truncated to the calibration sequence length.
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


dataset = dataset.map(encode, remove_columns=dataset.column_names)

# Recipe: FP8 dynamic quantization for all Linear layers except lm_head,
# plus an fp8 per-attention-head scheme for the KV cache.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["lm_head"],
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", strategy="attn_head"),
)

# Run one-shot calibration and apply the quantization recipe.
oneshot(
    model=model,
    dataset=dataset,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
inputs = tokenizer("Hello my name is", return_tensors="pt")
inputs = {key: value.to(model.device) for key, value in inputs.items()}
output = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save the quantized model and tokenizer to disk in compressed form.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-attention-fp8-head"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

0 commit comments

Comments (0)