
[TTS] Add code for training semantic codec#15524

Open
rlangman wants to merge 3 commits into main from codec_semantic

Conversation

@rlangman (Collaborator):

What does this PR do ?

Add code needed to train a single codebook semantic token and embed it inside a multi-codebook codec.

Collection: [TTS]

Changelog

  • Add semantic distillation using w2v-bert
  • Add inference logic to embed semantic token as first codebook in multi-codebook codec
  • Add option in data loader to resample audio
  • Remove dead commented-out code related to the ASR loss; the PhonemeASR module it references is not defined in NeMo.

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Signed-off-by: Ryan <rlangman@nvidia.com>
@rlangman rlangman requested review from Edresson, blisc and rfejgin March 19, 2026 22:58
@rlangman rlangman self-assigned this Mar 19, 2026
@github-actions github-actions bot removed the CI label Mar 19, 2026
Signed-off-by: rlangman <rlangman@users.noreply.github.com>
hidden_layer: Index of hidden layer to extract embeddings from.
Defaults to 16, which research suggests is effective for w2v-bert and TTS.
padding: Number of audio samples to pad before encoding to ensure output has a frame rate compatible with the audio codec.
scaling_factor: Constant factor to scale output embedding by.
Collaborator:

It looks like in practice we divide by this factor, not multiply. Maybe either change the contract to provide 1/scaling and then multiply, or change the comment to say we divide by it.

Collaborator:

Also, why is this scaling needed? Okay to keep it, just curious.

Collaborator (Author):

All loss functions in the codec training are implemented to produce approximately the same scale by default (around 0.1 - 0.3). The w2v embedding scale is arbitrary, based on whatever scale the layer norm in the model learned. It ends up producing an embedding whose max value is around 5, so I scale the embedding down to have a max value of about 1, which also reduces the scale of the SLM loss to about 0.2, comparable to the other losses.

The alternative would be to have the SLM loss scale in the AudioCodecModel class default to something like 0.1 or 0.2.
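A minimal sketch of the convention being discussed (the function name and the example factor of 5.0 are illustrative assumptions, not the PR's actual code): the "scaling" is applied as a division, which is what prompted the reviewer's comment about the docstring.

```python
import torch


def scale_slm_embedding(slm_emb: torch.Tensor, scaling_factor: float) -> torch.Tensor:
    # Divide the w2v-bert embedding by a constant so an embedding whose
    # max magnitude is ~5 ends up with a max of ~1, keeping the SLM loss
    # on roughly the same scale (~0.1-0.3) as the other codec losses.
    return slm_emb / scaling_factor
```

Renaming the argument to something like `inv_scale` (and multiplying) would make the contract match the docstring, at the cost of a less intuitive config value.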

return slm_emb


class SLMDecoder(NeuralModule):
Collaborator:

I'm not sure SLMDecoder is the best name for what this class does; on the face of it, it looks like it decodes the outputs of the SLM encoder, but that's not what it really does. Maybe SLMPredictor? Or any name that you think makes sense.

That said, I hope changing this won't cause too much trouble by invalidating existing checkpoints, etc.

Collaborator (Author):

I think the terms "predict" and "decode" are usually interchangeable. It is however a bit confusing in this instance because the target being generated by the decoder happens to be the latent space of a different encoder's output.

This can be safely changed, as I deleted the SLM encoder and decoders from the checkpoints I am using for inference anyways.

self,
dataset_meta: Dict,
sample_rate: int,
resample_rate: Optional[int] = None,
Collaborator:

could you update the docstring to add this argument?

Collaborator (Author):

Added docstring. The functionality might be a bit confusing, because the feature that is actually being added is the option to resample using batched NeMo code instead of librosa.

return state_dict

def load_state_dict(self, state_dict, strict=True):
# Override to load all the keys except .speaker_encoder. and WavLM model
Collaborator:

Could you update the comment to say why we are skipping some keys?

semantic_codec_cfg = cfg.get("semantic_codec")
semantic_codec = AudioCodecModel(cfg=semantic_codec_cfg)
elif cfg.get("semantic_codec_path"):
semantic_codec_path = cfg.get("semantic_codec_path")
Collaborator:

Are semantic_codec and semantic_codec_path mutually exclusive? If so, maybe we can add an error check.

Collaborator:

Also, a question: how does training of the semantic codec itself happen?

Collaborator (Author):

I tried to describe it succinctly in the comment. But what happens is that the first time you train this model there is no "semantic_codec" config, and it reads it from the "semantic_codec_path" you provide in your yaml file. It then loads the checkpoint, and stores it in a new config called "semantic_codec". On all future training runs, or during inference, both config values are present, but it prioritizes using the submodule instead of reading the checkpoint again. The "semantic_codec" is only auto-generated in this way, never defined by a user.

This was the only way I could find to get register_nemo_submodule to work in this way, and it feels like an awkward interface. If anyone knows a cleaner way to implement this, I would be happy to hear it.
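The priority logic described above can be sketched in plain Python (the function name and tuple-return convention here are hypothetical; the real code constructs an AudioCodecModel or restores a checkpoint rather than returning a tag):

```python
def resolve_semantic_codec_source(cfg: dict):
    # Mirror the described precedence: the auto-generated "semantic_codec"
    # sub-config (written back after the first training run) wins over
    # re-reading the checkpoint at "semantic_codec_path".
    if cfg.get("semantic_codec") is not None:
        return "config", cfg["semantic_codec"]
    if cfg.get("semantic_codec_path") is not None:
        return "checkpoint", cfg["semantic_codec_path"]
    return "none", None
```

This also shows why a strict mutual-exclusivity check would not work here: after the first run both keys are legitimately present, and the sub-config is simply preferred.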

Collaborator (Author):

Also, a question: how does training of the semantic codec itself happen?

You run the recipe using a config file that has 1 codebook, no discriminator, and the SLM loss enabled.

if self.discriminator is None:
schedulers.step()
else:
schedulers[0].step()
@rfejgin (Collaborator), Mar 20, 2026:

maybe it would be cleaner to iterate on all schedulers in the list
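A sketch of the suggested refactor (names assumed): normalize the input to a list and step everything, so the function no longer needs to know whether a discriminator scheduler exists.

```python
def step_schedulers(schedulers):
    # Accept None, a single scheduler, or a list of schedulers,
    # and call .step() on each one present.
    if schedulers is None:
        return
    if not isinstance(schedulers, (list, tuple)):
        schedulers = [schedulers]
    for scheduler in schedulers:
        scheduler.step()
```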

@rfejgin (Collaborator) left a comment:

Looks good and clean overall. See some generally minor comments.

encoded, encoded_len = self.audio_encoder(audio=audio_preprocessed, audio_len=audio_preprocessed_len)

if self.semantic_codec is not None:
semantic, _ = self.semantic_codec.encode_audio(audio=audio, audio_len=audio_len, sample_rate=sample_rate)
Collaborator:

should we add with torch.no_grad()?
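A sketch of the suggested change (the wrapper name is hypothetical), assuming the semantic codec is frozen and used purely as a feature extractor, so there is no reason to build an autograd graph through it:

```python
import torch


def encode_semantic_frozen(semantic_codec, audio, audio_len, sample_rate):
    # Disable autograd while running the frozen semantic codec; this
    # avoids storing activations for a backward pass that never happens.
    with torch.no_grad():
        semantic, semantic_len = semantic_codec.encode_audio(
            audio=audio, audio_len=audio_len, sample_rate=sample_rate
        )
    return semantic, semantic_len
```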

if schedulers is None or self.lr_schedule_interval != interval:
return

if self.discriminator is None:
Collaborator:

maybe instead just check if it's a list or single item, then this function doesn't need to know about discriminators

Signed-off-by: Ryan <rlangman@nvidia.com>