3
3
4
4
from functools import partial
5
5
import sys
6
- from typing import List , Optional
6
+ from typing import List , Optional , Tuple , Type
7
7
8
8
import torch
9
9
from torch import nn
13
13
from brevitas .nn import QuantLinear
14
14
from brevitas .nn .quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
15
15
from brevitas .proxy .parameter_quant import BiasQuantProxyFromInjectorBase
16
+ from brevitas .proxy .parameter_quant import ParameterQuantProxyFromInjector
16
17
from brevitas .proxy .parameter_quant import WeightQuantProxyFromInjectorBase
17
18
from brevitas .proxy .runtime_quant import ActQuantProxyFromInjectorBase
18
19
from brevitas .proxy .runtime_quant import ClampQuantProxyFromInjector
30
31
31
32
# Proxy classes grouped by the kind of quantizer they implement; used as
# `isinstance` filters when selectively enabling/disabling quantization.
_PARAM_PROXIES = (WeightQuantProxyFromInjectorBase, BiasQuantProxyFromInjectorBase)

# NOTE: trailing commas are required — `(X)` is just `X`, not a tuple. The
# bare class happens to work with `isinstance`, but these names are annotated
# downstream as Tuple[Type[...]], so make them genuine 1-tuples.
_WEIGHT_PROXIES = (WeightQuantProxyFromInjectorBase,)

_BIAS_PROXIES = (BiasQuantProxyFromInjectorBase,)

