-
Notifications
You must be signed in to change notification settings - Fork 76
Expand file tree
/
Copy pathtest.py
More file actions
228 lines (168 loc) · 8.51 KB
/
test.py
File metadata and controls
228 lines (168 loc) · 8.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
def _make_platform_config(hf_config):
    """Build a minimal ``SimpleNamespace`` stand-in for a vLLM config.

    Args:
        hf_config: Object placed at ``model_config.hf_config``; the tests pass
            a ``SimpleNamespace`` describing the HF model config.

    Returns:
        A ``SimpleNamespace`` with ``parallel_config``, ``model_config``,
        ``cache_config``, ``speculative_config`` and ``compilation_config``
        populated with fixed test values (CUDA graphs off, eager disabled,
        fusion enabled).
    """
    from vllm.config import CUDAGraphMode

    parallel = SimpleNamespace(worker_cls="manual", data_parallel_size=1)
    model = SimpleNamespace(
        use_mla=False,
        hf_config=hf_config,
        enforce_eager=False,
    )
    compilation = SimpleNamespace(
        cudagraph_mode=CUDAGraphMode.NONE,
        pass_config=SimpleNamespace(enable_fusion=True),
        backend=None,
        # Fresh list per call so tests cannot leak custom_ops into each other.
        custom_ops=[],
    )
    return SimpleNamespace(
        parallel_config=parallel,
        model_config=model,
        cache_config=SimpleNamespace(block_size=16),
        speculative_config=None,
        compilation_config=compilation,
    )
def _check_platform_config(vllm_config):
    """Run ``KunlunPlatform.check_and_update_config`` on *vllm_config*.

    ``vllm.envs.VLLM_ALL2ALL_BACKEND`` is temporarily forced to ``None``
    (created on the module if it does not exist) for the duration of the call,
    then restored.
    """
    import vllm.envs as envs
    from vllm_kunlun.platforms.kunlun import KunlunPlatform

    all2all_override = patch.object(
        envs, "VLLM_ALL2ALL_BACKEND", None, create=True
    )
    with all2all_override:
        KunlunPlatform.check_and_update_config(vllm_config)
def test_qwen3_vl_text_config_inherits_top_level_tie_word_embeddings():
    """A Qwen3-VL ``text_config`` without ``tie_word_embeddings`` copies the top level.

    Both ``False`` and ``True`` are exercised. Checking only ``False`` could
    not distinguish real inheritance from a hard-coded ``False`` default.
    Each iteration uses a fresh ``text_config`` so the cases are independent.
    """
    for top_level in (False, True):
        text_config = SimpleNamespace()
        hf_config = SimpleNamespace(
            architectures=["Qwen3VLForConditionalGeneration"],
            text_config=text_config,
            tie_word_embeddings=top_level,
        )
        _check_platform_config(_make_platform_config(hf_config))
        assert text_config.tie_word_embeddings is top_level, (
            f"top-level tie_word_embeddings={top_level} was not inherited"
        )
def test_qwen3_vl_text_config_existing_tie_word_embeddings_is_preserved():
    """An explicit ``text_config.tie_word_embeddings`` survives a differing top-level value."""
    nested = SimpleNamespace(tie_word_embeddings=True)
    vllm_config = _make_platform_config(
        SimpleNamespace(
            architectures=["Qwen3VLForConditionalGeneration"],
            text_config=nested,
            tie_word_embeddings=False,
        )
    )

    _check_platform_config(vllm_config)

    assert nested.tie_word_embeddings is True
def test_non_qwen3_vl_text_config_is_not_modified():
    """Non-Qwen3-VL architectures do not get ``tie_word_embeddings`` added to ``text_config``."""
    nested = SimpleNamespace()
    hf = SimpleNamespace(
        architectures=["Qwen3ForCausalLM"],
        text_config=nested,
        tie_word_embeddings=False,
    )

    _check_platform_config(_make_platform_config(hf))

    assert not hasattr(nested, "tie_word_embeddings")
def test_import():
    """The Kunlun compile wrapper class can be imported from its module."""
    from vllm_kunlun.compilation.wrapper import (
        TorchCompileWrapperWithCustomDispatcher as wrapper_cls,
    )

    assert wrapper_cls is not None
def test_basic_instantiation():
    """Constructing a concrete subclass sets up the wrapper's bookkeeping attributes.

    ``torch.compile`` is stubbed to return its input unchanged, and Dynamo's
    bytecode-hook registration is mocked out, so no real compilation happens.
    """
    from vllm_kunlun.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher

    class DoublingWrapper(TorchCompileWrapperWithCustomDispatcher):
        # Minimal concrete subclass: the wrapper needs a ``forward`` to wrap.
        def forward(self, x):
            return x * 2

    fake_config = MagicMock()
    fake_config.compilation_config.init_backend.return_value = "eager"
    fake_config.compilation_config.inductor_compile_config = None

    with patch("vllm.config.get_current_vllm_config", return_value=fake_config), \
            patch("vllm.config.CompilationLevel") as fake_level, \
            patch("torch.compile", side_effect=lambda func, **kwargs: func), \
            patch("torch._dynamo.convert_frame.register_bytecode_hook"):
        fake_level.DYNAMO_ONCE = 1
        wrapper = DoublingWrapper(compilation_level=0)

    expected_attrs = (
        "vllm_config",
        "compiled_callable",
        "original_code_object",
        "compiled_codes",
    )
    for attr in expected_attrs:
        assert hasattr(wrapper, attr), attr
    assert isinstance(wrapper.compiled_codes, list)
