forked from GPT-SoVITS-Devel/GPT-SoVITS_minimal_inference
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path onnx_to_fp16.py
More file actions
256 lines (216 loc) · 10 KB
/
onnx_to_fp16.py
File metadata and controls
256 lines (216 loc) · 10 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import os
import onnx
import onnx.helper
from onnx import TensorProto
from onnxconverter_common.float16 import convert_float_to_float16
import argparse
from onnxsim import simplify
import numpy as np
# --- Configuration ---
# Per-model conversion settings. optimize_single_model picks the FIRST key that
# occurs as a substring of the .onnx filename, so insertion order matters.
#   fp16             - convert this model to FP16 at all (False = keep FP32)
#   sensitive        - extra op types kept in FP32 in mixed-precision mode,
#                      on top of GLOBAL_SENSITIVE_OPS
#   native_sensitive - op types kept in FP32 in --native_fp16 mode (optional)
#   keep_input_types - optional; read by optimize_single_model but not set by
#                      any entry below
MODEL_CONFIGS = {
    "vq_encoder": {"fp16": False, "sensitive": []},
    "bert": {
        "fp16": True,
        "sensitive": ["LayerNormalization", "Mean"],
    },
    "ssl": {
        "fp16": True,
        "sensitive": ["LayerNormalization", "Mean"],
    },
    "gpt_encoder": {
        "fp16": True,
        "sensitive": ["Pow", "Exp", "Mean", "ReduceMean", "LayerNormalization"],
    },
    "gpt_step": {
        "fp16": True,
        "sensitive": ["Pow", "Exp", "MatMulInteger", "LayerNormalization"],
    },
    "sovits": {
        "fp16": True,
        "sensitive": ["InstanceNormalization", "Resize", "Mean", "Sum", "Exp"],
        "native_sensitive": ["Resize"],
    },
    # spectrogram and sv_embedding stay FP32: the STFT and the computation
    # that follows it need FP32 precision.
    "spectrogram": {"fp16": False, "sensitive": []},
    "sv_embedding": {"fp16": False, "sensitive": []},
}

# Op types kept in FP32 for every model in mixed-precision mode.
# Kept as a list because it is concatenated with each model's "sensitive" list.
GLOBAL_SENSITIVE_OPS = [
    "Softmax",
    "LayerNormalization",
    "InstanceNormalization",
    "ReduceMean",
    "Pow",
    "Exp",
    "Resize",
    "Mean",
    "Sum",
]
def get_tensor_type(name, type_map, initializer_map, graph_input_map):
    """Return the element type recorded for tensor *name*.

    The maps are consulted in priority order: initializers first, then graph
    inputs, then *type_map* (in this file, types collected from value_info).
    Falls back to TensorProto.UNDEFINED when no map knows the tensor.
    """
    for source in (initializer_map, graph_input_map, type_map):
        if name in source:
            return source[name]
    return TensorProto.UNDEFINED
def _unique_tensor_name(base, taken):
    """Return *base* (or base_dup, base_dup2, ...) not yet in *taken*, and reserve it."""
    candidate, suffix = base, 1
    while candidate in taken:
        candidate = f"{base}_dup" if suffix == 1 else f"{base}_dup{suffix}"
        suffix += 1
    taken.add(candidate)
    return candidate


