27 changes: 15 additions & 12 deletions src/brevitas/export/inference/handler.py
@@ -81,7 +81,9 @@ def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
return (x - zero_point) * scale

def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
- return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width
+ return self.dequantize(
+     self.quantize(x, self.scale, self.zero_point), self.scale,
+     self.zero_point) #, self.scale, self.zero_point, self.bit_width


class IntWeightInferencetHandler(IntInferencetHandler):
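For reference, the quantize/dequantize round trip that the activation handler's forward above now collapses into a single returned tensor looks roughly like the following standalone sketch; the int8 bounds and the scale/zero-point values are illustrative, not the Brevitas implementation:

```python
import torch

def fake_quant_int8(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    # quantize: map onto the integer grid and clamp to the int8 range
    q = torch.clamp(torch.round(x / scale + zero_point), -128, 127)
    # dequantize: map back to float; only this tensor is returned by the handler now
    return (q - zero_point) * scale

y = fake_quant_int8(torch.randn(4, 4), torch.tensor(0.05), torch.tensor(0.))
```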
@@ -108,7 +110,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
else:
x = self.inner_forward(x, self.scale, self.zero_point)

- return x, self.scale, self.zero_point, self.bit_width
+ return x #, self.scale, self.zero_point, self.bit_width


class DynamicIntInferenceHandler(IntInferencetHandler):
@@ -140,11 +142,10 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
inp_shape = x.shape
x, scale, zero_point, *other = self.module_forward(x)

- # If we skip quant tensor, we return the flattened version of the groupwise tensor
- if self.skip_create_quant_tensor:
-     x = groupwise_dequant_expand(x, scale, zero_point, self.group_dim, inp_shape)[0]
+ # When we skip quant tensor, we return the flattened version of the groupwise tensor
+ x = groupwise_dequant_expand(x, scale, zero_point, self.group_dim, inp_shape)[0]
- output_args = tuple([x, scale, zero_point] + list(other))
- return output_args
+ return x


class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
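The groupwise handlers see activations reshaped into (..., groups, group_size) blocks with one scale and zero-point per group; `groupwise_dequant_expand` dequantizes per group and restores the original input shape. A rough standalone illustration of that expansion, with arbitrary assumed shapes and values:

```python
import torch

inp_shape = (2, 64)                # original activation shape
x = torch.randn(2, 4, 16)          # grouped view: 64 = 4 groups of 16 elements
scale = torch.rand(2, 4, 1)        # one scale per group
zero_point = torch.zeros(2, 4, 1)  # one zero-point per group
# dequantize per group, then flatten back to the original shape
x_dequant = ((x - zero_point) * scale).reshape(inp_shape)
```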
@@ -182,7 +183,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:

# If we skip quant tensor, we return the flattened version of the groupwise tensor
out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]
- return out, scale, zero_point, self.bit_width
+ return out #, scale, zero_point, self.bit_width


class FloatInferencetHandler(InferenceHandler):
@@ -253,7 +254,9 @@ def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
return (x - zero_point) * scale

def forward(self, x: Tensor) -> Tuple[Tensor]:
- return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
+ return self.dequantize(
+     self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point
+ ) #, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class FloatWeightInferencetHandler(FloatInferencetHandler):
@@ -279,7 +282,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
x = self.cached_weight
else:
x = self.inner_forward(x, self.scale, self.zero_point)
- return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
+ return x # , self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class GroupwiseFloatInferenceHandler(FloatInferencetHandler):
@@ -301,8 +304,8 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
x, scale, zero_point, *other = self.module_forward(x)
# If we skip quant tensor, we return the flattened version of the groupwise tensor
x = groupwise_dequant_expand(x, scale, zero_point, self.group_dim, inp_shape)[0]
- output_args = tuple([x, scale, zero_point] + list(other))
- return output_args
+ # output_args = tuple([x, scale, zero_point] + list(other))
+ return x


class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
@@ -342,7 +345,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
out = self.inner_forward(x, scale, zero_point)
out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]

- return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
+ return out #, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class DynamicFloatInferenceHandler(FloatInferencetHandler):
2 changes: 1 addition & 1 deletion src/brevitas/export/inference/manager.py
@@ -44,7 +44,7 @@ def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool
_override_caching_mode(m, 'act', enabled, metadata_only)


- def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
+ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
_override_caching_mode(m, 'weight', enabled, metadata_only)


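These inference handlers and caching overrides are driven by the inference-mode manager; a minimal usage sketch, assuming the `quant_inference_mode` context manager exposed by `brevitas.export.inference` and an arbitrary toy layer (both are assumptions about the surrounding API, not part of this diff):

```python
import torch
import brevitas.nn as qnn
from brevitas.export.inference import quant_inference_mode

layer = qnn.QuantLinear(16, 8, bias=False, weight_bit_width=4)
layer.eval()

with torch.no_grad(), quant_inference_mode(layer):
    out = layer(torch.randn(2, 16))  # plain torch.Tensor, no QuantTensor metadata
```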
15 changes: 15 additions & 0 deletions src/brevitas/export/manager.py
@@ -66,6 +66,13 @@ def _override_bias_caching_mode(m: Module, enabled: bool):
m.cache_inference_quant_bias = enabled


+ def _override_weight_caching_mode(m: Module, enabled: bool):
+     if hasattr(m, 'cache_inference_quant_weight'):
+         if not hasattr(m, "cache_inference_quant_weight_backup"):
+             m.cache_inference_quant_weight_backup = m.cache_inference_quant_weight
+         m.cache_inference_quant_weight = enabled


def _override_act_caching_mode(m: Module, enabled: bool):
if hasattr(m, 'cache_inference_quant_act'):
if not hasattr(m, "cache_inference_quant_act_backup"):
@@ -91,6 +98,12 @@ def _restore_act_caching_mode(m: Module):
del m.cache_inference_quant_act_backup


+ def _restore_weight_caching_mode(m: Module):
+     if hasattr(m, "cache_inference_quant_weight_backup"):
+         m.cache_inference_quant_weight = m.cache_inference_quant_weight_backup
+         del m.cache_inference_quant_weight_backup


def _set_recurrent_layer_export_mode(model: Module, enabled: bool):
for m in model.modules():
if isinstance(m, QuantRecurrentLayerMixin) and hasattr(m, 'export_mode'):
@@ -196,11 +209,13 @@ def _cache_inp_out(cls, module, *args, **kwargs):
module.apply(lambda m: _override_quant_metadata_caching_mode(m, enabled=True))
module.apply(lambda m: _override_bias_caching_mode(m, enabled=True))
module.apply(lambda m: _override_act_caching_mode(m, enabled=True))
+ module.apply(lambda m: _override_weight_caching_mode(m, enabled=True))
_ = module.forward(*args, **kwargs)
# Restore previous caching properties
module.apply(lambda m: _restore_quant_metadata_caching_mode(m))
module.apply(lambda m: _restore_bias_caching_mode(m))
module.apply(lambda m: _restore_act_caching_mode(m))
+ module.apply(lambda m: _restore_weight_caching_mode(m))

@classmethod
def jit_inference_trace(
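The new weight-caching override/restore pair follows the same backup idiom as the existing act/bias helpers: the first override records the original flag, later overrides leave that backup untouched, and restore puts the original value back exactly once. A generic toy sketch of the idiom (names here are illustrative, not Brevitas API):

```python
class Toy:
    cache_flag = False

def override(m, enabled):
    # only the first override records the original value
    if not hasattr(m, "cache_flag_backup"):
        m.cache_flag_backup = m.cache_flag
    m.cache_flag = enabled

def restore(m):
    if hasattr(m, "cache_flag_backup"):
        m.cache_flag = m.cache_flag_backup
        del m.cache_flag_backup

t = Toy()
override(t, True)
override(t, True)  # repeated overrides do not clobber the backup
restore(t)
assert t.cache_flag is False
```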
11 changes: 8 additions & 3 deletions src/brevitas/nn/mixin/parameter.py
@@ -45,7 +45,8 @@ def output_channel_dim(self) -> int:
def quant_weight(
self,
quant_input: Optional[QuantTensor] = None,
- subtensor_slice_list: List[Optional[Tuple[int, int]]] = None):
+ subtensor_slice_list: List[Optional[Tuple[int, int]]] = None,
+ return_quant_tensor: bool = True):
weights_to_quantize = self.weight
if not self.weight_quant.is_quant_enabled and hasattr(self, 'weight_orig'):
weights_to_quantize = self.weight_orig.to(self.weight.device)
@@ -70,9 +71,13 @@ def quant_weight(
else:
weight_slice_tuple = slice(None)
if self.weight_quant.requires_quant_input:
- out = self.weight_quant(weights_to_quantize[weight_slice_tuple], quant_input)
+ out = self.weight_quant(
+     weights_to_quantize[weight_slice_tuple],
+     quant_input,
+     return_quant_tensor=return_quant_tensor)
else:
- out = self.weight_quant(weights_to_quantize[weight_slice_tuple])
+ out = self.weight_quant(
+     weights_to_quantize[weight_slice_tuple], return_quant_tensor=return_quant_tensor)
if subtensor_slice_list is not None:
# Restore the quantizer behaviour to full tensor quantization
# The modules to slice should have been cached already at this point
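A possible usage sketch of the new `return_quant_tensor` keyword: the default `True` preserves the previous behaviour, while `False` is expected to yield the plain dequantized tensor. The concrete layer and bit width below are assumptions for illustration:

```python
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor

layer = qnn.QuantConv2d(3, 8, kernel_size=3, weight_bit_width=4)
qt = layer.quant_weight()                           # QuantTensor, as before
w = layer.quant_weight(return_quant_tensor=False)   # plain dequantized tensor
assert isinstance(qt, QuantTensor)
```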
4 changes: 2 additions & 2 deletions src/brevitas/nn/quant_accumulator.py
@@ -37,7 +37,7 @@ def requires_export_handler(self):

def forward(self, input: QuantTensor):
x = self.unpack_input(input)
- x = self.trunc_quant(x)
+ x = self.trunc_quant(x, return_quant_tensor=self.return_quant_tensor)
return self.pack_output(x)


@@ -61,5 +61,5 @@ def requires_export_handler(self):

def forward(self, input: QuantTensor):
x = self.unpack_input(input)
- x = self.clamp_quant(x)
+ x = self.clamp_quant(x, return_quant_tensor=self.return_quant_tensor)
return self.pack_output(x)
4 changes: 2 additions & 2 deletions src/brevitas/nn/quant_avg_pool.py
@@ -83,7 +83,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
if not isinstance(x, QuantTensor):
x = self.cache_class.quant_tensor.set(value=x)
y = AvgPool2d.forward(self, x)
- y = self.trunc_quant(y)
+ y = self.trunc_quant(y, return_quant_tensor=self.return_quant_tensor)
else:
y = AvgPool2d.forward(self, _unpack_quant_tensor(x))

@@ -149,7 +149,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
if not isinstance(x, QuantTensor):
x = self.cache_class.quant_tensor.set(value=x)
y = AdaptiveAvgPool2d.forward(self, x)
- y = self.trunc_quant(y)
+ y = self.trunc_quant(y, return_quant_tensor=self.return_quant_tensor)
else:
y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))

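The layer's `return_quant_tensor` flag is now forwarded to the truncation quantizer of the pooling layers as well; a brief sketch, assuming `brevitas.nn.TruncAvgPool2d` with its default truncation quantizer and a quantized input produced by `QuantIdentity` (constructor details are assumptions):

```python
import torch
import brevitas.nn as qnn

quant_inp = qnn.QuantIdentity(return_quant_tensor=True)  # pooling expects a quantized input
pool = qnn.TruncAvgPool2d(kernel_size=2, return_quant_tensor=False)
y = pool(quant_inp(torch.randn(1, 3, 8, 8)))  # plain torch.Tensor out of trunc_quant
```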
2 changes: 1 addition & 1 deletion src/brevitas/nn/quant_embedding.py
@@ -60,7 +60,7 @@ def out_channels(self) -> int:
return self.num_embeddings

def forward(self, inp):
- quant_weight = self.quant_weight()
+ quant_weight = self.quant_weight(return_quant_tensor=self.return_quant_tensor)
out = embedding(
inp,
quant_weight,
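Likewise for the embedding lookup, which now fetches its quantized weight in whichever form the layer is configured to return; a short sketch with assumed constructor arguments:

```python
import torch
import brevitas.nn as qnn

emb = qnn.QuantEmbedding(100, 32, weight_bit_width=4, return_quant_tensor=False)
out = emb(torch.tensor([[1, 5, 7]]))  # embedding() receives a plain dequantized weight
```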
20 changes: 13 additions & 7 deletions src/brevitas/nn/quant_layer.py
@@ -11,8 +11,8 @@
from torch import Tensor
from torch.nn import Module

from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.torch_utils import compute_channel_view_shape

from .mixin import *
from .utils import merge_bn
@@ -45,12 +45,12 @@ def requires_export_handler(self):

def forward(self, input: Union[Tensor, QuantTensor]):
input = self.unpack_input(input)
- quant_input = self.input_quant(input)
+ quant_input = self.input_quant(input, return_quant_tensor=self.return_quant_tensor)
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(quant_input)
return out
- out = self.act_quant(quant_input)
+ out = self.act_quant(quant_input, return_quant_tensor=self.return_quant_tensor)
out = self.pack_output(out)
return out
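For the activation path above, `return_quant_tensor` now decides whether the proxies hand back a `QuantTensor` or a plain tensor; a small sketch using `QuantReLU` with its default quantizers (an assumption, not part of this diff):

```python
import torch
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor

act_qt = qnn.QuantReLU(return_quant_tensor=True)
act_t = qnn.QuantReLU(return_quant_tensor=False)
assert isinstance(act_qt(torch.randn(2, 8)), QuantTensor)
assert not isinstance(act_t(torch.randn(2, 8)), QuantTensor)
```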

@@ -142,8 +142,11 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
out = self.export_handler(inp)
return out

- quant_input = self.input_quant(inp)
- quant_weight = self.quant_weight(quant_input)
+ is_quant_tensor_required = self.return_quant_tensor or getattr(
+     self.bias_quant, 'requires_input_scale', False) or getattr(
+         self.weight_quant, 'requires_quant_input', False)
+ quant_input = self.input_quant(inp, return_quant_tensor=is_quant_tensor_required)
+ quant_weight = self.quant_weight(quant_input, return_quant_tensor=is_quant_tensor_required)

compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
quant_weight, QuantTensor)
@@ -152,12 +155,15 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
raise RuntimeError("QuantLayer is not correctly configured")

if self.bias is not None:
- quant_bias = self.bias_quant(self.bias, quant_input, quant_weight)
+ quant_bias = self.bias_quant(
+     self.bias, quant_input, quant_weight, return_quant_tensor=self.return_quant_tensor)
else:
quant_bias = None

output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)

- quant_output = self.output_quant(output_tensor)
+ quant_output = self.output_quant(
+     output_tensor, return_quant_tensor=self.return_quant_tensor)
return self.pack_output(quant_output)

def _load_from_state_dict(
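The `is_quant_tensor_required` guard means intermediate `QuantTensor`s are still materialized whenever the bias quantizer needs the input scale or the weight quantizer needs the quantized input, even if the layer itself returns a plain tensor. A sketch of such a configuration, using standard Brevitas quantizers but with layer sizes and defaults assumed:

```python
import torch
import brevitas.nn as qnn
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int32Bias

layer = qnn.QuantLinear(
    16, 8,
    bias=True,
    input_quant=Int8ActPerTensorFloat,
    bias_quant=Int32Bias,       # needs the input scale, so a QuantTensor is built internally
    return_quant_tensor=False)
out = layer(torch.randn(2, 16))  # plain tensor out; QuantTensors only exist internally
```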