diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 index 4a9f23f8..035fd5d1 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,4 @@ compile_commands.json pretrained_models/* *_pb2_grpc.py *_pb2.py -*.tar \ No newline at end of file +*.tar diff --git a/README.md b/README.md index 4a1dbd30..8c708b36 100644 --- a/README.md +++ b/README.md @@ -1,241 +1,21 @@ -[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners) -## 👉🏻 CosyVoice 👈🏻 -**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B) +## 使用方法 -**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M) +python cosyvoice_2_demo.py --fp16 --use_flow_cache -## Highlight🔥 +然后在命令行,里面输入下面的格式,回车即可。 -**CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities. -### Multilingual -- **Supported Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.) -- **Crosslingual & Mixlingual**:Support zero-shot voice cloning for cross-lingual and code-switching scenarios. -### Ultra-Low Latency -- **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies. -- **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output. -### High Accuracy -- **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0. -- **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set. -### Strong Stability -- **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis. -- **Cross-language Synthesis**: Marked improvements compared to version 1.0. -### Natural Experience -- **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53. -- **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments. +音色代码@要说的话 -## Roadmap +已经存有几个人的音色: +| 角色名称 | 音色代码 | +|---------|---------| +| 哈利波特 | hp | +| 老许 | laoxu | -- [x] 2024/12 - - [x] 25hz cosyvoice 2.0 released +### 例子 -- [x] 2024/09 +hp@Blimey! Professor Snape's given us a mountain of potions homework. Wish I had my invisibility cloak right now. Ron, Hermione, fancy a trip to Hogsmeade? 
- - [x] 25hz cosyvoice base model - - [x] 25hz cosyvoice voice conversion model - -- [x] 2024/08 - - - [x] Repetition Aware Sampling(RAS) inference for llm stability - - [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization - -- [x] 2024/07 - - - [x] Flow matching training support - - [x] WeTextProcessing support when ttsfrd is not available - - [x] Fastapi server and client - - -## Install - -**Clone and install** - -- Clone the repo -``` sh -git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git -# If you failed to clone submodule due to network failures, please run following command until success -cd CosyVoice -git submodule update --init --recursive -``` - -- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html -- Create Conda env: - -``` sh -conda create -n cosyvoice -y python=3.10 -conda activate cosyvoice -# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform. -conda install -y -c conda-forge pynini==2.1.5 -pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com - -# If you encounter sox compatibility issues -# ubuntu -sudo apt-get install sox libsox-dev -# centos -sudo yum install sox sox-devel -``` - -**Model download** - -We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource. - -``` python -# SDK模型下载 -from modelscope import snapshot_download -snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B') -snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M') -snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT') -snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct') -snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd') -``` - -``` sh -# git模型下载,请确保已安装git lfs -mkdir -p pretrained_models -git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B -git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M -git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT -git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct -git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd -``` - -Optionally, you can unzip `ttsfrd` resouce and install `ttsfrd` package for better text normalization performance. - -Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default. - -``` sh -cd pretrained_models/CosyVoice-ttsfrd/ -unzip resource.zip -d . -pip install ttsfrd_dependency-0.1-py3-none-any.whl -pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl -``` - -**Basic Usage** - -We strongly recommend using `CosyVoice2-0.5B` for better performance. -Follow code below for detailed usage of each model. 
- -``` python -import sys -sys.path.append('third_party/Matcha-TTS') -from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 -from cosyvoice.utils.file_utils import load_wav -import torchaudio -``` - -**CosyVoice2 Usage** -```python -cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False) - -# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference -# zero_shot usage -prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) -for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# save zero_shot spk for future usage -assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True -for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) -cosyvoice.save_spkinfo() - -# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248 -for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)): - torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# instruct usage -for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)): - torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# bistream usage, you can use generator as input, this is useful when using text llm model as input -# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length -def text_generator(): - yield '收到好友从远方寄来的生日礼物,' - yield '那份意外的惊喜与深深的祝福' - yield '让我心中充满了甜蜜的快乐,' - yield '笑容如花儿般绽放。' -for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) -``` - -**CosyVoice Usage** -```python -cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False) -# sft usage -print(cosyvoice.list_available_spks()) -# change stream=True for chunk stream inference -for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)): - torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M') -# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean -prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) -for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) -# cross_lingual usage -prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000) -for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. 
So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)): - torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) -# vc usage -prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) -source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000) -for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)): - torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct') -# instruct usage, support [laughter][breath] -for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的勇气智慧。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)): - torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) -``` - -**Start web demo** - -You can use our web demo page to get familiar with CosyVoice quickly. - -Please see the demo website for details. - -``` python -# change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference -python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M -``` - -**Advanced Usage** - -For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`. - -**Build for deployment** - -Optionally, if you want service deployment, -you can run following steps. - -``` sh -cd runtime/python -docker build -t cosyvoice:v1.0 . -# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference -# for grpc usage -docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity" -cd grpc && python3 client.py --port 50000 --mode -# for fastapi usage -docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity" -cd fastapi && python3 client.py --port 50000 --mode -``` - -## Discussion & Communication - -You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues). - -You can also scan the QR code to join our official Dingding chat group. - - - -## Acknowledge - -1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR). -2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec). -3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS). -4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec). -5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet). - -## Disclaimer -The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal. +laoxu@七牛毕竟,是国内最早做云存储的公司。所以我想,就是和云存储相关的交流,可以在这个会之后自由讨论的时候,知无不言,言无不尽. 
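
For clarity, the interactive demo simply splits each input line on `@`: the first field is the voice code, the second is the text to speak, and an optional third field is an instruction (see `cosyvoice_2_demo.py` later in this diff). A minimal sketch of that parsing convention is shown below; the fallback values are illustrative, not the demo's exact defaults.

```python
# Sketch of the "voice_code@text[@instruction]" convention used by the
# interactive loop in cosyvoice_2_demo.py (added later in this diff).
# The fallback values below are illustrative, not the demo's exact defaults.
def parse_demo_input(user_input: str):
    parts = user_input.strip().split("@")
    speaker = parts[0] if parts[0] else "hp"               # voice code, e.g. "hp" or "laoxu"
    tts_text = parts[1] if len(parts) > 1 else None        # text to synthesize
    instruct_text = parts[2] if len(parts) > 2 else None   # optional instruction
    return speaker, tts_text, instruct_text

# parse_demo_input("hp@Blimey! Professor Snape's given us a mountain of potions homework.")
# -> ("hp", "Blimey! Professor Snape's given us a mountain of potions homework.", None)
```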
\ No newline at end of file diff --git a/README_quantization.md b/README_quantization.md new file mode 100644 index 00000000..4fa0ce3a --- /dev/null +++ b/README_quantization.md @@ -0,0 +1,68 @@ +# CosyVoice 模型量化指南 + +本指南提供了使用不同量化方法对CosyVoice模型进行量化的步骤。 + +## 准备工作 + +首先,您需要安装相应的量化库。我们推荐使用bitsandbytes进行量化,它的兼容性最好: + +```bash +pip install bitsandbytes +``` + +## 量化模型 + +### 1. 使用 BitsAndBytes 量化 (推荐) + +BitsAndBytes是一种简单易用的量化方法,适合快速尝试,兼容性最好。 + +```bash +python quant_cosyvoice_bnb.py --model_dir pretrained_models/CosyVoice2-0.5B --output_dir pretrained_models/CosyVoice2-0.5B-bnb --bits 8 +``` + +参数说明: +- `--model_dir`: 原始模型目录 +- `--output_dir`: 量化后模型保存目录 +- `--bits`: 量化位数 (4 或 8),建议先尝试8位 + +### 2. 使用简化的量化方法 + +我们提供了一个简化的量化脚本,它使用bitsandbytes库对模型进行量化,但采用了更直接的方法: + +```bash +python quant_cosyvoice_gptq.py --model_dir pretrained_models/CosyVoice2-0.5B --output_dir pretrained_models/CosyVoice2-0.5B-quantized --bits 8 +``` + +参数说明: +- `--model_dir`: 原始模型目录 +- `--output_dir`: 量化后模型保存目录 +- `--bits`: 量化位数 (4 或 8) +- `--block_size`: 量化块大小 (默认32) + +## 使用量化后的模型 + +量化完成后,您可以使用以下命令测试量化后的模型: + +```bash +python cosyvoice_2_demo.py --model_dir pretrained_models/CosyVoice2-0.5B-bnb +``` + +## 简单量化方法 + +如果上述方法都遇到问题,所有脚本都包含了一个简单的备选量化方法,它不依赖于特定的量化库,而是使用简单的权重量化技术。这种方法虽然不如专业量化库精确,但兼容性最好。 + +## 注意事项 + +1. 量化会导致模型质量略有下降,但通常不会显著影响语音合成质量 +2. 4位量化可以显著减小模型大小,但可能会导致更多的质量损失 +3. 如果遇到问题,建议先尝试8位量化,再尝试4位量化 +4. 量化过程可能需要较长时间,请耐心等待 + +## 故障排除 + +如果在量化过程中遇到问题: + +1. 首先尝试BitsAndBytes方法,它的兼容性最好 +2. 如果出现内存错误,尝试在更大内存的机器上运行 +3. 如果所有方法都失败,使用脚本中的简单量化方法 +4. 确保您的Python环境干净,没有冲突的库版本 \ No newline at end of file diff --git a/api.py b/api.py new file mode 100644 index 00000000..5e84d5d1 --- /dev/null +++ b/api.py @@ -0,0 +1,241 @@ +import time +import io, os, sys +os.environ["CUDA_VISIBLE_DEVICES"] = "7" +# 获取当前文件的绝对路径的目录 +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +# 添加依赖的第三方库到系统路径 +sys.path.append('{}/third_party/AcademiCodec'.format(ROOT_DIR)) +sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR)) + +import requests +from pydub import AudioSegment + +import numpy as np +# 导入Flask相关库,用于构建Web API服务 +from flask import Flask, request, Response, send_from_directory +import torch +import torchaudio + +# 导入CosyVoice模型相关的库 +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import load_wav +# import ffmpeg + +# 导入CORS处理跨域问题的库 +from flask_cors import CORS +from flask import make_response + +import shutil + +import json + +# 初始化CosyVoice2模型,加载预训练模型 +cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_trt=True, fp16=True) + +# 获取模型内置的预训练音色列表 +default_voices = cosyvoice.list_available_spks() + +# 获取自定义音色列表 +spk_custom = [] +for name in os.listdir(f"{ROOT_DIR}/voices/"): + print(name.replace(".pt", "")) + spk_custom.append(name.replace(".pt", "")) + +print("默认音色", default_voices) +print("自定义音色", spk_custom) + +# 创建Flask应用实例 +app = Flask(__name__) + +# 配置跨域资源共享(CORS) +# 允许所有来源的跨域请求,解决前端与后端API交互时的跨域问题 +CORS(app, cors_allowed_origins="*") +# CORS(app, supports_credentials=True) # 支持携带凭证的跨域请求(如cookie),当前已注释 +def process_audio(tts_speeches, sample_rate=22050, format="wav"): + """ + 处理音频数据并返回响应 + + 参数: + tts_speeches: 待处理的音频张量列表 + sample_rate: 采样率,默认22050Hz + format: 输出音频格式,默认wav + + 返回: + 包含音频数据的内存缓冲区 + """ + buffer = io.BytesIO() # 创建内存缓冲区,是一个在内存中模拟文件操作的对象 + + # 合并多个音频片段,dim=1表示在第二个维度上合并 + # 假设tts_speeches中的每个张量形状为[1, L],合并后audio_data形状仍为[1, L_total] + # 第0维保留,表示声道数,通常为1(单声道) + audio_data = torch.concat(tts_speeches, dim=1) + + # 将音频数据保存到内存缓冲区,此操作会将文件写入buffer并移动指针到末尾 + 
torchaudio.save(buffer, audio_data, sample_rate, format=format) + + # 将缓冲区指针重置到开始位置,这样后续读取时才能从头开始读取数据 + # 如果不重置,后续读取将从末尾开始,得到空数据 + buffer.seek(0) + + return buffer + +def create_audio_response(buffer, format="wav"): + """ + 创建音频HTTP响应 + + 参数: + buffer: 包含音频数据的缓冲区 + format: 音频格式,默认wav + + 返回: + Flask响应对象,包含适当的MIME类型和头信息 + """ + if format == "wav": + # wav格式直接返回Response对象 + return Response(buffer.read(), mimetype="audio/wav") + else: + # 其他格式使用make_response创建响应,并设置适当的头信息 + response = make_response(buffer.read()) + response.headers['Content-Type'] = f'audio/{format}' + response.headers['Content-Disposition'] = f'attachment; filename=sound.{format}' + return response + +def load_voice_data(speaker): + """ + 加载自定义语音数据 + + 参数: + speaker: 说话人ID/名称 + + 返回: + 加载的语音参考数据,如果加载失败则返回None + """ + voice_path = f"{ROOT_DIR}/voices/{speaker}.pt" + try: + # 检测是否有GPU可用 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if not os.path.exists(voice_path): + return None + # 加载语音模型数据 + voice_data = torch.load(voice_path, map_location=device) + return voice_data.get('audio_ref') + except Exception as e: + raise ValueError(f"加载音色文件失败: {e}") + +# 定义路由,同时处理根路径和/tts路径的GET和POST请求 +@app.route("/", methods=['GET', 'POST']) +@app.route("/tts", methods=['GET', 'POST']) +def tts(): + """ + 文本转语音(TTS)API端点 + 处理文本到语音的转换请求,支持GET和POST方法 + + 请求参数: + text: 要转换的文本 + speaker: 说话人ID/名称 + instruct: 指令模式下的提示(可选) + streaming: 是否使用流式输出,默认0(非流式) + speed: 语速,默认1.0 + + 返回: + 音频数据或错误信息 + """ + # 根据请求方法获取参数,支持GET和POST两种方式 + params = request.get_json() if request.method == 'POST' else request.args + text = params.get('text') + speaker = params.get('speaker') + instruct = params.get('instruct') + streaming = int(params.get('streaming', 0)) + speed = float(params.get('speed', 1.0)) + + # 参数验证 + if not text or not speaker: + return {"error": "文本和角色名不能为空"}, 400 + + # 处理指令模式(可自定义音色) + if instruct: + prompt_speech_16k = load_voice_data(speaker) + if prompt_speech_16k is None: + return {"error": "预训练音色文件中缺少audio_ref数据!"}, 500 + + # 定义指令模式下的推理函数 + inference_func = lambda: cosyvoice.inference_instruct2( + text, instruct, prompt_speech_16k, stream=bool(streaming), speed=speed + ) + else: + # 定义标准模式下的推理函数 + inference_func = lambda: cosyvoice.inference_sft( + text, speaker, stream=bool(streaming), speed=speed + ) + + # 处理流式输出模式 + if streaming: + def generate(): + """生成器函数,用于流式传输音频片段""" + # 第一个标志,用于标记是否已经发送WAV头 + first_chunk = True + + for _, i in enumerate(inference_func()): + audio_data = i['tts_speech'].numpy()[0] # 获取原始音频数据 + + if first_chunk: + # 第一个片段,发送完整WAV头 + buffer = process_audio([i['tts_speech']], format="wav") + yield buffer.read() + first_chunk = False + else: + # 后续片段,只发送原始音频数据 + # 将音频数据转换为字节流 + audio_bytes = (audio_data * (2 ** 15)).astype(np.int16).tobytes() + yield audio_bytes + + # 创建流式响应 + response = make_response(generate()) + response.headers.update({ + 'Content-Type': 'audio/wav', + 'Transfer-Encoding': 'chunked', # 使用分块传输编码 + 'Content-Disposition': 'attachment; filename=sound.wav' + }) + return response + + # 处理非流式输出模式 + tts_speeches = [i['tts_speech'] for _, i in enumerate(inference_func())] + buffer = process_audio(tts_speeches, format="wav") + return create_audio_response(buffer) + + +@app.route("/speakers", methods=['GET', 'POST']) +def speakers(): + """ + 获取可用说话人列表的API端点 + 返回系统中所有可用的预训练和自定义音色列表 + + 返回: + 包含音色信息的JSON响应 + """ + voices = [] + + # 添加预训练的默认音色 + for x in default_voices: + voices.append({"name":x,"voice_id":x}) + + # 添加自定义音色 + for name in os.listdir("voices"): + name = name.replace(".pt","") + 
voices.append({"name":name,"voice_id":name}) + + # 创建JSON响应,确保使用UTF-8编码,并显式设置Content-Type + response = app.response_class( + response=json.dumps(voices, ensure_ascii=False), + status=200, + mimetype='application/json; charset=utf-8' + ) + response.headers.set('Content-Type', 'application/json; charset=utf-8') + return response + +# 程序入口点 +if __name__ == "__main__": + # 启动Flask Web服务器 + # host='0.0.0.0'表示接受来自任何IP的连接 + # port=9880指定服务端口 + app.run(host='0.0.0.0', port=9880) diff --git a/asset/HarryPorter.txt b/asset/HarryPorter.txt new file mode 100755 index 00000000..0dda59b3 --- /dev/null +++ b/asset/HarryPorter.txt @@ -0,0 +1 @@ +But, Hagrid, there must be a mistake. This says platform 93/4. There's no such thing, is there? diff --git a/asset/HarryPorter.wav b/asset/HarryPorter.wav new file mode 100755 index 00000000..5e0a291d Binary files /dev/null and b/asset/HarryPorter.wav differ diff --git a/asset/harry_potter_snape_injured.txt b/asset/harry_potter_snape_injured.txt new file mode 100755 index 00000000..25531c22 --- /dev/null +++ b/asset/harry_potter_snape_injured.txt @@ -0,0 +1 @@ +I’m not hungry. That explains the blood. Listen. Last night, I'm guessing Snape let the troll in as a diversion, so he could get past that dog. But he got bit, that's why he's limping. The day I was at Gringotts, Hagrid took something out of the vault. Said it was Hogwarts business, very secret. That's what the dog's guarding. That's what Snape wants. I never get mail. diff --git a/asset/harry_potter_snape_injured.wav b/asset/harry_potter_snape_injured.wav new file mode 100755 index 00000000..fe1a3f0a Binary files /dev/null and b/asset/harry_potter_snape_injured.wav differ diff --git a/asset/laoxu.txt b/asset/laoxu.txt new file mode 100755 index 00000000..afec8c0e --- /dev/null +++ b/asset/laoxu.txt @@ -0,0 +1 @@ +啊这个也能理解啊,因为七牛毕竟,是国内最早做云存储的公司。嗯,所以我想,就是和云存储相关的交流,可以在这个这个会之后自由讨论的时候,我们只管沟通啊。知无不言,言无不尽,哼哼. 
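
The new `api.py` wraps the model in a small Flask service on port 9880: `/tts` accepts `text`, `speaker`, an optional `instruct`, `streaming` and `speed`, and `/speakers` lists the built-in and custom voices. A minimal client sketch follows, assuming the server is running locally on its default port.

```python
# Minimal client sketch for the Flask endpoints added in api.py.
# Assumes a local deployment on the default port 9880.
import requests

BASE = "http://127.0.0.1:9880"  # assumption: server started with app.run(host='0.0.0.0', port=9880)

# List built-in and custom voices exposed by /speakers.
voices = requests.get(f"{BASE}/speakers").json()
print([v["voice_id"] for v in voices])

# Non-streaming synthesis: the endpoint returns a complete WAV body.
resp = requests.post(f"{BASE}/tts", json={
    "text": "收到好友从远方寄来的生日礼物。",
    "speaker": "laoxu",   # any voice_id returned by /speakers
    "streaming": 0,
    "speed": 1.0,
})
resp.raise_for_status()
with open("tts_output.wav", "wb") as f:
    f.write(resp.content)

# Streaming synthesis: the first chunk carries a WAV header, later chunks are
# raw 16-bit PCM, so a simple consumer can append the bytes as they arrive.
with requests.post(f"{BASE}/tts",
                   json={"text": "你好", "speaker": "laoxu", "streaming": 1},
                   stream=True) as r:
    r.raise_for_status()
    with open("tts_stream.wav", "wb") as f:
        for chunk in r.iter_content(chunk_size=4096):
            f.write(chunk)
```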
diff --git a/asset/laoxu.wav b/asset/laoxu.wav new file mode 100755 index 00000000..5723f311 Binary files /dev/null and b/asset/laoxu.wav differ diff --git a/asset/wzy_read_poet_27s.txt b/asset/wzy_read_poet_27s.txt new file mode 100755 index 00000000..d4c075d5 --- /dev/null +++ b/asset/wzy_read_poet_27s.txt @@ -0,0 +1 @@ +我最喜欢夏天,满地的鲜花,这里一朵,那里一朵, 真比天上的星星还多。 夜晚,我数着天上的星星,真比地上的花儿还要多。那里一颗,真比天上的花还,花儿还多。 diff --git a/asset/wzy_read_poet_27s.wav b/asset/wzy_read_poet_27s.wav new file mode 100755 index 00000000..13eed4fc Binary files /dev/null and b/asset/wzy_read_poet_27s.wav differ diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py old mode 100644 new mode 100755 index a7bfab4f..f003740c --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -15,6 +15,7 @@ import time from typing import Generator from tqdm import tqdm +from math import ceil from hyperpyyaml import load_hyperpyyaml from modelscope import snapshot_download import torch @@ -23,11 +24,16 @@ from cosyvoice.utils.file_utils import logging from cosyvoice.utils.class_utils import get_model_type +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir)) +grandparent_dir = os.path.dirname(parent_dir) class CosyVoice: def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): self.instruct = True if '-Instruct' in model_dir else False + self.is_05b = True if 'CosyVoice2-0.5B' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 if not os.path.exists(model_dir): @@ -77,56 +83,144 @@ def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id): def save_spkinfo(self): torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir)) - def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True): - for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): - model_input = self.frontend.frontend_sft(i, spk_id) + def _process_with_progress(self, model_input, text_segment, stream, speed): + """处理带进度条的TTS生成 + Args: + model_input: 模型输入 + text_segment: 当前文本片段 + stream: 是否流式输出 + speed: 语速 + """ + tqdm.write(f'{text_segment}\n') + + # 初始化进度条参数 + # 估计迭代次数:非流式模式下为1,流式模式下根据文本长度估计 + estimated_iterations = 1 if not stream else max(1, len(text_segment) // 10) + + # tqdm参数说明: + # - total: 预计的总迭代次数 + # - leave=False: 进度条完成后会被清除,不会在控制台留下痕迹 + # - desc: 进度条前面显示的描述文本 + # - disable=not stream: 当stream为False时禁用进度条,只在流式模式下显示 + with tqdm(total=estimated_iterations, leave=False, desc='当前片段', disable=not stream) as pbar: + iter_count = 0 + start_time = time.time() - logging.info('synthesis text {}'.format(i)) for model_output in self.model.tts(**model_input, stream=stream, speed=speed): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + speech = model_output['tts_speech'] + speech_len = speech.shape[1] / self.sample_rate + iter_count += 1 + if stream: + # 更新进度条后缀显示实时速率比(rtf) + rtf = (time.time() - start_time) / speech_len + pbar.set_postfix_str(f'rtf={rtf:.2f}') + # 持久打印rtf值,不随进度条消失 + print(f" 在cosyvoice.py函数: _process_with_progress 中,耗时: {time.time() - start_time:.2f} 秒") + # 仅在迭代次数小于3时根据实际语音长度更新预估总迭代次数 + if iter_count <= 3: + # 估计总token数量 + total_speech_tokens =len(self.model.tts_speech_token_dict[self.model.this_uuid]) + + # 每次迭代处理token_hop_len个token,计算总迭代次数 + # 
注意考虑到token_hop_len会随着迭代增加 + iterations_needed = 0 + remaining_tokens = total_speech_tokens + # current_hop_len = self.model.token_min_hop_len + current_hop_len = self.model.token_hop_len # 仅适用于CosyVoice2 + + while remaining_tokens > 0: + remaining_tokens -= current_hop_len + iterations_needed += 1 + # current_hop_len = min(self.model.token_max_hop_len, int(current_hop_len * self.model.stream_scale_factor)) # 仅适用于CosyVoice + + pbar.total = max(1, iterations_needed) + + # 确保总迭代次数至少等于当前迭代次数,动态调整进度条长度 + pbar.total = max(pbar.total, iter_count) + # 更新进度条,前进一步 + pbar.update(1) + else: + # 非流式模式下关闭进度条 + pbar.close() + yield model_output start_time = time.time() + def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True): + default_voices = self.list_available_spks() + + for text_segment in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend), desc='生成进度'): + + start_time = time.time() + # 根据音色ID获取模型输入 + spk = default_voices[0] if spk_id not in default_voices else spk_id + print(f"spk: {spk}, spk_id: {spk_id}") + model_input = self.frontend.frontend_sft(text_segment, spk) + + # # 如果是自定义音色,加载并更新音色相关特征 + # if spk_id not in default_voices: + # newspk = torch.load( + # f'{grandparent_dir}/voices/{spk_id}.pt', + # map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # ) + + + yield from self._process_with_progress(model_input, text_segment, stream, speed) + + def _save_voice_model(self, model_input, prompt_speech_16k, text_ref=None, save_path='output.pt'): + """保存音色模型到文件 + Args: + model_input: 包含音色信息的模型输入 + prompt_speech_16k: 参考音频 + text_ref: 参考文本(可选) + save_path: 保存路径,默认为output.pt + """ + model_input['audio_ref'] = prompt_speech_16k + if text_ref is not None: + model_input['text_ref'] = text_ref + + torch.save(model_input, save_path) + def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True): prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend) - for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): - if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text): - logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text)) - model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id) - start_time = time.time() - logging.info('synthesis text {}'.format(i)) - for model_output in self.model.tts(**model_input, stream=stream, speed=speed): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() + + # 先获取所有分段,找出最长的一段 + text_parts = list(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)) + longest_segment = max(text_parts, key=len) + longest_idx = text_parts.index(longest_segment) + + for idx, text_segment in enumerate(tqdm(text_parts, desc='生成进度')): + if (not isinstance(text_segment, Generator)) and len(text_segment) < 0.5 * len(prompt_text): + logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(text_segment, prompt_text)) + model_input = self.frontend.frontend_zero_shot(text_segment, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id) - def 
inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True): - for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): - model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate) - start_time = time.time() - logging.info('synthesis text {}'.format(i)) - for model_output in self.model.tts(**model_input, stream=stream, speed=speed): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() + if idx == 0 or idx == longest_idx: # 保存第一段或最长段作为音色模型 + self._save_voice_model(model_input, prompt_speech_16k, prompt_text) + + yield from self._process_with_progress(model_input, text_segment, stream, speed) + + def inference_cross_lingual(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True): + # 先获取所有分段,找出最长的一段 + text_parts = list(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)) + longest_segment = max(text_parts, key=len) + longest_idx = text_parts.index(longest_segment) + + for idx, text_segment in enumerate(tqdm(text_parts, desc='生成进度')): + model_input = self.frontend.frontend_cross_lingual(text_segment, prompt_speech_16k, self.sample_rate) + + if idx == 0 or idx == longest_idx: # 保存第一段或最长段作为音色模型 + self._save_voice_model(model_input, prompt_speech_16k) + + yield from self._process_with_progress(model_input, text_segment, stream, speed) def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True): assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!' 
if self.instruct is False: raise ValueError('{} do not support instruct inference'.format(self.model_dir)) instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend) - for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): - model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) - start_time = time.time() - logging.info('synthesis text {}'.format(i)) - for model_output in self.model.tts(**model_input, stream=stream, speed=speed): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() + for text_segment in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): + model_input = self.frontend.frontend_instruct(text_segment, spk_id, instruct_text) + yield from self._process_with_progress(model_input, text_segment, stream, speed) def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0): model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate) @@ -142,6 +236,7 @@ class CosyVoice2(CosyVoice): def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False): self.instruct = True if '-Instruct' in model_dir else False + self.is_05b = True if 'CosyVoice2-0.5B' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 if not os.path.exists(model_dir): @@ -172,6 +267,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_fl self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), self.fp16) + logging.info('CosyVoice2 初始化完成!') del configs def inference_instruct(self, *args, **kwargs): @@ -179,12 +275,6 @@ def inference_instruct(self, *args, **kwargs): def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True): assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!' 
- for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): - model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate) - start_time = time.time() - logging.info('synthesis text {}'.format(i)) - for model_output in self.model.tts(**model_input, stream=stream, speed=speed): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() + for text_segment in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): + model_input = self.frontend.frontend_instruct2(text_segment, instruct_text, prompt_speech_16k, self.sample_rate) + yield from self._process_with_progress(model_input, text_segment, stream, speed) diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py old mode 100644 new mode 100755 index 8770e312..d803b2df --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -68,7 +68,8 @@ def __init__(self, 'failed to initialize ttsfrd resource' self.frd.set_lang_type('pinyinvg') else: - self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + # alan wanglinlin 20250321 overwrite_cache=False 原来为True + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=False, remove_interjections=False) self.en_tn_model = EnNormalizer() self.inflect_parser = inflect.engine() @@ -149,9 +150,41 @@ def text_normalize(self, text, split=True, text_frontend=True): return texts if split is True else text def frontend_sft(self, tts_text, spk_id): + if spk_id not in self.spk2info: + logging.warning(f"in Frontend.py line 157, 说话人ID {spk_id} 不存在于 spk2info 中, embedding is {embedding}") + # 更新模型输入中的音色特征 + spk_fields = [ + "flow_embedding", "llm_embedding", + "llm_prompt_speech_token", "llm_prompt_speech_token_len", + "flow_prompt_speech_token", "flow_prompt_speech_token_len", + "prompt_speech_feat_len", "prompt_speech_feat", + "prompt_text", "prompt_text_len" + ] + model_input = {} + for field in spk_fields: + if field in self.spk2info[spk_id]: + model_input[field] = self.spk2info[spk_id][field] + tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) - embedding = self.spk2info[spk_id]['embedding'] - model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding} + + if 'llm_embedding' in self.spk2info[spk_id]: + embedding = self.spk2info[spk_id]['llm_embedding'] + elif 'embedding' in self.spk2info[spk_id]: + embedding = self.spk2info[spk_id]['embedding'] + else: + embedding = self.spk2info[spk_id]['flow_embedding'] + + # 确保即使 model_input 中已有相同的键,也会被更新 + update_fields = { + 'text': tts_text_token, + 'text_len': tts_text_token_len, + 'llm_embedding': embedding, + 'flow_embedding': embedding + } + + # 更新 model_input 中的字段 + for key, value in update_fields.items(): + model_input[key] = value return model_input def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id): @@ -179,7 +212,7 @@ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_ return model_input def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate): - model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate) + model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, 
'') # in cross lingual mode, we remove prompt in llm del model_input['prompt_text'] del model_input['prompt_text_len'] @@ -197,7 +230,7 @@ def frontend_instruct(self, tts_text, spk_id, instruct_text): return model_input def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate): - model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate) + model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, '') del model_input['llm_prompt_speech_token'] del model_input['llm_prompt_speech_token_len'] return model_input diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 20ddad03..d69c51ab 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -22,6 +22,7 @@ import uuid from cosyvoice.utils.common import fade_in_out from cosyvoice.utils.file_utils import convert_onnx_to_trt +import logging class CosyVoiceModel: @@ -39,7 +40,7 @@ def __init__(self, if self.fp16 is True: self.llm.half() self.flow.half() - self.token_min_hop_len = 2 * self.flow.input_frame_rate + self.token_min_hop_len = int(1.5 * self.flow.input_frame_rate) self.token_max_hop_len = 4 * self.flow.input_frame_rate self.token_overlap_len = 20 # mel fade in out @@ -51,7 +52,7 @@ def __init__(self, # speech fade in out self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related - self.stream_scale_factor = 1 + self.stream_scale_factor = 1.1 assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() self.lock = threading.Lock() @@ -61,6 +62,7 @@ def __init__(self, self.mel_overlap_dict = {} self.flow_cache_dict = {} self.hift_cache_dict = {} + self.this_uuid = '' def load(self, llm_model, flow_model, hift_model): self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True) @@ -90,6 +92,7 @@ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): import tensorrt as trt with open(flow_decoder_estimator_model, 'rb') as f: self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + logging.info('loading trt {}'.format(flow_decoder_estimator_model)) assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() @@ -98,6 +101,7 @@ def get_trt_kwargs(self): opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)] max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] input_names = ["x", "mask", "mu", "cond"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): @@ -173,6 +177,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) + self.this_uuid = this_uuid with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None @@ -260,6 
+265,8 @@ def __init__(self, # speech fade in out self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related + self.stream_scale_factor = 1.1 + assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() self.lock = threading.Lock() # dict used to store session related variable @@ -267,6 +274,7 @@ def __init__(self, self.llm_end_dict = {} self.flow_cache_dict = {} self.hift_cache_dict = {} + self.this_uuid = '' def init_flow_cache(self): encoder_cache = {'offset': 0, @@ -345,6 +353,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) + self.this_uuid = this_uuid with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None @@ -360,7 +369,8 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):] prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:] while True: - time.sleep(0.1) + time.sleep(0.05) + start_time = time.time() if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len: this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, @@ -373,10 +383,12 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device) prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device) yield {'tts_speech': this_tts_speech.cpu()} + # print(f"in model.py, token2wav耗时: {time.time() - start_time:.2f} 秒") with self.lock: self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:] if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len: break + # print(f"in model.py, before join 耗时: {time.time() - start_time_p:.2f} 秒") p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index 670ae69c..f8733b88 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -434,40 +434,65 @@ def inference_bistream( ) -> Generator[torch.Tensor, None, None]: device = prompt_text.device - # 1. prepare input + # 1. 
准备输入 + # 创建开始/结束嵌入向量,形状为 [1, 1, llm_input_size] sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + # 创建任务ID嵌入向量,形状为 [1, 1, llm_input_size] task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + # 处理提示语音标记,如果存在 if prompt_speech_token_len != 0: + # 将提示语音标记转换为嵌入向量 prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) else: + # 如果没有提示语音标记,创建空张量 prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) + + # 初始化模型输入,只包含开始标记嵌入 lm_input = torch.concat([sos_eos_emb], dim=1) - # 2. iterate text - out_tokens = [] - cache = None - # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5 + # 2. 迭代处理文本 + out_tokens = [] # 存储生成的输出标记 + cache = None # 初始化注意力缓存为空 + # 将提示文本转换为嵌入向量,作为文本缓存的初始值 text_cache = self.llm.model.model.embed_tokens(prompt_text) - next_fill_index = -1 + next_fill_index = -1 # 下一个填充标记的位置,初始为-1 + + # 逐批处理流式输入的文本 for this_text in text: + # 将当前批次文本转换为嵌入并添加到文本缓存 text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1) - # prompt_speech_token_emb not empty, try append to lm_input + + # 如果还有提示语音标记未处理,尝试将其添加到输入中 while prompt_speech_token_emb.size(1) != 0: - if text_cache.size(1) >= self.mix_ratio[0]: - lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]] + # 如果文本缓存长度足够,处理一批文本和语音标记 + if text_cache.size(1) >= self.mix_ratio[0]: # mix_ratio[0]是文本批次大小(如5) + # 提取文本和语音批次,mix_ratio[1]是语音批次大小(如15) + lm_input_text = text_cache[:, :self.mix_ratio[0]] + lm_input_speech = prompt_speech_token_emb[:, :self.mix_ratio[1]] logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) + + # 将文本和语音批次连接到输入中 lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1) - text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:] + + # 更新缓存,移除已处理的部分 + text_cache = text_cache[:, self.mix_ratio[0]:] + prompt_speech_token_emb = prompt_speech_token_emb[:, self.mix_ratio[1]:] else: + # 文本不足,等待更多输入 logging.info('not enough text token to decode, wait for more') break - # no prompt_speech_token_emb remain, can decode some speech token + + # 当提示语音标记处理完毕后,开始生成新的语音标记 if prompt_speech_token_emb.size(1) == 0: + # 处理填充标记的情况 if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1): logging.info('get fill token, need to append more text token') + # 如果文本缓存足够,添加新的文本批次 if text_cache.size(1) >= self.mix_ratio[0]: lm_input_text = text_cache[:, :self.mix_ratio[0]] logging.info('append {} text token'.format(lm_input_text.size(1))) + if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2: lm_input = lm_input_text else: diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index f0a450c5..cef5f70b 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -16,7 +16,7 @@ import json import torchaudio import logging -logging.getLogger('matplotlib').setLevel(logging.WARNING) +logging.getLogger('matplotlib').setLevel(logging.DEBUG) logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') diff --git a/cosyvoice_2_demo.py b/cosyvoice_2_demo.py new file mode 100755 index 00000000..15851dcd --- /dev/null +++ b/cosyvoice_2_demo.py @@ -0,0 +1,254 @@ +import sys +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "7" +import 
logging +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import load_wav +import torchaudio +import torch +import argparse +import time +import numpy as np +from stream_player import StreamPlayer + +# 设置根目录并添加第三方库路径 +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append("{}/third_party/Matcha-TTS".format(ROOT_DIR)) + + +# 设置日志级别为 DEBUG +logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logging.getLogger().setLevel(logging.DEBUG) + +# 确保设置影响所有模块 +for name in logging.root.manager.loggerDict: + logging.getLogger(name).setLevel(logging.DEBUG) + + +# 添加命令行参数解析 +parser = argparse.ArgumentParser(description="CosyVoice2 Demo") +parser.add_argument( + "--model_dir", + type=str, + default="pretrained_models/CosyVoice2-0.5B", + help="模型目录路径", +) +parser.add_argument( + "--fp16", action="store_true", default=False, help="是否使用半精度(fp16)推理" +) +parser.add_argument( + "--use_flow_cache", action="store_true", default=False, help="是否使用流式缓存" +) + +args = parser.parse_args() + +print(f"使用模型目录: {args.model_dir}") +cosyvoice = CosyVoice2( + args.model_dir, + load_jit=False, + load_trt=True, + fp16=args.fp16, + use_flow_cache=args.use_flow_cache, +) + +print(cosyvoice.list_available_spks()) + + +# # prompt_speech_16k = load_wav("./asset/sqr3.wav", 16000) +# # prompt_speech_16k = load_wav("./asset/wll3.wav", 16000) +# # prompt_speech_16k = load_wav("./asset/wzy_read_poet_27s.wav", 16000) +# prompt_speech_16k = load_wav("./asset/harry_potter_snape_injured.wav", 16000) +# # prompt_speech_16k = load_wav("./asset/laoxu.wav", 16000) +# for i, j in enumerate( +# cosyvoice.inference_zero_shot( +# "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。", +# # "声纹识别能力,多测一些", +# # '明天是星期六啦,我要去上果粒课啦,你们知道吗?', +# "I’m not hungry. That explains the blood. Listen. Last night, I'm guessing Snape let the troll in as a diversion, so he could get past that dog. But he got bit, that's why he's limping. The day I was at Gringotts, Hagrid took something out of the vault. Said it was Hogwarts business, very secret. That's what the dog's guarding. That's what Snape wants. I never get mail.", +# # "啊这个也能理解啊,因为七牛毕竟,是国内最早做云存储的公司。嗯,所以我想,就是和云存储相关的交流,可以在这个这个会之后自由讨论的时候,我们只管沟通啊。知无不言,言无不尽,哼哼。", +# # "我最喜欢夏天,满地的鲜花,这里一朵,那里一朵, 真比天上的星星还多。 夜晚,我数着天上的星星,真比地上的花儿还要多。那里一颗,真比天上的花还,花儿还多。", +# prompt_speech_16k, +# stream=args.use_flow_cache, +# ) +# ): +# torchaudio.save( +# "zero_shot_{}.wav".format(i), j["tts_speech"], cosyvoice.sample_rate +# ) + +# # save zero_shot spk for future usage +# assert ( +# cosyvoice.add_zero_shot_spk( +# # "声纹识别能力,多测一些", prompt_speech_16k, "wll" +# # '明天是星期六啦,我要去上果粒课啦,你们知道吗?', prompt_speech_16k, "wzy" +# # "啊这个也能理解啊,因为七牛毕竟,是国内最早做云存储的公司。嗯,所以我想,就是和云存储相关的交流,可以在这个这个会之后自由讨论的时候,我们只管沟通啊。知无不言,言无不尽,哼哼。", prompt_speech_16k, "laoxu" +# # "我最喜欢夏天,满地的鲜花,这里一朵,那里一朵, 真比天上的星星还多。 夜晚,我数着天上的星星,真比地上的花儿还要多。那里一颗,真比天上的花还,花儿还多。", +# "I’m not hungry. That explains the blood. Listen. Last night, I'm guessing Snape let the troll in as a diversion, so he could get past that dog. But he got bit, that's why he's limping. The day I was at Gringotts, Hagrid took something out of the vault. Said it was Hogwarts business, very secret. That's what the dog's guarding. That's what Snape wants. 
I never get mail.", +# prompt_speech_16k, +# "hp", +# ) +# is True +# ) +# cosyvoice.save_spkinfo() + + +player = StreamPlayer(sample_rate=cosyvoice.sample_rate, channels=1, block_size=18048) +player.start() + + +print( + "\n按回车使用默认文本,输入新文本后回车使用新文本,输入q后回车退出, 输入@后回车使用新指令\n" +) + + +while True: + # 交互式循环,可以反复输入文本生成语音 + # speaker = "xiaoluo_mandarin" + # speaker = "Donald J. Trump" + # default_tts_text = "CosyVoice迎来全面升级,提供更准、更稳、更快、 更好的语音生成能力。make america great again. " + default_speaker = "hp" + default_tts_text = "从此每当害怕时,他就想起那个和伙伴共同编织星光的夜晚 [noise] ,勇气便像萤火虫般在心底亮起来。" + default_instruct_text = "用很慢的语速读这个故事" + speaker = default_speaker + tts_text = default_tts_text + instruct_text = default_instruct_text + # 获取用户输入 + user_input = input( + f"请输入文本 (格式: ' speaker @ tts_text @ instruct_text') 退出: q " + ) + + # 检查是否退出 + if user_input.strip() == "q": + print("退出语音生成循环") + break + + if len(user_input) > 1: + speaker = user_input.split("@")[0] + if len(user_input.split("@")) > 1: + speaker = user_input.split("@")[0] + tts_text = user_input.split("@")[1] + if len(user_input.split("@")) > 2: + speaker = user_input.split("@")[0] + tts_text = user_input.split("@")[1] + instruct_text = user_input.split("@")[2] + + print(f"SPEAKER 是: {speaker}, tts_text 是: {tts_text}") + start_time = time.time() + for i, j in enumerate( + # cosyvoice.inference_instruct2( + # tts_text, + # instruct_text, + # prompt_speech_16k, + # stream=True, + # speed=0.8, + # text_frontend=True, + # ) + cosyvoice.inference_sft( + tts_text, + speaker, + stream=args.use_flow_cache, + ) + ): + current_time = time.time() + # logging.info(f"第 {i} 次生成耗时: {current_time - start_time:.2f} 秒") + start_time = current_time + + # torchaudio.save( + # "sft_{}.wav".format(i), j["tts_speech"], cosyvoice.sample_rate + # ) + player.play(j["tts_speech"].numpy().T) + +# 停止播放器 +player.stop() + +# # 最后一个示例,保存到文件而不是播放 +# start_time = time.time() +# for i, j in enumerate( +# cosyvoice.inference_zero_shot( +# # "这句话里面到底在使用了谁的语音呢?", +# "CosyVoice迎来全面升级,提供更准、更稳、更快、 更好的语音生成能力。make america great again. 
", +# "我会把三段话切成3段,用来做", +# prompt_speech_16k, +# stream=True, +# ) +# ): +# current_time = time.time() +# logging.info(f"第 {i} 次生成耗时: {current_time - start_time:.2f} 秒") +# start_time = current_time +# torchaudio.save( +# "zero_shot_{}.wav".format(i), j["tts_speech"], cosyvoice.sample_rate +# ) + +# # instruct usage +# for i, j in enumerate( +# cosyvoice.inference_instruct2( +# "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。", +# "用四川话说这句话", +# prompt_speech_16k, +# stream=True, +# ) +# ): +# torchaudio.save("instruct_{}.wav".format(i), j["tts_speech"], cosyvoice.sample_rate) + +# # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference +# # zero_shot usage +# prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) +# for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): +# torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + +# # fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248 +# for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)): +# torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + +# # instruct usage +# for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)): +# torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + +# # bistream usage, you can use generator as input, this is useful when using text llm model as input +# # NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length +# def text_generator(): +# yield '收到好友从远方寄来的生日礼物,' +# yield '那份意外的惊喜与深深的祝福' +# yield '让我心中充满了甜蜜的快乐,' +# yield '笑容如花儿般绽放。' +# for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): +# torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + +# start_time = time.time() + +# for i, j in enumerate( +# cosyvoice.inference_cross_lingual( +# "在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。", +# "没有用到的文本", +# prompt_speech_16k, +# stream=True, +# ) +# ): +# current_time = time.time() +# print(f"第 {i} 次生成耗时: {current_time - start_time:.2f} 秒") +# start_time = current_time +# torchaudio.save( +# "fine_grained_control_{}.wav".format(i), j["tts_speech"], cosyvoice.sample_rate +# ) + +# # 使用改进的播放机制进行流式语音生成和播放 +# print("ATTENTION: 文本已经给到模型,开始生成语音啦!!!") +# start_time = time.time() + +# # 流式生成并添加到播放队列 +# for i, j in enumerate( +# cosyvoice.inference_zero_shot( +# # "这句话里面到底在使用了谁的语音呢?", +# "CosyVoice迎来全面升级,提供更准、更稳、更快、 更好的语音生成能力。make america great again. 
", +# "我会把三段话切成3段,用来做", +# prompt_speech_16k, +# stream=True, +# ) +# ): +# current_time = time.time() +# logging.info(f"第 {i} 次生成耗时: {current_time - start_time:.2f} 秒") +# start_time = current_time + +# player.play(j["tts_speech"].numpy().T) diff --git a/cosyvoice_demo_quant_compare.py b/cosyvoice_demo_quant_compare.py new file mode 100644 index 00000000..385ae625 --- /dev/null +++ b/cosyvoice_demo_quant_compare.py @@ -0,0 +1,76 @@ +import sys + +sys.path.append("third_party/Matcha-TTS") +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import load_wav +import torchaudio +import argparse +import os + +# 添加命令行参数解析 +parser = argparse.ArgumentParser(description="CosyVoice2 Demo") +parser.add_argument( + "--model_dir", + type=str, + default="pretrained_models/CosyVoice2-0.5B", + help="模型目录路径", +) +parser.add_argument( + "--output_dir", + type=str, + default="exp", + help="输出目录路径", +) + +parser.add_argument( + "--fp16", action="store_true", default=False, help="是否使用半精度(fp16)推理" +) +args = parser.parse_args() + +print(f"使用模型目录: {args.model_dir}") +cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=args.fp16) +# cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False, fp16=args.fp16) + +k = 0 +for name, transcript in [ + ("./asset/sqr3.wav", "我会把三段话切成3段,用来做"), + ("./asset/wll3.wav", "声纹识别能力,多测一些"), + ("./asset/wzy_stereo.wav", "明天是星期六啦,我要去上果力课啦,你们知道吗?"), +]: + + prompt_speech_16k = load_wav(name, 16000) + for i, j in enumerate( + cosyvoice.inference_zero_shot( + "我们是X robot小组,[laughter],在做角色扮演的机器人。", + transcript, + prompt_speech_16k, + stream=False, + ) + ): + audio_path = os.path.join(args.output_dir, "zero_shot_{}.wav".format(k)) + torchaudio.save(audio_path, j["tts_speech"], cosyvoice.sample_rate) + k += 1 + +# # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference +# # zero_shot usage +# prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) +# for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): +# torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + +# # fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248 +# for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)): +# torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + +# # instruct usage +# for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)): +# torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + +# # bistream usage, you can use generator as input, this is useful when using text llm model as input +# # NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length +# def text_generator(): +# yield '收到好友从远方寄来的生日礼物,' +# yield '那份意外的惊喜与深深的祝福' +# yield '让我心中充满了甜蜜的快乐,' +# yield '笑容如花儿般绽放。' +# for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): +# torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) diff --git a/docker/Dockerfile b/docker/Dockerfile index d7faf031..7f88ee74 
100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 +FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 ARG VENV_NAME="cosyvoice" ENV VENV=$VENV_NAME @@ -8,17 +8,23 @@ ENV DEBIAN_FRONTEN=noninteractive ENV PYTHONUNBUFFERED=1 SHELL ["/bin/bash", "--login", "-c"] +# 设置阿里云镜像源 +RUN sed -i 's/archive.ubuntu.com/mirrors.aliyun.com/g' /etc/apt/sources.list +RUN sed -i 's/security.ubuntu.com/mirrors.aliyun.com/g' /etc/apt/sources.list + RUN apt-get update -y --fix-missing RUN apt-get install -y git build-essential curl wget ffmpeg unzip git git-lfs sox libsox-dev && \ apt-get clean && \ git lfs install +RUN git config --global http.proxy socks5://183.240.180.158:10080 +RUN git config --global https.proxy socks5://183.240.180.158:10080 # ================================================================== # conda install and conda forge channel as default # ------------------------------------------------------------------ # Install miniforge -RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \ - /bin/bash ~/miniforge.sh -b -p /opt/conda && \ +RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh +RUN /bin/bash ~/miniforge.sh -b -p /opt/conda && \ rm ~/miniforge.sh && \ ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ echo "source /opt/conda/etc/profile.d/conda.sh" >> /opt/nvidia/entrypoint.d/100.conda.sh && \ @@ -26,7 +32,7 @@ RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/downlo echo "conda activate ${VENV}" >> /opt/nvidia/entrypoint.d/110.conda_default_env.sh && \ echo "conda activate ${VENV}" >> $HOME/.bashrc -ENV PATH /opt/conda/bin:$PATH +ENV PATH=/opt/conda/bin:$PATH RUN conda config --add channels conda-forge && \ conda config --set channel_priority strict @@ -36,7 +42,12 @@ RUN conda config --add channels conda-forge && \ RUN conda create -y -n ${VENV} python=3.10 ENV CONDA_DEFAULT_ENV=${VENV} -ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH +ENV PATH=/opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH + +# 先升级pip +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --upgrade pip + WORKDIR /workspace @@ -45,7 +56,22 @@ ENV PYTHONPATH="${PYTHONPATH}:/workspace/CosyVoice:/workspace/CosyVoice/third_pa RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5 -RUN conda activate ${VENV} && cd CosyVoice && \ +RUN conda activate ${VENV} +RUN cd CosyVoice2-Ex + +ENV http_proxy=socks5://183.240.180.158:10080 +ENV https_proxy=socks5://183.240.180.158:10080 +ENV all_proxy=socks5://183.240.180.158:10080 + +WORKDIR /workspace/CosyVoice2-Ex +# 下一句要实现打印当前目录 +RUN echo "当前目录: $(pwd)" +RUN --mount=type=cache,target=/root/.cache/pip \ pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com -WORKDIR /workspace/CosyVoice + +ENV LOG_LEVEL=INFO + +EXPOSE 8080 + +CMD ["sh", "-c", "python webui.py --port 8080 --log_level ${LOG_LEVEL}"] diff --git a/docker/voice_install.sh b/docker/voice_install.sh new file mode 100644 index 00000000..d2a581bc --- /dev/null +++ b/docker/voice_install.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# 添加错误处理函数 +# 注意这里没有包含安装 ros2 humble, 需要另行安装 + +# 设置命令回显 +set -x + +# 定义函数用于显示安装步骤 +show_step() { + echo "====================================================" + echo "🔶 步骤: $@" + 
echo "====================================================" +} + + +show_step "###################正式安装从这里开始#####################################" + +# 1. 安装通用包,基础的环境 +show_step "安装通用包,基础的环境" +apt-get update -y --fix-missing +apt-get install -y git curl wget ffmpeg unzip git-lfs sox libsox-dev && \ + apt-get clean + +# 安装通用包 +apt-get install -y apt ssh make gcc curl cmake g++ unzip lsof net-tools + +# 安装ROS2 依赖的包, 并不安装 ros2 humble +apt-get install -y python3-pip ament-cmake + +# 2. 安装语音需要的系统级的包或库 +show_step "安装语音系统级依赖包" + +# 分批安装包,以减少错误风险 +# 基础开发工具 +apt-get install -y pkg-config libfftw3-dev nlohmann-json3-dev libeigen3-dev + +# 音频相关库 +apt-get install -y libsndfile1-dev pulseaudio pulseaudio-utils + +# Mesa相关 +apt-get install -y mesa-utils libglu1-mesa-dev + +# Pybind相关 +apt-get install -y python3-pybind11 pybind11-dev + +# 音频编解码相关 +apt-get install -y libmpg123-dev libmad0-dev libsndio-dev libwebrtc-audio-processing-dev libwavpack-dev + +# 视频编解码相关 +apt-get install -y libavcodec-dev libavc1394-dev + +# 3. 安装语音需要的 python 包 +show_step "安装语音需要的 python 包, 放入一个轻量级的虚拟环境中,没有使用过重的conda" + +echo "当前目录: $(pwd)" +# 安装轻量级的 venv +pip install virtualenv + +virtualenv venv_voice +source venv_voice/bin/activate +# source /disk1/venv_torch/bin/activate + +# 创建requirements.txt文件 +cat > requirements.txt << EOF +torch==2.6.0+cpu +torchaudio==2.6.0 +torchvision==0.21.0 +yeaudio==0.0.7 +tqdm==4.67.1 +SoundCard==0.4.3 +scikit-learn==1.6.1 +scipy==1.15.1 +pybind11==2.13.6 +pip==25.0.1 +llvmlite==0.44.0 +kaldi-native-fbank==1.20.2 +catkin-pkg==1.0.0 +librosa==0.10.2 +EOF + +# 在上面venv_voice中pip安装一大堆包 +pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com + +echo "所有依赖安装完成" \ No newline at end of file diff --git a/example_instructions.html b/example_instructions.html new file mode 100644 index 00000000..204a8f51 --- /dev/null +++ b/example_instructions.html @@ -0,0 +1,274 @@ + + + + + CosyVoice 控制指令示例 + + + +
CosyVoice 控制指令示例

点击示例文本可以复制到剪贴板。所有示例都可以直接在 CosyVoice 中使用。

1. 角色扮演控制
+
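下面各节除示例文本外,还补充了几段示意性的 Python 片段(并非仓库的正式实现,仅作用法参考)。统一假设:已按 README 下载 `pretrained_models/CosyVoice2-0.5B`,并准备了 16kHz 参考音频 `./asset/zero_shot_prompt.wav`;后续片段均复用这里创建的 `cosyvoice` 与 `prompt_speech_16k`:

```python
# 示意代码:初始化模型与参考音频(路径为假设值,请按实际环境调整)
import sys
sys.path.append('third_party/Matcha-TTS')

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# 流式场景(stream=True)可按需增加 use_flow_cache=True
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=False, fp16=False)
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
```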
神秘<|endofprompt|>那座古老的城堡笼罩在神秘的雾气中,吸引着冒险者前去探索。
+
凶猛<|endofprompt|>战士们以凶猛的勇气冲锋陷阵,让敌人闻风丧胆。
+
好奇<|endofprompt|>对艺术的无限可能,她总是充满好奇,愿意尝试各种不同的创作形式。
+
优雅<|endofprompt|>那位钢琴家的演奏优雅动人,每一个音符都触动人心。
+
孤独<|endofprompt|>夜深人静时,她坐在窗前,总是感到一种莫名的孤独。
+
模仿机器人风格<|endofprompt|>在人工智能技术的支持下,我能够识别语音指令并执行相关操作。
+
我想听听你模仿小猪佩奇的语气。<|endofprompt|>大家好,我是小猪佩奇,今天我和苏西羊一起去公园玩。
+
一个活泼、爱冒险的小精灵<|endofprompt|>嘿,看那片云,它看起来像一只大象在跳舞!
+
一位权威、威严的古代将军<|endofprompt|>战场上的胜利,不仅依赖于兵力,更取决于决策的果敢和士气的高昂。
+
一个忧郁的诗人<|endofprompt|>月光下的一切都是那么宁静,却也那么孤寂,正如我心中的那片荒芜。
+
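一个示意性的用法(非官方实现):把上面任意一条“角色指令<|endofprompt|>正文”形式的示例按 `<|endofprompt|>` 拆开,分别作为 `instruct_text` 和 `tts_text` 传给 `inference_instruct2`;是否需要保留 `<|endofprompt|>` 标记,以 webui 的实际拼接逻辑为准。

```python
# 示意代码:角色扮演指令合成,复用上文的 cosyvoice / prompt_speech_16k
example = '一个忧郁的诗人<|endofprompt|>月光下的一切都是那么宁静,却也那么孤寂,正如我心中的那片荒芜。'
instruct_text, tts_text = example.split('<|endofprompt|>', 1)

for i, j in enumerate(cosyvoice.inference_instruct2(tts_text, instruct_text,
                                                     prompt_speech_16k, stream=False)):
    torchaudio.save('roleplay_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```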
2. 方言控制
用粤语说这句话<|endofprompt|>我最近迷上一部经典港剧,入面嗰啲对白真系有嚟头。
+
用四川话说<|endofprompt|>前儿个去宽窄巷子喝盖碗茶,坐在小板凳上头听人摆龙门阵,简直安逸得很。
+
上海话<|endofprompt|>侬晓得伐,上礼拜我去淮海路的小马路上头捡漏,居然淘到一只老克勒的手表。
+
郑州话<|endofprompt|>这阵子我在听豫剧,虽然有些地方唱词听不太明白,但音乐一响,耳朵就被吸住了。
+
长沙话<|endofprompt|>哎呀,前几天去坡子街吃夜宵,那口味虾辣得我直冒汗,嘴巴烧得像火。
+
天津话<|endofprompt|>今儿个去逛古文化街,那些个手工艺品五花八门,特别是杨柳青年画。
+
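方言控制同样走 `inference_instruct2`。下面是一个示意性的流式用法(假设构造 `CosyVoice2` 时开启了 `use_flow_cache=True`),写法与仓库 demo 中被注释掉的 instruct 用例一致,顺便统计每个分片的生成耗时:

```python
# 示意代码:四川话 + 流式输出,逐段保存并打印耗时
import time

start_time = time.time()
for i, j in enumerate(cosyvoice.inference_instruct2(
        '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。',
        '用四川话说这句话',
        prompt_speech_16k,
        stream=True)):
    now = time.time()
    print(f'第 {i} 次生成耗时: {now - start_time:.2f} 秒')
    start_time = now
    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```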
3. 细粒度控制

笑声控制
在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。
+
他搞的一个恶作剧,让大家<laughter>忍俊不禁</laughter>。
+
Oh wow [laughter], I thought I had seen it all until now [laughter]. Your ability to surprise never ceases to amaze me [laughter].
强调控制
追求卓越不是终点,它需要你每天都<strong>付出</strong>和<strong>精进</strong>,最终才能达到巅峰。
+
With <strong>determination</strong> and <strong>focus</strong>, we can overcome <strong>any challenge</strong>.
呼吸控制
当你用心去倾听一首音乐时[breath],你会开始注意到那些细微的音符变化[breath],并通过它们感受到音乐背后的情感。
+
深呼吸[breath]让我们保持冷静[breath]仔细思考这个问题。
混合控制
这个笑话太有趣了[laughter],让我喘口气[breath],<strong>实在是太好笑了</strong>!
+
The performance was <strong>breathtaking</strong> [breath], and the audience burst into [laughter] thunderous applause.
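细粒度控制标记([laughter]、<laughter></laughter>、<strong></strong>、[breath])直接写在待合成文本里即可,支持的标记可参考 `cosyvoice/tokenizer/tokenizer.py#L248`。下面是一个示意片段(非官方实现),沿用上文的 `cosyvoice` 与 `prompt_speech_16k`:

```python
# 示意代码:混合使用笑声、呼吸与强调标记
text = '这个笑话太有趣了[laughter],让我喘口气[breath],<strong>实在是太好笑了</strong>!'
for i, j in enumerate(cosyvoice.inference_cross_lingual(text, prompt_speech_16k, stream=False)):
    torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```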
4. 说话风格控制

情感风格
用开心的语气说<|endofprompt|>参加朋友的婚礼,看着新人幸福的笑脸,我感到无比开心。
+
用伤心的语气说<|endofprompt|>收到拒信的那一刻,我感到无比伤心。
+
用惊讶的语气说<|endofprompt|>走进家门,看见墙上挂满了我的照片,我惊讶得愣住了。
+
用生气的语气说<|endofprompt|>在交通高峰期,遭遇到一位鲁莽的司机插队,我感到非常生气。
+
用恐惧的情感表达<|endofprompt|>看恐怖电影时,那突如其来的惊悚画面让我感到无比恐惧。
+
用恶心的情感表达<|endofprompt|>听到关于人体实验的细节描述,我感到非常恶心。
语速控制
快速<|endofprompt|>这款新应用程序加载速度极快,让用户体验得到了极大的提升。
+
非常快速<|endofprompt|>这款新应用程序加载速度极快,让用户体验得到了极大的提升。
+
慢速<|endofprompt|>听着轻柔的音乐,我在画布上慢慢地涂抹色彩。
+
非常慢速<|endofprompt|>听着轻柔的音乐,我在画布上慢慢地涂抹色彩。
语气控制
冷静<|endofprompt|>在争论中,我试图让自己冷静下来,理智地表达我的观点。
+
严肃<|endofprompt|>这个安全隐患问题必须严肃处理,我们不能掉以轻心。
英文风格
Bubbling with happiness<|endofprompt|>The laughter of children playing in the park fills the air.
+
Overcome with sorrow<|endofprompt|>I miss my dear friend who moved away last month.
+
Speaking very fast<|endofprompt|>I can't believe how much I have to do today!
+
Speaking with patience<|endofprompt|>As we work through this problem, I'll go slowly.
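说话风格(情感、语速、语气、英文风格)与角色扮演一样使用“风格指令<|endofprompt|>正文”的形式。下面的示意片段(非官方实现)批量合成本节中的几条示例,输出文件名为假设值,仍复用上文的 `cosyvoice` 与 `prompt_speech_16k`:

```python
# 示意代码:批量试听几条风格指令
style_examples = [
    '用开心的语气说<|endofprompt|>参加朋友的婚礼,看着新人幸福的笑脸,我感到无比开心。',
    '快速<|endofprompt|>这款新应用程序加载速度极快,让用户体验得到了极大的提升。',
    'Bubbling with happiness<|endofprompt|>The laughter of children playing in the park fills the air.',
]
for k, example in enumerate(style_examples):
    instruct_text, tts_text = example.split('<|endofprompt|>', 1)
    for i, j in enumerate(cosyvoice.inference_instruct2(tts_text, instruct_text,
                                                         prompt_speech_16k, stream=False)):
        torchaudio.save('style_{}_{}.wav'.format(k, i), j['tts_speech'], cosyvoice.sample_rate)
```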
+
+ + + + \ No newline at end of file diff --git a/quant_cosyvoice.py b/quant_cosyvoice.py new file mode 100644 index 00000000..1416ccd1 --- /dev/null +++ b/quant_cosyvoice.py @@ -0,0 +1,87 @@ +import sys +sys.path.append('third_party/Matcha-TTS') +import torch +import os +import shutil +from cosyvoice.cli.cosyvoice import CosyVoice2 +from cosyvoice.llm.llm import Qwen2LM + + +# 原始模型路径 +model_dir = 'pretrained_models/CosyVoice2-0.5B' +original_llm_path = os.path.join(model_dir, 'llm.pt') + +# 备份原始模型 +backup_path = os.path.join(model_dir, 'llm.pt.backup') +if not os.path.exists(backup_path): + print(f"备份原始模型到: {backup_path}") + shutil.copy(original_llm_path, backup_path) + +# 创建量化后的模型目录 +quantized_model_dir = 'pretrained_models/CosyVoice2-0.5B-quantized' +os.makedirs(quantized_model_dir, exist_ok=True) + +# 复制原始模型目录中除了llm.pt之外的所有文件 +for file_name in os.listdir(model_dir): + if not file_name.endswith('.pt') and not file_name.endswith(r'.backup') and not file_name.startswith(r'flow.'): + src_path = os.path.join(model_dir, file_name) + dst_path = os.path.join(quantized_model_dir, file_name) + if os.path.isfile(src_path): + shutil.copy2(src_path, dst_path) + print(f"复制文件: {src_path} -> {dst_path}") + +# 使用CosyVoice2类加载模型 +print("加载原始模型...") +cosyvoice2 = CosyVoice2(model_dir, load_jit=False, load_trt=False, fp16=False) + +# 提取LLM部分 +original_model = cosyvoice2.model.llm +original_model.eval() # 设置为评估模式 + +print("开始量化模型...") + +# 创建一个新的量化模型,只量化线性层 +# 使用更保守的量化设置 +quantized_model = torch.quantization.quantize_dynamic( + original_model, + {torch.nn.Linear}, # 只量化线性层 + dtype=torch.qint8, + inplace=False # 不要修改原始模型 +) + +# 保存量化后的模型到新目录 +quantized_model_path = os.path.join(quantized_model_dir, 'llm.pt') +print(f"保存量化模型到: {quantized_model_path}") + +# 使用torch.save保存整个模型,而不仅仅是state_dict +torch.save(quantized_model, quantized_model_path) + +print(f"量化完成!请使用以下命令测试量化后的模型:") +print(f"python cosyvoice_2_demo.py --model_dir {quantized_model_dir}") +print("如果出现问题,可以继续使用原始模型。") + + +""" +# 方案2: 如果需要量化嵌入层,可以尝试以下代码(取消注释使用) +# 注意:这需要PyTorch 1.13或更高版本 + +# import torch.ao.quantization as quantization +# from torch.ao.quantization import float_qparams_weight_only_qconfig +# from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx + +# # 为嵌入层设置特殊配置 +# qconfig_dict = { +# 'object_type': [ +# (torch.nn.Embedding, float_qparams_weight_only_qconfig), +# (torch.nn.Linear, torch.quantization.get_default_qconfig('fbgemm')) +# ] +# } + +# # 使用FX图模式量化 +# prepared_model = prepare_fx(model, qconfig_dict) +# # 如果有校准数据,可以在这里运行校准 +# quantized_model = convert_fx(prepared_model) + +# # 保存量化后的模型 +# torch.save(quantized_model.state_dict(), out_model) +""" diff --git a/quant_cosyvoice_awq.py b/quant_cosyvoice_awq.py new file mode 100644 index 00000000..09fa7288 --- /dev/null +++ b/quant_cosyvoice_awq.py @@ -0,0 +1,141 @@ +import sys +sys.path.append('third_party/Matcha-TTS') +import torch +import os +import shutil +from cosyvoice.cli.cosyvoice import CosyVoice2 +import argparse + +# 首先检查是否安装了awq +try: + import awq +except ImportError: + print("请先安装awq库: pip install awq") + sys.exit(1) + +# 解析命令行参数 +parser = argparse.ArgumentParser(description='使用AWQ量化CosyVoice模型') +parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice2-0.5B', + help='原始模型目录路径') +parser.add_argument('--output_dir', type=str, default='pretrained_models/CosyVoice2-0.5B-awq', + help='量化后模型保存目录') +parser.add_argument('--bits', type=int, default=4, choices=[4, 8], + help='量化位数 (4 或 8)') +parser.add_argument('--group_size', type=int, 
default=128, + help='量化组大小') +args = parser.parse_args() + +# 创建输出目录 +os.makedirs(args.output_dir, exist_ok=True) + +# 复制原始模型目录中除了llm.pt之外的所有文件 +print(f"复制模型文件从 {args.model_dir} 到 {args.output_dir}") +for file_name in os.listdir(args.model_dir): + if not file_name.endswith('.pt') and not file_name.endswith('.backup'): + src_path = os.path.join(args.model_dir, file_name) + dst_path = os.path.join(args.output_dir, file_name) + if os.path.isfile(src_path): + shutil.copy2(src_path, dst_path) + print(f"复制文件: {src_path} -> {dst_path}") + +# 加载原始模型 +print("加载原始模型...") +cosyvoice2 = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False) + +# 提取LLM部分 +original_model = cosyvoice2.model.llm +original_model.eval() + +# 保存模型配置 +if hasattr(original_model, 'config'): + config_path = os.path.join(args.output_dir, 'config.json') + if hasattr(original_model.config, 'to_json_file'): + original_model.config.to_json_file(config_path) + print(f"保存模型配置到: {config_path}") + +# 准备校准数据 +# 这里使用一些简单的文本作为校准数据 +calibration_data = [ + "这是一个用于校准的示例文本,包含一些常见的中文词汇和句子结构。", + "语音合成技术可以将文本转换为自然流畅的语音,广泛应用于各种场景。", + "人工智能的发展日新月异,语音技术是其中重要的一环。", + "这是一个测试句子,用于模型量化校准。", + "欢迎使用CosyVoice语音合成系统,它可以生成自然、流畅的语音。" +] + +print(f"开始使用AWQ进行{args.bits}位量化...") + +# 使用AWQ量化模型 +try: + from awq import AutoAWQForCausalLM + + # 获取tokenizer + tokenizer = original_model.tokenizer if hasattr(original_model, 'tokenizer') else None + + if tokenizer is None: + print("警告: 无法获取tokenizer,AWQ可能无法正常工作") + # 尝试从transformers加载tokenizer + try: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B") + print("使用Qwen2-7B的tokenizer作为替代") + except: + print("无法加载替代tokenizer,将尝试继续...") + + # 使用AWQ量化 + quantized_model = AutoAWQForCausalLM.from_pretrained( + original_model, + tokenizer=tokenizer, + ) + + # 执行量化 + quantized_model.quantize( + tokenizer=tokenizer, + quant_config={ + "bits": args.bits, + "group_size": args.group_size, + "zero_point": True, + "q_group_size": 128, + }, + calib_data=calibration_data, + ) + + # 保存量化后的模型 + quantized_model.save_quantized(args.output_dir) + print(f"量化模型已保存到: {args.output_dir}") + + print("量化完成!请使用以下命令测试量化后的模型:") + print(f"python cosyvoice_2_demo.py --model_dir {args.output_dir}") + +except Exception as e: + print(f"AWQ量化过程中出错: {e}") + print("尝试使用替代方法...") + + # 如果上面的方法失败,尝试使用optimum-awq + try: + from transformers import AutoModelForCausalLM + from optimum.awq import AWQConfig, load_quantized_model + + print("使用optimum-awq进行量化...") + + # 配置AWQ + awq_config = AWQConfig( + bits=args.bits, + group_size=args.group_size, + zero_point=True, + ) + + # 量化模型 + quantized_model = load_quantized_model( + original_model, + awq_config, + calibration_data, + ) + + # 保存量化后的模型 + quantized_model.save_pretrained(args.output_dir) + print(f"量化模型已保存到: {args.output_dir}") + + except Exception as e2: + print(f"替代方法也失败: {e2}") + print("建议尝试使用其他量化工具,如bitsandbytes或llama.cpp") \ No newline at end of file diff --git a/quant_cosyvoice_bnb.py b/quant_cosyvoice_bnb.py new file mode 100644 index 00000000..6268c1d4 --- /dev/null +++ b/quant_cosyvoice_bnb.py @@ -0,0 +1,143 @@ +import sys +sys.path.append('third_party/Matcha-TTS') +import torch +import os +import shutil +from cosyvoice.cli.cosyvoice import CosyVoice2 +import argparse + +# 首先检查是否安装了bitsandbytes +try: + import bitsandbytes as bnb + print("成功导入bitsandbytes库") +except ImportError: + print("请先安装bitsandbytes库: pip install bitsandbytes") + sys.exit(1) + +# 解析命令行参数 +parser = argparse.ArgumentParser(description='使用bitsandbytes量化CosyVoice模型') 
+parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice2-0.5B', + help='原始模型目录路径') +parser.add_argument('--output_dir', type=str, default='pretrained_models/CosyVoice2-0.5B-bnb', + help='量化后模型保存目录') +parser.add_argument('--bits', type=int, default=8, choices=[4, 8], + help='量化位数 (4 或 8)') +args = parser.parse_args() + +# 创建输出目录 +os.makedirs(args.output_dir, exist_ok=True) + +# 复制原始模型目录中除了llm.pt之外的所有文件 +print(f"复制模型文件从 {args.model_dir} 到 {args.output_dir}") +for file_name in os.listdir(args.model_dir): + if not file_name.endswith('.pt') and not file_name.endswith(r'.backup') and not file_name.startswith(r'flow.'): + src_path = os.path.join(args.model_dir, file_name) + dst_path = os.path.join(args.output_dir, file_name) + if os.path.isfile(src_path): + shutil.copy2(src_path, dst_path) + print(f"复制文件: {src_path} -> {dst_path}") + + +# 加载原始模型 +print("加载原始模型...") +cosyvoice2 = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False) + +# 提取LLM部分 +original_model = cosyvoice2.model.llm +original_model.eval() + +# 保存模型配置 +if hasattr(original_model, 'config'): + config_path = os.path.join(args.output_dir, 'config.json') + if hasattr(original_model.config, 'to_json_file'): + original_model.config.to_json_file(config_path) + print(f"保存模型配置到: {config_path}") + +print(f"开始使用bitsandbytes进行{args.bits}位量化...") + +# 使用bitsandbytes量化模型 +try: + # 创建量化模型的副本 + quantized_model = type(original_model)(original_model.config) + + # 将原始模型的权重复制到量化模型 + quantized_model.load_state_dict(original_model.state_dict()) + + # 将线性层转换为量化线性层 + for name, module in list(quantized_model.named_modules()): + if isinstance(module, torch.nn.Linear): + parent_name = name.rsplit('.', 1)[0] if '.' in name else '' + child_name = name.rsplit('.', 1)[1] if '.' in name else name + + if parent_name: + parent = quantized_model.get_submodule(parent_name) + + if args.bits == 8: + # 8位量化 + new_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + bias=module.bias is not None, + ) + else: + # 4位量化 + new_module = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + bias=module.bias is not None, + ) + + # 复制权重 + with torch.no_grad(): + if hasattr(new_module, 'weight'): + new_module.weight.copy_(module.weight) + if module.bias is not None and hasattr(new_module, 'bias'): + new_module.bias.copy_(module.bias) + + # 替换模块 + try: + setattr(parent, child_name, new_module) + print(f"成功量化模块: {name}") + except Exception as e: + print(f"无法量化模块 {name}: {e}") + + # 保存量化后的模型 + quantized_model_path = os.path.join(args.output_dir, 'llm.pt') + torch.save(quantized_model, quantized_model_path) + print(f"量化模型已保存到: {quantized_model_path}") + + print("量化完成!请使用以下命令测试量化后的模型:") + print(f"python cosyvoice_2_demo.py --model_dir {args.output_dir}") + +except Exception as e: + print(f"bitsandbytes量化过程中出错: {e}") + + # 如果上面的方法失败,尝试使用简单的量化方法 + try: + print("尝试使用简单量化方法...") + + # 创建一个简单的量化函数 + def simple_quantize(model, bits=8): + """简单的量化函数,将模型的权重量化为指定位数""" + for name, param in model.named_parameters(): + if 'weight' in name and param.dim() > 1: # 只量化权重矩阵 + # 计算量化范围 + max_val = torch.max(torch.abs(param.data)) + scale = (2**(bits-1) - 1) / max_val + + # 量化 + param.data = torch.round(param.data * scale) / scale + + return model + + # 量化模型 + quantized_model = simple_quantize(original_model, bits=args.bits) + + # 保存量化后的模型 + quantized_model_path = os.path.join(args.output_dir, 'llm.pt') + torch.save(quantized_model, quantized_model_path) + print(f"使用简单量化方法保存模型到: {quantized_model_path}") + + except Exception 
as e2: + print(f"简单量化方法也失败: {e2}") + print("建议尝试手动调整模型结构或使用其他量化工具") \ No newline at end of file diff --git a/quant_cosyvoice_bnb_advancd.py b/quant_cosyvoice_bnb_advancd.py new file mode 100644 index 00000000..7d846dd7 --- /dev/null +++ b/quant_cosyvoice_bnb_advancd.py @@ -0,0 +1,621 @@ +import sys +sys.path.append('third_party/Matcha-TTS') +import torch +import os +import shutil +from cosyvoice.cli.cosyvoice import CosyVoice2 +import argparse + +""" +CosyVoice模型量化脚本 + +本脚本使用bitsandbytes库对CosyVoice模型进行量化,支持8位和4位量化。 + +量化原理: +1. 8位量化(Linear8bitLt): + - 将模型的线性层权重从FP32/FP16量化为INT8 + - 使用LLM.int8()方法,将异常值(outliers)提取出来在FP16中计算,其余在INT8中计算 + - 实际量化发生在模型被移动到CUDA设备时(.to("cuda")) + - 输入数据需要是FP16类型 + +2. 4位量化(Linear4bit): + - 将模型的线性层权重量化为4位精度 + - 计算使用FP16进行 + - 同样需要将模型移动到CUDA设备触发量化 + - 输入数据需要是FP16类型 + +注意事项: +- 量化前最好确保模型权重是FP16类型 +- 使用量化模型时,输入必须是FP16类型 +- 量化会导致一定的精度损失,但可以显著减少内存占用 +""" + +# 首先检查是否安装了必要的库 +try: + import bitsandbytes as bnb + print("成功导入bitsandbytes库") +except ImportError: + print("请先安装bitsandbytes库: pip install bitsandbytes") + sys.exit(1) + +# 解析命令行参数 +parser = argparse.ArgumentParser(description='使用简化的量化方法量化CosyVoice模型') +parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice2-0.5B', + help='原始模型目录路径') +parser.add_argument('--output_dir', type=str, default='pretrained_models/CosyVoice2-0.5B-quantized', + help='量化后模型保存目录') +parser.add_argument('--bits', type=int, default=8, choices=[4, 8], + help='量化位数 (4 或 8)') +parser.add_argument('--block_size', type=int, default=32, + help='量化块大小') +parser.add_argument('--convert_to_fp16', action='store_true', + help='尝试将模型权重转换为fp16') +parser.add_argument('--save_quantized', action='store_true', + help='同时保存完整的量化模型(包含量化参数)') +args = parser.parse_args() + +# 创建输出目录 +os.makedirs(args.output_dir, exist_ok=True) + +# 复制原始模型目录中除了llm.pt之外的所有文件 +print(f"复制模型文件从 {args.model_dir} 到 {args.output_dir}") +for file_name in os.listdir(args.model_dir): + if not file_name.endswith('llm.pt') and not file_name.endswith(r'.backup'): + src_path = os.path.join(args.model_dir, file_name) + dst_path = os.path.join(args.output_dir, file_name) + if os.path.isfile(src_path) and not os.path.exists(dst_path): + shutil.copy2(src_path, dst_path) + print(f"复制文件: {src_path} -> {dst_path}") + +# 加载原始模型 +print("加载原始模型...") +cosyvoice2 = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False) + +# 提取LLM部分 +original_model = cosyvoice2.model.llm +original_model.eval() + +# 尝试将模型转换为fp16 +def convert_to_fp16(model): + """尝试将模型权重转换为fp16""" + print("尝试将模型权重转换为fp16...") + try: + # 检查当前权重类型 + weight_dtype = None + for name, param in model.named_parameters(): + if 'weight' in name: + weight_dtype = param.dtype + print(f"原始模型权重数据类型: {weight_dtype}") + break + + # 如果已经是fp16,则不需要转换 + if weight_dtype == torch.float16: + print("模型已经是fp16类型,无需转换") + return model + + # 转换为fp16 + fp16_model = model.half() + print("成功将模型转换为fp16") + + # 验证转换结果 + for name, param in fp16_model.named_parameters(): + if 'weight' in name: + print(f"转换后权重数据类型: {param.dtype}") + break + + return fp16_model + except Exception as e: + print(f"转换模型为fp16失败: {e}") + print("将继续使用原始模型进行量化") + return model + +# 尝试转换模型为fp16 +if args.convert_to_fp16: + original_model = convert_to_fp16(original_model) + +# 检查原始模型大小 +def check_model_size(model_path): + """检查模型文件大小""" + if os.path.exists(model_path): + size_bytes = os.path.getsize(model_path) + size_mb = size_bytes / (1024 * 1024) + return size_mb + return None + +original_model_path = os.path.join(args.model_dir, 'llm.pt') +original_size_mb = 
check_model_size(original_model_path) +if original_size_mb: + print(f"原始模型大小: {original_size_mb:.2f} MB") + +print(f"开始进行{args.bits}位量化...") + +# 定义一个更高级的量化函数 +def advanced_quantize(model, bits=8, block_size=32): + """ + 使用块量化方法对模型进行量化 + + 参数: + - model: 要量化的模型 + - bits: 量化位数 (4 或 8) + - block_size: 量化块大小 + + 返回: + - 量化后的模型 + """ + print(f"使用高级量化方法: {bits}位, 块大小={block_size}") + + if hasattr(model, 'config'): + print(f"model有config") + # 创建模型副本 + quantized_model = type(model)(model.config) if hasattr(model, 'config') else model + + # 复制模型状态 + if hasattr(model, 'state_dict'): + print(f"复制模型状态: model有state_dict") + quantized_model.load_state_dict(model.state_dict()) + + # 检查模型权重的数据类型 + weight_dtype = None + for name, param in quantized_model.named_parameters(): + if 'weight' in name: + weight_dtype = param.dtype + print(f"模型权重数据类型: {weight_dtype}") + break + + # 根据权重类型决定has_fp16_weights参数 + has_fp16_weights = weight_dtype == torch.float16 + if not has_fp16_weights: + print(f"警告: 模型权重不是float16类型,而是{weight_dtype}。将设置has_fp16_weights=False。") + print("这可能会影响量化效果,建议先将模型转换为fp16再进行量化。") + else: + print("模型权重是float16类型,将设置has_fp16_weights=True以获得最佳效果。") + + # 移除量化特定的参数 + quantized_count = 0 + # 对每个线性层进行量化 + for name, module in list(quantized_model.named_modules()): + if isinstance(module, torch.nn.Linear): + parent_name = name.rsplit('.', 1)[0] if '.' in name else '' + child_name = name.rsplit('.', 1)[1] if '.' in name else name + + if parent_name: + try: + parent = quantized_model.get_submodule(parent_name) + + # 根据位数选择量化方法 + if bits == 8: + # 8位量化 + new_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + bias=module.bias is not None, + has_fp16_weights=False, + threshold= 0.001 # 推荐的阈值 + ) + else: + # 4位量化 + new_module = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + bias=module.bias is not None, + compute_dtype=torch.float16, + ) + + # 复制权重 + with torch.no_grad(): + if hasattr(new_module, 'weight') and hasattr(module, 'weight'): + if hasattr(new_module.weight, 'copy_'): + # print("复制权重:", "name:", name, "parent_name:", parent_name, "child_name:", child_name) + new_module.weight.copy_(module.weight) + if module.bias is not None and hasattr(new_module, 'bias') and hasattr(module, 'bias'): + if hasattr(new_module.bias, 'copy_'): + new_module.bias.copy_(module.bias) + + # print("parent type:", type(parent)) + # 替换模块 + setattr(parent, child_name, new_module) + quantized_count += 1 + # print(f"成功量化模块: {name}") + except Exception as e: + print(f"无法量化模块 {name}: {e}") + + print(f"成功量化模块: {quantized_count} 个") + # 将模型移动到CUDA设备,触发实际的量化过程 + if torch.cuda.is_available(): + print("将模型移动到CUDA设备,触发实际量化...") + try: + # 记录CUDA内存使用情况 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + before_mem = torch.cuda.memory_allocated() / (1024 * 1024) + print(f"量化前CUDA内存占用: {before_mem:.2f} MB") + + quantized_model = quantized_model.to("cuda") + print("成功将模型移动到CUDA设备") + + model = None + torch.cuda.empty_cache() + + # 再次记录CUDA内存使用情况 + if torch.cuda.is_available(): + after_mem = torch.cuda.memory_allocated() / (1024 * 1024) + print(f"量化后CUDA内存占用: {after_mem:.2f} MB") + print(f"量化节省的CUDA内存: {(1 - after_mem/before_mem)*100:.2f}%") + except Exception as e: + print(f"将模型移动到CUDA设备时出错: {e}") + print("量化可能未完全生效") + else: + print("警告: 未检测到CUDA设备,量化可能不会生效") + + return quantized_model + +# 使用高级量化方法 +try: + quantized_model = advanced_quantize(original_model, bits=args.bits, block_size=args.block_size) + # quantized_model = original_model + # 清理模型中可能存在的原始权重,减小保存的模型大小 + 
print("清理模型中的原始权重以减小保存大小...") + for name, module in quantized_model.named_modules(): + if hasattr(module, 'weight_ori') and module.weight_ori is not None: + print(f"清理模块 {name} 的原始权重") + module.weight_ori = None + + # 保存量化后的模型 + quantized_model_path = os.path.join(args.output_dir, 'llm.pt') + + # 如果需要,保存完整的量化模型(包含量化参数) + if args.save_quantized: + quantized_full_path = os.path.join(args.output_dir, 'llm_quantized_full.pt') + torch.save(quantized_model.state_dict(), quantized_full_path) + print(f"完整量化模型(包含量化参数)已保存到: {quantized_full_path}") + + # 创建兼容的状态字典,移除量化特定的参数 + def create_compatible_state_dict(model): + """创建兼容的状态字典,移除量化特定的参数""" + print("创建兼容的状态字典,移除量化特定的参数...") + state_dict = model.state_dict() + compatible_state_dict = {} + + # 移除量化特定的参数 + removed_count = 0 + for key in list(state_dict.keys()): + if any(suffix in key for suffix in ['.SCB', '.weight_format', '.CB']): + # print(f"移除量化特定参数: {key}") + removed_count += 1 + continue + compatible_state_dict[key] = state_dict[key] + + print(f"总共移除了 {removed_count} 个量化特定参数") + return compatible_state_dict + + # 保存兼容的状态字典 + compatible_state_dict = create_compatible_state_dict(quantized_model) + torch.save(compatible_state_dict, quantized_model_path) + print(f"兼容的量化模型已保存到: {quantized_model_path}") + + # 检查模型大小 + model_size_mb = os.path.getsize(quantized_model_path) / (1024 * 1024) + print(f"量化后模型大小: {model_size_mb:.2f} MB") + + # 显示大小比较 + if original_size_mb: + size_ratio = model_size_mb / original_size_mb + size_reduction = (1 - size_ratio) * 100 + print(f"模型大小变化: {size_ratio:.2f}x 原始大小 (减少了 {size_reduction:.2f}%)") + if size_ratio > 1: + print("警告: 量化后的模型比原始模型更大,这可能是因为保存了额外的量化参数或原始权重。") + print("建议检查量化设置,特别是has_fp16_weights参数。") + + print("量化完成!请使用以下命令测试量化后的模型:") + print(f"python cosyvoice_2_demo.py --model_dir {args.output_dir}") + + # 添加使用提示 + if args.bits == 8: + print("\n重要提示:") + print("1. 使用8位量化模型时,请确保输入数据为float16类型") + print("2. 示例: model_input = model_input.to(torch.float16)") + print("3. 如果遇到性能问题,可能需要检查模型是否正确量化") + + if args.save_quantized: + print("\n如果要直接加载完整的量化模型,可以使用以下代码:") + print("```python") + print("import torch") + print("import bitsandbytes as bnb") + print("from transformers import AutoConfig") + print("from cosyvoice.cli.cosyvoice import CosyVoice2") + print("") + print("# 创建一个自定义加载器函数") + print("def load_quantized_model(model_dir):") + print(" # 加载配置") + print(" cosyvoice = CosyVoice2(model_dir, load_jit=False, load_trt=False, fp16=True)") + print(" # 替换线性层为量化层") + print(" for name, module in cosyvoice.model.llm.named_modules():") + print(" if isinstance(module, torch.nn.Linear):") + print(" parent_name = name.rsplit('.', 1)[0] if '.' in name else ''") + print(" child_name = name.rsplit('.', 1)[1] if '.' 
in name else name") + print(" if parent_name:") + print(" parent = cosyvoice.model.llm.get_submodule(parent_name)") + print(" # 创建8位量化层") + print(" new_module = bnb.nn.Linear8bitLt(") + print(" module.in_features,") + print(" module.out_features,") + print(" bias=module.bias is not None,") + print(" has_fp16_weights=False,") + print(" threshold=6.0") + print(" )") + print(" # 替换模块") + print(" setattr(parent, child_name, new_module)") + print(" # 加载量化模型权重") + print(f" cosyvoice.model.llm.load_state_dict(torch.load('{os.path.join(args.output_dir, 'llm_quantized_full.pt')}'))") + print(" # 移动到CUDA") + print(" cosyvoice.model.llm = cosyvoice.model.llm.to('cuda')") + print(" return cosyvoice") + print("") + print("# 使用自定义加载器加载量化模型") + print(f"cosyvoice = load_quantized_model('{args.output_dir}')") + print("```") + elif args.bits == 4: + print("\n重要提示:") + print("1. 使用4位量化模型时,请确保输入数据为float16类型") + print("2. 示例: model_input = model_input.to(torch.float16)") + print("3. 如果遇到性能问题,可能需要检查模型是否正确量化") + + # 如果保存了完整量化模型,创建加载器脚本 + if args.save_quantized: + loader_script_path = os.path.join(args.output_dir, 'load_quantized_model.py') + with open(loader_script_path, 'w', encoding='utf-8') as f: + f.write(""" +import torch +import bitsandbytes as bnb +from cosyvoice.cli.cosyvoice import CosyVoice2 +import os +import argparse + +# 添加命令行参数解析 +parser = argparse.ArgumentParser(description='加载量化的CosyVoice2模型') +parser.add_argument('--model_dir', type=str, default='""" + args.output_dir + """', + help='模型目录路径') +args = parser.parse_args() + +def load_quantized_model(model_dir): + \"\"\"加载量化的CosyVoice2模型\"\"\" + print(f"加载量化模型: {model_dir}") + + # 加载配置 + cosyvoice = CosyVoice2(model_dir, load_jit=False, load_trt=False, fp16=True) + + # 替换线性层为量化层 + print("替换线性层为量化层...") + for name, module in cosyvoice.model.llm.named_modules(): + if isinstance(module, torch.nn.Linear): + parent_name = name.rsplit('.', 1)[0] if '.' in name else '' + child_name = name.rsplit('.', 1)[1] if '.' 
in name else name + if parent_name: + try: + parent = cosyvoice.model.llm.get_submodule(parent_name) + # 创建8位量化层 + new_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + bias=module.bias is not None, + has_fp16_weights=False, + threshold=6.0 + ) + # 替换模块 + setattr(parent, child_name, new_module) + print(f"替换模块: {name}") + except Exception as e: + print(f"替换模块 {name} 失败: {e}") + + # 加载量化模型权重 + quantized_weights_path = os.path.join(model_dir, 'llm_quantized_full.pt') + print(f"加载量化权重: {quantized_weights_path}") + cosyvoice.model.llm.load_state_dict(torch.load(quantized_weights_path)) + + # 移动到CUDA + if torch.cuda.is_available(): + print("将模型移动到CUDA...") + cosyvoice.model.llm = cosyvoice.model.llm.to('cuda') + else: + print("警告: 未检测到CUDA设备") + + return cosyvoice + +# 加载量化模型 +cosyvoice = load_quantized_model(args.model_dir) + +# 测试模型 +print("\\n测试模型...") +from cosyvoice.utils.file_utils import load_wav +import torchaudio + +# 加载测试音频 +prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) + +# 测试推理 +print("执行推理...") +for i, j in enumerate(cosyvoice.inference_zero_shot('这是一个测试句子,用于验证量化模型是否正常工作。', '希望一切顺利。', prompt_speech_16k, stream=False)): + output_path = f'quantized_test_{i}.wav' + torchaudio.save(output_path, j['tts_speech'], cosyvoice.sample_rate) + print(f"已保存测试音频到: {output_path}") + +print("\\n测试完成!如果生成了音频文件,说明量化模型加载成功。") +""") + print(f"\n已创建量化模型加载器脚本: {loader_script_path}") + print(f"可以使用以下命令测试量化模型:") + print(f"python {loader_script_path}") + print(f"或者指定模型目录:") + print(f"python {loader_script_path} --model_dir 模型目录路径") + +except Exception as e: + print(f"高级量化方法失败: {e}") + print("尝试使用简单量化方法...") + + # 简单量化方法 + def simple_quantize(model, bits=8): + """简单的量化函数,将模型的权重量化为指定位数""" + print(f"使用简单量化方法: {bits}位") + + # 创建模型副本 + quantized_model = type(model)(model.config) if hasattr(model, 'config') else model + + # 复制模型状态 + if hasattr(model, 'state_dict'): + quantized_model.load_state_dict(model.state_dict()) + + # 检查模型权重的数据类型 + weight_dtype = None + for name, param in quantized_model.named_parameters(): + if 'weight' in name: + weight_dtype = param.dtype + print(f"模型权重数据类型: {weight_dtype}") + break + + # 对每个参数进行量化 + for name, param in quantized_model.named_parameters(): + if 'weight' in name and param.dim() > 1: # 只量化权重矩阵 + # 计算量化范围 + max_val = torch.max(torch.abs(param.data)) + scale = (2**(bits-1) - 1) / max_val + + # 量化 + param.data = torch.round(param.data * scale) / scale + print(f"量化参数: {name}") + + # 将模型移动到CUDA设备(如果可用) + if torch.cuda.is_available(): + print("将模型移动到CUDA设备...") + try: + quantized_model = quantized_model.to("cuda") + print("成功将模型移动到CUDA设备") + except Exception as e: + print(f"将模型移动到CUDA设备时出错: {e}") + else: + print("警告: 未检测到CUDA设备") + + # 如果保存了完整量化模型,创建加载器脚本 + if args.save_quantized: + loader_script_path = os.path.join(args.output_dir, 'load_quantized_model.py') + with open(loader_script_path, 'w', encoding='utf-8') as f: + f.write(""" +import torch +import bitsandbytes as bnb +from cosyvoice.cli.cosyvoice import CosyVoice2 +import os +import argparse + +# 添加命令行参数解析 +parser = argparse.ArgumentParser(description='加载量化的CosyVoice2模型') +parser.add_argument('--model_dir', type=str, default='""" + args.output_dir + """', + help='模型目录路径') +args = parser.parse_args() + +def load_quantized_model(model_dir): + \"\"\"加载量化的CosyVoice2模型\"\"\" + print(f"加载量化模型: {model_dir}") + + # 加载配置 + cosyvoice = CosyVoice2(model_dir, load_jit=False, load_trt=False, fp16=True) + + # 替换线性层为量化层 + print("替换线性层为量化层...") + for name, module in 
cosyvoice.model.llm.named_modules(): + if isinstance(module, torch.nn.Linear): + parent_name = name.rsplit('.', 1)[0] if '.' in name else '' + child_name = name.rsplit('.', 1)[1] if '.' in name else name + if parent_name: + try: + parent = cosyvoice.model.llm.get_submodule(parent_name) + # 创建8位量化层 + new_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + bias=module.bias is not None, + has_fp16_weights=False, + threshold=6.0 + ) + # 替换模块 + setattr(parent, child_name, new_module) + print(f"替换模块: {name}") + except Exception as e: + print(f"替换模块 {name} 失败: {e}") + + # 加载量化模型权重 + quantized_weights_path = os.path.join(model_dir, 'llm_quantized_full.pt') + print(f"加载量化权重: {quantized_weights_path}") + cosyvoice.model.llm.load_state_dict(torch.load(quantized_weights_path)) + + # 移动到CUDA + if torch.cuda.is_available(): + print("将模型移动到CUDA...") + cosyvoice.model.llm = cosyvoice.model.llm.to('cuda') + else: + print("警告: 未检测到CUDA设备") + + return cosyvoice + +# 加载量化模型 +cosyvoice = load_quantized_model(args.model_dir) + +# 测试模型 +print("\\n测试模型...") +from cosyvoice.utils.file_utils import load_wav +import torchaudio + +# 加载测试音频 +prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) + +# 测试推理 +print("执行推理...") +for i, j in enumerate(cosyvoice.inference_zero_shot('这是一个测试句子,用于验证量化模型是否正常工作。', '希望一切顺利。', prompt_speech_16k, stream=False)): + output_path = f'quantized_test_{i}.wav' + torchaudio.save(output_path, j['tts_speech'], cosyvoice.sample_rate) + print(f"已保存测试音频到: {output_path}") + +print("\\n测试完成!如果生成了音频文件,说明量化模型加载成功。") +""") + print(f"\n已创建量化模型加载器脚本: {loader_script_path}") + print(f"可以使用以下命令测试量化模型:") + print(f"python {loader_script_path}") + print(f"或者指定模型目录:") + print(f"python {loader_script_path} --model_dir 模型目录路径") + + return quantized_model + + try: + # 使用简单量化方法 + quantized_model = simple_quantize(original_model, bits=args.bits) + + # 保存量化后的模型 + quantized_model_path = os.path.join(args.output_dir, 'llm.pt') + + # 如果需要,保存完整的量化模型(包含量化参数) + if args.save_quantized: + quantized_full_path = os.path.join(args.output_dir, 'llm_quantized_full.pt') + torch.save(quantized_model.state_dict(), quantized_full_path) + print(f"完整量化模型(包含量化参数)已保存到: {quantized_full_path}") + + # 保存兼容的状态字典 + compatible_state_dict = create_compatible_state_dict(quantized_model) + torch.save(compatible_state_dict, quantized_model_path) + print(f"兼容的量化模型已保存到: {quantized_model_path}") + + # 检查模型大小 + model_size_mb = os.path.getsize(quantized_model_path) / (1024 * 1024) + print(f"量化后模型大小: {model_size_mb:.2f} MB") + + # 显示大小比较 + if original_size_mb: + size_ratio = model_size_mb / original_size_mb + size_reduction = (1 - size_ratio) * 100 + print(f"模型大小变化: {size_ratio:.2f}x 原始大小 (减少了 {size_reduction:.2f}%)") + if size_ratio > 1: + print("警告: 量化后的模型比原始模型更大,这可能是因为简单量化方法不够高效。") + print("建议尝试其他量化方法或工具。") + + print("量化完成!请使用以下命令测试量化后的模型:") + print(f"python cosyvoice_2_demo.py --model_dir {args.output_dir}") + + except Exception as e2: + print(f"简单量化方法也失败: {e2}") + print("建议尝试使用其他量化工具或手动调整模型结构") diff --git a/quant_cosyvoice_gptq_real.py b/quant_cosyvoice_gptq_real.py new file mode 100644 index 00000000..f7620ac0 --- /dev/null +++ b/quant_cosyvoice_gptq_real.py @@ -0,0 +1,167 @@ +import sys +sys.path.append('third_party/Matcha-TTS') +import torch +import os +import shutil +from cosyvoice.cli.cosyvoice import CosyVoice2 +import argparse + +# 首先检查是否安装了必要的库 +try: + import auto_gptq + from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig + from auto_gptq.utils.peft_utils import get_gptq_peft_model + 
print("成功导入auto-gptq库") +except ImportError as e: + print(f"导入auto-gptq库失败: {e}") + print("请安装auto-gptq库: pip install auto-gptq") + sys.exit(1) +except Exception as e: + print(f"auto-gptq库版本不兼容: {e}") + print("请尝试安装兼容的版本: pip install auto-gptq==0.4.2 transformers==4.30.0") + print("或者使用bitsandbytes方法: python quant_cosyvoice_bnb.py") + sys.exit(1) + +# 解析命令行参数 +parser = argparse.ArgumentParser(description='使用GPTQ量化CosyVoice模型') +parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice2-0.5B', + help='原始模型目录路径') +parser.add_argument('--output_dir', type=str, default='pretrained_models/CosyVoice2-0.5B-gptq', + help='量化后模型保存目录') +parser.add_argument('--bits', type=int, default=8, choices=[2, 3, 4, 8], + help='量化位数 (2, 3, 4, 或 8)') +parser.add_argument('--group_size', type=int, default=128, + help='量化组大小') +parser.add_argument('--desc_act', action='store_true', + help='是否使用描述激活') +args = parser.parse_args() + +# 创建输出目录 +os.makedirs(args.output_dir, exist_ok=True) + +# 复制原始模型目录中除了llm.pt之外的所有文件 +print(f"复制模型文件从 {args.model_dir} 到 {args.output_dir}") +for file_name in os.listdir(args.model_dir): + if not file_name.endswith('.pt') and not file_name.endswith(r'.backup') and not file_name.startswith(r'flow.'): + src_path = os.path.join(args.model_dir, file_name) + dst_path = os.path.join(args.output_dir, file_name) + if os.path.isfile(src_path): + shutil.copy2(src_path, dst_path) + print(f"复制文件: {src_path} -> {dst_path}") + +# 加载原始模型 +print("加载原始模型...") +cosyvoice2 = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False) + +# 提取LLM部分 +original_model = cosyvoice2.model.llm +original_model.eval() + +# 保存模型配置 +if hasattr(original_model, 'config'): + config_path = os.path.join(args.output_dir, 'config.json') + if hasattr(original_model.config, 'to_json_file'): + original_model.config.to_json_file(config_path) + print(f"保存模型配置到: {config_path}") + +# 设置GPTQ量化配置 +quantize_config = BaseQuantizeConfig( + bits=args.bits, # 量化位数 + group_size=args.group_size, # 量化组大小 + desc_act=args.desc_act, # 是否使用描述激活 +) + +# 准备校准数据 +# 这里使用一些简单的文本作为校准数据 +# 在实际应用中,应该使用更多样化的数据 +calibration_data = [ + "这是一个用于校准的示例文本,包含一些常见的中文词汇和句子结构。", + "语音合成技术可以将文本转换为自然流畅的语音,广泛应用于各种场景。", + "人工智能的发展日新月异,语音技术是其中重要的一环。", + "这是一个测试句子,用于模型量化校准。", + "欢迎使用CosyVoice语音合成系统,它可以生成自然、流畅的语音。" +] + +# 创建一个简单的数据集 +class SimpleDataset(torch.utils.data.Dataset): + def __init__(self, texts, tokenizer): + self.encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt") + + def __getitem__(self, idx): + return {key: val[idx] for key, val in self.encodings.items()} + + def __len__(self): + return len(self.encodings.input_ids) + +# 获取tokenizer +tokenizer = original_model.tokenizer if hasattr(original_model, 'tokenizer') else None + +if tokenizer is None: + print("警告: 无法获取tokenizer,将使用默认校准方法") + examples = [""] * 5 # 使用空字符串作为默认校准数据 +else: + # 创建校准数据集 + dataset = SimpleDataset(calibration_data, tokenizer) + examples = [{"input_ids": item["input_ids"]} for item in dataset] + +print(f"开始使用GPTQ进行{args.bits}位量化...") + + + + + + +# 使用GPTQ量化模型 +try: + # 对于不同的模型架构,可能需要调整这里的代码 + quantized_model = AutoGPTQForCausalLM.from_pretrained( + original_model, + quantize_config=quantize_config, + ) + + # 执行量化 + quantized_model.quantize(examples) + + # 保存量化后的模型 + quantized_model_path = os.path.join(args.output_dir, 'llm.pt') + quantized_model.save_pretrained(args.output_dir) + print(f"量化模型已保存到: {args.output_dir}") + + print("量化完成!请使用以下命令测试量化后的模型:") + print(f"python cosyvoice_2_demo.py --model_dir {args.output_dir}") + +except 
Exception as e: + print(f"量化过程中出错: {e}") + print("尝试使用替代方法...") + + + # 如果上面的方法失败,尝试使用更通用的方法 + try: + from transformers import AutoModelForCausalLM + from optimum.gptq import GPTQConfig, load_quantized_model + + print("使用optimum-gptq进行量化...") + + # 配置GPTQ + gptq_config = GPTQConfig( + bits=args.bits, + group_size=args.group_size, + desc_act=args.desc_act, + ) + + # 量化模型 + quantized_model = load_quantized_model( + original_model, + gptq_config, + calibration_data, + ) + + # 保存量化后的模型 + quantized_model.save_pretrained(args.output_dir) + print(f"量化模型已保存到: {args.output_dir}") + + except Exception as e2: + print(f"替代方法也失败: {e2}") + print("建议尝试使用其他量化工具,如bitsandbytes或llama.cpp") + + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4166dace..31b952b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ ---extra-index-url https://download.pytorch.org/whl/cu121 +--extra-index-url https://download.pytorch.org/whl/cu117 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684 conformer==0.3.2 deepspeed==0.14.2; sys_platform == 'linux' @@ -27,11 +27,11 @@ pyworld==0.3.4 rich==13.7.1 soundfile==0.12.1 tensorboard==2.14.0 -tensorrt-cu12==10.0.1; sys_platform == 'linux' -tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux' -tensorrt-cu12-libs==10.0.1; sys_platform == 'linux' -torch==2.3.1 -torchaudio==2.3.1 +tensorrt==10.0.1; sys_platform == 'linux' +tensorrt-bindings==10.0.1; sys_platform == 'linux' +tensorrt-libs==10.0.1; sys_platform == 'linux' +# torch==2.3.1 +# torchaudio==2.3.1 transformers==4.40.1 uvicorn==0.30.0 wget==3.2 diff --git a/runtime/python/Dockerfile b/runtime/python/Dockerfile index ae7e01fd..6245e6eb 100644 --- a/runtime/python/Dockerfile +++ b/runtime/python/Dockerfile @@ -1,4 +1,8 @@ -FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime +# syntax=docker/dockerfile:1 + +# FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 +FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel +# FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime ENV DEBIAN_FRONTEND=noninteractive WORKDIR /opt/CosyVoice @@ -7,7 +11,34 @@ RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list RUN apt-get update -y RUN apt-get -y install git unzip git-lfs g++ RUN git lfs install + +ENV http_proxy=socks5://183.240.180.158:10080 +ENV https_proxy=socks5://183.240.180.158:10080 +ENV all_proxy=socks5://183.240.180.158:10080 +RUN git config --global http.proxy socks5://183.240.180.158:10080 +RUN git config --global https.proxy socks5://183.240.180.158:10080 + RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git # here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed -RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com + +# 修改requirements.txt文件,将TensorRT的CUDA 12版本改为普通版本 +RUN sed -i 's/tensorrt-cu12==10.0.1/# tensorrt==9.0.0.post11.dev1/g' /opt/CosyVoice/CosyVoice/requirements.txt && \ + sed -i 's/tensorrt-cu12-bindings==10.0.1/# tensorrt-bindings==9.0.0.post11.dev1/g' /opt/CosyVoice/CosyVoice/requirements.txt && \ + sed -i 's/tensorrt-cu12-libs==10.0.1/# tensorrt-libs==9.0.0.post11.dev1/g' /opt/CosyVoice/CosyVoice/requirements.txt && \ + sed -i 's/torch==2.3.1/# torch==2.3.1/g' /opt/CosyVoice/CosyVoice/requirements.txt && \ + sed -i 's/torchaudio==2.3.1/# torchaudio==2.3.1/g' 
/opt/CosyVoice/CosyVoice/requirements.txt + +# 先升级pip +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --upgrade pip + +# 先安装onnxruntime-gpu,这是最耗时的部分 +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --verbose onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ + + +# 再安装其他依赖 +RUN --mount=type=cache,target=/root/.cache/pip \ + cd CosyVoice && pip3 install --verbose -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/ --trusted-host=pypi.tuna.tsinghua.edu.cn + RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto \ No newline at end of file diff --git a/runtime/python/fastapi/client.py b/runtime/python/fastapi/client.py index 0fb29b76..adfcf128 100644 --- a/runtime/python/fastapi/client.py +++ b/runtime/python/fastapi/client.py @@ -1,59 +1,109 @@ -# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import argparse import logging import requests import torch import torchaudio import numpy as np +import time +import os +import sys +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append('{}/../../..'.format(ROOT_DIR)) +# from stream_player import StreamPlayer +# player = StreamPlayer(sample_rate=22050, channels=1, block_size=18048) +# player.start() + +logging.basicConfig(level=logging.DEBUG) def main(): url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode) - if args.mode == 'sft': - payload = { - 'tts_text': args.tts_text, - 'spk_id': args.spk_id - } - response = requests.request("GET", url, data=payload, stream=True) - elif args.mode == 'zero_shot': - payload = { - 'tts_text': args.tts_text, - 'prompt_text': args.prompt_text - } - files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] - response = requests.request("GET", url, data=payload, files=files, stream=True) - elif args.mode == 'cross_lingual': - payload = { - 'tts_text': args.tts_text, - } - files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] - response = requests.request("GET", url, data=payload, files=files, stream=True) - else: - payload = { - 'tts_text': args.tts_text, - 'spk_id': args.spk_id, - 'instruct_text': args.instruct_text - } - response = requests.request("GET", url, data=payload, stream=True) - tts_audio = b'' - for r in response.iter_content(chunk_size=16000): - tts_audio += r - tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0) - logging.info('save response to {}'.format(args.tts_wav)) - torchaudio.save(args.tts_wav, tts_speech, target_sr) - logging.info('get response') + logging.info('请求URL: {}'.format(url)) + + time_start = time.time() + + try: + if args.mode == 'sft': + payload = { + 'tts_text': args.tts_text, + 'spk_id': args.spk_id + } + response = requests.request("GET", url, data=payload, stream=True, 
timeout=args.timeout) + elif args.mode == 'zero_shot': + payload = { + 'tts_text': args.tts_text, + 'prompt_text': args.prompt_text + } + files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] + response = requests.request("GET", url, data=payload, files=files, stream=True, timeout=args.timeout) + elif args.mode == 'cross_lingual': + payload = { + 'tts_text': args.tts_text, + } + files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] + response = requests.request("GET", url, data=payload, files=files, stream=True, timeout=args.timeout) + elif args.mode == 'instruct2': + payload = { + 'tts_text': args.tts_text, + 'instruct_text': args.instruct_text, + 'spk_id': args.spk_id + } + response = requests.request("GET", url, data=payload, stream=True, timeout=args.timeout) + else: + payload = { + 'tts_text': args.tts_text, + 'spk_id': args.spk_id, + 'instruct_text': args.instruct_text + } + response = requests.request("GET", url, data=payload, stream=True, timeout=args.timeout) + + # 确保响应状态码正确 + response.raise_for_status() + + # 接收并处理音频数据 + tts_audio = b'' + chunk_count = 0 + last_log_time = time.time() + + # 调整每次接收的块大小,建议设置为较大值以减少网络往返次数 + # 但不要太大,否则会增加首次播放延迟 + for r in response.iter_content(chunk_size=64000): + if r: # 过滤掉保持连接活跃的空块 + now = time.time() + chunk_count += 1 + tts_audio += r + + # # 播放音频 + # player.play(r) + + # 避免日志过于频繁 + if now - last_log_time > 0.5: + logging.debug(f"接收到第{chunk_count}块音频数据,大小: {len(r)} 字节,已接收总量: {len(tts_audio)}") + last_log_time = now + + # 记录最终接收到的数据量 + logging.info(f"接收完成,共接收{chunk_count}块数据,总大小: {len(tts_audio)} 字节") + + if len(tts_audio) == 0: + logging.error("未接收到任何音频数据!") + return + + # 将接收到的字节数据转换为PyTorch张量 + tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0) + time_end = time.time() + logging.info('处理时间: {:.2f}秒'.format(time_end - time_start)) + logging.info('保存音频到: {}'.format(args.tts_wav)) + torchaudio.save(args.tts_wav, tts_speech, target_sr) + logging.info('音频合成完成') + + except requests.exceptions.Timeout: + logging.error(f"请求超时!请尝试增加超时时间(当前: {args.timeout}秒)") + except requests.exceptions.ConnectionError: + logging.error("连接服务器失败!请检查服务器是否正在运行") + except requests.exceptions.HTTPError as e: + logging.error(f"HTTP错误: {e}") + except Exception as e: + logging.error(f"发生未知错误: {e}") if __name__ == "__main__": @@ -66,8 +116,8 @@ def main(): default='50000') parser.add_argument('--mode', default='sft', - choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'], - help='request mode') + choices=['sft', 'zero_shot', 'cross_lingual', 'instruct', 'instruct2'], + help='请求模式') parser.add_argument('--tts_text', type=str, default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?') @@ -87,6 +137,10 @@ def main(): parser.add_argument('--tts_wav', type=str, default='demo.wav') + parser.add_argument('--timeout', + type=int, + default=300, + help='请求超时时间(秒)') args = parser.parse_args() prompt_sr, target_sr = 16000, 22050 main() diff --git a/runtime/python/fastapi/readme.txt b/runtime/python/fastapi/readme.txt new file mode 100644 index 00000000..9e15daac --- /dev/null +++ b/runtime/python/fastapi/readme.txt @@ -0,0 +1,8 @@ +pip install soundfile +pip install torchaudio +pip install torch + +python client.py --tts_text " Yes! Can the lenses show who shares cookies best? Brilliant! The lenses will glow golden around generous friends like when Hagrid shares rock cakes! 
I will polish imaginary lenses" --spk_id "hp" + + +spk_id的选项有 laoxu(老许) hp (哈利波特) diff --git a/runtime/python/fastapi/server.py b/runtime/python/fastapi/server.py index 74c62d81..5f525835 100644 --- a/runtime/python/fastapi/server.py +++ b/runtime/python/fastapi/server.py @@ -1,34 +1,33 @@ -# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import os +os.environ["CUDA_VISIBLE_DEVICES"] = "7" import sys import argparse import logging -logging.getLogger('matplotlib').setLevel(logging.WARNING) +import torch +# logging.getLogger('matplotlib').setLevel(logging.WARNING) +logging.getLogger().setLevel(logging.DEBUG) from fastapi import FastAPI, UploadFile, Form, File from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware import uvicorn import numpy as np +import io +from typing import Iterator, Any, List ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../../..'.format(ROOT_DIR)) sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 from cosyvoice.utils.file_utils import load_wav +CosyVoice2_path = '{}/../../..'.format(ROOT_DIR) + app = FastAPI() -# set cross region allowance +# 设置跨域资源共享(CORS) +# 这段代码允许不同域名的前端应用访问这个API服务器 +# 如果没有这个设置,当前端应用(如网页)的域名与API服务器域名不同时, +# 浏览器会因安全限制阻止前端访问API,导致用户无法使用语音合成功能 +# 例如:如果网页在example.com上,而API在api.example.com上,没有CORS设置会导致请求失败 +# "*"表示允许任何网站访问,在生产环境中可能需要限制为特定域名以提高安全性 app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -36,18 +35,82 @@ allow_methods=["*"], allow_headers=["*"]) +def process_audio_chunk(audio_data) -> bytes: + """将模型输出的音频数据转换为字节格式""" + try: + # 当前使用标准16位PCM格式以确保兼容性 + audio_bytes = (audio_data.numpy() * (2 ** 15)).astype(np.int16).tobytes() + return audio_bytes + except Exception as e: + logging.error(f"处理音频数据时出错: {e}") + return b'' -def generate_data(model_output): - for i in model_output: - tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() - yield tts_audio +def generate_data(model_output) -> Iterator[bytes]: + """ + 处理模型输出的语音数据,将其转换为适合流式传输的格式 + + 关于数据格式选择的说明: + - 确实可以使用float16或float32格式,这样处理流程会更简洁 + - float16相比float32可以节省一半带宽,同时保持足够的精度 + - 音频数据在[-1,1]范围内的浮点数确实可以被现代音频库直接处理 + - 但需要确保客户端能正确解析这种格式的数据 + + 当前我们使用int16是因为: + 1. 广泛兼容性:几乎所有音频系统都支持16位PCM + 2. 
客户端代码:当前客户端已配置为接收并解析int16格式 + + 如果整个系统都支持浮点音频,可以考虑简化为float16格式 + """ + try: + # 检查model_output是否为可迭代对象 + if hasattr(model_output, '__iter__'): + for i, item in enumerate(model_output): + if 'tts_speech' in item: + audio_bytes = process_audio_chunk(item['tts_speech']) + logging.debug(f"发送音频数据片段[{i}],大小: {len(audio_bytes)} 字节") + yield audio_bytes + else: + # 处理单个输出的情况 + if hasattr(model_output, 'tts_speech'): + audio_bytes = process_audio_chunk(model_output.tts_speech) + logging.debug(f"发送单个音频数据片段,大小: {len(audio_bytes)} 字节") + yield audio_bytes + except Exception as e: + logging.error(f"生成数据时发生错误: {e}") + raise @app.get("/inference_sft") @app.post("/inference_sft") async def inference_sft(tts_text: str = Form(), spk_id: str = Form()): - model_output = cosyvoice.inference_sft(tts_text, spk_id) - return StreamingResponse(generate_data(model_output)) + # 使用流式模式生成语音,inference_sft本身会返回一个生成器 + try: + logging.info(f"开始语音合成: '{tts_text[:30]}...'") + + # 获取所有生成的音频数据 + all_outputs = [] + + # 对生成的每个片段进行处理 + async def stream_generator(): + for chunk in cosyvoice.inference_sft(tts_text, spk_id, stream=True): + # 处理单个音频数据块 + audio_bytes = process_audio_chunk(chunk['tts_speech']) + logging.debug(f"流式发送音频数据片段,大小: {len(audio_bytes)} 字节") + yield audio_bytes + + # 返回流式响应 + return StreamingResponse( + stream_generator(), + media_type="audio/wave", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked" + } + ) + except Exception as e: + logging.error(f"inference_sft处理失败: {e}") + raise @app.get("/inference_zero_shot") @@ -62,23 +125,66 @@ async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), @app.post("/inference_cross_lingual") async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()): prompt_speech_16k = load_wav(prompt_wav.file, 16000) - model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k) + model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, prompt_speech_16k) return StreamingResponse(generate_data(model_output)) @app.get("/inference_instruct") @app.post("/inference_instruct") async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()): - model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text) - return StreamingResponse(generate_data(model_output)) + try: + model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text) + return StreamingResponse( + generate_data(model_output), + media_type="audio/wav", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} + ) + except Exception as e: + logging.error(f"inference_instruct处理失败: {e}") + raise + + +def load_voice_data(speaker): + """ + 加载自定义语音数据 + + 参数: + speaker: 说话人ID/名称 + + 返回: + 加载的语音参考数据,如果加载失败则返回None + """ + + voice_path = f"{CosyVoice2_path}/voices/{speaker}.pt" + try: + # 检测是否有GPU可用 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if not os.path.exists(voice_path): + return None + # 加载语音模型数据 + voice_data = torch.load(voice_path, map_location=device) + return voice_data.get('audio_ref') + except Exception as e: + raise ValueError(f"加载音色文件失败: {e}") @app.get("/inference_instruct2") @app.post("/inference_instruct2") -async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()): - prompt_speech_16k = load_wav(prompt_wav.file, 16000) - model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k) - return 
StreamingResponse(generate_data(model_output)) +# async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File(), spk_id: str = Form()): +async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), spk_id: str = Form()): + try: + prompt_speech_16k = load_voice_data(spk_id) + # else: + # prompt_speech_16k = load_wav(prompt_wav.file, 16000) + model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k) + return StreamingResponse( + generate_data(model_output), + media_type="audio/wav", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} + ) + except Exception as e: + logging.error(f"inference_instruct2处理失败: {e}") + raise if __name__ == '__main__': @@ -88,14 +194,32 @@ async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form( default=50000) parser.add_argument('--model_dir', type=str, - default='iic/CosyVoice-300M', + default= CosyVoice2_path+'/pretrained_models/CosyVoice2-0.5B', help='local path or modelscope repo id') + parser.add_argument('--timeout', + type=int, + default=120, + help='服务器请求超时时间(秒)') args = parser.parse_args() try: - cosyvoice = CosyVoice(args.model_dir) + # cosyvoice = CosyVoice(args.model_dir) + print(f"使用模型目录: {args.model_dir}") + cosyvoice = CosyVoice2( + args.model_dir, + load_jit=False, + load_trt=True, + fp16=True, + use_flow_cache=True, + ) except Exception: - try: - cosyvoice = CosyVoice2(args.model_dir) - except Exception: - raise TypeError('no valid model_type!') - uvicorn.run(app, host="0.0.0.0", port=args.port) + raise TypeError('no valid model_type!') + + # 设置更长的超时时间,确保长文本语音合成不会中断 + uvicorn.run( + app, + host="0.0.0.0", + port=args.port, + timeout_keep_alive=args.timeout, + timeout_graceful_shutdown=args.timeout, + limit_concurrency=10 + ) diff --git a/stream_player.py b/stream_player.py new file mode 100644 index 00000000..25f820dd --- /dev/null +++ b/stream_player.py @@ -0,0 +1,233 @@ +import sounddevice as sd +import numpy as np +import threading +import time +import logging +import torch + + +# class StreamPlayer: +# def __init__(self, sample_rate=22050, channels=1, block_size=8192): +# self.sample_rate = sample_rate +# self.channels = channels +# self.block_size = block_size +# self.audio_queue = queue.Queue() +# self.playing = False +# self.play_thread = None + +# def start(self): +# """启动播放线程""" +# self.playing = True +# self.play_thread = threading.Thread(target=self._play_loop) +# # 将线程设置为守护线程,这样当主程序退出时,该线程会自动终止 +# # 避免因为播放线程未结束而导致程序无法正常退出 +# # 如果不设置daemon=True,则主程序结束时会等待该线程完成 +# self.play_thread.daemon = True +# self.play_thread.start() + +# def _play_loop(self): +# """播放线程循环函数""" +# while self.playing: +# try: +# # 从队列获取音频数据 +# audio_data = self.audio_queue.get(timeout=0.2) + +# # 如果获取到有效数据则播放 +# if audio_data is not None and len(audio_data) > 0: +# try: +# # 确保数据格式正确 +# sd.play(audio_data, self.sample_rate, blocksize=self.block_size) +# # 等待播放完成 +# sd.wait() +# except sd.PortAudioError as e: +# print(f"音频播放错误: {e}") +# # 短暂暂停后继续 +# time.sleep(0.5) + +# self.audio_queue.task_done() +# except queue.Empty: +# # 队列为空时短暂休眠 +# time.sleep(0.01) + +# def play(self, audio_data): +# """将音频数据添加到播放队列""" +# self.audio_queue.put(audio_data) + +# def stop(self): +# """停止播放并清理资源""" +# self.playing = False +# if self.play_thread and self.play_thread.is_alive(): +# self.play_thread.join(timeout=1.0) + +# # 清空队列 +# try: +# while True: +# self.audio_queue.get_nowait() +# self.audio_queue.task_done() +# except queue.Empty: +# 
pass + +# 实现自定义的音频播放器,使用sounddevice的回调机制 +class StreamPlayer: + def __init__(self, sample_rate=22050, channels=1, block_size=4096, latency='high', max_buffer_size=3000000): + """ + 使用连续缓冲区的音频播放器 + + 参数: + sample_rate: 采样率 + channels: 通道数 + block_size: 音频处理块大小 + latency: 延迟设置 ('low', 'high', 'medium') + max_buffer_size: 缓冲区最大样本数,超过会截断 + """ + self.sample_rate = sample_rate + self.channels = channels + self.block_size = block_size + self.latency = latency + self.max_buffer_size = max_buffer_size + + # 音频缓冲区,所有新音频都拼接到这里 + self.buffer = np.zeros((0,), dtype=np.float32) + self.buffer_lock = threading.Lock() # 用于同步访问缓冲区 + + # 当前播放位置 + self.position = 0 + self.stream = None + self.is_playing = False + + def _audio_callback(self, outdata, frames, time, status): + """ + sounddevice回调函数,从连续缓冲区读取数据 + """ + if status: + logging.warning(f"流状态: {status}") + + with self.buffer_lock: + # 计算可用的音频数据量 (缓冲区长度 - 当前位置) + available = len(self.buffer) - self.position + + if available <= 0: + # 缓冲区中没有可用数据,播放静音 + outdata.fill(0) + return + + # 确定要播放的样本数 + play_length = min(len(outdata), available) + + # 复制数据到输出缓冲区 + outdata[:play_length] = self.buffer[self.position:self.position+play_length].reshape(-1, 1) + + # 如果没有足够数据填满输出缓冲区,剩余部分填充静音 + if play_length < len(outdata): + outdata[play_length:].fill(0) + logging.info(f"缓冲区数据不足,部分输出静音 ({play_length}/{len(outdata)})") + + # 更新位置 + self.position += play_length + + # 如果位置超过了设定的阈值,裁剪缓冲区 + if self.position > self.max_buffer_size // 2: + # 保留后半部分缓冲区 + self.buffer = self.buffer[self.position - self.block_size:] + # 重置位置,保留一个块的余量防止播放断裂 + self.position = self.block_size + logging.debug(f"缓冲区已裁剪,新长度: {len(self.buffer)}") + + def start(self): + """启动音频流""" + if self.stream is None or not self.stream.active: + try: + self.stream = sd.OutputStream( + samplerate=self.sample_rate, + channels=self.channels, + callback=self._audio_callback, + blocksize=self.block_size, + latency=self.latency + ) + self.stream.start() + self.is_playing = True + logging.debug("音频播放流已启动") + except sd.PortAudioError as e: + logging.error(f"启动音频流失败: {e}") + raise + + def play(self, audio_data): + """ + 添加音频数据到连续缓冲区 + """ + if isinstance(audio_data, bytes): + # 处理从网络接收的int16字节流数据 + audio_data = np.frombuffer(audio_data, dtype=np.int16) + # 将int16数据转换回[-1,1]范围的float32 + audio_data = audio_data.astype(np.float32) / (2 ** 15) + elif not isinstance(audio_data, np.ndarray): + audio_data = np.array(audio_data, dtype=np.float32) + + # 保证数据是一维的 + if len(audio_data.shape) > 1: + audio_data = audio_data.flatten() + + with self.buffer_lock: + # 将新数据附加到缓冲区 + self.buffer = np.concatenate((self.buffer, audio_data)) + + # 如果缓冲区超过最大大小,则裁剪 + if len(self.buffer) > self.max_buffer_size: + # 保留后半部分,确保当前播放位置之后的数据不会丢失 + keep_from = max(0, self.position - self.block_size) + self.buffer = self.buffer[keep_from:] + self.position -= keep_from + logging.debug(f"缓冲区已裁剪,新长度: {len(self.buffer)}") + + # logging.debug(f"缓冲区状态: {self.get_buffer_status()}") + # logging.debug(f"添加了 {len(audio_data)} 个样本,当前缓冲区大小: {len(self.buffer)}, 当前位置: {self.position}") + + def stop(self): + """停止音频流并清理资源""" + if self.stream is not None and self.stream.active: + self.stream.stop() + self.stream.close() + self.stream = None + self.is_playing = False + + with self.buffer_lock: + # 清空缓冲区 + self.buffer = np.zeros((0,), dtype=np.float32) + self.position = 0 + + logging.debug("音频播放流已停止,缓冲区已清空") + + def is_empty(self): + """检查缓冲区是否为空""" + with self.buffer_lock: + return len(self.buffer) <= self.position + + def get_buffer_status(self): + """获取缓冲区状态""" + with 
self.buffer_lock: + total = len(self.buffer) + available = max(0, total - self.position) + return { + "total_size": total, + "position": self.position, + "available": available, + "buffer_seconds": available / self.sample_rate if self.sample_rate > 0 else 0 + } + + def start_with_prebuffer(self, min_buffer_samples=16384): + """启动音频流,但先等待缓冲区达到最小大小""" + # 检查缓冲区大小 + prebuffer_wait_start = time.time() + max_wait_time = 3.0 # 最多等待3秒 + + while len(self.buffer) < min_buffer_samples: + # 检查是否超时 + if time.time() - prebuffer_wait_start > max_wait_time: + logging.warning(f"预缓冲超时,当前缓冲区大小: {len(self.buffer)}") + break + + time.sleep(0.05) # 短暂休眠,等待缓冲区填充 + logging.debug(f"等待预缓冲,当前大小: {len(self.buffer)}/{min_buffer_samples}") + + # 正常启动流 + self.start() diff --git a/webui.py b/webui.py index 3552cd92..323ff61d 100644 --- a/webui.py +++ b/webui.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + # 必须在导入torch和创建任何模型之前设置CUDA_VISIBLE_DEVICES import os +os.environ["CUDA_VISIBLE_DEVICES"] = "5" import sys +import platform import argparse import gradio as gr import numpy as np @@ -20,181 +23,510 @@ import torchaudio import random import librosa +from funasr import AutoModel +from funasr.utils.postprocess_utils import rich_transcription_postprocess +from cosyvoice.utils.file_utils import load_wav, logging +import shutil +import time + +# # 设置日志级别为 DEBUG +# logging.basicConfig(level=logging.DEBUG, +# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +# logging.getLogger().setLevel(logging.DEBUG) + +# # 确保设置影响所有模块 +# for name in logging.root.manager.loggerDict: +# logging.getLogger(name).setLevel(logging.DEBUG) + + + +# 设置环境变量禁用tokenizers并行处理 +os.environ["TOKENIZERS_PARALLELISM"] = "false" + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR)) from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 -from cosyvoice.utils.file_utils import load_wav, logging from cosyvoice.utils.common import set_all_random_seed -inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制'] + + + +# from modelscope import snapshot_download +# snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B') +try: + shutil.copy2('spk2info.pt', 'pretrained_models/CosyVoice2-0.5B/spk2info.pt') +except Exception as e: + logging.warning(f'复制文件失败: {e}') + +inference_mode_list = ['预训练音色', '自然语言控制', '3s极速复刻', '跨语种复刻'] instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮', - '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮', - '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮', + '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮\n4. (可选)保存音色模型', + '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮\n3. (可选)保存音色模型', '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 
点击生成音频按钮'} stream_mode_list = [('否', False), ('是', True)] max_val = 0.8 +def refresh_sft_spk(): + """刷新音色选择列表 """ + # 获取自定义音色 + files = [(entry.name, entry.stat().st_mtime) for entry in os.scandir(f"{ROOT_DIR}/voices")] + files.sort(key=lambda x: x[1], reverse=True) # 按时间排序 + + # 添加预训练音色 + choices = [f[0].replace(".pt", "") for f in files] + cosyvoice.list_available_spks() -def generate_seed(): - seed = random.randint(1, 100000000) + if not choices: + choices = [''] + + return {"choices": choices, "__type__": "update"} + + +def refresh_prompt_wav(): + """刷新音频选择列表""" + files = [(entry.name, entry.stat().st_mtime) for entry in os.scandir(f"{ROOT_DIR}/audios")] + files.sort(key=lambda x: x[1], reverse=True) # 按时间排序 + choices = ["请选择参考音频或者自己上传"] + [f[0] for f in files] + + if not choices: + choices = [''] + + return {"choices": choices, "__type__": "update"} + + +def change_prompt_wav(filename): + """切换音频文件""" + full_path = f"{ROOT_DIR}/audios/{filename}" + if not os.path.exists(full_path): + logging.warning(f"音频文件不存在: {full_path}") + return None + + return full_path + +def save_voice_model(voice_name): + """保存音色模型""" + if not voice_name: + gr.Info("音色名称不能为空") + return False + + try: + shutil.copyfile(f"{ROOT_DIR}/output.pt", f"{ROOT_DIR}/voices/{voice_name}.pt") + gr.Info("音色保存成功,存放位置为voices目录") + return True + except Exception as e: + logging.error(f"保存音色失败: {e}") + gr.Warning("保存音色失败") + return False + +def generate_random_seed(): + """生成随机种子""" return { "__type__": "update", - "value": seed + "value": random.randint(1, 100000000) } - -def postprocess(speech, top_db=60, hop_length=220, win_length=440): +def postprocess(speech, top_db = 60, hop_length = 220, win_length = 440): + """音频后处理方法""" + # 修剪静音部分 speech, _ = librosa.effects.trim( - speech, top_db=top_db, + speech, + top_db=top_db, frame_length=win_length, hop_length=hop_length ) + + # 音量归一化 if speech.abs().max() > max_val: speech = speech / speech.abs().max() * max_val + + # 添加尾部静音 speech = torch.concat([speech, torch.zeros(1, int(cosyvoice.sample_rate * 0.2))], dim=1) return speech - def change_instruction(mode_checkbox_group): - return instruct_dict[mode_checkbox_group] + """切换模式的处理""" + voice_dropdown_visible = mode_checkbox_group in ['预训练音色', '自然语言控制'] + save_btn_visible = mode_checkbox_group in ['3s极速复刻'] + return ( + instruct_dict[mode_checkbox_group], + gr.update(visible=voice_dropdown_visible), + gr.update(visible=save_btn_visible) + ) +def prompt_wav_recognition(prompt_wav): + """音频识别文本""" + if prompt_wav is None: + return '' + + try: + res = asr_model.generate( + input=prompt_wav, + language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" + use_itn=True, + ) + return res[0]["text"].split('|>')[-1] + except Exception as e: + logging.error(f"音频识别文本失败: {e}") + gr.Warning("识别文本失败,请检查音频是否包含人声内容") + return '' + +def load_voice_data(voice_path): + """加载音色数据""" + try: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + voice_data = torch.load(voice_path, map_location=device) if os.path.exists(voice_path) else None + return voice_data.get('audio_ref') if voice_data else None + except Exception as e: + logging.error(f"加载音色文件失败: {e}") + return None + +def validate_input(mode, tts_text, sft_dropdown, prompt_text, prompt_wav, instruct_text): + """验证输入参数的合法性 + + Args: + mode: 推理模式 + tts_text: 合成文本 + sft_dropdown: 预训练音色 + prompt_text: prompt文本 + prompt_wav: prompt音频 + instruct_text: instruct文本 + + Returns: + bool: 验证是否通过 + str: 错误信息 + """ + if mode in ['自然语言控制']: + if not cosyvoice.is_05b and cosyvoice.instruct is 
False: + return False, f'您正在使用自然语言控制模式, {args.model_dir}模型不支持此模式' + if not instruct_text: + return False, '您正在使用自然语言控制模式, 请输入instruct文本' + + elif mode in ['跨语种复刻']: + if not cosyvoice.is_05b and cosyvoice.instruct is True: + return False, f'您正在使用跨语种复刻模式, {args.model_dir}模型不支持此模式' + if not prompt_wav: + return False, '您正在使用跨语种复刻模式, 请提供prompt音频' + + elif mode in ['3s极速复刻', '跨语种复刻']: + if not prompt_wav: + return False, 'prompt音频为空,您是否忘记输入prompt音频?' + if torchaudio.info(prompt_wav).sample_rate < prompt_sr: + return False, f'prompt音频采样率{torchaudio.info(prompt_wav).sample_rate}低于{prompt_sr}' + + elif mode in ['预训练音色']: + if not sft_dropdown: + return False, '没有可用的预训练音色!' + + if mode in ['3s极速复刻'] and not prompt_text: + return False, 'prompt文本为空,您是否忘记输入prompt文本?' + + return True, '' + +def process_audio(speech_generator, stream): + """处理音频生成 + + Args: + speech_generator: 音频生成器 + stream: 是否流式处理 + + Returns: + tuple: (音频数据列表, 总时长) + """ + tts_speeches = [] + total_duration = 0 + for i in speech_generator: + tts_speeches.append(i['tts_speech']) + total_duration += i['tts_speech'].shape[1] / cosyvoice.sample_rate + if stream: + yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()), None + + if not stream: + audio_data = torch.concat(tts_speeches, dim=1) + yield None, (cosyvoice.sample_rate, audio_data.numpy().flatten()) + + yield total_duration def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed): - if prompt_wav_upload is not None: - prompt_wav = prompt_wav_upload - elif prompt_wav_record is not None: - prompt_wav = prompt_wav_record - else: - prompt_wav = None - # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode - if mode_checkbox_group in ['自然语言控制']: - if cosyvoice.instruct is False: - gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir)) - yield (cosyvoice.sample_rate, default_data) - if instruct_text == '': - gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本') - yield (cosyvoice.sample_rate, default_data) - if prompt_wav is not None or prompt_text != '': - gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略') - # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language - if mode_checkbox_group in ['跨语种复刻']: - if cosyvoice.instruct is True: - gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir)) - yield (cosyvoice.sample_rate, default_data) - if instruct_text != '': - gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略') - if prompt_wav is None: - gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频') - yield (cosyvoice.sample_rate, default_data) - gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言') - # if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements - if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']: - if prompt_wav is None: - gr.Warning('prompt音频为空,您是否忘记输入prompt音频?') - yield (cosyvoice.sample_rate, default_data) - if torchaudio.info(prompt_wav).sample_rate < prompt_sr: - gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr)) - yield (cosyvoice.sample_rate, default_data) - # sft mode only use sft_dropdown - if mode_checkbox_group in ['预训练音色']: - if instruct_text != '' or prompt_wav is not None or prompt_text != '': - gr.Info('您正在使用预训练音色模式,prompt文本/prompt音频/instruct文本会被忽略!') - if sft_dropdown == '': - gr.Warning('没有可用的预训练音色!') - yield (cosyvoice.sample_rate, 
default_data) - # zero_shot mode only use prompt_wav prompt text - if mode_checkbox_group in ['3s极速复刻']: - if prompt_text == '': - gr.Warning('prompt文本为空,您是否忘记输入prompt文本?') - yield (cosyvoice.sample_rate, default_data) - if instruct_text != '': - gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!') + """生成音频的主函数 + + Args: + tts_text: 合成文本 + mode_checkbox_group: 推理模式 + sft_dropdown: 预训练音色 + prompt_text: prompt文本 + prompt_wav_upload: 上传的prompt音频 + prompt_wav_record: 录制的prompt音频 + instruct_text: instruct文本 + seed: 随机种子 + stream: 是否流式推理 + speed: 语速 + + Yields: + tuple: 音频数据 + """ + start_time = time.time() + logging.info(f"开始生成音频 - 模式: {mode_checkbox_group}, 文本长度: {len(tts_text)}") + # 处理prompt音频输入 + prompt_wav = prompt_wav_upload if prompt_wav_upload is not None else prompt_wav_record + + # 验证输入 + is_valid, error_msg = validate_input(mode_checkbox_group, tts_text, sft_dropdown, + prompt_text, prompt_wav, instruct_text) + if not is_valid: + gr.Warning(error_msg) + yield (cosyvoice.sample_rate, default_data), None + return + + # 设置随机种子 + set_all_random_seed(seed) + # 根据不同模式处理 if mode_checkbox_group == '预训练音色': - logging.info('get sft inference request') - set_all_random_seed(seed) - for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed): - yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) - elif mode_checkbox_group == '3s极速复刻': - logging.info('get zero_shot inference request') + # logging.info('get sft inference request') + generator = cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed) + + elif mode_checkbox_group in ['3s极速复刻', '跨语种复刻']: + # logging.info(f'get {mode_checkbox_group} inference request') prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr)) - set_all_random_seed(seed) - for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed): - yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) - elif mode_checkbox_group == '跨语种复刻': - logging.info('get cross_lingual inference request') - prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr)) - set_all_random_seed(seed) - for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed): - yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) - else: - logging.info('get instruct inference request') - set_all_random_seed(seed) - for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed): - yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) + inference_func = (cosyvoice.inference_zero_shot if mode_checkbox_group == '3s极速复刻' + else cosyvoice.inference_cross_lingual) + generator = inference_func(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed) + + else: # 自然语言控制模式 + # logging.info('get instruct inference request') + voice_path = f"{ROOT_DIR}/voices/{sft_dropdown}.pt" + prompt_speech_16k = load_voice_data(voice_path) + + if prompt_speech_16k is None: + gr.Warning('预训练音色文件中缺少prompt_speech数据!') + yield (cosyvoice.sample_rate, default_data), None + return + + generator = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, + stream=stream, speed=speed) + + # 处理音频生成并获取总时长 + audio_generator = process_audio(generator, stream) + total_duration = 0 + + # 收集所有音频输出 + for output in audio_generator: + if isinstance(output, (float, int)): # 如果是总时长 + total_duration = output + else: # 如果是音频数据 + yield output + processing_time = time.time() - start_time + rtf = processing_time 
/ total_duration if total_duration > 0 else 0 + logging.info(f"音频生成完成 耗时: {processing_time:.2f}秒, rtf: {rtf:.2f}") + +def update_audio_visibility(stream_enabled): + """更新音频组件的可见性""" + return [ + gr.update(visible=stream_enabled), # 流式音频组件 + gr.update(visible=not stream_enabled) # 非流式音频组件 + ] def main(): with gr.Blocks() as demo: - gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \ - 预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \ + # 页面标题和说明 + gr.Markdown("### 代码库 [CosyVoice2-Ex](https://github.com/journey-ad/CosyVoice2-Ex) 原始项目 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \ + 预训练模型 [CosyVoice2-0.5B](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B) \ + [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \ [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \ [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)") gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作") - tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。") + # 主要输入区域 + tts_text = gr.Textbox( + label="输入合成文本", + lines=1, + value="CosyVoice迎来全面升级,提供更准、更稳、更快、 更好的语音生成能力。CosyVoice is undergoing a comprehensive upgrade, providing more accurate, stable, faster, and better voice generation capabilities." + ) + with gr.Row(): - mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0]) - instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5) - sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25) - stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1]) - speed = gr.Number(value=1, label="速度调节(仅支持非流式推理)", minimum=0.5, maximum=2.0, step=0.1) + mode_checkbox_group = gr.Radio( + choices=inference_mode_list, + label='选择推理模式', + value=inference_mode_list[0] + ) + instruction_text = gr.Text( + label="操作步骤", + value=instruct_dict[inference_mode_list[0]], + scale=0.5 + ) + + # 音色选择部分 + sft_dropdown = gr.Dropdown( + choices=sft_spk, + label='选择预训练音色', + value=sft_spk[0], + scale=0.25 + ) + refresh_voice_button = gr.Button("刷新音色") + + # 流式控制和速度调节 + with gr.Column(scale=0.25): + stream = gr.Radio( + choices=stream_mode_list, + label='是否流式推理', + value=stream_mode_list[0][1] + ) + speed = gr.Number( + value=1, + label="速度调节(仅支持非流式推理)", + minimum=0.5, + maximum=2.0, + step=0.1 + ) + + # 随机种子控制 with gr.Column(scale=0.25): seed_button = gr.Button(value="\U0001F3B2") seed = gr.Number(value=0, label="随机推理种子") + # 音频输入区域 with gr.Row(): - prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件,注意采样率不低于16khz') - prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件') - prompt_text = gr.Textbox(label="输入prompt文本", lines=1, placeholder="请输入prompt文本,需与prompt音频内容一致,暂时不支持自动识别...", value='') - instruct_text = gr.Textbox(label="输入instruct文本", lines=1, placeholder="请输入instruct文本.", value='') + prompt_wav_upload = gr.Audio( + sources='upload', + type='filepath', + label='选择prompt音频文件,注意采样率不低于16khz' + ) + prompt_wav_record = gr.Audio( + sources='microphone', + type='filepath', + label='录制prompt音频文件' + ) + wavs_dropdown = gr.Dropdown( + label="参考音频列表", + choices=reference_wavs, + value="请选择参考音频或者自己上传", + interactive=True + ) + refresh_button = gr.Button("刷新参考音频") + # 文本输入区域 + prompt_text = gr.Textbox( + label="输入prompt文本", + lines=1, + 
placeholder="请输入prompt文本,支持自动识别,您可以自行修正识别结果...", + value='' + ) + instruct_text = gr.Textbox( + label="输入instruct文本", + lines=1, + placeholder="请输入instruct文本. 例如: 用四川话说这句话。", + value='' + ) + + # 保存音色按钮(默认隐藏) + with gr.Row(visible=False) as save_spk_btn: + new_name = gr.Textbox(label="输入新的音色名称", lines=1, placeholder="输入新的音色名称.", value='', scale=2) + save_button = gr.Button(value="保存音色模型", scale=1) + + # 生成按钮 generate_button = gr.Button("生成音频") - audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True) + # 音频输出区域 + with gr.Group() as audio_group: + audio_output_stream = gr.Audio( + label="合成音频(流式)", + value=None, + streaming=True, + autoplay=True, + show_label=True, + show_download_button=True, + visible=False + ) + audio_output_normal = gr.Audio( + label="合成音频", + value=None, + streaming=False, + autoplay=True, + show_label=True, + show_download_button=True, + visible=True + ) + + # 绑定事件 + refresh_voice_button.click(fn=refresh_sft_spk, inputs=[], outputs=[sft_dropdown]) + refresh_button.click(fn=refresh_prompt_wav, inputs=[], outputs=[wavs_dropdown]) + wavs_dropdown.change(change_prompt_wav, inputs=[wavs_dropdown], outputs=[prompt_wav_upload]) + save_button.click(save_voice_model, inputs=[new_name]) + seed_button.click(generate_random_seed, inputs=[], outputs=[seed]) + + generate_button.click( + generate_audio, + inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, + prompt_wav_upload, prompt_wav_record, instruct_text, + seed, stream, speed], + outputs=[audio_output_stream, audio_output_normal] + ) + + mode_checkbox_group.change( + fn=change_instruction, + inputs=[mode_checkbox_group], + outputs=[instruction_text, sft_dropdown, save_spk_btn] + ) + + prompt_wav_upload.change(fn=prompt_wav_recognition, inputs=[prompt_wav_upload], outputs=[prompt_text]) + prompt_wav_record.change(fn=prompt_wav_recognition, inputs=[prompt_wav_record], outputs=[prompt_text]) + + stream.change( + fn=update_audio_visibility, + inputs=[stream], + outputs=[audio_output_stream, audio_output_normal] + ) - seed_button.click(generate_seed, inputs=[], outputs=seed) - generate_button.click(generate_audio, - inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, - seed, stream, speed], - outputs=[audio_output]) - mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text]) + # 配置队列和启动服务 demo.queue(max_size=4, default_concurrency_limit=2) - demo.launch(server_name='0.0.0.0', server_port=args.port) + demo.launch(server_name='0.0.0.0', server_port=args.port, inbrowser=args.open) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--port', type=int, - default=8000) + default=8080) parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice2-0.5B', help='local path or modelscope repo id') + parser.add_argument('--open', + action='store_true', + help='open in browser') + parser.add_argument('--log_level', + type=str, + default='INFO', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='set log level') args = parser.parse_args() + + + logging.getLogger().setLevel(getattr(logging, args.log_level)) + + try: cosyvoice = CosyVoice(args.model_dir) - except Exception: + except Exception as e: + logging.warning(f"尝试加载CosyVoice模型失败: {e}") try: - cosyvoice = CosyVoice2(args.model_dir) - except Exception: + cosyvoice = CosyVoice2(args.model_dir, load_trt=True, fp16=True) + except Exception as e: + logging.error(f"尝试加载CosyVoice2模型也失败: 
{e}") raise TypeError('no valid model_type!') - sft_spk = cosyvoice.list_available_spks() - if len(sft_spk) == 0: - sft_spk = [''] + + sft_spk = refresh_sft_spk()['choices'] + reference_wavs = refresh_prompt_wav()['choices'] + prompt_sr = 16000 default_data = np.zeros(cosyvoice.sample_rate) + + model_dir = "iic/SenseVoiceSmall" + asr_model = AutoModel( + model=model_dir, + disable_update=True, + log_level=args.log_level, + device="cuda:0") main()