Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion lib/iris/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ dev = [
"pytest-playwright>=0.6.2",
"pytest-timeout",
"pytest-xdist",
"pytest>=8.3.2",
"pytest>=8.4",
]

[tool.hatch.build.hooks.custom]
Expand All @@ -89,6 +89,10 @@ markers = [
"e2e: end-to-end cluster tests (chaos, dashboard, scheduling)",
]
filterwarnings = ["ignore::DeprecationWarning"]
# Cap assertion-diff output even under -v, so a single bad assertion
# can't dump tens of KB of proto repr into CI logs.
truncation_limit_lines = 100
truncation_limit_chars = 10000
log_level = "INFO"
log_format = "%(asctime)s %(levelname)s %(message)s"
log_date_format = "%Y-%m-%d %H:%M:%S"
6 changes: 5 additions & 1 deletion lib/iris/src/iris/cluster/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ def __init__(
)

self._config = config
self._stopped = False
self._provider: TaskProvider | K8sTaskProvider = provider
self._provider_scheduling_events: list[SchedulingEvent] = []
self._provider_capacity: ClusterCapacity | None = None
Expand Down Expand Up @@ -1279,14 +1280,17 @@ def start(self) -> None:
logger.info("Registered system endpoint /system/log-server -> %s", self._log_service_address)

def stop(self) -> None:
"""Stop all background components gracefully.
"""Stop all background components gracefully. Idempotent.

