-
Notifications
You must be signed in to change notification settings - Fork 438
Expand file tree
/
Copy pathcalibration.py
More file actions
311 lines (247 loc) · 9.78 KB
/
calibration.py
File metadata and controls
311 lines (247 loc) · 9.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
from typing import Any, Optional
import torch
from compressed_tensors.quantization import (
DynamicType,
QuantizationArgs,
QuantizationStatus,
QuantizationStrategy,
)
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import (
align_module_device,
getattr_chain,
update_offload_parameter,
)
from loguru import logger
from torch.nn import Module
from llmcompressor.observers import Observer
# Public calibration API exported by this module; keep in sync with the
# function definitions below.
__all__ = [
    "initialize_observer",
    "update_weight_zp_scale",
    "calibrate_input_hook",
    "calibrate_output_hook",
    "freeze_module_quantization",
    "apply_calibration_status",
    "reset_quantization_status",
    "update_weight_global_scale",
    "calibrate_query_hook",
    "calibrate_key_hook",
    "calibrate_value_hook",
    "recompute_qparams_from_observer",
]
def initialize_observer(
    module: Module,
    base_name: str,
):
    """
    Initialize an observer and attach it to ``module`` as the submodule
    ``{base_name}_observer``.

    The observer name is fetched from the quantization args resolved from
    ``module.quantization_scheme``. Weight observers are always forced to
    memoryless variants, since training is no longer supported. Dynamic
    quantization computes qparams on the fly, so no observer is attached
    when ``args.dynamic is True``.

    :param module: torch.nn.Module that the observer is being attached to
    :param base_name: str used to name the observer attribute
    """
    # map the base name onto the corresponding quantization-args field
    if base_name == "weight":
        arg_name = "weights"
    elif base_name == "output":
        arg_name = "output_activations"
    else:  # input, q, k, v
        arg_name = "input_activations"

    args: Optional[QuantizationArgs] = getattr_chain(
        module, f"quantization_scheme.{arg_name}", None
    )
    # fix: previously `args.observer` was read before the None check,
    # raising AttributeError when the scheme does not quantize this tensor
    if args is None:
        return

    observer = args.observer
    # training is no longer supported: always use memoryless for weights
    if base_name == "weight":
        override = {
            "static_minmax": "memoryless_minmax",
            "minmax": "memoryless_minmax",
            "mse": "memoryless_mse",
        }.get(args.observer)
        if override is not None:
            observer = override
            logger.warning(
                "Overriding weight observer for lower memory usage "
                f"({args.observer} -> {observer})",
                log_once=True,
            )

    if args.dynamic is not True:
        observer = Observer.load_from_registry(
            observer, base_name=base_name, args=args, module=module
        )
        module.register_module(f"{base_name}_observer", observer)
def call_observer(
    module: Module,
    base_name: str,
    value: Optional[torch.Tensor] = None,
    should_calculate_gparam: bool = False,
    should_calculate_qparams: bool = True,
):
    """
    Invoke the observer attached to ``module`` as ``{base_name}_observer``
    and write the resulting quantization parameters back onto the module.

    :param module: torch.nn.Module
    :param base_name: substring used to fetch the observer, scales, and zp
    :param value: torch.Tensor to be passed to the observer for activations. If
        base_name is "weight", then the module's weight tensor will be used
    :param should_calculate_gparam: when True, compute and store the global scale
    :param should_calculate_qparams: when True, compute and store scale/zero_point
    """
    with align_module_device(module):
        # weights are read directly off the module when no value is supplied
        if base_name == "weight" and value is None:
            value = module.weight

        observer: Observer = getattr(module, f"{base_name}_observer")

        if should_calculate_gparam:
            update_offload_parameter(
                module,
                f"{base_name}_global_scale",
                observer.get_global_scale(value),
            )

        if should_calculate_qparams:
            scale, zero_point = observer(value)
            update_offload_parameter(module, f"{base_name}_scale", scale)
            zp_name = f"{base_name}_zero_point"
            # some schemes (e.g. symmetric) register no zero_point parameter
            if hasattr(module, zp_name):
                update_offload_parameter(module, zp_name, zero_point)
def update_weight_global_scale(module: Module):
    """
    Compute and store only the weight global scale for modules whose weight
    quantization strategy is TENSOR_GROUP; no-op otherwise.

    :param module: module whose weight global scale should be updated
    """
    weight_args = getattr_chain(module, "quantization_scheme.weights", None)
    if weight_args is None:
        return
    if weight_args.strategy != QuantizationStrategy.TENSOR_GROUP:
        return

    call_observer(
        module,
        base_name="weight",
        should_calculate_gparam=True,
        should_calculate_qparams=False,
    )
def update_weight_zp_scale(module: Module):
    """
    Calibrate a module's weight quantization parameters by running its
    weight observer and storing the resulting scale and zero point.

    Apply to the full model with ``model.apply(update_weight_zp_scale)``.
    No-op for modules without a weights quantization scheme; warns (but
    proceeds) if the module is not in calibration status.

    :param module: module whose weight scale/zero_point should be updated
    """
    if getattr_chain(module, "quantization_scheme.weights", None) is None:
        return
    # fix: docstring previously documented a nonexistent
    # `quantize_weights_upfront` parameter
    if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION:
        logger.warning(
            "Attempting to calibrate weights of a module not in calibration mode"
        )
    call_observer(module=module, base_name="weight")
