Mask construction & attention dispatch issues and possible ideas to allow for more models #1870

Open
@mirceamironenco

Description

  • Since torch 2.5.0, training with packed=True and attention dropout > 0.0 is not possible, because padded_collate_packed automatically chooses to build BlockMasks whenever flex is available (which will generally be the case on a CUDA device):
    if _SUPPORTS_FLEX_ATTENTION:
  • If a layer needs a custom attention mask (e.g. sliding window attention), the collate mechanism returning a BlockMask object makes it hard or impossible for that mask to be materialized, since the mask is constructed only with the causal & document mask_mod:
    def mask_mod(b, h, q_idx, kv_idx):
    and the seq_lens are discarded:
    "mask": block_mask,

    For example, in Gemma-2 you have interleaved global and sliding-window layers, so you would need two BlockMask/dense mask objects (see the sketch after this list).
  • The cartesian product of options for dropout, sliding window (when implemented), and possible score_mod (e.g. softcapping) seems to be problematic for the current dispatch mechanism; e.g. the same model with dropout > 0 (and packed=False due to the above bug) would be using F.sdpa at train time and flex during inference?
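
To make the second point concrete, here is a minimal sketch (not torchtune code) of the two mask_mods an interleaved global/local model would need on a packed batch. Both need the per-document boundaries (rebuilt here from a hypothetical seq_lens), which the current collate path discards after building a single causal & document BlockMask:

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Hypothetical packed batch: two samples of lengths 100 and 156 packed into 256 tokens.
seq_lens = [100, 156]
document_ids = torch.repeat_interleave(
    torch.arange(len(seq_lens), device="cuda"),
    torch.tensor(seq_lens, device="cuda"),
)
WINDOW = 64  # hypothetical sliding-window size for the local layers

def causal_document_mask(b, h, q_idx, kv_idx):
    # What padded_collate_packed bakes into its single BlockMask today.
    return (q_idx >= kv_idx) & (document_ids[q_idx] == document_ids[kv_idx])

def sliding_window_mask(b, h, q_idx, kv_idx):
    # What a local (sliding-window) layer additionally needs; it cannot be
    # recovered from the already-materialized causal & document BlockMask.
    return causal_document_mask(b, h, q_idx, kv_idx) & (q_idx - kv_idx < WINDOW)

# Interleaved global/local layers would therefore need two mask objects:
global_mask = create_block_mask(causal_document_mask, B=None, H=None, Q_LEN=256, KV_LEN=256)
local_mask = create_block_mask(sliding_window_mask, B=None, H=None, Q_LEN=256, KV_LEN=256)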

I think a lot of these issues can be solved by a more expressive AttentionMask abstraction. The objective would be to make minimal changes to recipe code. One way this could be done is to replace _MaskType:

mask: Optional[_MaskType] = None,

with an AttentionMask object that can be materialized once per forward pass. Currently, a BlockMask is constructed by first building a dense mask, which could also be reused for F.sdpa or naive attention (needed e.g. if we have dropout > 0 together with soft capping). Example (for flex_attention on PyTorch 2.5.0):

from typing import Optional

import torch
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    BlockMask,
    _convert_mask_to_block_mask,
    _create_sparse_block_from_block_mask,
    _mask_mod_signature,
    _round_up_to_multiple,
    create_mask,
)

class AttentionMask:
    _mask_tensor: torch.Tensor
    _block_mask: Optional[BlockMask]

    def __init__(
        self,
        mask_mod: _mask_mod_signature,
        B: Optional[int],
        H: Optional[int],
        Q_LEN: int,
        KV_LEN: int,
        device: str = "cuda",
        BLOCK_SIZE: int | tuple[int, int] = _DEFAULT_SPARSE_BLOCK_SIZE,
    ) -> None:
        if B is None:
            B = 1
        if H is None:
            H = 1
        if isinstance(BLOCK_SIZE, int):
            Q_BLOCK_SIZE = BLOCK_SIZE
            KV_BLOCK_SIZE = BLOCK_SIZE
        else:
            Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE

        if Q_LEN < 128:
            Q_BLOCK_SIZE = Q_LEN
        else:
            Q_LEN = _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
        KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)

        self.mask_mod = mask_mod
        self.B, self.H = B, H
        self.Q_LEN, self.KV_LEN = Q_LEN, KV_LEN
        self.Q_BLOCK_SIZE, self.KV_BLOCK_SIZE = Q_BLOCK_SIZE, KV_BLOCK_SIZE
        self.device = device

        # Only build dense mask
        # mask_mod can be e.g. causal & document
        self._mask_tensor = create_mask(
            mask_mod, self.B, self.H, self.Q_LEN, self.KV_LEN, self.device
        )
        self._block_mask = None

    @property
    def dense_mask(self) -> torch.Tensor:
        return self._mask_tensor

    @property
    def block_mask(self) -> BlockMask:
        # Materialize BlockMask if needed.
        if self._block_mask is None:
            partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
                self.dense_mask,
                Q_BLOCK_SIZE=self.Q_BLOCK_SIZE,
                KV_BLOCK_SIZE=self.KV_BLOCK_SIZE,
                separate_full_blocks=True,
            )
            self._block_mask = _create_sparse_block_from_block_mask(
                (partial_block_mask, full_block_mask),
                self.mask_mod,
                self.Q_BLOCK_SIZE,
                self.KV_BLOCK_SIZE,
            )
        return self._block_mask

    def windowed_dense(self, window_size: int) -> torch.Tensor:
        # Materialize a sliding-window variant on demand: restrict the
        # (already causal & document) dense mask to a band of width window_size.
        return torch.tril(self.dense_mask.squeeze()).triu_(diagonal=1 - window_size)

The AttentionMask object can still be constructed in the collate function for packed=True if that is desired. The BlockMask is then only built in the forward pass if the flex_attention API is compatible with the current options; otherwise the dense mask can be used for F.sdpa or naive attention (i.e. no flash, etc.).
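
As a rough sketch of what that dispatch could look like inside an attention layer (attend, attn_dropout, window_size, and training are hypothetical names, and the flex path is only shown for the plain causal & document case):

import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

def attend(q, k, v, mask: AttentionMask, *, attn_dropout: float = 0.0,
           window_size: int | None = None, training: bool = True):
    # Hypothetical dispatch inside an attention layer's forward pass.
    if attn_dropout == 0.0 and window_size is None:
        # flex_attention consumes the lazily built BlockMask (no dropout support).
        return flex_attention(q, k, v, block_mask=mask.block_mask)
    # Otherwise fall back to SDPA with a dense (optionally windowed) boolean
    # mask, which does support dropout.
    dense = mask.windowed_dense(window_size) if window_size is not None else mask.dense_mask
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=dense, dropout_p=attn_dropout if training else 0.0
    )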
