@@ -26,7 +26,7 @@

# from functools import reduce
# from math import gcd
-from typing import Dict, Optional, Callable, Any, List
+from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
@@ -37,6 +37,7 @@
from torchao.quantization.quant_api import (
    int4_weight_only,
    Int4WeightOnlyQuantizer,
+    int8_weight_only,
    Int8DynActInt4WeightQuantizer,
    quantize_,
)
@@ -45,8 +46,8 @@
    find_multiple,
    get_device_str,
    get_precision,
-    set_precision,
    name_to_dtype,
+    set_precision,
    state_dict_device,
    use_et_backend,
)
@@ -60,28 +61,36 @@

import inspect

+

def get_named_parameters(func: Callable) -> List[str]:
    # Get the signature of the function
    signature = inspect.signature(func)
-
+
    # Extract the parameters from the signature
    parameters = signature.parameters
-
+
    # Filter and return named parameters
    named_params = [
-        name for name, param in parameters.items()
-        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+        name
+        for name, param in parameters.items()
+        if param.kind
+        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]
    return named_params

-def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+
+def validate_args(
+    named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
+) -> Dict[str, Any]:
    for key in q_kwargs.keys():
        if key not in named_params:
-            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            print(
+                f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
+            )
            del q_kwargs[key]
    return q_kwargs
-
-
+
+
#########################################################################
### torchchat quantization API ###

@@ -110,21 +119,30 @@ def quantize_model(
        if quantizer not in quantizer_class_dict:
            raise RuntimeError(f"unknown quantizer {quantizer} specified")
        else:
+            ao_quant = True
            # Use tensor subclass API for int4 weight only.
            if device == "cuda" and quantizer == "linear:int4":
                quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            elif quantizer == "linear:int8":
+                print("quantizer is linear int8")
+                quantize_(model, int8_weight_only())
+            else:
+                ao_quant = False
+            if ao_quant:
                if not support_tensor_subclass:
                    unwrap_tensor_subclass(model)
                continue
-
+
        if quantizer in ["linear:a8wxdq", "embedding:wx"]:
            # These quantizers require float32 input weights. Note that after quantization,
            # the weights will no longer be float32, but lowbit integers
            if get_precision() != torch.float32:
-                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                print(
+                    f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
+                )
                set_precision(torch.float32)
-
-        # We set global precision from quantize options if it is specified at cli.py:485
+
+        # We set global precision from quantize options if it is specified at cli.py:485
        # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
        precision = get_precision()

@@ -141,14 +159,19 @@ def quantize_model(
        model = quant_handler.quantize(model)


-
#########################################################################
### QuantHandler API definition ###
### (unify with torchao in future) ###


class QuantHandler:
-    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+    ):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer
@@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module:


class PrecisionHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        dtype,
+    ):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer
@@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module:


class ExecutorHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        accelerator,
+    ):
        self.model_ = model

        if isinstance(accelerator, str):
@@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales):
    )


-class WeightOnlyInt8Linear(nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales: torch.Tensor
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        bias=None,
-        device=None,
-        dtype=None,
-        *,
-        weight: Optional[torch.Tensor] = None,
-        scales: Optional[torch.Tensor] = None,
-        groupsize: Optional[int] = None,
-    ):
-        super().__init__()
-        if dtype is None:
-            dtype = torch.get_default_dtype()
-
-        if device is None:
-            device = "cpu"
-
-        assert not bias, "Bias is not supported by LinearInt8"
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert (weight is None) == bool(
-            scales is None
-        ), "must specify both weights and scales, or neither"
-        if weight is None:
-            weight = torch.empty(
-                (out_features, in_features),
-                dtype=torch.int8,
-                device=device,
-            )
-            if groupsize is None or (groupsize == 0):
-                scales = torch.empty(out_features, dtype=dtype, device=device)
-            else:
-                n_groups = (in_features + groupsize - 1) // groupsize
-                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
-        self.register_buffer("weight", weight.to(device))
-        self.register_buffer("scales", scales.to(device))
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_aoti(input, self.weight, self.scales)
-
-    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_et(input, self.weight, self.scales)
-
-
-class WeightOnlyInt8QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device=None,
-        precision=None,
-        tokenizer=None,
-        *,
-        node_type: str = "*",
-        bitwidth: Optional[int] = None,
-        groupsize: Optional[int] = None,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.node_type = node_type
-        if bitwidth is None:
-            self.bitwidth = 8
-        else:
-            self.bitwidth = bitwidth
-
-    @torch.no_grad()
-    def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu"  # self.device
-
-        if self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, nn.Linear):
-                if (
-                    (self.node_type == "*")
-                    or (self.node_type == "output" and name == "output")
-                    or (self.node_type == "!output" and name != "output")
-                ):
-                    # print(f"{name, child}")
-                    input_weight = child.weight.float()
-                    # print(f"{name, child}")
-                    # print(f"in_features: {child.in_features}")
-                    # print(f"out_features: {child.out_features}")
-
-                    # print(f"expanded weight shape {input_weight.shape}")
-                    weight, scales, _ = dynamically_quantize_per_channel(
-                        input_weight,
-                        range_min,
-                        range_max,
-                        torch.int8,
-                        self.groupsize,
-                        scales_dtype=child.weight.dtype,
-                    )
-
-                    setattr(
-                        module,
-                        name,
-                        WeightOnlyInt8Linear(
-                            in_features=child.in_features,
-                            out_features=child.out_features,
-                            device=self.device,
-                            # update variables from quantization
-                            weight=weight,
-                            scales=scales,
-                            groupsize=self.groupsize,
-                        ),
-                    )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
#########################################################################
##### embedding table quantization ######
### (unify with torchao in future) ###
@@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module:
# class references
quantizer_class_dict = {
    "embedding": EmbeddingOnlyQuantHandler,
-    "linear:int8": WeightOnlyInt8QuantHandler,
    "precision": PrecisionHandler,
    "executor": ExecutorHandler,
    "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:int8": int8_weight_only,
    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}

@@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module:
        print("Slow fallback kernels will be used.")

except Exception as e:
+
    class ErrorHandler(QuantHandler):
-        def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
+        def __init__(
+            self, model: Optional[nn.Module] = None, device="cpu", precision=None
+        ):
            global torchao_experimental_load_error
-            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
-
+            raise Exception(
+                f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}"
+            )
+
    torchao_experimental_load_error = e
    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
    quantizer_class_dict["embedding:wx"] = ErrorHandler
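
For reference, a minimal usage sketch (not part of the commit) of the torchao path that `quantize_model` now takes for `"linear:int8"`, which replaces the removed `WeightOnlyInt8QuantHandler`. The `quantize_` and `int8_weight_only` calls are exactly the ones imported in the diff above; the toy module and shapes are illustrative assumptions.

```python
import torch
import torch.nn as nn

# torchao tensor-subclass API, as imported in the diff above.
from torchao.quantization.quant_api import int8_weight_only, quantize_

# Hypothetical toy model, for illustration only.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 16))

# What quantize_model() now does when quantizer == "linear:int8":
# swap each nn.Linear weight for an int8 weight-only quantized tensor, in place.
quantize_(model, int8_weight_only())

# The module still takes float activations; only the weights are int8.
x = torch.randn(2, 64)
print(model(x).shape)  # torch.Size([2, 16])
```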