
iris: add multi-VM CoreWeave support with JAX coordinator bootstrap#3638

Merged
rjpower merged 8 commits into main from multi/7410d2cb
Mar 14, 2026

Conversation


@rjpower rjpower commented Mar 13, 2026

Closes #3634

  • New iris.runtime.jax_init module: Task 0 registers its coordinator address via the existing endpoint registry; tasks 1..N-1 poll for it using ExponentialBackoff, then all call jax.distributed.initialize() with the discovered coordinator. Single-task jobs skip coordination entirely. JAX is imported at call time — no new dependency on JAX.
  • Multi-VM CoreweavePlatform: Lifts the num_vms > 1 restriction. Creates N worker Pods per slice with a shared ConfigMap. Proper multi-Pod terminate, partial-failure cleanup, and _list_slices_by_labels grouping.
  • Config & docs: Adds h100-16x scale group with num_vms: 2 to examples/coreweave.yaml. Documents multi-VM job submission with coscheduling in docs/coreweave.md.
  • Design doc at lib/iris/DESIGN.md.
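The coordinator bootstrap described in the first bullet can be sketched end-to-end. This is a hypothetical, self-contained model, not the actual iris.runtime.jax_init code: `Registry` stands in for Iris's endpoint registry, and the host name, endpoint name, and function signature are made up for illustration.

```python
import threading
import time

class Registry:
    """In-memory stand-in for the Iris endpoint registry (assumption)."""
    def __init__(self):
        self._endpoints = {}
        self._lock = threading.Lock()

    def register(self, name, address):
        with self._lock:
            self._endpoints[name] = address

    def lookup(self, name):
        with self._lock:
            return self._endpoints.get(name)

def bootstrap(task_index, num_tasks, registry, port=8476,
              poll_interval=0.01, timeout=5.0):
    """Task 0 registers its coordinator address; tasks 1..N-1 poll for it."""
    if num_tasks == 1:
        return None  # single-task jobs skip coordination entirely
    if task_index == 0:
        # Register BEFORE the blocking initialize() call: initialize()
        # waits for all processes, so registering afterwards would deadlock.
        address = f"task0.example.internal:{port}"  # placeholder host
        registry.register("jax", address)
        return address
    # Simplified exponential backoff while waiting for task 0.
    deadline = time.monotonic() + timeout
    delay = poll_interval
    while time.monotonic() < deadline:
        address = registry.lookup("jax")
        if address is not None:
            return address
        time.sleep(delay)
        delay = min(delay * 2, 1.0)
    raise TimeoutError("coordinator address never appeared")

# Every task would then call:
#   jax.distributed.initialize(coordinator_address=address,
#                              num_processes=num_tasks,
#                              process_id=task_index)
```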

rjpower and others added 2 commits March 13, 2026 15:59
Covers the jax_init module that uses Iris's existing endpoint registry
to distribute the coordinator address from task 0 to sibling tasks,
plus CoreWeave config changes for multi-VM GPU scale groups.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Replace ctx.endpoint_registry with ctx.registry (actual IrisContext property)
- Change default port from 0 to 8476; document that explicit port is required
  since JAX's gRPC coordinator doesn't expose the bound port
- Add implementation step for lifting CoreweavePlatform's num_vms > 1
  restriction, the hardest part of this work
- Fix feasibility check description: replicas must be an exact multiple of
  num_vms, not an exact match

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@rjpower rjpower added the agent-generated Created by automation/agent label Mar 13, 2026

claude bot commented Mar 13, 2026

Claude finished @rjpower's task in 6s


PR Review: Multi-VM CoreWeave JAX Coordinator Bootstrap

  • Read design document and gather context
  • Verify claims against existing codebase
  • Post review feedback

Overall

Well-written design doc. The core approach — reusing the existing Iris endpoint registry for coordinator address discovery — is clean and avoids unnecessary proto/RPC changes. I verified the key code references against the codebase and they are accurate. The endpoint lifecycle analysis (attempt validation, cascade deletes, coscheduled sibling cascade) is correct.

A few observations:

Line reference accuracy

The doc's line references are accurate (within ±1 line for some). Verified:

  • job_info.py:25-74 — JobInfo is at lines 26-74 (off by one on start)
  • protocol.py:80-90 — endpoint methods at lines 80-90
  • service.py:1050-1055 — attempt validation at lines 1050-1055
  • coreweave.py:627-631 — num_vms > 1 guard at lines 627-631
  • client.py:1025 — iris_ctx() at line 1025
  • 0001_init.sql:116 — endpoint task_id FK with ON DELETE CASCADE at line 116
  • distributed.py:348-387 — DistributedConfig.initialize() at lines 348-387

These are all accurate. Note that line references in design docs tend to drift as the codebase evolves — the doc should be understood as a snapshot.

Design feedback

  1. atexit for endpoint cleanup (line 84 of the design): The atexit handler to unregister the endpoint is a nice touch for graceful shutdown, but the doc correctly identifies that cascade deletes handle the abnormal-exit case. One subtlety: atexit handlers don't run on SIGKILL (which is how Kubernetes terminates pods after the grace period). This is fine because cascade deletes cover it, but worth noting in comments when implementing.

  2. jax.distributed.initialize() with defaults for single-task (line 74): When num_tasks == 1, the code calls jax.distributed.initialize() with no arguments. This works for single-process JAX but is a no-op on most setups. Might be worth a comment explaining why this is called at all (e.g., "ensures JAX backend is initialized consistently regardless of task count").

  3. Port collision risk: The default port 8476 is fine, but the doc should mention that if multiple JAX jobs land on the same node (with hostNetwork: true), they'll collide. The IRIS_PORT_JAX allocation path avoids this — consider making port allocation via IRIS_PORT_JAX the recommended/default path rather than the fallback, especially since hostNetwork: true means all pods on a node share the same port namespace.

  4. _poll_for_coordinator backoff: The doc mentions ExponentialBackoff from iris.time_utils — good choice. The 300s timeout with 2s initial interval is reasonable. One edge case: if task 0 registers the endpoint but then its jax.distributed.initialize() call hangs (before JAX is actually listening), tasks 1..N-1 will resolve the address and try to connect to a coordinator that isn't ready yet. JAX has its own internal retry logic for this, so it should be fine in practice, but worth mentioning.

  5. runtime directory: The design proposes lib/iris/src/iris/runtime/ as a new directory. This doesn't exist yet — makes sense as a namespace for these optional runtime utilities that tasks opt into.

  6. Dependency direction (line 31): The non-goal about Levanter integration is correct — iris should not import levanter. The proposed Levanter callsite pattern (step 6: user code checks IRIS_CONTROLLER_ADDRESS) is the right approach. The future work item about IrisCluster in JAX's ClusterEnv registry is the cleaner long-term solution.
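Observation 3 above suggests making the Iris-allocated port the preferred path over the fixed default. A minimal sketch of that resolution order, assuming IRIS_PORT_JAX is surfaced as an environment variable (the exact delivery mechanism is an assumption here):

```python
import os

DEFAULT_JAX_PORT = 8476  # fixed fallback from the design doc

def resolve_jax_port(env=None):
    """Prefer an Iris-allocated port; fall back to the fixed default.

    With hostNetwork: true, every pod on a node shares one port
    namespace, so a per-job allocated port avoids collisions when two
    JAX jobs land on the same node.
    """
    env = os.environ if env is None else env
    allocated = env.get("IRIS_PORT_JAX")
    return int(allocated) if allocated is not None else DEFAULT_JAX_PORT
```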

Minor nits

  • Line 9, job_info.py:25-74: Start line should be 26 (class definition is line 26, not 25). Very minor.
  • Line 92, ExponentialBackoff from iris.time_utils: This exists and is already used across many files. Worth confirming the exact import path when implementing.
  • Line 122, autoscaler feasibility: "validates that replicas is an exact multiple of num_vms" — I didn't verify this specific claim (autoscaler.py:1406). Worth double-checking during implementation.

Verdict

The design is sound. The approach of reusing existing endpoint infrastructure is pragmatic, and the implementation outline is well-scoped. The hardest part (step 3: lifting num_vms > 1 in CoreweavePlatform.create_slice()) is correctly identified and has a reasonable breakdown of sub-steps. Ready for implementation.

rjpower added 2 commits March 13, 2026 16:21
B1: Add 'Multi-VM Pod lifecycle' subsection to Step 3 specifying naming
convention (iris-worker-{slice_id}-vm{i}), shared ConfigMap, updated
_worker_pod_name signature, terminate() iteration, and partial-failure
cleanup pseudocode.

B2: Replace ambiguous 'task_index offset' with accurate description —
each Pod gets a distinct name as its IRIS_WORKER_ID; task assignment is
handled by the scheduler, not by vm_index.

B3: Clarify Step 4 as documentation-only (no framework code changes);
users pass ports=["jax"] in LaunchJobRequest.

N1: Clarify poll_interval is the initial delay for ExponentialBackoff.
N2: Document hostNetwork: true implicit anti-affinity.
N3: Note atexit handler is best-effort; cascade delete is the backstop.
N4: Point integration test to tests/e2e/conftest.py IrisTestCluster
fixture and specify mock pattern for jax.distributed.initialize.
N5: Mark Levanter callsite step as out of scope for initial PR.
- New iris.runtime.jax_init module: task 0 registers coordinator address
  via endpoint registry, tasks 1..N-1 poll and discover it before calling
  jax.distributed.initialize()
- CoreweavePlatform: lift num_vms>1 restriction, create N Pods per slice
  with shared ConfigMap, proper multi-Pod terminate and partial-failure
  cleanup
- Add h100-16x multi-VM scale group to coreweave.yaml example
- Update coreweave.md: remove single-VM limitation, add Multi-VM Jobs
  section with usage docs
- 15 new unit/integration tests for jax_init, 3 new coreweave platform
  tests (all 1293 Iris tests pass)
rjpower changed the title from "design: multi-VM CoreWeave JAX coordinator bootstrap for Iris" to "iris: add multi-VM CoreWeave support with JAX coordinator bootstrap" on Mar 13, 2026
Add pytest.importorskip("jax") at module level in both test_jax_init.py
and test_jax_init_integration.py so CI environments without jax skip
these tests instead of failing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@rjpower rjpower requested a review from yonromai March 14, 2026 00:52

rjpower commented Mar 14, 2026

Will start testing after this.


@yonromai yonromai left a comment


Approve. Clean, well-layered implementation that reuses existing infrastructure (endpoint registry, ExponentialBackoff) without adding new RPCs or proto changes. The jax_init module is tight and the CoreWeave multi-VM support follows the established pattern naturally.

One actionable finding: the partial-failure cleanup in _monitor_slice wraps all deletes in a single try/except, so one failed delete prevents remaining cleanup (see inline comment). Straightforward fix — fine to address in a follow-up.

Generated with Claude Code

self._kubectl.delete("configmap", cm_name)
for pod_name in handle._pod_names:
    self._kubectl.delete("pod", pod_name, force=True)
self._kubectl.delete("configmap", _worker_config_cm_name(handle.slice_id))


Medium — cleanup can skip pods and ConfigMap

The cleanup loop wraps all deletes in a single try/except. If kubectl delete pod raises for any pod (e.g. a pod whose creation failed and doesn't exist on the API server), subsequent pods and the ConfigMap are never cleaned up.

Suggested fix — try each delete independently:

for pod_name in handle._pod_names:
    try:
        self._kubectl.delete("pod", pod_name, force=True)
    except Exception:
        logger.warning("Failed to clean up pod %s for slice %s", pod_name, handle.slice_id)
try:
    self._kubectl.delete("configmap", _worker_config_cm_name(handle.slice_id))
except Exception:
    logger.warning("Failed to clean up ConfigMap for slice %s", handle.slice_id)

Generated with Claude Code

Comment thread: lib/iris/src/iris/runtime/jax_init.py (Outdated)
if task_index == 0:
    bound_port = job_info.ports.get("jax", port)
    address = f"{job_info.advertise_host}:{bound_port}"
    endpoint_id = ctx.registry.register(endpoint_name, address)


Low (non-blocking) — endpoint registered before coordinator is listening

Task 0 registers the endpoint here, then calls jax.distributed.initialize() on line 107 which starts the gRPC coordinator. During this window, tasks 1..N-1 can resolve the endpoint and try connecting to a coordinator that isn't listening yet.

JAX's internal gRPC client retries handle this gracefully in practice. But if you wanted to eliminate the race entirely, you could swap the order: call jax.distributed.initialize() first, then register. The non-coordinator tasks' ExponentialBackoff already accounts for startup delay.

Not blocking — just flagging the ordering.

Generated with Claude Code


rjpower commented Mar 14, 2026

@claude let's fix the issues from yonromai review


claude bot commented Mar 14, 2026

Claude finished @rjpower's task in 2m 3s


Fixing review feedback from @yonromai

  • Fix partial-failure cleanup in _monitor_slice — wrap each delete independently
  • Fix endpoint registration ordering in jax_init.py — register after jax.distributed.initialize()
  • Run pre-commit and tests
  • Commit and push

Both issues from @yonromai's review are fixed:

  1. coreweave.py cleanup — Each pod delete and the ConfigMap delete are now wrapped in their own try/except, so one failed cleanup doesn't prevent the rest from running.

  2. jax_init.py registration ordering — Task 0 now calls jax.distributed.initialize() before registering the endpoint. This eliminates the race window where other tasks could discover the coordinator address before the gRPC server is actually listening.

Testing: pre-commit passes, all 54 tests in test_jax_init.py and test_coreweave_platform.py pass (including test_multi_node_partial_failure_cleans_up).

…er init

1. Partial-failure cleanup in _monitor_slice now wraps each pod delete and
   ConfigMap delete independently so one failure doesn't skip remaining cleanup.

2. Task 0 now calls jax.distributed.initialize() before registering the
   coordinator endpoint, eliminating the race window where other tasks could
   discover the address before the gRPC coordinator is listening.

Co-authored-by: Russell Power <rjpower@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

rjpower commented Mar 14, 2026

@claude doesn't jax.distributed.initialize block? we need to register the endpoint before calling it IIUC, just rely on jax not being dumb.


claude bot commented Mar 14, 2026

Claude finished @rjpower's task in 1m 58s


Fixing endpoint registration ordering

  • Read jax_init.py to understand current state
  • Revert registration ordering back to register-before-initialize
  • Run pre-commit and tests
  • Commit and push

You're right — jax.distributed.initialize() blocks until all processes connect. Calling it before registering the endpoint would deadlock, since tasks 1..N-1 would never discover the coordinator address.

Reverted to register-before-initialize. The brief window where the endpoint is registered but the gRPC coordinator isn't yet listening is handled by JAX's internal retry logic. Updated the comment to explain the ordering rationale.

Testing: pre-commit passes. All 8 tests in test_jax_init.py pass.
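The deadlock this thread converges on can be reproduced in miniature. A toy model, assuming only what the thread establishes, that initialize() blocks until every process connects; the barrier below plays that role, and all names are illustrative:

```python
import threading
import time

def run_job(register_first, timeout=1.0):
    """Toy model: the barrier stands in for jax.distributed.initialize(),
    which blocks until every process has connected."""
    barrier = threading.Barrier(2)
    registry = {}

    def task0():
        if register_first:
            registry["jax"] = "coord:8476"   # register, THEN block
        try:
            barrier.wait(timeout=timeout)    # blocks like initialize()
        except threading.BrokenBarrierError:
            return                           # worker never showed up
        if not register_first:
            registry["jax"] = "coord:8476"   # never reached in the bad order

    def worker():
        for _ in range(20):                  # poll for the address
            if "jax" in registry:
                try:
                    barrier.wait(timeout=timeout)
                except threading.BrokenBarrierError:
                    pass
                return
            time.sleep(0.02)
        # address never appeared: give up without joining the barrier

    threads = [threading.Thread(target=task0), threading.Thread(target=worker)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return "jax" in registry and not barrier.broken
```

register-before-initialize succeeds; the reversed order times out, standing in for the real deadlock.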


…dlock

jax.distributed.initialize() blocks until all processes connect, so
registering the endpoint after calling it would deadlock — tasks 1..N-1
would never discover the coordinator address. JAX's internal gRPC retry
handles the brief window between registration and the coordinator
starting to listen.

Co-authored-by: Russell Power <rjpower@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@rjpower rjpower enabled auto-merge (squash) March 14, 2026 02:52
@rjpower rjpower merged commit e6ee448 into main Mar 14, 2026
16 checks passed
@rjpower rjpower deleted the multi/7410d2cb branch March 14, 2026 03:03



Development

Successfully merging this pull request may close these issues.

iris: add support for coreweave multi-host
