Commit 3259272

Fix proxies
1 parent 99e8b56 commit 3259272

File tree: 1 file changed (+49 -29 lines)

src/brevitas/export/inference/handler.py

Lines changed: 49 additions & 29 deletions
@@ -10,6 +10,7 @@
 import torch.nn as nn
 
 import brevitas.config as config
+from brevitas.core.function_wrapper.shape import DynamicOverSubChannelBlockView
 from brevitas.function import compute_max_mantissa
 from brevitas.function.ops import max_float
 from brevitas.function.ops import max_int
@@ -155,28 +156,18 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         return output_dict
 
 
-class IntInferencetHandler(InferenceHandler, FloatToIntMixin):
-    handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)
+class IntInferencetHandlerBase(InferenceHandler, FloatToIntMixin):
 
-    def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
+    def __init__(self):
         super().__init__()
-        self.register_buffer('scale', torch.ones(scale_shape))
-        self.register_buffer('zero_point', torch.ones(zero_point_shape))
         self.register_buffer('bit_width', torch.ones(()))
         self.register_buffer('min_clamp', torch.ones(()))
         self.register_buffer('max_clamp', torch.ones(()))
 
     def prepare_for_export(self, module: nn.Module):
+        InferenceHandler.prepare_for_export(self, module)
         FloatToIntMixin.prepare_for_export(self, module)
         if module.is_quant_enabled:
-            scale = module.scale_() if hasattr(module, 'scale_') else module.scale()
-            zero_point = module.zero_point_() if hasattr(module,
-                                                         'zero_point_') else module.zero_point()
-            # Contiguous is used to be extra-safe with torch.compile
-            self.scale = scale.contiguous()
-            self.zero_point = zero_point.contiguous()
-
-            self.zero_point = self.zero_point.to(self.scale.device)
             self.bit_width = module.bit_width()
             self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
             self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
@@ -192,6 +183,18 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
         return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width
 
 
+class IntInferencetHandler(IntInferencetHandlerBase, StaticScaleZeroPointMixin):
+    handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)
+
+    def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
+        IntInferencetHandlerBase.__init__(self)
+        StaticScaleZeroPointMixin.__init__(self, scale_shape, zero_point_shape)
+
+    def prepare_for_export(self, module: nn.Module):
+        IntInferencetHandlerBase.prepare_for_export(self, module)
+        StaticScaleZeroPointMixin.prepare_for_export(self, module)
+
+
 class IntWeightInferencetHandler(IntInferencetHandler):
     handled_layer = WeightQuantProxyFromInjector
 
@@ -219,26 +222,27 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         return x, self.scale, self.zero_point, self.bit_width
 
 
-class DynamicIntInferenceHandler(IntInferencetHandler):
+class DynamicIntInferenceHandler(IntInferencetHandlerBase):
     handled_layer = DynamicActQuantProxyFromInjector
 
     def prepare_for_export(self, module: nn.Module):
+        super().prepare_for_export(module)
         if module.is_quant_enabled:
             self.module_forward = module.fused_activation_quant_proxy.tensor_quant
 
     def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
         return self.module_forward(x)
 
 
-class GroupwiseIntInferenceHandler(IntInferencetHandler, GroupwiseMixin):
+class GroupwiseIntInferenceHandler(IntInferencetHandlerBase, GroupwiseMixin):
     handled_layer = GroupwiseActQuantProxyFromInjector
 
     def __init__(self):
         super().__init__()
         self.skip_create_quant_tensor = True
 
     def prepare_for_export(self, module):
-        GroupwiseMixin.prepare_for_export(self, module)
+        super().prepare_for_export(module)
         self.module_forward = None
         if module.is_quant_enabled:
             self.module_forward = module.fused_activation_quant_proxy.tensor_quant
@@ -262,12 +266,14 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
 class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler, GroupwiseMixin):
     handled_layer = GroupwiseWeightQuantProxyFromInjector
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
+        IntWeightInferencetHandler.__init__(self, scale_shape, zero_point_shape)
+        GroupwiseMixin.__init__(self)
         self.skip_create_quant_tensor = True
 
     def prepare_for_export(self, module):
-        super().prepare_for_export(module)
+        IntWeightInferencetHandler.prepare_for_export(self, module)
+        GroupwiseMixin.prepare_for_export(self, module)
         if module.is_quant_enabled:
             self.input_view = module.input_view_impl
 
@@ -301,23 +307,20 @@ def __init__(self):
         self.register_buffer('exponent_bit_width', torch.ones(()))
         self.register_buffer('exponent_bias', torch.ones(()))
         self.register_buffer('fp_internal_scale_min', torch.ones(()))
-        self.register_buffer('saturating_t', torch.ones(()).to(torch.bool))
         self.inf_values = None
         self.nan_values = None
         self.eps = 1e-8 #torch.finfo(self.scale.dtype).tiny
-
-    @property
-    def saturating(self):
-        return bool(self.saturating_t.item())
+        self.saturating = True
 
     def prepare_for_export(self, module):
+        InferenceHandler.prepare_for_export(self, module)
         FloatToIntMixin.prepare_for_export(self, module)
         if module.is_quant_enabled:
 
             self.exponent_bit_width = module.exponent_bit_width()
             self.mantissa_bit_width = module.mantissa_bit_width()
             self.exponent_bias = module.exponent_bias()
-            self.saturating_t = torch.tensor(module.is_saturating())
+            self.saturating = module.is_saturating()
             self.inf_values = module.inf_values()
             self.nan_values = module.nan_values()
             if module.tensor_quant is not None:
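This hunk replaces the registered `saturating_t` buffer and its property with a plain Python attribute, which drops the tensor round-trip (and the `.item()` call, a host sync when the buffer lives on GPU) every time the flag is read. A minimal before/after sketch, with hypothetical class names:

import torch
import torch.nn as nn


class _WithBufferFlag(nn.Module):
    # Old style: the flag lives in a registered bool buffer and needs .item() to read.
    def __init__(self):
        super().__init__()
        self.register_buffer('saturating_t', torch.ones(()).to(torch.bool))

    @property
    def saturating(self):
        return bool(self.saturating_t.item())  # device-to-host sync if the buffer is on GPU


class _WithPlainFlag(nn.Module):
    # New style: a plain Python attribute, no tensor round-trip on every read.
    def __init__(self):
        super().__init__()
        self.saturating = True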
@@ -350,7 +353,8 @@ def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor
         n_max_val_mask = -x > self.max_clamp
 
         # Clamp
-        x = torch.clamp(x, self.min_clamp.to(x.device), self.max_clamp.to(x.device))
+        # x = torch.clamp(x, self.min_clamp.to(x.device), self.max_clamp.to(x.device))
+        x = self.float_clamp_impl.saturating_clamp(x, self.max_clamp, self.min_clamp)
         if not self.saturating:
             x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)
 
@@ -367,10 +371,15 @@ class FloatInferencetHandler(FloatInferenceHandlerBase, StaticScaleZeroPointMixi
     handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
     def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
-        super().__init__(scale_shape, zero_point_shape)
+        FloatInferenceHandlerBase.__init__(self)
+        StaticScaleZeroPointMixin.__init__(self, scale_shape, zero_point_shape)
 
+    def prepare_for_export(self, module):
+        FloatInferenceHandlerBase.prepare_for_export(self, module)
+        StaticScaleZeroPointMixin.prepare_for_export(self, module)
 
-class FloatWeightInferencetHandler(FloatInferenceHandlerBase, StaticScaleZeroPointMixin):
+
+class FloatWeightInferencetHandler(FloatInferencetHandler):
     handled_layer = WeightFloatQuantProxyFromInjector
 
     def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
@@ -435,6 +444,7 @@ def __init__(self, scale_shape=(1,), zero_point_shape=(1,)):
         StaticScaleZeroPointMixin.__init__(self, scale_shape, zero_point_shape)
         GroupwiseMixin.__init__(self)
         self.skip_create_quant_tensor = True
+        self.reshape_op = DynamicOverSubChannelBlockView(self.group_size, self.group_dim)
 
     def reshape(self, x, group_dim, group_size):
         init_shape = list(x.shape)
@@ -446,6 +456,16 @@ def reshape(self, x, group_dim, group_size):
         x = x.reshape(shape)
         return x
 
+    def prepare_for_export(self, module):
+        FloatInferenceHandlerBase.prepare_for_export(self, module)
+        StaticScaleZeroPointMixin.prepare_for_export(self, module)
+        GroupwiseMixin.prepare_for_export(self, module)
+        if module.is_quant_enabled:
+            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
+                self.cached_weight = module._cached_weight.value
+            else:
+                self.cached_weight = None
+
     def inner_forward(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
         out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
         return out
@@ -459,7 +479,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         scale = self.scale
         zero_point = self.zero_point
         inp_shape = x.shape
-        x = self.reshape(x, self.group_dim, self.group_size)
+        x = self.reshape_op(x)
 
         out = self.inner_forward(x, scale, zero_point)
         out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]
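In `forward`, the hand-rolled `reshape(x, self.group_dim, self.group_size)` call is replaced by the `DynamicOverSubChannelBlockView` instance built in `__init__` and invoked as `self.reshape_op(x)`. The helper below is only an illustrative stand-in for that kind of view, assuming it splits `group_dim` into contiguous blocks of `group_size`; it is not the Brevitas implementation.

import torch


def over_subchannel_block_view(x: torch.Tensor, group_size: int, group_dim: int) -> torch.Tensor:
    """Illustrative stand-in: split ``group_dim`` into (num_groups, group_size) blocks.

    Assumes a non-negative ``group_dim`` whose size is divisible by ``group_size``.
    """
    shape = list(x.shape)
    assert shape[group_dim] % group_size == 0, "group_dim size must be divisible by group_size"
    shape[group_dim] = shape[group_dim] // group_size
    shape.insert(group_dim + 1, group_size)
    return x.view(shape)


# Example: a (64, 128) weight grouped along dim 1 with group_size 32 -> (64, 4, 32)
w = torch.randn(64, 128)
print(over_subchannel_block_view(w, group_size=32, group_dim=1).shape)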
