|
| 1 | +import argparse |
| 2 | +import gc |
1 | 3 | import socket |
2 | 4 | import struct |
3 | 5 | import torch |
4 | 6 | import torchaudio |
5 | | -from threading import Thread |
6 | | - |
7 | | - |
8 | | -import gc |
9 | 7 | import traceback |
| 8 | +from importlib.resources import files |
| 9 | +from threading import Thread |
10 | 10 |
|
| 11 | +from cached_path import cached_path |
11 | 12 |
|
12 | 13 | from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model |
13 | 14 | from model.backbones.dit import DiT |
14 | 15 |
|
15 | 16 |
|
16 | 17 | class TTSStreamingProcessor: |
17 | 18 | def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): |
18 | | - self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") |
| 19 | + self.device = device or ( |
| 20 | + "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" |
| 21 | + ) |
19 | 22 |
|
20 | 23 | # Load the model using the provided checkpoint and vocab files |
21 | 24 | self.model = load_model( |
@@ -137,23 +140,51 @@ def start_server(host, port, processor): |
137 | 140 |
|
138 | 141 |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--host", default="0.0.0.0")
    # type=int is required: without it a CLI-supplied port arrives as a str,
    # which socket.bind() rejects. The default 9998 was already an int.
    parser.add_argument("--port", default=9998, type=int)

    parser.add_argument(
        "--ckpt_file",
        default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
        help="Path to the model checkpoint file",
    )
    parser.add_argument(
        "--vocab_file",
        default="",
        help="Path to the vocab file if customized",
    )

    parser.add_argument(
        "--ref_audio",
        default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        help="Reference audio to provide model with speaker characteristics",
    )
    parser.add_argument(
        "--ref_text",
        default="",
        help="Reference audio subtitle, leave empty to auto-transcribe",
    )

    parser.add_argument("--device", default=None, help="Device to run the model on")
    # Accept the dtype as a plain name and resolve it against torch below.
    # Previously default=torch.float32 meant a CLI-supplied value arrived as a
    # raw string (e.g. "float16") and was forwarded to the model unconverted.
    parser.add_argument(
        "--dtype",
        default="float32",
        help="Data type to use for model inference (e.g. float32, float16, bfloat16)",
    )

    args = parser.parse_args()

    # Resolve the dtype name to a real torch.dtype; reject unknown names early
    # with a proper argparse error instead of a confusing downstream failure.
    dtype = getattr(torch, args.dtype, None)
    if not isinstance(dtype, torch.dtype):
        parser.error(f"unknown dtype: {args.dtype!r}")

    try:
        # Initialize the processor with the model and vocoder
        processor = TTSStreamingProcessor(
            ckpt_file=args.ckpt_file,
            vocab_file=args.vocab_file,
            ref_audio=args.ref_audio,
            ref_text=args.ref_text,
            device=args.device,
            dtype=dtype,
        )

        # Start the server
        start_server(args.host, args.port, processor)

    except KeyboardInterrupt:
        # Best-effort cleanup on Ctrl-C before the interpreter exits.
        gc.collect()
0 commit comments