Description
I am currently reproducing the Cosy3 model and have a question regarding the DiffRO process described in your paper.
In the paper, the predicted tokens are sampled using Gumbel-Softmax (Eq. 3), and then the token-to-text model computes the ASR loss to update the LLM (Eq. 4).
I have a question about this process:
During the training of the text model (the ASR model), the paper only mentions that the predicted tokens are used to compute the ASR loss. Since the argmax operation is non-differentiable, I assume the ASR model's input is either a one-hot vector representation of the speech tokens or the corresponding speech token embeddings (via nn.Embedding). If the input is the speech token embedding, does this mean the Gumbel-Softmax output (a one-hot vector) is multiplied by the speech token embedding table (i.e. one_hot_vector @ embedding.weight)?
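For reference, here is a minimal PyTorch sketch of the mechanism I am asking about. This is my own assumption of how it might be implemented, not code from the paper or repo: the straight-through Gumbel-Softmax (`hard=True`) produces a one-hot sample in the forward pass, and multiplying it by the embedding table acts like a lookup while keeping the path differentiable back to the LLM's logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes for illustration only.
vocab_size, embed_dim = 16, 8
logits = torch.randn(2, 5, vocab_size, requires_grad=True)  # (batch, time, vocab)
embedding = torch.nn.Embedding(vocab_size, embed_dim)

# Straight-through Gumbel-Softmax: the forward pass emits a hard one-hot
# sample, while the backward pass uses the soft distribution's gradient.
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)  # (2, 5, vocab_size)

# Multiplying the one-hot vectors by the embedding table is equivalent to an
# embedding lookup of the sampled tokens, but remains differentiable.
token_embeds = one_hot @ embedding.weight  # (2, 5, embed_dim)

# Gradients flow back into `logits` despite the hard sample in the forward pass.
token_embeds.sum().backward()
assert logits.grad is not None
```

If this matches the intended implementation, the ASR loss computed on `token_embeds` would update the LLM's logits through the straight-through estimator.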