# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-node PTQ (Post-Training Quantization) with FSDP2 support."""
import argparse
import json
import os
import random
import time
import warnings
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from accelerate import Accelerator
from example_utils import build_quant_cfg, get_tokenizer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.export import get_model_type
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
from modelopt.torch.quantization.config import need_calibration
from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets
# Constants
RAND_SEED = 1234
QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
}
KV_QUANT_CFG_CHOICES = {
"none": "none",
"fp8": "FP8_KV_CFG",
"nvfp4": "NVFP4_KV_CFG",
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
}
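# The values above name config attributes of modelopt.torch.quantization;
# build_quant_cfg is expected to resolve them (e.g. getattr(mtq, "FP8_KV_CFG")).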
# Enable HuggingFace checkpointing
mto.enable_huggingface_checkpointing()
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Multi-node post-training quantization with FSDP2")
parser.add_argument(
"--pyt_ckpt_path",
required=True,
help="Path to PyTorch checkpoint",
)
parser.add_argument(
"--qformat",
default="fp8",
choices=QUANT_CFG_CHOICES.keys(),
help="Quantization format",
)
parser.add_argument(
"--kv_cache_qformat",
default="fp8",
choices=list(KV_QUANT_CFG_CHOICES.keys()),
help="KV cache quantization format",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size for calibration",
)
parser.add_argument(
"--calib_size",
type=str,
default="512",
help="Comma-separated list of calibration sizes per dataset",
)
parser.add_argument(
"--dataset",
        help=(
            "Name of a dataset, or a comma-separated list of datasets. "
            f"Dataset choices are {get_supported_datasets()}"
        ),
type=str,
default=None,
)
parser.add_argument(
"--export_path",
default="exported_model",
help="Directory to export the quantized model",
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
help="Trust remote code for HuggingFace models",
)
parser.add_argument("--awq_block_size", default=0, type=int)
args = parser.parse_args()
# Parse comma-separated lists
args.dataset = args.dataset.split(",") if args.dataset else None
args.calib_size = [int(x) for x in args.calib_size.split(",")]
return args
def load_and_prepare_model(
model_path: str,
calib_dataloader: torch.utils.data.DataLoader,
accelerator: Accelerator,
trust_remote_code: bool = False,
) -> tuple[nn.Module, str, list[str], torch.utils.data.DataLoader]:
"""Load model and prepare it for FSDP2 distributed execution.
Args:
model_path: Path to the HuggingFace model
        calib_dataloader: Calibration dataloader to be sharded for calibration
accelerator: Accelerate's Accelerator instance
trust_remote_code: Whether to trust remote code
Returns:
Tuple of (prepared_model, model_type, original_architectures, calibration_dataloader)
"""
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype="auto",
trust_remote_code=trust_remote_code,
)
model.eval()
model_type = get_model_type(model)
# Need the original architectures for export
# FSDP prefix is added to the architectures for FSDP2 wrapped models
original_architectures = model.config.architectures
# FSDP2 requires an optimizer to be prepared together with the model
dummy_optimizer = torch.optim.SGD(model.parameters(), lr=0.0)
model, _, calibration_dataloader = accelerator.prepare(model, dummy_optimizer, calib_dataloader)
return model, model_type, original_architectures, calibration_dataloader
def create_calibration_dataloader(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
dataset_names: list[str],
calib_sizes: list[int],
batch_size: int,
) -> torch.utils.data.DataLoader:
"""Create calibration dataloader from dataset.
Args:
tokenizer: HuggingFace tokenizer
dataset_names: List of dataset names (defaults to cnn_dailymail)
calib_sizes: Number of samples for each dataset
batch_size: Batch size for calibration
Returns:
DataLoader for calibration
"""
return get_dataset_dataloader(
dataset_name=dataset_names,
tokenizer=tokenizer,
batch_size=batch_size,
num_samples=calib_sizes,
device=None, # Keep data on CPU, calibration loop handles device transfer
include_labels=False,
)
def create_fsdp2_calibration_loop(
model: nn.Module,
dataloader: torch.utils.data.DataLoader,
accelerator: Accelerator,
):
"""Create calibration loop compatible with FSDP2.
    For FSDP2, the forward pass must go through the outer FSDP-wrapped model
    rather than the unwrapped model that mtq.quantize passes to the forward
    loop, so that DTensor parameters are resolved correctly.
Args:
model: FSDP2-wrapped model
dataloader: Calibration dataloader
accelerator: Accelerator instance for device management
Returns:
Calibration function compatible with mtq.quantize
"""
def calibrate(unwrapped_model):
"""Calibration loop that uses the FSDP-wrapped model."""
for batch in tqdm(dataloader, desc="Calibrating"):
if isinstance(batch, dict):
batch = {
k: v.to(accelerator.device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
            # mtq.quantize unwraps the model before calling forward_loop; the
            # `unwrapped_model` argument is intentionally unused. Forward through
            # the outer FSDP-wrapped model so DTensor parameters resolve correctly.
model(**batch)
return calibrate
def export_model(
model: nn.Module,
accelerator: Accelerator,
export_path: str | Path,
architectures: list[str],
):
"""Export quantized model to HuggingFace format.
Args:
model: Quantized model
accelerator: Accelerator instance for state dict gathering
        export_path: Directory to export model to
        architectures: Original architectures (without the FSDP prefix) to
            restore in the exported config.json
    """
export_dir = Path(export_path)
export_dir.mkdir(parents=True, exist_ok=True)
post_state_dict, hf_quant_config = _export_transformers_checkpoint(
model, torch.bfloat16, accelerator=accelerator
)
if accelerator.is_main_process:
# Save hf_quant_config.json for backward compatibility
with open(f"{export_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)
hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
# Save model
model.save_pretrained(export_dir, state_dict=post_state_dict, save_modelopt_state=False)
original_config = f"{export_dir}/config.json"
config_data = {}
with open(original_config) as file:
config_data = json.load(file)
config_data["quantization_config"] = hf_quant_config
        # Restore the original architectures, which do not carry the FSDP prefix
config_data["architectures"] = architectures
with open(original_config, "w") as file:
json.dump(config_data, file, indent=4)
def main(args):
"""Main quantization workflow."""
# Validate GPU availability
if not torch.cuda.is_available():
raise OSError("GPU is required for quantization.")
# Validate quantization format
if args.qformat not in QUANT_CFG_CHOICES:
raise ValueError(
f"Quantization format {args.qformat} not supported. Choose from: {QUANT_CFG_CHOICES.keys()}"
)
# Set random seeds
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)
# Initialize accelerator
accelerator = Accelerator()
print(f"Rank: {os.environ.get('RANK', 'Not set')}")
print(f"World Size: {os.environ.get('WORLD_SIZE', 'Not set')}")
print(f"Local Rank: {os.environ.get('LOCAL_RANK', 'Not set')}")
# Load tokenizer
tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
default_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left" # Left padding for better calibration
# Set default dataset if not provided
if args.dataset is None:
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
warnings.warn(
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
)
# Adjust calib_size to match dataset length by extending or truncating as needed
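    # e.g. three datasets with --calib_size 512 -> [512, 512, 512];
    # two datasets with --calib_size 1,2,3,4 -> [1, 2]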
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
: len(args.dataset)
]
# Create calibration dataloader with max batch size
calib_dataloader = create_calibration_dataloader(
tokenizer=tokenizer,
dataset_names=args.dataset,
calib_sizes=args.calib_size,
batch_size=args.batch_size,
)
# Load and prepare model
model, model_type, original_architectures, calib_dataloader = load_and_prepare_model(
model_path=args.pyt_ckpt_path,
calib_dataloader=calib_dataloader,
accelerator=accelerator,
trust_remote_code=args.trust_remote_code,
)
# Build quantization config
quant_cfg = build_quant_cfg(
args.qformat,
args.kv_cache_qformat,
args.awq_block_size,
model_type,
QUANT_CFG_CHOICES,
KV_QUANT_CFG_CHOICES,
)
# Quantize the model
if accelerator.is_main_process:
print("Starting quantization...")
start_time = time.time()
if need_calibration(quant_cfg):
calibrate_fn = create_fsdp2_calibration_loop(model, calib_dataloader, accelerator)
else:
calibrate_fn = None
warnings.warn("Dynamic quantization. Calibration skipped.")
with torch.no_grad():
model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_fn)
elapsed = time.time() - start_time
if accelerator.is_main_process:
print(f"Quantization completed in {elapsed:.2f}s")
mtq.print_quant_summary(model)
start_time = time.time()
export_model(model, accelerator, args.export_path, original_architectures)
elapsed = time.time() - start_time
if accelerator.is_main_process:
# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
tokenizer.padding_side = default_padding_side
tokenizer.save_pretrained(args.export_path)
        print(f"Export completed in {elapsed:.2f}s")
print(f"Model exported to {args.export_path}")
print("Unpatching FSDP2 MP dtypes")
if __name__ == "__main__":
args = parse_args()
    # This context manager can be removed once the FSDP2 mixed-precision dtype
    # fix is available upstream in torch
with patch_fsdp_mp_dtypes():
main(args)