Your function seems a bit inefficient, since it computes Q @ K three times: once for max_score, once for sum_exp, and once for attn_prob. Even when Q and K are sparse, this can be quite costly.
Maybe something like the following (still without FlashAttention-style tiling) could be more efficient. Note that you would need to add your padding logic (and maybe more), which I did not include. In this version Q @ K is computed only once, but the score tensor is kept around. You could also recompute the scores and compute Q @ K twice, which is still one time fewer than in your original version.
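The single pass is correct because of the standard online-softmax rescaling: whenever the running max grows, the accumulated sum of exponentials is rescaled by exp(old_max - new_max). A small illustrative check of that identity (my own sketch, not part of your code):

import math
import torch

def online_softmax_denominator(scores):
    # One pass: keep a running max and rescale the running sum whenever it grows
    running_max, running_sum = float('-inf'), 0.0
    for s in scores.tolist():
        new_max = max(running_max, s)
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(s - new_max)
        running_max = new_max
    return running_max, running_sum

scores = torch.randn(8)
_, denom = online_softmax_denominator(scores)
# Same denominator as the usual two-pass softmax: sum(exp(s - max(s)))
assert abs(denom - torch.exp(scores - scores.max()).sum().item()) < 1e-5

The kernel below applies exactly this update per query, one key at a time.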
import torch
import triton
import triton.language as tl
@triton.jit
def _efficient_fused_attention_kernel(
    Q_ptr, K_ptr, V_ptr, Out_ptr,
    scores_ptr,  # scratch buffer for the raw scores from the first pass
    q_indices_ptr, k_indices_ptr,
    stride_qz, stride_qh, stride_qs, stride_qd,
    stride_kz, stride_kh, stride_ks, stride_kd,
    stride_vz, stride_vh, stride_vs, stride_vd,
    stride_oz, stride_oh, stride_os, stride_od,
    num_queries,  # runtime values, so the loops below are not unrolled
    num_keys,
    head_dim: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid = tl.program_id(0)
    if pid >= num_queries:
        return
    d_offsets = tl.arange(0, BLOCK_SIZE_D)
    d_mask = d_offsets < head_dim
    # Load the Q vector for the current query from global memory
    # (other=0.0 so the padded lanes contribute nothing to the dot product)
    q_idx = tl.load(q_indices_ptr + pid)
    q_vals = tl.load(Q_ptr + q_idx * stride_qs + d_offsets * stride_qd,
                     mask=d_mask, other=0.0)
    # Initialize the running max and sum of exponentials for this query
    max_score = float('-inf')
    sum_exp = 0.0
    # head_dim is constexpr, so the scale folds to a compile-time constant
    scale = 1.0 / (head_dim ** 0.5)
    # First pass: compute the scores and the running max and sum_exp
    for k in range(num_keys):
        k_idx = tl.load(k_indices_ptr + k)
        k_vals = tl.load(K_ptr + k_idx * stride_ks + d_offsets * stride_kd,
                         mask=d_mask, other=0.0)
        # Dot product and scale (tl.dot needs 2-D blocks, so reduce an
        # elementwise product instead)
        score = tl.sum(q_vals * k_vals, axis=0) * scale
        # Online-softmax update: rescale the old sum when the max grows
        old_max_score = max_score
        max_score = tl.maximum(old_max_score, score)
        sum_exp = sum_exp * tl.exp(old_max_score - max_score) + tl.exp(score - max_score)
        # Stash the raw score in the global scratch buffer
        # (a local tensor cannot be written with tl.store)
        tl.store(scores_ptr + pid * num_keys + k, score)
    # Second pass: compute the context vector
    context = tl.zeros([BLOCK_SIZE_D], dtype=tl.float32)
    for k in range(num_keys):
        # Load the raw score from the first pass, or recompute it here
        # if you really want to avoid storing the scores
        score = tl.load(scores_ptr + pid * num_keys + k)
        # Normalized attention probability
        attn_prob = tl.exp(score - max_score) / sum_exp
        # Load the V vector and accumulate the context
        k_idx = tl.load(k_indices_ptr + k)
        v_vals = tl.load(V_ptr + k_idx * stride_vs + d_offsets * stride_vd,
                         mask=d_mask, other=0.0)
        context += attn_prob * v_vals
    # Store the final context vector to the output
    tl.store(Out_ptr + q_idx * stride_os + d_offsets * stride_od,
             context, mask=d_mask)
# Python wrapper to launch the kernel
# TODO: insert padding (and batch/head offsets; the strides are passed in,
# but the kernel above only indexes along the sequence dimension)
def efficient_sparse_attention(Q, K, V, q_indices, k_indices):
    num_queries = q_indices.shape[0]
    num_keys = k_indices.shape[0]
    batch_size, num_heads, _, head_dim = Q.shape
    assert head_dim == K.shape[-1] == V.shape[-1]
    # Zero-init so rows not listed in q_indices are well defined
    Out = torch.zeros_like(Q)
    # Scratch buffer holding the raw scores between the two passes
    scores = torch.empty((num_queries, num_keys), device=Q.device, dtype=torch.float32)
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    # One program per query
    grid = (num_queries,)
    _efficient_fused_attention_kernel[grid](
        Q, K, V, Out,
        scores,
        q_indices, k_indices,
        Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
        K.stride(0), K.stride(1), K.stride(2), K.stride(3),
        V.stride(0), V.stride(1), V.stride(2), V.stride(3),
        Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),
        num_queries=num_queries,
        num_keys=num_keys,
        head_dim=head_dim,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
    )
    return Out
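For completeness, a hypothetical usage sketch (not from your code). It fixes batch and heads to 1, since the kernel only offsets along the sequence dimension, and checks the result against a dense PyTorch reference on the gathered rows:

torch.manual_seed(0)
B, H, S, D = 1, 1, 64, 32
Q = torch.randn(B, H, S, D, device='cuda')
K = torch.randn(B, H, S, D, device='cuda')
V = torch.randn(B, H, S, D, device='cuda')
q_indices = torch.tensor([3, 17, 42], device='cuda', dtype=torch.int32)
k_indices = torch.arange(0, S, 2, device='cuda', dtype=torch.int32)

out = efficient_sparse_attention(Q, K, V, q_indices, k_indices)

# Dense reference over the selected rows only
q = Q[0, 0, q_indices.long()]
k = K[0, 0, k_indices.long()]
v = V[0, 0, k_indices.long()]
ref = torch.softmax(q @ k.T / D ** 0.5, dim=-1) @ v
assert torch.allclose(out[0, 0, q_indices.long()], ref, atol=1e-4)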