
Inconsistency for CLS Token in TextTransformer #1113

@grashei

Description

When the TextTransformer is instantiated in this function

def _build_text_tower(

the keyword argument correct_cls_mask is not passed, so it is left at its default of False.

Hence, when the additive attention mask is created here:

if self.cls_emb is not None:
    cls_valid = valid.new_ones(valid.size(0), 1)  # [B, 1]
    # cls mask pos at end if correct or front for incorrect legacy mode in existing CoCa weights
    valid = torch.cat([valid, cls_valid] if self.correct_cls_mask else [cls_valid, valid], 1)

With correct_cls_mask left at False, the mask entry for the CLS token is therefore always positioned at the beginning of the sequence.
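
For reference, here is a small standalone reproduction of that concatenation (the tensor shapes and values are made up for illustration); with the default, it is the always-valid CLS slot that ends up at index 0 of the mask:

import torch

# valid marks non-padding text tokens; the last token here is padding
valid = torch.tensor([[True, True, False]])       # [B=1, seq_len=3]
cls_valid = valid.new_ones(valid.size(0), 1)      # [1, 1]

correct_cls_mask = False                          # the default left by _build_text_tower
mask = torch.cat([valid, cls_valid] if correct_cls_mask else [cls_valid, valid], 1)
print(mask)  # tensor([[True, True, True, False]]) -> always-valid slot at the front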
However, the CLS token itself is always appended at the end of the sequence:

# Optional class token (always appended ala CoCa)
if self.cls_emb is not None:
    x = torch.cat([x, _expand_token(self.cls_emb, x.size(0))], 1)
    seq_len += 1
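
A quick sketch of the token side (again with made-up shapes, and with expand_token as my approximation of what _expand_token does) shows the CLS embedding landing at the last position, not the first:

import torch

def expand_token(token, batch_size: int):
    # broadcast a single learned embedding to [batch_size, 1, width]
    return token.view(1, 1, -1).expand(batch_size, -1, -1)

B, seq_len, width = 1, 3, 4
x = torch.zeros(B, seq_len, width)                 # stand-in for the token embeddings
cls_emb = torch.ones(width)                        # stand-in for self.cls_emb
x = torch.cat([x, expand_token(cls_emb, B)], dim=1)
print(x.shape)   # torch.Size([1, 4, 4])
print(x[0, -1])  # tensor([1., 1., 1., 1.]) -> the CLS embedding sits at the end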

Am I correct, and could we make use of correct_cls_mask to solve this?
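
For what it's worth, flipping the flag does line the two up in the same toy setting: with correct_cls_mask=True the always-valid slot moves to the end of the mask, matching where the CLS embedding is appended. The open question would mainly be how to enable it without breaking existing CoCa weights trained with the legacy mask position.

import torch

valid = torch.tensor([[True, True, False]])
cls_valid = valid.new_ones(valid.size(0), 1)

legacy  = torch.cat([cls_valid, valid], 1)   # correct_cls_mask=False -> CLS slot at index 0
aligned = torch.cat([valid, cls_valid], 1)   # correct_cls_mask=True  -> CLS slot at the end
print(legacy)   # tensor([[True, True, True, False]])
print(aligned)  # tensor([[True, True, False, True]])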
