[Feature] Add Gemma4 model support and refactor RoPE #334

Merged
xyDong0223 merged 4 commits into baidu:main from GrootLiu:gemma4
May 7, 2026

Conversation

Contributor

@GrootLiu GrootLiu commented Apr 21, 2026

[Feature] Add Gemma4 support (MoE, multimodal) and refactor RoPE

This commit introduces Gemma4 model support including MoE routing,
YOCO KV-sharing, and multimodal processing, along with a refactor of
RoPE implementations and related infrastructure.

Feature:

  • Add Gemma4ForCausalLM:

    • MoE routing (custom softmax → top-k → renorm, per-expert scale folded)
    • YOCO KV-sharing prefill
    • Per-layer embedding
    • tied K/V (attention_k_eq_v)
    • 3D-packed MoE checkpoint unpacking
    • Pipeline Parallel (PP) support
    • Note: MoE path validated on E2B dense only (26B A4B not verified)
  • Add Gemma4RotaryEmbedding:

    • proportional scaling (inv_freq uses head_dim denominator)
    • zero-padding for non-rotary dimensions
  • Add RoPE extensibility:

    • rope_ext.get_rope() with _ROPE_TYPE_REGISTRY
    • register new rope type: "proportional"
  • Add multimodal support:

    • Gemma4MultiModalProcessor (image/audio/video)
    • Fix container vLLM incompatibility:
      • replace ctx.get_merged_mm_kwargs() with model_config API
    • Add fallback for numpy video inputs (missing metadata)
      • fix "too many values to unpack" error
  • Add inference enhancements:

    • EagleModelMixin / SupportsEagle3 (EAGLE-3 speculative decoding)

Refactor:

  • Restructure rotary embedding:

    • ops/rotary_embedding.py → ops/rotary_embedding/*
      (kunlun_rope, mrope, deepseek_rope, gemma4_rope, rope_ext, utils)
  • Refactor plugin initialization:

    • split into:
      • _load_native_extension
      • _patch_schema_utils
      • _install_import_hook
    • add _completed_steps for idempotency
  • Refactor attention ops:

    • simplify split_kv_cache API (v0/v1)
    • add head_size=512 support in paged_attn

Enhancement / Fix:

  • v1 kunlun_attn:

    • KV-sharing via update_block_table
    • prefill alpha scaling fix (scale * sqrt(head_dim))
    • fix decode scale (0.0 → self.scale)
    • skip cache write for KV-sharing layers
    • add SWA params (swa_left/right, sink)
    • unify decode attention with speculative_attention (remove paged_attention branch)
  • v0 kunlun_attn:

    • align with new kv cache API
  • Register:

    • Gemma4ForCausalLM
    • qwen3 reasoning parser
    • kunlun_reasoning_parser entrypoint

Signed-off-by: GrootLiu <1219671600@qq.com>

PR Description

1. New: Gemma4 model core

vllm_kunlun/models/gemma4.py (+1615 lines, new file)

  • Full Gemma4ForCausalLM implementation
  • MoE: Gemma4Router + Gemma4MoE (custom softmax→top-k→renorm routing, with the per-expert scale folded into the routing weights; 26B A4B is not yet verified, and Gemma MoE validation is planned for the next PR)
  • YOCO KV-sharing prefill
  • Per-Layer Embedding (PLE)
  • attention_k_eq_v: shared K/V weights
  • Unpacking logic for 3D-packed MoE checkpoint weights
  • Pipeline parallelism

vllm_kunlun/models/interfaces.py (+83 lines)

  • EagleModelMixin + SupportsEagle3 interfaces (EAGLE-3 speculative decoding)

vllm_kunlun/models/__init__.py (+4 lines)

  • Register Gemma4ForCausalLM

Adaptation of the multimodal Gemma4ForConditionalGeneration is not finished yet; it will be submitted together with Gemma4 MoE in the next PR.

2. Refactor: RoPE split from a single file into a package

Deleted vllm_kunlun/ops/rotary_embedding.py (-256 lines)

New directory vllm_kunlun/ops/rotary_embedding/ (+542 lines):

  • kunlun_rope.py (99): OOT RotaryEmbedding
  • kunlun_mrope.py (71): OOT MRotaryEmbedding
  • kunlun_deepseek_rope.py (64): OOT Deepseek Scaling RoPE
  • gemma4_rope.py (83): proportional scaling with head_dim as the denominator; non-rotated dimensions are zero-padded
  • rope_ext.py (108): extensible get_rope() + _ROPE_TYPE_REGISTRY (registers "proportional")
  • utils.py (63): Split_Norm_Rope utility
  • __init__.py (54)
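
To make the "proportional" description concrete, here is a minimal pure-Python sketch of how such inverse frequencies could be built. It is not taken from gemma4_rope.py: the function names, the `rotary_dim` parameter and the default base of 10000 are assumptions for illustration. The two properties it shows are the ones stated above: the exponent is divided by head_dim rather than by the rotated width, and the non-rotated tail is padded with zeros.

```python
import math


def proportional_inv_freq(head_dim: int, rotary_dim: int, base: float = 10000.0) -> list[float]:
    """Return head_dim // 2 inverse frequencies with a zero-padded tail.

    Hypothetical helper: names and defaults are illustrative only.
    """
    n_rot = rotary_dim // 2      # number of rotated frequency pairs
    n_total = head_dim // 2      # total number of pairs per head
    # "Proportional": the exponent denominator is head_dim, not rotary_dim.
    rotated = [1.0 / (base ** (2 * i / head_dim)) for i in range(n_rot)]
    # Zero frequency means angle 0 at every position, i.e. no rotation.
    return rotated + [0.0] * (n_total - n_rot)


def cos_sin(position: int, inv_freq: list[float]) -> tuple[list[float], list[float]]:
    """Per-pair cos/sin values for one token position."""
    angles = [position * f for f in inv_freq]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]
```

With zero padding, the padded pairs get cos = 1 and sin = 0, so applying the rotation to them leaves those dimensions unchanged.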

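3. New: MoE gate operator

Implemented but not yet enabled on XPU. Once the Gemma4 multimodal and Gemma4 MoE models are adapted, the gate operator will be switched from the current ReplicatedLinear to GateLinear.

vllm_kunlun/ops/fused_moe/router/gate_linear.py (+116 lines)

GateLinear (a subclass of ReplicatedLinear) uses a three-tier GEMM dispatch:

  • Tier 1: DSV3 kernel (SM90+, batch ≤ 16, fixed dimensions)
  • Tier 2: cuBLAS bf16→fp32 (SM90+ with bf16)
  • Tier 3: F.linear fallback

On XPU the dispatch always lands in Tier 3, which is equivalent to ReplicatedLinear. gemma4.py currently uses ReplicatedLinear directly, so GateLinear is not yet in use.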
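
The three tiers can be read as a simple decision function. The sketch below is a hypothetical illustration of that selection logic only; `GemmContext`, `select_tier` and the `dims_supported` flag are invented names, and the real gate_linear.py may check different or additional conditions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class GemmContext:
    """Hypothetical inputs to the tier decision (not the real API)."""
    sm_version: Optional[int]  # None on non-CUDA devices such as XPU
    batch_size: int
    is_bf16: bool
    dims_supported: bool       # stand-in for the "fixed dimensions" condition


def select_tier(ctx: GemmContext) -> int:
    """Return 1, 2 or 3 following the tier conditions listed above."""
    on_sm90 = ctx.sm_version is not None and ctx.sm_version >= 90
    if on_sm90 and ctx.batch_size <= 16 and ctx.dims_supported:
        return 1  # specialised DSV3 kernel
    if on_sm90 and ctx.is_bf16:
        return 2  # cuBLAS bf16 -> fp32
    return 3      # plain F.linear fallback
```

Under these assumptions a device with no SM version (for example XPU) can only reach Tier 3, which matches the statement that GateLinear is equivalent to ReplicatedLinear there.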

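4. Changed: Attention backend

v1/attention/backends/kunlun_attn.py

  • KV sharing is supported through update_block_table()
  • Prefill alpha scaling: scale * sqrt(head_dim) cancels the 1/sqrt(head_dim) that is hard-coded in the XPU kernel, so the scaling factor becomes customizable; when no custom value is set, it defaults to sqrt(head_dim)
  • Sliding-window parameters: swa_left, swa_right, sink
  • KV-sharing layers skip the cache write and directly use the KV cache of the designated shared layer
  • Fix decode scale: 0.0 → self.scale
  • Refactor decode attention: always use speculative_attention, remove the paged_attention branch, and fix the dead code caused by the static signature check
  • Fix max_context_len: replace the hard-coded 131072 with decode_meta.max_model_len

ops/attention/backends/kunlun_attn.py (minor changes) + ops/paged_attn.py

  • Simplify the split_kv_cache(kv_cache=...) interface (redundant parameters removed)
  • PagedAttention adds head_size 512 support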

5. Changed: plugin initialization

vllm_kunlun/__init__.py (+162/-?)

  • Split out _load_native_extension / _patch_schema_utils / _install_import_hook
  • _completed_steps makes the multi-stage registration idempotent

vllm_kunlun/reasoning/__init__.py (+22 lines)

  • Register the qwen3 reasoning parser

pyproject.toml (+1) + setup.py (+1)

  • entry point: kunlun_reasoning_parser in vllm.general_plugins

6. Changed: Gemma4 multimodal API compatibility

vllm_kunlun/models/gemma4_mm.py (+13/-2)

Fixes API differences between the vLLM environment in the container and what the code expects:

  • _call_hf_processor (493-494): replace self.info.ctx.get_merged_mm_kwargs(mm_kwargs) with self.info.ctx.model_config.get_multimodal_config().merge_mm_processor_kwargs(mm_kwargs), which resolves AttributeError: 'InputProcessingContext' object has no attribute 'get_merged_mm_kwargs'
  • Video handling (522-537): add a type-check branch inside the for item in videos loop. The vLLM parsing pipeline in the container strips the metadata before items reach _call_hf_processor, so unpacking video_array, metadata = item raises ValueError: too many values to unpack. Now:
    • If the item is an (array, metadata) 2-tuple: unpack as usual
    • If the item is a bare numpy array: synthesize a default metadata dict (fps=2.0, duration, total_num_frames, frames_indices, video_backend=opencv, do_sample_frames=False) so that behavior matches newer vLLM versions
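
The two branches of the video fallback could look roughly like the sketch below. It is an illustration, not the code in gemma4_mm.py: `split_video_item` is a hypothetical helper, frames are modelled as any sized sequence (a numpy array also supports `len()`), and deriving `duration` as frame count divided by fps is an assumption, since the description only names the field.

```python
from typing import Any


def split_video_item(item: Any, default_fps: float = 2.0) -> tuple[Any, dict]:
    """Return (frames, metadata) for either input shape (hypothetical helper)."""
    # Case 1: the pipeline kept the metadata, so item is (array, metadata).
    if isinstance(item, tuple) and len(item) == 2:
        frames, metadata = item
        return frames, metadata

    # Case 2: bare frames only. Unpacking these into two names is what
    # raised "too many values to unpack", so build default metadata instead.
    frames = item
    total = len(frames)
    metadata = {
        "fps": default_fps,
        "duration": total / default_fps,  # assumption: derived from frame count
        "total_num_frames": total,
        "frames_indices": list(range(total)),
        "video_backend": "opencv",
        "do_sample_frames": False,
    }
    return frames, metadata
```

Checking the item type first keeps the original code path untouched when metadata is present and only synthesizes defaults when it is missing.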

Checklist (Required)

Before submitting this PR, please ensure that all the following items are completed:

  • All code changes pass the pre-commit checks.
  • Commits are signed off using git commit -s.
  • The PR title is properly classified (see below).

PR Type

Please prefix the PR title with one or more of the following labels to help reviewers quickly understand the nature of the change:

  • [Feature] – New features or enhancements (e.g. Attention, Communicator, Kernel, Worker, etc.)
  • [Bugfix] – Bug fixes
  • [CI/Build] – CI, build system, or infrastructure improvements
  • [Doc] – Documentation updates or fixes
  • [Misc] – Other changes that do not fit the above categories (use sparingly)

Note: If the PR spans multiple categories, include all relevant prefixes.


Detailed Checklist

Thank you for contributing to vLLM Kunlun! To help us maintain high code quality and streamline the review process, please ensure your PR meets the following requirements.

1. Code Quality

  • All linting and formatting checks pass (pre-commit).
  • The code is well-structured and sufficiently documented.
  • The change is designed with maintainability and readability in mind.

2. Testing

  • Relevant unit tests are added or updated.
  • Integration tests are included when applicable.
  • Existing tests continue to pass.

3. DCO Compliance

This project follows the Developer Certificate of Origin (DCO).

  • All commits include a Signed-off-by: line.
  • Use git commit -s to automatically add the sign-off.

4. Review Expectations

During the review process, maintainers may:

  • Request code refactoring or additional tests.
  • Ask for clarifications on design decisions.
  • Suggest performance, stability, or maintainability improvements.

We appreciate your patience and collaboration throughout the review process!

New:
- Gemma4ForCausalLM: MoE (custom softmax→top-k→renorm routing with
  per-expert scale folded into weights), YOCO KV-sharing prefill,
  Per-Layer Embedding, tied K/V (attention_k_eq_v), 3D-packed MoE
  checkpoint unpacking, PP support.
  Note: MoE path verified only on E2B dense; 26B A4B not yet validated.
- Gemma4RotaryEmbedding: proportional scaling (inv_freq uses head_dim
  denominator, zero-pads non-rotated dims).
- rope_ext.get_rope(): _ROPE_TYPE_REGISTRY for custom rope_type dispatch
  (registers "proportional").
- fused_moe/router/gate_linear.py: 3-tier GEMM (DSV3 → cuBLAS bf16→fp32
  → F.linear). Currently unused on XPU (falls through to Tier 3).
- EagleModelMixin / SupportsEagle3 for EAGLE-3 speculative decoding.

Refactor:
- ops/rotary_embedding.py → ops/rotary_embedding/ package
  (kunlun_rope, kunlun_mrope, kunlun_deepseek_rope, gemma4_rope,
  rope_ext, utils).
- Plugin __init__: split into _load_native_extension /
  _patch_schema_utils / _install_import_hook with _completed_steps
  for idempotent registration.

Modified:
- v1 kunlun_attn: KV-sharing via update_block_table, prefill alpha
  scaling (scale*sqrt(head_dim)) to offset kernel's fixed 1/sqrt,
  SWA params (swa_left/right, sink), skip cache write on KV-share
  layers, fix decode scale 0.0 → self.scale.
- v0 kunlun_attn + paged_attn: simplified split_kv_cache(kv_cache=)
  API; paged_attn adds head_size 512.
- Register Gemma4ForCausalLM (models/__init__); register qwen3
  reasoning parser; entry-point kunlun_reasoning_parser in
  pyproject.toml / setup.py.

Signed-off-by: GrootLiu <1219671600@qq.com>
@xyDong0223 xyDong0223 requested review from Copilot and liwei109 and removed request for Copilot April 22, 2026 08:24
Contributor

Copilot AI left a comment

Pull request overview

Adds Gemma4 model support (including YOCO KV-sharing fast-prefill, PLE, MoE routing, PP/EAGLE-3 hooks) and refactors Kunlun RoPE into a package with extensible rope_type dispatch, while updating Kunlun attention backends and plugin initialization/entrypoints.

Changes:

  • Introduce Gemma4ForCausalLM (+ supporting interfaces) and register it in the model registry.
  • Refactor rotary embedding implementation into ops/rotary_embedding/ and add rope_ext.get_rope() registry-based dispatch (e.g., rope_type="proportional").
  • Update Kunlun attention backends for KV-sharing behavior, new scaling behavior, and simplify KV-cache helpers.

Reviewed changes

Copilot reviewed 20 out of 21 changed files in this pull request and generated 8 comments.

| File | Description |
| --- | --- |
| vllm_kunlun/v1/attention/backends/kunlun_attn.py | KV-sharing support in v1 backend (metadata update helper, cache-write skip, scaling/params tweaks). |
| vllm_kunlun/reasoning/__init__.py | Adds reasoning parser registration entrypoint for Qwen3. |
| vllm_kunlun/ops/rotary_embedding/utils.py | Moves Split_Norm_Rope utility into new RoPE package. |
| vllm_kunlun/ops/rotary_embedding/rope_ext.py | Adds extensible get_rope() dispatch via _ROPE_TYPE_REGISTRY. |
| vllm_kunlun/ops/rotary_embedding/kunlun_rope.py | Splits out Kunlun OOT RotaryEmbedding implementation. |
| vllm_kunlun/ops/rotary_embedding/kunlun_mrope.py | Splits out Kunlun OOT MRotaryEmbedding implementation. |
| vllm_kunlun/ops/rotary_embedding/kunlun_deepseek_rope.py | Splits out Kunlun OOT DeepseekScalingRotaryEmbedding implementation. |
| vllm_kunlun/ops/rotary_embedding/gemma4_rope.py | Adds Gemma4 proportional RoPE implementation. |
| vllm_kunlun/ops/rotary_embedding/__init__.py | New package initializer importing/triggering OOT registrations and exporting utilities. |
| vllm_kunlun/ops/rotary_embedding.py | Removes the previous monolithic RoPE module (replaced by package). |
| vllm_kunlun/ops/paged_attn.py | Simplifies KV-cache split API and adds head_size=512 support. |
| vllm_kunlun/ops/fused_moe/router/gate_linear.py | Introduces a 3-tier router GEMM layer (currently intended mainly for CUDA SM90+). |
| vllm_kunlun/ops/attention/backends/kunlun_attn.py | Updates split_kv_cache callsite to new API. |
| vllm_kunlun/models/interfaces.py | Adds EagleModelMixin and SupportsEagle3 protocol for speculative decoding. |
| vllm_kunlun/models/gemma4.py | Adds full Gemma4 model implementation (MoE/YOCO/PLE/PP/EAGLE-3 integration). |
| vllm_kunlun/models/__init__.py | Registers Gemma4ForCausalLM. |
| vllm_kunlun/__init__.py | Refactors plugin init into idempotent steps + adds reasoning-parser registration entrypoint. |
| setup.py | Adds kunlun_reasoning_parser entry point. |
| pyproject.toml | Adds kunlun_reasoning_parser entry point. |
| docs/source/installation.md | Formatting/whitespace cleanup in installation docs. |
| ci/scripts/env/install_env.sh | Whitespace cleanup in CI env script. |

Comment on lines +1 to +9
import importlib

from vllm.reasoning import ReasoningParserManager

"""
Reasoning parser registration module for vLLM Kunlun.
"""


Copilot AI Apr 24, 2026

The module docstring is placed after imports. For standard Python module semantics (and for tools like pydoc/Sphinx), the docstring should be the first statement in the file (before any imports).

Suggested change
- import importlib
- from vllm.reasoning import ReasoningParserManager
- """
- Reasoning parser registration module for vLLM Kunlun.
- """
+ """
+ Reasoning parser registration module for vLLM Kunlun.
+ """
+ import importlib
+ from vllm.reasoning import ReasoningParserManager
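
The effect of docstring placement can be checked with the standard library. The short sketch below parses two source strings with `ast` and asks for the module docstring; the source strings are shortened stand-ins for the file under review, not its real contents.

```python
import ast

# String literal placed after an import: just an expression statement.
AFTER_IMPORTS = 'import importlib\n\n"""Reasoning parser registration module."""\n'

# String literal as the first statement: a real module docstring.
AT_TOP = '"""Reasoning parser registration module."""\n\nimport importlib\n'


def module_docstring(source: str):
    """Return the module docstring of the given source, or None."""
    return ast.get_docstring(ast.parse(source))
```

Only the second layout yields a docstring; in the first, the string is evaluated and discarded, which is why tools such as pydoc do not pick it up.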

Copilot uses AI. Check for mistakes.
Comment on lines +13 to +22
def _maybe_add_hidden_state(
    self,
    aux_hidden_states: list[torch.Tensor],
    layer_idx: int,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
) -> list[torch.Tensor]:
    if layer_idx in self.aux_hidden_state_layers:
        value = hidden_states + residual if residual is not None else hidden_states
        aux_hidden_states.append(value)

Copilot AI Apr 24, 2026

EagleModelMixin._maybe_add_hidden_state annotates residual as torch.Tensor, but the implementation explicitly handles residual is None (and callers pass None). Update the type hint to torch.Tensor | None (and, if desired, reflect that in the docstring) to avoid misleading typing and static-analysis errors.

Comment on lines +30 to +54
"""

import logging

from vllm_kunlun.ops.rotary_embedding.gemma4_rope import (  # noqa: F401
    Gemma4RotaryEmbedding,
)
from vllm_kunlun.ops.rotary_embedding.kunlun_deepseek_rope import (  # noqa: F401
    KunlunDeepseekScalingRotaryEmbedding,
)
from vllm_kunlun.ops.rotary_embedding.kunlun_mrope import (  # noqa: F401
    KunlunMRotaryEmbedding,
)
from vllm_kunlun.ops.rotary_embedding.kunlun_rope import (  # noqa: F401
    KunlunRotaryEmbedding,
)
from vllm_kunlun.ops.rotary_embedding.utils import Split_Norm_Rope  # noqa: F401

logger = logging.getLogger("vllm_kunlun.ops.rotary_embedding")

# Log that OOT registration is complete
logger.info(
    "[KunlunOOT] Registered KunlunRotaryEmbedding, KunlunMRotaryEmbedding, "
    "KunlunDeepseekScalingRotaryEmbedding via CustomOp.register_oot"
)

Copilot AI Apr 24, 2026

This package __init__.py eagerly imports all concrete RoPE implementations (including modules that register CustomOps). Any import of vllm_kunlun.ops.rotary_embedding.rope_ext will execute this __init__ first, which defeats the “lazy import to avoid circular imports” goal described in rope_ext.py and can reintroduce circular-import / side-effect issues. Consider keeping __init__.py import-light (no registration side effects) and triggering OOT registrations from an explicit register_*() function invoked during plugin setup instead.

Suggested change
- """
- import logging
- from vllm_kunlun.ops.rotary_embedding.gemma4_rope import (  # noqa: F401
-     Gemma4RotaryEmbedding,
- )
- from vllm_kunlun.ops.rotary_embedding.kunlun_deepseek_rope import (  # noqa: F401
-     KunlunDeepseekScalingRotaryEmbedding,
- )
- from vllm_kunlun.ops.rotary_embedding.kunlun_mrope import (  # noqa: F401
-     KunlunMRotaryEmbedding,
- )
- from vllm_kunlun.ops.rotary_embedding.kunlun_rope import (  # noqa: F401
-     KunlunRotaryEmbedding,
- )
- from vllm_kunlun.ops.rotary_embedding.utils import Split_Norm_Rope  # noqa: F401
- logger = logging.getLogger("vllm_kunlun.ops.rotary_embedding")
- # Log that OOT registration is complete
- logger.info(
-     "[KunlunOOT] Registered KunlunRotaryEmbedding, KunlunMRotaryEmbedding, "
-     "KunlunDeepseekScalingRotaryEmbedding via CustomOp.register_oot"
- )
+ Import behavior:
+ - Keep this package initializer import-light to avoid defeating lazy-import paths
+   (for example in rope_ext.py) and to avoid import-time registration side effects.
+ - Call register_oot() explicitly during plugin setup to trigger OOT registrations.
+ """
+ import importlib
+ import logging
+ logger = logging.getLogger("vllm_kunlun.ops.rotary_embedding")
+ _REGISTERED = False
+ _SYMBOL_TO_MODULE = {
+     "Gemma4RotaryEmbedding": (
+         "vllm_kunlun.ops.rotary_embedding.gemma4_rope",
+         "Gemma4RotaryEmbedding",
+     ),
+     "KunlunDeepseekScalingRotaryEmbedding": (
+         "vllm_kunlun.ops.rotary_embedding.kunlun_deepseek_rope",
+         "KunlunDeepseekScalingRotaryEmbedding",
+     ),
+     "KunlunMRotaryEmbedding": (
+         "vllm_kunlun.ops.rotary_embedding.kunlun_mrope",
+         "KunlunMRotaryEmbedding",
+     ),
+     "KunlunRotaryEmbedding": (
+         "vllm_kunlun.ops.rotary_embedding.kunlun_rope",
+         "KunlunRotaryEmbedding",
+     ),
+     "Split_Norm_Rope": (
+         "vllm_kunlun.ops.rotary_embedding.utils",
+         "Split_Norm_Rope",
+     ),
+ }
+ __all__ = list(_SYMBOL_TO_MODULE) + ["register_oot"]
+ def register_oot():
+     """Import OOT modules explicitly so their registration side effects run."""
+     global _REGISTERED
+     if _REGISTERED:
+         return
+     importlib.import_module("vllm_kunlun.ops.rotary_embedding.gemma4_rope")
+     importlib.import_module(
+         "vllm_kunlun.ops.rotary_embedding.kunlun_deepseek_rope")
+     importlib.import_module("vllm_kunlun.ops.rotary_embedding.kunlun_mrope")
+     importlib.import_module("vllm_kunlun.ops.rotary_embedding.kunlun_rope")
+     _REGISTERED = True
+     logger.info(
+         "[KunlunOOT] Registered KunlunRotaryEmbedding, "
+         "KunlunMRotaryEmbedding, KunlunDeepseekScalingRotaryEmbedding "
+         "via CustomOp.register_oot"
+     )
+ def __getattr__(name):
+     if name not in _SYMBOL_TO_MODULE:
+         raise AttributeError(
+             f"module {__name__!r} has no attribute {name!r}")
+     module_name, attr_name = _SYMBOL_TO_MODULE[name]
+     module = importlib.import_module(module_name)
+     value = getattr(module, attr_name)
+     globals()[name] = value
+     return value

Comment on lines +724 to +726
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.embed_tokens(input_ids) * self.normalizer

Copilot AI Apr 24, 2026

embed_input_ids multiplies embeddings by self.normalizer, which is registered as a default torch.tensor(...) (float32). Multiplying (b)float16 embeddings by a float32 tensor will promote the whole embedding output to float32, increasing memory/compute and diverging from the comment about “downcast to model dtype”. Consider storing these scale factors as Python floats, or casting the buffer to embed_tokens.weight.dtype (or inputs_embeds.dtype) before multiplication.

Comment on lines +743 to +745
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
return per_layer_embeds.reshape(

Copilot AI Apr 24, 2026

per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer will also upcast per_layer_embeds to float32 because embed_scale_per_layer is a float32 tensor buffer. Same applies to per_layer_projection_scale / per_layer_input_scale later. To keep activations in the model dtype, cast these scale buffers to the activation dtype before use (or store them as Python floats).

Comment on lines +217 to +226
router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
indicator = torch.nn.functional.one_hot(
    topk_ids, num_classes=gating_output.size(-1)
).sum(dim=-2)
gate_weights = indicator * router_probabilities
renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
dispatch_weights = gate_weights / renorm_factor

topk_weights = dispatch_weights.gather(1, topk_ids)

Copilot AI Apr 24, 2026

The routing function materializes a dense [T, num_experts] one-hot indicator tensor (one_hot(...).sum(...)), which is O(T·E) memory and can be a significant overhead (e.g., large token batches with 256/384 experts). This can be computed without a full one-hot by gathering the softmax probabilities at topk_ids, renormalizing those topk values, and returning them directly.

Suggested change
- router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
- indicator = torch.nn.functional.one_hot(
-     topk_ids, num_classes=gating_output.size(-1)
- ).sum(dim=-2)
- gate_weights = indicator * router_probabilities
- renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
- renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
- dispatch_weights = gate_weights / renorm_factor
- topk_weights = dispatch_weights.gather(1, topk_ids)
+ router_probabilities = torch.nn.functional.softmax(
+     gating_output, dim=-1
+ )
+ topk_weights = router_probabilities.gather(1, topk_ids)
+ if renormalize:
+     renorm_factor = torch.sum(topk_weights, dim=-1, keepdim=True)
+     renorm_factor = torch.where(
+         renorm_factor > 0.0, renorm_factor,
+         torch.ones_like(renorm_factor)
+     )
+     topk_weights = topk_weights / renorm_factor
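
To see why the two formulations agree, the following pure-Python sketch routes a single token both ways: once through a dense per-expert mask and once by gathering the selected probabilities directly. It is a toy model of the math only, with hypothetical function names, and it assumes the ids in `topk_ids` are distinct, as top-k selection would normally guarantee.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one token's router logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_dense(logits: list[float], topk_ids: list[int]) -> list[float]:
    """Mask-based variant: touches every expert (O(E) work per token)."""
    probs = softmax(logits)
    indicator = [1.0 if e in topk_ids else 0.0 for e in range(len(logits))]
    gate = [i * p for i, p in zip(indicator, probs)]
    norm = sum(gate)
    norm = norm if norm > 0.0 else 1.0
    return [gate[e] / norm for e in topk_ids]


def route_gather(logits: list[float], topk_ids: list[int]) -> list[float]:
    """Gather-based variant: touches only the k selected experts."""
    probs = softmax(logits)
    picked = [probs[e] for e in topk_ids]
    norm = sum(picked)
    norm = norm if norm > 0.0 else 1.0
    return [p / norm for p in picked]
```

Both functions return weights in `topk_ids` order that sum to 1; they may differ in the last floating-point bits because the sums are accumulated in a different order.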

"""Update block table and slot mapping for a different KV cache group."""
new_metadata = copy.copy(metadata)
new_metadata.block_tables = blk_table
new_metadata.slot_mapping = slot_mapping

Copilot AI Apr 24, 2026

update_block_table uses a shallow copy.copy(metadata) and then overwrites block_tables/slot_mapping, but it leaves _cached_prefill_metadata / _cached_decode_metadata intact. If those caches were already populated, new_metadata.prefill_metadata / decode_metadata will return stale cached objects that still reference the old block tables / slot mapping. Consider clearing these cached fields (or rebuilding via dataclasses.replace while setting the cached fields to None) when creating the updated metadata.

Suggested change
- new_metadata.slot_mapping = slot_mapping
+ new_metadata.slot_mapping = slot_mapping
+ # Invalidate cached derived metadata so it is rebuilt against the
+ # updated block tables and slot mapping instead of reusing stale
+ # references copied from the original metadata object.
+ new_metadata._cached_prefill_metadata = None
+ new_metadata._cached_decode_metadata = None
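
The stale-cache hazard can be reproduced with a toy dataclass. The sketch below is not the Kunlun metadata class: `ToyMetadata` and the `invalidate` flag are hypothetical, and only one cached field is modelled. It shows that `copy.copy` carries the already-built cache over to the copy unless the cache is cleared.

```python
import copy
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyMetadata:
    """Minimal stand-in with one lazily cached derived view."""
    block_tables: list
    _cached_decode_metadata: Optional[dict] = None

    @property
    def decode_metadata(self) -> dict:
        # Built once, then reused on later accesses.
        if self._cached_decode_metadata is None:
            self._cached_decode_metadata = {"block_tables": self.block_tables}
        return self._cached_decode_metadata


def update_block_table(metadata: ToyMetadata, blk_table: list, invalidate: bool) -> ToyMetadata:
    """Shallow-copy the metadata and swap in a new block table."""
    new_metadata = copy.copy(metadata)
    new_metadata.block_tables = blk_table
    if invalidate:
        # Drop the copied cache so the derived view is rebuilt.
        new_metadata._cached_decode_metadata = None
    return new_metadata
```

Without invalidation the copy keeps returning the derived view built from the old block table; with it, the view is rebuilt on first access and the original object's cache is left untouched.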

Comment thread vllm_kunlun/__init__.py
Comment on lines +76 to 86
def _load_native_extension(logger: logging.Logger) -> None:
    """Load _kunlun C extension to register torch.ops._C.weak_ref_tensor."""
    if "native_ext" in _completed_steps:
        return
    _completed_steps.add("native_ext")  # only attempt once
    try:
        from . import _kunlun  # noqa: F401

        logger.info("[KunlunPlugin] _kunlun native extension loaded")
    except ImportError as e:
        logger.warning("[KunlunPlugin] Failed to load _kunlun: %s", e)

Copilot AI Apr 24, 2026

_load_native_extension adds "native_ext" to _completed_steps before attempting the import. This means a transient failure (e.g., extension not yet available during an early plugin-discovery phase) will never be retried, which contradicts the register() docstring about retrying previously-failed steps. Consider only marking the step as completed after a successful import (or track attempted vs completed separately).

- Implement Gemma4MultiModalProcessor with image, audio and video handling
- Replace self.info.ctx.get_merged_mm_kwargs() with
  model_config.get_multimodal_config().merge_mm_processor_kwargs()
  to fix AttributeError in container vLLM environment
- Add fallback for bare numpy array video items since container vLLM
  strips metadata before _call_hf_processor; synthesise default metadata
  to fix "too many values to unpack" error

Signed-off-by: GrootLiu <1219671600@qq.com>
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 21 out of 22 changed files in this pull request and generated 6 comments.


Comment on lines +1 to +9
import importlib

from vllm.reasoning import ReasoningParserManager

"""
Reasoning parser registration module for vLLM Kunlun.
"""


Copilot AI Apr 29, 2026

The triple-quoted string is placed after imports, so it is not treated as the module docstring (and may be flagged by linters as a pointless string literal). Move this docstring to the very top of the file (before imports) or convert it to a regular comment if you don't want it as doc.

Suggested change
- import importlib
- from vllm.reasoning import ReasoningParserManager
- """
- Reasoning parser registration module for vLLM Kunlun.
- """
+ """
+ Reasoning parser registration module for vLLM Kunlun.
+ """
+ import importlib
+ from vllm.reasoning import ReasoningParserManager

aux_hidden_states: list[torch.Tensor],
layer_idx: int,
hidden_states: torch.Tensor,
residual: torch.Tensor,

Copilot AI Apr 29, 2026

residual is typed as torch.Tensor, but this method explicitly handles residual is not None and callers pass None (e.g. during the first layer). Update the type annotation to torch.Tensor | None to match actual usage and avoid type-checker errors.

Suggested change
- residual: torch.Tensor,
+ residual: torch.Tensor | None,

Comment on lines +827 to +835
# NOTE(kunlun): prefill_attention kernel internally applies
# 1/sqrt(head_dim) and multiplies by alpha. Compute alpha to
# achieve the desired effective scaling:
# score = Q @ K^T * (1/sqrt(d)) * alpha
# We want: score = Q @ K^T * self.scale
# So: alpha = self.scale * sqrt(d) = self.scale / (1/sqrt(d))
import math

_prefill_alpha = self.scale * math.sqrt(self.head_size)

Copilot AI Apr 29, 2026

import math is inside the attention forward path, so it executes on every prefill call (hot path). Move the import to module scope (and consider precomputing sqrt(head_size) once) to avoid repeated work in a latency-sensitive function.
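
A small sketch of the suggested arrangement: import `math` at module scope and compute alpha once at construction time. `ToyAttention` and `effective_scale` are hypothetical names used only to check the arithmetic from the code comment, namely that the kernel's built-in 1/sqrt(head_dim) multiplied by alpha gives back the desired scale.

```python
import math  # module scope: resolved once when the module is imported


class ToyAttention:
    """Hypothetical stand-in for the backend's attention implementation."""

    def __init__(self, head_size: int, scale: float):
        self.head_size = head_size
        self.scale = scale
        # Precomputed once instead of on every prefill call.
        self.prefill_alpha = scale * math.sqrt(head_size)

    def effective_scale(self) -> float:
        """Scale the scores actually see: kernel's 1/sqrt(d) times alpha."""
        kernel_builtin = 1.0 / math.sqrt(self.head_size)
        return kernel_builtin * self.prefill_alpha
```

For example, with head_size 256 and a desired scale of 1.0, alpha is 16.0 and the product with the kernel's 1/16 is 1.0 again.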

Comment on lines +757 to +758
merged_kwargs = self.info.ctx.get_merged_mm_kwargs(
hf_processor_mm_kwargs,

Copilot AI Apr 29, 2026

This still uses self.info.ctx.get_merged_mm_kwargs(...), but earlier in this PR _call_hf_processor was updated because some vLLM environments don't have InputProcessingContext.get_merged_mm_kwargs. For consistency (and to avoid the same AttributeError during prompt update building), use the model_config.get_multimodal_config().merge_mm_processor_kwargs(...) path here as well.

Suggested change
- merged_kwargs = self.info.ctx.get_merged_mm_kwargs(
-     hf_processor_mm_kwargs,
+ model_config = self.info.ctx.model_config
+ merged_kwargs = (
+     model_config.get_multimodal_config()
+     .merge_mm_processor_kwargs(hf_processor_mm_kwargs)

Comment on lines +1288 to +1295
if multimodal_embeddings is None or is_multimodal is None:
    return super().embed_input_ids(input_ids)

return super().embed_input_ids(
    input_ids,
    multimodal_embeddings=multimodal_embeddings,
    is_multimodal=is_multimodal,
)

Copilot AI Apr 29, 2026

super().embed_input_ids(...) is called here, but Gemma4ForConditionalGeneration only inherits from nn.Module plus Protocols, so there is no guaranteed implementation of embed_input_ids in the MRO. This is likely to raise at runtime and it also bypasses the standard multimodal merge behavior used elsewhere in this repo (embed text first, then merge multimodal_embeddings at is_multimodal positions). Implement embed_input_ids explicitly by delegating to self.language_model.embed_input_ids and then merging multimodal embeddings (similar to other multimodal models in vllm_kunlun/models/*).

Suggested change
- if multimodal_embeddings is None or is_multimodal is None:
-     return super().embed_input_ids(input_ids)
- return super().embed_input_ids(
-     input_ids,
-     multimodal_embeddings=multimodal_embeddings,
-     is_multimodal=is_multimodal,
- )
+ inputs_embeds = self.language_model.embed_input_ids(input_ids)
+ if multimodal_embeddings is None or is_multimodal is None:
+     return inputs_embeds
+ merged_embeddings = inputs_embeds.clone()
+ merged_embeddings[
+     is_multimodal.to(inputs_embeds.device, non_blocking=True)
+ ] = multimodal_embeddings.to(
+     device=inputs_embeds.device,
+     dtype=inputs_embeds.dtype,
+     non_blocking=True,
+ )
+ return merged_embeddings

Comment on lines +91 to +94
def forward(
    self, x: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
    import vllm._custom_ops as ops

Copilot AI Apr 29, 2026

The return type annotation for forward() doesn't match the implementation: this method always returns a 2-tuple (output, bias_or_none), never a bare torch.Tensor. Tighten the return type to the tuple form (and align the bias element type with what ReplicatedLinear.forward actually returns) so callers and type-checkers can rely on it.

GrootLiu and others added 2 commits April 29, 2026 15:56
- Remove dead code caused by static signature check
- Use is_speculative as the main branch condition
- Remove paged_attention branch (speculative_attention covers its capability)
- Fix max_context_len from hardcoded 131072 to decode_meta.max_model_len

Signed-off-by: GrootLiu <1219671600@qq.com>
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 21 out of 22 changed files in this pull request and generated 1 comment.


Comment thread vllm_kunlun/__init__.py
Comment on lines +118 to +122
vLLM may invoke this multiple times during different discovery phases;
each step tracks its own completion state via ``_completed_steps`` so
already-succeeded work is skipped while previously-failed work (e.g.
_patch_rotary_embedding blocked by circular import) is retried.
"""

Copilot AI Apr 29, 2026

The register() docstring says previously-failed steps are retried, but _load_native_extension() marks the native_ext step as completed before attempting the import (so an ImportError will not be retried on subsequent register() calls). Please either adjust the docstring to match the actual behavior or change _load_native_extension() to only mark completion after a successful load if retries are intended.

Collaborator

@liwei109 liwei109 left a comment

LGTM

@xyDong0223 xyDong0223 merged commit 189a443 into baidu:main May 7, 2026
7 checks passed