diff --git a/lib/iris/docs/task-states.md b/lib/iris/docs/task-states.md index e95511ca6b..544940587b 100644 --- a/lib/iris/docs/task-states.md +++ b/lib/iris/docs/task-states.md @@ -55,6 +55,17 @@ independent job state machine. v (terminal) + +-----------+ | + | PREEMPTED | | + +-----------+ | + ^ | + | exhausted | + | | + +-----------+ | + | preempt |------------+ + | (ctrl) | + +-----------+ + Other terminal states: KILLED, UNSCHEDULABLE (never retried) ``` @@ -72,6 +83,7 @@ independent job state machine. | `KILLED` | 6 | Yes | No | Controller: job cancellation (`_on_job_cancelled`), job failure cascade (`_mark_remaining_tasks_killed`), per-task timeout | `killed` (grey) | | `WORKER_FAILED` | 7 | Yes | Yes | Controller: worker death cascade (`_on_worker_failed`), coscheduled sibling kill | `worker_failed` (purple) | | `UNSCHEDULABLE` | 8 | Yes | No | Controller: scheduling timeout expired (`_mark_task_unschedulable`) | `unschedulable` (red) | +| `PREEMPTED` | 10 | Yes | Yes | Controller: priority preemption with budget exhausted (`preempt_task`) | `preempted` (orange) | ## State Transitions in Detail @@ -171,6 +183,31 @@ terminally, `_cascade_coscheduled_failure` exhausts the preemption budget of all running siblings and transitions them to `WORKER_FAILED` (terminal). This prevents other hosts from hanging on collective operations. +### PREEMPTED + +Set by the controller when a higher-priority task evicts a lower-priority +running task via `preempt_task`. The preemption loop (`_run_preemption_pass`) +selects victims from lower priority bands and calls `preempt_task` for each. + +Retry evaluation uses `_resolve_task_failure_state` with the preemption budget: + +1. **ASSIGNED tasks**: always retry to PENDING regardless of budget (the task + never started executing, so preemption is free). +2. **BUILDING or RUNNING tasks**: `preemption_count` is incremented and compared + against `max_retries_preemption`. 
+ - If `preemption_count <= max_retries_preemption`: task is requeued to + `PENDING` for retry. The current attempt is marked `PREEMPTED`. + - If `preemption_count > max_retries_preemption`: task state is set to + `PREEMPTED` (terminal). Both the attempt and the task are `PREEMPTED`. + +`PREEMPTED` is in both `TERMINAL_TASK_STATES` and `FAILURE_TASK_STATES`. +When a coscheduled task becomes terminally `PREEMPTED`, the job state is +recomputed. If all tasks in the job are terminal, `_finalize_terminal_job` +kills any remaining non-terminal tasks and cascades to child jobs. Note that +unlike `WORKER_FAILED` reported via heartbeat, `preempt_task` does not +directly cascade coscheduled siblings — the cascade only occurs through job +finalization. + ### UNSCHEDULABLE Set by the controller's scheduling loop when a task's scheduling deadline @@ -185,10 +222,10 @@ are killed. Iris maintains two independent retry budgets per task: -| Budget | Counter | Limit Field | Default | Trigger State | +| Budget | Counter | Limit Field | Default | Trigger States | |---|---|---|---|---| | Failure | `failure_count` | `max_retries_failure` | 0 (no retries) | `FAILED` | -| Preemption | `preemption_count` | `max_retries_preemption` | 100 | `WORKER_FAILED` | +| Preemption | `preemption_count` | `max_retries_preemption` | 100 | `WORKER_FAILED`, `PREEMPTED` | ### Retry flow @@ -209,15 +246,18 @@ Iris maintains two independent retry budgets per task: ### What counts toward job failure Only `TASK_STATE_FAILED` counts toward the job's `max_task_failures` threshold. -Worker failures (preemptions) do not count. This means a job can survive +Worker failures and preemptions do not count. This means a job can survive unlimited preemptions as long as the per-task preemption budget is not -exhausted. +exhausted. 
`TASK_STATE_PREEMPTED` and `TASK_STATE_WORKER_FAILED` are grouped +together for job state derivation: if all tasks are terminal and any are in +one of these states, the job becomes `JOB_STATE_WORKER_FAILED`. ### States that are never retried - `SUCCEEDED`: task completed successfully - `KILLED`: explicit termination by user or cascade - `UNSCHEDULABLE`: scheduling timeout expired +- `PREEMPTED`: only when preemption budget is exhausted (otherwise retried as `PENDING`) ## Terminal State Summary @@ -230,6 +270,7 @@ A task is considered finished (`is_finished() == True`) when: | `UNSCHEDULABLE` | Always finished | | `FAILED` | Finished when `failure_count > max_retries_failure` | | `WORKER_FAILED` | Finished when `preemption_count > max_retries_preemption` | +| `PREEMPTED` | Finished when `preemption_count > max_retries_preemption` | The distinction matters: a task in `FAILED` state with retry budget remaining is in a terminal state at the attempt level but is not finished at the task @@ -252,6 +293,7 @@ strings (e.g., `TASK_STATE_RUNNING`) to lowercase display names by stripping the | `killed` | `.status-killed` | Grey (#57606a) | | `worker_failed` | `.status-worker_failed` | Purple (#8250df) | | `unschedulable` | `.status-unschedulable` | Red (#cf222e) | +| `preempted` | `.status-preempted` | Orange (#bc4c00) | The job detail page shows per-task attempt history. 
Each attempt has its own state badge, and worker failures are annotated with "(worker failure)" in the diff --git a/lib/iris/src/iris/cli/cluster.py b/lib/iris/src/iris/cli/cluster.py index 3a923ca527..5e18dc4dd3 100644 --- a/lib/iris/src/iris/cli/cluster.py +++ b/lib/iris/src/iris/cli/cluster.py @@ -20,9 +20,9 @@ find_marin_root, get_git_sha, ) -from iris.cli.main import require_controller_url +from iris.cli.main import require_controller_url, rpc_client from iris.cluster.config import IrisConfig, clear_remote_state, make_local_config -from iris.rpc import cluster_connect, cluster_pb2, vm_pb2 +from iris.rpc import cluster_pb2, vm_pb2 from iris.rpc.proto_utils import format_accelerator_display, vm_state_name from iris.time_proto import timestamp_from_proto from rigging.timing import Duration, ExponentialBackoff, Timestamp @@ -56,15 +56,15 @@ def _format_status_table(status: vm_pb2.AutoscalerStatus) -> str: def _get_autoscaler_status(controller_url: str) -> vm_pb2.AutoscalerStatus: - client = cluster_connect.ControllerServiceClientSync(controller_url) - request = cluster_pb2.Controller.GetAutoscalerStatusRequest() - return client.get_autoscaler_status(request).status + with rpc_client(controller_url) as client: + request = cluster_pb2.Controller.GetAutoscalerStatusRequest() + return client.get_autoscaler_status(request).status def _get_worker_status(controller_url: str, worker_id: str) -> cluster_pb2.Controller.GetWorkerStatusResponse: - client = cluster_connect.ControllerServiceClientSync(controller_url) - request = cluster_pb2.Controller.GetWorkerStatusRequest(id=worker_id) - return client.get_worker_status(request) + with rpc_client(controller_url) as client: + request = cluster_pb2.Controller.GetWorkerStatusRequest(id=worker_id) + return client.get_worker_status(request) def _parse_ghcr_tag(image_tag: str) -> tuple[str, str, str] | None: @@ -322,20 +322,20 @@ def cluster_start_smoke(ctx, label_prefix, url_file, min_workers, worker_timeout with 
bundle.controller.tunnel(address) as url: click.echo(f"Tunnel ready: {url}") - client = cluster_connect.ControllerServiceClientSync(url, timeout_ms=30000) - deadline = time.monotonic() + worker_timeout - healthy_count = 0 - while time.monotonic() < deadline: - workers = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()).workers - healthy = [w for w in workers if w.healthy] - healthy_count = len(healthy) - if healthy_count >= min_workers: - break - time.sleep(2) - else: - raise click.ClickException( - f"Only {healthy_count} of {min_workers} workers healthy after {worker_timeout}s" - ) + with rpc_client(url) as client: + deadline = time.monotonic() + worker_timeout + healthy_count = 0 + while time.monotonic() < deadline: + workers = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()).workers + healthy = [w for w in workers if w.healthy] + healthy_count = len(healthy) + if healthy_count >= min_workers: + break + time.sleep(2) + else: + raise click.ClickException( + f"Only {healthy_count} of {min_workers} workers healthy after {worker_timeout}s" + ) click.echo(f"{healthy_count} workers ready, writing URL to {url_file}") Path(url_file).write_text(url) @@ -398,10 +398,10 @@ def cluster_status_cmd(ctx): controller_url = require_controller_url(ctx) click.echo("Checking controller status...") try: - client = cluster_connect.ControllerServiceClientSync(controller_url) - proc = client.get_process_status(cluster_pb2.GetProcessStatusRequest()).process_info - workers = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()).workers - as_status = client.get_autoscaler_status(cluster_pb2.Controller.GetAutoscalerStatusRequest()).status + with rpc_client(controller_url) as client: + proc = client.get_process_status(cluster_pb2.GetProcessStatusRequest()).process_info + workers = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()).workers + as_status = 
client.get_autoscaler_status(cluster_pb2.Controller.GetAutoscalerStatusRequest()).status healthy = sum(1 for w in workers if w.healthy) click.echo("Controller Status:") click.echo(" Running: True") @@ -629,12 +629,12 @@ def controller_checkpoint(ctx, stop: bool): briefly and writes a consistent checkpoint DB copy. """ controller_url = require_controller_url(ctx) - client = cluster_connect.ControllerServiceClientSync(controller_url) - try: - resp = client.begin_checkpoint(cluster_pb2.Controller.BeginCheckpointRequest(), timeout_ms=60_000) - except Exception as e: - click.echo(f"Checkpoint failed: {e}", err=True) - raise SystemExit(1) from e + with rpc_client(controller_url) as client: + try: + resp = client.begin_checkpoint(cluster_pb2.Controller.BeginCheckpointRequest(), timeout_ms=60_000) + except Exception as e: + click.echo(f"Checkpoint failed: {e}", err=True) + raise SystemExit(1) from e click.echo(f"Checkpoint DB written: {resp.checkpoint_path}") click.echo(f" Jobs: {resp.job_count}") @@ -717,17 +717,15 @@ def controller_restart(ctx, skip_checkpoint: bool, checkpoint_timeout: int): click.echo("Skipping pre-restart checkpoint.") else: click.echo(f"Taking checkpoint (timeout {checkpoint_timeout}s)...") - client = cluster_connect.ControllerServiceClientSync(controller_url) - try: - resp = client.begin_checkpoint( - cluster_pb2.Controller.BeginCheckpointRequest(), - timeout_ms=checkpoint_timeout * 1000, - ) - except Exception as e: - click.echo(f"Checkpoint failed: {e}", err=True) - raise SystemExit(1) from e - finally: - client.close() + with rpc_client(controller_url) as client: + try: + resp = client.begin_checkpoint( + cluster_pb2.Controller.BeginCheckpointRequest(), + timeout_ms=checkpoint_timeout * 1000, + ) + except Exception as e: + click.echo(f"Checkpoint failed: {e}", err=True) + raise SystemExit(1) from e click.echo(f"Checkpoint: {resp.checkpoint_path} ({resp.job_count} jobs, {resp.worker_count} workers)") # Build fresh images so the new controller VM 
gets the latest code @@ -759,59 +757,59 @@ def worker_restart(ctx, worker_id: str | None, timeout: int): new worker process. """ controller_url = require_controller_url(ctx) - client = cluster_connect.ControllerServiceClientSync(controller_url) - # Get current workers - workers_resp = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()) - workers = workers_resp.workers - - if worker_id: - workers = [w for w in workers if w.worker_id == worker_id] - if not workers: - click.echo(f"Worker {worker_id} not found", err=True) - raise SystemExit(1) + with rpc_client(controller_url) as client: + # Get current workers + workers_resp = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()) + workers = workers_resp.workers - if not workers: - click.echo("No workers to restart") - return + if worker_id: + workers = [w for w in workers if w.worker_id == worker_id] + if not workers: + click.echo(f"Worker {worker_id} not found", err=True) + raise SystemExit(1) - click.echo(f"Restarting {len(workers)} worker(s) (timeout={timeout}s per worker)") + if not workers: + click.echo("No workers to restart") + return - succeeded = 0 - failed = 0 + click.echo(f"Restarting {len(workers)} worker(s) (timeout={timeout}s per worker)") - for worker in workers: - wid = worker.worker_id - click.echo(f"\nRestarting worker {wid}...") + succeeded = 0 + failed = 0 - resp = client.restart_worker( - cluster_pb2.Controller.RestartWorkerRequest(worker_id=wid), - timeout_ms=timeout * 1000, - ) + for worker in workers: + wid = worker.worker_id + click.echo(f"\nRestarting worker {wid}...") - if not resp.accepted: - click.echo(f" Failed: {resp.error}", err=True) - failed += 1 - continue + resp = client.restart_worker( + cluster_pb2.Controller.RestartWorkerRequest(worker_id=wid), + timeout_ms=timeout * 1000, + ) - # Poll until the worker re-registers as healthy - def _worker_healthy(target_id: str = wid) -> bool: - try: - resp = 
client.list_workers(cluster_pb2.Controller.ListWorkersRequest()) - return any(w.worker_id == target_id and w.healthy for w in resp.workers) - except Exception: - return False - - reregistered = ExponentialBackoff(initial=5.0, maximum=5.0, jitter=0.0).wait_until( - _worker_healthy, - timeout=Duration.from_seconds(timeout), - ) + if not resp.accepted: + click.echo(f" Failed: {resp.error}", err=True) + failed += 1 + continue + + # Poll until the worker re-registers as healthy + def _worker_healthy(target_id: str = wid) -> bool: + try: + resp = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()) + return any(w.worker_id == target_id and w.healthy for w in resp.workers) + except Exception: + return False + + reregistered = ExponentialBackoff(initial=5.0, maximum=5.0, jitter=0.0).wait_until( + _worker_healthy, + timeout=Duration.from_seconds(timeout), + ) - if reregistered: - click.echo(f" Worker {wid} restarted successfully") - succeeded += 1 - else: - click.echo(f" Worker {wid} did not re-register within {timeout}s", err=True) - failed += 1 + if reregistered: + click.echo(f" Worker {wid} restarted successfully") + succeeded += 1 + else: + click.echo(f" Worker {wid} did not re-register within {timeout}s", err=True) + failed += 1 click.echo(f"\nDone: {succeeded} succeeded, {failed} failed") diff --git a/lib/iris/src/iris/cli/main.py b/lib/iris/src/iris/cli/main.py index e57bd02612..ff20af310d 100644 --- a/lib/iris/src/iris/cli/main.py +++ b/lib/iris/src/iris/cli/main.py @@ -14,7 +14,8 @@ from iris.cli.token_store import cluster_name_from_url, load_any_token, load_token, store_token from rigging.log_setup import configure_logging from iris.rpc import cluster_pb2 as _cluster_pb2, config_pb2 -from iris.rpc.auth import GcpAccessTokenProvider, StaticTokenProvider, TokenProvider +from iris.rpc.auth import AuthTokenInjector, GcpAccessTokenProvider, StaticTokenProvider, TokenProvider +from iris.rpc.cluster_connect import ControllerServiceClientSync from 
iris.rpc.proto_utils import PRIORITY_BAND_NAMES, priority_band_name, priority_band_value logger = _logging_module.getLogger(__name__) @@ -71,6 +72,16 @@ def _configure_client_s3(config) -> None: configure_client_s3(config) +def rpc_client( + address: str, + token_provider: TokenProvider | None = None, + timeout_ms: int = 30_000, +) -> ControllerServiceClientSync: + """Create an RPC client with optional auth. Use as a context manager: ``with rpc_client(url) as c:``.""" + interceptors = [AuthTokenInjector(token_provider)] if token_provider else [] + return ControllerServiceClientSync(address, timeout_ms=timeout_ms, interceptors=interceptors) + + def require_controller_url(ctx: click.Context) -> str: """Get controller_url from context, establishing a tunnel lazily if needed. @@ -199,18 +210,15 @@ def login(ctx): config = ctx.obj.get("config") from iris.rpc import cluster_pb2 - from iris.rpc.cluster_connect import ControllerServiceClientSync if config and config.HasField("auth"): provider = config.auth.WhichOneof("provider") else: - client = ControllerServiceClientSync(address=controller_url, timeout_ms=30000) - try: - auth_info = client.get_auth_info(cluster_pb2.GetAuthInfoRequest()) - except Exception as e: - raise click.ClickException(f"Failed to discover auth method: {e}") from e - finally: - client.close() + with rpc_client(controller_url) as client: + try: + auth_info = client.get_auth_info(cluster_pb2.GetAuthInfoRequest()) + except Exception as e: + raise click.ClickException(f"Failed to discover auth method: {e}") from e provider = auth_info.provider or None if not provider: raise click.ClickException("Controller has no authentication configured") @@ -232,13 +240,11 @@ def login(ctx): raise click.ClickException(f"Unsupported auth provider: {provider}") # All providers converge: exchange identity_token for JWT via Login RPC - client = ControllerServiceClientSync(address=controller_url, timeout_ms=30000) - try: - response = 
client.login(cluster_pb2.LoginRequest(identity_token=identity_token)) - except Exception as e: - raise click.ClickException(f"Login failed: {e}") from e - finally: - client.close() + with rpc_client(controller_url) as client: + try: + response = client.login(cluster_pb2.LoginRequest(identity_token=identity_token)) + except Exception as e: + raise click.ClickException(f"Login failed: {e}") from e cluster_name = ctx.obj.get("cluster_name", "default") store_token(cluster_name, controller_url, response.token) @@ -249,15 +255,6 @@ def login(ctx): click.echo(f"Token stored for cluster '{cluster_name}'") -def _make_authenticated_client(controller_url: str, token_provider: TokenProvider | None): - """Create a ControllerServiceClientSync with auth interceptor if available.""" - from iris.rpc.auth import AuthTokenInjector - from iris.rpc.cluster_connect import ControllerServiceClientSync - - interceptors = [AuthTokenInjector(token_provider)] if token_provider else [] - return ControllerServiceClientSync(address=controller_url, timeout_ms=30000, interceptors=interceptors) - - @iris.group() @click.pass_context def key(ctx): @@ -277,11 +274,8 @@ def key_create(ctx, name: str, user_id: str, ttl_ms: int): from iris.rpc import cluster_pb2 - client = _make_authenticated_client(controller_url, token_provider) - try: + with rpc_client(controller_url, token_provider) as client: response = client.create_api_key(cluster_pb2.CreateApiKeyRequest(user_id=user_id, name=name, ttl_ms=ttl_ms)) - finally: - client.close() click.echo(f"Key ID: {response.key_id}") click.echo(f"Token: {response.token}") @@ -299,11 +293,8 @@ def key_list(ctx, user_id: str): from iris.rpc import cluster_pb2 - client = _make_authenticated_client(controller_url, token_provider) - try: + with rpc_client(controller_url, token_provider) as client: response = client.list_api_keys(cluster_pb2.ListApiKeysRequest(user_id=user_id)) - finally: - client.close() if not response.keys: click.echo("No API keys found.") @@ -324,11 
+315,8 @@ def key_revoke(ctx, key_id: str): from iris.rpc import cluster_pb2 - client = _make_authenticated_client(controller_url, token_provider) - try: + with rpc_client(controller_url, token_provider) as client: client.revoke_api_key(cluster_pb2.RevokeApiKeyRequest(key_id=key_id)) - finally: - client.close() click.echo(f"Revoked key: {key_id}") @@ -367,8 +355,7 @@ def budget_set(ctx, user_id: str, budget_limit: int, max_band: str): controller_url = require_controller_url(ctx) token_provider = ctx.obj.get("token_provider") - client = _make_authenticated_client(controller_url, token_provider) - try: + with rpc_client(controller_url, token_provider) as client: client.set_user_budget( _cluster_pb2.Controller.SetUserBudgetRequest( user_id=user_id, @@ -376,8 +363,6 @@ def budget_set(ctx, user_id: str, budget_limit: int, max_band: str): max_band=priority_band_value(max_band), ) ) - finally: - client.close() click.echo(f"Budget set for {user_id}: limit={budget_limit}, max_band={max_band}") @@ -390,11 +375,8 @@ def budget_get(ctx, user_id: str): controller_url = require_controller_url(ctx) token_provider = ctx.obj.get("token_provider") - client = _make_authenticated_client(controller_url, token_provider) - try: + with rpc_client(controller_url, token_provider) as client: resp = client.get_user_budget(_cluster_pb2.Controller.GetUserBudgetRequest(user_id=user_id)) - finally: - client.close() click.echo(f"User: {resp.user_id}") click.echo(f"Limit: {resp.budget_limit}") @@ -409,11 +391,8 @@ def budget_list(ctx): controller_url = require_controller_url(ctx) token_provider = ctx.obj.get("token_provider") - client = _make_authenticated_client(controller_url, token_provider) - try: + with rpc_client(controller_url, token_provider) as client: resp = client.list_user_budgets(_cluster_pb2.Controller.ListUserBudgetsRequest()) - finally: - client.close() if not resp.users: click.echo("No user budgets found.") diff --git a/lib/iris/src/iris/cli/process_status.py 
b/lib/iris/src/iris/cli/process_status.py index 8a7190ebcd..a7f0ea73da 100644 --- a/lib/iris/src/iris/cli/process_status.py +++ b/lib/iris/src/iris/cli/process_status.py @@ -9,13 +9,11 @@ controller itself. """ - import click import humanfriendly -from iris.cli.main import require_controller_url +from iris.cli.main import require_controller_url, rpc_client from iris.rpc import cluster_pb2 -from iris.rpc.cluster_connect import ControllerServiceClientSync _CONTROLLER_TARGET = "/system/process" @@ -55,10 +53,10 @@ def status(ctx, target: str | None, as_json: bool): from google.protobuf import json_format url = require_controller_url(ctx) - client = ControllerServiceClientSync(url) label = target or "Controller" - # GetProcessStatus uses empty string for controller - resp = client.get_process_status(cluster_pb2.GetProcessStatusRequest(max_log_lines=0, target=target or "")) + with rpc_client(url) as client: + # GetProcessStatus uses empty string for controller + resp = client.get_process_status(cluster_pb2.GetProcessStatusRequest(max_log_lines=0, target=target or "")) if as_json: click.echo(json_format.MessageToJson(resp.process_info, preserving_proto_field_name=True, indent=2)) else: @@ -83,36 +81,36 @@ def logs(ctx, target: str | None, level: str, follow: bool, max_lines: int, subs from datetime import datetime, timezone url = require_controller_url(ctx) - client = ControllerServiceClientSync(url) source = target or _CONTROLLER_TARGET - cursor = 0 - first = True - while True: - req = cluster_pb2.FetchLogsRequest( - source=source, - max_lines=max_lines if first else 100, - tail=first, - min_level=level, - cursor=cursor if not first else 0, - ) - if substring: - req.substring = substring - - resp = client.fetch_logs(req) - for entry in resp.entries: - ts = "" - if entry.timestamp and entry.timestamp.epoch_ms: - dt = datetime.fromtimestamp(entry.timestamp.epoch_ms / 1000, tz=timezone.utc) - ts = dt.strftime("%H:%M:%S") - click.echo(f"[{ts}] {entry.data}") - - cursor = 
resp.cursor - first = False - - if not follow: - break - time.sleep(2) + with rpc_client(url) as client: + cursor = 0 + first = True + while True: + req = cluster_pb2.FetchLogsRequest( + source=source, + max_lines=max_lines if first else 100, + tail=first, + min_level=level, + cursor=cursor if not first else 0, + ) + if substring: + req.substring = substring + + resp = client.fetch_logs(req) + for entry in resp.entries: + ts = "" + if entry.timestamp and entry.timestamp.epoch_ms: + dt = datetime.fromtimestamp(entry.timestamp.epoch_ms / 1000, tz=timezone.utc) + ts = dt.strftime("%H:%M:%S") + click.echo(f"[{ts}] {entry.data}") + + cursor = resp.cursor + first = False + + if not follow: + break + time.sleep(2) @process_group.command() @@ -141,7 +139,6 @@ def profile( /system/worker/ for a worker, /alice/job/0 for a task container. """ url = require_controller_url(ctx) - client = ControllerServiceClientSync(url) rpc_target = target or _CONTROLLER_TARGET label = target or "Controller" @@ -157,13 +154,14 @@ def profile( raise click.ClickException(f"Unknown profiler type: {profiler}") click.echo(f"Profiling {label} ({profiler}, {duration}s)...") - resp = client.profile_task( - cluster_pb2.ProfileTaskRequest( - target=rpc_target, - duration_seconds=duration, - profile_type=profile_type, + with rpc_client(url) as client: + resp = client.profile_task( + cluster_pb2.ProfileTaskRequest( + target=rpc_target, + duration_seconds=duration, + profile_type=profile_type, + ) ) - ) if resp.error: raise click.ClickException(f"Profiling failed: {resp.error}") diff --git a/lib/iris/src/iris/cli/query.py b/lib/iris/src/iris/cli/query.py index 7a854bf925..36c6a8070e 100644 --- a/lib/iris/src/iris/cli/query.py +++ b/lib/iris/src/iris/cli/query.py @@ -10,7 +10,7 @@ import click from tabulate import tabulate -from iris.cli.main import _make_authenticated_client, require_controller_url +from iris.cli.main import require_controller_url, rpc_client from iris.rpc import query_pb2 @@ -68,18 +68,15 
@@ def query_cmd(ctx: click.Context, sql: str, fmt: str) -> None: """ controller_url = require_controller_url(ctx) token_provider = ctx.obj.get("token_provider") if ctx.obj else None - client = _make_authenticated_client(controller_url, token_provider) - try: + with rpc_client(controller_url, token_provider) as client: request = query_pb2.RawQueryRequest(sql=sql) response = client.execute_raw_query(request) - columns = list(response.columns) - rows = _parse_rows(list(response.rows)) - formatter = _FORMATTERS[fmt] - output = formatter(columns, rows) + columns = list(response.columns) + rows = _parse_rows(list(response.rows)) + formatter = _FORMATTERS[fmt] + output = formatter(columns, rows) - if output: - click.echo(output) - finally: - client.close() + if output: + click.echo(output) diff --git a/lib/iris/src/iris/cli/rpc.py b/lib/iris/src/iris/cli/rpc.py index 656908f2d6..2019bec1e1 100644 --- a/lib/iris/src/iris/cli/rpc.py +++ b/lib/iris/src/iris/cli/rpc.py @@ -187,8 +187,11 @@ def call_rpc( interceptors = [AuthTokenInjector(token_provider)] if token_provider else [] client = service.client_class(url, interceptors=interceptors) - method_fn = getattr(client, method.method_fn_name) - return method_fn(request) + try: + method_fn = getattr(client, method.method_fn_name) + return method_fn(request) + finally: + client.close() def format_response(response: Message) -> str: diff --git a/lib/iris/src/iris/cli/task.py b/lib/iris/src/iris/cli/task.py index 111df1add9..25808e97dc 100644 --- a/lib/iris/src/iris/cli/task.py +++ b/lib/iris/src/iris/cli/task.py @@ -12,7 +12,7 @@ import click -from iris.cli.main import require_controller_url +from iris.cli.main import require_controller_url, rpc_client from iris.rpc import cluster_pb2 from iris.rpc.auth import TokenProvider @@ -50,19 +50,7 @@ def task_exec(ctx, task_id: str, command: tuple[str, ...], timeout_seconds: int) controller_url = require_controller_url(ctx) token_provider: TokenProvider | None = 
ctx.obj.get("token_provider") - from iris.rpc.cluster_connect import ControllerServiceClientSync - from iris.rpc.auth import AuthTokenInjector - - interceptors = [] - if token_provider: - interceptors.append(AuthTokenInjector(token_provider)) - - client = ControllerServiceClientSync( - address=controller_url, - interceptors=interceptors, - ) - - try: + with rpc_client(controller_url, token_provider) as client: request = cluster_pb2.Controller.ExecInContainerRequest( task_id=task_id, command=list(command), @@ -70,15 +58,13 @@ def task_exec(ctx, task_id: str, command: tuple[str, ...], timeout_seconds: int) ) response = client.exec_in_container(request) - if response.error: - click.echo(f"Error: {response.error}", err=True) - sys.exit(1) + if response.error: + click.echo(f"Error: {response.error}", err=True) + sys.exit(1) - if response.stdout: - click.echo(response.stdout, nl=False) - if response.stderr: - click.echo(response.stderr, nl=False, err=True) + if response.stdout: + click.echo(response.stdout, nl=False) + if response.stderr: + click.echo(response.stderr, nl=False, err=True) - sys.exit(response.exit_code) - finally: - client.close() + sys.exit(response.exit_code) diff --git a/lib/iris/src/iris/cluster/controller/controller.py b/lib/iris/src/iris/cluster/controller/controller.py index 812e7f26c6..7f966f979d 100644 --- a/lib/iris/src/iris/cluster/controller/controller.py +++ b/lib/iris/src/iris/cluster/controller/controller.py @@ -162,6 +162,45 @@ class PreemptionCandidate: band: int # proto PriorityBand value +@dataclass +class _SyncFailureAccumulator: + """Mutable accumulator for tracking failures during provider sync.""" + + fail_count: int = 0 + failed_workers: list[str] = field(default_factory=list) + all_tasks_to_kill: set[JobName] = field(default_factory=set) + all_task_kill_workers: dict[JobName, WorkerId] = field(default_factory=dict) + + +@dataclass(frozen=True) +class _SchedulingStateRead: + """Snapshot of pending tasks and workers read at the start 
of a scheduling cycle.""" + + pending_tasks: list[TaskRow] + workers: list[WorkerRow] + state_read_ms: int + + +@dataclass(frozen=True) +class _GatedCandidates: + """Tasks that passed deadline, reservation, and per-job-cap gates.""" + + schedulable_task_ids: list[JobName] + jobs: dict[JobName, JobRequirements] + has_reservation: set[JobName] + has_direct_reservation: set[JobName] + + +@dataclass(frozen=True) +class _SchedulingOrder: + """Priority-ordered task list with budget context for preemption.""" + + ordered_task_ids: list[JobName] + task_band_map: dict[JobName, int] + user_spend: dict[str, int] + user_budget_limits: dict[str, int] + + def job_requirements_from_job(job: JobSchedulingRow) -> JobRequirements: """Convert a job row to scheduler-compatible JobRequirements.""" return JobRequirements( @@ -1028,9 +1067,10 @@ def __init__( # are only valid before the controller loops begin (e.g. LoadCheckpoint). self._started = False - # Checkpoint coordination flag. When set, scheduling and autoscaler - # loops skip their work so the snapshot captures a quiescent state. - self._checkpoint_in_progress = False + # Checkpoint coordination: when set, scheduling and autoscaler loops + # skip their work so the snapshot captures a quiescent state. + # threading.Event (not a bare bool) for cross-thread memory ordering. 
+ self._checkpoint_paused = threading.Event() self._atexit_registered = False # Serializes heartbeat rounds against checkpoint snapshots so that @@ -1178,7 +1218,7 @@ def _run_scheduling_loop(self, stop_event: threading.Event) -> None: if stop_event.is_set(): break - if self._checkpoint_in_progress: + if self._checkpoint_paused.is_set(): continue if woken: @@ -1227,7 +1267,7 @@ def _run_autoscaler_loop(self, stop_event: threading.Event) -> None: while not stop_event.is_set(): if not limiter.wait(cancel=stop_event): break - if self._checkpoint_in_progress: + if self._checkpoint_paused.is_set(): continue try: self._run_autoscaler_once() @@ -1250,7 +1290,7 @@ def _run_provider_loop(self, stop_event: threading.Event) -> None: limiter.mark_run() if stop_event.is_set(): break - if self._checkpoint_in_progress: + if self._checkpoint_paused.is_set(): continue try: with self._heartbeat_lock: @@ -1267,7 +1307,7 @@ def _run_direct_provider_loop(self, stop_event: threading.Event) -> None: limiter.mark_run() if stop_event.is_set(): break - if self._checkpoint_in_progress: + if self._checkpoint_paused.is_set(): continue try: self._sync_direct_provider() @@ -1310,7 +1350,7 @@ def _run_profile_loop(self, stop_event: threading.Event) -> None: if stop_event.is_set(): break limiter.mark_run() - if self._checkpoint_in_progress: + if self._checkpoint_paused.is_set(): continue try: self._profile_all_running_tasks() @@ -1533,10 +1573,43 @@ def _run_scheduling(self) -> SchedulingOutcome: is serialized by ControllerDB._lock with multi-statement mutations wrapped in BEGIN IMMEDIATE transactions. """ - # Reservation claims are read and updated outside the scheduling transaction. - # This creates a narrow race window where a worker could be removed between - # claim reads and scheduling, but it's benign: queue_assignments() re-validates - # all assignments transactionally, and stale claims are cleaned up next cycle. 
+ claims = self._refresh_reservation_claims() + + timer = Timer() + state = self._read_scheduling_state() + + if not state.pending_tasks: + self._scheduling_diagnostics = {} + return SchedulingOutcome.NO_PENDING_TASKS + + gated = self._apply_scheduling_gates(state.pending_tasks, claims) + + if not gated.schedulable_task_ids: + self._scheduling_diagnostics = {} + return SchedulingOutcome.NO_PENDING_TASKS + + order = self._compute_scheduling_order( + gated.schedulable_task_ids, + state.pending_tasks, + gated.jobs, + ) + + all_assignments, context, tainted_jobs = self._run_scheduler_pass(order, gated, state, claims, timer) + + preemptions = self._apply_preemptions(order, tainted_jobs, all_assignments, claims, context) + + self._cache_scheduling_diagnostics(context, tainted_jobs, all_assignments, order.ordered_task_ids) + + if all_assignments or preemptions: + return SchedulingOutcome.ASSIGNMENTS_MADE + return SchedulingOutcome.NO_ASSIGNMENTS + + def _refresh_reservation_claims(self) -> dict[WorkerId, ReservationClaim]: + """Read, clean up, and refresh reservation claims. Returns updated claims.""" + # Claims are read outside the scheduling transaction. This creates a + # narrow race window where a worker could be removed between claim reads + # and scheduling, but it's benign: queue_assignments() re-validates all + # assignments transactionally, and stale claims are cleaned up next cycle. 
claims = _read_reservation_claims(self._db) claims_changed = self._cleanup_stale_claims(claims) claims_changed = self._claim_workers_for_reservations(claims) or claims_changed @@ -1545,20 +1618,26 @@ def _run_scheduling(self) -> SchedulingOutcome: logger.info("[DRY-RUN] Would update %d reservation claims", len(claims)) else: self._transitions.replace_reservation_claims(claims) + return claims + def _read_scheduling_state(self) -> _SchedulingStateRead: + """Fetch pending tasks and healthy workers from the DB.""" timer = Timer() with slow_log(logger, "scheduling state reads", threshold_ms=50): pending_tasks = _schedulable_tasks(self._db) workers = healthy_active_workers_with_attributes(self._db) - state_read_ms = timer.elapsed_ms() - - if not pending_tasks: - self._scheduling_diagnostics = {} - return SchedulingOutcome.NO_PENDING_TASKS + return _SchedulingStateRead( + pending_tasks=pending_tasks, + workers=workers, + state_read_ms=timer.elapsed_ms(), + ) - # Handle timeouts and reservation gates before scheduling. - # Holder tasks participate in scheduling like normal tasks. - # Cap non-coscheduled tasks per job to bound scheduling CPU time. 
+ def _apply_scheduling_gates( + self, + pending_tasks: list[TaskRow], + claims: dict[WorkerId, ReservationClaim], + ) -> _GatedCandidates: + """Filter tasks by deadline, reservation satisfaction, and per-job cap.""" schedulable_task_ids: list[JobName] = [] jobs: dict[JobName, JobRequirements] = {} has_reservation: set[JobName] = set() @@ -1591,15 +1670,24 @@ def _run_scheduling(self) -> SchedulingOutcome: has_direct_reservation.add(task.job_id) elif _find_reservation_ancestor(self._db, task.job_id) is not None: has_reservation.add(task.job_id) + return _GatedCandidates( + schedulable_task_ids=schedulable_task_ids, + jobs=jobs, + has_reservation=has_reservation, + has_direct_reservation=has_direct_reservation, + ) - if not schedulable_task_ids: - self._scheduling_diagnostics = {} - return SchedulingOutcome.NO_PENDING_TASKS + def _compute_scheduling_order( + self, + schedulable_task_ids: list[JobName], + pending_tasks: list[TaskRow], + jobs: dict[JobName, JobRequirements], + ) -> _SchedulingOrder: + """Compute priority-band interleaving and per-user cap. - # Per-band interleaving: group tasks by priority band, round-robin - # users within each band ordered by ascending budget spend. - # Budget down-weighting: users over budget have non-PRODUCTION tasks - # treated as BATCH for both scheduling order and preemption eligibility. + Maps tasks to effective bands (down-weighting over-budget users), + round-robins users within each band, and applies the per-user cap. 
+ """ with self._db.read_snapshot() as budget_snapshot: user_spend = compute_user_spend(budget_snapshot) user_budget_limits = self._db.get_all_user_budget_limits() @@ -1617,7 +1705,6 @@ def _run_scheduling(self) -> SchedulingOutcome: band_tasks = tasks_by_band[band_key] user_tasks = [UserTask(user_id=tid.user, task=tid) for tid in band_tasks] interleaved.extend(interleave_by_user(user_tasks, user_spend)) - schedulable_task_ids = interleaved # Per-user cap: limit how many tasks a single user can have considered # per scheduling cycle, ensuring fairness. @@ -1625,31 +1712,44 @@ def _run_scheduling(self) -> SchedulingOutcome: if user_cap > 0: tasks_per_user: dict[str, int] = defaultdict(int) capped: list[JobName] = [] - for task_id in schedulable_task_ids: + for task_id in interleaved: if tasks_per_user[task_id.user] < user_cap: capped.append(task_id) tasks_per_user[task_id.user] += 1 - schedulable_task_ids = capped + interleaved = capped + + return _SchedulingOrder( + ordered_task_ids=interleaved, + task_band_map=task_band_map, + user_spend=user_spend, + user_budget_limits=user_budget_limits, + ) - # Inject reservation taints: claimed workers get a taint attribute, - # non-reservation jobs get a NOT_EXISTS constraint for it. - modified_workers = _inject_reservation_taints(workers, claims) - jobs = _inject_taint_constraints(jobs, has_reservation, has_direct_reservation) + def _run_scheduler_pass( + self, + order: _SchedulingOrder, + gated: _GatedCandidates, + state: _SchedulingStateRead, + claims: dict[WorkerId, ReservationClaim], + timer: Timer, + ) -> tuple[list[tuple[JobName, WorkerId]], SchedulingContext, dict[JobName, JobRequirements]]: + """Run preference + normal assignment passes. 
Returns (assignments, context, taint-injected jobs).""" + modified_workers = _inject_reservation_taints(state.workers, claims) + modified_jobs = _inject_taint_constraints(gated.jobs, gated.has_reservation, gated.has_direct_reservation) with slow_log(logger, "building_counts", threshold_ms=50): - building_counts = _building_counts(self._db, workers=workers) + building_counts = _building_counts(self._db, workers=state.workers) context = self._scheduler.create_scheduling_context( modified_workers, building_counts=building_counts, - pending_tasks=schedulable_task_ids, - jobs=jobs, + pending_tasks=order.ordered_task_ids, + jobs=modified_jobs, ) - # Phase 1: soft preference — steer reservation tasks toward claimed workers. + # Soft preference — steer reservation tasks toward claimed workers. # Skips coscheduled jobs (they need atomic all-or-nothing via find_assignments). - preference_assignments = _preference_pass(context, has_reservation, claims) + preference_assignments = _preference_pass(context, gated.has_reservation, claims) - # Phase 2: normal scheduler for all remaining tasks. result = self._scheduler.find_assignments(context) all_assignments = preference_assignments + result.assignments @@ -1662,26 +1762,34 @@ def _run_scheduling(self) -> SchedulingOutcome: len(preference_assignments), len(result.assignments), timer.elapsed_ms(), - state_read_ms, + state.state_read_ms, ) + return all_assignments, context, modified_jobs - # Phase 3: preemption — evict lower-priority running tasks for - # higher-priority unscheduled work. 
+ def _apply_preemptions( + self, + order: _SchedulingOrder, + jobs: dict[JobName, JobRequirements], + all_assignments: list[tuple[JobName, WorkerId]], + claims: dict[WorkerId, ReservationClaim], + context: SchedulingContext, + ) -> list[tuple[JobName, JobName]]: + """Evict lower-priority running tasks for higher-priority unscheduled work.""" assigned_ids = {task_id for task_id, _ in all_assignments} unscheduled = [ PreemptionCandidate( job_name=tid, requirements=jobs[tid.parent], - band=task_band_map.get(tid, cluster_pb2.PRIORITY_BAND_INTERACTIVE), + band=order.task_band_map.get(tid, cluster_pb2.PRIORITY_BAND_INTERACTIVE), ) - for tid in schedulable_task_ids + for tid in order.ordered_task_ids if tid not in assigned_ids and tid.parent is not None and tid.parent in jobs ] preemptions: list[tuple[JobName, JobName]] = [] if unscheduled: claimed_workers = set(claims.keys()) running_info = _get_running_tasks_with_band_and_value( - self._db, claimed_workers, user_spend=user_spend, user_budget_limits=user_budget_limits + self._db, claimed_workers, user_spend=order.user_spend, user_budget_limits=order.user_budget_limits ) preemptions = _run_preemption_pass(unscheduled, running_info, context) for preemptor_name, victim_id in preemptions: @@ -1689,14 +1797,7 @@ def _run_scheduling(self) -> SchedulingOutcome: self.kill_tasks_on_workers(preempt_result.tasks_to_kill) if preemptions: logger.info("Preemption pass: %d tasks preempted", len(preemptions)) - - # Cache diagnostics for jobs that still have unassigned tasks. - # RPCs read from this cache instead of recomputing per request. 
- self._cache_scheduling_diagnostics(context, jobs, all_assignments, schedulable_task_ids) - - if all_assignments or preemptions: - return SchedulingOutcome.ASSIGNMENTS_MADE - return SchedulingOutcome.NO_ASSIGNMENTS + return preemptions def _cache_scheduling_diagnostics( self, @@ -1897,88 +1998,124 @@ def _sync_all_execution_units(self) -> None: return round_timer = Timer() - # Phase 0: fail workers whose last heartbeat exceeds the staleness - # threshold. This catches workers restored from a checkpoint whose - # backing VMs no longer exist. self._reap_stale_workers() - # Phase 1: drain dispatch for all healthy workers. with slow_log(logger, "provider sync phase 1 (snapshot)", threshold_ms=100): batches = self._transitions.drain_dispatch_all() if not batches: return - # Phase 2: sync with the execution backend. + # Sync with the execution backend (ThreadPoolExecutor inside provider). results = self._provider.sync(batches) - # Phase 3: apply results. - fail_count = 0 - failed_workers: list[str] = [] - with slow_log(logger, "provider sync phase 3 (apply results)", threshold_ms=500): - # Separate successes from failures so we can batch the common case. - success_reqs = [] - failure_entries = [] - for batch, apply_req, error in results: - if apply_req is not None: - success_reqs.append(apply_req) - else: - failure_entries.append((batch, error or "unknown error")) - - # Batch all successful heartbeats in one transaction. 
- all_tasks_to_kill: set[JobName] = set() - all_task_kill_workers: dict[JobName, WorkerId] = {} - if success_reqs: - batch_results = self._transitions.apply_heartbeats_batch(success_reqs) - for result in batch_results: - all_tasks_to_kill.update(result.tasks_to_kill) - all_task_kill_workers.update(result.task_kill_workers) - - failure_result = self._transitions.fail_heartbeats_batch(failure_entries) - all_tasks_to_kill.update(failure_result.tasks_to_kill) - all_task_kill_workers.update(failure_result.task_kill_workers) - - primary_failed_workers: list[str] = [] - for (batch, error), result in zip(failure_entries, failure_result.results, strict=False): - logger.debug("Sync error for %s: %s", batch.worker_id, error) - if result.action == HeartbeatAction.WORKER_FAILED: - fail_count += 1 - failed_workers.append(batch.worker_id) - self._provider.on_worker_failed(batch.worker_id, batch.worker_address) - primary_failed_workers.append(str(batch.worker_id)) - elif result.action == HeartbeatAction.TRANSIENT_FAILURE: - fail_count += 1 - failed_workers.append(batch.worker_id) - - if self._autoscaler and primary_failed_workers: - sibling_worker_ids = self._autoscaler.terminate_slices_for_workers(primary_failed_workers) - # TODO(#3425): This prunes sibling workers before their in-flight - # results are processed, causing apply_heartbeat() to - # silently drop any logs/states those workers reported this round. 
- sibling_failures = self._transitions.fail_workers_batch( - sibling_worker_ids, - reason="sibling worker failed, slice terminated", - ) - all_tasks_to_kill.update(sibling_failures.tasks_to_kill) - all_task_kill_workers.update(sibling_failures.task_kill_workers) - for wid, addr in sibling_failures.removed_workers: - self._provider.on_worker_failed(wid, addr) - if sibling_failures.removed_workers: - fail_count += len(sibling_failures.removed_workers) - failed_workers.extend(wid for wid, _ in sibling_failures.removed_workers) - logger.info( - "Failed %d sibling workers from slices: %s", - len(sibling_failures.removed_workers), - [wid for wid, _ in sibling_failures.removed_workers], - ) + acc = _SyncFailureAccumulator() + with slow_log(logger, "provider sync (apply results)", threshold_ms=500): + success_reqs, failure_entries = self._separate_sync_results(results) + self._apply_successful_heartbeats(success_reqs, acc) + primary_failed_workers = self._handle_failed_heartbeats(failure_entries, acc) + self._handle_sibling_worker_failures(primary_failed_workers, acc) + + if acc.all_tasks_to_kill: + self.kill_tasks_on_workers(acc.all_tasks_to_kill, acc.all_task_kill_workers) + + self._log_sync_health_summary( + batch_count=len(batches), + fail_count=acc.fail_count, + failed_workers=acc.failed_workers, + elapsed_ms=round_timer.elapsed_ms(), + ) + + def _separate_sync_results( + self, + results: list, + ) -> tuple[list, list[tuple]]: + """Partition provider sync results into successes and failures.""" + success_reqs = [] + failure_entries = [] + for batch, apply_req, error in results: + if apply_req is not None: + success_reqs.append(apply_req) + else: + failure_entries.append((batch, error or "unknown error")) + return success_reqs, failure_entries + + def _apply_successful_heartbeats( + self, + success_reqs: list, + acc: _SyncFailureAccumulator, + ) -> None: + """Batch-apply successful heartbeat results, accumulating kill targets.""" + if not success_reqs: + return + 
batch_results = self._transitions.apply_heartbeats_batch(success_reqs) + for result in batch_results: + acc.all_tasks_to_kill.update(result.tasks_to_kill) + acc.all_task_kill_workers.update(result.task_kill_workers) - if all_tasks_to_kill: - self.kill_tasks_on_workers(all_tasks_to_kill, all_task_kill_workers) + def _handle_failed_heartbeats( + self, + failure_entries: list[tuple], + acc: _SyncFailureAccumulator, + ) -> list[str]: + """Process failed heartbeats: update accumulator and return primary failed worker IDs.""" + failure_result = self._transitions.fail_heartbeats_batch(failure_entries) + acc.all_tasks_to_kill.update(failure_result.tasks_to_kill) + acc.all_task_kill_workers.update(failure_result.task_kill_workers) + + primary_failed_workers: list[str] = [] + for (batch, error), result in zip(failure_entries, failure_result.results, strict=False): + logger.debug("Sync error for %s: %s", batch.worker_id, error) + if result.action == HeartbeatAction.WORKER_FAILED: + acc.fail_count += 1 + acc.failed_workers.append(batch.worker_id) + self._provider.on_worker_failed(batch.worker_id, batch.worker_address) + primary_failed_workers.append(str(batch.worker_id)) + elif result.action == HeartbeatAction.TRANSIENT_FAILURE: + acc.fail_count += 1 + acc.failed_workers.append(batch.worker_id) + return primary_failed_workers + + def _handle_sibling_worker_failures( + self, + primary_failed_workers: list[str], + acc: _SyncFailureAccumulator, + ) -> None: + """Terminate slices containing failed workers and fail their siblings.""" + if not self._autoscaler or not primary_failed_workers: + return + sibling_worker_ids = self._autoscaler.terminate_slices_for_workers(primary_failed_workers) + # TODO(#3425): This prunes sibling workers before their in-flight + # results are processed, causing apply_heartbeat() to + # silently drop any logs/states those workers reported this round. 
+ sibling_failures = self._transitions.fail_workers_batch( + sibling_worker_ids, + reason="sibling worker failed, slice terminated", + ) + acc.all_tasks_to_kill.update(sibling_failures.tasks_to_kill) + acc.all_task_kill_workers.update(sibling_failures.task_kill_workers) + for wid, addr in sibling_failures.removed_workers: + self._provider.on_worker_failed(wid, addr) + if sibling_failures.removed_workers: + acc.fail_count += len(sibling_failures.removed_workers) + acc.failed_workers.extend(wid for wid, _ in sibling_failures.removed_workers) + logger.info( + "Failed %d sibling workers from slices: %s", + len(sibling_failures.removed_workers), + [wid for wid, _ in sibling_failures.removed_workers], + ) - elapsed = round_timer.elapsed_ms() - level = logging.WARNING if elapsed > _SLOW_HEARTBEAT_MS else logging.DEBUG + def _log_sync_health_summary( + self, + batch_count: int, + fail_count: int, + failed_workers: list[str], + elapsed_ms: int, + ) -> None: + """Log provider sync timing and periodic cluster health summary.""" + level = logging.WARNING if elapsed_ms > _SLOW_HEARTBEAT_MS else logging.DEBUG fmt = "Provider sync: %d workers, %d failed, %dms" - args: list[object] = [len(batches), fail_count, elapsed] + args: list[object] = [batch_count, fail_count, elapsed_ms] if failed_workers: fmt += " failed=[%s]" args.append(", ".join(failed_workers)) @@ -2047,7 +2184,7 @@ def begin_checkpoint(self) -> tuple[str, CheckpointResult]: if self._config.dry_run: logger.info("[DRY-RUN] Skipping checkpoint write") return ("dry-run", CheckpointResult(created_at=Timestamp.now(), job_count=0, task_count=0, worker_count=0)) - self._checkpoint_in_progress = True + self._checkpoint_paused.set() try: # Hold the heartbeat lock only for the SQLite backup (consistent # snapshot). 
Compression and GCS upload run outside the lock so @@ -2067,7 +2204,7 @@ def begin_checkpoint(self) -> tuple[str, CheckpointResult]: ) return path, result finally: - self._checkpoint_in_progress = False + self._checkpoint_paused.clear() def launch_job( self, diff --git a/lib/iris/src/iris/cluster/controller/transitions.py b/lib/iris/src/iris/cluster/controller/transitions.py index bf2cf6ef26..6c1d988034 100644 --- a/lib/iris/src/iris/cluster/controller/transitions.py +++ b/lib/iris/src/iris/cluster/controller/transitions.py @@ -294,6 +294,54 @@ def _has_reservation_flag(request: cluster_pb2.Controller.LaunchJobRequest) -> i return 1 if request.HasField("reservation") and request.reservation.entries else 0 +def delete_task_endpoints(cur: TransactionCursor, task_id: str) -> None: + """Remove all registered endpoints for a task.""" + cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id,)) + + +def enqueue_run_dispatch( + cur: TransactionCursor, + worker_id: str, + payload_proto: bytes, + now_ms: int, +) -> None: + """Queue a 'run' dispatch entry for delivery on the next heartbeat.""" + cur.execute( + "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " + "VALUES (?, 'run', ?, NULL, ?)", + (worker_id, payload_proto, now_ms), + ) + + +def enqueue_kill_dispatch( + cur: TransactionCursor, + worker_id: str | None, + task_id: str, + now_ms: int, +) -> None: + """Queue a 'kill' dispatch entry for delivery on the next heartbeat.""" + cur.execute( + "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " + "VALUES (?, 'kill', NULL, ?, ?)", + (worker_id, task_id, now_ms), + ) + + +def insert_task_attempt( + cur: TransactionCursor, + task_id: str, + attempt_id: int, + worker_id: str | None, + state: int, + now_ms: int, +) -> None: + """Record a new task attempt row.""" + cur.execute( + "INSERT INTO task_attempts(task_id, attempt_id, worker_id, state, created_at_ms) " "VALUES (?, ?, ?, ?, ?)", + (task_id, 
attempt_id, worker_id, state, now_ms), + ) + + def _decommit_worker_resources( cur: TransactionCursor, worker_id: str, @@ -315,6 +363,120 @@ def _decommit_worker_resources( ) +def _remove_worker(cur: TransactionCursor, worker_id: str) -> None: + """Remove a worker and sever all its foreign-key references. + + Must be called inside an existing transaction. The four statements + enforce the multi-table invariant: no dangling worker_id references + remain in task_attempts, tasks, or dispatch_queue after the worker + row is deleted. + """ + cur.execute("UPDATE task_attempts SET worker_id = NULL WHERE worker_id = ?", (worker_id,)) + cur.execute("UPDATE tasks SET current_worker_id = NULL WHERE current_worker_id = ?", (worker_id,)) + cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (worker_id,)) + cur.execute("DELETE FROM workers WHERE worker_id = ?", (worker_id,)) + + +def _assign_task( + cur: TransactionCursor, + task_id: str, + worker_id: str | None, + worker_address: str | None, + attempt_id: int, + now_ms: int, +) -> None: + """Create an attempt and mark a task as ASSIGNED in one consistent step. + + worker_id may be None for direct-provider tasks that have no backing + worker daemon. + """ + insert_task_attempt(cur, task_id, attempt_id, worker_id, cluster_pb2.TASK_STATE_ASSIGNED, now_ms) + if worker_id is not None: + cur.execute( + "UPDATE tasks SET state = ?, current_attempt_id = ?, " + "current_worker_id = ?, current_worker_address = ?, " + "started_at_ms = COALESCE(started_at_ms, ?) WHERE task_id = ?", + (cluster_pb2.TASK_STATE_ASSIGNED, attempt_id, worker_id, worker_address, now_ms, task_id), + ) + else: + cur.execute( + "UPDATE tasks SET state = ?, current_attempt_id = ?, " + "started_at_ms = COALESCE(started_at_ms, ?) 
WHERE task_id = ?",
+            (cluster_pb2.TASK_STATE_ASSIGNED, attempt_id, now_ms, task_id),
+        )
+
+
+def _terminate_task(
+    cur: TransactionCursor,
+    task_id: str,
+    attempt_id: int | None,
+    state: int,
+    error: str | None,
+    now_ms: int,
+    *,
+    attempt_state: int | None = None,
+    worker_id: str | None = None,
+    resources: "cluster_pb2.ResourceSpecProto | None" = None,
+    failure_count: int | None = None,
+    preemption_count: int | None = None,
+) -> None:
+    """Move a task (and its current attempt) out of active state consistently.
+
+    Enforces the multi-table invariant: attempt is marked terminal,
+    task state/error/finished_at are updated, endpoints are deleted,
+    and worker resources are released.
+
+    ``attempt_state`` overrides the state written to the attempt row when it
+    differs from the task state (e.g. attempt=WORKER_FAILED while task retries
+    to PENDING). Defaults to ``state`` when not provided.
+
+    attempt_id None or < 0 means no attempt exists; the attempt UPDATE is skipped.
+    """
+    finished_at_ms = None if state in ACTIVE_TASK_STATES or state == cluster_pb2.TASK_STATE_PENDING else now_ms
+    effective_attempt_state = attempt_state if attempt_state is not None else state
+
+    if attempt_id is not None and attempt_id >= 0:
+        cur.execute(
+            "UPDATE task_attempts SET state = ?, "
+            "finished_at_ms = COALESCE(finished_at_ms, ?), error = ? "
+            "WHERE task_id = ? AND attempt_id = ?",
+            (effective_attempt_state, now_ms, error, task_id, attempt_id),
+        )
+
+    # Build the UPDATE tasks statement dynamically based on optional counters.
+    # Use COALESCE for finished_at_ms when non-NULL to preserve any existing
+    # timestamp (defensive against double-termination). When NULL (retrying to
+    # PENDING), assign directly so the column is cleared.
+ if finished_at_ms is not None: + set_clauses = ["state = ?", "error = ?", "finished_at_ms = COALESCE(finished_at_ms, ?)"] + else: + set_clauses = ["state = ?", "error = ?", "finished_at_ms = ?"] + params: list[object] = [state, error, finished_at_ms] + + if failure_count is not None: + set_clauses.append("failure_count = ?") + params.append(failure_count) + if preemption_count is not None: + set_clauses.append("preemption_count = ?") + params.append(preemption_count) + + # Always clear worker columns when leaving active state. + if state not in ACTIVE_TASK_STATES: + set_clauses.append("current_worker_id = NULL") + set_clauses.append("current_worker_address = NULL") + + params.append(task_id) + cur.execute( + f"UPDATE tasks SET {', '.join(set_clauses)} WHERE task_id = ?", + tuple(params), + ) + + delete_task_endpoints(cur, task_id) + + if worker_id is not None and resources is not None: + _decommit_worker_resources(cur, worker_id, resources) + + _LAUNCH_JOB_DECODER = proto_decoder(cluster_pb2.Controller.LaunchJobRequest) @@ -340,24 +502,22 @@ def _kill_non_terminal_tasks( task_id = str(row["task_id"]) worker_id = row["current_worker_id"] task_name = JobName.from_wire(task_id) - cur.execute( - "UPDATE tasks SET state = ?, finished_at_ms = COALESCE(finished_at_ms, ?), error = ?, " - "current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", - (cluster_pb2.TASK_STATE_KILLED, now_ms, reason, task_id), - ) - if int(row["current_attempt_id"]) >= 0: - cur.execute( - "UPDATE task_attempts SET state = ?, " - "finished_at_ms = COALESCE(finished_at_ms, ?), error = ? " - "WHERE task_id = ? 
AND attempt_id = ?", - (cluster_pb2.TASK_STATE_KILLED, now_ms, reason, task_id, int(row["current_attempt_id"])), - ) + resources = None if worker_id is not None: req = proto_cache.get_or_decode(row["request_proto"], _LAUNCH_JOB_DECODER) - _decommit_worker_resources(cur, str(worker_id), req.resources) + resources = req.resources task_kill_workers[task_name] = WorkerId(str(worker_id)) + _terminate_task( + cur, + task_id, + int(row["current_attempt_id"]), + cluster_pb2.TASK_STATE_KILLED, + reason, + now_ms, + worker_id=str(worker_id) if worker_id is not None else None, + resources=resources, + ) tasks_to_kill.add(task_name) - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id,)) return tasks_to_kill, task_kill_workers @@ -471,28 +631,19 @@ def _terminate_coscheduled_siblings( error = f"Coscheduled sibling {failed_task_id.to_wire()} failed" for sib in siblings: - cur.execute( - "UPDATE task_attempts SET state = ?, " - "finished_at_ms = COALESCE(finished_at_ms, ?), error = ? " - "WHERE task_id = ? 
AND attempt_id = ?", - (cluster_pb2.TASK_STATE_WORKER_FAILED, now_ms, error, sib.task_id, sib.attempt_id), - ) - cur.execute( - "UPDATE tasks SET state = ?, finished_at_ms = ?, preemption_count = ?, error = ?, " - "current_worker_id = NULL, current_worker_address = NULL " - "WHERE task_id = ?", - ( - cluster_pb2.TASK_STATE_WORKER_FAILED, - now_ms, - sib.max_retries_preemption + 1, - error, - sib.task_id, - ), + _terminate_task( + cur, + sib.task_id, + sib.attempt_id, + cluster_pb2.TASK_STATE_WORKER_FAILED, + error, + now_ms, + worker_id=sib.worker_id, + resources=job_req.resources if sib.worker_id is not None else None, + preemption_count=sib.max_retries_preemption + 1, ) if sib.worker_id is not None: - _decommit_worker_resources(cur, sib.worker_id, job_req.resources) task_kill_workers[JobName.from_wire(sib.task_id)] = WorkerId(sib.worker_id) - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (sib.task_id,)) tasks_to_kill.add(JobName.from_wire(sib.task_id)) return tasks_to_kill, task_kill_workers @@ -1252,29 +1403,13 @@ def queue_assignments(self, assignments: list[Assignment]) -> AssignmentResult: job_cache[job_id_wire] = decoded_job job = job_cache[job_id_wire] attempt_id = int(task_row["current_attempt_id"]) + 1 - cur.execute( - "INSERT INTO task_attempts(task_id, attempt_id, worker_id, state, created_at_ms) " - "VALUES (?, ?, ?, ?, ?)", - ( - assignment.task_id.to_wire(), - attempt_id, - str(assignment.worker_id), - cluster_pb2.TASK_STATE_ASSIGNED, - now_ms, - ), - ) - cur.execute( - "UPDATE tasks SET state = ?, current_attempt_id = ?, " - "current_worker_id = ?, current_worker_address = ?, " - "started_at_ms = COALESCE(started_at_ms, ?) 
WHERE task_id = ?", - ( - cluster_pb2.TASK_STATE_ASSIGNED, - attempt_id, - str(assignment.worker_id), - str(worker_row["address"]), - now_ms, - assignment.task_id.to_wire(), - ), + _assign_task( + cur, + assignment.task_id.to_wire(), + str(assignment.worker_id), + str(worker_row["address"]), + attempt_id, + now_ms, ) if not job.is_reservation_holder: resources = job.request.resources @@ -1301,11 +1436,7 @@ def queue_assignments(self, assignments: list[Assignment]) -> AssignmentResult: attempt_id=attempt_id, constraints=list(job.request.constraints), ) - cur.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'run', ?, NULL, ?)", - (str(assignment.worker_id), run_request.SerializeToString(), now_ms), - ) + enqueue_run_dispatch(cur, str(assignment.worker_id), run_request.SerializeToString(), now_ms) has_real_dispatch = True cur.execute( "INSERT INTO worker_task_history(worker_id, task_id, assigned_at_ms) VALUES (?, ?, ?)", @@ -1527,7 +1658,7 @@ def _apply_task_transitions( _decommit_worker_resources(cur, str(worker_id), job_req.resources) if update.new_state in TERMINAL_TASK_STATES: - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (update.task_id.to_wire(),)) + delete_task_endpoints(cur, update.task_id.to_wire()) # Coscheduled jobs: a terminal host failure should cascade to siblings. 
if job_req is not None and task_state in FAILURE_TASK_STATES: @@ -1711,7 +1842,6 @@ def _remove_failed_worker( is_reservation_holder = bool(int(task_row["is_reservation_holder"])) if is_reservation_holder: new_task_state = cluster_pb2.TASK_STATE_PENDING - finished_ms: int | None = None preemption_count = int(task_row["preemption_count"]) else: new_task_state, preemption_count = _resolve_task_failure_state( @@ -1720,7 +1850,6 @@ def _remove_failed_worker( int(task_row["max_retries_preemption"]), cluster_pb2.TASK_STATE_WORKER_FAILED, ) - finished_ms = None if new_task_state == cluster_pb2.TASK_STATE_PENDING else now_ms if is_reservation_holder: cur.execute( "DELETE FROM task_attempts WHERE task_id = ? AND attempt_id = ?", @@ -1733,31 +1862,16 @@ def _remove_failed_worker( (new_task_state, tid), ) else: - cur.execute( - "UPDATE task_attempts SET state = ?, " - "finished_at_ms = COALESCE(finished_at_ms, ?), error = ? " - "WHERE task_id = ? AND attempt_id = ?", - ( - cluster_pb2.TASK_STATE_WORKER_FAILED, - now_ms, - f"Worker {worker_id} failed: {error}", - tid, - int(task_row["current_attempt_id"]), - ), - ) - cur.execute( - "UPDATE tasks SET state = ?, finished_at_ms = ?, error = ?, " - "preemption_count = ?, " - "current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", - ( - new_task_state, - finished_ms, - f"Worker {worker_id} failed: {error}", - preemption_count, - tid, - ), + _terminate_task( + cur, + tid, + int(task_row["current_attempt_id"]), + new_task_state, + f"Worker {worker_id} failed: {error}", + now_ms, + attempt_state=cluster_pb2.TASK_STATE_WORKER_FAILED, + preemption_count=preemption_count, ) - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (tid,)) task_id = JobName.from_wire(tid) parent_job_id, _ = task_id.require_task() new_job_state = self._recompute_job_state(cur, parent_job_id) @@ -1777,16 +1891,7 @@ def _remove_failed_worker( task_kill_workers.update(child_task_kill_workers) if new_task_state == 
cluster_pb2.TASK_STATE_WORKER_FAILED: tasks_to_kill.add(task_id) - cur.execute( - "UPDATE task_attempts SET worker_id = NULL WHERE worker_id = ?", - (str(worker_id),), - ) - cur.execute( - "UPDATE tasks SET current_worker_id = NULL WHERE current_worker_id = ?", - (str(worker_id),), - ) - cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (str(worker_id),)) - cur.execute("DELETE FROM workers WHERE worker_id = ?", (str(worker_id),)) + _remove_worker(cur, str(worker_id)) return TxResult(tasks_to_kill=tasks_to_kill, task_kill_workers=task_kill_workers) def _record_heartbeat_failure( @@ -1823,17 +1928,9 @@ def _record_heartbeat_failure( task_kill_workers.update(removal.task_kill_workers) else: for req in drained_dispatch.tasks_to_run: - cur.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'run', ?, NULL, ?)", - (str(worker_id), req.SerializeToString(), now_ms), - ) + enqueue_run_dispatch(cur, str(worker_id), req.SerializeToString(), now_ms) for task_id in drained_dispatch.tasks_to_kill: - cur.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'kill', NULL, ?, ?)", - (str(worker_id), task_id, now_ms), - ) + enqueue_kill_dispatch(cur, str(worker_id), task_id, now_ms) action = HeartbeatAction.WORKER_FAILED if should_remove else HeartbeatAction.TRANSIENT_FAILURE return HeartbeatFailureResult( tasks_to_kill=tasks_to_kill, @@ -1944,12 +2041,14 @@ def mark_task_unschedulable(self, task_id: JobName, reason: str) -> TxResult: if row is None: return TxResult() now_ms = Timestamp.now().epoch_ms() - cur.execute( - "UPDATE tasks SET state = ?, error = ?, finished_at_ms = ?, " - "current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", - (cluster_pb2.TASK_STATE_UNSCHEDULABLE, reason, now_ms, task_id.to_wire()), + _terminate_task( + cur, + task_id.to_wire(), + None, + cluster_pb2.TASK_STATE_UNSCHEDULABLE, + reason, + now_ms, ) - 
cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),)) self._recompute_job_state(cur, JobName.from_wire(str(row["job_id"]))) self._record_transaction( cur, "mark_task_unschedulable", [("task_unschedulable", task_id.to_wire(), {"reason": reason})] @@ -1986,39 +2085,30 @@ def preempt_task(self, task_id: JobName, reason: str) -> TxResult: int(row["max_retries_preemption"]), cluster_pb2.TASK_STATE_PREEMPTED, ) - finished_ms = None if new_state == cluster_pb2.TASK_STATE_PENDING else now_ms - - # Update attempt - cur.execute( - "UPDATE task_attempts SET state = ?, finished_at_ms = COALESCE(finished_at_ms, ?), error = ? " - "WHERE task_id = ? AND attempt_id = ?", - ( - cluster_pb2.TASK_STATE_PREEMPTED, - now_ms, - reason, - task_id.to_wire(), - int(row["current_attempt_id"]), - ), - ) - - # Update task - cur.execute( - "UPDATE tasks SET state = ?, error = ?, finished_at_ms = ?, preemption_count = ?, " - "current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", - (new_state, reason, finished_ms, preemption_count, task_id.to_wire()), - ) - - # Decommit worker resources + # Fetch worker_id from the attempt for resource decommit. attempt_row = cur.execute( "SELECT worker_id FROM task_attempts WHERE task_id = ? 
AND attempt_id = ?", (task_id.to_wire(), int(row["current_attempt_id"])), ).fetchone() - if attempt_row and attempt_row["worker_id"]: + attempt_worker_id = str(attempt_row["worker_id"]) if attempt_row and attempt_row["worker_id"] else None + attempt_resources = None + if attempt_worker_id is not None: job_req = cluster_pb2.Controller.LaunchJobRequest() job_req.ParseFromString(row["request_proto"]) - _decommit_worker_resources(cur, str(attempt_row["worker_id"]), job_req.resources) + attempt_resources = job_req.resources - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),)) + _terminate_task( + cur, + task_id.to_wire(), + int(row["current_attempt_id"]), + new_state, + reason, + now_ms, + attempt_state=cluster_pb2.TASK_STATE_PREEMPTED, + worker_id=attempt_worker_id, + resources=attempt_resources, + preemption_count=preemption_count, + ) # Recompute job state and cascade if terminal job_id = JobName.from_wire(str(row["job_id"])) @@ -2110,29 +2200,25 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes worker_id_str = row["worker_id"] job_req = job_req_cache[job_id_wire] tasks_to_kill.add(tid) + decommit_worker = None + decommit_resources = None if worker_id_str is not None: task_kill_workers[tid] = WorkerId(str(worker_id_str)) if not int(row["is_reservation_holder"]): - _decommit_worker_resources(cur, str(worker_id_str), job_req.resources) - cur.execute( - "UPDATE tasks SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?), " - "failure_count = ?, current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", - ( - cluster_pb2.TASK_STATE_FAILED, - reason, - now_ms, - int(row["failure_count"]) + 1, - task_id_wire, - ), - ) + decommit_worker = str(worker_id_str) + decommit_resources = job_req.resources attempt_id = row["current_attempt_id"] - if attempt_id is not None and int(attempt_id) >= 0: - cur.execute( - "UPDATE task_attempts SET state = ?, error = ?, finished_at_ms = 
COALESCE(finished_at_ms, ?) " - "WHERE task_id = ? AND attempt_id = ?", - (cluster_pb2.TASK_STATE_FAILED, reason, now_ms, task_id_wire, int(attempt_id)), - ) - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id_wire,)) + _terminate_task( + cur, + task_id_wire, + int(attempt_id) if attempt_id is not None else None, + cluster_pb2.TASK_STATE_FAILED, + reason, + now_ms, + worker_id=decommit_worker, + resources=decommit_resources, + failure_count=int(row["failure_count"]) + 1, + ) jobs_to_update.add(job_id_wire) # Terminate coscheduled siblings (deduplicated, all reads already done). @@ -2323,17 +2409,9 @@ def requeue_dispatch(self, batch: DispatchBatch) -> None: with self._db.transaction() as cur: now_ms = Timestamp.now().epoch_ms() for req in batch.tasks_to_run: - cur.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'run', ?, NULL, ?)", - (str(batch.worker_id), req.SerializeToString(), now_ms), - ) + enqueue_run_dispatch(cur, str(batch.worker_id), req.SerializeToString(), now_ms) for task_id in batch.tasks_to_kill: - cur.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'kill', NULL, ?, ?)", - (str(batch.worker_id), task_id, now_ms), - ) + enqueue_kill_dispatch(cur, str(batch.worker_id), task_id, now_ms) def remove_finished_job(self, job_id: JobName) -> bool: """Remove a finished job and its tasks from state. 
@@ -2368,10 +2446,7 @@ def remove_worker(self, worker_id: WorkerId) -> WorkerDetailRow | None: row = cur.execute("SELECT * FROM workers WHERE worker_id = ?", (str(worker_id),)).fetchone() if row is None: return None - cur.execute("UPDATE task_attempts SET worker_id = NULL WHERE worker_id = ?", (str(worker_id),)) - cur.execute("UPDATE tasks SET current_worker_id = NULL WHERE current_worker_id = ?", (str(worker_id),)) - cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (str(worker_id),)) - cur.execute("DELETE FROM workers WHERE worker_id = ?", (str(worker_id),)) + _remove_worker(cur, str(worker_id)) self._record_transaction(cur, "remove_worker", [("worker_removed", str(worker_id), {})]) self._db.remove_worker_from_attr_cache(worker_id) return WORKER_DETAIL_PROJECTION.decode_one([row]) @@ -2474,10 +2549,7 @@ def _stopped() -> bool: break worker_id = row["worker_id"] with self._db.transaction() as cur: - cur.execute("UPDATE task_attempts SET worker_id = NULL WHERE worker_id = ?", (worker_id,)) - cur.execute("UPDATE tasks SET current_worker_id = NULL WHERE current_worker_id = ?", (worker_id,)) - cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (worker_id,)) - cur.execute("DELETE FROM workers WHERE worker_id = ?", (worker_id,)) + _remove_worker(cur, str(worker_id)) self._record_transaction(cur, "prune_old_data", [("worker_pruned", str(worker_id), {})]) workers_deleted += 1 time.sleep(pause_between_s) @@ -2572,11 +2644,8 @@ def buffer_dispatch(self, worker_id: WorkerId, task_request: cluster_pb2.Worker. Called by the scheduling thread after committing resources via TaskAssignedEvent. The dispatch will be delivered when begin_heartbeat() drains the buffer. 
""" - self._db.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'run', ?, NULL, ?)", - (str(worker_id), task_request.SerializeToString(), Timestamp.now().epoch_ms()), - ) + with self._db.transaction() as cur: + enqueue_run_dispatch(cur, str(worker_id), task_request.SerializeToString(), Timestamp.now().epoch_ms()) def buffer_kill(self, worker_id: WorkerId, task_id: str) -> None: """Buffer a task kill for the next heartbeat. @@ -2584,11 +2653,8 @@ def buffer_kill(self, worker_id: WorkerId, task_id: str) -> None: Called when a task needs to be terminated on a worker. The kill will be delivered when begin_heartbeat() drains the buffer. """ - self._db.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'kill', NULL, ?, ?)", - (str(worker_id), task_id, Timestamp.now().epoch_ms()), - ) + with self._db.transaction() as cur: + enqueue_kill_dispatch(cur, str(worker_id), task_id, Timestamp.now().epoch_ms()) def begin_heartbeat(self, worker_id: WorkerId) -> DispatchBatch | None: """Drain dispatch for a worker and snapshot expected running attempts.""" @@ -2850,16 +2916,7 @@ def drain_for_direct_provider( job_req.ParseFromString(row["request_proto"]) resources = job_req.resources - cur.execute( - "INSERT INTO task_attempts(task_id, attempt_id, worker_id, state, created_at_ms) " - "VALUES (?, ?, NULL, ?, ?)", - (task_id, attempt_id, cluster_pb2.TASK_STATE_ASSIGNED, now_ms), - ) - cur.execute( - "UPDATE tasks SET state = ?, current_attempt_id = ?, " - "started_at_ms = COALESCE(started_at_ms, ?) 
WHERE task_id = ?", - (cluster_pb2.TASK_STATE_ASSIGNED, attempt_id, now_ms, task_id), - ) + _assign_task(cur, task_id, None, None, attempt_id, now_ms) run_req = cluster_pb2.Worker.RunTaskRequest( task_id=task_id, @@ -3086,7 +3143,7 @@ def apply_direct_provider_updates(self, updates: list[TaskUpdate]) -> TxResult: job_req.ParseFromString(job_row["request_proto"]) if update.new_state in TERMINAL_TASK_STATES: - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (update.task_id.to_wire(),)) + delete_task_endpoints(cur, update.task_id.to_wire()) # Coscheduled sibling cascade. if job_req is not None and task_state in FAILURE_TASK_STATES: @@ -3124,11 +3181,8 @@ def buffer_direct_kill(self, task_id: str) -> None: Inserts a kill entry into dispatch_queue with worker_id=NULL. Drained by drain_for_direct_provider(). """ - self._db.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (NULL, 'kill', NULL, ?, ?)", - (task_id, Timestamp.now().epoch_ms()), - ) + with self._db.transaction() as cur: + enqueue_kill_dispatch(cur, None, task_id, Timestamp.now().epoch_ms()) # ========================================================================= # Test helpers @@ -3171,19 +3225,6 @@ def create_attempt_for_test(self, task_id: JobName, worker_id: WorkerId) -> int: worker_address = str(worker_row["address"]) if worker_row is not None else str(worker_id) next_attempt_id = int(task["current_attempt_id"]) + 1 now_ms = Timestamp.now().epoch_ms() - self._db.execute( - "INSERT INTO task_attempts(task_id, attempt_id, worker_id, state, created_at_ms) VALUES (?, ?, ?, ?, ?)", - ( - task_id.to_wire(), - next_attempt_id, - str(worker_id), - cluster_pb2.TASK_STATE_ASSIGNED, - now_ms, - ), - ) - self._db.execute( - "UPDATE tasks SET current_attempt_id = ?, state = ?, " - "current_worker_id = ?, current_worker_address = ? 
WHERE task_id = ?", - (next_attempt_id, cluster_pb2.TASK_STATE_ASSIGNED, str(worker_id), worker_address, task_id.to_wire()), - ) + with self._db.transaction() as cur: + _assign_task(cur, task_id.to_wire(), str(worker_id), worker_address, next_attempt_id, now_ms) return next_attempt_id diff --git a/lib/iris/tests/cluster/controller/test_preemption.py b/lib/iris/tests/cluster/controller/test_preemption.py index 0ce608f5ce..8d12148ae5 100644 --- a/lib/iris/tests/cluster/controller/test_preemption.py +++ b/lib/iris/tests/cluster/controller/test_preemption.py @@ -3,7 +3,6 @@ """Tests for the preemption loop — higher-priority tasks evict lower-priority running tasks.""" - from iris.cluster.controller.budget import compute_effective_band from iris.cluster.controller.transitions import _resolve_task_failure_state from iris.cluster.controller.controller import ( @@ -20,8 +19,14 @@ from .conftest import ( ControllerTestHarness, + dispatch_task, make_controller_state, + make_test_entrypoint, + make_worker_metadata, + query_attempt, query_task, + register_worker, + submit_job, ) @@ -672,3 +677,121 @@ def test_resolve_failure_building_terminal_when_exhausted(): ) assert new_state == cluster_pb2.TASK_STATE_WORKER_FAILED assert count == 2 + + +# --------------------------------------------------------------------------- +# Integration tests: preempt_task attempt state and coscheduled cascade +# --------------------------------------------------------------------------- + + +def test_preempt_task_retries_when_budget_remains(): + """Preempted running task retries to PENDING with attempt marked PREEMPTED.""" + with make_controller_state() as state: + harness = ControllerTestHarness(state) + w1 = harness.add_worker("w1", cpu=4) + + tasks = harness.submit( + "/alice/batch-job", + cpu=1, + replicas=1, + max_retries_preemption=3, + ) + task = tasks[0] + harness.dispatch(task, w1) + assert query_task(state, task.task_id).state == cluster_pb2.TASK_STATE_RUNNING + + attempt_id_before = 
query_task(state, task.task_id).current_attempt_id + state.preempt_task(task.task_id, reason="Evicted by /bob/prod:0") + + # Task retries to PENDING + updated = query_task(state, task.task_id) + assert updated.state == cluster_pb2.TASK_STATE_PENDING + assert updated.preemption_count == 1 + + # The attempt is marked PREEMPTED even though the task retries + attempt = query_attempt(state, task.task_id, attempt_id_before) + assert attempt is not None + assert attempt.state == cluster_pb2.TASK_STATE_PREEMPTED + + +def test_preempt_task_terminal_when_budget_exhausted(): + """Preempted running task becomes terminal PREEMPTED when budget is spent.""" + with make_controller_state() as state: + harness = ControllerTestHarness(state) + w1 = harness.add_worker("w1", cpu=4) + + tasks = harness.submit( + "/alice/batch-job", + cpu=1, + replicas=1, + max_retries_preemption=0, + ) + task = tasks[0] + harness.dispatch(task, w1) + + result = state.preempt_task(task.task_id, reason="budget gone") + + updated = query_task(state, task.task_id) + assert updated.state == cluster_pb2.TASK_STATE_PREEMPTED + assert updated.preemption_count == 1 + assert updated.finished_at is not None + + # The preempted task is included in tasks_to_kill so the controller + # can send a kill RPC to the worker. + assert task.task_id in result.tasks_to_kill + + # Attempt is also PREEMPTED + attempt = query_attempt(state, task.task_id, updated.current_attempt_id) + assert attempt is not None + assert attempt.state == cluster_pb2.TASK_STATE_PREEMPTED + + +def test_preempt_task_cascades_coscheduled_siblings(): + """When all coscheduled tasks are preempted to terminal, the job finalizes and kills survivors. + + preempt_task does not directly cascade coscheduled siblings (unlike + WORKER_FAILED via heartbeat). Instead, siblings are killed when the job + reaches a terminal state through _finalize_terminal_job. 
+ """ + from iris.cluster.constraints import WellKnownAttribute + + with make_controller_state() as state: + # Register 2 workers with TPU attributes for coscheduling + for i in range(2): + meta = make_worker_metadata() + meta.attributes[WellKnownAttribute.TPU_NAME].string_value = "tpu-a" + meta.attributes[WellKnownAttribute.TPU_WORKER_ID].int_value = i + register_worker(state, f"w{i}", f"addr{i}:8080", meta) + + # Submit a coscheduled job with 2 replicas, no preemption retries + req = cluster_pb2.Controller.LaunchJobRequest( + name="cosched-preempt", + entrypoint=make_test_entrypoint(), + resources=cluster_pb2.ResourceSpecProto(cpu_millicores=1000, memory_bytes=1024**3), + replicas=2, + environment=cluster_pb2.EnvironmentConfig(), + max_retries_preemption=0, + ) + req.coscheduling.group_by = WellKnownAttribute.TPU_NAME + tasks = submit_job(state, "cosched-preempt", req) + assert len(tasks) == 2 + + # Dispatch both tasks + for i, task in enumerate(tasks): + dispatch_task(state, task, WorkerId(f"w{i}")) + + # Preempt the first task — it goes terminal PREEMPTED + result0 = state.preempt_task(tasks[0].task_id, reason="preempted by prod") + assert query_task(state, tasks[0].task_id).state == cluster_pb2.TASK_STATE_PREEMPTED + + # Second task is still running (preempt_task doesn't directly cascade siblings) + assert query_task(state, tasks[1].task_id).state == cluster_pb2.TASK_STATE_RUNNING + + # Preempt the second task — now ALL tasks are terminal, job finalizes + result1 = state.preempt_task(tasks[1].task_id, reason="preempted by prod") + assert query_task(state, tasks[1].task_id).state == cluster_pb2.TASK_STATE_PREEMPTED + + # Both tasks should be in the combined kill set + all_kills = result0.tasks_to_kill | result1.tasks_to_kill + assert tasks[0].task_id in all_kills + assert tasks[1].task_id in all_kills diff --git a/tests/integration/iris/test_kind_gpu_canary.py b/tests/integration/iris/test_iris_kind.py similarity index 100% rename from 
tests/integration/iris/test_kind_gpu_canary.py rename to tests/integration/iris/test_iris_kind.py