23 changes: 23 additions & 0 deletions tests/entrypoints/openai/test_cli_args.py
@@ -247,6 +247,29 @@ def test_default_chat_template_kwargs_default_none(serve_parser):
assert args.default_chat_template_kwargs is None


def test_enable_moe_topk_indices_nemo_rl_block_store(serve_parser):
args = serve_parser.parse_args(
args=["--enable-moe-topk-indices-nemo-rl-block-store"]
)
assert args.enable_moe_topk_indices_nemo_rl_block_store is True


def test_enable_moe_topk_indices_json(serve_parser):
args = serve_parser.parse_args(args=["--enable-moe-topk-indices-json"])
assert args.enable_moe_topk_indices_json is True


def test_moe_topk_indices_output_modes_are_mutually_exclusive(serve_parser):
args = serve_parser.parse_args(
args=[
"--enable-moe-topk-indices-nemo-rl-block-store",
"--enable-moe-topk-indices-json",
]
)
with pytest.raises(TypeError):
validate_parsed_serve_args(args)


def test_default_chat_template_kwargs_invalid_json(serve_parser):
"""Ensure invalid JSON raises an error"""
with pytest.raises(SystemExit):
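For reference, a minimal sketch of the mutual-exclusion check the last test above exercises. The actual validate_parsed_serve_args change is not part of this excerpt, so the helper name and error wording here are assumptions that only mirror the asserted TypeError:

def _check_moe_topk_indices_flags(args) -> None:
    # Hypothetical helper: reject both output modes at once, as the test expects.
    if (args.enable_moe_topk_indices_nemo_rl_block_store
            and args.enable_moe_topk_indices_json):
        raise TypeError(
            "--enable-moe-topk-indices-nemo-rl-block-store and "
            "--enable-moe-topk-indices-json are mutually exclusive")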
288 changes: 288 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -6,12 +6,14 @@
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import numpy as np
import pytest
import pytest_asyncio
from openai import OpenAI

from vllm._aiter_ops import is_aiter_found_and_supported
from vllm.config import MultiModalConfig
from vllm.entrypoints.openai.chat_completion import serving as chat_serving_module
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -527,6 +529,8 @@ class MockModelConfig:
encoder_config = None
generation_config: str = "auto"
override_generation_config: dict[str, Any] = field(default_factory=dict)
enable_moe_topk_indices_nemo_rl_block_store: bool = False
enable_moe_topk_indices_json: bool = False
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@@ -1042,6 +1046,290 @@ async def mock_generate(*args, **kwargs):
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None


@pytest.mark.asyncio
async def test_serving_chat_returns_moe_topk_indices():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig(enable_moe_topk_indices_json=True)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)

serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "what is 1+1?"}],
)

async def result_generator():
yield RequestOutput(
request_id=req.request_id,
prompt="what is 1+1?",
prompt_token_ids=[1, 2],
prompt_logprobs=None,
outputs=[
CompletionOutput(
index=0,
text="2",
token_ids=[3],
cumulative_logprob=0.0,
logprobs=None,
routed_experts=np.array([[[7, 8]]], dtype=np.int16),
finish_reason="stop",
)
],
finished=True,
prompt_routed_experts=np.array([[[1, 2]], [[3, 4]]], dtype=np.int16),
)

response = await serving_chat.chat_completion_full_generator(
request=req,
result_generator=result_generator(),
request_id="chatcmpl-test",
model_name=MODEL_NAME,
conversation=[],
tokenizer=MagicMock(),
request_metadata=RequestResponseMetadata(request_id=req.request_id),
)

assert isinstance(response, ChatCompletionResponse)
assert response.prompt_moe_topk_indices == [[[1, 2]], [[3, 4]]]
assert response.choices[0].moe_topk_indices == [[[7, 8]]]

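With --enable-moe-topk-indices-json, the indices come back inline, so a client can read them straight off the JSON body. A hedged client-side sketch (the endpoint, port, and model name are assumptions; the field names match the assertions above):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed local server
    json={
        "model": "my-moe-model",  # hypothetical model name
        "messages": [{"role": "user", "content": "what is 1+1?"}],
    },
).json()
prompt_indices = resp.get("prompt_moe_topk_indices")  # nested lists, or absent
choice_indices = resp["choices"][0].get("moe_topk_indices")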

@pytest.mark.asyncio
async def test_serving_chat_omits_moe_topk_indices_without_output_flags():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)

serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "what is 1+1?"}],
)

async def result_generator():
yield RequestOutput(
request_id=req.request_id,
prompt="what is 1+1?",
prompt_token_ids=[1, 2],
prompt_logprobs=None,
outputs=[
CompletionOutput(
index=0,
text="2",
token_ids=[3],
cumulative_logprob=0.0,
logprobs=None,
routed_experts=np.array([[[7, 8]]], dtype=np.int16),
finish_reason="stop",
)
],
finished=True,
prompt_routed_experts=np.array([[[1, 2]], [[3, 4]]], dtype=np.int16),
)

response = await serving_chat.chat_completion_full_generator(
request=req,
result_generator=result_generator(),
request_id="chatcmpl-test",
model_name=MODEL_NAME,
conversation=[],
tokenizer=MagicMock(),
request_metadata=RequestResponseMetadata(request_id=req.request_id),
)

assert isinstance(response, ChatCompletionResponse)
assert response.prompt_moe_topk_indices is None
assert response.choices[0].moe_topk_indices is None
response_dict = response.model_dump()
assert "prompt_moe_topk_indices" not in response_dict
assert "moe_topk_indices" not in response_dict["choices"][0]


@pytest.mark.asyncio
async def test_serving_chat_stores_moe_topk_indices_in_block_store(monkeypatch):
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig(
enable_moe_topk_indices_nemo_rl_block_store=True
)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)

serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "what is 1+1?"}],
)

class FakePutNumpy:
def __init__(self):
self.args = None

def remote(self, *args):
self.args = args
return "put-ref"

fake_put_numpy = FakePutNumpy()
serving_chat.block_store_instance_rank = 7
serving_chat.block_store_instance_id = "nemo_rl.block_store.node.127.0.0.1"
serving_chat.block_store_instance = MagicMock()
serving_chat.block_store_instance.put_numpy = fake_put_numpy

ray = pytest.importorskip("ray")

monkeypatch.setattr(ray, "get", lambda ref: ref)

async def result_generator():
yield RequestOutput(
request_id=req.request_id,
prompt="what is 1+1?",
prompt_token_ids=[1, 2],
prompt_logprobs=None,
outputs=[
CompletionOutput(
index=0,
text="2",
token_ids=[3],
cumulative_logprob=0.0,
logprobs=None,
routed_experts=np.array([[[7, 8]]], dtype=np.int16),
finish_reason="stop",
)
],
finished=True,
prompt_routed_experts=np.array([[[1, 2]], [[3, 4]]], dtype=np.int16),
)

response = await serving_chat.chat_completion_full_generator(
request=req,
result_generator=result_generator(),
request_id="chatcmpl-test",
model_name=MODEL_NAME,
conversation=[],
tokenizer=MagicMock(),
request_metadata=RequestResponseMetadata(request_id=req.request_id),
)

assert isinstance(response, ChatCompletionResponse)
block_cache_key = {
"instance_rank": 7,
"instance_id": "nemo_rl.block_store.node.127.0.0.1",
"req_id": "test",
"key": "moe_topk_indices",
}
assert response.prompt_moe_topk_indices == {"block_cache_key": block_cache_key}
assert response.choices[0].moe_topk_indices == {
"block_cache_key": block_cache_key
}
assert fake_put_numpy.args is not None
assert fake_put_numpy.args[0] == "test"
assert fake_put_numpy.args[1] == "moe_topk_indices"
np.testing.assert_array_equal(
fake_put_numpy.args[2],
np.array([[[1, 2]], [[3, 4]], [[7, 8]]], dtype=np.int16),
)

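In block-store mode the response carries only a block_cache_key descriptor, while the prompt and decode indices are concatenated into one array and uploaded once per request (as the put_numpy assertion above shows). A hedged consumer-side sketch of fetching them back; get_numpy is a hypothetical counterpart to the put_numpy mocked above, not a confirmed API:

import ray

# response_json: a parsed chat completion body from a server launched with
# --enable-moe-topk-indices-nemo-rl-block-store.
key = response_json["prompt_moe_topk_indices"]["block_cache_key"]
store = ray.get_actor(key["instance_id"])  # e.g. nemo_rl.block_store.node.<ip>
indices = ray.get(store.get_numpy.remote(key["req_id"], key["key"]))  # assumed API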

def test_block_store_cache_key_loads_instance_rank(monkeypatch):
ray = pytest.importorskip("ray")

mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig(
enable_moe_topk_indices_nemo_rl_block_store=True
)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)

serving_chat = _build_serving_chat(mock_engine)

class FakeGetRuntimeMetadata:
def remote(self):
return "metadata-ref"

fake_actor = MagicMock()
fake_actor.get_runtime_metadata = FakeGetRuntimeMetadata()

monkeypatch.setattr(
ray._private.services,
"get_node_ip_address",
lambda: "127.0.0.1",
)
monkeypatch.setattr(ray, "get_actor", lambda instance_id: fake_actor)
monkeypatch.setattr(ray, "get", lambda ref: {"instance_rank": 5})

block_cache_key = serving_chat._get_moe_topk_indices_block_cache_key(
"chatcmpl-test"
)

assert block_cache_key == {
"instance_rank": 5,
"instance_id": "nemo_rl.block_store.node.127.0.0.1",
"req_id": "test",
"key": "moe_topk_indices",
}
assert serving_chat.block_store_instance is fake_actor
assert serving_chat.block_store_instance_rank == 5

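A hedged reconstruction of the helper this test drives; the real implementation lives in the (unshown) serving module, so everything beyond the monkeypatched calls and the asserted dict is an assumption:

import ray

def _get_moe_topk_indices_block_cache_key(self, request_id: str) -> dict:
    if self.block_store_instance is None:  # resolve and cache the actor once
        node_ip = ray._private.services.get_node_ip_address()
        instance_id = f"nemo_rl.block_store.node.{node_ip}"  # naming per the test
        self.block_store_instance = ray.get_actor(instance_id)
        metadata = ray.get(
            self.block_store_instance.get_runtime_metadata.remote())
        self.block_store_instance_rank = metadata["instance_rank"]
        self.block_store_instance_id = instance_id
    return {
        "instance_rank": self.block_store_instance_rank,
        "instance_id": self.block_store_instance_id,
        "req_id": request_id.removeprefix("chatcmpl-"),
        "key": "moe_topk_indices",
    }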

def test_normalize_moe_topk_indices_warns_on_nested_lists(monkeypatch):
warnings = []
monkeypatch.setattr(
chat_serving_module.logger,
"warning_once",
lambda message, *args: warnings.append(message % args),
)

routed_experts = [[[1, 2]], [[3, 4]]]
normalized = OpenAIServingChat._normalize_moe_topk_indices_array(
routed_experts,
expected_len=2,
field_name="prompt_routed_experts",
)

np.testing.assert_array_equal(
normalized,
np.array([[[1, 2]], [[3, 4]]], dtype=np.int16),
)
assert warnings == [
"prompt_routed_experts uses nested Python lists for MoE top-k indices; "
"list[np.ndarray] is preferred."
]


def test_normalize_moe_topk_indices_warns_on_other_types(monkeypatch):
warnings = []
monkeypatch.setattr(
chat_serving_module.logger,
"warning_once",
lambda message, *args: warnings.append(message % args),
)

routed_experts = (((1, 2),),)
normalized = OpenAIServingChat._normalize_moe_topk_indices_array(
routed_experts,
expected_len=1,
field_name="routed_experts",
)

np.testing.assert_array_equal(
normalized,
np.array([[[1, 2]]], dtype=np.int16),
)
assert warnings == [
"routed_experts has MoE top-k indices type tuple; expected np.ndarray, "
"list[np.ndarray], or list[list]. Attempting numpy conversion."
]

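The two tests above pin the exact warning texts, which suggests a helper roughly like this hedged reconstruction (the dtype handling and the use of expected_len are guesses):

import numpy as np
from vllm.logger import init_logger

logger = init_logger(__name__)

def _normalize_moe_topk_indices_array(routed_experts, expected_len, field_name):
    if isinstance(routed_experts, np.ndarray):
        return routed_experts
    if isinstance(routed_experts, list):
        if routed_experts and isinstance(routed_experts[0], np.ndarray):
            return np.stack(routed_experts)  # preferred list[np.ndarray] path
        logger.warning_once(
            "%s uses nested Python lists for MoE top-k indices; "
            "list[np.ndarray] is preferred.", field_name)
    else:
        logger.warning_once(
            "%s has MoE top-k indices type %s; expected np.ndarray, "
            "list[np.ndarray], or list[list]. Attempting numpy conversion.",
            field_name, type(routed_experts).__name__)
    array = np.asarray(routed_experts)
    assert len(array) == expected_len  # assumed consistency check
    return array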

class TestServingChatWithHarmony:
"""
These tests ensure Chat Completion requests are being properly converted into
6 changes: 6 additions & 0 deletions vllm/config/model.py
@@ -210,6 +210,12 @@ class ModelConfig:
Processed means the values after applying all processors, including
temperature and top_k/top_p.
"""
enable_moe_topk_indices_nemo_rl_block_store: bool = False
"""If True, return MoE top-k indices via a block cache key payload and
upload the data to the node-local NeMo RL block store."""
enable_moe_topk_indices_json: bool = False
"""If True, return MoE top-k indices inline in chat completion responses
as nested JSON lists."""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
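Taken together, the two ModelConfig flags select between the two response shapes exercised in the tests above. Purely illustrative Python literals, with values copied from the test assertions:

# --enable-moe-topk-indices-json: indices returned inline as nested lists.
inline_response = {
    "prompt_moe_topk_indices": [[[1, 2]], [[3, 4]]],
    "choices": [{"moe_topk_indices": [[[7, 8]]]}],
}

# --enable-moe-topk-indices-nemo-rl-block-store: only a lookup key is returned;
# the arrays themselves live in the node-local NeMo RL block store.
block_store_response = {
    "prompt_moe_topk_indices": {
        "block_cache_key": {
            "instance_rank": 7,
            "instance_id": "nemo_rl.block_store.node.127.0.0.1",
            "req_id": "test",
            "key": "moe_topk_indices",
        }
    },
}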