-
Notifications
You must be signed in to change notification settings - Fork 490
Expand file tree
/
Copy pathprocess.py
More file actions
295 lines (241 loc) · 11.1 KB
/
process.py
File metadata and controls
295 lines (241 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import os
from collections import defaultdict
from typing import Iterable
import torch
from compressed_tensors.compressors import compress_module
from compressed_tensors.entrypoints.convert import Converter
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils import match_quantizable_tensors
from compressed_tensors.utils.safetensors_load import (
InverseWeightMap,
load_tensors_from_inverse_weight_map,
)
from loguru import logger
from safetensors.torch import save_file
from torch.nn import Module
from llmcompressor.entrypoints.model_free.lifecycle import (
calibrate_weight,
initialize_quantized_linear,
validate_weight_for_quantization,
)
from llmcompressor.entrypoints.model_free.microscale import (
get_fused_names,
is_microscale_scheme,
)
from llmcompressor.modifiers.quantization.calibration import (
apply_calibration_status,
freeze_module_quantization,
initialize_observer,
observe,
update_qparams,
)
from llmcompressor.observers import Observer
# Public API of this module; `split_fused_moe_experts` is intentionally a
# module-internal helper and is not exported here.
__all__ = [
    "validate_file",
    "process_file",
    "process_file_microscale_scheme",
]
def validate_file(
    inverse_weight_map: InverseWeightMap,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: str | torch.device,
    converter: Converter | None = None,
):
    """
    Validate that each quantizable tensor in a safetensors file can be quantized.

    :param inverse_weight_map: mapping of source file path -> tensor names to validate
    :param save_path: save path of file with quantized weights (accepted for
        signature parity with ``process_file``; not used during validation)
    :param scheme: quantization scheme to apply to tensors
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param device: device used to quantize and compress weights
    :param converter: optional converter to apply to the checkpoint,
        e.g. conversion of some layers from some format to compressed-tensors
    """
    loaded = load_tensors_from_inverse_weight_map(inverse_weight_map, device)

    # Let the converter reject unsupported checkpoints before per-tensor checks
    if converter is not None:
        converter.validate(loaded)

    matched = match_quantizable_tensors(loaded, ignore, scheme.targets)
    for _module_name, tensor_name in matched:
        validate_weight_for_quantization(loaded[tensor_name], scheme, tensor_name)
def process_file(
    inverse_weight_map: InverseWeightMap,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: str | torch.device,
    converter: Converter | None = None,
) -> tuple[int, dict[str, str]]:
    """
    Quantize and compress tensors in a given safetensors file.

    :param inverse_weight_map: mapping of source file path -> tensor names.
        For standard mode: {resolved_path: None} means load all tensors to process
    :param save_path: save path of file with quantized weights
    :param scheme: quantization scheme to apply to tensors
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param device: device used to quantize and compress weights
    :param converter: optional converter to apply to the checkpoint,
        e.g. conversion of some layers from some format to compressed-tensors
    :return: tuple of (total byte size of saved tensors, mapping of tensor name ->
        saved file basename) used by the caller to build the safetensors index
    """
    assert not is_microscale_scheme(scheme), "Use `process_file_microscale_scheme`"
    tensors = load_tensors_from_inverse_weight_map(inverse_weight_map, device)
    tensors = split_fused_moe_experts(tensors)
    if converter is not None:
        converter.process(tensors)

    for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets):
        validate_weight_for_quantization(tensors[name], scheme, name)

        # 1. initialize module with qparams (on device)
        module = initialize_quantized_linear(tensors[name], scheme, device)

        # 2. calibrate weight qparams
        calibrate_weight(module)

        # 3. compress module using qparams
        compress_module(module)

        # 4. save compressed data (on cpu); the original uncompressed tensor is
        # dropped and replaced by the module's compressed state dict entries
        del tensors[name]
        prefix = module_name + "."
        for key, value in module.state_dict(prefix=prefix).items():
            tensors[key] = value.to("cpu")

    # Create the output directory if needed. Fixes an inconsistency with
    # `process_file_microscale_scheme`, which already does this before saving.
    os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
    save_file(tensors, save_path)
    total_size = sum(tensor.nbytes for tensor in tensors.values())
    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
    return total_size, weight_map
