-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathwebui.py
512 lines (435 loc) · 19.4 KB
/
webui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import platform
import argparse
import gradio as gr
import numpy as np
import torch
import torchaudio
import random
import librosa
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
import shutil
import time
# Turn off HuggingFace tokenizers parallelism through the environment; this is
# done before the model libraries are imported further down.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Absolute directory of this script. The rest of the file builds paths from it
# (voices/, audios/, output.pt, third_party/).
ROOT_DIR = os.path.split(os.path.abspath(__file__))[0]
# Make the vendored Matcha-TTS package importable for the cosyvoice imports
# that follow.
sys.path.append(f'{ROOT_DIR}/third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav, logging
from cosyvoice.utils.common import set_all_random_seed
from modelscope import snapshot_download
# Download the default CosyVoice2-0.5B checkpoint at import time. This always
# runs, even when --model_dir points at a different model.
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
# Best-effort: overlay spk2info.pt (presumably the speaker table) onto the
# downloaded model. Both paths are relative to the current working directory,
# not ROOT_DIR. Only filesystem errors (missing file, permissions, same file)
# are tolerated; anything else is a real bug and is allowed to surface.
try:
    shutil.copy2('spk2info.pt', 'pretrained_models/CosyVoice2-0.5B/spk2info.pt')
except OSError as e:
    logging.warning(f'复制文件失败: {e}')
# Labels of the four inference modes, in the order shown by the mode radio
# group; main() uses the first entry as the default selection.
# (预训练音色 = pretrained voice, 自然语言控制 = natural-language instruct,
#  3s极速复刻 = 3-second zero-shot cloning, 跨语种复刻 = cross-lingual cloning)
inference_mode_list = ['预训练音色', '自然语言控制', '3s极速复刻', '跨语种复刻']
# Step-by-step instructions shown for each mode, keyed by the labels above.
# NOTE(review): the 跨语种复刻 text lists an optional "save voice" step, but
# change_instruction() only reveals the save controls for 3s极速复刻 — confirm
# which behavior is intended.
instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
                 '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮\n4. (可选)保存音色模型',
                 '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮\n3. (可选)保存音色模型',
                 '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
# (label, value) pairs for the streaming radio: 否 (no) -> False, 是 (yes) -> True.
stream_mode_list = [('否', False), ('是', True)]
# Peak-amplitude ceiling applied by postprocess() when normalizing prompt audio.
max_val = 0.8
def refresh_sft_spk():
    """Rebuild the choices of the voice dropdown.

    Custom voices are the ``*.pt`` files in ``{ROOT_DIR}/voices``, newest
    first and shown without the ``.pt`` suffix. They are followed by the
    speakers the loaded model reports through ``cosyvoice.list_available_spks()``.

    Returns:
        dict: a Gradio update payload ``{"choices": [...], "__type__": "update"}``.
        ``choices`` is ``['']`` when nothing is available.
    """
    suffix = ".pt"
    try:
        with os.scandir(f"{ROOT_DIR}/voices") as it:
            # Only "<name>.pt" files can be loaded later (generate_audio builds
            # that path), so other entries would be dead choices.
            entries = [
                (entry.name, entry.stat().st_mtime)
                for entry in it
                if entry.is_file() and entry.name.endswith(suffix)
            ]
    except FileNotFoundError:
        # A fresh checkout may not have a voices/ directory yet; treat it as
        # empty rather than crashing at startup.
        entries = []
    entries.sort(key=lambda item: item[1], reverse=True)  # newest first
    # Strip only the trailing suffix; str.replace would also remove ".pt"
    # occurring inside the name itself.
    custom_voices = [name[:-len(suffix)] for name, _ in entries]
    choices = custom_voices + cosyvoice.list_available_spks()
    if not choices:
        choices = ['']
    return {"choices": choices, "__type__": "update"}
def refresh_prompt_wav():
    """Rebuild the choices of the reference-audio dropdown.

    Returns:
        dict: a Gradio update payload whose ``choices`` always start with the
        placeholder entry, followed by the files in ``{ROOT_DIR}/audios``
        (newest first).
    """
    placeholder = "请选择参考音频或者自己上传"
    try:
        with os.scandir(f"{ROOT_DIR}/audios") as it:
            # Directories cannot be used as prompt audio, so list files only.
            entries = [(entry.name, entry.stat().st_mtime) for entry in it if entry.is_file()]
    except FileNotFoundError:
        # A missing audios/ directory just means there are no reference clips.
        entries = []
    entries.sort(key=lambda item: item[1], reverse=True)  # newest first
    # The placeholder guarantees a non-empty list, so no '' fallback is needed.
    choices = [placeholder] + [name for name, _ in entries]
    return {"choices": choices, "__type__": "update"}
def change_prompt_wav(filename):
    """Map a reference-audio dropdown choice to a file path.

    Args:
        filename: entry name selected in the dropdown. It may be the
            placeholder entry, empty, or ``None``.

    Returns:
        str | None: ``{ROOT_DIR}/audios/<filename>`` if that file exists,
        otherwise ``None``. main() wires the return value into the upload
        widget, so ``None`` clears it.
    """
    # The placeholder / empty selection is not a file: return quietly instead
    # of logging a spurious "file does not exist" warning.
    if not filename or filename == "请选择参考音频或者自己上传":
        return None
    # The value comes from the client. Keep only the final path component so a
    # tampered value such as "../../x" cannot escape the audios directory.
    safe_name = os.path.basename(filename)
    full_path = f"{ROOT_DIR}/audios/{safe_name}"
    if not os.path.isfile(full_path):
        logging.warning(f"音频文件不存在: {full_path}")
        return None
    return full_path
def save_voice_model(voice_name):
    """Persist the most recent cloned voice under a user-chosen name.

    Copies ``{ROOT_DIR}/output.pt`` to ``{ROOT_DIR}/voices/<voice_name>.pt``.
    NOTE(review): output.pt is not written anywhere in this file; presumably
    the cosyvoice inference pipeline produces it — confirm.

    Args:
        voice_name: name typed by the user; used as the file stem.

    Returns:
        bool: True on success. False if the name is empty or invalid, or if
        the copy failed (a toast explains which).
    """
    voice_name = (voice_name or "").strip()
    if not voice_name:
        gr.Info("音色名称不能为空")
        return False
    # The name comes straight from a text box. Reject path separators so it
    # cannot write outside the voices directory.
    if any(sep in voice_name for sep in ("/", "\\")):
        gr.Warning("音色名称不合法")
        return False
    voices_dir = f"{ROOT_DIR}/voices"
    try:
        # Create voices/ on first save instead of failing every time.
        os.makedirs(voices_dir, exist_ok=True)
        shutil.copyfile(f"{ROOT_DIR}/output.pt", f"{voices_dir}/{voice_name}.pt")
    except OSError as e:
        logging.error(f"保存音色失败: {e}")
        gr.Warning("保存音色失败")
        return False
    gr.Info("音色保存成功,存放位置为voices目录")
    return True
def generate_random_seed():
    """Return a Gradio update that puts a fresh random seed in [1, 100000000] into the seed box."""
    new_seed = random.randint(1, 100000000)
    return dict(__type__="update", value=new_seed)
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """Clean up prompt audio before it is handed to the model.

    Three steps, in order:
    1. Trim leading and trailing silence with ``librosa.effects.trim``.
    2. Scale the signal down so its peak does not exceed ``max_val``.
    3. Append ``int(cosyvoice.sample_rate * 0.2)`` zero samples.

    Args:
        speech: waveform tensor. Presumably shaped (1, num_samples), as
            produced by ``load_wav``; the final concat along dim=1 requires
            2-D input.
        top_db: level in dB below the peak that counts as silence.
        hop_length: librosa hop length used while trimming.
        win_length: librosa frame length used while trimming.

    Returns:
        The trimmed, normalized and padded waveform tensor.
    """
    trimmed, _interval = librosa.effects.trim(
        speech, top_db=top_db, frame_length=win_length, hop_length=hop_length
    )
    peak = trimmed.abs().max()
    if peak > max_val:
        trimmed = trimmed / peak * max_val
    # NOTE(review): the pad length uses the model's output sample rate, while
    # callers pass prompt audio loaded at prompt_sr — confirm this is intended.
    tail = torch.zeros(1, int(cosyvoice.sample_rate * 0.2))
    return torch.concat([trimmed, tail], dim=1)
def change_instruction(mode_checkbox_group):
    """Update the UI after the inference mode changes.

    Returns a 3-tuple for (instruction text, voice dropdown, save-voice row):
    - the step-by-step text for the selected mode;
    - an update showing the voice dropdown only for 预训练音色 / 自然语言控制;
    - an update showing the save-voice row only for 3s极速复刻.
    """
    steps = instruct_dict[mode_checkbox_group]
    show_voice = mode_checkbox_group in ('预训练音色', '自然语言控制')
    show_save = mode_checkbox_group == '3s极速复刻'
    return steps, gr.update(visible=show_voice), gr.update(visible=show_save)
def prompt_wav_recognition(prompt_wav):
    """Transcribe a prompt audio file with the SenseVoice ASR model.

    Args:
        prompt_wav: path of the prompt audio, or None.

    Returns:
        str: the recognized text after the last ``|>`` marker, which drops the
        special-token prefix. Returns '' when no audio is given or recognition
        fails; in the failure case a warning toast is also shown.
    """
    if prompt_wav is None:
        return ''
    try:
        results = asr_model.generate(
            input=prompt_wav,
            language="auto",  # or one of "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
        )
        _, _, text = results[0]["text"].rpartition('|>')
        return text
    except Exception as e:
        logging.error(f"音频识别文本失败: {e}")
        gr.Warning("识别文本失败,请检查音频是否包含人声内容")
        return ''
def load_voice_data(voice_path):
    """Load the reference prompt audio stored in a saved voice file.

    Args:
        voice_path: path to a ``.pt`` file written with ``torch.save``.

    Returns:
        The value stored under ``'audio_ref'``, loaded onto the CPU.
        Returns ``None`` if the file is missing, cannot be loaded, is not a
        dict, or lacks the key.
    """
    try:
        if not os.path.exists(voice_path):
            return None
        # Load onto the CPU so the instruct path receives prompt audio on the
        # same device as the cloning path, which uses load_wav() output
        # (presumably a CPU tensor). Loading straight to CUDA can cause
        # device-mismatch errors in CPU-side preprocessing.
        # SECURITY NOTE: torch.load unpickles the file (unless weights_only is
        # in effect). Only keep trusted files in the voices directory.
        voice_data = torch.load(voice_path, map_location='cpu')
    except Exception as e:
        logging.error(f"加载音色文件失败: {e}")
        return None
    # An explicit type check instead of truthiness: `if tensor` raises for
    # multi-element tensors.
    if not isinstance(voice_data, dict):
        return None
    return voice_data.get('audio_ref')
def validate_input(mode, tts_text, sft_dropdown, prompt_text, prompt_wav, instruct_text):
    """Check that the inputs required by the selected inference mode are present.

    Args:
        mode: inference mode label.
        tts_text: text to synthesize.
        sft_dropdown: selected pretrained / saved voice.
        prompt_text: transcript of the prompt audio.
        prompt_wav: path of the prompt audio, or None.
        instruct_text: natural-language instruction.

    Returns:
        tuple[bool, str]: (whether validation passed, error message — '' on success).
    """
    # Every mode needs something to say. Empty text would make inference
    # produce no audio at all.
    if not tts_text or not tts_text.strip():
        return False, '合成文本为空,请输入需要合成的文本'

    if mode == '自然语言控制':
        if not cosyvoice.is_05b and cosyvoice.instruct is False:
            return False, f'您正在使用自然语言控制模式, {args.model_dir}模型不支持此模式'
        if not instruct_text:
            return False, '您正在使用自然语言控制模式, 请输入instruct文本'
    elif mode in ('3s极速复刻', '跨语种复刻'):
        cross_lingual = mode == '跨语种复刻'
        if cross_lingual and not cosyvoice.is_05b and cosyvoice.instruct is True:
            return False, f'您正在使用跨语种复刻模式, {args.model_dir}模型不支持此模式'
        if not prompt_wav:
            if cross_lingual:
                return False, '您正在使用跨语种复刻模式, 请提供prompt音频'
            return False, 'prompt音频为空,您是否忘记输入prompt音频?'
        # Both cloning modes resample the prompt to prompt_sr, so both need
        # audio recorded at least at that rate.
        sample_rate = torchaudio.info(prompt_wav).sample_rate
        if sample_rate < prompt_sr:
            return False, f'prompt音频采样率{sample_rate}低于{prompt_sr}'
        if not cross_lingual and not prompt_text:
            return False, 'prompt文本为空,您是否忘记输入prompt文本?'
    elif mode == '预训练音色':
        if not sft_dropdown:
            return False, '没有可用的预训练音色!'
    return True, ''
def process_audio(speech_generator, stream, sample_rate=None):
    """Drain a CosyVoice inference generator into Gradio audio outputs.

    Yields, in order:
      * stream=True: ``((sr, chunk), None)`` for every chunk as it arrives;
      * stream=False: one ``(None, (sr, full_audio))`` once all chunks are
        collected. This is skipped, with a warning, if no audio was produced;
      * finally the total duration in seconds (0 if nothing was produced).

    Args:
        speech_generator: iterable of dicts holding a ``'tts_speech'`` tensor.
            Presumably shaped (1, num_samples): the duration uses shape[1] and
            chunks are concatenated along dim=1.
        stream: whether to forward chunks immediately.
        sample_rate: output sample rate. Defaults to ``cosyvoice.sample_rate``.
    """
    sr = cosyvoice.sample_rate if sample_rate is None else sample_rate
    chunks = []
    total_duration = 0
    for item in speech_generator:
        speech = item['tts_speech']
        total_duration += speech.shape[1] / sr
        if stream:
            yield (sr, speech.numpy().flatten()), None
        else:
            # Buffer only when the full clip is emitted at the end.
            chunks.append(speech)
    if not stream:
        if chunks:
            merged = torch.concat(chunks, dim=1)
            yield None, (sr, merged.numpy().flatten())
        else:
            # torch.concat raises on an empty list, so report this instead.
            logging.warning("模型未生成任何音频")
    yield total_duration
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
                   seed, stream, speed):
    """Gradio handler for the generate button: validate, run inference, yield audio.

    Args:
        tts_text: text to synthesize.
        mode_checkbox_group: selected inference mode label.
        sft_dropdown: selected voice (pretrained speaker or saved voice name).
        prompt_text: transcript of the prompt audio.
        prompt_wav_upload: path of the uploaded prompt audio, or None.
        prompt_wav_record: path of the recorded prompt audio, or None.
        instruct_text: natural-language instruction (自然语言控制 mode).
        seed: random seed from the seed box.
        stream: whether to stream chunks as they are generated.
        speed: speech-rate multiplier passed to the model.

    Yields:
        ``(streaming_output, normal_output)`` pairs for the two audio players
        (see process_audio). On a validation or voice-loading failure, a
        single ``default_data`` clip goes to the streaming output instead.
    """
    start_time = time.time()
    logging.info(f"开始生成音频 - 模式: {mode_checkbox_group}, 文本长度: {len(tts_text or '')}")

    # An uploaded file takes precedence over a microphone recording.
    prompt_wav = prompt_wav_upload if prompt_wav_upload is not None else prompt_wav_record

    is_valid, error_msg = validate_input(mode_checkbox_group, tts_text, sft_dropdown,
                                         prompt_text, prompt_wav, instruct_text)
    if not is_valid:
        gr.Warning(error_msg)
        yield (cosyvoice.sample_rate, default_data), None
        return

    # gr.Number can hand back a float, or None if the box is cleared. Always
    # seed with an int, because numpy's RNG seeding rejects non-integers
    # (set_all_random_seed presumably seeds numpy too — confirm).
    set_all_random_seed(int(seed) if seed is not None else 0)

    if mode_checkbox_group == '预训练音色':
        generator = cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed)
    elif mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
        # Load the prompt at prompt_sr, then trim / normalize / pad it.
        prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
        # NOTE(review): upstream CosyVoice's inference_cross_lingual takes
        # (tts_text, prompt_speech_16k, ...) without prompt_text. Confirm this
        # fork's signature accepts prompt_text as the second argument.
        inference_func = (cosyvoice.inference_zero_shot if mode_checkbox_group == '3s极速复刻'
                          else cosyvoice.inference_cross_lingual)
        generator = inference_func(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed)
    else:  # 自然语言控制 (natural-language instruct) mode
        # Instruct mode reuses the prompt audio stored in a saved voice file.
        voice_path = f"{ROOT_DIR}/voices/{sft_dropdown}.pt"
        prompt_speech_16k = load_voice_data(voice_path)
        if prompt_speech_16k is None:
            gr.Warning('预训练音色文件中缺少prompt_speech数据!')
            yield (cosyvoice.sample_rate, default_data), None
            return
        generator = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k,
                                                  stream=stream, speed=speed)

    total_duration = 0
    for output in process_audio(generator, stream):
        # process_audio yields audio tuples first and the total duration last.
        if isinstance(output, (float, int)):
            total_duration = output
        else:
            yield output

    processing_time = time.time() - start_time
    # Real-time factor: wall-clock seconds spent per second of generated audio.
    rtf = processing_time / total_duration if total_duration > 0 else 0
    logging.info(f"音频生成完成 耗时: {processing_time:.2f}秒, rtf: {rtf:.2f}")
