
Confusion about loss calculation #134

Open

@distillation-dcf

Hi!

In the forward() function of model.py, the text loss and the image loss are computed by:

```python
labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()  # shape: (bs, 127+1024=1151)
loss_text = F.cross_entropy(
    text_logits,
    labels[:, :self.text_seq_length])  # shape: (bs, 128)
loss_img = F.cross_entropy(
    image_logits,
    labels[:, self.text_seq_length:])  # shape: (bs, 1023)
```

Here text[:, 1:] drops the label for the first [BOS] text token, so only 128 - 1 = 127 text-token labels remain in labels. Yet the cross-entropy loss pairs the text logits (seq_len = 128) with labels[:, :self.text_seq_length] (shape (bs, 128)). I suspect that the very first image token, which comes right after the text tokens in labels, is pulled into the text loss computation by mistake.
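To make the overlap concrete, here is a minimal runnable sketch with dummy tensors; the lengths text_seq_length=128 and image_seq_length=1024 are assumed from the shape comments quoted above:

```python
import torch

# Assumed lengths, taken from the shape comments above.
bs, text_seq_length, image_seq_length = 2, 128, 1024
text = torch.randint(0, 100, (bs, text_seq_length))
image_input_ids = torch.randint(0, 100, (bs, image_seq_length))

# Same concatenation as in forward(): drop the [BOS] label, append image ids.
labels = torch.cat((text[:, 1:], image_input_ids), dim=1)  # (bs, 127 + 1024 = 1151)

text_labels = labels[:, :text_seq_length]   # (bs, 128): 127 text labels + 1 extra column
image_labels = labels[:, text_seq_length:]  # (bs, 1023)

# The last column of the "text" labels is in fact the first image token:
assert torch.equal(text_labels[:, -1], image_input_ids[:, 0])
```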

Am I understanding the code correctly? Will the text-token length used in the CE loss calculation affect the training process?
