-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgptq_v5.py
More file actions
1271 lines (1076 loc) · 48.6 KB
/
gptq_v5.py
File metadata and controls
1271 lines (1076 loc) · 48.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
# Based on original gptq algorithm and code from https://github.com/IST-DASLab/gptq
import contextlib
import math
import os
import sys
import threading
import time
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import transformers
from torch.nn.modules.conv import _ConvNd
from ..looper.named_module import NamedModule
from ..quantization import QuantizeConfig
from ..utils.device import get_device
from ..utils.logger import setup_logger
from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm
from .quantizer import HF_OPTIMUM, Quantizer
# Module-wide logger obtained from the project's setup_logger() helper.
log = setup_logger()
# Module-level lock.
# NOTE(review): no use of this lock is visible in this part of the file
# (_lease_workspace binds its own local variable named `lock`) - confirm where it is needed.
lock = threading.Lock()
# Shared workspaces are cached globally per device so that concurrent GPTQ
# instances reuse temporary buffers instead of repeatedly allocating large
# tensors during Hessian accumulation. Each device retains at most a single
# workspace; when size or dtype requirements change, the prior buffer is
# discarded to avoid unbounded cache growth.
# Keys are the (device type, device index) pairs produced by _device_cache_key.
_WORKSPACE_CACHE: Dict[Tuple[str, Optional[int]], torch.Tensor] = {}
# One lock per device key; _lease_workspace holds it while reading or writing that key's cache entry.
_WORKSPACE_LOCKS: Dict[Tuple[str, Optional[int]], threading.Lock] = {}
# Memoized results of the bfloat16 probe performed by _device_supports_bfloat16.
_BF16_SUPPORT_CACHE: Dict[Tuple[str, Optional[int]], bool] = {}
def _device_cache_key(device: torch.device) -> Tuple[str, Optional[int]]:
dev = torch.device(device)
return dev.type, dev.index
def _workspace_cache_key(device: torch.device) -> Tuple[str, Optional[int]]:
    """Return the key under which ``device``'s shared workspace is cached.

    This simply delegates to ``_device_cache_key``.
    """
    cache_key = _device_cache_key(device)
    return cache_key
def _needs_workspace_resize(
workspace: Optional[torch.Tensor],
dtype: torch.dtype,
required_rows: int,
cols: int,
) -> bool:
if workspace is None:
return True
if workspace.ndim != 2:
return True
if workspace.dtype != dtype:
return True
if workspace.shape[1] != cols:
return True
if workspace.shape[0] < required_rows:
return True
return False
@contextlib.contextmanager
def _lease_workspace(device: torch.device, dtype: torch.dtype, cols: int, required_rows: int):
    """Lend out the cached scratch tensor for ``device`` for the duration of a ``with`` block.

    The cached tensor is removed from ``_WORKSPACE_CACHE`` while it is on loan.
    If there is none, or ``_needs_workspace_resize`` rejects it, a fresh
    uninitialised ``(max(required_rows, 1), cols)`` tensor is allocated instead.
    When the block exits, normally or through an exception, the tensor that was
    lent out is stored back under the device's key.

    The per-device lock is held only while the cache entry is taken or put
    back, not while the caller is using the workspace.

    Yields:
        A 2-D tensor of ``dtype`` on ``device`` with ``cols`` columns and at
        least ``required_rows`` rows; its contents are unspecified.
    """
    cache_key = _workspace_cache_key(device)
    device_lock = _WORKSPACE_LOCKS.setdefault(cache_key, threading.Lock())

    with device_lock:
        buffer = _WORKSPACE_CACHE.pop(cache_key, None)
        if _needs_workspace_resize(buffer, dtype, required_rows, cols):
            row_count = max(required_rows, 1)
            buffer = torch.empty((row_count, cols), dtype=dtype, device=device)

    try:
        yield buffer
    finally:
        with device_lock:
            _WORKSPACE_CACHE[cache_key] = buffer
def _device_supports_bfloat16(device: torch.device) -> bool:
    """Report whether a bfloat16 matmul can be executed on ``device``.

    The answer is memoized in ``_BF16_SUPPORT_CACHE`` per device key. ``meta``
    devices are reported as unsupported without probing; every other device is
    probed once with a 1x1 bfloat16 matmul, and any exception raised by the
    probe is treated as "not supported".
    """
    key = _device_cache_key(device)
    known = _BF16_SUPPORT_CACHE.get(key)
    if known is not None:
        return known

    target = torch.device(device)
    if target.type == "meta":
        supported = False
    else:
        try:
            lhs = torch.zeros((1, 1), dtype=torch.bfloat16, device=target)
            rhs = torch.zeros((1, 1), dtype=torch.bfloat16, device=target)
            torch.matmul(lhs, rhs)
        except Exception:
            # Deliberately broad: any failure of the probe means "do not use bf16 here".
            supported = False
        else:
            supported = True

    _BF16_SUPPORT_CACHE[key] = supported
    return supported
