2323
2424from tico .quantization .algorithm .fpi_gptq .util import iterate_GPTQ
2525
26+
2627def quantize (x , scale , zero , maxq ):
2728 if maxq < 0 :
2829 return (x > scale / 2 ).float () * scale + (x < zero / 2 ).float () * zero
@@ -62,11 +63,11 @@ def configure(
6263
6364 def _prepare_tensor (self , x , weight = False ):
6465 """Prepare tensor for quantization by flattening according to per-channel setting.
65-
66+
6667 Args:
6768 x: Input tensor to prepare
6869 weight: Whether the tensor is a weight (affects flattening for activations)
69-
70+
7071 Returns:
7172 Tuple of (prepared tensor, original shape)
7273 """
@@ -88,10 +89,10 @@ def _prepare_tensor(self, x, weight=False):
8889
8990 def _compute_scale_zero_bounds (self , x ):
9091 """Compute scale and zero bounds from tensor values.
91-
92+
9293 Args:
9394 x: Prepared tensor (flattened according to per-channel setting)
94-
95+
9596 Returns:
9697 Tuple of (scale, zero, xmin, xmax) computed from tensor bounds
9798 """
@@ -123,17 +124,17 @@ def _compute_scale_zero_bounds(self, x):
123124
124125 def _reshape_scale_zero (self , shape , weight = False ):
125126 """Reshape scale and zero tensors according to the original tensor shape.
126-
127+
127128 Args:
128129 shape: Original tensor shape before preparation
129130 weight: Whether the tensor is a weight (affects reshape for activations)
130131 """
131132 if weight :
132133 shape = [- 1 ] + [1 ] * (len (shape ) - 1 )
133- self .scale = self .scale .reshape (shape )
134- self .zero = self .zero .reshape (shape )
134+ self .scale = self .scale .reshape (shape ) # type: ignore[has-type]
135+ self .zero = self .zero .reshape (shape ) # type: ignore[has-type]
135136 return
136-
137+
137138 if len (shape ) == 4 :
138139 self .scale = self .scale .reshape ((1 , - 1 , 1 , 1 ))
139140 self .zero = self .zero .reshape ((1 , - 1 , 1 , 1 ))
@@ -146,7 +147,7 @@ def _reshape_scale_zero(self, shape, weight=False):
146147
147148 def _expand_for_per_tensor (self , shape , weight = False ):
148149 """Expand scale and zero for per-tensor quantization.
149-
150+
150151 Args:
151152 shape: Original tensor shape before preparation
152153 weight: Whether the tensor is a weight
@@ -169,20 +170,24 @@ def find_params(self, x, weight=False):
169170
170171 self .scale , self .zero , xmin , xmax = self ._compute_scale_zero_bounds (x )
171172
172- if self .mse is not None and self .mse != "smse_for_gptq" and self .mse != "mse_for_gptq" :
173+ if (
174+ self .mse is not None
175+ and self .mse != "smse_for_gptq"
176+ and self .mse != "mse_for_gptq"
177+ ):
173178 self ._optimize_mse (x , xmin , xmax )
174179
175180 self ._expand_for_per_tensor (shape , weight )
176181 self ._reshape_scale_zero (shape , weight )
177182
178183 def _compute_shrink_params (self , p , xmin , xmax ):
179184 """Compute scale and zero for a shrink factor p.
180-
185+
181186 Args:
182187 p: Shrink factor (1 - i / grid)
183188 xmin: Minimum values per channel
184189 xmax: Maximum values per channel
185-
190+
186191 Returns:
187192 Tuple of (scale1, zero1) for the given shrink factor
188193 """
@@ -194,13 +199,13 @@ def _compute_shrink_params(self, p, xmin, xmax):
194199
195200 def _update_best_params (self , best , err , scale1 , zero1 ):
196201 """Update best parameters if current error is lower.
197-
202+
198203 Args:
199204 best: Current best error values
200205 err: Current iteration error values
201206 scale1: Current iteration scale values
202207 zero1: Current iteration zero values
203-
208+
204209 Returns:
205210 Updated best error values
206211 """
@@ -213,7 +218,7 @@ def _update_best_params(self, best, err, scale1, zero1):
213218
214219 def _grid_search (self , x , xmin , xmax , compute_error_fn ):
215220 """Common grid search loop for MSE optimization.
216-
221+
217222 Args:
218223 x: Prepared tensor
219224 xmin: Minimum values per channel
@@ -230,25 +235,28 @@ def _grid_search(self, x, xmin, xmax, compute_error_fn):
230235
231236 def _optimize_mse (self , x , xmin , xmax ):
232237 """Optimize scale and zero using MSE-based grid search.
233-
238+
234239 Args:
235240 x: Prepared tensor
236241 xmin: Minimum values per channel
237242 xmax: Maximum values per channel
238243 """
244+
239245 def compute_error (x , scale1 , zero1 ):
240246 q = quantize (x , scale1 .unsqueeze (1 ), zero1 .unsqueeze (1 ), self .maxq )
241247 q -= x
242248 q .abs_ ()
243249 if self .mse == "smse" : # sensitivity weighted mse
244250 # in case sensitivity is a second order derivatives of some global loss
245251 # (q**2) * self.sensitivity is just a global loss change due to quantization.
246- q = (q ** 2 ) * self .sensitivity .to (q .device ) # estimate global target change
252+ q = (q ** 2 ) * self .sensitivity .to (
253+ q .device
254+ ) # estimate global target change
247255 else :
248256 assert self .mse == "mse"
249257 q .pow_ (self .norm )
250258 return torch .sum (q , 1 )
251-
259+
252260 self ._grid_search (x , xmin , xmax , compute_error )
253261
254262 def update (self , x , Hinv , perm ):
@@ -269,13 +277,13 @@ def update(self, x, Hinv, perm):
269277 self ._optimize_mse_for_gptq (x , Hinv , sensitivity , xmin , xmax )
270278
271279 self ._reshape_scale_zero (shape , weight = True )
272-
280+
273281 del sensitivity
274282 sensitivity = None
275283
276284 def _optimize_mse_for_gptq (self , x , Hinv , sensitivity , xmin , xmax ):
277285 """Optimize scale and zero using GPTQ-aware MSE grid search.
278-
286+
279287 Args:
280288 x: Prepared tensor
281289 Hinv: Inverse Hessian matrix
@@ -284,7 +292,7 @@ def _optimize_mse_for_gptq(self, x, Hinv, sensitivity, xmin, xmax):
284292 xmax: Maximum values per channel
285293 """
286294 num_of_iters = 15
287-
295+
288296 def compute_error (x , scale1 , zero1 ):
289297 q , _ = iterate_GPTQ (
290298 scale1 .unsqueeze (1 ),
@@ -298,7 +306,7 @@ def compute_error(x, scale1, zero1):
298306 assert self .mse == "smse_for_gptq"
299307 err = ((q - x ) ** 2 ) * sensitivity .to (q .device )
300308 return torch .sum (err , 1 )
301-
309+
302310 self ._grid_search (x , xmin , xmax , compute_error )
303311
304312 def quantize (self , x ):
0 commit comments