
[RFC] Multi-queries paged attention Pallas kernel #8597

@vanbasten23


🚀 Feature

We need a paged attention kernel capable of handling multiple query tokens per sequence.

Motivation

The [existing paged attention kernel](https://github.com/jax-ml/jax/blob/3aa55992fe374987ff3701b69d6814c007c37bb3/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L374) is limited in that it requires the input to have a single query token per sequence. This limitation prevents vLLM from supporting powerful features such as speculative decoding, prefix caching, and chunked prefill. In speculative decoding, for example, we need to decode multiple query tokens for a given sequence in parallel. Hence, this kernel is a hard blocker for vLLM to shine on TPU, and a new paged attention kernel is needed.

Pitch

We need a new Pallas kernel:

def paged_attention(
    q: jax.Array,			# [batch_size, query_len, num_heads, head_dim] 
    k_pages: jax.Array,		# [num_kv_heads, total_num_pages, page_size, head_dim]
    v_pages: jax.Array,		# [num_kv_heads, total_num_pages, page_size, head_dim]
    lengths: jax.Array,		# i32[batch_size]
    page_indices: jax.Array,	# i32[batch_size, pages_per_sequence]
    effective_q_lens: jax.Array, # i32[batch_size]
) -> jax.Array:			# [batch_size, query_len, num_heads, head_dim]
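
For illustration, a call might look like the following (shapes and values here are made up for the example; `paged_attention` refers to the proposed kernel above, not an existing API). `lengths` holds each sequence's KV-cache length, while `effective_q_lens` holds how many of the (possibly padded) `query_len` query tokens are actually valid:

import jax
import jax.numpy as jnp

batch_size, query_len, num_heads, head_dim = 2, 4, 8, 128
num_kv_heads, total_num_pages, page_size = 2, 16, 16
pages_per_sequence = 8

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (batch_size, query_len, num_heads, head_dim), dtype=jnp.float32)
k_pages = jax.random.normal(key, (num_kv_heads, total_num_pages, page_size, head_dim), dtype=jnp.float32)
v_pages = jax.random.normal(key, (num_kv_heads, total_num_pages, page_size, head_dim), dtype=jnp.float32)
# Per-sequence KV-cache lengths (number of valid tokens in the paged cache).
lengths = jnp.array([37, 81], dtype=jnp.int32)
# Which physical pages back each sequence, in logical order (example layout).
page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32).reshape(
    batch_size, pages_per_sequence)
# Number of query tokens that are actually valid within the padded query_len dimension.
effective_q_lens = jnp.array([4, 3], dtype=jnp.int32)

out = paged_attention(q, k_pages, v_pages, lengths, page_indices, effective_q_lens)
# out.shape == (batch_size, query_len, num_heads, head_dim)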

The rough logic maps to the following pseudocode:

q = jnp.transpose(q, (0, 2, 1, 3))  # put the num_heads dim before the query_len dim
# Grid: one kernel invocation per (batch, kv_head, query block, kv block).
for b_idx in range(batch_size):
  for kv_head_idx in range(num_kv_heads):
    for q_blk_idx in range(num_queries_len_blocks):
      for kv_blk_idx in range(num_kv_len_blocks):
        # Within the kernel:
        # q.shape = [num_q_heads_per_kv_head, query_len_per_q_len_block, head_dim]
        # Load the kv pages corresponding to the current batch and kv block from HBM to VMEM.
        for q_head_idx in range(num_q_heads_per_kv_head):
          # Within the flash attention kernel:
          # q.shape    = [query_len_per_q_len_block, head_dim]
          # k.shape    = [kv_len_per_kv_len_block, head_dim]
          # attn.shape = [query_len_per_q_len_block, kv_len_per_kv_len_block]
          # v.shape    = [kv_len_per_kv_len_block, head_dim]
          # out.shape  = [query_len_per_q_len_block, head_dim]
          # Save out to slice q_head_idx of final_out.
        # final_out.shape = [num_q_heads_per_kv_head, query_len_per_q_len_block, head_dim]
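
To pin down the intended semantics, here is a pure-JAX reference sketch (not the Pallas kernel). The grouped-query head layout and the causal-mask convention, where query token i of sequence b sits at absolute position lengths[b] - effective_q_lens[b] + i, are assumptions of this sketch:

import jax
import jax.numpy as jnp

def paged_attention_reference(q, k_pages, v_pages, lengths, page_indices, effective_q_lens):
  """Illustrative pure-JAX reference for the proposed kernel (assumed semantics)."""
  batch_size, query_len, num_heads, head_dim = q.shape
  num_kv_heads, _, page_size, _ = k_pages.shape
  num_q_heads_per_kv_head = num_heads // num_kv_heads
  outputs = []
  for b in range(batch_size):
    # Gather this sequence's pages into contiguous [num_kv_heads, kv_len, head_dim] buffers.
    k = k_pages[:, page_indices[b]].reshape(num_kv_heads, -1, head_dim)
    v = v_pages[:, page_indices[b]].reshape(num_kv_heads, -1, head_dim)
    kv_len = k.shape[1]
    # Group query heads with their KV head: [num_kv_heads, q_per_kv, query_len, head_dim]
    # (assumes query heads k*q_per_kv .. (k+1)*q_per_kv-1 share KV head k).
    qb = q[b].transpose(1, 0, 2).reshape(
        num_kv_heads, num_q_heads_per_kv_head, query_len, head_dim)
    attn = jnp.einsum('kgqd,kcd->kgqc', qb, k) / jnp.sqrt(head_dim)
    # KV position c is visible to query i iff it is within the cache (c < lengths[b])
    # and not in the future relative to the query's assumed absolute position.
    q_pos = lengths[b] - effective_q_lens[b] + jnp.arange(query_len)
    kv_pos = jnp.arange(kv_len)
    mask = (kv_pos[None, :] <= q_pos[:, None]) & (kv_pos[None, :] < lengths[b])
    attn = jnp.where(mask[None, None], attn, -jnp.inf)
    weights = jax.nn.softmax(attn, axis=-1)
    out_b = jnp.einsum('kgqc,kcd->kgqd', weights, v)
    # Back to [query_len, num_heads, head_dim]; rows past effective_q_lens[b] are padding.
    outputs.append(out_b.reshape(num_heads, query_len, head_dim).transpose(1, 0, 2))
  return jnp.stack(outputs)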

Alternatives

Additional context
