-
Notifications
You must be signed in to change notification settings - Fork 107
Expand file tree
/
Copy pathtest_llmc_integration.py
More file actions
239 lines (205 loc) · 7.77 KB
/
test_llmc_integration.py
File metadata and controls
239 lines (205 loc) · 7.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import pytest
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round.calib_dataset import get_dataset
# YAML form of the 4-bit group-128 weight-only AutoRound recipe.
# YAML nesting is defined purely by indentation: every level below must stay
# indented relative to its parent, otherwise the document collapses into a
# flat mapping of null-valued keys and no modifier is configured.
recipe_str = """
quant_stage:
    quant_modifiers:
        AutoRoundModifier:
            ignore: ["lm_head"]
            iters: 10
            config_groups:
                group_0:
                    targets:
                        - "Linear"
                    input_activations: null
                    output_activations: null
                    weights:
                        num_bits: 4
                        type: "int"
                        symmetric: true
                        strategy: group
                        group_size: 128
"""
def _autoround(iters, **kwargs):
    """Build an ``AutoRoundModifier`` that leaves ``lm_head`` unquantized.

    ``iters`` is forwarded as the AutoRound iteration count; any further
    keyword arguments (``config_groups`` or ``scheme``) are passed through
    unchanged.
    """
    return AutoRoundModifier(ignore=["lm_head"], iters=iters, **kwargs)


def _linear_group(**args):
    """Return a single ``group_0`` config group targeting ``Linear`` layers.

    ``args`` are forwarded verbatim to ``QuantizationScheme`` (e.g.
    ``weights`` and optionally ``input_activations``).
    """
    return {"group_0": QuantizationScheme(targets=["Linear"], **args)}


# 4-bit weight-only quantization, group strategy with group_size=128.
recipe_modifier_full = _autoround(
    10,
    config_groups=_linear_group(
        weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
    ),
)

# Preset schemes selected by name rather than by explicit config groups.
recipe_modifier_nvfp4 = _autoround(2, scheme="NVFP4")
recipe_modifier_mxfp4 = _autoround(0, scheme="MXFP4")

# iters=0 variants used by the RTN test: 8-bit float weights plus 8-bit float
# input activations, quantized per-token dynamically ...
w8a8_dynamic_recipe_modifier = _autoround(
    0,
    config_groups=_linear_group(
        weights=QuantizationArgs(num_bits=8, type="float", strategy="channel"),
        input_activations=QuantizationArgs(
            num_bits=8, type="float", strategy="token", dynamic=True
        ),
    ),
)

# ... or per-tensor (dynamic left at its default).
w8a8_static_recipe_modifier = _autoround(
    0,
    config_groups=_linear_group(
        weights=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
        input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
    ),
)
_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


def _run_oneshot(recipe, tmp_path, *, seqlen, nsamples):
    """Quantize the test model with ``recipe`` and reload the saved checkpoint.

    Args:
        recipe: A YAML recipe string or a modifier instance passed to ``oneshot``.
        tmp_path: pytest-provided temporary directory for the output checkpoint.
        seqlen: Calibration sequence length passed to ``get_dataset``.
        nsamples: Number of calibration samples passed to ``get_dataset``.

    Returns:
        The model loaded back from the ``oneshot`` output directory.
    """
    output = tmp_path / "oneshot_output"
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    dataset = get_dataset(tokenizer=tokenizer, seqlen=seqlen, nsamples=nsamples)
    oneshot(
        model=_MODEL_ID,
        dataset=dataset,
        output_dir=output,
        recipe=recipe,
    )
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return AutoModelForCausalLM.from_pretrained(output, device_map=device)


def _check_quantized_model(model_loaded):
    """Assert the invariants every recipe here must satisfy.

    Checks that the checkpoint carries a quantization config with exactly one
    ``group_0`` scheme, that ``lm_head`` is ignored, and that a targeted
    Linear layer (but not ``lm_head``) carries a ``quantization_scheme``.

    Returns:
        The ``group_0`` ``QuantizationScheme`` read back from the checkpoint.
    """
    # decompress() attaches a quantization_config to the model, as the
    # checkpoint is decompressed right away on load.
    quantization_config = model_loaded.config.quantization_config.quantization_config
    assert quantization_config is not None
    assert "lm_head" in quantization_config.ignore
    assert len(quantization_config.config_groups) == 1
    quant_scheme = quantization_config.config_groups["group_0"]
    assert isinstance(quant_scheme, QuantizationScheme)
    assert isinstance(quant_scheme.weights, QuantizationArgs)
    # A targeted Linear layer is quantized ...
    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
    assert hasattr(targeted_linear_layer, "quantization_scheme")
    # ... while the ignored lm_head is not.
    assert not hasattr(model_loaded.lm_head, "quantization_scheme")
    return quant_scheme


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires at least 1 Cuda GPU")
@pytest.mark.parametrize(
    "recipe",
    [
        recipe_str,
        recipe_modifier_full,
        recipe_modifier_nvfp4,
        recipe_modifier_mxfp4,
    ],
)
def test_oneshot_application(recipe, tmp_path):
    """Each 4-bit recipe (YAML string or modifier) yields a quantized checkpoint."""
    model_loaded = _run_oneshot(recipe, tmp_path, seqlen=1024, nsamples=32)
    quant_scheme = _check_quantized_model(model_loaded)
    assert quant_scheme.weights.num_bits == 4


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires at least 2 Cuda GPUs")
def test_oneshot_with_device_ids(tmp_path):
    """AutoRound tuning spread over two GPUs via ``device_ids`` still quantizes."""
    recipe = AutoRoundModifier(
        ignore=["lm_head"],
        iters=10,
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
            )
        },
        device_ids="0,1",
    )
    model_loaded = _run_oneshot(recipe, tmp_path, seqlen=512, nsamples=4)
    quant_scheme = _check_quantized_model(model_loaded)
    assert quant_scheme.weights.num_bits == 4


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires at least 1 Cuda GPU")
@pytest.mark.parametrize(
    "recipe",
    [w8a8_dynamic_recipe_modifier, w8a8_static_recipe_modifier],
)
def test_rtn_oneshot(recipe, tmp_path):
    """iters=0 (round-to-nearest) recipes round-trip their weight/activation args."""
    model_loaded = _run_oneshot(recipe, tmp_path, seqlen=1024, nsamples=32)
    quant_scheme = _check_quantized_model(model_loaded)
    expected = recipe.config_groups["group_0"]

    weight_args = quant_scheme.weights
    assert weight_args.num_bits == expected.weights.num_bits
    assert weight_args.strategy == expected.weights.strategy

    act_args = quant_scheme.input_activations
    if expected.input_activations is None:
        assert act_args is None
    else:
        # The recipe quantizes input activations, so they must survive the
        # save/load round trip; previously a missing entry was silently skipped.
        assert act_args is not None
        assert act_args.num_bits == expected.input_activations.num_bits
        assert act_args.strategy == expected.input_activations.strategy