|
| 1 | +"""Minimal Kokoro v1.1-zh CLI wrapper for local_tts_server. |
| 2 | +
|
| 3 | +Usage: |
| 4 | + python kokoro_cli.py <text_file> <out_file> <voice> <speed> |
| 5 | +
|
| 6 | +Reads text from <text_file>, synthesizes with kokoro, writes WAV to <out_file>. |
| 7 | +""" |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import argparse |
| 12 | +import os |
| 13 | +import sys |
| 14 | +import wave |
| 15 | +from pathlib import Path |
| 16 | + |
| 17 | +import numpy as np |
| 18 | + |
| 19 | + |
# Repository id handed to KModel / KPipeline when no override is configured.
DEFAULT_REPO_ID = "hexgrad/Kokoro-82M-v1.1-zh"
# Voice used when the caller supplies none and no override is configured.
DEFAULT_VOICE = "zf_001"
# Frame rate (Hz) written into the output WAV header.
SAMPLE_RATE = 24_000
# Directory that contains this script; anchors the bundled model location.
SCRIPT_DIR = Path(__file__).resolve().parents[0]
# Local model checkout that is used when it exists on disk.
DEFAULT_LOCAL_REPO = SCRIPT_DIR.joinpath("kokoro_models", "Kokoro-82M-v1.1-zh")
| 25 | + |
| 26 | + |
def _audio_from_result(result):
    """Pull the audio payload out of a single pipeline result.

    An object exposing an ``audio`` attribute yields that attribute's value.
    Otherwise a non-empty tuple yields its last element. Anything else
    yields ``None``.
    """
    missing = object()
    payload = getattr(result, "audio", missing)
    if payload is not missing:
        return payload
    is_nonempty_tuple = isinstance(result, tuple) and len(result) > 0
    return result[-1] if is_nonempty_tuple else None
| 33 | + |
| 34 | + |
def _infer_lang_code(voice: str) -> str:
    """Map a Kokoro voice name to its single-letter language code.

    Kokoro uses single-letter language codes that match the first letter of
    the voice name: ``z`` = zh, ``a`` = en-us, ``b`` = en-gb.

    The voice may also be given as a path to a ``.pt`` voice file. Only the
    final path component is inspected, so a directory such as ``assets/`` or
    ``build/`` is not mistaken for an English voice prefix.

    Args:
        voice: Voice name (for example ``"zf_001"``) or path to a voice file.

    Returns:
        ``"a"`` or ``"b"`` when the voice name starts with that letter,
        otherwise ``"z"`` (this includes the empty string).
    """
    # Treat both POSIX and Windows separators as path separators, whatever
    # the host OS is, and keep only the file-name part.
    name = voice.replace("\\", "/").rsplit("/", 1)[-1]
    first = name[:1]
    return first if first in ("a", "b") else "z"
| 42 | + |
| 43 | + |
def _speed_callable(base_speed: float):
    """Build a length-dependent speed function scaled by *base_speed*.

    Intended to mitigate rushed long Chinese phoneme sequences in v1.1-zh.
    A *base_speed* that is not strictly positive is replaced by ``1.0``.

    The returned function takes a phoneme-sequence length and returns:

    * ``1.0`` scaled by the base for lengths up to 83,
    * a value falling linearly from 1.0 towards 0.8 for lengths in (83, 183),
    * ``0.8`` scaled by the base for lengths of 183 or more,

    with the final result never below ``0.5``.
    """
    multiplier = base_speed if base_speed > 0 else 1.0

    def speed_by_len(len_ps: int) -> float:
        if len_ps >= 183:
            factor = 0.8
        elif len_ps > 83:
            factor = 1.0 - (len_ps - 83) / 500.0
        else:
            factor = 1.0
        # Clamp so a small base speed cannot slow speech below half rate.
        return max(0.5, factor * multiplier)

    return speed_by_len
| 58 | + |
| 59 | + |
def _resolve_local_model_dir() -> Path | None:
    """Locate the local Kokoro model directory, if one is usable.

    A non-blank ``LOCAL_TTS_KOKORO_MODEL_DIR`` environment variable takes
    precedence over ``DEFAULT_LOCAL_REPO``. Whichever candidate is selected
    is returned only when it is an existing directory; otherwise the result
    is ``None`` (a configured but missing directory does not fall back to
    the default location).
    """
    configured = os.getenv("LOCAL_TTS_KOKORO_MODEL_DIR", "").strip()
    candidate = Path(configured) if configured else DEFAULT_LOCAL_REPO
    if candidate.is_dir():
        return candidate
    return None
