Skip to content

Commit 45bde42

Browse files
committed
feat: enhance block-sparse attention with dense math delegation and improved logits handling
1 parent: 870e2fb; commit: 45bde42

1 file changed

Lines changed: 46 additions & 23 deletions

File tree

ejkernel/kernels/_xla/blocksparse_attention/_interface.py

Lines changed: 46 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
from ejkernel.ops import BwdParams, FwdParams
3434

3535
from ..._registry import Backend, Platform, kernel_registry
36+
from ..attention import attention as dense_attention
3637

3738
if tp.TYPE_CHECKING:
3839
from ejkernel.kernels._pallas.tpu.blocksparse_attention._masks import Mask
@@ -225,33 +226,55 @@ def blocksparse_attention(
225226

226227
row_has_any = jnp.any(mask, axis=-1)
227228

228-
reps = num_heads // num_kv_heads
229-
q = query.reshape(batch, num_kv_heads, reps, q_len, head_dim)
230-
k = key
231-
v = value
229+
if softmax_aux is None:
230+
q_bthd = jnp.transpose(query, (0, 2, 1, 3))
231+
k_bthd = jnp.transpose(key, (0, 2, 1, 3))
232+
v_bthd = jnp.transpose(value, (0, 2, 1, 3))
233+
mask_4d = mask[:, None, :, :]
234+
235+
out_bthd, _ = dense_attention(
236+
query=q_bthd,
237+
key=k_bthd,
238+
value=v_bthd,
239+
attention_mask=mask_4d,
240+
softmax_aux=None,
241+
softmax_scale=softmax_scale,
242+
logits_soft_cap=logits_soft_cap,
243+
dtype=q_bthd.dtype,
244+
softmax_dtype=None,
245+
dropout_prob=0.0,
246+
deterministic=True,
247+
dropout_rng=None,
248+
causal=causal,
249+
sliding_window=None,
250+
bias=None,
251+
init_bias=None,
252+
)
232253

233-
scale = jnp.asarray(softmax_scale, dtype=q.dtype)
234-
logits = jnp.einsum("bhrqd,bhkd->bhrqk", q * scale, k, optimize=True)
254+
out_bthd = out_bthd * (row_has_any & q_valid).astype(out_bthd.dtype)[:, :, None, None]
255+
return jnp.transpose(out_bthd, (0, 2, 1, 3))
235256

257+
reps = num_heads // num_kv_heads
258+
if reps != 1:
259+
key_h = jnp.repeat(key, repeats=reps, axis=1)
260+
value_h = jnp.repeat(value, repeats=reps, axis=1)
261+
else:
262+
key_h = key
263+
value_h = value
264+
265+
logits = jnp.einsum("bhtd,bhkd->bhtk", query * softmax_scale, key_h, optimize=True)
236266
if logits_soft_cap is not None:
237-
cap = jnp.asarray(logits_soft_cap, dtype=logits.dtype)
238-
logits = cap * jnp.tanh(logits / cap)
267+
logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)
239268

240-
neg = jnp.finfo(logits.dtype).min
241-
logits = jnp.where(mask[:, None, None, :, :], logits, neg)
269+
logits = jnp.where(mask[:, None, :, :], logits, jnp.finfo(logits.dtype).min)
242270

243271
aux = _normalize_softmax_aux(softmax_aux, num_heads=num_heads, num_kv_heads=num_kv_heads, dtype=logits.dtype)
244-
if aux is not None:
245-
aux = aux.reshape(num_kv_heads, reps, aux.shape[-1])
246-
sinks = jnp.broadcast_to(aux[None, :, :, None, :], (batch, num_kv_heads, reps, q_len, aux.shape[-1]))
247-
combined = jnp.concatenate([logits, sinks], axis=-1)
248-
probs = jax.nn.softmax(combined.astype(jnp.float32), axis=-1).astype(logits.dtype)
249-
weights = probs[..., :kv_len]
250-
else:
251-
weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(logits.dtype)
252-
253-
weights = weights * row_has_any[:, None, None, :, None].astype(weights.dtype)
254-
255-
out = jnp.einsum("bhrqk,bhkd->bhrqd", weights, v, optimize=True).reshape(batch, num_heads, q_len, value.shape[-1])
256-
out = out * q_valid[:, None, :, None].astype(out.dtype)
272+
assert aux is not None
273+
sinks = jnp.broadcast_to(aux[None, :, None, :], (batch, num_heads, q_len, aux.shape[-1]))
274+
combined = jnp.concatenate([logits, sinks], axis=-1)
275+
probs = jax.nn.softmax(combined.astype(jnp.float32), axis=-1).astype(logits.dtype)
276+
weights = probs[..., :kv_len]
277+
278+
out = jnp.einsum("bhtk,bhkd->bhtd", weights, value_h, optimize=True)
279+
out = out * (row_has_any & q_valid).astype(out.dtype)[:, None, :, None]
257280
return out

0 commit comments

Comments
 (0)