def get_number_of_rows_and_cols(layer: nn.Module):
    """Return ``(rows, columns)`` of ``layer``'s weight in (n_out, n_in) orientation.

    A ``NamedModule`` is unwrapped to its ``.module`` first. For
    ``transformers.Conv1D`` the two weight dimensions are swapped; for every
    other layer the trailing weight dimensions are multiplied together to give
    the column count.
    """
    target = layer.module if isinstance(layer, NamedModule) else layer
    weight_shape = target.weight.shape

    if isinstance(target, transformers.Conv1D):
        # transformers.Conv1D stores its weight as (n_in, n_out).
        return weight_shape[1], weight_shape[0]

    # Everything else stores (n_out, n_in, ...).
    return weight_shape[0], np.prod(weight_shape[1:])
class GPTQ:
def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
    """Set up per-module GPTQ state for ``module``.

    Args:
        module: Layer to quantize. A ``NamedModule`` is unwrapped to its inner
            ``.module``; anything else is used directly and named ``HF_OPTIMUM``.
        qcfg: Quantization settings. A default ``QuantizeConfig()`` is used when
            this is ``None`` (or otherwise falsy).

    Raises:
        AssertionError: From ``_validate_module`` when the unwrapped module is
            not one of the supported layer types.
    """
    # Per-instance lock; add_batch and _materialize_global_hessian take it
    # while touching the Hessian accumulators.
    self.lock = threading.Lock()

    self.rows, self.columns = get_number_of_rows_and_cols(module)
    if isinstance(module, NamedModule):
        self.module = module.module
        self.name = module.name
        self._named_module = module
        self.layer_index = getattr(module, "layer_index", -1)
        self.full_name = getattr(module, "full_name", self.name)
    else:
        self.name = HF_OPTIMUM
        self.module = module
        self._named_module = None
        self.layer_index = -1
        self.full_name = self.name

    # Remember the unpadded dimensions before any padding columns are added below.
    self._original_rows = self.rows
    self._original_columns = self.columns

    # Padding metadata comes from the NamedModule state when available,
    # otherwise from a `_tp_pad_info` attribute on the raw module.
    if self._named_module is not None:
        pad_info = self._named_module.state.get("tp_pad_info")
    else:
        pad_info = getattr(self.module, "_tp_pad_info", None)

    if isinstance(pad_info, dict):
        # Missing, None or negative `pad_cols` all collapse to 0.
        pad_cols = int(pad_info.get("pad_cols", 0) or 0)
        pad_cols = max(pad_cols, 0)
    else:
        pad_info = None
        pad_cols = 0

    self._tp_pad_info = pad_info
    self._tp_pad_cols = pad_cols
    if self._tp_pad_cols:
        # From here on `self.columns` includes the zero padding columns.
        self.columns += self._tp_pad_cols

    module_device = get_device(self.module)
    # Stored on the module itself; quantize() reads it back as the Hessian target device.
    setattr(self.module, "target_device", module_device)
    self.device = module_device

    # A module on the meta device cannot hold the final Hessian, so hint CPU instead.
    if module_device.type == "meta":
        self._final_hessian_device_hint = torch.device("cpu")
    else:
        self._final_hessian_device_hint = torch.device(module_device)

    self._validate_module(self.module)

    self.qcfg = qcfg if qcfg else QuantizeConfig()  # HF compat will not pass qcfg

    self.module_copy = None

    self.H: Optional[torch.Tensor] = None
    self.nsamples = 0

    self.quantizer = self.create_quantizer(name=self.name)

    # fwd counter
    self.fwd_counter = 0

    self.fail_safe = False

    # Store per-device Hessian contributions so multi-GPU calibration can
    # keep local accumulators and merge only once when quantization begins.
    self._device_hessian_partials: Dict[torch.device, torch.Tensor] = {}
    self._device_sample_counts: Dict[torch.device, int] = {}
    self._hessian_dirty: bool = False

    # ===== Fair-GPTQ debiasing state =====
    # Every option is read with getattr so a config lacking the field falls back to "off".
    self.alpha: float = float(getattr(self.qcfg, "alpha", 0.0))
    self.log_delta_w: bool = bool(getattr(self.qcfg, "log_delta_w", False))
    self.log_hessian_sensitivity: bool = bool(
        getattr(self.qcfg, "log_hessian_sensitivity", False)
    )
    self.sum_hessians: int = int(getattr(self.qcfg, "sum_hessians", 0))
    self.nsamples_debias: int = 0
    self.H_x01: Optional[torch.Tensor] = None

    # `ignore_bias` is a comma-separated list of substrings; a match against
    # full_name switches debiasing off for this module.
    ignored_bias = False
    ignore_bias_cfg = getattr(self.qcfg, "ignore_bias", None)
    if ignore_bias_cfg:
        for name_unbias in ignore_bias_cfg.split(","):
            name_unbias = name_unbias.strip()
            if name_unbias and name_unbias in self.full_name:
                self.alpha = 0.0
                ignored_bias = True

    # `select_layer` is a comma-separated list of integer layer indices to debias.
    found_layer = False
    select_layer_cfg = getattr(self.qcfg, "select_layer", None)
    if select_layer_cfg:
        for layer_index_i in select_layer_cfg.split(","):
            layer_index_i = layer_index_i.strip()
            if not layer_index_i:
                continue
            if int(self.layer_index) == int(layer_index_i):
                if self.alpha != 0.0 and not ignored_bias:
                    found_layer = True
                    self.alpha = float(getattr(self.qcfg, "alpha", self.alpha))
        # NOTE(review): the nesting of this reset is a reconstruction. As written it
        # only runs when `select_layer` is configured; confirm it should not also
        # zero alpha when no layer selection is given.
        if not found_layer:
            # if layer is not in selected set
            self.alpha = 0.0

    if self.alpha:
        log.info(f"Selected Layer for debiasing: {self.full_name}")
        # NOTE(review): nesting reconstructed; this message may have been meant to be
        # logged independently of alpha - confirm.
        if self.sum_hessians:
            log.info("Using H2inv Hessian for deltaW.")
    # ===== end Fair-GPTQ debiasing state =====
