diff --git a/.gitmodules b/.gitmodules index f70299afd0..11735cacaf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,7 @@ [submodule "local_server/cosyvoice_server/CosyVoice"] path = local_server/cosyvoice_server/CosyVoice url = https://github.com/wehos/CosyVoice-v3 +[submodule "local_server/qwen3_tts_server/Qwen3-TTS"] + path = local_server/qwen3_tts_server/Qwen3-TTS + url = https://github.com/rophec/Qwen3-TTS.git + branch = qwen3_tts diff --git a/local_server/qwen3_tts_server/Qwen3-TTS b/local_server/qwen3_tts_server/Qwen3-TTS new file mode 160000 index 0000000000..1ab0dd7535 --- /dev/null +++ b/local_server/qwen3_tts_server/Qwen3-TTS @@ -0,0 +1 @@ +Subproject commit 1ab0dd75353392f28a0d05d9ca960c9954b13c83 diff --git a/local_server/qwen3_tts_server/local_server.py b/local_server/qwen3_tts_server/local_server.py new file mode 100644 index 0000000000..867bc6afdf --- /dev/null +++ b/local_server/qwen3_tts_server/local_server.py @@ -0,0 +1,766 @@ +import asyncio +import websockets +import json +import logging +import torch +import sys +import os +import time +import threading +import uuid +import queue +import numpy as np +import re +from pathlib import Path + +ENABLE_TRUE_STREAMING = False +# ======================================================== +# 1. 初始化 Logging +# ======================================================== +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - [TTS-Server] %(levelname)s - %(message)s' +) +logger = logging.getLogger("Qwen3-TTS-Server") + +# ======================================================== +# 2. 
路径配置 (适配 Linux 原生路径) +# ======================================================== +PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) +# 假设你还在 WSL 的这个位置 +MODEL_SOURCE_DIR = os.path.join(PROJECT_ROOT, "Qwen3-TTS-streaming") + +if MODEL_SOURCE_DIR not in sys.path: + sys.path.insert(0, MODEL_SOURCE_DIR) + +try: + from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration + from qwen_tts.core.models.processing_qwen3_tts import Qwen3TTSProcessor + from qwen_tts.inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem + + torch.serialization.add_safe_globals([VoiceClonePromptItem]) + logger.info("✅ 成功导入 Qwen3-TTS 原生组件") +except ImportError as e: + logger.error(f"❌ 组件导入失败: {e}") + sys.exit(1) + + +# ======================================================== +# 3. Server 类定义 (融合版) +# ======================================================== +class QwenLocalServer: + def __init__( + self, + model_path, + voice_pt_path=None, + ref_wav=None, + ref_text=None, + language=None, + chunk_size=None, + buffer_fallback_chars=None, + ): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.pt_path = voice_pt_path or os.path.join(PROJECT_ROOT, "06B_arona.pt") # 注意 这里的nyaning_voice 是测试用音声的音色 hidden2048 适配1.7B + self.model = None + self.model_hidden_size = None + self.cached_prompt = None + self.voice_lock = threading.Lock() + self.prompt_cache = {} + self.active_voice_file = str((Path(PROJECT_ROOT) / "active_voice.json").resolve()) + self.active_voice_path = None + self.active_voice_mtime = None + self.voice_version = 0 + self.pad_token_id = None + self.ref_text = ref_text or "アラバマ シュー ノ サイダイ トシ ワ バーミングハム デ アル。" # 后面的是自己的sample样本 需要更改 + self.ref_wav = ref_wav or os.path.join(PROJECT_ROOT, "uttid_f1.wav") # 测试用sample 样本 + self.language = language or "Chinese" + self.chunk_size = int(chunk_size) if chunk_size else 4096 + self.buffer_fallback_chars = int(buffer_fallback_chars) if buffer_fallback_chars else 30 + self.inference_lock 
=threading.Lock() + # --- 任务队列 (解决并发卡顿) --- + self.task_queue = queue.Queue() + + self.is_06b = "0.6B" in model_path.upper() + + if self.is_06b: + logger.info("🛣️ 探测到 0.6B 模型,走专属极速通道...") + self._load_engine_06b(model_path) + else: + logger.info("🛣️ 走 1.7B 原版通道...") + self._load_engine(model_path) + # 启动后台工人 + threading.Thread(target=self._worker_loop, daemon=True).start() + + def _validate_prompt_dim(self, ref_spk, pt_path: str): + """校验 ref_spk_embedding 维度是否匹配当前模型 hidden_size""" + if self.model_hidden_size is None or ref_spk is None: + return True + spk_dim = ref_spk.shape[-1] + if spk_dim != self.model_hidden_size: + logger.error( + f"❌ 音色文件 {pt_path} 的 ref_spk_embedding 维度 ({spk_dim}) " + f"与当前模型 hidden_size ({self.model_hidden_size}) 不匹配!" + f"请用当前模型重新生成 .pt 文件。" + ) + return False + return True + + def _load_prompt_from_pt(self, pt_path: str): + payload = torch.load(pt_path, map_location=self.device, weights_only=False) + + if isinstance(payload, dict) and "items" in payload: + items_raw = payload.get("items") + if isinstance(items_raw, list): + items = [] + for it in items_raw: + if isinstance(it, VoiceClonePromptItem): + items.append(it) + continue + if not isinstance(it, dict): + raise TypeError("Invalid voice prompt item") + ref_code = it.get("ref_code", None) + if ref_code is not None and not torch.is_tensor(ref_code): + ref_code = torch.tensor(ref_code, device=self.device) + if ref_code is not None: + ref_code = ref_code.to(device=self.device).long() + ref_spk = it.get("ref_spk_embedding", None) + if ref_spk is None: + raise ValueError("Missing ref_spk_embedding") + if not torch.is_tensor(ref_spk): + ref_spk = torch.tensor(ref_spk, device=self.device) + ref_spk = ref_spk.to(device=self.device, dtype=torch.bfloat16) + if not self._validate_prompt_dim(ref_spk, pt_path): + raise ValueError( + f"Prompt 维度不匹配: ref_spk={ref_spk.shape[-1]} " + f"vs model={self.model_hidden_size}。" + f"请用当前模型重新生成 .pt 文件。" + ) + items.append( + VoiceClonePromptItem( + 
ref_code=ref_code, + ref_spk_embedding=ref_spk, + x_vector_only_mode=bool(it.get("x_vector_only_mode", False)), + icl_mode=bool(it.get("icl_mode", not bool(it.get("x_vector_only_mode", False)))), + ref_text=it.get("ref_text", None), + ) + ) + return items + + if isinstance(payload, list): + for it in payload: + if isinstance(it, VoiceClonePromptItem): + if it.ref_code is not None: + it.ref_code = it.ref_code.to(device=self.device).long() + if it.ref_spk_embedding is not None: + it.ref_spk_embedding = it.ref_spk_embedding.to(device=self.device, dtype=torch.bfloat16) + if not self._validate_prompt_dim(it.ref_spk_embedding, pt_path): + raise ValueError( + f"Prompt 维度不匹配: ref_spk={it.ref_spk_embedding.shape[-1]} " + f"vs model={self.model_hidden_size}。" + f"请用当前模型重新生成 .pt 文件。" + ) + return payload + + return payload + + def _read_active_voice_path(self): + try: + if not os.path.exists(self.active_voice_file): + return None, None + mtime = os.path.getmtime(self.active_voice_file) + with open(self.active_voice_file, "r", encoding="utf-8") as f: + data = json.load(f) + p = None + if isinstance(data, dict): + p = data.get("voice_pt_path") + if not p: + return None, mtime + p = str(p) + return p, mtime + except Exception as e: + logger.error(f"读取 active_voice.json 失败: {e}") + return None, None + + def _ensure_active_prompt_loaded(self): + with self.voice_lock: + desired_path, mtime = self._read_active_voice_path() + if not desired_path: + desired_path = self.pt_path + try: + mtime = os.path.getmtime(desired_path) if os.path.exists(desired_path) else None + except Exception: + mtime = None + + if ( + self.active_voice_path == desired_path + and self.active_voice_mtime == mtime + and self.cached_prompt is not None + ): + return self.cached_prompt, self.voice_version + + cached = self.prompt_cache.get(desired_path) + if cached is not None: + cached_mtime, cached_prompt = cached + if cached_mtime == mtime: + self.cached_prompt = cached_prompt + else: + self.cached_prompt = 
self._load_prompt_from_pt(desired_path) + self.prompt_cache[desired_path] = (mtime, self.cached_prompt) + else: + self.cached_prompt = self._load_prompt_from_pt(desired_path) + self.prompt_cache[desired_path] = (mtime, self.cached_prompt) + + self.active_voice_path = desired_path + self.active_voice_mtime = mtime + self.pt_path = desired_path + self.voice_version += 1 + logger.info(f"🔁 Active voice switched: {self.active_voice_path}") + return self.cached_prompt, self.voice_version + + def _load_engine(self, model_path): + try: + t0 = time.time() + logger.info(f"正在启动 N.E.K.O 语音引擎 (Device: {self.device})...") + + processor = Qwen3TTSProcessor.from_pretrained(model_path, fix_mistral_regex=True) + + # --- 关键优化 1: 显式指定 Bfloat16 和 Flash Attention 2 --- + raw_model = Qwen3TTSForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, # 给 transformers 看 + dtype=torch.bfloat16, # 给 qwen 看 + attn_implementation="flash_attention_2", + device_map=self.device, # 旧的是cuda 硬编码 + low_cpu_mem_usage=True + ) + raw_model.eval() + + # --- 兼容修复: 0.6B 模型的 code_predictor 可能被默认初始化为 float32 --- + # device_map 模式下 .to(dtype) 会被 Accelerate 拦截,必须逐参数强制转换 + for param in raw_model.parameters(): + if param.data.dtype == torch.float32: + param.data = param.data.to(dtype=torch.bfloat16) + for buf in raw_model.buffers(): + if buf.dtype == torch.float32: + buf.data = buf.data.to(dtype=torch.bfloat16) + + self.model = Qwen3TTSModel(model=raw_model, processor=processor) + self.model_hidden_size = raw_model.config.talker_config.hidden_size + logger.info(f"📐 模型 hidden_size = {self.model_hidden_size}") + + # --- 关键优化 2: 预设 Config 防止推理时反复初始化 --- + self.pad_token_id = raw_model.config.pad_token_id or raw_model.config.eos_token_id + if hasattr(self.model, 'generation_config'): + self.model.generation_config.pad_token_id = self.pad_token_id + + # 加载/提取音色 + if os.path.exists(self.pt_path): + logger.info(f"✨ 发现音色特征 {self.pt_path},加载中...") + self.cached_prompt = 
self._load_prompt_from_pt(self.pt_path) + else: + logger.warning("🎙️ 未发现音色特征,尝试提取...") + ref_wav = self.ref_wav + if os.path.exists(ref_wav): + with torch.no_grad(): + with torch.amp.autocast(self.device, dtype=torch.bfloat16): + self.cached_prompt = self.model.create_voice_clone_prompt( + ref_audio=ref_wav, ref_text=self.ref_text + ) + torch.save(self.cached_prompt, self.pt_path) + logger.info("✅ 音色提取完成") + else: + logger.error(f"❌ 找不到参考音频: {ref_wav}") + + logger.info(f"🚀 引擎就绪 | 精度: {raw_model.dtype} | 耗时: {time.time() - t0:.2f}s") + except Exception as e: + logger.error(f"❌ 加载引擎异常: {e}") + sys.exit(1) + + def _worker_loop(self): + logger.info("👷 智能拼句队列服务已启动") + while True: + task = self.task_queue.get() + if task is None: + break + + full_text, job_id, loop, audio_queue, cancel_event, prompt_snapshot, language = task + + if cancel_event.is_set(): + self.task_queue.task_done() + continue + + # 任务分发 判断是不是qwen_0.6B + if getattr(self, 'is_06b', False): + self._do_inference_06b(full_text, job_id, loop, audio_queue, cancel_event, prompt_snapshot, language) + else: + self._do_inference(full_text, job_id, loop, audio_queue, cancel_event, prompt_snapshot, language) + + self.task_queue.task_done() + + def _do_inference(self, full_text, job_id, loop, audio_queue, cancel_event, prompt_snapshot, language): + try: + if not self.model or prompt_snapshot is None: + return + + start_time = time.time() + + # 🟢 [核心开关]:在这里切换 + # True: 使用魔改版的流式 Generator (算出一点发一点) + # False: 使用魔改版提供的普通接口 (算完一整句再发) + USE_STREAMING_GENERATOR = True + + # 探测流式方法 + stream_func = None + search_target = self.model + depth = 0 + while search_target is not None and depth < 3: + if hasattr(search_target, "stream_generate_voice_clone"): + stream_func = getattr(search_target, "stream_generate_voice_clone") + break + search_target = getattr(search_target, "model", None) + depth += 1 + + # ======================================================= + # 🛠️ --- 精度与类型预处理 (彻底消灭 Casting fp32 警告) --- + # 
======================================================= + # 1. 如果传入的是路径字符串,手动加载成张量对象 + if isinstance(prompt_snapshot, str): + try: + prompt_snapshot = torch.load(prompt_snapshot, map_location=self.device, weights_only=False) + except Exception as e: + logger.error(f"❌ 加载音色文件失败: {e}") + return + + # 兼容最新版带 "items" 字典包装的 .pt 文件 + prompt_items = prompt_snapshot.get("items", []) if isinstance(prompt_snapshot, dict) else prompt_snapshot + + if isinstance(prompt_items, list): + for item in prompt_items: + # 强转至设备与精度,匹配 Flash-Attention 2 + if hasattr(item, 'ref_code') and item.ref_code is not None: + item.ref_code = item.ref_code.to(device=self.device, dtype=torch.long) + if hasattr(item, 'ref_spk_embedding') and item.ref_spk_embedding is not None: + item.ref_spk_embedding = item.ref_spk_embedding.to(device=self.device, dtype=torch.bfloat16) + # ======================================================= + + # ======================================================= + # 🚀 模式 A:真流式 (Generator 模式) + # ======================================================= + if ENABLE_TRUE_STREAMING and USE_STREAMING_GENERATOR and stream_func: + logger.info(f"🌊 [{job_id}] 模式:真流式 (Generator)") + + total_samples = 0 # 统计总采样点,用于计算 RTF + + with torch.inference_mode(): + with self.inference_lock, torch.amp.autocast(self.device, dtype=torch.bfloat16): + for pcm_chunk in stream_func( + text=full_text, + voice_clone_prompt=prompt_snapshot, + language=language, + pad_token_id=self.pad_token_id, + emit_every_frames=24, # 1.7B 推荐 16 帧 + first_chunk_emit_every=6, + overlap_samples=512, + ): + if cancel_event.is_set(): + break + if pcm_chunk is not None: + # 兼容魔改库返回元组的格式 + audio_raw = pcm_chunk[0] if isinstance(pcm_chunk, (tuple, list)) else pcm_chunk + audio_data = np.asarray(audio_raw).flatten() + if len(audio_data) == 0: + continue + + # 累加采样的点数 + total_samples += len(audio_data) + + audio_int16 = (audio_data * 32767).clip(-32768, 32767).astype(np.int16) + loop.call_soon_threadsafe(audio_queue.put_nowait, 
audio_int16.tobytes()) + + inference_duration = time.time() - start_time + # Qwen3-TTS 默认采样率是 24000 + audio_real_duration = total_samples / 24000.0 + rtf = inference_duration / audio_real_duration if audio_real_duration > 0 else 0 + + logger.info(f"✅ [{job_id}] 真流式完成 | 耗时:{inference_duration:.3f}s | 音频:{audio_real_duration:.2f}s | RTF:{rtf:.4f}") + + # ======================================================= + # 🐢 模式 B:块生成 (原本的整句逻辑) + # ======================================================= + else: + mode_str = "原版块生成" if not stream_func else "流式库-整句模式" + logger.info(f"🐢 [{job_id}] 模式:{mode_str}") + with torch.inference_mode(): + with self.inference_lock, torch.amp.autocast(self.device, dtype=torch.bfloat16): + wavs, sr = self.model.generate_voice_clone( + text=full_text, + voice_clone_prompt=prompt_snapshot, + language=language, + pad_token_id=self.pad_token_id + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + inference_duration = time.time() - start_time + audio_data = wavs[0].flatten() + audio_int16 = (audio_data * 32767).clip(-32768, 32767).astype(np.int16) + + audio_real_duration = len(audio_int16) / sr + rtf = inference_duration / audio_real_duration if audio_real_duration > 0 else 0 + + logger.info( + f"✅ [{job_id}] 完成 | 耗时:{inference_duration:.3f}s | 音频:{audio_real_duration:.2f}s | RTF:{rtf:.4f}" + ) + + chunk_size = self.chunk_size + for i in range(0, len(audio_int16), chunk_size): + if cancel_event.is_set(): + break + chunk = audio_int16[i:i + chunk_size].tobytes() + loop.call_soon_threadsafe(audio_queue.put_nowait, chunk) + + except Exception as e: + logger.error(f"❌ 推理错误: {e}") + import traceback + logger.error(traceback.format_exc()) + finally: + loop.call_soon_threadsafe(audio_queue.put_nowait, b"__END__") + + + def _load_engine_06b(self, model_path): + try: + t0 = time.time() + logger.info(f"正在启动 0.6B 专属引擎 (Device: {self.device})...") + + # 锁死全局默认精度,防止内部隐式生成 fp32 张量(保持为 bfloat16,不再还原) + torch.set_default_dtype(torch.bfloat16) + + 
processor = Qwen3TTSProcessor.from_pretrained(model_path, fix_mistral_regex=True) + + # 剔除 device_map,防止 Accelerate 干扰精度 + raw_model = Qwen3TTSForConditionalGeneration.from_pretrained( + model_path, + dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map=self.device, + low_cpu_mem_usage=True + ) + + + raw_model = raw_model.to(device=self.device, dtype=torch.bfloat16) + # 0.6B 的 code_predictor 子模块很小(5层),强制用 eager attention 避免 Flash Attention dtype 问题 + cp = raw_model.talker.code_predictor + cp.config._attn_implementation = "eager" + for layer in cp.model.layers: + if hasattr(layer, 'self_attn'): + layer.self_attn.config._attn_implementation = "eager" + raw_model.eval() + + self.model = Qwen3TTSModel(model=raw_model, processor=processor) + self.model_hidden_size = raw_model.config.talker_config.hidden_size + logger.info(f"📐 0.6B 模型 hidden_size = {self.model_hidden_size}") + + self.pad_token_id = raw_model.config.pad_token_id or raw_model.config.eos_token_id + if hasattr(self.model, 'generation_config'): + self.model.generation_config.pad_token_id = self.pad_token_id + + if os.path.exists(self.pt_path): + logger.info(f"✨ 发现音色特征 {self.pt_path},加载中...") + self.cached_prompt = self._load_prompt_from_pt(self.pt_path) + else: + logger.warning("🎙️ 未发现音色特征,跳过自动提取") + + logger.info(f"🚀 0.6B 引擎就绪 | 精度: {raw_model.dtype} | 耗时: {time.time() - t0:.2f}s") + except Exception as e: + logger.error(f"❌ 0.6B 加载引擎异常: {e}") + + def _do_inference_06b(self, full_text, job_id, loop, audio_queue, cancel_event, prompt_snapshot, language): + try: + if not self.model or prompt_snapshot is None: return + + start_time = time.time() + USE_STREAMING_GENERATOR = True + + stream_func = None + search_target = self.model + depth = 0 + while search_target is not None and depth < 3: + if hasattr(search_target, "stream_generate_voice_clone"): + stream_func = getattr(search_target, "stream_generate_voice_clone") + break + search_target = getattr(search_target, "model", None) + depth 
+= 1 + + # --- 深度洗数据,杜绝音色精度报错 --- + if isinstance(prompt_snapshot, str): + try: + prompt_snapshot = torch.load(prompt_snapshot, map_location=self.device, weights_only=False) + except Exception as e: + logger.error(f"❌ 加载音色文件失败: {e}") + return + + prompt_items = prompt_snapshot.get("items", []) if isinstance(prompt_snapshot, + dict) else prompt_snapshot + if not isinstance(prompt_items, list): + prompt_items = [prompt_items] + + for item in prompt_items: + if hasattr(item, 'ref_code') and item.ref_code is not None: + item.ref_code = item.ref_code.to(device=self.device, dtype=torch.long) + if hasattr(item, 'ref_spk_embedding') and item.ref_spk_embedding is not None: + item.ref_spk_embedding = item.ref_spk_embedding.to(device=self.device, dtype=torch.bfloat16) + + try: + if hasattr(self.model, 'model'): + self.model.model.to(device=self.device, dtype=torch.bfloat16) + + # --- 推理开始 --- + if ENABLE_TRUE_STREAMING and USE_STREAMING_GENERATOR and stream_func: + logger.info(f"🌊 [{job_id}] 0.6B 专属真流式") + total_samples = 0 + with torch.inference_mode(): + with self.inference_lock, torch.amp.autocast(self.device, dtype=torch.bfloat16): + for pcm_chunk in stream_func( + text=full_text, + voice_clone_prompt=prompt_snapshot, + language=language, + pad_token_id=self.pad_token_id, + emit_every_frames=12, # 0.6B 专属的 12 帧对齐 + first_chunk_emit_every=3, + overlap_samples=512, + ): + if cancel_event.is_set(): break + if pcm_chunk is not None: + audio_raw = pcm_chunk[0] if isinstance(pcm_chunk, (tuple, list)) else pcm_chunk + audio_data = np.asarray(audio_raw).flatten() + if len(audio_data) == 0: continue + total_samples += len(audio_data) + audio_int16 = (audio_data * 32767).clip(-32768, 32767).astype(np.int16) + loop.call_soon_threadsafe(audio_queue.put_nowait, audio_int16.tobytes()) + + inference_duration = time.time() - start_time + audio_real_duration = total_samples / 24000.0 + rtf = inference_duration / audio_real_duration if audio_real_duration > 0 else 0 + logger.info( + f"✅ 
[{job_id}] 0.6B 真流式完成 | 耗时:{inference_duration:.3f}s | 音频:{audio_real_duration:.2f}s | RTF:{rtf:.4f}") + else: + logger.info(f"🐢 [{job_id}] 0.6B 原版生成") + with torch.inference_mode(): + with self.inference_lock, torch.amp.autocast(self.device, dtype=torch.bfloat16): + wavs, sr = self.model.generate_voice_clone( + text=full_text, voice_clone_prompt=prompt_snapshot, language=language, + pad_token_id=self.pad_token_id + ) + if torch.cuda.is_available(): torch.cuda.synchronize() + inference_duration = time.time() - start_time + audio_data = wavs[0].flatten() + audio_int16 = (audio_data * 32767).clip(-32768, 32767).astype(np.int16) + audio_real_duration = len(audio_int16) / sr + rtf = inference_duration / audio_real_duration if audio_real_duration > 0 else 0 + logger.info( + f"✅ [{job_id}] 0.6B 完成 | 耗时:{inference_duration:.3f}s | 音频:{audio_real_duration:.2f}s | RTF:{rtf:.4f}") + + chunk_size = self.chunk_size + for i in range(0, len(audio_int16), chunk_size): + if cancel_event.is_set(): break + chunk = audio_int16[i:i + chunk_size].tobytes() + loop.call_soon_threadsafe(audio_queue.put_nowait, chunk) + finally: + pass + except Exception as e: + logger.error(f"❌ 0.6B 推理错误: {e}") + finally: + loop.call_soon_threadsafe(audio_queue.put_nowait, b"__END__") + + async def handle_tts(self, websocket): + logger.info(f"客户端连接: {websocket.remote_address}") + loop = asyncio.get_running_loop() + + # 智能拼句缓冲区 + sentence_buffer = "" + + current_job_id = None + cancel_event = threading.Event() + audio_queue = asyncio.Queue() + last_voice_version = self.voice_version + + async def _stop_current_job(keep_buffer: bool = False): + nonlocal current_job_id, cancel_event, sentence_buffer + cancel_event.set() + current_job_id = None + cancel_event = threading.Event() + while not audio_queue.empty(): + try: + audio_queue.get_nowait() + except Exception: + break + if not keep_buffer: + sentence_buffer = "" + + # 发送循环 + async def _sender_loop(): + while True: + try: + chunk = await audio_queue.get() 
+ if chunk == b"__END__": + if current_job_id: + # 适配 N.E.K.O 客户端协议 + response_done = {"type": "response.done", "job_id": current_job_id} + await websocket.send(json.dumps(response_done)) + continue + await websocket.send(chunk) + except Exception: + break + + sender_task = asyncio.create_task(_sender_loop()) + + try: + # 发送 ready 信号 + await websocket.send(json.dumps({"type": "ready"})) + + async for message in websocket: + if isinstance(message, bytes): + continue + try: + data = json.loads(message) + except Exception: + continue + + msg_type = data.get("type") + if "text" in data and not msg_type: msg_type = "legacy.text" + + if msg_type == "input_text_buffer.append": + text_fragment = data.get("text", "") + sentence_buffer += text_fragment + + elif msg_type in ("input_text_buffer.commit", "legacy.text"): + if msg_type == "legacy.text": + sentence_buffer += data.get("text", "") + + _, current_version = self._ensure_active_prompt_loaded() + if current_version != last_voice_version: + last_voice_version = current_version + # Hot voice switch: cancel current audio job but keep the buffered text. + # Otherwise the first committed sentence right after switching may be dropped. 
+ await _stop_current_job(keep_buffer=True) + + # 正则智能断句 + parts = re.split(r'([。!?.!?\n]+)', sentence_buffer) + + if len(parts) > 1: + for i in range(0, len(parts) - 1, 2): + sentence = parts[i] + parts[i + 1] + sentence = sentence.strip() + if not sentence: continue + + if not current_job_id: + current_job_id = str(uuid.uuid4()) + await websocket.send(json.dumps({"type": "response.start", "job_id": current_job_id})) + + logger.info(f"📥 句子: {sentence[:20]}...") + prompt, _ = self._ensure_active_prompt_loaded() + self.task_queue.put((sentence, current_job_id, loop, audio_queue, cancel_event, prompt, self.language)) + + sentence_buffer = parts[-1] + + # 缓冲区兜底 (防止一直不说话) + if len(sentence_buffer) > self.buffer_fallback_chars: + sentence = sentence_buffer + sentence_buffer = "" + if not current_job_id: + current_job_id = str(uuid.uuid4()) + await websocket.send(json.dumps({"type": "response.start", "job_id": current_job_id})) + prompt, _ = self._ensure_active_prompt_loaded() + self.task_queue.put((sentence, current_job_id, loop, audio_queue, cancel_event, prompt, self.language)) + + elif msg_type == "cancel": + await _stop_current_job() + + elif msg_type == "session.update": + # 客户端会在连接后下发 voice_pt_path / language + # 这里更新运行参数并强制刷新音色缓存 + new_pt = data.get("voice_pt_path") + new_lang = data.get("language") + if new_pt: + with self.voice_lock: + self.pt_path = str(new_pt) + self.active_voice_path = None + self.active_voice_mtime = None + self.cached_prompt = None + self.voice_version += 1 + logger.info(f"🔁 会话更新: voice_pt_path -> {self.pt_path}") + if new_lang: + self.language = new_lang + logger.info(f"🔁 会话更新: language -> {self.language}") + + finally: + await _stop_current_job() + sender_task.cancel() + + +async def main(): + tts_custom = {} + repo_root = None + try: + p = Path(__file__).resolve() + config_path = None + for parent in [p.parent] + list(p.parents): + candidate = parent / "config" / "api_providers.json" + if candidate.exists(): + config_path = candidate + 
break + if config_path is not None: + with config_path.open("r", encoding="utf-8") as f: + cfg = json.load(f) + tts_custom = cfg.get("tts_custom", {}) or {} + repo_root = config_path.parent.parent + except Exception: + tts_custom = {} + repo_root = None + + + # model_path = tts_custom.get("model_path") or "/home/amadeus/models/qwen3_tts" # [旧的 1.7B 路径] + model_path = tts_custom.get("model_path") or "/home/amadeus/models/Qwen3-TTS-12Hz-0.6B-Base" # 0.6B 路径 按需更改 这里是硬编码 + host = tts_custom.get("host") or "127.0.0.1" + port = int(tts_custom.get("port") or 8765) + + voice_pt_path = tts_custom.get("voice_pt_path") + if voice_pt_path: + p = Path(voice_pt_path) + if not p.is_absolute() and repo_root is not None: + p = (repo_root / p).resolve() + voice_pt_path = str(p) + + ref_wav = tts_custom.get("ref_wav") + if ref_wav: + p = Path(ref_wav) + if not p.is_absolute() and repo_root is not None: + p = (repo_root / p).resolve() + ref_wav = str(p) + + ref_text = tts_custom.get("ref_text") + language = tts_custom.get("language") + chunk_size = tts_custom.get("chunk_size") + buffer_fallback_chars = tts_custom.get("buffer_fallback_chars") + + server = QwenLocalServer( + model_path, + voice_pt_path=voice_pt_path, + ref_wav=ref_wav, + ref_text=ref_text, + language=language, + chunk_size=chunk_size, + buffer_fallback_chars=buffer_fallback_chars, + ) + + async with websockets.serve(server.handle_tts, host, port): + logger.info(f"🚀 本地 TTS 服务已启动: ws://{host}:{port}") + await asyncio.Future() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/local_server/qwen3_tts_server/server_commit.py b/local_server/qwen3_tts_server/server_commit.py new file mode 100644 index 0000000000..e7c47b3804 --- /dev/null +++ b/local_server/qwen3_tts_server/server_commit.py @@ -0,0 +1,131 @@ +import os +import asyncio +import logging +import wave +from tts_realtime_client import TTSRealtimeClient, SessionMode +import pyaudio + +# QwenTTS 服务配置 +# 
以下是北京地域url,如果使用新加坡地域的模型,需要将url替换为:wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime?model=qwen3-tts-flash-realtime +URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model=qwen3-tts-flash-realtime" +# 新加坡和北京地域的API Key不同。获取API Key:https://help.aliyun.com/zh/model-studio/get-api-key +# 若没有配置环境变量,请用百炼API Key将下行替换为:API_KEY="sk-xxx" +API_KEY = os.getenv("DASHSCOPE_API_KEY") + +if not API_KEY: + raise ValueError("Please set DASHSCOPE_API_KEY environment variable") + +# 收集音频数据 +_audio_chunks = [] +# 实时播放相关 +_AUDIO_SAMPLE_RATE = 24000 +_audio_pyaudio = pyaudio.PyAudio() +_audio_stream = None # 将在运行时打开 + +def _audio_callback(audio_bytes: bytes): + """TTSRealtimeClient 音频回调: 实时播放并缓存""" + global _audio_stream + if _audio_stream is not None: + try: + _audio_stream.write(audio_bytes) + except Exception as exc: + logging.error(f"PyAudio playback error: {exc}") + _audio_chunks.append(audio_bytes) + logging.info(f"Received audio chunk: {len(audio_bytes)} bytes") + +def _save_audio_to_file(filename: str = "output.wav", sample_rate: int = 24000) -> bool: + """将收集到的音频数据保存为 WAV 文件""" + if not _audio_chunks: + logging.warning("No audio data to save") + return False + + try: + audio_data = b"".join(_audio_chunks) + with wave.open(filename, 'wb') as wav_file: + wav_file.setnchannels(1) # 单声道 + wav_file.setsampwidth(2) # 16-bit + wav_file.setframerate(sample_rate) + wav_file.writeframes(audio_data) + logging.info(f"Audio saved to: {filename}") + return True + except Exception as exc: + logging.error(f"Failed to save audio: {exc}") + return False + +async def _produce_text(client: TTSRealtimeClient): + """向服务器发送文本片段""" + text_fragments = [ + "阿里云的大模型服务平台百炼是一站式的大模型开发及应用构建平台。", + "不论是开发者还是业务人员,都能深入参与大模型应用的设计和构建。", + "您可以通过简单的界面操作,在5分钟内开发出一款大模型应用,", + "或在几小时内训练出一个专属模型,从而将更多精力专注于应用创新。", + ] + + logging.info("Sending text fragments…") + for text in text_fragments: + logging.info(f"Sending fragment: {text}") + await client.append_text(text) + await asyncio.sleep(0.1) # 片段间稍作延时 + + # 
等待服务器完成内部处理后结束会话 + await asyncio.sleep(1.0) + await client.finish_session() + +async def _run_demo(): + """运行完整 Demo""" + global _audio_stream + # 打开 PyAudio 输出流 + _audio_stream = _audio_pyaudio.open( + format=pyaudio.paInt16, + channels=1, + rate=_AUDIO_SAMPLE_RATE, + output=True, + frames_per_buffer=1024 + ) + + client = TTSRealtimeClient( + base_url=URL, + api_key=API_KEY, + voice="Cherry", + mode=SessionMode.SERVER_COMMIT, + audio_callback=_audio_callback + ) + + # 建立连接 + await client.connect() + + # 并行执行消息处理与文本发送 + consumer_task = asyncio.create_task(client.handle_messages()) + producer_task = asyncio.create_task(_produce_text(client)) + + await producer_task # 等待文本发送完成 + + # 等待 response.done + await client.wait_for_response_done() + + # 关闭连接并取消消费者任务 + await client.close() + consumer_task.cancel() + + # 关闭音频流 + if _audio_stream is not None: + _audio_stream.stop_stream() + _audio_stream.close() + _audio_pyaudio.terminate() + + # 保存音频数据 + os.makedirs("outputs", exist_ok=True) + _save_audio_to_file(os.path.join("outputs", "qwen_tts_output.wav")) + +def main(): + """同步入口""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + logging.info("Starting QwenTTS Realtime Client demo…") + asyncio.run(_run_demo()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/local_server/qwen3_tts_server/tts_realtime_client.py b/local_server/qwen3_tts_server/tts_realtime_client.py new file mode 100644 index 0000000000..ff97c57ba5 --- /dev/null +++ b/local_server/qwen3_tts_server/tts_realtime_client.py @@ -0,0 +1,6849 @@ +# -- coding: utf-8 -- + +import asyncio +import websockets +import json +import base64 +import time +from typing import Optional, Callable, Dict, Any +from enum import Enum + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
class SessionMode(Enum):
    """Session modes accepted by the TTS Realtime API.

    The member value is sent verbatim as the ``mode`` field of the
    ``session.update`` event (see ``TTSRealtimeClient.connect``).
    """

    SERVER_COMMIT = "server_commit"
    COMMIT = "commit"


class TTSRealtimeClient:
    """
    Client for talking to a TTS Realtime API over a WebSocket.

    Provides helpers to connect, send text, receive synthesized audio and
    manage the WebSocket connection.

    Attributes:
        base_url (str):
            Base address of the Realtime API.
        api_key (str):
            API key used for authentication (sent as a Bearer token).
        voice (str):
            Voice the server should use for synthesis.
        mode (SessionMode):
            Session mode, either ``server_commit`` or ``commit``.
        audio_callback (Callable[[bytes], None]):
            Callback invoked with every decoded audio chunk.
        language_type (str):
            Language of the synthesized speech. One of Chinese, English,
            German, Italian, Portuguese, Spanish, Japanese, Korean, French,
            Russian, Auto.
    """

    def __init__(
            self,
            base_url: str,
            api_key: str,
            voice: str = "Cherry",
            mode: SessionMode = SessionMode.SERVER_COMMIT,
            audio_callback: Optional[Callable[[bytes], None]] = None,
            language_type: str = "Auto"):
        self.base_url = base_url
        self.api_key = api_key
        self.voice = voice
        self.mode = mode
        self.ws = None
        self.audio_callback = audio_callback
        self.language_type = language_type

        # State of the response currently being generated.
        self._current_response_id = None
        self._current_item_id = None
        self._is_responding = False
        self._response_done_future = None

    async def connect(self) -> None:
        """Open the WebSocket connection and push the default session config."""
        headers = {
            "Authorization": f"Bearer {self.api_key}"
        }

        self.ws = await websockets.connect(self.base_url, additional_headers=headers)

        # Default session configuration: raw PCM at 24 kHz.
        await self.update_session({
            "mode": self.mode.value,
            "voice": self.voice,
            "language_type": self.language_type,
            "response_format": "pcm",
            "sample_rate": 24000
        })

    async def send_event(self, event) -> None:
        """Stamp ``event`` with a millisecond-based ``event_id`` and send it as JSON.

        Note: ``event`` is mutated in place (the ``event_id`` key is added).
        """
        event['event_id'] = "event_" + str(int(time.time() * 1000))
        print(f"发送事件: type={event['type']}, event_id={event['event_id']}")
        await self.ws.send(json.dumps(event))

    async def update_session(self, config: Dict[str, Any]) -> None:
        """Send a ``session.update`` event carrying ``config``."""
        event = {
            "type": "session.update",
            "session": config
        }
        print("更新会话配置: ", event)
        await self.send_event(event)

    async def append_text(self, text: str) -> None:
        """Append ``text`` to the server-side input text buffer."""
        event = {
            "type": "input_text_buffer.append",
            "text": text
        }
        await self.send_event(event)

    async def commit_text_buffer(self) -> None:
        """Commit the input text buffer to trigger processing."""
        event = {
            "type": "input_text_buffer.commit"
        }
        await self.send_event(event)

    async def clear_text_buffer(self) -> None:
        """Clear the input text buffer."""
        event = {
            "type": "input_text_buffer.clear"
        }
        await self.send_event(event)

    async def finish_session(self) -> None:
        """Ask the server to finish the session."""
        event = {
            "type": "session.finish"
        }
        await self.send_event(event)

    async def wait_for_response_done(self):
        """Wait until the current response is finished.

        Returns immediately when no response has been created yet. The wait
        also ends when ``handle_messages`` stops before ``response.done``
        arrived (e.g. the connection was closed), so callers never block
        forever on a dead connection.
        """
        if self._response_done_future:
            await self._response_done_future

    async def handle_messages(self) -> None:
        """Consume server events until the connection ends.

        Tracks the current response/item ids, forwards decoded audio deltas
        to ``audio_callback`` and resolves the "response done" future. The
        future resolves to ``True`` on ``response.done`` and to ``False``
        when the message stream ends while a response is still pending.
        """
        try:
            async for message in self.ws:
                event = json.loads(message)
                event_type = event.get("type")

                # Audio deltas are frequent; do not log each one.
                if event_type != "response.audio.delta":
                    print(f"收到事件: {event_type}")

                if event_type == "error":
                    print("错误: ", event.get('error', {}))
                    continue
                elif event_type == "session.created":
                    print("会话创建,ID: ", event.get('session', {}).get('id'))
                elif event_type == "session.updated":
                    print("会话更新,ID: ", event.get('session', {}).get('id'))
                elif event_type == "input_text_buffer.committed":
                    print("文本缓冲区已提交,项目ID: ", event.get('item_id'))
                elif event_type == "input_text_buffer.cleared":
                    print("文本缓冲区已清除")
                elif event_type == "response.created":
                    self._current_response_id = event.get("response", {}).get("id")
                    self._is_responding = True
                    # Fresh future for this response, bound to the running loop.
                    self._response_done_future = asyncio.get_running_loop().create_future()
                    print("响应已创建,ID: ", self._current_response_id)
                elif event_type == "response.output_item.added":
                    self._current_item_id = event.get("item", {}).get("id")
                    print("输出项已添加,ID: ", self._current_item_id)
                # Audio delta: base64 payload -> raw bytes for the callback.
                elif event_type == "response.audio.delta" and self.audio_callback:
                    audio_bytes = base64.b64decode(event.get("delta", ""))
                    self.audio_callback(audio_bytes)
                elif event_type == "response.audio.done":
                    print("音频生成完成")
                elif event_type == "response.done":
                    self._is_responding = False
                    self._current_response_id = None
                    self._current_item_id = None
                    # Wake up anyone blocked in wait_for_response_done().
                    if self._response_done_future and not self._response_done_future.done():
                        self._response_done_future.set_result(True)
                    print("响应完成")
                elif event_type == "session.finished":
                    print("会话已结束")

        except websockets.exceptions.ConnectionClosed:
            print("连接已关闭")
        except Exception as e:
            print("消息处理出错: ", str(e))
        finally:
            # The message stream is over. No response.done can arrive any
            # more, so release pending waiters instead of letting them hang.
            self._is_responding = False
            pending = self._response_done_future
            if pending is not None and not pending.done():
                pending.set_result(False)

    async def close(self) -> None:
        """Close the WebSocket connection if one was opened."""
        if self.ws:
            await self.ws.close()
+12,7 @@ import wave import aiohttp import asyncio +import os from functools import partial from utils.config_manager import get_config_manager from utils.logger_config import get_module_logger @@ -19,32 +20,32 @@ logger = get_module_logger(__name__, "Main") -def _resample_audio(audio_int16: np.ndarray, src_rate: int, dst_rate: int, +def _resample_audio(audio_int16: np.ndarray, src_rate: int, dst_rate: int, resampler: 'soxr.ResampleStream | None' = None) -> bytes: """使用 soxr 进行高质量音频重采样 - + Args: audio_int16: int16 格式的音频 numpy 数组 src_rate: 源采样率 dst_rate: 目标采样率 resampler: 可选的流式重采样器,用于维护 chunk 间状态 - + Returns: 重采样后的 bytes """ if src_rate == dst_rate: return audio_int16.tobytes() - + # 转换为 float32 进行高质量重采样 audio_float = audio_int16.astype(np.float32) / 32768.0 - + if resampler is not None: # 使用流式重采样器(维护 chunk 边界状态) resampled_float = resampler.resample_chunk(audio_float) else: # 无状态重采样(不推荐用于流式音频) resampled_float = soxr.resample(audio_float, src_rate, dst_rate, quality='HQ') - + # 转回 int16 resampled_int16 = (resampled_float * 32768.0).clip(-32768, 32767).astype(np.int16) return resampled_int16.tobytes() @@ -54,7 +55,7 @@ def step_realtime_tts_worker(request_queue, response_queue, audio_api_key, voice """ StepFun实时TTS worker(用于默认音色) 使用阶跃星辰的实时TTS API(step-tts-mini) - + Args: request_queue: 多进程请求队列,接收(speech_id, text)元组 response_queue: 多进程响应队列,发送音频数据(也用于发送就绪信号) @@ -62,11 +63,11 @@ def step_realtime_tts_worker(request_queue, response_queue, audio_api_key, voice voice_id: 音色ID,默认使用"qingchunshaonv" """ import asyncio - + # 使用默认音色 "qingchunshaonv" if not voice_id: voice_id = "qingchunshaonv" - + async def async_worker(): """异步TTS worker主循环""" if free_mode: @@ -81,13 +82,13 @@ async def async_worker(): response_done = asyncio.Event() # 用于标记当前响应是否完成 # 流式重采样器(24kHz→48kHz)- 维护 chunk 边界状态 resampler = soxr.ResampleStream(24000, 48000, 1, dtype='float32') - + try: # 连接WebSocket headers = {"Authorization": f"Bearer {audio_api_key}"} - + ws = await websockets.connect(tts_url, 
additional_headers=headers) - + # 等待连接成功事件 async def wait_for_connection(): """等待连接成功""" @@ -96,7 +97,7 @@ async def wait_for_connection(): async for message in ws: event = json.loads(message) event_type = event.get("type") - + if event_type == "tts.connection.done": session_id = event.get("data", {}).get("session_id") session_ready.set() @@ -106,7 +107,7 @@ async def wait_for_connection(): break except Exception as e: logger.error(f"等待连接时出错: {e}") - + # 等待连接成功 try: await asyncio.wait_for(wait_for_connection(), timeout=5.0) @@ -115,13 +116,13 @@ async def wait_for_connection(): # 发送失败信号 response_queue.put(("__ready__", False)) return - + if not session_ready.is_set() or not session_id: logger.error("连接未能正确建立") # 发送失败信号 response_queue.put(("__ready__", False)) return - + # 发送创建会话事件 create_event = { "type": "tts.create", @@ -133,14 +134,14 @@ async def wait_for_connection(): } } await ws.send(json.dumps(create_event)) - + # 等待会话创建成功 async def wait_for_session_ready(): try: async for message in ws: event = json.loads(message) event_type = event.get("type") - + if event_type == "tts.response.created": break elif event_type == "tts.response.error": @@ -148,16 +149,16 @@ async def wait_for_session_ready(): break except Exception as e: logger.error(f"等待会话创建时出错: {e}") - + try: await asyncio.wait_for(wait_for_session_ready(), timeout=1.0) except asyncio.TimeoutError: logger.warning("会话创建超时") - + # 发送就绪信号,通知主进程 TTS 已经可以使用 logger.info("StepFun TTS 已就绪,发送就绪信号") response_queue.put(("__ready__", True)) - + # 初始接收任务 async def receive_messages_initial(): """初始接收任务""" @@ -165,7 +166,7 @@ async def receive_messages_initial(): async for message in ws: event = json.loads(message) event_type = event.get("type") - + if event_type == "tts.response.error": logger.error(f"TTS错误: {event}") elif event_type == "tts.response.audio.delta": @@ -179,7 +180,7 @@ async def receive_messages_initial(): with wave.open(wav_io, 'rb') as wav_file: # 读取音频数据 pcm_data = 
wav_file.readframes(wav_file.getnframes()) - + # 转换为 numpy 数组 audio_array = np.frombuffer(pcm_data, dtype=np.int16) # 使用流式重采样器 24000Hz -> 48000Hz @@ -194,9 +195,9 @@ async def receive_messages_initial(): pass except Exception as e: logger.error(f"消息接收出错: {e}") - + receive_task = asyncio.create_task(receive_messages_initial()) - + # 主循环:处理请求队列 loop = asyncio.get_running_loop() while True: @@ -204,7 +205,7 @@ async def receive_messages_initial(): sid, tts_text = await loop.run_in_executor(None, request_queue.get) except Exception: break - + if sid is None: # 提交缓冲区完成当前合成 if ws and session_id and current_speech_id is not None: @@ -220,8 +221,8 @@ async def receive_messages_initial(): await asyncio.wait_for(response_done.wait(), timeout=20.0) logger.debug("音频生成完成,主动关闭连接") except asyncio.TimeoutError: - logger.warning("等待响应完成超时(20秒),强制关闭连接") - + logger.warning("等待响应完成超时(30秒),强制关闭连接") + # 主动关闭连接,避免连接一直保持到超时 if ws: try: @@ -242,7 +243,7 @@ async def receive_messages_initial(): except Exception as e: logger.error(f"完成生成失败: {e}") continue - + # 新的语音ID,重新建立连接 if current_speech_id != sid: current_speech_id = sid @@ -259,15 +260,15 @@ async def receive_messages_initial(): await receive_task except asyncio.CancelledError: pass - + # 建立新连接 try: ws = await websockets.connect(tts_url, additional_headers=headers) - + # 等待连接成功 session_id = None session_ready.clear() - + async def wait_conn(): nonlocal session_id try: @@ -279,16 +280,16 @@ async def wait_conn(): break except Exception: pass - + try: await asyncio.wait_for(wait_conn(), timeout=1.0) except asyncio.TimeoutError: logger.warning("新连接超时") continue - + if not session_id: continue - + # 创建会话 await ws.send(json.dumps({ "type": "tts.create", @@ -299,14 +300,14 @@ async def wait_conn(): "sample_rate": 24000 } })) - + # 启动新的接收任务 async def receive_messages(): try: async for message in ws: event = json.loads(message) event_type = event.get("type") - + if event_type == "tts.response.error": logger.error(f"TTS错误: {event}") elif 
event_type == "tts.response.audio.delta": @@ -319,11 +320,12 @@ async def receive_messages(): with wave.open(wav_io, 'rb') as wav_file: # 读取音频数据 pcm_data = wav_file.readframes(wav_file.getnframes()) - + # 转换为 numpy 数组 audio_array = np.frombuffer(pcm_data, dtype=np.int16) # 使用流式重采样器 24000Hz -> 48000Hz - response_queue.put(_resample_audio(audio_array, 24000, 48000, resampler)) + response_queue.put( + _resample_audio(audio_array, 24000, 48000, resampler)) except Exception as e: logger.error(f"处理音频数据时出错: {e}") elif event_type in ["tts.response.done", "tts.response.audio.done"]: @@ -334,20 +336,20 @@ async def receive_messages(): pass except Exception as e: logger.error(f"消息接收出错: {e}") - + receive_task = asyncio.create_task(receive_messages()) - + except Exception as e: logger.error(f"重新建立连接失败: {e}") continue - + # 检查文本有效性 if not tts_text or not tts_text.strip(): continue - + if not ws or not session_id: continue - + # 发送文本 try: text_event = { @@ -366,7 +368,7 @@ async def receive_messages(): current_speech_id = None # 清空ID以强制下次重连 if receive_task and not receive_task.done(): receive_task.cancel() - + except Exception as e: logger.error(f"StepFun实时TTS Worker错误: {e}") finally: @@ -377,13 +379,13 @@ async def receive_messages(): await receive_task except asyncio.CancelledError: pass - + if ws: try: await ws.close() except Exception: pass - + # 运行异步worker try: asyncio.run(async_worker()) @@ -395,7 +397,7 @@ def qwen_realtime_tts_worker(request_queue, response_queue, audio_api_key, voice """ Qwen实时TTS worker(用于默认音色) 使用阿里云的实时TTS API(qwen3-tts-flash-2025-09-18) - + Args: request_queue: 多进程请求队列,接收(speech_id, text)元组 response_queue: 多进程响应队列,发送音频数据(也用于发送就绪信号) @@ -406,7 +408,7 @@ def qwen_realtime_tts_worker(request_queue, response_queue, audio_api_key, voice if not voice_id: voice_id = "Momo" - + async def async_worker(): """异步TTS worker主循环""" tts_url = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model=qwen3-tts-flash-realtime-2025-11-27" @@ -417,11 +419,11 @@ async def 
async_worker(): response_done = asyncio.Event() # 用于标记当前响应是否完成 # 流式重采样器(24kHz→48kHz)- 维护 chunk 边界状态 resampler = soxr.ResampleStream(24000, 48000, 1, dtype='float32') - + try: # 连接WebSocket headers = {"Authorization": f"Bearer {audio_api_key}"} - + # 配置会话消息模板(在重连时复用) # 使用 SERVER_COMMIT 模式:多次 append 文本,最后手动 commit 触发合成 # 这样可以累积文本,避免"一个字一个字往外蹦"的问题 @@ -437,9 +439,9 @@ async def async_worker(): "bit_depth": 16 } } - + ws = await websockets.connect(tts_url, additional_headers=headers) - + # 等待并处理初始消息 async def wait_for_session_ready(): """等待会话创建确认""" @@ -447,7 +449,7 @@ async def wait_for_session_ready(): async for message in ws: event = json.loads(message) event_type = event.get("type") - + # Qwen TTS API 返回 session.updated 而不是 session.created if event_type in ["session.created", "session.updated"]: session_ready.set() @@ -457,10 +459,10 @@ async def wait_for_session_ready(): break except Exception as e: logger.error(f"等待会话就绪时出错: {e}") - + # 发送配置 await ws.send(json.dumps(config_message)) - + # 等待会话就绪(超时5秒) try: await asyncio.wait_for(wait_for_session_ready(), timeout=5.0) @@ -468,16 +470,16 @@ async def wait_for_session_ready(): logger.error("❌ 等待会话就绪超时") response_queue.put(("__ready__", False)) return - + if not session_ready.is_set(): logger.error("❌ 会话未能正确初始化") response_queue.put(("__ready__", False)) return - + # 发送就绪信号 logger.info("Qwen TTS 已就绪,发送就绪信号") response_queue.put(("__ready__", True)) - + # 初始接收任务(会在每次新 speech_id 时重新创建) async def receive_messages_initial(): """初始接收任务""" @@ -485,7 +487,7 @@ async def receive_messages_initial(): async for message in ws: event = json.loads(message) event_type = event.get("type") - + if event_type == "error": logger.error(f"TTS错误: {event}") elif event_type == "response.audio.delta": @@ -504,9 +506,9 @@ async def receive_messages_initial(): pass except Exception as e: logger.error(f"消息接收出错: {e}") - + receive_task = asyncio.create_task(receive_messages_initial()) - + # 主循环:处理请求队列 loop = asyncio.get_running_loop() while True: @@ 
-515,7 +517,7 @@ async def receive_messages_initial(): sid, tts_text = await loop.run_in_executor(None, request_queue.get) except Exception: break - + if sid is None: # 提交缓冲区完成当前合成(仅当之前有文本时) if ws and session_ready.is_set() and current_speech_id is not None: @@ -531,7 +533,7 @@ async def receive_messages_initial(): logger.debug("音频生成完成,主动关闭连接") except asyncio.TimeoutError: logger.warning("等待响应完成超时(20秒),强制关闭连接") - + # 主动关闭连接,避免连接一直保持到超时 if ws: try: @@ -551,7 +553,7 @@ async def receive_messages_initial(): except Exception as e: logger.error(f"提交缓冲区失败: {e}") continue - + # 新的语音ID,重新建立连接(类似 speech_synthesis_worker 的逻辑) # 直接关闭旧连接,打断旧语音 if current_speech_id != sid: @@ -569,15 +571,15 @@ async def receive_messages_initial(): await receive_task except asyncio.CancelledError: pass - + # 建立新连接 try: ws = await websockets.connect(tts_url, additional_headers=headers) await ws.send(json.dumps(config_message)) - + # 等待 session.created session_ready.clear() - + async def wait_ready(): try: async for message in ws: @@ -592,19 +594,19 @@ async def wait_ready(): break except Exception as e: logger.error(f"wait_ready 异常: {e}") - + try: await asyncio.wait_for(wait_ready(), timeout=2.0) except asyncio.TimeoutError: logger.warning("新会话创建超时") - + # 启动新的接收任务 async def receive_messages(): try: async for message in ws: event = json.loads(message) event_type = event.get("type") - + if event_type == "error": logger.error(f"TTS错误: {event}") elif event_type == "response.audio.delta": @@ -623,20 +625,20 @@ async def receive_messages(): pass except Exception as e: logger.error(f"消息接收出错: {e}") - + receive_task = asyncio.create_task(receive_messages()) - + except Exception as e: logger.error(f"重新建立连接失败: {e}") continue - + # 检查文本有效性 if not tts_text or not tts_text.strip(): continue - + if not ws or not session_ready.is_set(): continue - + # 追加文本到缓冲区(不立即提交,等待响应完成时的终止信号再 commit) try: await ws.send(json.dumps({ @@ -652,7 +654,7 @@ async def receive_messages(): session_ready.clear() if receive_task and 
not receive_task.done(): receive_task.cancel() - + except Exception as e: logger.error(f"Qwen实时TTS Worker错误: {e}") finally: @@ -663,13 +665,13 @@ async def receive_messages(): await receive_task except asyncio.CancelledError: pass - + if ws: try: await ws.close() except Exception: pass - + # 运行异步worker try: asyncio.run(async_worker()) @@ -680,7 +682,7 @@ async def receive_messages(): def cosyvoice_vc_tts_worker(request_queue, response_queue, audio_api_key, voice_id): """ TTS多进程worker函数,用于阿里云CosyVoice TTS - + Args: request_queue: 多进程请求队列,接收(speech_id, text)元组 response_queue: 多进程响应队列,发送音频数据(也用于发送就绪信号) @@ -690,7 +692,7 @@ def cosyvoice_vc_tts_worker(request_queue, response_queue, audio_api_key, voice_ import re import dashscope from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer, AudioFormat - + dashscope.api_key = audio_api_key _RE_KANA = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]') @@ -699,33 +701,33 @@ def cosyvoice_vc_tts_worker(request_queue, response_queue, audio_api_key, voice_ # CosyVoice 不需要预连接,直接发送就绪信号 logger.info("CosyVoice TTS 已就绪,发送就绪信号") response_queue.put(("__ready__", True)) - + class Callback(ResultCallback): def __init__(self, response_queue): self.response_queue = response_queue - - def on_open(self): + + def on_open(self): # 连接已建立,发送就绪信号 elapsed = time.time() - self.construct_start_time if hasattr(self, 'construct_start_time') else -1 logger.debug(f"TTS 连接已建立 (构造到open耗时: {elapsed:.2f}s)") - - def on_complete(self): + + def on_complete(self): pass - - def on_error(self, message: str): + + def on_error(self, message: str): if "request timeout after 23 seconds" not in message: logger.error(f"TTS Error: {message}") - - def on_close(self): + + def on_close(self): pass - - def on_event(self, message): + + def on_event(self, message): pass - + def on_data(self, data: bytes) -> None: # 直接转发 OGG OPUS 数据到前端解码 self.response_queue.put(data) - + callback = Callback(response_queue) current_speech_id = None synthesizer = None @@ -786,7 +788,7 @@ 
def _flush_buffer(): synthesizer.close() except Exception: pass - synthesizer = None + synthesizer = None current_speech_id = None char_buffer = "" detected_lang = None @@ -856,7 +858,7 @@ def cogtts_tts_worker(request_queue, response_queue, audio_api_key, voice_id): 使用智谱AI的CogTTS API(cogtts) 注意:CogTTS不支持流式输入,只支持流式输出 因此需要累积文本后一次性发送,但可以流式接收音频 - + Args: request_queue: 多进程请求队列,接收(speech_id, text)元组 response_queue: 多进程响应队列,发送音频数据(也用于发送就绪信号) @@ -864,35 +866,35 @@ def cogtts_tts_worker(request_queue, response_queue, audio_api_key, voice_id): voice_id: 音色ID,默认使用"tongtong"(支持:tongtong, chuichui, xiaochen, jam, kazi, douji, luodo) """ import asyncio - + # 使用默认音色 "tongtong" if not voice_id: voice_id = "tongtong" - + async def async_worker(): """异步TTS worker主循环""" tts_url = "https://open.bigmodel.cn/api/paas/v4/audio/speech" current_speech_id = None text_buffer = [] # 累积文本缓冲区 - + # CogTTS 是基于 HTTP 的,无需建立持久连接,直接发送就绪信号 logger.info("CogTTS TTS 已就绪,发送就绪信号") response_queue.put(("__ready__", True)) - + try: loop = asyncio.get_running_loop() - + while True: try: sid, tts_text = await loop.run_in_executor(None, request_queue.get) except Exception: break - + # 新的语音ID,清空缓冲区并重新开始 if current_speech_id != sid and sid is not None: current_speech_id = sid text_buffer = [] - + if sid is None: # 收到终止信号,合成累积的文本 if text_buffer and current_speech_id is not None: @@ -904,7 +906,7 @@ async def async_worker(): "Authorization": f"Bearer {audio_api_key}", "Content-Type": "application/json" } - + payload = { "model": "cogtts", "input": full_text[:1024], # CogTTS最大支持1024字符 @@ -915,7 +917,7 @@ async def async_worker(): "volume": 1.0, "stream": True, } - + # 使用异步HTTP客户端流式接收SSE响应 async with aiohttp.ClientSession() as session: async with session.post(tts_url, headers=headers, json=payload) as resp: @@ -927,45 +929,45 @@ async def async_worker(): async for chunk in resp.content.iter_any(): # 解码并添加到缓冲区 buffer += chunk.decode('utf-8') - + # 按行分割处理 while '\n' in buffer: line, buffer = buffer.split('\n', 1) 
line = line.strip() - + # 跳过空行 if not line: continue - + # 解析SSE格式: data: {...} if line.startswith('data: '): json_str = line[6:] # 去掉 "data: " 前缀 try: event_data = json.loads(json_str) - + # 提取音频数据: choices[0].delta.content choices = event_data.get('choices', []) if choices and 'delta' in choices[0]: delta = choices[0]['delta'] audio_b64 = delta.get('content', '') - + if audio_b64: # Base64解码得到PCM数据 audio_bytes = base64.b64decode(audio_b64) - + # 跳过过小的音频块(可能是初始化数据) # 至少需要 100 个采样点(约 4ms@24kHz)才处理 if len(audio_bytes) < 200: # 100 samples * 2 bytes logger.debug(f"跳过过小的音频块: {len(audio_bytes)} bytes") continue - + # CogTTS返回PCM格式(24000Hz, mono, 16bit) # 从返回的 return_sample_rate 获取采样率 sample_rate = delta.get('return_sample_rate', 24000) - + # 转换为 float32 进行高质量重采样 audio_array = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 - + # 对第一个音频块,裁剪掉开头的噪音部分(CogTTS有初始化噪音) if not first_audio_received: first_audio_received = True @@ -979,7 +981,7 @@ async def async_worker(): if fade_samples > 0: fade_curve = np.linspace(0.0, 1.0, fade_samples) audio_array[:fade_samples] *= fade_curve - + # 使用 soxr 进行高质量重采样 resampled = soxr.resample(audio_array, sample_rate, 48000, quality='HQ') # 转回 int16 格式 @@ -994,18 +996,18 @@ async def async_worker(): logger.error(f"CogTTS API错误 ({resp.status}): {error_text}") except Exception as e: logger.error(f"CogTTS合成失败: {e}") - + # 清空缓冲区 text_buffer = [] continue - + # 累积文本到缓冲区(不立即发送) if tts_text and tts_text.strip(): text_buffer.append(tts_text) - + except Exception as e: logger.error(f"CogTTS Worker错误: {e}") - + # 运行异步worker try: asyncio.run(async_worker()) @@ -1171,7 +1173,7 @@ def openai_tts_worker(request_queue, response_queue, audio_api_key, voice_id): 使用 OpenAI 的 TTS API(gpt-4o-mini-tts) 注意:OpenAI TTS 不支持流式输入,只支持流式输出 因此需要累积文本后一次性发送,但可以流式接收音频 - + Args: request_queue: 多进程请求队列,接收(speech_id, text)元组 response_queue: 多进程响应队列,发送音频数据(也用于发送就绪信号) @@ -1179,7 +1181,7 @@ def openai_tts_worker(request_queue, response_queue, audio_api_key, 
voice_id): voice_id: 音色ID,默认使用"marin"(支持:marin, alloy, ash, ballad, coral, echo, fable, onyx, nova, sage, shimmer) """ import asyncio - + try: from openai import AsyncOpenAI except ImportError: @@ -1193,37 +1195,37 @@ def openai_tts_worker(request_queue, response_queue, audio_api_key, voice_id): except Exception: break return - + # 使用默认音色 "marin" if not voice_id: voice_id = "marin" - + async def async_worker(): """异步TTS worker主循环""" current_speech_id = None text_buffer = [] # 累积文本缓冲区 - + # 初始化 OpenAI 客户端 client = AsyncOpenAI(api_key=audio_api_key) - + # OpenAI TTS 是基于 HTTP 的,无需建立持久连接,直接发送就绪信号 logger.info("OpenAI TTS 已就绪,发送就绪信号") response_queue.put(("__ready__", True)) - + try: loop = asyncio.get_running_loop() - + while True: try: sid, tts_text = await loop.run_in_executor(None, request_queue.get) except Exception: break - + # 新的语音ID,清空缓冲区并重新开始 if current_speech_id != sid and sid is not None: current_speech_id = sid text_buffer = [] - + if sid is None: # 收到终止信号,合成累积的文本 if text_buffer and current_speech_id is not None: @@ -1246,22 +1248,22 @@ async def async_worker(): # 重采样到 48000Hz resampled_bytes = _resample_audio(audio_array, 24000, 48000) response_queue.put(resampled_bytes) - + except Exception as e: logger.error(f"OpenAI TTS 合成失败: {e}") - + # 清空缓冲区 text_buffer = [] current_speech_id = None continue - + # 累积文本到缓冲区(不立即发送) if tts_text and tts_text.strip(): text_buffer.append(tts_text) - + except Exception as e: logger.error(f"OpenAI TTS Worker错误: {e}") - + # 运行异步worker try: asyncio.run(async_worker()) @@ -1271,14 +1273,14 @@ async def async_worker(): def gptsovits_tts_worker(request_queue, response_queue, audio_api_key, voice_id): """GPT-SoVITS TTS Worker - 使用 v3 WebSocket stream-input 双工模式 - + Args: request_queue: 多进程请求队列,接收 (speech_id, text) 元组 response_queue: 多进程响应队列,发送音频数据(也用于发送就绪信号) audio_api_key: API密钥(未使用,保持接口一致) voice_id: v3 声音配置ID,格式为 "voice_id" 或 "voice_id|高级参数JSON" 例如: "my_voice" 或 "my_voice|{\"speed\":1.2,\"text_lang\":\"all_zh\"}" - + 配置项(通过 
TTS_MODEL_URL 设置): base_url: GPT-SoVITS API 地址,如 "http://127.0.0.1:9881" 会自动转换为 ws:// 协议用于 WebSocket 连接 @@ -1527,7 +1529,7 @@ def dummy_tts_worker(request_queue, response_queue, audio_api_key, voice_id): """ 空的TTS worker(用于不支持TTS的core_api) 持续清空请求队列但不生成任何音频,使程序正常运行但无语音输出 - + Args: request_queue: 多进程请求队列,接收(speech_id, text)元组 response_queue: 多进程响应队列(也用于发送就绪信号) @@ -1535,10 +1537,10 @@ def dummy_tts_worker(request_queue, response_queue, audio_api_key, voice_id): voice_id: 音色ID(不使用) """ logger.warning("TTS Worker 未启用,不会生成语音") - + # 立即发送就绪信号 response_queue.put(("__ready__", True)) - + while True: try: # 持续清空队列以避免阻塞,但不做任何处理 @@ -1554,11 +1556,11 @@ def dummy_tts_worker(request_queue, response_queue, audio_api_key, voice_id): def get_tts_worker(core_api_type='qwen', has_custom_voice=False): """ 根据 core_api 类型和是否有自定义音色,返回对应的 TTS worker 函数 - + Args: core_api_type: core API 类型 ('qwen', 'step', 'glm' 等) has_custom_voice: 是否有自定义音色 (voice_id) - + Returns: 对应的 TTS worker 函数 """ @@ -1573,7 +1575,9 @@ def get_tts_worker(core_api_type='qwen', has_custom_voice=False): # local_cosyvoice:配置 ws:// URL,直接使用 WebSocket if base_url.startswith('http://') or base_url.startswith('https://'): return gptsovits_tts_worker - return local_cosyvoice_worker + if base_url.startswith('ws://') or base_url.startswith('wss://'): + return local_qwen3_tts_worker + # return local_cosyvoice_worker # 旧cosyvoice 的启用 正常情况下似乎不能使用 except Exception as e: logger.warning(f'TTS调度器检查报告:{e}') @@ -1603,18 +1607,18 @@ def local_cosyvoice_worker(request_queue, response_queue, audio_api_key, voice_i """ 本地 CosyVoice WebSocket Worker(OpenAI 兼容 bistream 版本) 适配 openai_server.py 定义的 /v1/audio/speech/stream 接口 - + 协议流程: 1. 连接后发送 config: {"voice": ..., "speed": ...} 2. 发送文本: {"text": ...} 3. 发送结束信号: {"event": "end"} 4. 
接收 bytes 音频数据(16-bit PCM, 22050Hz) - + 特性: - 双工流:发送和接收独立运行,互不阻塞 - 打断支持:speech_id 变化时关闭旧连接,打断旧语音 - 非阻塞:异步架构,不会卡住主循环 - + 注意:audio_api_key 参数未使用(本地模式不需要 API Key),保留是为了与其他 worker 保持统一签名 """ _ = audio_api_key # 本地模式不需要 API Key @@ -1638,10 +1642,10 @@ def local_cosyvoice_worker(request_queue, response_queue, audio_api_key, voice_i except Exception: break return - + # OpenAI 兼容端点 WS_URL = f'{ws_base}/v1/audio/speech/stream' - + # 从 voice_id 解析 voice 和 speed(格式:voice 或 voice:speed) voice_name = voice_id or "中文女" speech_speed = 1.0 @@ -1652,7 +1656,7 @@ def local_cosyvoice_worker(request_queue, response_queue, audio_api_key, voice_i speech_speed = float(parts[1]) except ValueError: pass - + # 服务器返回的采样率(22050Hz) SRC_RATE = 22050 @@ -1660,7 +1664,8 @@ async def async_worker(): ws = None receive_task = None current_speech_id = None - + response_done_event = asyncio.Event() + resampler = soxr.ResampleStream(SRC_RATE, 48000, 1, dtype='float32') async def receive_loop(ws_conn): @@ -1690,7 +1695,7 @@ async def send_end_signal(ws_conn): async def create_connection(): """创建新连接并发送配置""" nonlocal ws, receive_task, resampler - + # 清理旧连接 if receive_task and not receive_task.done(): receive_task.cancel() @@ -1703,14 +1708,14 @@ async def create_connection(): await ws.close() except Exception: pass - + # 重置 resampler resampler = soxr.ResampleStream(SRC_RATE, 48000, 1, dtype='float32') - + logger.info(f"🔄 [LocalTTS] 正在连接: {WS_URL}") ws = await websockets.connect(WS_URL, ping_interval=None) logger.info("✅ [LocalTTS] 连接成功") - + # 发送配置 config = { "voice": voice_name, @@ -1718,7 +1723,7 @@ async def create_connection(): } await ws.send(json.dumps(config)) logger.debug(f"发送配置: {config}") - + # 启动接收任务 receive_task = asyncio.create_task(receive_loop(ws)) return ws @@ -1747,7 +1752,7 @@ async def create_connection(): # 发送结束信号(文本已在实时流中发送过了) if ws: await send_end_signal(ws) - + current_speech_id = sid try: await create_connection() @@ -1765,7 +1770,7 @@ async def create_connection(): if not tts_text 
def local_qwen3_tts_worker(request_queue, response_queue, audio_api_key, voice_id):
    """TTS worker that streams text to a local Qwen3-TTS WebSocket server.

    Reads ``(speech_id, text)`` tuples from ``request_queue``, buffers the text,
    sends it to the server configured in ``tts_custom.base_url`` and forwards
    every binary frame it receives (interpreted as int16 PCM) through
    ``_resample_audio`` into ``response_queue``.

    Messages this client sends:
      - ``session.update`` once per connection (voice file path and language)
      - ``input_text_buffer.append`` followed by ``input_text_buffer.commit``
      - ``cancel`` when the speech id changes mid-utterance

    Messages this client understands:
      - text frame ``audio.meta``: its ``sample_rate`` becomes the source rate
      - text frame ``response.done``: releases the end-of-speech wait
      - binary frame: audio payload

    Args:
        request_queue: queue of ``(speech_id, text)``; ``speech_id is None``
            marks the end of the current utterance.
        response_queue: receives ``("__ready__", bool)`` once, then audio.
        audio_api_key: unused, kept so all workers share one signature.
        voice_id: currently unused; the voice prompt file is fixed below.
    """
    _ = audio_api_key  # the local server needs no API key

    cm = get_config_manager()
    tts_config = cm.get_model_api_config('tts_custom')

    ws_base = tts_config.get('base_url', '')
    if not ws_base:
        logger.error("local_qwen3_tts 未配置 tts_custom.base_url")
        response_queue.put(("__ready__", False))
        return

    # Normalize the scheme so an http(s) URL still yields a WebSocket URL.
    if ws_base.startswith("http://"):
        ws_base = "ws://" + ws_base[len("http://"):]
    elif ws_base.startswith("https://"):
        ws_base = "wss://" + ws_base[len("https://"):]

    WS_URL = ws_base.rstrip("/")

    # voice_id could later select a different .pt / reference sample; for now
    # a fixed cached voice prompt file is sent to the server.
    voice_pt_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "local_server", "qwen3_tts_server", "nyaning_voice.pt"))
    # Target sample rate: the frontend plays PCM at 48 kHz by default.
    DST_RATE = 48000

    # Client-side fast-send switch, meant to be paired with the server's
    # ENABLE_TRUE_STREAMING flag.
    TRUE_STREAM_MODE = True

    # Fast mode commits as soon as 8 characters are buffered; the fallback
    # mode waits for 60 characters.
    COMMIT_CHARS = 8 if TRUE_STREAM_MODE else 60

    # Sentence terminators that force a commit. The original tuple listed the
    # ASCII "!" and "?" twice; the duplicates are replaced by the full-width
    # forms so CJK sentences are committed at their punctuation too.
    COMMIT_PUNCS = ("。", "\uff01", "\uff1f", ".", "!", "?", "\n")

    async def async_worker():
        ws = None
        receive_task = None
        current_speech_id = None

        text_buf = ""
        src_rate = 24000  # default, overridden by an audio.meta message
        resampler = soxr.ResampleStream(src_rate, DST_RATE, 1, dtype="float32")
        response_done_event = asyncio.Event()

        async def receive_loop(ws_conn):
            """Handle control messages and forward audio until the socket ends."""
            nonlocal src_rate, resampler
            try:
                async for message in ws_conn:
                    if isinstance(message, str):
                        try:
                            evt = json.loads(message)
                        except Exception as e:
                            logger.debug(f"local_qwen3 parse message error:{e}")
                            continue
                        evt_type = evt.get("type")
                        if evt_type == "response.done":
                            response_done_event.set()
                        elif evt_type == "audio.meta":
                            # A malformed rate must not end the receive loop.
                            try:
                                sr = int(evt.get("sample_rate", src_rate))
                            except (TypeError, ValueError):
                                logger.warning(f"local_qwen3 invalid audio.meta: {evt}")
                                continue
                            if sr > 0 and sr != src_rate:
                                src_rate = sr
                                resampler = soxr.ResampleStream(src_rate, DST_RATE, 1, dtype="float32")
                        continue

                    if isinstance(message, bytes):
                        audio_i16 = np.frombuffer(message, dtype=np.int16)
                        resampled = _resample_audio(audio_i16, src_rate, DST_RATE, resampler)
                        response_queue.put(resampled)
            except Exception as e:
                logger.error(f"local_qwen3 receive_loop error: {e}")

        async def connect():
            """Open the WebSocket, start the receive task and send session.update."""
            nonlocal ws, receive_task
            try:
                logger.info(f"正在连接本地 TTS 服务: {WS_URL}")
                ws = await websockets.connect(WS_URL, ping_interval=None, max_size=None)

                receive_task = asyncio.create_task(receive_loop(ws))

                # Initial handshake/config. A failure here is not fatal: the
                # connection itself is already usable.
                try:
                    await ws.send(json.dumps({
                        "type": "session.update",
                        "voice_pt_path": voice_pt_path,
                        "language": "Chinese",
                    }, ensure_ascii=False))
                except Exception as e:
                    logger.debug(f"local_qwen3 session.update failed: {e}")

                return ws
            except Exception as e:
                logger.error(f"连接 TTS 服务失败: {e}")
                raise

        async def _cleanup_ws():
            """Stop the receive task and close the socket (reconnect / shutdown)."""
            nonlocal ws, receive_task
            if receive_task and not receive_task.done():
                receive_task.cancel()
                try:
                    await receive_task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    pass
            receive_task = None

            if ws:
                try:
                    await ws.close()
                except Exception:
                    pass
            ws = None

        async def reconnect():
            """Drop the current connection and open a fresh one."""
            await _cleanup_ws()
            await connect()

        async def ensure_connected():
            """Return True when a live connection exists, reconnecting if needed.

            The connection counts as dead when there is no socket (a previous
            reconnect failed) or the receive task has finished (the server
            closed the socket or the loop raised).
            """
            nonlocal resampler
            if ws is not None and receive_task is not None and not receive_task.done():
                return True
            logger.warning("local_qwen3 connection lost, reconnecting")
            try:
                await reconnect()
            except Exception as e:
                logger.error(f"local_qwen3 reconnect failed: {e}")
                return False
            # Fresh stream, fresh resampler state.
            resampler = soxr.ResampleStream(src_rate, DST_RATE, 1, dtype="float32")
            return True

        async def send_json(payload, what):
            """Send one JSON message; return True on success."""
            if ws is None:
                return False
            try:
                await ws.send(json.dumps(payload, ensure_ascii=False))
                return True
            except Exception as e:
                logger.warning(f"local_qwen3 {what} failed: {e}")
                return False

        async def flush_text():
            """Append and commit the buffered text; return True if both were sent."""
            nonlocal text_buf
            pending, text_buf = text_buf, ""
            if not pending:
                return False
            if not await ensure_connected():
                logger.error("local_qwen3 no connection, buffered text dropped")
                return False
            if not await send_json({"type": "input_text_buffer.append", "text": pending}, "send_append"):
                return False
            return await send_json({"type": "input_text_buffer.commit"}, "send_commit")

        # Initial connection; report the outcome to the main process.
        try:
            await connect()
            response_queue.put(("__ready__", True))
        except Exception as e:
            logger.error(f"local_qwen3 connect failed: {e}")
            response_queue.put(("__ready__", False))
            return

        loop = asyncio.get_running_loop()

        try:
            while True:
                try:
                    sid, tts_text = await loop.run_in_executor(None, request_queue.get)
                except Exception:
                    break

                # End-of-utterance signal: flush what is left and wait for the
                # server to finish, but only if the text was actually sent.
                if sid is None:
                    if text_buf.strip():
                        response_done_event.clear()
                        if await flush_text():
                            try:
                                await asyncio.wait_for(response_done_event.wait(), timeout=10.0)
                            except asyncio.TimeoutError:
                                logger.warning("local_qwen3 wait response.done timeout")
                    text_buf = ""
                    current_speech_id = None
                    continue

                # Speech id changed mid-utterance: cancel the old speech and
                # start from a clean connection and resampler.
                if current_speech_id is not None and sid != current_speech_id:
                    await send_json({"type": "cancel"}, "send_cancel")
                    try:
                        await reconnect()
                    except Exception as e:
                        logger.error(f"local_qwen3 reconnect failed: {e}")
                    text_buf = ""
                    resampler = soxr.ResampleStream(src_rate, DST_RATE, 1, dtype="float32")

                current_speech_id = sid

                if not tts_text or not tts_text.strip():
                    continue

                text_buf += tts_text

                # Commit when enough text is buffered or the new piece
                # contains a sentence terminator.
                if len(text_buf) >= COMMIT_CHARS or any(p in tts_text for p in COMMIT_PUNCS):
                    await flush_text()
        finally:
            await _cleanup_ws()

    try:
        asyncio.run(async_worker())
    except Exception as e:
        logger.error(f"local_qwen3_tts_worker crashed: {e}")