Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/f5_tts/model/backbones/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
text_num_embeds=256,
text_mask_padding=True,
qk_norm=None,
checkpoint_activations=False,
):
super().__init__()

Expand Down Expand Up @@ -126,6 +127,8 @@ def __init__(
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)

self.checkpoint_activations = checkpoint_activations

self.initialize_weights()

def initialize_weights(self):
Expand All @@ -142,6 +145,13 @@ def initialize_weights(self):
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)

def ckpt_wrapper(self, module):
def ckpt_forward(*inputs):
outputs = module(*inputs)
return outputs

return ckpt_forward

def get_input_embed(
self,
x, # b n d
Expand Down Expand Up @@ -205,7 +215,12 @@ def forward(
rope_text = self.rotary_embed.forward_from_seq_len(text_len)

for block in self.transformer_blocks:
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
if self.checkpoint_activations:
c, x = torch.utils.checkpoint.checkpoint(
self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, use_reentrant=False
)
else:
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)

x = self.norm_out(x, t)
output = self.proj_out(x)
Expand Down