
[ragged-paged-attn] Use hidden states in kv cache and support any num_kv_head #8851

Merged (4 commits) on Mar 19, 2025

Conversation

bythew3i (Contributor)

This PR uses hidden states (num_kv_head * head_dim) in the kv cache, which unblocks support for any num_kv_head. Previously, if num_kv_head == 1 and dtype=bfloat16, we would get implicit padding on TPU. Now that we use the hidden states directly from the projection, we no longer need a strided load; we just load by slice directly.

This PR should also help us support multi-chip sharding, which shards num_kv_head down to 1 for llama-3-70B.
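As an illustration, here is a minimal sketch of the layout change. The old per-head layout, dimension names, and sizes are assumptions based on this description, not the kernel's actual code:

```python
# Hedged sketch (not this PR's code): illustrates the kv-cache layout change only.
# Assumed old layout: k_pages[num_kv_heads, total_pages, page_size, head_dim]
#   -> reading one page for all heads needs a strided load across the head axis.
# New hidden-state layout: k_pages[total_pages, page_size, num_kv_heads * head_dim]
#   -> one page is a single contiguous slice of the hidden dimension.
import jax.numpy as jnp

num_kv_heads, total_pages, page_size, head_dim = 2, 8, 16, 128

old_k_pages = jnp.zeros((num_kv_heads, total_pages, page_size, head_dim),
                        dtype=jnp.bfloat16)

# Equivalent hidden-state layout used by the new kernel interface.
new_k_pages = jnp.transpose(old_k_pages, (1, 2, 0, 3)).reshape(
    total_pages, page_size, num_kv_heads * head_dim)

# Loading one page is now a plain slice; no strided load over heads is needed.
page = new_k_pages[3]  # shape: [page_size, num_kv_heads * head_dim]
```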

Tested:

python test/test_pallas.py -v -k PallasTest.test_ragged_paged_attention_wrapper

@@ -1014,8 +955,7 @@ def ragged_paged_attention(
):
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
  validate_ragged_paged_attention_inputs(q, k_pages, v_pages, kv_lens,
                                         page_indices, cu_q_lens, num_seqs)

Collaborator

Why did we stop checking with validate_ragged_paged_attention_inputs?

Contributor Author

Because we already have these static shape checks in JAX:


q_packing = get_dtype_packing(q_dtype)
max_q_tiling = 8 * q_packing
min_q_heads = lcm(max_q_tiling, num_q_heads_per_kv_head)
Collaborator

I am not sure I follow. If dtype is bf16, then max_q_tiling is 16. If it's Qwen, where num_q_heads=12, num_kv_heads=2, and num_q_heads_per_kv_head=6, then min_q_heads (= lcm(max_q_tiling, num_q_heads_per_kv_head)) would be 48. What does min_q_heads mean?

Contributor Author (@bythew3i, Mar 19, 2025)

It finds the smallest number that is divisible by both max_q_tiling and num_q_heads_per_kv_head (i.e. their lcm). If that number divides the total num_q_heads evenly, we use it as num_q_heads_per_blk; if not, we fall back to the total num_q_heads.

Checking divisibility by max_q_tiling makes sure the block can be fully tiled by XLA.
Checking divisibility by num_q_heads_per_kv_head makes sure we never need an inner split within a kv head's group of q heads.
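For context, here is a minimal sketch of that heuristic. It assumes get_dtype_packing returns how many values fit in a 32-bit lane, and the helper names are hypothetical; this is not the kernel's actual implementation:

```python
# Hedged sketch of the head-blocking heuristic described above.
from math import lcm

def get_dtype_packing_assumed(bits_per_value: int) -> int:
    # Assumption: packing = number of values per 32-bit lane (bf16 -> 2).
    return 32 // bits_per_value

def choose_num_q_heads_per_blk(num_q_heads, num_q_heads_per_kv_head, q_bits=16):
    q_packing = get_dtype_packing_assumed(q_bits)   # bf16 -> 2
    max_q_tiling = 8 * q_packing                    # bf16 -> 16
    min_q_heads = lcm(max_q_tiling, num_q_heads_per_kv_head)
    # Use min_q_heads as the block size only if it divides the total evenly;
    # otherwise process all q heads in one block.
    return min_q_heads if num_q_heads % min_q_heads == 0 else num_q_heads

# Qwen-like example from the thread: 12 q heads, num_q_heads_per_kv_head=6, bf16.
# lcm(16, 6) = 48 does not divide 12, so the fallback returns 12.
print(choose_num_q_heads_per_blk(num_q_heads=12, num_q_heads_per_kv_head=6))  # -> 12
```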

Collaborator

Thanks! Could you add what you said as a comment in the code?

max_kv_len = jnp.max(kv_lens)
min_pages_per_seq = ceil_div(max_kv_len, page_size)
min_pages_per_seq = cdiv(max_kv_len, page_size)
Collaborator

Why is it min? Shouldn't it be max_pages_per_seq, since you used cdiv(jnp.max(kv_lens), page_size)?

Contributor Author

That is the lower bound for pages_per_seq: the longest sequence needs cdiv(max_kv_len, page_size) pages, so pages_per_seq must be at least that.
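A tiny example of that bound, assuming cdiv is ceiling division:

```python
# Hedged example: cdiv is assumed to be ceiling division.
def cdiv(a, b):
    return (a + b - 1) // b

page_size = 16
kv_lens = [7, 100, 33]                             # longest sequence: 100 tokens
min_pages_per_seq = cdiv(max(kv_lens), page_size)  # ceil(100 / 16) = 7
# The caller-provided pages_per_seq (second dim of page_indices) must be
# at least 7, or the longest sequence's kv cache cannot fit.
print(min_pages_per_seq)  # 7
```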

Collaborator

gotcha

_, page_size, kv_model_dim = k_pages.shape
kv_packing = get_dtype_packing(k_pages.dtype)
if page_size % kv_packing != 0:
  raise ValueError(f"Expected {page_size=} is divisible by {kv_packing=}")
Collaborator

page_size % kv_packing != 0 indicates there will be padding, so we may waste some memory. Can we give a warning instead of raising an exception?

Contributor Author

The page size is chosen by the serving config, and the error indicates we should choose a better one. Otherwise, when people use bf16 or quantized types (fp8, int8, int4), there will be no bandwidth saving. We should prevent this.
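To illustrate why this is treated as an error rather than a warning, here is a hedged sketch of the packing constraint. The packing factors assume values-per-32-bit-lane; this is not the kernel's get_dtype_packing:

```python
# Hedged sketch: why page_size should be a multiple of the kv dtype's packing factor.
packing = {"bf16": 2, "fp8": 4, "int8": 4, "int4": 8}  # assumed values per 32-bit lane

page_size = 30
for dtype, kv_packing in packing.items():
    if page_size % kv_packing != 0:
        # The tail of every page is padded up to a full packed lane, so the
        # narrower dtype no longer reduces memory traffic proportionally.
        print(f"{dtype}: page_size={page_size} leaves "
              f"{kv_packing - page_size % kv_packing} padded slot(s) per page")
    else:
        print(f"{dtype}: page_size={page_size} packs cleanly")
```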

Collaborator

I see. I guess it's the same reason why, before this PR, when num_kv_head==1 and dtype=bf16, we would raise an exception:

if not can_be_xla_fully_tiled(num_kv_heads, kv_packing):
  raise ValueError(
      f"Not implemented: {num_kv_heads=} can not be XLA fully tiled.")
("Previously, if num_kv_head == 1 and dtype=bfloat16, we would get implicit padding on TPU.") The point is the code may still run fine, but there would be no bandwidth savings.

Contributor Author

Yes, the point of quantization is to save memory and bandwidth.

@vanbasten23 (Collaborator) left a comment

Thanks Jevin. LGTM pending CI.

@vanbasten23 (Collaborator)

@bythew3i I assume you have run the tests in tests/pallas/tpu_ragged_paged_attention_test.py and they all pass?

@bythew3i (Contributor Author)

> @bythew3i I assume you have run the tests in tests/pallas/tpu_ragged_paged_attention_test.py and they all pass?

Yes, I tested the kernel.

@vanbasten23 vanbasten23 merged commit 4190fc0 into pytorch:master Mar 19, 2025
23 checks passed