-
Notifications
You must be signed in to change notification settings - Fork 438
Expand file tree
/
Copy pathlifecycle.py
More file actions
63 lines (51 loc) · 1.86 KB
/
lifecycle.py
File metadata and controls
63 lines (51 loc) · 1.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from compressed_tensors.quantization import (
QuantizationScheme,
initialize_module_for_quantization,
)
from llmcompressor.modifiers.quantization.calibration import (
apply_calibration_status,
freeze_module_quantization,
initialize_observer,
update_weight_global_scale,
update_weight_zp_scale,
)
from llmcompressor.observers.helpers import flatten_for_calibration
__all__ = [
"initialize_quantized_linear",
"validate_weight_for_quantization",
"calibrate_global_scale",
"calibrate_scale_zp",
]
def validate_weight_for_quantization(
    weight: torch.Tensor, scheme: QuantizationScheme, tensor_name: str
):
    """Raise ``ValueError`` if ``weight`` cannot be quantized under ``scheme``.

    Two checks are performed: the tensor must be a 2D linear weight, and it
    must be flattenable for calibration with the scheme's weight args.
    ``tensor_name`` is used only to build the error message.
    """
    # Guard clause: anything other than a 2D weight is rejected up front.
    if weight.ndim != 2:
        raise ValueError(
            f"Unable to quantize tensor `{tensor_name}`: expected 2D linear weight, "
            f"but got shape {tuple(weight.shape)}"
        )
    # Dry-run the calibration flattening; any failure is surfaced as a
    # ValueError carrying the tensor name, with the original cause chained.
    try:
        flatten_for_calibration(weight, "weight", scheme.weights)
    except Exception as exc:
        raise ValueError(
            f"Unable to quantize tensor `{tensor_name}`: {exc}"
        ) from exc
def initialize_quantized_linear(
    weight: torch.Tensor, scheme: QuantizationScheme, device: str | torch.device
) -> torch.nn.Module:
    """Wrap ``weight`` in a bias-free ``nn.Linear`` prepared for quantization.

    The linear layer is created on ``device`` with the weight's dtype, the
    weight data is copied in, and the module is initialized for quantization
    under ``scheme`` (without forcing a zero point).
    """
    n_out, n_in = weight.shape
    linear = torch.nn.Linear(
        n_in, n_out, bias=False, device=device, dtype=weight.dtype
    )
    # In-place copy moves the data onto the layer's device as needed.
    linear.weight.data.copy_(weight)
    initialize_module_for_quantization(linear, scheme, force_zero_point=False)
    return linear
def calibrate_global_scale(module: torch.nn.Linear):
    """Run the weight global-scale calibration pipeline on ``module``.

    Sets up the weight observer, marks the module as calibrating, computes
    the weight global scale, and finally freezes quantization state.
    """
    initialize_observer(module, "weight")
    apply_calibration_status(module)
    # Compute the global scale for the weight, then lock in the result.
    update_weight_global_scale(module)
    freeze_module_quantization(module)
def calibrate_scale_zp(module: torch.nn.Linear):
    """Run the weight scale/zero-point calibration pipeline on ``module``.

    Mirrors ``calibrate_global_scale`` but computes the per-weight scale and
    zero point instead of the global scale, then freezes quantization state.
    """
    # Same lifecycle as the global-scale path: observe -> calibrate ->
    # update -> freeze, with the zp/scale update step in the middle.
    steps = (
        apply_calibration_status,
        update_weight_zp_scale,
        freeze_module_quantization,
    )
    initialize_observer(module, "weight")
    for step in steps:
        step(module)