-
Notifications
You must be signed in to change notification settings - Fork 438
Expand file tree
/
Copy pathprocess.py
More file actions
257 lines (205 loc) · 9.56 KB
/
process.py
File metadata and controls
257 lines (205 loc) · 9.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import os
from collections import defaultdict
from typing import Iterable
import torch
from compressed_tensors.entrypoints.convert import Converter
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils import match_quantizable_tensors
from safetensors.torch import load_file, save_file
from torch.nn import Module
from llmcompressor.entrypoints.model_free.lifecycle import (
calibrate_global_scale,
calibrate_scale_zp,
compress_module,
initialize_quantized_linear,
validate_weight_for_quantization,
)
from llmcompressor.entrypoints.model_free.microscale import (
get_fused_names,
is_microscale_scheme,
)
# Public entry points of this module (`split_fused_moe_experts` is not exported).
__all__ = ["validate_file", "process_file", "process_file_microscale_scheme"]
def validate_file(
    file_path: str | os.PathLike,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: str | torch.device,
    converter: Converter | None = None,
):
    """
    Validate that each quantizable tensor in a safetensors file can be quantized.

    Tensors are prepared the same way `process_file` prepares them: fused 3D MoE
    expert tensors are split into per-expert weights before the converter runs.
    Validation therefore checks the same tensor names that will later be
    quantized.

    :param file_path: safetensors file to validate
    :param save_path: unused here; kept so this function shares the signature of
        `process_file`
    :param scheme: quantization scheme to apply to tensors
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param device: unused here; kept so this function shares the signature of
        `process_file`
    :param converter: optional converter to apply to the checkpoint,
        e.g. conversion of some layers from some format to compressed-tensors
    """
    tensors = load_file(file_path)
    # Mirror `process_file`: split fused MoE experts first so the converter and
    # the quantizability checks see the per-expert weights that get quantized.
    tensors = split_fused_moe_experts(tensors)
    if converter is not None:
        converter.validate(tensors)

    for _, name in match_quantizable_tensors(tensors, ignore, scheme.targets):
        validate_weight_for_quantization(tensors[name], scheme, name)
def process_file(
    file_path: str | os.PathLike,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: str | torch.device,
    converter: Converter | None = None,
) -> tuple[int, dict[str, str]]:
    """
    Quantize and compress tensors in a given safetensors file.

    Only for non-microscale schemes; microscale schemes must go through
    `process_file_microscale_scheme`.

    :param file_path: safetensors file to process
    :param save_path: save path of file with quantized weights
    :param scheme: quantization scheme to apply to tensors
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param device: device used to quantize and compress weights
    :param converter: optional converter to apply to the checkpoint,
        e.g. conversion of some layers from some format to compressed-tensors
    :return: tuple of (total size in bytes of all tensors written to
        `save_path`, mapping from every saved tensor name to the basename of
        `save_path`)
    """
    assert not is_microscale_scheme(scheme), "Use `process_file_microscale_scheme`"
    tensors = load_file(file_path)
    # fused 3D MoE expert tensors become per-expert 2D weights so they can be
    # matched and quantized like ordinary linear weights
    tensors = split_fused_moe_experts(tensors)
    if converter is not None:
        converter.process(tensors)

    # snapshot the matches: the loop below deletes and inserts keys in `tensors`
    matches = list(match_quantizable_tensors(tensors, ignore, scheme.targets))
    for module_name, name in matches:
        validate_weight_for_quantization(tensors[name], scheme, name)

        # 1. initialize module with qparams (on device)
        module = initialize_quantized_linear(tensors[name], scheme, device)

        # 2. calibrate weight qparams
        calibrate_scale_zp(module)

        # 3. compress module using qparams
        compress_module(module)

        # 4. save compressed data (on cpu), replacing the original weight
        del tensors[name]
        prefix = module_name + "."
        for key, value in module.state_dict(prefix=prefix).items():
            tensors[key] = value.to("cpu")

    save_file(tensors, save_path)
    total_size = sum(tensor.nbytes for tensor in tensors.values())
    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
    return total_size, weight_map
