Prune text tokens that are masked #8

@dxqb

Description

The maximum text sequence length for this model is 512. The image sequence length is usually between 1024 and 4096, depending on resolution. This means a significant fraction of the processed tokens are text tokens - but due to Chroma's design, most of them are masked unless you use very long prompts.

Masked tokens can be removed:

        # Number of unmasked (real) tokens per sample
        seq_lengths = bool_attention_mask.sum(dim=1)
        # Longest prompt in the batch determines how far we can trim
        max_seq_length = seq_lengths.max().item()
        # Trim both the hidden states and the mask to that length
        text_encoder_output = text_encoder_output[:, :max_seq_length, :]
        bool_attention_mask = bool_attention_mask[:, :max_seq_length]

(This must be applied after the mask has been expanded; otherwise max_seq_length is off by one.)

Training and inference take about 25% less time at 512 px. Less time is saved at higher resolutions, where image tokens make up a larger share of the sequence.
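A minimal self-contained sketch of the pruning above, with dummy tensors (the function name and shapes here are illustrative, not Chroma's actual pipeline):

```python
import torch

def prune_masked_text_tokens(text_encoder_output, bool_attention_mask):
    """Trim padded text tokens down to the longest prompt in the batch."""
    # Number of unmasked (real) tokens per sample
    seq_lengths = bool_attention_mask.sum(dim=1)
    # Longest prompt in the batch determines how far we can trim
    max_seq_length = int(seq_lengths.max().item())
    return (
        text_encoder_output[:, :max_seq_length, :],
        bool_attention_mask[:, :max_seq_length],
    )

# Example: batch of 2 prompts padded to 512 tokens, real lengths 7 and 19
hidden = torch.randn(2, 512, 64)
mask = torch.zeros(2, 512, dtype=torch.bool)
mask[0, :7] = True
mask[1, :19] = True

pruned_hidden, pruned_mask = prune_masked_text_tokens(hidden, mask)
print(pruned_hidden.shape)  # torch.Size([2, 19, 64])
print(pruned_mask.shape)    # torch.Size([2, 19])
```

Note that the whole batch is trimmed to the single longest prompt, so the mask is still needed for shorter prompts within the batch - only the tokens masked for every sample are dropped.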
