1515 QuantVisibility ,
1616 TensorQuantizationConfig ,
1717)
18- from mppq .utils .qfunction import ppq_quant_toint
18+ from mppq .utils .qfunction import ppq_quant_toint , ppq_quant_tofloat
1919from mppq .utils .round import ppq_tensor_round
2020
2121
@@ -91,6 +91,12 @@ def infer_qtype(self, config: TensorQuantizationConfig):
9191 if config .num_of_bits > 8 :
9292 offset_dtype = torch .int32
9393 value_dtype = torch .int32
94+ if config .exponent_bits == 4 :
95+ offset_dtype = torch .float8_e4m3fn
96+ value_dtype = torch .float8_e4m3fn
97+ if config .exponent_bits == 5 :
98+ offset_dtype = torch .float8_e5m2
99+ value_dtype = torch .float8_e5m2
94100 return offset_dtype , value_dtype
95101
96102 def insert_quantize_node (
@@ -141,17 +147,18 @@ def insert_quantize_node(
141147 elif config .policy .has_property (QuantizationProperty .FLOATING ):
142148 # Following code will export Linear Quantization Config
143149 # That is for FP32 -> FP8
150+ offset_dtype , value_type = self .infer_qtype (config )
144151 scale = convert_any_to_tensor (config .scale .clone (), dtype = torch .float32 )
145- offset = convert_any_to_tensor (config .offset .clone (), dtype = torch . float32 )
152+ offset = convert_any_to_tensor (config .offset .clone (), dtype = offset_dtype )
146153
147154 created = graph .create_operation (
148- op_type = "QuantizeFloating " ,
149- attributes = {
150- "min" : config .quant_min ,
151- "max" : config .quant_max ,
152- "exponent" : config .exponent_bits ,
153- "mantissa" : config .mantissa_bits ,
154- },
155+ op_type = "QuantizeLinear " ,
156+ # attributes={
157+ # "min": config.quant_min,
158+ # "max": config.quant_max,
159+ # "exponent": config.exponent_bits,
160+ # "mantissa": config.mantissa_bits,
161+ # },
155162 )
156163
157164 if config .policy .has_property (QuantizationProperty .PER_CHANNEL ):
@@ -171,10 +178,11 @@ def insert_quantize_node(
171178 graph .create_variable (
172179 name = None , value = scale , is_parameter = True , dest_ops = [created ]
173180 )
174- graph .create_variable (
175- name = None , value = offset , is_parameter = True , dest_ops = [created ]
176- )
181+ # graph.create_variable(
182+ # name=None, value=offset, is_parameter=True, dest_ops=[created]
183+ # ) # zero_point is not used for FP8
177184
185+ created .outputs [0 ].dtype = value_type
178186 created .outputs [0 ].shape = var .shape
179187 created .inputs [0 ].shape = var .shape
180188 return created
@@ -231,17 +239,18 @@ def insert_dequantize_node(
231239 return created
232240
233241 elif config .policy .has_property (QuantizationProperty .FLOATING ):
242+ offset_dtype , value_type = self .infer_qtype (config )
234243 scale = convert_any_to_tensor (config .scale .clone (), dtype = torch .float32 )
235- offset = convert_any_to_tensor (config .offset .clone (), dtype = torch . float32 )
244+ offset = convert_any_to_tensor (config .offset .clone (), dtype = offset_dtype )
236245
237246 created = graph .create_operation (
238- op_type = "DequantizeFloating " ,
239- attributes = {
240- "min" : config .quant_min ,
241- "max" : config .quant_max ,
242- "exponent" : config .exponent_bits ,
243- "mantissa" : config .mantissa_bits ,
244- },
247+ op_type = "DequantizeLinear " ,
248+ # attributes={
249+ # "min": config.quant_min,
250+ # "max": config.quant_max,
251+ # "exponent": config.exponent_bits,
252+ # "mantissa": config.mantissa_bits,
253+ # },
245254 )
246255
247256 if config .policy .has_property (QuantizationProperty .PER_CHANNEL ):
@@ -261,12 +270,14 @@ def insert_dequantize_node(
261270 graph .create_variable (
262271 name = None , value = scale , is_parameter = True , dest_ops = [created ]
263272 )
264- graph .create_variable (
265- name = None , value = offset , is_parameter = True , dest_ops = [created ]
266- )
273+ # graph.create_variable(
274+ # name=None, value=offset, is_parameter=True, dest_ops=[created]
275+ # )
267276
268- created .outputs [0 ].shape = var .shape
269277 created .inputs [0 ].shape = var .shape
278+ created .inputs [0 ].dtype = value_type
279+ created .outputs [0 ].shape = var .shape
280+ created .outputs [0 ].dtype = torch .float32
270281
271282 return created
272283 else :
@@ -468,6 +479,11 @@ def convert_operation(
468479 ):
469480 var .value = ppq_quant_toint (tensor = var .value , config = config )
470481
482+ if quantized_param and config .policy .has_property (
483+ QuantizationProperty .FLOATING
484+ ):
485+ var .value = ppq_quant_tofloat (tensor = var .value , config = config )
486+
471487 elif not var .is_parameter :
472488
473489 # Patch 20230103: