Skip to content

Commit 0f7ec53

Browse files
[plugin][UT] add test cases (not model level) for ATOM OOT
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
1 parent 43ca1de commit 0f7ec53

File tree

6 files changed

+339
-0
lines changed

6 files changed

+339
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import pytest
2+
3+
import atom.plugin.config as plugin_config
4+
5+
6+
class _Obj:
    """Ad-hoc attribute container: every keyword argument becomes an attribute.

    Used to fabricate lightweight stand-ins for vLLM's nested config objects
    without importing vLLM itself.
    """

    def __init__(self, **kwargs):
        # A plain instance dict update is equivalent to setattr() per key here
        # because this class defines no properties or __slots__.
        self.__dict__.update(kwargs)
10+
11+
12+
class _FakeConfig:
    """Stand-in for ``atom.config.Config``: stores whatever fields a test supplies."""

    def __init__(self, **kwargs):
        for field_name, field_value in kwargs.items():
            setattr(self, field_name, field_value)
16+
17+
18+
class _FakeCompilationConfig:
    """Stand-in for ``atom.config.CompilationConfig`` with only the fields under test."""

    def __init__(self, level, use_cudagraph, cudagraph_mode):
        # Positional signature is fixed: the code under test constructs this
        # the same way it would construct the real CompilationConfig.
        for attr, value in (
            ("level", level),
            ("use_cudagraph", use_cudagraph),
            ("cudagraph_mode", cudagraph_mode),
        ):
            setattr(self, attr, value)
23+
24+
25+
def _patch_atom_config_module(monkeypatch):
    """Swap the real Config/CompilationConfig in ``atom.config`` for test fakes.

    ``raising=False`` keeps this working even if the attribute is absent in
    the installed version of the module.
    """
    import atom.config as atom_config_module

    for attr_name, fake_cls in (
        ("Config", _FakeConfig),
        ("CompilationConfig", _FakeCompilationConfig),
    ):
        monkeypatch.setattr(atom_config_module, attr_name, fake_cls, raising=False)
32+
33+
34+
def test_generate_from_vllm_translates_core_fields(monkeypatch):
    """A fabricated vLLM config is translated field-by-field into an ATOM config."""
    _patch_atom_config_module(monkeypatch)
    monkeypatch.setenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0")

    # Build the nested vLLM-shaped config from named pieces for readability.
    cache_cfg = _Obj(
        gpu_memory_utilization=0.5,
        block_size=16,
        num_gpu_blocks=1024,
        cache_dtype="auto",
        enable_prefix_caching=True,
    )
    parallel_cfg = _Obj(
        rank=1, tensor_parallel_size=2, enable_expert_parallel=False
    )
    vllm_cfg = _Obj(
        model_config=_Obj(model="m1", max_model_len=4096),
        scheduler_config=_Obj(max_num_batched_tokens=2048, max_num_seqs=8),
        cache_config=cache_cfg,
        parallel_config=parallel_cfg,
        compilation_config=_Obj(mode=3),
        quant_config=_Obj(name="q"),
    )

    cfg = plugin_config._generate_atom_config_from_vllm_config(vllm_cfg)

    # Core model / scheduler fields are copied straight across.
    assert cfg.model == "m1"
    assert cfg.max_model_len == 4096
    assert cfg.max_num_batched_tokens == 2048
    assert cfg.max_num_seqs == 8
    assert cfg.tensor_parallel_size == 2
    assert cfg.enforce_eager is True
    assert cfg.compilation_config.level == 3
    # Plugin-mode bookkeeping: vLLM is the host, attention stays on ATOM.
    assert cfg.plugin_config.is_plugin_mode is True
    assert cfg.plugin_config.is_vllm is True
    assert cfg.plugin_config.is_sglang is False
    assert cfg.plugin_config.vllm_use_atom_attention is True
68+
69+
70+
def test_generate_atom_config_requires_plugin_mode(monkeypatch):
    """generate_atom_config_for_plugin_mode must refuse to run outside plugin mode."""
    import atom.plugin as plugin_module
    import atom.plugin.config as config_module
    import atom.config as atom_config_module

    # Force both framework probes to report "not hosted by any framework".
    for probe_name in ("is_vllm", "is_sglang"):
        monkeypatch.setattr(plugin_module, probe_name, lambda: False, raising=False)
    # Neutralize the global-config setter so the test leaves no state behind.
    monkeypatch.setattr(
        atom_config_module, "set_current_atom_config", lambda _cfg: None, raising=False
    )

    with pytest.raises(ValueError, match="running in plugin mode"):
        config_module.generate_atom_config_for_plugin_mode(config=None)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import importlib
2+
import importlib.util
3+
import sys
4+
import types
5+
6+
import pytest
7+
8+
9+
def test_disable_vllm_plugin_flag_disables_platform(monkeypatch):
    """ATOM_DISABLE_VLLM_PLUGIN=1 must win regardless of the attention flag:
    vLLM should not get the ATOM platform/attention at all."""
    for attention_flag in ("0", "1"):
        monkeypatch.setenv("ATOM_DISABLE_VLLM_PLUGIN", "1")
        monkeypatch.setenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", attention_flag)

        import atom.plugin.vllm.platform as platform_module
        import atom.plugin.vllm.register as register_module

        # Reload so the modules re-read the environment flags set above.
        importlib.reload(platform_module)
        importlib.reload(register_module)

        # With the plugin globally disabled, neither the platform class nor
        # the registration hook is exposed to vLLM.
        assert platform_module.ATOMPlatform is None
        assert register_module.register_platform() is None
24+
25+
26+
@pytest.mark.skipif(
    importlib.util.find_spec("vllm") is None,
    reason="vllm is not installed in current test environment",
)
def test_disable_vllm_plugin_attention_fallbacks_to_non_atom_backend(monkeypatch):
    """With only attention disabled, ATOMPlatform defers to the stock ROCm backend."""

    class _RocmPlatform:
        @classmethod
        def get_attn_backend_cls(cls, selected_backend, attn_selector_config):
            # Sentinel value: proves the call was delegated to the base class.
            return "vllm.default.backend"

    rocm_module = types.ModuleType("vllm.platforms.rocm")
    rocm_module.RocmPlatform = _RocmPlatform

    # Install a minimal fake vllm package tree so the plugin can import it.
    for mod_name, module in (
        ("vllm", types.ModuleType("vllm")),
        ("vllm.platforms", types.ModuleType("vllm.platforms")),
        ("vllm.platforms.rocm", rocm_module),
    ):
        monkeypatch.setitem(sys.modules, mod_name, module)

    monkeypatch.setenv("ATOM_DISABLE_VLLM_PLUGIN", "0")
    monkeypatch.setenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "1")

    import atom.plugin.vllm.platform as platform_module

    # Reload so the module re-reads the environment flags set above.
    importlib.reload(platform_module)

    backend = platform_module.ATOMPlatform.get_attn_backend_cls(
        selected_backend="x",
        attn_selector_config=types.SimpleNamespace(use_mla=True),
    )
    assert backend == "vllm.default.backend"
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
3+
from atom.plugin import prepare as plugin_prepare
4+
5+
6+
@pytest.fixture(autouse=True)
def _reset_framework_state():
    """Pin the framework backbone to the server-mode default around every test."""
    # Autouse fixture: pytest runs this before/after every test.
    # Setup: guarantee a test never inherits a previous test's backbone choice.
    plugin_prepare._set_framework_backbone("atom")
    yield
    # Teardown: restore the default even if the test switched backbones.
    plugin_prepare._set_framework_backbone("atom")
12+
13+
14+
def test_default_mode_is_server_mode():
    """With the default backbone, ATOM runs in server mode: no plugin host active."""
    for probe in (
        plugin_prepare.is_plugin_mode,
        plugin_prepare.is_vllm,
        plugin_prepare.is_sglang,
    ):
        # `is False` deliberately checks the exact boolean, not mere falsiness.
        assert probe() is False
18+
19+
20+
def test_set_framework_to_vllm():
    """Selecting 'vllm' enables plugin mode and identifies vLLM as the host."""
    plugin_prepare._set_framework_backbone("vllm")

    expectations = {
        plugin_prepare.is_plugin_mode: True,
        plugin_prepare.is_vllm: True,
        plugin_prepare.is_sglang: False,
    }
    for probe, expected in expectations.items():
        assert probe() is expected