Shutdown ordering:
1. Unregister atexit hook so it doesn't fire against a closed DB.
2. Stop scheduling/heartbeat/autoscaler loops so no new work is triggered.
3. Shut down the autoscaler (stops monitors, terminates VMs, stops platform).
4. Stop remaining threads (server) and executors.
"""
if self._stopped:
return
self._stopped = True
# Unregister atexit hook before closing DB connections.
if self._atexit_registered:
atexit.unregister(self._atexit_checkpoint)
Expand Down
65 changes: 63 additions & 2 deletions lib/iris/tests/cluster/controller/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
)
from iris.cluster.controller.autoscaler import Autoscaler
from iris.cluster.controller.autoscaler.models import DemandEntry
from iris.cluster.controller.autoscaler.scaling_group import ScalingGroup
from iris.cluster.controller.controller import Controller, ControllerConfig
from iris.cluster.controller.db import (
ACTIVE_TASK_STATES,
ControllerDB,
_decode_attribute_rows,
task_row_can_be_scheduled,
task_row_is_finished,
)
from iris.cluster.controller.provider import ProviderUnsupportedError
from iris.cluster.controller.schema import (
ATTEMPT_PROJECTION,
JOB_CONFIG_JOIN,
Expand All @@ -47,8 +50,6 @@
WorkerRow,
tasks_with_attempts,
)
from iris.cluster.controller.provider import ProviderUnsupportedError
from iris.cluster.controller.autoscaler.scaling_group import ScalingGroup
from iris.cluster.controller.service import ControllerServiceImpl
from iris.log_server.server import LogServiceImpl
from iris.cluster.controller.transitions import (
Expand Down Expand Up @@ -201,6 +202,66 @@ def make_controller_state(**kwargs):
shutil.rmtree(tmp, ignore_errors=True)


@pytest.fixture
def make_controller(tmp_path):
    """Yield a factory that builds ``Controller`` objects and stops them at teardown.

    Why a factory: ``Controller.__init__`` installs a ``RemoteLogHandler`` on
    the ``iris`` logger and starts a ``LogPusher`` drain thread. If ``stop()``
    is never called, both outlive the test and keep collecting every
    ``iris.*`` log record, which may later be flushed into a different test's
    monkeypatched ``LogServiceClientSync``. Every controller built through the
    factory is remembered and ``stop()``-ed when the fixture is torn down.

    Factory arguments:
        config: A ready-made ``ControllerConfig``. Mutually exclusive with
            extra keyword arguments.
        provider: Task provider to use; a fresh ``FakeProvider`` when omitted.
        db: Pre-built ``ControllerDB`` to inject; when omitted the
            ``Controller`` opens its own under ``config.local_state_dir``.
        **config_kwargs: Forwarded to ``ControllerConfig`` when ``config`` is
            not given. ``remote_state_dir`` defaults to a ``file://`` URL
            under ``tmp_path``.

    Example::

        def test_foo(make_controller, tmp_path):
            ctrl = make_controller(remote_state_dir="file:///tmp/iris-state")
            # Or inject an existing DB / provider:
            ctrl = make_controller(
                remote_state_dir="file:///tmp/iris-state",
                local_state_dir=tmp_path,
                db=my_db,
            )
    """
    tracked: list[Controller] = []

    def _build_controller(
        config: ControllerConfig | None = None,
        *,
        provider=None,
        db: ControllerDB | None = None,
        **config_kwargs,
    ) -> Controller:
        # An explicit config and loose config kwargs cannot be combined.
        if config is not None and config_kwargs:
            raise TypeError("make_controller: pass either a config or config kwargs, not both")
        if config is None:
            config_kwargs.setdefault("remote_state_dir", f"file://{tmp_path}/remote")
            config = ControllerConfig(**config_kwargs)

        chosen_provider = FakeProvider() if provider is None else provider
        instance = Controller(config=config, provider=chosen_provider, db=db)
        tracked.append(instance)
        return instance

    yield _build_controller

    # Teardown: attempt to stop every controller even if an earlier stop()
    # fails, then surface the first failure that was seen.
    first_failure: BaseException | None = None
    for instance in tracked:
        try:
            instance.stop()
        except BaseException as exc:
            if first_failure is None:
                first_failure = exc
    if first_failure is not None:
        raise first_failure


def make_test_entrypoint() -> job_pb2.RuntimeEntrypoint:
entrypoint = job_pb2.RuntimeEntrypoint()
entrypoint.run_command.argv[:] = ["python", "-c", "pass"]
Expand Down
57 changes: 12 additions & 45 deletions lib/iris/tests/cluster/controller/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,19 @@

"""Tests for controller checkpoint: remote-only write and download-before-create restore."""

from pathlib import Path

from iris.cluster.controller.checkpoint import (
download_checkpoint_to_local,
prune_old_checkpoints,
write_checkpoint,
)
from iris.cluster.controller.controller import (
Controller,
ControllerConfig,
)
from iris.cluster.controller.db import ControllerDB
from rigging.timing import Duration
from tests.cluster.controller.conftest import FakeProvider


def _local_state_dir(tmp_path: Path, name: str = "state") -> Path:
    """Create the directory ``tmp_path / name`` and return its path.

    Missing parent directories are created as well, and an already existing
    directory is not treated as an error, so repeated calls are safe.
    """
    target = tmp_path.joinpath(name)
    target.mkdir(parents=True, exist_ok=True)
    return target


def _make_controller(tmp_path: Path, remote_state_dir: str | None = None, **kwargs) -> Controller:
    """Build a ``Controller`` backed by a ``FakeProvider`` for checkpoint tests.

    ``remote_state_dir`` defaults to a ``file://`` URL under ``tmp_path``; the
    local state directory is created via ``_local_state_dir(tmp_path)``. Any
    extra keyword arguments are passed through to ``ControllerConfig``.
    """
    remote = f"file://{tmp_path}/remote" if remote_state_dir is None else remote_state_dir
    config = ControllerConfig(
        remote_state_dir=remote,
        local_state_dir=_local_state_dir(tmp_path),
        **kwargs,
    )
    return Controller(config=config, provider=FakeProvider())


def test_write_checkpoint_uploads_compressed(tmp_path):
def test_write_checkpoint_uploads_compressed(tmp_path, make_controller):
"""write_checkpoint creates a timestamped directory with .zst files."""
remote_dir = f"file://{tmp_path}/remote"
controller = _make_controller(tmp_path, remote_state_dir=remote_dir)
controller = make_controller(remote_state_dir=remote_dir)

path, result = write_checkpoint(controller._db, remote_dir)

Expand All @@ -52,26 +31,22 @@ def test_write_checkpoint_uploads_compressed(tmp_path):
assert result.task_count == 0
assert result.worker_count == 0

controller._db.close()


def test_begin_checkpoint_returns_remote_path(tmp_path):
def test_begin_checkpoint_returns_remote_path(tmp_path, make_controller):
"""begin_checkpoint returns a remote path string."""
remote_dir = f"file://{tmp_path}/remote"
controller = _make_controller(tmp_path, remote_state_dir=remote_dir)
controller = make_controller(remote_state_dir=remote_dir)

path, result = controller.begin_checkpoint()

assert path.startswith(f"file://{tmp_path}/remote/controller-state/")
assert result.job_count == 0

controller._db.close()


def test_atexit_checkpoint_writes_to_remote(tmp_path):
def test_atexit_checkpoint_writes_to_remote(tmp_path, make_controller):
"""_atexit_checkpoint writes directly to remote storage."""
remote_dir = f"file://{tmp_path}/remote"
controller = _make_controller(tmp_path, remote_state_dir=remote_dir)
controller = make_controller(remote_state_dir=remote_dir)

controller._atexit_checkpoint()

Expand All @@ -80,8 +55,6 @@ def test_atexit_checkpoint_writes_to_remote(tmp_path):
assert len(timestamped_dirs) >= 1
assert (timestamped_dirs[0] / "controller.sqlite3.zst").exists()

controller._db.close()


def test_download_checkpoint_to_local(tmp_path):
"""download_checkpoint_to_local copies remote DB to local path."""
Expand Down Expand Up @@ -117,23 +90,22 @@ def test_download_from_explicit_path(tmp_path):
assert (local_db_dir / "controller.sqlite3").exists()


def test_write_checkpoint_roundtrip(tmp_path):
def test_write_checkpoint_roundtrip(tmp_path, make_controller):
"""Write then download produces a valid DB."""
remote_dir = f"file://{tmp_path}/remote"
controller = _make_controller(tmp_path, remote_state_dir=remote_dir)
controller = make_controller(remote_state_dir=remote_dir)
write_checkpoint(controller._db, remote_dir)
controller._db.close()

local_db_dir = tmp_path / "restored"
download_checkpoint_to_local(remote_dir, local_db_dir)
restored_db = ControllerDB(db_dir=local_db_dir)
restored_db.close()


def test_write_checkpoint_cleans_up_temp_file(tmp_path):
def test_write_checkpoint_cleans_up_temp_file(tmp_path, make_controller):
"""write_checkpoint does not leave temp files in the DB directory."""
remote_dir = f"file://{tmp_path}/remote"
controller = _make_controller(tmp_path, remote_state_dir=remote_dir)
controller = make_controller(remote_state_dir=remote_dir)
db_dir = controller._db.db_path.parent

files_before = set(db_dir.iterdir())
Expand All @@ -144,8 +116,6 @@ def test_write_checkpoint_cleans_up_temp_file(tmp_path):
sqlite_temps = [f for f in new_files if ".sqlite3" in f.name and f.name != ControllerDB.DB_FILENAME]
assert len(sqlite_temps) == 0

controller._db.close()


def test_local_db_exists_skips_remote_download(tmp_path):
"""When a local DB already exists, download_checkpoint_to_local should not be called.
Expand Down Expand Up @@ -217,11 +187,10 @@ def test_download_from_explicit_path_pairs_profiles_db(tmp_path):
assert (local_db_dir / "profiles.sqlite3").exists(), "profiles DB should be downloaded into local_db_dir"


def test_periodic_checkpoint_inline(tmp_path):
def test_periodic_checkpoint_inline(tmp_path, make_controller):
"""Controller writes periodic checkpoints when limiter fires."""
remote_dir = f"file://{tmp_path}/remote"
controller = _make_controller(
tmp_path,
controller = make_controller(
remote_state_dir=remote_dir,
checkpoint_interval=Duration.from_seconds(0),
)
Expand All @@ -235,8 +204,6 @@ def test_periodic_checkpoint_inline(tmp_path):
assert len(timestamped_dirs) >= 1
assert (timestamped_dirs[0] / "controller.sqlite3.zst").exists()

controller._db.close()


def test_download_uncompressed_fallback(tmp_path):
"""download_checkpoint_to_local falls back to uncompressed files from old checkpoints."""
Expand Down
15 changes: 2 additions & 13 deletions lib/iris/tests/cluster/controller/test_dry_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@

import pytest

from iris.cluster.controller.controller import Controller, ControllerConfig
from iris.cluster.controller.db import ControllerDB
from iris.cluster.controller.schema import TASK_DETAIL_PROJECTION
from iris.cluster.types import JobName
from iris.rpc import job_pb2
from tests.cluster.controller.conftest import (
FakeProvider,
make_job_request,
make_worker_metadata,
register_worker,
Expand All @@ -24,16 +21,8 @@


@pytest.fixture
def dry_run_controller(tmp_path):
db = ControllerDB(db_dir=tmp_path / "db")
config = ControllerConfig(
dry_run=True,
remote_state_dir=f"file://{tmp_path}/remote",
local_state_dir=tmp_path,
)
controller = Controller(config=config, provider=FakeProvider(), db=db)
yield controller
controller.stop()
def dry_run_controller(make_controller):
    """Return a controller built by the ``make_controller`` factory with ``dry_run=True``."""
    return make_controller(dry_run=True)


def test_dry_run_controller_starts_and_stops(dry_run_controller):
Expand Down
26 changes: 13 additions & 13 deletions lib/iris/tests/cluster/controller/test_heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

import iris.cluster.controller.worker_provider as worker_provider_module
import pytest
from iris.cluster.controller.controller import Controller, ControllerConfig, _SyncFailureAccumulator
from iris.cluster.controller.controller import _SyncFailureAccumulator
from iris.cluster.controller.db import ControllerDB
from iris.cluster.controller.schema import (
TASK_DETAIL_PROJECTION,
WORKER_DETAIL_PROJECTION,
)
from tests.cluster.controller.conftest import FakeProvider
from iris.cluster.controller.transitions import (
Assignment,
ControllerTransitions,
Expand Down Expand Up @@ -122,11 +121,14 @@ def test_fail_heartbeat_returns_transient_and_worker_stays_alive(state, worker_m
assert q.fetchone("SELECT 1 FROM workers WHERE worker_id = ?", ("worker1",)) is not None


def test_ping_failures_accumulate_and_terminate_inline(tmp_path, worker_metadata):
def test_ping_failures_accumulate_and_terminate_inline(tmp_path, worker_metadata, make_controller):
"""Ten consecutive ping failures via _handle_failed_heartbeats terminates the worker inline."""
db = ControllerDB(db_dir=tmp_path)
config = ControllerConfig(remote_state_dir="file:///tmp/iris-test-state", local_state_dir=tmp_path)
controller = Controller(config=config, provider=FakeProvider(), db=db)
controller = make_controller(
remote_state_dir="file:///tmp/iris-test-state",
local_state_dir=tmp_path,
db=db,
)
state = controller.state
_register_worker(state, "worker1", worker_metadata, address="10.0.0.1:10001")

Expand All @@ -150,9 +152,6 @@ def test_ping_failures_accumulate_and_terminate_inline(tmp_path, worker_metadata
with db.snapshot() as q:
assert q.fetchone("SELECT 1 FROM workers WHERE worker_id = ?", ("worker1",)) is None

controller.stop()
db.close()


def test_complete_heartbeat_unhealthy_worker_increments_failures(state, worker_metadata):
"""Worker reporting unhealthy increments failure count (not immediate removal)."""
Expand Down Expand Up @@ -263,10 +262,13 @@ def close(self) -> None:
pass


def test_handle_failed_heartbeats_logs_diagnostics(tmp_path, worker_metadata, caplog):
def test_handle_failed_heartbeats_logs_diagnostics(tmp_path, worker_metadata, caplog, make_controller):
db = ControllerDB(db_dir=tmp_path)
config = ControllerConfig(remote_state_dir="file:///tmp/iris-test-state", local_state_dir=tmp_path)
controller = Controller(config=config, provider=FakeProvider(), db=db)
controller = make_controller(
remote_state_dir="file:///tmp/iris-test-state",
local_state_dir=tmp_path,
db=db,
)
state = controller.state
_register_worker(state, "worker1", worker_metadata, address="10.0.0.1:10001")

Expand All @@ -292,8 +294,6 @@ def test_handle_failed_heartbeats_logs_diagnostics(tmp_path, worker_metadata, ca
assert "last_success_age_s=" in caplog.text
assert "deadline exceeded after 12000ms" in caplog.text

controller.stop()


def test_rpc_worker_stub_factory_default_timeout(monkeypatch):
captured: dict[str, object] = {}
Expand Down
Loading
Loading