From 92632dc67768ff5458a79a26e9584a616319fa30 Mon Sep 17 00:00:00 2001
From: Markus Krimmel
Date: Thu, 29 Feb 2024 16:50:35 +0100
Subject: [PATCH 1/4] feat: started implementing ALiBi for non-flash attention

---
 flash_attn/modules/mha.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 89c7680d5..4886d851b 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -234,11 +234,25 @@ class SelfAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
+        self.alibi_slopes = self.register_buffer('alibi_slopes', alibi_slopes, persistent=False)
+        if alibi_slopes is not None:
+            self.alibi_tensor = self.register_buffer('alibi_tensor', self._build_alibi_tensor(16), persistent=False)
+        else:
+            self.alibi_tensor = None
+
+    def _build_alibi_tensor(self, seqlen):
+        context_position = torch.arange(seqlen, device=self.alibi_slopes.device)[:, None]
+        memory_position = torch.arange(seqlen, device=self.alibi_slopes.device)[None, :]
+        # distance tensor is of shape (seqlen, seqlen)
+        distance = torch.abs(memory_position - context_position)
+        # alibi tensor is of shape (1, H, seqlen, seqlen)
+        alibi_tensor = (distance[None, ...] * self.alibi_tensor[:, None, None])[None, ...]
+        return alibi_tensor

     def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
@@ -261,6 +275,11 @@ def forward(self, qkv, causal=None, key_padding_mask=None):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+        if self.alibi_slopes is not None:
+            if seqlen > self.alibi_tensor.shape[-1]:
+                self.alibi_tensor = self._build_alibi_tensor(seqlen).to(scores.device)
+            cropped_alibi = self.alibi_slopes[..., :seqlen, :seqlen].to(scores.device)
+            scores = scores - cropped_alibi
         if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
@@ -420,7 +439,7 @@ def __init__(
         self.return_residual = return_residual
         self.checkpointing = checkpointing
         if use_alibi:
-            assert use_flash_attn, "ALiBi code path requires flash_attn"
+            assert not cross_attn or use_flash_attn, "ALiBi code path requires self-attention or cross-attention with flash_attn"
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
@@ -458,7 +477,7 @@ def __init__(
         inner_attn_cls = (
             partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
-            else SelfAttention
+            else partial(SelfAttention, alibi_slopes=alibi_slopes)
         )
         inner_cross_attn_cls = (
             partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)

From 732beaa02ac98e30009f4e241f944a114aca4b98 Mon Sep 17 00:00:00 2001
From: Markus Krimmel
Date: Thu, 29 Feb 2024 17:05:25 +0100
Subject: [PATCH 2/4] fixed buffer registration, refactoring of variable names

---
 flash_attn/modules/mha.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 4886d851b..2f3c6df3c 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -239,19 +239,19 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alib
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
-        self.alibi_slopes = self.register_buffer('alibi_slopes', alibi_slopes, persistent=False)
+        self.register_buffer('alibi_slopes', alibi_slopes, persistent=False)
         if alibi_slopes is not None:
-            self.alibi_tensor = self.register_buffer('alibi_tensor', self._build_alibi_tensor(16), persistent=False)
+            self.register_buffer('linear_biases', self._build_linear_biases(16), persistent=False)
         else:
             self.alibi_tensor = None

-    def _build_alibi_tensor(self, seqlen):
+    def _build_linear_biases(self, seqlen):
         context_position = torch.arange(seqlen, device=self.alibi_slopes.device)[:, None]
         memory_position = torch.arange(seqlen, device=self.alibi_slopes.device)[None, :]
         # distance tensor is of shape (seqlen, seqlen)
         distance = torch.abs(memory_position - context_position)
         # alibi tensor is of shape (1, H, seqlen, seqlen)
-        alibi_tensor = (distance[None, ...] * self.alibi_tensor[:, None, None])[None, ...]
+        alibi_tensor = (distance[None, ...] * self.alibi_slopes[:, None, None])[None, ...]
         return alibi_tensor

     def forward(self, qkv, causal=None, key_padding_mask=None):
@@ -276,10 +276,10 @@ def forward(self, qkv, causal=None, key_padding_mask=None):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if self.alibi_slopes is not None:
-            if seqlen > self.alibi_tensor.shape[-1]:
-                self.alibi_tensor = self._build_alibi_tensor(seqlen).to(scores.device)
-            cropped_alibi = self.alibi_slopes[..., :seqlen, :seqlen].to(scores.device)
-            scores = scores - cropped_alibi
+            if seqlen > self.linear_biases.shape[-1]:
+                self.linear_biases = self._build_linear_biases(seqlen).to(scores.device)
+            cropped_biases = self.linear_biases[..., :seqlen, :seqlen].to(scores.device)
+            scores = scores - cropped_biases
         if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float

From e156a15d4d4dd82a4752acd64f1dd80289adb0ab Mon Sep 17 00:00:00 2001
From: Markus Krimmel
Date: Thu, 29 Feb 2024 17:09:00 +0100
Subject: [PATCH 3/4] feat: some further refactoring

---
 flash_attn/modules/mha.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 2f3c6df3c..83a05b32a 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -243,7 +243,7 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alib
         if alibi_slopes is not None:
             self.register_buffer('linear_biases', self._build_linear_biases(16), persistent=False)
         else:
-            self.alibi_tensor = None
+            self.linear_biases = None

     def _build_linear_biases(self, seqlen):
         context_position = torch.arange(seqlen, device=self.alibi_slopes.device)[:, None]
@@ -251,8 +251,8 @@ def _build_linear_biases(self, seqlen):
         # distance tensor is of shape (seqlen, seqlen)
         distance = torch.abs(memory_position - context_position)
         # alibi tensor is of shape (1, H, seqlen, seqlen)
-        alibi_tensor = (distance[None, ...] * self.alibi_slopes[:, None, None])[None, ...]
-        return alibi_tensor
+        linear_biases = (distance[None, ...] * self.alibi_slopes[:, None, None])[None, ...]
+        return linear_biases

     def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
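For reference, a minimal standalone sketch of the bias tensor that the refactored _build_linear_biases above produces. The head count, sequence length, and the power-of-two slope schedule are illustrative assumptions rather than values taken from the patches (the module itself receives its slopes through the alibi_slopes argument):

import torch

num_heads, seqlen = 8, 16  # illustrative sizes; the module starts with a 16x16 cache
# ALiBi slope schedule for a power-of-two head count: 2^(-8/H), 2^(-16/H), ...
slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])

context_position = torch.arange(seqlen)[:, None]  # query positions, shape (seqlen, 1)
memory_position = torch.arange(seqlen)[None, :]   # key positions, shape (1, seqlen)
distance = torch.abs(memory_position - context_position)  # (seqlen, seqlen)
linear_biases = (distance[None, ...] * slopes[:, None, None])[None, ...]

# (1, num_heads, seqlen, seqlen): broadcasts against attention scores of shape (B, H, S, S)
assert linear_biases.shape == (1, num_heads, seqlen, seqlen)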
From 4eb887a559c97ee747e20e3796bf1e9f0cb4e92e Mon Sep 17 00:00:00 2001
From: Markus Krimmel
Date: Thu, 29 Feb 2024 17:11:25 +0100
Subject: [PATCH 4/4] fix: don't move linear biases to device

---
 flash_attn/modules/mha.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 83a05b32a..e264f3a7e 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -277,8 +277,8 @@ def forward(self, qkv, causal=None, key_padding_mask=None):
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if self.alibi_slopes is not None:
             if seqlen > self.linear_biases.shape[-1]:
-                self.linear_biases = self._build_linear_biases(seqlen).to(scores.device)
-            cropped_biases = self.linear_biases[..., :seqlen, :seqlen].to(scores.device)
+                self.linear_biases = self._build_linear_biases(seqlen)
+            cropped_biases = self.linear_biases[..., :seqlen, :seqlen]
             scores = scores - cropped_biases
         if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
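With the series applied, the non-flash SelfAttention accepts alibi_slopes directly, and MHA selects partial(SelfAttention, alibi_slopes=alibi_slopes) whenever use_flash_attn is false. A hypothetical usage sketch follows; the import of get_alibi_slopes from flash_attn.modules.mha and the toy sizes are assumptions for illustration:

import torch
from flash_attn.modules.mha import SelfAttention, get_alibi_slopes  # assumed import path

batch, seqlen, num_heads, head_dim = 2, 32, 8, 64
alibi_slopes = torch.tensor(get_alibi_slopes(num_heads))  # per-head slopes, as in the patched MHA.__init__

attn = SelfAttention(causal=True, alibi_slopes=alibi_slopes)
qkv = torch.randn(batch, seqlen, 3, num_heads, head_dim)  # packed QKV, as SelfAttention.forward expects
out = attn(qkv)  # (batch, seqlen, num_heads, head_dim); seqlen > 16 triggers the lazy rebuild of linear_biases

Dropping the explicit .to(scores.device) calls in PATCH 4 works because the biases are (re)built on self.alibi_slopes.device, which already matches the device of the scores when the inputs live on the module's device.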