1 change: 1 addition & 0 deletions src/brevitas/graph/base.py
@@ -232,6 +232,7 @@ def init_new_module(self, old_module: Module, name: str = None, load_state_dict:
is_assign_supported = 'assign' in inspect.signature(
old_module.load_state_dict).parameters.keys()
if 'device' in new_module_signature_keys and is_assign_supported and load_state_dict:
new_kwargs['quant_device'] = new_kwargs['device']
new_kwargs['device'] = torch.device("meta")

# init the new module
18 changes: 16 additions & 2 deletions src/brevitas/nn/mixin/base.py
@@ -43,12 +43,26 @@ def __init__(
kwargs_prefix: str,
**kwargs):
proxy_name = proxy_prefix + 'quant'
filtered_kwargs = filter_kwargs(kwargs_prefix, kwargs)

# `device` and `dtype` are special keywords, propagated through regardless of the prefix
# `None` is the default value for these parameters when they are not specified

# When applying quantization, if we use the `meta` device to avoid memory duplication, we need
# to keep track of the original device for the eventual quantization parameters
# Otherwise, they would be stuck on `meta` and error out
device = kwargs.get('device') if kwargs.get('device') not in [
'meta', torch.device('meta')] else kwargs.get('quant_device')
dtype = kwargs.get('dtype')
filtered_kwargs['device'] = device
filtered_kwargs['dtype'] = dtype

if quant is None:
quant_injector = none_quant_injector.let(**filter_kwargs(kwargs_prefix, kwargs))
quant_injector = none_quant_injector.let(**filtered_kwargs)
quant = quant_injector.proxy_class(self, quant_injector)
elif isclass(quant) and issubclass(quant, (Injector, ExtendedInjector)):
quant_injector = quant
quant_injector = quant_injector.let(**filter_kwargs(kwargs_prefix, kwargs))
quant_injector = quant_injector.let(**filtered_kwargs)
quant = quant_injector.proxy_class(self, quant_injector)
else:
if not isinstance(quant, proxy_protocol):
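Taken together, the two hunks above implement the device bookkeeping for quantization on the `meta` device: `init_new_module` in `graph/base.py` records the original device under a `quant_device` kwarg before switching the new module to `meta`, and the quant mixin falls back to that record when resolving the device for quantization parameters. A minimal, self-contained sketch of that resolution logic; the `resolve_quant_device` helper and the example kwargs are illustrative only, not part of the diff:

```python
from typing import Optional

import torch


def resolve_quant_device(kwargs: dict) -> Optional[torch.device]:
    """Return the device that quantization parameters should be created on."""
    device = kwargs.get('device')
    if device in ['meta', torch.device('meta')]:
        # Module weights live on `meta`, so fall back to the recorded original device
        return kwargs.get('quant_device')
    return device


# The rewriter stashes the original device, then moves the module to `meta`
new_kwargs = {'device': torch.device('cpu')}
new_kwargs['quant_device'] = new_kwargs['device']
new_kwargs['device'] = torch.device('meta')

assert resolve_quant_device(new_kwargs) == torch.device('cpu')
```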
6 changes: 6 additions & 0 deletions src/brevitas/nn/quant_conv.py
@@ -75,6 +75,8 @@ def __init__(
input_quant=input_quant,
output_quant=output_quant,
return_quant_tensor=return_quant_tensor,
device=device,
dtype=dtype,
**kwargs)
self.is_same_padded_strided = is_same_padded_strided

@@ -162,6 +164,8 @@ def __init__(
input_quant=input_quant,
output_quant=output_quant,
return_quant_tensor=return_quant_tensor,
device=device,
dtype=dtype,
**kwargs)
self.is_same_padded_strided = is_same_padded_strided

@@ -251,6 +255,8 @@ def __init__(
input_quant=input_quant,
output_quant=output_quant,
return_quant_tensor=return_quant_tensor,
device=device,
dtype=dtype,
**kwargs)
self.is_same_padded_strided = is_same_padded_strided

6 changes: 6 additions & 0 deletions src/brevitas/nn/quant_convtranspose.py
@@ -74,6 +74,8 @@ def __init__(
input_quant=input_quant,
output_quant=output_quant,
return_quant_tensor=return_quant_tensor,
device=device,
dtype=dtype,
**kwargs)
self._output_size = None

@@ -169,6 +171,8 @@ def __init__(
input_quant=input_quant,
output_quant=output_quant,
return_quant_tensor=return_quant_tensor,
device=device,
dtype=dtype,
**kwargs)
self._output_size = None

@@ -264,6 +268,8 @@ def __init__(
input_quant=input_quant,
output_quant=output_quant,
return_quant_tensor=return_quant_tensor,
device=device,
dtype=dtype,
**kwargs)
self._output_size = None

3 changes: 2 additions & 1 deletion src/brevitas/nn/quant_embedding.py
@@ -47,7 +47,8 @@ def __init__(
_weight=_weight,
device=device,
dtype=dtype)
QuantWeightMixin.__init__(self, weight_quant=weight_quant, **kwargs)
QuantWeightMixin.__init__(
self, weight_quant=weight_quant, device=device, dtype=dtype, **kwargs)
self.accept_quant_tensor = False
self.return_quant_tensor = return_quant_tensor

2 changes: 2 additions & 0 deletions src/brevitas/nn/quant_linear.py
@@ -46,6 +46,8 @@ def __init__(
input_quant=input_quant,
output_quant=output_quant,
return_quant_tensor=return_quant_tensor,
device=device,
dtype=dtype,
**kwargs)

@property
71 changes: 10 additions & 61 deletions src/brevitas/nn/quant_rnn.py
@@ -61,16 +61,9 @@ def forward(self):
class GateParams(QuantBiasMixin, nn.Module):

def __init__(
self,
input_size,
hidden_size,
bias,
weight_quant,
bias_quant,
input_weight,
dtype,
device,
**kwargs):
self, input_size, hidden_size, bias, weight_quant, bias_quant, input_weight, **kwargs):
device = kwargs.get('device', None)
dtype = kwargs.get('dtype', None)
nn.Module.__init__(self)
if bias:
self.bias = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
@@ -340,8 +333,6 @@ def __init__(
shared_input_hidden_weights: bool,
return_quant_tensor: bool,
nonlinearity: str,
dtype: Optional[torch.dtype],
device: Optional[torch.device],
input_weight: GateWeight = None,
**kwargs):
nn.Module.__init__(self)
@@ -373,15 +364,7 @@ def __init__(
shared_input_hidden_weights=shared_input_hidden_weights,
return_quant_tensor=return_quant_tensor)
self.gate_params = GateParams(
input_size,
hidden_size,
bias,
weight_quant,
bias_quant,
input_weight,
dtype=dtype,
device=device,
**kwargs)
input_size, hidden_size, bias, weight_quant, bias_quant, input_weight, **kwargs)
self.reset_parameters()

@property
@@ -496,8 +479,6 @@ def __init__(
shared_cell_state_quant: bool,
shared_intra_layer_weight_quant: bool,
shared_intra_layer_gate_acc_quant: bool,
dtype: Optional[torch.dtype],
device: Optional[torch.device],
return_quant_tensor: bool,
input_input_weight: GateWeight = None,
input_forget_weight: GateWeight = None,
@@ -568,15 +549,7 @@ def __init__(
return_quant_tensor=return_quant_tensor)

self.input_gate_params = GateParams(
input_size,
hidden_size,
bias,
weight_quant,
bias_quant,
input_forget_weight,
dtype=dtype,
device=device,
**kwargs)
input_size, hidden_size, bias, weight_quant, bias_quant, input_forget_weight, **kwargs)
if shared_intra_layer_weight_quant:
# Share the input-to-hidden input weight quantizer, which is also shared with hidden-to-hidden
weight_quant = self.input_gate_params.input_weight.weight_quant
@@ -590,29 +563,11 @@ def __init__(
weight_quant,
bias_quant,
input_input_weight,
dtype=dtype,
device=device,
**kwargs)
self.cell_gate_params = GateParams(
input_size,
hidden_size,
bias,
weight_quant,
bias_quant,
input_cell_weight,
dtype=dtype,
device=device,
**kwargs)
input_size, hidden_size, bias, weight_quant, bias_quant, input_cell_weight, **kwargs)
self.output_gate_params = GateParams(
input_size,
hidden_size,
bias,
weight_quant,
bias_quant,
input_output_weight,
dtype=dtype,
device=device,
**kwargs)
input_size, hidden_size, bias, weight_quant, bias_quant, input_output_weight, **kwargs)
self.shared_cell_state_quant = shared_cell_state_quant
self.cifg = cifg
self.reset_parameters()
@@ -786,8 +741,6 @@ def __init__(
bidirectional: bool,
io_quant,
shared_input_hidden_weights: bool,
dtype: Optional[torch.dtype],
device: Optional[torch.device],
return_quant_tensor: bool,
**kwargs):
super(QuantRecurrentStackBase, self).__init__()
@@ -812,8 +765,6 @@ def __init__(
reverse_input=False,
quantize_output_only=quantize_output_only,
shared_input_hidden_weights=shared_input_hidden_weights,
dtype=dtype,
device=device,
return_quant_tensor=layer_return_quant_tensor,
**kwargs)
directions.append(left_to_right)
@@ -829,8 +780,6 @@ def __init__(
reverse_input=True,
quantize_output_only=quantize_output_only,
shared_input_hidden_weights=shared_input_hidden_weights,
dtype=dtype,
device=device,
return_quant_tensor=layer_return_quant_tensor,
**shared_weights,
**kwargs)
@@ -914,8 +863,8 @@ def __init__(
gate_acc_quant=gate_acc_quant,
shared_input_hidden_weights=shared_input_hidden_weights,
return_quant_tensor=return_quant_tensor,
dtype=dtype,
device=device,
dtype=dtype,
**kwargs)


@@ -943,8 +892,8 @@ def __init__(
shared_intra_layer_gate_acc_quant=False,
shared_cell_state_quant=True,
return_quant_tensor: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
**kwargs):
super(QuantLSTM, self).__init__(
layer_impl=_QuantLSTMLayer,
Expand All @@ -967,8 +916,8 @@ def __init__(
shared_intra_layer_gate_acc_quant=shared_intra_layer_gate_acc_quant,
shared_cell_state_quant=shared_cell_state_quant,
return_quant_tensor=return_quant_tensor,
dtype=dtype,
device=device,
dtype=dtype,
**kwargs)
if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:
raise RuntimeError("Concatenating cell states requires shared cell quantizers.")
6 changes: 5 additions & 1 deletion src/brevitas/nn/quant_sdpa.py
@@ -142,8 +142,12 @@ def __init__(
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v

special_keys = ['device', 'dtype', 'quant_dtype']

def filter_kwargs(prefix):
return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
return {
k[len(prefix):]: v for k,
v in kwargs.items() if k.startswith(prefix) or k in special_keys}

self.q_scaled_quant = QuantIdentity(act_quant=q_scaled_quant, **filter_kwargs('q_scaled_'))
self.k_transposed_quant = QuantIdentity(
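The `quant_sdpa.py` hunk widens the per-quantizer kwargs routing so that `device`, `dtype`, and `quant_dtype` reach every sub-quantizer even though they carry no prefix. Below is a rough sketch of that routing under the assumption that un-prefixed special keys should be forwarded unchanged; the `route_kwargs` name and the sample kwargs are made up for illustration, and the in-tree helper strips keys slightly differently:

```python
def route_kwargs(prefix: str, kwargs: dict) -> dict:
    """Select the kwargs meant for one sub-quantizer.

    Prefixed keys (e.g. 'q_scaled_bit_width') are forwarded with the prefix
    stripped; 'device', 'dtype' and 'quant_dtype' are forwarded as-is.
    """
    special_keys = ('device', 'dtype', 'quant_dtype')
    routed = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    routed.update({k: kwargs[k] for k in special_keys if k in kwargs})
    return routed


kwargs = {'q_scaled_bit_width': 8, 'attn_output_weights_bit_width': 8, 'device': 'cpu'}
print(route_kwargs('q_scaled_', kwargs))  # {'bit_width': 8, 'device': 'cpu'}
```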
1 change: 0 additions & 1 deletion src/brevitas/quant/solver/weight.py
@@ -104,7 +104,6 @@ class WeightQuantSolver(SolveStatsReduceDimFromEnum,
SolveParameterScalingShape,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveWeightTensorQuantFromEnum,
SolveDtypeDeviceFromTrackedParameterList,
SolveInputViewImpl):
"""
Translate enum and shape directives to weight-specific quantization core modules.