def process_file_microscale_scheme(
    inverse_weight_map: InverseWeightMap,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: str | torch.device,
    converter: Converter | None = None,
) -> tuple[int, dict[str, str]]:
    """
    Quantize and compress tensors for a single output shard using a microscale
    scheme (NVFP4, MXFP4).

    Accepts a precomputed inverse_weight_map that specifies exactly which tensors
    to load from which source files — including any fused partner tensors from
    other shards needed for global scale computation. This avoids runtime
    discovery of fused partners and redundant tensor reads.

    Partner tensors fetched from other shards are re-saved into this shard's
    output. The caller updates the safetensors index to reflect new locations.

    :param inverse_weight_map: mapping of resolved source file path ->
        list of tensor names to load from that file.
        Example: {"/path/shard0.safetensors": ["q_proj.weight"],
        "/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]}
    :param save_path: output path for this shard's compressed weights
    :param scheme: microscale quantization scheme (NVFP4, MXFP4)
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param device: device used to quantize and compress weights
    :param converter: optional converter to apply to the checkpoint,
        e.g. conversion of some layers from some format to compressed-tensors
    :return: tuple of (total byte size of saved tensors, mapping of tensor name ->
        saved file basename) used by the caller to build the safetensors index
    """
    assert is_microscale_scheme(scheme), "Use `process_file` for non-microscale scheme"
    tensors = load_tensors_from_inverse_weight_map(inverse_weight_map, device)
    tensors = split_fused_moe_experts(tensors)
    if converter is not None:
        converter.process(tensors)

    # Get fused sets. Non-primary shards may have incomplete sets (k/v without q)
    # since only the primary-owning shard fetches partners — this is correct.
    fused_sets, _ = get_fused_names(list(tensors.keys()))
    # Reverse lookup: tensor name -> index of the fused set it belongs to.
    # `None` entries mark fused slots with no tensor in this shard and are skipped.
    fused_name_to_fused_index: dict[str, int] = {
        name: index
        for index, matched_set in enumerate(fused_sets)
        for name in matched_set.values()
        if name is not None
    }
    # fused set index -> {tensor name: initialized-but-uncalibrated module}
    fused_modules: dict[int, dict[str, Module]] = defaultdict(dict)
    for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets):
        validate_weight_for_quantization(tensors[name], scheme, name)

        # 1. initialize module with qparams (on device)
        module = initialize_quantized_linear(tensors[name], scheme, device)

        # gather fused modules for later processing: calibration is deferred so
        # all members of a fused set can share one global scale (see loop below)
        if name in fused_name_to_fused_index:
            fused_index = fused_name_to_fused_index[name]
            fused_modules[fused_index][name] = module
            initialize_observer(module, "weight")
            apply_calibration_status(module)
            continue

        # 2. get module qparams
        calibrate_weight(module)

        # 3. compress module using qparams
        compress_module(module)

        # 4. save compressed data (on cpu)
        del tensors[name]
        prefix = module_name + "."
        for key, value in module.state_dict(prefix=prefix).items():
            tensors[key] = value.to("cpu")

    # Compress fused modules with shared global scale
    for named_modules in fused_modules.values():
        # 2. fuse observers, observe weights, and get qparams
        Observer.fuse([mod.weight_observer for mod in named_modules.values()])
        observe(named_modules.values(), base_name="weight")
        update_qparams(named_modules.values(), base_name="weight")
        for name, module in named_modules.items():
            freeze_module_quantization(module)

            # 3. compress module using microscale qparams
            compress_module(module)

            # 4. save compressed data (on cpu)
            del tensors[name]
            # recover the module name by stripping the trailing ".weight" segment
            module_name, _ = name.rsplit(".", 1)
            prefix = module_name + "."
            for key, value in module.state_dict(prefix=prefix).items():
                tensors[key] = value.to("cpu")

    # Save ALL tensors to this shard's output — including partner tensors fetched
    # from other shards. Partners are re-saved here so future runs don't need to
    # re-fetch them. The caller updates the safetensors index to reflect new locations.
    os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
    save_file(tensors, save_path)
    total_size = sum(t.nbytes for t in tensors.values())
    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
    return total_size, weight_map
def split_fused_moe_experts(
    tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """
    Detect fused MoE expert weights (3D gate_up_proj/down_proj tensors) and
    expand each into individual 2D per-expert tensors.

    Args:
        tensors: tensor name -> tensor, as loaded from a safetensors file

    Returns:
        A new dictionary where fused 3D expert tensors are replaced by
        per-expert 2D tensors; all other entries are carried over unchanged.
    """
    # Fused parameter name -> per-expert projection names it splits into.
    # A 3D gate_up_proj becomes a gate_proj and an up_proj per expert;
    # a 3D down_proj becomes one down_proj per expert.
    fused_to_split = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "down_proj": ["down_proj"],
    }

    result: dict[str, torch.Tensor] = {}
    for name, tensor in tensors.items():
        matches = [fused for fused in fused_to_split if fused in name]
        if len(matches) > 1:
            raise ValueError(f"Found multiple keys matching {name}: {matches}")

        if not matches or tensor.ndim != 3:
            # Non-MoE or non-3D tensors, keep as is
            result[name] = tensor
            continue

        fused_name = matches[0]
        targets = fused_to_split[fused_name]
        num_experts = tensor.shape[0]
        if tensor.shape[1] % len(targets) != 0:
            raise ValueError(
                f"{fused_name} expects a second dimension divisible by "
                f"{len(targets)} but got shape: {tensor.shape}"
            )

        # Each expert slab splits evenly along dim 0 into the target layers
        chunk_rows = tensor.shape[1] // len(targets)
        for expert_idx in range(num_experts):
            pieces = tensor[expert_idx].split(chunk_rows, dim=0)
            for target_name, piece in zip(targets, pieces):
                new_key = name.replace(fused_name, f"{expert_idx}.{target_name}.weight")
                result[new_key] = piece
        logger.info(f"Split {name} into {num_experts} experts")

    return result