Skip to content

Commit 8f53a20

Browse files
liwei109Li Wei
authored and committed
[FP8] support FP8 quantization
1 parent bf0482a commit 8f53a20

File tree

4 files changed

+89
-26
lines changed

4 files changed

+89
-26
lines changed

mppq/data.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ class DataType(IntEnum):
3434
INT64 = TensorProto.INT64
3535
UINT64 = TensorProto.UINT64
3636

37+
FP8_E4M3FN = TensorProto.FLOAT8E4M3FN
38+
FP8_E4M3FNUZ = TensorProto.FLOAT8E4M3FNUZ
3739
FP8_E5M2 = TensorProto.FLOAT8E5M2
40+
FP8_E5M2FNUZ = TensorProto.FLOAT8E5M2FNUZ
3841
BF16 = TensorProto.BFLOAT16
3942
FP16 = TensorProto.FLOAT16
4043
FP32 = TensorProto.FLOAT
@@ -89,7 +92,10 @@ def from_torch(cls, dtype: torch_type):
8992
torch.float16: DataType.FP16,
9093
torch.float32: DataType.FP32,
9194
torch.float64: DataType.FP64,
95+
torch.float8_e4m3fn: DataType.FP8_E4M3FN,
96+
torch.float8_e4m3fnuz: DataType.FP8_E4M3FNUZ,
9297
torch.float8_e5m2: DataType.FP8_E5M2,
98+
torch.float8_e5m2fnuz: DataType.FP8_E5M2FNUZ,
9399
}
94100
if dtype not in torch_converting_dict:
95101
raise TypeError(
@@ -129,7 +135,10 @@ def to_torch(cls, dtype) -> torch_type:
129135
DataType.FP16: torch.float16,
130136
DataType.FP32: torch.float32,
131137
DataType.FP64: torch.float64,
138+
DataType.FP8_E4M3FN: torch.float8_e4m3fn,
139+
DataType.FP8_E4M3FNUZ: torch.float8_e4m3fnuz,
132140
DataType.FP8_E5M2: torch.float8_e5m2,
141+
DataType.FP8_E5M2FNUZ: torch.float8_e5m2fnuz,
133142
}
134143
assert isinstance(dtype, int)
135144
return torch_converting_dict[DataType(dtype)]
@@ -211,6 +220,10 @@ def convert_any_to_numpy(x: Any, accept_none: bool = True) -> None | np.ndarray:
211220
raise ValueError("Trying to convert an empty value.")
212221
return x
213222
elif isinstance(x, torch.Tensor):
223+
if "float8" in str(x.dtype):
224+
return convert_any_to_numpy(
225+
x.cpu().to(torch.float32).numpy(), accept_none=accept_none
226+
)
214227
return convert_any_to_numpy(x.cpu().numpy(), accept_none=accept_none)
215228
elif isinstance(x, Number):
216229
return np.array([x])

