@ananthsub ananthsub commented Jan 26, 2026

What does this PR do ?

Fixes #2064

Adds a callback system that lets third-party extensions register functions or classes to run during training and evaluation. This PR adds the following hooks:

  • on_train_start
  • on_train_step_start
  • on_train_step_end
  • on_train_end
  • on_eval_start
  • on_eval_step_start
  • on_eval_step_end
  • on_eval_end
  • on_test_start
  • on_test_step_start
  • on_test_step_end
  • on_test_end
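To make the hook names above concrete, here is a minimal, self-contained sketch of a class-based callback. `ToyContext`, `StepCounter`, and `run_toy_loop` are hypothetical names invented for this illustration; they are not the classes this PR adds, and the real signatures may differ.

```python
from dataclasses import dataclass, field


@dataclass
class ToyContext:
    """Hypothetical stand-in for the state passed to each hook."""
    step: int = 0
    user_state: dict = field(default_factory=dict)


class StepCounter:
    """Class-based callback that defines two of the hooks listed above."""

    def on_train_start(self, ctx):
        ctx.user_state["steps_seen"] = 0

    def on_train_step_end(self, ctx):
        ctx.user_state["steps_seen"] += 1


def run_toy_loop(callback, num_steps):
    """Toy driver that calls only the hooks the callback defines."""
    ctx = ToyContext()

    def call(event):
        hook = getattr(callback, event, None)
        if hook is not None:
            hook(ctx)

    call("on_train_start")
    for step in range(num_steps):
        ctx.step = step
        call("on_train_step_start")
        call("on_train_step_end")
    call("on_train_end")
    return ctx


result = run_toy_loop(StepCounter(), num_steps=3)
print(result.user_state["steps_seen"])
```

Hooks the callback does not define are simply skipped, so an extension only pays for the events it cares about.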

Changelog

  • Support callbacks for third-party extensions

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added a callback framework enabling custom hooks into training and evaluation lifecycle events (start, step start/end, end).
    • Callbacks support both class-based and functional registration patterns with persistent state management.
  • Documentation

    • Added comprehensive documentation covering callback system usage, patterns, and API reference.
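The functional registration pattern and persistent state mentioned in the release notes can be sketched roughly as follows. This is an assumption-laden toy: `ToyRegistry`, its `register` and `fire` methods, and the `loss` field are hypothetical and only illustrate the idea of plain functions sharing one state dict across invocations.

```python
from types import SimpleNamespace

TOY_EVENTS = frozenset({"on_train_step_end", "on_train_end"})


class ToyRegistry:
    """Hypothetical registry mapping event names to plain functions."""

    def __init__(self):
        self._handlers = {}

    def register(self, event, fn):
        if event not in TOY_EVENTS:
            raise ValueError(f"unknown event: {event}")
        self._handlers.setdefault(event, []).append(fn)

    def fire(self, event, ctx):
        # Handlers run in the order they were registered.
        for fn in self._handlers.get(event, ()):
            fn(ctx)


def record_loss(ctx):
    ctx.user_state.setdefault("losses", []).append(ctx.loss)


def summarize(ctx):
    losses = ctx.user_state["losses"]
    ctx.user_state["mean_loss"] = sum(losses) / len(losses)


registry = ToyRegistry()
registry.register("on_train_step_end", record_loss)
registry.register("on_train_end", summarize)

user_state = {}  # one dict shared by every invocation
for loss in (1.0, 2.0, 3.0):
    registry.fire("on_train_step_end", SimpleNamespace(loss=loss, user_state=user_state))
registry.fire("on_train_end", SimpleNamespace(loss=None, user_state=user_state))
print(user_state["mean_loss"])
```

Because the same `user_state` dict is handed to every context, a function registered for one event can read what a function registered for another event wrote earlier.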


Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>

copy-pr-bot bot commented Jan 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ananthsub
Contributor Author

/ok to test 065acea

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
@ananthsub
Contributor Author

/ok to test f9e411c

@ananthsub ananthsub changed the title Support callbacks for third-party extensions feat: Support callbacks for third-party extensions Jan 26, 2026
@ananthsub ananthsub marked this pull request as ready for review January 26, 2026 12:54
coderabbitai bot commented Jan 26, 2026

Walkthrough

This PR introduces a lightweight event-based callback system for Megatron-Bridge training workflows. It adds CallbackManager to orchestrate callbacks, CallbackContext to provide framework state, and Callback as a base class for users to define custom logic. The system integrates eight hooks (on_train_start, on_train_step_start, on_train_step_end, on_train_end, on_eval_start, on_eval_step_start, on_eval_step_end, on_eval_end) into pretrain, train, and eval functions. Documentation covers design, usage patterns, and API reference.

Changes

| Cohort / File(s) | Change Summary |
|---|---|
| Documentation Updates: docs/index.md, docs/training/README.md, docs/training/callbacks.md | Added callbacks reference to Training toctree and README navigation; new comprehensive 215-line documentation file covering callback system design, quick-start patterns, event lifecycle, CallbackContext fields, distributed training considerations, and API reference. |
| New Callback Framework: src/megatron/bridge/training/callbacks.py | New 337-line module introducing public exports: VALID_EVENTS (frozenset), CallbackContext (dataclass), Callback (base class), CallbackManager (event orchestrator), normalize_callbacks and should_fire (utility functions). Supports class-based and functional callback registration patterns with per-event function firing in registration order. |
| Training Loop Integration: src/megatron/bridge/training/pretrain.py, src/megatron/bridge/training/train.py | Added callbacks parameter to pretrain; propagates normalized CallbackManager through _pretrain and train. train() fires on_train_start/end and on_train_step_start/end with context containing loss_dict, grad_norm, skipped_iter; user_state persisted across all invocations. |
| Evaluation Loop Integration: src/megatron/bridge/training/eval.py | Added callback_manager and callback_user_state parameters to evaluate and evaluate_and_print_results; fires on_eval_start/end and on_eval_step_start/end around evaluation stages with appropriate context fields. |
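The summary names `normalize_callbacks` and `should_fire` but does not show their behavior. The sketch below is one plausible reading, written with hypothetical `toy_`-prefixed helpers so it is not mistaken for the module's actual code: normalization turns user input into a manager (or `None`), and a guard lets the loop skip building a context when nothing is registered for an event.

```python
class ToyManager:
    """Hypothetical manager that stores handlers per event."""

    def __init__(self, handlers=None):
        self.handlers = dict(handlers or {})

    def fire(self, event, ctx):
        for fn in self.handlers.get(event, ()):
            fn(ctx)


def toy_normalize(callbacks):
    """Accept None, a manager, or a list of objects with on_* methods."""
    if callbacks is None:
        return None
    if isinstance(callbacks, ToyManager):
        return callbacks
    handlers = {}
    for cb in callbacks:
        for name in dir(cb):
            if name.startswith("on_") and callable(getattr(cb, name)):
                handlers.setdefault(name, []).append(getattr(cb, name))
    return ToyManager(handlers)


def toy_should_fire(manager, event):
    """True only when at least one handler would actually run."""
    return manager is not None and bool(manager.handlers.get(event))


class MarkSeen:
    def on_train_step_end(self, ctx):
        ctx["seen"] = True


manager = toy_normalize([MarkSeen()])
context_builds = 0
for event in ("on_train_step_start", "on_train_step_end"):
    if toy_should_fire(manager, event):
        context_builds += 1  # the context is built only past the guard
        ctx = {}
        manager.fire(event, ctx)
print(context_builds)
```

In this toy, only one of the two events has a handler, so only one context is ever constructed.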

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant pretrain
    participant _pretrain
    participant CallbackManager
    participant train
    participant evaluate_and_print_results

    User->>pretrain: pretrain(config, forward_step_func, callbacks=[...])
    pretrain->>CallbackManager: normalize_callbacks(callbacks)
    CallbackManager-->>pretrain: CallbackManager instance
    pretrain->>_pretrain: _pretrain(..., callback_manager)
    _pretrain->>train: train(..., callback_manager, callback_user_state)
    
    train->>CallbackManager: fire('on_train_start', context)
    CallbackManager-->>train: execute all registered callbacks
    
    loop for each training step
        train->>CallbackManager: fire('on_train_step_start', context)
        CallbackManager-->>train: execute callbacks
        train->>train: forward_backward_func()
        train->>CallbackManager: fire('on_train_step_end', context with loss_dict, grad_norm)
        CallbackManager-->>train: execute callbacks
    end
    
    train->>evaluate_and_print_results: evaluate_and_print_results(..., callback_manager, callback_user_state)
    evaluate_and_print_results->>CallbackManager: fire('on_eval_start', context)
    CallbackManager-->>evaluate_and_print_results: execute callbacks
    
    loop for each eval step
        evaluate_and_print_results->>CallbackManager: fire('on_eval_step_start', context)
        CallbackManager-->>evaluate_and_print_results: execute callbacks
        evaluate_and_print_results->>evaluate_and_print_results: forward_backward_func()
        evaluate_and_print_results->>CallbackManager: fire('on_eval_step_end', context)
        CallbackManager-->>evaluate_and_print_results: execute callbacks
    end
    
    evaluate_and_print_results->>CallbackManager: fire('on_eval_end', context with total_loss_dict)
    CallbackManager-->>evaluate_and_print_results: execute callbacks
    
    train->>CallbackManager: fire('on_train_end', context)
    CallbackManager-->>train: execute callbacks
    
    _pretrain-->>pretrain: completed
    pretrain-->>User: training finished
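The event order in the sequence diagram can be traced with a small toy loop. `toy_train` is a hypothetical stand-in that mirrors the diagram's simplified flow (one evaluation pass after the training steps); the real `train` function may schedule evaluation differently.

```python
def toy_train(fire, train_steps, eval_steps):
    """Fire events in the order the diagram shows."""
    fire("on_train_start")
    for _ in range(train_steps):
        fire("on_train_step_start")
        # forward/backward pass would run here
        fire("on_train_step_end")
    fire("on_eval_start")
    for _ in range(eval_steps):
        fire("on_eval_step_start")
        # evaluation forward pass would run here
        fire("on_eval_step_end")
    fire("on_eval_end")
    fire("on_train_end")


trace = []
toy_train(trace.append, train_steps=2, eval_steps=1)
for event in trace:
    print(event)
```

Recording the trace this way is also a cheap way to assert hook ordering in a unit test.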

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~35 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Test Results For Major Changes | ⚠️ Warning | PR implements major callback system feature but lacks documented test results, test file paths, coverage metrics, or performance validation of zero-cost overhead claim. | Provide test file paths, execution results, performance benchmarks validating zero-cost overhead, and confirmation that existing tests pass without regressions. |

✅ Passed checks (5 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'feat: Support callbacks for third-party extensions' directly and clearly describes the main change: adding callback support for third-party extensions. |
| Linked Issues check | ✅ Passed | The PR fully implements all core requirements from issue #2064: CallbackManager for registration, Callback base class, CallbackContext with framework state access, all eight hook events, functional and class-based patterns, zero-cost when unused, and proper first-party isolation. |
| Out of Scope Changes check | ✅ Passed | All changes are directly related to implementing the callback system. Documentation updates and integration into training/evaluation loops are within scope; no unrelated modifications detected. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
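One way to gather evidence for the zero-cost overhead claim flagged in the checks above is a micro-benchmark of the guard itself. The sketch below uses only the standard-library `timeit` module; `step_without_hooks` and `step_with_guard` are hypothetical stand-ins for a training step, so the numbers say nothing about the real loop and serve only to show the measurement approach.

```python
import timeit


def step_without_hooks():
    return 1 + 1  # stand-in for the real training step


def step_with_guard(manager=None):
    if manager is not None:  # guard that skips all callback work when unused
        manager.fire("on_train_step_end", {})
    return 1 + 1


baseline = timeit.timeit(step_without_hooks, number=100_000)
guarded = timeit.timeit(step_with_guard, number=100_000)
print(f"baseline={baseline:.4f}s guarded={guarded:.4f}s")
```

Timings vary by machine and run, so a real validation would repeat the measurement and compare against full training-step time rather than a trivial function.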


✨ Finishing touches
  • 📝 Docstrings were successfully generated.

Development

Successfully merging this pull request may close these issues.

Callback support for third-party extensions

2 participants