@staticmethod
def _validate_module(module):
    """Assert that ``module`` is one of the layer types this class can quantize.

    Raises:
        AssertionError: If ``module`` is not an ``nn.Linear``, ``nn.Conv1d``,
            ``nn.Conv2d`` or ``transformers.Conv1D``.
    """
    supported_types = (nn.Linear, nn.Conv1d, nn.Conv2d, transformers.Conv1D)
    assert isinstance(module, supported_types), (
        f"We support only linear and convolutional layers. actual = `{module}`"
    )
def create_quantizer(self, name: str) -> Quantizer:
    """Build a ``Quantizer`` from this instance's config, labelled with ``name``."""
    quantizer = Quantizer(qcfg=self.qcfg, name=name)
    return quantizer
def shape(self):
if hasattr(self, "module"):
return self.module.weight.shape
else:
return (0, 0)
def _mock_hessian_inverse(self, H: torch.Tensor):
"""Mock hessian inverse for fast testing"""
damp = self.qcfg.damp_percent
# Return identity matrix instead of complex inversion
identity = torch.eye(H.shape[0], dtype=torch.float32, device=H.device)
return identity, damp
def _clone_module(self, copy=True, device: torch.device = None):
    """Return the module weight as a 2-D float32 matrix ready for quantization.

    Args:
        copy: Forwarded to ``Tensor.to(copy=...)`` when moving the weight.
        device: Destination device; the weight's current device is used when
            this is not given.

    Returns:
        The weight with convolution kernels flattened from dimension 1,
        ``Conv1D`` weights transposed, any tensor-parallel padding columns
        appended as zeros, and the result cast with ``.float()``.
    """
    weight = self.module.weight.data
    destination = device if device else weight.device

    matrix = weight.to(copy=copy, device=destination)

    if isinstance(self.module, _ConvNd):
        matrix = matrix.flatten(1)

    if isinstance(self.module, transformers.pytorch_utils.Conv1D):
        matrix = matrix.t()

    extra_cols = self._tp_pad_cols
    if extra_cols:
        # Zero columns keep the matrix width equal to self.columns.
        zeros = matrix.new_zeros((matrix.shape[0], extra_cols))
        matrix = torch.cat((matrix, zeros), dim=1)

    return matrix.float()
@staticmethod
def _truncate_last_dim(tensor: torch.Tensor, length: int) -> torch.Tensor:
if tensor.dim() == 0:
return tensor
trim = min(length, tensor.shape[-1])
if trim == tensor.shape[-1]:
return tensor
return tensor.narrow(tensor.dim() - 1, 0, trim).contiguous()
def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[int] = None):
print("INP SHAPE")
print(inp.shape)
print(self.full_name)
print("="*15)
batch_token_size, xtx, device = self.process_batch(inp)
if batch_token_size == 0 or xtx is None:
return
dev = torch.device(device)
with self.lock:
self.fwd_counter += 1
existing = self._device_hessian_partials.get(dev)
if existing is None:
self._device_hessian_partials[dev] = xtx
else:
existing.add_(xtx)
del xtx
self._device_sample_counts[dev] = (
self._device_sample_counts.get(dev, 0) + batch_token_size
)
self.nsamples += batch_token_size
self._hessian_dirty = True
def _preferred_staging_dtype(self, input_dtype: torch.dtype, device: torch.device) -> torch.dtype:
device = torch.device(device)
if not self.qcfg.hessian_use_bfloat16_staging:
return torch.float32
if input_dtype not in (torch.float16, torch.bfloat16):
return torch.float32
if not _device_supports_bfloat16(device):
return torch.float32
return torch.bfloat16
def _resolve_hessian_chunk_size(self, rows: int, stage_dtype: torch.dtype) -> Optional[int]:
if rows == 0:
return None
cfg_chunk = self.qcfg.hessian_chunk_size
if cfg_chunk is not None:
return max(1, min(cfg_chunk, rows))
bytes_budget = self.qcfg.hessian_chunk_bytes
if bytes_budget is not None:
bytes_per_row = self.columns * torch.tensor([], dtype=stage_dtype).element_size()
if bytes_per_row > 0:
chunk_rows = bytes_budget // bytes_per_row
if chunk_rows > 0:
return max(1, min(int(chunk_rows), rows))
return 1
return None
@contextlib.contextmanager
def _borrow_materialized_chunk_fp32(
    self,
    chunk: torch.Tensor,
    rows: int,
) -> torch.Tensor:
    """Context manager yielding a float32 copy of ``chunk`` held in a leased workspace.

    Args:
        chunk: Slice of the input matrix to materialize.
        rows: Number of rows of the workspace to fill and expose.

    Yields:
        A ``[:rows, :]`` float32 view into a shared workspace leased through
        ``_lease_workspace`` (or an empty ``(0, self.columns)`` float32 tensor
        when ``rows`` is 0). The view is only valid inside the ``with`` block,
        because the workspace goes back to the cache on exit.
    """
    if rows == 0:
        yield chunk.new_zeros((0, self.columns), dtype=torch.float32)
        return

    device = chunk.device
    stage_dtype = self._preferred_staging_dtype(chunk.dtype, device)

    with _lease_workspace(device, stage_dtype, self.columns, rows) as staging_workspace:
        # The leased workspace may have more rows than needed; use only the leading `rows`.
        staging_view = staging_workspace[:rows, :]
        staging_view.copy_(chunk.to(dtype=stage_dtype))

        if stage_dtype == torch.float32:
            # Staging is already fp32, so the staging view is handed out directly.
            try:
                yield staging_view
            finally:
                # Synchronize the current CUDA stream before the workspace is released
                # (the enclosing lease exits right after this).
                if device.type == "cuda":
                    torch.cuda.current_stream(device).synchronize()
        else:
            # Staged in a narrower dtype: upcast into a second, fp32 workspace.
            # NOTE(review): both leases use the same per-device cache slot (the key is
            # derived from the device only), so the fp32 buffer stored when the inner
            # lease exits is overwritten when the outer lease exits - verify whether
            # the fp32 workspace is meant to be reused across calls.
            with _lease_workspace(device, torch.float32, self.columns, rows) as fp32_workspace:
                try:
                    fp32_view = fp32_workspace[:rows, :]
                    fp32_view.copy_(staging_view.to(torch.float32))
                    yield fp32_view
                finally:
                    if device.type == "cuda":
                        torch.cuda.current_stream(device).synchronize()
