23 commits
- b06589a: Add elastic TPU training transport and benchmarks (dlwh, Mar 12, 2026)
- 6158b4e: Add DiLoCo-style elastic sync mode (dlwh, Mar 12, 2026)
- 1148c15: Fix Iris nested job context detection (dlwh, Mar 12, 2026)
- 00cba33: Harden elastic DiLoCo rerun instrumentation (dlwh, Mar 12, 2026)
- 142d72d: Use standard validation mixes in elastic benchmarks (dlwh, Mar 12, 2026)
- 86d53b8: Document elastic validation rerun (dlwh, Mar 12, 2026)
- 798c7bb: Launch elastic budget compare via executor (dlwh, Mar 12, 2026)
- e47d75c: Log executor-dispatched elastic rerun (dlwh, Mar 12, 2026)
- 1679185: Treat deleted transfer buffers as transient publish misses (dlwh, Mar 12, 2026)
- db8c017: Log deleted-array elastic publish fix (dlwh, Mar 12, 2026)
- e21fc76: Stage elastic transfer payloads before publish (dlwh, Mar 12, 2026)
- 46d6063: Log staged-publish rollout to 0312g (dlwh, Mar 12, 2026)
- 7f644cd: Decouple DiLoCo anchor from donated model buffers (dlwh, Mar 13, 2026)
- fe55611: Log DiLoCo anchor fix and 0312h relaunch (dlwh, Mar 13, 2026)
- 0c378ba: Stop sharing DiLoCo outer state across peers (dlwh, Mar 13, 2026)
- c33afb0: Stabilize DiLoCo sync with staleness gating and update clipping (dlwh, Mar 13, 2026)
- d63014a: Add MaxText-style Nesterov DiLoCo outer optimizer (dlwh, Mar 13, 2026)
- 58dc350: Log elastic Adam vs Nesterov A/B launch (dlwh, Mar 13, 2026)
- ea0d0b4: Fix elastic compare launch path for default validation (dlwh, Mar 13, 2026)
- 8211557: Merge remote-tracking branch 'origin/main' into codex/research/resili… (dlwh, Mar 19, 2026)
- 4195b6b: Merge remote-tracking branch 'origin/main' into codex/research/resili… (dlwh, Mar 19, 2026)
- 1729aeb: Merge remote-tracking branch 'origin/main' into codex/research/resili… (dlwh, Mar 20, 2026)
- 39d1c86: [iris] Ignore unknown inherited constraint fields (dlwh, Mar 26, 2026)
926 changes: 926 additions & 0 deletions .agents/logbooks/resilient-tpu-training.md

Large diffs are not rendered by default.

205 changes: 205 additions & 0 deletions .agents/projects/resilient-tpu-training.md
# Resilient TPU Training Design Sketch

## Problem

Current multislice training in Fray is still gang-scheduled: losing one slice cancels the entire cohort, and the existing "flex" path only accepts exact sizes from a fixed list. That is better than a single hard size, but it is not truly elastic, and it does not satisfy the desired property that the run keeps making progress with whatever slice count is currently available.

## Recommendation

Build elasticity as a controller-driven restart system, not as in-place world-size mutation inside a running JAX job.

The core observations are:

- JAX process groups appear static once formed.
- Levanter checkpoints already support restore on a different host count.
- Iris already has a persistent controller model and slice reconciliation logic.
- JAX's official fault-tolerance docs are still experimental and explicitly not fully ready for TPU.

That combination suggests a clean architecture:

1. A persistent controller owns run identity, membership policy, and last committed state.
2. TPU workers are launched as short-lived fixed-mesh cohorts.
3. Membership changes trigger a controlled stop at the next safe point, followed by relaunch on the currently available slice count.

## Proposed Components

### 1. Elastic policy

Replace `num_slices: int | Sequence[int]` as the high-level user API with a range and policy object.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class ElasticSlicesConfig:
    min_slices: int = 1
    max_slices: int | None = None
    target_slices: int | None = None
    scale_up_cooldown_seconds: int = 1800
    resize_mode: Literal["checkpoint_boundary", "manual"] = "checkpoint_boundary"
```

This makes the contract explicit:

- any slice count in `[min_slices, max_slices]` is admissible;
- progress is required at `min_slices`;
- scale-up is opportunistic, not mandatory.

### 2. Persistent elastic controller

Add an `ElasticTrainingController` that runs on non-preemptible CPU infrastructure. It should not be a TPU worker or a best-effort Ray actor.

Responsibilities:

- track run epoch, current slice set, last committed checkpoint, and restart reason;
- decide when to admit scale-up or trigger scale-down;
- coordinate safe-point barriers;
- relaunch cohorts with the new world size;
- publish status for observability.

Suggested control loop:

```python
while not terminal:
    available = slice_manager.available_slices()
    wanted = policy.choose_slice_count(available)

    if cohort is None:
        cohort = launch_cohort(checkpoint=state.latest_checkpoint, num_slices=wanted)
        continue

    event = controller.wait_for_event()

    if event.kind in {"slice_lost", "slice_gained"}:
        if policy.should_reconfigure(event, available):
            cohort.request_stop_at_safe_point(reason=event.kind)
            checkpoint = cohort.await_checkpoint()
            cohort = launch_cohort(checkpoint=checkpoint, num_slices=policy.choose_slice_count(available))
```

### 3. Fixed-mesh training cohorts

Each cohort is a normal Levanter run:

- fixed JAX process group,
- fixed device mesh,
- fixed sharding for the duration of the cohort epoch.

This keeps the training code close to current semantics. Elasticity only happens at safe points between cohorts.

### 4. Safe-point protocol

Add a lightweight safe-point callback in Levanter:

- every `N` steps, or on controller request, all hosts enter a barrier;
- if no resize is requested, continue immediately;
- if resize is requested, flush async checkpoint writes, write a small manifest, and exit cleanly.
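The three branches above can be sketched as a small per-step callback. Every hook name here (`stop_requested`, `barrier`, `flush_checkpoints`, `write_manifest`) is a hypothetical injection point, not an existing Levanter API:

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class SafePoint:
    """Sketch of the safe-point protocol; all hooks are hypothetical."""

    every_n_steps: int
    stop_requested: Callable[[], bool]      # controller's pending-resize flag
    barrier: Callable[[], None]             # cross-host rendezvous
    flush_checkpoints: Callable[[], None]   # drain async checkpoint writes
    write_manifest: Callable[[], None]      # emit the small manifest below

    def on_step(self, step: int) -> bool:
        """Returns True when the process should exit cleanly for a relaunch."""
        if step % self.every_n_steps != 0 and not self.stop_requested():
            return False                    # not a safe point, keep training
        self.barrier()
        if not self.stop_requested():
            return False                    # no resize pending: continue immediately
        self.flush_checkpoints()
        self.write_manifest()
        return True                        # exit cleanly; controller relaunches
```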

The manifest needs only a few fields:

```python
@dataclass(frozen=True)
class ElasticCheckpointManifest:
    run_id: str
    epoch: int
    step: int
    checkpoint_path: str
    slice_count: int
    mesh_shape: dict[str, int]
    batch_tokens: int
```

### 5. Batch and optimizer semantics

Keep per-device microbatch fixed. Adjust global effective batch by changing gradient accumulation after each restart.

Rules:

- if slice count drops, increase accumulation so optimizer semantics stay close to the original target batch;
- if slice count rises, reduce accumulation before increasing batch directly;
- record the effective batch in the checkpoint manifest and logs.

This avoids coupling correctness to a specific world size.
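The arithmetic behind these rules is simple: pick the accumulation count from the target batch and the devices actually present. A hedged sketch (the helper name and the round-up choice are illustrative, not from the codebase):

```python
def accumulation_for(target_batch_tokens: int,
                     tokens_per_microbatch_per_device: int,
                     devices: int) -> int:
    """Hypothetical helper: choose gradient-accumulation steps so the
    effective batch stays at or just above the original target."""
    per_step = tokens_per_microbatch_per_device * devices
    # Ceiling division: round up so the effective batch never drops below target.
    return -(-target_batch_tokens // per_step)
```

For example, at a ~4M-token target batch with 2048-token per-device microbatches, 1024 devices need 2 accumulation steps; halving the cohort to 512 devices doubles that to 4, keeping optimizer semantics close to the original target.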

### 6. State transport abstraction

Use durable checkpoint restore as the correctness path.

Add an optional fast path:

```python
class ElasticStateRelay(Protocol):
    def publish(self, state_ref: str, state: PyTree) -> None: ...
    def fetch(self, state_ref: str, exemplar: PyTree) -> PyTree: ...
```

Implementations:

- `TensorStoreRelay`: always available, durable, slower.
- `ArrowFlightRelay`: existing, practical fallback for fast host-to-host state movement.
- `JaxTransferRelay`: only enable when the runtime proves the API is available and stable.

This is where TransferServer fits. It should accelerate warm restarts, not define the whole system's correctness story.
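A warm-restart fetch might chain the relays in that order, treating every fast path as best-effort. This is a sketch under the assumption that relay objects implement the `ElasticStateRelay` protocol above; `fetch_state` itself is a hypothetical name:

```python
def fetch_state(state_ref, exemplar, fast_relays, durable_relay, log=print):
    """Try fast relays in order; fall back to the durable source of truth.

    Sketch only: `durable_relay` (e.g. a TensorStore-backed relay) is the
    correctness path, so any fast-relay failure is survivable.
    """
    for relay in fast_relays:
        try:
            return relay.fetch(state_ref, exemplar)
        except Exception as exc:  # fast paths are accelerators, not requirements
            log(f"{type(relay).__name__} failed ({exc!r}); falling back")
    return durable_relay.fetch(state_ref, exemplar)
```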

## Why not just extend Fray's list-of-sizes model?

Because the list-of-sizes workaround solves only admission. It does not solve runtime membership change.

Current Fray behavior on slice failure is still:

1. cancel the cohort,
2. count it as preemption,
3. relaunch the whole job.

That is fundamentally gang behavior. A more generic `num_slices` parser does not change the underlying lifecycle model.

## Why not lead with DiLoCo?

DiLoCo-like training is worth a separate track, especially if the goal is to use many unreliable slices with low communication overhead. There is now enough evidence to treat it seriously: the original DiLoCo paper reports robustness to workers appearing and disappearing, OpenDiLoCo reports multi-continent replication with `90-95%` utilization, and INTELLECT-1 reports a `10B` / `1T` token run built on a hybrid DiLoCo system.

But it still changes the optimization algorithm and only helps if the full model fits on the minimum local slice footprint.

That makes it a good research branch, not the baseline fix for Levanter's current synchronous training stack.

Recommended sequencing:

1. ship elastic restart for synchronous training;
2. benchmark optional state relays;
3. evaluate DiLoCo or similar as an alternative training mode for models that fit on a single slice.

## Main Risks

- Restart latency may be dominated by compile time rather than checkpoint read time.
- Some models cannot actually make progress on one slice because model-parallel state does not fit.
- Input pipeline determinism must survive world-size changes.
- Optimizer-state restore across arbitrary slice counts must be validated, not assumed.
- JAX TransferServer may remain too experimental for production use.

## Phased Rollout

### Phase 0: prove prerequisites

- verify checkpoint restore across arbitrary slice counts;
- measure compile/cache reuse after resize;
- measure checkpoint barrier latency.

### Phase 1: elastic scale-down only

- trigger restart when slices are lost;
- continue on any admissible lower slice count;
- no opportunistic scale-up yet.

### Phase 2: opportunistic scale-up

- detect additional healthy slices;
- restart at a checkpoint boundary onto a larger cohort when worth it.

### Phase 3: fast-state relay

- benchmark Arrow Flight and JAX TransferServer as resume accelerators;
- keep TensorStore as fallback and source of truth.

## Immediate Next Steps

1. Prototype a controller-owned safe-point request path.
2. Add an elastic slice policy config at the Levanter launcher layer.
3. Run checkpoint restore tests across changing slice counts.
4. Decide whether the first controller implementation lives in Levanter, Fray, or Iris-backed launcher code.
22 changes: 16 additions & 6 deletions lib/iris/src/iris/cluster/client/job_info.py
```diff
@@ -89,12 +89,17 @@ def get_job_info() -> JobInfo | None:
         return info

     # Fall back to environment variables.
-    raw_task_id = os.environ.get("IRIS_TASK_ID")
-    if raw_task_id:
+    raw_task_attempt = os.environ.get("IRIS_TASK_ID")
+    raw_legacy_task_id = os.environ.get("IRIS_JOB_ID")
+    if raw_task_attempt or raw_legacy_task_id:
         try:
-            parsed = TaskAttempt.from_wire(raw_task_id)
-            task_id = parsed.task_id
-            attempt_id = parsed.attempt_id if parsed.attempt_id is not None else 0
+            if raw_task_attempt:
+                parsed = TaskAttempt.from_wire(raw_task_attempt)
+                task_id = parsed.task_id
+                attempt_id = parsed.attempt_id if parsed.attempt_id is not None else 0
+            else:
+                task_id = JobName.from_wire(raw_legacy_task_id)
+                attempt_id = int(os.environ.get("IRIS_ATTEMPT_ID", "0"))
             task_id.require_task()
         except ValueError:
             return None
@@ -104,7 +109,12 @@ def get_job_info() -> JobInfo | None:
     constraints: list[Constraint] = []
     if constraints_json:
         for item in json.loads(constraints_json):
-            constraints.append(Constraint.from_proto(json_format.ParseDict(item, cluster_pb2.Constraint())))
+            constraint_proto = json_format.ParseDict(
+                item,
+                cluster_pb2.Constraint(),
+                ignore_unknown_fields=True,
+            )
+            constraints.append(Constraint.from_proto(constraint_proto))

     info = JobInfo(
         task_id=task_id,
```
59 changes: 59 additions & 0 deletions lib/iris/tests/cluster/client/test_job_info.py
```diff
@@ -37,6 +37,65 @@ def _raise():
     assert resolve_job_user() == "root"


+def test_get_job_info_accepts_iris_task_id_env(monkeypatch):
+    set_job_info(None)
+    monkeypatch.delenv("IRIS_JOB_ID", raising=False)
+    monkeypatch.setenv("IRIS_TASK_ID", "/alice/train/0:0")
+    monkeypatch.setenv("IRIS_NUM_TASKS", "2")
+    monkeypatch.setenv("IRIS_ATTEMPT_ID", "3")
+    monkeypatch.setenv("IRIS_WORKER_ID", "worker-7")
+    monkeypatch.setenv("IRIS_CONTROLLER_ADDRESS", "http://10.0.0.1:10000")
+
+    info = get_job_info()
+
+    assert info is not None
+    assert info.task_id == JobName.from_wire("/alice/train/0")
+    assert info.job_id == JobName.from_wire("/alice/train")
+    assert info.num_tasks == 2
+    assert info.attempt_id == 3
+    assert info.worker_id == "worker-7"
+    assert info.controller_address == "http://10.0.0.1:10000"
+    set_job_info(None)
+
+
+def test_get_job_info_accepts_legacy_iris_job_id_env(monkeypatch):
+    set_job_info(None)
+    monkeypatch.delenv("IRIS_TASK_ID", raising=False)
+    monkeypatch.setenv("IRIS_JOB_ID", "/alice/train/0")
+    monkeypatch.setenv("IRIS_ATTEMPT_ID", "3")
+    monkeypatch.setenv("IRIS_NUM_TASKS", "2")
+    monkeypatch.setenv("IRIS_WORKER_ID", "worker-7")
+    monkeypatch.setenv("IRIS_CONTROLLER_ADDRESS", "http://10.0.0.1:10000")
+
+    info = get_job_info()
+
+    assert info is not None
+    assert info.task_id == JobName.from_wire("/alice/train/0")
+    assert info.job_id == JobName.from_wire("/alice/train")
+    assert info.num_tasks == 2
+    assert info.attempt_id == 3
+    assert info.worker_id == "worker-7"
+    assert info.controller_address == "http://10.0.0.1:10000"
+    set_job_info(None)
+
+
+def test_get_job_info_ignores_unknown_constraint_fields(monkeypatch):
+    set_job_info(None)
+    monkeypatch.delenv("IRIS_JOB_ID", raising=False)
+    monkeypatch.setenv("IRIS_TASK_ID", "/alice/train/0:0")
+    monkeypatch.setenv(
+        "IRIS_JOB_CONSTRAINTS",
+        '[{"key":"region","op":0,"mode":"cohort"}]',
+    )
+
+    info = get_job_info()
+
+    assert info is not None
+    assert len(info.constraints) == 1
+    assert info.constraints[0].key == "region"
+    set_job_info(None)
+
+
 def test_worker_region_from_env(monkeypatch):
     """IRIS_WORKER_REGION is read into JobInfo.worker_region."""
     set_job_info(None)
```