def update_audio_visibility(stream_enabled):
    """Toggle which output player is shown.

    Returns a two-element list of updates, [streaming player, non-streaming
    player]. The streaming player gets ``visible=stream_enabled``; the other
    gets its negation.
    """
    show_normal = not stream_enabled
    stream_update = gr.update(visible=stream_enabled)
    normal_update = gr.update(visible=show_normal)
    return [stream_update, normal_update]
def main():
    """Build the Gradio UI, wire its event handlers, and start the web server.

    Reads module-level state prepared in the ``__main__`` block: ``sft_spk`` and
    ``reference_wavs`` (initial dropdown choices) and ``args`` (port and
    --open flag).

    NOTE(review): the container nesting below (which components sit inside
    which Row/Column/Group) was reconstructed from a copy whose indentation
    had been lost — confirm it against the intended layout.
    """
    with gr.Blocks() as demo:
        # Page header: links to this code base, the upstream project and the
        # pretrained models. (Continuation lines stay at column 0 because
        # leading spaces would become part of the string.)
        gr.Markdown("### 代码库 [CosyVoice2-Ex](https://github.com/journey-ad/CosyVoice2-Ex) 原始项目 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
预训练模型 [CosyVoice2-0.5B](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B) \
[CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
[CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
[CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
        gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
        # Main input: the text to synthesize.
        tts_text = gr.Textbox(
            label="输入合成文本",
            lines=1,
            value="CosyVoice迎来全面升级,提供更准、更稳、更快、 更好的语音生成能力。CosyVoice is undergoing a comprehensive upgrade, providing more accurate, stable, faster, and better voice generation capabilities."
        )
        with gr.Row():
            mode_checkbox_group = gr.Radio(
                choices=inference_mode_list,
                label='选择推理模式',
                value=inference_mode_list[0]
            )
            # Step-by-step instructions for the current mode (updated by change_instruction).
            instruction_text = gr.Text(
                label="操作步骤",
                value=instruct_dict[inference_mode_list[0]],
                scale=0.5
            )
            # Voice selection: saved custom voices plus the model's pretrained speakers.
            sft_dropdown = gr.Dropdown(
                choices=sft_spk,
                label='选择预训练音色',
                value=sft_spk[0],
                scale=0.25
            )
            refresh_voice_button = gr.Button("刷新音色")
            # Streaming toggle and speed control.
            with gr.Column(scale=0.25):
                stream = gr.Radio(
                    choices=stream_mode_list,
                    label='是否流式推理',
                    value=stream_mode_list[0][1]
                )
                speed = gr.Number(
                    value=1,
                    label="速度调节(仅支持非流式推理)",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.1
                )
            # Random-seed controls: the dice button fills in a new seed.
            with gr.Column(scale=0.25):
                seed_button = gr.Button(value="\U0001F3B2")
                seed = gr.Number(value=0, label="随机推理种子")
        # Prompt audio inputs: file upload, microphone recording, or a clip
        # picked from the audios/ directory.
        with gr.Row():
            prompt_wav_upload = gr.Audio(
                sources='upload',
                type='filepath',
                label='选择prompt音频文件,注意采样率不低于16khz'
            )
            prompt_wav_record = gr.Audio(
                sources='microphone',
                type='filepath',
                label='录制prompt音频文件'
            )
            wavs_dropdown = gr.Dropdown(
                label="参考音频列表",
                choices=reference_wavs,
                value="请选择参考音频或者自己上传",
                interactive=True
            )
            refresh_button = gr.Button("刷新参考音频")
        # Prompt transcript (auto-filled by ASR, user-editable) and instruct text.
        prompt_text = gr.Textbox(
            label="输入prompt文本",
            lines=1,
            placeholder="请输入prompt文本,支持自动识别,您可以自行修正识别结果...",
            value=''
        )
        instruct_text = gr.Textbox(
            label="输入instruct文本",
            lines=1,
            placeholder="请输入instruct文本. 例如: 用四川话说这句话。",
            value=''
        )
        # Save-voice controls: hidden initially, revealed by change_instruction.
        with gr.Row(visible=False) as save_spk_btn:
            new_name = gr.Textbox(label="输入新的音色名称", lines=1, placeholder="输入新的音色名称.", value='', scale=2)
            save_button = gr.Button(value="保存音色模型", scale=1)
        # Generate button.
        generate_button = gr.Button("生成音频")
        # Output players: one streaming, one non-streaming. Only one is
        # visible at a time (toggled by update_audio_visibility).
        with gr.Group() as audio_group:
            audio_output_stream = gr.Audio(
                label="合成音频(流式)",
                value=None,
                streaming=True,
                autoplay=True,
                show_label=True,
                show_download_button=True,
                visible=False
            )
            audio_output_normal = gr.Audio(
                label="合成音频",
                value=None,
                streaming=False,
                autoplay=True,
                show_label=True,
                show_download_button=True,
                visible=True
            )
        # Event wiring.
        refresh_voice_button.click(fn=refresh_sft_spk, inputs=[], outputs=[sft_dropdown])
        refresh_button.click(fn=refresh_prompt_wav, inputs=[], outputs=[wavs_dropdown])
        # Picking a reference clip fills the upload widget. Gradio also fires
        # .change on programmatic updates, so prompt_wav_upload.change below
        # then transcribes it.
        wavs_dropdown.change(change_prompt_wav, inputs=[wavs_dropdown], outputs=[prompt_wav_upload])
        save_button.click(save_voice_model, inputs=[new_name])
        seed_button.click(generate_random_seed, inputs=[], outputs=[seed])
        generate_button.click(
            generate_audio,
            inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text,
                    prompt_wav_upload, prompt_wav_record, instruct_text,
                    seed, stream, speed],
            outputs=[audio_output_stream, audio_output_normal]
        )
        mode_checkbox_group.change(
            fn=change_instruction,
            inputs=[mode_checkbox_group],
            outputs=[instruction_text, sft_dropdown, save_spk_btn]
        )
        # Auto-transcribe new prompt audio into the prompt text box.
        prompt_wav_upload.change(fn=prompt_wav_recognition, inputs=[prompt_wav_upload], outputs=[prompt_text])
        prompt_wav_record.change(fn=prompt_wav_recognition, inputs=[prompt_wav_record], outputs=[prompt_text])
        stream.change(
            fn=update_audio_visibility,
            inputs=[stream],
            outputs=[audio_output_stream, audio_output_normal]
        )
    # Queue at most 4 pending events and run at most 2 handlers concurrently by default.
    demo.queue(max_size=4, default_concurrency_limit=2)
    # Listen on all network interfaces; --open also opens a browser tab.
    demo.launch(server_name='0.0.0.0', server_port=args.port, inbrowser=args.open)
if __name__ == '__main__':
    # Command-line options; the parsed `args` is read as a module global by
    # validate_input() and main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--port',
                        type=int,
                        default=8080)
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice2-0.5B',
                        help='local path or modelscope repo id')
    parser.add_argument('--open',
                        action='store_true',
                        help='open in browser')
    parser.add_argument('--log_level',
                        type=str,
                        default='INFO',
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        help='set log level')
    args = parser.parse_args()
    logging.getLogger().setLevel(getattr(logging, args.log_level))

    # Try the CosyVoice (v1) class first, then fall back to CosyVoice2.
    try:
        cosyvoice = CosyVoice(args.model_dir)
    except Exception:
        try:
            cosyvoice = CosyVoice2(args.model_dir)
        except Exception as e:
            # Chain the cause so the real loading error is not lost.
            raise TypeError('no valid model_type!') from e

    # The dropdown builders and the save-voice handler read and write these
    # folders. Make sure they exist so a fresh checkout starts cleanly.
    os.makedirs(f"{ROOT_DIR}/voices", exist_ok=True)
    os.makedirs(f"{ROOT_DIR}/audios", exist_ok=True)

    # Module-level state read by main() and the event handlers.
    sft_spk = refresh_sft_spk()['choices']
    reference_wavs = refresh_prompt_wav()['choices']
    prompt_sr = 16000
    # One second of silence at the model's output rate, used as a placeholder clip.
    default_data = np.zeros(cosyvoice.sample_rate)

    # SenseVoice ASR model used to auto-transcribe prompt audio. Use the GPU
    # when one is available so the UI also starts on CPU-only machines.
    model_dir = "iic/SenseVoiceSmall"
    asr_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    asr_model = AutoModel(
        model=model_dir,
        disable_update=True,
        log_level=args.log_level,
        device=asr_device)
    main()