[Feature] Add Gemma4 model support and refactor RoPE#334
Conversation
New: - Gemma4ForCausalLM: MoE (custom softmax→top-k→renorm routing with per-expert scale folded into weights), YOCO KV-sharing prefill, Per-Layer Embedding, tied K/V (attention_k_eq_v), 3D-packed MoE checkpoint unpacking, PP support. Note: MoE path verified only on E2B dense; 26B A4B not yet validated. - Gemma4RotaryEmbedding: proportional scaling (inv_freq uses head_dim denominator, zero-pads non-rotated dims). - rope_ext.get_rope(): _ROPE_TYPE_REGISTRY for custom rope_type dispatch (registers "proportional"). - fused_moe/router/gate_linear.py: 3-tier GEMM (DSV3 → cuBLAS bf16→fp32 → F.linear). Currently unused on XPU (falls through to Tier 3). - EagleModelMixin / SupportsEagle3 for EAGLE-3 speculative decoding. Refactor: - ops/rotary_embedding.py → ops/rotary_embedding/ package (kunlun_rope, kunlun_mrope, kunlun_deepseek_rope, gemma4_rope, rope_ext, utils). - Plugin __init__: split into _load_native_extension / _patch_schema_utils / _install_import_hook with _completed_steps for idempotent registration. Modified: - v1 kunlun_attn: KV-sharing via update_block_table, prefill alpha scaling (scale*sqrt(head_dim)) to offset kernel's fixed 1/sqrt, SWA params (swa_left/right, sink), skip cache write on KV-share layers, fix decode scale 0.0 → self.scale. - v0 kunlun_attn + paged_attn: simplified split_kv_cache(kv_cache=) API; paged_attn adds head_size 512. - Register Gemma4ForCausalLM (models/__init__); register qwen3 reasoning parser; entry-point kunlun_reasoning_parser in pyproject.toml / setup.py. Signed-off-by: GrootLiu <1219671600@qq.com>
There was a problem hiding this comment.
Pull request overview
Adds Gemma4 model support (including YOCO KV-sharing fast-prefill, PLE, MoE routing, PP/EAGLE-3 hooks) and refactors Kunlun RoPE into a package with extensible rope_type dispatch, while updating Kunlun attention backends and plugin initialization/entrypoints.
Changes:
- Introduce
Gemma4ForCausalLM(+ supporting interfaces) and register it in the model registry. - Refactor rotary embedding implementation into
ops/rotary_embedding/and addrope_ext.get_rope()registry-based dispatch (e.g.,rope_type="proportional"). - Update Kunlun attention backends for KV-sharing behavior, new scaling behavior, and simplify KV-cache helpers.
Reviewed changes
Copilot reviewed 20 out of 21 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| vllm_kunlun/v1/attention/backends/kunlun_attn.py | KV-sharing support in v1 backend (metadata update helper, cache-write skip, scaling/params tweaks). |
| vllm_kunlun/reasoning/init.py | Adds reasoning parser registration entrypoint for Qwen3. |
| vllm_kunlun/ops/rotary_embedding/utils.py | Moves Split_Norm_Rope utility into new RoPE package. |
| vllm_kunlun/ops/rotary_embedding/rope_ext.py | Adds extensible get_rope() dispatch via _ROPE_TYPE_REGISTRY. |
| vllm_kunlun/ops/rotary_embedding/kunlun_rope.py | Splits out Kunlun OOT RotaryEmbedding implementation. |
| vllm_kunlun/ops/rotary_embedding/kunlun_mrope.py | Splits out Kunlun OOT MRotaryEmbedding implementation. |
| vllm_kunlun/ops/rotary_embedding/kunlun_deepseek_rope.py | Splits out Kunlun OOT DeepseekScalingRotaryEmbedding implementation. |
| vllm_kunlun/ops/rotary_embedding/gemma4_rope.py | Adds Gemma4 proportional RoPE implementation. |
| vllm_kunlun/ops/rotary_embedding/init.py | New package initializer importing/triggering OOT registrations and exporting utilities. |
| vllm_kunlun/ops/rotary_embedding.py | Removes the previous monolithic RoPE module (replaced by package). |
| vllm_kunlun/ops/paged_attn.py | Simplifies KV-cache split API and adds head_size=512 support. |
| vllm_kunlun/ops/fused_moe/router/gate_linear.py | Introduces a 3-tier router GEMM layer (currently intended mainly for CUDA SM90+). |
| vllm_kunlun/ops/attention/backends/kunlun_attn.py | Updates split_kv_cache callsite to new API. |
| vllm_kunlun/models/interfaces.py | Adds EagleModelMixin and SupportsEagle3 protocol for speculative decoding. |
| vllm_kunlun/models/gemma4.py | Adds full Gemma4 model implementation (MoE/YOCO/PLE/PP/EAGLE-3 integration). |
| vllm_kunlun/models/init.py | Registers Gemma4ForCausalLM. |
| vllm_kunlun/init.py | Refactors plugin init into idempotent steps + adds reasoning-parser registration entrypoint. |
| setup.py | Adds kunlun_reasoning_parser entry point. |
| pyproject.toml | Adds kunlun_reasoning_parser entry point. |
| docs/source/installation.md | Formatting/whitespace cleanup in installation docs. |
| ci/scripts/env/install_env.sh | Whitespace cleanup in CI env script. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| import importlib | ||
|
|
||
| from vllm.reasoning import ReasoningParserManager | ||
|
|
||
| """ | ||
| Reasoning parser registration module for vLLM Kunlun. | ||
| """ | ||
|
|
||
|
|
There was a problem hiding this comment.
The module docstring is placed after imports. For standard Python module semantics (and for tools like pydoc/Sphinx), the docstring should be the first statement in the file (before any imports).
| import importlib | |
| from vllm.reasoning import ReasoningParserManager | |
| """ | |
| Reasoning parser registration module for vLLM Kunlun. | |
| """ | |
| """ | |
| Reasoning parser registration module for vLLM Kunlun. | |
| """ | |
| import importlib | |
| from vllm.reasoning import ReasoningParserManager |
| def _maybe_add_hidden_state( | ||
| self, | ||
| aux_hidden_states: list[torch.Tensor], | ||
| layer_idx: int, | ||
| hidden_states: torch.Tensor, | ||
| residual: torch.Tensor, | ||
| ) -> list[torch.Tensor]: | ||
| if layer_idx in self.aux_hidden_state_layers: | ||
| value = hidden_states + residual if residual is not None else hidden_states | ||
| aux_hidden_states.append(value) |
There was a problem hiding this comment.
EagleModelMixin._maybe_add_hidden_state annotates residual as torch.Tensor, but the implementation explicitly handles residual is None (and callers pass None). Update the type hint to torch.Tensor | None (and, if desired, reflect that in the docstring) to avoid misleading typing and static-analysis errors.
| """ | ||
|
|
||
| import logging | ||
|
|
||
| from vllm_kunlun.ops.rotary_embedding.gemma4_rope import ( # noqa: F401 | ||
| Gemma4RotaryEmbedding, | ||
| ) | ||
| from vllm_kunlun.ops.rotary_embedding.kunlun_deepseek_rope import ( # noqa: F401 | ||
| KunlunDeepseekScalingRotaryEmbedding, | ||
| ) | ||
| from vllm_kunlun.ops.rotary_embedding.kunlun_mrope import ( # noqa: F401 | ||
| KunlunMRotaryEmbedding, | ||
| ) | ||
| from vllm_kunlun.ops.rotary_embedding.kunlun_rope import ( # noqa: F401 | ||
| KunlunRotaryEmbedding, | ||
| ) | ||
| from vllm_kunlun.ops.rotary_embedding.utils import Split_Norm_Rope # noqa: F401 | ||
|
|
||
| logger = logging.getLogger("vllm_kunlun.ops.rotary_embedding") | ||
|
|
||
| # Log that OOT registration is complete | ||
| logger.info( | ||
| "[KunlunOOT] Registered KunlunRotaryEmbedding, KunlunMRotaryEmbedding, " | ||
| "KunlunDeepseekScalingRotaryEmbedding via CustomOp.register_oot" | ||
| ) |
There was a problem hiding this comment.
This package __init__.py eagerly imports all concrete RoPE implementations (including modules that register CustomOps). Any import of vllm_kunlun.ops.rotary_embedding.rope_ext will execute this __init__ first, which defeats the “lazy import to avoid circular imports” goal described in rope_ext.py and can reintroduce circular-import / side-effect issues. Consider keeping __init__.py import-light (no registration side effects) and triggering OOT registrations from an explicit register_*() function invoked during plugin setup instead.
| """ | |
| import logging | |
| from vllm_kunlun.ops.rotary_embedding.gemma4_rope import ( # noqa: F401 | |
| Gemma4RotaryEmbedding, | |
| ) | |
| from vllm_kunlun.ops.rotary_embedding.kunlun_deepseek_rope import ( # noqa: F401 | |
| KunlunDeepseekScalingRotaryEmbedding, | |
| ) | |
| from vllm_kunlun.ops.rotary_embedding.kunlun_mrope import ( # noqa: F401 | |
| KunlunMRotaryEmbedding, | |
| ) | |
| from vllm_kunlun.ops.rotary_embedding.kunlun_rope import ( # noqa: F401 | |
| KunlunRotaryEmbedding, | |
| ) | |
| from vllm_kunlun.ops.rotary_embedding.utils import Split_Norm_Rope # noqa: F401 | |
| logger = logging.getLogger("vllm_kunlun.ops.rotary_embedding") | |
| # Log that OOT registration is complete | |
| logger.info( | |
| "[KunlunOOT] Registered KunlunRotaryEmbedding, KunlunMRotaryEmbedding, " | |
| "KunlunDeepseekScalingRotaryEmbedding via CustomOp.register_oot" | |
| ) | |
| Import behavior: | |
| - Keep this package initializer import-light to avoid defeating lazy-import paths | |
| (for example in rope_ext.py) and to avoid import-time registration side effects. | |
| - Call register_oot() explicitly during plugin setup to trigger OOT registrations. | |
| """ | |
| import importlib | |
| import logging | |
| logger = logging.getLogger("vllm_kunlun.ops.rotary_embedding") | |
| _REGISTERED = False | |
| _SYMBOL_TO_MODULE = { | |
| "Gemma4RotaryEmbedding": ( | |
| "vllm_kunlun.ops.rotary_embedding.gemma4_rope", | |
| "Gemma4RotaryEmbedding", | |
| ), | |
| "KunlunDeepseekScalingRotaryEmbedding": ( | |
| "vllm_kunlun.ops.rotary_embedding.kunlun_deepseek_rope", | |
| "KunlunDeepseekScalingRotaryEmbedding", | |
| ), | |
| "KunlunMRotaryEmbedding": ( | |
| "vllm_kunlun.ops.rotary_embedding.kunlun_mrope", | |
| "KunlunMRotaryEmbedding", | |
| ), | |
| "KunlunRotaryEmbedding": ( | |
| "vllm_kunlun.ops.rotary_embedding.kunlun_rope", | |
| "KunlunRotaryEmbedding", | |
| ), | |
| "Split_Norm_Rope": ( | |
| "vllm_kunlun.ops.rotary_embedding.utils", | |
| "Split_Norm_Rope", | |
| ), | |
| } | |
| __all__ = list(_SYMBOL_TO_MODULE) + ["register_oot"] | |
| def register_oot(): | |
| """Import OOT modules explicitly so their registration side effects run.""" | |
| global _REGISTERED | |
| if _REGISTERED: | |
| return | |
| importlib.import_module("vllm_kunlun.ops.rotary_embedding.gemma4_rope") | |
| importlib.import_module( | |
| "vllm_kunlun.ops.rotary_embedding.kunlun_deepseek_rope") | |
| importlib.import_module("vllm_kunlun.ops.rotary_embedding.kunlun_mrope") | |
| importlib.import_module("vllm_kunlun.ops.rotary_embedding.kunlun_rope") | |
| _REGISTERED = True | |
| logger.info( | |
| "[KunlunOOT] Registered KunlunRotaryEmbedding, " | |
| "KunlunMRotaryEmbedding, KunlunDeepseekScalingRotaryEmbedding " | |
| "via CustomOp.register_oot" | |
| ) | |
| def __getattr__(name): | |
| if name not in _SYMBOL_TO_MODULE: | |
| raise AttributeError( | |
| f"module {__name__!r} has no attribute {name!r}") | |
| module_name, attr_name = _SYMBOL_TO_MODULE[name] | |
| module = importlib.import_module(module_name) | |
| value = getattr(module, attr_name) | |
| globals()[name] = value | |
| return value |
| def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: | ||
| return self.embed_tokens(input_ids) * self.normalizer | ||
|
|
There was a problem hiding this comment.
embed_input_ids multiplies embeddings by self.normalizer, which is registered as a default torch.tensor(...) (float32). Multiplying (b)float16 embeddings by a float32 tensor will promote the whole embedding output to float32, increasing memory/compute and diverging from the comment about “downcast to model dtype”. Consider storing these scale factors as Python floats, or casting the buffer to embed_tokens.weight.dtype (or inputs_embeds.dtype) before multiplication.
| per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens) | ||
| per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer | ||
| return per_layer_embeds.reshape( |
There was a problem hiding this comment.
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer will also upcast per_layer_embeds to float32 because embed_scale_per_layer is a float32 tensor buffer. Same applies to per_layer_projection_scale / per_layer_input_scale later. To keep activations in the model dtype, cast these scale buffers to the activation dtype before use (or store them as Python floats).
| router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) | ||
| indicator = torch.nn.functional.one_hot( | ||
| topk_ids, num_classes=gating_output.size(-1) | ||
| ).sum(dim=-2) | ||
| gate_weights = indicator * router_probabilities | ||
| renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) | ||
| renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) | ||
| dispatch_weights = gate_weights / renorm_factor | ||
|
|
||
| topk_weights = dispatch_weights.gather(1, topk_ids) |
There was a problem hiding this comment.
The routing function materializes a dense [T, num_experts] one-hot indicator tensor (one_hot(...).sum(...)), which is O(T·E) memory and can be a significant overhead (e.g., large token batches with 256/384 experts). This can be computed without a full one-hot by gathering the softmax probabilities at topk_ids, renormalizing those topk values, and returning them directly.
| router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) | |
| indicator = torch.nn.functional.one_hot( | |
| topk_ids, num_classes=gating_output.size(-1) | |
| ).sum(dim=-2) | |
| gate_weights = indicator * router_probabilities | |
| renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) | |
| renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) | |
| dispatch_weights = gate_weights / renorm_factor | |
| topk_weights = dispatch_weights.gather(1, topk_ids) | |
| router_probabilities = torch.nn.functional.softmax( | |
| gating_output, dim=-1 | |
| ) | |
| topk_weights = router_probabilities.gather(1, topk_ids) | |
| if renormalize: | |
| renorm_factor = torch.sum(topk_weights, dim=-1, keepdim=True) | |
| renorm_factor = torch.where( | |
| renorm_factor > 0.0, renorm_factor, | |
| torch.ones_like(renorm_factor) | |
| ) | |
| topk_weights = topk_weights / renorm_factor |
| """Update block table and slot mapping for a different KV cache group.""" | ||
| new_metadata = copy.copy(metadata) | ||
| new_metadata.block_tables = blk_table | ||
| new_metadata.slot_mapping = slot_mapping |
There was a problem hiding this comment.
update_block_table uses a shallow copy.copy(metadata) and then overwrites block_tables/slot_mapping, but it leaves _cached_prefill_metadata / _cached_decode_metadata intact. If those caches were already populated, new_metadata.prefill_metadata / decode_metadata will return stale cached objects that still reference the old block tables / slot mapping. Consider clearing these cached fields (or rebuilding via dataclasses.replace while setting the cached fields to None) when creating the updated metadata.
| new_metadata.slot_mapping = slot_mapping | |
| new_metadata.slot_mapping = slot_mapping | |
| # Invalidate cached derived metadata so it is rebuilt against the | |
| # updated block tables and slot mapping instead of reusing stale | |
| # references copied from the original metadata object. | |
| new_metadata._cached_prefill_metadata = None | |
| new_metadata._cached_decode_metadata = None |
| def _load_native_extension(logger: logging.Logger) -> None: | ||
| """Load _kunlun C extension to register torch.ops._C.weak_ref_tensor.""" | ||
| if "native_ext" in _completed_steps: | ||
| return | ||
| _completed_steps.add("native_ext") # only attempt once | ||
| try: | ||
| from . import _kunlun # noqa: F401 | ||
|
|
||
| logger.info("[KunlunPlugin] _kunlun native extension loaded") | ||
| except ImportError as e: | ||
| logger.warning("[KunlunPlugin] Failed to load _kunlun: %s", e) |
There was a problem hiding this comment.
_load_native_extension adds "native_ext" to _completed_steps before attempting the import. This means a transient failure (e.g., extension not yet available during an early plugin-discovery phase) will never be retried, which contradicts the register() docstring about retrying previously-failed steps. Consider only marking the step as completed after a successful import (or track attempted vs completed separately).
- Implement Gemma4MultiModalProcessor with image, audio and video handling - Replace self.info.ctx.get_merged_mm_kwargs() with model_config.get_multimodal_config().merge_mm_processor_kwargs() to fix AttributeError in container vLLM environment - Add fallback for bare numpy array video items since container vLLM strips metadata before _call_hf_processor; synthesise default metadata to fix "too many values to unpack" error Signed-off-by: GrootLiu <1219671600@qq.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 21 out of 22 changed files in this pull request and generated 6 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| import importlib | ||
|
|
||
| from vllm.reasoning import ReasoningParserManager | ||
|
|
||
| """ | ||
| Reasoning parser registration module for vLLM Kunlun. | ||
| """ | ||
|
|
||
|
|
There was a problem hiding this comment.
The triple-quoted string is placed after imports, so it is not treated as the module docstring (and may be flagged by linters as a pointless string literal). Move this docstring to the very top of the file (before imports) or convert it to a regular comment if you don't want it as doc.
| import importlib | |
| from vllm.reasoning import ReasoningParserManager | |
| """ | |
| Reasoning parser registration module for vLLM Kunlun. | |
| """ | |
| """ | |
| Reasoning parser registration module for vLLM Kunlun. | |
| """ | |
| import importlib | |
| from vllm.reasoning import ReasoningParserManager |
| aux_hidden_states: list[torch.Tensor], | ||
| layer_idx: int, | ||
| hidden_states: torch.Tensor, | ||
| residual: torch.Tensor, |
There was a problem hiding this comment.
residual is typed as torch.Tensor, but this method explicitly handles residual is not None and callers pass None (e.g. during the first layer). Update the type annotation to torch.Tensor | None to match actual usage and avoid type-checker errors.
| residual: torch.Tensor, | |
| residual: torch.Tensor | None, |
| # NOTE(kunlun): prefill_attention kernel internally applies | ||
| # 1/sqrt(head_dim) and multiplies by alpha. Compute alpha to | ||
| # achieve the desired effective scaling: | ||
| # score = Q @ K^T * (1/sqrt(d)) * alpha | ||
| # We want: score = Q @ K^T * self.scale | ||
| # So: alpha = self.scale * sqrt(d) = self.scale / (1/sqrt(d)) | ||
| import math | ||
|
|
||
| _prefill_alpha = self.scale * math.sqrt(self.head_size) |
There was a problem hiding this comment.
import math is inside the attention forward path, so it executes on every prefill call (hot path). Move the import to module scope (and consider precomputing sqrt(head_size) once) to avoid repeated work in a latency-sensitive function.
| merged_kwargs = self.info.ctx.get_merged_mm_kwargs( | ||
| hf_processor_mm_kwargs, |
There was a problem hiding this comment.
This still uses self.info.ctx.get_merged_mm_kwargs(...), but earlier in this PR _call_hf_processor was updated because some vLLM environments don't have InputProcessingContext.get_merged_mm_kwargs. For consistency (and to avoid the same AttributeError during prompt update building), use the model_config.get_multimodal_config().merge_mm_processor_kwargs(...) path here as well.
| merged_kwargs = self.info.ctx.get_merged_mm_kwargs( | |
| hf_processor_mm_kwargs, | |
| model_config = self.info.ctx.model_config | |
| merged_kwargs = ( | |
| model_config.get_multimodal_config() | |
| .merge_mm_processor_kwargs(hf_processor_mm_kwargs) |
| if multimodal_embeddings is None or is_multimodal is None: | ||
| return super().embed_input_ids(input_ids) | ||
|
|
||
| return super().embed_input_ids( | ||
| input_ids, | ||
| multimodal_embeddings=multimodal_embeddings, | ||
| is_multimodal=is_multimodal, | ||
| ) |
There was a problem hiding this comment.
super().embed_input_ids(...) is called here, but Gemma4ForConditionalGeneration only inherits from nn.Module plus Protocols, so there is no guaranteed implementation of embed_input_ids in the MRO. This is likely to raise at runtime and it also bypasses the standard multimodal merge behavior used elsewhere in this repo (embed text first, then merge multimodal_embeddings at is_multimodal positions). Implement embed_input_ids explicitly by delegating to self.language_model.embed_input_ids and then merging multimodal embeddings (similar to other multimodal models in vllm_kunlun/models/*).
| if multimodal_embeddings is None or is_multimodal is None: | |
| return super().embed_input_ids(input_ids) | |
| return super().embed_input_ids( | |
| input_ids, | |
| multimodal_embeddings=multimodal_embeddings, | |
| is_multimodal=is_multimodal, | |
| ) | |
| inputs_embeds = self.language_model.embed_input_ids(input_ids) | |
| if multimodal_embeddings is None or is_multimodal is None: | |
| return inputs_embeds | |
| merged_embeddings = inputs_embeds.clone() | |
| merged_embeddings[ | |
| is_multimodal.to(inputs_embeds.device, non_blocking=True) | |
| ] = multimodal_embeddings.to( | |
| device=inputs_embeds.device, | |
| dtype=inputs_embeds.dtype, | |
| non_blocking=True, | |
| ) | |
| return merged_embeddings |
| def forward( | ||
| self, x: torch.Tensor | ||
| ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: | ||
| import vllm._custom_ops as ops |
There was a problem hiding this comment.
The return type annotation for forward() doesn't match the implementation: this method always returns a 2-tuple (output, bias_or_none), never a bare torch.Tensor. Tighten the return type to the tuple form (and align the bias element type with what ReplicatedLinear.forward actually returns) so callers and type-checkers can rely on it.
- Remove dead code caused by static signature check - Use is_speculative as the main branch condition - Remove paged_attention branch (speculative_attention covers its capability) - Fix max_context_len from hardcoded 131072 to decode_meta.max_model_len Signed-off-by: GrootLiu <1219671600@qq.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 21 out of 22 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| vLLM may invoke this multiple times during different discovery phases; | ||
| each step tracks its own completion state via ``_completed_steps`` so | ||
| already-succeeded work is skipped while previously-failed work (e.g. | ||
| _patch_rotary_embedding blocked by circular import) is retried. | ||
| """ |
There was a problem hiding this comment.
The register() docstring says previously-failed steps are retried, but _load_native_extension() marks the native_ext step as completed before attempting the import (so an ImportError will not be retried on subsequent register() calls). Please either adjust the docstring to match the actual behavior or change _load_native_extension() to only mark completion after a successful load if retries are intended.
[Feature] Add Gemma4 support (MoE, multimodal) and refactor RoPE
This commit introduces Gemma4 model support including MoE routing,
YOCO KV-sharing, and multimodal processing, along with a refactor of
RoPE implementations and related infrastructure.
Feature:
Add Gemma4ForCausalLM:
Add Gemma4RotaryEmbedding:
Add RoPE extensibility:
Add multimodal support:
Add inference enhancements:
Refactor:
Restructure rotary embedding:
(kunlun_rope, mrope, deepseek_rope, gemma4_rope, rope_ext, utils)
Refactor plugin initialization:
Refactor attention ops:
Enhancement / Fix:
v1 kunlun_attn:
v0 kunlun_attn:
Register:
Signed-off-by: GrootLiu 1219671600@qq.com
PR Description
1. 新增:Gemma4 模型主体
vllm_kunlun/models/gemma4.py(+1615 行,纯新增)
vllm_kunlun/models/interfaces.py(+83 行)
vllm_kunlun/models/init.py(+4 行)
2. 重构:RoPE 从单文件拆成 package
删除
vllm_kunlun/ops/rotary_embedding.py(-256 行)新建
vllm_kunlun/ops/rotary_embedding/目录(+542 行):kunlun_rope.py(99)— OOT RotaryEmbeddingkunlun_mrope.py(71)— OOT MRotaryEmbeddingkunlun_deepseek_rope.py(64)— OOT Deepseek Scaling RoPEgemma4_rope.py(83)— proportional scaling,head_dim 为分母,非旋转维 zero-padrope_ext.py(108)— 可扩展 get_rope() + _ROPE_TYPE_REGISTRY(注册 "proportional")utils.py(63)— Split_Norm_Rope 工具__init__.py(54)3. 新增:MoE Gate 算子
vllm_kunlun/ops/fused_moe/router/gate_linear.py(+116 行)
GateLinear(继承 ReplicatedLinear)三级 GEMM dispatch:
XPU 上一定会走 Tier 3,等价于 ReplicatedLinear;gemma4.py 目前直接用 ReplicatedLinear,GateLinear 尚未投入使用
4. 修改:Attention backend
v1/attention/backends/kunlun_attn.py
update_block_table()支持 KV sharingscale * sqrt(head_dim),抵消 XPU kernel 里写死的1/sqrt(head_dim),变为可自定义缩放比例,如果未自定义,默认为sqrt(head_dim)swa_left、swa_right、sink0.0 → self.scalespeculative_attention,移除paged_attention分支,修复静态签名检查导致的死代码问题max_context_len:从硬编码131072改为decode_meta.max_model_lenops/attention/backends/kunlun_attn.py(小改) + ops/paged_attn.py
split_kv_cache(kv_cache=...)接口(去掉多余参数)5. 修改:插件初始化
vllm_kunlun/init.py(+162/-?)
_load_native_extension/_patch_schema_utils/_install_import_hook_completed_steps做幂等多阶段注册vllm_kunlun/reasoning/init.py(+22 行)
pyproject.toml(+1)+ setup.py(+1)
kunlun_reasoning_parser→vllm.general_plugins6. 修改:Gemma4 多模态 API 兼容性
vllm_kunlun/models/gemma4_mm.py(+13/-2)
修复容器中 vLLM 环境与代码期望的 API 差异:
_call_hf_processor(493-494):将self.info.ctx.get_merged_mm_kwargs(mm_kwargs)替换为self.info.ctx.model_config.get_multimodal_config().merge_mm_processor_kwargs(mm_kwargs),解决AttributeError: 'InputProcessingContext' object has no attribute 'get_merged_mm_kwargs'for item in videos循环中增加类型判断分支。容器中 vLLM 解析管道会在传入_call_hf_processor前剥离 metadata,导致video_array, metadata = item解包时出现ValueError: too many values to unpack。现在:(array, metadata)二元组:正常解包Checklist (Required)
Before submitting this PR, please ensure that all the following items are completed:
pre-commitchecks.git commit -s.PR Type
Please prefix the PR title with one or more of the following labels to help reviewers quickly understand the nature of the change:
[Feature]– New features or enhancements (e.g. Attention, Communicator, Kernel, Worker, etc.)[Bugfix]– Bug fixes[CI/Build]– CI, build system, or infrastructure improvements[Doc]– Documentation updates or fixes[Misc]– Other changes that do not fit the above categories (use sparingly)Detailed Checklist (Click to Expand)
Thank you for contributing to vLLM Kunlun! To help us maintain high code quality and streamline the review process, please ensure your PR meets the following requirements.
1. Code Quality
pre-commit).2. Testing
3. DCO Compliance
This project follows the Developer Certificate of Origin (DCO).
Signed-off-by:line.git commit -sto automatically add the sign-off.4. Review Expectations
During the review process, maintainers may:
We appreciate your patience and collaboration throughout the review process!