def test_forward_call():
    """Calling the wrapper yields ``forward``'s result when compilation is a no-op.

    With ``torch.compile`` stubbed to return its input unchanged, the compiled
    callable is ``forward`` itself, so ``wrapper(x)`` should equal ``x * 2``.
    """
    from vllm_kunlun.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher

    class DoublingWrapper(TorchCompileWrapperWithCustomDispatcher):
        def forward(self, x):
            return x * 2

    fake_config = MagicMock()
    fake_config.compilation_config.init_backend.return_value = "eager"
    fake_config.compilation_config.inductor_compile_config = None

    with patch("vllm.config.get_current_vllm_config", return_value=fake_config), \
            patch("vllm.config.CompilationLevel") as fake_level, \
            patch("torch.compile", side_effect=lambda func, **kwargs: func), \
            patch("torch._dynamo.convert_frame.register_bytecode_hook"):
        fake_level.DYNAMO_ONCE = 1
        wrapper = DoublingWrapper(compilation_level=0)

        sample = torch.tensor([1.0, 2.0, 3.0])
        assert torch.allclose(wrapper(sample), sample * 2)
def test_custom_callable():
    """A caller-supplied ``compiled_callable`` is stored as-is and used for calls.

    Besides checking that the callable is stored, the test pins three things
    about the call:

    * the callable is invoked exactly once;
    * the input tensor is forwarded to it;
    * its result is what the wrapper returns.

    Previously the return value was discarded behind ``# noqa`` and only
    ``.called`` was checked.
    """
    from vllm_kunlun.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher

    class TestWrapper(TorchCompileWrapperWithCustomDispatcher):
        def forward(self, x):
            return x * 2

    custom_func = Mock(return_value=torch.tensor([5.0]))
    mock_config = MagicMock()
    mock_config.compilation_config.init_backend.return_value = "eager"
    # torch.compile is left unpatched: with an explicit compiled_callable the
    # wrapper should not need to compile ``forward`` itself.
    with patch("vllm.config.get_current_vllm_config", return_value=mock_config):
        with patch("vllm.config.CompilationLevel") as mock_level:
            mock_level.DYNAMO_ONCE = 1
            with patch("torch._dynamo.convert_frame.register_bytecode_hook"):
                wrapper = TestWrapper(
                    compiled_callable=custom_func, compilation_level=0
                )

                # The supplied callable is stored without being wrapped.
                assert wrapper.compiled_callable is custom_func

                input_tensor = torch.tensor([1.0])
                result = wrapper(input_tensor)

    # Exactly one dispatch, with the input forwarded and its result returned.
    custom_func.assert_called_once()
    passed_args, _ = custom_func.call_args
    assert torch.equal(passed_args[0], input_tensor)
    assert result is custom_func.return_value
def test_bytecode_hook_basic():
    """``bytecode_hook`` records nothing for code objects that are not the wrapped ``forward``."""
    from types import CodeType

    from vllm_kunlun.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher

    class DoublingWrapper(TorchCompileWrapperWithCustomDispatcher):
        def forward(self, x):
            return x * 2

    fake_config = MagicMock()
    compilation_config = fake_config.compilation_config
    compilation_config.init_backend.return_value = "eager"
    compilation_config.inductor_compile_config = None
    compilation_config.local_cache_dir = None

    with patch("vllm.config.get_current_vllm_config", return_value=fake_config), \
            patch("vllm.config.CompilationLevel") as fake_level, \
            patch("torch.compile", side_effect=lambda func, **kwargs: func), \
            patch("torch._dynamo.convert_frame.register_bytecode_hook"):
        fake_level.DYNAMO_ONCE = 1
        wrapper = DoublingWrapper(compilation_level=0)

        # Neither mock is the wrapped ``forward`` code object, so the hook is
        # expected to ignore the pair and leave compiled_codes unchanged.
        foreign_old = MagicMock(spec=CodeType)
        foreign_new = MagicMock(spec=CodeType)
        count_before = len(wrapper.compiled_codes)
        wrapper.bytecode_hook(foreign_old, foreign_new)
        assert len(wrapper.compiled_codes) == count_before
def test_use_custom_dispatcher_flag():
    """``use_custom_dispatcher`` is False at level 0 and True at level 2."""
    from vllm_kunlun.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher

    class DoublingWrapper(TorchCompileWrapperWithCustomDispatcher):
        def forward(self, x):
            return x * 2

    fake_config = MagicMock()
    fake_config.compilation_config.init_backend.return_value = "eager"
    fake_config.compilation_config.inductor_compile_config = None

    # NOTE(review): these ``vllm.config.*`` patches only take effect if the
    # wrapper module looks the names up through ``vllm.config`` at call time.
    # If it does ``from vllm.config import ...`` at import time, the real
    # objects are used instead. Confirm the intended patch target.
    with patch("vllm.config.get_current_vllm_config", return_value=fake_config), \
            patch("vllm.config.CompilationLevel") as fake_level, \
            patch("torch.compile", side_effect=lambda func, **kwargs: func), \
            patch("torch._dynamo.convert_frame.register_bytecode_hook"):
        fake_level.DYNAMO_ONCE = 1

        # Level 0 sits below the (mocked) DYNAMO_ONCE threshold.
        below_threshold = DoublingWrapper(compilation_level=0)
        assert below_threshold.use_custom_dispatcher is False

        # Level 2 sits above it.
        above_threshold = DoublingWrapper(compilation_level=2)
        assert above_threshold.use_custom_dispatcher is True
if __name__ == "__main__":
    # Propagate pytest's exit status so that `python test.py` exits non-zero
    # when any test fails; previously the return value was discarded and the
    # script always exited 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))