Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions sky/server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import enum
import functools
import os
import sys
from typing import Literal, Optional

from sky import sky_logging
Expand Down Expand Up @@ -74,6 +75,40 @@ class WorkerConfig:
garanteed_parallelism: int
burstable_parallelism: int
num_db_connections_per_worker: int
# Recycle a guaranteed worker after it has handled this many requests, to
# bound its high-water-mark RSS. None keeps workers for the pool lifetime.
max_tasks_per_child: Optional[int] = None


def _get_worker_max_tasks_per_child() -> Optional[int]:
"""Reads the per-worker task limit from the environment.

Returns None (no recycling) when unset/invalid, or when the interpreter is
older than 3.11 (``ProcessPoolExecutor.max_tasks_per_child`` was added in
3.11), logging a warning in the latter case so the misconfiguration is
visible.
"""
raw = os.environ.get(server_constants.WORKER_MAX_TASKS_PER_CHILD_ENV_VAR)
if not raw:
return None
try:
value = int(raw)
except ValueError:
logger.warning('Ignoring %s=%r: expected a positive integer.',
server_constants.WORKER_MAX_TASKS_PER_CHILD_ENV_VAR, raw)
return None
if value < 1:
logger.warning('Ignoring %s=%r: expected a positive integer.',
server_constants.WORKER_MAX_TASKS_PER_CHILD_ENV_VAR, raw)
return None
if sys.version_info < (3, 11):
logger.warning(
'%s is set but worker recycling requires Python 3.11+; '
'ignoring it on Python %d.%d.',
server_constants.WORKER_MAX_TASKS_PER_CHILD_ENV_VAR,
sys.version_info[0], sys.version_info[1])
return None
return value


@dataclasses.dataclass
Expand Down Expand Up @@ -134,6 +169,7 @@ def compute_server_config(
max_parallel_for_short = _max_short_worker_parallism(
mem_size_gb, max_parallel_for_long)
queue_backend = QueueBackend.MULTIPROCESSING
max_tasks_per_child = _get_worker_max_tasks_per_child()
burstable_parallel_for_long = 0
burstable_parallel_for_short = 0
# if num_db_connections_per_worker is 0, server will use NullPool
Expand Down Expand Up @@ -202,11 +238,13 @@ def compute_server_config(
long_worker_config=WorkerConfig(
garanteed_parallelism=max_parallel_for_long,
burstable_parallelism=burstable_parallel_for_long,
num_db_connections_per_worker=num_db_connections_per_worker),
num_db_connections_per_worker=num_db_connections_per_worker,
max_tasks_per_child=max_tasks_per_child),
short_worker_config=WorkerConfig(
garanteed_parallelism=max_parallel_for_short,
burstable_parallelism=burstable_parallel_for_short,
num_db_connections_per_worker=num_db_connections_per_worker),
num_db_connections_per_worker=num_db_connections_per_worker,
max_tasks_per_child=max_tasks_per_child),
num_db_connections_per_worker=num_db_connections_per_worker,
)

Expand Down
8 changes: 8 additions & 0 deletions sky/server/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@
# Keep in sync with websocket_proxy.py
API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'

# Maximum number of requests a guaranteed executor worker process handles
# before it is recycled (replaced by a fresh process), to bound the
# worker's high-water-mark RSS. Maps to ProcessPoolExecutor's
# `max_tasks_per_child`, which requires Python 3.11+; ignored on older
# interpreters. Unset (default) keeps workers for the lifetime of the pool.
WORKER_MAX_TASKS_PER_CHILD_ENV_VAR = (
f'{constants.SKYPILOT_ENV_VAR_PREFIX}API_SERVER_WORKER_MAX_TASKS_PER_CHILD')

