
Conversation

@baileykuehl baileykuehl (Contributor) commented Jan 21, 2026

Added a model merging callback that averages model weights over the last N steps of training and saves the result as a stepXXXX-merged directory

Key features:

  • Disabled by default (enabled=False) because it appears to somewhat affect performance at merge steps on larger models (measuring this impact was out of scope for now; I will follow up with experiments if this gets adopted)
  • For WSD-S schedulers, automatically merges and evaluates right before each decay phase begins. This overlaps with the pre-decay eval in the old ladder.
  • Saves accumulator state to checkpoints, so interrupted runs can resume mid-merge-window without losing progress
  • Added an integration test that trains a small model for 5 steps (averaging the last 3)
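
A minimal usage sketch (the import path, constructor signature, and `TrainerConfig`-style callbacks dict are assumptions; the field names come from this PR's description and commits):

```python
# Hypothetical usage sketch -- field names (enabled, merge_step, merge_last_n_steps,
# eval_merged) come from this PR; the import path and exact signature are assumptions.
from olmo_core.train.callbacks import ModelMergeCallback  # assumed import path

# trainer_config: an existing olmo_core TrainerConfig with a callbacks dict
trainer_config.callbacks["model_merge"] = ModelMergeCallback(
    enabled=True,                # disabled by default
    merge_step=[5_000, 10_000],  # an int or a list of steps; defaults to trainer.max_steps
    merge_last_n_steps=100,      # average the final 100 steps before each merge step
    eval_merged=True,            # evaluate merged models with a "merged-" metric prefix
)
```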

A few examples of recent runs with this callback:

@baileykuehl baileykuehl changed the title Bk/model merging Add model merging callback and integration test Jan 21, 2026
@baileykuehl baileykuehl changed the title Add model merging callback and integration test [DRAFT] Add model merging callback and integration test Jan 21, 2026
baileykuehl and others added 22 commits January 21, 2026 11:05
Adds a trainer callback that maintains a running average of model weights
during training. At merge_step, it saves the averaged model as a new checkpoint.

This provides an alternative to learning rate annealing by averaging weights
from multiple training steps.
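
A minimal sketch of the running-average idea (not the PR's actual implementation; sharding and checkpointing concerns are ignored here):

```python
# Minimal sketch: keep a running sum of parameters over the merge window and
# divide by the count at merge time. Not the PR's actual code.
import torch

class WeightAverager:
    def __init__(self) -> None:
        self._sums: dict[str, torch.Tensor] = {}
        self._n_accumulated = 0

    @torch.no_grad()
    def accumulate(self, model: torch.nn.Module) -> None:
        for name, param in model.named_parameters():
            p = param.detach().to("cpu", dtype=torch.float32)
            if name in self._sums:
                self._sums[name] += p
            else:
                self._sums[name] = p.clone()
        self._n_accumulated += 1

    def merged_state_dict(self) -> dict[str, torch.Tensor]:
        # Average of all accumulated steps, ready to save as a stepXXXX-merged checkpoint.
        return {k: v / self._n_accumulated for k, v in self._sums.items()}
```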
- merge_step now defaults to trainer.max_steps if not explicitly set
- Fixed off-by-one: merge_last_n_steps=100 now gives exactly 100 steps
- Updated docstring with clearer examples
… support

- Support list of merge steps (merge_step can be int or List[int])
- Auto-detect decay phase from scheduler and set merge_step accordingly
- Warn if merge steps fall within the decay phase
- Add evaluation support for merged models with "merged-" metric prefix
- Auto-discover evaluators from EvaluatorCallback instances
- Add eval_merged flag to enable/disable merged model evaluation
- Add `prefix` parameter to EvaluatorCallback._perform_eval() to support
  custom metric prefixes (defaults to "eval")
- Simplify ModelMergeCallback._evaluate_merged() to call existing
  _perform_eval(prefix="eval/merged") instead of duplicating eval logic
- Add `validate` flag to ModelMergeCallback for testing weight averaging
  correctness by comparing running sum vs stack-and-mean methods
- Remove unused `evaluators` and `eval_duration` fields from ModelMergeCallback
- Add integration test that trains a small 30M model for 5 steps with
  weight averaging validation enabled
- Add GitHub Actions job to run integration tests on GPU
- Move import pytest to top of file (lint fix)
- Use torchrun instead of pytest in CI for the integration test
  since FSDP requires distributed environment initialization

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove FSDP and bfloat16 (not needed for testing callback logic)
- Move from GPU checks to regular CPU checks in CI
- Simplify test file (remove torchrun, main function)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Keep CPU test for basic callback logic
- Add GPU test with FSDP to verify sharded checkpoint handling
- GPU test runs with torchrun and 2 GPUs

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
ModelMergeCallback now auto-detects WSDS schedulers and computes
merge steps before each decay phase. For a ladder run with N periods,
it will automatically set up N merge steps (one per anneal).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
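
A hypothetical sketch of the per-period merge-step computation described in the commit above (the inputs here are stand-ins, not the scheduler's real attribute names):

```python
# Hypothetical sketch: one merge step per WSD-S period, placed right before each
# decay phase begins. `period_end_steps` and `decay_steps` are assumed inputs.
def compute_merge_steps(period_end_steps: list[int], decay_steps: list[int]) -> list[int]:
    return [end - decay for end, decay in zip(period_end_steps, decay_steps)]

# e.g. two periods ending at steps 10_000 and 20_000, each with a 1_000-step decay:
# merge/eval at steps 9_000 and 19_000, just before each anneal.
assert compute_merge_steps([10_000, 20_000], [1_000, 1_000]) == [9_000, 19_000]
```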
Match existing test patterns - use run_distributed_test which handles
spawning processes internally, instead of torchrun wrapper.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Move accumulator and captured weights to CPU during accumulation
- Explicitly move weights back to correct device when evaluating merged model
- Avoids keeping extra copy of model weights on GPU during training

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Document tradeoffs for 70B+ models:
- Per-step state_dict() overhead during merge window
- Checkpoint size increase when saving accumulator

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use public scheduler attributes instead of private _resolve_decay()
- Rename batch_size to tokens_per_step to clarify units
- Add documentation about assumption that global_batch_size is in tokens
- Log tokens/step in info message so users can verify
- Add debug logging for each period's computed values

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Ensures all ranks finish writing before any rank proceeds to
evaluation or next merge step.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
baileykuehl and others added 5 commits January 21, 2026 16:44
Add -m "not gpu" to skip GPU-marked tests in the CPU CI job.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use get_model_state_dict/set_model_state_dict for proper FSDP support
- Use full_state_dict=True in accumulation to ensure correct saving
- Remove broken LMEvaluatorCallbackConfig from integration test (missing label metadata)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
baileykuehl and others added 15 commits January 23, 2026 16:45
Adds a ladder script that includes ModelMergeCallback for weight averaging experiments.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add __post_init__ validation for merge_last_n_steps and merge_step values
- Fix resume handling: save accumulated weights when past merge step
- Use olmo_core.io utilities for remote path support (S3/GCS)
- Check for .metadata file to detect complete checkpoints (vs partial)
- Add overlap detection for merge windows with a helpful error message (sketched below)
- Add warning for truncated accumulation windows
- Fix decay phase warning to apply when user specifies merge_step on WSDS
- Switch to sharded checkpoint approach (full_state_dict=False)
- Add unit tests for validation logic

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
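
A hypothetical sketch of the merge-window overlap check mentioned above (not the PR's actual code):

```python
# Hypothetical sketch: each merge window covers
# [merge_step - merge_last_n_steps + 1, merge_step]; adjacent windows must not overlap.
def check_no_overlap(merge_steps: list[int], merge_last_n_steps: int) -> None:
    steps = sorted(merge_steps)
    for prev, curr in zip(steps, steps[1:]):
        window_start = curr - merge_last_n_steps + 1
        if window_start <= prev:
            raise ValueError(
                f"Merge window for step {curr} starts at step {window_start}, which "
                f"overlaps the earlier merge step {prev}; reduce merge_last_n_steps "
                f"or space the merge steps further apart."
            )
```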
- Add a `continue` after handling past merge steps so remaining merge steps are still checked
- Add barriers around merged model evaluation to synchronize all ranks

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Change `enabled` default from True to False so callback must be
  explicitly enabled
- Remove olmo3_model_merging_ladder.py (test script, not needed in PR)
@baileykuehl baileykuehl changed the title [DRAFT] Add model merging callback and integration test Add model merging callback with integration test Jan 30, 2026
@baileykuehl baileykuehl requested a review from dirkgr February 2, 2026 18:09
- Making accumulator checkpointing configurable
"""
state = {
"n_accumulated": self._n_accumulated,
Member

The way the trainer saves this state is not ideal for model state, especially when it's distributed/sharded. So we'll probably need to rethink this.

Here's one alternative:

Say we're merging the last 100 steps. We enforce saving a full training checkpoint right when we start to accumulate, i.e. 100 steps before we want to evaluate the merged model, and we don't checkpoint again until after the merged model is evaluated. That way we wouldn't need to save the state of this callback, because it would always be recomputed.
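
For example, something like this (a sketch only; `trainer.save_checkpoint()` and the accumulation helper are assumed/hypothetical):

```python
# Sketch of the alternative: pin a full checkpoint when the merge window opens and
# skip further checkpoints until the merged model is evaluated, so the accumulator
# never needs to be saved and is simply recomputed on resume.
def on_step_start(step: int, merge_step: int, merge_last_n_steps: int, trainer) -> None:
    window_start = merge_step - merge_last_n_steps + 1
    if step == window_start:
        trainer.save_checkpoint()  # assumed trainer API; pins the resume point
    if window_start <= step <= merge_step:
        accumulate_weights(trainer.train_module.model)  # hypothetical accumulation helper
```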

@baileykuehl baileykuehl (Contributor Author) Feb 3, 2026

That makes a lot of sense! So the trade-off would just be that if we get interrupted during the merge window, we would lose at most merge_last_n_steps - 1 steps of work? That seems reasonable while we're using fairly low values.

Comment on lines 404 to 407
sd_options = dist_cp_sd.StateDictOptions(full_state_dict=False, cpu_offload=True)
model_state = dist_cp_sd.get_model_state_dict(
self.trainer.train_module.model, options=sd_options
)
Member

Using dist_cp* functions comes with additional overhead that I think we can safely avoid by just collecting model parameters directly, e.g. something like model_state = {k: get_local_tensor(p.data.detach()) for k, p in model.named_parameters()}.

There's still the overhead of copying those params to CPU... One option is to keep them on GPU if we can afford it. But for larger models we're typically memory-starved, so that doesn't scale. Alternatively, we could still copy to CPU but try to hide the overhead of the copy by doing it asynchronously.

For example,

  1. Trigger an async copy as early as possible, like in pre_load_batch(). You can do an async copy with olmo_core.utils.move_to_device(tensor, torch.device("cpu"), non_blocking=True). Now you can't do anything else with the CPU tensor until you've triggered a host-device sync, so for now you just wait.
  2. Right before the optimizer step you force a host-device sync (call torch.cuda.synchronize()) to ensure the copy to CPU has finished before params are modified.
  3. Then accumulate those weights.

This means the accumulated weights are a step behind (since they're updated pre-optim step), but I think we can adjust the bookkeeping accordingly.
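
Putting that together, a rough sketch (the hook names and the `get_local_tensor` import path are assumptions):

```python
# Rough sketch of the async-copy scheme. Assumptions: the hook names
# (pre_load_batch / pre_optim_step) and the import path for get_local_tensor.
import torch
from olmo_core.utils import move_to_device
from olmo_core.distributed.utils import get_local_tensor  # assumed import path

class AsyncWeightCapture:
    def __init__(self) -> None:
        self._pending: dict[str, torch.Tensor] = {}

    def pre_load_batch(self, model: torch.nn.Module) -> None:
        # 1. Kick off non-blocking device-to-CPU copies as early as possible.
        self._pending = {
            k: move_to_device(get_local_tensor(p.data.detach()), torch.device("cpu"), non_blocking=True)
            for k, p in model.named_parameters()
        }

    def pre_optim_step(self, accumulator: dict[str, torch.Tensor]) -> None:
        if not self._pending:
            return
        # 2. Force a sync so the copies have landed before params are modified.
        torch.cuda.synchronize()
        # 3. Accumulate. The accumulated weights lag the optimizer by one step,
        #    so the merge-window bookkeeping has to shift accordingly.
        for k, t in self._pending.items():
            if k in accumulator:
                accumulator[k] += t.float()
            else:
                accumulator[k] = t.float().clone()
```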