def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:
rows = matrix.shape[0]
if rows == 0:
return torch.zeros(
(self.columns, self.columns),
dtype=torch.float32,
device=matrix.device,
)
stage_dtype = self._preferred_staging_dtype(matrix.dtype, matrix.device)
chunk_size = self._resolve_hessian_chunk_size(rows, stage_dtype)
if chunk_size is None:
mat32 = matrix.to(dtype=torch.float32)
return torch.matmul(mat32.T, mat32)
xtx_accum = torch.zeros(
(self.columns, self.columns),
dtype=torch.float32,
device=matrix.device,
)
for start in range(0, rows, chunk_size):
rows_this = min(chunk_size, rows - start)
source = matrix[start : start + rows_this]
with self._borrow_materialized_chunk_fp32(source, rows_this) as materialized:
materialized32 = materialized
xtx_accum.add_(torch.matmul(materialized32.T, materialized32))
return xtx_accum
def _accumulate_debias_from_inp(self, inp: torch.Tensor, inp_device: torch.device) -> None:
    """
    Port of Fair-GPTQ debiasing term H_x01 from the old GPTQ implementation.
    Uses the original [batch, seq, hidden] input before GPTQ reshaping.

    The input is treated as two groups X0 and X1, and the scaled outer product
    of (X0 - X1), weighted by ``self.alpha``, is added to the running
    ``self.H_x01``. The method does nothing when ``self.alpha`` is falsy, when
    the module is not ``nn.Linear``/``transformers.Conv1D``, or when the input
    is not 2-D or 3-D.

    Args:
        inp: Raw module input for one batch.
        inp_device: Not used by this method; the work is done on the module's
            ``target_device`` (falling back to ``self.device``).
    """
    if not self.alpha:
        return

    # Only handle linear / Conv1D style inputs where last dim is the hidden dimension.
    if not isinstance(self.module, (nn.Linear, transformers.Conv1D)):
        return

    device = getattr(self.module, "target_device", self.device)
    inp1 = inp.to(device=device)
    # print("INP. SHAPE")
    # print(inp1.shape)
    # Expect [B, T, H]; if [T, H], make it [1, T, H]
    if inp1.dim() == 2:
        inp1 = inp1.unsqueeze(0)

    if inp1.dim() != 3:
        # We only know how to do this for [batch, seq, hidden]
        return

    batch_size, seq_len, hidden_dim = inp1.shape

    # Same logic as old code: if batch_size != 2, split sequence in half into 2 groups
    # NOTE(review): this reshape only has a matching element count when batch_size == 1
    # and seq_len is even; an odd seq_len or any other batch size makes reshape raise.
    # Confirm callers guarantee that layout.
    if batch_size != 2:
        real_seq_len = seq_len // 2
        if real_seq_len == 0:
            return
        inp1 = inp1.reshape(2, real_seq_len, hidden_dim)

    tmp = inp1.shape[0]  # should be 2

    X0 = inp1[0]  # [seq_len', hidden_dim]
    X1 = inp1[1]  # [seq_len', hidden_dim]

    X0 = X0.t()  # [hidden_dim, seq_len']
    X1 = X1.t()

    # Initialize or rescale H_x01
    cols = self.columns
    if self.H_x01 is None:
        try:
            self.H_x01 = torch.zeros(
                (cols, cols),
                device=device,
                dtype=torch.float32,
            )
        except RuntimeError:
            # Fallback: free cached CUDA memory, allocate on CPU, then move to `device`.
            # NOTE(review): the final .to(device) targets the same device whose
            # allocation just failed - confirm this retry is expected to succeed.
            log.info("Memory: OOM H_x01 allocate bypass")
            if device.type == "cuda" and torch.cuda.is_available():
                torch.cuda.empty_cache()
            self.H_x01 = torch.zeros(
                (cols, cols),
                device="cpu",
                dtype=torch.float32,
            ).to(device)
    else:
        # Down-weight the existing accumulator by n_old / (n_old + tmp) before adding the new term.
        n_old = float(self.nsamples_debias)
        if n_old > 0.0:
            self.H_x01.mul_(n_old / (n_old + tmp))

    # Scale factor for the new term: sqrt(2 / ((samples seen so far + tmp) / 2)).
    samples_Hx01 = (float(self.nsamples_debias) + tmp) / 2.0
    if samples_Hx01 <= 0.0:
        samples_Hx01 = tmp

    delta = math.sqrt(2.0 / samples_Hx01) * (X0 - X1).float()  # [H, L]

    # If tensor-parallel padding exists, pad rows up to self.columns
    if self._tp_pad_cols:
        pad_rows = self._tp_pad_cols
        pad = delta.new_zeros((pad_rows, delta.shape[1]))
        delta_full = torch.cat((delta, pad), dim=0)  # [columns, L]
    else:
        delta_full = delta  # [columns, L] when columns == hidden_dim

    try:
        self.H_x01.add_(self.alpha * (delta_full @ delta_full.t()))
    except RuntimeError:
        # Any RuntimeError (not only out-of-memory) is retried on CPU; the result is
        # then moved back to the device H_x01 lived on.
        log.info("Memory: OOM cpu bypass for debias matmul")
        dev0 = self.H_x01.device
        H_x01_cpu = self.H_x01.to("cpu")
        delta_cpu = delta_full.to("cpu")
        H_x01_cpu.add_(self.alpha * (delta_cpu @ delta_cpu.t()))
        self.H_x01 = H_x01_cpu.to(dev0)

    self.nsamples_debias += int(tmp)