# The path to the dashboard build output
DASHBOARD_DIR = os.path.join(os.path.dirname(__file__), '..', 'dashboard',
'out')
Expand Down
2 changes: 2 additions & 0 deletions sky/server/requests/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def __init__(self, schedule_type: api_requests.ScheduleType,
self.burstable_parallelism = config.burstable_parallelism
self.num_db_connections_per_worker = (
config.num_db_connections_per_worker)
self.max_tasks_per_child = config.max_tasks_per_child
self._thread: Optional[threading.Thread] = None
self._cancel_event = threading.Event()

Expand Down Expand Up @@ -392,6 +393,7 @@ def run(self) -> None:
executor = process.BurstableExecutor(
garanteed_workers=self.garanteed_parallelism,
burst_workers=self.burstable_parallelism,
max_tasks_per_child=self.max_tasks_per_child,
initializer=executor_initializer,
initargs=(proc_group, clean_env_module.get_clean_server_env()))
# Initialize the appropriate gauge for the number of free executors
Expand Down
24 changes: 21 additions & 3 deletions sky/server/requests/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import concurrent.futures
import logging
import multiprocessing
import sys
import threading
import time
from typing import Callable, Dict, Optional, Tuple
Expand All @@ -24,7 +25,16 @@ class PoolExecutor(concurrent.futures.ProcessPoolExecutor):
shutting down instead of indefinitely waiting.
"""

def __init__(self, max_workers: int, **kwargs):
def __init__(self,
max_workers: int,
max_tasks_per_child: Optional[int] = None,
**kwargs):
if max_tasks_per_child is not None and sys.version_info >= (3, 11):
# Recycle a worker process after it has handled this many tasks so
# its high-water-mark RSS is reclaimed. `max_tasks_per_child` was
# added to ProcessPoolExecutor in Python 3.11; on older versions it
# is silently ignored (the caller logs a warning instead).
kwargs['max_tasks_per_child'] = max_tasks_per_child
super().__init__(max_workers=max_workers, **kwargs)
self.max_workers: int = max_workers
# The number of workers that are handling tasks, atomicity across
Expand Down Expand Up @@ -203,13 +213,21 @@ class BurstableExecutor:
def __init__(self,
garanteed_workers: int,
burst_workers: int = 0,
max_tasks_per_child: Optional[int] = None,
**kwargs):
if garanteed_workers > 0:
self._guaranteed_workers = garanteed_workers
self._guaranteed_pool_kwargs = kwargs
# Worker recycling applies to the guaranteed pool only; burst
# workers are already disposed after each task. Keep
# max_tasks_per_child in the stored kwargs so the pool is rebuilt
# with the same setting after a BrokenProcessPool.
self._guaranteed_pool_kwargs = {
**kwargs,
'max_tasks_per_child': max_tasks_per_child,
}
self._guaranteed_pool_lock = threading.Lock()
self._executor = PoolExecutor(max_workers=garanteed_workers,
**kwargs)
**self._guaranteed_pool_kwargs)
if burst_workers > 0:
self._burst_executor = DisposableExecutor(max_workers=burst_workers,
**kwargs)
Expand Down
54 changes: 54 additions & 0 deletions tests/unit_tests/test_sky/server/requests/test_process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Unit tests for sky/server/requests/process.py."""
from concurrent.futures import Future
import concurrent.futures.process
import os
import sys
import time
import unittest.mock

Expand All @@ -17,6 +19,11 @@ def dummy_task(sleep_time=0.1):
return True


def pid_task():
"""Return the worker process PID, to observe worker recycling."""
return os.getpid()


def failing_task():
"""A task that raises an exception."""
raise ValueError('Task failed')
Expand Down Expand Up @@ -86,6 +93,53 @@ def test_pool_executor():
executor.shutdown()


@pytest.mark.skipif(sys.version_info < (3, 11),
reason='max_tasks_per_child requires Python 3.11+')
def test_pool_executor_recycles_after_max_tasks():
"""A guaranteed worker is replaced after max_tasks_per_child tasks."""
executor = PoolExecutor(max_workers=1, max_tasks_per_child=2)
try:
# Submit one at a time so all tasks land on the single worker in
# order, making recycling deterministic.
pids = []
for _ in range(4):
pids.append(executor.submit(pid_task).result(timeout=20))
# With max_tasks_per_child=2 the worker is recycled after every 2
# tasks: [A, A, B, B] with A != B.
assert pids[0] == pids[1], pids
assert pids[2] == pids[3], pids
assert pids[0] != pids[2], pids
finally:
executor.shutdown()


def test_pool_executor_no_recycle_by_default():
"""Without max_tasks_per_child the worker is reused for every task."""
executor = PoolExecutor(max_workers=1)
try:
pids = [executor.submit(pid_task).result(timeout=20) for _ in range(4)]
assert len(set(pids)) == 1, pids
finally:
executor.shutdown()


def test_burstable_executor_max_tasks_per_child_routing():
"""max_tasks_per_child is applied to the guaranteed pool only.

Burst workers (DisposableExecutor) already dispose after each task, so the
setting must not be forwarded to them.
"""
executor = BurstableExecutor(garanteed_workers=1,
burst_workers=1,
max_tasks_per_child=5)
try:
assert (executor._guaranteed_pool_kwargs['max_tasks_per_child'] == 5)
# DisposableExecutor does not accept/track the setting.
assert not hasattr(executor._burst_executor, 'max_tasks_per_child')
finally:
executor.shutdown()


def test_disposable_executor():
"""Test DisposableExecutor functionality."""
executor = DisposableExecutor(max_workers=2)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit_tests/test_sky/server/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ def test_compute_server_config_pool(cpu_count, mem_size_gb, buildkite_mock):
assert controller_utils._get_request_parallelism(pool=True) == 40


_MAX_TASKS_ENV = config.server_constants.WORKER_MAX_TASKS_PER_CHILD_ENV_VAR


def test_worker_max_tasks_per_child_unset(monkeypatch):
monkeypatch.delenv(_MAX_TASKS_ENV, raising=False)
assert config._get_worker_max_tasks_per_child() is None


def test_worker_max_tasks_per_child_valid(monkeypatch):
monkeypatch.setenv(_MAX_TASKS_ENV, '100')
monkeypatch.setattr(config.sys, 'version_info', (3, 11, 0))
assert config._get_worker_max_tasks_per_child() == 100


@pytest.mark.parametrize('value', ['0', '-3', 'abc', ''])
def test_worker_max_tasks_per_child_invalid(monkeypatch, value):
monkeypatch.setenv(_MAX_TASKS_ENV, value)
monkeypatch.setattr(config.sys, 'version_info', (3, 11, 0))
assert config._get_worker_max_tasks_per_child() is None


def test_worker_max_tasks_per_child_ignored_pre_311(monkeypatch):
monkeypatch.setenv(_MAX_TASKS_ENV, '100')
monkeypatch.setattr(config.sys, 'version_info', (3, 10, 0))
assert config._get_worker_max_tasks_per_child() is None


@mock.patch('sky.utils.common_utils.get_mem_size_gb', return_value=8)
@mock.patch('sky.utils.common_utils.get_cpu_count', return_value=4)
def test_compute_server_config_propagates_max_tasks(cpu_count, mem_size_gb,
monkeypatch):
monkeypatch.setenv(_MAX_TASKS_ENV, '50')
monkeypatch.setattr(config.sys, 'version_info', (3, 11, 0))
c = config.compute_server_config(deploy=True)
assert c.long_worker_config.max_tasks_per_child == 50
assert c.short_worker_config.max_tasks_per_child == 50


def test_parallel_size_long():
# Test with insufficient memory
cpu_count = 4
Expand Down
Loading