Skip to content

Commit deaca8d

Browse files
committed
v0.3.2 add flags and default values to socket_server.py
1 parent 0e1f2fc commit deaca8d

File tree

2 files changed

+49
-18
lines changed

2 files changed

+49
-18
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "f5-tts"
7-
version = "0.3.1"
7+
version = "0.3.2"
88
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
99
readme = "README.md"
1010
license = {text = "MIT License"}

src/f5_tts/socket_server.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
1+
import argparse
2+
import gc
13
import socket
24
import struct
35
import torch
46
import torchaudio
5-
from threading import Thread
6-
7-
8-
import gc
97
import traceback
8+
from importlib.resources import files
9+
from threading import Thread
1010

11+
from cached_path import cached_path
1112

1213
from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
1314
from model.backbones.dit import DiT
1415

1516

1617
class TTSStreamingProcessor:
1718
def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
18-
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
19+
self.device = device or (
20+
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
21+
)
1922

2023
# Load the model using the provided checkpoint and vocab files
2124
self.model = load_model(
@@ -137,23 +140,51 @@ def start_server(host, port, processor):
137140

138141

139142
if __name__ == "__main__":
140-
try:
141-
# Load the model and vocoder using the provided files
142-
ckpt_file = "" # pointing your checkpoint "ckpts/model/model_1096.pt"
143-
vocab_file = "" # Add vocab file path if needed
144-
ref_audio = "" # add ref audio"./tests/ref_audio/reference.wav"
145-
ref_text = ""
143+
parser = argparse.ArgumentParser()
144+
145+
parser.add_argument("--host", default="0.0.0.0")
146+
parser.add_argument("--port", default=9998)
147+
148+
parser.add_argument(
149+
"--ckpt_file",
150+
default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
151+
help="Path to the model checkpoint file",
152+
)
153+
parser.add_argument(
154+
"--vocab_file",
155+
default="",
156+
help="Path to the vocab file if customized",
157+
)
158+
159+
parser.add_argument(
160+
"--ref_audio",
161+
default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
162+
help="Reference audio to provide model with speaker characteristics",
163+
)
164+
parser.add_argument(
165+
"--ref_text",
166+
default="",
167+
help="Reference audio subtitle, leave empty to auto-transcribe",
168+
)
169+
170+
parser.add_argument("--device", default=None, help="Device to run the model on")
171+
parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference")
172+
173+
args = parser.parse_args()
146174

175+
try:
147176
# Initialize the processor with the model and vocoder
148177
processor = TTSStreamingProcessor(
149-
ckpt_file=ckpt_file,
150-
vocab_file=vocab_file,
151-
ref_audio=ref_audio,
152-
ref_text=ref_text,
153-
dtype=torch.float32,
178+
ckpt_file=args.ckpt_file,
179+
vocab_file=args.vocab_file,
180+
ref_audio=args.ref_audio,
181+
ref_text=args.ref_text,
182+
device=args.device,
183+
dtype=args.dtype,
154184
)
155185

156186
# Start the server
157-
start_server("0.0.0.0", 9998, processor)
187+
start_server(args.host, args.port, processor)
188+
158189
except KeyboardInterrupt:
159190
gc.collect()

0 commit comments

Comments (0)