def fix_mixed_types_robust(model):
    """Cast FP32 data inputs to FP16 where MatMul/Gemm/Conv has an FP16 weight.

    This repairs type mismatches left behind after a mixed-precision FP16
    conversion of the GPT/SoVITS models. For every MatMul, Gemm or Conv whose
    input 1 (weight) is FLOAT16 while input 0 (data) is FLOAT or of unknown
    type, a ``Cast(to=FLOAT16)`` is inserted on input 0. Casting an input that
    is in fact already FP16 is a harmless no-op.

    Each Cast is placed immediately before the node that consumes it, so the
    node list stays topologically sorted as the ONNX IR requires. Appending the
    Casts at the end of the graph would break onnx.checker and later shape
    inference.

    Args:
        model: onnx ModelProto produced by convert_float_to_float16.

    Returns:
        The model, shape-inferred when inference succeeds, with any Casts
        inserted.
    """
    initializer_map = {init.name: init.data_type for init in model.graph.initializer}
    graph_input_map = {
        inp.name: inp.type.tensor_type.elem_type
        for inp in model.graph.input
        if inp.type.HasField("tensor_type")
    }

    try:
        model = onnx.shape_inference.infer_shapes(model)
    except Exception as e:
        # Best effort: without inferred types, data inputs of unknown type are
        # treated as needing a cast (see the condition below).
        print(f" [Warn] Shape inference failed inside fix_mixed_types_robust: {e}")

    type_map = {
        vi.name: vi.type.tensor_type.elem_type
        for vi in model.graph.value_info
        if vi.type.HasField("tensor_type")
    }

    # Every name already used in the graph. New Cast names (also used as the
    # Cast's output tensor name) must not collide with any of them.
    taken = set(initializer_map) | {inp.name for inp in model.graph.input}
    for existing in model.graph.node:
        taken.add(existing.name)
        taken.update(existing.output)

    ops_to_check = {"MatMul", "Gemm", "Conv"}
    ordered = []
    inserted = 0
    for node in model.graph.node:
        if node.op_type in ops_to_check and len(node.input) >= 2:
            data_name = node.input[0]
            t_data = get_tensor_type(data_name, type_map, initializer_map, graph_input_map)
            t_weight = get_tensor_type(node.input[1], type_map, initializer_map, graph_input_map)
            # FP16 weight but FP32 (or unknown) data -> force a Cast on the data.
            if t_weight == TensorProto.FLOAT16 and t_data in (TensorProto.FLOAT, TensorProto.UNDEFINED):
                cast_name = _unique_tensor_name(f"{data_name}_cast_fp16_fix_{node.name}", taken)
                ordered.append(onnx.helper.make_node(
                    "Cast", inputs=[data_name], outputs=[cast_name],
                    to=TensorProto.FLOAT16, name=cast_name,
                ))
                node.input[0] = cast_name
                inserted += 1
        ordered.append(node)

    if inserted:
        # Snapshot every node before clearing: `ordered` holds references into
        # the repeated field that is about to be emptied and refilled.
        snapshot = []
        for n in ordered:
            copied = onnx.NodeProto()
            copied.CopyFrom(n)
            snapshot.append(copied)
        model.graph.ClearField("node")
        model.graph.node.extend(snapshot)
        print(f" [Robust Fix] Inserted {inserted} Cast nodes.")
    return model
