diff --git a/src/exo/shared/types/chunks.py b/src/exo/shared/types/chunks.py index 204556e88c..2e51c076f4 100644 --- a/src/exo/shared/types/chunks.py +++ b/src/exo/shared/types/chunks.py @@ -85,6 +85,6 @@ class PrefillProgressChunk(BaseChunk): total_tokens: int -GenerationChunk = ( - TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk | PrefillProgressChunk -) +StatusChunk = PrefillProgressChunk +GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk +Chunk = StatusChunk | GenerationChunk diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py index b750a2ae00..9fd35071d5 100644 --- a/src/exo/shared/types/events.py +++ b/src/exo/shared/types/events.py @@ -5,7 +5,7 @@ from exo.shared.models.model_cards import ModelCard from exo.shared.topology import Connection -from exo.shared.types.chunks import GenerationChunk, InputImageChunk +from exo.shared.types.chunks import Chunk, InputImageChunk from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId, SystemId from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.worker.downloads import DownloadProgress @@ -91,7 +91,7 @@ class NodeDownloadProgress(BaseEvent): class ChunkGenerated(BaseEvent): command_id: CommandId - chunk: GenerationChunk + chunk: Chunk class InputChunkReceived(BaseEvent): diff --git a/src/exo/shared/types/tasks.py b/src/exo/shared/types/tasks.py index e764ec3370..a5fec8cc33 100644 --- a/src/exo/shared/types/tasks.py +++ b/src/exo/shared/types/tasks.py @@ -101,3 +101,6 @@ class Shutdown(BaseTask): # emitted by Worker | ImageEdits | Shutdown ) +TextTask = TextGeneration +ImageTask = ImageGeneration | ImageEdits +GenerationTask = TextTask | ImageTask diff --git a/src/exo/shared/types/worker/__init__.py b/src/exo/shared/types/worker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/exo/shared/types/worker/runner_response.py b/src/exo/shared/types/worker/runner_response.py index e415bdca3a..9fb3301904 100644 --- a/src/exo/shared/types/worker/runner_response.py +++ b/src/exo/shared/types/worker/runner_response.py @@ -16,10 +16,6 @@ class BaseRunnerResponse(TaggedModel): pass -class TokenizedResponse(BaseRunnerResponse): - prompt_tokens: int - - class GenerationResponse(BaseRunnerResponse): text: str token: int @@ -75,6 +71,10 @@ class ModelLoadingResponse(BaseRunnerResponse): total: int +class CancelledResponse(BaseRunnerResponse): + pass + + class PrefillProgressResponse(BaseRunnerResponse): processed_tokens: int total_tokens: int diff --git a/src/exo/worker/engines/base.py b/src/exo/worker/engines/base.py new file mode 100644 index 0000000000..9efa13aea1 --- /dev/null +++ b/src/exo/worker/engines/base.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod +from collections.abc import Generator, Iterable + +from exo.shared.types.chunks import Chunk +from exo.shared.types.tasks import CANCEL_ALL_TASKS, GenerationTask, TaskId +from exo.shared.types.worker.instances import BoundInstance +from exo.shared.types.worker.runner_response import ( + CancelledResponse, + FinishedResponse, + ModelLoadingResponse, +) + + +class Engine(ABC): + _cancelled_tasks: set[TaskId] + + def should_cancel(self, task_id: TaskId) -> bool: + return ( + task_id in self._cancelled_tasks + or CANCEL_ALL_TASKS in self._cancelled_tasks + ) + + @abstractmethod + def warmup(self) -> None: ... + + @abstractmethod + def submit( + self, + task: GenerationTask, + ) -> None: ... 
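+    # Contract: submit() only enqueues work; step() drives it, yielding
+    # (task_id, result) pairs where result is a Chunk to stream onward or a
+    # terminal CancelledResponse / FinishedResponse that retires the task.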
+ + @abstractmethod + def step( + self, + ) -> Iterable[tuple[TaskId, Chunk | CancelledResponse | FinishedResponse]]: ... + + @abstractmethod + def close(self) -> None: ... + + +class Builder(ABC): + @abstractmethod + def connect(self, bound_instance: BoundInstance) -> None: ... + + @abstractmethod + def load( + self, + bound_instance: BoundInstance, + ) -> Generator[ModelLoadingResponse]: ... + + @abstractmethod + def build(self) -> Engine: ... + + @abstractmethod + def close(self) -> None: ... diff --git a/src/exo/worker/engines/image/__init__.py b/src/exo/worker/engines/image/__init__.py index c83b4702f4..d0ccbac89a 100644 --- a/src/exo/worker/engines/image/__init__.py +++ b/src/exo/worker/engines/image/__init__.py @@ -1,12 +1,16 @@ +from exo.worker.engines.image.builder import ( + ImageEngine, + MfluxBuilder, +) from exo.worker.engines.image.distributed_model import ( DistributedImageModel, - initialize_image_model, ) from exo.worker.engines.image.generate import generate_image, warmup_image_generator __all__ = [ + "MfluxBuilder", + "ImageEngine", "DistributedImageModel", "generate_image", - "initialize_image_model", "warmup_image_generator", ] diff --git a/src/exo/worker/engines/image/builder.py b/src/exo/worker/engines/image/builder.py new file mode 100644 index 0000000000..81f4df9b43 --- /dev/null +++ b/src/exo/worker/engines/image/builder.py @@ -0,0 +1,212 @@ +import contextlib +from collections import deque +from collections.abc import Generator, Iterable +from dataclasses import dataclass, field + +import mlx.core as mx +from loguru import logger + +from exo.api.types import ImageEditsTaskParams, ImageGenerationTaskParams +from exo.shared.constants import EXO_TRACING_ENABLED +from exo.shared.tracing import clear_trace_buffer, get_trace_buffer +from exo.shared.types.chunks import Chunk, ErrorChunk +from exo.shared.types.events import ( + Event, + TraceEventData, + TracesCollected, +) +from exo.shared.types.tasks import ( + GenerationTask, + ImageEdits, + ImageGeneration, + ImageTask, + TaskId, +) +from exo.shared.types.worker.instances import BoundInstance +from exo.shared.types.worker.runner_response import ( + CancelledResponse, + FinishedResponse, + ModelLoadingResponse, +) +from exo.shared.types.worker.shards import ( + CfgShardMetadata, + PipelineShardMetadata, + ShardMetadata, +) +from exo.utils.channels import MpReceiver, MpSender +from exo.worker.engines.base import Builder, Engine +from exo.worker.engines.image.distributed_model import ( + DistributedImageModel, +) +from exo.worker.engines.image.generate import ( + generate_image, + warmup_image_generator, +) +from exo.worker.engines.mlx.utils_mlx import ( + initialize_mlx, +) + + +def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool: + """Check if this node is the primary output node for image generation. + + For CFG models: the last pipeline stage in CFG group 0 (positive prompt). + For non-CFG models: the last pipeline stage. 
+ """ + if isinstance(shard_metadata, CfgShardMetadata): + is_pipeline_last = ( + shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1 + ) + return is_pipeline_last and shard_metadata.cfg_rank == 0 + elif isinstance(shard_metadata, PipelineShardMetadata): + return shard_metadata.device_rank == shard_metadata.world_size - 1 + return False + + +def _send_traces_if_enabled( + event_sender: MpSender[Event], + task_id: TaskId, + rank: int, +) -> None: + if not EXO_TRACING_ENABLED: + return + + traces = get_trace_buffer() + if traces: + trace_data = [ + TraceEventData( + name=t.name, + start_us=t.start_us, + duration_us=t.duration_us, + rank=t.rank, + category=t.category, + ) + for t in traces + ] + event_sender.send( + TracesCollected( + task_id=task_id, + rank=rank, + traces=trace_data, + ) + ) + clear_trace_buffer() + + +@dataclass +class MfluxBuilder(Builder): + event_sender: MpSender[Event] + cancel_receiver: MpReceiver[TaskId] + shard_metadata: ShardMetadata | None = None + image_model: DistributedImageModel | None = None + group: mx.distributed.Group | None = None + + def connect(self, bound_instance: BoundInstance) -> None: + self.group = initialize_mlx(bound_instance) + + def load(self, bound_instance: BoundInstance) -> Generator[ModelLoadingResponse]: + self.shard_metadata = bound_instance.bound_shard + self.image_model = DistributedImageModel.from_shard_metadata( + bound_instance.bound_shard, self.group + ) + return + # very important! + yield + + def close(self) -> None: + with contextlib.suppress(NameError, AttributeError): + del self.image_model, self.group + + def build( + self, + ) -> Engine: + assert self.image_model + assert self.shard_metadata + + return ImageEngine( + self.image_model, + self.shard_metadata, + self.event_sender, + self.cancel_receiver, + ) + + +@dataclass +class ImageEngine(Engine): + image_model: DistributedImageModel + shard_metadata: ShardMetadata + event_sender: MpSender[Event] + cancel_receiver: MpReceiver[TaskId] + current_gen: Generator[tuple[TaskId, Chunk]] | None = field( + init=False, default=None + ) + queue: deque[ImageTask] = field(init=False, default_factory=deque) + + def warmup(self) -> None: + image = warmup_image_generator(model=self.image_model) + if image is not None: + logger.info(f"warmed up by generating {image.size} image") + else: + logger.info("warmup completed (non-primary node)") + + def submit( + self, + task: GenerationTask, + ) -> None: + assert isinstance(task, (ImageGeneration, ImageEdits)) + self.queue.append(task) + + def step( + self, + ) -> Iterable[tuple[TaskId, Chunk | CancelledResponse | FinishedResponse]]: + resp = None + if self.current_gen is not None: + resp = next(self.current_gen, None) + if resp is None and len(self.queue) > 0: + task = self.queue.popleft() + self.current_gen = self._run_image_task(task.task_id, task.task_params) + resp = next(self.current_gen, None) + return (resp,) if resp is not None else () + + def close(self) -> None: + with contextlib.suppress(NameError, AttributeError): + del self.image_model + + def _run_image_task( + self, + task_id: TaskId, + task_params: ImageGenerationTaskParams | ImageEditsTaskParams, + ) -> Generator[tuple[TaskId, Chunk]]: + assert self.image_model + logger.info(f"received image task: {str(task_params)[:500]}") + + def cancel_checker() -> bool: + for cancel_id in self.cancel_receiver.collect(): + self._cancelled_tasks.add(cancel_id) + return self.should_cancel(task_id) + + try: + for response in generate_image( + model=self.image_model, + 
task=task_params, + cancel_checker=cancel_checker, + ): + if _is_primary_output_node(self.shard_metadata): + yield (task_id, response) + except Exception as e: + if _is_primary_output_node(self.shard_metadata): + yield ( + task_id, + ErrorChunk( + model=self.shard_metadata.model_card.model_id, + finish_reason="error", + error_message=str(e), + ), + ) + raise + finally: + _send_traces_if_enabled( + self.event_sender, task_id, self.shard_metadata.device_rank + ) + + return diff --git a/src/exo/worker/engines/image/distributed_model.py b/src/exo/worker/engines/image/distributed_model.py index 4c9e7406ec..4fc375cac9 100644 --- a/src/exo/worker/engines/image/distributed_model.py +++ b/src/exo/worker/engines/image/distributed_model.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Generator from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal import mlx.core as mx from mflux.models.common.config.config import Config @@ -9,8 +9,11 @@ from exo.api.types import AdvancedImageParams from exo.download.download_utils import build_model_path from exo.shared.types.common import ModelId -from exo.shared.types.worker.instances import BoundInstance -from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata +from exo.shared.types.worker.shards import ( + CfgShardMetadata, + PipelineShardMetadata, + ShardMetadata, +) from exo.worker.engines.image.config import ImageModelConfig from exo.worker.engines.image.models import ( create_adapter_for_model, @@ -18,7 +21,7 @@ ) from exo.worker.engines.image.models.base import ModelAdapter from exo.worker.engines.image.pipeline import DiffusionRunner -from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier +from exo.worker.engines.mlx.utils_mlx import mx_barrier from exo.worker.runner.bootstrap import logger @@ -33,7 +36,7 @@ def __init__( model_id: ModelId, local_path: Path, shard_metadata: PipelineShardMetadata | CfgShardMetadata, - group: Optional[mx.distributed.Group] = None, + group: mx.distributed.Group | None, quantize: int | None = None, ): config = get_config_for_model(model_id) @@ -76,32 +79,21 @@ def __init__( self._runner = runner @classmethod - def from_bound_instance( - cls, bound_instance: BoundInstance + def from_shard_metadata( + cls, shard: ShardMetadata, group: mx.distributed.Group | None ) -> "DistributedImageModel": - model_id = bound_instance.bound_shard.model_card.model_id + model_id = shard.model_card.model_id model_path = build_model_path(model_id) - shard_metadata = bound_instance.bound_shard - if not isinstance(shard_metadata, (PipelineShardMetadata, CfgShardMetadata)): + if not isinstance(shard, (PipelineShardMetadata, CfgShardMetadata)): raise ValueError( "Expected PipelineShardMetadata or CfgShardMetadata for image generation" ) - is_distributed = ( - len(bound_instance.instance.shard_assignments.node_to_runner) > 1 - ) - - if is_distributed: - logger.info("Starting distributed init for image model") - group = mlx_distributed_init(bound_instance) - else: - group = None - return cls( model_id=model_id, local_path=model_path, - shard_metadata=shard_metadata, + shard_metadata=shard, group=group, ) @@ -176,7 +168,3 @@ def generate( else: logger.info("generated image") yield result - - -def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel: - return DistributedImageModel.from_bound_instance(bound_instance) diff --git a/src/exo/worker/engines/mlx/builder.py b/src/exo/worker/engines/mlx/builder.py new file mode 100644 
index 0000000000..ca095ad1ec --- /dev/null +++ b/src/exo/worker/engines/mlx/builder.py @@ -0,0 +1,108 @@ +import contextlib +import os +from collections.abc import Generator +from dataclasses import dataclass + +import mlx.core as mx +from mlx_lm.tokenizer_utils import TokenizerWrapper + +from exo.shared.types.common import ModelId +from exo.shared.types.events import Event +from exo.shared.types.mlx import Model +from exo.shared.types.tasks import TaskId +from exo.shared.types.worker.instances import BoundInstance +from exo.shared.types.worker.runner_response import ModelLoadingResponse +from exo.utils.channels import MpReceiver, MpSender +from exo.worker.engines.base import Builder, Engine +from exo.worker.engines.mlx.cache import KVPrefixCache +from exo.worker.engines.mlx.utils_mlx import ( + initialize_mlx, + load_mlx_items, +) +from exo.worker.engines.mlx.vision import VisionProcessor +from exo.worker.runner.bootstrap import logger +from exo.worker.runner.llm_inference.batch_generator import ( + BatchGenerator, + SequentialGenerator, +) +from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser + + +@dataclass +class MlxBuilder(Builder): + model_id: ModelId + event_sender: MpSender[Event] + cancel_receiver: MpReceiver[TaskId] + inference_model: Model | None = None + tokenizer: TokenizerWrapper | None = None + group: mx.distributed.Group | None = None + vision_processor: VisionProcessor | None = None + + def connect(self, bound_instance: BoundInstance) -> None: + self.group = initialize_mlx(bound_instance) + + def load(self, bound_instance: BoundInstance) -> Generator[ModelLoadingResponse]: + ( + self.inference_model, + self.tokenizer, + self.vision_processor, + ) = yield from load_mlx_items(bound_instance, self.group) + + def close(self) -> None: + with contextlib.suppress(NameError, AttributeError): + del self.inference_model, self.tokenizer, self.group + + def build( + self, + ) -> Engine: + assert self.inference_model + assert self.tokenizer + + vision_processor = self.vision_processor + + tool_parser = None + logger.info( + f"model has_tool_calling={self.tokenizer.has_tool_calling} using tokens {self.tokenizer.tool_call_start}, {self.tokenizer.tool_call_end}" + ) + if ( + self.tokenizer.tool_call_start + and self.tokenizer.tool_call_end + and self.tokenizer.tool_parser # type: ignore + ): + tool_parser = make_mlx_parser( + self.tokenizer.tool_call_start, + self.tokenizer.tool_call_end, + self.tokenizer.tool_parser, # type: ignore + ) + + kv_prefix_cache = KVPrefixCache(self.group) + + device_rank = 0 if self.group is None else self.group.rank() + if os.environ.get("EXO_NO_BATCH"): + logger.info("using SequentialGenerator (batching disabled)") + return SequentialGenerator( + model=self.inference_model, + tokenizer=self.tokenizer, + group=self.group, + tool_parser=tool_parser, + kv_prefix_cache=kv_prefix_cache, + model_id=self.model_id, + device_rank=device_rank, + cancel_receiver=self.cancel_receiver, + event_sender=self.event_sender, + vision_processor=vision_processor, + ) + else: + logger.info("using BatchGenerator") + return BatchGenerator( + model=self.inference_model, + tokenizer=self.tokenizer, + group=self.group, + tool_parser=tool_parser, + kv_prefix_cache=kv_prefix_cache, + model_id=self.model_id, + device_rank=device_rank, + cancel_receiver=self.cancel_receiver, + event_sender=self.event_sender, + vision_processor=vision_processor, + ) diff --git a/src/exo/worker/engines/mlx/generator/batch_generate.py 
b/src/exo/worker/engines/mlx/generator/batch_generate.py index b0b2a53793..08e6399a10 100644 --- a/src/exo/worker/engines/mlx/generator/batch_generate.py +++ b/src/exo/worker/engines/mlx/generator/batch_generate.py @@ -457,6 +457,7 @@ def cancel(self, uids: list[int]) -> None: def close(self) -> None: self._mlx_gen.close() + mx.clear_cache() def _save_prefix_cache( self, diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index c676c85cf6..b35f946aac 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -57,7 +57,7 @@ from exo.utils.keyed_backoff import KeyedBackoff from exo.utils.task_group import TaskGroup from exo.worker.plan import plan -from exo.worker.runner.runner_supervisor import RunnerSupervisor +from exo.worker.runner.supervisor import RunnerSupervisor class Worker: diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py index 1ca2cfaf5d..3824e4bb7a 100644 --- a/src/exo/worker/plan.py +++ b/src/exo/worker/plan.py @@ -41,7 +41,7 @@ RunnerWarmingUp, ) from exo.utils.keyed_backoff import KeyedBackoff -from exo.worker.runner.runner_supervisor import RunnerSupervisor +from exo.worker.runner.supervisor import RunnerSupervisor def plan( diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index cc382d9202..e904118022 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -8,6 +8,7 @@ from exo.shared.types.worker.instances import BoundInstance from exo.shared.types.worker.runners import RunnerFailed from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender +from exo.worker.engines.base import Builder logger: "loguru.Logger" = loguru.logger @@ -35,23 +36,32 @@ def entrypoint( # Import main after setting global logger - this lets us just import logger from this module try: + from exo.worker.runner.runner import Runner + + builder: Builder + if bound_instance.is_image_model: - from exo.worker.runner.image_models.runner import Runner as ImageRunner + from exo.worker.engines.image.builder import MfluxBuilder - runner = ImageRunner( - bound_instance, event_sender, task_receiver, cancel_receiver + builder = MfluxBuilder( + event_sender, cancel_receiver, bound_instance.bound_shard ) - runner.main() else: from exo.worker.engines.mlx.patches import apply_mlx_patches - from exo.worker.runner.llm_inference.runner import Runner apply_mlx_patches() - runner = Runner( - bound_instance, event_sender, task_receiver, cancel_receiver + from exo.worker.engines.mlx.builder import MlxBuilder + + # evil sharing of the event sender + builder = MlxBuilder( + model_id=bound_instance.bound_shard.model_card.model_id, + event_sender=event_sender, + cancel_receiver=cancel_receiver, ) - runner.main() + + runner = Runner(bound_instance, builder, event_sender, task_receiver) + runner.main() except ClosedResourceError: logger.warning("Runner communication closed unexpectedly") diff --git a/src/exo/worker/runner/image_models/runner.py b/src/exo/worker/runner/image_models/runner.py deleted file mode 100644 index e743c4ce4b..0000000000 --- a/src/exo/worker/runner/image_models/runner.py +++ /dev/null @@ -1,315 +0,0 @@ -import time -from typing import TYPE_CHECKING - -import mlx.core as mx - -from exo.api.types import ( - ImageEditsTaskParams, - ImageGenerationTaskParams, -) -from exo.shared.constants import EXO_TRACING_ENABLED -from exo.shared.models.model_cards import ModelTask -from exo.shared.tracing import clear_trace_buffer, get_trace_buffer -from exo.shared.types.chunks import ErrorChunk -from 
exo.shared.types.common import CommandId -from exo.shared.types.events import ( - ChunkGenerated, - Event, - RunnerStatusUpdated, - TaskAcknowledged, - TaskStatusUpdated, - TraceEventData, - TracesCollected, -) -from exo.shared.types.tasks import ( - CANCEL_ALL_TASKS, - ConnectToGroup, - ImageEdits, - ImageGeneration, - LoadModel, - Shutdown, - StartWarmup, - Task, - TaskId, - TaskStatus, -) -from exo.shared.types.worker.instances import BoundInstance -from exo.shared.types.worker.runners import ( - RunnerConnected, - RunnerConnecting, - RunnerIdle, - RunnerLoaded, - RunnerLoading, - RunnerReady, - RunnerRunning, - RunnerShutdown, - RunnerShuttingDown, - RunnerStatus, - RunnerWarmingUp, -) -from exo.shared.types.worker.shards import ( - CfgShardMetadata, - PipelineShardMetadata, - ShardMetadata, -) -from exo.utils.channels import MpReceiver, MpSender -from exo.worker.engines.image import ( - DistributedImageModel, - generate_image, - initialize_image_model, - warmup_image_generator, -) -from exo.worker.engines.mlx.utils_mlx import ( - initialize_mlx, -) -from exo.worker.runner.bootstrap import logger - - -def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool: - """Check if this node is the primary output node for image generation. - - For CFG models: the last pipeline stage in CFG group 0 (positive prompt). - For non-CFG models: the last pipeline stage. - """ - if isinstance(shard_metadata, CfgShardMetadata): - is_pipeline_last = ( - shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1 - ) - return is_pipeline_last and shard_metadata.cfg_rank == 0 - elif isinstance(shard_metadata, PipelineShardMetadata): - return shard_metadata.device_rank == shard_metadata.world_size - 1 - return False - - -def _send_traces_if_enabled( - event_sender: MpSender[Event], - task_id: TaskId, - rank: int, -) -> None: - if not EXO_TRACING_ENABLED: - return - - traces = get_trace_buffer() - if traces: - trace_data = [ - TraceEventData( - name=t.name, - start_us=t.start_us, - duration_us=t.duration_us, - rank=t.rank, - category=t.category, - ) - for t in traces - ] - event_sender.send( - TracesCollected( - task_id=task_id, - rank=rank, - traces=trace_data, - ) - ) - clear_trace_buffer() - - -class Runner: - def __init__( - self, - bound_instance: BoundInstance, - event_sender: MpSender[Event], - task_receiver: MpReceiver[Task], - cancel_receiver: MpReceiver[TaskId], - ): - self.event_sender = event_sender - self.task_receiver = task_receiver - self.cancel_receiver = cancel_receiver - self.bound_instance = bound_instance - - self.instance, self.runner_id, self.shard_metadata = ( - bound_instance.instance, - bound_instance.bound_runner_id, - bound_instance.bound_shard, - ) - self.device_rank = self.shard_metadata.device_rank - - logger.info("hello from the runner") - if getattr(self.shard_metadata, "immediate_exception", False): - raise Exception("Fake exception - runner failed to spin up.") - if timeout := getattr(self.shard_metadata, "should_timeout", 0): - time.sleep(timeout) - - self.setup_start_time = time.time() - self.cancelled_tasks = set[TaskId]() - - self.image_model: DistributedImageModel | None = None - self.group = None - - self.current_status: RunnerStatus = RunnerIdle() - logger.info("runner created") - self.update_status(RunnerIdle()) - self.seen = set[TaskId]() - - def update_status(self, status: RunnerStatus): - self.current_status = status - self.event_sender.send( - RunnerStatusUpdated( - runner_id=self.runner_id, runner_status=self.current_status - ) - ) - - def 
send_task_status(self, task: Task, status: TaskStatus): - self.event_sender.send( - TaskStatusUpdated(task_id=task.task_id, task_status=status) - ) - - def acknowledge_task(self, task: Task): - self.event_sender.send(TaskAcknowledged(task_id=task.task_id)) - - def _check_cancelled(self, task_id: TaskId) -> bool: - for cancel_id in self.cancel_receiver.collect(): - self.cancelled_tasks.add(cancel_id) - return ( - task_id in self.cancelled_tasks or CANCEL_ALL_TASKS in self.cancelled_tasks - ) - - def _run_image_task( - self, - task: Task, - task_params: ImageGenerationTaskParams | ImageEditsTaskParams, - command_id: CommandId, - ) -> None: - assert self.image_model - logger.info(f"received image task: {str(task)[:500]}") - logger.info("runner running") - self.update_status(RunnerRunning()) - self.acknowledge_task(task) - - def cancel_checker() -> bool: - return self._check_cancelled(task.task_id) - - try: - for chunk in generate_image( - model=self.image_model, - task=task_params, - cancel_checker=cancel_checker, - ): - if _is_primary_output_node(self.shard_metadata): - if chunk.is_partial: - logger.info( - f"sending partial ImageChunk {chunk.partial_index}/{chunk.total_partials}" - ) - else: - logger.info("sending final ImageChunk") - self.event_sender.send( - ChunkGenerated(command_id=command_id, chunk=chunk) - ) - except Exception as e: - if _is_primary_output_node(self.shard_metadata): - self.event_sender.send( - ChunkGenerated( - command_id=command_id, - chunk=ErrorChunk( - model=self.shard_metadata.model_card.model_id, - finish_reason="error", - error_message=str(e), - ), - ) - ) - raise - finally: - _send_traces_if_enabled(self.event_sender, task.task_id, self.device_rank) - - self.current_status = RunnerReady() - logger.info("runner ready") - - def main(self): - with self.task_receiver as tasks: - for task in tasks: - if task.task_id in self.seen: - logger.warning("repeat task - potential error") - self.seen.add(task.task_id) - self.cancelled_tasks.discard(CANCEL_ALL_TASKS) - self.send_task_status(task, TaskStatus.Running) - self.handle_task(task) - was_cancelled = (task.task_id in self.cancelled_tasks) or ( - CANCEL_ALL_TASKS in self.cancelled_tasks - ) - if not was_cancelled: - self.send_task_status(task, TaskStatus.Complete) - self.update_status(self.current_status) - - if isinstance(self.current_status, RunnerShutdown): - break - - def handle_task(self, task: Task): - match task: - case ConnectToGroup() if isinstance(self.current_status, RunnerIdle): - logger.info("runner connecting") - self.update_status(RunnerConnecting()) - self.acknowledge_task(task) - self.group = initialize_mlx(self.bound_instance) - - logger.info("runner connected") - self.current_status = RunnerConnected() - - # we load the model if it's connected with a group, or idle without a group. 
we should never tell a model to connect if it doesn't need to - case LoadModel() if ( - isinstance(self.current_status, RunnerConnected) - and self.group is not None - ) or (isinstance(self.current_status, RunnerIdle) and self.group is None): - logger.info("runner loading") - self.update_status(RunnerLoading()) - self.acknowledge_task(task) - - assert ( - ModelTask.TextToImage in self.shard_metadata.model_card.tasks - or ModelTask.ImageToImage in self.shard_metadata.model_card.tasks - ), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}" - - self.image_model = initialize_image_model(self.bound_instance) - self.current_status = RunnerLoaded() - logger.info("runner loaded") - - case StartWarmup() if isinstance(self.current_status, RunnerLoaded): - logger.info("runner warming up") - self.update_status(RunnerWarmingUp()) - self.acknowledge_task(task) - - logger.info(f"warming up inference for instance: {self.instance}") - - assert self.image_model - image = warmup_image_generator(model=self.image_model) - if image is not None: - logger.info(f"warmed up by generating {image.size} image") - else: - logger.info("warmup completed (non-primary node)") - - logger.info( - f"runner initialized in {time.time() - self.setup_start_time} seconds" - ) - - self.current_status = RunnerReady() - logger.info("runner ready") - - case ( - ImageGeneration(task_params=task_params, command_id=command_id) - | ImageEdits(task_params=task_params, command_id=command_id) - ) if isinstance(self.current_status, RunnerReady): - self._run_image_task(task, task_params, command_id) - - case Shutdown(): - logger.info("runner shutting down") - if not TYPE_CHECKING: - del self.image_model, self.group - mx.clear_cache() - import gc - - gc.collect() - - self.update_status(RunnerShuttingDown()) - self.acknowledge_task(task) - - self.current_status = RunnerShutdown() - case _: - raise ValueError( - f"Received {task.__class__.__name__} outside of state machine in {self.current_status=}" - ) diff --git a/src/exo/worker/runner/llm_inference/batch_generator.py b/src/exo/worker/runner/llm_inference/batch_generator.py index ee536a9939..795ed68dcb 100644 --- a/src/exo/worker/runner/llm_inference/batch_generator.py +++ b/src/exo/worker/runner/llm_inference/batch_generator.py @@ -1,6 +1,5 @@ import itertools import time -from abc import ABC, abstractmethod from collections import deque from collections.abc import Generator, Iterator from dataclasses import dataclass, field @@ -13,10 +12,20 @@ from exo.shared.types.common import ModelId from exo.shared.types.events import ChunkGenerated, Event from exo.shared.types.mlx import Model -from exo.shared.types.tasks import CANCEL_ALL_TASKS, TaskId, TextGeneration +from exo.shared.types.tasks import ( + CANCEL_ALL_TASKS, + GenerationTask, + TaskId, + TextGeneration, +) from exo.shared.types.text_generation import TextGenerationTaskParams -from exo.shared.types.worker.runner_response import GenerationResponse +from exo.shared.types.worker.runner_response import ( + CancelledResponse, + FinishedResponse, + GenerationResponse, +) from exo.utils.channels import MpReceiver, MpSender +from exo.worker.engines.base import Engine from exo.worker.engines.mlx.cache import KVPrefixCache from exo.worker.engines.mlx.generator.batch_generate import ExoBatchGenerator from exo.worker.engines.mlx.generator.generate import ( @@ -36,14 +45,6 @@ from .tool_parsers import ToolParser -class Cancelled: - pass - - -class Finished: - pass - - class GeneratorQueue[T]: def __init__(self): self._q = deque[T]() 
@@ -59,33 +60,6 @@ def gen(self) -> Generator[T | None]: yield self._q.popleft() -class InferenceGenerator(ABC): - _cancelled_tasks: set[TaskId] - - def should_cancel(self, task_id: TaskId) -> bool: - return ( - task_id in self._cancelled_tasks - or CANCEL_ALL_TASKS in self._cancelled_tasks - ) - - @abstractmethod - def warmup(self) -> None: ... - - @abstractmethod - def submit( - self, - task: TextGeneration, - ) -> None: ... - - @abstractmethod - def step( - self, - ) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]: ... - - @abstractmethod - def close(self) -> None: ... - - EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL" EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM" EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT" @@ -109,7 +83,7 @@ def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None: @dataclass(eq=False) -class SequentialGenerator(InferenceGenerator): +class SequentialGenerator(Engine): model: Model tokenizer: TokenizerWrapper group: mx.distributed.Group | None @@ -150,8 +124,9 @@ def warmup(self): def submit( self, - task: TextGeneration, + task: GenerationTask, ) -> None: + assert isinstance(task, TextGeneration) self._cancelled_tasks.discard(CANCEL_ALL_TASKS) self._all_tasks[task.task_id] = task self._maybe_queue.append(task) @@ -181,28 +156,34 @@ def agree_on_cancellations(self) -> None: def step( self, - ) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]: + ) -> Iterator[ + tuple[TaskId, GenerationChunk | FinishedResponse | CancelledResponse] + ]: if self._active is None: self.agree_on_tasks() if self._queue: self._start_next() else: - return map(lambda task: (task, Cancelled()), self._cancelled_tasks) + return map( + lambda task: (task, CancelledResponse()), self._cancelled_tasks + ) assert self._active is not None - task, mlx_gen, queue, output_generator = self._active - output: list[tuple[TaskId, GenerationChunk | Cancelled | Finished]] = [] + task, gen, queue, output_generator = self._active + output: list[ + tuple[TaskId, GenerationChunk | CancelledResponse | FinishedResponse] + ] = [] try: - response = next(mlx_gen) + response = next(gen) queue.push(response) # drain potentially many responses every time while (parsed := next(output_generator, None)) is not None: output.append((task.task_id, parsed)) except (StopIteration, PrefillCancelled): - output.append((task.task_id, Finished())) + output.append((task.task_id, FinishedResponse())) self._active = None if self._queue: self._start_next() @@ -214,13 +195,13 @@ def step( return itertools.chain( output, - map(lambda task: (task, Cancelled()), self._cancelled_tasks), + map(lambda task: (task, CancelledResponse()), self._cancelled_tasks), ) def _start_next(self) -> None: task = self._queue.popleft() try: - mlx_gen = self._build_generator(task) + gen = self._build_generator(task) except Exception as e: self._send_error(task, e) raise @@ -240,7 +221,7 @@ def _start_next(self) -> None: self.model_id, task.task_params.tools, ) - self._active = (task, mlx_gen, queue, output_generator) + self._active = (task, gen, queue, output_generator) def _send_error(self, task: TextGeneration, e: Exception) -> None: if self.device_rank == 0: @@ -310,7 +291,7 @@ def close(self) -> None: @dataclass(eq=False) -class BatchGenerator(InferenceGenerator): +class BatchGenerator(Engine): model: Model tokenizer: TokenizerWrapper group: mx.distributed.Group | None @@ -328,7 +309,7 @@ class BatchGenerator(InferenceGenerator): _maybe_cancel: list[TextGeneration] = field(default_factory=list, 
init=False) _all_tasks: dict[TaskId, TextGeneration] = field(default_factory=dict, init=False) _queue: deque[TextGeneration] = field(default_factory=deque, init=False) - _mlx_gen: ExoBatchGenerator = field(init=False) + _gen: ExoBatchGenerator = field(init=False) _active_tasks: dict[ int, tuple[ @@ -339,7 +320,7 @@ class BatchGenerator(InferenceGenerator): ] = field(default_factory=dict, init=False) def __post_init__(self) -> None: - self._mlx_gen = ExoBatchGenerator( + self._gen = ExoBatchGenerator( model=self.model, tokenizer=self.tokenizer, group=self.group, @@ -357,8 +338,9 @@ def warmup(self): def submit( self, - task: TextGeneration, + task: GenerationTask, ) -> None: + assert isinstance(task, TextGeneration) self._cancelled_tasks.discard(CANCEL_ALL_TASKS) self._all_tasks[task.task_id] = task self._maybe_queue.append(task) @@ -388,7 +370,9 @@ def agree_on_cancellations(self) -> None: def step( self, - ) -> Iterator[tuple[TaskId, GenerationChunk | Cancelled | Finished]]: + ) -> Iterator[ + tuple[TaskId, GenerationChunk | CancelledResponse | FinishedResponse] + ]: if not self._queue: self.agree_on_tasks() @@ -420,12 +404,14 @@ def step( ) self._active_tasks[uid] = (task, queue, output_generator) - if not self._mlx_gen.has_work: + if not self._gen.has_work: return self._apply_cancellations() - results = self._mlx_gen.step() + results = self._gen.step() - output: list[tuple[TaskId, GenerationChunk | Cancelled | Finished]] = [] + output: list[ + tuple[TaskId, GenerationChunk | CancelledResponse | FinishedResponse] + ] = [] for uid, response in results: if uid not in self._active_tasks: # should we error here? @@ -440,35 +426,35 @@ def step( # check if original response was terminal and append a Finished() if response.finish_reason is not None: - output.append((task.task_id, Finished())) + output.append((task.task_id, FinishedResponse())) del self._active_tasks[uid] return itertools.chain(output, self._apply_cancellations()) def _apply_cancellations( self, - ) -> Iterator[tuple[TaskId, Cancelled]]: + ) -> Iterator[tuple[TaskId, CancelledResponse]]: if not self._cancelled_tasks: return iter([]) cancel_all = CANCEL_ALL_TASKS in self._cancelled_tasks uids_to_cancel: list[int] = [] - results: list[tuple[TaskId, Cancelled]] = [] + results: list[tuple[TaskId, CancelledResponse]] = [] for uid, (task, _, _) in list(self._active_tasks.items()): if task.task_id in self._cancelled_tasks or cancel_all: uids_to_cancel.append(uid) - results.append((task.task_id, Cancelled())) + results.append((task.task_id, CancelledResponse())) del self._active_tasks[uid] if uids_to_cancel: - self._mlx_gen.cancel(uids_to_cancel) + self._gen.cancel(uids_to_cancel) already_cancelled = {tid for tid, _ in results} for tid in self._cancelled_tasks: if tid != CANCEL_ALL_TASKS and tid not in already_cancelled: - results.append((tid, Cancelled())) + results.append((tid, CancelledResponse())) self._cancelled_tasks.clear() return iter(results) @@ -523,7 +509,7 @@ def on_generation_token() -> None: self.agree_on_tasks() - return self._mlx_gen.submit( + return self._gen.submit( task_params=task.task_params, prompt=prompt, on_prefill_progress=on_prefill_progress, @@ -532,5 +518,5 @@ def on_generation_token() -> None: ) def close(self) -> None: - self._mlx_gen.close() + self._gen.close() del self.model, self.tokenizer, self.group diff --git a/src/exo/worker/runner/llm_inference/runner.py b/src/exo/worker/runner/runner.py similarity index 58% rename from src/exo/worker/runner/llm_inference/runner.py rename to 
src/exo/worker/runner/runner.py index 9cc2506a22..cee1b655bc 100644 --- a/src/exo/worker/runner/llm_inference/runner.py +++ b/src/exo/worker/runner/runner.py @@ -1,16 +1,10 @@ -import os import time -from collections.abc import Generator -from dataclasses import dataclass from enum import Enum -import mlx.core as mx from anyio import WouldBlock -from mlx_lm.tokenizer_utils import TokenizerWrapper -from exo.shared.models.model_cards import ModelTask -from exo.shared.types.chunks import GenerationChunk -from exo.shared.types.common import CommandId, ModelId +from exo.shared.types.chunks import Chunk +from exo.shared.types.common import CommandId from exo.shared.types.events import ( ChunkGenerated, Event, @@ -18,9 +12,11 @@ TaskAcknowledged, TaskStatusUpdated, ) -from exo.shared.types.mlx import Model from exo.shared.types.tasks import ( ConnectToGroup, + GenerationTask, + ImageEdits, + ImageGeneration, LoadModel, Shutdown, StartWarmup, @@ -31,7 +27,8 @@ ) from exo.shared.types.worker.instances import BoundInstance from exo.shared.types.worker.runner_response import ( - ModelLoadingResponse, + CancelledResponse, + FinishedResponse, ) from exo.shared.types.worker.runners import ( RunnerConnected, @@ -47,21 +44,9 @@ RunnerWarmingUp, ) from exo.utils.channels import MpReceiver, MpSender -from exo.worker.engines.mlx.cache import KVPrefixCache -from exo.worker.engines.mlx.utils_mlx import ( - initialize_mlx, - load_mlx_items, -) -from exo.worker.engines.mlx.vision import VisionProcessor -from exo.worker.runner.bootstrap import logger -from exo.worker.runner.llm_inference.batch_generator import ( - BatchGenerator, - InferenceGenerator, - SequentialGenerator, -) +from exo.worker.engines.base import Builder, Engine -from .batch_generator import Cancelled, Finished -from .tool_parsers import make_mlx_parser +from .bootstrap import logger class ExitCode(str, Enum): @@ -73,13 +58,12 @@ class Runner: def __init__( self, bound_instance: BoundInstance, + builder: Builder, event_sender: MpSender[Event], task_receiver: MpReceiver[Task], - cancel_receiver: MpReceiver[TaskId], ): self.event_sender = event_sender self.task_receiver = task_receiver - self.cancel_receiver = cancel_receiver self.bound_instance = bound_instance self.instance, self.runner_id, self.shard_metadata = ( @@ -98,16 +82,12 @@ def __init__( self.setup_start_time = time.time() - self.generator: Builder | InferenceGenerator = Builder( - self.model_id, - self.event_sender, - self.cancel_receiver, - ) + self.generator: Builder | Engine = builder self.seen: set[TaskId] = set() self.active_tasks: dict[ TaskId, - TextGeneration, + GenerationTask, ] = {} logger.info("runner created") @@ -150,7 +130,7 @@ def handle_first_task(self, task: Task): self.update_status(RunnerConnecting()) self.acknowledge_task(task) - self.generator.group = initialize_mlx(self.bound_instance) + self.generator.connect(self.bound_instance) self.send_task_status(task.task_id, TaskStatus.Complete) self.update_status(RunnerConnected()) @@ -158,14 +138,7 @@ def handle_first_task(self, task: Task): # we load the model if it's connected with a group, or idle without a group. 
we should never tell a model to connect if it doesn't need to case LoadModel() if isinstance(self.generator, Builder) and ( - ( - isinstance(self.current_status, RunnerConnected) - and self.generator.group is not None - ) - or ( - isinstance(self.current_status, RunnerIdle) - and self.generator.group is None - ) + isinstance(self.current_status, (RunnerConnected, RunnerIdle)) ): total_layers = ( self.shard_metadata.end_layer - self.shard_metadata.start_layer @@ -177,26 +150,11 @@ def handle_first_task(self, task: Task): ) self.acknowledge_task(task) - assert ( - ModelTask.TextGeneration in self.shard_metadata.model_card.tasks - ), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}" - - def load_model() -> Generator[ModelLoadingResponse]: - assert isinstance(self.generator, Builder) - ( - self.generator.inference_model, - self.generator.tokenizer, - self.generator.vision_processor, - ) = yield from load_mlx_items( - self.bound_instance, - self.generator.group, - ) - - for load_resp in load_model(): + for load_progress in self.generator.load(self.bound_instance): self.update_status( RunnerLoading( - layers_loaded=load_resp.layers_loaded, - total_layers=load_resp.total, + layers_loaded=load_progress.layers_loaded, + total_layers=load_progress.total, ) ) @@ -207,7 +165,7 @@ def load_model() -> Generator[ModelLoadingResponse]: logger.info("runner loaded") case StartWarmup() if isinstance(self.current_status, RunnerLoaded): - assert isinstance(self.generator, InferenceGenerator) + assert isinstance(self.generator, Engine) logger.info("runner warming up") self.update_status(RunnerWarmingUp()) @@ -223,7 +181,9 @@ def load_model() -> Generator[ModelLoadingResponse]: self.update_status(RunnerReady()) logger.info("runner ready") - case TextGeneration() if isinstance(self.current_status, RunnerReady): + case TextGeneration() | ImageEdits() | ImageGeneration() if isinstance( + self.current_status, RunnerReady + ): return_code = self.handle_generation_tasks(starting_task=task) if return_code == ExitCode.Shutdown: return @@ -241,23 +201,21 @@ def shutdown(self, task: Task): logger.info("runner shutting down") self.update_status(RunnerShuttingDown()) self.acknowledge_task(task) - if isinstance(self.generator, InferenceGenerator): - self.generator.close() - mx.clear_cache() + self.generator.close() import gc gc.collect() self.send_task_status(task.task_id, TaskStatus.Complete) self.update_status(RunnerShutdown()) - def submit_text_generation(self, task: TextGeneration): - assert isinstance(self.generator, InferenceGenerator) + def submit_generation(self, task: GenerationTask): + assert isinstance(self.generator, Engine) self.active_tasks[task.task_id] = task self.generator.submit(task) - def handle_generation_tasks(self, starting_task: TextGeneration): + def handle_generation_tasks(self, starting_task: GenerationTask): assert isinstance(self.current_status, RunnerReady) - assert isinstance(self.generator, InferenceGenerator) + assert isinstance(self.generator, Engine) logger.info(f"received chat request: {starting_task}") self.update_status(RunnerRunning()) @@ -265,7 +223,7 @@ def handle_generation_tasks(self, starting_task: TextGeneration): self.acknowledge_task(starting_task) self.seen.add(starting_task.task_id) - self.submit_text_generation(starting_task) + self.submit_generation(starting_task) while self.active_tasks: results = self.generator.step() @@ -273,13 +231,13 @@ def handle_generation_tasks(self, starting_task: TextGeneration): finished: list[TaskId] = [] for task_id, result in 
results: match result: - case Cancelled(): + case CancelledResponse(): finished.append(task_id) - case Finished(): + case FinishedResponse(): self.send_task_status(task_id, TaskStatus.Complete) finished.append(task_id) - case _: - self.send_chunk(result, self.active_tasks[task_id].command_id) + case other: + self.send_chunk(other, self.active_tasks[task_id].command_id) for task_id in finished: self.active_tasks.pop(task_id, None) @@ -293,9 +251,9 @@ def handle_generation_tasks(self, starting_task: TextGeneration): self.seen.add(task.task_id) match task: - case TextGeneration(): + case TextGeneration() | ImageEdits() | ImageGeneration(): self.acknowledge_task(task) - self.submit_text_generation(task) + self.submit_generation(task) case Shutdown(): self.shutdown(task) return ExitCode.Shutdown @@ -314,74 +272,8 @@ def handle_generation_tasks(self, starting_task: TextGeneration): def send_chunk( self, - chunk: GenerationChunk, + chunk: Chunk, command_id: CommandId, ): if self.device_rank == 0: self.event_sender.send(ChunkGenerated(command_id=command_id, chunk=chunk)) - - -@dataclass -class Builder: - model_id: ModelId - event_sender: MpSender[Event] - cancel_receiver: MpReceiver[TaskId] - inference_model: Model | None = None - tokenizer: TokenizerWrapper | None = None - group: mx.distributed.Group | None = None - vision_processor: VisionProcessor | None = None - - def build( - self, - ) -> InferenceGenerator: - assert self.model_id - assert self.inference_model - assert self.tokenizer - - vision_processor = self.vision_processor - - tool_parser = None - logger.info( - f"model has_tool_calling={self.tokenizer.has_tool_calling} using tokens {self.tokenizer.tool_call_start}, {self.tokenizer.tool_call_end}" - ) - if ( - self.tokenizer.tool_call_start - and self.tokenizer.tool_call_end - and self.tokenizer.tool_parser # type: ignore - ): - tool_parser = make_mlx_parser( - self.tokenizer.tool_call_start, - self.tokenizer.tool_call_end, - self.tokenizer.tool_parser, # type: ignore - ) - - kv_prefix_cache = KVPrefixCache(self.group) - - device_rank = 0 if self.group is None else self.group.rank() - if os.environ.get("EXO_NO_BATCH"): - logger.info("using SequentialGenerator (batching disabled)") - return SequentialGenerator( - model=self.inference_model, - tokenizer=self.tokenizer, - group=self.group, - tool_parser=tool_parser, - kv_prefix_cache=kv_prefix_cache, - model_id=self.model_id, - device_rank=device_rank, - cancel_receiver=self.cancel_receiver, - event_sender=self.event_sender, - vision_processor=vision_processor, - ) - logger.info("using BatchGenerator") - return BatchGenerator( - model=self.inference_model, - tokenizer=self.tokenizer, - group=self.group, - tool_parser=tool_parser, - kv_prefix_cache=kv_prefix_cache, - model_id=self.model_id, - device_rank=device_rank, - cancel_receiver=self.cancel_receiver, - event_sender=self.event_sender, - vision_processor=vision_processor, - ) diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/supervisor.py similarity index 100% rename from src/exo/worker/runner/runner_supervisor.py rename to src/exo/worker/runner/supervisor.py diff --git a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py index a4133b1ca7..2fc4c752c9 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py +++ b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py @@ -1,15 +1,13 @@ # Check tasks are complete before runner is ever ready. 
-import unittest.mock from collections.abc import Iterable from dataclasses import dataclass from typing import Callable -import mlx.core as mx import pytest +import exo.worker.engines.mlx.builder as mlx_builder import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator import exo.worker.runner.llm_inference.model_output_parsers as mlx_model_output_parsers -import exo.worker.runner.llm_inference.runner as mlx_runner from exo.shared.types.chunks import TokenChunk from exo.shared.types.events import ( ChunkGenerated, @@ -47,6 +45,8 @@ RunnerWarmingUp, ) from exo.utils.channels import mp_channel +from exo.worker.engines.mlx.builder import MlxBuilder +from exo.worker.runner.runner import Runner from ...constants import ( CHAT_COMPLETION_TASK_ID, @@ -125,13 +125,13 @@ class MockLoadOutput: @pytest.fixture def patch_out_mlx(monkeypatch: pytest.MonkeyPatch): # initialize_mlx returns a mock group - monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup())) + monkeypatch.setattr(mlx_builder, "initialize_mlx", make_nothin(MockGroup())) def lmi_gen(): yield MockLoadOutput(1, 1) return (1, MockTokenizer, None) - monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin(lmi_gen())) + monkeypatch.setattr(mlx_builder, "load_mlx_items", make_nothin(lmi_gen())) monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1)) monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin) monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False)) @@ -274,17 +274,18 @@ def _on_event(event: Event) -> None: # this is some c++ nonsense task_receiver.close = nothin task_receiver.join = nothin - with unittest.mock.patch( - "exo.worker.runner.llm_inference.runner.mx.distributed.all_gather", - make_nothin(mx.array([1])), - ): - runner = mlx_runner.Runner( - bound_instance, - event_sender, # pyright: ignore[reportArgumentType] - task_receiver, - cancel_receiver, - ) - runner.main() + builder = MlxBuilder( + bound_instance.bound_shard.model_card.model_id, + event_sender, # pyright: ignore[reportArgumentType] + cancel_receiver, + ) + runner = Runner( + bound_instance, + builder, + event_sender, # pyright: ignore[reportArgumentType] + task_receiver, + ) + runner.main() return event_sender.events diff --git a/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py b/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py index 82612d6d67..39c991a193 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py +++ b/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py @@ -17,7 +17,7 @@ from exo.shared.types.worker.instances import BoundInstance, InstanceId from exo.shared.types.worker.runners import RunnerFailed, RunnerId from exo.utils.channels import channel, mp_channel -from exo.worker.runner.runner_supervisor import RunnerSupervisor +from exo.worker.runner.supervisor import RunnerSupervisor from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
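
Reviewer note: below is a minimal, self-contained sketch of the Builder/Engine lifecycle this PR introduces. `ToyBuilder`, `ToyEngine`, `LoadProgress`, and the string task ids are illustrative stand-ins rather than the real exo types; only the call order (`connect` -> `load` -> `build` -> `warmup` -> `submit`/`step` -> `close`) mirrors what the new `Runner` in `src/exo/worker/runner/runner.py` drives.

```python
from collections.abc import Generator, Iterable
from dataclasses import dataclass


@dataclass
class LoadProgress:
    layers_loaded: int
    total: int


class ToyEngine:
    def __init__(self) -> None:
        self._queue: list[str] = []

    def warmup(self) -> None:
        # real engines run one dummy generation here
        pass

    def submit(self, task_id: str) -> None:
        # enqueue only; step() does the actual work
        self._queue.append(task_id)

    def step(self) -> Iterable[tuple[str, str | None]]:
        # one unit of work; real engines yield Chunk, CancelledResponse or
        # FinishedResponse -- None stands in for the terminal marker here
        if self._queue:
            task_id = self._queue.pop(0)
            yield (task_id, "a-chunk")
            yield (task_id, None)

    def close(self) -> None:
        pass


class ToyBuilder:
    def connect(self, bound_instance: object) -> None:
        # real builders initialise the mx.distributed group here
        pass

    def load(self, bound_instance: object) -> Generator[LoadProgress, None, None]:
        # real builders stream loading progress while weights come up
        yield LoadProgress(layers_loaded=1, total=1)

    def build(self) -> ToyEngine:
        return ToyEngine()


builder = ToyBuilder()
builder.connect(None)                  # ConnectToGroup
for progress in builder.load(None):    # LoadModel (streams progress)
    print(f"loading {progress.layers_loaded}/{progress.total}")
engine = builder.build()
engine.warmup()                        # StartWarmup
engine.submit("task-1")                # a generation task arrives
for task_id, result in engine.step():  # drive until a terminal result
    print(task_id, "finished" if result is None else result)
engine.close()                         # Shutdown
```

Keeping the backend-specific pieces (mlx text generation vs. mflux image generation) behind `Builder`/`Engine` is what lets `bootstrap.entrypoint` construct either `MlxBuilder` or `MfluxBuilder` and hand both to the same `Runner` state machine.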