 import torch.nn as nn
 
 import brevitas.config as config
+from brevitas.core.function_wrapper.shape import DynamicOverSubChannelBlockView
 from brevitas.function import compute_max_mantissa
 from brevitas.function.ops import max_float
 from brevitas.function.ops import max_int
@@ -155,28 +156,18 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         return output_dict
 
 
-class IntInferencetHandler(InferenceHandler, FloatToIntMixin):
-    handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)
+class IntInferencetHandlerBase(InferenceHandler, FloatToIntMixin):
 
-    def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
+    def __init__(self):
         super().__init__()
-        self.register_buffer('scale', torch.ones(scale_shape))
-        self.register_buffer('zero_point', torch.ones(zero_point_shape))
         self.register_buffer('bit_width', torch.ones(()))
         self.register_buffer('min_clamp', torch.ones(()))
         self.register_buffer('max_clamp', torch.ones(()))
 
     def prepare_for_export(self, module: nn.Module):
+        InferenceHandler.prepare_for_export(self, module)
         FloatToIntMixin.prepare_for_export(self, module)
         if module.is_quant_enabled:
-            scale = module.scale_() if hasattr(module, 'scale_') else module.scale()
-            zero_point = module.zero_point_() if hasattr(module,
-                                                         'zero_point_') else module.zero_point()
-            # Continguous is used to be extra-safe with torch.compile
-            self.scale = scale.contiguous()
-            self.zero_point = zero_point.contiguous()
-
-            self.zero_point = self.zero_point.to(self.scale.device)
             self.bit_width = module.bit_width()
             self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
             self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
@@ -192,6 +183,18 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
         return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width
 
 
+class IntInferencetHandler(IntInferencetHandlerBase, StaticScaleZeroPointMixin):
+    handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)
+
+    def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
+        IntInferencetHandlerBase.__init__(self)
+        StaticScaleZeroPointMixin.__init__(self, scale_shape, zero_point_shape)
+
+    def prepare_for_export(self, module: nn.Module):
+        IntInferencetHandlerBase.prepare_for_export(self, module)
+        StaticScaleZeroPointMixin.prepare_for_export(self, module)
+
+
 class IntWeightInferencetHandler(IntInferencetHandler):
     handled_layer = WeightQuantProxyFromInjector
 
@@ -219,26 +222,27 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         return x, self.scale, self.zero_point, self.bit_width
 
 
-class DynamicIntInferenceHandler(IntInferencetHandler):
+class DynamicIntInferenceHandler(IntInferencetHandlerBase):
     handled_layer = DynamicActQuantProxyFromInjector
 
     def prepare_for_export(self, module: nn.Module):
+        super().prepare_for_export(module)
         if module.is_quant_enabled:
             self.module_forward = module.fused_activation_quant_proxy.tensor_quant
 
     def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
         return self.module_forward(x)
 
 
-class GroupwiseIntInferenceHandler(IntInferencetHandler, GroupwiseMixin):
+class GroupwiseIntInferenceHandler(IntInferencetHandlerBase, GroupwiseMixin):
     handled_layer = GroupwiseActQuantProxyFromInjector
 
     def __init__(self):
         super().__init__()
         self.skip_create_quant_tensor = True
 
     def prepare_for_export(self, module):
-        GroupwiseMixin.prepare_for_export(self, module)
+        super().prepare_for_export(module)
         self.module_forward = None
         if module.is_quant_enabled:
             self.module_forward = module.fused_activation_quant_proxy.tensor_quant
@@ -262,12 +266,14 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
 class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler, GroupwiseMixin):
     handled_layer = GroupwiseWeightQuantProxyFromInjector
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
+        IntWeightInferencetHandler.__init__(self, scale_shape, zero_point_shape)
+        GroupwiseMixin.__init__(self)
         self.skip_create_quant_tensor = True
 
     def prepare_for_export(self, module):
