-
Notifications
You must be signed in to change notification settings - Fork 490
Expand file tree
/
Copy pathtest_base.py
More file actions
282 lines (236 loc) · 8.78 KB
/
test_base.py
File metadata and controls
282 lines (236 loc) · 8.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import pytest
import torch
from compressed_tensors.utils import getattr_chain
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.factory import ModifierFactory
from llmcompressor.modifiers.transform.smoothquant.base import (
SmoothQuantModifier,
)
@pytest.mark.unit
@pytest.mark.usefixtures("setup_modifier_factory")
def test_smooth_quant_is_registered():
    """The factory registry resolves "SmoothQuantModifier" to the PyTorch class
    and forwards constructor kwargs through to the created instance."""
    expected_strength = 0.3
    expected_mappings = [(["layer1", "layer2"], "layer3")]

    created = ModifierFactory.create(
        type_="SmoothQuantModifier",
        allow_experimental=False,
        allow_registered=True,
        smoothing_strength=expected_strength,
        mappings=expected_mappings,
    )

    assert isinstance(
        created, SmoothQuantModifier
    ), "PyTorch SmoothQuant not registered"
    # kwargs passed through the factory must land on the instance unchanged
    assert created.smoothing_strength == expected_strength
    assert created.mappings == expected_mappings
@pytest.mark.unit
@pytest.mark.usefixtures("setup_modifier_factory")
def test_smooth_quant_defaults():
    """A SmoothQuantModifier built with no arguments defaults to strength 0.5."""
    assert SmoothQuantModifier().smoothing_strength == 0.5
@pytest.mark.unit
def test_override_defaults():
    """Constructor keyword arguments override the modifier's default values."""
    custom_strength = 0.7
    custom_mappings = [(["layer1", "layer2"], "layer3")]

    modifier = SmoothQuantModifier(
        smoothing_strength=custom_strength, mappings=custom_mappings
    )

    assert modifier.smoothing_strength == custom_strength
    assert modifier.mappings == custom_mappings
@pytest.mark.unit
def test_moe_all_experts_smoothed():
    """
    Test that SmoothQuant smooths ALL experts in MoE models, not just expert.0.

    Verifies that all experts are included in balance_layers when resolving
    mappings for MoE models with multiple experts.
    """
    num_experts = 8
    hidden_size = 256

    def make_expert():
        # Minimal stand-in for one MoE expert: two linear projections
        return torch.nn.ModuleDict(
            {
                "w1": torch.nn.Linear(hidden_size, hidden_size),
                "w2": torch.nn.Linear(hidden_size, hidden_size),
            }
        )

    experts = torch.nn.ModuleList([make_expert() for _ in range(num_experts)])
    decoder_layer = torch.nn.ModuleDict(
        {
            "input_layernorm": torch.nn.LayerNorm(hidden_size),
            "mlp": torch.nn.ModuleDict(
                {
                    "gate": torch.nn.Linear(hidden_size, num_experts),
                    "experts": experts,
                }
            ),
        }
    )
    model = torch.nn.ModuleDict(
        {"layers": torch.nn.ModuleList([decoder_layer])}
    )

    sq = SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[(["re:.*experts.*w1"], "re:.*input_layernorm")],
        ignore=["re:.*gate"],
    )
    resolved = sq._resolve_mappings(model)

    # One layernorm -> one resolved mapping
    assert len(resolved) == 1
    mapping = resolved[0]
    assert "input_layernorm" in mapping.smooth_name
    assert len(mapping.balance_layers) == num_experts, (
        f"Expected {num_experts} balance layers, got {len(mapping.balance_layers)}"
    )
    # Verify no duplicates
    unique_ids = {id(layer) for layer in mapping.balance_layers}
    assert len(unique_ids) == len(mapping.balance_layers)
    # Verify correct layers: exactly the per-expert w1 projections
    assert set(mapping.balance_layers) == {expert.w1 for expert in experts}