| 66 | + |
| 67 | + |
def _find_model_file(model_dir: Path) -> Path | None:
    """Pick the ``.pth`` weights file inside *model_dir*.

    ``kokoro-v1_1-zh.pth`` wins when it exists. Otherwise the result is the
    first ``*.pth`` entry in sorted path order, or ``None`` when the
    directory holds no ``.pth`` file.
    """
    canonical = model_dir / "kokoro-v1_1-zh.pth"
    if not canonical.is_file():
        # The smallest path is the same element sorted() would put first.
        return min(model_dir.glob("*.pth"), default=None)
    return canonical
| 74 | + |
| 75 | + |
def _resolve_voice(voice: str, model_dir: Path | None) -> str:
    """Translate a voice name into the value handed to the pipeline.

    When a model directory is given and *voice* does not already end in
    ``.pt``, the file ``<model_dir>/voices/<voice>.pt`` is used if it
    exists. In every other case *voice* is returned unchanged.
    """
    if model_dir and not voice.endswith(".pt"):
        candidate = model_dir / "voices" / f"{voice}.pt"
        if candidate.is_file():
            return str(candidate)
    return voice
| 83 | + |
| 84 | + |
def _available_local_voices(model_dir: Path | None) -> set[str]:
    """List the voice names shipped in ``<model_dir>/voices``.

    Returns the stems of all regular ``*.pt`` files found there. The result
    is an empty set when no model directory is given or it has no
    ``voices`` sub-directory.
    """
    found: set[str] = set()
    if model_dir:
        voices_root = model_dir / "voices"
        if voices_root.is_dir():
            for entry in voices_root.glob("*.pt"):
                if entry.is_file():
                    found.add(entry.stem)
    return found
