Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions tests/ut/test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,115 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch

import pytest
import torch


def _make_platform_config(hf_config):
from vllm.config import CUDAGraphMode

return SimpleNamespace(
parallel_config=SimpleNamespace(worker_cls="manual", data_parallel_size=1),
model_config=SimpleNamespace(
use_mla=False,
hf_config=hf_config,
enforce_eager=False,
),
cache_config=SimpleNamespace(block_size=16),
speculative_config=None,
compilation_config=SimpleNamespace(
cudagraph_mode=CUDAGraphMode.NONE,
pass_config=SimpleNamespace(enable_fusion=True),
backend=None,
custom_ops=[],
),
)


def _check_platform_config(vllm_config):
import vllm.envs as envs

from vllm_kunlun.platforms.kunlun import KunlunPlatform

with patch.object(envs, "VLLM_ALL2ALL_BACKEND", None, create=True):
KunlunPlatform.check_and_update_config(vllm_config)


def test_qwen3_vl_text_config_inherits_top_level_tie_word_embeddings():
text_config = SimpleNamespace()
hf_config = SimpleNamespace(
architectures=["Qwen3VLForConditionalGeneration"],
text_config=text_config,
tie_word_embeddings=False,
)

_check_platform_config(_make_platform_config(hf_config))

assert text_config.tie_word_embeddings is False


def test_qwen3_vl_text_config_inherits_from_string_architecture():
text_config = SimpleNamespace()
hf_config = SimpleNamespace(
architectures="Qwen3VLForConditionalGeneration",
text_config=text_config,
tie_word_embeddings=True,
)

_check_platform_config(_make_platform_config(hf_config))

assert text_config.tie_word_embeddings is True


def test_qwen3_vl_config_type_is_detected_without_architectures():
text_config = SimpleNamespace()
hf_config = type("Qwen3VLConfig", (), {})()
hf_config.text_config = text_config
hf_config.tie_word_embeddings = False

_check_platform_config(_make_platform_config(hf_config))

assert text_config.tie_word_embeddings is False


def test_qwen3_vl_text_config_existing_tie_word_embeddings_is_preserved():
text_config = SimpleNamespace(tie_word_embeddings=True)
hf_config = SimpleNamespace(
architectures=["Qwen3VLForConditionalGeneration"],
text_config=text_config,
tie_word_embeddings=False,
)

_check_platform_config(_make_platform_config(hf_config))

assert text_config.tie_word_embeddings is True


def test_non_qwen3_vl_text_config_is_not_modified():
text_config = SimpleNamespace()
hf_config = SimpleNamespace(
architectures=["Qwen3ForCausalLM"],
text_config=text_config,
tie_word_embeddings=False,
)

_check_platform_config(_make_platform_config(hf_config))

assert not hasattr(text_config, "tie_word_embeddings")


def test_qwen3_vl_text_config_without_top_level_tie_word_embeddings_is_not_modified():
text_config = SimpleNamespace()
hf_config = SimpleNamespace(
architectures=["Qwen3VLForConditionalGeneration"],
text_config=text_config,
)

_check_platform_config(_make_platform_config(hf_config))

assert not hasattr(text_config, "tie_word_embeddings")


def test_import():
"""Test that the module can be imported successfully."""
from vllm_kunlun.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
Expand Down
29 changes: 29 additions & 0 deletions vllm_kunlun/platforms/kunlun.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,33 @@

logger = init_logger(__name__)

_QWEN3_VL_ARCHITECTURES = {"Qwen3VLForConditionalGeneration"}


def _is_qwen3_vl_config(hf_config) -> bool:
config_type = type(hf_config).__name__
architectures = getattr(hf_config, "architectures", None) or ()
if isinstance(architectures, str):
architectures = (architectures,)

return config_type == "Qwen3VLConfig" or any(
architecture in _QWEN3_VL_ARCHITECTURES for architecture in architectures
)
Comment on lines +24 to +32
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 8a46ecb. I added targeted regression coverage for the string architectures path, the Qwen3VLConfig type-name path, and the missing top-level field case.



def _patch_qwen3_vl_text_config(hf_config) -> None:
if hf_config is None or not _is_qwen3_vl_config(hf_config):
return

text_config = getattr(hf_config, "text_config", None)
if text_config is None or hasattr(text_config, "tie_word_embeddings"):
return

if not hasattr(hf_config, "tie_word_embeddings"):
return

text_config.tie_word_embeddings = hf_config.tie_word_embeddings


class KunlunPlatform(Platform):
"""KunlunPlatform"""
Expand Down Expand Up @@ -179,6 +206,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
parallel_config = vllm_config.parallel_config # Not use scheduler_config
# scheduler_config = vllm_config.scheduler_config
model_config = vllm_config.model_config
if model_config is not None:
_patch_qwen3_vl_text_config(getattr(model_config, "hf_config", None))

if parallel_config.worker_cls == "auto":
# v0.15.1 do not support v0.15.1, remove the if condition
Expand Down
Loading