Skip to content

Commit 7858493

Browse files
author
Li Wei
committed
[FP8] support FP8 quantization
1 parent 33b2b59 commit 7858493

File tree

5 files changed

+98
-26
lines changed

5 files changed

+98
-26
lines changed

mppq/data.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ class DataType(IntEnum):
3434
INT64 = TensorProto.INT64
3535
UINT64 = TensorProto.UINT64
3636

37+
FP8_E4M3FN = TensorProto.FLOAT8E4M3FN
38+
FP8_E4M3FNUZ = TensorProto.FLOAT8E4M3FNUZ
3739
FP8_E5M2 = TensorProto.FLOAT8E5M2
40+
FP8_E5M2FNUZ = TensorProto.FLOAT8E5M2FNUZ
3841
BF16 = TensorProto.BFLOAT16
3942
FP16 = TensorProto.FLOAT16
4043
FP32 = TensorProto.FLOAT
@@ -89,7 +92,10 @@ def from_torch(cls, dtype: torch_type):
8992
torch.float16: DataType.FP16,
9093
torch.float32: DataType.FP32,
9194
torch.float64: DataType.FP64,
95+
torch.float8_e4m3fn: DataType.FP8_E4M3FN,
96+
torch.float8_e4m3fnuz: DataType.FP8_E4M3FNUZ,
9297
torch.float8_e5m2: DataType.FP8_E5M2,
98+
torch.float8_e5m2fnuz: DataType.FP8_E5M2FNUZ,
9399
}
94100
if dtype not in torch_converting_dict:
95101
raise TypeError(
@@ -129,7 +135,10 @@ def to_torch(cls, dtype) -> torch_type:
129135
DataType.FP16: torch.float16,
130136
DataType.FP32: torch.float32,
131137
DataType.FP64: torch.float64,
138+
DataType.FP8_E4M3FN: torch.float8_e4m3fn,
139+
DataType.FP8_E4M3FNUZ: torch.float8_e4m3fnuz,
132140
DataType.FP8_E5M2: torch.float8_e5m2,
141+
DataType.FP8_E5M2FNUZ: torch.float8_e5m2fnuz,
133142
}
134143
assert isinstance(dtype, int)
135144
return torch_converting_dict[DataType(dtype)]
@@ -211,6 +220,10 @@ def convert_any_to_numpy(x: Any, accept_none: bool = True) -> None | np.ndarray:
211220
raise ValueError("Trying to convert an empty value.")
212221
return x
213222
elif isinstance(x, torch.Tensor):
223+
if "float8" in str(x.dtype):
224+
return convert_any_to_numpy(
225+
x.cpu().to(torch.float32).numpy(), accept_none=accept_none
226+
)
214227
return convert_any_to_numpy(x.cpu().numpy(), accept_none=accept_none)
215228
elif isinstance(x, Number):
216229
return np.array([x])

mppq/ffi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def dummy_locator():
152152
class ENABLE_CUDA_KERNEL:
153153
"""Auto config compiler path before entering compiling CUDA context"""
154154

155-
USING_CUDA_KERNEL = False
155+
USING_CUDA_KERNEL = True
156156

157157
def __init__(self) -> None:
158158
self._state = True

mppq/frontend/onnx/onnxruntime_exporter.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
QuantVisibility,
1616
TensorQuantizationConfig,
1717
)
18-
from mppq.utils.qfunction import ppq_quant_toint
18+
from mppq.utils.qfunction import ppq_quant_toint, ppq_quant_tofloat
1919
from mppq.utils.round import ppq_tensor_round
2020

2121

@@ -91,6 +91,12 @@ def infer_qtype(self, config: TensorQuantizationConfig):
9191
if config.num_of_bits > 8:
9292
offset_dtype = torch.int32
9393
value_dtype = torch.int32
94+
if config.exponent_bits == 4:
95+
offset_dtype = torch.float8_e4m3fn
96+
value_dtype = torch.float8_e4m3fn
97+
if config.exponent_bits == 5:
98+
offset_dtype = torch.float8_e5m2
99+
value_dtype = torch.float8_e5m2
94100
return offset_dtype, value_dtype
95101

