
Commit c0a4484

Temp fix

1 parent e872263 commit c0a4484

9 files changed, +56 -22 lines changed

src/brevitas/graph/base.py

Lines changed: 1 addition & 0 deletions
@@ -232,6 +232,7 @@ def init_new_module(self, old_module: Module, name: str = None, load_state_dict:
         is_assign_supported = 'assign' in inspect.signature(
             old_module.load_state_dict).parameters.keys()
         if 'device' in new_module_signature_keys and is_assign_supported and load_state_dict:
+            new_kwargs['quant_device'] = new_kwargs['device']
             new_kwargs['device'] = torch.device("meta")

         # init the new module
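
For context, a minimal sketch of the hand-off this one-line addition sets up (`new_kwargs` and `quant_device` are the names used in the hunk; the concrete device and dtype values below are illustrative, not from the commit): the real target device is stashed under `quant_device` before `device` is overwritten with `meta`, so it can be recovered later when the quantization parameters are materialized.

import torch

# Illustrative stand-in for the kwargs assembled inside init_new_module.
new_kwargs = {'device': torch.device('cuda:0'), 'dtype': torch.float16}

# Pattern added by this commit: remember the real device, then instantiate on 'meta'
# so the replacement module allocates no real weight memory during construction.
new_kwargs['quant_device'] = new_kwargs['device']
new_kwargs['device'] = torch.device('meta')
print(new_kwargs)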

src/brevitas/nn/mixin/base.py

Lines changed: 10 additions & 3 deletions
@@ -47,9 +47,16 @@ def __init__(

         # `device` and `dtype` are special keywords, propagated through no matter the prefix
         # `None` is the default value for these parameters when they are not specified
-        special_keys = ['device', 'dtype']
-        for key in special_keys:
-            filtered_kwargs[key] = kwargs.get(key, None)
+
+        # When applying quantization, if we use `meta` device to avoid memory duplication, we need
+        # to keep track of the original device for the eventual quantization parameters
+        # If not, they would be stuck in `meta` and error out
+        device = kwargs.get('device') if kwargs.get('device') not in [
+            'meta', torch.device('meta')] else kwargs.get('quant_device')
+        dtype = kwargs.get('dtype')
+        filtered_kwargs['device'] = device
+        filtered_kwargs['dtype'] = dtype
+
         if quant is None:
             quant_injector = none_quant_injector.let(**filtered_kwargs)
             quant = quant_injector.proxy_class(self, quant_injector)
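
The recovery side of the same mechanism can be sketched as a standalone helper (the function name `resolve_quant_device` is hypothetical; in the commit the logic lives inline in the mixin's `__init__` as shown above): a `meta` device is swapped back for the remembered `quant_device`, anything else passes through unchanged.

from typing import Optional

import torch


def resolve_quant_device(kwargs: dict) -> Optional[torch.device]:
    # Hypothetical helper mirroring the inline expression in the hunk above.
    device = kwargs.get('device')
    if device in ['meta', torch.device('meta')]:
        # The module itself lives on 'meta'; quant parameters need the real device.
        return kwargs.get('quant_device')
    return device


print(resolve_quant_device({'device': torch.device('meta'), 'quant_device': torch.device('cpu')}))  # cpu
print(resolve_quant_device({'device': torch.device('cpu')}))  # cpu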

src/brevitas/nn/quant_conv.py

Lines changed: 12 additions & 6 deletions
@@ -45,9 +45,9 @@ def __init__(
             input_quant: Optional[ActQuantType] = None,
             output_quant: Optional[ActQuantType] = None,
             return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        device = kwargs.get('device', None)
-        dtype = kwargs.get('dtype', None)
         # avoid an init error in the super class by setting padding to 0
         if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
                 stride, int) else any(map(lambda x: x > 1, stride))):
@@ -75,6 +75,8 @@ def __init__(
             input_quant=input_quant,
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)
         self.is_same_padded_strided = is_same_padded_strided

@@ -132,9 +134,9 @@ def __init__(
             input_quant: Optional[ActQuantType] = None,
             output_quant: Optional[ActQuantType] = None,
             return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        device = kwargs.get('device', None)
-        dtype = kwargs.get('dtype', None)
         # avoid an init error in the super class by setting padding to 0
         if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
                 stride, int) else any(map(lambda x: x > 1, stride))):
@@ -162,6 +164,8 @@ def __init__(
             input_quant=input_quant,
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)
         self.is_same_padded_strided = is_same_padded_strided

@@ -221,9 +225,9 @@ def __init__(
             input_quant: Optional[ActQuantType] = None,
             output_quant: Optional[ActQuantType] = None,
             return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        device = kwargs.get('device', None)
-        dtype = kwargs.get('dtype', None)
         # avoid an init error in the super class by setting padding to 0
         if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
                 stride, int) else any(map(lambda x: x > 1, stride))):
@@ -251,6 +255,8 @@ def __init__(
             input_quant=input_quant,
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)
         self.is_same_padded_strided = is_same_padded_strided

src/brevitas/nn/quant_convtranspose.py

Lines changed: 12 additions & 6 deletions
@@ -50,9 +50,9 @@ def __init__(
             input_quant: Optional[ActQuantType] = None,
             output_quant: Optional[ActQuantType] = None,
             return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        device = kwargs.get('device', None)
-        dtype = kwargs.get('dtype', None)
         ConvTranspose1d.__init__(
             self,
             in_channels=in_channels,
@@ -74,6 +74,8 @@ def __init__(
             input_quant=input_quant,
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)
         self._output_size = None

@@ -145,9 +147,9 @@ def __init__(
             input_quant: Optional[ActQuantType] = None,
             output_quant: Optional[ActQuantType] = None,
             return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        device = kwargs.get('device', None)
-        dtype = kwargs.get('dtype', None)
         ConvTranspose2d.__init__(
             self,
             in_channels=in_channels,
@@ -169,6 +171,8 @@ def __init__(
             input_quant=input_quant,
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)
         self._output_size = None

@@ -240,9 +244,9 @@ def __init__(
             input_quant: Optional[ActQuantType] = None,
             output_quant: Optional[ActQuantType] = None,
             return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        device = kwargs.get('device', None)
-        dtype = kwargs.get('dtype', None)
         ConvTranspose3d.__init__(
             self,
             in_channels=in_channels,
@@ -264,6 +268,8 @@ def __init__(
             input_quant=input_quant,
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)
         self._output_size = None

src/brevitas/nn/quant_embedding.py

Lines changed: 4 additions & 3 deletions
@@ -32,9 +32,9 @@ def __init__(
             _weight: Optional[Tensor] = None,
             weight_quant: WeightQuantType = Int8WeightPerTensorFloat,
             return_quant_tensor=False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        device = kwargs.get('device', None)
-        dtype = kwargs.get('dtype', None)
         Embedding.__init__(
             self,
             num_embeddings=num_embeddings,
@@ -47,7 +47,8 @@ def __init__(
             _weight=_weight,
             device=device,
             dtype=dtype)
-        QuantWeightMixin.__init__(self, weight_quant=weight_quant, **kwargs)
+        QuantWeightMixin.__init__(
+            self, weight_quant=weight_quant, device=device, dtype=dtype, **kwargs)
         self.accept_quant_tensor = False
         self.return_quant_tensor = return_quant_tensor

src/brevitas/nn/quant_linear.py

Lines changed: 4 additions & 2 deletions
@@ -35,9 +35,9 @@ def __init__(
             input_quant: Optional[ActQuantType] = None,
             output_quant: Optional[ActQuantType] = None,
             return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        device = kwargs.get('device', None)
-        dtype = kwargs.get('dtype', None)
         Linear.__init__(self, in_features, out_features, bias, device=device, dtype=dtype)
         QuantWBIOL.__init__(
             self,
@@ -46,6 +46,8 @@ def __init__(
             input_quant=input_quant,
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)

     @property
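
With `device` and `dtype` promoted from `kwargs.get(...)` to explicit keyword arguments, a call such as the following (a hypothetical usage sketch, not part of the commit) forwards them both to `Linear.__init__` and to `QuantWBIOL.__init__`, and from there to the quantizer injector:

import torch

from brevitas.nn import QuantLinear

# Hypothetical usage: device and dtype now reach the float layer and the quant machinery alike.
layer = QuantLinear(16, 32, bias=True, device=torch.device('cpu'), dtype=torch.float32)
print(layer.weight.device, layer.weight.dtype)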

src/brevitas/nn/quant_rnn.py

Lines changed: 8 additions & 0 deletions
@@ -845,6 +845,8 @@ def __init__(
             gate_acc_quant=Int8ActPerTensorFloat,
             shared_input_hidden_weights=False,
             return_quant_tensor: bool = False,
+            dtype: Optional[torch.dtype] = None,
+            device: Optional[torch.device] = None,
             **kwargs):
         super(QuantRNN, self).__init__(
             layer_impl=_QuantRNNLayer,
@@ -861,6 +863,8 @@ def __init__(
             gate_acc_quant=gate_acc_quant,
             shared_input_hidden_weights=shared_input_hidden_weights,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)


@@ -888,6 +892,8 @@ def __init__(
             shared_intra_layer_gate_acc_quant=False,
             shared_cell_state_quant=True,
             return_quant_tensor: bool = False,
+            dtype: Optional[torch.dtype] = None,
+            device: Optional[torch.device] = None,
             **kwargs):
         super(QuantLSTM, self).__init__(
             layer_impl=_QuantLSTMLayer,
@@ -910,6 +916,8 @@ def __init__(
             shared_intra_layer_gate_acc_quant=shared_intra_layer_gate_acc_quant,
             shared_cell_state_quant=shared_cell_state_quant,
             return_quant_tensor=return_quant_tensor,
+            device=device,
+            dtype=dtype,
             **kwargs)
         if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:
             raise RuntimeError("Concatenating cell states requires shared cell quantizers.")

src/brevitas/nn/quant_sdpa.py

Lines changed: 5 additions & 1 deletion
@@ -142,8 +142,12 @@ def __init__(
         self.pre_process_k = pre_process_k
         self.pre_process_v = pre_process_v

+        special_keys = ['device', 'dtype', 'quant_dtype']
+
         def filter_kwargs(prefix):
-            return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
+            return {
+                k[len(prefix):]: v for k,
+                v in kwargs.items() if k.startswith(prefix) or k in special_keys}

         self.q_scaled_quant = QuantIdentity(act_quant=q_scaled_quant, **filter_kwargs('q_scaled_'))
         self.k_transposed_quant = QuantIdentity(
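
A self-contained sketch of the updated filtering (the kwargs dict below is hypothetical; `filter_kwargs` and `special_keys` are the names from the hunk): prefixed keys are stripped of their prefix as before, and keys listed in `special_keys` now pass the filter for every sub-quantizer as well, although they too are sliced by `len(prefix)`.

special_keys = ['device', 'dtype', 'quant_dtype']

# Hypothetical kwargs a caller might pass to the quantized SDPA layer.
kwargs = {'q_scaled_bit_width': 8, 'device': 'cpu', 'v_bit_width': 4}


def filter_kwargs(prefix):
    # Same logic as the hunk above: keep keys carrying the prefix or listed as special,
    # then drop the first len(prefix) characters of every surviving key.
    return {
        k[len(prefix):]: v for k, v in kwargs.items()
        if k.startswith(prefix) or k in special_keys}


print(filter_kwargs('q_scaled_'))
# -> {'bit_width': 8, '': 'cpu'}  (the special key's name is shortened by the slice too)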

src/brevitas/quant/solver/weight.py

Lines changed: 0 additions & 1 deletion
@@ -104,7 +104,6 @@ class WeightQuantSolver(SolveStatsReduceDimFromEnum,
                         SolveParameterScalingShape,
                         SolveWeightScalingPerOutputChannelShapeFromModule,
                         SolveWeightTensorQuantFromEnum,
-                        SolveDtypeDeviceFromTrackedParameterList,
                         SolveInputViewImpl):
     """
     Translate enum and shape directives to weight-specific quantization core modules.
