
MLA KV Cache Fusion TPU Inference #1856

Open
mourado wants to merge 11 commits into main from bouache_mla/tpu_inference

Conversation

@mourado
Collaborator

@mourado mourado commented Mar 4, 2026

This PR fuses the KV Cache update into the MLA attention kernel using Pallas.

Baseline — no fusion
Two separate passes over HBM:

1. Update KV cache — scatter new tokens into cache_kv in HBM (one read + write per token)
2. Run attention — DMA-fetch KV blocks from HBM, run Flash Attention on VPU

The cache update and attention are sequential. The VPU sits idle while the cache writes finish.

Fused — KV Cache update inside Pallas attention
A single Pallas kernel does both in one pass. Pallas on TPU exposes two independent execution lanes:

1. Scalar unit — runs DMA commands (prefetch next KV block + write new tokens into cache)
2. Vector unit (VPU) — runs Flash Attention on the current KV block

These run concurrently. While the VPU computes attention on block N, the scalar unit simultaneously writes new tokens into the cache and prefetches block N+1. The KV write latency is fully hidden behind VPU compute.
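The scatter the scalar lane performs can be summarized by the slot arithmetic below — a minimal, hypothetical sketch of how a token's position in its sequence maps into the packed paged cache layout (`page_size` and `kv_packing` are illustrative names, not the exact kernel parameters):

```python
def kv_slot(token_idx_in_seq: int, page_size: int, kv_packing: int) -> tuple[int, int, int]:
    """Map a token's position in its sequence to (page_num, row, col).

    The cache is laid out as [pages, page_size // kv_packing, kv_packing, dim],
    so kv_packing consecutive tokens share one row of a page.
    """
    page_num = token_idx_in_seq // page_size
    offset = token_idx_in_seq % page_size
    return page_num, offset // kv_packing, offset % kv_packing

# Example: with page_size=16 and kv_packing=2, token 18 lands on
# page 1, row 1, col 0.
```

This is the per-token index computation; in the fused kernel the same arithmetic runs on the scalar lane while the VPU is busy with the previous block.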

Checklist
I have performed a self-review of my code.
I have added necessary comments to my code, particularly in hard-to-understand areas.
I have made or will make corresponding changes to any relevant documentation.


Copilot AI left a comment


Pull request overview

This PR aims to improve TPU inference performance for DeepSeek MLA attention by fusing KV-cache updates into the MLA ragged paged attention kernel (Pallas), enabling overlap between KV writes/prefetch (scalar lane) and attention compute (VPU lane). It also updates sharding configuration to introduce an attn_dp_expert mesh axis and adjusts quantization/random-weight-loading utilities accordingly.

Changes:

  • Introduce an additional mesh axis (attn_dp_expert) and propagate it through sharding strategy + TPU runner mesh construction.
  • Add MLA attention-side KV quantization plumbing (scales + key quantization) and adjust output sharding constraints.
  • Add a new MLA v1 “baseline” kernel module and update Qwix random-weight-loading scale key construction.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
tpu_inference/runner/tpu_runner.py Imports multihost utils, updates mesh shape to include attn_dp_expert, and tweaks compilation padding buckets for KV packing/alignment.
tpu_inference/models/jax/utils/qwix/qwix_utils.py Changes scale-key derivation for Qwix random weight loading to handle deeper module paths.
tpu_inference/models/jax/deepseek_v3.py Adjusts attention output sharding constraint placement; adds KV quantization for MLA inputs and new sharding knobs.
tpu_inference/layers/common/sharding.py Adds attn_dp_expert axis, extends sharding axis-name groupings, and updates DP size computation/validation.
tpu_inference/layers/common/quantization/__init__.py Makes quantize_kv accept value=None for key-only quantization.
tpu_inference/kernels/mla/v1/baseline.py Adds a new MLA v1 kernel/baseline implementation, including KV update logic and Pallas call scaffolding.


Comment on lines 313 to 321
def _create_single_slice_mesh(self) -> jax.Array:
sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
mesh_shape = (
sharding_strategy.model_dp_size,
sharding_strategy.attn_dp_size,
sharding_strategy.attn_dp_expert_size,
sharding_strategy.expert_size,
sharding_strategy.tp_size,
)

Copilot AI Mar 4, 2026


Adding attn_dp_expert_size introduces a 5D mesh shape for single-slice, but _create_multi_slice_mesh() still builds a 4D ici_mesh_shape while the mesh axis names (MESH_AXIS_NAMES) are now 5D. This will likely cause a shape/axis mismatch (or silently incorrect sharding) when NUM_SLICES > 1; update the multi-slice mesh construction to include the new attn_dp_expert axis (and adjust dcn_mesh_shape accordingly).
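One cheap way to catch this class of mismatch early is a guard that ties the mesh shape tuple to the axis-name tuple, so single- and multi-slice paths fail loudly if they drift apart. A sketch only — `MESH_AXIS_NAMES` here is an illustrative stand-in, not the project's actual constant:

```python
# Illustrative axis names mirroring the 5D mesh described in this PR.
MESH_AXIS_NAMES = ("model_dp", "attn_dp", "attn_dp_expert", "expert", "tp")

def check_mesh_shape(mesh_shape: tuple[int, ...]) -> None:
    """Fail fast if a mesh shape disagrees with the axis-name tuple."""
    if len(mesh_shape) != len(MESH_AXIS_NAMES):
        raise ValueError(
            f"mesh shape {mesh_shape} has {len(mesh_shape)} dims, "
            f"but {len(MESH_AXIS_NAMES)} axis names are defined: {MESH_AXIS_NAMES}")
```

Calling this from both `_create_single_slice_mesh` and `_create_multi_slice_mesh` would have surfaced the 4D/5D inconsistency at construction time instead of as silently incorrect sharding.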

Comment on lines 409 to 423
additional_sizes = self.vllm_config.additional_config.get("compilation_sizes", [])
# [16, 32, 64, 128, 256, 512, 1024, 2048]
cache_dtype = self.cache_config.cache_dtype
if cache_dtype == "auto":
cache_dtype = self.dtype
kv_cache_dtype = to_jax_dtype(cache_dtype)
kv_packing = common_utils.get_dtype_packing(kv_cache_dtype)
self.num_tokens_paddings = runner_utils.get_token_paddings(
min_token_size=max(16, self.dp_size * kv_packing),
max_token_size=scheduler_config.max_num_batched_tokens *
self.dp_size,
padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
self.num_tokens_paddings = sorted(self.num_tokens_paddings + additional_sizes)
self.num_tokens_paddings_per_dp = [
padding // self.dp_size for padding in self.num_tokens_paddings

Copilot AI Mar 4, 2026


additional_sizes are appended directly into num_tokens_paddings, but later num_tokens_paddings_per_dp is computed via padding // self.dp_size. If any additional_sizes entries are not multiples of dp_size (and/or the kv_packing alignment you just introduced), the per-DP padding will be truncated and can create inconsistent shapes between global/per-DP token counts. Consider validating/rounding additional_sizes to the required alignment before merging them into the paddings list.
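The validation suggested above could look like the following sketch: round each user-supplied size up to the required multiple before merging (the function name and signature are hypothetical, not the project's API):

```python
def align_compilation_sizes(sizes: list[int], dp_size: int, kv_packing: int) -> list[int]:
    """Round each size up to a multiple of dp_size * kv_packing.

    Guarantees that the later `padding // dp_size` division is exact and
    that every padding respects the KV packing alignment.
    """
    multiple = dp_size * kv_packing
    # Ceiling division via -(-s // m), deduplicated and sorted.
    return sorted({-(-s // multiple) * multiple for s in sizes})

# align_compilation_sizes([100, 128, 130], dp_size=4, kv_packing=2)
# -> [104, 128, 136]
```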

num_queries_per_block=num_queries_per_block,
q_scale=q_scale,
k_scale=k_scale,
v_scale=k_scale)

Copilot AI Mar 4, 2026


v_scale is computed (and set to self._v_scale for quantized KV) but the MLA kernel call passes v_scale=k_scale. This applies the wrong dequant scale to the attention output; it should pass the value scale (v_scale).

Suggested change
v_scale=k_scale)
v_scale=v_scale)

Comment on lines 265 to +269
) -> Tuple[jax.Array, jax.Array]:
"""Static quantize key and value tensors."""
key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
if value is None:
return key, None

Copilot AI Mar 4, 2026


quantize_kv now allows value=None and returns (key, None), but the return type annotation is still Tuple[jax.Array, jax.Array]. Also, k_scale/v_scale are defaulted to None but are passed directly into static_per_tensor_quantize_tensor, which expects a real scale value (will error if a caller relies on the defaults). Consider updating the return type to include None (e.g. jax.Array | None) and either make scales required again or add an explicit check/ValueError when they are None.

Suggested change
) -> Tuple[jax.Array, jax.Array]:
"""Static quantize key and value tensors."""
key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
if value is None:
return key, None
) -> tuple[jax.Array, jax.Array | None]:
"""Static quantize key and value tensors."""
if k_scale is None:
raise ValueError("k_scale must be provided for quantize_kv.")
key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
if value is None:
return key, None
if v_scale is None:
raise ValueError(
"v_scale must be provided for quantize_kv when value is not None."
)

Comment on lines +51 to +104
def update_kv_cache(
new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
cache_kv: jax.Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
distribution: jax.Array, # i32[3]
) -> tuple[jax.Array, jax.Array]:
"""Update KV cache with new tokens."""
actual_r_dim = new_k_pe.shape[-1]
r_dim = align_to(actual_r_dim, 128)
if actual_r_dim != r_dim:
new_k_pe = jnp.pad(new_k_pe, ((0, 0), (0, r_dim - actual_r_dim)),
constant_values=0)
actual_lkv_dim = new_kv_c.shape[-1]
lkv_dim = align_to(actual_lkv_dim, 128)
if actual_lkv_dim != lkv_dim:
new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)),
constant_values=0)
kv_dim = r_dim + lkv_dim
_, page_size_per_kv_packing, kv_packing, cache_kv_dim = cache_kv.shape
assert kv_dim == cache_kv_dim
page_size = page_size_per_kv_packing * kv_packing

max_num_seqs = kv_lens.shape[0]
num_page_indices = page_indices.shape[0]
pages_per_seq = num_page_indices // max_num_seqs

def seq_loop_body(i, cache_kv):
q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
q_len = q_end - q_start
kv_len = kv_lens[i]

def token_loop_body(j, cache_kv_):
token_idx_in_seq = kv_len - q_len + j
page_num_in_seq = token_idx_in_seq // page_size
page_indices_start = i * pages_per_seq
page_idx = page_indices[page_indices_start + page_num_in_seq]
row = (token_idx_in_seq % page_size) // kv_packing
col = (token_idx_in_seq % page_size) % kv_packing

cache_kv_ = cache_kv_.at[page_idx, row, col,
..., :lkv_dim].set(new_kv_c[q_start + j])
cache_kv_ = cache_kv_.at[page_idx, row, col, ...,
lkv_dim:].set(new_k_pe[q_start + j])
return cache_kv_

return lax.fori_loop(0, q_len, token_loop_body, cache_kv)

cache_kv = lax.fori_loop(0, distribution[-1], seq_loop_body, cache_kv)

return cache_kv

Copilot AI Mar 4, 2026


update_kv_cache is annotated as returning tuple[jax.Array, jax.Array], but it actually returns a single cache_kv array. This is inconsistent with its implementation and with how callers use it (as a single array), and will confuse type-checkers/readers; update the return annotation (and docstring if needed) to match the actual return value.

Comment on lines +1196 to +1234
# Fused KV-Cache update: handled inside the Pallas kernel.
# Use JAX-compatible update_kv_cache to update the cache in a jit-friendly way
# Parallelize KV cache update across sequences using vmap
# Pad new_kv_c and new_k_pe to aligned dims before update
actual_lkv_dim = new_kv_c.shape[-1]
actual_r_dim = new_k_pe.shape[-1]
lkv_dim = align_to(actual_lkv_dim, 128)
r_dim = align_to(actual_r_dim, 128)
if actual_lkv_dim != lkv_dim:
new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)), constant_values=0)
if actual_r_dim != r_dim:
new_k_pe = jnp.pad(new_k_pe, ((0, 0), (0, r_dim - actual_r_dim)), constant_values=0)

def update_kv_cache_per_seq(seq_idx, cache_kv):
q_start = cu_q_lens[seq_idx]
q_end = cu_q_lens[seq_idx + 1]
q_len = q_end - q_start
kv_len = kv_lens[seq_idx]
_, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
page_size = page_size_per_kv_packing * kv_packing
num_page_indices = page_indices.shape[0]
max_num_seqs = kv_lens.shape[0]
pages_per_seq = num_page_indices // max_num_seqs
def update_token(j, cache_kv_):
token_idx_in_seq = kv_len - q_len + j
page_num_in_seq = token_idx_in_seq // page_size
page_indices_start = seq_idx * pages_per_seq
page_idx = page_indices[page_indices_start + page_num_in_seq]
row = (token_idx_in_seq % page_size) // kv_packing
col = (token_idx_in_seq % page_size) % kv_packing
cache_kv_ = cache_kv_.at[page_idx, row, col, ..., :lkv_dim].set(new_kv_c[q_start + j])
cache_kv_ = cache_kv_.at[page_idx, row, col, ..., lkv_dim:].set(new_k_pe[q_start + j])
return cache_kv_
cache_kv = jax.lax.fori_loop(0, q_len, update_token, cache_kv)
return cache_kv

seq_indices = jnp.arange(kv_lens.shape[0])
cache_kv = jax.lax.fori_loop(0, seq_indices.shape[0], update_kv_cache_per_seq, cache_kv)


Copilot AI Mar 4, 2026


This block claims the KV-cache update is "handled inside the Pallas kernel", but the code performs a full KV update pass in Python/JAX (fori_loop over sequences and tokens) before launching the Pallas attention kernel. That contradicts the intended fusion/concurrency described in the PR and will reintroduce the extra HBM pass (and likely compile a large scatter loop). Either move the KV update into _mla_ragged_paged_attention_kernel (scalar lane) or update the comments/PR description and keep this as an explicit unfused baseline path.

Comment on lines +132 to +146
# Currently tensor_parallelism is also used for other things like determining number of Ray workers.
pc_tensor_parallelism = parallel_config.tensor_parallel_size
ss_tensor_parallelsim = sharding_strategy.get("tensor_parallelism", 1)
data_parallelism = parallel_config.data_parallel_size
expert_parallelism = sharding_strategy.get("expert_parallelism", 1)
sequence_parallelism = sharding_strategy.get("sequence_parallelism", 1)
device_indexes = sharding_strategy.get("device_indexes", None)

enable_dp_attention = sharding_strategy.get("enable_dp_attention",
False)
if pc_tensor_parallelism != ss_tensor_parallelsim and ss_tensor_parallelsim > 1:
# The user has explicitly set the tensor parallelism in the sharding config.
tensor_parallelism = ss_tensor_parallelsim
else:
tensor_parallelism = pc_tensor_parallelism

Copilot AI Mar 4, 2026


ss_tensor_parallelsim is misspelled (should be ss_tensor_parallelism). Beyond readability, this makes it easier to accidentally introduce a second similarly-named variable later; consider renaming for clarity/consistency.

Comment on lines 129 to +147
sharding_strategy = vllm_config.additional_config.get(
"sharding", {}).get("sharding_strategy", {})
parallel_config = vllm_config.parallel_config
tensor_parallelism = parallel_config.tensor_parallel_size
# Currently tensor_parallelism is also used for other things like determining number of Ray workers.
pc_tensor_parallelism = parallel_config.tensor_parallel_size
ss_tensor_parallelsim = sharding_strategy.get("tensor_parallelism", 1)
data_parallelism = parallel_config.data_parallel_size
expert_parallelism = sharding_strategy.get("expert_parallelism", 1)
sequence_parallelism = sharding_strategy.get("sequence_parallelism", 1)
device_indexes = sharding_strategy.get("device_indexes", None)

enable_dp_attention = sharding_strategy.get("enable_dp_attention",
False)
if pc_tensor_parallelism != ss_tensor_parallelsim and ss_tensor_parallelsim > 1:
# The user has explicitly set the tensor parallelism in the sharding config.
tensor_parallelism = ss_tensor_parallelsim
else:
tensor_parallelism = pc_tensor_parallelism


Copilot AI Mar 4, 2026


The logic for overriding tensor parallelism only applies when ss_tensor_parallelsim > 1. If a user explicitly sets tensor_parallelism to 1 in the sharding strategy (to override a larger parallel_config.tensor_parallel_size), that configuration will be ignored despite being explicitly provided. Consider checking for presence of the key (e.g. 'tensor_parallelism' in sharding_strategy) rather than > 1 so explicit overrides to 1 are honored.
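The presence check suggested above could be sketched as follows, with plain dicts standing in for the real config objects (names here are illustrative):

```python
def resolve_tensor_parallelism(parallel_config_tp: int, sharding_strategy: dict) -> int:
    """Honor an explicitly configured tensor_parallelism, even when it is 1.

    Keying on presence rather than `> 1` means a user-supplied override to 1
    is respected instead of silently falling back to parallel_config's value.
    """
    if "tensor_parallelism" in sharding_strategy:
        return sharding_strategy["tensor_parallelism"]
    return parallel_config_tp

# resolve_tensor_parallelism(8, {"tensor_parallelism": 1}) -> 1  (override honored)
# resolve_tensor_parallelism(8, {})                        -> 8  (fallback)
```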



There's already a kernel.py in main. Can you base this diff on that file instead of submitting a whole new file?

Collaborator Author

@mourado mourado Mar 9, 2026


Yes, these changes are specifically for MLA optimization. I initially added new files to avoid breaking the base version during testing, but I'll move the optimized fused logic into the existing kernel.py now and remove the redundant files to keep the diff clean.



are these changes related to MLA?



we need to add test coverage before submission imo (see: tests/kernels/mla_v1_test.py)
Please work with Jaehong to expand test coverage and get those committed.

Collaborator Author


Agreed on the test coverage. I'll add cases to mla_v1_test.py that specifically exercise the new fused version and vectorization paths, and include them in the next push.

@kyuyeunk
Collaborator

kyuyeunk commented Mar 4, 2026

There are many conflicts in this PR. Please update the branch to main's head first.

@mourado mourado force-pushed the bouache_mla/tpu_inference branch from 474def3 to 0f41147 Compare March 9, 2026 16:50
@kyuyeunk
Collaborator

kyuyeunk commented Mar 9, 2026

Is this kernel connected anywhere? I don't see a way for an e2e run to trigger this code path.

@mourado mourado added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 11, 2026