| 92 | + |
| 93 | + |
def synthesize(text_path: str, out_path: str, voice: str, speed: float) -> int:
    """Synthesize the text in *text_path* into a mono 16-bit WAV at *out_path*.

    Configuration is read from the environment:

    * ``LOCAL_TTS_KOKORO_REPO_ID`` - repository id (default ``DEFAULT_REPO_ID``).
    * ``LOCAL_TTS_KOKORO_DEFAULT_VOICE`` - voice used when *voice* is blank
      or unavailable locally (default ``DEFAULT_VOICE``).
    * ``LOCAL_TTS_KOKORO_DEVICE`` - torch device; when unset, ``cuda`` is
      used if available, else ``cpu``.

    The local model directory comes from ``_resolve_local_model_dir``.

    Args:
        text_path: UTF-8 text file to read.
        out_path: Destination WAV file.
        voice: Voice name, or path to an existing ``.pt`` voice file. Blank
            selects the configured default voice.
        speed: Base speed. For Chinese voices it scales a length-dependent
            speed function; for other voices it is passed through as-is.

    Returns:
        ``0`` on success. ``1`` on any handled failure: missing dependencies,
        unreadable or empty text file, invalid local model directory, no
        audio produced, or unwritable output file.
    """
    try:
        import torch
        from kokoro import KModel, KPipeline
    except ImportError:
        print(
            'kokoro v1.1-zh deps missing. Run: uv pip install "kokoro>=0.8.2" "misaki[zh]>=0.8.2"',
            file=sys.stderr,
        )
        return 1

    # Report an unreadable input through the same stderr + status-1 channel
    # as every other failure instead of surfacing a traceback.
    try:
        text = Path(text_path).read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError) as exc:
        print(f"Cannot read text file {text_path}: {exc}", file=sys.stderr)
        return 1
    if not text:
        print("Empty text file", file=sys.stderr)
        return 1

    model_dir = _resolve_local_model_dir()
    repo_id = os.getenv("LOCAL_TTS_KOKORO_REPO_ID", DEFAULT_REPO_ID).strip() or DEFAULT_REPO_ID
    # Resolve the configured default once so that a blank or whitespace-only
    # environment value can never produce an empty voice name.
    default_voice = os.getenv("LOCAL_TTS_KOKORO_DEFAULT_VOICE", DEFAULT_VOICE).strip() or DEFAULT_VOICE
    voice = (voice or "").strip() or default_voice
    available_voices = _available_local_voices(model_dir)
    # An existing .pt file named by the caller is an explicit choice and must
    # not be replaced by the local-voice fallback below.
    explicit_voice_file = voice.endswith(".pt") and Path(voice).is_file()
    if available_voices and voice not in available_voices and not explicit_voice_file:
        fallback_voice = default_voice
        if fallback_voice not in available_voices:
            fallback_voice = sorted(available_voices)[0]
        print(
            f"Kokoro voice '{voice}' not found in local model dir; falling back to '{fallback_voice}'.",
            file=sys.stderr,
        )
        voice = fallback_voice
    pipeline_voice = _resolve_voice(voice, model_dir)
    # Infer the language from the file name only, so the directory part of a
    # voice path cannot influence the result.
    lang = _infer_lang_code(Path(voice).name)
    device = os.getenv("LOCAL_TTS_KOKORO_DEVICE", "").strip()
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_dir:
        config_path = model_dir / "config.json"
        model_path = _find_model_file(model_dir)
        if not config_path.is_file() or model_path is None:
            print(
                f"Invalid LOCAL_TTS_KOKORO_MODEL_DIR: {model_dir} "
                "(expected config.json and a .pth model file)",
                file=sys.stderr,
            )
            return 1
        model = KModel(repo_id=repo_id, config=str(config_path), model=str(model_path)).to(device).eval()
    else:
        model = KModel(repo_id=repo_id).to(device).eval()

    en_callable = None
    if lang == "z":
        # Model-less English pipeline, used only to obtain phonemes for
        # English fragments embedded in Chinese text.
        en_pipeline = KPipeline(lang_code="a", repo_id=repo_id, model=False)

        def en_callable(text_part: str):
            # Hand-written pronunciations for two specific words.
            if text_part == "Kokoro":
                return "kˈOkəɹO"
            if text_part == "Sol":
                return "sˈOl"
            return next(en_pipeline(text_part)).phonemes

    pipeline = KPipeline(
        lang_code=lang,
        repo_id=repo_id,
        model=model,
        en_callable=en_callable,
    )
    effective_speed = _speed_callable(speed) if lang == "z" else speed
    generator = pipeline(text, voice=pipeline_voice, speed=effective_speed)

    chunks: list[np.ndarray] = []
    for result in generator:
        audio = _audio_from_result(result)
        if audio is not None:
            # Flatten so 0-d or multi-dimensional chunks still concatenate.
            chunks.append(np.asarray(audio, dtype=np.float32).reshape(-1))

    if not chunks:
        print("No audio generated", file=sys.stderr)
        return 1

    # Clip to [-1, 1] before scaling so the int16 conversion cannot wrap.
    pcm = np.clip(np.concatenate(chunks), -1.0, 1.0)
    pcm_int16 = (pcm * 32767.0).astype(np.int16)

    try:
        with wave.open(out_path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(SAMPLE_RATE)
            wf.writeframes(pcm_int16.tobytes())
    except OSError as exc:
        print(f"Cannot write WAV file {out_path}: {exc}", file=sys.stderr)
        return 1

    print(
        f"Wrote {out_path}: {len(pcm_int16)} samples @ {SAMPLE_RATE} Hz "
        f"repo={repo_id} model_dir={model_dir or '<hf-cache>'} voice={voice} device={device}"
    )
    return 0
| 189 | + |
| 190 | + |
def main() -> int:
    """Parse the command line and run ``synthesize``.

    Expects four positional arguments: ``text_file``, ``out_file``,
    ``voice`` and ``speed`` (parsed as a float). Returns the status code
    produced by ``synthesize``.
    """
    cli = argparse.ArgumentParser(description="Kokoro CLI wrapper for local_tts")
    for positional in ("text_file", "out_file", "voice"):
        cli.add_argument(positional)
    cli.add_argument("speed", type=float)
    opts = cli.parse_args()
    return synthesize(opts.text_file, opts.out_file, opts.voice, opts.speed)
| 199 | + |
| 200 | + |
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
0 commit comments