[dpo] Revive LoRA-DPO on canonical train_dpo #4634
Conversation
- Update train_dpo.py imports to use LmDataConfig instead of SingleDatasetLMConfig
- Migrate to components-based data config structure
- Replace text.py with text/ package structure (from simpo)
- Add preference.py with DPO-specific classes
- Update DPO YAML configs to use components: structure
- Merge validation split functions into single _build_validation_split
- Copy updated Levanter main scripts from simpo (train_lm.py, eval_lm.py, etc.)
- Copy updated marin tokenize files from simpo

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Copy all config files from simpo (updated to components: structure)
- Copy updated source files from simpo (trainer_state.py, optim/, etc.)
- Add EpochDataset class to dataset.py for DPO training
- Update text/__init__.py exports for preference functions
- Add SimPO config files from simpo branch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Keep preference format handling in datasets.py and formats.py
- Keep DPO exports in text/__init__.py
- Accept main's partitioning.py changes (use axis_names)
- Restore EpochDataset class in dataset.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
… branch
- tokenizer: marin-community/marin-tokenizer
- train_batch_size: 128, num_train_steps: 2150
- learning_rate: 5e-7, lr_schedule: cosine, warmup: 0.1
- beta: 0.01
- Add both train and validation components with proper cache dirs
- Use GCS model paths for reference_model_path and initialize_from_hf
- validation_split_fraction: null (uses separate validation component)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This lets the simple experiment and executor wrappers drive adapter-based DPO without bypassing the canonical train_dpo entrypoint. It also enables reference-eval caching by default in SimpleDPOConfig, so repeated evals no longer pay the uncached reference-model cost.
Add 5 executor-native LoRA DPO tuning scripts under experiments/tune_lora/ that sweep LR (5e-6 to 1e-5) and seed around the best known config from the endlboq3 baseline. Update the babysit logbook with job launch details, OOM recovery incidents, cross-region GCS fix, and ongoing monitoring. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The LoRA DPO sweep halved batch size (128→64) but kept num_train_steps=850, training on only half the dataset. Instead of requiring manual step counts, SimpleDPOConfig now defaults to num_epochs=1.0 and auto-resolves the correct num_train_steps from the tokenized cache stats at launch time. Validation scheduling is also auto-computed (5 runs: initial, 3 interior, final). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
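The epoch-based resolution described above can be sketched as follows. The 108,800-pair dataset size is inferred from the 128 × 850 baseline; the function name and the cache-stats lookup that would feed it are assumptions, not the actual Marin API:

```python
import math

def resolve_num_train_steps(num_pairs: int, train_batch_size: int, num_epochs: float = 1.0) -> int:
    """Derive the step count that covers `num_epochs` full passes over the
    tokenized preference dataset, instead of a manually fixed step count."""
    return max(1, math.ceil(num_pairs * num_epochs / train_batch_size))

# With the sweep's halved batch size, the step count doubles instead of
# silently training on only half the dataset:
# resolve_num_train_steps(108800, 128) -> 850
# resolve_num_train_steps(108800, 64)  -> 1700
```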
Reduce bloom speceval v2 validation set from 23,454 pairs (9 per prompt) to ~2,606 (1 per prompt) via scripts/dedupe_val_prefs.py. This cuts DPO eval time from ~40 min to ~4 min per round. Update dpo_bloom_speceval_v2.py to use the deduped val path and add codex logbook notes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
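The deduplication step (9 pairs per prompt down to 1) amounts to keeping the first pair seen per prompt. A minimal sketch of the idea behind scripts/dedupe_val_prefs.py; the record schema with a `"prompt"` key is an assumption:

```python
def dedupe_one_per_prompt(pairs: list[dict]) -> list[dict]:
    """Keep only the first preference pair seen for each prompt,
    shrinking the validation set roughly 9x."""
    seen: set[str] = set()
    kept: list[dict] = []
    for pair in pairs:
        if pair["prompt"] not in seen:
            seen.add(pair["prompt"])
            kept.append(pair)
    return kept
```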
Attach the default Paloma and Uncheatable validation sets to DPO as a separate LM-eval path instead of mixing them into preference validation. This keeps DPO loss metrics distinct while giving standard and LoRA runs the same tagged LM eval coverage as pretraining.
Add 13 new sweep wrapper scripts covering 9 LRs (1e-6 to 1e-5) x 2 seeds (0, 2). Move heavy imports in uncheatable evals to avoid circular deps. Update codex logbook with sweep plan and launch commands. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Preserve the DPO and LoRA investigation state with reproducible wrappers, local plotting scripts, and append-only logbooks so the February and April runs can be compared offline. Also make future LoRA runs expose their configured W&B tags reliably so sweep metadata stays searchable.
Backport the LoRA merge axis-order fix and embed chat templates in saved tokenizer metadata so merged HF exports load cleanly in downstream HF and vLLM inference paths.
Add a one-off 5-step LoRA DPO wrapper that forces checkpoint and HF export on v5p-8. Record the initial Iris launch attempts and the controller-side submission blocker so the smoke path can be resumed cleanly.
Retire disposable DPO/LoRA probes and plotting scripts now that the investigation has converged. Keep the successful LoRA smoke-run and downstream export verification recorded in the codex logbook.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 982ef92eab
```python
os.path.join(full_save_path, f"step-{step.step}"),
upload_to_hf=upload_to_hf,
dtype=save_dtype,
**hf_upload_kwargs,
```
Pass generation_config when exporting separate-reference checkpoints
In the separate-reference/no-adapter path, the HF export hook calls converter.save_pretrained(...) without forwarding generation_config, so hf_generation_eos_token_ids is silently dropped for periodic exports. This regresses behavior from the previous implementation and can produce checkpoints with incorrect EOS/chat generation settings in downstream inference.
```shell
WANDB_API_KEY="${WANDB_API_KEY:-<redacted>}"
HF_TOKEN="${HF_TOKEN:-<redacted>}"
```
Remove committed API token defaults from launch script
This script embeds concrete default values for WANDB_API_KEY and HF_TOKEN, so running it without environment overrides will use committed credentials. That is a credential-leak/security risk and can also cause accidental writes under the wrong external accounts.
```shell
# Step 1: Create a tarball of the codebase (excluding heavy dirs)
echo "[1/4] Creating codebase tarball..."
cd /Users/ahmed/code/marin
```
Replace machine-specific source path in launch script
The launch workflow hardcodes cd /Users/ahmed/code/marin, which makes the script fail immediately on any machine that does not have that exact local path. Since the script is checked into the repo as a reusable launcher, it should resolve the repo path dynamically instead of depending on one developer workstation layout.
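One way to make the launcher portable, as suggested, is to resolve the path from the script's own location rather than one workstation layout. A sketch, assuming the script lives inside the repo checkout:

```shell
# Sketch: derive the working directory from the script's own location
# instead of hardcoding /Users/ahmed/code/marin, so the launcher works
# from any checkout. (Using `git rev-parse --show-toplevel` from this
# directory would find the repo root if the repo nests the script deeper.)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)"
cd "$SCRIPT_DIR"
echo "[1/4] Creating codebase tarball from $SCRIPT_DIR..."
```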
```python
interval = max(1, math.ceil(num_train_steps / (total_validation_runs - 1)))
return [step for step in range(interval, num_train_steps - 1, interval)]
```
Honor auto_validation_runs when deriving interior eval steps
The interval-based calculation can under-schedule validation passes relative to auto_validation_runs; for example, num_train_steps=10 and total_validation_runs=5 yields interior steps [3, 6], i.e. only 4 total runs including initial/final. This violates the documented contract that auto_validation_runs controls the total number of validation passes.
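A sketch of a schedule that honors the contract: reserve the initial and final runs, then space exactly `total_validation_runs - 2` interior steps evenly. The function name is illustrative, not the repo's actual helper:

```python
import math

def interior_validation_steps(num_train_steps: int, total_validation_runs: int) -> list[int]:
    """Pick evenly spaced interior steps so that initial + interior + final
    equals total_validation_runs exactly, avoiding the under-scheduling of
    a fixed ceil-based interval."""
    n_interior = max(0, total_validation_runs - 2)  # reserve initial and final
    if n_interior == 0:
        return []
    return [round(i * num_train_steps / (n_interior + 1)) for i in range(1, n_interior + 1)]

# num_train_steps=10, total_validation_runs=5 -> [2, 5, 8]:
# three interior runs plus initial/final = 5 total, as documented.
```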
Adds per_device_parallelism to SimpleDPOConfig so DPO training can use gradient accumulation on smaller-HBM TPUs like v6e-8 (31.25 GB/chip).

Changes:
- SimpleDPOConfig: expose per_device_parallelism (default -1 = auto)
- defaults.py: pass per_device_parallelism through to TrainerConfig
- New v6e-8 probe scripts using mirrored() for cross-region data access
- Logbook: full v6e feasibility analysis, memory accounting, throughput comparison (v6e-8 14.2s/step vs v5p-8 25.6s/step at batch=64)

Validated: v6e-8 with per_device=4 and 2x grad accum fits Llama 8B DPO LoRA (r=64, seq=4096, batch=64) in 31.25 GB HBM. Runs 1.8x faster than v5p-8 in wall-clock tokens/s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
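The relationship between batch size, per-device parallelism, and gradient accumulation implied above can be written out directly. A sketch assuming a v6e-8 exposes 8 devices (consistent with the validated 64 / (4 × 8) = 2x accumulation):

```python
def grad_accum_steps(train_batch_size: int, per_device_parallelism: int, num_devices: int) -> int:
    """Number of microbatch passes needed per optimizer step when each
    device processes `per_device_parallelism` examples at a time."""
    per_pass = per_device_parallelism * num_devices
    assert train_batch_size % per_pass == 0, "batch must divide evenly into microbatches"
    return train_batch_size // per_pass

# The validated v6e-8 config: batch=64, per_device=4, 8 chips -> 2x grad accum.
```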
…aint The executor parent runs on a CPU node in us-central1. Child jobs inherit region=us-central1 via Iris client auto-inheritance. But v6e-16 hardware only exists in europe-west4, us-east5, us-east1 — no groups match us-central1. Fix: set explicit regions on ResourceConfig to override the inherited constraint. Not an Iris bug, just a region mismatch. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
v6e-16 multi-host scheduling failed because child jobs inherit the parent executor's region (us-central1) where no v6e-16 exists. Fix: set explicit regions on ResourceConfig. Also generalized make_v6e_probe() to support both v6e-8 and v6e-16. Launched v6e-16 probes in europe-west4 and us-east5. Both accepted by controller — waiting for TPU capacity. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
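The region-override fix can be sketched as below. The `regions` field name and dict shape are assumptions standing in for the actual ResourceConfig API; the region list comes from the commit message:

```python
# Regions where v6e-16 hardware exists, per the commit message.
V6E16_REGIONS = ["europe-west4", "us-east5", "us-east1"]

def make_resource_config(tpu_type: str, inherited_region: str) -> dict:
    """Override the region inherited from the parent executor when the
    requested TPU type has no capacity there (illustrative field names)."""
    if tpu_type == "v6e-16" and inherited_region not in V6E16_REGIONS:
        # Without this override, the child job inherits us-central1 and
        # no scheduling group ever matches.
        return {"tpu_type": tpu_type, "regions": V6E16_REGIONS}
    return {"tpu_type": tpu_type, "regions": [inherited_region]}
```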
jax_init.py previously skipped jax.distributed.initialize() for ALL TPU jobs, assuming the TPU runtime handles multi-host discovery. But on Iris- managed multi-host TPU, the container doesn't inherit TPU_WORKER_HOSTNAMES, so JAX can't auto-discover other hosts. Fix: only skip for single-task TPU jobs. Multi-task TPU jobs (num_tasks>1) now fall through to the Iris coordinator dance, same as GPU multi-host. Bug found when v6e-16 (4 VMs) DPO training failed with: RuntimeError: multihost_broadcast_sync requires jax distributed client to be initialized Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
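The predicate change is small enough to state exactly. A sketch of the fixed skip condition (the function name is illustrative; the surrounding jax_init.py plumbing is omitted):

```python
def should_skip_distributed_init(is_tpu: bool, num_tasks: int) -> bool:
    """Before: skip jax.distributed.initialize() for ALL TPU jobs.
    After: only single-task TPU jobs can rely on the TPU runtime's
    auto-discovery; multi-task TPU jobs fall through to the Iris
    coordinator handshake, same as multi-host GPU."""
    return is_tpu and num_tasks == 1
```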
xprof trace (steps 5-15) confirms v6e-8 DPO LoRA bottleneck:
- 47% matmul, 22% splash attention, 12% elementwise, 9% buffer alloc
- Matmuls are memory-bound (472 FLOP/byte < 522 crossover)
- AllocateBuffer stalls (9%) from HBM pressure at 99% utilization
- Communication only 2.7% (not a bottleneck)
- Carry offloading to enable pd=8 would eliminate grad accum AND buffer allocation overhead — estimated ~1.5x additional speedup

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Uses gradient_checkpointing="offload" to move checkpoint saves from HBM to pinned host memory, freeing ~17 GB HBM. Should enable per_device=8 without gradient accumulation, eliminating 2x weight reads and the 9% buffer allocation stalls observed in xprof. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
JAX init fix worked but v6e-16 needs 38.97 GB/chip at pd=4 — 8 GB more than v6e-8 at the same pd=4, likely from DCN communication buffers. Deprioritizing v6e-16 in favor of v6e-8 with carry offloading. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
gradient_checkpointing="offload" triggers SIGABRT in XLA TPU compiler at async_dynamic_index_emitter.cc:584 — sublane alignment bug with save_and_offload DMA transfers on v6e with JAX 0.8.0. Not fixable in our code. Next: try TP=4+FSDP=2 or file JAX bug. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Carry offloading crashes on v6e (XLA codegen bug). Pivoting to TP=4+FSDP=2. Changes: - SimpleDPOConfig: add mesh field - defaults.py: pass mesh through to TrainerConfig - make_v6e_probe: accept mesh parameter - TP=4 probe scripts for all 3 v6e regions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
MeshConfig(axes={model: 4}) compiles but doesn't shard attention heads.
Full bf16[32,32,4096,4096] attention scores allocated on one chip (34 GB).
Need to map model axis to head dimension explicitly.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
GQA models use "kv_head" axis (not "heads") for attention scores. The default shared_mapping only maps "heads" → "model", so GQA attention scores weren't being sharded by TP. Adding "kv_head" → "model" fixes it. With 8 KV heads and model=4, each chip gets 2 KV heads. Attention scores go from bf16[8,4,4096,4096] total to bf16[2,4,4096,4096] per chip. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
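The mapping fix and its effect on the per-chip score shape can be sketched as follows. The `shared_mapping` dict here is illustrative of the axis-mapping structure, not the verbatim Levanter default:

```python
# Illustrative axis mapping: the default only maps "heads" -> "model", so
# GQA attention scores (which live on the "kv_head" axis) were replicated.
# Adding "kv_head" -> "model" shards them across the TP axis.
shared_mapping = {
    "embed": "data",
    "heads": "model",
    "kv_head": "model",  # the fix for GQA models
}

def kv_heads_per_chip(kv_heads: int, model_parallelism: int) -> int:
    """Per-chip KV heads once the kv_head axis maps onto the model axis."""
    assert kv_heads % model_parallelism == 0
    return kv_heads // model_parallelism

# 8 KV heads over model=4 -> 2 KV heads per chip; attention scores shrink
# from bf16[8,4,4096,4096] total to bf16[2,4,4096,4096] per chip.
```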
…on attempts Working config: v6e-8, pd=4, 2x grad accum, 14.2s/step, 18474 tokens/s. Both optimization paths (carry offload, TP=4) blocked by framework issues. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Land canonical LoRA-DPO on top of current main by unifying the train_dpo and LoRA config paths, wiring Marin executor/simple-config launches, and fixing merged HF export behavior. Add tests and docs for adapter-base references and validate the upstreamed path with Iris LoRA runs and downstream vLLM loading.
Fixes #4556