@pytest.mark.unit
def test_moe_multiple_layers_all_experts_smoothed():
    """
    Test SmoothQuant with multiple MoE layers to ensure all experts across
    all layers are smoothed correctly.
    """
    num_layers = 2
    num_experts = 4
    hidden_size = 128

    def build_layer():
        # One decoder layer: layernorm + gated MoE MLP with num_experts experts
        expert_list = torch.nn.ModuleList(
            [
                torch.nn.ModuleDict(
                    {
                        "w1": torch.nn.Linear(hidden_size, hidden_size),
                        "w2": torch.nn.Linear(hidden_size, hidden_size),
                    }
                )
                for _ in range(num_experts)
            ]
        )
        return torch.nn.ModuleDict(
            {
                "input_layernorm": torch.nn.LayerNorm(hidden_size),
                "mlp": torch.nn.ModuleDict(
                    {
                        "gate": torch.nn.Linear(hidden_size, num_experts),
                        "experts": expert_list,
                    }
                ),
            }
        )

    layer_stack = torch.nn.ModuleList([build_layer() for _ in range(num_layers)])
    model = torch.nn.ModuleDict({"layers": layer_stack})

    sq = SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[(["re:.*experts.*w1"], "re:.*input_layernorm")],
        ignore=["re:.*gate"],
    )
    resolved = sq._resolve_mappings(model)

    # One resolved mapping per decoder layer
    assert len(resolved) == num_layers
    for idx, mapping in enumerate(resolved):
        assert len(mapping.balance_layers) == num_experts, (
            f"Layer {idx}: Expected {num_experts} balance layers, "
            f"got {len(mapping.balance_layers)}"
        )
        # Verify all balance layers are unique
        unique_ids = {id(layer) for layer in mapping.balance_layers}
        assert len(unique_ids) == len(mapping.balance_layers)
@pytest.mark.unit
def test_ignore_behavior():
    """Test that mapping is skipped when ALL layers are in ignore list"""
    hidden_size = 64
    attn = torch.nn.ModuleDict(
        {
            proj: torch.nn.Linear(hidden_size, hidden_size)
            for proj in ("q_proj", "k_proj", "v_proj")
        }
    )
    model = torch.nn.ModuleDict(
        {
            "decoder": torch.nn.ModuleDict(
                {
                    "input_layernorm": torch.nn.LayerNorm(hidden_size),
                    "self_attn": attn,
                }
            )
        }
    )

    # Test case 1: Some balance layers ignored - mapping should proceed
    partial_ignore = SmoothQuantModifier(
        smoothing_strength=0.5,
        mappings=[
            (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm")
        ],
        ignore=["re:.*q_proj", "re:.*k_proj"],  # Only 2 of 3 balance layers ignored
    )
    # Mapping should exist because v_proj is not ignored
    assert len(partial_ignore._resolve_mappings(model)) == 1

    # Test case 2: All layers ignored - mapping should be skipped
    full_ignore = SmoothQuantModifier(
        smoothing_strength=0.5,
        mappings=[
            (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm")
        ],
        ignore=[
            "re:.*input_layernorm",
            "re:.*q_proj",
            "re:.*k_proj",
            "re:.*v_proj",
        ],
    )
    # Mapping should be skipped because all layers are ignored
    assert len(full_ignore._resolve_mappings(model)) == 0
@pytest.mark.smoke
@pytest.mark.integration
def test_smoothquant_e2e():
    """
    Test that SmoothQuant applied via oneshot actually transforms weights correctly.

    Runs oneshot with SmoothQuantModifier on a small model and verifies:
    1. Weights actually changed (smoothing was applied)
    2. The model output is approximately preserved (the SmoothQuant invariant)
    """
    model_id = "nm-testing/tinysmokellama-3.2"
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def is_tracked(param_name: str) -> bool:
        # These are the parameters SmoothQuant rescales in this test's mapping
        return "input_layernorm" in param_name or "q_proj.weight" in param_name

    # Snapshot the pre-smoothing values of the tracked parameters
    snapshot = {
        name: param.clone()
        for name, param in model.named_parameters()
        if is_tracked(name)
    }

    # Capture the model's output before smoothing for the invariant check
    encoded = tokenizer("Hello world", return_tensors="pt")
    with torch.no_grad():
        logits_before = model(**encoded).logits

    oneshot(
        model=model,
        dataset="open_platypus",
        splits="train[:10%]",
        recipe=SmoothQuantModifier(smoothing_strength=0.5),
        num_calibration_samples=4,
        max_seq_length=128,
    )

    # 1. Verify weights actually changed
    for name, before in snapshot.items():
        after = getattr_chain(model, name)
        assert not torch.equal(
            before, after.to(before.device)
        ), f"Weight {name} was not modified by SmoothQuant"

    # 2. Verify model output is approximately preserved (SmoothQuant invariant)
    with torch.no_grad():
        logits_after = model(**encoded).logits
    torch.testing.assert_close(
        logits_before, logits_after.to(logits_before.device), atol=0.1, rtol=0.1
    )