96102
def insert_quantize_node(
@@ -141,17 +147,18 @@ def insert_quantize_node(
141147
elif config.policy.has_property(QuantizationProperty.FLOATING):
142148
# Following code will export Linear Quantization Config
143149
# That is for FP32 -> FP8
150+
offset_dtype, value_type = self.infer_qtype(config)
144151
scale = convert_any_to_tensor(config.scale.clone(), dtype=torch.float32)
145-
offset = convert_any_to_tensor(config.offset.clone(), dtype=torch.float32)
152+
offset = convert_any_to_tensor(config.offset.clone(), dtype=offset_dtype)
146153

147154
created = graph.create_operation(
148-
op_type="QuantizeFloating",
149-
attributes={
150-
"min": config.quant_min,
151-
"max": config.quant_max,
152-
"exponent": config.exponent_bits,
153-
"mantissa": config.mantissa_bits,
154-
},
155+
op_type="QuantizeLinear",
156+
# attributes={
157+
# "min": config.quant_min,
158+
# "max": config.quant_max,
159+
# "exponent": config.exponent_bits,
160+
# "mantissa": config.mantissa_bits,
161+
# },
155162
)
156163

157164
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
@@ -171,10 +178,11 @@ def insert_quantize_node(
171178
graph.create_variable(
172179
name=None, value=scale, is_parameter=True, dest_ops=[created]
173180
)
174-
graph.create_variable(
175-
name=None, value=offset, is_parameter=True, dest_ops=[created]
176-
)
181+
# graph.create_variable(
182+
# name=None, value=offset, is_parameter=True, dest_ops=[created]
183+
# ) # zero_point is not used for FP8
177184

185+
created.outputs[0].dtype = value_type
178186
created.outputs[0].shape = var.shape
179187
created.inputs[0].shape = var.shape
180188
return created
@@ -231,17 +239,18 @@ def insert_dequantize_node(
231239
return created
232240

233241
elif config.policy.has_property(QuantizationProperty.FLOATING):
242+
offset_dtype, value_type = self.infer_qtype(config)
234243
scale = convert_any_to_tensor(config.scale.clone(), dtype=torch.float32)
235-
offset = convert_any_to_tensor(config.offset.clone(), dtype=torch.float32)
244+
offset = convert_any_to_tensor(config.offset.clone(), dtype=offset_dtype)
236245

237246
created = graph.create_operation(
238-
op_type="DequantizeFloating",
239-
attributes={
240-
"min": config.quant_min,
241-
"max": config.quant_max,
242-
"exponent": config.exponent_bits,
243-
"mantissa": config.mantissa_bits,
244-
},
247+
op_type="DequantizeLinear",
248+
# attributes={
249+
# "min": config.quant_min,
250+
# "max": config.quant_max,
251+
# "exponent": config.exponent_bits,
252+
# "mantissa": config.mantissa_bits,
253+
# },
245254
)
246255

247256
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
@@ -261,12 +270,14 @@ def insert_dequantize_node(
261270
graph.create_variable(
262271
name=None, value=scale, is_parameter=True, dest_ops=[created]
263272
)
264-
graph.create_variable(
265-
name=None, value=offset, is_parameter=True, dest_ops=[created]
266-
)
273+
# graph.create_variable(
274+
# name=None, value=offset, is_parameter=True, dest_ops=[created]
275+
# )
267276

268-
created.outputs[0].shape = var.shape
269277
created.inputs[0].shape = var.shape
278+
created.inputs[0].dtype = value_type
279+
created.outputs[0].shape = var.shape
280+
created.outputs[0].dtype = torch.float32
270281

271282
return created
272283
else:
@@ -468,6 +479,11 @@ def convert_operation(
468479
):
469480
var.value = ppq_quant_toint(tensor=var.value, config=config)
470481

482+
if quantized_param and config.policy.has_property(
483+
QuantizationProperty.FLOATING
484+
):
485+
var.value = ppq_quant_tofloat(tensor=var.value, config=config)
486+
471487
elif not var.is_parameter:
472488

473489
# Patch 20230103:

mppq/utils/qfunction/__init__.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
QuantizationStates,
77
TensorQuantizationConfig,
88
)
9-
from mppq.utils.qfunction.floating import floating_quant
9+
from mppq.utils.qfunction.floating import floating_quant, floating_quant_tofloat
1010
from mppq.utils.qfunction.linear import (
1111
dynamic_linear_quant,
1212
linear_fake_quant,
@@ -80,8 +80,29 @@ def ppq_quant_toint(
8080
)
8181

8282

83+
def ppq_quant_tofloat(
84+
tensor: torch.Tensor, config: TensorQuantizationConfig
85+
) -> torch.Tensor:
86+
"""
87+
## PPQ 核心量化函数
88+
89+
根据 config 中描述的策略,这个函数将会执行线性量化,动态量化
90+
91+
但是结果直接出来是float8类型
92+
"""
93+
if config.policy.has_property(QuantizationProperty.FLOATING):
94+
if not config.policy.has_property(QuantizationProperty.DYNAMIC):
95+
return floating_quant_tofloat(tensor, config)
96+
97+
raise ValueError(
98+
"Unexpected Quantization Property Found in ppq_quant_tofp8. "
99+
"Do not know how to quantize your config yet."
100+
)
101+
102+
83103
__all__ = [
84104
"ppq_fake_quant",
85105
"ppq_quant_toint",
106+
"ppq_quant_tofloat",
86107
"BaseQuantFunction",
87108
]

mppq/utils/qfunction/floating.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,25 @@ def floating_quant(
137137
)
138138
assert isinstance(qtensor, torch.Tensor)
139139
return qtensor
140+
141+
142+
def floating_quant_tofloat(
143+
tensor: torch.Tensor, config: TensorQuantizationConfig
144+
) -> torch.Tensor:
145+
"""PPQ 核心量化函数,没啥好说的了吧,这个玩意只做 quant 不做 dequant"""
146+
if not config.policy.has_property(QuantizationProperty.FLOATING):
147+
raise ValueError("Critical Quantization Error! Non-floating config detected.")
148+
149+
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
150+
shape = [
151+
1 if axis != config.channel_axis else -1 for axis in range(tensor.ndim)
152+
]
153+
scale = config.scale.view(shape)
154+
offset = config.offset.view(shape).to(tensor.device)
155+
tensor = (tensor / scale) + offset
156+
else: # QuantizationProperty.PER_TENSOR
157+
tensor = (tensor / config.scale.to(tensor.device)) + config.offset.to(
158+
tensor.device
159+
)
160+
tensor = torch.clamp(tensor, config.quant_min, config.quant_max)
161+
return tensor.to(torch.float8_e4m3fn)

0 commit comments

Comments (0)