Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
489 changes: 270 additions & 219 deletions lib/iris/src/iris/cluster/controller/controller.py

Large diffs are not rendered by default.

23 changes: 3 additions & 20 deletions lib/iris/src/iris/cluster/controller/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Protocol

from iris.cluster.controller.transitions import DispatchBatch, HeartbeatApplyRequest
from iris.cluster.types import WorkerId
from iris.rpc import job_pb2

Expand All @@ -21,30 +20,14 @@ class ProviderUnsupportedError(ProviderError):
class TaskProvider(Protocol):
"""Abstraction over a task execution backend.

The controller calls sync() in a loop. The provider is responsible for
submitting/cancelling tasks and collecting their state. It returns
HeartbeatApplyRequest batches which the controller applies via
ControllerTransitions.apply_heartbeat().
The WorkerProvider communicates with workers via focused RPCs (Ping,
StartTasks, StopTasks, PollTasks). The controller orchestrates these
calls through dedicated loops.

Logs are pushed directly to the LogService by workers/tasks, not carried
via heartbeats or fetched from the provider.
"""

def sync(
    self,
    batches: list[DispatchBatch],
) -> list[tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]]:
    """Sync task state with the execution backend.

    This is a Protocol method stub; the body is intentionally ``...`` and
    concrete providers supply the implementation.

    Args:
        batches: One DispatchBatch per active execution unit, drained from the DB.

    Returns:
        One tuple per batch, shaped (batch, apply_request | None, error_str | None).
        apply_request is None on communication failure (caller uses fail_heartbeat).
    """
    ...

def get_process_status(
self,
worker_id: WorkerId,
Expand Down
51 changes: 50 additions & 1 deletion lib/iris/src/iris/cluster/controller/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from iris.cluster.controller.query import execute_raw_query
from iris.rpc import query_pb2
from iris.cluster.controller.scheduler import SchedulingContext
from iris.cluster.controller.transitions import ControllerTransitions
from iris.cluster.controller.transitions import ControllerTransitions, HeartbeatApplyRequest, TaskUpdate
from iris.cluster.controller.provider import ProviderError
from iris.cluster.log_store import build_log_source, worker_log_key
from iris.cluster.process_status import get_process_status
Expand Down Expand Up @@ -2319,3 +2319,52 @@ def get_scheduler_state(
total_pending=total_pending,
total_running=len(running_protos),
)

# --- Worker Push ---

def update_task_status(
    self,
    request: controller_pb2.Controller.UpdateTaskStatusRequest,
    _ctx: Any,
) -> controller_pb2.Controller.UpdateTaskStatusResponse:
    """Apply task state transitions pushed by a worker.

    Each entry in ``request.updates`` is converted into a ``TaskUpdate`` and the
    resulting batch is handed to ``self._transitions.apply_heartbeat()`` wrapped
    in a ``HeartbeatApplyRequest`` (the same apply path the poll-based heartbeat
    uses). Entries whose state is UNSPECIFIED or PENDING are dropped.

    Args:
        request: Worker id plus the list of per-task status entries.
        _ctx: RPC context; unused here.

    Returns:
        A response whose ``tasks_to_stop`` lists the wire ids of tasks the
        controller wants the worker to stop. It is empty when no update
        survived filtering or when the transition result names no tasks to kill.
    """
    skipped_states = (job_pb2.TASK_STATE_UNSPECIFIED, job_pb2.TASK_STATE_PENDING)

    # Empty/zero proto fields are normalised to None on the TaskUpdate.
    # NOTE(review): a reported exit_code of 0 therefore also becomes None;
    # presumably intentional because the scalar has no presence bit - confirm.
    task_updates: list[TaskUpdate] = [
        TaskUpdate(
            task_id=JobName.from_wire(item.task_id),
            attempt_id=item.attempt_id,
            new_state=item.state,
            error=item.error or None,
            exit_code=None if item.exit_code == 0 else item.exit_code,
            resource_usage=None if item.resource_usage.ByteSize() <= 0 else item.resource_usage,
            container_id=item.container_id or None,
        )
        for item in request.updates
        if item.state not in skipped_states
    ]

    # Nothing to apply: answer immediately without touching transitions
    # or waking the controller.
    if not task_updates:
        return controller_pb2.Controller.UpdateTaskStatusResponse(
            tasks_to_stop=[],
        )

    heartbeat = HeartbeatApplyRequest(
        worker_id=WorkerId(request.worker_id),
        worker_resource_snapshot=None,
        updates=task_updates,
    )
    outcome = self._transitions.apply_heartbeat(heartbeat)

    stop_ids: list[str] = []
    if outcome.tasks_to_kill:
        stop_ids = [task.to_wire() for task in outcome.tasks_to_kill]

    # Wake the controller so it can act on any state changes promptly
    self._controller.wake()

    return controller_pb2.Controller.UpdateTaskStatusResponse(
        tasks_to_stop=stop_ids,
    )
Loading