
Commit 3bc6f0b

marksverdhei and claude committed
style: fix ruff formatting in speaker_embedding_interpolation example
Signed-off-by: marksverdhei <marksverdhei@hotmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: marksverdhei <marksverdhei@hotmail.com>
1 parent c0d3fcd commit 3bc6f0b

File tree

1 file changed: +19 -22 lines changed


examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py

Lines changed: 19 additions & 22 deletions
@@ -73,7 +73,6 @@ def load_speaker_encoder(model_path: str, device: str = "cpu") -> torch.nn.Modul
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
     # Dynamically import from the downloaded model files
     import importlib
-    import tempfile

     from huggingface_hub import snapshot_download

@@ -119,31 +118,24 @@ def _load_speaker_encoder_weights(encoder: torch.nn.Module, model_path: str) ->
     state_dict = {}

     # Try safetensors first, then pytorch bin
-    safetensor_files = sorted(
-        f for f in os.listdir(model_dir) if f.endswith(".safetensors")
-    )
+    safetensor_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".safetensors"))
     if safetensor_files:
         for fname in safetensor_files:
             shard = load_file(os.path.join(model_dir, fname))
             for k, v in shard.items():
                 if k.startswith(prefix):
                     state_dict[k[len(prefix) :]] = v
     else:
-        bin_files = sorted(
-            f for f in os.listdir(model_dir) if f.endswith(".bin")
-        )
+        bin_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".bin"))
         for fname in bin_files:
-            shard = torch.load(
-                os.path.join(model_dir, fname), map_location="cpu", weights_only=True
-            )
+            shard = torch.load(os.path.join(model_dir, fname), map_location="cpu", weights_only=True)
             for k, v in shard.items():
                 if k.startswith(prefix):
                     state_dict[k[len(prefix) :]] = v

     if not state_dict:
         raise RuntimeError(
-            f"No speaker_encoder weights found in {model_path}. "
-            "Make sure this is a Qwen3-TTS-*-Base checkpoint."
+            f"No speaker_encoder weights found in {model_path}. Make sure this is a Qwen3-TTS-*-Base checkpoint."
         )

     encoder.load_state_dict(state_dict)
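
For readers skimming the hunk above: it implements a common pattern for extracting a submodule's weights from sharded checkpoints, stripping a key prefix as it goes. A minimal standalone sketch of the same pattern (the helper name and the exact prefix value are illustrative assumptions, not the example's code):

import os

from safetensors.torch import load_file


def collect_prefixed_weights(model_dir: str, prefix: str = "speaker_encoder.") -> dict:
    """Gather tensors whose checkpoint keys start with `prefix`, stripping it."""
    state_dict = {}
    shards = sorted(f for f in os.listdir(model_dir) if f.endswith(".safetensors"))
    for fname in shards:
        for k, v in load_file(os.path.join(model_dir, fname)).items():
            if k.startswith(prefix):
                state_dict[k[len(prefix):]] = v
    return state_dict

The .bin fallback in the hunk runs the same loop with torch.load(..., map_location="cpu", weights_only=True) as the shard loader.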
@@ -161,9 +153,7 @@ def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor:

     from librosa.filters import mel as librosa_mel_fn

-    mel_basis = torch.from_numpy(
-        librosa_mel_fn(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)
-    ).float()
+    mel_basis = torch.from_numpy(librosa_mel_fn(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)).float()

     n_fft = 1024
     hop_size = 256
@@ -173,8 +163,13 @@ def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor:

     hann_window = torch.hann_window(win_size)
     spec = torch.stft(
-        y, n_fft, hop_length=hop_size, win_length=win_size,
-        window=hann_window, center=False, return_complex=True,
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window,
+        center=False,
+        return_complex=True,
     )
     spec = torch.abs(spec)
     mel = torch.matmul(mel_basis, spec)
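
Taken together with the previous hunk, the parameters here describe a 128-bin mel spectrogram at 24 kHz. A self-contained sketch of the full computation follows; win_size = 1024 and the manual reflect padding (needed because center=False) are assumptions, since the diff truncates before showing how y is prepared:

import numpy as np
import torch
from librosa.filters import mel as librosa_mel_fn


def mel_spectrogram(audio: np.ndarray) -> torch.Tensor:
    # Filterbank matching the example: 24 kHz, 1024-point FFT, 128 mel bins.
    mel_basis = torch.from_numpy(
        librosa_mel_fn(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)
    ).float()
    n_fft, hop_size, win_size = 1024, 256, 1024
    y = torch.from_numpy(audio).float().unsqueeze(0)
    # center=False, so reflect-pad by hand (an assumed step; not shown in the diff).
    pad = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(y.unsqueeze(0), (pad, pad), mode="reflect").squeeze(0)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=torch.hann_window(win_size),
        center=False,
        return_complex=True,
    )
    # Project magnitudes onto the mel basis: (128, 513) @ (1, 513, T) -> (1, 128, T).
    return torch.matmul(mel_basis, torch.abs(spec))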
@@ -183,9 +178,7 @@ def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor:


 @torch.inference_mode()
-def extract_embedding(
-    encoder: torch.nn.Module, audio_path: str, device: str = "cpu"
-) -> np.ndarray:
+def extract_embedding(encoder: torch.nn.Module, audio_path: str, device: str = "cpu") -> np.ndarray:
     """Extract a 1024-dim speaker embedding from an audio file."""
     import librosa

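With the signature on one line, the call pattern is easy to read. Downstream usage presumably looks something like this (file paths are placeholders):

encoder = load_speaker_encoder("Qwen/Qwen3-TTS-12Hz-1.7B-Base", device="cpu")
emb_a = extract_embedding(encoder, "voice_a.wav")  # 1024-dim np.ndarray
emb_b = extract_embedding(encoder, "voice_b.wav")
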
@@ -342,7 +335,8 @@ def main():
     parser.add_argument("--api-base", default=DEFAULT_API_BASE, help="TTS API base URL")
     parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key")
     parser.add_argument(
-        "--model", default="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
+        "--model",
+        default="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
         help="Model name (used for both weight loading and API requests)",
     )
     parser.add_argument("--device", default="cpu", help="Device for embedding extraction (cpu/cuda)")
@@ -367,7 +361,10 @@ def main():
     p_pipe.add_argument("--audio-a", required=True, help="Audio file for voice A")
     p_pipe.add_argument("--audio-b", required=True, help="Audio file for voice B")
     p_pipe.add_argument(
-        "--ratios", nargs="+", type=float, default=[0.0, 0.25, 0.5, 0.75, 1.0],
+        "--ratios",
+        nargs="+",
+        type=float,
+        default=[0.0, 0.25, 0.5, 0.75, 1.0],
         help="SLERP ratios to generate (default: 0.0 0.25 0.5 0.75 1.0)",
     )
     p_pipe.add_argument("--text", required=True, help="Text to synthesize")