def process_file_microscale_scheme(
    file_path: str | os.PathLike,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: str | torch.device,
    converter: Converter | None = None,
) -> tuple[int, dict[str, str]]:
    """
    Quantize and compress tensors in a given safetensors file with a microscale
    scheme.

    Tensors that `get_fused_names` reports as belonging to a fused set are
    calibrated together. Each module's weight global scale is overwritten with
    the minimum global scale of its set before scale/zero-point calibration and
    compression. All other tensors are calibrated and compressed individually.

    :param file_path: safetensors file to process
    :param save_path: save path of file with quantized weights
    :param scheme: quantization scheme to apply to tensors
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param device: device used to quantize and compress weights
    :param converter: optional converter to apply to the checkpoint,
        e.g. conversion of some layers from some format to compressed-tensors
    :return: tuple of (total size in bytes of all tensors written to
        `save_path`, mapping from every saved tensor name to the basename of
        `save_path`)
    """
    assert is_microscale_scheme(scheme), "Use `process_file` for non-microscale scheme"
    tensors = load_file(file_path)
    # same preprocessing as `process_file`: split fused 3D MoE expert tensors
    # into per-expert 2D weights so experts are handled under this scheme too
    tensors = split_fused_moe_experts(tensors)
    if converter is not None:
        converter.process(tensors)

    fused_sets, unmatched_sets = get_fused_names(tensors)
    # should be caught earlier by `validate_safetensors_index`
    assert not unmatched_sets

    # fused tensor name -> index of the fused set it belongs to
    fused_name_to_fused_index: dict[str, int] = {
        name: index
        for index, matched_set in enumerate(fused_sets)
        for name in matched_set.values()
    }
    # fused set index -> {tensor name: initialized module awaiting calibration}
    fused_modules: dict[int, dict[str, Module]] = defaultdict(dict)

    # snapshot the matches: the loop below deletes and inserts keys in `tensors`
    matches = list(match_quantizable_tensors(tensors, ignore, scheme.targets))
    for module_name, name in matches:
        validate_weight_for_quantization(tensors[name], scheme, name)

        # 1. initialize module with qparams (on device)
        module = initialize_quantized_linear(tensors[name], scheme, device)

        # 2. calibrate weight qparams. Delay scale/zp calibration for fused modules
        calibrate_global_scale(module)
        if name in fused_name_to_fused_index:
            fused_index = fused_name_to_fused_index[name]
            fused_modules[fused_index][name] = module
            continue

        calibrate_scale_zp(module)

        # 3. compress module using qparams
        compress_module(module)

        # 4. save compressed data (on cpu), replacing the original weight
        del tensors[name]
        prefix = module_name + "."
        for key, value in module.state_dict(prefix=prefix).items():
            tensors[key] = value.to("cpu")

    # compress and save microscale fused modules
    for named_modules in fused_modules.values():
        # 2.1. fuse global scales: every module in the set gets the smallest one
        global_scales = [m.weight_global_scale for m in named_modules.values()]
        fused_global_scale = torch.min(torch.cat(global_scales, dim=0))

        for name, module in named_modules.items():
            module_name, _ = name.rsplit(".", 1)
            module.weight_global_scale.data.copy_(fused_global_scale)

            # 2.2. finish calibration with fused global scales
            calibrate_scale_zp(module)

            # 3. compress module using microscale qparams
            compress_module(module)

            # 4. save compressed data (on cpu), replacing the original weight
            del tensors[name]
            prefix = module_name + "."
            for key, value in module.state_dict(prefix=prefix).items():
                tensors[key] = value.to("cpu")

    save_file(tensors, save_path)
    total_size = sum(tensor.nbytes for tensor in tensors.values())
    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
    return total_size, weight_map
def split_fused_moe_experts(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Split fused MoE expert weights into individual per-expert 2D weights.

    A 3D tensor whose name contains ``experts.gate_up_proj`` or
    ``experts.down_proj`` is treated as a stack of expert weights along dim 0.
    It is replaced by one entry per expert ``i``; in the new name that
    substring becomes ``experts.{i}``:

    - ``experts.gate_up_proj``: the first half of dim 1 is stored under
      ``<new name>.gate_proj.weight`` and the second half under
      ``<new name>.up_proj.weight``.
    - ``experts.down_proj``: the expert slice is stored under
      ``<new name>.down_proj.weight``.

    Every other tensor is passed through unchanged. That includes a
    ``gate_up_proj`` whose dim 1 is odd and so cannot be split in half. Split
    weights are views into the fused tensors; no data is copied.

    NOTE(review): assumes gate_up_proj is laid out as
    (num_experts, 2 * intermediate, hidden) with the gate rows first, and
    down_proj as (num_experts, hidden, intermediate). Confirm this per model
    family; a transposed layout would be split incorrectly.

    :param tensors: tensors loaded from a safetensors file
    :return: new dictionary with fused expert weights replaced by per-expert
        weights; the input dictionary is not modified
    """
    import logging

    logger = logging.getLogger(__name__)

    result: dict[str, torch.Tensor] = {}
    for name, tensor in tensors.items():
        is_gate_up = "experts.gate_up_proj" in name
        is_down = "experts.down_proj" in name
        if tensor.ndim != 3 or not (is_gate_up or is_down):
            # non-MoE or non-3D tensors are kept as is
            result[name] = tensor
            continue

        num_experts = tensor.shape[0]
        if is_gate_up:
            if tensor.shape[1] % 2 != 0:
                # cannot split into gate/up halves: keep the fused tensor
                # instead of silently dropping it from the checkpoint
                logger.warning(
                    "gate_up_proj %s has odd second dimension: %s; keeping it fused",
                    name,
                    tuple(tensor.shape),
                )
                result[name] = tensor
                continue

            half = tensor.shape[1] // 2
            for expert_idx in range(num_experts):
                expert_tensor = tensor[expert_idx]
                # replace exactly the substring used for detection, so
                # non-`mlp.` prefixes (e.g. `feed_forward.`) get distinct keys
                base_key = name.replace("experts.gate_up_proj", f"experts.{expert_idx}")
                result[base_key + ".gate_proj.weight"] = expert_tensor[:half, :]
                result[base_key + ".up_proj.weight"] = expert_tensor[half:, :]
        else:
            for expert_idx in range(num_experts):
                base_key = name.replace("experts.down_proj", f"experts.{expert_idx}")
                result[base_key + ".down_proj.weight"] = tensor[expert_idx]

        logger.info("Split %s into %d experts", name, num_experts)

    return result