Commit 0d9e1c9

Finer-grained quantization disabling

1 parent 4b87428 commit 0d9e1c9

1 file changed: src/brevitas/graph/calibrate.py (+74 −28 lines)
@@ -3,7 +3,7 @@
 
 from functools import partial
 import sys
-from typing import List, Optional
+from typing import List, Optional, Tuple, Type
 
 import torch
 from torch import nn
@@ -13,6 +13,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
+from brevitas.proxy.parameter_quant import ParameterQuantProxyFromInjector
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase
 from brevitas.proxy.runtime_quant import ClampQuantProxyFromInjector
@@ -30,6 +31,7 @@
 
 _PARAM_PROXIES = (WeightQuantProxyFromInjectorBase, BiasQuantProxyFromInjectorBase)
 
+_WEIGHT_PROXIES = (WeightQuantProxyFromInjectorBase)
 _BIAS_PROXIES = (BiasQuantProxyFromInjectorBase)
 
 _ACC_PROXIES = (TruncQuantProxyFromInjector, ClampQuantProxyFromInjector)
@@ -205,17 +207,40 @@ def disable_act_quantization(self, model, is_training):
                 module.train(is_training)
                 module.disable_quant = True
 
-    def disable_param_quantization(self, model, is_training):
-        for module in model.modules():
-            if isinstance(module, _PARAM_PROXIES):
-                module.train(is_training)
-                module.disable_quant = True
-
-    def disable_bias_quantization(self, model, is_training):
+    def _set_param_quantization(
+            self,
+            model: nn.Module,
+            is_training: bool,
+            disable_quant: bool,
+            quant_proxies: Tuple[Type[ParameterQuantProxyFromInjector]] = _PARAM_PROXIES) -> None:
         for module in model.modules():
-            if isinstance(module, _BIAS_PROXIES):
+            if isinstance(module, quant_proxies):
                 module.train(is_training)
-                module.disable_quant = True
+                module.disable_quant = disable_quant
+
+    def disable_param_quantization(self, model: nn.Module, is_training: bool) -> None:
+        self._set_param_quantization(
+            model=model,
+            is_training=is_training,
+            disable_quant=True,
+            quant_proxies=_PARAM_PROXIES,
+        )
+
+    def disable_bias_quantization(self, model: nn.Module, is_training: bool):
+        self._set_param_quantization(
+            model=model,
+            is_training=is_training,
+            disable_quant=True,
+            quant_proxies=_BIAS_PROXIES,
+        )
+
+    def disable_weight_quantization(self, model: nn.Module, is_training: bool):
+        self._set_param_quantization(
+            model=model,
+            is_training=is_training,
+            disable_quant=True,
+            quant_proxies=_WEIGHT_PROXIES,
+        )
 
     def enable_act_quantization(self, model, is_training):
         for module in model.modules():
@@ -229,17 +254,29 @@ def enable_act_quantization(self, model, is_training):
                     if hasattr(m, 'observer_only'):
                         m.observer_only = False
 
-    def enable_param_quantization(self, model, is_training):
-        for module in model.modules():
-            if isinstance(module, _PARAM_PROXIES):
-                module.disable_quant = False
-                module.train(is_training)
-
-    def enable_bias_quantization(self, model, is_training):
-        for module in model.modules():
-            if isinstance(module, _BIAS_PROXIES):
-                module.disable_quant = False
-                module.train(is_training)
+    def enable_param_quantization(self, model: nn.Module, is_training: bool):
+        self._set_param_quantization(
+            model=model,
+            is_training=is_training,
+            disable_quant=False,
+            quant_proxies=_PARAM_PROXIES,
+        )
+
+    def enable_bias_quantization(self, model: nn.Module, is_training: bool):
+        self._set_param_quantization(
+            model=model,
+            is_training=is_training,
+            disable_quant=False,
+            quant_proxies=_BIAS_PROXIES,
+        )
+
+    def enable_weight_quantization(self, model: nn.Module, is_training: bool):
+        self._set_param_quantization(
+            model=model,
+            is_training=is_training,
+            disable_quant=False,
+            quant_proxies=_WEIGHT_PROXIES,
+        )
 
     def apply(self, model, is_training, quantization_enabled):
         if not quantization_enabled:
@@ -258,7 +295,10 @@ class disable_enable_quantization:
     Args:
         model (nn.Module): module for which quantization will be enabled/
             disabled
-        disable_quant (bool): whether to disable quantization
+        disable_quant_act (bool): whether to disable activation quantization
+        disable_weight_quant (bool): whether to disable weight quantization
+        disable_bias_quant (bool): whether to disable bias quantization
+        disable_out_quant (bool): whether to disable output quantization
        excluded_modules (list): list of submodules of modules to be excluded
            from quantization disabling
    """
@@ -267,12 +307,14 @@ def __init__(
             self,
             model: nn.Module,
             disable_act_quant: bool = True,
-            disable_param_quant: bool = True,
+            disable_weight_quant: bool = True,
+            disable_bias_quant: bool = True,
             disable_out_quant: bool = True,
             excluded_modules: Optional[List[nn.Module]] = None):
         self.model = model
         self.disable_act_quant = disable_act_quant
-        self.disable_param_quant = disable_param_quant
+        self.disable_weight_quant = disable_weight_quant
+        self.disable_bias_quant = disable_bias_quant
         self.disable_out_quant = disable_out_quant
         self.excluded_modules = excluded_modules if excluded_modules is not None else []
         self.disable_quant_class = DisableEnableQuantization()
@@ -281,8 +323,10 @@ def __init__(
     def __enter__(self):
         if self.disable_act_quant:
             self.disable_quant_class.disable_act_quantization(self.model, False)
-        if self.disable_param_quant:
-            self.disable_quant_class.disable_param_quantization(self.model, False)
+        if self.disable_weight_quant:
+            self.disable_quant_class.disable_weight_quantization(self.model, False)
+        if self.disable_bias_quant:
+            self.disable_quant_class.disable_bias_quantization(self.model, False)
         if self.disable_out_quant:
             self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
         # Re-enable quantization for excluded modules
@@ -294,8 +338,10 @@ def __enter__(self):
     def __exit__(self, type, value, traceback):
         if self.disable_act_quant:
             self.disable_quant_class.enable_act_quantization(self.model, False)
-        if self.disable_param_quant:
-            self.disable_quant_class.enable_param_quantization(self.model, False)
+        if self.disable_weight_quant:
+            self.disable_quant_class.enable_weight_quantization(self.model, False)
+        if self.disable_bias_quant:
+            self.disable_quant_class.enable_bias_quantization(self.model, False)
         if self.disable_out_quant:
             restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
 
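Usage note: with this change, the disable_enable_quantization context manager toggles weight and bias quantization independently instead of through the single disable_param_quant flag it replaced. A minimal sketch of the new interface, assuming a model built from brevitas.nn quantized layers (the QuantLinear configuration below is illustrative, not part of this commit):

import torch
from brevitas.nn import QuantLinear
from brevitas.graph.calibrate import disable_enable_quantization

model = QuantLinear(16, 8, bias=True)

# Disable only bias quantization inside the block; activation, weight and
# output quantization keep running because their flags are set to False.
with disable_enable_quantization(
        model,
        disable_act_quant=False,
        disable_weight_quant=False,
        disable_bias_quant=True,
        disable_out_quant=False):
    out = model(torch.randn(2, 16))
# On exit, bias quantization is re-enabled automatically.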
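The refactored DisableEnableQuantization methods can also be called directly when the context manager is too coarse, since disable_weight_quantization and enable_weight_quantization now exist alongside the bias and combined variants. A sketch under the same assumptions (illustrative model, not from this commit):

import torch
from brevitas.nn import QuantLinear
from brevitas.graph.calibrate import DisableEnableQuantization

model = QuantLinear(16, 8, bias=True)
toggler = DisableEnableQuantization()

# Run a forward pass with float weights while any bias/activation
# quantizers stay active.
toggler.disable_weight_quantization(model, is_training=False)
out = model(torch.randn(2, 16))

# Restore weight quantization afterwards.
toggler.enable_weight_quantization(model, is_training=False)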