Commit 7192edf

Enabling MOE Quantization using linear decomposition (#2043)
* Enabling MOE Quantization using linear decomposition

  Summary: This PR is a first step at optimizing MoE inference using torchao. The goal for this step is to enable existing quantization kernels and workflows to work for MoE quantization by decomposing the grouped gemm into a sequence of unbalanced linear ops that can use the existing quantized kernels. To enable this, we had to add support for quantizing these 3D tensors, as well as for slicing and indexing them.

  Two methods of achieving this were implemented. For int8wo, int8dq, int4wo, fp8wo, and fp8dq, the underlying quantized tensor subclass was adapted to support 3D tensors, indexing, and slicing, along with an updated transformation function that can handle ConditionalFeedForwardAOQuantizable modules when the filter function in quantize_ is used to target that module. For some complex kernels that use packed data and couldn't easily be made to work in 3D, we also added FakeExtraDimTensor, which can wrap any quantized tensor subclass so that it supports the slice and index operations needed for MoE quantization. This option is enabled by using MoEQuantConfig. This can be applied to Hugging Face Llama4, for instance, as shown in the llama4_quant.py example. Since the HF MoE module is implemented in a way that's not conducive to quantization, it first requires a module swap to MOEFeedForwardAOQuantizable.

  TODO: final benchmark numbers from run.sh; consolidate the 3x implementations of MOEFeedForwardAOQuantizable and ConditionalFeedForwardAOQuantizable; verify hqq.

  Test Plan:
  python test/quantization/test_moe_quant.py
  python test/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py -k "test_moe_quant_intx"
  sh torchao/_models/mixtral-moe/run.sh

* Follow-up fixup commits: fixing CI, lint, removing test code, fixing the experimental tests and experimental CI, fixing generate.py device handling, fixing tests that weren't skipping, ruff format, updating the API and removing branching in the quant_api.py transform functions, and removing a change to test_integration.py.
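For reference, a minimal usage sketch distilled from the new test file below. The module, config, and filter names come from this PR; the model sizes, dtype, and device are illustrative (the tests gate some configs on torch version and GPU capability):

```python
import torch

from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
    MOEFeedForwardAOQuantizable,
)
from torchao.quantization.prototype.moe_quant.utils import (
    MoEQuantConfig,
    UseFakeExtraDimTensor,
    cond_ffn_filter,
)
from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_

# hidden_dim=512, expert_dim=256, num_experts=8, top_k=2 (sizes are illustrative)
model = (
    MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False)
    .to(torch.bfloat16)
    .to("cuda")
)

# Base path: the quantized tensor subclass itself handles the 3D expert
# weights, including the slicing/indexing that MoE routing requires.
config = MoEQuantConfig(Int8WeightOnlyConfig())

# Alternative path for kernels whose packed data can't easily work in 3D:
# wrap the expert weights in a FakeExtraDimTensor instead.
# config = MoEQuantConfig(
#     Int8WeightOnlyConfig(),
#     use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
# )

# cond_ffn_filter targets the ConditionalFeedForwardAOQuantizable submodule
quantize_(model, config, cond_ffn_filter)

tokens = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")
out = model(tokens)
```

As in the tests, models that don't already use MOEFeedForwardAOQuantizable (e.g. the HF Llama4 MoE module) need a module swap before quantize_ is applied.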
1 parent e417afc commit 7192edf

24 files changed: +2585 / -123 lines changed

test/quantization/test_moe_quant.py

Lines changed: 361 additions & 0 deletions
@@ -0,0 +1,361 @@
import unittest

import torch
from parameterized import parameterized

from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl
from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
    MOEFeedForwardAOQuantizable,
)
from torchao.quantization.prototype.moe_quant.utils import (
    FakeExtraDimTensor,
    MoEQuantConfig,
    UseFakeExtraDimTensor,
    cond_ffn_filter,
)
from torchao.quantization.quant_api import (
    AffineQuantizedTensor,
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    LinearActivationQuantizedTensor,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    TORCH_VERSION_AT_LEAST_2_6,
    is_sm_at_least_90,
)

class TestMoEQuantCompile(unittest.TestCase):
    DEFAULT_PARAMS = (512, 256, 8, 2)  # hidden_dim, expert_dim, num_experts, top_k

    @torch.no_grad()
    def _test_impl_moe_quant(
        self,
        config,
        num_tokens=1,
        model_params=None,
        base_class=AffineQuantizedTensor,
        tensor_impl_class=None,
        dtype=torch.bfloat16,
        device="cuda",
        fullgraph=False,
    ):
        """
        Shared implementation for the moe quant tests, covering both the base
        quantized tensor path and the FakeExtraDimTensor path.
        """
        if model_params is None:
            model_params = self.DEFAULT_PARAMS

        input_shape = (num_tokens, model_params[0])
        model = (
            MOEFeedForwardAOQuantizable(*model_params, empty_init=False)
            .to(dtype)
            .to(device)
        )
        input = torch.randn(input_shape, dtype=dtype, device=device)

        out = model(input)

        # quantize the expert weights; cond_ffn_filter targets the
        # ConditionalFeedForwardAOQuantizable submodule
        quantize_(model, config, cond_ffn_filter)

        if (
            isinstance(config, MoEQuantConfig)
            and config.use_fake_extra_dim_tensor == UseFakeExtraDimTensor.TRUE
        ):
            self.assertIsInstance(model.experts.w1, FakeExtraDimTensor)
            if base_class is not None:
                self.assertIsInstance(model.experts.w1.head_tensor, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(
                    model.experts.w1.head_tensor.tensor_impl, tensor_impl_class
                )
        else:
            if base_class is not None:
                self.assertIsInstance(model.experts.w1, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class)

        out_q = model(input)

        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph)

        # warm up the compiled model before capturing the output for comparison
        model_c(input)
        model_c(input)
        out_qc = model_c(input).clone()

        # run on fresh inputs to shake out recompilation/guard issues
        for i in range(10):
            input = torch.randn(input_shape, dtype=dtype, device=device)
            model_c(input)

        self.assertGreaterEqual(compute_error(out_q, out), 10)
        self.assertGreaterEqual(compute_error(out_qc, out), 10)

    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not TORCH_VERSION_AT_LEAST_2_5:
            self.skipTest("Test only enabled for 2.5+")

        config = MoEQuantConfig(
            Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
        )
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int4wo_base(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not is_sm_at_least_90():
            self.skipTest("Requires CUDA capability >= 9.0")
        if not TORCH_VERSION_AT_LEAST_2_5:
            self.skipTest("Test only enabled for 2.5+")

        config = MoEQuantConfig(Int4WeightOnlyConfig())
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not TORCH_VERSION_AT_LEAST_2_5:
            self.skipTest("Test only enabled for 2.5+")

        config = MoEQuantConfig(
            Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
        )
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_base(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not TORCH_VERSION_AT_LEAST_2_6:
            self.skipTest("Test only enabled for 2.6+")

        config = MoEQuantConfig(Int8WeightOnlyConfig())
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
        if not TORCH_VERSION_AT_LEAST_2_6:
            self.skipTest("Test only enabled for 2.6+")

        config = MoEQuantConfig(Int8WeightOnlyConfig())
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
            device="cpu",
        )

    @parameterized.expand(
        [
            ("multiple_tokens", 32, False),
        ]
    )
    def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not TORCH_VERSION_AT_LEAST_2_5:
            self.skipTest("Test only enabled for 2.5+")

        config = MoEQuantConfig(
            Int8DynamicActivationInt8WeightConfig(),
            use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
        )
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("multiple_tokens", 32, False),
        ]
    )
    def test_int8dq_base(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not TORCH_VERSION_AT_LEAST_2_5:
            self.skipTest("Test only enabled for 2.5+")

        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not is_sm_at_least_90():
            self.skipTest("Requires CUDA capability >= 9.0")

        config = MoEQuantConfig(
            Float8WeightOnlyConfig(),
            use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
        )
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8wo_base(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not is_sm_at_least_90():
            self.skipTest("Requires CUDA capability >= 9.0")

        config = MoEQuantConfig(Float8WeightOnlyConfig())
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not is_sm_at_least_90():
            self.skipTest("Requires CUDA capability >= 9.0")

        config = MoEQuantConfig(
            Float8DynamicActivationFloat8WeightConfig(),
            use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
        )
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8dq_base(self, name, num_tokens, fullgraph):
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not is_sm_at_least_90():
            self.skipTest("Requires CUDA capability >= 9.0")

        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )


if __name__ == "__main__":
    unittest.main()

torchao/_models/mixtral-moe/README.md

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
## Mixtral-MoE

This folder contains code and scripts for benchmarking the Mixtral-MoE model.

Running

`sh scripts/prepare.sh`

should download the model, and `sh run.sh` will run the benchmarks.