def fix_broken_attributes(model):
    """Align Random*/Cast dtype attributes with the element types recorded in the graph.

    Shape inference is re-run, and a name -> elem_type map is built from the
    graph inputs, outputs and value_info. Then:

    * RandomNormal / RandomUniform / RandomNormalLike / RandomUniformLike:
      the ``dtype`` attribute is overwritten with the recorded output type. If
      the attribute is missing and the output is FLOAT16, ``dtype=FLOAT16`` is
      added explicitly, because RandomNormal's default dtype is FLOAT (1).
    * Cast: the ``to`` attribute is overwritten with the recorded output type.

    Outputs with no recorded type, or a recorded type of UNDEFINED (0), are
    left alone. Writing 0 into ``dtype``/``to`` would produce an invalid node.

    Args:
        model: onnx ModelProto after FP16 conversion.

    Returns:
        The model (shape-inferred when inference succeeds) with attributes fixed.
    """
    try:
        model = onnx.shape_inference.infer_shapes(model)
    except Exception:
        print(" [Warn] Shape inference failed inside fix_broken_attributes, relying on partial info.")

    # Build the type map (name -> element type).
    type_map = {}
    for vi in list(model.graph.input) + list(model.graph.output) + list(model.graph.value_info):
        if vi.type.HasField("tensor_type"):
            type_map[vi.name] = vi.type.tensor_type.elem_type

    random_ops = {"RandomNormal", "RandomUniform", "RandomNormalLike", "RandomUniformLike"}
    cnt = 0
    for node in model.graph.node:
        if node.op_type not in random_ops and node.op_type != "Cast":
            continue
        # Only act when the output's type is actually known.
        real_dtype = type_map.get(node.output[0], TensorProto.UNDEFINED)
        if real_dtype == TensorProto.UNDEFINED:
            continue

        if node.op_type == "Cast":
            for attr in node.attribute:
                if attr.name == "to" and attr.i != real_dtype:
                    attr.i = real_dtype
                    cnt += 1
            continue

        # Random* ops: correct an existing dtype attribute ...
        found_dtype = False
        for attr in node.attribute:
            if attr.name == "dtype":
                found_dtype = True
                if attr.i != real_dtype:
                    attr.i = real_dtype
                    cnt += 1
        # ... or add one when the output must be FP16 and none is present.
        if not found_dtype and real_dtype == TensorProto.FLOAT16:
            node.attribute.extend([onnx.helper.make_attribute("dtype", TensorProto.FLOAT16)])
            cnt += 1

    if cnt > 0:
        print(f" [Attribute Fix] Fixed {cnt} attributes (Random/Cast mismatch).")
    return model
def _select_config(filename):
    """Return (key, config) for the first MODEL_CONFIGS key contained in *filename*."""
    for key, cfg in MODEL_CONFIGS.items():
        if key in filename:
            return key, cfg
    # Unrecognised models (e.g. unknown.onnx) stay FP32 to be safe.
    return None, {"fp16": False, "sensitive": []}


def _describe_strategy(config, native_fp16, keep_io):
    """Build the human-readable strategy label printed for each model."""
    if not config["fp16"]:
        # Nothing is converted, so the model's I/O stays FP32 as well.
        return "FP32 (Keep) [FP32 Input]"
    if native_fp16:
        return f"Native FP16 (I/O: {'FP32' if keep_io else 'FP16'})"
    return f"FP16 (Mixed) [{'FP32 Input' if keep_io else 'FP16'}]"


def _convert_to_fp16(model, config, native_fp16, keep_io):
    """Convert *model* to FP16 (native or mixed precision) and repair its types/attributes."""
    print(" Converting to FP16...")
    if native_fp16:
        # Pass a list (possibly empty) rather than None: convert_float_to_float16
        # substitutes its own default block list only when op_block_list is None.
        block_list = config.get("native_sensitive", [])
        if block_list:
            print(f" [Native FP16] Converting all layers to FP16, but preserving {block_list}...")
        else:
            print(f" [Native FP16] Converting all layers to FP16 (I/O: {'FP32' if keep_io else 'FP16'})...")
        model = convert_float_to_float16(model, keep_io_types=keep_io, op_block_list=block_list)
    else:
        # Mixed precision: keep the globally and per-model sensitive ops in FP32.
        block_list = GLOBAL_SENSITIVE_OPS + config["sensitive"]
        model = convert_float_to_float16(model, keep_io_types=keep_io, op_block_list=block_list)
        # Type mismatches at FP16/FP32 boundaries only arise in mixed mode.
        model = fix_mixed_types_robust(model)
    # Attribute repair is needed in both FP16 modes.
    return fix_broken_attributes(model)


def _simplify_checked(model):
    """Run onnxsim; keep the input model if simplification raises or fails its check."""
    print(" Simplifying...")
    try:
        simplified, check_ok = simplify(model)
    except Exception as e:
        print(f" [Warn] Simplify failed/warned: {e}")
        return model
    if not check_ok:
        print(" [Warn] Simplified model failed onnxsim's check; keeping the unsimplified model.")
        return model
    return simplified


