diff --git a/python/ray/llm/_internal/serve/configs/prompt_formats.py b/python/ray/llm/_internal/serve/configs/prompt_formats.py index a3c0f14b6f474..86530aa5205aa 100644 --- a/python/ray/llm/_internal/serve/configs/prompt_formats.py +++ b/python/ray/llm/_internal/serve/configs/prompt_formats.py @@ -24,7 +24,6 @@ class Text(BaseModel): - field: str = "text" type: str = "text" text: str @@ -35,18 +34,23 @@ class Text(BaseModel): # This is to support the "content" content type in the prompt format, as opposite of # the "text" content from the above which most other model uses. class Content(BaseModel): - field: str = "text" type: str = "text" content: str class Image(BaseModel): - field: str = "image_url" + type: str = "image_url" image_url: Dict @field_validator("image_url") @classmethod def check_image_url(cls, value): + """Checks if the image_url is a dict with a 'url' key. + Example: + image_url = { + "url": "https://example.com/image.png" + } + """ if "url" not in value or not value["url"] or not isinstance(value["url"], str): raise ValueError( # TODO(xwjiang): Link to doc. @@ -120,6 +124,7 @@ class EngineInput(BaseModel): image: Optional[List[ImageInput]] = None +# TODO (Kourosh): We can delete this abstraction. class AbstractPromptFormat(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/python/ray/llm/_internal/serve/configs/server_models.py b/python/ray/llm/_internal/serve/configs/server_models.py index 1a1b5520a4e78..17265b76bcb5c 100644 --- a/python/ray/llm/_internal/serve/configs/server_models.py +++ b/python/ray/llm/_internal/serve/configs/server_models.py @@ -945,3 +945,4 @@ class GenerationRequest(BaseModelExtended): prompt: Union[str, List[int], List[str]] request_id: Union[str, List[str]] sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None + stream: bool = False diff --git a/python/ray/llm/_internal/serve/deployments/llm/llm_engine.py b/python/ray/llm/_internal/serve/deployments/llm/llm_engine.py new file mode 100644 index 0000000000000..556ae23342ea9 --- /dev/null +++ b/python/ray/llm/_internal/serve/deployments/llm/llm_engine.py @@ -0,0 +1,67 @@ +from typing import AsyncGenerator, Optional + +from ray.llm._internal.serve.configs.server_models import ( + Prompt, + LLMRawResponse, + LLMConfig, + GenerationRequest, + DiskMultiplexConfig, +) + + +import abc + + +class LLMEngine(abc.ABC): + """Base class for all LLM engines""" + + def __init__(self, llm_config: LLMConfig): + self._llm_config = llm_config + + @abc.abstractmethod + async def start(self): + """Start the engine""" + pass + + @abc.abstractmethod + async def prepare_request( + self, + request_id: str, + prompt: Prompt, + stream: bool, + disk_lora_model: Optional[DiskMultiplexConfig] = None, + **kwargs, + ) -> GenerationRequest: + """Prepare a GenerationRequest for the engine""" + pass + + @abc.abstractmethod + async def generate( + self, request: GenerationRequest + ) -> AsyncGenerator[LLMRawResponse, None]: + """Generate an LLMRawResponse stream based on the GenerationRequest""" + pass + + async def check_health(self) -> bool: + """Check the health of the engine""" + return True + + ############################################################## + # Optional methods + # These methods will be implemented in the future to allow + # more granular life-cycle management of the engine. + # e.g. in usecases like RL training, we need to put the engine + # to sleep during training and wake up during rollouts. 
+ ############################################################## + + async def sleep(self): + """Puts the engine to sleep""" + pass + + async def wakeup(self): + """Wakes up the engine""" + pass + + def shutdown(self): + """Shuts down the engine""" + pass diff --git a/python/ray/llm/_internal/serve/deployments/llm/llm_server.py b/python/ray/llm/_internal/serve/deployments/llm/llm_server.py index bbe89f25f875a..47aaf77047ec5 100644 --- a/python/ray/llm/_internal/serve/deployments/llm/llm_server.py +++ b/python/ray/llm/_internal/serve/deployments/llm/llm_server.py @@ -4,7 +4,6 @@ from typing import AsyncGenerator, Dict, Any, Optional, Type, Union # Third-party imports -from pydantic import ValidationError as PydanticValidationError from ray import serve from ray._common.utils import import_attr @@ -15,9 +14,6 @@ ENGINE_START_TIMEOUT_S, RAYLLM_VLLM_ENGINE_CLS_ENV, ) -from ray.llm._internal.serve.configs.error_handling import ( - ValidationErrorWithPydantic, -) from ray.llm._internal.serve.configs.openai_api_models import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -47,15 +43,12 @@ LLMConfig, LLMRawResponse, ) +from ray.llm._internal.serve.deployments.llm.llm_engine import LLMEngine from ray.llm._internal.serve.deployments.llm.image_retriever import ImageRetriever from ray.llm._internal.serve.deployments.llm.multiplex.lora_model_loader import ( LoraModelLoader, ) from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import VLLMEngine -from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import ( - VLLMGenerationRequest, - VLLMSamplingParams, -) from ray.llm._internal.serve.deployments.utils.error_handling_utils import ( StreamingErrorHandler, ) @@ -472,9 +465,11 @@ async def __init__( self.response_postprocessor = ResponsePostprocessor() @property - def _get_engine_class(self) -> VLLMEngine: - """Helper to load the engine class from the environment variable if existed - else it will fallback to the default engine class. + def _get_engine_class(self) -> Type[LLMEngine]: + """Helper to load the engine class from the environment variable. + + This is used for testing or escape-hatch for patching purposes. + If env variable is not set, it will fallback to the default engine class. """ engine_cls_path = os.environ.get(RAYLLM_VLLM_ENGINE_CLS_ENV) if engine_cls_path: @@ -485,7 +480,6 @@ def _get_engine_class(self) -> VLLMEngine: f"Failed to import engine class {engine_cls_path}. " f"Using the default engine class {self._engine_cls}." ) - return self._engine_cls async def _start_engine(self): @@ -511,50 +505,24 @@ async def _predict( """ logger.info(f"Received streaming request {request_id}") - try: - multiplexed_model_id = serve.get_multiplexed_model_id() - - if multiplexed_model_id: - assert ( - self._llm_config.lora_config is not None - ), "Must setup lora config for multiplexed requests." - disk_lora_model = await self._disk_lora_model(multiplexed_model_id) - else: - disk_lora_model = None - - prompt_output = self._llm_config.prompt_format.generate_prompt(prompt) - - sampling_params = VLLMSamplingParams.from_prompt(prompt) - prompt_text = prompt_output.text - image_input = prompt_output.image - image = [] - if not self._llm_config.supports_vision and image_input: - raise RuntimeError( - "You provided image input while the engine is not set up to handle images. " - "Did you forget to set `input_modality` to image in yaml file?" 
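
Editor's sketch (not part of the patch): the hunks above introduce the abstract `LLMEngine` base class and rewire `LLMServer._predict` to call `engine.prepare_request(...)` followed by `engine.generate(request)`. A minimal illustration of a third-party engine plugging into that interface could look like the following; the `EchoEngine` name and its trivial bodies are hypothetical, only the signatures come from the base class, and additional `LLMRawResponse` fields may be required in practice.

    from typing import AsyncGenerator, Optional

    from ray.llm._internal.serve.configs.server_models import (
        DiskMultiplexConfig,
        GenerationRequest,
        LLMRawResponse,
        Prompt,
    )
    from ray.llm._internal.serve.deployments.llm.llm_engine import LLMEngine


    class EchoEngine(LLMEngine):
        """Hypothetical engine that echoes the prompt back as a single chunk."""

        async def start(self):
            # Nothing to bring up for this toy engine.
            pass

        async def prepare_request(
            self,
            request_id: str,
            prompt: Prompt,
            stream: bool,
            disk_lora_model: Optional[DiskMultiplexConfig] = None,
            **kwargs,
        ) -> GenerationRequest:
            # The engine owns prompt formatting now; here we just pass text through.
            text = prompt.prompt if isinstance(prompt.prompt, str) else str(prompt.prompt)
            return GenerationRequest(prompt=text, request_id=request_id, stream=stream)

        async def generate(
            self, request: GenerationRequest
        ) -> AsyncGenerator[LLMRawResponse, None]:
            # Field values are illustrative; a real engine streams incremental
            # responses and fills in token and timing accounting.
            yield LLMRawResponse(
                generated_text=str(request.prompt),
                num_generated_tokens=1,
                generation_time=0.0,
            )
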
- ) + multiplexed_model_id = serve.get_multiplexed_model_id() + + if multiplexed_model_id: + assert ( + self._llm_config.lora_config is not None + ), "Must setup lora config for multiplexed requests." + disk_lora_model = await self._disk_lora_model(multiplexed_model_id) + else: + disk_lora_model = None + + llm_request = await self.engine.prepare_request( + request_id=request_id, + prompt=prompt, + stream=stream, + disk_lora_model=disk_lora_model, + ) - if self._llm_config.supports_vision and image_input: - for _image in image_input: - image_url = _image.image_url - image.append(await self.image_retriever.get(image_url)) - - request_params = { - "prompt": prompt_text, - "request_id": request_id, - "sampling_params": sampling_params, - "disk_multiplex_config": disk_lora_model, - "serve_request_context": serve.context._serve_request_context.get(), - } - if image: - request_params["multi_modal_data"] = {"image": image} - vllm_request = VLLMGenerationRequest(**request_params) - except PydanticValidationError as e: - # Wrap the PydanticValidationError in a ValidationErrorWithPydantic - # so that it can be used in a RayActorError - # See https://github.com/ray-project/ray/issues/43401 - raise ValidationErrorWithPydantic(e) from None - async for llm_response in self.engine.generate(vllm_request, stream): + async for llm_response in self.engine.generate(llm_request): yield llm_response async def chat(self, request: ChatCompletionRequest) -> LLMChatResponse: @@ -598,8 +566,8 @@ async def completions(self, request: CompletionRequest) -> LLMCompletionsRespons model=self._llm_config.model_id, gen=gen, stream=stream ) - async def check_health(self): - """Check the health of the vllm engine.""" + async def check_health(self) -> bool: + """Check the health of the llm engine.""" return await self.engine.check_health() async def _load_model(self, lora_model_id: str) -> DiskMultiplexConfig: diff --git a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py index 6565ff8e1c14f..7dde59ba09c32 100644 --- a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py +++ b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py @@ -20,7 +20,6 @@ ) from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine_stats import ( ArgUsage, - VLLMEngineStats, VLLMEngineStatTracker, usage_counters, ) @@ -36,7 +35,9 @@ initialize_node as initialize_node_util, ) from ray.llm._internal.serve.configs.server_models import ( - BatchedLLMRawResponse, + Prompt, + GenerationRequest, + DiskMultiplexConfig, LLMConfig, LLMRawResponse, LogProb, @@ -52,6 +53,9 @@ MAX_NUM_TOPLOGPROBS_ALLOWED, ) from ray.llm._internal.utils import try_import +from ray.llm._internal.serve.deployments.utils.batcher import LLMRawResponsesBatcher + +from ray.llm._internal.serve.deployments.llm.llm_engine import LLMEngine if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -63,7 +67,6 @@ vllm = try_import("vllm") logger = get_logger(__name__) - time_in_queue_histogram = metrics.Histogram( "vllm_engine_stats_time_in_queue_ms", "Time a request spends in the queue first forward pass not included (ms).", @@ -126,93 +129,6 @@ def _clear_current_platform_cache(): current_platform.get_device_capability.cache_clear() -class BatchLLMRawResponses: - """This class batches multiple LLMRawResponses from a generator into a - single response, at some time interval. - - Args: - generator: the async generator that this class pulls LLMRawResponses - from. 
- interval_ms: the interval at which this class yields the current batch. - If None, this class will batch all responses from the generator - together and yield the entire batch once. - """ - - def __init__( - self, - generator: AsyncGenerator[LLMRawResponse, None], - interval_ms: Optional[float] = MODEL_RESPONSE_BATCH_TIMEOUT_MS, - ): - self.generator = generator - self.queue: asyncio.Queue = asyncio.Queue() - - if interval_ms is None: - self.interval_s = None - else: - self.interval_s = interval_ms / 1000 - - self.done_event: asyncio.Event = asyncio.Event() - - # We are okay with this task getting cancelled (to propagate cancellations) - self.read_task = asyncio.create_task(self.read()) - - async def stream(self) -> AsyncGenerator[BatchedLLMRawResponse, None]: - """Drain from the queue every interval_ms and yield the merged results""" - try: - while True: - # Wait for the interval or until we finish, whichever is faster. - # We use an event to avoid asyncio.wait_for cancelling the real task on timeout. - try: - if self.interval_s is None: - await self.done_event.wait() - else: - await asyncio.wait_for( - self.done_event.wait(), timeout=self.interval_s - ) - except asyncio.TimeoutError: - pass - - # Get all elements from the queue - results, is_done = self.check_done_and_drain() - - # If there are results, merge and yield them - if results: - output: BatchedLLMRawResponse = BatchedLLMRawResponse.merge_stream(*results) # type: ignore - yield output - - # If the read task is done, exit the stream task - if is_done: - # Raise exception, if any - self.read_task.result() - break - finally: - # If the stream task is done, make sure to exit the read task - if not self.read_task.done(): - self.read_task.cancel() - - def check_done_and_drain(self): - results = self.drain_queue() - return results, self.read_task.done() - - async def read(self): - """Read from the generator and put into the queue in a tight loop""" - try: - async for x in self.generator: - self.queue.put_nowait(x) - finally: - self.done_event.set() - - def drain_queue(self): - """Drain all results currently in the queue""" - results = [] - try: - while True: - results.append(self.queue.get_nowait()) - except asyncio.QueueEmpty: - pass - return results - - class _EngineBackgroundProcess: def __init__(self, ipc_path, engine_args, engine_config): from vllm.engine.multiprocessing.engine import MQLLMEngine @@ -251,7 +167,7 @@ def get_error(self): return self._error -class VLLMEngine: +class VLLMEngine(LLMEngine): def __init__( self, llm_config: LLMConfig, @@ -261,6 +177,8 @@ def __init__( Args: llm_config: The llm configuration for this engine """ + super().__init__(llm_config) + if vllm is None: raise ImportError( "vLLM is not installed. Please install it with `pip install ray[llm]`." @@ -278,6 +196,11 @@ def __init__( self.engine = None self.vllm_config: "VllmConfig" = None + # Chat template content format (openai or string) + self._resolved_content_format = None + # Also need local instance of the tokenizer to manage prompt formatting. + self._tokenizer = None + @staticmethod async def initialize_node(llm_config: LLMConfig) -> InitializeNodeOutput: """Run the node initializer. @@ -293,16 +216,30 @@ async def start(self): If the engine is already running, do nothing. """ + from vllm.entrypoints.chat_utils import resolve_chat_template_content_format + if self.running: # The engine is already running! 
logger.info("Skipping engine restart because the engine is already running") return - # Get the scaling options self.engine = await self._start_engine() self.running = True self.model_config = await self.engine.get_model_config() + self._tokenizer = await self.engine.get_tokenizer() + self._resolved_content_format = resolve_chat_template_content_format( + # Use HF to get the chat template so set it to None here. + chat_template=None, + # Default to None, change when it's needed. + # vLLM does not have a high level API to support all of this. + tools=None, + # Let vLLM decide the content format. + given_format="auto", + tokenizer=self._tokenizer, + trust_remote_code=self.model_config.trust_remote_code, + ) + logger.info("Started vLLM engine.") async def _start_engine(self) -> "EngineClient": @@ -524,25 +461,73 @@ def _start_async_llm_engine( log_stats=not engine_args.disable_log_stats, ) - async def generate( + async def prepare_request( self, - vllm_engine_request: VLLMGenerationRequest, + request_id: str, + prompt: Prompt, stream: bool, - ) -> AsyncGenerator[LLMRawResponse, None]: - batch_interval_ms = MODEL_RESPONSE_BATCH_TIMEOUT_MS if stream else None - if vllm_engine_request.serve_request_context: - ray.serve.context._serve_request_context.set( - vllm_engine_request.serve_request_context + disk_lora_model: Optional[DiskMultiplexConfig] = None, + ) -> GenerationRequest: + from vllm.entrypoints.chat_utils import ( + parse_chat_messages_futures, + apply_hf_chat_template, + ) + + model_config = self.model_config + mm_data = None + + if isinstance(prompt.prompt, list): + messages = [m.model_dump() for m in prompt.prompt] + conversation, mm_futures = parse_chat_messages_futures( + messages=messages, + model_config=model_config, + tokenizer=self._tokenizer, + content_format=self._resolved_content_format, ) - response_stream = BatchLLMRawResponses( - self._generate(vllm_engine_request), + mm_data = await mm_futures + + prompt_text = apply_hf_chat_template( + tokenizer=self._tokenizer, + conversation=conversation, + chat_template=None, + tools=None, + trust_remote_code=model_config.trust_remote_code, + tokenize=False, + # **kwargs for tokenizer.apply_chat_template + add_generation_prompt=True, + continue_final_message=False, + ) + else: + prompt_text = prompt.prompt + + request_params = { + "prompt": prompt_text, + "request_id": request_id, + "sampling_params": VLLMSamplingParams.from_prompt(prompt), + "disk_multiplex_config": disk_lora_model, + "stream": stream, + } + if mm_data: + request_params["multi_modal_data"] = mm_data + + vllm_request = VLLMGenerationRequest(**request_params) + return vllm_request + + async def generate( + self, + request: GenerationRequest, + ) -> AsyncGenerator[LLMRawResponse, None]: + batch_interval_ms = MODEL_RESPONSE_BATCH_TIMEOUT_MS if request.stream else None + + response_stream = LLMRawResponsesBatcher( + self._generate(request), interval_ms=batch_interval_ms, ) async for response in response_stream.stream(): yield response async def _generate( - self, vllm_generation_request: VLLMGenerationRequest + self, request: GenerationRequest ) -> AsyncGenerator[LLMRawResponse, None]: """Generate an LLMRawResponse stream @@ -560,20 +545,17 @@ async def _generate( """ if RAYLLM_ENABLE_REQUEST_PROMPT_LOGS: logger.info( - f"Request {vllm_generation_request.request_id} started. " - f"Prompt: {vllm_generation_request.prompt}" + f"Request {request.request_id} started. 
" f"Prompt: {request.prompt}" ) # Construct a results generator from vLLM results_generator: AsyncGenerator["RequestOutput", None] = self.engine.generate( prompt=vllm.inputs.TextPrompt( - prompt=vllm_generation_request.prompt, - multi_modal_data=vllm_generation_request.multi_modal_data, - ), - sampling_params=self._parse_sampling_params( - vllm_generation_request.sampling_params + prompt=request.prompt, + multi_modal_data=request.multi_modal_data, ), - request_id=vllm_generation_request.request_id, - lora_request=vllm_generation_request.lora_request, # type: ignore + sampling_params=self._parse_sampling_params(request.sampling_params), + request_id=request.request_id, + lora_request=request.lora_request, # type: ignore ) # Loop over the results @@ -607,7 +589,7 @@ async def _generate( log_probs, log_probs_idx = self._extract_logprobs( output, log_probs_idx, - vllm_generation_request.sampling_params.top_logprobs, + request.sampling_params.top_logprobs, ) yield LLMRawResponse( generated_text=text_output, @@ -644,7 +626,7 @@ async def _generate( generated_tokens_s = all_tokens_collected / generation_time logger.info( - f"Request {vllm_generation_request.request_id} finished ({finish_reason}). " + f"Request {request.request_id} finished ({finish_reason}). " f"Total time: {total_request_time}s, " f"Queue time: {queue_time}, " f"Generation+async time: {generation_time_str}, " @@ -655,7 +637,7 @@ async def _generate( ) else: logger.warning( - f"Request {vllm_generation_request.request_id} " + f"Request {request.request_id} " "finished without any output. " f"Input tokens: {num_input_tokens}." ) @@ -669,7 +651,7 @@ async def _generate( finally: # Ensure that we cancel on the engine once we have exited the streaming # phase - await self.engine.abort(vllm_generation_request.request_id) + await self.engine.abort(request.request_id) def _get_prompt_limit(self) -> int: """Helper to get the prompt limit from scheduler config @@ -695,6 +677,7 @@ def _handle_input_too_long( if ( finish_reason and finish_reason == FinishReason.LENGTH + and hasattr(request_output.metrics, "first_token_time") and request_output.metrics.first_token_time is None ): # This means that the prompt was too long and we did not generate anything. @@ -702,9 +685,9 @@ def _handle_input_too_long( len(request_output.prompt_token_ids), self._get_prompt_limit() ).exception - async def check_health(self): + async def check_health(self) -> bool: if not hasattr(self.engine, "check_health"): - return + return False try: return await asyncio.wait_for(self.engine.check_health(), timeout=15) @@ -712,12 +695,6 @@ async def check_health(self): logger.exception("Healthcheck failed. 
The replica will be restarted") raise e from None - def stats(self) -> VLLMEngineStats: - return self._stats.to_stats() - - def shutdown(self, shutdown_pg: bool = True): - raise NotImplementedError() - @staticmethod def _collect_usage_metrics(sampling_params: VLLMSamplingParams) -> None: if sampling_params.best_of is not None: diff --git a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py index 4916615c87289..87a05e5e3dae5 100644 --- a/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py +++ b/python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_models.py @@ -1,8 +1,7 @@ import os -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union from pydantic import ConfigDict, Field -from ray import serve from ray.util.placement_group import ( PlacementGroup, get_current_placement_group, @@ -222,9 +221,10 @@ class VLLMSamplingParams(SamplingParams): class VLLMGenerationRequest(GenerationRequest): model_config = ConfigDict(arbitrary_types_allowed=True) - sampling_params: VLLMSamplingParams + sampling_params: Optional[ + Union[VLLMSamplingParams, List[VLLMSamplingParams]] + ] = None multi_modal_data: Optional[Dict[str, Any]] = None - serve_request_context: Optional[serve.context._RequestContext] = None disk_multiplex_config: Optional[DiskMultiplexConfig] = None @property diff --git a/python/ray/llm/_internal/serve/deployments/utils/batcher.py b/python/ray/llm/_internal/serve/deployments/utils/batcher.py new file mode 100644 index 0000000000000..9b3b8e66a63a8 --- /dev/null +++ b/python/ray/llm/_internal/serve/deployments/utils/batcher.py @@ -0,0 +1,103 @@ +import asyncio +from typing import AsyncGenerator, Optional + + +from ray.llm._internal.serve.observability.logging import get_logger +from ray.llm._internal.serve.configs.server_models import ( + BatchedLLMRawResponse, + LLMRawResponse, +) + +from ray.llm._internal.serve.configs.constants import ( + MODEL_RESPONSE_BATCH_TIMEOUT_MS, +) + + +logger = get_logger(__name__) + + +class LLMRawResponsesBatcher: + """This class batches multiple LLMRawResponses from a generator into a + single response, at some time interval. + + Args: + generator: the async generator that this class pulls LLMRawResponses + from. + interval_ms: the interval at which this class yields the current batch. + If None, this class will batch all responses from the generator + together and yield the entire batch once. + """ + + def __init__( + self, + generator: AsyncGenerator[LLMRawResponse, None], + interval_ms: Optional[float] = MODEL_RESPONSE_BATCH_TIMEOUT_MS, + ): + self.generator = generator + self.queue: asyncio.Queue = asyncio.Queue() + + if interval_ms is None: + self.interval_s = None + else: + self.interval_s = interval_ms / 1000 + + self.done_event: asyncio.Event = asyncio.Event() + + # We are okay with this task getting cancelled (to propagate cancellations) + self.read_task = asyncio.create_task(self.read()) + + async def stream(self) -> AsyncGenerator[BatchedLLMRawResponse, None]: + """Drain from the queue every interval_ms and yield the merged results""" + try: + while True: + # Wait for the interval or until we finish, whichever is faster. + # We use an event to avoid asyncio.wait_for cancelling the real task on timeout. 
+ try: + if self.interval_s is None: + await self.done_event.wait() + else: + await asyncio.wait_for( + self.done_event.wait(), timeout=self.interval_s + ) + except asyncio.TimeoutError: + pass + + # Get all elements from the queue + results, is_done = self.check_done_and_drain() + + # If there are results, merge and yield them + if results: + output: BatchedLLMRawResponse = BatchedLLMRawResponse.merge_stream(*results) # type: ignore + yield output + + # If the read task is done, exit the stream task + if is_done: + # Raise exception, if any + self.read_task.result() + break + finally: + # If the stream task is done, make sure to exit the read task + if not self.read_task.done(): + self.read_task.cancel() + + def check_done_and_drain(self): + results = self.drain_queue() + return results, self.read_task.done() + + async def read(self): + """Read from the generator and put into the queue in a tight loop""" + try: + async for x in self.generator: + self.queue.put_nowait(x) + finally: + self.done_event.set() + + def drain_queue(self): + """Drain all results currently in the queue""" + results = [] + try: + while True: + results.append(self.queue.get_nowait()) + except asyncio.QueueEmpty: + pass + return results diff --git a/python/ray/llm/tests/serve/conftest.py b/python/ray/llm/tests/serve/conftest.py index b7fb72da162a2..5b57e52696615 100644 --- a/python/ray/llm/tests/serve/conftest.py +++ b/python/ray/llm/tests/serve/conftest.py @@ -1,7 +1,6 @@ import ray from ray import serve import pytest -from ray.llm._internal.serve.configs.constants import RAYLLM_VLLM_ENGINE_CLS_ENV from ray.llm._internal.serve.configs.server_models import ( LLMConfig, ModelLoadingConfig, @@ -30,15 +29,6 @@ def shutdown_ray_and_serve(): ray.shutdown() -@pytest.fixture -def use_mock_vllm_engine(monkeypatch): - monkeypatch.setenv( - RAYLLM_VLLM_ENGINE_CLS_ENV, - "ray.llm.tests.serve.mocks.mock_vllm_engine.MockVLLMEngine", - ) - yield - - @pytest.fixture def llm_config(model_pixtral_12b): yield LLMConfig( @@ -101,30 +91,16 @@ def get_rayllm_testing_model( @pytest.fixture -def testing_model(shutdown_ray_and_serve, use_mock_vllm_engine, model_pixtral_12b): +def testing_model(shutdown_ray_and_serve): test_model_path = get_test_model_path("mock_vllm_model.yaml") - with open(test_model_path, "r") as f: - loaded_llm_config = yaml.safe_load(f) - - loaded_llm_config["model_loading_config"]["model_source"] = model_pixtral_12b - test_model_path = write_yaml_file(loaded_llm_config) - with get_rayllm_testing_model(test_model_path) as (client, model_id): yield client, model_id @pytest.fixture -def testing_model_no_accelerator( - shutdown_ray_and_serve, use_mock_vllm_engine, model_pixtral_12b -): +def testing_model_no_accelerator(shutdown_ray_and_serve): test_model_path = get_test_model_path("mock_vllm_model_no_accelerator.yaml") - with open(test_model_path, "r") as f: - loaded_llm_config = yaml.safe_load(f) - - loaded_llm_config["model_loading_config"]["model_source"] = model_pixtral_12b - test_model_path = write_yaml_file(loaded_llm_config) - with get_rayllm_testing_model(test_model_path) as (client, model_id): yield client, model_id diff --git a/python/ray/llm/tests/serve/cpu/builders/test_application_builders.py b/python/ray/llm/tests/serve/cpu/builders/test_application_builders.py index 1f5354480fe3d..0c9b997ba2b95 100644 --- a/python/ray/llm/tests/serve/cpu/builders/test_application_builders.py +++ b/python/ray/llm/tests/serve/cpu/builders/test_application_builders.py @@ -20,54 +20,70 @@ import tempfile import signal import sys 
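
Editor's sketch (not part of the patch): the batcher moved out of `vllm_engine.py` into `deployments/utils/batcher.py` as `LLMRawResponsesBatcher`. A small usage sketch follows; the toy generator, its field values, and the assumption that merged batches sum token counts are illustrative, based on the test suite rather than guaranteed API behavior.

    import asyncio

    from ray.llm._internal.serve.configs.server_models import LLMRawResponse
    from ray.llm._internal.serve.deployments.utils.batcher import LLMRawResponsesBatcher


    async def fake_token_stream():
        # Stands in for the per-token stream produced by an engine.
        for i in range(10):
            yield LLMRawResponse(
                generated_text=f"tok{i} ",
                num_generated_tokens=1,
                generation_time=0.0,
            )
            await asyncio.sleep(0.01)


    async def main():
        # interval_ms=None merges the whole stream into a single batch;
        # a float interval (e.g. 100) yields one merged batch per interval instead.
        batcher = LLMRawResponsesBatcher(fake_token_stream(), interval_ms=None)
        async for batch in batcher.stream():
            # Merged count across the batch (10 here if counts are summed on merge).
            print(batch.num_generated_tokens)


    asyncio.run(main())
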
+import re from ray._private.test_utils import wait_for_condition @pytest.fixture -def get_llm_serve_args(llm_config): - yield LLMServingArgs(llm_configs=[llm_config]) +def llm_config_with_mock_engine(llm_config): + # Make sure engine is mocked. + if llm_config.runtime_env is None: + llm_config.runtime_env = {} + llm_config.runtime_env.setdefault("env_vars", {})[ + "RAYLLM_VLLM_ENGINE_CLS" + ] = "ray.llm.tests.serve.mocks.mock_vllm_engine.MockVLLMEngine" + yield llm_config + + +@pytest.fixture +def get_llm_serve_args(llm_config_with_mock_engine): + yield LLMServingArgs(llm_configs=[llm_config_with_mock_engine]) @pytest.fixture() -def serve_config_separate_model_config_files(model_pixtral_12b): - with tempfile.TemporaryDirectory() as config_dir: - serve_config_filename = "llm_app_separate_model_config_files.yaml" - config_root = os.path.join(os.path.dirname(__file__), "test_config_files") - serve_config_src = os.path.join(config_root, serve_config_filename) - serve_config_dst = os.path.join(config_dir, serve_config_filename) +def serve_config_separate_model_config_files(): + config_dir = tempfile.mkdtemp() + serve_config_filename = "llm_app_separate_model_config_files.yaml" + config_root = os.path.join(os.path.dirname(__file__), "test_config_files") + serve_config_src = os.path.join(config_root, serve_config_filename) + serve_config_dst = os.path.join(config_dir, serve_config_filename) - with open(serve_config_src, "r") as f: - serve_config_yaml = yaml.safe_load(f) + with open(serve_config_src, "r") as f: + serve_config_yaml = yaml.safe_load(f) - for application in serve_config_yaml["applications"]: - llm_configs = application["args"]["llm_configs"] - tmp_llm_config_files = [] - for llm_config in llm_configs: - llm_config_src = llm_config.replace(".", config_root, 1) - llm_config_dst = llm_config.replace(".", config_dir, 1) - tmp_llm_config_files.append(llm_config_dst) + for application in serve_config_yaml["applications"]: + llm_configs = application["args"]["llm_configs"] + tmp_llm_config_files = [] + for llm_config in llm_configs: + llm_config_src = llm_config.replace(".", config_root, 1) + llm_config_dst = llm_config.replace(".", config_dir, 1) + tmp_llm_config_files.append(llm_config_dst) - with open(llm_config_src, "r") as f: - llm_config_yaml = yaml.safe_load(f) - llm_config_yaml["model_loading_config"]["model_id"] = model_pixtral_12b + with open(llm_config_src, "r") as f: + llm_config_yaml = yaml.safe_load(f) - os.makedirs(os.path.dirname(llm_config_dst), exist_ok=True) - with open(llm_config_dst, "w") as f: - yaml.dump(llm_config_yaml, f) + # Make sure engine is mocked. 
+ if llm_config_yaml.get("runtime_env", None) is None: + llm_config_yaml["runtime_env"] = {} + llm_config_yaml["runtime_env"]["env_vars"] = { + "RAYLLM_VLLM_ENGINE_CLS": "ray.llm.tests.serve.mocks.mock_vllm_engine.MockVLLMEngine" + } - application["args"]["llm_configs"] = tmp_llm_config_files + os.makedirs(os.path.dirname(llm_config_dst), exist_ok=True) + with open(llm_config_dst, "w") as f: + yaml.dump(llm_config_yaml, f) - with open(serve_config_dst, "w") as f: - yaml.dump(serve_config_yaml, f) + application["args"]["llm_configs"] = tmp_llm_config_files - yield serve_config_dst + with open(serve_config_dst, "w") as f: + yaml.dump(serve_config_yaml, f) + + yield serve_config_dst class TestBuildOpenaiApp: - def test_build_openai_app( - self, get_llm_serve_args, shutdown_ray_and_serve, use_mock_vllm_engine - ): + def test_build_openai_app(self, get_llm_serve_args, shutdown_ray_and_serve): """Test `build_openai_app` can build app and run it with Serve.""" app = build_openai_app( @@ -77,25 +93,36 @@ def test_build_openai_app( serve.run(app) def test_build_openai_app_with_config( - self, - serve_config_separate_model_config_files, - shutdown_ray_and_serve, - use_mock_vllm_engine, + self, serve_config_separate_model_config_files, shutdown_ray_and_serve ): """Test `build_openai_app` can be used in serve config.""" def deployments_healthy(): status_response = subprocess.check_output(["serve", "status"]) - serve_status = yaml.safe_load(status_response)["applications"][ - "llm-endpoint" - ] - assert len(serve_status["deployments"]) == 2 - deployment_status = serve_status["deployments"].values() - assert all([status["status"] == "HEALTHY" for status in deployment_status]) + print("[TEST] Status response: ", status_response) + applications = extract_applications_from_output(status_response) + + if "llm-endpoint" not in applications: + print("[TEST] Application 'llm-endpoint' not found.") + return False + + llm_endpoint_status = applications["llm-endpoint"] + if len(llm_endpoint_status["deployments"]) != 2: + print( + f"[TEST] Expected 2 deployments, found {len(llm_endpoint_status['deployments'])}" + ) + return False + + deployment_status = llm_endpoint_status["deployments"].values() + if not all([status["status"] == "HEALTHY" for status in deployment_status]): + print(f"[TEST] Not all deployments healthy: {deployment_status}") + return False + + print("[TEST] All deployments healthy.") return True p = subprocess.Popen(["serve", "run", serve_config_separate_model_config_files]) - wait_for_condition(deployments_healthy, timeout=30) + wait_for_condition(deployments_healthy, timeout=60, retry_interval_ms=1000) p.send_signal(signal.SIGINT) # Equivalent to ctrl-C p.wait() @@ -149,16 +176,35 @@ def test_router_built_with_autoscaling_configs(self): class TestBuildVllmDeployment: def test_build_llm_deployment( self, - llm_config, + llm_config_with_mock_engine, shutdown_ray_and_serve, - use_mock_vllm_engine, ): """Test `build_llm_deployment` can build a vLLM deployment.""" - app = build_llm_deployment(llm_config) + app = build_llm_deployment(llm_config_with_mock_engine) assert isinstance(app, serve.Application) serve.run(app) +def extract_applications_from_output(output: bytes) -> dict: + """ + Extracts the 'applications' block from mixed output and returns it as a dict. + """ + # 1. Decode bytes to string + text = output.decode("utf-8", errors="ignore") + + # 2. 
Regex to find the 'applications:' block and its indented content + # This matches 'applications:' and all following lines that are indented (YAML block) + match = re.search(r"(^applications:\n(?:^(?: {2,}|\t).*\n?)+)", text, re.MULTILINE) + if not match: + raise ValueError("Could not find 'applications:' block in output.") + + applications_block = match.group(1) + + # 3. Parse the YAML block + applications_dict = yaml.safe_load(applications_block) + return applications_dict["applications"] + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/builders/test_config_files/model_config/llm_config.yaml b/python/ray/llm/tests/serve/cpu/builders/test_config_files/model_config/llm_config.yaml index 91421496d5b38..567b9457f296e 100644 --- a/python/ray/llm/tests/serve/cpu/builders/test_config_files/model_config/llm_config.yaml +++ b/python/ray/llm/tests/serve/cpu/builders/test_config_files/model_config/llm_config.yaml @@ -1,8 +1,6 @@ model_loading_config: model_id: model1 -accelerator_type: "L4" - deployment_config: ray_actor_options: resources: diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_multiplex_deployment.py b/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_multiplex_deployment.py index 20eebf335caf2..a3c60a85e34f3 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_multiplex_deployment.py +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/multiplex/test_multiplex_deployment.py @@ -143,13 +143,11 @@ async def test_multiplex_deployment( assert arg is not None - expected_lora_out_with_serve_request_context = dict(expected_lora_out) - expected_lora_out_with_serve_request_context[ - "serve_request_context" - ] = arg.model_dump().get("serve_request_context") + expected_lora_out_modified = dict(expected_lora_out) + expected_lora_out_modified["stream"] = stream_tokens print("***arg***", arg.model_dump()) - print("***exp***", expected_lora_out_with_serve_request_context) - assert arg == arg.__class__(**expected_lora_out_with_serve_request_context) + print("***exp***", expected_lora_out_modified) + assert arg == arg.__class__(**expected_lora_out_modified) responses = [ x @@ -188,8 +186,8 @@ async def test_multiplex_deployment( "best_of": 1, }, "multi_modal_data": None, - "serve_request_context": arg.model_dump().get("serve_request_context"), "disk_multiplex_config": None, + "stream": stream_tokens, } assert arg.model_dump() == expected_model_dump, ( "Arg model dump didn't match expected value." 
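
Editor's sketch (not part of the patch): the `llm_config_with_mock_engine` fixture above and the mock model YAMLs below both route the deployment to the mock engine through a `runtime_env` environment variable. The same override can be expressed directly on an `LLMConfig`; the model id below is a placeholder and other `LLMConfig` fields (deployment or engine options) may be required in practice.

    from ray.llm._internal.serve.configs.server_models import LLMConfig, ModelLoadingConfig

    llm_config = LLMConfig(
        model_loading_config=ModelLoadingConfig(model_id="FAKE_MODEL_UNDER_TEST"),
        runtime_env={
            "env_vars": {
                "RAYLLM_VLLM_ENGINE_CLS": (
                    "ray.llm.tests.serve.mocks.mock_vllm_engine.MockVLLMEngine"
                ),
            }
        },
    )
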
diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/test_vllm_engine.py b/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/test_vllm_engine.py index 1e92f11d201fa..a8ed4508b11e9 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/test_vllm_engine.py +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/test_vllm_engine.py @@ -7,8 +7,8 @@ import pytest from ray.llm._internal.serve.configs.server_models import FinishReason +from ray.llm._internal.serve.deployments.utils.batcher import LLMRawResponsesBatcher from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import ( - BatchLLMRawResponses, VLLMEngine, ) from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import ( @@ -87,6 +87,7 @@ def get_fake_engine_and_request(llm_config: LLMConfig, expected_out: List[str]): request_id="req_id", sampling_params=VLLMSamplingParams(), disk_multiplex_config=None, + stream=True, ) return vllm_engine, req, engine_mock @@ -129,7 +130,7 @@ async def test_vllm_engine_error_in_caller(self, llm_config): ) with pytest.raises(RuntimeError): - async for _x in vllm_engine.generate(req, stream=True): + async for _x in vllm_engine.generate(req): raise RuntimeError() await asyncio.sleep(0.02) # wait for asyncio task scheduling @@ -144,7 +145,7 @@ async def test_vllm_engine_caller_cancellation(self, llm_config): ) async def run(): - async for x in vllm_engine.generate(req, stream=True): + async for x in vllm_engine.generate(req): print(x) task = asyncio.create_task(run()) @@ -229,7 +230,7 @@ class TestBatching: @pytest.mark.asyncio async def test_batch(self): count = 0 - batcher = BatchLLMRawResponses(fake_generator()) + batcher = LLMRawResponsesBatcher(fake_generator()) async for x in batcher.stream(): count += 1 assert x.num_generated_tokens == 100 @@ -242,7 +243,7 @@ async def test_batch(self): @pytest.mark.asyncio async def test_batch_timing(self): count = 0 - batcher = BatchLLMRawResponses(fake_generator_slow(num_batches=10)) + batcher = LLMRawResponsesBatcher(fake_generator_slow(num_batches=10)) async for _x in batcher.stream(): count += 1 @@ -258,7 +259,7 @@ async def test_batch_last_return_is_immediate(self): the last response if it returns quickly.""" count = 0 token_count = 0 - batcher = BatchLLMRawResponses(fake_generator_slow_last_return_immediate()) + batcher = LLMRawResponsesBatcher(fake_generator_slow_last_return_immediate()) last_response = None async for _x in batcher.stream(): count += 1 @@ -278,7 +279,7 @@ async def test_batch_last_return_is_immediate(self): async def test_batch_no_interval(self): """Check that the class creates only one batch if there's no interval.""" - batcher = BatchLLMRawResponses( + batcher = LLMRawResponsesBatcher( fake_generator_slow(num_batches=10), interval_ms=None ) @@ -301,7 +302,7 @@ async def generator_should_raise(): raise ValueError() count = 0 - batched = BatchLLMRawResponses( + batched = LLMRawResponsesBatcher( generator_should_raise(), interval_ms=interval_ms ) @@ -340,7 +341,7 @@ async def generator_should_raise(): if to_cancel == "inner": raise asyncio.CancelledError() - batched = BatchLLMRawResponses( + batched = LLMRawResponsesBatcher( generator_should_raise(), interval_ms=interval_ms ) diff --git a/python/ray/llm/tests/serve/mock_vllm_model.yaml b/python/ray/llm/tests/serve/mock_vllm_model.yaml index 879fc9ec8f3d1..1e89e2fa7bdcc 100644 --- a/python/ray/llm/tests/serve/mock_vllm_model.yaml +++ b/python/ray/llm/tests/serve/mock_vllm_model.yaml @@ -1,5 +1,10 @@ model_loading_config: - model_id: VLLMFakeModel + 
model_id: FAKE_MODEL_UNDER_TEST + +# Overriding the engine class to only focus on testing the components around the engine +runtime_env: + env_vars: + RAYLLM_VLLM_ENGINE_CLS: "ray.llm.tests.serve.mocks.mock_vllm_engine.MockVLLMEngine" llm_engine: vLLM diff --git a/python/ray/llm/tests/serve/mock_vllm_model_no_accelerator.yaml b/python/ray/llm/tests/serve/mock_vllm_model_no_accelerator.yaml index 701fb4171a396..a54b2d597840c 100644 --- a/python/ray/llm/tests/serve/mock_vllm_model_no_accelerator.yaml +++ b/python/ray/llm/tests/serve/mock_vllm_model_no_accelerator.yaml @@ -1,5 +1,10 @@ model_loading_config: - model_id: VLLMFakeModel + model_id: FAKE_MODEL_UNDER_TEST + +# Overriding the engine class to only focus on testing the components around the engine +runtime_env: + env_vars: + RAYLLM_VLLM_ENGINE_CLS: "ray.llm.tests.serve.mocks.mock_vllm_engine.MockVLLMEngine" llm_engine: vLLM diff --git a/python/ray/llm/tests/serve/mocks/mock_vllm_engine.py b/python/ray/llm/tests/serve/mocks/mock_vllm_engine.py index 9880fa4156d22..26587a9d74e21 100644 --- a/python/ray/llm/tests/serve/mocks/mock_vllm_engine.py +++ b/python/ray/llm/tests/serve/mocks/mock_vllm_engine.py @@ -2,7 +2,7 @@ import json import random from random import randint -from typing import Dict +from typing import Dict, Optional from PIL import Image from vllm.sampling_params import SamplingParams as VLLMInternalSamplingParams @@ -18,6 +18,7 @@ DiskMultiplexConfig, LLMConfig, LLMRawResponse, + Prompt, ) from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine_stats import ( VLLMEngineStats, @@ -30,9 +31,10 @@ from ray.llm._internal.serve.deployments.utils.node_initialization_utils import ( InitializeNodeOutput, ) +from ray.llm._internal.serve.deployments.llm.llm_engine import LLMEngine -class MockVLLMEngine: +class MockVLLMEngine(LLMEngine): def __init__(self, llm_config: LLMConfig): """Create a vLLM Engine class @@ -72,7 +74,26 @@ async def async_range(count): yield i await asyncio.sleep(0.0) - async def generate(self, vllm_engine_request: VLLMGenerationRequest, stream: bool): + async def prepare_request( + self, request_id: str, prompt: Prompt, stream: bool, **kwargs + ) -> VLLMGenerationRequest: + + if isinstance(prompt.prompt, list): + # Simplification: Assume prompt is a list of messages with one user message + assert len(prompt.prompt) == 1 + assert hasattr(prompt.prompt[0], "content") + prompt_text = prompt.prompt[0].content + else: + prompt_text = prompt.prompt + + return VLLMGenerationRequest( + request_id=request_id, + prompt=prompt_text, + stream=stream, + sampling_params=VLLMSamplingParams.from_prompt(prompt), + ) + + async def generate(self, vllm_engine_request: VLLMGenerationRequest): sampling_params = self._parse_sampling_params( vllm_engine_request.sampling_params ) @@ -228,7 +249,7 @@ def _convert_to_json(self, vllm_engine_request: VLLMGenerationRequest) -> Dict: res.update({"has_image": has_image}) return json.dumps(res) - async def generate(self, vllm_engine_request: VLLMGenerationRequest, stream: bool): + async def generate(self, vllm_engine_request: VLLMGenerationRequest): yield LLMRawResponse( generated_text=self._convert_to_json(vllm_engine_request), num_input_tokens=0, @@ -241,7 +262,7 @@ async def generate(self, vllm_engine_request: VLLMGenerationRequest, stream: boo ) -class MockMultiplexEngine: +class MockMultiplexEngine(LLMEngine): def __init__(self, *args, **kwargs): self.started = False @@ -253,10 +274,35 @@ async def initialize_node(llm_config: LLMConfig) -> InitializeNodeOutput: extra_init_kwargs={}, 
) + async def prepare_request( + self, + request_id: str, + prompt: Prompt, + stream: bool, + disk_lora_model: Optional[DiskMultiplexConfig] = None, + ) -> VLLMGenerationRequest: + + if isinstance(prompt.prompt, list): + # Simplification: Assume prompt is a list of messages with one user message + assert len(prompt.prompt) == 1 + assert hasattr(prompt.prompt[0], "content") + prompt_text = prompt.prompt[0].content + else: + prompt_text = prompt.prompt + + output = VLLMGenerationRequest( + request_id=request_id, + prompt=prompt_text, + stream=stream, + sampling_params=VLLMSamplingParams.from_prompt(prompt), + disk_multiplex_config=disk_lora_model, + ) + return output + async def start(self): self.started = True - async def generate(self, arg, stream): + async def generate(self, arg): assert self.started, "Engine was not started" # First yield the arg yield arg @@ -328,7 +374,7 @@ async def generate_json(self, json_schema, max_tokens, prompt_len): yield llm_response await asyncio.sleep(generation_time) - async def generate(self, vllm_engine_request: VLLMGenerationRequest, stream: bool): + async def generate(self, vllm_engine_request: VLLMGenerationRequest): sampling_params = self._parse_sampling_params( vllm_engine_request.sampling_params )
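
Editor's sketch (not part of the patch): taken together, the new request flow is a two-step engine call — `prepare_request` builds the engine-specific request (including prompt formatting), then `generate` consumes it with no separate `stream` argument. A standalone illustration against the test suite's `MockVLLMEngine`; the `Prompt` construction is illustrative and a concrete `LLMConfig` (such as the one built in tests/serve/conftest.py) is assumed.

    import asyncio

    from ray.llm._internal.serve.configs.server_models import Prompt
    from ray.llm.tests.serve.mocks.mock_vllm_engine import MockVLLMEngine


    async def main(llm_config):
        # llm_config: an LLMConfig such as the one built in the serve test fixtures.
        engine = MockVLLMEngine(llm_config)
        await engine.start()

        # Step 1: the engine owns prompt formatting and builds its own request.
        request = await engine.prepare_request(
            request_id="req-1", prompt=Prompt(prompt="Hello"), stream=True
        )

        # Step 2: generation consumes that request and streams LLMRawResponse chunks.
        async for chunk in engine.generate(request):
            print(chunk.generated_text, end="")


    # asyncio.run(main(llm_config))  # run with a concrete LLMConfig
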