[ragged-paged-attn] Use hidden states in kv cache and support any num_kv_head #8851
Conversation
@@ -1014,8 +955,7 @@ def ragged_paged_attention(
):
  if mask_value is None:
    mask_value = DEFAULT_MASK_VALUE
  validate_ragged_paged_attention_inputs(q, k_pages, v_pages, kv_lens,
                                          page_indices, cu_q_lens, num_seqs)
Why did we stop calling validate_ragged_paged_attention_inputs?
Because we already have these static shape checks in JAX.
torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py
q_packing = get_dtype_packing(q_dtype)
max_q_tiling = 8 * q_packing
min_q_heads = lcm(max_q_tiling, num_q_heads_per_kv_head)
I am not sure I follow. If the dtype is bf16, then max_q_tiling is 16. For Qwen, where num_q_heads=12, num_kv_heads=2, and num_q_heads_per_kv_head=6, min_q_heads (= lcm(max_q_tiling, num_q_heads_per_kv_head)) will be 48. What does min_q_heads mean?
It tries to find the smallest number that is divisible by both max_q_tiling and num_q_heads_per_kv_head. If that number divides the total num_q_heads evenly, we use it as num_q_heads_per_blk; if not, we fall back to the total num_q_heads.
Checking that it is divisible by max_q_tiling makes sure the block can be fully tiled by XLA.
Checking that it is divisible by num_q_heads_per_kv_head makes sure we do not need an inner split within num_q_heads_per_kv_head.
Thanks! Could you add what you said as a comment in the code?
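A minimal sketch of the head-blocking rule described above (the helper name and signature are illustrative, not the kernel's actual code):

from math import lcm

def pick_num_q_heads_per_blk(num_q_heads, num_q_heads_per_kv_head, q_packing):
  # One XLA tile covers 8 * q_packing rows (e.g. q_packing=2 for bf16 -> 16).
  max_q_tiling = 8 * q_packing
  # Smallest head count divisible by both the tiling and the q-per-kv ratio.
  min_q_heads = lcm(max_q_tiling, num_q_heads_per_kv_head)
  if num_q_heads % min_q_heads == 0:
    return min_q_heads
  # Otherwise fall back to processing all q heads in one block.
  return num_q_heads

# Qwen-like example from the discussion: 12 q heads, 6 q heads per kv head, bf16.
# lcm(16, 6) = 48 does not divide 12, so the block uses all 12 q heads.
assert pick_num_q_heads_per_blk(12, 6, 2) == 12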
max_kv_len = jnp.max(kv_lens)
min_pages_per_seq = ceil_div(max_kv_len, page_size)
min_pages_per_seq = cdiv(max_kv_len, page_size)
Why is it min? Shouldn't it be max_pages_per_seq, since you used cdiv(jnp.max(kv_lens), page_size)?
That is the lower bound for pages_per_seq: the longest sequence needs cdiv(max_kv_len, page_size) pages, so pages_per_seq cannot be smaller than that.
gotcha
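A small sketch of that bound with made-up numbers (only cdiv, kv_lens, and page_size come from the diff above; the values are hypothetical):

import jax.numpy as jnp

def cdiv(a, b):
  # Ceiling division: the smallest n such that n * b >= a.
  return (a + b - 1) // b

kv_lens = jnp.array([128, 1000, 512])            # hypothetical per-sequence kv lengths
page_size = 16
max_kv_len = jnp.max(kv_lens)                    # 1000
min_pages_per_seq = cdiv(max_kv_len, page_size)  # 63: fewer pages cannot hold the longest sequence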
_, page_size, kv_model_dim = k_pages.shape
kv_packing = get_dtype_packing(k_pages.dtype)
if page_size % kv_packing != 0:
  raise ValueError(f"Expected {page_size=} is divisible by {kv_packing=}")
page_size % kv_packing != 0 indicates there will be padding, so we may waste some memory. Can we give a warning instead of raising an exception?
The page size is chosen by the serving config, and the error indicates we should choose a better one. Otherwise, when people use bf16 or quantized types (fp8, int8, int4), there will be no bandwidth saving. We should prevent this.
I see. I guess it's the same reason why, before this PR, when num_kv_heads == 1 and dtype == bf16 we would raise an exception:
xla/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py, lines 534 to 536 in c7d0b1e:
if not can_be_xla_fully_tiled(num_kv_heads, kv_packing):
  raise ValueError(
      f"Not implemented: {num_kv_heads=} can not be XLA fully tiled.")
(From the PR description: "Previously, if num_kv_head == 1 and dtype=bfloat16, we would have implicit padding on TPU.") The point is the code may still run fine, but there will be no bandwidth savings.
Yes, the point of quantization is to save more memory and bandwidth
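A small sketch of the check under discussion (the packing table and the get_dtype_packing stand-in assume packing into 32-bit words; they are illustrative, not the kernel's actual helper):

# Values packed per 32-bit word (illustrative stand-in, assuming 32-bit registers).
_PACKING = {"float32": 1, "bfloat16": 2, "float8": 4, "int8": 4, "int4": 8}

def get_dtype_packing(dtype_name: str) -> int:
  return _PACKING[dtype_name]

def check_page_size(page_size: int, kv_dtype: str) -> None:
  kv_packing = get_dtype_packing(kv_dtype)
  if page_size % kv_packing != 0:
    # A non-divisible page_size implies padded lanes, i.e. wasted memory and bandwidth.
    raise ValueError(f"Expected {page_size=} to be divisible by {kv_packing=}")

check_page_size(16, "int4")    # OK: 16 % 8 == 0
# check_page_size(18, "int4")  # would raise: padding would cancel the bandwidth win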
Thanks Jevin. LGTM pending CI.
@bythew3i I assume you have run the tests in tests/pallas/tpu_ragged_paged_attention_test.py and they all pass?
Yes, I tested the kernel.
This PR stores hidden states (num_kv_head * head_dim) in the kv cache. This change unblocks us for any num_kv_head. Previously, if num_kv_head == 1 and dtype=bfloat16, we would have implicit padding on TPU. Now, by using the hidden states directly from the projection, we no longer need a strided load; we can load by slice directly.
This PR should help us support multi-chip sharding, which shards num_kv_head down to 1 for Llama-3-70B.
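A rough illustration of the layout change (shapes, names, and sizes below are assumptions for the sketch, not the kernel's exact signatures): with a per-head cache the page data for one kv head sits behind a leading head axis, while with a hidden-state cache each page is one contiguous slice of size num_kv_heads * head_dim.

import jax.numpy as jnp

# Illustrative sizes only (not taken from the PR).
num_kv_heads, head_dim = 2, 128
total_pages, page_size = 64, 16
head, page = 1, 3

# Per-head layout (before): one leading axis per kv head.
k_pages_per_head = jnp.zeros(
    (num_kv_heads, total_pages, page_size, head_dim), dtype=jnp.bfloat16)
page_old = k_pages_per_head[head, page]   # [page_size, head_dim], indexed across the head axis

# Hidden-state layout (after): each page stores the full num_kv_heads * head_dim.
k_pages_hidden = jnp.zeros(
    (total_pages, page_size, num_kv_heads * head_dim), dtype=jnp.bfloat16)
page_new = k_pages_hidden[page]           # [page_size, num_kv_heads * head_dim], one contiguous slice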
Tested: