Skip to content

Commit c906500

Browse files
mgoin and claude authored
Add the QuantizedActivation linear-kernel contract (vllm-project#44260)
Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent 9eaacb2 commit c906500

13 files changed

Lines changed: 327 additions & 27 deletions

File tree

.buildkite/test_areas/quantization.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ steps:
2121
- uv pip install --system conch-triton-kernels
2222
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
2323

24+
- label: Quantized Fusions
25+
key: quantized-fusions
26+
timeout_in_minutes: 30
27+
source_file_dependencies:
28+
- tests/fusion
29+
- vllm/model_executor/layers/fusion
30+
- vllm/model_executor/kernels/linear
31+
- vllm/model_executor/layers/quantization/compressed_tensors
32+
- vllm/model_executor/layers/quantization/modelopt.py
33+
commands:
34+
- pytest -v -s fusion/
35+
2436
- label: Quantized MoE Test (B200)
2537
key: quantized-moe-test-b200
2638
timeout_in_minutes: 60

tests/fusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Contract tests for the QuantizedActivation linear-kernel integration."""
4+
5+
import pytest
6+
import torch
7+
8+
from vllm.model_executor.kernels.linear import (
9+
_POSSIBLE_FP8_BLOCK_KERNELS,
10+
_POSSIBLE_FP8_KERNELS,
11+
_POSSIBLE_INT8_KERNELS,
12+
_POSSIBLE_NVFP4_KERNELS,
13+
)
14+
from vllm.model_executor.kernels.linear.nvfp4.base import (
15+
NvFp4LinearKernel,
16+
NvFp4LinearLayerConfig,
17+
)
18+
from vllm.model_executor.kernels.linear.nvfp4.flashinfer import (
19+
FlashInferCutlassNvFp4LinearKernel,
20+
FlashInferTrtllmNvFp4LinearKernel,
21+
)
22+
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
23+
CutlassFP8ScaledMMLinearKernel,
24+
)
25+
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
26+
FlashInferFP8ScaledMMLinearKernel,
27+
)
28+
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
29+
FP8ScaledMMLinearLayerConfig,
30+
Int8ScaledMMLinearKernel,
31+
Int8ScaledMMLinearLayerConfig,
32+
)
33+
from vllm.model_executor.layers.fusion.quant_activation import (
34+
QuantizedActivation,
35+
as_quantized_activation,
36+
expose_input_quant_key,
37+
)
38+
from vllm.model_executor.layers.quantization.utils.quant_utils import (
39+
kFp8StaticTensorSym,
40+
kNvfp4Dynamic,
41+
)
42+
from vllm.platforms import current_platform
43+
44+
# The only backends that consume a pre-quantized activation.
45+
SUPPORTING = {
46+
CutlassFP8ScaledMMLinearKernel,
47+
FlashInferFP8ScaledMMLinearKernel,
48+
FlashInferCutlassNvFp4LinearKernel,
49+
}
50+
51+
52+
def _all_kernel_classes() -> list[type]:
    """Return every kernel class found in the linear-kernel registries.

    Each class appears once, in the order it is first encountered while
    walking the FP8, FP8-block, INT8 and NVFP4 registries in turn.
    """
    registries = (
        _POSSIBLE_FP8_KERNELS,
        _POSSIBLE_FP8_BLOCK_KERNELS,
        _POSSIBLE_INT8_KERNELS,
        _POSSIBLE_NVFP4_KERNELS,
    )
    # dict.fromkeys drops duplicates while keeping first-seen order.
    unique = dict.fromkeys(
        kernel_cls
        for registry in registries
        for kernel_list in registry.values()
        for kernel_cls in kernel_list
    )
    return list(unique)
64+
65+
66+
def _probe(cls: type):
    """Create a kernel of ``cls`` via ``__new__`` and attach a plausible config.

    ``__init__`` is bypassed so the hardware-gated constructor never runs;
    only ``config`` is populated, which lets input_quant_key() be queried.
    """
    kernel = cls.__new__(cls)  # type: ignore[call-overload]

    config: object
    if issubclass(cls, NvFp4LinearKernel):
        config = NvFp4LinearLayerConfig()
    elif issubclass(cls, Int8ScaledMMLinearKernel):
        config = Int8ScaledMMLinearLayerConfig(
            is_static_input_scheme=True,
            is_channelwise=False,
            input_symmetric=True,
        )
    else:
        # Every remaining kernel gets a static per-tensor FP8 config.
        fp8_key = kFp8StaticTensorSym
        config = FP8ScaledMMLinearLayerConfig(
            weight_quant_key=fp8_key,
            activation_quant_key=fp8_key,
            weight_shape=(16, 16),
            input_dtype=torch.bfloat16,
            out_dtype=torch.bfloat16,
        )

    kernel.config = config
    return kernel
85+
86+
87+
def _resolved_apply_weights(cls: type):
88+
for base in cls.__mro__:
89+
if "apply_weights" in base.__dict__:
90+
return base.__dict__["apply_weights"]
91+
raise AssertionError(f"{cls.__name__} has no apply_weights in its MRO")
92+
93+
94+
def test_only_known_backends_support_prequantized_input():
    """Exactly the kernels in SUPPORTING report a truthy input_quant_key()."""
    declarers = set()
    for kernel_cls in _all_kernel_classes():
        if _probe(kernel_cls).input_quant_key():
            declarers.add(kernel_cls)
    assert declarers == SUPPORTING
97+
98+
99+
def test_supporting_backend_declares_consume_via_helper():
    """Each supporting kernel's apply_weights references as_quantized_activation."""
    for kernel_cls in SUPPORTING:
        # co_names lists the global/attribute names the function body refers to.
        referenced = _resolved_apply_weights(kernel_cls).__code__.co_names
        assert "as_quantized_activation" in referenced, kernel_cls.__name__
103+
104+
105+
def test_bridge_marks_supporting_and_skips_others():
    """expose_input_quant_key tags a layer only when the kernel declares a key."""
    # A kernel that declares a key: the layer receives input_quant_key.
    cutlass_kernel = _probe(FlashInferCutlassNvFp4LinearKernel)
    marked_layer = torch.nn.Module()
    expose_input_quant_key(marked_layer, cutlass_kernel)
    assert marked_layer.input_quant_key == kNvfp4Dynamic

    # A kernel that declares none: the layer is left without the attribute.
    trtllm_kernel = _probe(FlashInferTrtllmNvFp4LinearKernel)
    assert trtllm_kernel.input_quant_key() is None
    untouched_layer = torch.nn.Module()
    expose_input_quant_key(untouched_layer, trtllm_kernel)
    assert not hasattr(untouched_layer, "input_quant_key")
116+
117+
118+
def test_as_quantized_activation_validates_key():
    """as_quantized_activation asserts on a key mismatch, returns None for a
    plain tensor, and returns the same object when the key matches."""
    expected_key = kFp8StaticTensorSym
    qa = QuantizedActivation(
        data=torch.zeros(2, 4, dtype=current_platform.fp8_dtype()),
        scale=torch.tensor(1.0),
        orig_dtype=torch.bfloat16,
        orig_shape=torch.Size([2, 4]),
        quant_key=expected_key,
    )

    # A QuantizedActivation offered with a different key, or no key, is rejected.
    for mismatched_key in (kNvfp4Dynamic, None):
        with pytest.raises(AssertionError):
            as_quantized_activation(qa, mismatched_key)

    plain_tensor = torch.zeros(2, 4)
    assert as_quantized_activation(plain_tensor, expected_key) is None
    assert as_quantized_activation(qa, expected_key) is qa

vllm/model_executor/kernels/linear/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import torch
99
from typing_extensions import Self
1010

11+
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
12+
1113

1214
@dataclass
1315
class MMLinearLayerConfig: ...
@@ -237,6 +239,12 @@ def __init__(self, config: _ConfigT) -> None:
237239
"""
238240
self.config = config
239241

242+
def input_quant_key(self) -> QuantKey | None:
    """Report which input quantization key this kernel accepts.

    Returns:
        The supported input ``QuantKey``, or ``None`` when the kernel does
        not support input quantization performed outside of the kernel.
        This default implementation always returns ``None``.
    """
    return None
247+
240248
@abstractmethod
241249
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
242250
"""Process and transform weights after loading from checkpoint.

vllm/model_executor/kernels/linear/nvfp4/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import torch
88

9+
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
10+
911

1012
@dataclass
1113
class NvFp4LinearLayerConfig:
@@ -33,6 +35,12 @@ def __init__(self, config: NvFp4LinearLayerConfig) -> None:
3335
assert self.is_supported()[0]
3436
self.config = config
3537

38+
def input_quant_key(self) -> QuantKey | None:
    """Report which input quantization key this kernel accepts.

    Returns:
        The supported input ``QuantKey``, or ``None`` when the kernel does
        not support input quantization performed outside of the kernel.
        This default implementation always returns ``None``.
    """
    return None
43+
3644
@classmethod
3745
@abstractmethod
3846
def is_supported(

vllm/model_executor/kernels/linear/nvfp4/flashinfer.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
import torch
55

66
from vllm._custom_ops import scaled_fp4_quant
7+
from vllm.model_executor.layers.fusion.quant_activation import (
8+
QuantizedActivation,
9+
as_quantized_activation,
10+
)
711
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
812
pad_nvfp4_activation_for_cutlass,
913
pad_nvfp4_weight_for_cutlass,
1014
slice_nvfp4_output,
1115
swizzle_blockscale,
1216
)
17+
from vllm.model_executor.layers.quantization.utils.quant_utils import (
18+
QuantKey,
19+
kNvfp4Dynamic,
20+
)
1321
from vllm.platforms import current_platform
1422
from vllm.utils.flashinfer import (
1523
flashinfer_scaled_fp4_mm,
@@ -23,6 +31,11 @@
2331
class FlashInferCutlassNvFp4LinearKernel(NvFp4LinearKernel):
2432
"""NVFP4 GEMM via FlashInfer's CUTLASS wrapper."""
2533

34+
def input_quant_key(self) -> QuantKey | None:
    """Declare support for dynamically quantized NVFP4 input.

    By convention, blockscales supplied with a pre-quantized activation
    must already use the swizzled layout.
    """
    return kNvfp4Dynamic
38+
2639
@classmethod
2740
def is_supported(
2841
cls, compute_capability: int | None = None
@@ -56,21 +69,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
5669
def apply_weights(
5770
self,
5871
layer: torch.nn.Module,
59-
x: torch.Tensor,
72+
x: torch.Tensor | QuantizedActivation,
6073
bias: torch.Tensor | None = None,
6174
) -> torch.Tensor:
6275
output_size = layer.output_size_per_partition
63-
output_dtype = x.dtype
64-
output_shape = [*x.shape[:-1], output_size]
6576
weights_padding_bytes = getattr(layer, "weights_padding_cols", 0)
6677

67-
x_fp4, x_blockscale = scaled_fp4_quant(
68-
x,
69-
layer.input_global_scale_inv,
70-
is_sf_swizzled_layout=True,
71-
backend="flashinfer-cutlass",
72-
padded_n=x.shape[-1] + weights_padding_bytes * 2,
73-
)
78+
qa = as_quantized_activation(x, self.input_quant_key())
79+
if qa is not None:
80+
x_fp4, x_blockscale = qa.data, qa.scale
81+
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_bytes)
82+
output_dtype = qa.orig_dtype
83+
output_shape = [*qa.orig_shape[:-1], output_size]
84+
else:
85+
assert isinstance(x, torch.Tensor)
86+
output_dtype = x.dtype
87+
output_shape = [*x.shape[:-1], output_size]
88+
x_fp4, x_blockscale = scaled_fp4_quant(
89+
x,
90+
layer.input_global_scale_inv,
91+
is_sf_swizzled_layout=True,
92+
backend="flashinfer-cutlass",
93+
padded_n=x.shape[-1] + weights_padding_bytes * 2,
94+
)
7495

7596
out = flashinfer_scaled_fp4_mm(
7697
x_fp4,

vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88

99
import torch
1010

11+
from vllm.model_executor.layers.fusion.quant_activation import (
12+
QuantizedActivation,
13+
as_quantized_activation,
14+
)
1115
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
1216
from vllm.model_executor.layers.quantization.utils.quant_utils import (
1317
QuantKey,
@@ -71,6 +75,17 @@ def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
7175
self.config = c
7276
self.layer_param_names = layer_param_names
7377

78+
def input_quant_key(self) -> QuantKey | None:
    """Return the activation quant key this kernel accepts pre-quantized.

    Manual fusion consults this to decide whether activation quantization
    can be hoisted out of apply_weights into an upstream fused kernel.

    Returns:
        ``None`` when quantization has to happen inside the kernel (custom
        padding or swizzling, dynamic scales, etc.); this default always
        returns ``None``. A kernel that returns a key must consume the
        activation via as_quantized_activation.
    """
    return None
88+
7489
@abstractmethod
7590
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
7691
raise NotImplementedError
@@ -120,30 +135,30 @@ def _get_layer_params(self, layer) -> _FP8ParamsT:
120135
def apply_weights(
121136
self,
122137
layer: torch.nn.Module,
123-
x: torch.Tensor,
138+
x: torch.Tensor | QuantizedActivation,
124139
bias: torch.Tensor | None = None,
125140
) -> torch.Tensor:
126141
fp8_dtype = self.fp8_dtype
127142
maybe_out_dtype = self.config.out_dtype
128143
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
129144

130-
# ops.scaled_fp8_quant supports both dynamic and static quant.
131-
# If dynamic, layer.input_scale is None and x_s computed from x.
132-
# If static, layer.input_scale is scalar and x_s is input_scale.
133-
# View input as 2D matrix for fp8 methods
134-
x_2d = x.view(-1, x.shape[-1])
135-
output_shape = [*x.shape[:-1], w.shape[1]]
136-
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
145+
qa = as_quantized_activation(x, self.input_quant_key())
146+
if qa is not None:
147+
x_data, x_s = qa.data, qa.scale
148+
orig_shape, orig_dtype = qa.orig_shape, qa.orig_dtype
149+
assert x_data.dtype == fp8_dtype
150+
else:
151+
assert isinstance(x, torch.Tensor)
152+
x_data = x
153+
orig_shape, orig_dtype = x.shape, x.dtype
154+
155+
x_2d = x_data.view(-1, x_data.shape[-1])
156+
output_shape = [*orig_shape[:-1], w.shape[1]]
157+
out_dtype = orig_dtype if maybe_out_dtype is None else maybe_out_dtype
137158

138-
# If input not quantized
139-
# TODO(luka) remove this path if not used anymore
140159
x_2d_q = x_2d
141-
if x.dtype != fp8_dtype:
142-
x_2d_q, x_s = self.quant_fp8(
143-
x_2d,
144-
x_s,
145-
x_s_ub,
146-
)
160+
if qa is None:
161+
x_2d_q, x_s = self.quant_fp8(x_2d, x_s, x_s_ub)
147162
return self.apply_scaled_mm(
148163
A=x_2d_q,
149164
B=w,

vllm/model_executor/kernels/linear/scaled_mm/cutlass.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from vllm.model_executor.layers.quantization.utils import replace_parameter
1212
from vllm.model_executor.layers.quantization.utils.quant_utils import (
1313
GroupShape,
14+
QuantKey,
15+
kFp8StaticTensorSym,
1416
)
1517
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
1618
CUTLASS_BLOCK_FP8_SUPPORTED,
@@ -171,6 +173,13 @@ def is_supported(
171173
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
172174
return True, None
173175

176+
def input_quant_key(self) -> QuantKey | None:
    """Return kFp8StaticTensorSym when the config's activation key is that
    key, else None.

    Only static per-tensor activation quantization is supported for
    quantization performed outside the kernel.
    """
    is_static_per_tensor = self.config.activation_quant_key == kFp8StaticTensorSym
    return kFp8StaticTensorSym if is_static_per_tensor else None
182+
174183
@staticmethod
175184
def _pad_to_alignment(
176185
x: torch.Tensor, dim: int, alignment: int, value: float = 0.0

vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
)
1313
from vllm.model_executor.layers.quantization.utils.quant_utils import (
1414
GroupShape,
15+
QuantKey,
16+
kFp8StaticTensorSym,
1517
)
1618
from vllm.platforms import current_platform
1719
from vllm.utils.flashinfer import (
@@ -62,6 +64,11 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non
6264

6365
return True, None
6466

67+
def input_quant_key(self) -> QuantKey | None:
    """Return kFp8StaticTensorSym when the config's activation key is that
    key, else None."""
    is_static_per_tensor = self.config.activation_quant_key == kFp8StaticTensorSym
    return kFp8StaticTensorSym if is_static_per_tensor else None
71+
6572
def apply_scaled_mm(
6673
self,
6774
*,

0 commit comments

Comments
 (0)