Skip to content

feat: Add Support for v2 Model in Web UI #903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
'跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
'自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
'自然语言控制': '1. 选择预训练音色(v2模型需要选择或录入prompt音频)\n2. 输入instruct文本\n3. 点击生成音频按钮'}
stream_mode_list = [('否', False), ('是', True)]
max_val = 0.8
model_versions = None


def generate_seed():
Expand Down Expand Up @@ -61,6 +62,10 @@ def change_instruction(mode_checkbox_group):

def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
seed, stream, speed):
if model_versions == 'v2':
if stream:
stream = False
gr.Warning('您正在使用v2版本模型, 不支持流式推理, 将使用非流式模式.')
if prompt_wav_upload is not None:
prompt_wav = prompt_wav_upload
elif prompt_wav_record is not None:
Expand All @@ -69,13 +74,13 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
prompt_wav = None
# if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
if mode_checkbox_group in ['自然语言控制']:
if cosyvoice.instruct is False:
if cosyvoice.instruct is False and model_versions == 'v1':
gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
yield (cosyvoice.sample_rate, default_data)
if instruct_text == '':
gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
yield (cosyvoice.sample_rate, default_data)
if prompt_wav is not None or prompt_text != '':
if (prompt_wav is not None or prompt_text != '') and model_versions == 'v1':
gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
# if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
if mode_checkbox_group in ['跨语种复刻']:
Expand Down Expand Up @@ -128,11 +133,20 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
set_all_random_seed(seed)
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
else:
elif mode_checkbox_group == '自然语言控制':
logging.info('get instruct inference request')
set_all_random_seed(seed)
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
if model_versions == 'v1':
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
elif model_versions == 'v2':
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
else:
gr.Warning('非预期的模型版本!')
else:
gr.Warning('非预期的选项!')


def main():
Expand Down Expand Up @@ -186,9 +200,11 @@ def main():
args = parser.parse_args()
try:
cosyvoice = CosyVoice(args.model_dir)
model_versions = 'v1'
except Exception:
try:
cosyvoice = CosyVoice2(args.model_dir)
model_versions = 'v2'
except Exception:
raise TypeError('no valid model_type!')

Expand Down
Loading