def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], torch.device]:
    """Reduce one calibration batch to its ``X^T X`` contribution.

    Steps: optionally accumulate the Fair-GPTQ debias term from the raw input,
    flatten the input into a 2-D ``(tokens, features)`` matrix (convolution
    inputs are unfolded into patches), append zero padding columns if
    configured, and compute ``X^T X`` in float32.

    Returns:
        ``(token_count, xtx, device)``. ``xtx`` is ``None`` and
        ``token_count`` is 0 when the batch has no rows. ``device`` is the
        input's device, or CPU when a CUDA out-of-memory error forced the
        computation onto the CPU.

    Raises:
        RuntimeError: Re-raised when the Hessian computation fails for any
            reason other than a CUDA out-of-memory condition.
    """
    source_device = get_device(inp)

    # Debiasing accumulation on original input (Fair-GPTQ)
    if self.alpha:
        self._accumulate_debias_from_inp(inp, source_device)

    layer = self.module
    if isinstance(layer, (nn.Linear, transformers.Conv1D)):
        tokens = inp.reshape(-1, inp.shape[-1])
    else:
        grouped_batch = inp.size(0) * layer.groups
        grouped_channels = inp.size(1) // layer.groups
        if isinstance(layer, nn.Conv1d):
            # Treat the 1-D signal as a 2-D one with a trailing axis of size 1.
            patch_source = inp.reshape(grouped_batch, grouped_channels, inp.shape[2], 1)
            unfold = nn.Unfold(
                layer.kernel_size + (1,),
                dilation=layer.dilation + (1,),
                padding=layer.padding + (0,),
                stride=layer.stride + (1,),
            )
        else:
            patch_source = inp.reshape(
                grouped_batch, grouped_channels, inp.shape[2], inp.shape[3]
            )
            unfold = nn.Unfold(
                layer.kernel_size,
                dilation=layer.dilation,
                padding=layer.padding,
                stride=layer.stride,
            )
        # unfold output size: (batch_size, channels * \prod kernel_size, num_patches)
        tokens = unfold(patch_source).transpose(1, 2).flatten(0, 1)

    # Delay dtype conversion until we materialize Hessian chunks to avoid unnecessary temporaries
    tokens = tokens.contiguous()
    if self._tp_pad_cols:
        zero_cols = tokens.new_zeros((tokens.shape[0], self._tp_pad_cols))
        tokens = torch.cat((tokens, zero_cols), dim=1)

    result_device = torch.device(source_device)

    token_count = tokens.shape[0]
    if token_count == 0:
        del tokens
        return 0, None, result_device

    try:
        gram = self._compute_hessian_xtx(tokens).to(dtype=torch.float32)
    except RuntimeError as exc:
        cuda_oom = (
            torch.device(source_device).type == "cuda"
            and "out of memory" in str(exc).lower()
        )
        if not cuda_oom:
            del tokens
            raise

        log.warn(
            "GPTQ module '%s' fell back to CPU Hessian accumulation due to GPU OOM during batch processing.",
            getattr(self, "name", "<unknown>"),
        )
        tokens_cpu = tokens.to(device=torch.device("cpu"))
        del tokens
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        result_device = torch.device("cpu")
        gram = self._compute_hessian_xtx(tokens_cpu).to(dtype=torch.float32)
        gram = gram.detach()
        del tokens_cpu
    else:
        gram = gram.detach()
        del tokens

    return token_count, gram, result_device
