Description
- Since torch 2.5.0, training with `packed=True` and `attention_dropout > 0.0` is not possible, because `padded_collate_packed` automatically chooses to build `BlockMask`s if flex is available (which will generally be the case on CUDA devices); see the sketch after this list.
- If a layer needs a custom attention mask (e.g. sliding window attention), the collate mechanism returning a `BlockMask` object makes it hard or impossible for this mask to be materialized, since the mask is constructed with the causal & document mask mods and the `seq_lens` are discarded (`torchtune/torchtune/data/_collate.py`, line 504 in 3ca0d30). For example, gemma-2 has interleaved global and sliding window layers, so you would need two `BlockMask`/dense mask objects.
- The cartesian product of options for dropout, sliding window (when implemented), and possible `score_mod` (e.g. softcapping) seems problematic for the current dispatch mechanism, e.g. the same model with dropout > 0 (and `packed=False` due to the above bug) would be using `F.sdpa` at train time and flex during inference?
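To make the first point concrete, here is a minimal standalone sketch of the conflict (plain PyTorch 2.5 code, not torchtune's actual collate/attention code): `flex_attention` consumes a `BlockMask` but exposes no `dropout_p`, while `F.scaled_dot_product_attention` supports dropout but needs a dense mask, not the `BlockMask` that the packed collate returns.

```python
# Minimal repro sketch of the mismatch, assuming a CUDA device with flex available.
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, create_mask, flex_attention


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


B, H, S, D = 1, 4, 256, 64
q = k = v = torch.randn(B, H, S, D, device="cuda")

block_mask = create_block_mask(causal, B, H, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)  # works, but there is no dropout_p argument

# With attention_dropout > 0.0 we would want SDPA instead, but it needs a dense
# boolean/float mask; the BlockMask coming out of the collate function cannot be
# passed here, so a dense mask has to be rebuilt from somewhere.
dense_mask = create_mask(causal, B, H, S, S, device="cuda")
out = F.scaled_dot_product_attention(q, k, v, attn_mask=dense_mask, dropout_p=0.1)
```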
I think a lot of these issues can be solved by a more expressive `AttentionMask` abstraction. The objective would be to make minimal changes to recipe code. One way this could be done is to replace `MaskType` with an `AttentionMask` object that can be materialized once per forward pass. Currently a `BlockMask` is constructed by first building a dense mask, which could be used for `F.sdpa` or naive attention (needed e.g. if we have dropout > 0 and soft capping). Example (for flex_attention on PyTorch 2.5.0):
```python
from typing import Optional

import torch
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    BlockMask,
    _convert_mask_to_block_mask,
    _create_sparse_block_from_block_mask,
    _mask_mod_signature,
    _score_mod_signature,
    create_block_mask,
    create_mask,
    _round_up_to_multiple,
)


class AttentionMask:
    _mask_tensor: torch.Tensor
    _block_mask: Optional[BlockMask]

    def __init__(
        self,
        mask_mod: _mask_mod_signature,
        B: Optional[int],
        H: Optional[int],
        Q_LEN: int,
        KV_LEN: int,
        device: str = "cuda",
        BLOCK_SIZE: int | tuple[int, int] = _DEFAULT_SPARSE_BLOCK_SIZE,
    ) -> None:
        # Same normalization of B/H/BLOCK_SIZE as create_block_mask.
        if B is None:
            B = 1
        if H is None:
            H = 1
        if isinstance(BLOCK_SIZE, int):
            Q_BLOCK_SIZE = BLOCK_SIZE
            KV_BLOCK_SIZE = BLOCK_SIZE
        else:
            Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE
        if Q_LEN < 128:
            Q_BLOCK_SIZE = Q_LEN
        else:
            Q_LEN = _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
        KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)

        self.mask_mod = mask_mod
        self.B, self.H = B, H
        self.Q_LEN, self.KV_LEN = Q_LEN, KV_LEN
        self.Q_BLOCK_SIZE, self.KV_BLOCK_SIZE = Q_BLOCK_SIZE, KV_BLOCK_SIZE
        self.device = device

        # Only build the dense mask up front;
        # mask_mod can be e.g. causal & document.
        self._mask_tensor = create_mask(
            mask_mod, self.B, self.H, self.Q_LEN, self.KV_LEN, self.device
        )
        self._block_mask = None

    @property
    def dense_mask(self) -> torch.Tensor:
        return self._mask_tensor

    @property
    def block_mask(self) -> BlockMask:
        # Materialize the BlockMask lazily, only if a flex layer asks for it.
        if self._block_mask is None:
            partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
                self.dense_mask,
                Q_BLOCK_SIZE=self.Q_BLOCK_SIZE,
                KV_BLOCK_SIZE=self.KV_BLOCK_SIZE,
                separate_full_blocks=True,
            )
            self._block_mask = _create_sparse_block_from_block_mask(
                (partial_block_mask, full_block_mask),
                self.mask_mod,
                self.Q_BLOCK_SIZE,
                self.KV_BLOCK_SIZE,
            )
        return self._block_mask

    def windowed_dense(self, window_size: int) -> torch.Tensor:
        # Materialize a dense sliding-window variant if needed.
        return torch.tril(self.dense_mask.squeeze()).triu_(diagonal=1 - window_size)
```
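As a usage sketch continuing the snippet above (all names below are illustrative, not existing torchtune code), the collate function could build the causal & document `mask_mod` from the packed `seq_lens` and hand back an `AttentionMask` instead of a `BlockMask`:

```python
# Hypothetical usage: the standard flex_attention document-masking pattern,
# wrapped in the AttentionMask above. CPU is used only to keep the snippet
# self-contained; in a recipe this would live on the CUDA device.
def causal_document_mask(document_ids: torch.Tensor) -> _mask_mod_signature:
    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        return causal & same_doc

    return mask_mod


seq_lens = torch.tensor([96, 32])  # two packed sequences, 128 tokens total
document_ids = torch.repeat_interleave(
    torch.arange(len(seq_lens)), seq_lens
).unsqueeze(0)  # (B=1, S=128): token -> packed-sequence id

S = int(seq_lens.sum())
mask = AttentionMask(
    causal_document_mask(document_ids), B=1, H=None, Q_LEN=S, KV_LEN=S, device="cpu"
)

dense = mask.dense_mask        # for F.scaled_dot_product_attention / naive attention
block = mask.block_mask        # materialized lazily for flex_attention layers
swa = mask.windowed_dense(16)  # dense sliding-window variant for SWA layers
```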
The `AttentionMask` object can still be constructed in the collate function for `packed=True` if that is desired, and the `BlockMask` is potentially only built in the forward pass, if the flex_attention API is compatible with the current options; otherwise the dense mask can be used for `F.sdpa` or naive attention (i.e. no flash, etc.).
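A minimal sketch of what the per-layer dispatch could then look like, assuming a layer that knows its own dropout, softcapping, and sliding-window settings (none of this is existing torchtune code, and the naive-attention fallback for softcapping with dropout is elided):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention


def attention(q, k, v, mask: AttentionMask, *, dropout_p=0.0, window_size=None, softcap=None):
    # Hypothetical dispatch: use flex_attention only when every active option is
    # supported by it; otherwise fall back to SDPA on a dense mask.
    use_flex = dropout_p == 0.0 and window_size is None
    if use_flex:
        score_mod = None
        if softcap is not None:
            def score_mod(score, b, h, q_idx, kv_idx):
                # Logit softcapping as used by gemma-2.
                return softcap * torch.tanh(score / softcap)
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=mask.block_mask)

    attn_mask = (
        mask.windowed_dense(window_size) if window_size is not None else mask.dense_mask
    )
    # Softcapping together with dropout would need naive attention instead of SDPA.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)
```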