diff --git a/.github/actions/run-test/action.yml b/.github/actions/run-test/action.yml index 6cdef7357..cd9348a80 100644 --- a/.github/actions/run-test/action.yml +++ b/.github/actions/run-test/action.yml @@ -55,7 +55,13 @@ runs: run: | uv pip install --system -r examples/applications/requirements_applications.txt uv pip install --system -r examples/ray_compat/requirements.txt + readarray -t skip_examples < examples/skip_examples.txt for example in "./examples"/*.py; do + filename=$(basename "$example") + if [[ " ${skip_examples[*]} " =~ [[:space:]]${filename}[[:space:]] ]]; then + echo "Skipping $example" + continue + fi echo "Running $example" python $example done diff --git a/.gitignore b/.gitignore index 1d0bc0bd1..f06bb07cd 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ CMakeFiles/ .vs/ src/scaler/protocol/capnp/*.c++ src/scaler/protocol/capnp/*.h + +orb/logs/ +orb/metrics/ diff --git a/README.md b/README.md index cd64f35fd..f9707396b 100644 --- a/README.md +++ b/README.md @@ -250,6 +250,8 @@ The following table maps each Scaler command to its corresponding section name i | `scaler_worker_adapter_native` | `[native_worker_adapter]` | | `scaler_worker_adapter_fixed_native` | `[fixed_native_worker_adapter]` | | `scaler_worker_adapter_symphony` | `[symphony_worker_adapter]` | +| `scaler_worker_manager_orb` | `[orb_worker_adapter]` | +| `scaler_worker_adapter_ecs` | `[ecs_worker_adapter]` | ### Practical Scenarios & Examples @@ -466,6 +468,26 @@ where `deepest_nesting_level` is the deepest nesting level a task has in your wo workload that has a base task that calls a nested task that calls another nested task, then the deepest nesting level is 2. +## ORB (AWS EC2) integration + +A Scaler scheduler can interface with ORB (Open Resource Broker) to dynamically provision and manage workers on AWS EC2 instances.
+ +```bash +$ scaler_worker_manager_orb tcp://127.0.0.1:2345 --image-id ami-0528819f94f4f5fa5 +``` + +This will start an ORB worker adapter that connects to the Scaler scheduler at `tcp://127.0.0.1:2345`. The scheduler can then request new workers from this adapter, which will be launched as EC2 instances. + +### Configuration + +The ORB adapter requires `orb-py` and `boto3` to be installed. You can install them with: + +```bash +$ pip install "opengris-scaler[orb]" +``` + +For more details on configuring ORB, including AWS credentials and instance templates, please refer to the [ORB Worker Adapter documentation](https://finos.github.io/opengris-scaler/tutorials/worker_adapters/orb.html). + ## Worker Adapter usage > **Note**: This feature is experimental and may change in future releases. diff --git a/docs/source/index.rst b/docs/source/index.rst index 997a488d8..562420bb5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,7 @@ Content tutorials/worker_adapters/index tutorials/worker_adapters/native tutorials/worker_adapters/fixed_native + tutorials/worker_adapters/orb tutorials/worker_adapters/aws_hpc/index tutorials/worker_adapters/common_parameters tutorials/compatibility/ray diff --git a/docs/source/tutorials/compatibility/ray.rst b/docs/source/tutorials/compatibility/ray.rst index 8a2ac3d5d..0f59cce34 100644 --- a/docs/source/tutorials/compatibility/ray.rst +++ b/docs/source/tutorials/compatibility/ray.rst @@ -6,7 +6,7 @@ Ray Compatibility Layer Scaler is a lightweight distributed computation engine similar to Ray. Scaler supports many of the same concepts as Ray including remote functions (known as tasks in Scaler), futures, cluster object storage, labels (known as capabilities in Scaler), and it comes with comparable monitoring tools. 
-Unlike Ray, Scaler supports both local clusters and also easily integrates with multiple cloud providers out of the box, including AWS EC2 and IBM Symphony, +Unlike Ray, Scaler supports both local clusters and also easily integrates with multiple cloud providers out of the box, including ORB (AWS EC2) and IBM Symphony, with more providers planned for the future. You can view our `roadmap on GitHub `_ for details on upcoming cloud integrations. diff --git a/docs/source/tutorials/configuration.rst b/docs/source/tutorials/configuration.rst index fbdd2b06d..166611334 100644 --- a/docs/source/tutorials/configuration.rst +++ b/docs/source/tutorials/configuration.rst @@ -199,6 +199,8 @@ The following table maps each Scaler command to its corresponding section name i - ``[fixed_native_worker_adapter]`` * - ``scaler_worker_adapter_symphony`` - ``[symphony_worker_adapter]`` + * - ``scaler_worker_manager_orb`` + - ``[orb_worker_adapter]`` * - ``scaler_worker_adapter_ecs`` - ``[ecs_worker_adapter]`` * - ``python -m scaler.entry_points.worker_adapter_aws_hpc`` diff --git a/docs/source/tutorials/examples.rst b/docs/source/tutorials/examples.rst index c01f15284..b211baa47 100644 --- a/docs/source/tutorials/examples.rst +++ b/docs/source/tutorials/examples.rst @@ -15,6 +15,14 @@ Shows how to send a basic task to scheduler .. literalinclude:: ../../../examples/simple_client.py :language: python +Submit Tasks +~~~~~~~~~~~~ + +Shows various ways to submit tasks (submit, map, starmap) + +.. 
literalinclude:: ../../../examples/submit_tasks.py + :language: python + Client Mapping Tasks ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/tutorials/worker_adapters/common_parameters.rst b/docs/source/tutorials/worker_adapters/common_parameters.rst index b5a53e415..8f6d4dc08 100644 --- a/docs/source/tutorials/worker_adapters/common_parameters.rst +++ b/docs/source/tutorials/worker_adapters/common_parameters.rst @@ -1,7 +1,7 @@ Common Worker Adapter Parameters ================================ -All worker adapters in Scaler share a set of common configuration parameters for connecting to the cluster, configuring the internal web server, and managing worker behavior. +All worker adapters in Scaler share a set of common configuration parameters for connecting to the cluster and managing worker behavior. .. note:: For more details on how to configure Scaler, see the :doc:`../configuration` section. diff --git a/docs/source/tutorials/worker_adapters/index.rst b/docs/source/tutorials/worker_adapters/index.rst index de1d8d1af..a505894fa 100644 --- a/docs/source/tutorials/worker_adapters/index.rst +++ b/docs/source/tutorials/worker_adapters/index.rst @@ -46,6 +46,11 @@ AWS HPC The :doc:`AWS HPC ` worker adapter allows Scaler to offload task execution to cloud environments, currently supporting AWS Batch. It is ideal for bursting workloads to the cloud or utilizing specific hardware not available locally. +ORB (AWS EC2) +~~~~~~~~~~~~~ + +The :doc:`ORB ` worker adapter allows Scaler to dynamically provision workers on AWS EC2 instances. This is ideal for scaling workloads that require significant cloud compute resources or specialized hardware like GPUs. 
+ Common Parameters ~~~~~~~~~~~~~~~~~ @@ -56,5 +61,6 @@ All worker adapters share a set of :doc:`common configuration parameters :8516 \ + --object-storage-address tcp://:8517 \ + --image-id ami-0528819f94f4f5fa5 \ + --instance-type t3.medium \ + --aws-region us-east-1 \ + --logging-level INFO \ + --task-timeout-seconds 60 + +Equivalent configuration using a TOML file: + +.. code-block:: bash + + scaler_worker_manager_orb tcp://:8516 --config config.toml + +.. code-block:: toml + + # config.toml + + [orb_worker_adapter] + object_storage_address = "tcp://:8517" + image_id = "ami-0528819f94f4f5fa5" + instance_type = "t3.medium" + aws_region = "us-east-1" + logging_level = "INFO" + task_timeout_seconds = 60 + +* ``tcp://:8516`` is the address workers will use to connect to the scheduler. +* ``tcp://:8517`` is the address workers will use to connect to the object storage server. +* New workers will be launched using the specified AMI and instance type. + +Networking Configuration +------------------------ + +Workers launched by the ORB adapter are EC2 instances and require an externally-reachable IP address for the scheduler. + +* **Internal Communication**: If the machine running the scheduler is another EC2 instance in the same VPC, you can use EC2 private IP addresses. +* **Public Internet**: If communicating over the public internet, it is highly recommended to set up robust security rules and/or a VPN to protect the cluster. + +Publicly Available AMIs +----------------------- + +We regularly publish publicly available Amazon Machine Images (AMIs) with Python and ``opengris-scaler`` pre-installed. + +.. list-table:: Available Public AMIs + :widths: 15 15 20 20 30 + :header-rows: 1 + + * - Scaler Version + - Python Version + - Amazon Linux 2023 Version + - Date (MM/DD/YYYY) + - AMI ID (us-east-1) + * - 1.14.2 + - 3.13 + - 2023.10.20260120 + - 01/30/2026 + - ``ami-0528819f94f4f5fa5`` + +New AMIs will be added to this list as they become available. 
+ +Supported Parameters +-------------------- + +.. note:: + For more details on how to configure Scaler, see the :doc:`../configuration` section. + +The ORB worker adapter supports ORB-specific configuration parameters as well as common worker adapter parameters. + +Orb Template Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* ``--image-id`` (Required): AMI ID for the worker instances. +* ``--instance-type``: EC2 instance type (default: ``t2.micro``). +* ``--aws-region``: AWS region (default: ``us-east-1``). +* ``--key-name``: AWS key pair name for the instances. If not provided, a temporary key pair will be created and deleted on cleanup. +* ``--subnet-id``: AWS subnet ID where the instances will be launched. If not provided, it attempts to discover the default subnet in the default VPC. +* ``--security-group-ids``: Comma-separated list of AWS security group IDs. +* ``--allowed-ip``: IP address to allow in the security group (if created automatically). Defaults to the adapter's external IP. +* ``--orb-config-path``: Path to the ORB root directory (default: ``src/scaler/drivers/orb``). + +Common Parameters +~~~~~~~~~~~~~~~~~ + +For a full list of common parameters including networking, worker configuration, and logging, see :doc:`common_parameters`. + +Cleanup +------- + +The ORB worker adapter is designed to be self-cleaning, but it is important to be aware of the resources it manages: + +* **Key Pairs**: If a ``--key-name`` is not provided, the adapter creates a temporary AWS key pair. +* **Security Groups**: If ``--security-group-ids`` are not provided, the adapter creates a temporary security group to allow communication. +* **Launch Templates**: ORB may additionally create EC2 Launch Templates as part of the machine provisioning process. + +The adapter attempts to delete these temporary resources and terminate all launched EC2 instances when it shuts down gracefully. 
However, in the event of an ungraceful crash or network failure, some resources may persist in your AWS account. + +.. tip:: + It is recommended to periodically check your AWS console for any orphaned resources (instances, security groups, key pairs, or launch templates) and clean them up manually if necessary to avoid unexpected costs. + +.. warning:: + **Subnet and Security Groups**: Currently, specifying ``--subnet-id`` or ``--security-group-ids`` via configuration might not have the intended effect as the adapter is designed to auto-discover or create these resources. Specifically, the adapter may still attempt to use default subnets or create its own temporary security groups regardless of these parameters. diff --git a/examples/readme.md b/examples/readme.md index be996f185..c0865c381 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -17,6 +17,8 @@ Ensure that the scheduler and cluster are set up before running clients. Shows how to send a nested task to scheduler - `simple_client.py` Shows how to send a basic task to scheduler +- `submit_tasks.py` + Shows various ways to submit tasks (submit, map, starmap) - `task_capabilities.py` Shows how to use capabilities to route task to various workers - `ray_compat/` diff --git a/examples/skip_examples.txt b/examples/skip_examples.txt new file mode 100644 index 000000000..e69de29bb diff --git a/examples/submit_tasks.py b/examples/submit_tasks.py new file mode 100644 index 000000000..013166339 --- /dev/null +++ b/examples/submit_tasks.py @@ -0,0 +1,67 @@ +""" +This example demonstrates various ways to submit tasks to a Scaler scheduler. +It shows how to use the Client to: +1. Submit a single task using .submit() +2. Submit multiple tasks using .map() +3. 
Submit tasks with multiple arguments using .map() and .starmap() +""" + +import argparse +import math + +from scaler import Client, SchedulerClusterCombo + + +def square(value: int): + return value * value + + +def add(x: int, y: int): + return x + y + + +def main(): + parser = argparse.ArgumentParser(description="Submit tasks to a Scaler scheduler.") + parser.add_argument("url", nargs="?", help="The URL of the Scaler scheduler (e.g., tcp://127.0.0.1:2345)") + args = parser.parse_args() + + cluster = None + if args.url is None: + + print("No scheduler URL provided. Spinning up a local cluster...") + cluster = SchedulerClusterCombo(n_workers=4) + address = cluster.get_address() + else: + address = args.url + + try: + print(f"Connecting to scheduler at {address}...") + + # Use the Client as a context manager to ensure proper cleanup + with Client(address=address) as client: + print("Submitting a single task using .submit()...") + future = client.submit(square, 4) + print(f"Result of square(4): {future.result()}") + + print("\nSubmitting multiple tasks using .map()...") + # client.map() works like Python's built-in map() + results = client.map(math.sqrt, range(1, 6)) + print(f"Results of sqrt(1..5): {list(results)}") + + print("\nSubmitting tasks with multiple arguments using .map()...") + # You can pass multiple iterables to map() for functions with multiple arguments + results_add = client.map(add, [1, 2, 3], [10, 20, 30]) + print(f"Results of add([1,2,3], [10,20,30]): {list(results_add)}") + + print("\nSubmitting tasks with multiple arguments using .starmap()...") + # starmap() takes an iterable of argument tuples + results_starmap = client.starmap(add, [(5, 5), (10, 10)]) + print(f"Results of starmap(add, [(5,5), (10,10)]): {list(results_starmap)}") + finally: + if cluster: + cluster.shutdown() + print("\nAll tasks completed successfully.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 62f94f5d5..275f9cac8 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,10 @@ graphblas = [ aws = [ "boto3", ] +orb = [ + "orb-py~=1.1; python_version >= '3.10'", + "boto3; python_version >= '3.10'", +] all = [ "nicegui[plotly]==2.24.2; python_version == '3.8'", "nicegui[plotly]==3.6.1; python_version >= '3.9'", @@ -63,6 +67,7 @@ all = [ "numpy==2.2.6; python_version >= '3.10'", "uvloop; platform_system != 'Windows'", "boto3", + "orb-py~=1.1; python_version >= '3.10'" ] [dependency-groups] @@ -99,6 +104,7 @@ scaler_worker_manager_baremetal_fixed_native = "scaler.entry_points.worker_manag scaler_worker_manager_symphony = "scaler.entry_points.worker_manager_symphony:main" scaler_worker_manager_aws_raw_ecs = "scaler.entry_points.worker_manager_aws_raw_ecs:main" scaler_worker_manager_aws_hpc_batch = "scaler.entry_points.worker_manager_aws_hpc_batch:main" +scaler_worker_manager_orb = "scaler.entry_points.worker_manager_orb:main" [tool.scikit-build] cmake.source-dir = "." diff --git a/src/run_worker_manager_orb.py b/src/run_worker_manager_orb.py new file mode 100644 index 000000000..fd8b02bd9 --- /dev/null +++ b/src/run_worker_manager_orb.py @@ -0,0 +1,5 @@ +from scaler.entry_points.worker_manager_orb import main +from scaler.utility.debug import pdb_wrapped + +if __name__ == "__main__": + pdb_wrapped(main)() diff --git a/src/scaler/cluster/cluster.py b/src/scaler/cluster/cluster.py index e45fbaa9c..84fbea271 100644 --- a/src/scaler/cluster/cluster.py +++ b/src/scaler/cluster/cluster.py @@ -34,6 +34,7 @@ def __init__(self, config: ClusterConfig): self._logging_paths = config.logging_config.paths self._logging_config_file = config.logging_config.config_file self._logging_level = config.logging_config.level + self._deterministic_worker_ids = config.deterministic_worker_ids # we create the config here, but create the actual adapter in the run method # to ensure that it's created in the correct process diff --git a/src/scaler/config/section/cluster.py b/src/scaler/config/section/cluster.py index 
1ec0994dd..9fa3988fd 100644 --- a/src/scaler/config/section/cluster.py +++ b/src/scaler/config/section/cluster.py @@ -53,6 +53,10 @@ class ClusterConfig(ConfigClass): ) worker_config: WorkerConfig = dataclasses.field(default_factory=WorkerConfig) logging_config: LoggingConfig = dataclasses.field(default_factory=LoggingConfig) + deterministic_worker_ids: bool = dataclasses.field( + default=False, + metadata=dict(short="-dwi", action="store_true", help="enable deterministic worker id generation"), + ) def __post_init__(self): if self.worker_names.names and len(self.worker_names.names) != self.num_of_workers: diff --git a/src/scaler/config/section/orb_worker_adapter.py b/src/scaler/config/section/orb_worker_adapter.py new file mode 100644 index 000000000..98c9f7b28 --- /dev/null +++ b/src/scaler/config/section/orb_worker_adapter.py @@ -0,0 +1,57 @@ +import dataclasses +import pathlib +from typing import List, Optional + +from scaler.config import defaults +from scaler.config.common.logging import LoggingConfig +from scaler.config.common.worker import WorkerConfig +from scaler.config.common.worker_adapter import WorkerAdapterConfig +from scaler.config.config_class import ConfigClass +from scaler.utility.event_loop import EventLoopType + +_DEFAULT_ORB_CONFIG_PATH = str(pathlib.Path(__file__).parent.parent.parent / "worker_manager_adapter" / "orb") + + +@dataclasses.dataclass +class ORBWorkerAdapterConfig(ConfigClass): + """Configuration for the ORB worker adapter.""" + + worker_adapter_config: WorkerAdapterConfig + + # ORB Template configuration + image_id: str = dataclasses.field(metadata=dict(help="AMI ID for the worker instances", required=True)) + key_name: Optional[str] = dataclasses.field( + default=None, metadata=dict(help="AWS key pair name for the instances (optional)") + ) + subnet_id: Optional[str] = dataclasses.field( + default=None, metadata=dict(help="AWS subnet ID where the instances will be launched (optional)") + ) + + worker_config: WorkerConfig = 
dataclasses.field(default_factory=WorkerConfig) + logging_config: LoggingConfig = dataclasses.field(default_factory=LoggingConfig) + event_loop: str = dataclasses.field( + default="builtin", + metadata=dict(short="-el", choices=EventLoopType.allowed_types(), help="select the event loop type"), + ) + + worker_io_threads: int = dataclasses.field( + default=defaults.DEFAULT_IO_THREADS, + metadata=dict(short="-wit", help="set the number of io threads for io backend per worker"), + ) + + orb_config_path: str = dataclasses.field( + default=_DEFAULT_ORB_CONFIG_PATH, metadata=dict(help="Path to the ORB root directory") + ) + + instance_type: str = dataclasses.field(default="t2.micro", metadata=dict(help="EC2 instance type")) + aws_region: Optional[str] = dataclasses.field(default="us-east-1", metadata=dict(help="AWS region")) + security_group_ids: List[str] = dataclasses.field( + default_factory=list, + metadata=dict( + type=lambda s: [x for x in s.split(",") if x], help="Comma-separated list of AWS security group IDs" + ), + ) + + def __post_init__(self) -> None: + if self.worker_io_threads <= 0: + raise ValueError("worker_io_threads must be a positive integer.") diff --git a/src/scaler/drivers/ami/build.sh b/src/scaler/drivers/ami/build.sh new file mode 100755 index 000000000..c073caa06 --- /dev/null +++ b/src/scaler/drivers/ami/build.sh @@ -0,0 +1,22 @@ +#!/bin/bash +set -e +set -x + +# This script builds the AMI for opengris-scaler using Packer +# It reads the version from the version.txt file and passes it as a variable + +# Get the directory where the script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +VERSION_FILE="$SCRIPT_DIR/../../version.txt" + +if [ ! 
-f "$VERSION_FILE" ]; then + echo "Error: Version file not found at $VERSION_FILE" + exit 1 +fi + +VERSION=$(cat "$VERSION_FILE" | tr -d '[:space:]') + +echo "Building AMI for version: $VERSION" + +cd "$SCRIPT_DIR" +packer build -var "version=$VERSION" opengris-scaler.pkr.hcl diff --git a/src/scaler/drivers/ami/opengris-scaler.pkr.hcl b/src/scaler/drivers/ami/opengris-scaler.pkr.hcl new file mode 100644 index 000000000..8611a939f --- /dev/null +++ b/src/scaler/drivers/ami/opengris-scaler.pkr.hcl @@ -0,0 +1,69 @@ +packer { + required_plugins { + amazon = { + version = "~> 1" + source = "github.com/hashicorp/amazon" + } + } +} + +variable "aws_region" { + type = string + default = "us-east-1" +} + +variable "version" { + type = string +} + +variable "ami_regions" { + type = list(string) + default = [] + description = "A list of regions to copy the AMI to." +} + +variable "ami_groups" { + type = list(string) + default = ["all"] + description = "A list of groups to share the AMI with. Set to ['all'] to make public." 
+} + +variable "python_version" { + type = string + default = "3.13" +} + +source "amazon-ebs" "opengris-scaler" { + ami_name = "opengris-scaler-${var.version}-py${var.python_version}" + instance_type = "t2.small" + region = var.aws_region + ami_regions = var.ami_regions + ami_groups = var.ami_groups + source_ami_filter { + filters = { + name = "al2023-ami-2023.*-kernel-*-x86_64" + root-device-type = "ebs" + virtualization-type = "hvm" + } + most_recent = true + owners = ["amazon"] + } + ssh_username = "ec2-user" +} + +build { + name = "opengris-scaler-build" + sources = ["source.amazon-ebs.opengris-scaler"] + + provisioner "shell" { + inline = [ + "sudo dnf update -y", + "sudo dnf install -y python${var.python_version} python${var.python_version}-pip", + "sudo python${var.python_version} -m venv /opt/opengris-scaler", + "sudo /opt/opengris-scaler/bin/python -m pip install --upgrade pip", + "sudo /opt/opengris-scaler/bin/pip install opengris-scaler==${var.version}", + "sudo ln -sf /opt/opengris-scaler/bin/scaler_* /usr/local/bin/", + "sudo ln -sf /opt/opengris-scaler/bin/python /usr/local/bin/opengris-python" + ] + } +} diff --git a/src/scaler/entry_points/worker_manager_orb.py b/src/scaler/entry_points/worker_manager_orb.py new file mode 100644 index 000000000..0c60b5706 --- /dev/null +++ b/src/scaler/entry_points/worker_manager_orb.py @@ -0,0 +1,20 @@ +from scaler.config.section.orb_worker_adapter import ORBWorkerAdapterConfig +from scaler.utility.logging.utility import setup_logger +from scaler.worker_manager_adapter.orb.worker_manager import ORBWorkerAdapter + + +def main(): + orb_adapter_config = ORBWorkerAdapterConfig.parse("Scaler ORB Worker Adapter", "orb_worker_adapter") + + setup_logger( + orb_adapter_config.logging_config.paths, + orb_adapter_config.logging_config.config_file, + orb_adapter_config.logging_config.level, + ) + + orb_worker_adapter = ORBWorkerAdapter(orb_adapter_config) + orb_worker_adapter.run() + + +if __name__ == "__main__": + main() 
diff --git a/src/scaler/io/uv_ymq/__init__.py b/src/scaler/io/uv_ymq/__init__.py index 72a56723d..3263fd244 100644 --- a/src/scaler/io/uv_ymq/__init__.py +++ b/src/scaler/io/uv_ymq/__init__.py @@ -10,13 +10,5 @@ "UVYMQException", ] -from scaler.io.uv_ymq._uv_ymq import ( - Address, - AddressType, - Bytes, - ErrorCode, - IOContext, - Message, - UVYMQException, -) +from scaler.io.uv_ymq._uv_ymq import Address, AddressType, Bytes, ErrorCode, IOContext, Message, UVYMQException from scaler.io.uv_ymq.sockets import BinderSocket, ConnectorSocket diff --git a/src/scaler/io/uv_ymq/_uv_ymq.pyi b/src/scaler/io/uv_ymq/_uv_ymq.pyi index 61941072e..73d3fe6a1 100644 --- a/src/scaler/io/uv_ymq/_uv_ymq.pyi +++ b/src/scaler/io/uv_ymq/_uv_ymq.pyi @@ -81,7 +81,6 @@ class BinderSocket: """Create a BinderSocket with the specified identity.""" def __repr__(self) -> str: ... - def bind_to(self, callback: Callable[[Union[Address, Exception]], None], address: str) -> None: """Bind the socket to an address and listen for incoming connections.""" @@ -114,7 +113,6 @@ class ConnectorSocket: """Create a ConnectorSocket and initiate connection to the remote address.""" def __repr__(self) -> str: ... - def send_message(self, callback: Callable[[Optional[Exception]], None], message_payload: Bytes) -> None: """Send a message to the connected remote peer.""" diff --git a/src/scaler/scheduler/controllers/worker_adapter_controller.py b/src/scaler/scheduler/controllers/worker_adapter_controller.py index 98b370862..83d0205c9 100644 --- a/src/scaler/scheduler/controllers/worker_adapter_controller.py +++ b/src/scaler/scheduler/controllers/worker_adapter_controller.py @@ -64,6 +64,12 @@ async def on_heartbeat(self, source: bytes, heartbeat: WorkerAdapterHeartbeat): worker_groups = {gid: info.worker_ids for gid, info in adapter_groups.items()} worker_group_capabilities = {gid: info.capabilities for gid, info in adapter_groups.items()} + # Wait for the previous command to complete before sending another. 
+ # Adapters can take a long time to fulfill commands (e.g. ORB polls for instance IDs), + # so sending a new command before the response arrives causes duplicate work and errors. + if source in self._pending_commands: + return + commands = self._scaler_policy.get_scaling_commands( information_snapshot, heartbeat, worker_groups, worker_group_capabilities ) diff --git a/src/scaler/utility/dict_utils.py b/src/scaler/utility/dict_utils.py new file mode 100644 index 000000000..97a0b8429 --- /dev/null +++ b/src/scaler/utility/dict_utils.py @@ -0,0 +1,38 @@ +import re +from typing import Any + + +def to_camel_case(snake_str: str) -> str: + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def to_snake_case(camel_str: str) -> str: + pattern = re.compile(r"(?<!^)(?=[A-Z])") + return pattern.sub("_", camel_str).lower() + + +def camelcase_dict(d: Any) -> Any: + if isinstance(d, dict): + new_d = {} + for k, v in d.items(): + new_key = to_camel_case(k) if isinstance(k, str) else k + new_d[new_key] = camelcase_dict(v) + return new_d + elif isinstance(d, list): + return [camelcase_dict(i) for i in d] + else: + return d + + +def snakecase_dict(d: Any) -> Any: + if isinstance(d, dict): + new_d = {} + for k, v in d.items(): + new_key = to_snake_case(k) if isinstance(k, str) else k + new_d[new_key] = snakecase_dict(v) + return new_d + elif isinstance(d, list): + return [snakecase_dict(i) for i in d] + else: + return d diff --git a/src/scaler/worker/worker.py b/src/scaler/worker/worker.py index c96e8f087..acc47a9b7 100644 --- a/src/scaler/worker/worker.py +++ b/src/scaler/worker/worker.py @@ -64,6 +64,7 @@ def __init__( hard_processor_suspend: bool, logging_paths: Tuple[str, ...], logging_level: str, + deterministic_worker_ids: bool = False, ): multiprocessing.Process.__init__(self, name="Agent") @@ -76,7 +77,10 @@ def __init__( self._io_threads = io_threads self._task_queue_size = task_queue_size - self._ident = WorkerID.generate_worker_id(name) # _identity is internal to multiprocessing.Process + if
deterministic_worker_ids: + self._ident = WorkerID(name.encode()) + else: + self._ident = WorkerID.generate_worker_id(name) self._address_path_internal = os.path.join(tempfile.gettempdir(), f"scaler_worker_{uuid.uuid4().hex}") self._address_internal = ZMQConfig(ZMQType.ipc, host=self._address_path_internal) diff --git a/src/scaler/worker_manager_adapter/orb/__init__.py b/src/scaler/worker_manager_adapter/orb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/scaler/worker_manager_adapter/orb/config/config.json b/src/scaler/worker_manager_adapter/orb/config/config.json new file mode 100644 index 000000000..e309deddd --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/config/config.json @@ -0,0 +1,35 @@ +{ + "version": "2.0.0", + "provider": { + "active_provider": "aws-default", + "selection_policy": "FIRST_AVAILABLE", + "providers": [ + { + "name": "aws-default", + "type": "aws", + "enabled": true, + "priority": 1, + "config": { + "region": "us-east-1", + "profile": "default" + } + } + ] + }, + "storage": { + "strategy": "json", + "json_strategy": { + "storage_type": "single_file", + "base_path": "data", + "filenames": { + "single_file": "request_database.json" + } + } + }, + "logging": { + "level": "INFO", + "file_path": "logs/app.log", + "console_enabled": true + } +} + diff --git a/src/scaler/worker_manager_adapter/orb/config/default_config.json b/src/scaler/worker_manager_adapter/orb/config/default_config.json new file mode 100644 index 000000000..e309deddd --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/config/default_config.json @@ -0,0 +1,35 @@ +{ + "version": "2.0.0", + "provider": { + "active_provider": "aws-default", + "selection_policy": "FIRST_AVAILABLE", + "providers": [ + { + "name": "aws-default", + "type": "aws", + "enabled": true, + "priority": 1, + "config": { + "region": "us-east-1", + "profile": "default" + } + } + ] + }, + "storage": { + "strategy": "json", + "json_strategy": { + "storage_type": 
"single_file", + "base_path": "data", + "filenames": { + "single_file": "request_database.json" + } + } + }, + "logging": { + "level": "INFO", + "file_path": "logs/app.log", + "console_enabled": true + } +} + diff --git a/src/scaler/worker_manager_adapter/orb/exception.py b/src/scaler/worker_manager_adapter/orb/exception.py new file mode 100644 index 000000000..9ae10cbe9 --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/exception.py @@ -0,0 +1,9 @@ +from typing import Any + + +class ORBException(Exception): + """Exception raised for errors in ORB operations.""" + + def __init__(self, data: Any): + self.data = data + super().__init__(f"ORB Exception: {data}") diff --git a/src/scaler/worker_manager_adapter/orb/helper.py b/src/scaler/worker_manager_adapter/orb/helper.py new file mode 100644 index 000000000..5a54c2e54 --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/helper.py @@ -0,0 +1,240 @@ +import json +import os +import shutil +import subprocess +import tempfile +from os import path +from typing import Any, Dict, List, Optional + +from scaler.utility.dict_utils import snakecase_dict +from scaler.worker_manager_adapter.orb.exception import ORBException +from scaler.worker_manager_adapter.orb.types import ORBMachine, ORBRequest, ORBTemplate + + +class ORBHelper: + """Helper class to interact with the ORB CLI.""" + + @staticmethod + def _filter_data(cls: Any, data: Dict[str, Any]) -> Dict[str, Any]: + """Filter data to match dataclass fields.""" + if not hasattr(cls, "__annotations__"): + return data + valid_keys = cls.__annotations__.keys() + return {k: v for k, v in data.items() if k in valid_keys} + + class Templates: + """API for managing compute templates.""" + + def __init__(self, helper: "ORBHelper"): + self.helper = helper + + def _command(self, command: List[str]) -> Any: + return self.helper._command(["templates", *command]) + + def list(self) -> List[ORBTemplate]: + """List all templates.""" + data = self._command(["list"]) + # Handle case 
where list command returns a dict with items or a list + if isinstance(data, list): + items = data + else: + items = data.get("templates") or data.get("items") or [] + + # Convert items to snake_case and filter + result = [] + for item in items: + snake_item = snakecase_dict(item) + filtered_item = ORBHelper._filter_data(ORBTemplate, snake_item) + result.append(ORBTemplate(**filtered_item)) + return result + + def create(self, config_file_path: Optional[str] = None, template_id: Optional[str] = None) -> ORBTemplate: + """Create a new template from a JSON configuration file.""" + cmd = ["create"] + if config_file_path: + cmd.extend(["--file", config_file_path]) + if template_id: + cmd.extend(["--template-id", template_id]) + + data = self._command(cmd) + # Handle possible nesting in creation response + if isinstance(data, dict) and "templates" in data and data["templates"]: + data = data["templates"][0] + + snake_data = snakecase_dict(data) + return ORBTemplate(**ORBHelper._filter_data(ORBTemplate, snake_data)) + + def delete(self, template_id: str) -> Dict[str, Any]: + """Delete a template by ID.""" + return self._command(["delete", "--force", template_id]) + + class Machines: + """API for managing compute instances.""" + + def __init__(self, helper: "ORBHelper"): + self.helper = helper + + def _command(self, command: List[str]) -> Any: + return self.helper._command(["machines", *command]) + + def list(self) -> List[ORBMachine]: + """List all machines.""" + data = self._command(["list"]) + if isinstance(data, list): + items = data + else: + items = data.get("machines") or data.get("items") or [] + + result = [] + for item in items: + snake_item = snakecase_dict(item) + filtered_item = ORBHelper._filter_data(ORBMachine, snake_item) + result.append(ORBMachine(**filtered_item)) + return result + + def show(self, machine_id: str) -> ORBMachine: + """Show details for a specific machine.""" + data = self._command(["show", machine_id]) + if isinstance(data, dict) and 
"machines" in data and data["machines"]: + data = data["machines"][0] + + snake_data = snakecase_dict(data) + return ORBMachine(**ORBHelper._filter_data(ORBMachine, snake_data)) + + def request( + self, template_id: str, count: int, wait: bool = False, timeout: Optional[int] = None + ) -> ORBRequest: + """Request new machines using a template.""" + cmd = ["request", template_id, str(count)] + if wait: + cmd.append("--wait") + if timeout: + cmd.extend(["--timeout", str(timeout)]) + data = self._command(cmd) + # requestMachines usually returns requestId directly at top level or in a message + # But handle list just in case + if isinstance(data, dict) and "requests" in data and data["requests"]: + data = data["requests"][0] + + snake_data = snakecase_dict(data) + return ORBRequest(**ORBHelper._filter_data(ORBRequest, snake_data)) + + def return_machines(self, machine_ids: List[str]) -> ORBRequest: + """Return (terminate) one or more machines.""" + data = self._command(["return", *machine_ids]) + if isinstance(data, dict) and "requests" in data and data["requests"]: + data = data["requests"][0] + + snake_data = snakecase_dict(data) + return ORBRequest(**ORBHelper._filter_data(ORBRequest, snake_data)) + + class Requests: + """API for managing provisioning requests.""" + + def __init__(self, helper: "ORBHelper"): + self.helper = helper + + def _command(self, command: List[str]) -> Any: + return self.helper._command(["requests", *command]) + + def list(self) -> List[ORBRequest]: + """List all requests.""" + data = self._command(["list"]) + if isinstance(data, list): + items = data + else: + items = data.get("requests") or data.get("items") or [] + + result = [] + for item in items: + snake_item = snakecase_dict(item) + filtered_item = ORBHelper._filter_data(ORBRequest, snake_item) + result.append(ORBRequest(**filtered_item)) + return result + + def show(self, request_id: str) -> ORBRequest: + """Show details for a specific request.""" + data = self._command(["show", 
request_id]) + if isinstance(data, dict) and "requests" in data and data["requests"]: + data = data["requests"][0] + + snake_data = snakecase_dict(data) + return ORBRequest(**ORBHelper._filter_data(ORBRequest, snake_data)) + + def cancel(self, request_id: str) -> ORBRequest: + """Cancel a provisioning request.""" + data = self._command(["cancel", request_id]) + if isinstance(data, dict) and "requests" in data and data["requests"]: + data = data["requests"][0] + + snake_data = snakecase_dict(data) + return ORBRequest(**ORBHelper._filter_data(ORBRequest, snake_data)) + + def __init__(self, config_root_path: str, region: str = "us-east-1"): + """Initialize the helper. + + :param config_root_path: The root directory containing an ORB 'config/' subdirectory. + :param region: AWS region to inject into ORB config. + """ + self._temp_dir = tempfile.TemporaryDirectory() + self._cwd = self._temp_dir.name + + source_config_dir = path.join(config_root_path, "config") + dest_config_dir = path.join(self._cwd, "config") + shutil.copytree(source_config_dir, dest_config_dir) + os.makedirs(path.join(self._cwd, "logs"), exist_ok=True) + + config_json_path = path.join(self._cwd, "config", "config.json") + if path.isfile(config_json_path): + with open(config_json_path) as f: + config_data = json.load(f) + for provider in config_data.get("provider", {}).get("providers", []): + if "config" in provider: + provider["config"]["region"] = region + with open(config_json_path, "w") as f: + json.dump(config_data, f, indent=4) + + self.templates = self.Templates(self) + self.machines = self.Machines(self) + self.requests = self.Requests(self) + + @property + def cwd(self) -> str: + """Return the working directory (temp dir).""" + return self._cwd + + def _command(self, command: List[str]) -> Any: + """Run an ORB CLI command and return the parsed JSON output.""" + # Set environment variables to point to the temp config directory + env = os.environ.copy() + config_dir = path.join(self._cwd, 
"config") + logs_dir = path.join(self._cwd, "logs") + + env["HF_PROVIDER_CONFDIR"] = config_dir + env["ORB_CONFIG_DIR"] = config_dir + env["ORB_LOG_DIR"] = logs_dir + env["HF_PROVIDER_LOGDIR"] = logs_dir + + cmd = ["orb", *command] + + try: + result = subprocess.run(cmd, check=True, capture_output=True, text=True, cwd=self._cwd, env=env) + stdout = result.stdout + + if not stdout.strip(): + return {} + + data = json.loads(stdout) + + if isinstance(data, dict) and "error" in data: + raise ORBException(data) + return data + except subprocess.CalledProcessError as e: + error_msg = f"Failed to run ORB command: {e}" + if e.stdout: + error_msg += f"\nSTDOUT:\n{e.stdout}" + if e.stderr: + error_msg += f"\nSTDERR:\n{e.stderr}" + raise RuntimeError(error_msg) from e + except json.JSONDecodeError as e: + raise RuntimeError(f"Failed to parse ORB command output as JSON: {e}\nOutput: {stdout}") from e diff --git a/src/scaler/worker_manager_adapter/orb/types.py b/src/scaler/worker_manager_adapter/orb/types.py new file mode 100644 index 000000000..fe6f7e8e2 --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/types.py @@ -0,0 +1,120 @@ +from dataclasses import dataclass, field, fields +from datetime import datetime +from typing import Any, Dict, List, Optional + + +@dataclass +class ORBTemplate: + template_id: str + name: Optional[str] = None + description: Optional[str] = None + vm_type: Optional[str] = None + image_id: Optional[str] = None + max_number: int = 1 + subnet_id: Optional[str] = None + subnet_ids: List[str] = field(default_factory=list) + security_group_ids: List[str] = field(default_factory=list) + price_type: str = "ondemand" + allocation_strategy: str = "lowest_price" + max_price: Optional[float] = None + instance_types: Dict[str, int] = field(default_factory=dict) + primary_instance_type: Optional[str] = None + network_zones: List[str] = field(default_factory=list) + public_ip_assignment: Optional[bool] = None + root_volume_size: Optional[int] = None + 
root_volume_type: Optional[str] = None + root_volume_iops: Optional[int] = None + root_volume_throughput: Optional[int] = None + storage_encryption: Optional[bool] = None + encryption_key: Optional[str] = None + key_pair_name: Optional[str] = None + user_data_script: Optional[str] = None + instance_profile: Optional[str] = None + monitoring_enabled: Optional[bool] = None + tags: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + launch_template_spec: Optional[Dict[str, Any]] = None + provider_api_spec: Optional[Dict[str, Any]] = None + provider_type: Optional[str] = None + provider_name: Optional[str] = None + provider_api: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + is_active: bool = True + vm_types: Dict[str, Any] = field(default_factory=dict) + key_name: Optional[str] = None + + +@dataclass +class ORBMachine: + machine_id: str = "" + instance_id: str = "" + template_id: str = "" + request_id: Optional[str] = None + provider_type: str = "" + instance_type: str = "" + image_id: str = "" + private_ip: Optional[str] = None + public_ip: Optional[str] = None + subnet_id: Optional[str] = None + security_group_ids: List[str] = field(default_factory=list) + status: str = "" + status_reason: Optional[str] = None + launch_time: Optional[datetime] = None + termination_time: Optional[datetime] = None + tags: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + provider_data: Dict[str, Any] = field(default_factory=dict) + version: int = 0 + created_at: Optional[datetime] = None + + +@dataclass +class ORBRequest: + request_id: str = "" + request_type: str = "" + provider_type: str = "" + template_id: str = "" + provider_instance: Optional[str] = None + requested_count: int = 1 + desired_capacity: int = 1 + provider_name: Optional[str] = None + provider_api: Optional[str] = None + resource_ids: List[str] = 
field(default_factory=list) + status: str = "" + status_message: Optional[str] = None + message: Optional[str] = None + instance_ids: List[str] = field(default_factory=list) + machines: List[ORBMachine] = field(default_factory=list) + successful_count: int = 0 + failed_count: int = 0 + created_at: Optional[datetime] = None + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + metadata: Dict[str, Any] = field(default_factory=dict) + error_details: Dict[str, Any] = field(default_factory=dict) + provider_data: Dict[str, Any] = field(default_factory=dict) + version: int = 0 + + def __post_init__(self): + _machine_fields = {f.name for f in fields(ORBMachine)} + self.machines = [ + ORBMachine(**{k: v for k, v in m.items() if k in _machine_fields}) if isinstance(m, dict) else m + for m in self.machines + ] + + def get_instance_ids(self) -> List[str]: + """Extract instance IDs from any available field in the request.""" + # 1. Try explicit instance_ids + if self.instance_ids: + return self.instance_ids + + # 2. Try resource_ids (often used for reservation IDs, but can contain instance IDs) + if self.resource_ids: + return self.resource_ids + + # 3. 
Try nested machines list + if self.machines: + return [m.instance_id or m.machine_id for m in self.machines if m.instance_id or m.machine_id] + + return [] diff --git a/src/scaler/worker_manager_adapter/orb/worker_manager.py b/src/scaler/worker_manager_adapter/orb/worker_manager.py new file mode 100644 index 000000000..87133727c --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/worker_manager.py @@ -0,0 +1,429 @@ +import asyncio +import json +import logging +import os +import signal +import time +import uuid +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Tuple + +import boto3 +import zmq + +from scaler.config.section.orb_worker_adapter import ORBWorkerAdapterConfig +from scaler.io.mixins import AsyncConnector +from scaler.io.utility import create_async_connector, create_async_simple_context +from scaler.io.ymq import ymq +from scaler.protocol.python.message import ( + Message, + WorkerAdapterCommand, + WorkerAdapterCommandResponse, + WorkerAdapterCommandType, + WorkerAdapterHeartbeat, + WorkerAdapterHeartbeatEcho, +) +from scaler.utility.event_loop import create_async_loop_routine, register_event_loop, run_task_forever +from scaler.utility.identifiers import WorkerID +from scaler.utility.logging.utility import setup_logger +from scaler.worker_manager_adapter.common import WorkerGroupID, format_capabilities +from scaler.worker_manager_adapter.orb.helper import ORBHelper +from scaler.worker_manager_adapter.orb.types import ORBTemplate + +Status = WorkerAdapterCommandResponse.Status +logger = logging.getLogger(__name__) + + +# Polling configuration for ORB machine requests +ORB_POLLING_INTERVAL_SECONDS = 5 +ORB_MAX_POLLING_ATTEMPTS = 60 + + +def get_orb_worker_name(instance_id: str) -> str: + """ + Returns the deterministic worker name for an ORB instance. + If instance_id is the bash variable '${INSTANCE_ID}', it returns a bash-compatible string. 
+ """ + if instance_id == "${INSTANCE_ID}": + return "Worker|ORB|${INSTANCE_ID}|${INSTANCE_ID//i-/}" + tag = instance_id.replace("i-", "") + return f"Worker|ORB|{instance_id}|{tag}" + + +class ORBWorkerAdapter: + _config: ORBWorkerAdapterConfig + _orb: Optional[ORBHelper] + _worker_groups: Dict[WorkerGroupID, WorkerID] + _template_id: str + _created_security_group_id: Optional[str] + _created_key_name: Optional[str] + _ec2: Optional[Any] + + def __init__(self, config: ORBWorkerAdapterConfig): + self._config = config + self._address = config.worker_adapter_config.scheduler_address + self._heartbeat_interval_seconds = config.worker_config.heartbeat_interval_seconds + self._capabilities = config.worker_config.per_worker_capabilities.capabilities + self._max_workers = config.worker_adapter_config.max_workers + self._workers_per_group = 1 + + self._event_loop = config.event_loop + self._logging_paths = config.logging_config.paths + self._logging_level = config.logging_config.level + self._logging_config_file = config.logging_config.config_file + + source_orb_root = os.path.abspath(config.orb_config_path) + if not os.path.isdir(source_orb_root): + raise NotADirectoryError(f"orb_config_path must be a directory: {source_orb_root}") + + self._orb: Optional[ORBHelper] = None + self._ec2: Optional[Any] = None + self._context = None + self._connector_external: Optional[AsyncConnector] = None + self._created_security_group_id: Optional[str] = None + self._created_key_name: Optional[str] = None + self._cleaned_up = False + self._worker_groups: Dict[WorkerGroupID, WorkerID] = {} + self._ident: bytes = b"worker_adapter_orb|uninitialized" + self._subnet_id: Optional[str] = None + + def __initialize(self): + source_orb_root = os.path.abspath(self._config.orb_config_path) + region = self._config.aws_region or "us-east-1" + self._orb = ORBHelper(config_root_path=source_orb_root, region=region) + + self._ec2 = boto3.client("ec2", region_name=region) + + self._subnet_id = 
self._config.subnet_id or self._discover_default_subnet() + + self._template_id = os.urandom(8).hex() + + security_group_ids = self._config.security_group_ids + if not security_group_ids: + self._create_security_group() + security_group_ids = [self._created_security_group_id] + + key_name = self._config.key_name + if not key_name: + self._create_key_pair() + key_name = self._created_key_name + + user_data = self._create_user_data() + user_data_file_path = os.path.join(self._orb.cwd, "config", "user_data.sh") + with open(user_data_file_path, "w") as f: + f.write(user_data) + + template = ORBTemplate( + template_id=self._template_id, + max_number=self._config.worker_adapter_config.max_workers, + provider_api="RunInstances", + provider_name="aws-default", + image_id=self._config.image_id, + vm_type=self._config.instance_type, + instance_types={self._config.instance_type: 1}, + subnet_ids=[self._subnet_id], + security_group_ids=security_group_ids, + key_name=key_name, + user_data_script=user_data_file_path, + metadata={ + "attributes": { + "type": ["String", "X86_64"], + "ncpus": ["Numeric", "1"], + "nram": ["Numeric", "1024"], + "ncores": ["Numeric", "1"], + } + }, + ) + + # Create template in ORB + # Use the cwd from ORBHelper to place the templates file + templates_file_path = os.path.join(self._orb.cwd, "config", "templates.json") + with open(templates_file_path, "w") as f: + template_dict = asdict(template) + json.dump({"templates": [template_dict]}, f, indent=4) + + self._context = create_async_simple_context() + self._name = "worker_adapter_orb" + self._ident = f"{self._name}|{uuid.uuid4().bytes.hex()}".encode() + + self._connector_external = create_async_connector( + self._context, + name=self._name, + socket_type=zmq.DEALER, + address=self._address, + bind_or_connect="connect", + callback=self.__on_receive_external, + identity=self._ident, + ) + + async def __on_receive_external(self, message: Message): + if isinstance(message, WorkerAdapterCommand): + await 
self._handle_command(message) + elif isinstance(message, WorkerAdapterHeartbeatEcho): + pass + else: + logging.warning(f"Received unknown message type: {type(message)}") + + async def _handle_command(self, command: WorkerAdapterCommand): + cmd_type = command.command + worker_group_id = command.worker_group_id + response_status = Status.Success + worker_ids: List[bytes] = [] + capabilities: Dict[str, int] = {} + + cmd_res = WorkerAdapterCommandType.StartWorkerGroup + if cmd_type == WorkerAdapterCommandType.StartWorkerGroup: + cmd_res = WorkerAdapterCommandType.StartWorkerGroup + worker_group_id, response_status = await self.start_worker_group() + if response_status == Status.Success: + worker_ids = [bytes(self._worker_groups[worker_group_id])] + capabilities = self._capabilities + elif cmd_type == WorkerAdapterCommandType.ShutdownWorkerGroup: + cmd_res = WorkerAdapterCommandType.ShutdownWorkerGroup + response_status = await self.shutdown_worker_group(worker_group_id) + else: + raise ValueError("Unknown Command") + + assert self._connector_external is not None + await self._connector_external.send( + WorkerAdapterCommandResponse.new_msg( + worker_group_id=bytes(worker_group_id), + command=cmd_res, + status=response_status, + worker_ids=worker_ids, + capabilities=capabilities, + ) + ) + + async def __send_heartbeat(self): + assert self._connector_external is not None + await self._connector_external.send( + WorkerAdapterHeartbeat.new_msg( + max_worker_groups=self._max_workers, + workers_per_group=self._workers_per_group, + capabilities=self._capabilities, + ) + ) + + def run(self) -> None: + self._loop = asyncio.new_event_loop() + run_task_forever(self._loop, self._run(), cleanup_callback=self._cleanup) + + def __destroy(self): + print(f"Worker adapter {self._ident!r} received signal, shutting down") + self._task.cancel() + + def __register_signal(self): + self._loop.add_signal_handler(signal.SIGINT, self.__destroy) + self._loop.add_signal_handler(signal.SIGTERM, 
self.__destroy) + + async def _run(self) -> None: + register_event_loop(self._event_loop) + setup_logger(self._logging_paths, self._logging_config_file, self._logging_level) + self.__initialize() + self._task = self._loop.create_task(self.__get_loops()) + self.__register_signal() + await self._task + + async def __get_loops(self): + assert self._connector_external is not None + loops = [ + create_async_loop_routine(self._connector_external.routine, 0), + create_async_loop_routine(self.__send_heartbeat, self._heartbeat_interval_seconds), + ] + + try: + await asyncio.gather(*loops) + except asyncio.CancelledError: + pass + except ymq.YMQException as e: + if e.code == ymq.ErrorCode.ConnectorSocketClosedByRemoteEnd: + pass + else: + logging.exception(f"{self._ident!r}: failed with unhandled exception:\n{e}") + + def _create_user_data(self) -> str: + worker_config = self._config.worker_config + adapter_config = self._config.worker_adapter_config + + # We assume 1 worker per machine for ORB + # TODO: Add support for multiple workers per machine if needed + num_workers = 1 + + # Build the command + # We construct the full WorkerID here so it's deterministic and matches what the adapter calculates + # We fetch instance_id once and use it to construct the ID + script = f"""#!/bin/bash +INSTANCE_ID=$(ec2-metadata --instance-id --quiet) +WORKER_NAME="{get_orb_worker_name('${INSTANCE_ID}')}" + +nohup /usr/local/bin/scaler_cluster {adapter_config.scheduler_address.to_address()} \ + --num-of-workers {num_workers} \ + --worker-names "${{WORKER_NAME}}" \ + --per-worker-task-queue-size {worker_config.per_worker_task_queue_size} \ + --heartbeat-interval-seconds {worker_config.heartbeat_interval_seconds} \ + --task-timeout-seconds {worker_config.task_timeout_seconds} \ + --garbage-collect-interval-seconds {worker_config.garbage_collect_interval_seconds} \ + --death-timeout-seconds {worker_config.death_timeout_seconds} \ + --trim-memory-threshold-bytes 
{worker_config.trim_memory_threshold_bytes} \ + --event-loop {self._config.event_loop} \ + --worker-io-threads {self._config.worker_io_threads} \ + --deterministic-worker-ids""" + + if worker_config.hard_processor_suspend: + script += " \ + --hard-processor-suspend" + + if adapter_config.object_storage_address: + script += f" \ + --object-storage-address {adapter_config.object_storage_address.to_string()}" + + capabilities = worker_config.per_worker_capabilities.capabilities + if capabilities: + cap_str = format_capabilities(capabilities) + if cap_str.strip(): + script += f" \ + --per-worker-capabilities {cap_str}" + + script += " > /var/log/opengris-scaler.log 2>&1 &\n" + + return script + + def _discover_default_subnet(self) -> str: + vpcs = self._ec2.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}]) + if not vpcs["Vpcs"]: + raise RuntimeError("No default VPC found, and no subnet_id provided.") + default_vpc_id = vpcs["Vpcs"][0]["VpcId"] + + subnets = self._ec2.describe_subnets(Filters=[{"Name": "vpc-id", "Values": [default_vpc_id]}]) + if not subnets["Subnets"]: + raise RuntimeError(f"No subnets found in default VPC {default_vpc_id}.") + + subnet_id = subnets["Subnets"][0]["SubnetId"] + logger.info(f"Auto-discovered subnet_id: {subnet_id}") + return subnet_id + + def _create_security_group(self): + # Get VPC ID from Subnet + subnet_response = self._ec2.describe_subnets(SubnetIds=[self._subnet_id]) + vpc_id = subnet_response["Subnets"][0]["VpcId"] + + # Create Security Group (outbound-only — workers connect out to scheduler via ZMQ) + group_name = f"opengris-orb-sg-{self._template_id}" + sg_response = self._ec2.create_security_group( + Description="Temporary security group created for OpenGRIS ORB worker adapter", + GroupName=group_name, + VpcId=vpc_id, + ) + self._created_security_group_id = sg_response["GroupId"] + logger.info(f"Created security group with ID: {self._created_security_group_id}") + + def _create_key_pair(self): + key_name = 
f"opengris-orb-key-{self._template_id}" + self._ec2.create_key_pair(KeyName=key_name) + self._created_key_name = key_name + logger.info(f"Created key pair: {key_name}") + + def _cleanup(self): + if self._cleaned_up: + return + self._cleaned_up = True + + if self._connector_external is not None: + self._connector_external.destroy() + + logger.info("Starting cleanup of ORB and AWS resources...") + + # 1. Shutdown all active worker groups (terminate instances) + if self._worker_groups and self._orb is not None: + logger.info(f"Terminating {len(self._worker_groups)} worker groups...") + instance_ids = [wg_id.decode() for wg_id in self._worker_groups.keys()] + try: + # Use ORB to return (terminate) the machines + self._orb.machines.return_machines(instance_ids) + logger.info(f"Successfully requested termination of instances: {instance_ids}") + except Exception as e: + logger.warning(f"Failed to terminate instances during cleanup: {e}") + self._worker_groups.clear() + + if self._created_security_group_id is not None: + try: + logger.info(f"Deleting AWS security group: {self._created_security_group_id}") + self._ec2.delete_security_group(GroupId=self._created_security_group_id) + except Exception as e: + logger.warning(f"Failed to delete security group {self._created_security_group_id}: {e}") + + if self._created_key_name is not None: + try: + logger.info(f"Deleting AWS key pair: {self._created_key_name}") + self._ec2.delete_key_pair(KeyName=self._created_key_name) + except Exception as e: + logger.warning(f"Failed to delete key pair {self._created_key_name}: {e}") + + logger.info("Cleanup completed.") + + def __del__(self): + self._cleanup() + + def _poll_for_instance_id(self, request_id: str) -> Optional[str]: + for _ in range(ORB_MAX_POLLING_ATTEMPTS): + status_response = self._orb.requests.show(request_id) + logger.debug(f"ORB polling response for {request_id}: {status_response}") + + instance_ids = status_response.get_instance_ids() + if instance_ids: + 
logger.info(f"ORB request {request_id} fulfilled with instance IDs: {instance_ids}") + return instance_ids[0] + + if status_response.status in ["failed", "cancelled", "timeout"]: + error_msg = status_response.status_message or "Unknown failure" + logger.error( + f"ORB machine request {request_id} failed with status '{status_response.status}': {error_msg}" + ) + return None + + time.sleep(ORB_POLLING_INTERVAL_SECONDS) + + timeout_seconds = ORB_MAX_POLLING_ATTEMPTS * ORB_POLLING_INTERVAL_SECONDS + logger.error(f"ORB machine request {request_id} timed out waiting for instance IDs after {timeout_seconds}s.") + return None + + async def start_worker_group(self) -> Tuple[WorkerGroupID, Status]: + if len(self._worker_groups) >= self._max_workers != -1: + return b"", Status.WorkerGroupTooMuch + + # Request a machine. Note: wait and timeout flags in ORB CLI are currently ignored by the handler, + # so we must handle the polling ourselves. + response = self._orb.machines.request(template_id=self._template_id, count=1) + + if not response.request_id: + logger.error(f"ORB machine request failed to return a request ID. 
Response: {response}") + return b"", Status.UnknownAction + + logger.info(f"ORB machine request {response.request_id} submitted, polling for instance IDs...") + + instance_id = await self._loop.run_in_executor(None, lambda: self._poll_for_instance_id(response.request_id)) + if not instance_id: + return b"", Status.UnknownAction + + worker_group_id = instance_id.encode() + + # Deterministic WorkerID calculation to match the user_data script + worker_id = WorkerID(get_orb_worker_name(instance_id).encode()) + + self._worker_groups[worker_group_id] = worker_id + return worker_group_id, Status.Success + + async def shutdown_worker_group(self, worker_group_id: WorkerGroupID) -> Status: + if not worker_group_id: + return Status.WorkerGroupIDNotSpecified + + if worker_group_id not in self._worker_groups: + logger.warning(f"Worker group with ID {bytes(worker_group_id).decode()} does not exist.") + return Status.WorkerGroupIDNotFound + + instance_id = worker_group_id.decode() + self._orb.machines.return_machines([instance_id]) + + del self._worker_groups[worker_group_id] + return Status.Success diff --git a/tests/io/uv_ymq/test_sockets.py b/tests/io/uv_ymq/test_sockets.py index 40eea5d4e..bad94aeb3 100644 --- a/tests/io/uv_ymq/test_sockets.py +++ b/tests/io/uv_ymq/test_sockets.py @@ -1,14 +1,7 @@ import asyncio import unittest -from scaler.io.uv_ymq import ( - BinderSocket, - Bytes, - ConnectorSocket, - ErrorCode, - IOContext, -) -from scaler.io.uv_ymq import _uv_ymq +from scaler.io.uv_ymq import BinderSocket, Bytes, ConnectorSocket, ErrorCode, IOContext, _uv_ymq class TestSockets(unittest.IsolatedAsyncioTestCase):