
[iris] Fix TPU JAX bootstrap retries #4652

Open
Calvin-Xu wants to merge 3 commits into main from codex/fix-tpu-jax-bootstrap

Conversation

@Calvin-Xu
Member

Delete stale endpoint registrations before retried task attempts so workers cannot resolve an old coordinator. Add regression coverage for the retry path and Iris TPU distributed initialization so failed attempts do not split bootstrap state across retries.

Coscheduled retries could briefly leave old coordinator endpoints visible, letting tasks bootstrap against different JAX coordinators. Prefer the newest endpoint and clear a task's stale endpoints before assigning a retry so distributed init converges on a single coordinator.
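The cleanup-before-retry described above can be sketched as follows. This is a minimal illustration against an in-memory SQLite database, not the actual Iris code; the table layout, column names, and the `assign_retry` helper are all hypothetical stand-ins for the logic in transitions.py.

```python
import sqlite3

def assign_retry(cur: sqlite3.Cursor, task_id: str) -> None:
    # Delete any endpoints left over from the previous attempt so a
    # retried task cannot resolve a stale coordinator address.
    cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id,))
    # Only bump the attempt counter after stale state is gone.
    cur.execute(
        "UPDATE tasks SET attempt_id = attempt_id + 1 WHERE task_id = ?",
        (task_id,),
    )

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE tasks (task_id TEXT PRIMARY KEY, attempt_id INTEGER)")
cur.execute("CREATE TABLE endpoints (endpoint_id TEXT, task_id TEXT, address TEXT)")
cur.execute("INSERT INTO tasks VALUES ('t1', 0)")
cur.execute("INSERT INTO endpoints VALUES ('e1', 't1', '10.0.0.1:1234')")  # stale

assign_retry(cur, "t1")
stale = cur.execute(
    "SELECT COUNT(*) FROM endpoints WHERE task_id = 't1'"
).fetchone()[0]
```

Running both statements inside one transaction means no reader can observe the new attempt_id alongside an old endpoint.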
Calvin-Xu added the agent-generated label (Created by automation/agent) Apr 11, 2026
@claude
Contributor

claude bot commented Apr 11, 2026

Claude finished @Calvin-Xu's task in 4m 17s


Code review

No high-signal issues found. Checked for bugs and CLAUDE.md / AGENTS.md compliance.

Summary of what the PR does

Belt-and-suspenders fix for stale JAX coordinator resolution across retries:

  1. transitions.py — When a retry is assigned in both queue_assignments (worker-direct path) and drain_for_direct_provider (K8s direct-provider path), delete any lingering endpoints tied to the task_id before bumping attempt_id. This closes an overlap window where a stale endpoint could be re-registered between the initial _move_task_out_of_active (which already clears endpoints on failure) and the retry dispatch — e.g., a late update from the dying attempt. The current_attempt_id >= 0 guard correctly skips first-dispatch tasks.
  2. db.py — Move ORDER BY e.registered_at_ms DESC, e.endpoint_id ASC into endpoint_query_sql so all callers get a deterministic order. This matters for _resolve_endpoint, which returns endpoints[0] — without the ORDER BY, SQLite could return an older row if multiple matched. The existing test-file workaround that added this ORDER BY manually is correctly removed since it's now redundant.
  3. test_transitions.py — New regression test (test_retry_assignment_deletes_stale_endpoints_before_next_attempt) correctly exercises the race by manually inserting a stale endpoint after fail_worker (which clears endpoints via _move_task_out_of_active) and then verifying it's cleared when queue_assignments dispatches the retry. The assertions verify both state (TASK_STATE_ASSIGNED) and endpoint cleanup.
  4. test_distributed.py — Regression test asserting that when an Iris job context is detected (get_job_info() is not None), initialize_iris_jax() is used even when _is_distributed=True and PJRT_DEVICE=TPU, so manual jax.distributed.initialize is not called. This matches the existing control flow at distributed.py:348-361.
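The ordering fix in item 2 can be illustrated with a small sketch. The table and column names here are reconstructed from the review text (`registered_at_ms`, `endpoint_id`) and otherwise hypothetical; the point is that an explicit two-key ORDER BY makes `endpoints[0]` deterministically the newest registration, whereas SQLite gives no ordering guarantee without one.

```python
import sqlite3

# Newest-first, with endpoint_id as a tiebreaker for full determinism.
ENDPOINT_QUERY_SQL = """
    SELECT endpoint_id, registered_at_ms FROM endpoints
    WHERE name = ?
    ORDER BY registered_at_ms DESC, endpoint_id ASC
"""

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute(
    "CREATE TABLE endpoints (endpoint_id TEXT, name TEXT, registered_at_ms INTEGER)"
)
cur.executemany(
    "INSERT INTO endpoints VALUES (?, ?, ?)",
    [("old", "coordinator", 100), ("new", "coordinator", 200)],
)
rows = cur.execute(ENDPOINT_QUERY_SQL, ("coordinator",)).fetchall()
# rows[0] is guaranteed to be the most recently registered endpoint.
```

Putting the ORDER BY in the shared query string, rather than at each call site, is what makes every caller (including `_resolve_endpoint`) see the same order.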

Minor observation (non-blocking): The two new cur.execute("DELETE FROM endpoints WHERE task_id = ?", ...) calls at transitions.py:1419 and transitions.py:3008 duplicate the body of the existing delete_task_endpoints(cur, task_id) helper, which is already used in three other places in this file. Using the helper here would be more consistent, but it's purely a style nit — behavior is identical.
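For reference, the helper the observation suggests reusing would look roughly like this. The signature is reconstructed from how the review describes `delete_task_endpoints(cur, task_id)`; the docstring and setup code are illustrative only.

```python
import sqlite3

def delete_task_endpoints(cur: sqlite3.Cursor, task_id: str) -> None:
    """Remove every endpoint registered by any attempt of this task."""
    cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id,))

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE endpoints (endpoint_id TEXT, task_id TEXT)")
cur.executemany(
    "INSERT INTO endpoints VALUES (?, ?)",
    [("e1", "t1"), ("e2", "t1"), ("e3", "t2")],
)
delete_task_endpoints(cur, "t1")  # only t2's endpoint survives
remaining = [r[0] for r in cur.execute("SELECT endpoint_id FROM endpoints")]
```

Routing all endpoint deletions through one helper keeps the cleanup semantics in a single place if the schema ever grows (e.g. a second table keyed by task_id).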


  • Gather PR context (diff, files)
  • Read relevant CLAUDE.md / AGENTS.md files
  • Review transitions.py changes (stale endpoint cleanup)
  • Review db.py change
  • Review test additions
  • Post review findings
Branch: codex/fix-tpu-jax-bootstrap

@rjpower
Collaborator

rjpower commented Apr 12, 2026

This seems reasonable, but @claude, don't we already have a post-condition from the TASK_PREEMPTED/TASK_FAILED path which deletes registry entries from tasks? We should be deleting the entry long before we try to reschedule.

@claude
Contributor

claude bot commented Apr 12, 2026

Claude finished @rjpower's task in 3m 36s


  • Trace the existing failure → retry flow for endpoint cleanup
  • Examine the new cleanup code in this PR
  • Determine if there's a real race window or if the new code is redundant
  • Post findings

Answer

Good question. Yes — both existing failure paths already delete endpoints before the task enters PENDING for retry. The cleanup happens well before rescheduling:

  1. Worker failure path (fail_workers_batch → remove_failed_worker → _terminate_task): _terminate_task unconditionally calls delete_task_endpoints(cur, task_id) at line 476, regardless of whether the resolved task state is terminal or PENDING. So endpoints are cleared within the same transaction that sets the task to PENDING for retry.

  2. Heartbeat update path (apply_task_updates → _apply_task_transitions): At line 1672–1673, endpoints are deleted when update.new_state in TERMINAL_TASK_STATES. Since TASK_FAILED and TASK_WORKER_FAILED are both terminal, endpoints are deleted even when the task's DB state is subsequently set to PENDING for retry (lines 1591–1605 use update.new_state to gate the deletion, not the resolved task_state).

So Russell is right — under normal flow, endpoints are gone before the task is even eligible for rescheduling.

What the PR guards against is re-registration in the window between cleanup and dispatch: since add_endpoint (line 2901–2926) only rejects inserts when the task is in a TERMINAL_TASK_STATE, a dying worker from the old attempt that calls add_endpoint after the failure transaction commits (task is now PENDING, not terminal) would succeed. The stale endpoint would then persist until the next failure cycle. The new DELETE in queue_assignments and drain_for_direct_provider catches any such late arrivals right before the new attempt is dispatched.

Whether this race is worth defending against in practice is a judgment call — it requires the old worker to register an endpoint after the controller has already processed the failure and set the task to PENDING, which is a narrow window. But for JAX coordinator resolution specifically (where _resolve_endpoint returns endpoints[0]), a stale coordinator address would cause all peers in the new attempt to bootstrap against the wrong coordinator, so the blast radius of the race is high even if the probability is low.
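The terminal-state gate described above can be reduced to a few lines. This is a toy model of the described behavior, not the add_endpoint from db.py: the state names are taken from the discussion, but the in-memory registry and function signature are hypothetical.

```python
# add_endpoint only rejects registrations for terminal states, so a late
# call from a dying attempt lands while the task is already PENDING.
TERMINAL_TASK_STATES = {"TASK_SUCCEEDED", "TASK_FAILED", "TASK_WORKER_FAILED"}

endpoints: dict[str, list[str]] = {}

def add_endpoint(task_state: str, task_id: str, address: str) -> bool:
    if task_state in TERMINAL_TASK_STATES:
        return False  # rejected: the task has finished for good
    endpoints.setdefault(task_id, []).append(address)
    return True

# Failure handling already cleared endpoints and set the task to PENDING...
late = add_endpoint("TASK_PENDING", "t1", "10.0.0.1:1234")     # accepted: stale!
blocked = add_endpoint("TASK_FAILED", "t1", "10.0.0.2:1234")   # rejected
```

The `late` registration is exactly what the new DELETE in queue_assignments and drain_for_direct_provider sweeps away before dispatching the next attempt.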

Branch: codex/fix-tpu-jax-bootstrap
