Skip to content

Commit e6ee448

Browse files
authored
iris: add multi-VM CoreWeave support with JAX coordinator bootstrap (#3638)
Closes #3634 - **New `iris.runtime.jax_init` module**: Task 0 registers its coordinator address via the existing endpoint registry; tasks 1..N-1 poll for it using `ExponentialBackoff`, then all call `jax.distributed.initialize()` with the discovered coordinator. Single-task jobs skip coordination entirely. JAX is imported at call time — no new dependency on JAX. - **Multi-VM `CoreweavePlatform`**: Lifts the `num_vms > 1` restriction. Creates N worker Pods per slice with a shared ConfigMap. Proper multi-Pod terminate, partial-failure cleanup, and `_list_slices_by_labels` grouping. - **Config & docs**: Adds `h100-16x` scale group with `num_vms: 2` to `examples/coreweave.yaml`. Documents multi-VM job submission with coscheduling in `docs/coreweave.md`.
1 parent 0e4a8c0 commit e6ee448

File tree

8 files changed

+970
-103
lines changed

8 files changed

+970
-103
lines changed

lib/iris/docs/coreweave.md

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,81 @@ Standard Iris flow. Controller assigns task via heartbeat RPC. Worker calls
409409
2. `handle.terminate()` force-deletes the worker Pod
410410
3. CoreWeave autoscaler deprovisions the bare-metal node when no Pods remain
411411

412-
## 13. Credentials Summary
412+
## 13. Multi-VM Jobs
413+
414+
Multi-VM scale groups allow training across multiple nodes. Each slice in a
415+
multi-VM group provisions N worker Pods (one per VM) that share a single
416+
ConfigMap. All Pods in a slice must reach Ready before the slice is usable.
417+
418+
### Configuration
419+
420+
Define a scale group with `num_vms > 1` in the cluster config. The
421+
`slice_template.num_vms` must match the top-level `num_vms`:
422+
423+
```yaml
424+
scale_groups:
425+
h100-16x:
426+
num_vms: 2
427+
resources:
428+
cpu: 128
429+
ram: 2048GB
430+
disk: 1TB
431+
device_type: gpu
432+
device_variant: H100
433+
device_count: 8
434+
worker:
435+
attributes:
436+
region: US-WEST-04A
437+
pool: h100-16x
438+
min_slices: 0
439+
max_slices: 1
440+
priority: 50
441+
slice_template:
442+
num_vms: 2
443+
coreweave:
444+
region: US-WEST-04A
445+
instance_type: gd-8xh100ib-i128
446+
```
447+
448+
### Submitting multi-replica jobs
449+
450+
Jobs targeting a multi-VM group must use coscheduling so all replicas land on
451+
workers in the same pool. Include `ports=["jax"]` so Iris allocates a named
452+
port for JAX coordinator discovery:
453+
454+
```python
455+
from iris.sdk import IrisClient, CoschedulingConfig
456+
457+
client = IrisClient()
458+
client.submit(
459+
name="multi-node-training",
460+
image="ghcr.io/marin-community/iris-task:latest",
461+
command=["python", "train.py"],
462+
replicas=2,
463+
ports=["jax"],
464+
coscheduling=CoschedulingConfig(group_by="pool"),
465+
resources={"gpu": 8},
466+
)
467+
```
468+
469+
Each replica receives `IRIS_TASK_ID` (0 or 1), `IRIS_NUM_TASKS` (2), and
470+
`IRIS_PORT_JAX` (the allocated coordinator port). Task code calls
471+
`iris.runtime.jax_init.initialize_jax()` to bootstrap JAX distributed — task 0
472+
registers its coordinator address via the endpoint API, and task 1 discovers it
473+
by polling.
474+
475+
### Requirements
476+
477+
- **Coscheduling is mandatory**: Without `CoschedulingConfig(group_by="pool")`,
478+
replicas may land on workers from different scale groups, which lack
479+
InfiniBand connectivity.
480+
- **hostNetwork anti-affinity**: Because worker Pods use `hostNetwork: true`,
481+
two Pods binding the same port cannot schedule on the same node. This
482+
provides implicit anti-affinity — no explicit `podAntiAffinity` rule needed.
483+
- **Gang semantics**: If any task in a coscheduled group fails terminally, all
484+
siblings are killed and the entire group retries together.
485+
486+
## 14. Credentials Summary
413487

414488
### Platform-managed (all created by `iris cluster start`)
415489

@@ -431,19 +505,16 @@ The `kubeconfig_path` config field is only needed when running the CLI
431505
**outside** the cluster (e.g., `iris cluster start` from a laptop). Inside the
432506
cluster, Pods use in-cluster auth automatically.
433507

434-
## 14. Open Questions / Known Limitations
435-
436-
1. **Multi-node slices**: `num_vms > 1` is not supported and raises `ValueError`.
437-
InfiniBand co-scheduling for multi-node training needs investigation.
508+
## 15. Open Questions / Known Limitations
438509

439-
2. **NodePool rate limits**: Creating many NodePools at scale has not been
510+
1. **NodePool rate limits**: Creating many NodePools at scale has not been
440511
validated with CoreWeave.
441512

442-
3. **Task Pod GC**: `ownerReferences` on task Pods only trigger GC when the
513+
2. **Task Pod GC**: `ownerReferences` on task Pods only trigger GC when the
443514
worker Pod object is deleted. If the worker crash-loops in place, stale task
444515
Pods can accumulate. See TODO in `kubernetes.py`.
445516

446-
## 15. Troubleshooting
517+
## 16. Troubleshooting
447518

448519
### NodePool not scaling up
449520

@@ -487,7 +558,7 @@ kubectl logs <pod> -n iris --previous # Logs from the last crash
487558
If `cache_dir` is not set to `/mnt/local/...`, the 15 GB root RAM disk fills
488559
instantly. Fix in config and redeploy.
489560

490-
## 16. References
561+
## 17. References
491562

492563
- [CoreWeave CKS Introduction](https://docs.coreweave.com/docs/products/cks)
493564
- [CKS Cluster Creation](https://docs.coreweave.com/docs/products/cks/clusters/create)

lib/iris/examples/coreweave.yaml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,26 @@ scale_groups:
108108
coreweave:
109109
region: US-WEST-04A
110110
instance_type: gd-8xh100ib-i128
111+
112+
# 16x H100 (2-VM) with InfiniBand — multi-node training
113+
h100-16x:
114+
num_vms: 2
115+
resources:
116+
cpu: 128
117+
ram: 2048GB
118+
disk: 1TB
119+
device_type: gpu
120+
device_variant: H100
121+
device_count: 8
122+
worker:
123+
attributes:
124+
region: US-WEST-04A
125+
pool: h100-16x
126+
min_slices: 0
127+
max_slices: 1
128+
priority: 50
129+
slice_template:
130+
num_vms: 2
131+
coreweave:
132+
region: US-WEST-04A
133+
instance_type: gd-8xh100ib-i128

0 commit comments

Comments
 (0)