def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
    """
    Calibrate input or output activations by calling the module's attached
    observer with ``value``.

    :param module: torch.nn.Module
    :param value: torch.Tensor to be passed to the observer
    :param base_name: substring used to fetch the observer, scales, and zp
    """
    # MoE routing can hand a module an empty tensor; nothing to observe then
    if value.numel() == 0:
        return

    # q/k/v hooks share the input-activation args; only "output" differs
    field_name = "output" if base_name == "output" else "input"
    quantization_args = getattr_chain(
        module, f"quantization_scheme.{field_name}_activations", None
    )

    calculate_qparams = True
    calculate_gparam = False
    if quantization_args is not None:
        # dynamic (including locally dynamic) schemes compute qparams at runtime
        calculate_qparams = quantization_args.dynamic not in (
            True,
            DynamicType.LOCAL,
        )
        calculate_gparam = (
            quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP
        )

    call_observer(
        module=module,
        base_name=base_name,
        value=value,
        should_calculate_gparam=calculate_gparam,
        should_calculate_qparams=calculate_qparams,
    )
def calibrate_input_hook(module: Module, args: Any):
    """
    Forward-pre-hook that calibrates input activations.

    Runs the input observer to update scales/zp before input QDQ is applied
    in the module's forward pass. ``args`` may be the positional-args tuple
    supplied by torch hooks, in which case only the first argument is used.
    """
    value = args[0] if isinstance(args, tuple) else args
    calibrate_activations(module, value=value, base_name="input")
def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
    """
    Forward-hook that calibrates output activations.

    Runs the output observer to update scales/zp, then applies output QDQ
    and returns the quantize-dequantized output.
    """
    calibrate_activations(module, value=output, base_name="output")
    return forward_quantize(
        module=module,
        value=output,
        base_name="output",
        args=module.quantization_scheme.output_activations,
    )
def calibrate_query_hook(module: Module, query_states: torch.Tensor):
    """Calibrate attention query activations via the module's "q" observer."""
    calibrate_activations(module, query_states, base_name="q")
def calibrate_key_hook(module: Module, key_states: torch.Tensor):
    """Calibrate attention key activations via the module's "k" observer."""
    calibrate_activations(module, key_states, base_name="k")
def calibrate_value_hook(module: Module, value_states: torch.Tensor):
    """Calibrate attention value activations via the module's "v" observer."""
    calibrate_activations(module, value_states, base_name="v")
def recompute_qparams_from_observer(module: Module, base_name: str):
    """
    Recompute scale and zero_point from an observer's accumulated
    past_min_vals/past_max_vals. Used after DDP all-reduce to update
    qparams from synchronized statistics.

    No-op if the observer is missing or has accumulated no statistics.

    :param module: module with quantization parameters
    :param base_name: "input", "output", "q", "k", or "v"
    """
    # imported locally to avoid a module-level import cycle — TODO confirm
    from compressed_tensors.quantization.utils import calculate_qparams

    observer: Optional[Observer] = getattr(module, f"{base_name}_observer", None)
    if observer is None:
        return

    min_vals = getattr(observer, "past_min_vals", None)
    max_vals = getattr(observer, "past_max_vals", None)
    if min_vals is None or max_vals is None:
        return

    scale, zero_point = calculate_qparams(
        min_vals=min_vals,
        max_vals=max_vals,
        quantization_args=observer.args,
        global_scale=getattr(module, f"{base_name}_global_scale", None),
    )
    update_offload_parameter(module, f"{base_name}_scale", scale)
    zp_name = f"{base_name}_zero_point"
    if hasattr(module, zp_name):
        update_offload_parameter(module, zp_name, zero_point)
def apply_calibration_status(module: Module):
    """
    Mark a quantized module as being in calibration.

    Modules without a quantization scheme are left untouched.
    """
    if getattr(module, "quantization_scheme", None):
        module.quantization_status = QuantizationStatus.CALIBRATION
def freeze_module_quantization(module: Module):
    """
    Delete observers once calibration is complete and mark the module frozen.

    Apply to the full model with ``model.apply(freeze_module_quantization)``.
    No-op for modules without a quantization scheme or already frozen.

    :param module: module to freeze quantization for
    """
    if getattr(module, "quantization_scheme", None) is None:
        # no quantization scheme, nothing to do
        return
    if module.quantization_status == QuantizationStatus.FROZEN:
        # already frozen, nothing to do
        return

    # drop every attached observer (weight, activations, attention states)
    for obs_name in (
        f"{name}_observer" for name in ("input", "weight", "output", "q", "k", "v")
    ):
        if hasattr(module, obs_name):
            delattr(module, obs_name)

    module.quantization_status = QuantizationStatus.FROZEN
def reset_quantization_status(model: Module):
    """Remove the ``quantization_status`` attribute from every submodule of ``model``."""
    for submodule in model.modules():
        try:
            del submodule.quantization_status
        except AttributeError:
            pass