Skip to content

Commit e95b100

Browse files
committed
feat: Add Support for v2 Model in Web UI
- Added support for the v2 model in the web UI.
- Implemented logic to handle v2-specific features, including handling of prompt audio and disabling streaming inference for v2 models.
- Updated the UI instructions so that users are properly guided when selecting the v2 model.
1 parent 41c5e8c commit e95b100

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

webui.py

+22-8
Original file line number | Diff line number | Diff line change
@@ -30,10 +30,10 @@
3030
instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
3131
'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
3232
'跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
33-
'自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
33+
'自然语言控制': '1. 选择预训练音色(v2模型需要选择或录入prompt音频)\n2. 输入instruct文本\n3. 点击生成音频按钮'}
3434
stream_mode_list = [('否', False), ('是', True)]
3535
max_val = 0.8
36-
36+
model_versions = None
3737

3838
def generate_seed():
3939
seed = random.randint(1, 100000000)
@@ -61,6 +61,10 @@ def change_instruction(mode_checkbox_group):
6161

6262
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
6363
seed, stream, speed):
64+
if model_versions == 'v2':
65+
if stream == True:
66+
stream = False
67+
gr.Warning('您正在使用v2版本模型, 不支持流式推理, 将使用非流式模式.')
6468
if prompt_wav_upload is not None:
6569
prompt_wav = prompt_wav_upload
6670
elif prompt_wav_record is not None:
@@ -69,13 +73,13 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
6973
prompt_wav = None
7074
# if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
7175
if mode_checkbox_group in ['自然语言控制']:
72-
if cosyvoice.instruct is False:
76+
if cosyvoice.instruct is False and model_versions == 'v1':
7377
gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
7478
yield (cosyvoice.sample_rate, default_data)
7579
if instruct_text == '':
7680
gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
7781
yield (cosyvoice.sample_rate, default_data)
78-
if prompt_wav is not None or prompt_text != '':
82+
if prompt_wav is not None or prompt_text != '' and model_versions == 'v1':
7983
gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
8084
# if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
8185
if mode_checkbox_group in ['跨语种复刻']:
@@ -128,12 +132,20 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
128132
set_all_random_seed(seed)
129133
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
130134
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
131-
else:
135+
elif mode_checkbox_group == '自然语言控制':
132136
logging.info('get instruct inference request')
133137
set_all_random_seed(seed)
134-
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
135-
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
136-
138+
if model_versions == 'v1':
139+
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
140+
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
141+
elif model_versions == 'v2':
142+
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
143+
for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream):
144+
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
145+
else:
146+
gr.Warning('非预期的模型版本!')
147+
else:
148+
gr.Warning('非预期的选项!')
137149

138150
def main():
139151
with gr.Blocks() as demo:
@@ -186,9 +198,11 @@ def main():
186198
args = parser.parse_args()
187199
try:
188200
cosyvoice = CosyVoice(args.model_dir)
201+
model_versions = 'v1'
189202
except Exception:
190203
try:
191204
cosyvoice = CosyVoice2(args.model_dir)
205+
model_versions = 'v2'
192206
except Exception:
193207
raise TypeError('no valid model_type!')
194208

0 commit comments

Comments
 (0)