def _select_hessian_target_device(self, requested: Optional[torch.device]) -> torch.device:
if requested is not None:
return torch.device(requested)
hint = getattr(self, "_final_hessian_device_hint", None)
if hint is not None:
return torch.device(hint)
if self._device_hessian_partials:
partial_device = next(iter(self._device_hessian_partials.keys()))
return torch.device(partial_device)
return torch.device("cpu")
def _materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None:
device = self._select_hessian_target_device(target_device)
with self.lock:
if not self._hessian_dirty and self.H is not None:
if self.H.device != device:
self.H = self.H.to(device=device)
return
total_samples = sum(self._device_sample_counts.values())
# Reuse the existing tensor when possible to avoid an extra allocation.
reuse_buffer = (
self.H is not None
and self.H.shape == (self.columns, self.columns)
and self.H.device == device
)
result_accum: torch.Tensor
if reuse_buffer and self.H.dtype == torch.float32:
result_accum = self.H
result_accum.zero_()
else:
result_accum = torch.zeros(
(self.columns, self.columns),
dtype=torch.float32,
device=device,
)
if total_samples == 0:
self.H = result_accum
self.nsamples = 0
self._hessian_dirty = False
self._final_hessian_device_hint = device
self._device_hessian_partials.clear()
self._device_sample_counts.clear()
return
for partial_device, partial in self._device_hessian_partials.items():
if partial.device != result_accum.device or partial.dtype != torch.float32:
tmp = partial.to(device=result_accum.device, dtype=torch.float32)
result_accum.add_(tmp)
del tmp
else:
result_accum.add_(partial)
result_accum.mul_(2.0 / float(total_samples))
self.H = result_accum
self.nsamples = total_samples
self._hessian_dirty = False
self._final_hessian_device_hint = result_accum.device
self._device_hessian_partials.clear()
self._device_sample_counts.clear()
del result_accum
def finalize_hessian(self, target_device: Optional[torch.device] = None) -> torch.Tensor:
self._materialize_global_hessian(target_device=target_device)
if self.H is None:
self.H = torch.zeros(
(self.columns, self.columns),
dtype=torch.float32,
device=self._select_hessian_target_device(target_device),
)
return self.H
# FIXME, optimum needs fasterquant, we need to remove it
def fasterquant(
    self,
    blocksize=128,
    percdamp=0.01,
    damp_auto_increment=0.0015,
    group_size=-1,
    actorder=False,
    static_groups=False,
):
    """Compatibility alias that forwards every argument, in order, to ``hf_quantize``."""
    forwarded = (
        blocksize,
        percdamp,
        damp_auto_increment,
        group_size,
        actorder,
        static_groups,
    )
    return self.hf_quantize(*forwarded)
# public api exposed to hf
def hf_quantize(
    self,
    blocksize=128,
    percdamp=0.01,
    damp_auto_increment=0.0015,
    group_size=-1,
    actorder=False,
    static_groups=False,
    act_group_aware: Optional[bool] = None,
):
    """Copy the call arguments into ``self.qcfg``, quantize, and install the result.

    The quantized weight returned by ``quantize`` is written to
    ``self.module.weight.data``.

    Returns:
        ``(scale, zero, g_idx, duration, avg_loss, damp_percent)`` as produced
        by ``quantize``; the quantized weight and the sample count from that
        call are not returned.
    """
    cfg = self.qcfg
    cfg.group_size = group_size
    cfg.damp_percent = percdamp
    cfg.damp_auto_increment = damp_auto_increment
    cfg.desc_act = actorder
    # Only override the configured value when the caller passed one explicitly.
    if act_group_aware is not None:
        cfg.act_group_aware = act_group_aware
    cfg._resolve_activation_ordering(actorder, act_group_aware)
    cfg.static_groups = static_groups

    outcome = self.quantize(blocksize=blocksize)
    Q, scale, zero, g_idx, duration, avg_loss, damp_percent, _nsamples = outcome

    self.module.weight.data = Q
    return scale, zero, g_idx, duration, avg_loss, damp_percent
