Skip to content

Commit 00296a5

Browse files
authored
[Feature] Support and update INT8/INT4 quantization inference (#299)
Signed-off-by: Li Wei <liwei.109@outlook.com>
1 parent f421ea0 commit 00296a5

20 files changed

Lines changed: 488 additions & 351 deletions

File tree

ci/scripts/env/install_env.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ docker exec "${DOCKER_NAME}" bash -lc "
5454
# Patch torch dynamo eval_frame
5555
cp vllm_kunlun/patches/eval_frame.py \
5656
/root/miniconda/envs/${CONDA_ENV}/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
57+
58+
# Patch quantization __init__.py
59+
cp vllm_kunlun/quantization/__init__.py \
60+
/root/miniconda/envs/${CONDA_ENV}/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/__init__.py
5761
5862
########################################
5963
# Kunlun runtime dependencies

docs/source/installation.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ Copy the eval_frame.py patch:
7676

7777
```
7878
cp vllm_kunlun/patches/eval_frame.py "${CONDA_PREFIX:-$VIRTUAL_ENV}"/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
79+
```
80+
81+
### Replace quantization __init__.py
82+
83+
```
84+
cp vllm_kunlun/quantization/__init__.py "${CONDA_PREFIX:-$VIRTUAL_ENV}"/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/__init__.py
7985
```
8086

8187
## Choose to download customized xpytorch

vllm_kunlun/ops/__init__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,3 @@
2929

3030
# TODO @xyDong0223 remove v0.16.0
3131
# import vllm_kunlun.ops.mla
32-
33-
# quantization
34-
# TODO @liwei109 enable quantization in v0.16.0
35-
# import vllm_kunlun.ops.quantization.awq
36-
# import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
37-
# import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
38-
# import vllm_kunlun.ops.quantization.gptq
39-
# import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
40-
# import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
41-
# import vllm_kunlun.ops.quantization.moe_wna16

vllm_kunlun/ops/linear.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,22 +336,47 @@ def _load_fused_module_from_checkpoint(
336336
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
337337

338338

339+
def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
340+
if loaded_shard_id is None:
341+
return
342+
if isinstance(loaded_shard_id, tuple):
343+
for idx in loaded_shard_id:
344+
if not (0 <= idx < len(self.output_sizes)):
345+
raise ValueError(
346+
f"Shard id index {idx} should be between 0 and "
347+
f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
348+
)
349+
if len(loaded_shard_id) > 1 and any(
350+
b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
351+
):
352+
raise ValueError(
353+
"Shard id with multiple indices should be consecutive. "
354+
f"Got shard id {loaded_shard_id}."
355+
)
356+
return
357+
elif isinstance(loaded_shard_id, int):
358+
if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
359+
raise ValueError(
360+
f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
361+
f"Got shard id {loaded_shard_id}."
362+
)
363+
return
364+
365+
339366
def weight_loader_v2(
340367
self,
341368
param: BasevLLMParameter,
342369
loaded_weight: torch.Tensor,
343370
loaded_shard_id: tuple[int, ...] | int | None = None,
344371
):
372+
self.validate_shard_id(loaded_shard_id)
345373
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
346374
if isinstance(param, PerTensorScaleParameter):
347375
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
348376
return
349377
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
350378
param.load_merged_column_weight(loaded_weight=loaded_weight)
351379
return
352-
# TODO: @dsikka - move to parameter.py
353-
self._load_fused_module_from_checkpoint(param, loaded_weight)
354-
return
355380
output_sizes = (
356381
[self.output_sizes[idx] for idx in loaded_shard_id]
357382
if loaded_shard_id
@@ -363,6 +388,7 @@ def weight_loader_v2(
363388
adjust_block_scale_shard(weight_block_size, size, 0)[0]
364389
for size in (output_sizes or self.output_sizes)
365390
]
391+
# TODO: @dsikka - move to parameter.py
366392
self._load_fused_module_from_checkpoint(
367393
param, loaded_weight, output_sizes=output_sizes
368394
)
@@ -394,6 +420,7 @@ def weight_loader_v2(
394420
MergedColumnParallelLinear._load_fused_module_from_checkpoint = (
395421
_load_fused_module_from_checkpoint
396422
)
423+
MergedColumnParallelLinear.validate_shard_id = validate_shard_id
397424
MergedColumnParallelLinear.weight_loader_v2 = weight_loader_v2
398425

399426

vllm_kunlun/ops/quantization/__init__.py

Whitespace-only changes.

vllm_kunlun/ops/quantization/compressed_tensors/__init__.py

Whitespace-only changes.

vllm_kunlun/ops/quantization/kernels/__init__.py

Whitespace-only changes.

vllm_kunlun/platforms/kunlun.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import vllm.envs as envs
88
from vllm.logger import init_logger
99
from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum
10+
from vllm.utils.argparse_utils import FlexibleArgumentParser
1011
from vllm.v1.attention.backends.registry import AttentionBackendEnum
1112

1213
if TYPE_CHECKING:
@@ -375,3 +376,15 @@ def support_hybrid_kv_cache(cls) -> bool:
375376
@classmethod
def support_static_graph_mode(cls) -> bool:
    """Return ``True`` unconditionally.

    NOTE(review): by its name this presumably tells vLLM that the platform
    supports static graph mode -- confirm against the ``Platform`` interface.
    """
    return True
379+
380+
@classmethod
def pre_register_and_update(
    cls, parser: FlexibleArgumentParser | None = None
) -> None:
    """Import the Kunlun quantization modules.

    None of the imported names are used in this method and ``parser`` is
    ignored. NOTE(review): the imports are presumably performed for their
    import-time side effects (e.g. registering the Kunlun quantization
    configs and kernel lists) -- confirm against ``vllm_kunlun.quantization``.
    """
    from vllm_kunlun.quantization.awq import KunlunAWQConfig  # noqa
    from vllm_kunlun.quantization.compressed_tensors import (  # noqa
        KunlunCompressedTensorsConfig,
    )
    from vllm_kunlun.quantization.gptq import KunlunGPTQConfig  # noqa
    from vllm_kunlun.quantization.kernels import (  # noqa
        _POSSIBLE_INT8_KERNELS,
        _POSSIBLE_KERNELS,
    )
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# patched by vLLM-Kunlun
4+
5+
from typing import Literal, get_args
6+
7+
from vllm.logger import init_logger
8+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
9+
from vllm.platforms import current_platform
10+
11+
logger = init_logger(__name__)
12+
13+
# Names of the quantization methods known to this module. The "fp_quant"
# entry is commented out, matching the commented-out FPQuantConfig import in
# get_quantization_config.
QuantizationMethods = Literal[
    "awq",
    "fp8",
    "ptpc_fp8",
    "fbgemm_fp8",
    # "fp_quant",
    "modelopt",
    "modelopt_fp4",
    "bitblas",
    "gguf",
    "gptq_marlin_24",
    "gptq_marlin",
    "gptq_bitblas",
    "awq_marlin",
    "gptq",
    "compressed-tensors",
    "bitsandbytes",
    "experts_int8",
    "ipex",
    "quark",
    "moe_wna16",
    "torchao",
    "inc",
    "mxfp4",
    "petit_nvfp4",
    "cpu_awq",
]
# Mutable runtime copy of the names above; register_quantization_config
# appends newly registered method names to it.
QUANTIZATION_METHODS: list[str] = [*get_args(QuantizationMethods)]

# Method names listed as deprecated (how this list is consumed is not
# visible in this file).
DEPRECATED_QUANTIZATION_METHODS: list[str] = [
    "tpu_int8",
    "ptpc_fp8",
    "fbgemm_fp8",
    # "fp_quant",
    "bitblas",
    "gptq_marlin_24",
    "gptq_bitblas",
    "experts_int8",
    "ipex",
    "petit_nvfp4",
]

# Customized quantization methods are added to this dict, keyed by method
# name, by register_quantization_config.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG: "dict[str, type[QuantizationConfig]]" = {}
57+
58+
59+
def register_quantization_config(quantization: str):
    """Register a customized vllm quantization config.

    When a quantization method is not supported by vllm, you can register a customized
    quantization config to support it.

    Args:
        quantization (str): The quantization method name.

    Returns:
        A class decorator that records the decorated class under
        ``quantization`` and returns the class unchanged.

    Raises:
        ValueError: Raised by the decorator when the decorated class is not a
            subclass of ``QuantizationConfig``. Nothing is registered in that
            case.

    Examples:
        >>> from vllm.model_executor.layers.quantization import (
        ...     register_quantization_config,
        ... )
        >>> from vllm.model_executor.layers.quantization import get_quantization_config
        >>> from vllm.model_executor.layers.quantization.base_config import (
        ...     QuantizationConfig,
        ... )
        >>>
        >>> @register_quantization_config("my_quant")
        ... class MyQuantConfig(QuantizationConfig):
        ...     pass
        >>>
        >>> get_quantization_config("my_quant")
        <class 'MyQuantConfig'>
    """  # noqa: E501

    def _wrapper(quant_config_cls):
        # Validate before touching any registry so that a rejected class does
        # not leave the method name listed without a matching config.
        if not issubclass(quant_config_cls, QuantizationConfig):
            raise ValueError(
                "The quantization config must be a subclass of `QuantizationConfig`."
            )

        if quantization in QUANTIZATION_METHODS:
            logger.warning(
                "The quantization method '%s' already exists and will be "
                "overwritten by the quantization config %s.",
                quantization,
                quant_config_cls,
            )
        else:
            QUANTIZATION_METHODS.append(quantization)
            # Automatically assume the custom quantization config is supported
            if sq := current_platform.supported_quantization:
                sq.append(quantization)

        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
        return quant_config_cls

    return _wrapper
107+
108+
109+
def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
    """Return the quantization config class registered under ``quantization``.

    Built-in config classes are imported inside the function; entries from
    ``_CUSTOMIZED_METHOD_TO_QUANT_CONFIG`` take precedence over built-in
    entries with the same name.

    Args:
        quantization: The quantization method name.

    Raises:
        ValueError: If ``quantization`` is not in ``QUANTIZATION_METHODS``.
    """
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")

    # lazy import to avoid triggering `torch.compile` too early
    from vllm.model_executor.layers.quantization.awq import AWQConfig
    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
    from vllm.model_executor.layers.quantization.bitblas import BitBLASConfig
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
        CompressedTensorsConfig,
    )
    from vllm.model_executor.layers.quantization.cpu_wna16 import CPUAWQConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    # from vllm.model_executor.layers.quantization.fp_quant import FPQuantConfig
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq import GPTQConfig
    from vllm.model_executor.layers.quantization.gptq_bitblas import GPTQBitBLASConfig
    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
    from vllm.model_executor.layers.quantization.inc import INCConfig
    from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
    from vllm.model_executor.layers.quantization.modelopt import (
        ModelOptFp8Config,
        ModelOptNvFp4Config,
    )
    from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
    from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
    from vllm.model_executor.layers.quantization.petit import PetitNvFp4Config
    from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8Config
    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
    from vllm.model_executor.layers.quantization.torchao import TorchAOConfig

    builtin_configs: dict[str, type[QuantizationConfig]] = {
        "awq": AWQConfig,
        "fp8": Fp8Config,
        "fbgemm_fp8": FBGEMMFp8Config,
        # "fp_quant": FPQuantConfig,
        "modelopt": ModelOptFp8Config,
        "modelopt_fp4": ModelOptNvFp4Config,
        "bitblas": BitBLASConfig,
        "gguf": GGUFConfig,
        "gptq_marlin_24": GPTQMarlin24Config,
        "gptq_marlin": GPTQMarlinConfig,
        "gptq_bitblas": GPTQBitBLASConfig,
        "awq_marlin": AWQMarlinConfig,
        "gptq": GPTQConfig,
        "compressed-tensors": CompressedTensorsConfig,
        "bitsandbytes": BitsAndBytesConfig,
        "ptpc_fp8": PTPCFp8Config,
        "experts_int8": ExpertsInt8Config,
        "ipex": IPEXConfig,
        "quark": QuarkConfig,
        "moe_wna16": MoeWNA16Config,
        "torchao": TorchAOConfig,
        "auto-round": INCConfig,
        "inc": INCConfig,
        "mxfp4": Mxfp4Config,
        "petit_nvfp4": PetitNvFp4Config,
        "cpu_awq": CPUAWQConfig,
    }

    # Customized registrations are unpacked last so they win over built-ins.
    resolved_configs: dict[str, type[QuantizationConfig]] = {
        **builtin_configs,
        **_CUSTOMIZED_METHOD_TO_QUANT_CONFIG,
    }
    return resolved_configs[quantization]
179+
180+
181+
__all__ = [
182+
"QuantizationConfig",
183+
"QuantizationMethods",
184+
"get_quantization_config",
185+
"register_quantization_config",
186+
"QUANTIZATION_METHODS",
187+
]

0 commit comments

Comments
 (0)