Skip to content

Commit 2b33690

Browse files
committed
feat: Add Support for v2 Model in Web UI
- Added support for the v2 model in the web UI.
- Implemented logic to handle v2-specific features, including the handling of prompt audio and disabling streaming inference for v2 models.
- Updated UI instructions to ensure users are properly guided when selecting the v2 model.
1 parent 41c5e8c commit 2b33690

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

webui.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030
instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
3131
'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
3232
'跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
33-
'自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
33+
'自然语言控制': '1. 选择预训练音色(v2模型需要选择或录入prompt音频)\n2. 输入instruct文本\n3. 点击生成音频按钮'}
3434
stream_mode_list = [('否', False), ('是', True)]
3535
max_val = 0.8
36+
model_versions = None
3637

3738

3839
def generate_seed():
@@ -61,6 +62,10 @@ def change_instruction(mode_checkbox_group):
6162

6263
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
6364
seed, stream, speed):
65+
if model_versions == 'v2':
66+
if stream:
67+
stream = False
68+
gr.Warning('您正在使用v2版本模型, 不支持流式推理, 将使用非流式模式.')
6469
if prompt_wav_upload is not None:
6570
prompt_wav = prompt_wav_upload
6671
elif prompt_wav_record is not None:
@@ -69,13 +74,13 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
6974
prompt_wav = None
7075
# if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
7176
if mode_checkbox_group in ['自然语言控制']:
72-
if cosyvoice.instruct is False:
77+
if cosyvoice.instruct is False and model_versions == 'v1':
7378
gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
7479
yield (cosyvoice.sample_rate, default_data)
7580
if instruct_text == '':
7681
gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
7782
yield (cosyvoice.sample_rate, default_data)
78-
if prompt_wav is not None or prompt_text != '':
83+
if (prompt_wav is not None or prompt_text != '') and model_versions == 'v1':
7984
gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
8085
# if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
8186
if mode_checkbox_group in ['跨语种复刻']:
@@ -128,11 +133,20 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
128133
set_all_random_seed(seed)
129134
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
130135
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
131-
else:
136+
elif mode_checkbox_group == '自然语言控制':
132137
logging.info('get instruct inference request')
133138
set_all_random_seed(seed)
134-
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
135-
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
139+
if model_versions == 'v1':
140+
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
141+
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
142+
elif model_versions == 'v2':
143+
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
144+
for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream):
145+
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
146+
else:
147+
gr.Warning('非预期的模型版本!')
148+
else:
149+
gr.Warning('非预期的选项!')
136150

137151

138152
def main():
@@ -186,9 +200,11 @@ def main():
186200
args = parser.parse_args()
187201
try:
188202
cosyvoice = CosyVoice(args.model_dir)
203+
model_versions = 'v1'
189204
except Exception:
190205
try:
191206
cosyvoice = CosyVoice2(args.model_dir)
207+
model_versions = 'v2'
192208
except Exception:
193209
raise TypeError('no valid model_type!')
194210

0 commit comments

Comments
 (0)