36 changes: 36 additions & 0 deletions docs/source/tutorials/configuration.rst
@@ -257,3 +257,39 @@ You can override any value from the TOML file by providing it as a command-line
scaler_cluster tcp://127.0.0.1:6378 --config example_config.toml --num-of-workers 12
The cluster will start with **12 workers**, but all other settings (like ``task_timeout_seconds``) will still be loaded from the ``[cluster]`` section of ``example_config.toml``.


**Scenario 3: Waterfall Scaling Configuration**

To use the ``waterfall_v1`` policy engine for priority-based scaling across multiple worker adapters, set ``policy_engine_type = "waterfall_v1"`` and provide rules in ``policy_content`` (one rule per line, ``#`` comments supported):

**waterfall_config.toml**

.. code-block:: toml

    [scheduler]
    object_storage_address = "tcp://127.0.0.1:6379"
    monitor_address = "tcp://127.0.0.1:6380"
    logging_level = "INFO"
    policy_engine_type = "waterfall_v1"
    policy_content = """
    # priority, adapter_id_prefix, max_workers
    1, NAT, 8
    2, ECS, 50
    """

    [native_worker_adapter]
    max_workers = 8

    [ecs_worker_adapter]
    max_workers = 50

Then start the scheduler and worker adapters:

.. code-block:: bash

    scaler_scheduler tcp://127.0.0.1:8516 --config waterfall_config.toml &
    scaler_worker_adapter_native tcp://127.0.0.1:8516 --config waterfall_config.toml &
    scaler_worker_adapter_ecs tcp://127.0.0.1:8516 --config waterfall_config.toml &

Local ``NAT`` workers will scale up first. When they reach capacity, ``ECS`` workers will begin scaling. On scale-down, ECS workers drain before local workers.
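The scale-up and scale-down ordering described above can be sketched in a few lines of Python. This is a simplified illustration of the cascade, not the scheduler's actual implementation; the function names and tier representation are hypothetical:

```python
# Hypothetical sketch of waterfall ordering. Each tier is a tuple of
# (priority, adapter_id_prefix, max_workers); `current` maps prefix -> running workers.

def next_scale_up(tiers, current):
    """Return the prefix of the highest-priority tier with spare capacity, or None."""
    for _priority, prefix, max_workers in sorted(tiers):
        if current.get(prefix, 0) < max_workers:
            return prefix
    return None

def next_scale_down(tiers, current):
    """Return the prefix of the lowest-priority tier with running workers, or None."""
    for _priority, prefix, _max_workers in sorted(tiers, reverse=True):
        if current.get(prefix, 0) > 0:
            return prefix
    return None

tiers = [(1, "NAT", 8), (2, "ECS", 50)]
print(next_scale_up(tiers, {"NAT": 8, "ECS": 3}))    # prints "ECS" (NAT is at capacity)
print(next_scale_down(tiers, {"NAT": 8, "ECS": 3}))  # prints "ECS" (ECS drains first)
```

With the configuration above, new demand fills the ``NAT`` tier to its limit of 8 before any ``ECS`` workers start, and on scale-down the ``ECS`` tier empties before any ``NAT`` worker is shut down.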
28 changes: 28 additions & 0 deletions docs/source/tutorials/scaling.rst
@@ -45,6 +45,8 @@ Scaler provides several built-in scaling policies:
- Capability-aware scaling. Scales worker groups based on task-required capabilities (e.g., GPU, memory).
* - ``fixed_elastic``
- Hybrid scaling using primary and secondary worker managers with configurable limits.
* - ``waterfall_v1``
- Priority-based cascading across multiple worker managers. Higher-priority managers fill first; overflow goes to lower-priority managers.


No Scaling (``no``)
@@ -161,6 +163,32 @@ This is useful for scenarios where you have a fixed pool of dedicated resources
* When scaling down, only secondary manager groups are shut down


Waterfall Scaling (``waterfall_v1``)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The waterfall scaling policy cascades worker scaling across prioritized worker managers. Higher-priority managers fill first; when they reach capacity, overflow goes to the next priority tier. When scaling down, the lowest-priority managers drain first.

This is useful for hybrid deployments where you want to prefer cheaper or lower-latency resources (e.g., local bare-metal) and only burst to more expensive resources (e.g., cloud) when needed.

**Configuration:**

The waterfall policy uses ``policy_engine_type = "waterfall_v1"`` and a newline-separated rule format for ``policy_content``. Each rule is a comma-separated line with three fields: ``priority``, ``manager_id_prefix``, ``max_workers``. Lines starting with ``#`` are comments.

.. code:: toml

    [scheduler]
    policy_engine_type = "waterfall_v1"
    policy_content = """
    # priority, manager_id_prefix, max_workers
    # Use local workers first (cheap, low latency)
    1, NAT, 8
    # Overflow to ECS when local capacity is exhausted
    2, ECS, 50
    """

Rules reference worker manager ID prefixes. At runtime, each worker manager generates a full ID like ``NAT|<pid>``; the prefix ``NAT`` matches any manager whose ID starts with ``NAT``. Multiple managers can share the same prefix and are governed by the same rule.
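The rule format and prefix matching can be sketched as follows. This is an illustrative parser under the documented format, not the library's actual code; the function names are hypothetical:

```python
# Hypothetical sketch of parsing waterfall_v1 rules and matching manager IDs.
# Rules follow the documented 'priority, manager_id_prefix, max_workers' format.

def parse_rules(policy_content):
    """Parse rule lines, skipping blank lines and '#' comments."""
    rules = []
    for line in policy_content.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        priority, prefix, max_workers = (field.strip() for field in line.split(","))
        rules.append((int(priority), prefix.encode(), int(max_workers)))
    return sorted(rules)  # lowest priority number = highest precedence

def matching_rule(rules, worker_manager_id):
    """Return the first rule whose prefix matches a runtime ID like b'NAT|1234'."""
    for rule in rules:
        if worker_manager_id.startswith(rule[1]):
            return rule
    return None

rules = parse_rules("# priority, prefix, max\n1, NAT, 8\n2, ECS, 50\n")
print(matching_rule(rules, b"NAT|1234"))  # prints "(1, b'NAT', 8)"
```

Because matching is by prefix, the managers ``NAT|1001`` and ``NAT|1002`` both fall under the ``1, NAT, 8`` rule and share its ``max_workers`` budget.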


Worker Manager Protocol
-----------------------

1 change: 1 addition & 0 deletions src/scaler/protocol/capnp/message.capnp
@@ -89,6 +89,7 @@ struct WorkerManagerHeartbeat {
maxWorkerGroups @0 :UInt32;
workersPerGroup @1 :UInt32;
capabilities @2 :List(CommonType.TaskCapability);
workerManagerID @3 :Data;
}

struct WorkerManagerHeartbeatEcho {
7 changes: 6 additions & 1 deletion src/scaler/protocol/python/message.py
@@ -365,9 +365,13 @@ def workers_per_group(self) -> int:
def capabilities(self) -> Dict[str, int]:
return {capability.name: capability.value for capability in self._msg.capabilities}

@property
def worker_manager_id(self) -> bytes:
return self._msg.workerManagerID

@staticmethod
def new_msg(
max_worker_groups: int, workers_per_group: int, capabilities: Dict[str, int]
max_worker_groups: int, workers_per_group: int, capabilities: Dict[str, int], worker_manager_id: bytes
) -> "WorkerManagerHeartbeat":
return WorkerManagerHeartbeat(
_message.WorkerManagerHeartbeat(
@@ -376,6 +380,7 @@ def new_msg(
capabilities=[
TaskCapability.new_msg(name, value).get_message() for name, value in capabilities.items()
],
workerManagerID=worker_manager_id,
)
)

7 changes: 6 additions & 1 deletion src/scaler/scheduler/controllers/mixins.py
@@ -19,7 +19,11 @@
WorkerManagerHeartbeat,
)
from scaler.protocol.python.status import ScalingManagerStatus
from scaler.scheduler.controllers.policies.simple_policy.scaling.types import WorkerGroupCapabilities, WorkerGroupState
from scaler.scheduler.controllers.policies.simple_policy.scaling.types import (
WorkerGroupCapabilities,
WorkerGroupState,
WorkerManagerSnapshot,
)
from scaler.utility.identifiers import ClientID, ObjectID, TaskID, WorkerID
from scaler.utility.mixins import Reporter

@@ -255,6 +259,7 @@ def get_scaling_commands(
worker_manager_heartbeat: WorkerManagerHeartbeat,
worker_groups: WorkerGroupState,
worker_group_capabilities: WorkerGroupCapabilities,
worker_manager_snapshots: Dict[bytes, WorkerManagerSnapshot],
) -> List[WorkerManagerCommand]:
"""Pure function: state in, commands out. Commands are either all start or all shutdown, never mixed."""
raise NotImplementedError()
5 changes: 5 additions & 0 deletions src/scaler/scheduler/controllers/policies/library/utility.py
@@ -10,4 +10,9 @@ def create_policy(policy_engine_type: str, policy_content: str) -> ScalerPolicy:

return SimplePolicy(policy_content)

if engine_type == PolicyEngineType.WATERFALL_V1:
from scaler.scheduler.controllers.policies.waterfall_v1.waterfall_v1_policy import WaterfallV1Policy

return WaterfallV1Policy(policy_content)

raise ValueError(f"Unknown policy_engine_type: {policy_engine_type}")
7 changes: 6 additions & 1 deletion src/scaler/scheduler/controllers/policies/mixins.py
@@ -3,7 +3,11 @@

from scaler.protocol.python.message import InformationSnapshot, Task, WorkerManagerCommand, WorkerManagerHeartbeat
from scaler.protocol.python.status import ScalingManagerStatus
from scaler.scheduler.controllers.policies.simple_policy.scaling.types import WorkerGroupCapabilities, WorkerGroupState
from scaler.scheduler.controllers.policies.simple_policy.scaling.types import (
WorkerGroupCapabilities,
WorkerGroupState,
WorkerManagerSnapshot,
)
from scaler.utility.identifiers import TaskID, WorkerID


@@ -51,6 +55,7 @@ def get_scaling_commands(
worker_manager_heartbeat: WorkerManagerHeartbeat,
worker_groups: WorkerGroupState,
worker_group_capabilities: WorkerGroupCapabilities,
worker_manager_snapshots: Dict[bytes, WorkerManagerSnapshot],
) -> List[WorkerManagerCommand]:
raise NotImplementedError()

@@ -14,6 +14,7 @@
WorkerGroupCapabilities,
WorkerGroupID,
WorkerGroupState,
WorkerManagerSnapshot,
)
from scaler.utility.identifiers import WorkerID

@@ -40,6 +41,7 @@ def get_scaling_commands(
worker_manager_heartbeat: WorkerManagerHeartbeat,
worker_groups: WorkerGroupState,
worker_group_capabilities: WorkerGroupCapabilities,
worker_manager_snapshots: Dict[bytes, WorkerManagerSnapshot],
) -> List[WorkerManagerCommand]:
# Derive worker_groups_by_capability from worker_groups + worker_group_capabilities
worker_groups_by_capability = self._derive_worker_groups_by_capability(worker_groups, worker_group_capabilities)
@@ -176,7 +178,6 @@ def _get_shutdown_commands(
if not worker_group_dict:
continue

# Find tasks that these workers can handle
task_count = 0
for task_capability_keys, tasks in tasks_by_capability.items():
if task_capability_keys <= capability_keys:
@@ -190,7 +191,6 @@

task_ratio = task_count / worker_count
if task_ratio < self._lower_task_ratio:
# Find the worker group with the least queued tasks
worker_group_task_counts: Dict[WorkerGroupID, int] = {}
for worker_group_id, worker_ids in worker_group_dict.items():
total_queued = sum(
@@ -203,17 +203,13 @@
if not worker_group_task_counts:
continue

# Select the worker group with the fewest queued tasks to shut down
least_busy_group_id = min(worker_group_task_counts, key=lambda gid: worker_group_task_counts[gid])

# Don't scale down if there are pending tasks and this would leave no capable workers
workers_in_group = len(worker_group_dict.get(least_busy_group_id, []))
remaining_worker_count = worker_count - workers_in_group
if task_count > 0 and remaining_worker_count == 0:
# This is the last worker group that can handle these tasks - don't shut it down
continue
if remaining_worker_count > 0 and (task_count / remaining_worker_count) > self._upper_task_ratio:
# Shutting down this group would cause task ratio to exceed upper threshold and scale-up again
continue

commands.append(
@@ -13,6 +13,7 @@
WorkerGroupCapabilities,
WorkerGroupID,
WorkerGroupState,
WorkerManagerSnapshot,
)


@@ -42,6 +43,7 @@ def get_scaling_commands(
worker_manager_heartbeat: WorkerManagerHeartbeat,
worker_groups: WorkerGroupState,
worker_group_capabilities: WorkerGroupCapabilities,
worker_manager_snapshots: Dict[bytes, WorkerManagerSnapshot],
) -> List[WorkerManagerCommand]:
if not information_snapshot.workers:
if information_snapshot.tasks:
@@ -1,9 +1,13 @@
import abc
from typing import List
from typing import Dict, List

from scaler.protocol.python.message import InformationSnapshot, WorkerManagerCommand, WorkerManagerHeartbeat
from scaler.protocol.python.status import ScalingManagerStatus
from scaler.scheduler.controllers.policies.simple_policy.scaling.types import WorkerGroupCapabilities, WorkerGroupState
from scaler.scheduler.controllers.policies.simple_policy.scaling.types import (
WorkerGroupCapabilities,
WorkerGroupState,
WorkerManagerSnapshot,
)


class ScalingPolicy:
@@ -21,6 +25,7 @@ def get_scaling_commands(
worker_manager_heartbeat: WorkerManagerHeartbeat,
worker_groups: WorkerGroupState,
worker_group_capabilities: WorkerGroupCapabilities,
worker_manager_snapshots: Dict[bytes, WorkerManagerSnapshot],
) -> List[WorkerManagerCommand]:
"""
Pure function: state in, commands out.
@@ -1,9 +1,13 @@
from typing import List
from typing import Dict, List

from scaler.protocol.python.message import InformationSnapshot, WorkerManagerCommand, WorkerManagerHeartbeat
from scaler.protocol.python.status import ScalingManagerStatus
from scaler.scheduler.controllers.policies.simple_policy.scaling.mixins import ScalingPolicy
from scaler.scheduler.controllers.policies.simple_policy.scaling.types import WorkerGroupCapabilities, WorkerGroupState
from scaler.scheduler.controllers.policies.simple_policy.scaling.types import (
WorkerGroupCapabilities,
WorkerGroupState,
WorkerManagerSnapshot,
)


class NoScalingPolicy(ScalingPolicy):
@@ -16,6 +20,7 @@ def get_scaling_commands(
worker_manager_heartbeat: WorkerManagerHeartbeat,
worker_groups: WorkerGroupState,
worker_group_capabilities: WorkerGroupCapabilities,
worker_manager_snapshots: Dict[bytes, WorkerManagerSnapshot],
) -> List[WorkerManagerCommand]:
return []

@@ -18,6 +18,16 @@ class WorkerGroupInfo:
WorkerGroupCapabilities = Dict[WorkerGroupID, Dict[str, int]]


@dataclasses.dataclass(frozen=True)
class WorkerManagerSnapshot:
"""Immutable snapshot of a worker manager's state, passed to stateless scaling policies."""

worker_manager_id: bytes
max_worker_groups: int
worker_group_count: int
last_seen_s: float # time.time() epoch seconds of last heartbeat


class ScalingPolicyStrategy(enum.Enum):
NO = "no"
VANILLA = "vanilla"
@@ -13,6 +13,7 @@
WorkerGroupCapabilities,
WorkerGroupID,
WorkerGroupState,
WorkerManagerSnapshot,
)


@@ -31,6 +32,7 @@ def get_scaling_commands(
worker_manager_heartbeat: WorkerManagerHeartbeat,
worker_groups: WorkerGroupState,
worker_group_capabilities: WorkerGroupCapabilities,
worker_manager_snapshots: Dict[bytes, WorkerManagerSnapshot],
) -> List[WorkerManagerCommand]:
if not information_snapshot.workers:
if information_snapshot.tasks:
@@ -11,6 +11,7 @@
ScalingPolicyStrategy,
WorkerGroupCapabilities,
WorkerGroupState,
WorkerManagerSnapshot,
)
from scaler.scheduler.controllers.policies.simple_policy.scaling.utility import create_scaling_policy
from scaler.utility.identifiers import TaskID, WorkerID
@@ -64,9 +65,14 @@ def get_scaling_commands(
worker_manager_heartbeat: WorkerManagerHeartbeat,
worker_groups: WorkerGroupState,
worker_group_capabilities: WorkerGroupCapabilities,
worker_manager_snapshots: Dict[bytes, WorkerManagerSnapshot],
) -> List[WorkerManagerCommand]:
return self._scaling_policy.get_scaling_commands(
information_snapshot, worker_manager_heartbeat, worker_groups, worker_group_capabilities
information_snapshot,
worker_manager_heartbeat,
worker_groups,
worker_group_capabilities,
worker_manager_snapshots,
)

def get_scaling_status(self, worker_groups: WorkerGroupState) -> ScalingManagerStatus:
1 change: 1 addition & 0 deletions src/scaler/scheduler/controllers/policies/types.py
@@ -3,6 +3,7 @@

class PolicyEngineType(enum.Enum):
SIMPLE = "simple"
WATERFALL_V1 = "waterfall_v1"

def __str__(self):
return self.name
Empty file.
@@ -0,0 +1,10 @@
import dataclasses


@dataclasses.dataclass(frozen=True)
class WaterfallRule:
"""A single rule in the waterfall config, parsed from 'priority,worker_type,max_task_concurrency'."""

priority: int
worker_type: bytes
max_task_concurrency: int