@torch.inference_mode()
def hessian_inverse(self, H: torch.Tensor):
    """Compute the upper Cholesky factor of the inverse of the damped Hessian ``H``.

    Each try adds ``damp * mean(diag)`` to the diagonal and attempts a Cholesky
    factorisation. On failure the damping is raised by
    ``qcfg.damp_auto_increment`` while it stays strictly between 0 and 1. If
    damping alone is not enough, the diagonal is raised by a floor that grows
    tenfold per attempt and the damping loop is repeated.

    Args:
        H: Square Hessian. Its diagonal is modified in place during the search;
            after each try it is reset to the diagonal used for that attempt.

    Returns:
        ``(Hinv, damp)`` on success, where ``Hinv`` is the upper-triangular
        Cholesky factor of the inverse and ``damp`` is the damping fraction that
        worked; ``(None, 1.0)`` when every attempt failed.
    """
    # Capture a writable view of the Hessian diagonal so we can restore it between attempts.
    diag_view = H.diagonal()
    orig_diag = diag_view.clone()

    # When a block is numerically singular, pure damping can stall at 1.0.
    # Prepare a tiny diagonal floor (relative to the largest entry) that we
    # only inject if the normal damping loop fails. Keeping the scale near 1e-6
    # of the dominant entry keeps the bias negligible for healthy layers while
    # still rescuing pathological Hessian blocks.
    base_abs_max = torch.max(orig_diag.abs()).item()
    if not math.isfinite(base_abs_max) or base_abs_max == 0.0:
        base_abs_max = 1.0
    floor_base = base_abs_max * 1e-6
    max_floor_attempts = 6
    used_damp = self.qcfg.damp_percent
    last_error = None

    # Attempt 0 uses the untouched diagonal; attempts 1..max_floor_attempts add the floor.
    attempt = 0
    while attempt <= max_floor_attempts:
        if attempt == 0:
            current_diag = orig_diag
        else:
            floor_increment = floor_base * math.pow(10.0, attempt - 1)
            # Shift every entry up by the floor and clamp so none falls below it.
            current_diag = torch.clamp(
                orig_diag + floor_increment, min=floor_increment
            )
            if attempt == 1:
                log.warn(
                    f"Quantization: Module `{self.name}` -> Applying Hessian diagonal floor (+{floor_increment:.2e}) to recover positive definiteness."
                )
            else:
                log.warn(
                    f"Quantization: Module `{self.name}` -> Increasing Hessian diagonal floor to +{floor_increment:.2e}."
                )

        diag_view.copy_(current_diag)
        mean = torch.mean(current_diag)
        # Damping restarts from the configured value for every floor attempt.
        damp = self.qcfg.damp_percent

        damp_recovery_started = False
        recovery_initial_damp = None
        recovery_last_damp = None

        while 0 < damp < 1:
            try:
                diag_view.add_(damp * mean)
                H2 = torch.linalg.cholesky(H)
                Hinv_result = torch.linalg.cholesky(
                    torch.cholesky_inverse(H2), upper=True
                )
                # Remove the damping from H again. With a floor applied (attempt >= 1)
                # the floored diagonal, not the original one, is what remains in H.
                diag_view.copy_(current_diag)
                del H2
                used_damp = damp
                if damp_recovery_started:
                    log.warn(
                        f"Quantization: Module `{self.name}` -> Damp recovery succeeded at `damp_percent={damp:.5f}` "
                        f"(started at {recovery_initial_damp:.5f})."
                    )
                return Hinv_result, used_damp
            except torch._C._LinAlgError as e:
                last_error = e
                # Undo the damping that was just added before trying a different value.
                diag_view.copy_(current_diag)
                if self.qcfg.damp_auto_increment != 0:
                    if not damp_recovery_started:
                        damp_recovery_started = True
                        recovery_initial_damp = damp
                        log.warn(
                            f"Quantization: Module `{self.name}` -> Starting damp recovery at "
                            f"`damp_percent={damp:.5f}`, increment step `{self.qcfg.damp_auto_increment:.5f}`."
                        )
                    damp += self.qcfg.damp_auto_increment
                    recovery_last_damp = damp
                else:
                    log.warn(
                        f"Quantization: Module `{self.name}` -> Hessian Cholesky failed with `damp_percent={damp:.5f}` and no auto increment configured."
                    )
                    break

        if damp_recovery_started:
            final_damp = recovery_last_damp if recovery_last_damp is not None else damp
            log.warn(
                f"Quantization: Module `{self.name}` -> Damp recovery failed after reaching `damp_percent={final_damp:.5f}`."
            )

        attempt += 1

    # Every floor attempt failed. H keeps the diagonal of the last attempt at this point.
    log.error(
        f"Quantization: Module `{self.name}` -> Hessian remained non positive-definite after diagonal floor attempts. Last `damp_percent` tried = {damp:.5f}."
    )
    if last_error is not None:
        log.debug(f"Hessian failure detail: {last_error}")
    return None, 1.0
