Skip to content

Inefficient Fused Kernel #2

@t0278611

Description

@t0278611

Your function seems a bit inefficient, since it calculates Q @ K three times: once for max_score, once for sum_exp, and one last time for attn_prob. Even when Q and K are sparse, this could be quite inefficient.

Maybe something like the following (still without flash-attention-like tiling) could be more efficient. Note that you still need to add your padding logic (and maybe more), which I did not include. In this version Q @ K is only calculated once, but the score tensor is saved. You could also recalculate the scores instead and calculate Q @ K twice, which is still one time fewer than in your original version.

import torch
import triton
import triton.language as tl

@triton.jit
def _efficient_fused_attention_kernel(
    Q_ptr, K_ptr, V_ptr, Out_ptr,
    q_indices_ptr, k_indices_ptr,
    stride_qz, stride_qh, stride_qs, stride_qd,
    stride_kz, stride_kh, stride_ks, stride_kd,
    stride_vz, stride_vh, stride_vs, stride_vd,
    stride_oz, stride_oh, stride_os, stride_od,
    num_queries: tl.constexpr,
    num_keys: tl.constexpr,
    head_dim: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr
):
    """Softmax attention of one selected query row over the selected keys.

    One program handles one entry of ``q_indices`` (grid axis 0) for one
    batch element (grid axis 1) and one head (grid axis 2).  When the kernel
    is launched with a 1-D grid, axes 1 and 2 are 0, so only batch 0 / head 0
    is addressed.

    The softmax is evaluated in a single pass with a running maximum and a
    running normaliser, so every Q.K score is computed exactly once and no
    per-key score buffer is needed.

    ``BLOCK_SIZE_D`` must be a power of two that is >= ``head_dim``; lanes
    beyond ``head_dim`` are masked out.
    """
    pid_q = tl.program_id(0)
    if pid_q >= num_queries:
        return

    # Widen to int64 before multiplying by strides so large tensors do not
    # overflow 32-bit offset arithmetic.
    pid_z = tl.program_id(1).to(tl.int64)
    pid_h = tl.program_id(2).to(tl.int64)

    offs_d = tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < head_dim

    # Base pointers of the (batch, head) slice this program works on.
    q_base = Q_ptr + pid_z * stride_qz + pid_h * stride_qh
    k_base = K_ptr + pid_z * stride_kz + pid_h * stride_kh
    v_base = V_ptr + pid_z * stride_vz + pid_h * stride_vh
    o_base = Out_ptr + pid_z * stride_oz + pid_h * stride_oh

    # Load the query row.  `other=0.0` makes the padding lanes contribute
    # nothing to the dot product below.
    q_idx = tl.load(q_indices_ptr + pid_q).to(tl.int64)
    q_vals = tl.load(
        q_base + q_idx * stride_qs + offs_d * stride_qd,
        mask=d_mask, other=0.0,
    ).to(tl.float32)

    # 1 / sqrt(head_dim), folded at compile time because head_dim is constexpr.
    sm_scale = 1.0 / (head_dim ** 0.5)

    # Running softmax statistics.  They are kept as BLOCK_SIZE_D-wide tensors
    # (every lane holds the same value) so the loop-carried variables have a
    # fixed tensor type from the first iteration on.
    running_max = tl.zeros([BLOCK_SIZE_D], dtype=tl.float32) - float("inf")
    running_sum = tl.zeros([BLOCK_SIZE_D], dtype=tl.float32)
    context = tl.zeros([BLOCK_SIZE_D], dtype=tl.float32)

    for k in range(num_keys):
        k_idx = tl.load(k_indices_ptr + k).to(tl.int64)
        k_vals = tl.load(
            k_base + k_idx * stride_ks + offs_d * stride_kd,
            mask=d_mask, other=0.0,
        ).to(tl.float32)

        # Vector-vector dot product; tl.dot only accepts 2-D blocks.
        score = tl.sum(q_vals * k_vals, axis=0) * sm_scale

        # Online softmax: rescale what has been accumulated so far to the new
        # maximum, then add this key's contribution.
        new_max = tl.maximum(running_max, score)
        rescale = tl.exp(running_max - new_max)
        weight = tl.exp(score - new_max)

        # V is gathered with the same index as K.
        v_vals = tl.load(
            v_base + k_idx * stride_vs + offs_d * stride_vd,
            mask=d_mask, other=0.0,
        ).to(tl.float32)

        running_sum = running_sum * rescale + weight
        context = context * rescale + weight * v_vals
        running_max = new_max

    # With zero keys the normaliser is 0; divide by 1 instead so the output
    # row is all zeros rather than NaN.
    denom = tl.where(running_sum > 0.0, running_sum, 1.0)
    tl.store(
        o_base + q_idx * stride_os + offs_d * stride_od,
        context / denom,
        mask=d_mask,
    )

# Python wrapper function to launch the kernel
# TODO: Insert padding
def efficient_sparse_attention(Q, K, V, q_indices, k_indices):
    """Run softmax attention for selected query rows over selected key rows.

    Args:
        Q: 4-D tensor unpacked as (batch, heads, seq, head_dim).
        K: 4-D tensor with the same batch, head and last dimension as ``Q``.
        V: 4-D tensor with the same batch, head and last dimension as ``Q``.
        q_indices: 1-D integer tensor of row indices into the sequence axis of
            ``Q``; one output row is produced per entry.
        k_indices: 1-D integer tensor of row indices into the sequence axis of
            ``K`` and ``V``; every selected query attends to all of them.

    Returns:
        A tensor shaped like ``Q``.  Rows listed in ``q_indices`` hold the
        attention output; all other rows are zero.

    Raises:
        AssertionError: if the last dimensions of ``Q``, ``K`` and ``V`` differ.
        ValueError: if ``K`` or ``V`` does not share ``Q``'s batch and head
            dimensions.
    """
    num_queries = q_indices.shape[0]
    num_keys = k_indices.shape[0]
    batch_size, num_heads, _, head_dim = Q.shape
    assert head_dim == K.shape[-1] == V.shape[-1]

    # The kernel addresses K and V with the same (batch, head) coordinates as
    # Q, so mismatched leading dimensions would read out of bounds.
    if tuple(K.shape[:2]) != (batch_size, num_heads) or tuple(V.shape[:2]) != (batch_size, num_heads):
        raise ValueError(
            "K and V must have the same batch and head dimensions as Q"
        )

    # Zero-initialise so rows that are not selected by q_indices are
    # deterministic instead of uninitialised memory.
    Out = torch.zeros_like(Q)

    # A launch grid with a zero-sized axis is invalid; nothing to compute.
    if num_queries == 0 or batch_size == 0 or num_heads == 0 or head_dim == 0:
        return Out

    # The kernel reads the index tensors with unit stride.
    q_indices = q_indices.contiguous()
    k_indices = k_indices.contiguous()

    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)

    # One program per (selected query, batch element, head).
    grid = (num_queries, batch_size, num_heads)

    _efficient_fused_attention_kernel[grid](
        Q, K, V, Out,
        q_indices, k_indices,
        Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
        K.stride(0), K.stride(1), K.stride(2), K.stride(3),
        V.stride(0), V.stride(1), V.stride(2), V.stride(3),
        Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),
        num_queries=num_queries,
        num_keys=num_keys,
        head_dim=head_dim,
        BLOCK_SIZE_D=BLOCK_SIZE_D
    )
    return Out

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions