
Commit 798f948

eustlb and ArthurZucker authored
Add CSM model (#36719)
* draft structure
* depth decoder with forward pre hook
* full model forward draft
* draft update
* depth decoder update
* ConversationalSpeechModelForCausalLM udpates
* add generate
* max length criteria small fix
* udpate
* updates
* generation update
* update in loss compute
* conversion script
* update for correct input embeddings
* handle interleaved rope
* update
* update
* update
* support compile
* update training
* add doc
* update doc
* correct inits
* ConversationalSpeechModel -> Csm
* conf update
* name update
* tests CsmForCausalLMTest
* convert use cached_file
* conf + modeling updates
* generate utils handle third dim shape
* integration test
* modeling + conf updates
* common test handle more than 2 dims
* add nested audio list utils
* processing handle nested audio list
* csm processing draft
* mimi util
* init updates
* modular update
* convert modular
* processing update
* csm tests update
* generate tests handle third dim
* generate utils handle third dim
* propagate _get_initial_cache_position update
* tied_weight_keys update + convert correctly
* fix inputs_embeds
* revert audio nested list
* batch inference update + return audio
* audio_utils update
* processor update
* some more integration tests
* remove old test
* porcessing output labels
* improve
* fix
* update rope values with equivalent ones
* conversion update
* udpate tests
* handle depth decoder generation config
* remove default eos_token_id
* make style
* revert modeling_mimi
* add default generation_config
* remove sdpa since handled by default
* make
* fix conflict
* fix conflicts
* correct naming
* correct imports
* make
* causal -> conditional naming
* causal -> conditional naming
* auto update
* make
* make
* add doc
* test update
* fix weight init
* audio tokens offsets as buffer
* 4d mask in conditional class
* make
* doc update
* fix causal mask
* fix causal mask
* doc update
* doc update
* add processor doc
* update doc
* fix 4d causal mask
* update make_list_of_audio
* do not default to mutable
* remove duplicates
* remove useless reset_parameters
* use GradientCheckpointingLayer
* use can_return_tuple
* formatting
* prepend placeholder in _sample
* torch compile fix
* some more fixies
* convert modular
* fix
* default max_length in convert
* handle depth decoder generation config correctly
* clearer formulation
* handle output_loading_info
* handle softmax warning
* add doc
* propagate _get_initial_cache_position changes
* generation in its own module
* add processor tests
* fix compile witu cuda graphs
* fix compile with cuda graphs
* add csm.md
* include CSM loss
* doc nit
* doc nit
* doc nit
* Update docs/source/en/model_doc/csm.md Co-authored-by: Arthur <[email protected]>
* add save_audio to processor
* Update src/transformers/models/csm/modular_csm.py Co-authored-by: Arthur <[email protected]>
* doc update
* simplify audio_codes_mask computation
* doc update
* simplify loss computation
* fix static cache test
* fix
* remove comment
* simplify encoded length computation
* use hf-internal-testing
* doc update
* cast to float before numpy
* nit
* mem efficient codebook head
* nit
* cat input values with cutoffs

---------

Co-authored-by: Arthur <[email protected]>
1 parent c8607a1 commit 798f948

29 files changed: +5827 -86 lines changed

docs/source/en/_toctree.yml

+2
@@ -825,6 +825,8 @@
       title: Bark
     - local: model_doc/clap
       title: CLAP
+    - local: model_doc/csm
+      title: CSM
     - local: model_doc/dac
       title: dac
     - local: model_doc/encodec

docs/source/en/model_doc/csm.md

+377
@@ -0,0 +1,377 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Csm

## Overview

The Conversational Speech Model (CSM) is the first open-source contextual text-to-speech model [released by Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). It is designed to generate natural-sounding speech with or without conversational context. This context typically consists of multi-turn dialogue between speakers, represented as sequences of text and corresponding spoken audio.

**Model Architecture:**
CSM is composed of two LLaMA-style auto-regressive transformer decoders: a backbone decoder that predicts the first codebook token and a depth decoder that generates the remaining tokens. It uses the pretrained codec model [Mimi](./mimi.md), introduced by Kyutai, to encode speech into discrete codebook tokens and decode them back into audio.

The original csm-1b checkpoint is available under the [Sesame](https://huggingface.co/sesame/csm-1b) organization on Hugging Face.

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/eustlb/documentation-images/resolve/main/csm_architecture.png"/>
</div>
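To get a feel for the codebook tokens that the backbone and depth decoders predict, here is a minimal sketch that round-trips a clip through the Mimi codec on its own. This is separate from CSM's API (CSM wires Mimi in internally); it assumes the `kyutai/mimi` checkpoint and the `MimiModel` / `AutoFeatureExtractor` classes from Transformers.

```python
# Illustrative sketch only: encode/decode audio with the standalone Mimi codec
# (assumes the `kyutai/mimi` checkpoint; CSM performs these steps internally).
import torch
from datasets import Audio, load_dataset
from transformers import AutoFeatureExtractor, MimiModel

feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
codec = MimiModel.from_pretrained("kyutai/mimi")

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))

inputs = feature_extractor(
    raw_audio=ds[0]["audio"]["array"],
    sampling_rate=feature_extractor.sampling_rate,
    return_tensors="pt",
)

with torch.no_grad():
    # encode speech into discrete codebook tokens: (batch, num_codebooks, num_frames)
    audio_codes = codec.encode(inputs["input_values"]).audio_codes
    # decode the tokens back into a waveform
    audio_values = codec.decode(audio_codes)[0]

print(audio_codes.shape, audio_values.shape)
```

Each generated speech frame corresponds to one such stack of codebook tokens: the backbone decoder predicts the first codebook and the depth decoder fills in the rest.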
## Usage Tips

### Without Conversational Context

CSM can be used to generate speech from a text prompt alone:

```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
text = "[0]The past is just a story we tell ourselves."  # `[0]` for speaker id 0
inputs = processor(text, add_special_tokens=True).to(device)

# another equivalent way to prepare the inputs
conversation = [
    {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
]
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_without_context.wav")
```
### With Conversational Context

CSM can be used to generate speech given a conversation, which keeps voices consistent across turns and makes generation aware of the preceding content:

```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []

# 1. context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
    conversation.append(
        {
            "role": f"{speaker_id}",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )

# 2. text prompt
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_with_context.wav")
```
### Batched Inference

CSM supports batched inference!

```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
# here is a batch with two prompts
conversation = [
    [
        {
            "role": f"{ds[0]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[0]["text"]},
                {"type": "audio", "path": ds[0]["audio"]["array"]},
            ],
        },
        {
            "role": f"{ds[1]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[1]["text"]},
            ],
        },
    ],
    [
        {
            "role": f"{ds[0]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[0]["text"]},
            ],
        }
    ],
]
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, [f"speech_batch_idx_{i}.wav" for i in range(len(audio))])
```
### Making The Model Go Brrr

CSM supports full-graph compilation with CUDA graphs!

```python
import torch
import copy
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset

model_id = "eustlb/csm-1b"
device = "cuda"

# set logs to ensure no recompilation and graph breaks
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# use static cache, automatically enabling torch compile with fullgraph and reduce-overhead
model.generation_config.max_length = 250  # big enough to avoid recompilation
model.generation_config.max_new_tokens = None  # would take precedence over max_length
model.generation_config.cache_implementation = "static"
model.depth_decoder.generation_config.cache_implementation = "static"

# generation kwargs
gen_kwargs = {
    "do_sample": False,
    "depth_decoder_do_sample": False,
    "temperature": 1.0,
    "depth_decoder_temperature": 1.0,
}

# Define a timing context manager
class TimerContext:
    def __init__(self, name="Execution"):
        self.name = name
        self.start_event = None
        self.end_event = None

    def __enter__(self):
        # Use CUDA events for more accurate GPU timing
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self

    def __exit__(self, *args):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
        print(f"{self.name} time: {elapsed_time:.4f} seconds")

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")

conversation = [
    {
        "role": f"{ds[0]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[0]["text"]},
            {"type": "audio", "path": ds[0]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[1]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[1]["text"]},
            {"type": "audio", "path": ds[1]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[2]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[2]["text"]},
        ],
    },
]

padded_inputs_1 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "="*50)
print("First generation - compiling and recording CUDA graphs...")
with TimerContext("First generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)

print("\n" + "="*50)
print("Second generation - fast !!!")
with TimerContext("Second generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)

# now with different inputs
conversation = [
    {
        "role": f"{ds[0]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[2]["text"]},
            {"type": "audio", "path": ds[2]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[1]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[3]["text"]},
            {"type": "audio", "path": ds[3]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[2]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[4]["text"]},
        ],
    },
]
padded_inputs_2 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "="*50)
print("Generation with other inputs!")
with TimerContext("Generation with different inputs"):
    _ = model.generate(**padded_inputs_2, **gen_kwargs)
print("="*50)
```
### Training

The CSM Transformers integration supports training!

```python
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
model.train()

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []

# context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
    conversation.append(
        {
            "role": f"{speaker_id}",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
    output_labels=True,
).to(device)

out = model(**inputs)
out.loss.backward()
```
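Since the forward pass already returns a loss when the processor is called with `output_labels=True`, a plain PyTorch optimizer loop is enough to fine-tune. Below is a minimal sketch of such a loop, reusing `model` and `inputs` from the example above; the optimizer choice and learning rate are illustrative only.

```python
# Minimal fine-tuning sketch (illustrative hyperparameters), reusing `model` and `inputs` from above.
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=1e-5)

model.train()
for step in range(5):  # a few steps on a single batch, just to show the mechanics
    optimizer.zero_grad()
    out = model(**inputs)  # labels were produced by the processor via `output_labels=True`
    out.loss.backward()
    optimizer.step()
    print(f"step {step}: loss = {out.loss.item():.4f}")
```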
This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
The original code can be found [here](https://github.com/SesameAILabs/csm).


## CsmConfig

[[autodoc]] CsmConfig

## CsmDepthDecoderConfig

[[autodoc]] CsmDepthDecoderConfig

## CsmProcessor

[[autodoc]] CsmProcessor
    - __call__

## CsmForConditionalGeneration

[[autodoc]] CsmForConditionalGeneration
    - forward
    - generate

## CsmDepthDecoderForCausalLM

[[autodoc]] CsmDepthDecoderForCausalLM

## CsmDepthDecoderModel

[[autodoc]] CsmDepthDecoderModel

## CsmBackboneModel

[[autodoc]] CsmBackboneModel