-        super().prepare_for_export(module)
+        IntWeightInferencetHandler.prepare_for_export(self, module)
+        GroupwiseMixin.prepare_for_export(self, module)
         if module.is_quant_enabled:
             self.input_view = module.input_view_impl
 
@@ -301,23 +307,20 @@ def __init__(self):
         self.register_buffer('exponent_bit_width', torch.ones(()))
         self.register_buffer('exponent_bias', torch.ones(()))
         self.register_buffer('fp_internal_scale_min', torch.ones(()))
-        self.register_buffer('saturating_t', torch.ones(()).to(torch.bool))
         self.inf_values = None
         self.nan_values = None
         self.eps = 1e-8  #torch.finfo(self.scale.dtype).tiny
-
-    @property
-    def saturating(self):
-        return bool(self.saturating_t.item())
+        self.saturating = True
 
     def prepare_for_export(self, module):
+        InferenceHandler.prepare_for_export(self, module)
         FloatToIntMixin.prepare_for_export(self, module)
         if module.is_quant_enabled:
 
             self.exponent_bit_width = module.exponent_bit_width()
             self.mantissa_bit_width = module.mantissa_bit_width()
             self.exponent_bias = module.exponent_bias()
-            self.saturating_t = torch.tensor(module.is_saturating())
+            self.saturating = module.is_saturating()
             self.inf_values = module.inf_values()
             self.nan_values = module.nan_values()
             if module.tensor_quant is not None:
@@ -350,7 +353,8 @@ def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor
             n_max_val_mask = -x > self.max_clamp
 
         # Clamp
-        x = torch.clamp(x, self.min_clamp.to(x.device), self.max_clamp.to(x.device))
+        # x = torch.clamp(x, self.min_clamp.to(x.device), self.max_clamp.to(x.device))
+        x = self.float_clamp_impl.saturating_clamp(x, self.max_clamp, self.min_clamp)
         if not self.saturating:
             x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)
 
@@ -367,10 +371,15 @@ class FloatInferencetHandler(FloatInferenceHandlerBase, StaticScaleZeroPointMixi
     handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
     def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
-        super().__init__(scale_shape, zero_point_shape)
+        FloatInferenceHandlerBase.__init__(self)
+        StaticScaleZeroPointMixin.__init__(self, scale_shape, zero_point_shape)
 
+    def prepare_for_export(self, module):
+        FloatInferenceHandlerBase.prepare_for_export(self, module)
+        StaticScaleZeroPointMixin.prepare_for_export(self, module)
 
-class FloatWeightInferencetHandler(FloatInferenceHandlerBase, StaticScaleZeroPointMixin):
+
+class FloatWeightInferencetHandler(FloatInferencetHandler):
     handled_layer = WeightFloatQuantProxyFromInjector
 
     def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
@@ -435,6 +444,7 @@ def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
         StaticScaleZeroPointMixin.__init__(self, scale_shape, zero_point_shape)
         GroupwiseMixin.__init__(self)
         self.skip_create_quant_tensor = True
+        self.reshape_op = DynamicOverSubChannelBlockView(self.group_size, self.group_dim)
 
     def reshape(self, x, group_dim, group_size):
         init_shape = list(x.shape)
@@ -446,6 +456,16 @@ def reshape(self, x, group_dim, group_size):
         x = x.reshape(shape)
         return x
 
+    def prepare_for_export(self, module):
+        FloatInferenceHandlerBase.prepare_for_export(self, module)
+        StaticScaleZeroPointMixin.prepare_for_export(self, module)
+        GroupwiseMixin.prepare_for_export(self, module)
+        if module.is_quant_enabled:
+            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
+                self.cached_weight = module._cached_weight.value
+            else:
+                self.cached_weight = None
+
     def inner_forward(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
         out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
         return out
@@ -459,7 +479,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
             scale = self.scale
             zero_point = self.zero_point
             inp_shape = x.shape
-            x = self.reshape(x, self.group_dim, self.group_size)
+            x = self.reshape_op(x)
 
             out = self.inner_forward(x, scale, zero_point)
             out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]