Commit 3219b16

kylesayrs and brian-dellabetta authored and committed
Add Attention Quantization Examples (vllm-project#2484)
## Purpose ##
* Add attention and kv quantization examples to reflect current vLLM support

## Changes ##
* Add fp8 tensor attention example
* Add fp8 kv head example

## Testing ##
* Ran examples e2e and ran in vLLM

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Signed-off-by: Ziming <frankziming26@outlook.com>
1 parent 0c142a3 commit 3219b16

File tree

5 files changed: +111 additions, −24 deletions

README.md

Lines changed: 3 additions & 2 deletions
@@ -43,7 +43,7 @@ Some of the exciting new features include:
  * **Updated FP4 Microscale Support**: GPTQ now supports FP4 quantization schemes, including both [MXFP4](examples/quantization_w4a16_fp4/mxfp4/llama3_example.py) and [NVFP4](examples/quantization_w4a4_fp4/llama3_gptq_example.py). MXFP4 support has also been improved with updated weight scale generation. Models with weight-only quantization in the MXFP4 format can now run in vLLM as of vLLM v0.14.0. MXFP4 models with activation quantization are not yet supported in vLLM for compressed-tensors models
  * **New Model-Free PTQ Pathway**: A new model-free PTQ pathway has been added to LLM Compressor, called [`model_free_ptq`](src/llmcompressor/entrypoints/model_free/__init__.py#L36). This pathway allows you to quantize your model without requiring a Hugging Face model definition and is especially useful in cases where `oneshot` may fail. This pathway currently supports data-free pathways only, i.e. FP8 quantization, and was leveraged to quantize the [Mistral Large 3 model](https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512). Additional [examples](examples/model_free_ptq) have been added illustrating how LLM Compressor can be used for Kimi K2
  * **MXFP8 Microscale Support (Experimental)**: LLM Compressor now supports MXFP8 quantization via PTQ. Both W8A8 ([MXFP8](experimental/mxfp8/qwen3_example_w8a8_mxfp8.py)) and W8A16 weight-only ([MXFP8A16](experimental/mxfp8/qwen3_example_w8a16_mxfp8.py)) modes are available.
- * **Extended KV Cache and Attention Quantization Support**: LLM Compressor now supports attention quantization. KV Cache quantization, which previously only supported per-tensor scales, has been extended to support any quantization scheme including a new `per-head` quantization scheme. Support for these checkpoints is on-going in vLLM and scripts to get started have been added to the [experimental folder](experimental/attention)
+ * **Extended KV Cache and Attention Quantization Support**: LLM Compressor now supports attention quantization, as well as fine-grained KV Cache quantization. Previously, only per-tensor KV cache quantization was supported. Now, you can quantize the KV cache with `per-head` scales and run with vLLM. Examples of more generalized attention and kv cache quantization can be found in the [experimental folder](experimental/attention).

  ### Supported Formats

@@ -86,7 +86,8 @@ Applying quantization with `llmcompressor`:
  * [Weight only quantization to `int4` using AWQ](examples/awq/README.md)
  * [Weight only quantization to `int4` using AutoRound](examples/autoround/quantization_w4a16/README.md)
  * [KV Cache quantization to `fp8`](examples/quantization_kv_cache/README.md)
- * [Attention quantization to `fp8` (experimental)](experimental/attention/README.md)
+ * [KV Cache quantization to `fp8` using per-head scales](examples/quantization_kv_cache/llama3_fp8_head_kv_example.py)
+ * [Attention quantization to `fp8`](examples/quantization_attention/README.md)
  * [Attention quantization to `nvfp4` with SpinQuant (experimental)](experimental/attention/README.md)
  * [Quantizing MoE LLMs](examples/quantizing_moe/README.md)
  * [Quantizing Vision-Language Models](examples/multimodal_vision/README.md)
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Attention Quantization in LLM Compressor #
LLM Compressor supports applying static attention quantization to models.

## Per-Head FP8 Attention Example ##
For an example applying attention quantization, see [llama3_attention.py](/examples/quantization_attention/llama3_attention.py).

```python
recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="attn_head"
            ),
        )
    }
)
```

Accuracy should be almost identical to the base model for FP8 attention.
Note that attention quantization also implicitly applies kv cache quantization with the same quantization arguments.
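The idea behind the `attn_head` strategy in the recipe above can be sketched outside of LLM Compressor: compute one scale per attention head rather than one for the whole tensor. This is a minimal numpy illustration, not library internals; the shapes and the FP8 maximum are assumptions.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # max finite magnitude of float8_e4m3 (assumed here)


def head_scales(x):
    # x: activations shaped [num_heads, seq_len, head_dim].
    # Per-head strategy computes one scale per attention head,
    # instead of a single scale for the whole tensor.
    absmax = np.abs(x).reshape(x.shape[0], -1).max(axis=1)
    return absmax / FP8_E4M3_MAX


rng = np.random.default_rng(0)
acts = rng.standard_normal((8, 128, 64)).astype(np.float32)
acts[3] *= 50.0  # one head with a much wider dynamic range

scales = head_scales(acts)
print(scales.shape)  # (8,) -- one scale per head
```

A head with a wide dynamic range gets its own larger scale, so it no longer stretches the quantization grid of the other heads.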

experimental/attention/llama3_attention.py renamed to examples/quantization_attention/llama3_attention.py

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
+ from compressed_tensors.offload import dispatch_model
  from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
  from datasets import load_dataset
  from transformers import AutoModelForCausalLM, AutoTokenizer

  from llmcompressor import oneshot
  from llmcompressor.modifiers.quantization import QuantizationModifier
- from compressed_tensors.offload import dispatch_model

  # Select model and load it.
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

@@ -56,7 +56,7 @@ def tokenize(sample):
          "attention": QuantizationScheme(
              targets=["LlamaAttention"],
              input_activations=QuantizationArgs(
-                 num_bits=8, type="float", strategy="attn_head"
+                 num_bits=8, type="float", strategy="tensor"
              ),
          )
      }

@@ -82,6 +82,6 @@ def tokenize(sample):
  print("==========================================\n\n")

  # Save to disk compressed.
- SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-attention-fp8-head"
+ SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-attention-fp8"
  model.save_pretrained(SAVE_DIR, save_compressed=True)
  tokenizer.save_pretrained(SAVE_DIR)
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
from compressed_tensors.offload import dispatch_model
from compressed_tensors.quantization import QuantizationArgs
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["lm_head"],
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", strategy="attn_head"),
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-fp8-kv-head"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
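As a back-of-the-envelope check on why the `strategy="attn_head"` KV cache scheme used above helps, the round-trip error of per-tensor versus per-head scaling can be compared on synthetic activations. This is a sketch using a uniform symmetric 8-bit grid as a stand-in for FP8; shapes and values are illustrative.

```python
import numpy as np


def fake_quant(x, scale, qmax=127):
    # symmetric quantize/dequantize round trip on a uniform 8-bit grid
    return np.clip(np.round(x / scale), -qmax, qmax) * scale


rng = np.random.default_rng(0)
kv = rng.standard_normal((8, 128, 64)).astype(np.float32)
kv[3] *= 50.0  # one outlier head stretches a shared per-tensor scale

# One scale for the whole tensor vs. one scale per head
per_tensor = fake_quant(kv, np.abs(kv).max() / 127)
head_scale = np.abs(kv).reshape(8, -1).max(axis=1) / 127
per_head = fake_quant(kv, head_scale[:, None, None])

err_tensor = np.mean((kv - per_tensor) ** 2)
err_head = np.mean((kv - per_head) ** 2)
print(err_head < err_tensor)  # True: per-head scales shrink the error
```

The non-outlier heads quantize on a much finer grid when each head carries its own scale, which is the motivation for extending KV cache quantization beyond per-tensor scales.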

experimental/attention/README.md

Lines changed: 1 addition & 19 deletions
@@ -1,23 +1,5 @@
  # Attention Quantization in LLM Compressor #
- LLM Compressor supports applying static attention quantization to models. Please note that attention quantization support in vLLM is still ongoing and is not fully supported as of this writing.
-
- ## FP8 Attention Example ##
- For an example applying attention quantization, see [llama3_attention.py](/experimental/attention/llama3_attention.py).
-
- ```python
- recipe = QuantizationModifier(
-     config_groups={
-         "attention": QuantizationScheme(
-             targets=["LlamaAttention"],
-             input_activations=QuantizationArgs(
-                 num_bits=8, type="float", strategy="attn_head"
-             ),
-         )
-     }
- )
- ```
-
- Note that attention quantization also implicitly applies kv cache quantization with the same quantization arguments.
+ LLM Compressor supports applying static attention quantization to models. Please note that NVFP4 attention quantization and R3 support in vLLM are still ongoing and not fully supported as of this writing.

  ## NVFP4 Attention + R3 Example ##
  Attention quantization can be improved using the R3 transform, as described by [SpinQuant](https://arxiv.org/abs/2405.16406). This transform reduces the presence of outliers in the attention activation distribution, thereby improving accuracy recovery.
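The outlier-reduction effect of such rotations can be seen with a tiny experiment. This sketch uses a Sylvester-construction Hadamard rotation as a stand-in for R3; it is illustrative, not the SpinQuant implementation.

```python
import numpy as np


def hadamard(n):
    # Sylvester construction; n must be a power of two
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


n = 64
x = np.zeros(n)
x[0] = 100.0  # a single large activation outlier

R = hadamard(n) / np.sqrt(n)  # orthogonal: norm-preserving and invertible
y = R @ x

print(np.abs(x).max())  # 100.0
print(np.abs(y).max())  # 12.5, i.e. 100 / sqrt(64)
```

The rotation spreads the outlier's energy evenly across all coordinates while preserving the vector's norm, so the quantization grid no longer has to cover a single extreme value.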
