Conversation
Pull request overview
This PR aims to improve TPU inference performance for DeepSeek MLA attention by fusing KV-cache updates into the MLA ragged paged attention kernel (Pallas), enabling overlap between KV writes/prefetch (scalar lane) and attention compute (VPU lane). It also updates sharding configuration to introduce an attn_dp_expert mesh axis and adjusts quantization/random-weight-loading utilities accordingly.
Changes:
- Introduce an additional mesh axis (attn_dp_expert) and propagate it through the sharding strategy and TPU runner mesh construction.
- Add MLA attention-side KV quantization plumbing (scales + key quantization) and adjust output sharding constraints.
- Add a new MLA v1 "baseline" kernel module and update Qwix random-weight-loading scale key construction.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| tpu_inference/runner/tpu_runner.py | Imports multihost utils, updates mesh shape to include attn_dp_expert, and tweaks compilation padding buckets for KV packing/alignment. |
| tpu_inference/models/jax/utils/qwix/qwix_utils.py | Changes scale-key derivation for Qwix random weight loading to handle deeper module paths. |
| tpu_inference/models/jax/deepseek_v3.py | Adjusts attention output sharding constraint placement; adds KV quantization for MLA inputs and new sharding knobs. |
| tpu_inference/layers/common/sharding.py | Adds attn_dp_expert axis, extends sharding axis-name groupings, and updates DP size computation/validation. |
| tpu_inference/layers/common/quantization/__init__.py | Makes quantize_kv accept value=None for key-only quantization. |
| tpu_inference/kernels/mla/v1/baseline.py | Adds a new MLA v1 kernel/baseline implementation, including KV update logic and Pallas call scaffolding. |
```python
def _create_single_slice_mesh(self) -> jax.Array:
    sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
    mesh_shape = (
        sharding_strategy.model_dp_size,
        sharding_strategy.attn_dp_size,
        sharding_strategy.attn_dp_expert_size,
        sharding_strategy.expert_size,
        sharding_strategy.tp_size,
    )
```
Adding attn_dp_expert_size introduces a 5D mesh shape for single-slice, but _create_multi_slice_mesh() still builds a 4D ici_mesh_shape while the mesh axis names (MESH_AXIS_NAMES) are now 5D. This will likely cause a shape/axis mismatch (or silently incorrect sharding) when NUM_SLICES > 1; update the multi-slice mesh construction to include the new attn_dp_expert axis (and adjust dcn_mesh_shape accordingly).
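The rank mismatch can be checked in plain Python before any mesh is built. This is a minimal sketch with illustrative sizes; `MESH_AXIS_NAMES` and `NUM_SLICES` stand in for the real module-level values, and the actual fix would live inside `_create_multi_slice_mesh()`:

```python
# Hypothetical stand-ins for the module-level constants and config fields.
MESH_AXIS_NAMES = ("model_dp", "attn_dp", "attn_dp_expert", "expert", "tp")
NUM_SLICES = 2
model_dp_size, attn_dp_size, attn_dp_expert_size, expert_size, tp_size = 1, 1, 1, 1, 4

# Per-slice (ICI) mesh: now 5D, matching the single-slice path above.
ici_mesh_shape = (model_dp_size, attn_dp_size, attn_dp_expert_size,
                  expert_size, tp_size)
# The cross-slice (DCN) mesh must have the same rank; replicate every axis
# except the one assumed to span slices.
dcn_mesh_shape = (NUM_SLICES, 1, 1, 1, 1)

# Both meshes must agree in rank with the axis-name tuple, or mesh
# construction will fail (or silently mis-shard).
assert len(ici_mesh_shape) == len(dcn_mesh_shape) == len(MESH_AXIS_NAMES)
```

With both tuples at rank 5, a hybrid mesh helper (e.g. `mesh_utils.create_hybrid_device_mesh`) can consume them consistently with `MESH_AXIS_NAMES`.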
tpu_inference/runner/tpu_runner.py (outdated)
```diff
 additional_sizes = self.vllm_config.additional_config.get("compilation_sizes", [])
 # [16, 32, 64, 128, 256, 512, 1024, 2048]
 cache_dtype = self.cache_config.cache_dtype
 if cache_dtype == "auto":
     cache_dtype = self.dtype
 kv_cache_dtype = to_jax_dtype(cache_dtype)
 kv_packing = common_utils.get_dtype_packing(kv_cache_dtype)
 self.num_tokens_paddings = runner_utils.get_token_paddings(
-    min_token_size=max(16, self.dp_size),
+    min_token_size=max(16, self.dp_size * kv_packing),
     max_token_size=scheduler_config.max_num_batched_tokens *
     self.dp_size,
     padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
 self.num_tokens_paddings = sorted(self.num_tokens_paddings + additional_sizes)
 self.num_tokens_paddings_per_dp = [
     padding // self.dp_size for padding in self.num_tokens_paddings
```
additional_sizes are appended directly into num_tokens_paddings, but later num_tokens_paddings_per_dp is computed via padding // self.dp_size. If any additional_sizes entries are not multiples of dp_size (and/or the kv_packing alignment you just introduced), the per-DP padding will be truncated and can create inconsistent shapes between global/per-DP token counts. Consider validating/rounding additional_sizes to the required alignment before merging them into the paddings list.
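One way to enforce the alignment before merging is to round each user-supplied size up to the required multiple. This is a sketch only; the helper names are illustrative, and the required multiple is assumed to be `dp_size * kv_packing`, per the diff above:

```python
def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple."""
    return ((x + multiple - 1) // multiple) * multiple

def align_additional_sizes(additional_sizes, dp_size, kv_packing):
    # Every global padding bucket must split evenly across DP ranks and
    # respect the KV-packing alignment, so round up to dp_size * kv_packing
    # and deduplicate before merging into the padding list.
    alignment = dp_size * kv_packing
    return sorted({round_up(s, alignment) for s in additional_sizes})

# e.g. dp_size=4, kv_packing=2 -> alignment of 8
print(align_additional_sizes([16, 30, 100], dp_size=4, kv_packing=2))
# → [16, 32, 104]
```

After this, `padding // self.dp_size` is exact for every entry, so the global and per-DP token counts stay consistent.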
```python
num_queries_per_block=num_queries_per_block,
q_scale=q_scale,
k_scale=k_scale,
v_scale=k_scale)
```
v_scale is computed (and set to self._v_scale for quantized KV) but the MLA kernel call passes v_scale=k_scale. This applies the wrong dequant scale to the attention output; it should pass the value scale (v_scale).
```diff
-v_scale=k_scale)
+v_scale=v_scale)
```
```python
) -> Tuple[jax.Array, jax.Array]:
    """Static quantize key and value tensors."""
    key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
    if value is None:
        return key, None
```
quantize_kv now allows value=None and returns (key, None), but the return type annotation is still Tuple[jax.Array, jax.Array]. Also, k_scale/v_scale are defaulted to None but are passed directly into static_per_tensor_quantize_tensor, which expects a real scale value (will error if a caller relies on the defaults). Consider updating the return type to include None (e.g. jax.Array | None) and either make scales required again or add an explicit check/ValueError when they are None.
```diff
-) -> Tuple[jax.Array, jax.Array]:
-    """Static quantize key and value tensors."""
-    key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
-    if value is None:
-        return key, None
+) -> tuple[jax.Array, jax.Array | None]:
+    """Static quantize key and value tensors."""
+    if k_scale is None:
+        raise ValueError("k_scale must be provided for quantize_kv.")
+    key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
+    if value is None:
+        return key, None
+    if v_scale is None:
+        raise ValueError(
+            "v_scale must be provided for quantize_kv when value is not None."
+        )
```
```python
def update_kv_cache(
    new_kv_c: jax.Array,  # [num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [num_tokens, actual_r_dim]
    cache_kv: jax.Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
) -> tuple[jax.Array, jax.Array]:
    """Update KV cache with new tokens."""
    actual_r_dim = new_k_pe.shape[-1]
    r_dim = align_to(actual_r_dim, 128)
    if actual_r_dim != r_dim:
        new_k_pe = jnp.pad(new_k_pe, ((0, 0), (0, r_dim - actual_r_dim)),
                           constant_values=0)
    actual_lkv_dim = new_kv_c.shape[-1]
    lkv_dim = align_to(actual_lkv_dim, 128)
    if actual_lkv_dim != lkv_dim:
        new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)),
                           constant_values=0)
    kv_dim = r_dim + lkv_dim
    _, page_size_per_kv_packing, kv_packing, cache_kv_dim = cache_kv.shape
    assert kv_dim == cache_kv_dim
    page_size = page_size_per_kv_packing * kv_packing

    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    pages_per_seq = num_page_indices // max_num_seqs

    def seq_loop_body(i, cache_kv):
        q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
        q_len = q_end - q_start
        kv_len = kv_lens[i]

        def token_loop_body(j, cache_kv_):
            token_idx_in_seq = kv_len - q_len + j
            page_num_in_seq = token_idx_in_seq // page_size
            page_indices_start = i * pages_per_seq
            page_idx = page_indices[page_indices_start + page_num_in_seq]
            row = (token_idx_in_seq % page_size) // kv_packing
            col = (token_idx_in_seq % page_size) % kv_packing

            cache_kv_ = cache_kv_.at[page_idx, row, col,
                                     ..., :lkv_dim].set(new_kv_c[q_start + j])
            cache_kv_ = cache_kv_.at[page_idx, row, col, ...,
                                     lkv_dim:].set(new_k_pe[q_start + j])
            return cache_kv_

        return lax.fori_loop(0, q_len, token_loop_body, cache_kv)

    cache_kv = lax.fori_loop(0, distribution[-1], seq_loop_body, cache_kv)

    return cache_kv
```
update_kv_cache is annotated as returning tuple[jax.Array, jax.Array], but it actually returns a single cache_kv array. This is inconsistent with its implementation and with how callers use it (as a single array), and will confuse type-checkers/readers; update the return annotation (and docstring if needed) to match the actual return value.
```python
# Fused KV-Cache update: handled inside the Pallas kernel.
# Use JAX-compatible update_kv_cache to update the cache in a jit-friendly way
# Parallelize KV cache update across sequences using vmap
# Pad new_kv_c and new_k_pe to aligned dims before update
actual_lkv_dim = new_kv_c.shape[-1]
actual_r_dim = new_k_pe.shape[-1]
lkv_dim = align_to(actual_lkv_dim, 128)
r_dim = align_to(actual_r_dim, 128)
if actual_lkv_dim != lkv_dim:
    new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)), constant_values=0)
if actual_r_dim != r_dim:
    new_k_pe = jnp.pad(new_k_pe, ((0, 0), (0, r_dim - actual_r_dim)), constant_values=0)

def update_kv_cache_per_seq(seq_idx, cache_kv):
    q_start = cu_q_lens[seq_idx]
    q_end = cu_q_lens[seq_idx + 1]
    q_len = q_end - q_start
    kv_len = kv_lens[seq_idx]
    _, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
    page_size = page_size_per_kv_packing * kv_packing
    num_page_indices = page_indices.shape[0]
    max_num_seqs = kv_lens.shape[0]
    pages_per_seq = num_page_indices // max_num_seqs

    def update_token(j, cache_kv_):
        token_idx_in_seq = kv_len - q_len + j
        page_num_in_seq = token_idx_in_seq // page_size
        page_indices_start = seq_idx * pages_per_seq
        page_idx = page_indices[page_indices_start + page_num_in_seq]
        row = (token_idx_in_seq % page_size) // kv_packing
        col = (token_idx_in_seq % page_size) % kv_packing
        cache_kv_ = cache_kv_.at[page_idx, row, col, ..., :lkv_dim].set(new_kv_c[q_start + j])
        cache_kv_ = cache_kv_.at[page_idx, row, col, ..., lkv_dim:].set(new_k_pe[q_start + j])
        return cache_kv_

    cache_kv = jax.lax.fori_loop(0, q_len, update_token, cache_kv)
    return cache_kv

seq_indices = jnp.arange(kv_lens.shape[0])
cache_kv = jax.lax.fori_loop(0, seq_indices.shape[0], update_kv_cache_per_seq, cache_kv)
```
This block claims the KV-cache update is "handled inside the Pallas kernel", but the code performs a full KV update pass in Python/JAX (fori_loop over sequences and tokens) before launching the Pallas attention kernel. That contradicts the intended fusion/concurrency described in the PR and will reintroduce the extra HBM pass (and likely compile a large scatter loop). Either move the KV update into _mla_ragged_paged_attention_kernel (scalar lane) or update the comments/PR description and keep this as an explicit unfused baseline path.
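Whichever path is kept, both versions share the same packed-cache index math, which can be sanity-checked in plain Python. A sketch with illustrative sizes (the helper name is hypothetical; it mirrors the `page_num_in_seq`/`row`/`col` arithmetic above):

```python
def kv_slot(token_idx_in_seq: int, page_size_per_kv_packing: int,
            kv_packing: int) -> tuple[int, int, int]:
    """Map a token's position within a sequence to (page_num_in_seq, row, col)
    in the packed cache layout used above."""
    page_size = page_size_per_kv_packing * kv_packing
    page_num_in_seq = token_idx_in_seq // page_size
    offset = token_idx_in_seq % page_size
    # Consecutive tokens fill the packing dimension (col) first, then rows.
    return page_num_in_seq, offset // kv_packing, offset % kv_packing

# page_size_per_kv_packing=8, kv_packing=2 -> 16 tokens per page
print(kv_slot(0, 8, 2))   # → (0, 0, 0)
print(kv_slot(3, 8, 2))   # → (0, 1, 1)
print(kv_slot(17, 8, 2))  # → (1, 0, 1)
```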
```python
# Currently tensor_parallelism is also used for other things like determining number of Ray workers.
pc_tensor_parallelism = parallel_config.tensor_parallel_size
ss_tensor_parallelsim = sharding_strategy.get("tensor_parallelism", 1)
data_parallelism = parallel_config.data_parallel_size
expert_parallelism = sharding_strategy.get("expert_parallelism", 1)
sequence_parallelism = sharding_strategy.get("sequence_parallelism", 1)
device_indexes = sharding_strategy.get("device_indexes", None)

enable_dp_attention = sharding_strategy.get("enable_dp_attention",
                                            False)
if pc_tensor_parallelism != ss_tensor_parallelsim and ss_tensor_parallelsim > 1:
    # The user has explicitly set the tensor parallelism in the sharding config.
    tensor_parallelism = ss_tensor_parallelsim
else:
    tensor_parallelism = pc_tensor_parallelism
```
ss_tensor_parallelsim is misspelled (should be ss_tensor_parallelism). Beyond readability, this makes it easier to accidentally introduce a second similarly-named variable later; consider renaming for clarity/consistency.
```python
sharding_strategy = vllm_config.additional_config.get(
    "sharding", {}).get("sharding_strategy", {})
parallel_config = vllm_config.parallel_config
tensor_parallelism = parallel_config.tensor_parallel_size
# Currently tensor_parallelism is also used for other things like determining number of Ray workers.
pc_tensor_parallelism = parallel_config.tensor_parallel_size
ss_tensor_parallelsim = sharding_strategy.get("tensor_parallelism", 1)
data_parallelism = parallel_config.data_parallel_size
expert_parallelism = sharding_strategy.get("expert_parallelism", 1)
sequence_parallelism = sharding_strategy.get("sequence_parallelism", 1)
device_indexes = sharding_strategy.get("device_indexes", None)

enable_dp_attention = sharding_strategy.get("enable_dp_attention",
                                            False)
if pc_tensor_parallelism != ss_tensor_parallelsim and ss_tensor_parallelsim > 1:
    # The user has explicitly set the tensor parallelism in the sharding config.
    tensor_parallelism = ss_tensor_parallelsim
else:
    tensor_parallelism = pc_tensor_parallelism
```
The logic for overriding tensor parallelism only applies when ss_tensor_parallelsim > 1. If a user explicitly sets tensor_parallelism to 1 in the sharding strategy (to override a larger parallel_config.tensor_parallel_size), that configuration will be ignored despite being explicitly provided. Consider checking for presence of the key (e.g. 'tensor_parallelism' in sharding_strategy) rather than > 1 so explicit overrides to 1 are honored.
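A sketch of the suggested presence check (names follow the snippet above; the corrected spelling `ss_tensor_parallelism` is used here):

```python
def resolve_tensor_parallelism(sharding_strategy: dict,
                               pc_tensor_parallelism: int) -> int:
    # Honor an explicit setting -- including an explicit 1 -- by checking
    # for the key's presence rather than comparing against a magic default.
    if "tensor_parallelism" in sharding_strategy:
        ss_tensor_parallelism = sharding_strategy["tensor_parallelism"]
        return ss_tensor_parallelism
    return pc_tensor_parallelism

print(resolve_tensor_parallelism({"tensor_parallelism": 1}, 8))  # → 1
print(resolve_tensor_parallelism({}, 8))                         # → 8
```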
There's already a kernel.py in main. Can you base the diff on that file instead of submitting a whole new file?

Yes, these changes are specifically for MLA optimization. I initially added new files to avoid breaking the base version during testing, but I'll move the optimized fused logic into the existing kernel.py now and remove the redundant files to keep the diff clean.

We need to add test coverage before submission imo (see: tests/kernels/mla_v1_test.py). Please work with Jaehong to expand test coverage and get those committed.

Agreed on the test coverage. I'll add cases to mla_v1_test.py that specifically exercise the new fused version and the vectorization paths, and include those in the next push.
There are so many conflicts in this PR; please update the branch to main's head first.
…pdated some shardings in DSV3.
Force-pushed from 474def3 to 0f41147.
Is this kernel connected anywhere? I don't see a way for e2e to trigger this code path.
…, license header)
This PR fuses the KV Cache update into the MLA attention kernel using Pallas.
Baseline — no fusion
Two separate passes over HBM:
1. Update KV cache — scatter new tokens into cache_kv in HBM (one read + write per token)
2. Run attention — DMA-fetch KV blocks from HBM, run Flash Attention on VPU
The cache update and attention are sequential. The VPU sits idle while the cache writes finish.
Fused — KV Cache update inside Pallas attention
A single Pallas kernel does both in one pass. Pallas on TPU exposes two independent execution lanes:
1. Scalar unit — runs DMA commands (prefetch next KV block + write new tokens into cache)
2. Vector unit (VPU) — runs Flash Attention on the current KV block
These run concurrently. While the VPU computes attention on block N, the scalar unit simultaneously writes new tokens into the cache and prefetches block N+1. The KV write latency is fully hidden behind VPU compute.
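The intended two-lane overlap can be illustrated with a toy schedule in pure Python (no Pallas; block counts and event strings are illustrative, not the kernel's actual API):

```python
def fused_schedule(num_blocks: int) -> list[str]:
    """Simulate the per-block event order: while the VPU computes attention
    on block n, the scalar unit writes new tokens and prefetches block n+1."""
    events = ["scalar: prefetch block 0"]
    for n in range(num_blocks):
        # On real hardware the next two events run concurrently: the KV write
        # and next-block prefetch are hidden behind the attention compute.
        events.append(f"scalar: write new tokens + prefetch block {n + 1}"
                      if n + 1 < num_blocks else
                      "scalar: write new tokens")
        events.append(f"vpu: attention on block {n}")
    return events

for e in fused_schedule(3):
    print(e)
```

The point of the simulation is the interleaving: no step exists where the VPU waits on a cache write, which is exactly the idle time the baseline pays for.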
Checklist
- I have performed a self-review of my code.
- I have added necessary comments in my code, particularly in hard-to-understand areas.
- I have made or will make corresponding changes to any relevant documentation.