mppq/frontend/onnx/openvino_exporter.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,18 @@ def insert_quantize_node(
5959
elif config.policy.has_property(QuantizationProperty.FLOATING):
6060
# Following code will export Linear Quantization Config
6161
# That is for FP32 -> FP8
62+
offset_dtype, value_type = self.infer_qtype(config)
6263
scale = convert_any_to_tensor(config.scale.clone(), dtype=torch.float32)
63-
offset = convert_any_to_tensor(config.offset.clone(), dtype=torch.float32)
64+
offset = convert_any_to_tensor(config.offset.clone(), dtype=offset_dtype)
6465

6566
created = graph.create_operation(
66-
op_type="QuantizeFloating",
67-
attributes={
68-
"min": config.quant_min,
69-
"max": config.quant_max,
70-
"exponent": config.exponent_bits,
71-
"mantissa": config.mantissa_bits,
72-
},
67+
op_type="QuantizeLinear",
68+
# attributes={
69+
# "min": config.quant_min,
70+
# "max": config.quant_max,
71+
# "exponent": config.exponent_bits,
72+
# "mantissa": config.mantissa_bits,
73+
# },
7374
)
7475

7576
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
@@ -89,10 +90,11 @@ def insert_quantize_node(
8990
graph.create_variable(
9091
name=None, value=scale, is_parameter=True, dest_ops=[created]
9192
)
92-
graph.create_variable(
93-
name=None, value=offset, is_parameter=True, dest_ops=[created]
94-
)
93+
# graph.create_variable(
94+
# name=None, value=offset, is_parameter=True, dest_ops=[created]
95+
# ) # zero_point is not used for FP8
9596

97+
created.outputs[0].dtype = value_type
9698
created.outputs[0].shape = var.shape
9799
created.inputs[0].shape = var.shape
98100
return created
@@ -149,17 +151,18 @@ def insert_dequantize_node(
149151
return created
150152

151153
elif config.policy.has_property(QuantizationProperty.FLOATING):
154+
offset_dtype, value_type = self.infer_qtype(config)
152155
scale = convert_any_to_tensor(config.scale.clone(), dtype=torch.float32)
153-
offset = convert_any_to_tensor(config.offset.clone(), dtype=torch.float32)
156+
offset = convert_any_to_tensor(config.offset.clone(), dtype=offset_dtype)
154157

155158
created = graph.create_operation(
156-
op_type="DequantizeFloating",
157-
attributes={
158-
"min": config.quant_min,
159-
"max": config.quant_max,
160-
"exponent": config.exponent_bits,
161-
"mantissa": config.mantissa_bits,
162-
},
159+
op_type="DequantizeLinear",
160+
# attributes={
161+
# "min": config.quant_min,
162+
# "max": config.quant_max,
163+
# "exponent": config.exponent_bits,
164+
# "mantissa": config.mantissa_bits,
165+
# },
163166
)
164167

165168
if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
@@ -179,12 +182,14 @@ def insert_dequantize_node(
179182
graph.create_variable(
180183
name=None, value=scale, is_parameter=True, dest_ops=[created]
181184
)
182-
graph.create_variable(
183-
name=None, value=offset, is_parameter=True, dest_ops=[created]
184-
)
185+
# graph.create_variable(
186+
# name=None, value=offset, is_parameter=True, dest_ops=[created]
187+
# )
185188

186-
created.outputs[0].shape = var.shape
187189
created.inputs[0].shape = var.shape
190+
created.inputs[0].dtype = value_type
191+
created.outputs[0].shape = var.shape
192+
created.outputs[0].dtype = torch.float32
188193

189194
return created
190195

mppq/utils/qfunction/__init__.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
QuantizationStates,
77
TensorQuantizationConfig,
88
)
9-
from mppq.utils.qfunction.floating import floating_quant
9+
from mppq.utils.qfunction.floating import floating_fake_quant, floating_quant_tofp8
1010
from mppq.utils.qfunction.linear import (
1111
dynamic_linear_quant,
1212
linear_fake_quant,
@@ -52,7 +52,9 @@ def ppq_fake_quant(
5252

5353
if config.policy.has_property(QuantizationProperty.FLOATING):
5454
if not config.policy.has_property(QuantizationProperty.DYNAMIC):
55-
return floating_quant(tensor, config)
55+
return floating_fake_quant(tensor, config)
56+
else:
57+
raise NotImplementedError("Dynamic floating quant is not support now!")
5658

5759
raise ValueError(
5860
"Unexpected Quantization Property Found in ppq_fake_quant. "
@@ -80,8 +82,29 @@ def ppq_quant_toint(
8082
)
8183

8284

85+
def ppq_quant_tofp8(
    tensor: torch.Tensor, config: TensorQuantizationConfig
) -> torch.Tensor:
    """Quantize ``tensor`` to a true FP8 tensor (quantize only, no dequantize).

    Unlike :func:`ppq_fake_quant`, which quantizes and immediately
    dequantizes back to the original floating dtype, this function returns
    the quantized values in an FP8 torch dtype.

    Only *static floating-point* quantization policies are supported; any
    other policy (including dynamic floating quantization) raises.

    Args:
        tensor: The tensor to quantize.
        config: Quantization configuration; must carry the
            ``QuantizationProperty.FLOATING`` property and must not be
            ``DYNAMIC``.

    Returns:
        The FP8-quantized tensor.

    Raises:
        ValueError: If ``config`` does not describe static floating
            quantization.
    """
    if config.policy.has_property(QuantizationProperty.FLOATING):
        if not config.policy.has_property(QuantizationProperty.DYNAMIC):
            return floating_quant_tofp8(tensor, config)

    raise ValueError(
        "Unexpected Quantization Property Found in ppq_quant_tofp8. "
        "Do not know how to quantize your config yet."
    )
103+
104+
83105
# Public API of this module. NOTE: the commit originally exported
# "ppq_quant_tofloat", but no such name exists in this module — the
# function defined above is ``ppq_quant_tofp8``; exporting the wrong name
# makes ``from mppq.utils.qfunction import *`` raise AttributeError.
__all__ = [
    "ppq_fake_quant",
    "ppq_quant_toint",
    "ppq_quant_tofp8",
    "BaseQuantFunction",
]

mppq/utils/qfunction/floating.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def backward(ctx, *dy: torch.Tensor):
107107
return dy[0], None, None, None, None, None, None, None, None, None
108108

109109

110-
def floating_quant(
110+
def floating_fake_quant(
111111
tensor: torch.Tensor, config: TensorQuantizationConfig
112112
) -> torch.Tensor:
113113
"""PPQ 核心量化函数,没啥好说的了吧,这个玩意既做 quant 也做 dequant"""
@@ -137,3 +137,25 @@ def floating_quant(
137137
)
138138
assert isinstance(qtensor, torch.Tensor)
139139
return qtensor
140+
141+
142+
def floating_quant_tofp8(
    tensor: torch.Tensor, config: TensorQuantizationConfig
) -> torch.Tensor:
    """Core FP8 quantization function: quantize only, no dequantize.

    Scales (and offsets) ``tensor`` according to ``config``, clamps the
    result to ``[config.quant_min, config.quant_max]`` and casts it to an
    FP8 torch dtype.

    Args:
        tensor: The tensor to quantize.
        config: Floating quantization configuration (per-tensor or
            per-channel).

    Returns:
        The tensor quantized to an FP8 dtype.

    Raises:
        ValueError: If ``config`` is not a floating quantization config.
    """
    if not config.policy.has_property(QuantizationProperty.FLOATING):
        raise ValueError("Critical Quantization Error! Non-floating config detected.")

    if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
        # Reshape scale/offset so they broadcast along channel_axis.
        shape = [
            1 if axis != config.channel_axis else -1 for axis in range(tensor.ndim)
        ]
        # Move *both* scale and offset to the tensor's device — the original
        # moved only offset, which fails for CUDA tensors with CPU scales.
        scale = config.scale.view(shape).to(tensor.device)
        offset = config.offset.view(shape).to(tensor.device)
        tensor = (tensor / scale) + offset
    else:  # QuantizationProperty.PER_TENSOR
        tensor = (tensor / config.scale.to(tensor.device)) + config.offset.to(
            tensor.device
        )
    tensor = torch.clamp(tensor, config.quant_min, config.quant_max)
    # Pick the FP8 dtype from the configured exponent width: 5 exponent
    # bits -> e5m2, otherwise e4m3fn. Defaults to e4m3fn (the original
    # hard-coded cast) for backward compatibility.
    # NOTE(review): assumes config exposes ``exponent_bits`` as used by the
    # exporter — confirm against TensorQuantizationConfig.
    exponent_bits = getattr(config, "exponent_bits", 4)
    fp8_dtype = torch.float8_e5m2 if exponent_bits == 5 else torch.float8_e4m3fn
    return tensor.to(fp8_dtype)

0 commit comments

Comments
 (0)