Skip to content

Commit a645edd

Browse files
authored
diagonal noncausal try accounting for max_seqlens too
1 parent fdc50f3 commit a645edd

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

csrc/flash_attn/src/mask.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,13 @@ struct Mask {
177177
for (int j = 0; j < size<1, 0>(tensor); ++j) {
178178
const int col_idx = col_idx_base + j;
179179
if constexpr (Has_alibi) {
180-
tensor(make_coord(i, mi), make_coord(j, nj)) += ((row_idx + max_seqlen_k == max_seqlen_q + col_idx) ? 0 : alibi_slope);
180+
if constexpr (Is_causal) {
181+
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);
182+
183+
} else {
184+
tensor(make_coord(i, mi), make_coord(j, nj)) += ((row_idx + max_seqlen_k == max_seqlen_q + col_idx) ? 0 : alibi_slope);
185+
186+
}
181187
}
182188
if constexpr (Causal_mask) {
183189
if (col_idx >= col_idx_limit_right) {

0 commit comments

Comments
 (0)