25+
26+
27+
def test_set_framework_to_sgl_alias():
    """The short alias 'sgl' selects the SGLang backbone."""
    plugin_prepare._set_framework_backbone("sgl")

    expectations = {
        plugin_prepare.is_plugin_mode: True,
        plugin_prepare.is_vllm: False,
        plugin_prepare.is_sglang: True,
    }
    for probe, expected in expectations.items():
        assert probe() is expected
32+
33+
34+
def test_set_framework_unsupported_raises():
    """Any name outside the known backbones is rejected with a ValueError."""
    bogus_framework = "tensorflow"
    with pytest.raises(ValueError, match="Unsupported framework"):
        plugin_prepare._set_framework_backbone(bogus_framework)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import sys
2+
import types
3+
import importlib
4+
import importlib.util
5+
6+
import pytest
7+
8+
from atom.plugin import prepare as plugin_prepare
9+
import atom.plugin.vllm.register as vllm_register
10+
11+
12+
@pytest.fixture(autouse=True)
def _reset_framework_state():
    """Pin the framework backbone to the server-mode default around every test."""
    # Setup: start from the non-plugin "atom" backbone.
    plugin_prepare._set_framework_backbone("atom")
    yield
    # Teardown: restore the default even if the test switched backbones.
    plugin_prepare._set_framework_backbone("atom")
17+
18+
19+
@pytest.mark.skipif(
    importlib.util.find_spec("vllm") is None,
    reason="vllm is not installed in current test environment",
)
def test_register_platform_returns_oot_platform(monkeypatch):
    """register_platform() must hand vLLM the dotted path of ATOMPlatform."""

    class _RocmPlatform:
        pass

    rocm_module = types.ModuleType("vllm.platforms.rocm")
    rocm_module.RocmPlatform = _RocmPlatform
    vllm_platforms = types.ModuleType("vllm.platforms")
    vllm_platforms.current_platform = None

    # Install a minimal fake vllm package tree before the plugin imports it.
    for mod_name, module in (
        ("vllm", types.ModuleType("vllm")),
        ("vllm.platforms", vllm_platforms),
        ("vllm.platforms.rocm", rocm_module),
    ):
        monkeypatch.setitem(sys.modules, mod_name, module)

    monkeypatch.setenv("ATOM_DISABLE_VLLM_PLUGIN", "0")
    monkeypatch.setenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0")

    import atom.plugin.vllm.platform as platform_module

    # Reload so both modules re-read the environment flags set above.
    importlib.reload(platform_module)
    importlib.reload(vllm_register)

    # Resolve the returned "package.module.Class" path the way vLLM would.
    platform_path = vllm_register.register_platform()
    module_name, class_name = platform_path.rsplit(".", 1)
    vllm_platforms.current_platform = getattr(
        importlib.import_module(module_name), class_name
    )

    # get current platform from vllm side and validate it is ATOM platform.
    assert vllm_platforms.current_platform is platform_module.ATOMPlatform
53+
54+
55+
def test_register_platform_can_be_disabled(monkeypatch):
    """When the disable flag is set, registration is a no-op that returns None."""
    monkeypatch.setattr(vllm_register, "disable_vllm_plugin", True, raising=False)
    result = vllm_register.register_platform()
    assert result is None
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import importlib.util
2+
import importlib
3+
import sys
4+
import types
5+
6+
import pytest
7+
8+
9+
# FIXME: remove it later when enabling fallback for unsupported models
@pytest.mark.skipif(
    importlib.util.find_spec("vllm") is None,
    reason="vllm is not installed in current test environment",
)
def test_vllm_wrapper_rejects_unsupported_model_arch(monkeypatch):
    """An unknown architecture name must raise instead of silently falling back."""
    # Stub out the heavy model-loader module so importing the wrapper stays
    # cheap and does not pull deep dependencies at collection/import time.
    fake_loader = types.ModuleType("atom.model_loader.loader")
    fake_loader.load_model_in_plugin_mode = lambda **kwargs: set()
    monkeypatch.setitem(sys.modules, "atom.model_loader.loader", fake_loader)

    model_wrapper = importlib.import_module("atom.plugin.vllm.model_wrapper")

    unsupported_arch = "UnknownModelForCausalLM"
    with pytest.raises(ValueError, match="not supported by ATOM OOT backend"):
        model_wrapper._get_atom_model_cls(unsupported_arch)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import importlib.util
