Open

Commits

83 commits
3775267
dpo init
ahmeda14960 Jan 22, 2026
5218a2c
wip dpo working somewhat
ahmeda14960 Jan 23, 2026
fd91d3c
config
ahmeda14960 Jan 23, 2026
7a9f9fa
Merge remote-tracking branch 'origin/main' into dpo
ahmeda14960 Jan 23, 2026
4860174
claude suggestions
ahmeda14960 Jan 23, 2026
1e0550c
fix scan
ahmeda14960 Jan 23, 2026
63da7f7
update
ahmeda14960 Jan 23, 2026
9c7db7c
update dpo
ahmeda14960 Jan 23, 2026
dedec64
Merge remote-tracking branch 'origin/main' into dpo_claude_opus
ahmeda14960 Jan 24, 2026
f5df1b7
update claude dpo branch
ahmeda14960 Jan 24, 2026
f4aca7c
wip
ahmeda14960 Jan 24, 2026
7278f64
wip
ahmeda14960 Jan 24, 2026
bfb9171
Fix trainer_state.py to match main (remove _fill_missing_namedarrays)
ahmeda14960 Jan 31, 2026
f034997
Refactor train_dpo.py to use new LmDataConfig API
ahmeda14960 Jan 31, 2026
2f6620b
Fix missing Union import in train_dpo.py
ahmeda14960 Jan 31, 2026
0de46ad
Update test files from simpo branch
ahmeda14960 Jan 31, 2026
a8f59eb
Sync dpo_claude_opus with simpo branch
ahmeda14960 Jan 31, 2026
4534dbe
Merge main into dpo_claude_opus - resolve conflicts
ahmeda14960 Jan 31, 2026
54af964
Update dpo_ultrafeedback_llama3_8b.yaml with hyperparameters from dpo…
ahmeda14960 Jan 31, 2026
f6b602a
wip
ahmeda14960 Jan 31, 2026
acac2a1
Merge remote-tracking branch 'origin/main' into dpo_claude_opus
ahmeda14960 Jan 31, 2026
8889a8a
track actualy single epoch
ahmeda14960 Jan 31, 2026
7b9ab19
Merge remote-tracking branch 'origin/main' into dpo_claude_opus
ahmeda14960 Feb 2, 2026
8a572a9
Refactor: move is_path_like to marin.utils for reuse
ahmeda14960 Feb 2, 2026
2c02787
add agent docs
ahmeda14960 Feb 2, 2026
1ddd631
Merge remote-tracking branch 'origin/main' into dpo_claude_opus
ahmeda14960 Feb 2, 2026
0c8f9de
Merge remote-tracking branch 'origin/main' into dpo_claude_opus
ahmeda14960 Feb 2, 2026
0d473dd
final touches
ahmeda14960 Feb 5, 2026
5e078ae
Merge remote-tracking branch 'origin/main' into dpo_claude_opus
ahmeda14960 Feb 5, 2026
5ea87db
Merge remote-tracking branch 'origin/main' into dpo_claude_opus
ahmeda14960 Feb 6, 2026
0bc6365
delete simpo tracker
ahmeda14960 Feb 6, 2026
7fb3990
do david's fix
ahmeda14960 Feb 6, 2026
e1533ad
david's changes
ahmeda14960 Feb 6, 2026
d46be03
remove redudant memory loading logic
ahmeda14960 Feb 6, 2026
ba1fc71
should be clean now
ahmeda14960 Feb 6, 2026
05cee85
Merge remote-tracking branch 'origin/main' into dpo_claude_opus
ahmeda14960 Feb 8, 2026
4cdb3d7
more fixes
ahmeda14960 Feb 8, 2026
2817a32
training not broken, now refactoring dataset
ahmeda14960 Feb 8, 2026
79736f1
more dpo updates
ahmeda14960 Feb 8, 2026
9c03a3c
Add technical report on LoRA best practices for DPO
ahmeda14960 Feb 8, 2026
d0058a9
go back to mixture
ahmeda14960 Feb 8, 2026
1d70cf3
go back to mixture
ahmeda14960 Feb 8, 2026
2a5865f
Add LoRA-DPO training: single-copy DPO via implicit reference model
ahmeda14960 Feb 9, 2026
c373905
Add temporary test config for 70B LoRA-DPO with tensor parallelism
ahmeda14960 Feb 9, 2026
fc8a76f
Merge remote-tracking branch 'origin/main' into dpo_lora
ahmeda14960 Feb 9, 2026
33bd433
Merge branch 'dpo_claude_opus' into dpo_lora
ahmeda14960 Feb 13, 2026
a2f8db1
wip
ahmeda14960 Feb 13, 2026
045f7ab
dpo lora update
ahmeda14960 Mar 12, 2026
2fdff31
Merge origin/main into dpo_lora, resolving conflicts with main's DPO
ahmeda14960 Mar 12, 2026
cbb848f
Merge remote-tracking branch 'origin/main' into dpo_lora
ahmeda14960 Mar 25, 2026
b89dbd4
Unify DPO and LoRA DPO training
ahmeda14960 Mar 29, 2026
7ccc531
Improve DPO eval tooling and configs
ahmeda14960 Mar 30, 2026
e6dcb2e
Merge remote-tracking branch 'origin/main' into dpo-lora
ahmeda14960 Mar 31, 2026
35d9c44
[dpo] Cache reference eval logprobs
ahmeda14960 Mar 31, 2026
a356e08
[dpo] Fix LoRA resume and capture follow-ups
ahmeda14960 Mar 31, 2026
304a3bf
Merge remote-tracking branch 'origin/main' into dpo-lora
ahmeda14960 Mar 31, 2026
eef838d
[dpo] Add LoRA support to simple DPO configs
ahmeda14960 Mar 31, 2026
9989cbc
[dpo] Add LoRA DPO tuning sweep experiments and babysit logbook
ahmeda14960 Mar 31, 2026
1218ce6
[dpo] Auto-compute training steps from dataset size and num_epochs
ahmeda14960 Apr 1, 2026
fe91deb
[dpo] Deduplicate val preferences and update logbook
ahmeda14960 Apr 1, 2026
1e0e5fe
[dpo] Add LM eval suites to DPO runs
ahmeda14960 Apr 1, 2026
9c0ff0f
[dpo] Add expanded LoRA DPO sweep scripts and LM eval integration
ahmeda14960 Apr 1, 2026
21f52d7
[dpo] Add comparison runs and W&B tag repair
ahmeda14960 Apr 6, 2026
11ddc0c
Merge origin/main into dpo-lora
ahmeda14960 Apr 9, 2026
08f4ae8
[levanter] Fix LoRA HF export path
ahmeda14960 Apr 9, 2026
8ceb13c
Merge remote-tracking branch 'origin/main' into dpo-lora
ahmeda14960 Apr 9, 2026
450e7f2
[dpo] Add LoRA smoke export probe
ahmeda14960 Apr 9, 2026
1ac7d27
[dpo] Remove one-off experiment scripts
ahmeda14960 Apr 10, 2026
982ef92
Merge remote-tracking branch 'origin/main' into dpo-lora
ahmeda14960 Apr 10, 2026
722f615
Merge remote-tracking branch 'origin/main' into dpo-lora
ahmeda14960 Apr 12, 2026
f89aa35
[dpo] Add v6e-8 LoRA DPO support with gradient accumulation
ahmeda14960 Apr 13, 2026
3fadb5d
[dpo] Logbook: v6e-16 scheduling root cause — inherited region constr…
ahmeda14960 Apr 13, 2026
40eea63
[dpo] Add v6e-16 probe scripts, fix scheduling via explicit region
ahmeda14960 Apr 13, 2026
2fe470a
[iris] Fix multi-host TPU JAX distributed init on Iris
ahmeda14960 Apr 13, 2026
1f35abb
[dpo] Logbook: document multi-host TPU JAX init bug and fix
ahmeda14960 Apr 13, 2026
3a66b78
[dpo] xprof analysis: v6e-8 is bandwidth-bound, 9% buffer alloc stalls
ahmeda14960 Apr 13, 2026
2da5b06
[dpo] Add carry offloading probe for v6e-8 (pd=8, no grad accum)
ahmeda14960 Apr 13, 2026
8424b7d
[dpo] Logbook: v6e-16 OOMs at 38.97 GB (multi-host DCN overhead)
ahmeda14960 Apr 13, 2026
65be317
[dpo] Logbook: carry offloading crashes on v6e (XLA codegen bug)
ahmeda14960 Apr 13, 2026
986c660
[dpo] Add TP=4+FSDP=2 probes for v6e-8
ahmeda14960 Apr 13, 2026
5c1de57
[dpo] Logbook: TP=4 OOM — attention heads not sharded by model axis
ahmeda14960 Apr 13, 2026
0b228b3
[dpo] Fix TP=4 attention sharding: map kv_head axis to model
ahmeda14960 Apr 13, 2026
00bc13e
[dpo] Logbook: document working v6e-8 config at top, close optimizati…
ahmeda14960 Apr 13, 2026
3,789 changes: 3,789 additions & 0 deletions .agents/logbooks/dpo-lora-claude.md

Large diffs are not rendered by default.

3,150 changes: 3,150 additions & 0 deletions .agents/logbooks/dpo-lora-codex.md

Large diffs are not rendered by default.

180 changes: 180 additions & 0 deletions .agents/logbooks/levanter_mesh_explained.md
@@ -0,0 +1,180 @@
# How Levanter Sets Up the TPU Mesh

## Overview

Levanter doesn't hardcode any knowledge of TPU types. It asks JAX what devices exist, JAX asks the hardware, and Levanter builds the mesh from two numbers: **total device count** and **number of hosts (slices)**.

## Step-by-Step: How the Mesh Gets Built

### Step 1: JAX discovers devices

When a TPU program starts, JAX calls into the TPU runtime and gets back a list of device objects. Each device has:
- `id`: unique device ID
- `platform`: "tpu"
- `device_kind`: "TPU v5p", "TPU v6e", etc.
- `slice_index`: which host this device belongs to (only present on multi-host setups)

### Step 2: Levanter counts hosts

```python
# lib/levanter/src/levanter/trainer.py, line 960
num_slices = max(getattr(device, "slice_index", 0) for device in jax.devices()) + 1
```

If `slice_index` doesn't exist on the device (single host), `num_slices = 1`.

### Step 3: Levanter computes chips per host

```python
# lib/levanter/src/levanter/trainer.py, line 966
per_slice = jax.device_count() // num_slices
```
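Steps 1–3 can be sketched end-to-end with mocked device objects (the `FakeDevice` class and `mesh_counts` helper are hypothetical stand-ins for illustration; real device objects come from `jax.devices()`):

```python
from dataclasses import dataclass

@dataclass
class FakeDevice:
    """Stand-in for a JAX TPU device object (hypothetical, for illustration)."""
    id: int
    slice_index: int = 0  # real devices omit this attribute on single-host setups

def mesh_counts(devices):
    # Step 2: number of hosts = highest slice_index + 1 (getattr default covers single host)
    num_slices = max(getattr(d, "slice_index", 0) for d in devices) + 1
    # Step 3: chips per host = total devices spread evenly over hosts
    per_slice = len(devices) // num_slices
    return num_slices, per_slice

# v6e-128: 32 hosts x 4 chips each
v6e_128 = [FakeDevice(id=i, slice_index=i // 4) for i in range(128)]
print(mesh_counts(v6e_128))   # (32, 4)

# v5p-8: 1 host, 4 chips
v5p_8 = [FakeDevice(id=i) for i in range(4)]
print(mesh_counts(v5p_8))     # (1, 4)
```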

### Step 4: The mesh axes are computed

The defaults (from `lib/levanter/src/levanter/utils/mesh.py`):
```python
DEFAULT_ICI_AXIS_SPEC = {"data": -1, "replica": 1, "model": 1} # within-host
DEFAULT_DCN_AXIS_SPEC = {"replica_dcn": -1} # across-host
```

The `-1` means "absorb whatever's left." So:
- `data = per_slice / (replica × model) = per_slice` (all chips within a host)
- `replica_dcn = num_slices` (one slot per host)
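The absorb rule can be sketched as a small helper (the name `resolve_axes` is hypothetical; Levanter's actual implementation lives in `mesh.py`):

```python
def resolve_axes(spec, total):
    """Replace the single -1 entry in `spec` with whatever factor of `total` remains."""
    fixed = 1
    for size in spec.values():
        if size != -1:
            fixed *= size
    assert total % fixed == 0, "fixed axis sizes must divide the device count"
    return {name: (total // fixed if size == -1 else size) for name, size in spec.items()}

# ICI axes within one host of 4 chips
print(resolve_axes({"data": -1, "replica": 1, "model": 1}, 4))
# {'data': 4, 'replica': 1, 'model': 1}

# DCN axes across 32 hosts
print(resolve_axes({"replica_dcn": -1}, 32))
# {'replica_dcn': 32}
```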

### Concrete examples

| TPU config | `jax.device_count()` | `num_slices` | `per_slice` | Mesh: `data` | Mesh: `replica_dcn` |
|------------|---------------------|-------------|-------------|--------------|---------------------|
| v5p-8 (1 host, 4 chips) | 4 | 1 | 4 | 4 | 1 |
| v5p-32 (4 hosts, 4 chips each) | 16 | 4 | 4 | 4 | 4 |
| v6e-4 (1 host, 4 chips) | 4 | 1 | 4 | 4 | 1 |
| v6e-128 (32 hosts, 4 chips each) | 128 | 32 | 4 | 4 | 32 |

**v5p-8 and v6e-4 look identical to Levanter** — both are 4 chips, 1 host. The only difference is HBM size (95 GB vs 31 GB), which Levanter doesn't check.

**v5p-32 and v6e-128** both have `per_slice=4` — same within-host FSDP depth. The difference is 4 vs 32 hosts, and 95 vs 31 GB per chip.

## How FSDP Sharding Depth Is Determined

The critical config (line 40 in mesh.py):
```python
param_mapping: {"embed": "data"}
```

This means: **shard the model's embed dimension across the `data` axis only.**

Since `data` is an ICI axis (within-host), FSDP only shards within a host. The `replica_dcn` axis is used for data parallelism (batch distribution + gradient averaging) but NOT for parameter sharding.

### What this looks like on v6e-128

```
Host 0: [chip0: 1/4 model] [chip1: 1/4 model] [chip2: 1/4 model] [chip3: 1/4 model]
Host 1: [chip0: 1/4 model] [chip1: 1/4 model] [chip2: 1/4 model] [chip3: 1/4 model]
...
Host 31: [chip0: 1/4 model] [chip1: 1/4 model] [chip2: 1/4 model] [chip3: 1/4 model]
```

Each chip holds 1/4 of the model (sharded within-host). All 32 hosts hold identical copies. Adding more hosts doesn't reduce per-chip model memory.

Per-chip model storage: `8B params × 4 bytes (f32) / 4 = 8 GB`

On v6e (31.25 GB HBM), 8 GB is 25% of the chip — before activations, optimizer state, or XLA temp buffers.

## How to Shard Across All Chips

Change the param_mapping in the YAML config:

```yaml
trainer:
  mesh:
    param_mapping:
      embed: [replica_dcn, data]
```

This tells Levanter: "shard the embed dimension across both `replica_dcn` AND `data`." On v6e-128 that's 32 × 4 = 128-way sharding.

Per-chip model storage: `8B params × 4 bytes / 128 = 250 MB`
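Both storage numbers follow from one formula (decimal GB; parameter storage only, excluding optimizer state and activations):

```python
def per_chip_gb(n_params, bytes_per_param, shard_ways):
    # Parameter bytes held by each chip under shard_ways-way FSDP, in decimal GB
    return n_params * bytes_per_param / shard_ways / 1e9

print(per_chip_gb(8e9, 4, 4))    # 8.0  GB: within-host FSDP (embed: data)
print(per_chip_gb(8e9, 4, 128))  # 0.25 GB: cross-host FSDP (embed: [replica_dcn, data])
```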

### The tradeoff

**Default (`embed: data`):**
- All-gathers stay within-host (fast ICI, microseconds)
- Each host has a full model copy (high memory per chip)
- Good for v5p (95 GB HBM — plenty of room)

**Cross-host (`embed: [replica_dcn, data]`):**
- All-gathers go across hosts (slow DCN, milliseconds)
- Each chip holds 1/128 of the model (low memory per chip)
- Necessary for v6e (31 GB HBM — too small for 1/4 of 8B model in f32)

### Why this is especially OK for LoRA

| Communication type | Full fine-tuning | LoRA |
|-------------------|-----------------|------|
| Forward all-gather (base weights, per layer) | ~864 MB × 64 layers | Same (unavoidable) |
| Gradient all-reduce (across hosts) | **32 GB** (all params) | **620 MB** (LoRA only, 50× less) |
| Optimizer step | All params | LoRA only (tiny) |

The forward all-gather is the same cost regardless. But gradient all-reduce — which is the OTHER expensive cross-host communication — is 50× cheaper with LoRA. So the performance penalty of cross-host FSDP is much smaller for LoRA than for full fine-tuning.
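A back-of-the-envelope check on the ~50× claim, taking the table's 32 GB and 620 MB at face value (the 620 MB LoRA figure is from the table above; the exact adapter size depends on rank and which layers get adapters):

```python
full_grad_gb = 8e9 * 4 / 1e9   # 32.0 GB: every base parameter's f32 gradient crosses DCN
lora_grad_gb = 0.62            # ~620 MB: only LoRA adapter gradients cross DCN

# Cross-host gradient traffic shrinks by roughly 50x with LoRA
print(full_grad_gb / lora_grad_gb)
```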

## ICI vs DCN: The Two Networks

TPU pods have two interconnects:

- **ICI (Inter-Chip Interconnect)**: Connects chips within a host. Very fast, very high bandwidth. This is the `data`, `replica`, and `model` axes.
- **DCN (Data Center Network)**: Connects hosts to each other. Much slower, lower bandwidth. This is the `replica_dcn` axis.

The default design philosophy: keep parameter communication on ICI, use DCN only for gradient averaging. This is optimal when chips have enough HBM. When they don't (v6e), you trade DCN latency for memory savings.

## MeshConfig: The Full Configuration Surface

From `lib/levanter/src/levanter/utils/mesh.py` (fields and their defaults, summarized):

```python
@dataclass(frozen=True)
class MeshConfig:
    axes: {"data": -1, "replica": 1, "model": 1}  # ICI axis sizes (-1 = absorb remaining)
    dcn_axes: {"replica_dcn": -1}                 # DCN axis sizes (-1 = absorb remaining)
    batch_axis_name: "batch"                      # logical name for the batch axis
    shared_mapping: {}                            # logical → physical (shared by compute + params)
    compute_mapping: {}                           # logical → physical (compute only)
    param_mapping: {"embed": "data"}              # logical → physical (params + optimizer)
```

Resolved mappings:
- **Parameters**: `{"mlp": "model", "heads": "model", "embed": "data"}` — FSDP on `data`, TP on `model`
- **Compute**: same + `{"batch": ("replica_dcn", "replica", "data")}` — batch spans ALL axes
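Assuming the resolution is a straightforward dict merge (a sketch, not Levanter's actual code), the two resolved mappings come out as:

```python
# Defaults, plus the TP entries for mlp/heads from the resolved parameter mapping
shared_mapping = {}
param_mapping = {"mlp": "model", "heads": "model", "embed": "data"}
compute_mapping = {"batch": ("replica_dcn", "replica", "data")}

# Parameters (and optimizer state): shared + param entries -> FSDP on data, TP on model
params_resolved = {**shared_mapping, **param_mapping}

# Compute: the same, plus the batch axis spanning every physical axis
compute_resolved = {**params_resolved, **compute_mapping}

print(params_resolved)            # {'mlp': 'model', 'heads': 'model', 'embed': 'data'}
print(compute_resolved["batch"])  # ('replica_dcn', 'replica', 'data')
```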

Everything is configurable via the trainer YAML:
```yaml
trainer:
  mesh:
    axes:
      data: -1
      model: 1
    dcn_axes:
      replica_dcn: -1
    param_mapping:
      embed: [replica_dcn, data]  # cross-host FSDP
    compute_mapping:
      batch: [replica_dcn, replica, data]
```

## TPU Generation Differences

Levanter's mesh code is **generation-agnostic**. It doesn't detect v5p vs v6e. The differences that matter:

| Property | v5p | v6e |
|----------|-----|-----|
| HBM per chip | 95.74 GB | 31.25 GB |
| Chips per host | 4 | 4 |
| ICI bandwidth | High | High |
| DCN bandwidth | Medium | Medium |
| Levanter mesh | Identical | Identical |

The only code that checks TPU generation is the Pallas kernel tuning (`tuned_block_sizes.py`), which detects `"v5p" in device_kind` for fused cross-entropy block sizes.

## Git History

The mesh code was refactored in commit `896a390de` (Dec 18, 2025) by David Hall and William Held: "Mesh refactor in support of context parallelism." This introduced the clean ICI/DCN separation and the configurable `MeshConfig`. The default of `embed: data` (FSDP within-host only) was an intentional performance choice, not a limitation — it's configurable.