iris: add multi-VM CoreWeave support with JAX coordinator bootstrap #3638
Conversation
Covers the jax_init module that uses Iris's existing endpoint registry to distribute the coordinator address from task 0 to sibling tasks, plus CoreWeave config changes for multi-VM GPU scale groups. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Replace `ctx.endpoint_registry` with `ctx.registry` (actual `IrisContext` property)
- Change default port from 0 to 8476; document that an explicit port is required since JAX's gRPC coordinator doesn't expose the bound port
- Add implementation step for lifting `CoreweavePlatform`'s `num_vms > 1` restriction, the hardest part of this work
- Fix feasibility check description: replicas must be an exact multiple of `num_vms`, not an exact match

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
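The corrected feasibility rule in the last bullet can be sketched as a minimal check. The function name and error messages here are illustrative, not from the codebase:

```python
def check_feasible(replicas: int, num_vms: int) -> None:
    # A slice spans num_vms VMs, so the replica count must divide evenly
    # across them: an exact multiple, not an exact match.
    if num_vms <= 0:
        raise ValueError(f"num_vms must be positive, got {num_vms}")
    if replicas % num_vms != 0:
        raise ValueError(
            f"replicas={replicas} must be an exact multiple of num_vms={num_vms}"
        )

check_feasible(replicas=8, num_vms=2)  # ok: 8 is a multiple of 2
```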
Claude finished @rjpower's task in 6s

PR Review: Multi-VM CoreWeave JAX Coordinator Bootstrap
Overall

Well-written design doc. The core approach, reusing the existing Iris endpoint registry for coordinator address discovery, is clean and avoids unnecessary proto/RPC changes. I verified the key code references against the codebase and they are accurate. The endpoint lifecycle analysis (attempt validation, cascade deletes, coscheduled sibling cascade) is correct. A few observations:

Line reference accuracy

The doc's line references are accurate (within ±1 line for some), verified against the codebase. Note that line references in design docs tend to drift as the codebase evolves; the doc should be understood as a snapshot.

Design feedback
Minor nits
Verdict

The design is sound. The approach of reusing existing endpoint infrastructure is pragmatic, and the implementation outline is well-scoped. The hardest part (step 3: lifting the `num_vms > 1` restriction in `CoreweavePlatform`) is correctly identified.
- B1: Add a 'Multi-VM Pod lifecycle' subsection to Step 3 specifying the naming convention (`iris-worker-{slice_id}-vm{i}`), the shared ConfigMap, the updated `_worker_pod_name` signature, `terminate()` iteration, and partial-failure cleanup pseudocode.
- B2: Replace the ambiguous 'task_index offset' with an accurate description: each Pod gets a distinct name as its `IRIS_WORKER_ID`; task assignment is handled by the scheduler, not by `vm_index`.
- B3: Clarify Step 4 as documentation-only (no framework code changes); users pass `ports=["jax"]` in `LaunchJobRequest`.
- N1: Clarify that `poll_interval` is the initial delay for `ExponentialBackoff`.
- N2: Document the implicit anti-affinity of `hostNetwork: true`.
- N3: Note the `atexit` handler is best-effort; cascade delete is the backstop.
- N4: Point the integration test to the `IrisTestCluster` fixture in `tests/e2e/conftest.py` and specify the mock pattern for `jax.distributed.initialize`.
- N5: Mark the Levanter callsite step as out of scope for the initial PR.
- New `iris.runtime.jax_init` module: task 0 registers the coordinator address via the endpoint registry; tasks 1..N-1 poll and discover it before calling `jax.distributed.initialize()`
- `CoreweavePlatform`: lift the `num_vms > 1` restriction; create N Pods per slice with a shared ConfigMap, proper multi-Pod terminate, and partial-failure cleanup
- Add `h100-16x` multi-VM scale group to the `coreweave.yaml` example
- Update `coreweave.md`: remove the single-VM limitation, add a Multi-VM Jobs section with usage docs
- 15 new unit/integration tests for `jax_init`, 3 new coreweave platform tests (all 1293 Iris tests pass)
Add pytest.importorskip("jax") at module level in both test_jax_init.py
and test_jax_init_integration.py so CI environments without jax skip
these tests instead of failing.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
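The guard described in this commit is the standard pytest pattern. Demonstrated here with a module that is always present (`json`) standing in for `jax`:

```python
import pytest

# At module level in a test file, importorskip returns the module when it
# is installed and marks the whole file as skipped when it is not, so CI
# environments without the dependency skip these tests instead of failing.
mod = pytest.importorskip("json")
```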
Will start testing after this.
yonromai left a comment
Approve. Clean, well-layered implementation that reuses existing infrastructure (endpoint registry, ExponentialBackoff) without adding new RPCs or proto changes. The jax_init module is tight and the CoreWeave multi-VM support follows the established pattern naturally.
One actionable finding: the partial-failure cleanup in _monitor_slice wraps all deletes in a single try/except, so one failed delete prevents remaining cleanup (see inline comment). Straightforward fix — fine to address in a follow-up.
Generated with Claude Code
```python
self._kubectl.delete("configmap", cm_name)
for pod_name in handle._pod_names:
    self._kubectl.delete("pod", pod_name, force=True)
self._kubectl.delete("configmap", _worker_config_cm_name(handle.slice_id))
```
Medium — cleanup can skip pods and ConfigMap
The cleanup loop wraps all deletes in a single try/except. If kubectl delete pod raises for any pod (e.g. a pod whose creation failed and doesn't exist on the API server), subsequent pods and the ConfigMap are never cleaned up.
Suggested fix — try each delete independently:
```python
for pod_name in handle._pod_names:
    try:
        self._kubectl.delete("pod", pod_name, force=True)
    except Exception:
        logger.warning("Failed to clean up pod %s for slice %s", pod_name, handle.slice_id)
try:
    self._kubectl.delete("configmap", _worker_config_cm_name(handle.slice_id))
except Exception:
    logger.warning("Failed to clean up ConfigMap for slice %s", handle.slice_id)
```

Generated with Claude Code
```python
if task_index == 0:
    bound_port = job_info.ports.get("jax", port)
    address = f"{job_info.advertise_host}:{bound_port}"
    endpoint_id = ctx.registry.register(endpoint_name, address)
```
Low (non-blocking) — endpoint registered before coordinator is listening
Task 0 registers the endpoint here, then calls jax.distributed.initialize() on line 107 which starts the gRPC coordinator. During this window, tasks 1..N-1 can resolve the endpoint and try connecting to a coordinator that isn't listening yet.
JAX's internal gRPC client retries handle this gracefully in practice. But if you wanted to eliminate the race entirely, you could swap the order: call jax.distributed.initialize() first, then register. The non-coordinator tasks' ExponentialBackoff already accounts for startup delay.
Not blocking — just flagging the ordering.
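For context, the non-coordinator side that tolerates this window might look like the following sketch. The `resolve` callback stands in for Iris's registry lookup, and the parameter names are illustrative:

```python
import time

def discover_coordinator(resolve, endpoint_name, initial_delay=0.5,
                         max_delay=8.0, timeout=120.0):
    # Poll until task 0's endpoint appears, doubling the delay each miss.
    # This same retry budget also absorbs the brief window where the
    # endpoint is registered but the coordinator isn't listening yet.
    delay = initial_delay
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        address = resolve(endpoint_name)
        if address is not None:
            return address
        time.sleep(delay)
        delay = min(delay * 2, max_delay)
    raise TimeoutError(f"endpoint {endpoint_name!r} never appeared")
```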
Generated with Claude Code
@claude let's fix the issues from yonromai's review
Claude finished @rjpower's task in 2m 3s

Fixing review feedback from @yonromai
Both issues from @yonromai's review are fixed.
…er init

1. Partial-failure cleanup in `_monitor_slice` now wraps each pod delete and ConfigMap delete independently so one failure doesn't skip remaining cleanup.
2. Task 0 now calls `jax.distributed.initialize()` before registering the coordinator endpoint, eliminating the race window where other tasks could discover the address before the gRPC coordinator is listening.

Co-authored-by: Russell Power <rjpower@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@claude doesn't jax.distributed.initialize block? we need to register the endpoint before calling it IIUC, just rely on jax not being dumb.
Claude finished @rjpower's task in 1m 58s

Fixing endpoint registration ordering
You're right. Reverted to register-before-initialize. The brief window where the endpoint is registered but the gRPC coordinator isn't yet listening is handled by JAX's internal retry logic. Updated the comment to explain the ordering rationale.
…dlock

`jax.distributed.initialize()` blocks until all processes connect, so registering the endpoint after calling it would deadlock: tasks 1..N-1 would never discover the coordinator address. JAX's internal gRPC retry handles the brief window between registration and the coordinator starting to listen.

Co-authored-by: Russell Power <rjpower@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
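The deadlock argument above can be demonstrated with a toy model: a `threading.Barrier` stands in for the blocking `initialize()` rendezvous and a `queue.Queue` for the endpoint registry. All names here are illustrative, not from the codebase:

```python
import queue
import threading

def simulate(register_first: bool, timeout: float = 0.5) -> bool:
    # Two tasks: task 0 publishes the coordinator address; the worker must
    # discover it before both can rendezvous in "initialize" (the Barrier).
    registry: "queue.Queue[str]" = queue.Queue()
    init_barrier = threading.Barrier(2)
    connected = []

    def task0():
        try:
            if register_first:
                registry.put("10.0.0.1:8476")  # publish, then block in init
                init_barrier.wait(timeout)
            else:
                init_barrier.wait(timeout)     # block in init first...
                registry.put("10.0.0.1:8476")  # ...never reached: deadlock
        except threading.BrokenBarrierError:
            pass

    def worker():
        try:
            registry.get(timeout=timeout)      # discover the address
            init_barrier.wait(timeout)
            connected.append(True)
        except (queue.Empty, threading.BrokenBarrierError):
            pass

    threads = [threading.Thread(target=task0), threading.Thread(target=worker)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return bool(connected)
```

`simulate(True)` completes, while `simulate(False)` only finishes because the toy barrier has a timeout; without one, both sides would hang forever, which is the deadlock the commit message describes.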
Closes #3634
- `iris.runtime.jax_init` module: Task 0 registers its coordinator address via the existing endpoint registry; tasks 1..N-1 poll for it using `ExponentialBackoff`, then all call `jax.distributed.initialize()` with the discovered coordinator. Single-task jobs skip coordination entirely. JAX is imported at call time, so there is no new dependency on JAX.
- `CoreweavePlatform`: Lifts the `num_vms > 1` restriction. Creates N worker Pods per slice with a shared ConfigMap. Proper multi-Pod terminate, partial-failure cleanup, and `_list_slices_by_labels` grouping.
- Adds an `h100-16x` scale group with `num_vms: 2` to `examples/coreweave.yaml`. Documents multi-VM job submission with coscheduling in `docs/coreweave.md`.
- Updates `lib/iris/DESIGN.md`.