Skip to content

Commit 6a40e1c

Browse files
authored
Fix looping limit in causal attention (#1999)
1 parent 9307b2a commit 6a40e1c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,7 @@ template <
237 237   if (do_causal) {
238 238   int q_max = (tid.x + 1) * BQ + params->qL_off;
239 239   kb_lim = (q_max + BK - 1) / BK;
    240 + kb_lim = min(params->NK, kb_lim);
240 241   }
241 242
242 243   // Loop over KV seq length
@@ -290,7 +291,7 @@ template <
290 291   }
291 292
292 293   // Mask out if causal
293     - if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
    294 + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
294 295   using stile_t = decltype(Stile);
295 296   using selem_t = typename stile_t::elem_type;
296 297   constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();

0 commit comments

Comments
 (0)