Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/iris/src/iris/cluster/controller/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def endpoint_query_sql(query: EndpointQuery) -> tuple[str, list[object]]:
sql = from_clause
if conditions:
sql += " WHERE " + " AND ".join(conditions)
sql += " ORDER BY e.registered_at_ms DESC, e.endpoint_id ASC"
if query.limit is not None:
sql += " LIMIT ?"
params.append(query.limit)
Expand Down
13 changes: 11 additions & 2 deletions lib/iris/src/iris/cluster/controller/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,7 +1410,13 @@ def queue_assignments(self, assignments: list[Assignment]) -> AssignmentResult:
continue
job_cache[job_id_wire] = decoded_job
job = job_cache[job_id_wire]
attempt_id = int(task_row["current_attempt_id"]) + 1
current_attempt_id = int(task_row["current_attempt_id"])
attempt_id = current_attempt_id + 1
if current_attempt_id >= 0:
# Clear endpoints from the previous attempt before
# launching a retry so new peers cannot resolve a stale
# coordinator.
cur.execute("DELETE FROM endpoints WHERE task_id = ?", (assignment.task_id.to_wire(),))
_assign_task(
cur,
assignment.task_id.to_wire(),
Expand Down Expand Up @@ -2992,11 +2998,14 @@ def drain_for_direct_provider(

for row in pending_rows:
task_id = str(row["task_id"])
attempt_id = int(row["current_attempt_id"]) + 1
current_attempt_id = int(row["current_attempt_id"])
attempt_id = current_attempt_id + 1
job_req = controller_pb2.Controller.LaunchJobRequest()
job_req.ParseFromString(row["request_proto"])
resources = job_req.resources

if current_attempt_id >= 0:
cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id,))
_assign_task(cur, task_id, None, None, attempt_id, now_ms)

run_req = job_pb2.RunTaskRequest(
Expand Down
54 changes: 52 additions & 2 deletions lib/iris/tests/cluster/controller/test_transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def _queued_dispatch(

def _endpoints(state: ControllerTransitions, query: EndpointQuery = EndpointQuery()) -> list[EndpointRow]:
    """Return the endpoint rows matching *query* from the controller DB.

    Ordering is supplied by ``endpoint_query_sql`` itself (``registered_at_ms``
    DESC, ``endpoint_id`` ASC), so no extra ``ORDER BY`` may be appended here:
    a second ``ORDER BY`` — placed after any ``LIMIT ?`` the builder emits —
    would be invalid SQL.
    """
    sql, params = endpoint_query_sql(query)
    with state._db.snapshot() as q:
        return ENDPOINT_PROJECTION.decode(q.fetchall(sql, tuple(params)))

Expand Down Expand Up @@ -706,6 +704,58 @@ def test_endpoint_deleted_on_worker_failure(state):
assert len(_endpoints(state, EndpointQuery(exact_name="ns-1/actor"))) == 0


def test_retry_assignment_deletes_stale_endpoints_before_next_attempt(state):
    """Retry assignment removes stale endpoints before the next attempt is assigned."""

    # Two workers: the task is first dispatched to worker_1, then retried on worker_2.
    worker_1 = register_worker(state, "w1", "host-1:8080", make_worker_metadata())
    worker_2 = register_worker(state, "w2", "host-2:8080", make_worker_metadata())

    req = make_job_request("test")
    # Allow the task to be retried after the worker failure below.
    req.max_retries_preemption = 2
    tasks = submit_job(state, "ns-1", req)
    task = tasks[0]

    # First attempt: dispatch to worker_1 and register its coordinator endpoint.
    dispatch_task(state, task, worker_1)
    state.add_endpoint(
        EndpointRow(
            endpoint_id="ep-1",
            name="ns-1/jax_coordinator",
            address="10.0.0.1:8476",
            job_id=JobName.root("test-user", "ns-1"),
            metadata={},
            registered_at=Timestamp.now(),
        ),
        task_id=task.task_id,
    )

    # Kill worker_1; the task must go back to PENDING so it can be retried.
    fail_worker(state, worker_1, "Connection lost")
    assert _query_task(state, task.task_id).state == job_pb2.TASK_STATE_PENDING

    # Simulate the overlap window where a stale coordinator endpoint is still
    # visible just before the retry is launched.
    state.add_endpoint(
        EndpointRow(
            endpoint_id="ep-stale",
            name="ns-1/jax_coordinator",
            address="10.0.0.9:8476",
            job_id=JobName.root("test-user", "ns-1"),
            metadata={},
            registered_at=Timestamp.now(),
        ),
        task_id=task.task_id,
    )
    # Sanity check: the stale endpoint is resolvable before the retry is queued.
    assert [endpoint.address for endpoint in _endpoints(state, EndpointQuery(exact_name="ns-1/jax_coordinator"))] == [
        "10.0.0.9:8476"
    ]

    # Queuing the retry assignment should purge the task's stale endpoints
    # before the next attempt is handed to worker_2.
    state.queue_assignments([Assignment(task_id=task.task_id, worker_id=worker_2)])

    retried_task = _query_task(state, task.task_id)
    assert retried_task is not None
    assert retried_task.state == job_pb2.TASK_STATE_ASSIGNED
    # No endpoint from the failed attempt may survive into the new attempt.
    assert _endpoints(state, EndpointQuery(exact_name="ns-1/jax_coordinator")) == []


def test_endpoint_survives_building_state(state):
"""Endpoints registered during BUILDING are not deleted by subsequent BUILDING updates."""

Expand Down
21 changes: 21 additions & 0 deletions lib/levanter/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,24 @@ def test_distributed_config_initializes_via_iris_when_iris_job_present(

mock_initialize_iris_jax.assert_called_once_with()
mock_jax_initialize.assert_not_called()


# NOTE: decorators apply bottom-up, so the mock arguments below are listed in
# reverse decorator order (innermost patch first).
@patch("jax.distributed.initialize")
@patch("iris.runtime.jax_init.initialize_jax")
@patch("iris.cluster.client.job_info.get_job_info")
@patch("levanter.distributed.DistributedConfig._is_distributed", return_value=True)
def test_distributed_config_skips_manual_init_for_iris_tpu_jobs(
    mock_is_distributed,
    mock_get_job_info,
    mock_initialize_iris_jax,
    mock_jax_initialize,
    monkeypatch,
):
    """Iris TPU jobs should defer distributed init to the TPU runtime."""
    # Mark this process as a TPU runtime so the config takes the TPU path.
    monkeypatch.setenv("PJRT_DEVICE", "TPU")
    # Any truthy job info marks the process as running inside an Iris job.
    mock_get_job_info.return_value = object()

    DistributedConfig().initialize()

    # The Iris-specific initializer must run, and the manual
    # jax.distributed.initialize must be skipped entirely.
    mock_initialize_iris_jax.assert_called_once_with()
    mock_jax_initialize.assert_not_called()
Loading