Skip to content

Commit 428e5ee

Browse files
committed
format
TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 7e9ecf7 commit 428e5ee

7 files changed

Lines changed: 46 additions & 36 deletions

File tree

tico/quantization/algorithm/fpi_gptq/fpi_gptq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
)
3333

3434
from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
35-
from tico.quantization.algorithm.fpi_gptq.util import quantize, iterate_GPTQ
35+
from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ, quantize
36+
3637

3738
class FPI_GPTQ:
3839
def __init__(self, layer):

tico/quantization/algorithm/fpi_gptq/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import torch
2222

23+
2324
def quantize(x, scale, zero, maxq):
2425
if maxq < 0:
2526
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
@@ -49,7 +50,7 @@ def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50):
4950

5051
if torch.cuda.is_available():
5152
torch.cuda.empty_cache()
52-
53+
5354
cur_Q = quantize(cur_weights, scale, zero, maxq)
5455

5556
return cur_Q, cur_weights

tico/quantization/algorithm/gptq/gptq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,9 @@ def fasterquant(
360360
H = torch.cholesky_inverse(H)
361361
H = torch.linalg.cholesky(H, upper=True).float()
362362
Hinv = H
363-
363+
364364
self.quantizer.update(W, Hinv, perm)
365-
365+
366366
assert isinstance(Hinv, torch.Tensor)
367367
for i1 in range(0, self.columns, blocksize):
368368
i2 = min(i1 + blocksize, self.columns)

tico/quantization/algorithm/gptq/quant.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ
2525

26+
2627
def quantize(x, scale, zero, maxq):
2728
if maxq < 0:
2829
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
@@ -62,11 +63,11 @@ def configure(
6263

6364
def _prepare_tensor(self, x, weight=False):
6465
"""Prepare tensor for quantization by flattening according to per-channel setting.
65-
66+
6667
Args:
6768
x: Input tensor to prepare
6869
weight: Whether the tensor is a weight (affects flattening for activations)
69-
70+
7071
Returns:
7172
Tuple of (prepared tensor, original shape)
7273
"""
@@ -88,10 +89,10 @@ def _prepare_tensor(self, x, weight=False):
8889

8990
def _compute_scale_zero_bounds(self, x):
9091
"""Compute scale and zero bounds from tensor values.
91-
92+
9293
Args:
9394
x: Prepared tensor (flattened according to per-channel setting)
94-
95+
9596
Returns:
9697
Tuple of (scale, zero, xmin, xmax) computed from tensor bounds
9798
"""
@@ -123,17 +124,17 @@ def _compute_scale_zero_bounds(self, x):
123124

124125
def _reshape_scale_zero(self, shape, weight=False):
125126
"""Reshape scale and zero tensors according to the original tensor shape.
126-
127+
127128
Args:
128129
shape: Original tensor shape before preparation
129130
weight: Whether the tensor is a weight (affects reshape for activations)
130131
"""
131132
if weight:
132133
shape = [-1] + [1] * (len(shape) - 1)
133-
self.scale = self.scale.reshape(shape)
134-
self.zero = self.zero.reshape(shape)
134+
self.scale = self.scale.reshape(shape) # type: ignore[has-type]
135+
self.zero = self.zero.reshape(shape) # type: ignore[has-type]
135136
return
136-
137+
137138
if len(shape) == 4:
138139
self.scale = self.scale.reshape((1, -1, 1, 1))
139140
self.zero = self.zero.reshape((1, -1, 1, 1))
@@ -146,7 +147,7 @@ def _reshape_scale_zero(self, shape, weight=False):
146147

147148
def _expand_for_per_tensor(self, shape, weight=False):
148149
"""Expand scale and zero for per-tensor quantization.
149-
150+
150151
Args:
151152
shape: Original tensor shape before preparation
152153
weight: Whether the tensor is a weight
@@ -169,20 +170,24 @@ def find_params(self, x, weight=False):
169170

170171
self.scale, self.zero, xmin, xmax = self._compute_scale_zero_bounds(x)
171172

172-
if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq":
173+
if (
174+
self.mse is not None
175+
and self.mse != "smse_for_gptq"
176+
and self.mse != "mse_for_gptq"
177+
):
173178
self._optimize_mse(x, xmin, xmax)
174179

175180
self._expand_for_per_tensor(shape, weight)
176181
self._reshape_scale_zero(shape, weight)
177182

178183
def _compute_shrink_params(self, p, xmin, xmax):
179184
"""Compute scale and zero for a shrink factor p.
180-
185+
181186
Args:
182187
p: Shrink factor (1 - i / grid)
183188
xmin: Minimum values per channel
184189
xmax: Maximum values per channel
185-
190+
186191
Returns:
187192
Tuple of (scale1, zero1) for the given shrink factor
188193
"""
@@ -194,13 +199,13 @@ def _compute_shrink_params(self, p, xmin, xmax):
194199

195200
def _update_best_params(self, best, err, scale1, zero1):
196201
"""Update best parameters if current error is lower.
197-
202+
198203
Args:
199204
best: Current best error values
200205
err: Current iteration error values
201206
scale1: Current iteration scale values
202207
zero1: Current iteration zero values
203-
208+
204209
Returns:
205210
Updated best error values
206211
"""
@@ -213,7 +218,7 @@ def _update_best_params(self, best, err, scale1, zero1):
213218

214219
def _grid_search(self, x, xmin, xmax, compute_error_fn):
215220
"""Common grid search loop for MSE optimization.
216-
221+
217222
Args:
218223
x: Prepared tensor
219224
xmin: Minimum values per channel
@@ -230,25 +235,28 @@ def _grid_search(self, x, xmin, xmax, compute_error_fn):
230235

231236
def _optimize_mse(self, x, xmin, xmax):
232237
"""Optimize scale and zero using MSE-based grid search.
233-
238+
234239
Args:
235240
x: Prepared tensor
236241
xmin: Minimum values per channel
237242
xmax: Maximum values per channel
238243
"""
244+
239245
def compute_error(x, scale1, zero1):
240246
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
241247
q -= x
242248
q.abs_()
243249
if self.mse == "smse": # sensitivity weighted mse
244250
# in case sensitivity holds second-order derivatives of some global loss,
245251
# (q**2) * self.sensitivity is just a global loss change due to quantization.
246-
q = (q**2) * self.sensitivity.to(q.device) # estimate global target change
252+
q = (q**2) * self.sensitivity.to(
253+
q.device
254+
) # estimate global target change
247255
else:
248256
assert self.mse == "mse"
249257
q.pow_(self.norm)
250258
return torch.sum(q, 1)
251-
259+
252260
self._grid_search(x, xmin, xmax, compute_error)
253261

254262
def update(self, x, Hinv, perm):
@@ -269,13 +277,13 @@ def update(self, x, Hinv, perm):
269277
self._optimize_mse_for_gptq(x, Hinv, sensitivity, xmin, xmax)
270278

271279
self._reshape_scale_zero(shape, weight=True)
272-
280+
273281
del sensitivity
274282
sensitivity = None
275283

276284
def _optimize_mse_for_gptq(self, x, Hinv, sensitivity, xmin, xmax):
277285
"""Optimize scale and zero using GPTQ-aware MSE grid search.
278-
286+
279287
Args:
280288
x: Prepared tensor
281289
Hinv: Inverse Hessian matrix
@@ -284,7 +292,7 @@ def _optimize_mse_for_gptq(self, x, Hinv, sensitivity, xmin, xmax):
284292
xmax: Maximum values per channel
285293
"""
286294
num_of_iters = 15
287-
295+
288296
def compute_error(x, scale1, zero1):
289297
q, _ = iterate_GPTQ(
290298
scale1.unsqueeze(1),
@@ -298,7 +306,7 @@ def compute_error(x, scale1, zero1):
298306
assert self.mse == "smse_for_gptq"
299307
err = ((q - x) ** 2) * sensitivity.to(q.device)
300308
return torch.sum(err, 1)
301-
309+
302310
self._grid_search(x, xmin, xmax, compute_error)
303311

304312
def quantize(self, x):

tico/quantization/algorithm/gptq/quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def _hook(_, inp, out):
478478
model.lm_head = model.lm_head.to(old_device)
479479
if torch.cuda.is_available():
480480
torch.cuda.empty_cache()
481-
481+
482482
device = next(layer.parameters()).device # in case lm_head is located on cpu
483483
for batch_idx in tqdm(
484484
range(batch_num),

tico/quantization/algorithm/gptq/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def compute_sensitivity_info(self):
223223
if model.device.type != "cpu":
224224
torch.cuda.synchronize()
225225
torch.cuda.empty_cache()
226-
226+
227227
model = model.to(dtype)
228228

229229
return sensitivity

tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def _print_sample(title, items):
370370
_print_sample("unused GPTQ entries", unused)
371371

372372

373-
def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"):
373+
def evaluate_ppl_of_model_on_dataset(model, dataset, device):
374374
if hasattr(model, "device") and model.device.type != device.type:
375375
if hasattr(model, "to"):
376376
model.to(device)
@@ -415,6 +415,7 @@ def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"):
415415
ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
416416
return ppl
417417

418+
418419
# -------------------------------------------------------------------------
419420
# Helper — clear gptq quantizers after injection
420421
# -------------------------------------------------------------------------
@@ -1349,12 +1350,11 @@ def main():
13491350

13501351
calib_inputs = build_calibration_inputs(model, tokenizer, args, device)
13511352
train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset(
1352-
model, calib_inputs, device=device
1353-
)
1353+
model, calib_inputs, device=device
1354+
)
13541355
print("\n┌── Wikitext-2 train perplexity ─────────────")
13551356
print(f"│ FP32 : {train_ppl_ioqdtype:8.2f}")
13561357
print("└───────────────────────────────────────────")
1357-
13581358

13591359
model = apply_spinquant(model, args)
13601360
model = apply_cle(model, args)
@@ -1363,14 +1363,14 @@ def main():
13631363
q_m = quantize_using_PTQ(model, calib_inputs, args)
13641364

13651365
evaluate(q_m, tokenizer, dataset_test, args)
1366-
1366+
13671367
train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset(
1368-
q_m, calib_inputs, device=device
1369-
)
1368+
q_m, calib_inputs, device=device
1369+
)
13701370
print("\n┌── Wikitext-2 train perplexity ─────────────")
13711371
print(f"│ int16 : {train_ppl_ioqdtype:8.2f}")
13721372
print("└───────────────────────────────────────────")
1373-
1373+
13741374
save_requested_artifacts(q_m, tokenizer, calib_inputs, args)
13751375

13761376

0 commit comments

Comments
 (0)