@@ -16,7 +16,7 @@ def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype):
    )[:, :, :, :k_seq_len]


-def xformers_attn(queries, keys, values, is_causal, document_seqlens = None):
+def xformers_attn(queries, keys, values, is_causal, document_seqlens=None):
    # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim]
    # We assume that queries match the last part of the key / value sequences
    # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
@@ -47,18 +47,24 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None):
        device = queries.device
        for ds in document_seqlens:
            if is_causal and queries.shape[1] == keys.shape[1]:
-                masks.append(xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype))
+                masks.append(
+                    xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(
+                        shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype
+                    )
+                )
            elif is_causal and queries.shape[1] > 1:
-                masks.append(xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype))
+                masks.append(
+                    xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(
+                        shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype
+                    )
+                )
        mask = torch.cat(masks, dim=0)

    return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask)


-def torch_attn(queries, keys, values, is_causal, document_seqlens = None):
-
+def torch_attn(queries, keys, values, is_causal, document_seqlens=None):
    if document_seqlens is None or len(document_seqlens) == 1:
-
        # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail.
        # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention
        # changed between 2.0 and 2.1
@@ -89,7 +95,7 @@ def torch_attn(queries, keys, values, is_causal, document_seqlens = None):
            .transpose(1, 2)
            .contiguous()
        )
-
+
    else:
        raise NotImplementedError("Currently supporting --mask-across-documents only with xformers attention.")
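For context (not part of the commit): a minimal sketch of what the block-diagonal bias built in xformers_attn materializes to. The head count, sequence length, and document lengths below are made-up illustration values, and it assumes xformers is installed.

# Illustration only: inspect the additive bias that BlockDiagonalCausalMask produces
# for one packed sequence containing a 3-token and a 5-token document.
import torch
import xformers.ops as xops

heads, seq_len = 2, 8
doc_seqlens = [3, 5]  # document lengths for one batch element (made-up values)

bias = xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(doc_seqlens).materialize(
    shape=(1, heads, seq_len, seq_len), device="cpu", dtype=torch.float32
)

# Entries are 0.0 where attention is allowed and -inf where it is blocked.
# Query position 4 sits in the second document, so only key positions 3 and 4
# are unmasked: earlier documents and future positions stay at -inf.
print(bias[0, 0, 4])

In the diff above, one such mask is built per batch element and concatenated along dim 0 before being passed as attn_bias to xops.memory_efficient_attention.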
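For the torch fallback touched by the last hunk, the diff's comments suggest it wraps torch.nn.functional.scaled_dot_product_attention, which works on [batch, heads, seq_len, head_dim] while the rest of the code uses [batch, seq_len, heads, head_dim]. A hedged sketch of that single-document path, under that assumption and with illustrative tensor sizes:

# Illustration only: transpose into SDPA's layout, attend, transpose back.
import torch
import torch.nn.functional as F

batch, seq_len, heads, head_dim = 2, 16, 4, 64
queries = torch.randn(batch, seq_len, heads, head_dim)
keys = torch.randn(batch, seq_len, heads, head_dim)
values = torch.randn(batch, seq_len, heads, head_dim)

out = (
    F.scaled_dot_product_attention(
        queries.transpose(1, 2),  # -> [batch, heads, seq_len, head_dim]
        keys.transpose(1, 2),
        values.transpose(1, 2),
        is_causal=True,
    )
    .transpose(1, 2)  # back to [batch, seq_len, heads, head_dim]
    .contiguous()     # as the comment notes, needed for later .view() calls in torch >= 2.1
)
print(out.shape)  # torch.Size([2, 16, 4, 64])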