-
Notifications
You must be signed in to change notification settings - Fork 357
Expand file tree
/
Copy pathexport.py
More file actions
136 lines (106 loc) · 4.45 KB
/
export.py
File metadata and controls
136 lines (106 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import warnings
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import modelopt.torch.opt as mto
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.quantization.utils import set_quantizer_state_dict
from modelopt.torch.utils import print_rank_0
# NOTE(review): RAND_SEED is not referenced anywhere in this file — confirm it is
# used by an importer before removing.
RAND_SEED = 1234

# Enable automatic save/load of modelopt state huggingface checkpointing
mto.enable_huggingface_checkpointing()
def get_model(
    ckpt_path: str,
    device: str = "cuda",
):
    """
    Loads a QLoRA model that has been trained using modelopt trainer.

    Args:
        ckpt_path: Path to the HuggingFace-format training checkpoint directory.
        device: Target device; "cpu" forces a CPU device map, anything else
            uses device_map="auto".

    Returns:
        The loaded model with its modelopt state restored (for LoRA models).
    """
    # TODO: Add support for merging adapters in BF16 and merging adapters with quantization for deployment
    device_map = "auto"
    if device == "cpu":
        device_map = "cpu"

    # Load model
    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)

    # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
    # (enable_huggingface_checkpointing() at module import time hooks from_pretrained).
    if hasattr(model, "peft_config"):
        # The trainer saves the modelopt state separately for LoRA models;
        # assumes the file lives directly under the checkpoint dir — TODO confirm.
        modelopt_state = mto.load_modelopt_state(f"{ckpt_path}/modelopt_state_train.pth")
        restore_from_modelopt_state(model, modelopt_state)
        print_rank_0("Restored modelopt state")

        # Restore modelopt quantizer state dict; weights are stored under an
        # optional "modelopt_state_weights" key and popped before use.
        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
        if modelopt_weights is not None:
            set_quantizer_state_dict(model, modelopt_weights)
            print_rank_0("Restored modelopt quantizer state dict")
    return model
def main(args):
    """Export a modelopt-trained checkpoint to a unified HuggingFace checkpoint.

    For QLoRA models (detected via a ``peft_config`` attribute), the quantized
    base model is exported to ``<export_path>/base_model`` and the adapters to
    ``<export_path>``; otherwise everything goes to ``<export_path>``.

    Args:
        args: Parsed CLI namespace with ``pyt_ckpt_path``, ``device`` and
            ``export_path`` attributes.

    Raises:
        Exception: Re-raises any export failure after emitting a warning.
    """
    # Load model and tokenizer from the training checkpoint
    model = get_model(args.pyt_ckpt_path, args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

    is_qlora = hasattr(model, "peft_config")

    # Export HF checkpoint
    export_dir = Path(args.export_path)
    export_dir.mkdir(parents=True, exist_ok=True)

    # QLoRA: base model weights go into a subdirectory, adapters at the top level.
    if is_qlora:
        base_model_dir = export_dir / "base_model"
        base_model_dir.mkdir(parents=True, exist_ok=True)
    else:
        base_model_dir = export_dir

    try:
        post_state_dict, hf_quant_config = _export_transformers_checkpoint(
            model, is_modelopt_qlora=is_qlora
        )

        # Persist the raw quant config next to the base model weights.
        with open(base_model_dir / "hf_quant_config.json", "w") as file:
            json.dump(hf_quant_config, file, indent=4)

        # Convert to the format embedded in config.json's quantization_config.
        hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

        # Save model
        if is_qlora:
            model.base_model.save_pretrained(str(base_model_dir), state_dict=post_state_dict)
            model.save_pretrained(export_dir)
        else:
            model.save_pretrained(export_dir, state_dict=post_state_dict)

        # Rewrite config.json with the converted quantization_config embedded.
        config_path = base_model_dir / "config.json"
        config_data = model.config.to_dict()
        config_data["quantization_config"] = hf_quant_config
        with open(config_path, "w") as file:
            json.dump(config_data, file, indent=4)

        # Save tokenizer
        tokenizer.save_pretrained(export_dir)
    except Exception:
        warnings.warn(
            "Cannot export model to the model_config. The modelopt-optimized model state_dict"
            " can be saved with torch.save for further inspection."
        )
        # Bare raise preserves the original traceback (unlike `raise e`).
        raise
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the export.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--pyt_ckpt_path",
        required=True,
        help="Specify where the PyTorch checkpoint path is",
    )
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--export_path",
        default="exported_model",
        help="Path to save the exported model",
    )
    main(parser.parse_args())