@@ -45,9 +45,9 @@ def __init__(
4545 input_quant : Optional [ActQuantType ] = None ,
4646 output_quant : Optional [ActQuantType ] = None ,
4747 return_quant_tensor : bool = False ,
48+ device : Optional [torch .device ] = None ,
49+ dtype : Optional [torch .dtype ] = None ,
4850 ** kwargs ) -> None :
49- device = kwargs .get ('device' , None )
50- dtype = kwargs .get ('dtype' , None )
5151 # avoid an init error in the super class by setting padding to 0
5252 if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance (
5353 stride , int ) else any (map (lambda x : x > 1 , stride ))):
@@ -75,6 +75,8 @@ def __init__(
7575 input_quant = input_quant ,
7676 output_quant = output_quant ,
7777 return_quant_tensor = return_quant_tensor ,
78+ device = device ,
79+ dtype = dtype ,
7880 ** kwargs )
7981 self .is_same_padded_strided = is_same_padded_strided
8082
@@ -132,9 +134,9 @@ def __init__(
132134 input_quant : Optional [ActQuantType ] = None ,
133135 output_quant : Optional [ActQuantType ] = None ,
134136 return_quant_tensor : bool = False ,
137+ device : Optional [torch .device ] = None ,
138+ dtype : Optional [torch .dtype ] = None ,
135139 ** kwargs ) -> None :
136- device = kwargs .get ('device' , None )
137- dtype = kwargs .get ('dtype' , None )
138140 # avoid an init error in the super class by setting padding to 0
139141 if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance (
140142 stride , int ) else any (map (lambda x : x > 1 , stride ))):
@@ -162,6 +164,8 @@ def __init__(
162164 input_quant = input_quant ,
163165 output_quant = output_quant ,
164166 return_quant_tensor = return_quant_tensor ,
167+ device = device ,
168+ dtype = dtype ,
165169 ** kwargs )
166170 self .is_same_padded_strided = is_same_padded_strided
167171
@@ -221,9 +225,9 @@ def __init__(
221225 input_quant : Optional [ActQuantType ] = None ,
222226 output_quant : Optional [ActQuantType ] = None ,
223227 return_quant_tensor : bool = False ,
228+ device : Optional [torch .device ] = None ,
229+ dtype : Optional [torch .dtype ] = None ,
224230 ** kwargs ) -> None :
225- device = kwargs .get ('device' , None )
226- dtype = kwargs .get ('dtype' , None )
227231 # avoid an init error in the super class by setting padding to 0
228232 if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance (
229233 stride , int ) else any (map (lambda x : x > 1 , stride ))):
@@ -251,6 +255,8 @@ def __init__(
251255 input_quant = input_quant ,
252256 output_quant = output_quant ,
253257 return_quant_tensor = return_quant_tensor ,
258+ device = device ,
259+ dtype = dtype ,
254260 ** kwargs )
255261 self .is_same_padded_strided = is_same_padded_strided
256262
0 commit comments