Skip to content

Commit 7af80a4

Browse files
committed
feat(tts): migrate speech providers to backend direct routing
1 parent 3808c0a commit 7af80a4

20 files changed

+5267
-568
lines changed

backend/app/api/tts.py

Lines changed: 747 additions & 1 deletion
Large diffs are not rendered by default.

backend/app/services/providers/registry.py

Lines changed: 206 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
from typing import List, Optional
1+
import os
2+
import json
3+
from functools import lru_cache
4+
from pathlib import Path
5+
from typing import Any, Dict, List, Optional, Set
26

37
import httpx
4-
import os
58

9+
from app.services.engines import runtime_store
610
from app.services.providers.types import ProviderConfig, ProviderValidation
711

812

@@ -43,6 +47,17 @@
4347
"aliyun-nls-asr",
4448
}
4549

50+
LOCAL_TTS_PROVIDER_IDS = {
51+
"volcengine-speech",
52+
"alibaba-cloud-model-studio-speech",
53+
}
54+
LOCAL_TTS_VOICES_DIR = Path(__file__).resolve().parent / "voices"
55+
LOCAL_TTS_VOICE_FILES = {
56+
"volcengine-speech": LOCAL_TTS_VOICES_DIR / "volcengine.json",
57+
"alibaba-cloud-model-studio-speech": LOCAL_TTS_VOICES_DIR / "alibaba.json",
58+
}
59+
60+
4661
class ProviderRegistry:
4762
async def validate(self, config: ProviderConfig) -> ProviderValidation:
4863
provider_id = config.provider_id
@@ -56,8 +71,10 @@ async def validate(self, config: ProviderConfig) -> ProviderValidation:
5671
if not str(api_key or "").strip():
5772
return ProviderValidation(valid=False, reason="Missing apiKey for Alibaba Bailian ASR")
5873
return ProviderValidation(valid=True)
74+
5975
if provider_id in {"dify", "fastgpt"}:
6076
return self._validate_basic(config, require_base_url=True, require_api_key=True)
77+
6178
if provider_id == "coze":
6279
result = self._validate_basic(config, require_base_url=True, require_api_key=True)
6380
if not result.valid:
@@ -85,27 +102,67 @@ async def list_models(self, config: ProviderConfig) -> List[dict]:
85102
{"id": "qwen3-asr-flash-realtime", "label": "qwen3-asr-flash-realtime"},
86103
{"id": "qwen3-asr-flash", "label": "qwen3-asr-flash"},
87104
]
105+
88106
if provider_id in OPENAI_COMPAT_IDS or "openai" in provider_id:
89107
return await self._fetch_openai_models(config)
108+
90109
return []
91110

92111
async def list_voices(self, config: ProviderConfig) -> List[dict]:
93-
return []
112+
if config.provider_id not in LOCAL_TTS_PROVIDER_IDS:
113+
return []
114+
115+
voices = await _load_local_tts_voices(config.provider_id)
116+
tts_runtime = runtime_store.get("tts", config.provider_id)
117+
118+
if config.provider_id == "alibaba-cloud-model-studio-speech":
119+
runtime_model = str(tts_runtime.model or "").strip() if tts_runtime else ""
120+
filter_model = config.model or runtime_model
121+
model_candidates = _resolve_model_candidates(filter_model)
122+
if model_candidates:
123+
voices = [
124+
voice
125+
for voice in voices
126+
if _voice_matches_model_candidates(voice, model_candidates)
127+
]
128+
129+
options: List[dict] = []
130+
for voice in voices:
131+
if not isinstance(voice, dict):
132+
continue
133+
134+
voice_id = voice.get("id")
135+
voice_name = voice.get("name")
136+
if not isinstance(voice_id, str) or not voice_id.strip():
137+
continue
138+
if not isinstance(voice_name, str) or not voice_name.strip():
139+
continue
140+
141+
options.append(
142+
{
143+
"id": voice_id.strip(),
144+
"label": voice_name.strip(),
145+
"description": _build_voice_description(voice),
146+
}
147+
)
148+
149+
return options
94150

95151
@staticmethod
96152
def _validate_basic(
97153
config: ProviderConfig, require_base_url: bool, require_api_key: bool
98154
) -> ProviderValidation:
99-
if require_api_key and not config.api_key:
155+
if require_api_key and not str(config.api_key or "").strip():
100156
return ProviderValidation(valid=False, reason="Missing API key")
101-
if require_base_url and not config.base_url:
157+
if require_base_url and not str(config.base_url or "").strip():
102158
return ProviderValidation(valid=False, reason="Missing base URL")
103159
return ProviderValidation(valid=True)
104160