# Accumulator quantization proxies (truncation / clamping).
_ACC_PROXIES = (TruncQuantProxyFromInjector, ClampQuantProxyFromInjector)
@@ -205,17 +207,40 @@ def disable_act_quantization(self, model, is_training):
205
207
module .train (is_training )
206
208
module .disable_quant = True
207
209
208
- def disable_param_quantization (self , model , is_training ):
209
- for module in model .modules ():
210
- if isinstance (module , _PARAM_PROXIES ):
211
- module .train (is_training )
212
- module .disable_quant = True
213
-
214
- def disable_bias_quantization (self , model , is_training ):
210
+ def _set_param_quantization (
211
+ self ,
212
+ model : nn .Module ,
213
+ is_training : bool ,
214
+ disable_quant : bool ,
215
+ quant_proxies : Tuple [Type [ParameterQuantProxyFromInjector ]] = _PARAM_PROXIES ) -> None :
215
216
for module in model .modules ():
216
- if isinstance (module , _BIAS_PROXIES ):
217
+ if isinstance (module , quant_proxies ):
217
218
module .train (is_training )
218
- module .disable_quant = True
219
+ module .disable_quant = disable_quant
220
+
221
+ def disable_param_quantization (self , model : nn .Module , is_training : bool ) -> None :
222
+ self ._set_param_quantization (
223
+ model = model ,
224
+ is_training = is_training ,
225
+ disable_quant = True ,
226
+ quant_proxies = _PARAM_PROXIES ,
227
+ )
228
+
229
+ def disable_bias_quantization (self , model : nn .Module , is_training : bool ):
230
+ self ._set_param_quantization (
231
+ model = model ,
232
+ is_training = is_training ,
233
+ disable_quant = True ,
234
+ quant_proxies = _BIAS_PROXIES ,
235
+ )
236
+
237
+ def disable_weight_quantization (self , model : nn .Module , is_training : bool ):
238
+ self ._set_param_quantization (
239
+ model = model ,
240
+ is_training = is_training ,
241
+ disable_quant = True ,
242
+ quant_proxies = _WEIGHT_PROXIES ,
243
+ )
219
244
220
245
def enable_act_quantization (self , model , is_training ):
221
246
for module in model .modules ():
@@ -229,17 +254,29 @@ def enable_act_quantization(self, model, is_training):
229
254
if hasattr (m , 'observer_only' ):
230
255
m .observer_only = False
231
256
232
- def enable_param_quantization (self , model , is_training ):
233
- for module in model .modules ():
234
- if isinstance (module , _PARAM_PROXIES ):
235
- module .disable_quant = False
236
- module .train (is_training )
237
-
238
- def enable_bias_quantization (self , model , is_training ):
239
- for module in model .modules ():
240
- if isinstance (module , _BIAS_PROXIES ):
241
- module .disable_quant = False
242
- module .train (is_training )
257
+ def enable_param_quantization (self , model : nn .Module , is_training : bool ):
258
+ self ._set_param_quantization (
259
+ model = model ,
260
+ is_training = is_training ,
261
+ disable_quant = False ,
262
+ quant_proxies = _PARAM_PROXIES ,
263
+ )
264
+
265
+ def enable_bias_quantization (self , model : nn .Module , is_training : bool ):
266
+ self ._set_param_quantization (
267
+ model = model ,
268
+ is_training = is_training ,
269
+ disable_quant = False ,
270
+ quant_proxies = _BIAS_PROXIES ,
271
+ )
272
+
273
+ def enable_weight_quantization (self , model : nn .Module , is_training : bool ):
274
+ self ._set_param_quantization (
275
+ model = model ,
276
+ is_training = is_training ,
277
+ disable_quant = False ,
278
+ quant_proxies = _WEIGHT_PROXIES ,
279
+ )
243
280
244
281
def apply (self , model , is_training , quantization_enabled ):
245
282
if not quantization_enabled :
@@ -258,7 +295,10 @@ class disable_enable_quantization:
258
295
Args:
259
296
model (nn.Module): module for which quantization will be enabled/
260
297
disabled
261
- disable_quant (bool): whether to disable quantization
298
+ disable_act_quant (bool): whether to disable activation quantization
299
+ disable_weight_quant (bool): whether to disable weight quantization
300
+ disable_bias_quant (bool): whether to disable bias quantization
301
+ disable_out_quant (bool): whether to disable output quantization
262
302
excluded_modules (list): list of submodules of modules to be excluded
263
303
from quantization disabling
264
304
"""
@@ -267,12 +307,14 @@ def __init__(
267
307
self ,
268
308
model : nn .Module ,
269
309
disable_act_quant : bool = True ,
270
- disable_param_quant : bool = True ,
310
+ disable_weight_quant : bool = True ,
311
+ disable_bias_quant : bool = True ,
271
312
disable_out_quant : bool = True ,
272
313
excluded_modules : Optional [List [nn .Module ]] = None ):
273
314
self .model = model
274
315
self .disable_act_quant = disable_act_quant
275
- self .disable_param_quant = disable_param_quant
316
+ self .disable_weight_quant = disable_weight_quant
317
+ self .disable_bias_quant = disable_bias_quant
276
318
self .disable_out_quant = disable_out_quant
277
319
self .excluded_modules = excluded_modules if excluded_modules is not None else []
278
320
self .disable_quant_class = DisableEnableQuantization ()
@@ -281,8 +323,10 @@ def __init__(
281
323
def __enter__ (self ):
282
324
if self .disable_act_quant :
283
325
self .disable_quant_class .disable_act_quantization (self .model , False )
284
- if self .disable_param_quant :
285
- self .disable_quant_class .disable_param_quantization (self .model , False )
326
+ if self .disable_weight_quant :
327
+ self .disable_quant_class .disable_weight_quantization (self .model , False )
328
+ if self .disable_bias_quant :
329
+ self .disable_quant_class .disable_bias_quantization (self .model , False )
286
330
if self .disable_out_quant :
287
331
self .return_quant_tensor_state = disable_return_quant_tensor (self .model )
288
332
# Re-enable quantization for excluded modules
@@ -294,8 +338,10 @@ def __enter__(self):
294
338
def __exit__ (self , type , value , traceback ):
295
339
if self .disable_act_quant :
296
340
self .disable_quant_class .enable_act_quantization (self .model , False )
297
- if self .disable_param_quant :
298
- self .disable_quant_class .enable_param_quantization (self .model , False )
341
+ if self .disable_weight_quant :
342
+ self .disable_quant_class .enable_weight_quantization (self .model , False )
343
+ if self .disable_bias_quant :
344
+ self .disable_quant_class .enable_bias_quantization (self .model , False )
299
345
if self .disable_out_quant :
300
346
restore_return_quant_tensor (self .model , self .return_quant_tensor_state )
301
347
0 commit comments