@torch.inference_mode()
def quantize(
self,
blocksize=128,
):
# log.info(f"Quantization `{self.name}` using samples: `{self.nsamples}`")
start = time.time()
target_device = getattr(self.module, "target_device", None)
self.finalize_hessian(target_device=target_device)
if self.qcfg.mock_quantization:
# Use simplified hessian inverse (identity matrix)
self.hessian_inverse = self._mock_hessian_inverse
# TODO: waiting for pytorch implementation of ops for MPS
if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
raise RuntimeError(
"For MacOS you must set env `PYTORCH_ENABLE_MPS_FALLBACK=1` before running quantization."
)
if self.module_copy is None:
W = self._clone_module(device=self.H.device)
else:
W = self.module_copy.to(device=self.H.device)
del self.module_copy
self.quantizer.find_params(W, weight=True)
dead = torch.diag(self.H) == 0
self.H[dead, dead] = 1
W[:, dead] = 0
scale = []
zero = []
now_idx = 1
if self.qcfg.static_groups:
import copy
groups = []
for i in range(0, self.columns, self.qcfg.group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(
W[:, i : (i + self.qcfg.group_size)], weight=True
)
scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
perm = None
invperm = None
final_perm = None
global_perm = None
if self.qcfg.desc_act:
perm = torch.argsort(torch.diag(self.H), descending=True)
W = W[:, perm]
self.H = self.H[perm][:, perm]
invperm = torch.argsort(perm)
elif self.qcfg.act_group_aware:
diag_h = torch.diag(self.H)
local_perms, local_values = compute_local_perms(
diag_h, self.qcfg.group_size, return_values=True
)
global_perm = compute_global_perm(
diag_h,
self.qcfg.group_size,
precomputed_values=local_values,
)
del local_values
final_perm = compose_final_perm(
local_perms, global_perm, self.qcfg.group_size
)
W = W[:, final_perm]
self.H = self.H[final_perm][:, final_perm]
# Align H_x01 (if present) with the same basis as H/W
H_x01 = self.H_x01
if H_x01 is not None:
H_x01 = H_x01.to(device=self.H.device, dtype=torch.float32)
if self.qcfg.desc_act and perm is not None:
H_x01 = H_x01[perm][:, perm]
elif self.qcfg.act_group_aware and final_perm is not None:
H_x01 = H_x01[final_perm][:, final_perm]
self.H_x01 = H_x01
# Optional combined Hessian H_2 = H + H_x01 for sum_hessians mode
H_2 = None
if self.alpha and self.sum_hessians and self.H_x01 is not None:
H_2 = (self.H + self.H_x01).to(dtype=torch.float32, device=self.H.device)
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
Hinv, damp = self.hessian_inverse(self.H)
# Optional inverse for H_2 (H + H_x01) used only for debias deltaW
H2inv = None
if H_2 is not None:
H2inv, _ = self.hessian_inverse(H_2)
# Fair-GPTQ debias ΔW correction (before quantization loop)
if self.alpha and self.H_x01 is not None and Hinv is not None:
H_x01 = self.H_x01.to(device=W.device, dtype=torch.float32)
inv_for_delta = H2inv if (self.sum_hessians and H2inv is not None) else Hinv
if inv_for_delta is not None:
if self.log_delta_w:
W_old = W.clone()
# tmp = H_x01 * W^T (cols x cols) * (cols x rows) -> (cols x rows)
tmp = H_x01 @ W.t()
# tmp2 = inv_for_delta^T * (inv_for_delta * tmp)
tmp2 = inv_for_delta.t() @ (inv_for_delta @ tmp)
# W -= tmp2^T
W.sub_(tmp2.t())
if self.log_delta_w:
delta_W = W - W_old
amplitude = torch.norm(delta_W, p=2).item()
amp_rel = amplitude / (torch.norm(W_old, p=2) + 1e-12).item()
fname = "./logs/delta_w.txt"
full_name = getattr(self, "full_name", self.name)
try:
with open(fname, "a") as f:
f.write(
f"{full_name} abs={amplitude:.6g} rel={amp_rel:.6g}\n"
)
except Exception as e:
log.warn(f"Failed to log delta_w for `{full_name}`: {e}")
# Use simplified loop when mock_quantization is active
if self.qcfg.mock_quantization or (self.fail_safe and self.fwd_counter == 0):
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2]
Q1 = torch.zeros_like(W1)
# Handle group quantization parameters efficiently (similar to original)
if self.qcfg.group_size != -1:
if not self.qcfg.static_groups:
# Find parameters for entire groups at once (optimized)
group_start_cols = list(range(i1, i2, self.qcfg.group_size))
for group_start in group_start_cols:
group_end = min(
group_start + self.qcfg.group_size, self.columns
)
if group_start < group_end:
self.quantizer.find_params(
W[:, group_start:group_end], weight=True
)
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
else:
# Static groups - use pre-computed groups
for i in range(count):
idx = i1 + i
if self.qcfg.desc_act and perm is not None:
idx = perm[idx]
self.quantizer = groups[idx // self.qcfg.group_size]
# Vectorized quantization for the entire block (major optimization)
if len(scale) > 0 and len(zero) > 0:
# Use latest scale and zero for the entire block
latest_scale = scale[-1]
latest_zero = zero[-1]
# Reshape scales and zeros to match block dimensions
if latest_scale.dim() == 1:
latest_scale = latest_scale.view(-1, 1)
if latest_zero.dim() == 1:
latest_zero = latest_zero.view(-1, 1)
# Apply quantization formula using the cloned weights W1
maxq_val = 2 ** self.qcfg.bits - 1
if self.qcfg.sym:
# Symmetric quantization: Q = scale * clamp(round(x/scale), -maxq/2, maxq/2)
Q1 = latest_scale * torch.clamp(
torch.round(W1 / latest_scale),
-(maxq_val // 2),
maxq_val // 2,
)
else:
# Asymmetric quantization: Q = scale * (clamp(round(x/scale) + zero, 0, maxq) - zero)
quantized = torch.clamp(
torch.round(W1 / latest_scale) + latest_zero,
0,
maxq_val,
)
Q1 = latest_scale * (quantized - latest_zero)