Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions lib/iris/src/iris/cluster/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
from iris.cluster.controller.dashboard import ControllerDashboard
from iris.cluster.providers.k8s.tasks import K8sTaskProvider
from iris.cluster.controller.provider import TaskProvider
from iris.cluster.controller.worker_provider import WorkerProvider
from iris.cluster.controller.scheduler import (
JobRequirements,
Scheduler,
Expand Down Expand Up @@ -1058,10 +1057,10 @@ def __init__(
log_client_interceptors = _log_client_interceptors(config)
self._remote_log_service = LogServiceProxy(self._log_service_address, interceptors=log_client_interceptors)

# Providers push directly to the log server via RPC.
provider_log_pusher = LogPusher(self._log_service_address, interceptors=log_client_interceptors)
if isinstance(self._provider, (K8sTaskProvider, WorkerProvider)):
self._provider.log_pusher = provider_log_pusher
# Providers that collect logs outside the worker process push directly
# to the log server via RPC.
if isinstance(self._provider, K8sTaskProvider):
self._provider.log_pusher = LogPusher(self._log_service_address, interceptors=log_client_interceptors)

# Controller process logs ship to the log server via RemoteLogHandler.
self._log_pusher = LogPusher(self._log_service_address, interceptors=log_client_interceptors)
Expand Down Expand Up @@ -2131,12 +2130,20 @@ def _reap_stale_workers(self) -> None:
stale = [w for w in workers if w.last_heartbeat.age_ms() > threshold_ms]
if not stale:
return
stale_samples = [
{
"worker_id": str(w.worker_id),
"age_s": int(w.last_heartbeat.age_ms() / 1000),
"address": w.address,
}
for w in stale[:10]
]

logger.warning(
"Failing %d workers with stale heartbeats (threshold=%ds): %s",
len(stale),
HEARTBEAT_STALENESS_THRESHOLD.to_seconds(),
[str(w.worker_id) for w in stale[:10]],
stale_samples,
)
failure_result = self._transitions.fail_workers_batch(
[str(w.worker_id) for w in stale],
Expand Down Expand Up @@ -2231,7 +2238,25 @@ def _handle_failed_heartbeats(

primary_failed_workers: list[str] = []
for (batch, error), result in zip(failure_entries, failure_result.results, strict=False):
logger.debug("Sync error for %s: %s", batch.worker_id, error)
last_success_age_s = (
"unknown" if result.last_heartbeat_age_ms is None else f"{result.last_heartbeat_age_ms / 1000.0:.1f}"
)
log_level = logging.ERROR if result.action == HeartbeatAction.WORKER_FAILED else logging.WARNING
logger.log(
log_level,
"Heartbeat RPC failure: worker=%s address=%s action=%s failures=%d/%d last_success_age_s=%s "
"expected=%d run=%d kill=%d error=%s",
batch.worker_id,
batch.worker_address or "<missing>",
result.action.value,
result.consecutive_failures,
result.failure_threshold,
last_success_age_s,
len(batch.running_tasks),
len(batch.tasks_to_run),
len(batch.tasks_to_kill),
error,
)
if result.action == HeartbeatAction.WORKER_FAILED:
acc.fail_count += 1
acc.failed_workers.append(batch.worker_id)
Expand Down
18 changes: 15 additions & 3 deletions lib/iris/src/iris/cluster/controller/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ class HeartbeatApplyResult(TxResult):
class HeartbeatFailureResult(TxResult):
    """Outcome of recording one heartbeat RPC failure against a worker row."""
    # True when the worker row is gone: either it was already inactive/absent
    # when the failure was recorded, or this failure crossed the threshold.
    worker_removed: bool = False
    # Escalation decision: WORKER_FAILED once removed, otherwise a transient blip.
    action: HeartbeatAction = HeartbeatAction.TRANSIENT_FAILURE
    # Failure streak length including the failure just recorded.
    consecutive_failures: int = 0
    # Streak length at which the worker is marked unhealthy/failed.
    failure_threshold: int = HEARTBEAT_FAILURE_THRESHOLD
    # Milliseconds since the last successful heartbeat; None when the worker
    # has no recorded heartbeat (or its row was missing).
    last_heartbeat_age_ms: int | None = None


@dataclass(frozen=True)
Expand Down Expand Up @@ -2121,20 +2124,26 @@ def _record_heartbeat_failure(
tasks_to_kill: set[JobName] = set()
task_kill_workers: dict[JobName, WorkerId] = {}
row = cur.execute(
"SELECT consecutive_failures FROM workers WHERE worker_id = ? AND active = 1",
"SELECT consecutive_failures, last_heartbeat_ms FROM workers WHERE worker_id = ? AND active = 1",
(str(worker_id),),
).fetchone()
if row is None:
return HeartbeatFailureResult(worker_removed=True, action=HeartbeatAction.WORKER_FAILED)
return HeartbeatFailureResult(
worker_removed=True,
action=HeartbeatAction.WORKER_FAILED,
failure_threshold=self._heartbeat_failure_threshold,
)

now_ms = now_ms or Timestamp.now().epoch_ms()
last_heartbeat_ms = row["last_heartbeat_ms"]
last_heartbeat_age_ms = None if last_heartbeat_ms is None else max(0, now_ms - int(last_heartbeat_ms))
failures = int(row["consecutive_failures"]) + 1
cur.execute(
"UPDATE workers SET consecutive_failures = ?, healthy = CASE WHEN ? >= ? THEN 0 ELSE healthy END "
"WHERE worker_id = ?",
(failures, failures, self._heartbeat_failure_threshold, str(worker_id)),
)
should_remove = force_remove or failures >= self._heartbeat_failure_threshold
now_ms = now_ms or Timestamp.now().epoch_ms()
if should_remove:
removal = self._remove_failed_worker(cur, worker_id, error, now_ms=now_ms)
tasks_to_kill.update(removal.tasks_to_kill)
Expand All @@ -2150,6 +2159,9 @@ def _record_heartbeat_failure(
task_kill_workers=task_kill_workers,
worker_removed=should_remove,
action=action,
consecutive_failures=failures,
failure_threshold=self._heartbeat_failure_threshold,
last_heartbeat_age_ms=last_heartbeat_age_ms,
)

def record_heartbeat_failure(
Expand Down
67 changes: 46 additions & 21 deletions lib/iris/src/iris/cluster/controller/worker_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from time import sleep
from time import monotonic, sleep
from typing import Protocol

from iris.chaos import chaos
Expand All @@ -17,7 +17,6 @@
HeartbeatApplyRequest,
TaskUpdate,
)
from iris.cluster.log_store._types import LogPusherProtocol, TaskAttempt, task_log_key
from iris.cluster.types import JobName, WorkerId
from iris.rpc import job_pb2
from iris.rpc import worker_pb2
Expand All @@ -26,6 +25,23 @@

logger = logging.getLogger(__name__)

# Default deadline for controller-to-worker RPCs (used by RpcWorkerStubFactory).
DEFAULT_WORKER_RPC_TIMEOUT = Duration.from_seconds(30.0)
# Heartbeat RPCs that succeed but take at least this long are logged as warnings.
_SLOW_HEARTBEAT_RPC_LOG_THRESHOLD_MS = 10_000


def _heartbeat_rpc_context(
batch: DispatchBatch,
*,
elapsed_ms: int,
timeout_ms: int | None,
) -> str:
timeout_fragment = f" timeout_ms={timeout_ms}" if timeout_ms is not None else ""
return (
f"worker={batch.worker_id} address={batch.worker_address or '<missing>'}"
f" elapsed_ms={elapsed_ms}{timeout_fragment}"
f" expected={len(batch.running_tasks)} run={len(batch.tasks_to_run)} kill={len(batch.tasks_to_kill)}"
)


class WorkerStubFactory(Protocol):
"""Factory for getting cached worker RPC stubs."""
Expand All @@ -39,11 +55,15 @@ class RpcWorkerStubFactory:
"""Caches WorkerServiceClientSync stubs by address so each worker gets
one persistent httpx.Client instead of a new one per RPC."""

def __init__(self, timeout: Duration = Duration.from_seconds(5.0)) -> None:
def __init__(self, timeout: Duration = DEFAULT_WORKER_RPC_TIMEOUT) -> None:
    """Initialize an empty stub cache; stubs are created lazily per address."""
    # Lock serializes cache access; stubs may be fetched from many threads.
    self._lock = threading.Lock()
    # One persistent client per worker address, keyed by address string.
    self._stubs: dict[str, WorkerServiceClientSync] = {}
    self._timeout = timeout

@property
def timeout_ms(self) -> int:
    """Per-RPC timeout for worker stubs, expressed in milliseconds."""
    millis = self._timeout.to_ms()
    return millis

def get_stub(self, address: str) -> WorkerServiceClientSync:
with self._lock:
stub = self._stubs.get(address)
Expand Down Expand Up @@ -105,7 +125,6 @@ class WorkerProvider:
"""

stub_factory: WorkerStubFactory
log_pusher: LogPusherProtocol | None = None
parallelism: int = 32
_pool: ThreadPoolExecutor = field(init=False)

Expand All @@ -131,6 +150,9 @@ def sync(

def _heartbeat_one(self, batch: DispatchBatch) -> HeartbeatApplyRequest:
"""Send heartbeat RPC to one worker and return the apply request."""
started = monotonic()
timeout_ms = getattr(self.stub_factory, "timeout_ms", None)

if rule := chaos("controller.heartbeat"):
sleep(rule.delay_seconds)
raise ProviderError("chaos: heartbeat unavailable")
Expand All @@ -155,23 +177,26 @@ def _heartbeat_one(self, batch: DispatchBatch) -> HeartbeatApplyRequest:
tasks_to_kill=batch.tasks_to_kill,
expected_tasks=expected_tasks,
)
response = stub.heartbeat(request)

if not response.worker_healthy:
health_error = response.health_error or "worker reported unhealthy"
raise ProviderError(f"worker {batch.worker_id} reported unhealthy: {health_error}")

# Forward log entries from old workers that still piggyback logs on
# heartbeat responses. New workers push logs directly via LogPusher.
if self.log_pusher:
for entry in response.tasks:
if entry.log_entries:
key = task_log_key(
TaskAttempt(task_id=JobName.from_wire(entry.task_id), attempt_id=entry.attempt_id)
)
self.log_pusher.push(key, list(entry.log_entries))

return _apply_request_from_response(batch.worker_id, response)
try:
response = stub.heartbeat(request)

if not response.worker_healthy:
health_error = response.health_error or "worker reported unhealthy"
raise ProviderError(f"worker {batch.worker_id} reported unhealthy: {health_error}")

elapsed_ms = int((monotonic() - started) * 1000)
if elapsed_ms >= _SLOW_HEARTBEAT_RPC_LOG_THRESHOLD_MS:
logger.warning(
"Slow heartbeat RPC succeeded: %s",
_heartbeat_rpc_context(batch, elapsed_ms=elapsed_ms, timeout_ms=timeout_ms),
)
return _apply_request_from_response(batch.worker_id, response)
except Exception as e:
elapsed_ms = int((monotonic() - started) * 1000)
context = _heartbeat_rpc_context(batch, elapsed_ms=elapsed_ms, timeout_ms=timeout_ms)
if isinstance(e, ProviderError):
raise ProviderError(f"{e}; {context}") from e
raise ProviderError(f"heartbeat RPC failed: {context}; error={e}") from e

def get_process_status(
self,
Expand Down
Loading
Loading