Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions lib/iris/src/iris/cluster/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,13 @@ def __init__(

self._host_metrics = HostMetricsCollector(disk_path=str(self._cache_dir))

# LogPusher and RemoteLogHandler are created after registration, once
# the worker can resolve /system/log-server via ListEndpoints.
# LogPusher and RemoteLogHandler are created before registration so
# pre-register failures (container bring-up, disk/health probes,
# registration rejection) leave remote logs. Attachment relies on
# ``self._worker_id`` having been resolved locally (IRIS_WORKER_ID,
# slice_id + TPU index, or GCE instance name); the rare case where
# the controller assigns the id is handled by re-attaching post-
# register.
self._log_pusher: LogPusher | None = None
self._log_handler: RemoteLogHandler | None = None

Expand Down Expand Up @@ -379,9 +384,14 @@ def _run_lifecycle(self, stop_event: threading.Event) -> None:

This loop runs continuously until shutdown. On each iteration:
1. Reset worker state (kill all containers)
2. Register with controller (retry until accepted)
3. Serve (wait for heartbeats from controller)
4. If heartbeat timeout expires, return to step 1
2. Attach the remote log handler so pre-registration log lines ship
to the central log server (and a refreshed /system/log-server
endpoint is picked up after any log-server failover)
3. Register with controller (retry until accepted)
4. If the controller assigned a worker_id we didn't know locally,
re-attach the handler under the canonical key
5. Serve (wait for heartbeats from controller)
6. If heartbeat timeout expires, return to step 1

On the first iteration after a restart with adopted containers,
step 1 is skipped to preserve the running tasks.
Expand All @@ -396,12 +406,14 @@ def _run_lifecycle(self, stop_event: threading.Event) -> None:
else:
self._reset_worker_state()
first_iteration = False
self._attach_log_handler()
worker_id = self._register(stop_event)
if worker_id is None:
# Shutdown requested during registration
break
self._worker_id = worker_id
self._attach_log_handler()
if worker_id != self._worker_id:
self._worker_id = worker_id
self._attach_log_handler()
self._serve(stop_event)
except Exception:
logger.exception("Worker lifecycle crashed")
Expand Down Expand Up @@ -449,15 +461,25 @@ def _register(self, stop_event: threading.Event) -> str | None:
return None

def _resolve_log_service(self) -> str | None:
"""Resolve the LogService address via the /system/log-server endpoint."""
"""Resolve the LogService address via the /system/log-server endpoint.

Called before registration, so the controller may not yet be reachable.
Treats RPC errors and missing endpoints the same: log a warning and
return None so the caller can skip remote log attachment without
crashing the lifecycle thread.
"""
if not self._controller_client:
return None
resp = self._controller_client.list_endpoints(
controller_pb2.Controller.ListEndpointsRequest(
prefix="/system/log-server",
exact=True,
),
)
try:
resp = self._controller_client.list_endpoints(
controller_pb2.Controller.ListEndpointsRequest(
prefix="/system/log-server",
exact=True,
),
)
except Exception as e:
logger.warning("Failed to resolve /system/log-server: %s", e)
return None
if not resp.endpoints:
logger.warning("No /system/log-server endpoint registered on controller")
return None
Expand All @@ -466,7 +488,16 @@ def _resolve_log_service(self) -> str | None:
return addr

def _attach_log_handler(self) -> None:
"""Create LogPusher and attach RemoteLogHandler after registration."""
"""Create LogPusher and attach RemoteLogHandler under ``worker_log_key``.

Always tears down any existing handler first so each lifecycle cycle
re-resolves /system/log-server (picking up log-server failover) and
rebuilds the LogPusher against the fresh address.

Skipped when ``self._worker_id`` is not yet known locally — in that
(rare) case the controller will assign an id during ``_register`` and
the lifecycle loop re-calls this method with the canonical id.
"""
self._detach_log_handler()
if not self._worker_id:
return
Expand Down
118 changes: 118 additions & 0 deletions lib/iris/tests/cluster/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import socket
import time
import zipfile
from typing import ClassVar
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -583,6 +584,123 @@ def test_port_binding_failure(mock_bundle_store, tmp_path):
assert "address already in use" in final_task.error


# ============================================================================
# Remote log handler attach tests (regression for #4794)
# ============================================================================


def _log_server_endpoints(address: str):
    """Build a ListEndpointsResponse advertising *address* as /system/log-server."""
    from iris.rpc import controller_pb2

    endpoint = controller_pb2.Controller.Endpoint(
        endpoint_id="/system/log-server",
        name="/system/log-server",
        address=address,
    )
    return controller_pb2.Controller.ListEndpointsResponse(endpoints=[endpoint])


class _RecordingPusher:
"""Records server_url so tests can observe LogPusher re-creation."""

instances: ClassVar[list["_RecordingPusher"]] = []

def __init__(self, server_url, **_kwargs):
self.server_url = server_url
_RecordingPusher.instances.append(self)

def push(self, key, entries):
pass

def flush(self):
pass

def close(self):
pass


@pytest.fixture
def recording_log_pusher(monkeypatch):
    """Patch the worker module's LogPusher with ``_RecordingPusher``.

    Resets the recorder's construction log before each test and yields the
    recorder class so tests can inspect ``instances``; ``monkeypatch``
    restores the real LogPusher automatically on teardown.
    """
    from iris.cluster.worker import worker as worker_module

    monkeypatch.setattr(worker_module, "LogPusher", _RecordingPusher)
    _RecordingPusher.instances = []
    yield _RecordingPusher


def test_attach_log_handler_uses_worker_log_key_before_register(
    mock_bundle_store, mock_runtime, tmp_path, recording_log_pusher
):
    """A worker whose id is known locally (e.g. via slice_id) attaches the
    remote handler under worker_log_key *before* registering, so failures
    that happen pre-register still ship remote logs."""
    from iris.cluster.log_store import worker_log_key

    config = WorkerConfig(
        port=0,
        port_range=(50000, 50100),
        cache_dir=tmp_path / "cache",
        default_task_image="mock-image",
        worker_id="w-1",
    )
    worker = Worker(config, bundle_store=mock_bundle_store, container_runtime=mock_runtime)
    list_endpoints = Mock(return_value=_log_server_endpoints("http://log:9000"))
    worker._controller_client = Mock(list_endpoints=list_endpoints)

    try:
        worker._attach_log_handler()
        handler = worker._log_handler
        assert handler is not None
        assert handler.key == worker_log_key("w-1")
    finally:
        worker._detach_log_handler()


def test_attach_log_handler_tolerates_resolve_failure(mock_bundle_store, mock_runtime, tmp_path, recording_log_pusher):
    """A ListEndpoints RPC failure must not crash the lifecycle thread."""
    config = WorkerConfig(
        port=0,
        port_range=(50000, 50100),
        cache_dir=tmp_path / "cache",
        default_task_image="mock-image",
        worker_id="w-1",
    )
    worker = Worker(config, bundle_store=mock_bundle_store, container_runtime=mock_runtime)
    worker._controller_client = Mock(list_endpoints=Mock(side_effect=ConnectionError("controller down")))

    try:
        # _attach_log_handler must swallow the RPC error and leave the worker
        # without a remote handler rather than propagating the exception.
        worker._attach_log_handler()
        assert worker._log_handler is None
        assert worker._log_pusher is None
        assert recording_log_pusher.instances == []
    finally:
        # Mirror the sibling tests' teardown: if attach ever (incorrectly)
        # succeeds, an assertion failure above must not leak an attached
        # handler into global logging state for later tests.
        worker._detach_log_handler()


def test_attach_log_handler_rebuilds_pusher_on_reattach(mock_bundle_store, mock_runtime, tmp_path, recording_log_pusher):
    """Repeated attach must tear down the old LogPusher so log-server failover
    is picked up — protects against the regression Codex flagged."""
    config = WorkerConfig(
        port=0,
        port_range=(50000, 50100),
        cache_dir=tmp_path / "cache",
        default_task_image="mock-image",
        worker_id="w-1",
    )
    worker = Worker(config, bundle_store=mock_bundle_store, container_runtime=mock_runtime)
    addresses = ["http://log-a:9000", "http://log-b:9000"]
    address_iter = iter(addresses)
    worker._controller_client = Mock(
        list_endpoints=Mock(side_effect=lambda _req: _log_server_endpoints(next(address_iter)))
    )

    try:
        for _ in addresses:
            worker._attach_log_handler()
        observed = [pusher.server_url for pusher in recording_log_pusher.instances]
        assert observed == addresses
        assert worker._log_pusher is recording_log_pusher.instances[-1]
    finally:
        worker._detach_log_handler()


# ============================================================================
# Integration Tests (with real Docker)
# ============================================================================
Expand Down
Loading