Skip to content

Commit 00b9af1

Browse files
committed
Fix up sampling.
1 parent 3b30ee0 commit 00b9af1

3 files changed

Lines changed: 32 additions & 19 deletions

File tree

README.md

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,14 @@ pip install mlx-e2-tts
1313
## Usage
1414

1515
```python
16-
1716
import mlx.core as mx
1817

1918
from e2_tts_mlx.model import E2TTS
2019
from e2_tts_mlx.trainer import E2Trainer
2120
from e2_tts_mlx.data import load_libritts_r
2221

2322
e2tts = E2TTS(
24-
tokenizer="char-utf8", # or "phoneme_en" for phoneme-based tokenization
23+
tokenizer="char-utf8", # or "phoneme_en"
2524
cond_drop_prob = 0.25,
2625
frac_lengths_mask = (0.7, 0.9),
2726
transformer = dict(
@@ -40,11 +39,31 @@ mx.eval(e2tts.parameters())
4039
batch_size = 128
4140
max_duration = 30
4241

43-
dataset = load_libritts_r(split="dev-clean", max_duration = max_duration) # or any other audio/caption data set
42+
dataset = load_libritts_r(split="dev-clean") # or any audio/caption dataset
4443

4544
trainer = E2Trainer(model = e2tts, num_warmup_steps = 1000)
46-
trainer.train(train_dataset = dataset, learning_rate = 7.5e-5, batch_size = batch_size)
4745

46+
trainer.train(
47+
train_dataset = ...,
48+
learning_rate = 7.5e-5,
49+
batch_size = batch_size
50+
)
51+
```
52+
53+
... after much training ...
54+
55+
```python
56+
cond = ...
57+
text = ...
58+
duration = ... # from a trained DurationPredictor or otherwise
59+
60+
generated_mel_spec = e2tts.sample(
61+
cond = cond,
62+
text = text,
63+
duration = duration,
64+
steps = 32,
65+
cfg_strength = 1.0, # if trained for cfg
66+
)
4867
```
4968

5069
Note the model size specified above (from the paper) is very large. See `train_example.py` for a more practical-sized model you can train on your local device.

e2_tts_mlx/model.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def lens_to_mask(
3535
length: int | None = None,
3636
) -> mx.array: # Bool['b n']
3737
if not exists(length):
38-
length = t.amax()
38+
length = t.max()
3939

4040
seq = mx.arange(length)
4141
return einx.less("n, b -> b n", seq, t)
@@ -844,7 +844,6 @@ def __init__(
844844
self,
845845
transformer: dict | Transformer = None,
846846
duration_predictor: dict | DurationPredictor | None = None,
847-
odeint_kwargs: dict = dict(atol=1e-5, rtol=1e-5, method="midpoint"),
848847
cond_drop_prob=0.25,
849848
num_channels=None,
850849
mel_spec_module: nn.Module | None = None,
@@ -877,10 +876,6 @@ def __init__(
877876

878877
self.duration_predictor = duration_predictor
879878

880-
# sampling
881-
882-
self.odeint_kwargs = odeint_kwargs
883-
884879
# mel spec
885880

886881
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
@@ -1016,7 +1011,6 @@ def odeint(self, func, y0, t):
10161011

10171012
return mx.stack(ys)
10181013

1019-
@mx.compile
10201014
def sample(
10211015
self,
10221016
cond: mx.array,
@@ -1049,7 +1043,7 @@ def sample(
10491043
assert text.shape[0] == batch
10501044

10511045
if exists(text):
1052-
text_lens = (text != -1).sum(dim=-1)
1046+
text_lens = (text != -1).sum(axis=-1)
10531047
lens = mx.maximum(
10541048
text_lens, lens
10551049
) # make sure lengths are at least those of the text characters
@@ -1070,15 +1064,15 @@ def sample(
10701064
duration = mx.maximum(
10711065
lens + 1, duration
10721066
) # just add one token so something is generated
1073-
duration = duration.clamp(max=max_duration)
1067+
duration = mx.minimum(duration, max_duration)
10741068

10751069
assert duration.shape[0] == batch
10761070

1077-
max_duration = duration.amax()
1071+
max_duration = duration.max().item()
10781072

1079-
cond = mx.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
1073+
cond = mx.pad(cond, [(0, 0), (0, max_duration - cond_seq_len), (0, 0)], constant_values=0)
10801074
cond_mask = mx.pad(
1081-
cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
1075+
cond_mask, [(0, 0), (0, max_duration - cond_mask.shape[-1])], constant_values=False
10821076
)
10831077
cond_mask = rearrange(cond_mask, "... -> ... 1")
10841078

@@ -1088,7 +1082,7 @@ def sample(
10881082

10891083
def fn(t, x):
10901084
# at each step, conditioning is fixed
1091-
1085+
10921086
step_cond = mx.where(cond_mask, cond, mx.zeros_like(cond))
10931087

10941088
# predict flow
@@ -1100,7 +1094,7 @@ def fn(t, x):
11001094
y0 = mx.random.normal(cond.shape)
11011095
t = mx.linspace(0, 1, steps)
11021096

1103-
trajectory = self.odeint(fn, y0, t, **self.odeint_kwargs)
1097+
trajectory = self.odeint(fn, y0, t)
11041098
sampled = trajectory[-1]
11051099

11061100
out = sampled

e2_tts_mlx/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ def train_step(mel_spec, text_inputs, mel_lens):
159159
log_start_date = datetime.datetime.now()
160160

161161
print(
162-
f"step {global_step}: loss = {loss.item():.4f}, sec per step = {elapsed_time.seconds / log_every}"
162+
f"step {global_step}: loss = {loss.item():.4f}, sec per step = {(elapsed_time.seconds / log_every):.2f}"
163163
)
164164

165165
if exists(self.duration_predictor):

0 commit comments

Comments
 (0)