diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py
index ccf145a0c65e..c434717133bc 100644
--- a/tests/entrypoints/openai/test_cli_args.py
+++ b/tests/entrypoints/openai/test_cli_args.py
@@ -247,6 +247,29 @@ def test_default_chat_template_kwargs_default_none(serve_parser):
     assert args.default_chat_template_kwargs is None
 
 
+def test_enable_moe_topk_indices_nemo_rl_block_store(serve_parser):
+    args = serve_parser.parse_args(
+        args=["--enable-moe-topk-indices-nemo-rl-block-store"]
+    )
+    assert args.enable_moe_topk_indices_nemo_rl_block_store is True
+
+
+def test_enable_moe_topk_indices_json(serve_parser):
+    args = serve_parser.parse_args(args=["--enable-moe-topk-indices-json"])
+    assert args.enable_moe_topk_indices_json is True
+
+
+def test_moe_topk_indices_output_modes_are_mutually_exclusive(serve_parser):
+    args = serve_parser.parse_args(
+        args=[
+            "--enable-moe-topk-indices-nemo-rl-block-store",
+            "--enable-moe-topk-indices-json",
+        ]
+    )
+    with pytest.raises(TypeError):
+        validate_parsed_serve_args(args)
+
+
 def test_default_chat_template_kwargs_invalid_json(serve_parser):
     """Ensure invalid JSON raises an error"""
     with pytest.raises(SystemExit):
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 33c69578ce93..c42eae3798aa 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -6,12 +6,14 @@
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock
 
+import numpy as np
 import pytest
 import pytest_asyncio
 from openai import OpenAI
 
 from vllm._aiter_ops import is_aiter_found_and_supported
 from vllm.config import MultiModalConfig
+from vllm.entrypoints.openai.chat_completion import serving as chat_serving_module
 from vllm.entrypoints.openai.chat_completion.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -527,6 +529,8 @@ class MockModelConfig:
     encoder_config = None
     generation_config: str = "auto"
     override_generation_config: dict[str, Any] = field(default_factory=dict)
+    enable_moe_topk_indices_nemo_rl_block_store: bool = False
+    enable_moe_topk_indices_json: bool = False
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
     skip_tokenizer_init: bool = False
     is_encoder_decoder: bool = False
@@ -1042,6 +1046,290 @@ async def mock_generate(*args, **kwargs):
     assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
 
 
+@pytest.mark.asyncio
+async def test_serving_chat_returns_moe_topk_indices():
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.errored = False
+    mock_engine.model_config = MockModelConfig(enable_moe_topk_indices_json=True)
+    mock_engine.input_processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)
+
+    serving_chat = _build_serving_chat(mock_engine)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 1+1?"}],
+    )
+
+    async def result_generator():
+        yield RequestOutput(
+            request_id=req.request_id,
+            prompt="what is 1+1?",
+            prompt_token_ids=[1, 2],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="2",
+                    token_ids=[3],
+                    cumulative_logprob=0.0,
+                    logprobs=None,
+                    routed_experts=np.array([[[7, 8]]], dtype=np.int16),
+                    finish_reason="stop",
+                )
+            ],
+            finished=True,
+            prompt_routed_experts=np.array([[[1, 2]], [[3, 4]]], dtype=np.int16),
+        )
+
+    response = await serving_chat.chat_completion_full_generator(
+        request=req,
+        result_generator=result_generator(),
+        request_id="chatcmpl-test",
+        model_name=MODEL_NAME,
+        conversation=[],
+        tokenizer=MagicMock(),
+        request_metadata=RequestResponseMetadata(request_id=req.request_id),
+    )
+
+    assert isinstance(response, ChatCompletionResponse)
+    assert response.prompt_moe_topk_indices == [[[1, 2]], [[3, 4]]]
+    assert response.choices[0].moe_topk_indices == [[[7, 8]]]
+
+
+@pytest.mark.asyncio
+async def test_serving_chat_omits_moe_topk_indices_without_output_flags():
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.errored = False
+    mock_engine.model_config = MockModelConfig()
+    mock_engine.input_processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)
+
+    serving_chat = _build_serving_chat(mock_engine)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 1+1?"}],
+    )
+
+    async def result_generator():
+        yield RequestOutput(
+            request_id=req.request_id,
+            prompt="what is 1+1?",
+            prompt_token_ids=[1, 2],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="2",
+                    token_ids=[3],
+                    cumulative_logprob=0.0,
+                    logprobs=None,
+                    routed_experts=np.array([[[7, 8]]], dtype=np.int16),
+                    finish_reason="stop",
+                )
+            ],
+            finished=True,
+            prompt_routed_experts=np.array([[[1, 2]], [[3, 4]]], dtype=np.int16),
+        )
+
+    response = await serving_chat.chat_completion_full_generator(
+        request=req,
+        result_generator=result_generator(),
+        request_id="chatcmpl-test",
+        model_name=MODEL_NAME,
+        conversation=[],
+        tokenizer=MagicMock(),
+        request_metadata=RequestResponseMetadata(request_id=req.request_id),
+    )
+
+    assert isinstance(response, ChatCompletionResponse)
+    assert response.prompt_moe_topk_indices is None
+    assert response.choices[0].moe_topk_indices is None
+    response_dict = response.model_dump()
+    assert "prompt_moe_topk_indices" not in response_dict
+    assert "moe_topk_indices" not in response_dict["choices"][0]
+
+
+@pytest.mark.asyncio
+async def test_serving_chat_stores_moe_topk_indices_in_block_store(monkeypatch):
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.errored = False
+    mock_engine.model_config = MockModelConfig(
+        enable_moe_topk_indices_nemo_rl_block_store=True
+    )
+    mock_engine.input_processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)
+
+    serving_chat = _build_serving_chat(mock_engine)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 1+1?"}],
+    )
+
+    class FakePutNumpy:
+        def __init__(self):
+            self.args = None
+
+        def remote(self, *args, **kwargs):
+            self.args = args
+            return "put-ref"
+
+    fake_put_numpy = FakePutNumpy()
+    serving_chat.block_store_instance_rank = 7
+    serving_chat.block_store_instance_id = "nemo_rl.block_store.node.127.0.0.1"
+    serving_chat.block_store_instance = MagicMock()
+    serving_chat.block_store_instance.put_numpy = fake_put_numpy
+
+    ray = pytest.importorskip("ray")
+
+    monkeypatch.setattr(ray, "get", lambda ref: ref)
+
+    async def result_generator():
+        yield RequestOutput(
+            request_id=req.request_id,
+            prompt="what is 1+1?",
+            prompt_token_ids=[1, 2],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="2",
+                    token_ids=[3],
+                    cumulative_logprob=0.0,
+                    logprobs=None,
+                    routed_experts=np.array([[[7, 8]]], dtype=np.int16),
+                    finish_reason="stop",
+                )
+            ],
+            finished=True,
+            prompt_routed_experts=np.array([[[1, 2]], [[3, 4]]], dtype=np.int16),
+        )
+
+    response = await serving_chat.chat_completion_full_generator(
+        request=req,
+        result_generator=result_generator(),
+        request_id="chatcmpl-test",
+        model_name=MODEL_NAME,
+        conversation=[],
+        tokenizer=MagicMock(),
+        request_metadata=RequestResponseMetadata(request_id=req.request_id),
+    )
+
+    assert isinstance(response, ChatCompletionResponse)
+    block_cache_key = {
+        "instance_rank": 7,
+        "instance_id": "nemo_rl.block_store.node.127.0.0.1",
+        "req_id": "test",
+        "key": "moe_topk_indices",
+    }
+    assert response.prompt_moe_topk_indices == {"block_cache_key": block_cache_key}
+    assert response.choices[0].moe_topk_indices == {
+        "block_cache_key": block_cache_key
+    }
+    assert fake_put_numpy.args is not None
+    assert fake_put_numpy.args[0] == "test"
+    assert fake_put_numpy.args[1] == "moe_topk_indices"
+    np.testing.assert_array_equal(
+        fake_put_numpy.args[2],
+        np.array([[[1, 2]], [[3, 4]], [[7, 8]]], dtype=np.int16),
+    )
+
+
+def test_block_store_cache_key_loads_instance_rank(monkeypatch):
+    ray = pytest.importorskip("ray")
+
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.errored = False
+    mock_engine.model_config = MockModelConfig(
+        enable_moe_topk_indices_nemo_rl_block_store=True
+    )
+    mock_engine.input_processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)
+
+    serving_chat = _build_serving_chat(mock_engine)
+
+    class FakeGetRuntimeMetadata:
+        def remote(self):
+            return "metadata-ref"
+
+    fake_actor = MagicMock()
+    fake_actor.get_runtime_metadata = FakeGetRuntimeMetadata()
+
+    monkeypatch.setattr(
+        ray._private.services,
+        "get_node_ip_address",
+        lambda: "127.0.0.1",
+    )
+    monkeypatch.setattr(ray, "get_actor", lambda instance_id: fake_actor)
+    monkeypatch.setattr(ray, "get", lambda ref: {"instance_rank": 5})
+
+    block_cache_key = serving_chat._get_moe_topk_indices_block_cache_key(
+        "chatcmpl-test"
+    )
+
+    assert block_cache_key == {
+        "instance_rank": 5,
+        "instance_id": "nemo_rl.block_store.node.127.0.0.1",
+        "req_id": "test",
+        "key": "moe_topk_indices",
+    }
+    assert serving_chat.block_store_instance is fake_actor
+    assert serving_chat.block_store_instance_rank == 5
+
+
+def test_normalize_moe_topk_indices_warns_on_nested_lists(monkeypatch):
+    warnings = []
+    monkeypatch.setattr(
+        chat_serving_module.logger,
+        "warning_once",
+        lambda message, *args: warnings.append(message % args),
+    )
+
+    routed_experts = [[[1, 2]], [[3, 4]]]
+    normalized = OpenAIServingChat._normalize_moe_topk_indices_array(
+        routed_experts,
+        expected_len=2,
+        field_name="prompt_routed_experts",
+    )
+
+    np.testing.assert_array_equal(
+        normalized,
+        np.array([[[1, 2]], [[3, 4]]], dtype=np.int16),
+    )
+    assert warnings == [
+        "prompt_routed_experts uses nested Python lists for MoE top-k indices; "
+        "list[np.ndarray] is preferred."
+    ]
+
+
+def test_normalize_moe_topk_indices_warns_on_other_types(monkeypatch):
+    warnings = []
+    monkeypatch.setattr(
+        chat_serving_module.logger,
+        "warning_once",
+        lambda message, *args: warnings.append(message % args),
+    )
+
+    routed_experts = (((1, 2),),)
+    normalized = OpenAIServingChat._normalize_moe_topk_indices_array(
+        routed_experts,
+        expected_len=1,
+        field_name="routed_experts",
+    )
+
+    np.testing.assert_array_equal(
+        normalized,
+        np.array([[[1, 2]]], dtype=np.int16),
+    )
+    assert warnings == [
+        "routed_experts has MoE top-k indices type tuple; expected np.ndarray, "
+        "list[np.ndarray], or list[list]. Attempting numpy conversion."
+    ]
+
+
 class TestServingChatWithHarmony:
     """
     These tests ensure Chat Completion requests are being properly converted into
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 7a669118ebc6..e97b084a6f9c 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -210,6 +210,12 @@ class ModelConfig:
     Processed means the values after applying all processors, including
     temperature and top_k/top_p.
     """
+    enable_moe_topk_indices_nemo_rl_block_store: bool = False
+    """If True, return MoE top-k indices via a block cache key payload and
+    upload the data to the node-local NeMo RL block store."""
+    enable_moe_topk_indices_json: bool = False
+    """If True, return MoE top-k indices inline in chat completion responses
+    as nested JSON lists."""
     disable_sliding_window: bool = False
     """Whether to disable sliding window. If True, we will disable the sliding
     window functionality of the model, capping to sliding window size. If the
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 8b88f9fdbf94..03a446e0f4d8 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -460,6 +460,10 @@ class EngineArgs:
     max_num_seqs: int | None = None
     max_logprobs: int = ModelConfig.max_logprobs
     logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
+    enable_moe_topk_indices_nemo_rl_block_store: bool = (
+        ModelConfig.enable_moe_topk_indices_nemo_rl_block_store
+    )
+    enable_moe_topk_indices_json: bool = ModelConfig.enable_moe_topk_indices_json
     disable_log_stats: bool = False
     aggregate_engine_logging: bool = False
     revision: str | None = ModelConfig.revision
@@ -696,6 +700,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
         model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
         model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
+        model_group.add_argument(
+            "--enable-moe-topk-indices-nemo-rl-block-store",
+            **model_kwargs["enable_moe_topk_indices_nemo_rl_block_store"],
+        )
+        model_group.add_argument(
+            "--enable-moe-topk-indices-json",
+            **model_kwargs["enable_moe_topk_indices_json"],
+        )
         model_group.add_argument(
             "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
         )
@@ -1317,6 +1329,15 @@ def create_model_config(self) -> ModelConfig:
         if is_gguf(self.model):
             self.quantization = self.load_format = "gguf"
 
+        if (
+            self.enable_moe_topk_indices_nemo_rl_block_store
+            and self.enable_moe_topk_indices_json
+        ):
+            raise ValueError(
+                "enable_moe_topk_indices_nemo_rl_block_store and "
+                "enable_moe_topk_indices_json are mutually exclusive"
+            )
+
         if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
             logger.warning(
                 "The global random seed is set to %d. Since "
Since " @@ -1350,6 +1371,10 @@ def create_model_config(self) -> ModelConfig: enforce_eager=self.enforce_eager, max_logprobs=self.max_logprobs, logprobs_mode=self.logprobs_mode, + enable_moe_topk_indices_nemo_rl_block_store=( + self.enable_moe_topk_indices_nemo_rl_block_store + ), + enable_moe_topk_indices_json=self.enable_moe_topk_indices_json, disable_sliding_window=self.disable_sliding_window, disable_cascade_attn=self.disable_cascade_attn, skip_tokenizer_init=self.skip_tokenizer_init, diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 0abe85ae8558..788d76ed9b85 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -12,6 +12,7 @@ ChatCompletionAudio as OpenAIChatCompletionAudio, ) from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation +from openai.types.shared import Metadata from pydantic import Field, model_validator from vllm.config import ModelConfig @@ -81,6 +82,10 @@ class ChatCompletionLogProbs(OpenAIBaseModel): content: list[ChatCompletionLogProbsContent] | None = None +MoETopKIndices = list[Any] | list[list[list[int]]] | None +MoETopKIndicesPayload = dict[str, Any] | MoETopKIndices + + class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage @@ -92,6 +97,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): # not part of the OpenAI spec but is useful for tracing the tokens # in agent scenarios token_ids: list[int] | None = None + moe_topk_indices: MoETopKIndicesPayload = Field( + default=None, exclude_if=lambda value: value is None + ) class ChatCompletionResponse(OpenAIBaseModel): @@ -110,6 +118,9 @@ class ChatCompletionResponse(OpenAIBaseModel): kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters." ) + prompt_moe_topk_indices: MoETopKIndicesPayload = Field( + default=None, exclude_if=lambda value: value is None + ) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): @@ -162,6 +173,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "the max_completion_tokens field", ) max_completion_tokens: int | None = None + metadata: Metadata | None = None n: int | None = 1 presence_penalty: float | None = 0.0 response_format: AnyResponseFormat | None = None diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 06b16cde6748..165b81ac400f 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -9,7 +9,9 @@ from typing import Any, Final import jinja2 +import numpy as np import partial_json_parser +import ray import regex as re from fastapi import Request from openai_harmony import Message as OpenAIMessage @@ -34,6 +36,8 @@ ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, + MoETopKIndices, + MoETopKIndicesPayload, ) from vllm.entrypoints.openai.chat_completion.stream_harmony import ( TokenState, @@ -174,6 +178,10 @@ def __init__( self.supports_code_interpreter = False self.python_tool = None + self.block_store_instance_rank: int | None = None + self.block_store_instance_id: str | None = None + self.block_store_instance = None + async def warmup(self) -> None: """ Warm up the chat template processing to avoid first-request latency. 
@@ -1419,6 +1427,41 @@ async def chat_completion_full_generator(
 
         assert final_res is not None
 
+        rl_metadata = None
+        if isinstance(request.metadata, dict):
+            rl_metadata = request.metadata.get("rl_metadata")
+        if isinstance(rl_metadata, str):
+            rl_metadata = json.loads(rl_metadata)
+        if rl_metadata is not None:
+            logger.debug(f"chat_completion_full_generator: rl metadata = {rl_metadata}")
+        else:
+            logger.debug("chat_completion_full_generator: no rl metadata")
+
+        block_store_enabled = (
+            self.model_config.enable_moe_topk_indices_nemo_rl_block_store
+        )
+        json_enabled = self.model_config.enable_moe_topk_indices_json
+        return_moe_topk_indices = block_store_enabled or json_enabled
+        has_routed_experts = (
+            final_res.prompt_routed_experts is not None
+            or any(output.routed_experts is not None for output in final_res.outputs)
+        )
+        if block_store_enabled and has_routed_experts and len(final_res.outputs) > 1:
+            return self.create_error_response(
+                "multiple response outputs not supported with "
+                "--enable-moe-topk-indices-nemo-rl-block-store"
+            )
+
+        prompt_moe_topk_indices = (
+            self._get_moe_topk_indices_payload(
+                request_id,
+                final_res.prompt_routed_experts,
+                rl_metadata=rl_metadata,
+            )
+            if return_moe_topk_indices
+            else None
+        )
+
         choices: list[ChatCompletionResponseChoice] = []
         if self.tool_call_id_type == "kimi_k2":
             history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
@@ -1433,6 +1476,15 @@ async def chat_completion_full_generator(
             token_ids = output.token_ids
             out_logprobs = output.logprobs
             tool_call_info = None
+            completion_moe_topk_indices = (
+                self._get_moe_topk_indices_payload(
+                    request_id,
+                    output.routed_experts,
+                    rl_metadata=rl_metadata,
+                )
+                if return_moe_topk_indices
+                else None
+            )
 
             if request.logprobs and request.top_logprobs is not None:
                 assert out_logprobs is not None, "Did not output logprobs"
@@ -1493,6 +1545,7 @@ async def chat_completion_full_generator(
                 token_ids=(
                     as_list(output.token_ids) if request.return_token_ids else None
                 ),
+                moe_topk_indices=completion_moe_topk_indices,
             )
             choices.append(choice_data)
             continue
@@ -1702,6 +1755,7 @@ async def chat_completion_full_generator(
                 token_ids=(
                     as_list(output.token_ids) if request.return_token_ids else None
                 ),
+                moe_topk_indices=completion_moe_topk_indices,
             )
 
             choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
@@ -1723,6 +1777,16 @@ async def chat_completion_full_generator(
             choice.message.content = full_message
 
         assert final_res.prompt_token_ids is not None
+        if block_store_enabled and final_res.outputs:
+            self._maybe_store_moe_topk_indices(
+                request_id=request_id,
+                prompt_routed_experts=final_res.prompt_routed_experts,
+                completion_routed_experts=final_res.outputs[0].routed_experts,
+                num_prompt_tokens=len(final_res.prompt_token_ids),
+                num_completion_tokens=len(final_res.outputs[0].token_ids),
+                rl_metadata=rl_metadata,
+            )
+
         num_prompt_tokens = len(final_res.prompt_token_ids)
         if final_res.encoder_prompt_token_ids is not None:
             num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
@@ -1752,6 +1816,7 @@ async def chat_completion_full_generator(
                 final_res.prompt_token_ids if request.return_token_ids else None
             ),
             kv_transfer_params=final_res.kv_transfer_params,
+            prompt_moe_topk_indices=prompt_moe_topk_indices,
         )
 
         # Log complete response if output logging is enabled
@@ -1788,6 +1853,166 @@ async def chat_completion_full_generator(
 
         return response
 
+    @staticmethod
+    def _base_chat_request_id(request_id: str) -> str:
+        if request_id.startswith("chatcmpl-") or request_id.startswith("chatcmpl_"):
request_id.startswith("chatcmpl_"): + return request_id[9:] + return request_id + + def _get_moe_topk_indices_payload( + self, + request_id: str, + routed_experts: Any, + rl_metadata: dict[str, Any] | None = None, + ) -> MoETopKIndicesPayload: + if self.model_config.enable_moe_topk_indices_nemo_rl_block_store: + return { + "block_cache_key": self._get_moe_topk_indices_block_cache_key( + request_id, + rl_metadata=rl_metadata, + ), + } + if self.model_config.enable_moe_topk_indices_json: + return self._format_moe_topk_indices(routed_experts) + return None + + def _get_moe_topk_indices_block_cache_key( + self, + request_id: str, + rl_metadata: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + if not self.model_config.enable_moe_topk_indices_nemo_rl_block_store: + return None + + if self.block_store_instance is None: + node_ip = ray._private.services.get_node_ip_address() + instance_id = f"nemo_rl.block_store.node.{node_ip}" + try: + self.block_store_instance = ray.get_actor(instance_id) + runtime_metadata = ray.get( + self.block_store_instance.get_runtime_metadata.remote() + ) + self.block_store_instance_rank = runtime_metadata["instance_rank"] + self.block_store_instance_id = instance_id + except ValueError: + logger.warning_once( + "MoE top-k block store actor %s is not available.", instance_id + ) + self.block_store_instance_rank = None + self.block_store_instance = None + self.block_store_instance_id = None + return None + + if self.block_store_instance_id is None: + return None + + req_id = self._base_chat_request_id(request_id) + return { + "instance_rank": self.block_store_instance_rank, + "instance_id": self.block_store_instance_id, + "req_id": req_id, + "key": "moe_topk_indices", + } + + @staticmethod + def _normalize_moe_topk_indices_array( + routed_experts: Any, + expected_len: int, + field_name: str, + ) -> np.ndarray | None: + if routed_experts is None: + return None + + if isinstance(routed_experts, list): + if routed_experts and isinstance(routed_experts[0], list): + logger.warning_once( + "%s uses nested Python lists for MoE top-k indices; " + "list[np.ndarray] is preferred.", + field_name, + ) + elif routed_experts and not isinstance(routed_experts[0], np.ndarray): + elem_types = sorted( + {type(item).__name__ for item in routed_experts} + ) + logger.warning_once( + "%s has MoE top-k indices list element types %s; expected " + "list[np.ndarray]. Attempting numpy conversion.", + field_name, + elem_types, + ) + elif not isinstance(routed_experts, np.ndarray): + logger.warning_once( + "%s has MoE top-k indices type %s; expected np.ndarray, " + "list[np.ndarray], or list[list]. 
+                field_name,
+                type(routed_experts).__name__,
+            )
+
+        routed_experts = np.asarray(routed_experts, dtype=np.int16)
+        if routed_experts.ndim == 0:
+            return None
+
+        seq_len = int(routed_experts.shape[0])
+        if seq_len > expected_len:
+            routed_experts = routed_experts[:expected_len]
+        elif seq_len < expected_len:
+            if seq_len == 0:
+                return None
+            pad_shape = (expected_len - seq_len, *routed_experts.shape[1:])
+            routed_experts = np.concatenate(
+                (routed_experts, np.zeros(pad_shape, dtype=routed_experts.dtype)),
+                axis=0,
+            )
+
+        return routed_experts
+
+    def _maybe_store_moe_topk_indices(
+        self,
+        request_id: str,
+        prompt_routed_experts: Any,
+        completion_routed_experts: Any,
+        num_prompt_tokens: int,
+        num_completion_tokens: int,
+        rl_metadata: dict[str, Any] | None = None,
+    ) -> None:
+        if self.block_store_instance is None:
+            return
+
+        prompt_array = self._normalize_moe_topk_indices_array(
+            prompt_routed_experts, num_prompt_tokens, "prompt_routed_experts"
+        )
+        completion_array = self._normalize_moe_topk_indices_array(
+            completion_routed_experts, num_completion_tokens, "routed_experts"
+        )
+        arrays = [
+            array for array in (prompt_array, completion_array) if array is not None
+        ]
+        if not arrays:
+            return
+
+        moe_topk_indices = (
+            arrays[0] if len(arrays) == 1 else np.concatenate(arrays, axis=0)
+        )
+        req_id = self._base_chat_request_id(request_id)
+        ray.get(
+            self.block_store_instance.put_numpy.remote(
+                req_id,
+                "moe_topk_indices",
+                moe_topk_indices,
+                rl_metadata=rl_metadata,
+            )
+        )
+
+    def _format_moe_topk_indices(self, routed_experts: Any) -> list | None:
+        if routed_experts is None:
+            return None
+        if isinstance(routed_experts, list):
+            return routed_experts
+        tolist = getattr(routed_experts, "tolist", None)
+        if callable(tolist):
+            return tolist()
+        return list(routed_experts)
+
     def _get_top_logprobs(
         self,
         logprobs: dict[int, Logprob],
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index d3a66c183346..47c4ee12ac7e 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -368,6 +368,14 @@ def validate_parsed_serve_args(args: argparse.Namespace):
         raise TypeError("Error: --enable-auto-tool-choice requires --tool-call-parser")
     if args.enable_log_outputs and not args.enable_log_requests:
         raise TypeError("Error: --enable-log-outputs requires --enable-log-requests")
+    if (
+        args.enable_moe_topk_indices_nemo_rl_block_store
+        and args.enable_moe_topk_indices_json
+    ):
+        raise TypeError(
+            "Error: --enable-moe-topk-indices-nemo-rl-block-store and "
+            "--enable-moe-topk-indices-json are mutually exclusive"
+        )
 
 
 def create_parser_for_docs() -> FlexibleArgumentParser:
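
Not part of the patch: a minimal client-side sketch of how the inline JSON output added above might be consumed. It assumes a server already running on localhost:8000 that was launched with --enable-moe-topk-indices-json and uses the requests library for illustration; only the response field names (prompt_moe_topk_indices, moe_topk_indices) come from the diff itself.

import requests

# Hypothetical usage sketch (assumed host/port and model name; not part of
# the patch). With --enable-moe-topk-indices-json, the server returns MoE
# top-k indices inline as nested lists instead of a block-store cache key.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-moe-model",  # placeholder model name
        "messages": [{"role": "user", "content": "what is 1+1?"}],
    },
    timeout=60,
)
body = resp.json()

# Indices for the prompt tokens sit at the top level of the response ...
prompt_indices = body.get("prompt_moe_topk_indices")
# ... and indices for the generated tokens are attached to each choice.
completion_indices = body["choices"][0].get("moe_topk_indices")
print(prompt_indices, completion_indices)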