
Commit 3e28211

Add SeqMSE support to aimet pass (microsoft#2158)
## Describe your changes

Adds sequential MSE support to the aimet quantization pass.

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link

---------

Signed-off-by: Michael Tuttle <mtuttle@qti.qualcomm.com>
1 parent 910cbaf commit 3e28211
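To see the user-facing shape of the change up front: based on the test added in this commit, a pass config that enables the new technique looks roughly like the sketch below (the `data_config` is a placeholder for a real calibration dataset):

```python
# Sketch of an AimetQuantization pass config enabling sequential MSE,
# mirroring the config used in the added test.
config = {
    "precision": "int4",
    "techniques": [
        {
            "name": "seqmse",     # dispatched to aimet_onnx.apply_seq_mse
            "num_candidates": 5,  # encoding candidates swept per weight (default: 20)
        }
    ],
    # "data_config": ...,  # calibration data; seqmse falls back to it when the
    #                      # technique entry itself gives no data_config
}
```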

File tree

3 files changed, +69 -11 lines changed


olive/olive_config.json

Lines changed: 1 addition & 1 deletion

@@ -602,7 +602,7 @@
         }
     },
     "extra_dependencies": {
-        "aimet-onnx": [ "aimet-onnx>=2.12.0" ],
+        "aimet-onnx": [ "aimet-onnx>=2.15.0" ],
         "auto-opt": [ "optimum" ],
         "azureml": [ "azure-ai-ml>=1.11.1", "azure-identity" ],
         "bnb": [ "bitsandbytes", "triton" ],

olive/passes/onnx/aimet_quantization.py

Lines changed: 24 additions & 1 deletion
@@ -171,6 +171,29 @@ def apply( # pylint: disable=arguments-differ
         return sim
 
 
+class SeqMSE(_AimetTechnique):
+    @staticmethod
+    def apply(  # pylint: disable=arguments-differ
+        sim,
+        *,
+        data_config=None,
+        num_candidates: int = 20,
+    ):
+        """Apply the aimet_onnx sequential MSE technique to sim.
+
+        Args:
+            sim: QuantizationSimModel to optimize.
+            data_config: Dataset to use for optimization. If not specified for the technique, will default to the calibration data.
+            num_candidates: Number of encoding candidates to sweep for each weight.
+
+        """
+        from aimet_onnx import apply_seq_mse
+
+        apply_seq_mse(sim, data_config, num_candidates)
+
+        return sim
+
+
 class AimetQuantization(Pass):
     """Quantize ONNX model using aimet-onnx."""
 
@@ -340,6 +363,6 @@ def _run_for_config(
         )
 
         sim.compute_encodings(calib_dataloader)
-        qdq_model = sim.to_onnx_qdq()
+        qdq_model = sim.to_onnx_qdq(prequantize_constants=True)
 
         return model_proto_to_olive_model(qdq_model, output_model_path, config)
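For context on what `apply_seq_mse` does: sequential MSE selects a quantization range for each weight by sweeping `num_candidates` candidate ranges and keeping the one whose quantized layer output is closest, in mean squared error, to the float output. Below is a minimal NumPy sketch of that idea; it is illustrative only, not the aimet-onnx implementation (which optimizes layers sequentially over the QuantizationSimModel):

```python
import numpy as np

def seq_mse_candidate_sweep(weight, activations, num_candidates=20, bitwidth=4):
    """Pick the symmetric weight-quantization scale minimizing output MSE (sketch)."""
    qmax = 2 ** (bitwidth - 1) - 1
    float_out = activations @ weight
    best_scale, best_mse = None, np.inf
    for i in range(1, num_candidates + 1):
        # Candidate range: a growing fraction of the weight's absolute-max range.
        scale = (i / num_candidates) * np.abs(weight).max() / qmax
        # Fake-quantize the weight at this candidate scale.
        q_weight = np.clip(np.round(weight / scale), -qmax - 1, qmax) * scale
        mse = np.mean((activations @ q_weight - float_out) ** 2)
        if mse < best_mse:
            best_scale, best_mse = scale, mse
    return best_scale

# Example: sweep 20 candidates for a random 4-bit linear layer.
w = np.random.randn(16, 8).astype(np.float32)
x = np.random.randn(32, 16).astype(np.float32)
print(seq_mse_candidate_sweep(w, x))
```

Separately, the switch to `to_onnx_qdq(prequantize_constants=True)` makes the exported QDQ model store its initializers already quantized; that appears to be why the test updates below strip a `_q` suffix when mapping quantizer nodes back to tensor names.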

test/passes/onnx/test_aimet_quantization.py

Lines changed: 44 additions & 9 deletions
@@ -155,7 +155,9 @@ def test_aimet_quantization_uses_provided_precisions(tmp_path, precisions):
 
     initializer_dict = {tensor.name: tensor for tensor in model.graph.initializer}
     tensor_to_quantizer = {
-        node.input[0]: node for node in model.graph.node if node.op_type in ("QuantizeLinear", "DequantizeLinear")
+        node.input[0].removesuffix("_q"): node
+        for node in model.graph.node
+        if node.op_type in ("QuantizeLinear", "DequantizeLinear")
     }
 
     # Weight should be symmetrically quantized with precision type
@@ -269,11 +271,11 @@ def test_aimet_quantization_applies_adaround(tmp_path):
     }
     p = create_pass_from_dict(AimetQuantization, config, disable_search=True)
 
-    with patch("aimet_onnx.apply_adaround") as mock_seq_mse:
+    with patch("aimet_onnx.apply_adaround") as mock_adaround:
         out = p.run(input_model, tmp_path)
-    assert mock_seq_mse.call_count == 1
+    assert mock_adaround.call_count == 1
 
-    (_, data, num_iterations, nodes_to_include), _ = mock_seq_mse.call_args
+    (_, data, num_iterations, nodes_to_include), _ = mock_adaround.call_args
     assert isinstance(data, Iterable)
     assert num_iterations == 5
     assert nodes_to_include is None
@@ -305,13 +307,44 @@ def test_aimet_quantization_excludes_adaround_nodes(tmp_path):
     }
     p = create_pass_from_dict(AimetQuantization, config, disable_search=True)
 
-    with patch("aimet_onnx.apply_adaround") as mock_seq_mse:
+    with patch("aimet_onnx.apply_adaround") as mock_adaround:
         p.run(input_model, tmp_path)
-    assert mock_seq_mse.call_count == 1
-    (_, _, _, nodes_to_include), _ = mock_seq_mse.call_args
+    assert mock_adaround.call_count == 1
+    (_, _, _, nodes_to_include), _ = mock_adaround.call_args
     assert not nodes_to_include
 
 
+@pytest.mark.skipif(not IS_LINUX, reason="Only run on linux")
+@pytest.mark.skipif(CUDA_AVAILABLE, reason="Only run on cpu tests")
+def test_aimet_quantization_applies_seq_mse(tmp_path):
+    input_model = dummy_onnx_matmul_model(tmp_path / "dummy_model_mm.onnx")
+    config = {
+        "data_config": DataConfig(
+            name="test_quant_dc_config",
+            load_dataset_config=DataComponentConfig(type="simple_dataset"),
+            dataloader_config=DataComponentConfig(type="_test_quant_dataloader_len_16"),
+        ),
+        "precision": "int4",
+        "techniques": [
+            {
+                "name": "seqmse",
+                "num_candidates": 5,
+            }
+        ],
+    }
+    p = create_pass_from_dict(AimetQuantization, config, disable_search=True)
+
+    with patch("aimet_onnx.apply_seq_mse") as mock_seq_mse:
+        out = p.run(input_model, tmp_path)
+    assert mock_seq_mse.call_count == 1
+
+    (_, data, num_candidates), _ = mock_seq_mse.call_args
+    assert isinstance(data, Iterable)
+    assert num_candidates == 5
+
+    assert out is not None
+
+
 @pytest.mark.skipif(not IS_LINUX, reason="Only run on linux")
 @pytest.mark.skipif(CUDA_AVAILABLE, reason="Only run on cpu tests")
 @pytest.mark.parametrize(
@@ -344,7 +377,7 @@ def test_aimet_quantization_excludes_op_types(tmp_path, op_types, disabled_quant
     model = onnx.load(out.model_path)
 
     tensor_to_quantizer = {
-        tensor: node
+        tensor.removesuffix("_q"): node
         for node in model.graph.node
        for tensor in (node.input[0], node.output[0])
         if node.op_type in ("QuantizeLinear", "DequantizeLinear")
@@ -374,7 +407,9 @@ def test_aimet_quantization_preserves_quantization_in_prequantized_model(tmp_pat
     model = onnx.load(out.model_path)
 
     tensor_to_quantizer = {
-        node.input[0]: node for node in model.graph.node if node.op_type in ("QuantizeLinear", "DequantizeLinear")
+        node.input[0].removesuffix("_q"): node
+        for node in model.graph.node
+        if node.op_type in ("QuantizeLinear", "DequantizeLinear")
     }
 
     weight_quantizer = tensor_to_quantizer["weight_dq"]