|
33 | 33 | from ejkernel.ops import BwdParams, FwdParams |
34 | 34 |
|
35 | 35 | from ..._registry import Backend, Platform, kernel_registry |
| 36 | +from ..attention import attention as dense_attention |
36 | 37 |
|
37 | 38 | if tp.TYPE_CHECKING: |
38 | 39 | from ejkernel.kernels._pallas.tpu.blocksparse_attention._masks import Mask |
@@ -225,33 +226,55 @@ def blocksparse_attention( |
225 | 226 |
|
226 | 227 | row_has_any = jnp.any(mask, axis=-1) |
227 | 228 |
|
228 | | - reps = num_heads // num_kv_heads |
229 | | - q = query.reshape(batch, num_kv_heads, reps, q_len, head_dim) |
230 | | - k = key |
231 | | - v = value |
| 229 | + if softmax_aux is None: |
| 230 | + q_bthd = jnp.transpose(query, (0, 2, 1, 3)) |
| 231 | + k_bthd = jnp.transpose(key, (0, 2, 1, 3)) |
| 232 | + v_bthd = jnp.transpose(value, (0, 2, 1, 3)) |
| 233 | + mask_4d = mask[:, None, :, :] |
| 234 | + |
| 235 | + out_bthd, _ = dense_attention( |
| 236 | + query=q_bthd, |
| 237 | + key=k_bthd, |
| 238 | + value=v_bthd, |
| 239 | + attention_mask=mask_4d, |
| 240 | + softmax_aux=None, |
| 241 | + softmax_scale=softmax_scale, |
| 242 | + logits_soft_cap=logits_soft_cap, |
| 243 | + dtype=q_bthd.dtype, |
| 244 | + softmax_dtype=None, |
| 245 | + dropout_prob=0.0, |
| 246 | + deterministic=True, |
| 247 | + dropout_rng=None, |
| 248 | + causal=causal, |
| 249 | + sliding_window=None, |
| 250 | + bias=None, |
| 251 | + init_bias=None, |
| 252 | + ) |
232 | 253 |
|
233 | | - scale = jnp.asarray(softmax_scale, dtype=q.dtype) |
234 | | - logits = jnp.einsum("bhrqd,bhkd->bhrqk", q * scale, k, optimize=True) |
| 254 | + out_bthd = out_bthd * (row_has_any & q_valid).astype(out_bthd.dtype)[:, :, None, None] |
| 255 | + return jnp.transpose(out_bthd, (0, 2, 1, 3)) |
235 | 256 |
|
| 257 | + reps = num_heads // num_kv_heads |
| 258 | + if reps != 1: |
| 259 | + key_h = jnp.repeat(key, repeats=reps, axis=1) |
| 260 | + value_h = jnp.repeat(value, repeats=reps, axis=1) |
| 261 | + else: |
| 262 | + key_h = key |
| 263 | + value_h = value |
| 264 | + |
| 265 | + logits = jnp.einsum("bhtd,bhkd->bhtk", query * softmax_scale, key_h, optimize=True) |
236 | 266 | if logits_soft_cap is not None: |
237 | | - cap = jnp.asarray(logits_soft_cap, dtype=logits.dtype) |
238 | | - logits = cap * jnp.tanh(logits / cap) |
| 267 | + logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap) |
239 | 268 |
|
240 | | - neg = jnp.finfo(logits.dtype).min |
241 | | - logits = jnp.where(mask[:, None, None, :, :], logits, neg) |
| 269 | + logits = jnp.where(mask[:, None, :, :], logits, jnp.finfo(logits.dtype).min) |
242 | 270 |
|
243 | 271 | aux = _normalize_softmax_aux(softmax_aux, num_heads=num_heads, num_kv_heads=num_kv_heads, dtype=logits.dtype) |
244 | | - if aux is not None: |
245 | | - aux = aux.reshape(num_kv_heads, reps, aux.shape[-1]) |
246 | | - sinks = jnp.broadcast_to(aux[None, :, :, None, :], (batch, num_kv_heads, reps, q_len, aux.shape[-1])) |
247 | | - combined = jnp.concatenate([logits, sinks], axis=-1) |
248 | | - probs = jax.nn.softmax(combined.astype(jnp.float32), axis=-1).astype(logits.dtype) |
249 | | - weights = probs[..., :kv_len] |
250 | | - else: |
251 | | - weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(logits.dtype) |
252 | | - |
253 | | - weights = weights * row_has_any[:, None, None, :, None].astype(weights.dtype) |
254 | | - |
255 | | - out = jnp.einsum("bhrqk,bhkd->bhrqd", weights, v, optimize=True).reshape(batch, num_heads, q_len, value.shape[-1]) |
256 | | - out = out * q_valid[:, None, :, None].astype(out.dtype) |
| 272 | + assert aux is not None |
| 273 | + sinks = jnp.broadcast_to(aux[None, :, None, :], (batch, num_heads, q_len, aux.shape[-1])) |
| 274 | + combined = jnp.concatenate([logits, sinks], axis=-1) |
| 275 | + probs = jax.nn.softmax(combined.astype(jnp.float32), axis=-1).astype(logits.dtype) |
| 276 | + weights = probs[..., :kv_len] |
| 277 | + |
| 278 | + out = jnp.einsum("bhtk,bhkd->bhtd", weights, value_h, optimize=True) |
| 279 | + out = out * (row_has_any & q_valid).astype(out.dtype)[:, None, :, None] |
257 | 280 | return out |
0 commit comments