feat: Support callbacks for third-party extensions #2063
base: main
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
/ok to test 065acea
/ok to test f9e411c
📝 Walkthrough
This PR introduces a lightweight event-based callback system for Megatron-Bridge training workflows. It adds CallbackManager to orchestrate callbacks, CallbackContext to provide framework state, and Callback as a base class for users to define custom logic. The system integrates eight hooks (on_train_start, on_train_step_start, on_train_step_end, on_train_end, on_eval_start, on_eval_step_start, on_eval_step_end, on_eval_end) into the pretrain, train, and eval functions. Documentation covers design, usage patterns, and API reference.
Changes
Sequence Diagram(s)
sequenceDiagram
participant User
participant pretrain
participant _pretrain
participant CallbackManager
participant train
participant evaluate_and_print_results
User->>pretrain: pretrain(config, forward_step_func, callbacks=[...])
pretrain->>CallbackManager: normalize_callbacks(callbacks)
CallbackManager-->>pretrain: CallbackManager instance
pretrain->>_pretrain: _pretrain(..., callback_manager)
_pretrain->>train: train(..., callback_manager, callback_user_state)
train->>CallbackManager: fire('on_train_start', context)
CallbackManager-->>train: execute all registered callbacks
loop for each training step
train->>CallbackManager: fire('on_train_step_start', context)
CallbackManager-->>train: execute callbacks
train->>train: forward_backward_func()
train->>CallbackManager: fire('on_train_step_end', context with loss_dict, grad_norm)
CallbackManager-->>train: execute callbacks
end
train->>evaluate_and_print_results: evaluate_and_print_results(..., callback_manager, callback_user_state)
evaluate_and_print_results->>CallbackManager: fire('on_eval_start', context)
CallbackManager-->>evaluate_and_print_results: execute callbacks
loop for each eval step
evaluate_and_print_results->>CallbackManager: fire('on_eval_step_start', context)
CallbackManager-->>evaluate_and_print_results: execute callbacks
evaluate_and_print_results->>evaluate_and_print_results: forward_backward_func()
evaluate_and_print_results->>CallbackManager: fire('on_eval_step_end', context)
CallbackManager-->>evaluate_and_print_results: execute callbacks
end
evaluate_and_print_results->>CallbackManager: fire('on_eval_end', context with total_loss_dict)
CallbackManager-->>evaluate_and_print_results: execute callbacks
train->>CallbackManager: fire('on_train_end', context)
CallbackManager-->>train: execute callbacks
_pretrain-->>pretrain: completed
pretrain-->>User: training finished
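The firing order in the diagram above can be sketched in plain Python. This is a minimal, self-contained illustration and not the code added by this PR: the class names mirror the ones in the walkthrough, but every field, method signature, and the `toy_train` loop are assumptions made for the example.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class CallbackContext:
    """Illustrative stand-in for the state passed to hooks (fields are assumed)."""
    step: int = 0
    loss_dict: Optional[Dict[str, float]] = None
    user_state: Dict[str, Any] = field(default_factory=dict)


class Callback:
    """Illustrative base class: subclasses override only the hooks they need."""


class CallbackManager:
    """Illustrative dispatcher: calls the named hook on each registered callback."""

    def __init__(self, callbacks: Optional[List[Callback]] = None) -> None:
        self.callbacks = list(callbacks or [])

    def fire(self, event: str, context: CallbackContext) -> None:
        for callback in self.callbacks:
            hook = getattr(callback, event, None)
            if callable(hook):  # callbacks without this hook are skipped
                hook(context)


def toy_train(manager: CallbackManager, train_steps: int, eval_steps: int) -> None:
    """Hypothetical loop that fires hooks in the order shown in the diagram."""
    context = CallbackContext()
    manager.fire("on_train_start", context)
    for step in range(1, train_steps + 1):
        context.step = step
        manager.fire("on_train_step_start", context)
        # Placeholder for the real forward/backward pass.
        context.loss_dict = {"loss": 1.0 / step}
        manager.fire("on_train_step_end", context)
    manager.fire("on_eval_start", context)
    for _ in range(eval_steps):
        manager.fire("on_eval_step_start", context)
        manager.fire("on_eval_step_end", context)
    manager.fire("on_eval_end", context)
    manager.fire("on_train_end", context)


class LossRecorder(Callback):
    """Example user callback: records per-step losses and lifecycle events."""

    def __init__(self) -> None:
        self.events: List[str] = []
        self.losses: List[float] = []

    def on_train_start(self, context: CallbackContext) -> None:
        self.events.append("on_train_start")

    def on_train_step_end(self, context: CallbackContext) -> None:
        self.losses.append(context.loss_dict["loss"])

    def on_train_end(self, context: CallbackContext) -> None:
        self.events.append("on_train_end")


recorder = LossRecorder()
toy_train(CallbackManager([recorder]), train_steps=2, eval_steps=1)
```

Looking hooks up by name with `getattr` keeps the sketch short and lets a callback implement only a subset of the hooks. The real `CallbackManager` may dispatch differently.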
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~35 minutes
🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
What does this PR do?
Fixes #2064
Adds a callback system for third-party extensions, letting users register functions or classes that run during training and evaluation. This PR adds the following hooks:
- on_train_start
- on_train_step_start
- on_train_step_end
- on_train_end
- on_eval_start
- on_eval_step_start
- on_eval_step_end
- on_eval_end
- on_test_start
- on_test_step_start
- on_test_step_end
- on_test_end
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation