Description
Reference doc: https://docs.google.com/document/d/1LyH7KURd6ShYeTKQF7ntxPB0ggZVOUFTI6TKfdVvwes/
Users need to inject custom logic (e.g., custom logging) into the training loop. The current functional hooks like forward_step_func are too coarse-grained for these needs, forcing users to maintain complex forks.
This proposal introduces a lightweight, third-party-only callback system that lets users register functions for specific events (e.g., on_train_step_end). Callbacks are orchestrated by a CallbackManager class, which is passed into the main pretrain function. User functions receive a CallbackContext object that provides access to framework state.
Background & Problem
Megatron's training loop is already designed around a functional paradigm that uses dependency injection for high-level components. For example, training accepts a forward_step_func to define the model's core computation, a model provider to specify model initialization, and a dataset provider to manage data loading.
These functional hooks are too coarse-grained for common user customization requests. For example, a user who wants to log custom metrics after each training step or use a proprietary metrics logger must resort to one of the following:
- Piecemeal Use: Users import only the low-level components they need (e.g., models, optimizers) and write their own training loop from scratch. This is the recommended approach for projects like NeMo-RL that have a fundamentally different training paradigm.
- Fork and Patch: Users fork the entire repository and directly modify the train.py or train_step.py files. This offers full control but creates a maintenance burden, as the user is responsible for rebasing and adapting their changes every time the upstream framework evolves.
Ideally, users in these cases could consume the full framework as a package while still adding small, private pieces of logic as custom third-party extensions. We should create a stable, safe, and maintainable "escape hatch" that makes using the full framework viable without the high cost of a fork.
NeMo Migration
NeMo2 users are accustomed to callbacks through PyTorch Lightning. Lightning uses a class-based callback system that is deeply integrated into the framework: many of its key features (like checkpointing) are implemented as callbacks. This system is very powerful but complex, since it grants callbacks significant control, including the ability to modify the trainer's internal state, which lets users override key functionality via callbacks.
The key difference here is that core functionality (e.g., checkpointing) remains as-is, and callbacks serve exclusively for third-party integrations that the project cannot directly host.
Reference Issues in Megatron Bridge:
Custom logging / optimizer analysis
- Analyze parameters' gradients right before optimizer_step runs.
- Calculate and record the source distribution of the training data once on_load_data_end executes.
Diffusion Model support: #688
Requirements
Any solution will be evaluated against these core principles:
- First-Party Isolation: The core framework code must never use the callback system to orchestrate its own logic. It is an external-facing API only, and the framework must be fully functional with callbacks disabled. For common extensions that we can support directly in the project, the code will be explicitly imported and called in the framework to preserve readability.
- Framework State Access: The API must provide access to the trainer's/evaluation's internal state.
- Performance: Training should have zero overhead when no callbacks are registered.
- API Consistency: Callback event functions should have a consistent signature.
Proposal
We will implement a lightweight, event-based callback system by introducing these components: a CallbackManager to register and orchestrate callback functions, a Callback base class that lets users optionally group hooks together and manage state, and a CallbackContext that stores data for callback functions to act upon.
CallbackManager
A class that users will instantiate, register their callbacks onto, and pass into the main training entry point.
from collections import defaultdict
from typing import Callable

SUPPORTED_EVENTS = frozenset({
    "on_train_start", "on_train_step_start", "on_train_step_end", "on_train_end",
    "on_eval_start", "on_eval_step_start", "on_eval_step_end", "on_eval_end",
})

class CallbackManager:
    """Manages the registration and firing of callback functions."""

    def __init__(self):
        self.callbacks: dict[str, list[Callable]] = defaultdict(list)

    def register(self, event_name: str, callback_fn: Callable) -> None:
        """Register a function to be called on a specific event."""
        self.callbacks[event_name].append(callback_fn)

    def add(self, callback: "Callback | list[Callback]") -> None:
        """Register Callback instances for convenience: each supported hook
        method an instance defines is registered under that event name."""
        for cb in callback if isinstance(callback, list) else [callback]:
            for event_name in SUPPORTED_EVENTS:
                hook = getattr(cb, event_name, None)
                if callable(hook):
                    self.register(event_name, hook)

    def fire(self, event_name: str, context: "CallbackContext") -> None:
        """Execute all functions registered for a given event, passing the context to each."""
        if event_name in self.callbacks:
            for fn in self.callbacks[event_name]:
                fn(context)

    def list_callbacks(self, event_name: str) -> list[Callable[["CallbackContext"], None]]:
        """Support introspection of what has been registered for an event."""
        return list(self.callbacks.get(event_name, []))

    @property
    def events(self) -> frozenset[str]:
        """Return the set of valid event names to use for function registration."""
        return SUPPORTED_EVENTS
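The Callback base class referenced by add() is intentionally thin. This proposal does not pin down its exact shape; a minimal sketch (hook names matching the Supported Hooks section below) could look like:

class Callback:
    """Optional base class for grouping hooks together and managing state.

    Subclasses implement only the events they care about, e.g.:

        def on_train_step_end(self, context: CallbackContext) -> None: ...

    CallbackManager.add registers whichever supported hook methods an
    instance defines.
    """

Keeping the base class free of no-op default methods means add() registers only the hooks a subclass actually implements, so fire() never invokes empty methods.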
# In pretrain.py
def pretrain(config: ConfigContainer, forward_step_func: Callable,
             callbacks: CallbackManager | list[Callback] | None = None) -> None:
    _pretrain(..., callbacks=callbacks)
All fire calls in the framework will be wrapped in an if callbacks: check to ensure zero cost when unused.
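For illustration, a fire site at the end of a training step might look like the sketch below, using the CallbackContext defined in the next section. The local variable names (state, model, user_state, optimizer, etc.) are assumptions about the loop's internals, not the actual train.py code.

# Hypothetical fire site inside the training loop (variable names illustrative)
if callbacks:
    callbacks.fire(
        "on_train_step_end",
        CallbackContext(
            state=state,
            model=model,
            user_state=user_state,
            optimizer=optimizer,
            scheduler=scheduler,
            loss_dict=loss_dict,
            grad_norm=grad_norm,
            skipped_iter=skipped_iter,
        ),
    )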
CallbackContext
from dataclasses import dataclass

import torch

# GlobalState, MegatronModule, MegatronOptimizer, and OptimizerParamScheduler
# are framework types imported from Megatron Bridge / Megatron Core.

@dataclass
class CallbackContext:
    """Context passed to callbacks.

    Contains framework state and a persistent user_state dict.
    Modifying framework objects is at the user's own risk.

    Field availability by event:
        All events: state, model, user_state
        Training events: optimizer, scheduler
        on_train_step_end: loss_dict, grad_norm, skipped_iter
        on_eval_end: total_loss_dict
    """

    # Always available
    state: GlobalState
    model: list[MegatronModule]
    user_state: dict

    # Training events only
    optimizer: MegatronOptimizer | None = None
    scheduler: OptimizerParamScheduler | None = None

    # on_train_step_end
    loss_dict: dict[str, torch.Tensor] | None = None
    grad_norm: float | None = None
    skipped_iter: bool | None = None

    # on_eval_end
    total_loss_dict: dict[str, torch.Tensor] | None = None
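Because the event-specific fields default to None, a callback registered for multiple events should guard its access. A small illustrative example (the function name is hypothetical):

def log_grad_norm(context: CallbackContext) -> None:
    # grad_norm is only populated for on_train_step_end
    if context.grad_norm is not None:
        step = context.state.train_state.step
        print(f"step {step}: grad norm {context.grad_norm:.4f}")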
User code would then look like:
Functional
from megatron.bridge.training.callback import CallbackManager
from megatron.bridge.training import pretrain

# Create manager
callback_manager = CallbackManager()

# Define callback function
def my_callback(context):
    # Access framework state (read-only)
    iteration = context.state.train_state.step
    print(iteration)

# Register to event
callback_manager.register('on_train_step_end', my_callback)

# Pass to training
pretrain(config, forward_step_func, callbacks=callback_manager)
Class-based
import time

class TimingCallback(Callback):
    def on_train_start(self, context: CallbackContext) -> None:
        context.user_state['train_start'] = time.time()

    def on_train_end(self, context: CallbackContext) -> None:
        elapsed = time.time() - context.user_state['train_start']
        print(f"Training completed in {elapsed:.2f}s")

    def on_eval_start(self, context: CallbackContext) -> None:
        context.user_state['eval_start'] = time.time()

    def on_eval_end(self, context: CallbackContext) -> None:
        elapsed = time.time() - context.user_state['eval_start']
        print(f"Eval completed in {elapsed:.2f}s")

callback_manager = CallbackManager()
callback_manager.add(TimingCallback())
callback_manager.add([MyCallback1(), MyCallback2()])
At known points in the training/evaluation loop (e.g. on_eval_end) we will run all of the callback functions registered for that event.
Supported Hooks
We will support the following hooks:
- on_train_start
- on_train_step_start
- on_train_step_end
- on_train_end
- on_eval_start
- on_eval_step_start
- on_eval_step_end
- on_eval_end
For simplicity:
- No control flow from callbacks back to the training loop. For instance, callbacks cannot signal the loop to stop training.
- No exception handling for callbacks is provided by the training framework: users are expected to handle exceptions within their callback functions, since these functions are strictly optional from the framework's point of view.
- Distributed control is left to the callback implementor: callback functions run on all ranks, and in a distributed environment it is the user's responsibility to gate logic to specific ranks if needed (see the sketch after this list).
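For example, a user callback might gate its side effects to rank 0 and catch its own exceptions. A hypothetical sketch:

import torch.distributed as dist

def rank0_logging_callback(context: CallbackContext) -> None:
    # Callbacks run on every rank; gate side effects to rank 0 explicitly.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    try:
        step = context.state.train_state.step
        print(f"[rank 0] finished step {step}")
    except Exception as exc:
        # The framework does not catch callback errors; handle them here so a
        # logging failure cannot take down training.
        print(f"callback failed: {exc}")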