50 changes: 46 additions & 4 deletions lib/iris/docs/task-states.md
@@ -55,6 +55,17 @@ independent job state machine.
v
(terminal)

+-----------+ |
| PREEMPTED | |
+-----------+ |
^ |
| exhausted |
| |
+-----------+ |
| preempt |------------+
| (ctrl) |
+-----------+

Other terminal states: KILLED, UNSCHEDULABLE (never retried)
```

@@ -72,6 +83,7 @@ independent job state machine.
| `KILLED` | 6 | Yes | No | Controller: job cancellation (`_on_job_cancelled`), job failure cascade (`_mark_remaining_tasks_killed`), per-task timeout | `killed` (grey) |
| `WORKER_FAILED` | 7 | Yes | Yes | Controller: worker death cascade (`_on_worker_failed`), coscheduled sibling kill | `worker_failed` (purple) |
| `UNSCHEDULABLE` | 8 | Yes | No | Controller: scheduling timeout expired (`_mark_task_unschedulable`) | `unschedulable` (red) |
| `PREEMPTED` | 10 | Yes | Yes | Controller: priority preemption with budget exhausted (`preempt_task`) | `preempted` (orange) |


## State Transitions in Detail
@@ -171,6 +183,31 @@ terminally, `_cascade_coscheduled_failure` exhausts the preemption budget of
all running siblings and transitions them to `WORKER_FAILED` (terminal). This
prevents other hosts from hanging on collective operations.
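
A minimal sketch of that cascade, assuming a simplified `Task` shape and bare state strings (the real Iris types are not shown in this diff); only the budget-exhaust-then-fail behaviour follows the description above:

```python
from dataclasses import dataclass

# Illustrative sketch only: the Task fields and bare state strings are
# assumptions; only the cascade behaviour follows the description above.
@dataclass
class Task:
    state: str
    preemption_count: int = 0
    max_retries_preemption: int = 100

def cascade_coscheduled_failure(tasks: list[Task], failed: Task) -> None:
    for sibling in tasks:
        if sibling is failed or sibling.state != "RUNNING":
            continue
        # Exhaust the preemption budget so retry evaluation cannot requeue
        # the sibling, then mark it terminally WORKER_FAILED.
        sibling.preemption_count = sibling.max_retries_preemption + 1
        sibling.state = "WORKER_FAILED"
```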

### PREEMPTED

Set by the controller when a higher-priority task evicts a lower-priority
running task via `preempt_task`. The preemption loop (`_run_preemption_pass`)
selects victims from lower priority bands and calls `preempt_task` for each.

Retry evaluation uses `_resolve_task_failure_state` with the preemption budget:

1. **ASSIGNED tasks**: always retry to PENDING regardless of budget (the task
never started executing, so preemption is free).
2. **BUILDING or RUNNING tasks**: `preemption_count` is incremented and compared
against `max_retries_preemption`.
- If `preemption_count <= max_retries_preemption`: task is requeued to
`PENDING` for retry. The current attempt is marked `PREEMPTED`.
- If `preemption_count > max_retries_preemption`: task state is set to
`PREEMPTED` (terminal). Both the attempt and the task are `PREEMPTED`.

`PREEMPTED` is in both `TERMINAL_TASK_STATES` and `FAILURE_TASK_STATES`.
When a coscheduled task becomes terminally `PREEMPTED`, the job state is
recomputed. If all tasks in the job are terminal, `_finalize_terminal_job`
kills any remaining non-terminal tasks and cascades to child jobs. Note that
unlike `WORKER_FAILED` reported via heartbeat, `preempt_task` does not
directly cascade coscheduled siblings — the cascade only occurs through job
finalization.
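
The branching described above can be condensed into a short sketch. This is not the actual `_resolve_task_failure_state` implementation; the function name, bare state strings, and return shape are assumptions based on the description:

```python
# Sketch of the preemption branch described above; everything except the
# preemption_count / max_retries_preemption fields is assumed for illustration.
def resolve_preemption(
    task_state: str, preemption_count: int, max_retries_preemption: int
) -> tuple[str, str, int]:
    """Return (attempt_state, new_task_state, new_preemption_count)."""
    if task_state == "ASSIGNED":
        # Never started executing: requeue for free, budget untouched.
        return "PREEMPTED", "PENDING", preemption_count
    # BUILDING or RUNNING: the preemption consumes budget.
    preemption_count += 1
    if preemption_count <= max_retries_preemption:
        return "PREEMPTED", "PENDING", preemption_count  # attempt PREEMPTED, task retried
    return "PREEMPTED", "PREEMPTED", preemption_count  # budget exhausted: terminal
```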

### UNSCHEDULABLE

Set by the controller's scheduling loop when a task's scheduling deadline
@@ -185,10 +222,10 @@ are killed.

Iris maintains two independent retry budgets per task:

| Budget | Counter | Limit Field | Default | Trigger State |
| Budget | Counter | Limit Field | Default | Trigger States |
|---|---|---|---|---|
| Failure | `failure_count` | `max_retries_failure` | 0 (no retries) | `FAILED` |
| Preemption | `preemption_count` | `max_retries_preemption` | 100 | `WORKER_FAILED` |
| Preemption | `preemption_count` | `max_retries_preemption` | 100 | `WORKER_FAILED`, `PREEMPTED` |

### Retry flow

@@ -209,15 +246,18 @@ Iris maintains two independent retry budgets per task:
### What counts toward job failure

Only `TASK_STATE_FAILED` counts toward the job's `max_task_failures` threshold.
Worker failures (preemptions) do not count. This means a job can survive
Worker failures and preemptions do not count. This means a job can survive
unlimited preemptions as long as the per-task preemption budget is not
exhausted.
exhausted. `TASK_STATE_PREEMPTED` and `TASK_STATE_WORKER_FAILED` are grouped
together for job state derivation: if all tasks are terminal and any are in
one of these states, the job becomes `JOB_STATE_WORKER_FAILED`.
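
A hedged sketch of that grouping rule (the function name and bare state strings are assumptions; only the grouping of the two states follows the text above):

```python
# Assumed state-name strings; only the grouping rule follows the text above.
PREEMPTION_LIKE = {"WORKER_FAILED", "PREEMPTED"}

def job_is_worker_failed(task_states: list[str], terminal_states: set[str]) -> bool:
    """True when every task is terminal and at least one ended in a preemption-like state."""
    all_terminal = all(state in terminal_states for state in task_states)
    return all_terminal and any(state in PREEMPTION_LIKE for state in task_states)
```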

### States that are never retried

- `SUCCEEDED`: task completed successfully
- `KILLED`: explicit termination by user or cascade
- `UNSCHEDULABLE`: scheduling timeout expired
- `PREEMPTED`: only terminal when the preemption budget is exhausted (otherwise retried as `PENDING`)

## Terminal State Summary

@@ -230,6 +270,7 @@ A task is considered finished (`is_finished() == True`) when:
| `UNSCHEDULABLE` | Always finished |
| `FAILED` | Finished when `failure_count > max_retries_failure` |
| `WORKER_FAILED` | Finished when `preemption_count > max_retries_preemption` |
| `PREEMPTED` | Finished when `preemption_count > max_retries_preemption` |
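
A sketch of the check implied by this table (assumed signature and bare state strings, not the actual `is_finished` body):

```python
# Assumed signature and state strings; per-state conditions mirror the table above.
def is_finished(
    state: str,
    failure_count: int,
    max_retries_failure: int,
    preemption_count: int,
    max_retries_preemption: int,
) -> bool:
    if state in ("SUCCEEDED", "KILLED", "UNSCHEDULABLE"):
        return True
    if state == "FAILED":
        return failure_count > max_retries_failure
    if state in ("WORKER_FAILED", "PREEMPTED"):
        return preemption_count > max_retries_preemption
    return False  # non-terminal states are never finished
```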

The distinction matters: a task in `FAILED` state with retry budget remaining
is in a terminal state at the attempt level but is not finished at the task
@@ -252,6 +293,7 @@ strings (e.g., `TASK_STATE_RUNNING`) to lowercase display names by stripping the
| `killed` | `.status-killed` | Grey (#57606a) |
| `worker_failed` | `.status-worker_failed` | Purple (#8250df) |
| `unschedulable` | `.status-unschedulable` | Red (#cf222e) |
| `preempted` | `.status-preempted` | Orange (#bc4c00) |
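
The mapping can be sketched in a couple of lines (deriving the CSS class from the display name is an assumption based on the table above, not confirmed by the source):

```python
# Sketch of the prefix-strip-and-lowercase mapping described above; deriving
# the CSS class from the display name is an assumption based on the table.
def display_name(state_name: str) -> str:
    return state_name.removeprefix("TASK_STATE_").lower()  # "TASK_STATE_PREEMPTED" -> "preempted"

def css_class(state_name: str) -> str:
    return f"status-{display_name(state_name)}"  # -> "status-preempted"
```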

The job detail page shows per-task attempt history. Each attempt has its own
state badge, and worker failures are annotated with "(worker failure)" in the
174 changes: 86 additions & 88 deletions lib/iris/src/iris/cli/cluster.py
@@ -20,9 +20,9 @@
find_marin_root,
get_git_sha,
)
from iris.cli.main import require_controller_url
from iris.cli.main import require_controller_url, rpc_client
from iris.cluster.config import IrisConfig, clear_remote_state, make_local_config
from iris.rpc import cluster_connect, cluster_pb2, vm_pb2
from iris.rpc import cluster_pb2, vm_pb2
from iris.rpc.proto_utils import format_accelerator_display, vm_state_name
from iris.time_proto import timestamp_from_proto
from rigging.timing import Duration, ExponentialBackoff, Timestamp
@@ -56,15 +56,15 @@ def _format_status_table(status: vm_pb2.AutoscalerStatus) -> str:


def _get_autoscaler_status(controller_url: str) -> vm_pb2.AutoscalerStatus:
client = cluster_connect.ControllerServiceClientSync(controller_url)
request = cluster_pb2.Controller.GetAutoscalerStatusRequest()
return client.get_autoscaler_status(request).status
with rpc_client(controller_url) as client:
request = cluster_pb2.Controller.GetAutoscalerStatusRequest()
return client.get_autoscaler_status(request).status


def _get_worker_status(controller_url: str, worker_id: str) -> cluster_pb2.Controller.GetWorkerStatusResponse:
client = cluster_connect.ControllerServiceClientSync(controller_url)
request = cluster_pb2.Controller.GetWorkerStatusRequest(id=worker_id)
return client.get_worker_status(request)
with rpc_client(controller_url) as client:
request = cluster_pb2.Controller.GetWorkerStatusRequest(id=worker_id)
return client.get_worker_status(request)


def _parse_ghcr_tag(image_tag: str) -> tuple[str, str, str] | None:
@@ -322,20 +322,20 @@ def cluster_start_smoke(ctx, label_prefix, url_file, min_workers, worker_timeout
with bundle.controller.tunnel(address) as url:
click.echo(f"Tunnel ready: {url}")

client = cluster_connect.ControllerServiceClientSync(url, timeout_ms=30000)
deadline = time.monotonic() + worker_timeout
healthy_count = 0
while time.monotonic() < deadline:
workers = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()).workers
healthy = [w for w in workers if w.healthy]
healthy_count = len(healthy)
if healthy_count >= min_workers:
break
time.sleep(2)
else:
raise click.ClickException(
f"Only {healthy_count} of {min_workers} workers healthy after {worker_timeout}s"
)
with rpc_client(url) as client:
deadline = time.monotonic() + worker_timeout
healthy_count = 0
while time.monotonic() < deadline:
workers = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()).workers
healthy = [w for w in workers if w.healthy]
healthy_count = len(healthy)
if healthy_count >= min_workers:
break
time.sleep(2)
else:
raise click.ClickException(
f"Only {healthy_count} of {min_workers} workers healthy after {worker_timeout}s"
)

click.echo(f"{healthy_count} workers ready, writing URL to {url_file}")
Path(url_file).write_text(url)
@@ -398,10 +398,10 @@ def cluster_status_cmd(ctx):
controller_url = require_controller_url(ctx)
click.echo("Checking controller status...")
try:
client = cluster_connect.ControllerServiceClientSync(controller_url)
proc = client.get_process_status(cluster_pb2.GetProcessStatusRequest()).process_info
workers = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()).workers
as_status = client.get_autoscaler_status(cluster_pb2.Controller.GetAutoscalerStatusRequest()).status
with rpc_client(controller_url) as client:
proc = client.get_process_status(cluster_pb2.GetProcessStatusRequest()).process_info
workers = client.list_workers(cluster_pb2.Controller.ListWorkersRequest()).workers
as_status = client.get_autoscaler_status(cluster_pb2.Controller.GetAutoscalerStatusRequest()).status
healthy = sum(1 for w in workers if w.healthy)
click.echo("Controller Status:")
click.echo(" Running: True")
@@ -629,12 +629,12 @@ def controller_checkpoint(ctx, stop: bool):
briefly and writes a consistent checkpoint DB copy.
"""
controller_url = require_controller_url(ctx)
client = cluster_connect.ControllerServiceClientSync(controller_url)
try:
resp = client.begin_checkpoint(cluster_pb2.Controller.BeginCheckpointRequest(), timeout_ms=60_000)
except Exception as e:
click.echo(f"Checkpoint failed: {e}", err=True)
raise SystemExit(1) from e
with rpc_client(controller_url) as client:
try:
resp = client.begin_checkpoint(cluster_pb2.Controller.BeginCheckpointRequest(), timeout_ms=60_000)
except Exception as e:
click.echo(f"Checkpoint failed: {e}", err=True)
raise SystemExit(1) from e

click.echo(f"Checkpoint DB written: {resp.checkpoint_path}")
click.echo(f" Jobs: {resp.job_count}")
@@ -717,17 +717,15 @@ def controller_restart(ctx, skip_checkpoint: bool, checkpoint_timeout: int):
click.echo("Skipping pre-restart checkpoint.")
else:
click.echo(f"Taking checkpoint (timeout {checkpoint_timeout}s)...")
client = cluster_connect.ControllerServiceClientSync(controller_url)
try:
resp = client.begin_checkpoint(
cluster_pb2.Controller.BeginCheckpointRequest(),
timeout_ms=checkpoint_timeout * 1000,
)
except Exception as e:
click.echo(f"Checkpoint failed: {e}", err=True)
raise SystemExit(1) from e
finally:
client.close()
with rpc_client(controller_url) as client:
try:
resp = client.begin_checkpoint(
cluster_pb2.Controller.BeginCheckpointRequest(),
timeout_ms=checkpoint_timeout * 1000,
)
except Exception as e:
click.echo(f"Checkpoint failed: {e}", err=True)
raise SystemExit(1) from e
click.echo(f"Checkpoint: {resp.checkpoint_path} ({resp.job_count} jobs, {resp.worker_count} workers)")

# Build fresh images so the new controller VM gets the latest code
@@ -759,59 +757,59 @@ def worker_restart(ctx, worker_id: str | None, timeout: int):
new worker process.
"""
controller_url = require_controller_url(ctx)
client = cluster_connect.ControllerServiceClientSync(controller_url)

# Get current workers
workers_resp = client.list_workers(cluster_pb2.Controller.ListWorkersRequest())
workers = workers_resp.workers

if worker_id:
workers = [w for w in workers if w.worker_id == worker_id]
if not workers:
click.echo(f"Worker {worker_id} not found", err=True)
raise SystemExit(1)
with rpc_client(controller_url) as client:
# Get current workers
workers_resp = client.list_workers(cluster_pb2.Controller.ListWorkersRequest())
workers = workers_resp.workers

if not workers:
click.echo("No workers to restart")
return
if worker_id:
workers = [w for w in workers if w.worker_id == worker_id]
if not workers:
click.echo(f"Worker {worker_id} not found", err=True)
raise SystemExit(1)

click.echo(f"Restarting {len(workers)} worker(s) (timeout={timeout}s per worker)")
if not workers:
click.echo("No workers to restart")
return

succeeded = 0
failed = 0
click.echo(f"Restarting {len(workers)} worker(s) (timeout={timeout}s per worker)")

for worker in workers:
wid = worker.worker_id
click.echo(f"\nRestarting worker {wid}...")
succeeded = 0
failed = 0

resp = client.restart_worker(
cluster_pb2.Controller.RestartWorkerRequest(worker_id=wid),
timeout_ms=timeout * 1000,
)
for worker in workers:
wid = worker.worker_id
click.echo(f"\nRestarting worker {wid}...")

if not resp.accepted:
click.echo(f" Failed: {resp.error}", err=True)
failed += 1
continue
resp = client.restart_worker(
cluster_pb2.Controller.RestartWorkerRequest(worker_id=wid),
timeout_ms=timeout * 1000,
)

# Poll until the worker re-registers as healthy
def _worker_healthy(target_id: str = wid) -> bool:
try:
resp = client.list_workers(cluster_pb2.Controller.ListWorkersRequest())
return any(w.worker_id == target_id and w.healthy for w in resp.workers)
except Exception:
return False

reregistered = ExponentialBackoff(initial=5.0, maximum=5.0, jitter=0.0).wait_until(
_worker_healthy,
timeout=Duration.from_seconds(timeout),
)
if not resp.accepted:
click.echo(f" Failed: {resp.error}", err=True)
failed += 1
continue

# Poll until the worker re-registers as healthy
def _worker_healthy(target_id: str = wid) -> bool:
try:
resp = client.list_workers(cluster_pb2.Controller.ListWorkersRequest())
return any(w.worker_id == target_id and w.healthy for w in resp.workers)
except Exception:
return False

reregistered = ExponentialBackoff(initial=5.0, maximum=5.0, jitter=0.0).wait_until(
_worker_healthy,
timeout=Duration.from_seconds(timeout),
)

if reregistered:
click.echo(f" Worker {wid} restarted successfully")
succeeded += 1
else:
click.echo(f" Worker {wid} did not re-register within {timeout}s", err=True)
failed += 1
if reregistered:
click.echo(f" Worker {wid} restarted successfully")
succeeded += 1
else:
click.echo(f" Worker {wid} did not re-register within {timeout}s", err=True)
failed += 1

click.echo(f"\nDone: {succeeded} succeeded, {failed} failed")