- Unify regular DPO and LoRA-DPO under the canonical `levanter.main.train_dpo` entrypoint.
- Keep `TrainerState.model` policy-only for both regular DPO and LoRA-DPO.
- Factor adapter behavior into a reusable layer so future LoRA-enabled training flows can share the same machinery.
- Represent DPO variants via config flips:
  - `adapter.type: none` + `reference.type: separate` for regular DPO
  - `adapter.type: lora` + `reference.type: adapter_base` for LoRA-DPO
- Keep `lora_dpo.py` only as a compatibility shim for legacy configs.
- Preserve regular DPO behavior closely enough to sanity-check against earlier runs, including the legacy `model_key` lineage.
- Update configs, tests, and docs to match the new runtime shape.
- Refactored `lib/levanter/src/levanter/main/train_dpo.py` to be the single canonical DPO runtime.
- Removed the old `DpoModel(policy, reference)` trainer-state shape.
- Added `DpoReferenceConfig` with:
  - `SeparateReferenceConfig`
  - `AdapterBaseReferenceConfig`
- Added validation rules for invalid combinations:
  - `reference.type=adapter_base` requires a non-`none` `adapter`
  - `adapter.type=lora` + `reference.type=adapter_base` requires `zero_init_b=true`
- Changed the loss function to accept only the policy model.
- Implemented separate frozen-reference loading outside `TrainerState`, captured by the loss closure.
- Implemented adapter-base reference lookup through the adapter runtime.
- Applied `jax.lax.stop_gradient` to reference log-probs so the reference path is explicitly non-differentiated in both modes.
- Added `lib/levanter/src/levanter/adaptation.py`.
- Introduced:
  - `AdaptationConfig`
  - `AdaptationExportConfig`
  - `NoAdaptationConfig`
  - `LoraAdaptationConfig`
- Centralized:
  - adapter application
  - trainable-filter derivation
  - adapter-base model view lookup
  - export-hook installation
- Reused low-level LoRA operations from `lib/levanter/src/levanter/lora.py` rather than using legacy `lora_lm.py` as the architectural template.
- Added `lib/levanter/src/levanter/main/model_init.py`.
- Factored out shared model loading/bootstrap logic for:
  - HF converter setup
  - tokenizer replacement
  - optional HF-config adoption
  - HF checkpoint load
  - Levanter checkpoint load
  - parameter casting/sharding
- Rewrote `lib/levanter/src/levanter/main/lora_dpo.py` into a translation wrapper.
- Kept the legacy `LoraDpoConfig` surface for old config files.
- Translates legacy LoRA-DPO configs into canonical `TrainDpoConfig` with:
  - `adapter=LoraAdaptationConfig(...)`
  - `reference=AdapterBaseReferenceConfig()`
- Forwards execution into canonical `train_dpo.main`.
- Updated `experiments/defaults.py` so `default_dpo(...)` constructs canonical `TrainDpoConfig(reference=SeparateReferenceConfig(...))`.
- Updated canonical DPO YAML configs under `lib/levanter/config/` to use nested `adapter`/`reference` blocks instead of top-level `reference_model_path`/`reference_is_hf`.
- Left legacy `lib/levanter/config/dpo/lora_dpo_*` YAMLs on the compatibility path intentionally, so old LoRA-DPO configs still route through the shim.
- Updated `lib/levanter/tests/test_dpo.py` to match the new architecture.
- Removed old tests that assumed `DpoModel` lived in trainer state.
- Added tests for:
  - policy-only `TrainerState`
  - invalid adapter/reference combinations
  - canonical config parsing for `adapter.type: none`
  - canonical config parsing for `adapter.type: lora`
  - legacy `LoraDpoConfig` translation
- Kept existing `lib/levanter/tests/test_lora_dpo.py` passing against the refactor.
- Replaced brittle parse-from-repo-config tests with minimal temp YAML fixtures after an initial failure exposed unrelated data-config parsing fields.
- Updated `lib/levanter/docs/guides/DPO-Training.md` to describe:
  - canonical `train_dpo.py`
  - nested `adapter`/`reference` config
  - policy-only trainer state
  - separate-reference vs adapter-base reference behavior
- Updated `lib/levanter/docs/guides/LoRA-DPO-Training.md` to describe:
  - canonical `train_dpo.py` usage
  - `adapter.type: lora` + `reference.type: adapter_base`
  - the explicit `zero_init_b: true` requirement
  - legacy `lora_dpo.py` status as compatibility-only
After the initial refactor, a follow-up review requested two concrete changes.
- Changed the new `logger.info(...)` calls in `lib/levanter/src/levanter/main/train_dpo.py` back to f-strings.
- Restored the legacy full-DPO top-level split shape in canonical `train_dpo.py`:
  - `data_key, adapter_key, model_key, training_key = split(PRNGKey(seed), 4)`
  - this intentionally repurposes the old unused loader-key slot as `adapter_key`
- Added `_derive_training_keys(seed)` to preserve the old regular-DPO policy key lineage:
  - `policy_key = split(model_key)[0]`
- Used `model_key` as the separate-reference checkpoint shape key so non-HF separate references follow the old regular-DPO path more closely.
- Added a regression test in `lib/levanter/tests/test_dpo.py` to verify that the derived `data_key`, `model_key`, `policy_key`, and `training_key` match the legacy full-DPO derivation.
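The preserved key lineage can be sketched with plain `jax.random`; the `_derive_training_keys` name and exact signature in the repo are assumptions here, only the split shape follows the bullets above.

```python
import jax

def derive_training_keys(seed: int):
    # Legacy full-DPO top-level split shape. The second slot was the old
    # unused loader key, now repurposed as adapter_key.
    data_key, adapter_key, model_key, training_key = jax.random.split(
        jax.random.PRNGKey(seed), 4
    )
    # Legacy regular-DPO policy key lineage: first child of model_key.
    policy_key = jax.random.split(model_key)[0]
    return data_key, adapter_key, model_key, policy_key, training_key

keys = derive_training_keys(0)
```

Because every key is derived deterministically from the seed, a regression test only has to compare these outputs against the values produced by the legacy derivation.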
- `inference_mode(...)` does not stop gradients; it flips modules with an `inference` flag into eval behavior, typically relevant for dropout-like modules.
- `jax.lax.stop_gradient(...)` is still needed for the reference path when the reference can be derived from the policy model itself.
- There is no `haliax.stop_gradient` helper in the current Haliax version used here, so `jax.lax.stop_gradient(...)` is the direct primitive.
- `SeparateReferenceModelProvider.model_for(policy_model)` takes `policy_model` only to share one interface with the adapter-base provider. In the separate-reference case it is intentionally ignored.
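A toy example (not the repo's loss) of why the explicit `jax.lax.stop_gradient` matters when the reference quantity is computed from the same parameters as the policy term:

```python
import jax
import jax.numpy as jnp

def dpo_like_loss(theta, x):
    # Policy term: differentiated normally.
    logp_policy = -jnp.sum((x - theta) ** 2)
    # Reference term derived from the same parameters: without stop_gradient
    # the two terms would cancel and the gradient would be identically zero.
    logp_ref = jax.lax.stop_gradient(-jnp.sum((x - theta) ** 2))
    return logp_policy - logp_ref

# Gradient comes only from the policy term:
# d/dtheta[-(x - theta)^2] = 2 * (x - theta) = 4.0 at theta=1, x=3
grad = jax.grad(dpo_like_loss)(jnp.array(1.0), jnp.array(3.0))
```

The loss value itself is zero either way; only the gradient distinguishes the two paths, which is exactly why the non-differentiated reference must be marked explicitly.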
- Ran targeted syntax verification on changed Python files during implementation.
- Ran targeted DPO tests multiple times during the refactor.
- Final targeted test result:
  - `uv run --project lib/levanter --group test python -m pytest lib/levanter/tests/test_dpo.py lib/levanter/tests/test_lora_dpo.py`
  - Result: `27 passed, 1 skipped`
- Ran repo-required targeted lint/type/style checks with `./infra/pre-commit.py --fix ...`
  - Final result: passed on the touched files
- I did not migrate `lora_lm.py` to the new adaptation layer.
- I did not remove legacy LoRA-DPO YAMLs; they still intentionally target the compatibility shim.
- I did not run a full training job end-to-end.
- I did not update every `SimpleDPOConfig` call site in `experiments/`; the executor defaults path now emits canonical `TrainDpoConfig`, but the higher-level simple config surface still uses `reference_model_path`/`reference_is_hf`.
- The original refactor goal "`TrainerState.model` is policy-only for both regular DPO and LoRA-DPO" did not survive a real multihost TPU run.
- For `reference.type=separate`, I changed regular DPO to load the frozen reference once and close over it in the loss function instead of storing it in trainer state.
- On multihost TPU, that closed over a sharded `jax.Array` spanning non-addressable devices.
- JAX rejected that during lowering with:
  - `RuntimeError: Closing over jax.Array that spans non-addressable (non process local) devices is not allowed.`
- This was not an Iris scheduling problem and not a W&B naming problem. It was a real training-runtime bug in the new canonical DPO code path.
- The old regular DPO script worked because the frozen reference lived inside `state.model`, so it was passed into the compiled train step as an argument instead of being captured as a constant.
- Because I launched a batch of relaunches before catching this, all of the `new_dpo` regular-DPO reruns failed at the same compile/lowering boundary.
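A minimal single-host illustration of the two shapes (hypothetical functions, not the repo's code; on one host both versions run fine, and the closure form only fails once the captured array is sharded across non-addressable devices):

```python
import jax
import jax.numpy as jnp

def make_loss_with_closure(reference):
    # Anti-pattern: `reference` is baked into the compiled function as a
    # constant. On multihost TPU, capturing a sharded jax.Array this way
    # raises the "non-addressable devices" RuntimeError during lowering.
    @jax.jit
    def loss(policy, batch):
        return jnp.sum((policy - reference) * batch)
    return loss

@jax.jit
def loss_with_argument(policy, reference, batch):
    # Fix shape: the reference enters as a traced argument (e.g. carried in
    # state.model), so JAX shards it like any other input.
    return jnp.sum((policy - reference) * batch)
```

Both compute the same value; the difference is purely whether the reference crosses the jit boundary as a constant or as an argument.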
- I relaunched the central1 `lr7.5e-7` sweep members before validating a real multihost TPU training run of the refactor.
- I also had east5 `new_dpo` relaunches still active while debugging the central1 failure path.
- After the failure was confirmed, I killed the still-running east5 sibling reruns so they would stop consuming TPU capacity:
  - `new_dpo_bloom_speceval_v2_marin_instruct_beta0.1_seed2`
  - `new_dpo_bloom_speceval_v2_marin_instruct_beta0.01_seed1`
  - `new_dpo_bloom_speceval_v2_marin_instruct_beta0.01_seed0`
  - `new_dpo_bloom_speceval_v2_marin_instruct_beta0.1_seed1`
- I then used `/ahmed/new_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-7_seed0` as the single debug specimen.
- When relaunching the fixed job in two regions in parallel, I had to stop reusing the same run id. Parallel east5/central1 launches need distinct W&B ids and distinct checkpoint roots, or they collide.
- Restored `DpoModel(policy, reference)` in `lib/levanter/src/levanter/main/train_dpo.py` for `reference.type=separate`.
- Kept policy-only state for `reference.type=adapter_base`, so LoRA-DPO still uses the unified adapter-base path.
- Updated the loss function to accept:
  - `DpoModel` for separate-reference regular DPO
  - `LmHeadModel` for adapter-base LoRA-DPO
- Added `_load_separate_reference_model(...)` for the explicit separate-reference load path.
- Added `_install_separate_reference_export_hooks(...)` so the separate-reference path still exports only the policy model.
- Kept the adapter-based export path for LoRA unchanged.
- Ran:
  - `./infra/pre-commit.py --fix lib/levanter/src/levanter/main/train_dpo.py lib/levanter/tests/test_dpo.py`
  - `uv run --project lib/levanter --group test python -m pytest lib/levanter/tests/test_dpo.py lib/levanter/tests/test_lora_dpo.py`
- Final result: `29 passed, 1 skipped`
- Added a regression test in `lib/levanter/tests/test_dpo.py` to assert that the separate reference is marked non-saveable / non-trainable in the mask.
- Relaunched only the single target experiment in parallel across two regions with distinct ids:
  - central1: `/ahmed/new_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-7_seed0-c1`
  - east5: `/ahmed/new_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-7_seed0-e5`
- Both use the same training hyperparameters as the original `beta0.1_seed0_lr7.5e-7` sweep member.
- Both add region / TPU tags so the hardware is visible later in W&B:
  - `new_dpo`
  - `v5p-32`
  - region tag (`us-central1-a` or `us-east5-a`)
- Current state at time of writing:
  - central1 job is `PENDING`, waiting for `tpu_v5p_32-us-central1-a` capacity
  - east5 job is `RUNNING`
- The abstraction split is still good at the config level:
  - regular DPO: `adapter.type=none`, `reference.type=separate`
  - LoRA-DPO: `adapter.type=lora`, `reference.type=adapter_base`
- The runtime shape is not fully symmetric:
  - regular DPO currently needs `DpoModel(policy, reference)` in trainer state for multihost safety
  - LoRA-DPO can stay policy-only because the reference is derived from the policy model inside the step
- So the original "policy-only for both modes" objective should be considered disproven by experiment.
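For concreteness, the two config-level shapes might look like the following YAML (a hedged sketch; the exact field names and nesting in the repo's configs may differ):

```yaml
# Regular DPO: no adapter, frozen reference loaded separately.
adapter:
  type: none
reference:
  type: separate

---
# LoRA-DPO: LoRA adapter, reference recovered from the adapter's base view.
adapter:
  type: lora
  zero_init_b: true   # required when reference.type is adapter_base
reference:
  type: adapter_base
```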
- Watch the east5 relaunch until it gets past the old failure point:
- JAX lowering / first train-step compile
- first actual optimization step
- If the east5 run clears that boundary, keep it as the validation run and decide whether the queued central1 duplicate is still useful.
- If the east5 run fails again, pull logs immediately and compare against the old closure error to see whether there is a second bug behind the first one.
- If the central1 run later gets capacity and starts cleanly, compare behavior across regions before relaunching any broader DPO sweep.
- Only after one regular DPO run is confirmed stable on TPU should the rest of the failed `new_dpo` regular-DPO reruns be resubmitted.
- Both duplicate validation jobs are now genuinely running, not just scheduler-level `RUNNING`:
  - `/ahmed/new_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-7_seed0-c1`
  - `/ahmed/new_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-7_seed0-e5`
- Iris shows both jobs with:
  - `JOB_STATE_RUNNING`
  - `task_state_counts.running = 4`
  - empty `pending_reason`
- Both jobs loaded the model weights successfully.
- Both jobs reached:
  - `First batch loaded ..., starting first train step (includes JIT compilation)...`
  - `Tracing train_step for jaxpr...`
  - `Lowering train_step to HLO...`
  - `Lowered train_step ...`
- Both jobs completed the first optimization step and entered the step-0 eval:
  - log line shape: `Progress on: train 1.00it/850it`
- This is the exact boundary where the earlier regular-DPO refactor was failing, so the separate-reference fix is now validated in live TPU execution.
- `-c1` is currently ahead:
  - reached about `106/184` eval batches
  - eval loss is still `0.693` (about ln 2, the expected DPO loss at step 0, before any learning)
- `-e5` is behind but healthy:
  - reached about `68/184` eval batches
  - eval loss is also `0.693`
- W&B can still appear to show only system metrics even while the jobs are healthy.
- The reason is the trainer ordering:
- step-0 eval runs before the per-step tracker log flush
- this eval is fairly large (`184` batches)
- So the jobs can be actively training/evaluating while W&B still has little or no `train/*` / `eval/*` history visible.
- This is a logging-order artifact, not evidence of a hang.
- The duplicate runs have now served their original debugging purpose:
- they proved that the multihost compile/lowering regression is fixed
- The remaining reason to keep both is redundancy across regions.
- Practical recommendation:
  - keep watching until `-c1` finishes the initial eval and emits real W&B metrics
  - then keep `-c1` and kill `-e5`, unless cross-region duplication is still desired overnight
- Get `experiments/eval_dpo.py` into a runnable state so DPO eval can be profiled directly without burning 400+ training steps just to trigger an eval hook.
- The checked-in `experiments/eval_dpo.py` did not match current Levanter APIs.
- It had broken model-init calls, broken cache loading, and an incomplete profiler / tracker lifecycle.
- It also diverged from the standard standalone eval pattern enough that even a superficially fixed script would not necessarily profile the real Levanter eval path.
- Updated `experiments/eval_dpo.py` to use current model-init helpers:
  - `prepare_model_init_context(...)`
  - `load_model_from_source(...)`
- Switched validation-cache loading to the current `TreeCache.load(..., exemplar=..., options=...)` pattern.
- Reused existing Bloom SpecEval v2 experiment config, tokenizer, and model definitions instead of keeping parallel hardcoded copies.
- Switched the profiling lifecycle to `profile_ctx(...)`.
- Explicitly finish the tracker on exit for remote worker safety.
Observed remote failure during warmup:
- fused next-token loss crashed with `Axis embed present in both specs with different sizes`
  - the local shard looked like `embed(256)` while the model contract axis was `embed(4096)`
- This was not a dataset-shape bug and not an HF-config mismatch.
- The standalone script had left `parameter_axis_mapping` active across the eval path.
- In this mesh config, `parameter_axis_mapping` includes `embed -> data`.
- That is correct for parameter layout, but wrong for executing the eval kernel.
- The fused CE path uses the current thread-local mapping when it enters `shard_map(...)`, so eval wound up executing under parameter sharding instead of compute sharding.
- Removed the outer `hax.axis_mapping(parameter_axis_mapping)` context from `experiments/eval_dpo.py`.
- Changed the eval kernel from plain `eqx.filter_jit` to `@hax.named_jit(axis_resources=compute_axis_mapping)`.
- Kept model loading under `parameter_axis_mapping`.
Resulting model:
- parameters are still loaded/sharded the same way as training expects
- eval batches now execute under compute sharding, which should avoid the `embed(256)` vs `embed(4096)` fused-loss conflict
- `uv run python -m py_compile experiments/eval_dpo.py` passed
- `uv run python - <<'PY' import experiments.eval_dpo as m; print(...) PY` (import smoke test) passed
- `experiments/eval_dpo.py` should now be much closer to the real standalone eval shape used elsewhere in Levanter.
- The next material check is to relaunch the Iris job and confirm the warmup batches clear the previous fused-loss crash.
What looks high-confidence:
- Caching frozen-reference log-probs is the cleanest eval optimization.
- It removes 2 of the 4 forward passes in DPO eval without requiring sharding redesign.
What looks plausible but not yet proven:
- An eval-specific replicated/data-parallel model layout may reduce communication overhead.
- That is not a simple `axis_mapping` flip. It would require a second eval layout for the model itself, not just a different thread-local mapping.
- The right order is still:
- cache reference log-probs
- re-profile
- only then decide whether an eval-specific layout is worth building
- Remove the two frozen-reference forward passes from every DPO eval by precomputing:
  - `logp_ref_chosen`
  - `logp_ref_rejected`
- Keep the cache reusable across runs.
- Avoid introducing meaningful runtime I/O overhead during eval.
- Persist the cache to GCS once.
- Load the full cache into host RAM at job startup.
- Join cached values to eval examples by dataset index.
This is the right split because the cached payload is tiny but expensive to compute.
For the Bloom SpecEval v2 validation set:
- the cache only needs two scalar floats per example
- DPO already reduces over position via `_logp_sum(...)`
- `PreferenceChatLmDatasetFormat` defaults to `pack=False`, so the validation set is naturally aligned 1:1 with cache rows
That means the runtime working set is on the order of a few hundred KB, not GB.
Per validation example, store:
{
"logp_ref_chosen": np.float32(...),
"logp_ref_rejected": np.float32(...),
}

Do not cache per-token reference log-probs for this optimization. They are unnecessary for DPO eval and would only complicate storage and shape handling.
Use a small sidecar TreeCache, not an ad hoc JSON/NPY blob.
Reasons:
- it reuses the repo’s existing cache machinery
- it naturally lives on GCS
- it gives us a ledger and metadata validation
- it is easy to load back as an `AsyncDataset`
Implementation sketch:
exemplar = {
"logp_ref_chosen": np.zeros((), dtype=np.float32),
"logp_ref_rejected": np.zeros((), dtype=np.float32),
}
with SerialCacheWriter(cache_dir, exemplar, metadata=CacheMetadata(...)) as writer:
    writer.write_batch(batch_dict)

The cache must be invalidated whenever the frozen reference semantics could change.
Minimum key inputs:
- validation cache path
- reference model path / HF ref
- immutable reference revision if available
- tokenizer id
- sequence length
- preference-format fields that affect tokenization/loss semantics:
  - `chosen_field`
  - `rejected_field`
  - `mask_user_turns`
  - `slice_strategy`
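One hedged way to turn those inputs into a stable cache key (the helper name and field set are illustrative, not the repo's API):

```python
import hashlib
import json

def stable_cache_key(**key_inputs) -> str:
    # Canonical JSON (sorted keys) so identical inputs always hash identically,
    # regardless of argument order.
    blob = json.dumps(key_inputs, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:16]

key = stable_cache_key(
    val_cache="gs://.../validation",  # validation cache path (elided here)
    reference_model="marin-community/marin-8b-instruct",
    reference_revision=None,          # immutable reference revision if available
    tokenizer="marin_tokenizer",
    seq_len=4096,
    mask_user_turns=True,
    slice_strategy="raise",
)
```

Any change to a key input produces a different key, so a stale sidecar directory is simply never found rather than silently reused.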
Suggested metadata shape:
CacheMetadata(
preprocessor_metadata={
"kind": "dpo_reference_logprobs_v1",
"val_cache": "...",
"reference_model": "...",
"reference_revision": "...",
"tokenizer": "...",
"seq_len": 4096,
"mask_user_turns": True,
"slice_strategy": "raise",
}
)

Use a sibling GCS path near the existing validation token cache, not a W&B artifact-only flow.
Preferred pattern:
<validation-cache-root>/reference_logprobs/<stable-cache-key>/
For the current experiment, that means something under the same marin-us-central1 storage hierarchy as:
gs://marin-us-central1/tokenized/bloom_speceval_v2_val_prefs_marin_tokenizer-a06ae8/validation
This keeps the data close to the source cache and avoids inventing a second distribution mechanism for something that is fundamentally dataset-derived.
Do not read the sidecar cache remotely batch-by-batch during eval.
Instead:
- load the sidecar cache once at startup
- materialize the two float arrays in CPU memory
- map validation examples to include cached reference values by index
- during eval, compute only:
  - `logp_pi_chosen`
  - `logp_pi_rejected`
- combine with cached reference scalars to produce the normal DPO metrics
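The combine step can be sketched with the standard DPO eval quantities; `beta` and the metric names follow the usual DPO formulation and may not match the repo's exact helpers:

```python
import math

def dpo_eval_metrics(logp_pi_c, logp_pi_r, logp_ref_c, logp_ref_r, beta=0.1):
    # Fresh policy log-probs combined with cached reference scalars.
    chosen_reward = beta * (logp_pi_c - logp_ref_c)
    rejected_reward = beta * (logp_pi_r - logp_ref_r)
    margin = chosen_reward - rejected_reward
    # Standard DPO loss: -log sigmoid(margin).
    loss = -math.log(1.0 / (1.0 + math.exp(-margin)))
    accuracy = 1.0 if margin > 0 else 0.0
    return loss, accuracy, chosen_reward, rejected_reward
```

When policy and reference agree, the margin is 0 and the loss is log 2 ≈ 0.693, which matches the expected step-0 value.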
The index-based join is straightforward because `MappedAsyncDataset` receives the item index.
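A hedged sketch of the startup-time join; `sidecar_rows` stands in for rows read from the sidecar cache, and `build_reference_lookup` is an illustrative helper, not the repo's `MappedAsyncDataset` API:

```python
import numpy as np

def build_reference_lookup(sidecar_rows, val_len):
    # Materialize the two scalar-per-example arrays in host RAM once.
    chosen = np.asarray([r["logp_ref_chosen"] for r in sidecar_rows], np.float32)
    rejected = np.asarray([r["logp_ref_rejected"] for r in sidecar_rows], np.float32)
    # A stale or truncated cache should fail loudly, not silently misjoin.
    if len(chosen) != val_len:
        raise ValueError("sidecar cache length != validation dataset length")
    # The join key is simply the dataset item index.
    return lambda i: (chosen[i], rejected[i])

lookup = build_reference_lookup(
    [{"logp_ref_chosen": -1.5, "logp_ref_rejected": -2.5}], val_len=1
)
```

At ~8 bytes per example, even the full 23,552-example validation set stays comfortably in host memory.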
- Add a small cache-builder helper, probably in a DPO-focused module rather than inside the one-off experiment script.
- Add a lightweight cached-example type for eval-only use.
- Add a helper that:
- tries to load the reference-logprob sidecar cache
- builds it if missing
- returns host arrays or a mapped dataset carrying cached scalars
- Teach DPO eval to use cached reference scalars when present.
- Leave training loss unchanged for now. Only the eval hook path should use the cache in phase 1.
- Implement the cache only for evaluation.
- Keep the existing policy path unchanged.
- Use the one-off `experiments/eval_dpo.py` script to validate correctness and speedup.
- Wire cached reference values into the main DPO eval hook path in `train_dpo.py`.
- Re-profile a training run with eval enabled.
- Decide whether more aggressive eval-layout changes are still worth doing after the 2x reference-pass reduction.
Correctness checks:
- compare uncached vs cached eval loss/metrics on a bounded number of batches
- assert numerical agreement within a tight tolerance
- verify cache length exactly matches validation dataset length
Performance checks:
- re-run the standalone eval profile
- verify that `reference_chosen` and `reference_rejected` no longer dominate the trace
- compare total eval wall time before/after
- Do not cache per-token log-probs.
- Do not try to make the cache work for packed preference datasets yet.
- Do not combine this with eval-specific replicated model layout in the same patch.
- Do not redesign checkpointing in the same change.
- DPO eval should drop from 4 forward passes per example to 2.
- The cache should be reusable across runs and cheap to load.
- After this lands, we will have a much cleaner answer to whether remaining eval time is still dominated by communication or by local model compute / HBM bandwidth.
- Claude finished the reference log-prob caching path for standalone DPO eval and validated it on real full-validation runs.
- The cached path appears correct and gives a reproducible ~2x eval speedup by removing the two frozen-reference forward passes.
- Claude's follow-up idea of replicated non-FSDP eval sharding is plausible, but the proposed implementation seam was too optimistic.
- The better next experiment is: keep model loading unchanged, run cached eval only, explicitly reshard the loaded policy model once into an eval-only replicated layout, then measure steady-state batches and communication.
Claude's most important completed work was on standalone DPO evaluation, not the training-path eval hooks.
Implemented in `experiments/eval_dpo.py`:
- three eval modes:
  - `--mode uncached`
  - `--mode build`
  - `--mode cached`
- a deterministic reference-logprob sidecar cache path
- cache build logic for:
  - `logp_ref_chosen`
  - `logp_ref_rejected`
- a cached eval path that avoids loading or running the reference model in cached mode
- a cached-example wrapper compatible with Levanter/JAX pytree expectations
The core design was sound:
- build reference log-probs once for the full validation set
- persist them as a tiny sidecar cache
- join by dataset index at eval time
- reduce DPO eval from 4 forward passes to 2
- Cached loss compute dropped from `26.9s` to `13.4s`.
- This established the expected 2x compute reduction.
- Uncached total eval: `552.1s`
- Cached total eval: `279.1s`
- Measured speedup: `1.98x`
Policy:
- trained step-849 checkpoint from the new DPO pipeline
Reference:
- `marin-community/marin-8b-instruct`
Validation set:
- full `bloom_speceval_v2_val`
- `23,552` examples
- `46` batches at eval parallelism `32`
Reported metric agreement:
- `loss`: exact match
- `dpo_accuracy`: exact match
- `dpo_chosen_reward`: exact match
- `dpo_rejected_reward`: exact match
- `dpo_margin_policy`: exact match
- `dpo_margin_ref`: exact match
Reported performance at full scale:
- uncached total eval: about `20.7 min`
- cached total eval: about `10.4 min`
- measured speedup: `1.98x`
Projected per-run savings from Claude's numbers:
- roughly `41 min` to `48 min` saved per full training run, depending on which baseline is used
These are worth preserving because they explain the multihost failure modes we already hit:
- use `process_allgather(..., tiled=True)` before converting sharded outputs to numpy
- only one host should write the sidecar cache
- gate cache writes with `jax.process_index() == 0`
- synchronize after cache writes so non-writers do not race ahead
- use an `eqx.Module` for cached examples instead of a plain dataclass so DataLoader/JIT can batch them
- treat `TreeCache.get_batch_sync(...)` as returning a list of dicts, not a dict
- remove trailing `(1,)` dimensions from cached scalar values before wrapping them with `hax.named(...)`
- run eval/loss under `compute_axis_mapping`, not under `parameter_axis_mapping`
- avoid losing results behind multihost barrier/profiler cleanup ordering
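The single-writer bullets can be sketched as follows; `write_cache` is a hypothetical stand-in for the real sidecar cache writer, and `sync_global_devices` is effectively a no-op on a single host:

```python
import jax
from jax.experimental import multihost_utils

def write_sidecar_cache(rows, write_cache):
    # Only host 0 writes the sidecar cache.
    if jax.process_index() == 0:
        write_cache(rows)
    # All hosts block here so non-writers cannot race ahead and read a
    # partially written cache.
    multihost_utils.sync_global_devices("sidecar_cache_written")

written = []
write_sidecar_cache([0.1, 0.2], written.extend)
```

The barrier name is arbitrary but should be unique per synchronization point so mismatched barriers across hosts are easier to diagnose.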
The standalone eval path is much further along than the main DPO training path.
At the current workspace state:
- `experiments/eval_dpo.py` exists and compiles locally
- `experiments/dpo_bloom_speceval_v2_profile.py` exists and compiles locally
- `train_dpo.py` still computes reference log-probs online inside the loss
- cached-reference eval does not yet appear wired into the main `train_dpo.py` eval hook path
- the worktree is still dirty and includes untracked standalone eval/profiling files
Relevant local verification:
- `uv run python -m py_compile experiments/eval_dpo.py experiments/dpo_bloom_speceval_v2_profile.py` passed
- pytest was not available in this environment, so I did not complete a local test run
The high-level idea is plausible:
- load or place the eval-time model so weights are replicated
- keep batches sharded across the data axis
- remove the FSDP parameter-layout transition during eval
However, the logged implementation proposal was too optimistic in three ways.
Claude proposed making the experiment by changing:

load_model_from_source(..., parameter_axis_mapping=parameter_axis_mapping)

to:

load_model_from_source(..., parameter_axis_mapping=compute_axis_mapping)

This is not just an eval-layout change. It changes the HF/GCS deserialization path too, because `load_model_from_source(...)` passes that mapping into the loader itself.
That means the experiment would conflate:
- load-time sharding behavior
- deserialization behavior
- runtime eval layout
If the goal is to test whether the FSDP parameter-to-compute layout transition is the bottleneck, that is the wrong first seam to touch.
The current HF/GCS path builds a full state dict on each host before sharding it into the model. So the difference is not simply:
- FSDP load = each chip loads only its shard
- replicated load = each chip loads the full model
For the current Levanter HF path, both cases already involve heavy per-host state-dict loading before device placement. Replicated eval still may cost more HBM, but the loading story is not as favorable to the current FSDP path as the plan suggested.
Claude had already recorded the key caveat earlier:
- we had collective op counts
- we did not yet have per-op durations
- the collectives might be overlapped with compute
So a forecast like another 1.5x to 3x on top of caching was speculative. The right framing is:
- communication-free eval might help
- we do not yet know how much of the remaining 13.4s/batch is actually recoverable
Answer the narrow question first:
- after reference-logprob caching, how much steady-state eval time is still caused by the FSDP parameter-layout transition?
Do not combine this with training-hook integration, checkpoint serialization changes, or other cleanup work.
Run this experiment only in cached eval mode.
Reasons:
- cached eval only needs the policy model
- that cuts replicated-model HBM requirements roughly in half versus uncached/build mode
- it isolates the layout question from the already-solved reference-forward-pass problem
Keep model loading unchanged at first:
policy_model = load_model_from_source(
...,
parameter_axis_mapping=parameter_axis_mapping,
...
)

Then explicitly reshard once, after load, into an eval-only layout:

policy_model = hax.shard_with_axis_mapping(policy_model, compute_axis_mapping)

or the equivalent `hax.shard(...)` form.
Why this is the better first test:
- it isolates runtime/layout effects from loader effects
- it preserves the currently-working HF/GCS load path
- it matches the actual question we care about: whether eval gets faster when the model already lives in the compute layout
For both baselines:
- cached eval with normal FSDP-loaded model
- cached eval with one-time explicit reshard to replicated compute layout
collect these separately:
- policy model load time
- one-time reshard time
- warmup / first-compile time
- steady-state loss time over:
- 20 batches first
- then full 46-batch validation if promising
- total eval wall time
- peak HBM / OOM behavior
Do not treat "single end-to-end wall time" as the only metric. The compile and reshard one-time costs need to be split out.
Call the experiment successful only if all of the following hold:
- cached metrics still match the FSDP-cached baseline
- the trace shows parameter-layout collectives disappear or materially shrink
- steady-state loss time improves enough to matter after excluding one-time reshard/compile cost
- HBM remains stable at the chosen eval parallelism
If collectives disappear but steady-state batches barely improve, stop. That means the remaining bottleneck is local compute, memory bandwidth, or something else.
- freeze a clean baseline for cached mode at:
  - `per_device_eval_parallelism=32`
  - `20` batches
  - then full validation if needed
- record:
- load time
- warmup time
- steady-state loss time
- total eval time
- keep the same model load path
- reshard the loaded policy model once into compute mapping
- repeat the exact same measurement protocol
- try `per_device_eval_parallelism=64`
- treat this as a separate capacity/throughput experiment, not part of the original sharding question
If 64 fails or becomes unstable:
- keep `32`
- preserve the simpler replicated-eval result if it helps
The profile should distinguish:
- policy forward
- reference forward
- eval loop load time
- eval loop loss time
- one-time warmup/compile
This is already partly in place from the current local changes:
- named scopes in `train_dpo.py`
- timing metrics in `levanter.callbacks.eval_loss_loop`
- the standalone profiling script for steady-state DPO eval
The experiment should look roughly like:
policy_model = load_model_from_source(
context=model_context,
Vocab=Vocab,
model_key=policy_key,
parameter_axis_mapping=parameter_axis_mapping,
compute_dtype=trainer_config.mp.compute_dtype,
cast_to_param=trainer_config.mp.cast_to_param,
hf_ref=model_name,
)
if args.eval_param_layout == "replicated":
t_reshard = time.time()
policy_model = hax.shard_with_axis_mapping(policy_model, compute_axis_mapping)
    logger.info("Resharded policy model for eval in %.1fs", time.time() - t_reshard)

with an explicit CLI switch, for example:
--eval_param_layout fsdp
--eval_param_layout replicated
That keeps the experiment reversible and makes the comparison honest.
- One-time reshard may still be non-trivial in cost.
- JIT recompilation is expected because input/output shardings change.
- HBM headroom at `per_device=64` is still speculative.
- If the loader or model carries arrays whose metadata assumes parameter layout in a subtle way, the explicit post-load reshard may surface another shape/sharding bug.
These are acceptable risks because they are all directly tied to the hypothesis being tested.
The next best step is not "switch `load_model_from_source(...)` to `compute_axis_mapping` and hope."
The next best step is:
- keep the working loader path
- run cached eval only
- explicitly reshard the already-loaded policy model once
- measure steady-state batches and communication before trying higher batch sizes
If that experiment shows a real gain, then we can decide whether it is worth teaching the loader or training eval hooks about an explicit eval-only replicated layout.
We ran the standalone replicated cached-eval experiment to completion on Iris.
Identifiers:
- W&B run: `lr15l9n2`
- W&B display name: `jumping-resonance-174`
- Iris job: `/ahmed/dpo-eval-cached-repl-v5p32-20260329-230854`

Run shape:
- mode: `cached`
- eval parameter layout: `replicated`
- hardware: `v5p-32`
- devices: `16`
- hosts: `4`
- checkpoint: `step-849` from `new_dpo_v2_bloom_speceval_v2_marin_instruct_beta0.1_-7_seed2-947c5d`
Controller outcome:
- Iris job finished `JOB_STATE_SUCCEEDED`
- `4/4` tasks succeeded
- no preemptions
- no worker failures
Measured result:
- total eval time: about `592.2s` (`9.9 min`)
- warmup time: about `24.1s`
- batches run: `46`
- data load time: about `8.2s` to `8.8s`
- loss compute time: about `583.3s` to `583.9s`
- average time per batch: about `12.87s`
Metric summary from W&B:
- `eval_cached/loss`: `0.00551`
- `eval_cached/dpo_loss`: `0.00551`
- `eval_cached/dpo_accuracy`: `0.9958`
- `eval_cached/dpo_chosen_reward`: `2.5412`
- `eval_cached/dpo_rejected_reward`: `-11.9235`
- `eval_cached/dpo_margin_policy`: `-45.0787`
- `eval_cached/dpo_margin_ref`: `-189.7258`
Interpretation:
- the standalone replicated-layout path is real and stable on multihost TPU
- cached eval remains overwhelmingly compute-bound after caching
- data loading was only about `1.4%` to `1.5%` of total eval time
- the remaining time is dominated by loss compute, not by input loading
Most important practical conclusion:
- the end-to-end improvement versus the earlier cached full-validation result (`~10.4 min`) is only modest (`~9.9 min` here)
- that is directionally positive, but not large enough to justify more complexity right now
Decision:
- do not prioritize integrating eval-only replicated sharding into `train_dpo.py`
- do not spend more time teaching the loader or training eval hooks about replicated eval layout right now
- leave the standalone `experiments/eval_dpo.py` support in place as a profiling/benchmark tool
- defer this thread until eval speed becomes a top blocker again
In short: this experiment de-risked the idea, but it did not earn promotion into the mainline DPO path yet.
- Not for the Bloom SpecEval v2 Marin Instruct run family in this thread.
- I do not see any recorded Bloom SpecEval v2 LoRA-DPO training launch in:
- this Codex logbook
- Claude's parallel logbook
- the current Iris job list
- What we have today is code-path readiness, not experiment evidence:
  - canonical `levanter.main.train_dpo` supports `adapter.type=lora` with `reference.type=adapter_base`
  - the legacy `levanter.main.lora_dpo` wrapper still works for old configs
  - LoRA-DPO tests and docs are in place
  - existing LoRA YAMLs target Ultrafeedback / legacy sanity checks, not Bloom SpecEval v2
Target run:
- W&B: https://wandb.ai/marin-community/dpo/runs/new_dpo_v2_bloom_speceval_v2_marin_instruct_beta0.1_-7_seed2-947c5d
- Executor source: `experiments/sweep_dpo/beta0.1_seed2_lr7.5e-7.py`
Baseline knobs that should stay fixed for the first LoRA reproduction attempt:
- dataset: Bloom SpecEval v2 preference data (`bloom_openai_model_spec_v2_gpt41_vs_mixtral_opposite`)
- tokenizer: `marin-community/marin-tokenizer`
- base model: `marin-community/marin-8b-instruct`
- seed: `2`
- beta: `0.1`
- learning rate: `7.5e-7`
- train batch size: `128`
- num train steps: `850`
- train / max seq len: `4096`
- steps per eval: `200`
- hf export cadence: `200`
- hardware target: `v5p-32`
- memory target: `256GB`
- eval parallelism target: `32`
The only intended semantic change for the first reproduction run is:
- full-FT policy + separate frozen reference
- becomes
- LoRA policy + implicit adapter-base reference
- `experiments/defaults.default_dpo(...)` and `SimpleDPOConfig` currently only construct regular DPO with `reference.type=separate`.
- There is no Bloom SpecEval v2 experiment wrapper in `experiments/` that can already emit:
  - `adapter.type: lora`
  - `reference.type: adapter_base`
  - LoRA export settings like `merged_hf_save_path`
So the fastest and lowest-risk first LoRA experiment is:
- write a standalone canonical `TrainDpoConfig` YAML for Bloom SpecEval v2
- launch `python -m levanter.main.train_dpo` directly on Iris
- only after one successful run, decide whether to teach `SimpleDPOConfig` / `default_dpo` about LoRA for sweep parity
That avoids mixing "does LoRA-DPO train correctly here?" with "did we correctly refactor the executor config surface?"
Goal: reproduce the ...947c5d run as faithfully as possible while changing only the DPO parameterization.
Planned config delta relative to the seed-2 baseline:
```yaml
adapter:
  type: lora
  r: 64
  alpha: 64.0
  dropout: 0.0
  zero_init_b: true
  target_modules: null
reference:
  type: adapter_base
```

Everything else should match the baseline run unless a TPU-specific blocker appears.
Rationale for this first run shape:
- `zero_init_b: true` is mandatory for DPO because the step-0 policy must equal the implicit reference
- `r=64, alpha=64` is the current house default and matches the LoRA-DPO guide
- `target_modules: null` means all linear modules, which is the current recommended setting
- keeping `lr=7.5e-7` makes this a true reproduction attempt rather than an immediate retuning study
First run should be a dedicated canonical config, something like:
- config path: `lib/levanter/config/dpo/lora_dpo_bloom_speceval_v2_seed2_lr7.5e-7_central1.yaml`
- entrypoint: `uv run python -m levanter.main.train_dpo --config_path ...`
Suggested job command shape:
```
uv run iris --config lib/iris/examples/marin.yaml job run --no-wait \
  --extra marin:tpu \
  --tpu v5p-32 \
  --memory 256GB \
  --disk 50GB \
  --zone us-central1-a \
  --job-name lora-new-dpo-v2-bloom-s2-lr7p5e7 \
  -e WANDB_API_KEY "$WANDB_API_KEY" \
  -- python -m levanter.main.train_dpo \
  --config_path lib/levanter/config/dpo/lora_dpo_bloom_speceval_v2_seed2_lr7.5e-7_central1.yaml
```

- Use the same Bloom SpecEval v2 tokenized train/val caches as the full-DPO baseline.
- Keep `validation_split_fraction: null` so eval uses the explicit validation set rather than creating a new split.
- Keep `trainer.per_device_eval_parallelism: 32` for comparability with the current full-DPO runs.
- Keep `train_batch_size: 128` and `num_train_steps: 850` unchanged.
- Use `merged_hf_save_path` for the first run, not `peft_save_path`.
Why merged_hf_save_path first:
- older LoRA TPU configs in this repo explicitly disabled `peft_save_path` because of multihost serialization issues
- even if the new adaptation path may be better, the first Bloom SpecEval v2 LoRA run should minimize new failure modes
Call the first LoRA reproduction successful only if all of the following hold:
- it compiles and gets past the first train step on multihost TPU
- step-0 DPO loss is near `ln(2) ≈ 0.693`, not a large blown-up value
- no container OOM occurs at `v5p-32` / `256GB`
- W&B logs normal DPO metrics:
  - `loss`
  - `dpo_accuracy`
  - `dpo_margin_policy`
  - `dpo_margin_ref`
  - `dpo_chosen_reward`
  - `dpo_rejected_reward`
- throughput and eval cadence look sane relative to the regular-DPO baseline
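The `ln(2)` criterion follows directly from the DPO loss form. A minimal plain-Python check (assuming the standard DPO loss, not the actual Levanter implementation):

```python
import math

def dpo_loss(beta, pol_chosen, pol_rejected, ref_chosen, ref_rejected):
    # Standard DPO loss: -log(sigmoid(beta * (policy margin - reference margin))).
    margin = (pol_chosen - pol_rejected) - (ref_chosen - ref_rejected)
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))

# With zero_init_b=true, the step-0 LoRA policy equals the adapter-base
# reference, so the margins cancel and the loss is exactly ln(2) ~= 0.693,
# regardless of beta or the actual log-prob values.
loss = dpo_loss(beta=0.1, pol_chosen=-12.3, pol_rejected=-45.6,
                ref_chosen=-12.3, ref_rejected=-45.6)
print(round(loss, 3))  # 0.693
```

Any step-0 value far from `0.693` means policy and reference are not actually identical at initialization.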
If step-0 loss is badly wrong, that is a strong sign that the LoRA identity/reference assumption is broken, most likely:
- `zero_init_b` not applied
- wrong reference mode
- adapter modules not wired as expected
There is a real chance that a literal `lr=7.5e-7` LoRA run will learn too slowly. The LoRA-DPO guide in this branch recommends starting closer to `5e-6`.
So the plan should be:
- run the strict reproduction first at `7.5e-7`
- if it is stable but under-trains, launch a LoRA-tuned follow-up at `5e-6`
- keep the same:
- dataset
- seed
- beta
- batch size
- steps
- hardware
- adapter rank
That gives two distinct answers:
- strict reproduction: "what happens if we only swap in LoRA?"
- LoRA-tuned comparison: "what is the fairer LoRA baseline once the optimizer is adjusted?"
Do not start by extending the executor/sweep layer.
After the first LoRA run succeeds, the next cleanup step should be:
- add adapter/reference/export fields to `SimpleDPOConfig`
- teach `default_dpo(...)` to emit either regular DPO or LoRA-DPO
- add a proper `experiments/sweep_dpo/lora_beta0.1_seed2_lr7.5e-7.py`
That sequencing keeps the first experiment focused on model behavior instead of config refactoring.
Source read:
- This article provides strong evidence about LoRA behavior in supervised fine-tuning and policy-gradient RL settings.
- It is not a direct DPO paper.
- For our Bloom SpecEval v2 plan, the safest interpretation is:
- treat DPO as much closer to the article's supervised setting than to its RL setting
- carry over the supervised LoRA practices first
- mark any DPO-specific conclusions as inference, not established fact
The practices below are directly supported by the article:
- Apply LoRA to all layers, especially MLP/MoE layers.
- Attention-only LoRA underperformed materially.
- MLP-only was much better, and all-layer LoRA was the safest default.
- Do not judge LoRA from a single learning rate.
- The article explicitly swept LR for each condition before comparing LoRA to FullFT.
- For supervised-style training, LoRA's best LR is about 10x the FullFT LR.
  - The article reports an empirical multiplier of about `9.8x`.
- Large batch size can hurt LoRA more than it hurts FullFT.
- This effect appeared to be mostly independent of rank.
- Keep the standard LoRA parametrization unless there is evidence otherwise.
- `alpha / r` scaling
- zero-init on `B`
- standard random init for `A`
- same LR for `A` and `B`
- the authors report they could not improve on this basic setup
- Rank is mainly a capacity knob, not a cure for bad optimization settings.
- If LoRA is capacity-constrained, training falls off the FullFT curve.
- But larger rank does not remove the large-batch penalty.
These are my inferences from the article, not direct claims made there:
- Our Bloom SpecEval v2 DPO run should be treated as a supervised-style LoRA problem, not as an RL-style low-capacity case.
- Therefore the article's RL result ("very low rank can match FullFT") should not be used to justify tiny-rank DPO first runs.
- The current repo default of:
  - `target_modules: null`
  - `zero_init_b: true`
  - `alpha = r`

  is directionally correct for the first Bloom SpecEval v2 LoRA experiment.
- Because Levanter currently excludes `lm_head` from LoRA by default, our practical "all-layer" run is really "all supported linear layers except `lm_head`". That is still much closer to the article's recommendation than attention-only LoRA.
This section updates the previous plan.
The earlier plan proposed a strict reproduction at:
- FullFT baseline LR: `7.5e-7`
- LoRA reproduction LR: also `7.5e-7`
After reading the article, that should be demoted to a sanity / lineage run only.
Reason:
- the article's strongest operational finding is that LoRA wants about 10x the FullFT LR in supervised settings
- so a same-LR comparison is likely unfair and likely to make LoRA look artificially weak
Updated interpretation:
- strict same-LR run (`7.5e-7`): useful only to answer "what happens if I swap in LoRA and change nothing else?"
- fair LoRA comparison: should center around `7.5e-6`
Minimal first sweep for the Bloom SpecEval v2 seed-2 setup:
- `r = 64`
- `target_modules = null`
- `zero_init_b = true`
- LR grid centered around the article's 10x rule:
  - `5e-6`
  - `7.5e-6`
  - `1e-5`
For our 850-step training run, 7.5e-6 is the natural anchor because it is exactly 10x the validated FullFT baseline LR.
The article also suggests a somewhat higher multiplier in very short runs. If we do only a brief screening run, e.g. <=100 steps, then a fourth point near 1.1e-5 is reasonable. For the full 850-step run, the main comparison should still center near 7.5e-6.
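The anchor choice is pure arithmetic on values already in this plan:

```python
full_ft_lr = 7.5e-7                 # validated FullFT baseline LR from the ...947c5d run
lora_anchor_lr = 10 * full_ft_lr    # the article's ~10x supervised-LoRA rule -> 7.5e-6
lr_grid = [5e-6, lora_anchor_lr, 1e-5]  # sweep bracketing the anchor
```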
This is one of the clearest actionable points from the article.
If the first LoRA runs are stable but learn more poorly than FullFT:
- do not immediately conclude that rank 64 is too small
- do not expect higher rank to fix a large-batch optimization penalty
Instead, test smaller train batch sizes first, for example:
- `128` (baseline)
- `64`
- possibly `32` if needed
Keep the hardware fixed if possible so the comparison stays interpretable.
Recommended order:
- get a stable run with all-layer LoRA and a LoRA-appropriate LR
- if that still underfits, test batch reduction
- only then test higher rank
Minimal rank ladder:
- `64` first
- `128` second if there are signs of capacity limits
I do not think we should start with attention-only LoRA or with tiny ranks.
The article is a strong argument against piling on extra LoRA tricks in the first Bloom SpecEval v2 experiment.
So the first serious run should keep:
- standard `alpha / r` scaling
- same LR for `A` and `B`
- no LoRA+
- no rank-dependent alpha hacks
- no attention-only targeting
For this codebase, that means:
```yaml
adapter:
  type: lora
  r: 64
  alpha: 64.0
  dropout: 0.0
  zero_init_b: true
  target_modules: null
```

The article deliberately used loss-based comparisons rather than only sample-based evals.
For our DPO runs, the corresponding best practice is:
- compare validation `loss`
- compare `dpo_accuracy`
- compare `dpo_margin_policy`
- compare `dpo_margin_ref`
- compare chosen/rejected rewards
- compare throughput / wall-clock / memory
Do not treat a handful of qualitative generations as the main evidence.
This is the current recommended order for Bloom SpecEval v2 LoRA:
- Optional sanity run:
  - same config as baseline, but with LoRA and `lr=7.5e-7`
  - purpose: verify the pipeline and observe how much performance is lost if LR is not retuned
- Main fair comparison run:
  - same config, `r=64`, all-layer LoRA, `lr=7.5e-6`
- Small LR sweep around the fair run:
  - `5e-6`, `7.5e-6`, `1e-5`
- Batch-size follow-up only if needed:
  - reduce `train_batch_size` from `128` to `64`
- Rank follow-up only if needed:
  - `r=128`
The most important correction from the article is simple:
- a Bloom SpecEval v2 LoRA run at `7.5e-7` should not be considered the serious LoRA baseline
- the serious baseline should be around `7.5e-6`, with all supported linear layers adapted, and with batch size treated as a separate optimization variable
- User requested that the next LoRA experiment actually be launched on `v5p-8`.
- The current best next experiment from this logbook is the fair LoRA baseline, not the same-LR sanity run.
- I am keeping the run in `us-central1-a` to stay in-region with the Bloom SpecEval v2 preference data and tokenized caches.
- Config path: `lib/levanter/config/dpo/lora_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-6_seed2_v5p8_central1.yaml`
- Main settings:
- dataset: Bloom SpecEval v2 GPT-4.1 vs Mixtral opposite-mode preferences
- train cache: `gs://marin-us-central1/tokenized/bloom_speceval_v2_train_prefs_marin_tokenizer-12920b`
- val cache: `gs://marin-us-central1/tokenized/bloom_speceval_v2_val_prefs_marin_tokenizer-a06ae8`
- base model: `marin-community/marin-8b-instruct`
- adapter: LoRA `r=64`, `alpha=64`, `dropout=0`, `zero_init_b=true`, `target_modules=null`
- reference: `adapter_base`
- LR: `7.5e-6`
- beta: `0.1`
- seed: `2`
- train batch size: `128`
- train steps: `850`
- hardware: `v5p-8`
- eval parallelism: `16`
Planned Iris submission:
```
uv run iris --config lib/iris/examples/marin.yaml job run --no-wait \
  --extra marin:tpu \
  --tpu v5p-8 \
  --memory 256GB \
  --disk 50GB \
  --zone us-central1-a \
  --job-name lora-bsv2-mi-b0p1-s2-lr7p5e6-v5p8 \
  -e WANDB_API_KEY "$WANDB_API_KEY" \
  -- python -m levanter.main.train_dpo \
  --config_path lib/levanter/config/dpo/lora_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-6_seed2_v5p8_central1.yaml
```

- Babysit this as an Iris TPU job, not a fire-and-forget launch.
- Watch specifically for:
- scheduler capacity wait vs real runtime failure
- first-step compile/lowering failures
- HBM / OOM signals
- bad step-0 DPO loss that would indicate broken LoRA identity/reference behavior
Initial launch:
```
uv run iris --config lib/iris/examples/marin.yaml job run --no-wait \
  --extra marin:tpu \
  --tpu v5p-8 \
  --memory 256GB \
  --disk 50GB \
  --zone us-central1-a \
  --job-name lora-bsv2-mi-b0p1-s2-lr7p5e6-v5p8-20260330-000321 \
  -e WANDB_API_KEY "$WANDB_API_KEY" \
  -- python -m levanter.main.train_dpo \
  --config_path lib/levanter/config/dpo/lora_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-6_seed2_v5p8_central1.yaml
```

Job id: `/ahmed/lora-bsv2-mi-b0p1-s2-lr7p5e6-v5p8-20260330-000321`
Result:
- never allocated
- stayed `JOB_STATE_PENDING`
- pending reason was:
  - insufficient memory on ready `tpu_v5p_8-us-central1-a` workers (need 256GB, only about `11.8GB` available)
  - autoscaler scale-up for `tpu_v5p_8-us-central1-a` was quota-blocked
Action taken:
- stopped the pending central1 job
- `tpu_v5p_8-us-east5-a` was not quota-blocked
- multiple east5 `v5p-8` slices were fully idle (`committed_mem_bytes=0`, `committed_tpu=0`)
- east5 already has the Bloom SpecEval v2 tokenized caches:
  - `gs://marin-us-east5/tokenized/bloom_speceval_v2_train_prefs_marin_tokenizer-12920b`
  - `gs://marin-us-east5/tokenized/bloom_speceval_v2_val_prefs_marin_tokenizer-a06ae8`
So east5 was the first region with a realistic chance of actually running tonight.
Relaunch:
```
uv run iris --config lib/iris/examples/marin.yaml job run --no-wait \
  --extra marin:tpu \
  --tpu v5p-8 \
  --memory 256GB \
  --disk 50GB \
  --zone us-east5-a \
  --job-name lora-bsv2-mi-b0p1-s2-lr7p5e6-v5p8-east5-20260330-000700 \
  -e WANDB_API_KEY "$WANDB_API_KEY" \
  -- python -m levanter.main.train_dpo \
  --config_path lib/levanter/config/dpo/lora_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-6_seed2_v5p8_east5.yaml
```

Active job id: `/ahmed/lora-bsv2-mi-b0p1-s2-lr7p5e6-v5p8-east5-20260330-000700`
Current state at last check:
- `JOB_STATE_RUNNING`
- `task_state_counts.running = 1`
- no pending reason
This run is past scheduler allocation and into real startup.
Observed so far:
- W&B initialized successfully
- run id: `053ujx8y`
- W&B URL: https://wandb.ai/marin-community/dpo/runs/053ujx8y
- train and validation token caches loaded from east5
- no OOM / `RESOURCE_EXHAUSTED` / JAX lowering failure seen yet
- worker is actively reading HF model shards for `marin-community/marin-8b-instruct`
Notable warning:
- cache metadata mismatch warning on `preprocessor_metadata`
- this did not immediately kill the run
- at last check the worker was still making forward progress through model shard reads
What is already ruled out:
- scheduler-capacity failure on east5
- immediate container death on startup
- immediate host-RAM OOM during initial process boot
What is still not ruled out yet:
- failure later in model load
- first-step compile/lowering failure
- TPU HBM OOM once actual training starts
- finish loading all `marin-8b-instruct` safetensor shards
- reach first batch load
- reach train-step tracing / lowering
- survive step-0 eval without OOM
The east5 run did not survive first-step compile.
Final job state:
- job id: `/ahmed/lora-bsv2-mi-b0p1-s2-lr7p5e6-v5p8-east5-20260330-000700`
- state: `JOB_STATE_FAILED`
- W&B run: `053ujx8y`
What happened:
- the job loaded the east5 train/validation caches successfully
- it loaded all `marin-community/marin-8b-instruct` safetensor shards successfully
- it reached first batch load, tracing, and HLO lowering
- it then failed in JAX/XLA with TPU HBM exhaustion during compile
Key failure signal from logs:
- `RESOURCE_EXHAUSTED: XLA:TPU compile permanent error`
- HBM used: `111.15G` of `95.74G`
- over capacity by: `15.41G`
- dominant temporary allocation included a `bf16[32,32,4096,4096]` broadcast of about `32.00G`
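As a cross-check on those numbers (back-of-envelope arithmetic, not taken from the logs):

```python
# The dominant temporary from the XLA error: bf16[32, 32, 4096, 4096].
elems = 32 * 32 * 4096 * 4096
temp_gib = elems * 2 / (1024 ** 3)   # bf16 = 2 bytes per element
print(temp_gib)                      # 32.0, matching the reported ~32.00G

# Reported HBM overage: used 111.15G against 95.74G of capacity.
overage = 111.15 - 95.74
print(round(overage, 2))             # 15.41
```

The `32, 32` leading dims are consistent with per-device batch times heads-scale factors at `4096x4096` sequence-by-sequence shape, which is why cutting the train batch size is the first lever.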
Interpretation:
- this was not a scheduler-capacity failure
- this was not a host-RAM failure
- this was not a tokenizer/cache problem
- this was a true TPU-device-memory failure at the first training step
Most important conclusion:
- LoRA reduced trainable-parameter / optimizer-state cost, but it did not make the activation and temporary-memory footprint small enough for `train_batch_size: 128` at `4096` tokens on `v5p-8`
Therefore the next experiment should follow the earlier LoRA plan exactly:
- keep `r=64`
- keep all-layer LoRA (`target_modules: null`)
- keep `lr=7.5e-6`
- keep `beta=0.1`
beta=0.1 - keep the same dataset and seed
- reduce train batch size before touching rank
Recommended next ladder on the same hardware:
- rerun with `train_batch_size: 64`
- if that still OOMs, rerun with `train_batch_size: 32`
- only after memory is sane, compare learning behavior and consider rank changes
This failure is actually consistent with the experimental guidance already recorded above: batch is the first knob to lower before increasing rank.
- DPO eval during training still does four forward passes because `loss_function` in `lib/levanter/src/levanter/main/train_dpo.py:402` computes both policy and reference log-probs inline.
- The eval hook added in `lib/levanter/src/levanter/main/train_dpo.py:593` goes through the generic `Trainer.add_eval_hook(...)` path in `lib/levanter/src/levanter/trainer.py:606`, so eval currently reuses the uncached training loss.
- `experiments/eval_dpo.py:79` proved that caching reference log-probs gives about a 2x eval speedup, but that file is a standalone profiling script and the wrong long-term home for production DPO logic.
- Pre-emption is routine on these jobs, so an in-memory-only cache is not acceptable. The cache must survive restarts and be safe to rebuild after partial writes.
- Add an opt-in DPO config flag that builds or loads a reference-eval cache before training starts.
- Persist the cache to GCS, then load it into host RAM for repeated eval use within the run.
- Keep the cache scoped to validation/eval only. Do not change the training step, optimizer state, or checkpoint format.
- Reuse Levanter's existing finished-ledger cache semantics from `lib/levanter/src/levanter/store/cache.py:137` so incomplete caches are treated as missing and rebuilt.
- Keep the generic `Trainer` API unchanged. This is DPO-specific behavior and should live in DPO code, not as a trainer-wide abstraction.
- Support both DPO reference modes:
  - `reference.type=separate`
  - `reference.type=adapter_base`
- Do not turn `experiments/eval_dpo.py` into a production library module.
- Do not cache training-time reference log-probs.
- Do not put the cache in device RAM or TPU HBM.
- Do not serialize cached reference log-probs inside training checkpoints.
- Do not introduce generic eval-cache machinery for all trainers.
- Add a new library module: `lib/levanter/src/levanter/dpo.py`
- Keep `lib/levanter/src/levanter/main/train_dpo.py` as the orchestration layer:
  - build validation dataset specs
  - decide whether caching is enabled
  - call the cache build/load helper before `trainer.add_eval_hook(...)`
- Leave `experiments/eval_dpo.py` as disposable profiling / validation code. After the library path lands, either:
  - reduce it to a thin profiling script that imports the library helpers, or
  - archive/delete it if nobody needs the standalone path anymore
Reason for using `levanter/dpo.py` instead of adding more code to `train_dpo.py`:
- the cache logic is real runtime behavior, not experiment glue
- it is too DPO-specific for `trainer.py`
- it is not purely data formatting, so it does not belong in `data/text/preference.py`
- it gives us one canonical home for reusable DPO runtime code instead of leaving it trapped in the CLI entrypoint
- it is lower churn than introducing a full `levanter/dpo/` package right now, while still setting up a natural migration path if DPO code keeps growing
```python
@dataclass(frozen=True)
class ReferenceEvalCacheConfig:
    mode: Literal["disabled", "build_or_load"] = "disabled"
    cache_dir: str | None = None


@dataclass(frozen=True)
class ValidationDatasetSpec:
    name: str
    dataset: AsyncDataset[DpoExample]
    source_cache_path: str
    source_split: str
    slice_start: int | None = None
    slice_end: int | None = None
```

```python
if config.reference_eval_cache.mode == "build_or_load":
    validation_specs = prepare_reference_eval_caches(
        validation_specs,
        config=config,
        trainer=trainer,
        tokenizer=tokenizer,
        model_context=model_context,
        Pos=Pos,
    )

for spec in validation_specs:
    trainer.add_eval_hook(spec.dataset, name=spec.name or None)
```

- Source of truth: GCS sidecar cache
- Fast serving path during the run: host RAM numpy arrays
- Not worth streaming from GCS every eval batch because the cache is tiny
- Not worth putting in checkpoint state because it is derived, immutable, and easier to rebuild than to checkpoint safely
Approximate size:
- two `float32` values per example = 8 bytes/example
- `23,552` examples is about `188 KB`
- even `1,000,000` examples is only about `8 MB`
Conclusion: persist to GCS for durability, then load into RAM per host once.
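The sizing argument as explicit arithmetic (float32 = 4 bytes; example counts from above):

```python
bytes_per_example = 2 * 4                  # two float32 log-probs per example
small = 23_552 * bytes_per_example         # current validation set
large = 1_000_000 * bytes_per_example      # a much larger hypothetical set

print(small)   # 188416  -> ~188 KB
print(large)   # 8000000 -> ~8 MB
```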
- `Trainer.add_eval_hook(...)` in `lib/levanter/src/levanter/trainer.py:606` is intentionally generic.
- We should not add DPO-specific cache controls there.
- The DPO loss in `train_dpo.py` should instead accept either:
  - a normal `DpoExample`
  - a cached `CachedDpoExample`
- That keeps the generic trainer unchanged while letting eval compile a separate cached path when needed.
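A toy sketch of that dispatch idea (plain Python; the type names and field names mirror the plan above but are illustrative, not the real Levanter types):

```python
from dataclasses import dataclass


@dataclass
class DpoExample:
    chosen_ids: tuple
    rejected_ids: tuple


@dataclass
class CachedDpoExample:
    example: DpoExample
    ref_chosen_logp: float
    ref_rejected_logp: float


def reference_logps(batch, ref_logp_fn):
    # Cached eval path: reuse stored log-probs, skip the reference forward passes.
    if isinstance(batch, CachedDpoExample):
        return batch.ref_chosen_logp, batch.ref_rejected_logp
    # Uncached path (training, or eval without a cache): run the reference model.
    return ref_logp_fn(batch.chosen_ids), ref_logp_fn(batch.rejected_ids)
```

The training step always passes plain `DpoExample`, so its compiled path is untouched; only eval batches arrive wrapped, which lets JIT compile a separate, cheaper eval function.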
- `MappedAsyncDataset` in `lib/levanter/src/levanter/data/dataset.py:227` does not pass the example index into the user callback; it only calls `fn(item, ...)`.
- The cache attach step needs the dataset index so it can read `ref_chosen[index]` and `ref_rejected[index]`.
- So the right production shape is the same basic idea validated in `experiments/eval_dpo.py:477`: a small `AsyncDataset` wrapper that returns `CachedDpoExample` by index.
Cache key must include:
- validation cache identity
- validation slice identity
- full validation cache
- or train-cache slice bounds when `validation_split_fraction` is used
- reference identity
- sequence length
- cache schema version
For `reference.type=separate`, reference identity should include:
- `reference.model_path`
- `reference.is_hf`
- model config class / shape if needed for safety
For `reference.type=adapter_base`, reference identity should include the frozen base initialization source, not the current policy weights:
- `initialize_from_hf`
- `initialize_from_checkpoint_path`
- relevant model/tokenizer identity
- adapter type and settings only if they change the base-model view semantics
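One way to make the key deterministic is to hash a canonical JSON encoding of exactly these identity fields. A sketch (the helper name and field layout are assumptions, not the shipped API):

```python
import hashlib
import json


def reference_eval_cache_key(*, source_cache_path, split, slice_bounds,
                             reference_identity, seq_len, schema_version=1):
    payload = {
        "source_cache_path": source_cache_path,
        "split": split,
        "slice_bounds": slice_bounds,        # None, or [start, end]
        "reference": reference_identity,     # mode-dependent dict (see above)
        "seq_len": seq_len,
        "schema_version": schema_version,
    }
    # sort_keys makes the JSON byte-stable, so equal inputs hash equally
    # across processes and restarts.
    blob = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:16]
```

Any change to any identity component changes the key, which is exactly the "distinct cache keys" property the tests below should pin down.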
- `TreeCache.load(...)` in `lib/levanter/src/levanter/store/cache.py:137` already refuses unfinished caches.
- We should lean into that:
- if the cache is missing: build
- if the ledger is unfinished: build
- if metadata/version mismatches: rebuild
- Important change from current generic cache loading behavior:
- metadata mismatch should be treated as a cache miss for this path, not merely a warning
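The strict policy boils down to a small control-flow helper. A sketch (the loader/builder callables stand in for the real `TreeCache` / `SerialCacheWriter` machinery):

```python
def build_or_load_reference_cache(path, expected_metadata, try_load, build):
    """try_load returns None for a missing or unfinished-ledger cache."""
    loaded = try_load(path)
    if loaded is not None and loaded.get("metadata") == expected_metadata:
        return loaded
    # Missing, unfinished ledger, or metadata mismatch: all treated as a
    # cache miss, so the cache is rebuilt rather than warned about.
    return build(path, expected_metadata)
```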
1. Add config surface
   - Extend `TrainDpoConfig` in `lib/levanter/src/levanter/main/train_dpo.py`
   - Extend `SimpleDPOConfig` in `experiments/simple_dpo_config.py`
   - Thread the config through `experiments/defaults.py`
2. Refactor validation-set construction
   - Replace the plain `dict[str, AsyncDataset[DpoExample]]` with `ValidationDatasetSpec`
   - Preserve provenance needed for deterministic cache paths:
     - source cache path
     - split name
     - slice bounds when using `validation_split_fraction`
3. Add `levanter/dpo.py`
   - Move reusable DPO runtime helpers out of `train_dpo.py`
   - Move only the validated cache ideas from `experiments/eval_dpo.py`
   - Add:
     - `DpoModel`
     - `dpo_loss_from_logps`
     - `ReferenceEvalCacheConfig`
     - `CachedDpoExample`
     - `CachedReferenceDataset`
     - `ValidationDatasetSpec`
     - deterministic cache-path builder
     - strict cache metadata builder
     - build-or-load helper that uses `SerialCacheWriter` and `TreeCache`
   - Keep the implementation process-safe:
     - all hosts compute
     - process 0 writes
     - all hosts barrier before load/continue
4. Wire the cache into `train_dpo.py`
   - After validation datasets are built and after the reference identity is known, call the cache helper before training starts
   - Replace uncached eval datasets with cached wrappers when available
   - Extend the DPO loss to branch on `CachedDpoExample` and skip the reference forward passes in eval
   - Leave the training path unchanged
5. Add tests
   - Extend `lib/levanter/tests/test_dpo.py` and add a focused new test file if needed
   - Cover:
     - deterministic cache-key changes when source cache / slice / reference / seq len changes
     - unfinished cache rebuild behavior
     - cached vs uncached DPO-loss parity on a small real model/dataset path
     - `validation_split_fraction` provenance produces distinct cache keys from explicit validation caches
6. Update docs
   - Add a short note to `lib/levanter/docs/guides/DPO-Training.md`
   - Document:
     - what the flag does
     - that it writes a durable sidecar cache
     - that the first run pays a one-time build cost
     - that later runs reuse the completed cache
- This should be opt-in at first. Default behavior should remain uncached until the path is validated in real training jobs.
- The cache builder should be invoked before training starts, not lazily inside the first eval hook. That makes failure/rebuild behavior obvious and keeps step-0 eval timing less surprising.
- The first implementation should stay conservative and only cache validation reference log-probs. Training-time caching is a separate problem with different correctness and storage tradeoffs.
- The reference cache should remain a DPO-only feature until there is a second real user of the same pattern.
- If DPO-specific runtime code keeps growing, promote `lib/levanter/src/levanter/dpo.py` into a `lib/levanter/src/levanter/dpo/` package later.
- After the library path lands, archive or shrink `experiments/eval_dpo.py` so the production implementation has one canonical home.
- If eval is still slow after reference caching, revisit eval-specific parameter layout / replicated-weight experiments.
- I reviewed the failed Claude merge attempt and confirmed the main mistake: it tried to resolve the DPO conflicts by provenance ("keep our refactor everywhere") instead of by behavior.
- That would have dropped the recent `origin/main` change that writes `generation_config.json` with chat stop tokens for HF exports. For chat models this is important because inference stacks like vLLM need the chat stop token, not just the tokenizer EOS token, to stop generation correctly after DPO.
- The conflict itself was centered in `lib/levanter/src/levanter/main/train_dpo.py`, but the full feature was not limited to conflict markers. Main also added required support in the shared HF export path and corresponding tests/docs.
- I kept the refactored DPO runtime shape from this branch:
  - `train_dpo.py` remains the canonical entrypoint
  - the adapter/reference abstraction stays in place
  - `model_init.py` remains the model/bootstrap helper
- I ported main's generation-config behavior into the refactored export layer instead of resurrecting main's older inline save path.
- Concretely, I threaded `generation_config` through:
  - `lib/levanter/src/levanter/compat/hf_checkpoints.py`
  - `lib/levanter/src/levanter/adaptation.py`
  - `lib/levanter/src/levanter/lora.py`
  - `lib/levanter/src/levanter/main/train_dpo.py`
- I also brought over the DPO-facing config/docs/tests pieces:
  - `hf_generation_eos_token_ids` config plumbing
  - the DPO guide section on generation stop tokens
  - HF export tests and DPO smoke coverage
- HF export tests and DPO smoke coverage
- This is a better refactor than replaying main literally because merged LoRA exports now get the same `generation_config.json` behavior as regular DPO.
- I merged current `origin/main` into `dpo-lora` and created merge commit `e6dcb2e96`.
- After the merge, `origin/dpo-lora` was confirmed to be `53` commits ahead and `0` commits behind `origin/main`, so the branch contains current main as of `2026-03-30`.
- Verification for the merge-resolution work:
  - `./infra/pre-commit.py --fix ...` on the touched DPO/HF export files: passed
  - `uv run --project lib/levanter --group test python -m pytest lib/levanter/tests/test_hf_export.py lib/levanter/tests/test_hf_checkpoints.py lib/levanter/tests/test_dpo.py lib/levanter/tests/test_lora_dpo.py`: `49 passed, 2 skipped`
- I implemented the pre-emption-safe cached reference-eval path in the real DPO runtime.
- The production implementation lives in `lib/levanter/src/levanter/dpo.py`. `experiments/eval_dpo.py` remains a prototype/profiling artifact and is not the canonical home for this feature.
- The cache is durable on GCS via Levanter's cache ledger semantics, and then loaded into host RAM for repeated eval use during the run.
- The feature is opt-in and scoped to validation/eval only. Training-time reference computation and checkpoint state are unchanged.
- New shared DPO runtime module: `lib/levanter/src/levanter/dpo.py`
- That module now owns:
  - `DpoModel`
  - shared DPO loss helpers
  - `ReferenceEvalCacheConfig`
  - `CachedDpoExample`
  - `CachedReferenceDataset`
  - `ValidationDatasetSpec`
  - deterministic cache metadata / cache path helpers
  - build/load/build-or-load helpers for reference-eval caches
- `lib/levanter/src/levanter/main/train_dpo.py` stays as orchestration:
  - builds validation specs
- derives reference identity
- optionally materializes caches before training
- swaps eval datasets onto cached wrappers
- leaves training behavior unchanged
- `TrainDpoConfig` now has `reference_eval_cache`, and the config is threaded through:
  - `lib/levanter/src/levanter/main/train_dpo.py`
  - `lib/levanter/src/levanter/main/lora_dpo.py`
  - `experiments/simple_dpo_config.py`
  - `experiments/defaults.py`
- Validation datasets now preserve provenance needed for deterministic cache identity:
- source cache path
- split name
- slice bounds when using `validation_split_fraction`
- Cache identity includes:
- validation cache provenance
- slice bounds
- reference identity
- sequence length
- schema/version kind
- Cache loading is intentionally strict:
  - missing cache -> build
  - unfinished ledger -> build
  - metadata mismatch -> rebuild
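The strict policy above can be sketched as a small decision function. This is a hedged illustration of the rule, not the actual implementation; the function name and boolean inputs are hypothetical.

```python
def resolve_cache_action(exists: bool, ledger_finished: bool, metadata_matches: bool) -> str:
    """Sketch of the strict reference-eval cache policy: never trust a partial or stale cache."""
    if not exists:
        return "build"      # missing cache -> build from scratch
    if not ledger_finished:
        return "build"      # unfinished ledger -> treat as absent and build
    if not metadata_matches:
        return "rebuild"    # identity mismatch -> rebuild rather than trust it
    return "load"           # only a finished, matching cache is loaded
```

The key property is that "load" is only reachable when every check passes, so a preempted or misconfigured build can never be silently reused.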
- The eval loss path now accepts either a normal `DpoExample` or a `CachedDpoExample`, so cached eval skips the reference forward passes while the training path remains the normal uncached computation.
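A minimal sketch of that dispatch, assuming the standard DPO loss on summed sequence log-probs; the field names and the plain-dict example shape are hypothetical stand-ins for the real `DpoExample` / `CachedDpoExample` types:

```python
import math

def dpo_loss(beta, pol_chosen, pol_rejected, ref_chosen, ref_rejected):
    # Standard DPO objective: -log sigmoid(beta * (policy margin - reference margin))
    margin = (pol_chosen - ref_chosen) - (pol_rejected - ref_rejected)
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))

def eval_loss(beta, example, reference_forward=None):
    """Use cached reference log-probs when present; otherwise run the reference model."""
    if "ref_chosen_logp" in example:   # cached path: no reference forward pass needed
        ref_c, ref_r = example["ref_chosen_logp"], example["ref_rejected_logp"]
    else:                              # uncached path: compute reference log-probs now
        ref_c, ref_r = reference_forward(example)
    return dpo_loss(beta, example["pol_chosen_logp"], example["pol_rejected_logp"], ref_c, ref_r)
```

The cached-vs-uncached parity tests mentioned below amount to checking that both branches return the same loss for the same underlying log-probs.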
- I updated `lib/levanter/tests/test_dpo.py` with focused coverage for:
  - config parsing
  - cache-path identity changes
  - metadata mismatch rejection
  - cache build/load behavior
  - cached-vs-uncached DPO-loss parity
- `lib/levanter/tests/test_lora_dpo.py` now imports the shared DPO helpers from `levanter.dpo`.
- I added a short "Reference Eval Cache" section to `lib/levanter/docs/guides/DPO-Training.md`.
- `./infra/pre-commit.py --fix .agents/logbooks/dpo-lora-codex.md experiments/defaults.py experiments/simple_dpo_config.py lib/levanter/docs/guides/DPO-Training.md lib/levanter/src/levanter/dpo.py lib/levanter/src/levanter/main/lora_dpo.py lib/levanter/src/levanter/main/train_dpo.py lib/levanter/tests/test_dpo.py lib/levanter/tests/test_lora_dpo.py`: passed
- `uv run --project lib/levanter --group test python -m pytest lib/levanter/tests/test_dpo.py lib/levanter/tests/test_lora_dpo.py`: passed with `34 passed, 1 skipped`
- I committed the cache work as `35d9c444b` with subject `[dpo] Cache reference eval logprobs`.
- I pushed that commit to `origin/dpo-lora`.
- I did not use the repo's literal `make fix` target for this step because in this dirty worktree it would have swept unrelated local modifications and still would not have included the new untracked `lib/levanter/src/levanter/dpo.py`.
- Instead I ran the required repo fixer directly on the scoped file set, then committed and pushed only the DPO cache work.
- Keep babysitting the resumed LoRA DPO run for W&B run `endlboq3` until it either:
  - finishes final eval and writes the final merged HF export, or
  - fails / gets preempted and needs another recovery decision
- The specific thing we were trying to validate was that the LoRA resume path is now correct after the adapter-checkpoint fix:
  - trainer state should load from the east5 training checkpoint
  - the base model should be reconstructed from the configured central1 HF source
  - resumed LoRA weights should be overlaid on top of that base model
- Several earlier central1 relaunches were wrong:
  - one used unsupported deep `draccus` CLI overrides for dataset cache dirs
  - one used `initialize_from_checkpoint_path` against a LoRA trainer checkpoint and failed with missing base arrays
  - one used `trainer.initialize_from` together with `initialize_from_hf` and hit the config validation error
- The clean relaunch path is:
  - keep `initialize_from_hf` in the config
  - do not use `trainer.initialize_from`
  - do not use `initialize_from_checkpoint_path`
  - instead pass `--trainer.load_checkpoint_path <east5 trainer checkpoint>`
- Correct running job:
  - Iris job id: `/ahmed/lora-bsv2-mi-b0p1-s2-lr7p5e6-b64-v5p8-central1-resume-20260330-2236`
  - zone / hardware: `us-central1-a`, `v5p-8`
  - W&B run: https://wandb.ai/marin-community/dpo/runs/endlboq3
  - config: `lib/levanter/config/dpo/lora_dpo_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-6_seed2_v5p8_central1_b64.yaml`
  - checkpoint source for resume: `gs://marin-us-east5/checkpoints/dpo/lora_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-6_seed2_v5p8_b64/endlboq3/step-835`
- W&B reused `endlboq3`.
- The process loaded the east5 trainer checkpoint via `--trainer.load_checkpoint_path`.
- The fixed LoRA resume path ran:
  - log line: `Resuming from step 836, using checkpoint policy weights.`
  - log line: `Adapter checkpoints only store trainable weights. Reconstructing the base policy model from the configured source before overlaying resumed adapter parameters.`
- The run rebuilt the base model from the central1 HF shards and resumed training from step `836`.
- Training completed successfully on the resumed job.
- A final central1 training checkpoint was written to: `gs://marin-us-central1/checkpoints/dpo/lora_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-6_seed2_v5p8_b64/endlboq3/step-849`
- Timestamp of last status refresh: `2026-03-30 23:26:48 PDT`
- Iris state: `JOB_STATE_RUNNING`, `preemption_count: 0`
- The job is in the final eval loop, not dead or hung.
- Latest observed eval progress:
  - `147 / 367` validation batches
  - about `13.2 s/it`
  - about `48 minutes` remaining from that sample
- There is still no merged HF export under `gs://marin-us-central1/checkpoints/dpo/lora_bloom_speceval_v2_marin_instruct_beta0.1_lr7.5e-6_seed2_v5p8_b64/merged_hf/endlboq3/`.
- That is expected so far: the export hooks have not fired yet because the job is still inside eval.
- This run is not using the new reference-eval cache feature:
  - the YAML has no `reference_eval_cache` block
  - so final eval is on the old uncached DPO path
- That is why eval is slow: uncached DPO eval still computes all four forward passes per batch:
  - policy chosen
  - policy rejected
  - reference chosen
  - reference rejected
- The long final eval time is therefore expected and is not evidence of another resume failure.
- There are occasional `Failed to compress input file` log lines:
  - they appeared while the job was otherwise healthy
  - they did not correlate with failure, preemption, or checkpoint corruption
- Keep the long cadence; this is now mostly a wait-for-terminal-state problem.
- Watch for three things only:
  - preemption / worker death
  - clean end of eval
  - creation of the final merged HF export
- If the job completes normally, verify:
  - terminal success in Iris
  - final eval finished
  - the merged HF export appears under `.../merged_hf/endlboq3/`
  - ideally a `step-849` or final-step-equivalent directory created by the export hook
- If the job fails or gets preempted, do not go back to `trainer.initialize_from`:
  - the known-good restart pattern is the same `central1_b64` YAML plus `--trainer.load_checkpoint_path <latest training checkpoint>`
- The user wanted the `SimpleDPOConfig` surface to support a stable "one pass over the train set" mode for DPO experiments instead of hard-coding step budgets.
- The user also wanted a default validation cadence that runs:
  - once before training
  - three evenly spaced interior validations
  - once at the end
- This follow-up was code-only. I did not launch a new TPU training run in this pass.
- Updated `experiments/simple_dpo_config.py` so:
  - `num_train_steps: int | None = None`
  - `num_epochs: float = 1.0`
  - `steps_per_eval: int | None = None`
  - added validation in `__post_init__` for non-positive values
- Updated `experiments/defaults.py` so `default_dpo(...)` now:
  - keeps explicit `num_train_steps` / `steps_per_eval` if they are set
  - otherwise enables runtime auto-resolution through `TrainDpoOnPodConfig(auto_num_epochs=..., auto_validation_runs=5)`
- Updated `lib/marin/src/marin/training/training.py` so the DPO runtime launch now:
  - reads tokenizer stats from the concrete cache path
  - uses `total_elements` as the DPO train-set size
  - applies the same validation-split rule as the DPO runtime
  - converts `num_epochs` into `num_train_steps` with `BatchSchedule.find_step_containing_offset(...) + 1`
  - resolves exact interior eval steps for the 5-validation schedule
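For a constant batch size, the epoch-to-steps conversion reduces to simple arithmetic. This is a hedged sketch of what `find_step_containing_offset(...) + 1` computes in that case, not the real `BatchSchedule` implementation:

```python
def num_steps_for_epochs(num_epochs: float, dataset_size: int, batch_size: int) -> int:
    """Step containing the last example of the epoch budget, plus one (constant batch size)."""
    total_examples = int(num_epochs * dataset_size)
    last_offset = total_examples - 1          # offset of the final example to be consumed
    return last_offset // batch_size + 1      # its containing step, converted to a step count
```

With the `108,765`-pair train set this reproduces the budgets quoted later: `1700` steps at batch `64` and `850` steps at batch `128`.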
- Added `lib/marin/src/marin/processing/tokenize/cache_stats.py`.
- This is intentionally a read-side-only helper for this PR:
  - `TokenizedCacheStats`
  - `tokenized_cache_stats_path(...)`
  - `read_tokenized_cache_stats(...)`
- It uses `rigging.filesystem.url_to_fs(...)` plus `open_url(...)` instead of `os.path.join(...)`.
- Exported the helper from `lib/marin/src/marin/processing/tokenize/__init__.py`.
- I explicitly did not refactor the tokenizer write path in this pass; it still writes the same `.stats.json` schema in place.
- Updated `lib/levanter/src/levanter/main/train_dpo.py` to add:
  - `scheduled_eval_steps`
  - `run_initial_eval`
- Implemented:
  - an explicit initial validation before training starts
  - exact interior eval steps via a small hook wrapper
  - final validation through the trainer's existing forced end-of-training hook path
- Removed the old extra DPO-side `trainer.run_hooks(last_info, force=True)` call because `Trainer.train()` already force-runs hooks at the end.
- I reviewed a competing critique carefully. The main conclusions after code inspection and tests:
  - removing the extra DPO-side `run_hooks(..., force=True)` is not a regression because `Trainer.train()` already force-runs hooks at the end
  - the synthetic initial eval logs at step 0, not step 1, because `StepInfo.step = state.step - 1`
  - using `.stats.json` is the right source of DPO dataset size for this feature; for preference data the meaningful field is `total_elements`
  - private imports from `train_dpo.py` were avoided in the final version by moving the stats read into a shared helper and keeping the tiny DPO-specific component/split logic local in `training.py`
- Updated `experiments/tune_lora/common.py`:
  - `LoraTuneSpec.num_epochs = 1.0`
  - removed the explicit `num_train_steps=850`
  - removed the explicit `steps_per_eval=200`
- Updated `experiments/dpo_bloom_speceval_v2.py`:
  - switched the main script from `num_train_steps=850` to `num_epochs=1.0`
  - removed the explicit `steps_per_eval=200`
- I intentionally left `experiments/dpo_bloom_speceval_v2_profile.py` alone because that script is explicitly built around fixed eval timing for profiling.
- The train tokenizer stats we have been using report `108,765` train preference pairs.
- Therefore:
  - the LoRA sweep config in `experiments/tune_lora/common.py` with batch `64` and `num_epochs=1.0` resolves to `1700` train steps
    - the auto eval schedule there is 5 total validations: initial, steps `425`, `850`, `1275`, and final
  - the main non-LoRA Bloom SpecEval v2 script with batch `128` and `num_epochs=1.0` resolves to `850` train steps
    - the auto eval schedule there is 5 total validations: initial, steps `213`, `426`, `639`, and final
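One consistent rule reproduces both sets of interior eval steps quoted above: space the three interior validations by `ceil(total_steps / 4)`. This is a hedged reconstruction from the logged numbers, not the actual resolution code in `training.py`:

```python
import math

def interior_eval_steps(total_steps: int, total_validations: int = 5) -> list[int]:
    """Interior eval steps for the auto schedule: initial + interior points + final.

    With 5 total validations there are 3 interior points, spaced by
    ceil(total_steps / 4) so the final forced eval closes the last interval.
    """
    spacing = math.ceil(total_steps / (total_validations - 1))
    num_interior = total_validations - 2
    return [k * spacing for k in range(1, num_interior + 1)]
```

The ceiling matters only when `total_steps` is not divisible by 4, which is why the batch-128 schedule lands on `213, 426, 639` rather than multiples of 212.5.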
- Ran:
  - `./infra/pre-commit.py --fix lib/marin/src/marin/processing/tokenize/cache_stats.py lib/marin/src/marin/processing/tokenize/__init__.py lib/marin/src/marin/training/training.py tests/test_training.py`
  - `uv run python -m pytest -o addopts='' tests/test_training.py`
  - `cd lib/levanter && uv run --group test python -m pytest tests/test_dpo.py tests/test_lora_dpo.py`
  - `./infra/pre-commit.py --fix experiments/dpo_bloom_speceval_v2.py`
- Results:
  - targeted Marin checks passed
  - `tests/test_training.py`: `8 passed`
  - Levanter DPO tests: `35 passed, 1 skipped`
- No new runtime experiment was launched in this follow-up.
- The code now supports:
  - explicit step-based DPO configs when a script sets `num_train_steps` / `steps_per_eval`
  - one-epoch auto budgeting plus the 5-point eval cadence when those fields are left unset
- The main Bloom SpecEval v2 scripts now opt into the new auto path directly.
- Add the standard pretraining LM eval suites to DPO runs without mixing them into preference validation.
- Apply this automatically to both:
  - the standard Bloom SpecEval v2 DPO run
  - the LoRA DPO sweep, via `default_dpo(...)`
- Updated `experiments/defaults.py` so `default_dpo(...)` now always builds a separate `lm_validation_data` from `default_validation_sets(tokenizer=...)`.
- This reuses the same validation bundle used by pretraining:
  - Paloma tokenized eval sets
  - Uncheatable Eval tokenized eval sets
- Updated `lib/levanter/src/levanter/main/train_dpo.py` to add:
  - `lm_validation_data: LmDataConfig | None`
  - `lm_validation_prefix: str = "lm_eval"`
- DPO now runs LM evals through `levanter.eval.cb_tagged_lm_evaluate(...)` on the same cadence as DPO preference eval:
  - initial eval before training
  - scheduled interior evals
  - final forced eval at the end of training
- Preference validation remains under `eval/...` and still reports DPO loss metrics for preference datasets.
- LM eval metrics are logged separately under `lm_eval/...` so they do not collide with preference loss:
  - `lm_eval/paloma/...`
  - `lm_eval/uncheatable_eval/...`
- This keeps DPO-specific validation and LM generalization validation distinct in W&B.
- I did not try to cram `default_validation_sets()` into `PreferenceLmDataConfig`; those datasets are LM eval data, not preference pairs.
- Instead, DPO now carries two parallel validation paths:
  - preference validation for DPO loss
  - LM validation for next-token loss / bpb on Paloma and Uncheatable
- Added small callback wrappers in `train_dpo.py` to prevent duplicate eval runs when:
  - initial eval happens at step 0
  - the trainer force-runs hooks at the final step
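The dedup wrapper can be sketched as a closure that remembers which steps have already been evaluated. This is an illustration of the idea only; the real wrappers in `train_dpo.py` work against Levanter's hook machinery, and the function name here is hypothetical:

```python
def dedup_eval_hook(eval_fn):
    """Wrap an eval callback so the same step is never evaluated twice.

    This guards the two overlap cases above: the synthetic step-0 initial eval,
    and a scheduled step that the trainer also force-runs at the end of training.
    """
    seen: set[int] = set()

    def hook(step: int):
        if step in seen:
            return None       # already evaluated at this step; skip silently
        seen.add(step)
        return eval_fn(step)

    return hook
```

Keeping the guard in the wrapper (rather than in each eval callback) means both the preference eval and the LM eval share the same dedup behavior.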
- Ran:
  - `make fix`
  - `uv run python -m pytest -o addopts='' tests/test_training.py`
  - `uv run --directory lib/levanter --group test python -m pytest tests/test_dpo.py tests/test_lora_dpo.py`
- Results:
  - `tests/test_training.py`: `9 passed`
  - Levanter DPO tests: `37 passed, 1 skipped`
- Any DPO experiment that goes through `default_dpo(...)` now picks up Paloma + Uncheatable LM evals automatically.
- If a Bloom SpecEval v2 LoRA run is launched now from `experiments/tune_lora/common.py`:
  - training uses the existing train preference cache with `108,765` train pairs
  - validation uses the deduped validation cache with `2,606` validation pairs
  - `num_epochs=1.0` and batch size `64` resolve to `1,700` train steps
  - DPO preference eval runs 5 times total: initial, steps `425`, `850`, `1275`, and final
  - Paloma + Uncheatable LM evals run on that exact same schedule under `lm_eval/...`
- The branch containing this work was pushed as commit `1e0e5fe9f`: `[dpo] Add LM eval suites to DPO runs`
- Expand the LoRA DPO learning-rate sweep around the current Bloom SpecEval v2 setup.
- Use two seeds per hyperparameter combination.
- Keep the current executor-native training behavior:
  - `num_epochs=1.0`
  - preference eval on the 5-point schedule
  - Paloma + Uncheatable LM evals via `default_validation_sets()`
- Keep:
  - `beta=0.1`
  - batch size `64`
  - rank `64`
  - `alpha=64`
  - `dropout=0.0`
  - `reference.type=adapter_base`
  - `reference_eval_cache.mode=build_or_load`
  - `train_seq_len=max_seq_len=4096`
  - `v5p-8`
- Do not reuse the exact same slug string for new combinations.
- Reuse the existing slug pattern: `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr{lr}_seed{seed}_b64_v5p8`
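The `{lr}` token in the slug follows a recoverable convention: `.` becomes `p` and the negative exponent drops its sign and leading zero. A hypothetical helper reproducing that convention (inferred from the slugs listed below, not taken from the sweep code):

```python
def lr_slug(lr: float) -> str:
    """Format a learning rate for slug use: 7.5e-6 -> '7p5e6', 1e-5 -> '1e5'."""
    # '%g' yields scientific notation like '7.5e-06' for these magnitudes.
    mantissa, _, exponent = f"{lr:g}".partition("e")
    return mantissa.replace(".", "p") + "e" + str(abs(int(exponent)))
```

Generating slugs programmatically like this avoids the collision risk called out above, since distinct LRs map to distinct strings.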
- The current scripted sweep is concentrated in the `5e-6` to `1e-5` range.
- The LoRA DPO best-practices note suggests the most promising regime is often lower, roughly `5e-7` to `5e-6`.
- So the next useful move is not to push higher than `1e-5`, but to add denser lower-LR points while also filling in missing seed coverage.
- Keep the existing scripted LRs: `5e-6`, `6.25e-6`, `7.5e-6`, `8.75e-6`, `1e-5`
- Add lower-LR points: `1e-6`, `2.5e-6`, `3.75e-6`, `4.5e-6`
- Run every LR with `seed=0` and `seed=2`.
- If we want to stay conservative on job count, the minimum good expansion is:
  - add `2.5e-6`
  - add `3.75e-6`
  - add missing `seed=0` companions for the currently seed-2-only LRs
- This gives broader LR coverage without exploding the sweep immediately.
- Existing or already represented:
  - `lr=5e-6`, seeds `0, 2`
  - `lr=6.25e-6`, seeds `0, 2`
  - `lr=7.5e-6`, seeds `0, 2`
  - `lr=8.75e-6`, seeds `0, 2`
  - `lr=1e-5`, seeds `0, 2`
- New lower-LR additions:
  - `lr=1e-6`, seeds `0, 2`
  - `lr=2.5e-6`, seeds `0, 2`
  - `lr=3.75e-6`, seeds `0, 2`
  - `lr=4.5e-6`, seeds `0, 2`
- New run slugs:
  - `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr1e6_seed0_b64_v5p8`
  - `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr1e6_seed2_b64_v5p8`
  - `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr2p5e6_seed0_b64_v5p8`
  - `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr2p5e6_seed2_b64_v5p8`
  - `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr3p75e6_seed0_b64_v5p8`
  - `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr3p75e6_seed2_b64_v5p8`
  - `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr4p5e6_seed0_b64_v5p8`
  - `bloom_speceval_v2_marin_instruct_lora_beta0p1_lr4p5e6_seed2_b64_v5p8`
- Each run should:
  - resolve to `1,700` train steps from `num_epochs=1.0`
  - run DPO preference eval at: initial, steps `425`, `850`, `1275`, and final
  - run Paloma + Uncheatable LM evals on the same schedule under `lm_eval/...`
- Added full `9 x 2` wrapper coverage in `experiments/tune_lora/`:
  - learning rates: `1e-6`, `2.5e-6`, `3.75e-6`, `4.5e-6`, `5e-6`, `6.25e-6`, `7.5e-6`, `8.75e-6`, `1e-5`
  - seeds: `0`, `2`
- Launch each LoRA sweep job on Iris with `v5p-8`.
- Allow Iris to place the job in either `us-central1` or `us-east5`.
- We do not care which region wins as long as the job lands on `v5p-8`.
- Use repeated region flags: `--region us-central1 --region us-east5`
- Do not use `us-east5a`:
  - `us-east5` is the region; `us-east5-a` is the zone.
- The train preference cache is present in both regions.
- The deduped validation preference cache is present in both regions.
- `default_validation_sets()` is present `23/23` in both regions.
- That means the current LoRA sweep is safe to launch with dual-region placement.
- Each LoRA job should:
  - request `v5p-8`
  - resolve `num_epochs=1.0` to `1,700` train steps
  - run DPO preference eval at: initial, steps `425`, `850`, `1275`, and final
  - run Paloma + Uncheatable LM evals on the same schedule under `lm_eval/...`
Run these from the repo root:

- If launching from this worktree, first make sure local `wandb_claude_data/` and `wandb_data/` are removed, ignored, or moved out of tree; Iris previously rejected a `73MB` source bundle from those directories alone.
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr1e6-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr1e6_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr1e6-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr1e6_seed2_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr2p5e6-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr2p5e6_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr2p5e6-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr2p5e6_seed2_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr3p75e6-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr3p75e6_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr3p75e6-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr3p75e6_seed2_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr4p5e6-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr4p5e6_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr4p5e6-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr4p5e6_seed2_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr5e6-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr5e6_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr5e6-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr5e6_seed2_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr6p25e6-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr6p25e6_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr6p25e6-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr6p25e6_seed2_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr7p5e6-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr7p5e6_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr7p5e6-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr7p5e6_seed2_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr8p75e6-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr8p75e6_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr8p75e6-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr8p75e6_seed2_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr1e5-seed0-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr1e5_seed0_b64.py
uv run iris --config lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-lora-beta0p1-lr1e5-seed2-b64 --tpu v5p-8 --region us-central1 --region us-east5 --no-wait -- uv run python experiments/tune_lora/beta0p1_lr1e5_seed2_b64.py

- Created a local export root at:
- Exported the canonical 18 tune-LoRA runs only:
  - the `1700`-step variants
  - not the older superseded `850`-step duplicates
- Wrote a top-level manifest:
- For each run, saved:
  - `run_metadata.json`
  - `config.json`
  - `summary.json`
  - `files_manifest.json`
  - `logged_artifacts.json`
  - `history.jsonl.gz`
  - `system_history_sampled.csv`
  - `downloaded_files.json`
  - a `files/` mirror of the run-attached W&B files
- This export is W&B-local metadata/history/files only.
- I did not download large checkpoint artifacts from W&B artifact references.
- Manifest count: `18` runs
- All exported runs are marked `state = finished`.
- Export summary from the manifest:
  - `history_rows = 2000` for each run
  - `num_file_errors = 0` for every run
- Total local size after export: about `305M`
- The run-attached file count varies across runs (`16` to `40` files), but all attached files that W&B exposed for these runs downloaded successfully.
- The canonical run list used for this export is the same `18`-run `9 x 2` matrix described above under the updated sweep plan.
`uv run iris --config=lib/iris/examples/marin.yaml cluster status`

At the time of this check, Iris reported:

- `tpu_v5p_8-us-central1-a`: `Ready 21`
- `tpu_v5p_8-us-east5-a`: `Ready 21`
- `tpu_v5p_32-us-central1-a`: `Ready 2`
- `tpu_v5p_32-us-east5-a`: `Ready 0`

So the answer is: it is not only v5p-8. There is currently live v5p-32 capacity in us-central1-a, although that pool is much smaller than the v5p-8 pool.
The current code paths are meaningfully different:

- Full DPO in `experiments/sweep_dpo/` uses a true separate reference model via `SeparateReferenceConfig`.
- In `train_dpo.py`, that path constructs `DpoModel(policy, reference)` and loads the reference weights separately.
- The tune-LoRA sweep in `experiments/tune_lora/common.py` uses `AdapterBaseReferenceConfig`, so the reference is taken as the base-model view of the policy instead of loading a second full model.

Important nuance: both paths still compute policy chosen/rejected and reference chosen/rejected log-probs during training, so the batch/sequence-driven activation and temporary-allocation pressure should be directionally similar. The main extra burden in full DPO is the resident frozen-reference model memory, not a different loss shape.
Known hard evidence from the LoRA v5p-8 failure:

- observed usage at first-step compile: `111.15G`
- `v5p-8` HBM capacity seen by XLA: `95.74G`
- over by: `15.41G`

If we make the intentionally crude assumption that the dominant training-step memory is roughly linear in batch size, then:

- batch `128`: `111.15G`
- batch `64`: about `55.58G`
- implied headroom at batch `64`: about `40.16G`

That is a large enough gap that simply halving the batch looks plausibly sufficient from the activation/temp-allocation side.
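The napkin arithmetic above can be written out explicitly. It is only the crude linear-in-batch projection, using the two hard numbers from the failed compile:

```python
# Crude linear-in-batch-size projection of v5p-8 HBM pressure.
observed_b128 = 111.15   # GiB used at first-step compile with batch 128 (hard evidence)
hbm_capacity = 95.74     # GiB of v5p-8 HBM visible to XLA (hard evidence)

over_by = observed_b128 - hbm_capacity      # how far batch 128 overshot capacity
projected_b64 = observed_b128 / 2           # linear assumption: halve batch, halve usage
headroom_b64 = hbm_capacity - projected_b64 # implied slack at batch 64

print(f"over by {over_by:.2f}G; batch-64 projection {projected_b64:.2f}G; headroom {headroom_b64:.2f}G")
```

The linearity assumption ignores fixed costs (parameters, optimizer state) that do not scale with batch, which is exactly why the full-DPO case below needs separate treatment.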
The catch is full DPO:

- LoRA with `reference=adapter_base` avoided storing a second separately loaded reference model.
- Full DPO with `reference=separate` must keep extra frozen reference parameters resident.
- Therefore full DPO at batch `64` on `v5p-8` is not guaranteed by the LoRA math.

My best napkin read is:

- full DPO on `v5p-8` at batch `64` looks plausible enough to try once
- but it is still materially riskier than the LoRA `b64` sweep because the extra reference-model residency eats into that `~40G` crude headroom
- if it fails, `train_batch_size: 32` is the obvious next rung

Recommendation:

- If the goal is the lowest-risk path to reproduce the February full-DPO runs with more evals, prefer `v5p-32`.
- If the goal is to test whether we can get away with cheaper hardware, a single `v5p-8`, batch-64 full-DPO trial is justified by the current napkin math.
- I would treat that `v5p-8` batch-64 run as a fit/probe, not as guaranteed-capacity ground truth.
The user requested a single launch on v5p-16 named `dummy`.

I am using a one-off wrapper at:

Config choices for this probe:

- full DPO, not LoRA
- Bloom SpecEval v2 tokenized train/val
- `beta=0.1`
- seed `0`
- `v5p-16`
- `train_batch_size=64`
- `num_train_steps=10`
- `steps_per_eval=5`

Reason for the short run:

- this is a hardware-fit / launch probe, not a canonical long baseline
- the main uncertainty is whether full DPO with a separate reference model fits and starts cleanly on `v5p-16`

Planned launch command:

- `uv run iris --config=lib/iris/examples/marin.yaml job run --job-name dummy --tpu v5p-16 --region us-central1 -e WANDB_API_KEY "$WANDB_API_KEY" --no-wait -- uv run python experiments/sweep_dpo/dummy.py`

Submission result:

- submitted successfully via Iris
- job id: `/ahmed/dummy`
- immediate scheduler state from `iris job list`: `pending`
- scheduler note: `Coscheduling: need 2 workers in 'tpu-name' group ...`
Interpretation:

- this is expected for a multinode `v5p-16` request
- the job is queued cleanly; it has not failed at submit time

Subsequent monitoring showed:

- child training job: `/ahmed/dummy/train_dpo`
- child scheduler state: `running`
- W&B run from process 0: https://wandb.ai/marin-community/dpo/runs/dummy-3ca308

Signals seen before first-step compile:

- both TPU tasks started and joined JAX distributed cleanly
- W&B initialized successfully on process 0
- the training cache and validation cache started loading
- no `RESOURCE_EXHAUSTED`, HBM OOM, or `FAILED_PRECONDITION` signal observed yet

This means the probe has cleared scheduler placement and entered the actual train-DPO runtime path on v5p-16.

Continued babysitting showed that `dummy` advanced beyond simple process startup:

- it finished one full pass reading the `marin-community/marin-8b-instruct` safetensor shard set
- it then began a second HF load pass, consistent with full DPO loading the separate reference model
- the job remained in scheduler state `running` throughout this phase

Important negative signal:

- still no `RESOURCE_EXHAUSTED`, TPU HBM, `FAILED_PRECONDITION`, traceback, or process-death signal during this model-loading phase

This is not yet proof that the first-step compile will fit, but it is stronger evidence than mere scheduler placement: the full-DPO v5p-16 probe is making real progress through the heavyweight model-load path without early failure.
Once the dummy probe reached the real `train_dpo` child runtime on v5p-16, the next step was to launch a comparable full-DPO sweep against the LoRA b64, 1-epoch runs.

New wrappers:

- `experiments/sweep_dpo/compare_lora_beta0p1_b64_v5p16_common.py`
- `experiments/sweep_dpo/compare_lora_beta0p1_b64_v5p16_seed0.py`
- `experiments/sweep_dpo/compare_lora_beta0p1_b64_v5p16_seed1.py`
- `experiments/sweep_dpo/compare_lora_beta0p1_b64_v5p16_seed2.py`

Sweep shape:

- full DPO, not LoRA
- `beta=0.1`
- seeds `0`, `1`, `2`
- `train_batch_size=64`
- `num_epochs=1.0`
- `v5p-16`
- same Bloom SpecEval v2 train/val preference data
- same base LR as the original full-DPO baseline: `5e-7`
- auto-scheduled 5-point validation because `steps_per_eval` is intentionally left unset

Planned launch commands:

- `uv run iris --config=lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-full-dpo-compare-lora-beta0p1-seed0-b64-v5p16 --tpu v5p-16 --region us-central1 -e WANDB_API_KEY "$WANDB_API_KEY" --no-wait -- uv run python experiments/sweep_dpo/compare_lora_beta0p1_b64_v5p16_seed0.py`
- `uv run iris --config=lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-full-dpo-compare-lora-beta0p1-seed1-b64-v5p16 --tpu v5p-16 --region us-central1 -e WANDB_API_KEY "$WANDB_API_KEY" --no-wait -- uv run python experiments/sweep_dpo/compare_lora_beta0p1_b64_v5p16_seed1.py`
- `uv run iris --config=lib/iris/examples/marin.yaml job run --job-name bloom-speceval-v2-full-dpo-compare-lora-beta0p1-seed2-b64-v5p16 --tpu v5p-16 --region us-central1 -e WANDB_API_KEY "$WANDB_API_KEY" --no-wait -- uv run python experiments/sweep_dpo/compare_lora_beta0p1_b64_v5p16_seed2.py`

Job ids after submit:

- seed `0`: `/ahmed/bloom-speceval-v2-full-dpo-compare-lora-beta0p1-seed0-b64-v5p16`
- seed `1`: `/ahmed/bloom-speceval-v2-full-dpo-compare-lora-beta0p1-seed1-b64-v5p16`
- seed `2`: `/ahmed/bloom-speceval-v2-full-dpo-compare-lora-beta0p1-seed2-b64-v5p16`

Immediate scheduler states after submit:

- seed `0`: `running`
- seed `1`: `pending`
- seed `2`: `pending`

Pending reason for seeds 1 and 2: `Coscheduling: need 2 workers in 'tpu-name' group ...`

Interpretation:

- the cluster had enough free `v5p-16` capacity to start one sweep job immediately while `dummy` was still occupying another slice
- the remaining two sweep jobs are queued cleanly behind the same multinode placement constraint
Created a new local read-only W&B export root for the newly finished full-DPO runs:

Top-level manifest:

Archived runs:

- `dummy-3ca308`
- `bloom_speceval_v2_beta0.1_seed0_b64_v5p16-68f963`
- `bloom_speceval_v2_beta0.1_seed1_b64_v5p16-c50842`
- `bloom_speceval_v2_beta0.1_seed2_b64_v5p16-2272c8`

Saved per-run files:

- `run_metadata.json`
- `config.json`
- `summary.json`
- `files_manifest.json`
- `downloaded_files.json`
- `logged_artifacts.json`
- `history.jsonl.gz`
- `system_history_sampled.csv`
- `files/` with downloaded W&B run files

Verification summary:

- manifest length: `4`
- all run states: `finished`
- the compare runs exported `1700` history rows each
- `dummy` exported `10` history rows
- file download errors: `0` for all four runs
- archive size on disk: about `33M`

Implementation note:

- the first export attempt hit a benign W&B API attribute mismatch on `run.updated_at`; rerunning with `getattr(..., None)` completed the archive without modifying any runs
Built a new local comparison report for the OG February full-DPO baseline versus the new April full-DPO rerun:

- `scripts/dpo/plot_beta0p1_og_dpo_vs_new_dpo.py`
- `beta0p1_og_dpo_vs_new_dpo.html`
- `beta0p1_og_dpo_vs_new_dpo_selection.json`

Selection rule:

- OG side: `beta=0.1` February full-DPO baseline runs from `scratch/wandb_dpo_data/og_no_lora`, averaged across seeds `0/1/2`
- new side: `beta=0.1` April full-DPO reruns from `scratch/wandb_dpo_data/new_dpo`, averaged across seeds `0/1/2`
- excluded: the `10`-step `dummy-3ca308` probe

Plotting choices:

- same four Bloom SpecEval v2 DPO metrics as the earlier LoRA comparison
- the x-axis is normalized to percent of the run so the OG `850`-step b128 and new `1700`-step b64 runs line up as one-epoch curves
- both loss panels use log-scale y-axes
- train loss is lightly smoothed; eval metrics remain checkpoint markers
Added a one-off executor wrapper for the cleanest current-head comparison against the old new_dpo_v2 run family:

Intent:

- keep the linked `new_dpo_v2_bloom_speceval_v2_marin_instruct_beta0.1_-7_seed2-947c5d` training shape
- run on `v5p-32`
- keep `train_batch_size=128`
- keep `num_train_steps=850`
- keep `learning_rate=7.5e-7`
- keep `steps_per_eval=200`
- keep the same seed semantics as the old executor path (`data_seed=2`, trainer seed remains defaulted)
- but swap the primary DPO validation set to the current deduped val path
- and let the current `default_dpo(...)` add the modern validation stack (LM eval suites plus reference-eval cache)

This is intentionally a one-off debugging/repro script, not a new canonical experiment entry point.
Submitted the one-off regression run on Iris in us-central1-a.
Launch command:
`uv run iris --config=lib/iris/examples/marin.yaml job run --job-name regression-test-dpo-beta0p1-lr7p5e-7-seed2-deduped-val --tpu v5p-32 --zone us-central1-a -e WANDB_API_KEY "$WANDB_API_KEY" --no-wait -- uv run python experiments/sweep_dpo/regression_test_dpo.py`

Identifiers:
- Iris job: `/ahmed/regression-test-dpo-beta0p1-lr7p5e-7-seed2-deduped-val`
- submit time: `2026-04-03 12:29:52 PDT`
Initial scheduler state: `pending`
- reason: `Scheduler: Coscheduling: need 4 workers in 'tpu-name' group ...`
By 2026-04-03 13:08 PDT, the child TPU training job had moved past the original coscheduling wait:
- parent executor job `/ahmed/regression-test-dpo-beta0p1-lr7p5e-7-seed2-deduped-val`: `running`
- child training job `/ahmed/regression-test-dpo-beta0p1-lr7p5e-7-seed2-deduped-val/train_dpo`: `running`
- child task counts at that moment: `building=4`
Interpretation:
- the outer executor wrapper is alive and orchestrating steps
- the inner `train_dpo` child has now received the `v5p-32` worker gang
- startup is still in the worker-build / container-init phase, so this is not yet proof that the model has begun compiling or training
Next monitoring gate:
- keep short startup checks until child logs show actual training initialization or an early failure
By 2026-04-03 13:09 PDT, the child job had cleared the build phase and was fully running on all four TPU tasks:
- child training job `/ahmed/regression-test-dpo-beta0p1-lr7p5e-7-seed2-deduped-val/train_dpo`: `running`
- child task counts: `running=4`
Earliest substantive startup signal observed:
- Levanter began loading the deduped validation cache ledger from `gs://marin-us-central1/tokenized/bloom_speceval_v2_val_deduped_prefs_marin_tokenizer-589b86/validation/shard_ledger.json`
No early RESOURCE_EXHAUSTED, HBM OOM, FAILED_PRECONDITION, or node-death signal had appeared by that point.
Checked again at 2026-04-03 12:42 PDT and the one-off regression run had completed successfully:
- parent executor job `/ahmed/regression-test-dpo-beta0p1-lr7p5e-7-seed2-deduped-val`: `succeeded`
- child training job `/ahmed/regression-test-dpo-beta0p1-lr7p5e-7-seed2-deduped-val/train_dpo`: `succeeded`
- child task counts: `succeeded=4`
- failure count: `0`
- preemption count: `0`
Runtime from Iris child timestamps:
- child start: `2026-04-03 12:01:13 PDT`
- child finish: `2026-04-03 17:15:42 PDT`
- child wall time: `5h 14m 28s`
Parent executor wall time:
- parent start: `2026-04-03 12:00:16 PDT`
- parent finish: `2026-04-03 17:15:52 PDT`
- parent wall time: `5h 15m 36s`
Pulled the finished regression run from W&B and archived it into the existing new_dpo scratch bundle.
Canonical W&B run:
- display name: `regression_test_dpo_bloom_lr7.5e-7_seed2_deduped_val-1e4e93`
- URL: https://wandb.ai/marin-community/dpo/runs/regression_test_dpo_bloom_lr7.5e-7_seed2_deduped_val-1e4e93
Local archive path:
Archived contents:
- `run_metadata.json`, `config.json`, `summary.json`, `history.jsonl.gz`, `system_history_sampled.csv`, `files_manifest.json`, `downloaded_files.json`, `logged_artifacts.json`
- downloaded run files under `files/`
Verification:
- manifest `finished_new_dpo_runs_manifest.json` now has `5` finished runs
- regression run export has `history_rows = 850`, `system_history_rows = 500`, `num_files = 11`, `num_file_errors = 0`
Pulled the older full-val reference run into a separate scratch archive so it would not contaminate the seed-averaged new_dpo directory.
Canonical W&B run:
- display name: `new_dpo_v2_bloom_speceval_v2_marin_instruct_beta0.1_-7_seed2-947c5d`
- URL: https://wandb.ai/marin-community/dpo/runs/new_dpo_v2_bloom_speceval_v2_marin_instruct_beta0.1_-7_seed2-947c5d
Local archive path:
Verification:
- manifest `finished_reference_dpo_runs_manifest.json` has `1` finished run
- reference export has `history_rows = 1000`, `system_history_rows = 500`, `num_files = 13`, `num_file_errors = 0`
Built a dedicated exact-step comparison plot for the two single runs:
- baseline full-val run `947c5d`
- deduped-val regression run `1e4e93`
Artifacts:
- plot script: `scripts/dpo/plot_regression_test_vs_new_dpo_v2.py`
- HTML: `scratch/wandb_dpo_data/plots/regression_test_vs_new_dpo_v2_seed2.html`
- selection summary: `scratch/wandb_dpo_data/plots/regression_test_vs_new_dpo_v2_seed2_selection.json`
Key readout from the archived W&B summaries:
- train loss is identical at step `849`: `0.003482460742816329` in both runs
- eval loss diverges: `0.005509627517312765` in `947c5d` vs `0.10757964104413986` in `1e4e93`
- eval accuracy diverges: `99.5669%` in `947c5d` vs `84.8307%` in `1e4e93`
Interpretation:
- the exact same training shape reproduced cleanly
- the validation-side numbers changed sharply once the run used the deduped validation set plus current validation callbacks
- this supports the hypothesis that the training path is not what changed in the regression test
After reviewing the exact-step overlay in regression_test_vs_new_dpo_v2_seed2.html, confidence increased further:
- the train-loss curve is not just close at the endpoint; it visually overlays almost exactly across the full run
- the final train metrics also match exactly at step `849`
Current working conclusion:
- the regression run `1e4e93` is a faithful reproduction of the old `947c5d` training run
- current-head DPO training appears sound for this configuration
- the large metric gap is on the validation side, not in the optimization path
Practical implication for future comparisons:
- treat `1e4e93` as the bridge run that establishes continuity between old full-val reporting and the new deduped-val reporting
- compare future runs primarily against the deduped-val metric family rather than against the historical full-val numbers
- if needed, use `947c5d` only as the archival reference showing that training itself reproduced
Built a dedicated local-archive comparison between the batch-64 full-DPO reruns and two selected LoRA groups.
Artifacts:
- plot script: `scripts/dpo/plot_full_dpo_b64_vs_tune_lora.py`
- HTML: `scratch/wandb_dpo_data/plots/beta0p1_full_dpo_b64_vs_tune_lora.html`
- selection summary: `scratch/wandb_dpo_data/plots/beta0p1_full_dpo_b64_vs_tune_lora_selection.json`
Selection:
- Full DPO: batch `64`, beta `0.1`, lr `5e-7`, steps `1700`, seeds `0/1/2`
- LoRA 10x LR: batch `64`, beta `0.1`, lr `5e-6`, steps `1700`, archived seeds `0/2`
- LoRA Best Eval: batch `64`, beta `0.1`, lr `1e-5`, steps `1700`, archived seeds `0/2`
Important caveat:
- the archived LoRA sweep contains only seeds `0` and `2` for each learning rate
- so the full-DPO group is a true 3-seed average, while both LoRA groups are true 2-seed averages
Best-Eval selection rule:
- highest mean final eval accuracy
- tie-break lowest mean final eval loss
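The selection rule above collapses to a single sort key. A small sketch with illustrative family records (`rank_families` is a hypothetical helper; values are rounded from the archived sweep):

```python
def rank_families(families):
    """Rank LR families by highest mean final eval accuracy,
    tie-breaking on lowest mean final eval loss."""
    return sorted(families, key=lambda f: (-f["mean_final_eval_acc"], f["mean_final_eval_loss"]))


families = [
    {"lr": "1e-5",    "mean_final_eval_acc": 0.992759, "mean_final_eval_loss": 0.006395},
    {"lr": "8.75e-6", "mean_final_eval_acc": 0.992569, "mean_final_eval_loss": 0.006547},
    {"lr": "5e-6",    "mean_final_eval_acc": 0.991997, "mean_final_eval_loss": 0.007281},
    {"lr": "4.5e-6",  "mean_final_eval_acc": 0.991997, "mean_final_eval_loss": 0.007452},
]

ranked = rank_families(families)
assert ranked[0]["lr"] == "1e-5"
# Accuracy tie between 5e-6 and 4.5e-6 is broken by the lower mean loss.
assert [f["lr"] for f in ranked[2:]] == ["5e-6", "4.5e-6"]
```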
Top of the resulting LoRA LR ranking:
- `1e-5`: mean final eval accuracy `0.992759108543396`, mean final eval loss `0.006394619820639491`
- `8.75e-6`: mean final eval accuracy `0.9925685524940491`, mean final eval loss `0.006546571617946029`
- `5e-6` (10x LR): mean final eval accuracy `0.9919969439506532`, mean final eval loss `0.007281250087544322`
Final mean metrics from the selected groups:
- Full DPO: eval loss `0.023859122768044472`, eval accuracy `0.9692460497220358`, eval policy margin `-40.860984802246094`
- LoRA 10x LR: eval loss `0.007281250087544322`, eval accuracy `0.9919969439506532`, eval policy margin `14.115777015686035`
- LoRA Best Eval: eval loss `0.006394619820639491`, eval accuracy `0.992759108543396`, eval policy margin `67.51775360107422`
Interpretation:
- on the current deduped-val metric family, the selected LoRA runs clearly outperform the batch-64 full-DPO reruns
- the strongest LoRA setting in the archived sweep is `lr=1e-5`
- the `10x` heuristic choice at `5e-6` is close, but still weaker than the `1e-5` group by the chosen eval ranking
Updated beta0p1_full_dpo_b64_vs_tune_lora.html to include the full shared metric set from the archived histories, not just the DPO subset.
The current shared set across all three selected groups is 17 metrics:
- training: `train/loss`, `train/dpo_loss`, `train/dpo_accuracy`, `train/dpo_chosen_reward`, `train/dpo_rejected_reward`, `train/dpo_margin_policy`, `train/dpo_margin_ref`
- evaluation: `eval/bloom_speceval_v2_val/loss`, `eval/bloom_speceval_v2_val/dpo_loss`, `eval/bloom_speceval_v2_val/dpo_accuracy`, `eval/bloom_speceval_v2_val/dpo_chosen_reward`, `eval/bloom_speceval_v2_val/dpo_rejected_reward`, `eval/bloom_speceval_v2_val/dpo_margin_policy`, `eval/bloom_speceval_v2_val/dpo_margin_ref`
- eval timing/default bookkeeping: `eval/bloom_speceval_v2_val/timing/load_time`, `eval/bloom_speceval_v2_val/timing/loss_time`, `eval/bloom_speceval_v2_val/timing/num_batches`
Note:
- the earlier statement that there were no extra LM-suite keys in this comparison was wrong
- the shared archive actually includes a large `lm_eval/...` block on all three groups
Patched `scripts/dpo/plot_full_dpo_b64_vs_tune_lora.py` so the comparison now includes shared `lm_eval/...` metrics in addition to `train/...` and `eval/...`.
Verification from the rebuilt selection summary:
- `beta0p1_full_dpo_b64_vs_tune_lora_selection.json` reports `common_metric_count = 77` and `contains_lm_eval = true`
Examples of the newly included shared LM-eval metrics:
- aggregate: `lm_eval/loss`, `lm_eval/bpb`, `lm_eval/macro_loss`, `lm_eval/macro_bpb`, `lm_eval/loading_time`, `lm_eval/total_time`
- Paloma slices: `lm_eval/paloma/...`
- Uncheatable eval slices: `lm_eval/uncheatable_eval/...`
Net:
- the HTML now includes the shared train metrics, shared DPO validation metrics, and the shared LM-eval suite metrics for the selected full-DPO and LoRA groups
The all-in-one 77-metric page was too cluttered, so beta0p1_full_dpo_b64_vs_tune_lora.html is now rendered as a tabbed HTML instead of a single giant Plotly grid.
Current panes:
- Core DPO
- LM Summary
- Paloma Summary
- Paloma BPB
- Paloma Loss
- Uncheatable Summary
- Uncheatable BPB
- Uncheatable Loss
Implementation detail:
- the script now writes multiple Plotly figures into a single custom HTML shell with clickable tabs
- tab metadata is recorded in `beta0p1_full_dpo_b64_vs_tune_lora_selection.json`
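The tabbed shell amounts to wrapping pre-rendered figure divs in a single page with a small show/hide script. A simplified sketch (the real script embeds Plotly figure HTML; here the pane bodies are placeholders and `render_tabbed_html` is a hypothetical name):

```python
def render_tabbed_html(panes):
    """Wrap pre-rendered HTML fragments (name -> fragment) in one page with
    clickable tab buttons; only one pane is visible at a time."""
    buttons, divs = [], []
    for i, (name, body) in enumerate(panes.items()):
        display = "block" if i == 0 else "none"
        buttons.append(f'<button onclick="showPane({i})">{name}</button>')
        divs.append(f'<div class="pane" id="pane-{i}" style="display:{display}">{body}</div>')
    script = (
        "<script>function showPane(k){"
        "document.querySelectorAll('.pane').forEach((p,i)=>{"
        "p.style.display = i===k ? 'block' : 'none';});}</script>"
    )
    return "<html><body>" + "".join(buttons) + "".join(divs) + script + "</body></html>"


# Hypothetical panes; in the real script each body would be a Plotly figure div.
html = render_tabbed_html({"Core DPO": "<p>dpo plots</p>", "LM Summary": "<p>lm plots</p>"})
assert html.count('class="pane"') == 2
assert 'id="pane-0" style="display:block"' in html  # first tab starts visible
```

Keeping each pane a separate figure avoids the single giant Plotly grid that made the 77-metric page unreadable.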
Copied the Bloom spec-adherence judging data needed for the Std/Opp/Natural/Adversarial figure into scratch/bloom_data.
Archive contents:
- selected judging subtree: `scratch/bloom_data/judging`
- selection manifest: `scratch/bloom_data/selected_runs_manifest.json`
- copied plotting script from Bloom: `scratch/bloom_data/adherence.py`
- copied plot outputs: `scratch/bloom_data/plot_output`
Selection details:
- `13` unique judging runs copied
- only plotting-relevant artifacts copied per run: `summary.json`, `run_metadata.json`, and `per_statement/*.json`
- total local archive size after copy: about `854M`
Plot provenance:
- source script is Bloom's `plot/adherence.py`
- source figure copied locally as `scratch/bloom_data/plot_output/overall_adherence.png`
Local rerender note:
- attempted to rerun the plot against the copied scratch judging root
- this environment stalled during `matplotlib` startup/import before the plotting logic ran
- as a practical fallback, copied the exact Bloom-generated plot outputs into the scratch bundle next to the selected data and plotting script
Looked across the full February no-LoRA family (12 runs: 2 betas x 2 learning rates x 3 seeds) to see whether any old W&B eval metric already preferred the same family as the Bloom opposite-mode adherence plot.
Scope:
- source runs: the full CS229 cache under `data/wandb_logs`
- available old eval metrics are only the `7` DPO-validation metrics: `eval/bloom_speceval_v2_val/dpo_accuracy`, `eval/bloom_speceval_v2_val/dpo_chosen_reward`, `eval/bloom_speceval_v2_val/dpo_loss`, `eval/bloom_speceval_v2_val/dpo_margin_policy`, `eval/bloom_speceval_v2_val/dpo_margin_ref`, `eval/bloom_speceval_v2_val/dpo_rejected_reward`, `eval/bloom_speceval_v2_val/loss`
Main result:
- the cleanest old proxy for the Bloom opposite-mode preference toward the `beta=0.1` family is `eval/bloom_speceval_v2_val/dpo_accuracy`
Family means:
- `beta=0.1, lr=5e-7`: `0.995811`
- `beta=0.1, lr=7.5e-7`: `0.995754`
- `beta=0.01, lr=5e-7`: `0.994650`
- `beta=0.01, lr=7.5e-7`: `0.994141`
Interpretation:
- this metric puts both `beta=0.1` configs above both `beta=0.01` configs
- it also slightly prefers `beta=0.1, lr=5e-7` over `beta=0.1, lr=7.5e-7`, matching the opposite-mode read that `beta=0.1, lr=5e-7` is the healthiest controllability point
- seed variance is negligible, so this is not a noise artifact
Supporting proxy:
- `eval/bloom_speceval_v2_val/dpo_chosen_reward` also strongly separates the families:
  - `beta=0.1` family is positive (`2.86`, `2.57`)
  - `beta=0.01` family is negative (`-0.95`, `-2.32`)
Misleading old proxy:
- `eval/bloom_speceval_v2_val/dpo_margin_policy` points in the wrong direction for controllability:
  - `beta=0.01` family has huge positive policy margins (`430`, `1644`)
  - `beta=0.1` family is much smaller / negative (`-79`, `-46`)
- this now looks like a likely over-optimization signal rather than a health signal, since the same family is the one that fails to flip under opposite-mode system prompts
Ranked the canonical 18 tune-LoRA sweep runs in `scratch/wandb_dpo_data/tune_lora` by `eval/bloom_speceval_v2_val/dpo_accuracy`.
Top individual runs:
- `lr=1e-5`
  - `bloom_speceval_v2_marin_lr1e5_seed0_b64_v5p8-41586d`: `dpo_accuracy=0.992759`, `dpo_loss=0.006383`
  - `bloom_speceval_v2_marin_lr1e5_seed2_b64_v5p8-a73d6f`: `dpo_accuracy=0.992759`, `dpo_loss=0.006406`
- `lr=8.75e-6`
  - `bloom_speceval_v2_marin_lr8p75e6_seed2_b64_v5p8-f0636c`: `dpo_accuracy=0.992759`, `dpo_loss=0.006553`
  - `bloom_speceval_v2_marin_lr8p75e6_seed0_b64_v5p8-ee2e69`: `dpo_accuracy=0.992378`, `dpo_loss=0.006541`
Per-LR means ranked by mean eval `dpo_accuracy`:
- `lr=1e-5`: mean `dpo_accuracy=0.992759`, mean `dpo_loss=0.006395`
- `lr=8.75e-6`: mean `dpo_accuracy=0.992569`, mean `dpo_loss=0.006547`
- `lr=2.5e-6`: mean `dpo_accuracy=0.992378`, mean `dpo_loss=0.009144`
- `lr=7.5e-6`: mean `dpo_accuracy=0.992187`, mean `dpo_loss=0.006739`
- `lr=6.25e-6`: mean `dpo_accuracy=0.991997`, mean `dpo_loss=0.006965`
- `lr=5e-6`: mean `dpo_accuracy=0.991997`, mean `dpo_loss=0.007281`
- `lr=4.5e-6`: mean `dpo_accuracy=0.991997`, mean `dpo_loss=0.007452`
- `lr=3.75e-6`: mean `dpo_accuracy=0.991997`, mean `dpo_loss=0.007804`
- `lr=1e-6`: mean `dpo_accuracy=0.991806`, mean `dpo_loss=0.052525`
Interpretation:
- if we treat eval `dpo_accuracy` as the primary selector, the best LoRA family is clearly `lr=1e-5`
- the next best family is `lr=8.75e-6`
- the previously-highlighted `10x` family (`lr=5e-6`) is not best on this metric; it sits in the middle of the pack and mainly won before on the "tinker recommended" heuristic rather than pure eval-accuracy ranking
- caveat: the archived LoRA sweep only has seeds `0` and `2`, so these family means are 2-seed means rather than 3-seed means
Read the separate vLLM/export investigation logbook at:
Key finding from that thread:
- historical merged LoRA HF exports could write LoRA-targeted weights with the wrong axis order during `LoraLinear.merge()`
- this broke plain HF/vLLM loading for merged exports even though training itself was not implicated
- Bloom/vLLM inference also depended on `tokenizer_config.json` embedding `chat_template`, not only on shipping `chat_template.jinja`
Backported fixes on this branch:
- fixed `LoraLinear.merge()` to rearrange the LoRA delta onto the wrapped linear's weight axes before addition
- taught `HFCheckpointConverter.save_pretrained()` to save tokenizers through a shared helper that embeds `chat_template` into `tokenizer_config.json`
- wired `save_peft_pretrained()` through the same tokenizer-save helper so adapter exports do not rely on `chat_template.jinja` alone
- updated `HfMarinTokenizer.as_hf_tokenizer()` and `KitokenMarinTokenizer.as_hf_tokenizer()` to carry over the in-memory chat template onto the HF tokenizer object
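The bug class behind the `LoraLinear.merge()` fix can be illustrated in plain NumPy. This is a conceptual sketch, not Levanter's named-axis implementation; the shapes and the `merge_lora` helper are hypothetical:

```python
import numpy as np


def merge_lora(weight, lora_a, lora_b, scale=1.0):
    """Merge a low-rank update into the wrapped linear's weight.

    weight: (in_dim, out_dim) as stored by the wrapped linear.
    lora_a: (in_dim, rank); lora_b: (rank, out_dim).
    The delta A @ B already matches the wrapped weight's (in, out) axis
    order. A broken variant that effectively adds the transposed delta
    crashes on non-square weights but silently corrupts square ones,
    which is why the regression test uses non-square shapes.
    """
    delta = lora_a @ lora_b  # (in_dim, out_dim)
    return weight + scale * delta


in_dim, out_dim, rank = 6, 4, 2  # non-square on purpose
rng = np.random.default_rng(0)
w = rng.normal(size=(in_dim, out_dim))
a = rng.normal(size=(in_dim, rank))
b = rng.normal(size=(rank, out_dim))

merged = merge_lora(w, a, b)
assert merged.shape == w.shape  # wrapped axis order preserved

# The transposed delta does not even broadcast against a non-square weight:
try:
    _ = w + (a @ b).T
    raise AssertionError("transposed delta unexpectedly broadcast")
except ValueError:
    pass
```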
Added regression coverage:
- `lib/levanter/tests/test_lora.py`
  - assert merged LoRA non-square weights preserve the wrapped axis order
  - add a Llama-style merged-HF load regression that would catch the old transpose bug
- `lib/levanter/tests/test_hf_export.py`
  - assert saved tokenizer metadata includes an embedded `chat_template`
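The tokenizer-side fix boils down to writing the template into the config JSON rather than only shipping a sidecar file. An illustrative stdlib-only sketch (the helper name and config fields are hypothetical, not Levanter's actual API):

```python
import json
import tempfile
from pathlib import Path


def save_tokenizer_config_with_chat_template(save_dir, config, chat_template):
    """Embed the chat template directly into tokenizer_config.json so loaders
    that ignore chat_template.jinja (e.g. some vLLM setups) still pick it up."""
    config = dict(config)  # avoid mutating the caller's dict
    if chat_template is not None:
        config["chat_template"] = chat_template
    path = Path(save_dir) / "tokenizer_config.json"
    path.write_text(json.dumps(config, indent=2))
    return path


with tempfile.TemporaryDirectory() as d:
    template = "{% for message in messages %}{{ message.content }}{% endfor %}"
    path = save_tokenizer_config_with_chat_template(d, {"model_max_length": 4096}, template)
    reloaded = json.loads(path.read_text())
    assert reloaded["chat_template"] == template  # template travels with the config
```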
Validation:
- `uv run --directory lib/levanter --group test python -m pytest tests/test_hf_export.py tests/test_lora.py -q`: `20 passed, 3 skipped`
- `./infra/pre-commit.py --fix lib/levanter/src/levanter/lora.py lib/levanter/src/levanter/compat/hf_checkpoints.py lib/levanter/src/levanter/tokenizers.py lib/levanter/tests/test_lora.py lib/levanter/tests/test_hf_export.py`: passed
Important caveat:
- this patch fixes future merged LoRA HF exports on this branch
- existing merged LoRA `hf/step-*` artifacts produced before the axis-order fix should still be treated as potentially tainted until they are re-exported or repaired
User asked for a real LoRA DPO smoke run to make sure the recent export fixes did not break training or HF export:
- `5` training steps
- `v5p-8` in `us-central1`
- starting from `marin-community/marin-8b-instruct`
I added a dedicated one-off wrapper:
What the wrapper does:
- reuses the canonical `tune_lora` Bloom SpecEval v2 LoRA-DPO path
- pins `num_train_steps=5`
- pins `steps_per_eval=5`
- pins `steps_per_checkpoint=5`
- pins `steps_per_hf_export=5`
- keeps `model_name_or_path="marin-community/marin-8b-instruct"`
- keeps the LoRA adapter/reference setup from the canonical executor path
Launch attempts:
- Initial Iris job with the inherited LoRA resource request (`ram="400g"`):
  - job: `/ahmed/lora-dpo-smoke-export-v5p8-5step`
  - parent executor job started
  - child `/train_dpo` remained pending
  - pending reason was scheduler admission on the host RAM request: `Insufficient memory (need 400.0GB, available <~2GB>)`
  - later also `Unsatisfied autoscaler demand ... quota-pool tier monotonicity`
- Reduced only the smoke wrapper's RAM request to `ram="256g"` and relaunched:
  - attempted fixed-name relaunches, then a fresh-name relaunch, then an auto-generated-name relaunch
  - none of those follow-on launches actually created a new job object in Iris before the client RPC timed out
Infra signal observed while diagnosing:
- other Iris jobs on the cluster are currently failing with bundle-fetch timeouts from the controller, e.g. `RuntimeError: Failed to fetch <bundle_id>: timed out`
- that strongly suggests the current blocker is Iris/controller bundle distribution health, not a LoRA-DPO code regression in this branch
Status at end of this session:
- the smoke wrapper is ready and lint-clean
- the first `400g` attempt was intentionally terminated after confirming the scheduler bottleneck
- the `256g` relaunches have not yet produced a runnable job object because of controller-side submission/bundle issues
- as a result, I do not yet have a completed LoRA-DPO smoke run proving end-to-end HF export on-cluster
Retried submission after pulling current main Iris CLI behavior forward:
- top-level coordinator launched CPU-only with `--extra marin:tpu`
- child `train_dpo` step requested the actual `v5p-8` worker
- successful job: `/ahmed/lora-dpo-smoke-export-v5p8-5step-r2`
- W&B run: `lora_smoke_export_marin_8b_instruct_v5p8_5step-e97bb1`
  - https://wandb.ai/marin-community/dpo/runs/lora_smoke_export_marin_8b_instruct_v5p8_5step-e97bb1
Final Iris result:
- parent `/ahmed/lora-dpo-smoke-export-v5p8-5step-r2`: `succeeded`
- child `/ahmed/lora-dpo-smoke-export-v5p8-5step-r2/train_dpo`: `succeeded`
- child runtime: `57m 39s`
- task runtime: `56m 21s`
- failures: `0`
- preemptions: `0`
What this run proved:
- canonical Marin executor `tune_lora` path can still launch LoRA-DPO on Iris after the recent export fixes
- `v5p-8` training survived startup, reference-cache rebuild, compile, train steps, repeated eval suites, checkpoint save, and merged HF export
- no `RESOURCE_EXHAUSTED`, HBM, `FAILED_PRECONDITION`, or node-failure signal appeared during the successful run
Observed training/export milestones:
- loaded base model from `marin-community/marin-8b-instruct`
- rebuilt the LoRA reference-eval cache after a metadata mismatch on the old cache identity schema
- completed all `5` train steps
- saved checkpoints to:
  - `gs://marin-us-central1/checkpoints/dpo/tune_lora/lora_smoke_export_marin_8b_instruct_v5p8_5step-e97bb1/checkpoints/step-1`
  - `gs://marin-us-central1/checkpoints/dpo/tune_lora/lora_smoke_export_marin_8b_instruct_v5p8_5step-e97bb1/checkpoints/step-4`
- pruned the temporary `step-1` checkpoint after the final save
- completed merged HF export to `gs://marin-us-central1/checkpoints/dpo/tune_lora/lora_smoke_export_marin_8b_instruct_v5p8_5step-e97bb1/hf/step-4`
- HF export completed all `7` safetensor shards and logged: `Finished saving HF-compatible checkpoint`
Conclusion:
- this branch now has a successful on-cluster LoRA-DPO smoke run from the upstreamed path, and the merged HF export path completed end-to-end
Follow-up validation from Codex on the exported merged HF checkpoint:
- tested checkpoint: `gs://marin-us-central1/checkpoints/dpo/tune_lora/lora_smoke_export_marin_8b_instruct_v5p8_5step-e97bb1/hf/step-4`
- downstream inference output: `gs://marin-us-central1/eval/marin_dpo_lora_smoke_export_step4_bloom_speceval_r2/inference-6f1fa3`
What the downstream test established:
- vLLM loaded the merged HF export successfully
- inference completed and produced coherent English outputs
- the old weight-load failure did not reproduce:
  - no `assert param_data.shape == loaded_weight.shape` failure
- this confirms the recent merged-LoRA HF export fix is working for this checkpoint in an actual serving path, not just during Levanter export
Observed caveats from the downstream run:
- early launch attempts had infra issues unrelated to model export correctness:
  - one TPU type had no `us-central1` overlap
  - later child-task worker loss forced retries
- model quality itself is weak because this is only a `step-4` smoke checkpoint
- export/serving correctness is now the important signal here, not benchmark quality
Net conclusion:
- LoRA-DPO smoke run succeeded on Iris
- merged HF export completed
- downstream vLLM inference loaded that export and generated real English text
- this closes the loop on the export fix for the smoke-run checkpoint