Skip to content

Commit 3f35dfc

Browse files
authored
Add opt-in v3 multilingual checkpoint, skip analyzer for v3 (#516)
* Add opt-in v3 multilingual checkpoint, skip analyzer for v3 * Remove alignment analyzer; lower rep_penalty default to 1.2; trim final speech token artifact
1 parent 59bc590 commit 3f35dfc

7 files changed

Lines changed: 54 additions & 218 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ ta.save("test-english.wav", wav, model.sr)
8484

8585
# Multilingual examples
8686
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device)
87+
# v2 remains the default. To use the v3 multilingual checkpoint:
88+
# multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device, t3_model="v3")
8789

8890
french_text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
8991
wav_french = multilingual_model.generate(french_text, language_id="fr")

example_tts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ta.save("test-1.wav", wav, model.sr)
2222

2323
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device)
24+
# v2 is the default. Pass t3_model="v3" to use the v3 multilingual checkpoint.
2425
text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
2526
wav = multilingual_model.generate(text, language_id="fr")
2627
ta.save("test-2.wav", wav, multilingual_model.sr)

multilingual_app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import random
2+
import os
23
import numpy as np
34
import torch
45
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
56
import gradio as gr
67

78
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9+
T3_MODEL = os.getenv("CHATTERBOX_MULTILINGUAL_T3_MODEL", "v2")
810
print(f"🚀 Running on device: {DEVICE}")
11+
print(f"Using multilingual T3 model: {T3_MODEL}")
912

1013
# --- Global Model Initialization ---
1114
MODEL = None
@@ -140,7 +143,7 @@ def get_or_load_model():
140143
if MODEL is None:
141144
print("Model not loaded, initializing...")
142145
try:
143-
MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
146+
MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE, t3_model=T3_MODEL)
144147
if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
145148
MODEL.to(DEVICE)
146149
print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")

src/chatterbox/models/t3/inference/alignment_stream_analyzer.py

Lines changed: 0 additions & 181 deletions
This file was deleted.

src/chatterbox/models/t3/inference/t3_hf_backend.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,12 @@ def __init__(
2323
speech_head,
2424
latents_queue=None,
2525
logits_queue=None,
26-
alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
2726
):
2827
super().__init__(config)
2928
self.model = llama
3029
self.speech_enc = speech_enc
3130
self.speech_head = speech_head
3231
self._added_cond = False
33-
self.alignment_stream_analyzer = alignment_stream_analyzer
3432

3533
@torch.inference_mode()
3634
def prepare_inputs_for_generation(
@@ -105,9 +103,6 @@ def forward(
105103
logits = self.speech_head(hidden_states)
106104
# assert inputs_embeds.size(0) == 1 # (disabled for CFG)
107105

108-
# NOTE: hallucination handler may modify logits to force emit an EOS token
109-
# logits = self.alignment_stream_analyzer.step(logits)
110-
111106
return CausalLMOutputWithCrossAttentions(
112107
logits=logits,
113108
past_key_values=tfmr_out.past_key_values,

src/chatterbox/models/t3/t3.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from .modules.t3_config import T3Config
2525
from .llama_configs import LLAMA_CONFIGS
2626
from .inference.t3_hf_backend import T3HuggingfaceBackend
27-
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
2827
from ..utils import AttrDict
2928

3029

@@ -275,24 +274,11 @@ def inference(
275274
# TODO? synchronize the expensive compile function
276275
# with self.compile_lock:
277276
if not self.compiled:
278-
# Default to None for English models, only create for multilingual
279-
alignment_stream_analyzer = None
280-
if self.hp.is_multilingual:
281-
alignment_stream_analyzer = AlignmentStreamAnalyzer(
282-
self.tfmr,
283-
None,
284-
text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
285-
alignment_layer_idx=9, # TODO: hparam or something?
286-
eos_idx=self.hp.stop_speech_token,
287-
)
288-
assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
289-
290277
patched_model = T3HuggingfaceBackend(
291278
config=self.cfg,
292279
llama=self.tfmr,
293280
speech_enc=self.speech_emb,
294281
speech_head=self.speech_head,
295-
alignment_stream_analyzer=alignment_stream_analyzer,
296282
)
297283
self.patched_model = patched_model
298284
self.compiled = True
@@ -341,7 +327,7 @@ def inference(
341327
inputs_embeds=inputs_embeds,
342328
past_key_values=None,
343329
use_cache=True,
344-
output_attentions=True,
330+
output_attentions=False,
345331
output_hidden_states=True,
346332
return_dict=True,
347333
)
@@ -357,14 +343,6 @@ def inference(
357343
cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
358344
logits = cond + cfg * (cond - uncond)
359345

360-
# Apply alignment stream analyzer integrity checks
361-
if self.patched_model.alignment_stream_analyzer is not None:
362-
if logits.dim() == 1: # guard in case something upstream squeezed
363-
logits = logits.unsqueeze(0) # (1, V)
364-
# Pass the last generated token for repetition tracking
365-
last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
366-
logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V)
367-
368346
# Apply repetition penalty
369347
ids_for_proc = generated_ids[:1, ...] # batch = 1
370348
logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
@@ -400,7 +378,7 @@ def inference(
400378
output = self.patched_model(
401379
inputs_embeds=next_token_embed,
402380
past_key_values=past,
403-
output_attentions=True,
381+
output_attentions=False,
404382
output_hidden_states=True,
405383
return_dict=True,
406384
)

0 commit comments

Comments
 (0)