-
Notifications
You must be signed in to change notification settings - Fork 723
Expand file tree
/
Copy path__init__.py
More file actions
217 lines (194 loc) · 8.71 KB
/
__init__.py
File metadata and controls
217 lines (194 loc) · 8.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
quantization module
"""
from typing import List, Type
from fastdeploy import envs
from fastdeploy.utils import parse_quantization
from .quant_base import QuantConfigBase
# Registry of quantization method names accepted by get_quantization_config().
# Keep this list in sync with the method_to_config mapping built there.
# Note: "mxfp4" is listed unconditionally but only resolves to a config class
# when the FD_MOE_MXFP4_BACKEND environment setting is present.
QUANTIZATION_METHODS: List[str] = [
    "wint2",
    "wint4",
    "wint8",
    "weight_only",
    "block_wise_fp8",
    "w4afp8",
    "w8a8",
    "w4a8",
    "wfp8afp8",
    "mix_quant",
    "tensor_wise_fp8",
    "kvcache",
    "modelopt_fp4",
    "mxfp4",
]
def _compute_hadamard_block_size(moe_intermediate_size: int, tp_size: int) -> int:
if moe_intermediate_size % tp_size != 0:
raise ValueError(
f"moe_intermediate_size ({moe_intermediate_size}) must be divisible by " f"tp_size ({tp_size})"
)
shard_size = moe_intermediate_size // tp_size
block_size = shard_size & (-shard_size)
block_size = min(block_size, 512)
return block_size
def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
    """Resolve the quantization config for a model, if any.

    Combines two sources, in priority order: the checkpoint's own offline
    ``model_config.quantization_config``, then the online ``--quantization``
    engine argument. As a side effect, updates the
    ``model_config.is_quantized`` / ``model_config.is_moe_quantized`` flags.

    Args:
        args: engine args; reads ``quantization``, ``dynamic_load_weight``
            and optionally ``tensor_parallel_size``.
        model_config: model config; reads ``model_format``,
            ``quantization_config``, optionally ``moe_intermediate_size``,
            and mutates its quantization flags.
        is_ernie: whether this is an Ernie model ("wint4" is then expanded
            to a wint8-dense / wint4-moe mix_quant).
        is_v1_loader: forwarded to offline quant-config-name resolution.

    Returns:
        A quant config instance built via ``from_config``, or ``None`` when
        no quantization is configured.

    Raises:
        ValueError: if a non-torch offline quantization_config lacks the
            'quantization' key, or an unsupported method is named.
    """
    # The CLI flag may arrive as a plain string; normalize it to a dict.
    if args.quantization is not None and isinstance(args.quantization, str):
        args.quantization = parse_quantization(args.quantization)

    # 1.model_config.is_quantized
    # TODO(bukejiyu) model_config.is_quantized is v0 only need to be removed in future
    if model_config.model_format == "torch":
        quantization_config = model_config.quantization_config
        # A torch-format checkpoint that ships a quant config is offline-quantized.
        if quantization_config is not None:
            model_config.is_quantized = True
    else:
        quantization_config = model_config.quantization_config
        if not model_config.is_quantized:
            if quantization_config is not None:
                if "is_quantized" in quantization_config:
                    model_config.is_quantized = quantization_config["is_quantized"]
                elif "is_moe_quantized" in quantization_config:
                    model_config.is_moe_quantized = quantization_config["is_moe_quantized"]
                elif "kv_cache_quant_type" not in quantization_config:
                    # No explicit flags and not a KV-cache-only config:
                    # treat the checkpoint as offline weight-quantized.
                    model_config.is_quantized = True
                    if "is_moe_quantized" not in quantization_config:
                        model_config.is_quantized = True
                    else:
                        # NOTE(review): unreachable given the elif above already
                        # handled "is_moe_quantized" — confirm intent upstream.
                        model_config.is_moe_quantized = True
        # Non-torch offline configs must name their method explicitly.
        # NOTE(review): this check sits under the non-torch branch — at top
        # level it would reject torch "quant_method"-style configs; confirm.
        if quantization_config is not None and quantization_config.get("quantization", None) is None:
            raise ValueError(
                "quantization_config should have a key named 'quantization' for specify quant config."
            )

    quant_config_name = None
    if quantization_config is not None:
        # Offline quantization: the checkpoint dictates the method.
        quant_config_name = _get_offline_quant_config_name(
            quantization_config, model_config.model_format == "torch", is_v1_loader
        )
    elif args.quantization is not None:
        # Online quantization requested on the command line.
        quantization_config = {}
        try:
            quantization_config.update(args.quantization)
            quant_config_name = quantization_config["quantization"]
        except Exception:
            # args.quantization is not a plain mapping (or lacks the key after
            # the merge); fall back to direct item access. Was a bare `except:`
            # before — narrowed so KeyboardInterrupt/SystemExit propagate.
            quant_config_name = args.quantization["quantization"]
            quantization_config["quantization"] = quant_config_name
        # Special handling for Ernie models
        if quant_config_name == "wint4" and is_ernie:
            quantization_config["dense_quant_type"] = "wint8"
            quantization_config["moe_quant_type"] = "wint4"
            quantization_config["quantization"] = "mix_quant"
            quant_config_name = "mix_quant"
        # Special handling for moe w4afp8 dynamic quant
        elif quant_config_name == "w4afp8":
            quantization_config["dense_quant_type"] = "block_wise_fp8"
            quantization_config["moe_quant_type"] = "w4afp8"
            tp_size = getattr(args, "tensor_parallel_size", 1)
            moe_intermediate_size = getattr(model_config, "moe_intermediate_size", None)
            if moe_intermediate_size is not None:
                hadamard_block_size = _compute_hadamard_block_size(moe_intermediate_size, tp_size)
                quantization_config["hadamard_block_size"] = hadamard_block_size
            else:
                # Unknown MoE intermediate size: fall back to the maximum block size.
                quantization_config["hadamard_block_size"] = 512
            quantization_config["quantization"] = "mix_quant"
            quant_config_name = "mix_quant"

    if quant_config_name is None:
        quant_config = None
    else:
        # Handle both dict and QuantizationConfig object
        if hasattr(quantization_config, "to_dict"):
            quantization_config_dict = quantization_config.to_dict()
        else:
            quantization_config_dict = quantization_config if isinstance(quantization_config, dict) else {}
        if not quantization_config_dict.get("is_quantized"):
            quantization_config_dict["is_quantized"] = model_config.is_quantized
        # Dynamic weight loading implies the weights are already quantized.
        if args.dynamic_load_weight and quantization_config is not None:
            quantization_config_dict["is_quantized"] = True
        quant_cls = get_quantization_config(quant_config_name)
        quant_config = quant_cls.from_config(quantization_config_dict)
    return quant_config
def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_loader):
if is_torch_weight:
# only support block_wise_fp8 now
# Handle both dict and QuantizationConfig object
if hasattr(quantization_config, "quant_method"):
quant_method = quantization_config.quant_method
else:
quant_method = quantization_config.get("quant_method")
has_block_size = (
"weight_block_size" in quantization_config
if isinstance(quantization_config, dict)
else hasattr(quantization_config, "weight_block_size")
and quantization_config.weight_block_size is not None
)
if quant_method == "fp8" and has_block_size:
quant_config_name = "block_wise_fp8"
elif quant_method == "modelopt":
# Try to get quant_algo from dict or from to_dict() method
quant_algo = None
if isinstance(quantization_config, dict):
quant_algo = quantization_config.get("quant_algo", "")
elif hasattr(quantization_config, "to_dict"):
quant_algo = quantization_config.to_dict().get("quant_algo", "")
if quant_algo == "NVFP4":
quant_config_name = "modelopt_fp4"
else:
raise ValueError(f"modelopt only supports NVFP4 quantization, got quant_algo={quant_algo}")
elif quant_method == "mxfp4":
quant_config_name = "mxfp4"
else:
raise ValueError(
f"Torch weight offline quantization only supports block-wise FP8, modelopt NVFP4, or mxfp4. Got quant_method={quant_method}"
)
else:
quant_config_name = quantization_config["quantization"]
return quant_config_name
def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
    """
    Get the quantization config class by the quantization name.

    Args:
        quantization: one of the names in QUANTIZATION_METHODS.

    Returns:
        The QuantConfigBase subclass implementing the requested method.

    Raises:
        ValueError: if the name is not in QUANTIZATION_METHODS, or if
            "mxfp4" is requested while FD_MOE_MXFP4_BACKEND is unset
            (previously this surfaced as an opaque KeyError).
    """
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    # Function-local imports — presumably to avoid import cycles at package
    # load time; confirm before hoisting them to module level.
    from .block_wise_fp8 import BlockWiseFP8Config
    from .kv_cache import KvCacheQuantConfig
    from .mix_quant import MixQuantConfig
    from .nvfp4 import ModelOptNvFp4Config
    from .tensor_wise_fp8 import TensorWiseFP8Config
    from .w4a8 import W4A8Config
    from .w4afp8 import W4AFP8Config
    from .w8a8 import W8A8Config
    from .weight_only import WeightOnlyConfig, WINT4Config, WINT8Config
    from .wfp8afp8 import WFP8AFP8Config
    from .wint2 import WINT2Config

    method_to_config = {
        "wint2": WINT2Config,
        "wint4": WINT4Config,
        "wint8": WINT8Config,
        "weight_only": WeightOnlyConfig,
        "block_wise_fp8": BlockWiseFP8Config,
        "w4afp8": W4AFP8Config,
        "w8a8": W8A8Config,
        "w4a8": W4A8Config,
        "wfp8afp8": WFP8AFP8Config,
        "tensor_wise_fp8": TensorWiseFP8Config,
        "kvcache": KvCacheQuantConfig,
        "mix_quant": MixQuantConfig,
        "modelopt_fp4": ModelOptNvFp4Config,
    }
    # mxfp4 is gated on the MoE MXFP4 backend being configured; do the env
    # check once (the original checked it twice) and fail with a clear
    # message instead of a KeyError when the backend is missing.
    if envs.FD_MOE_MXFP4_BACKEND is not None:
        from .mxfp4 import MXFP4Config

        method_to_config["mxfp4"] = MXFP4Config
    elif quantization == "mxfp4":
        raise ValueError(
            "Quantization method 'mxfp4' requires the FD_MOE_MXFP4_BACKEND environment variable to be set."
        )
    return method_to_config[quantization]