def optimize_single_model(input_path, output_path, native_fp16=False):
    """Convert one ONNX model according to its MODEL_CONFIGS entry, simplify it, and save it.

    The config is chosen by substring match of the filename against
    MODEL_CONFIGS keys (first match wins). Models whose config has fp16=True
    are converted, natively when *native_fp16* is set and otherwise with mixed
    precision. All other models are left in FP32. Every model is then passed
    through onnxsim, and the simplified result is used only when onnxsim
    reports success.

    Args:
        input_path: path of the source .onnx file.
        output_path: path the processed model is saved to.
        native_fp16: convert all layers to FP16 except each config's
            optional "native_sensitive" ops, instead of using mixed precision.
    """
    filename = os.path.basename(input_path)
    _, config = _select_config(filename)
    keep_io = config.get("keep_input_types", False)
    print(f"Processing: {filename} | Strategy: {_describe_strategy(config, native_fp16, keep_io)}")

    model = onnx.load(input_path)
    if config["fp16"]:
        model = _convert_to_fp16(model, config, native_fp16, keep_io)
    else:
        print(" Skipping FP16 conversion (Sensitivity/Low-Cost).")

    # Simplification applies to both FP16 and FP32 models.
    model = _simplify_checked(model)
    onnx.save(model, output_path)
    print(f" Saved: {output_path}")
import shutil
def _copy_if_present(src, dst):
    """Copy *src* to *dst* when *src* exists; return whether *src* existed."""
    if not os.path.exists(src):
        return False
    try:
        shutil.copy(src, dst)
    except shutil.SameFileError:
        # input_dir == output_dir: the file is already where it needs to be.
        pass
    return True


def process_directory(input_dir, output_dir, native_fp16=False):
    """Convert every ``*.onnx`` file in *input_dir* into *output_dir*.

    *output_dir* is created if needed. Each model is handled by
    optimize_single_model, in sorted filename order so runs are reproducible.
    A sibling ``<model>.onnx.data`` file is copied along when present.
    ``config.json`` is copied when it exists; if it is missing, a warning is
    printed instead of raising after all models were already converted.

    Args:
        input_dir: directory containing the FP32 .onnx models.
        output_dir: destination directory for the processed models.
        native_fp16: forwarded to optimize_single_model.
    """
    os.makedirs(output_dir, exist_ok=True)
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith(".onnx"):
            continue
        optimize_single_model(
            os.path.join(input_dir, filename),
            os.path.join(output_dir, filename),
            native_fp16=native_fp16,
        )
        # Copy the external-data sidecar, if any.
        # NOTE(review): onnx.load inlines external data by default, so the saved
        # model may no longer reference this file; confirm it is still needed.
        _copy_if_present(
            os.path.join(input_dir, filename + ".data"),
            os.path.join(output_dir, filename + ".data"),
        )
    if not _copy_if_present(os.path.join(input_dir, "config.json"), os.path.join(output_dir, "config.json")):
        print(f" [Warn] config.json not found in {input_dir}; not copied.")
    print(f"\nOptimization complete: {output_dir}")
if __name__ == "__main__":
    # Command-line entry point: parse directories/flags, then convert the whole directory.
    cli = argparse.ArgumentParser(description="Convert ONNX models to FP16")
    cli.add_argument("--input_dir", required=True, help="Input directory containing FP32 ONNX models")
    cli.add_argument("--output_dir", required=True, help="Output directory for FP16 models")
    cli.add_argument(
        "--native_fp16",
        action="store_true",
        help="Use native FP16 mode (convert all layers to FP16, no mixed precision)",
    )
    opts = cli.parse_args()

    if opts.native_fp16:
        rule = "=" * 60
        banner_lines = [
            rule,
            "NATIVE FP16 MODE ENABLED",
            "All layers will be converted to FP16",
            "Note: This may cause numerical instability in some operations",
            rule,
        ]
        print("\n".join(banner_lines))

    process_directory(opts.input_dir, opts.output_dir, native_fp16=opts.native_fp16)