2+
3+
import pytest
4+
5+
6+
@pytest.mark.skipif(
    importlib.util.find_spec("vllm") is None,
    reason="vllm is not installed in current test environment",
)
def test_vllm_import_paths_guardrail():
    """Guardrail for OOT vLLM import paths used by ATOM plugin mode.

    Imports every vLLM symbol the ATOM plugin depends on, so an upstream
    rename/move surfaces here as an ImportError instead of a deep runtime
    failure inside the plugin.
    """
    # attention.py / paged_attention.py (new path with legacy fallback)
    try:
        from vllm.attention.layer import Attention, MLAAttention, AttentionType
    except ImportError:
        from vllm.model_executor.layers.attention import Attention, MLAAttention
        from vllm.v1.attention.backend import AttentionType

    # attention.py
    from vllm.config import (
        VllmConfig,
        get_current_vllm_config,
        get_layers_from_vllm_config,
    )
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonMetadataBuilder,
        QueryLenSupport,
    )
    from vllm.utils.math_utils import cdiv, round_down
    from vllm.v1.attention.backend import AttentionCGSupport, AttentionMetadataBuilder
    from vllm.v1.attention.backends.utils import (
        get_dcp_local_seq_lens,
        split_decodes_and_prefills,
        split_decodes_prefills_and_extends,
    )
    from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
    from vllm.v1.attention.ops.merge_attn_states import merge_attn_states

    # model_wrapper.py (core vLLM model interfaces)
    from vllm.model_executor.models.interfaces import SupportsPP, SupportsQuant
    from vllm.model_executor.models.interfaces_base import (
        VllmModel,
        VllmModelForTextGeneration,
    )
    from vllm.model_executor.models.registry import ModelRegistry
    from vllm.sequence import IntermediateTensors

    # attention_mla.py / platform.py / register.py
    from vllm import _custom_ops
    from vllm.distributed.parallel_state import get_dcp_group
    from vllm.platforms import current_platform
    from vllm.platforms.rocm import RocmPlatform

    # Map every symbol to its name so a failure reports *which* one is None;
    # a bare `assert all(...)` over an anonymous list hides the culprit.
    imported = {
        "Attention": Attention,
        "MLAAttention": MLAAttention,
        "AttentionType": AttentionType,
        "QueryLenSupport": QueryLenSupport,
        "MLACommonMetadataBuilder": MLACommonMetadataBuilder,
        "cdiv": cdiv,
        "round_down": round_down,
        "AttentionCGSupport": AttentionCGSupport,
        "AttentionMetadataBuilder": AttentionMetadataBuilder,
        "get_dcp_local_seq_lens": get_dcp_local_seq_lens,
        "split_decodes_and_prefills": split_decodes_and_prefills,
        "split_decodes_prefills_and_extends": split_decodes_prefills_and_extends,
        "cp_lse_ag_out_rs": cp_lse_ag_out_rs,
        "merge_attn_states": merge_attn_states,
        "VllmConfig": VllmConfig,
        "get_current_vllm_config": get_current_vllm_config,
        "get_layers_from_vllm_config": get_layers_from_vllm_config,
        "SupportsPP": SupportsPP,
        "SupportsQuant": SupportsQuant,
        "VllmModel": VllmModel,
        "VllmModelForTextGeneration": VllmModelForTextGeneration,
        "ModelRegistry": ModelRegistry,
        "IntermediateTensors": IntermediateTensors,
        "_custom_ops": _custom_ops,
        "get_dcp_group": get_dcp_group,
        "current_platform": current_platform,
        "RocmPlatform": RocmPlatform,
    }
    missing = [name for name, obj in imported.items() if obj is None]
    assert not missing, f"unexpectedly None after import: {missing}"

0 commit comments

Comments
 (0)