105161
@staticmethod
106162
async def _fetch_openai_models(config: ProviderConfig) -> List[dict]:
107163
if not config.base_url:
108164
raise ValueError("Base URL is required")
165+
109166
headers = {}
110167
if config.api_key:
111168
headers["Authorization"] = f"Bearer {config.api_key}"
@@ -129,4 +186,148 @@ async def _fetch_openai_models(config: ProviderConfig) -> List[dict]:
129186
]
130187

131188

189+
async def _load_local_tts_voices(provider_id: str) -> List[dict]:
190+
path = LOCAL_TTS_VOICE_FILES.get(provider_id)
191+
if not path:
192+
return []
193+
return _load_local_tts_voices_cached(provider_id, str(path))
194+
195+
196+
@lru_cache(maxsize=8)
197+
def _load_local_tts_voices_cached(provider_id: str, path: str) -> List[dict]:
198+
source = Path(path)
199+
if not source.exists():
200+
return []
201+
202+
try:
203+
raw = json.loads(source.read_text(encoding="utf-8"))
204+
except Exception:
205+
return []
206+
207+
if provider_id == "alibaba-cloud-model-studio-speech":
208+
return _parse_alibaba_voices(raw)
209+
if provider_id == "volcengine-speech":
210+
return _parse_volcengine_voices(raw)
211+
return []
212+
213+
214+
def _parse_alibaba_voices(raw: Any) -> List[dict]:
215+
if not isinstance(raw, list):
216+
return []
217+
218+
voices: List[dict] = []
219+
for item in raw:
220+
if not isinstance(item, dict):
221+
continue
222+
223+
voice_id = str(item.get("voice") or "").strip()
224+
name = str(item.get("name") or "").strip()
225+
model = str(item.get("model") or "").strip()
226+
language = str(item.get("language") or "").strip()
227+
228+
if not voice_id:
229+
continue
230+
231+
voice: Dict[str, Any] = {
232+
"id": voice_id,
233+
"name": name or voice_id,
234+
"compatible_models": [model] if model else [],
235+
}
236+
if language:
237+
voice["languages"] = [{"title": language, "code": language}]
238+
voices.append(voice)
239+
240+
return voices
241+
242+
243+
def _parse_volcengine_voices(raw: Any) -> List[dict]:
244+
if not isinstance(raw, dict):
245+
return []
246+
247+
data = raw.get("data")
248+
if not isinstance(data, dict):
249+
return []
250+
251+
resource_packs = data.get("resource_packs")
252+
if not isinstance(resource_packs, list):
253+
return []
254+
255+
voices: List[dict] = []
256+
for item in resource_packs:
257+
if not isinstance(item, dict):
258+
continue
259+
260+
details = item.get("details")
261+
details = details if isinstance(details, dict) else {}
262+
voice_id = str(item.get("code") or "").strip()
263+
name = str(item.get("resource_display") or "").strip()
264+
language = str(details.get("language") or "").strip()
265+
266+
if not voice_id:
267+
continue
268+
269+
voice: Dict[str, Any] = {
270+
"id": voice_id,
271+
"name": name or voice_id,
272+
"compatible_models": ["v1"],
273+
}
274+
if language:
275+
voice["languages"] = [{"title": language, "code": language}]
276+
voices.append(voice)
277+
278+
return voices
279+
280+
281+
def _resolve_model_candidates(model: Optional[str]) -> Set[str]:
282+
if not model:
283+
return set()
284+
285+
candidate = model.strip()
286+
if not candidate:
287+
return set()
288+
289+
result: Set[str] = {candidate}
290+
if "/" in candidate:
291+
short_model = candidate.split("/")[-1].strip()
292+
if short_model:
293+
result.add(short_model)
294+
else:
295+
result.add(f"alibaba/{candidate}")
296+
return result
297+
298+
299+
def _voice_matches_model_candidates(voice: dict, model_candidates: Set[str]) -> bool:
300+
compatible_models = voice.get("compatible_models")
301+
if not isinstance(compatible_models, list) or len(compatible_models) == 0:
302+
return True
303+
304+
normalized = {
305+
str(model).strip()
306+
for model in compatible_models
307+
if isinstance(model, str) and str(model).strip()
308+
}
309+
return len(normalized.intersection(model_candidates)) > 0
310+
311+
312+
def _build_voice_description(voice: dict) -> str:
313+
descriptions: List[str] = []
314+
315+
languages = voice.get("languages")
316+
if isinstance(languages, list) and len(languages) > 0:
317+
titles: List[str] = []
318+
for language in languages:
319+
if not isinstance(language, dict):
320+
continue
321+
title = language.get("title")
322+
code = language.get("code")
323+
if isinstance(title, str) and title.strip():
324+
titles.append(title.strip())
325+
elif isinstance(code, str) and code.strip():
326+
titles.append(code.strip())
327+
if titles:
328+
descriptions.append(", ".join(titles))
329+
330+
return " | ".join(descriptions)
331+
332+
132333
registry = ProviderRegistry()

0 commit comments

Comments
 (0)