Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions tests/ut/test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,79 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch

import pytest
import torch


def _make_platform_config(hf_config):
    """Build a lightweight stand-in for a vLLM config wrapping ``hf_config``.

    Every section is a plain ``SimpleNamespace`` carrying only the attributes
    set below, so no real vLLM config objects are constructed. The only vLLM
    symbol needed is ``CUDAGraphMode``, imported lazily inside the helper.

    Args:
        hf_config: Object to expose as ``model_config.hf_config``.

    Returns:
        A ``SimpleNamespace`` with ``parallel_config``, ``model_config``,
        ``cache_config``, ``speculative_config`` and ``compilation_config``.
    """
    from vllm.config import CUDAGraphMode

    parallel_section = SimpleNamespace(worker_cls="manual", data_parallel_size=1)
    model_section = SimpleNamespace(
        use_mla=False,
        hf_config=hf_config,
        enforce_eager=False,
    )
    cache_section = SimpleNamespace(block_size=16)
    compilation_section = SimpleNamespace(
        cudagraph_mode=CUDAGraphMode.NONE,
        pass_config=SimpleNamespace(enable_fusion=True),
        backend=None,
        custom_ops=[],
    )

    return SimpleNamespace(
        parallel_config=parallel_section,
        model_config=model_section,
        cache_config=cache_section,
        speculative_config=None,
        compilation_config=compilation_section,
    )


def _check_platform_config(vllm_config):
    """Run ``KunlunPlatform.check_and_update_config`` against ``vllm_config``.

    While the check runs, ``vllm.envs.VLLM_ALL2ALL_BACKEND`` is patched to
    ``None``; ``create=True`` lets the patch apply even when the attribute
    does not already exist on the module.
    """
    import vllm.envs as envs

    from vllm_kunlun.platforms.kunlun import KunlunPlatform

    backend_override = patch.object(
        envs, "VLLM_ALL2ALL_BACKEND", None, create=True
    )
    with backend_override:
        KunlunPlatform.check_and_update_config(vllm_config)


def test_qwen3_vl_text_config_inherits_top_level_tie_word_embeddings():
    """A Qwen3-VL ``text_config`` lacking the field inherits the top-level value.

    Both ``True`` and ``False`` are exercised: checking ``False`` alone cannot
    tell real inheritance apart from a hard-coded ``False`` fallback, so the
    ``True`` case is the one that actually proves the value was copied.
    """
    for top_level_value in (True, False):
        text_config = SimpleNamespace()
        hf_config = SimpleNamespace(
            architectures=["Qwen3VLForConditionalGeneration"],
            text_config=text_config,
            tie_word_embeddings=top_level_value,
        )

        _check_platform_config(_make_platform_config(hf_config))

        assert text_config.tie_word_embeddings is top_level_value


def test_qwen3_vl_text_config_existing_tie_word_embeddings_is_preserved():
    """A value already present on ``text_config`` must not be overwritten."""
    nested_config = SimpleNamespace(tie_word_embeddings=True)
    vllm_config = _make_platform_config(
        SimpleNamespace(
            architectures=["Qwen3VLForConditionalGeneration"],
            text_config=nested_config,
            tie_word_embeddings=False,
        )
    )

    _check_platform_config(vllm_config)

    # The top-level value is False; the nested True has to survive the check.
    assert nested_config.tie_word_embeddings is True


def test_non_qwen3_vl_text_config_is_not_modified():
    """Configs whose architecture is not Qwen3-VL are left untouched."""
    nested_config = SimpleNamespace()
    vllm_config = _make_platform_config(
        SimpleNamespace(
            architectures=["Qwen3ForCausalLM"],
            text_config=nested_config,
            tie_word_embeddings=False,
        )
    )

    _check_platform_config(vllm_config)

    # No attribute may have been added to the nested text config.
    assert "tie_word_embeddings" not in vars(nested_config)


def test_import():
"""Test that the module can be imported successfully."""
from vllm_kunlun.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
Expand Down
26 changes: 26 additions & 0 deletions vllm_kunlun/platforms/kunlun.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,30 @@

# Module-level logger, obtained from init_logger with this module's name.
logger = init_logger(__name__)

# Architecture names (as listed in an HF config's ``architectures`` field)
# that _is_qwen3_vl_config treats as identifying a Qwen3-VL model.
_QWEN3_VL_ARCHITECTURES = {"Qwen3VLForConditionalGeneration"}


def _is_qwen3_vl_config(hf_config) -> bool:
    """Report whether ``hf_config`` identifies a Qwen3-VL model.

    A config matches when its class is named ``Qwen3VLConfig`` or when any
    entry of its ``architectures`` attribute is in
    ``_QWEN3_VL_ARCHITECTURES``. ``architectures`` may be missing, ``None``,
    empty, a single string, or an iterable of strings.
    """
    declared = getattr(hf_config, "architectures", None) or ()
    if isinstance(declared, str):
        # A bare string is one architecture name, not an iterable of chars.
        declared = (declared,)

    if type(hf_config).__name__ == "Qwen3VLConfig":
        return True

    for arch_name in declared:
        if arch_name in _QWEN3_VL_ARCHITECTURES:
            return True
    return False
Comment on lines +24 to +32
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 8a46ecb. I added targeted regression coverage for the string architectures path, the Qwen3VLConfig type-name path, and the missing top-level field case.



def _patch_qwen3_vl_text_config(hf_config) -> None:
    """Copy ``tie_word_embeddings`` from a Qwen3-VL config onto its text config.

    The nested ``text_config`` is mutated in place, and only when all of the
    following hold:

    * ``hf_config`` is not ``None`` and ``_is_qwen3_vl_config`` accepts it;
    * ``hf_config.text_config`` exists and is not ``None``;
    * ``text_config`` does not already define ``tie_word_embeddings``
      (an existing value is never overwritten);
    * ``hf_config`` itself explicitly defines ``tie_word_embeddings``.

    Args:
        hf_config: The top-level Hugging Face style config object, or ``None``.
    """
    if hf_config is None or not _is_qwen3_vl_config(hf_config):
        return

    text_config = getattr(hf_config, "text_config", None)
    if text_config is None or hasattr(text_config, "tie_word_embeddings"):
        return

    # Do not invent a value: when the top-level config has no opinion, leave
    # text_config untouched rather than forcing an arbitrary ``False`` onto it.
    if not hasattr(hf_config, "tie_word_embeddings"):
        return

    text_config.tie_word_embeddings = hf_config.tie_word_embeddings
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, fixed in 8a46ecb. The patch now only copies tie_word_embeddings when the top-level config explicitly defines it.



class KunlunPlatform(Platform):
"""KunlunPlatform"""
Expand Down Expand Up @@ -179,6 +203,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
parallel_config = vllm_config.parallel_config # Not use scheduler_config
# scheduler_config = vllm_config.scheduler_config
model_config = vllm_config.model_config
if model_config is not None:
_patch_qwen3_vl_text_config(getattr(model_config, "hf_config", None))

if parallel_config.worker_cls == "auto":
# v0.15.1 do not support v0.15.1, remove the if condition
Expand Down
Loading