Add model merging callback with integration test #558
base: main
Conversation
Adds a trainer callback that maintains a running average of model weights during training. At merge_step, it saves the averaged model as a new checkpoint. This provides an alternative to learning rate annealing by averaging weights from multiple training steps.
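A minimal sketch of the running-average bookkeeping described above, assuming a simple per-step hook; the class and method names here are illustrative, not the actual olmo-core Callback API:

```python
import torch

class RunningWeightAverage:
    """Illustrative sketch: accumulate weights over the merge window, then average."""

    def __init__(self, merge_step: int, merge_last_n_steps: int):
        self.merge_step = merge_step
        self.merge_last_n_steps = merge_last_n_steps
        self._accumulator: dict[str, torch.Tensor] = {}
        self._n_accumulated = 0

    def post_step(self, step: int, model: torch.nn.Module):
        window_start = self.merge_step - self.merge_last_n_steps + 1
        if window_start <= step <= self.merge_step:
            # Accumulate a CPU copy of the weights for this step.
            for name, param in model.named_parameters():
                cpu_param = param.detach().to("cpu", torch.float32)
                if name in self._accumulator:
                    self._accumulator[name] += cpu_param
                else:
                    self._accumulator[name] = cpu_param.clone()
            self._n_accumulated += 1
        if step == self.merge_step:
            # Average and hand back; the real callback saves this as a step{step}-merged checkpoint.
            return {k: v / self._n_accumulated for k, v in self._accumulator.items()}
        return None
```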
- merge_step now defaults to trainer.max_steps if not explicitly set
- Fixed off-by-one: merge_last_n_steps=100 now gives exactly 100 steps (see the arithmetic sketch below)
- Updated docstring with clearer examples
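The off-by-one fix makes the accumulation window inclusive on both ends; a small illustrative check:

```python
merge_step, merge_last_n_steps = 1000, 100
window = range(merge_step - merge_last_n_steps + 1, merge_step + 1)  # steps 901..1000
assert len(window) == merge_last_n_steps  # exactly 100 steps
```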
… support
- Support list of merge steps (merge_step can be int or List[int])
- Auto-detect decay phase from scheduler and set merge_step accordingly
- Warn if merge steps fall within the decay phase
- Add evaluation support for merged models with "merged-" metric prefix
- Auto-discover evaluators from EvaluatorCallback instances
- Add eval_merged flag to enable/disable merged model evaluation
- Add `prefix` parameter to EvaluatorCallback._perform_eval() to support custom metric prefixes (defaults to "eval")
- Simplify ModelMergeCallback._evaluate_merged() to call the existing _perform_eval(prefix="eval/merged") instead of duplicating eval logic
- Add `validate` flag to ModelMergeCallback for testing weight averaging correctness by comparing the running-sum and stack-and-mean methods
- Remove unused `evaluators` and `eval_duration` fields from ModelMergeCallback
- Add integration test that trains a small 30M model for 5 steps with weight averaging validation enabled
- Add GitHub Actions job to run integration tests on GPU
- Move import pytest to top of file (lint fix)
- Use torchrun instead of pytest in CI for the integration test, since FSDP requires distributed environment initialization

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove FSDP and bfloat16 (not needed for testing callback logic)
- Move from GPU checks to regular CPU checks in CI
- Simplify test file (remove torchrun, main function)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Keep CPU test for basic callback logic
- Add GPU test with FSDP to verify sharded checkpoint handling
- GPU test runs with torchrun and 2 GPUs

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
ModelMergeCallback now auto-detects WSDS schedulers and computes merge steps before each decay phase. For a ladder run with N periods, it will automatically set up N merge steps (one per anneal). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Match existing test patterns: use run_distributed_test, which handles spawning processes internally, instead of the torchrun wrapper. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Move accumulator and captured weights to CPU during accumulation
- Explicitly move weights back to the correct device when evaluating the merged model
- Avoids keeping an extra copy of the model weights on GPU during training

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Document tradeoffs for 70B+ models:
- Per-step state_dict() overhead during the merge window
- Checkpoint size increase when saving the accumulator

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use public scheduler attributes instead of private _resolve_decay()
- Rename batch_size to tokens_per_step to clarify units (see the conversion sketch below)
- Add documentation about the assumption that global_batch_size is in tokens
- Log tokens/step in an info message so users can verify
- Add debug logging for each period's computed values

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
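A hedged sketch of the tokens-to-steps conversion this commit describes. The inputs (decay_start_tokens, the choice of merging one step before decay) are hypothetical stand-ins, not the actual WSDS scheduler attributes:

```python
# Hypothetical inputs: token count at which each WSDS period's decay phase begins,
# plus the (assumed constant) global batch size expressed in tokens.
decay_start_tokens = [50_000_000_000, 100_000_000_000]  # one entry per period
tokens_per_step = 2_097_152  # global_batch_size, assumed to be in tokens

merge_steps = []
for period_idx, start_tokens in enumerate(decay_start_tokens):
    decay_start_step = start_tokens // tokens_per_step
    merge_step = decay_start_step - 1  # assumption: merge just before the decay phase begins
    merge_steps.append(merge_step)
    print(f"period {period_idx}: decay starts at step {decay_start_step}, merge at step {merge_step}")
```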
Ensures all ranks finish writing before any rank proceeds to evaluation or next merge step. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add -m "not gpu" to skip GPU-marked tests in the CPU CI job. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use get_model_state_dict/set_model_state_dict for proper FSDP support
- Use full_state_dict=True in accumulation to ensure correct saving
- Remove broken LMEvaluatorCallbackConfig from the integration test (missing label metadata)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Adds a ladder script that includes ModelMergeCallback for weight averaging experiments. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add __post_init__ validation for merge_last_n_steps and merge_step values (see the sketch below)
- Fix resume handling: save accumulated weights when past the merge step
- Use olmo_core.io utilities for remote path support (S3/GCS)
- Check for the .metadata file to detect complete checkpoints (vs. partial ones)
- Add overlap detection for merge windows with a helpful error message
- Add a warning for truncated accumulation windows
- Fix the decay phase warning to apply when the user specifies merge_step on WSDS
- Switch to the sharded checkpoint approach (full_state_dict=False)
- Add unit tests for the validation logic

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
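A minimal sketch of the validation and overlap check mentioned above, assuming a dataclass-style config; field names mirror the PR but the class shape is illustrative:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ModelMergeConfigSketch:
    merge_steps: List[int] = field(default_factory=list)
    merge_last_n_steps: int = 100

    def __post_init__(self):
        if self.merge_last_n_steps < 1:
            raise ValueError("merge_last_n_steps must be >= 1")
        if any(s < 1 for s in self.merge_steps):
            raise ValueError("merge steps must be positive")
        # Each window is [s - merge_last_n_steps + 1, s]; consecutive windows must not intersect.
        ordered = sorted(self.merge_steps)
        for prev, cur in zip(ordered, ordered[1:]):
            if cur - self.merge_last_n_steps + 1 <= prev:
                raise ValueError(
                    f"merge windows for steps {prev} and {cur} overlap; "
                    f"space them at least merge_last_n_steps={self.merge_last_n_steps} apart"
                )
```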
- Add a continue after handling past merge steps so the remaining steps are still checked
- Add barriers around merged model evaluation to synchronize all ranks

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Change the `enabled` default from True to False so the callback must be explicitly enabled
- Remove olmo3_model_merging_ladder.py (test script, not needed in this PR)
```python
    - Making accumulator checkpointing configurable
    """
    state = {
        "n_accumulated": self._n_accumulated,
```
The way the trainer saves this state is not ideal for model state, especially when it's distributed/sharded. So we'll probably need to rethink this.
Here's one alternative:
Say we're merging the last 100 steps. We enforce saving a full training checkpoint right when we start to accumulate, i.e. 100 steps before we want to evaluate the merged model, and we don't checkpoint again until after the merged model is evaluated. That way we wouldn't need to save the state of this callback because it would always be recomputed.
That makes a lot of sense! So the trade-off would just be that if we get interrupted during the merge window, we'd lose at most merge_last_n_steps - 1 steps of work? That seems reasonable while we're using fairly low values.
```python
    sd_options = dist_cp_sd.StateDictOptions(full_state_dict=False, cpu_offload=True)
    model_state = dist_cp_sd.get_model_state_dict(
        self.trainer.train_module.model, options=sd_options
    )
```
Using dist_cp* functions comes with additional overhead that I think we can safely avoid by just collecting model parameters directly, e.g. something like model_state = {k: get_local_tensor(p.data.detach()) for k, p in model.named_parameters()}.
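A sketch of that direct-collection approach; the DTensor import path varies across torch versions, and get_local_tensor here is a stand-in for the olmo-core helper referenced above:

```python
import torch
from torch.distributed.tensor import DTensor  # older torch versions: torch.distributed._tensor

def get_local_tensor(t: torch.Tensor) -> torch.Tensor:
    # Unwrap a sharded DTensor to its local shard; plain tensors pass through unchanged.
    return t.to_local() if isinstance(t, DTensor) else t

def collect_local_params(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    # Grab each rank's local shard directly, skipping the dist_cp_sd machinery.
    return {k: get_local_tensor(p.data.detach()) for k, p in model.named_parameters()}
```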
There's still the overhead of copying those params to CPU... One option is to keep them on GPU if we can afford it. But for larger models we're typically memory-starved, so that doesn't scale. Alternatively we could still copy to CPU but try to hide the overhead of the copy by doing it asynchronously.
For example:
- Trigger an async copy as early as possible, like in pre_load_batch(). You can do an async copy with olmo_core.utils.move_to_device(tensor, torch.device("cpu"), non_blocking=True). You can't do anything else with the CPU tensor until you've triggered a host-device sync, so for now you just wait.
- Right before the optimizer step, force a host-device sync (call torch.cuda.synchronize()) to ensure the copy to CPU has finished before the params are modified.
- Then accumulate those weights.
This means the accumulated weights are a step behind (since they're updated pre-optim step), but I think we can adjust the bookkeeping accordingly.
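A hedged sketch of that flow, using plain torch instead of the olmo_core.utils.move_to_device helper; the pre_load_batch/pre_optim_step hook names are illustrative:

```python
import torch

class AsyncCPUAccumulatorSketch:
    """Illustrative only: overlap the GPU->CPU copy with the forward/backward pass."""

    def __init__(self):
        self._pending: dict[str, torch.Tensor] = {}
        self._accumulator: dict[str, torch.Tensor] = {}

    def pre_load_batch(self, model: torch.nn.Module):
        # Kick off non-blocking copies as early in the step as possible.
        for name, param in model.named_parameters():
            self._pending[name] = param.detach().to("cpu", non_blocking=True)

    def pre_optim_step(self):
        # Force a host-device sync so the CPU copies are complete before params change.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # These are the pre-optim-step weights, i.e. one step behind, as noted above.
        for name, cpu_param in self._pending.items():
            if name in self._accumulator:
                self._accumulator[name] += cpu_param
            else:
                self._accumulator[name] = cpu_param.clone()
        self._pending.clear()
```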
Co-authored-by: Pete Walsh <petew@allenai.org>
Added a model merging callback that averages model weights over the last N steps of training and saves the result as a stepXXXX-merged directory.
Key features:
A few examples of recent runs with this callback: