156 changes: 156 additions & 0 deletions docs/guides/callbacks.md
@@ -0,0 +1,156 @@
# Custom Callbacks

## Introduction

Callbacks provide a flexible way to inject custom logic into the training loop without modifying recipe code. They enable integration with external systems, custom logging, metrics collection, and monitoring.

Callbacks work across **all NeMo AutoModel recipes**: LLM fine-tuning, VLM fine-tuning, sequence classification, and knowledge distillation.

## Key Features

- **PyTorch Lightning-style API**: Familiar hooks like `on_train_start`, `on_train_batch_end`, `on_validation_end`
- **Programmatic integration**: Pass callbacks directly to recipe constructors
- **Full training context**: Access to recipe state, metrics, and checkpoint information
- **Distributed training support**: Includes `@rank_zero_only` decorator for multi-GPU environments

## Available Hooks

| Hook | When Called | Key Arguments |
|------|------------|---------------|
| `on_train_start` | After setup, before training begins | `recipe` |
| `on_train_batch_end` | After each training step | `recipe`, `train_log_data` |
| `on_validation_end` | After validation completes | `recipe`, `val_results` |
| `on_save_checkpoint` | When checkpoint is saved | `recipe`, `checkpoint_info` |
| `on_exception` | When training fails | `recipe`, `exception` |
| `on_train_end` | When training completes successfully | `recipe` |
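
The call order implied by the table can be illustrated with a small stand-in dispatcher. This is a hedged sketch, not NeMo AutoModel's implementation: `SketchCallback`, `Recorder`, `run_hooks`, and `fake_training_run` are hypothetical names, and the real recipes may sequence or invoke hooks differently.

```python
class SketchCallback:
    """Hypothetical stand-in for the library's Callback base class: every hook is a no-op."""

    def on_train_start(self, recipe, **kwargs): ...
    def on_train_batch_end(self, recipe, **kwargs): ...
    def on_validation_end(self, recipe, **kwargs): ...
    def on_save_checkpoint(self, recipe, **kwargs): ...
    def on_exception(self, recipe, **kwargs): ...
    def on_train_end(self, recipe, **kwargs): ...


class Recorder(SketchCallback):
    """Records the name of each hook that fires."""

    def __init__(self):
        self.calls = []

    def on_train_start(self, recipe, **kwargs):
        self.calls.append("on_train_start")

    def on_train_batch_end(self, recipe, **kwargs):
        self.calls.append("on_train_batch_end")

    def on_exception(self, recipe, **kwargs):
        self.calls.append("on_exception")

    def on_train_end(self, recipe, **kwargs):
        self.calls.append("on_train_end")


def run_hooks(callbacks, hook_name, recipe, **kwargs):
    """Invoke one hook on every callback, in list order."""
    for cb in callbacks:
        getattr(cb, hook_name)(recipe, **kwargs)


def fake_training_run(callbacks, num_steps, fail_at=None):
    """Simulate the hook sequence for a run that succeeds or fails at a given step."""
    recipe = object()  # placeholder; a real recipe carries training state
    run_hooks(callbacks, "on_train_start", recipe)
    try:
        for step in range(1, num_steps + 1):
            if step == fail_at:
                raise RuntimeError(f"simulated failure at step {step}")
            # The real hook receives a MetricsSample; a plain dict is used here.
            run_hooks(callbacks, "on_train_batch_end", recipe, train_log_data={"step": step})
    except Exception as exc:
        run_hooks(callbacks, "on_exception", recipe, exception=exc)
        raise
    run_hooks(callbacks, "on_train_end", recipe)
```

In this sketch a failing run ends with `on_exception` and never reaches `on_train_end`, which matches the "completes successfully" wording in the table.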

## Quick Example

### 1. Define a Custom Callback

```python
from nemo_automodel.components.callbacks import Callback

class MetricsReporterCallback(Callback):
    """Report metrics to an external API."""

    def on_train_batch_end(self, recipe, **kwargs):
        train_log_data = kwargs['train_log_data']
        step = train_log_data.step
        loss = train_log_data.metrics['loss']

        # Send to your API, database, or monitoring system
        print(f"Step {step}: Loss = {loss:.4f}")

    def on_validation_end(self, recipe, **kwargs):
        val_results = kwargs['val_results']

        # val_results is a dict: {"validation": MetricsSample, ...}
        for name, log_data in val_results.items():
            val_loss = log_data.metrics['val_loss']
            print(f"Validation '{name}': Loss = {val_loss:.4f}")

    def on_save_checkpoint(self, recipe, **kwargs):
        checkpoint_info = kwargs['checkpoint_info']
        print(f"Checkpoint saved: {checkpoint_info['checkpoint_path']}")
```

### 2. Pass to Recipe

```python
from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction

# Instantiate your callbacks
metrics_callback = MetricsReporterCallback()

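# Pass to the recipe constructor. `cfg` is the loaded recipe config
# (the full example script obtains it with parse_args_and_load_config).
recipe = TrainFinetuneRecipeForNextTokenPrediction(
    cfg,
    callbacks=[metrics_callback],
)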

recipe.setup()
recipe.run_train_validation_loop()
```

## Distributed Training

In multi-GPU training, callbacks run on **all ranks**. Python `logging` (e.g., `logger.info()`) is automatically filtered to rank 0, but other operations need explicit handling.

### Use `@rank_zero_only` Decorator

```python
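import requests

from nemo_automodel.components.callbacks import Callback, rank_zero_only

class CustomizerCallback(Callback):
    @rank_zero_only
    def on_validation_end(self, recipe, **kwargs):
        # This only runs on rank 0
        val_results = kwargs['val_results']

        # Safe to do file I/O, API calls, etc.
        # Send plain metrics dicts; MetricsSample objects themselves are
        # unlikely to be JSON-serializable.
        payload = {name: sample.metrics for name, sample in val_results.items()}
        requests.post('https://api.example.com/metrics', json=payload)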
```
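
To make the decorator's behavior concrete, here is a minimal sketch of how a rank-zero guard could be written. It is an assumption, not the library's `rank_zero_only`: the name `rank_zero_only_sketch` is hypothetical, and it reads the `RANK` environment variable that `torchrun` sets for each worker, whereas the real decorator may determine the rank differently.

```python
import functools
import os


def rank_zero_only_sketch(fn):
    """Run the wrapped function only when the RANK environment variable is 0 or unset."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # torchrun sets RANK per worker; a single-process run has no RANK, so treat it as 0.
        if int(os.environ.get("RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None  # non-zero ranks skip the call entirely

    return wrapper


class Demo:
    """Hypothetical callback-like class used only to exercise the decorator."""

    def __init__(self):
        self.calls = 0

    @rank_zero_only_sketch
    def on_validation_end(self, recipe, **kwargs):
        self.calls += 1
        return "ran"
```

Because skipped ranks return `None`, a guarded hook should not contain collective operations that every rank must join.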

### Use Manual Rank Checking

```python
def on_train_batch_end(self, recipe, **kwargs):
    # Check if main rank before doing expensive operations
    if recipe.dist_env.is_main:
        # Do rank-0-only work (file I/O, API calls, etc.)
        save_metrics_to_file()  # placeholder for your own helper
```

## Hook Details

For complete API reference, see the [API documentation](../apidocs/index.rst).

### `on_train_batch_end`

```python
def on_train_batch_end(self, recipe, **kwargs):
    train_log_data = kwargs['train_log_data']  # MetricsSample
    # Fields: train_log_data.step, .epoch, .metrics (dict), .timestamp
```
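
As an illustration of working with these fields, the sketch below keeps a moving average of the training loss. `FakeMetricsSample` is a hypothetical stand-in that only mirrors the field names listed above, and `SmoothedLossTracker` is written as a plain class so the snippet runs without the library installed; in real use it would subclass `Callback`.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class FakeMetricsSample:
    """Hypothetical stand-in exposing the fields listed above."""

    step: int
    epoch: int
    metrics: dict = field(default_factory=dict)
    timestamp: float = 0.0


class SmoothedLossTracker:
    """Keeps a moving average of the last `window` training losses."""

    def __init__(self, window=3):
        self.losses = deque(maxlen=window)
        self.smoothed = {}  # step -> moving-average loss

    def on_train_batch_end(self, recipe, **kwargs):
        sample = kwargs["train_log_data"]
        # Assumes metrics["loss"] is a number (or converts cleanly with float()).
        self.losses.append(float(sample.metrics["loss"]))
        self.smoothed[sample.step] = sum(self.losses) / len(self.losses)
```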

### `on_validation_end`

```python
def on_validation_end(self, recipe, **kwargs):
    val_results = kwargs['val_results']  # dict[str, MetricsSample]
    # For single validation set: {"validation": MetricsSample}
    # For multiple validation sets: {"squad": MetricsSample, "hellaswag": MetricsSample, ...}
```
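
When validation results need to leave the process (for example as an HTTP request body), the mapping has to be flattened into plain types first. The following is a sketch under assumptions: `FakeMetricsSample` is a hypothetical stand-in for `MetricsSample`, `val_results_to_payload` is an illustrative helper rather than a library function, and metric values are assumed to be numeric.

```python
import json
from dataclasses import dataclass, field


@dataclass
class FakeMetricsSample:
    """Hypothetical stand-in with the MetricsSample fields described in this guide."""

    step: int
    epoch: int
    metrics: dict = field(default_factory=dict)
    timestamp: float = 0.0


def val_results_to_payload(val_results):
    """Convert {name: sample} into a JSON-serializable dict of plain numbers."""
    return {
        name: {
            "step": sample.step,
            "epoch": sample.epoch,
            # float() also unwraps scalar-like values; assumes every metric is numeric.
            "metrics": {key: float(value) for key, value in sample.metrics.items()},
        }
        for name, sample in val_results.items()
    }


example = {
    "squad": FakeMetricsSample(step=100, epoch=1, metrics={"val_loss": 1.25}),
    "hellaswag": FakeMetricsSample(step=100, epoch=1, metrics={"val_loss": 2.5}),
}
body = json.dumps(val_results_to_payload(example), sort_keys=True)
```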

### `on_save_checkpoint`

```python
def on_save_checkpoint(self, recipe, **kwargs):
    checkpoint_info = kwargs['checkpoint_info']  # dict
    # Fields: 'epoch', 'step', 'train_loss', 'val_losses', 'checkpoint_path', 'best_metric_key'
```
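
One way to use `checkpoint_info` is to append each saved checkpoint to a JSON Lines manifest. This is a minimal sketch: `CheckpointManifest` and the manifest file are hypothetical, only a subset of the fields above is recorded, and `train_loss` is assumed to be numeric. It is a plain class so it runs standalone; a real version would subclass `Callback` and restrict the write to rank 0.

```python
import json
from pathlib import Path


class CheckpointManifest:
    """Appends one JSON line per saved checkpoint to a manifest file."""

    def __init__(self, manifest_path):
        self.manifest_path = Path(manifest_path)

    def on_save_checkpoint(self, recipe, **kwargs):
        info = kwargs["checkpoint_info"]
        record = {
            "step": info["step"],
            "epoch": info["epoch"],
            "train_loss": float(info["train_loss"]),  # assumes a numeric loss
            "checkpoint_path": str(info["checkpoint_path"]),
        }
        # Append mode keeps earlier entries, so the file grows by one line per save.
        with self.manifest_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
```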

## Complete Example

See [`examples/llm_finetune/finetune_with_callback.py`](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/finetune_with_callback.py) for a full working example demonstrating:
- Multiple callbacks
- Distributed training with `@rank_zero_only`
- Metrics collection for external reporting
- Custom logging with prefixes

### Running the Example

```bash
# Single GPU
uv run python examples/llm_finetune/finetune_with_callback.py

# Multi-GPU
uv run torchrun --nproc-per-node=8 examples/llm_finetune/finetune_with_callback.py
```

## Use Cases

- **External integrations**: Report metrics to W&B, MLflow, Customizer, or custom APIs
- **Progress monitoring**: Send Slack/email notifications on training milestones or failures
- **Custom metrics collection**: Track and store domain-specific metrics beyond standard training logs
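
The progress-monitoring use case could look roughly like the sketch below. Everything here is illustrative: `FailureNotifier` is a hypothetical class, and `send` is any callable you supply (a Slack webhook wrapper, an email helper, or simply `print`). It is a plain class so it runs without the library; in practice it would subclass `Callback` and send from rank 0 only.

```python
class FailureNotifier:
    """Sends a short message when training fails or finishes."""

    def __init__(self, send, run_name="run"):
        self.send = send  # callable taking a single message string
        self.run_name = run_name

    def on_exception(self, recipe, **kwargs):
        exc = kwargs["exception"]
        self.send(f"[{self.run_name}] training failed: {type(exc).__name__}: {exc}")

    def on_train_end(self, recipe, **kwargs):
        self.send(f"[{self.run_name}] training completed")


# Usage: collect messages in a list instead of calling a real service.
messages = []
notifier = FailureNotifier(messages.append, run_name="demo")
notifier.on_exception(None, exception=ValueError("bad batch"))
notifier.on_train_end(None)
```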
1 change: 1 addition & 0 deletions docs/index.md
@@ -66,6 +66,7 @@ guides/pipelining.md
guides/llm/knowledge-distillation.md
guides/fp8-training.md
guides/mlflow-logging.md
guides/callbacks.md

apidocs/index.rst
```
198 changes: 198 additions & 0 deletions examples/llm_finetune/finetune_with_callback.py
@@ -0,0 +1,198 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example: Using Custom Callbacks with Automodel Training

This example demonstrates how to use callbacks to hook into the training loop
for custom logging, monitoring, or integration with external systems.

Usage:
    # Using default small model (270M - fastest for testing)
    python examples/llm_finetune/finetune_with_callback.py

    # Or specify a different model
    python examples/llm_finetune/finetune_with_callback.py \\
        -c examples/llm_finetune/llama3_2/llama3_2_1b_hellaswag.yaml
"""

from __future__ import annotations

import logging

from nemo_automodel.components.callbacks import Callback, rank_zero_only
from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction

logger = logging.getLogger(__name__)


class SimpleLoggingCallback(Callback):
    """
    A basic callback that logs training progress at key milestones.

    Note: logger.info() is automatically filtered to rank 0 by RankFilter,
    so no explicit rank checking is needed for simple logging.
    """

    def on_train_start(self, recipe, **kwargs):
        # Basic logging works without rank checks (automatically filtered to rank 0 with RankFilter)
        logger.info("[SimpleLoggingCallback] 🔥 Training is starting!")
        logger.info(f"[SimpleLoggingCallback] World size: {recipe.dist_env.world_size} GPUs")
        logger.info(f"[SimpleLoggingCallback] Total steps: {recipe.step_scheduler.max_steps}")

    def on_train_batch_end(self, recipe, **kwargs):
        step = recipe.step_scheduler.step
        if step % 10 == 0:
            metrics = kwargs["train_log_data"].metrics
            logger.info(
                f"[SimpleLoggingCallback] 🚀 Step {step}/{recipe.step_scheduler.max_steps}: "
                f"Loss = {metrics['loss']:.4f}, LR = {metrics['lr']:.2e}"
            )

    def on_validation_end(self, recipe, **kwargs):
        val_results = kwargs["val_results"]
        # val_results is a dict: {"validation": MetricsSample, ...}
        for name, log_data in val_results.items():
            logger.info(f"[SimpleLoggingCallback] ✅ Validation '{name}': Loss = {log_data.metrics['val_loss']:.4f}")

    def on_save_checkpoint(self, recipe, **kwargs):
        checkpoint_info = kwargs["checkpoint_info"]
        logger.info(
            f"[SimpleLoggingCallback] 💾 Checkpoint saved at step {checkpoint_info['step']}, "
            f"epoch {checkpoint_info['epoch']}, path: {checkpoint_info['checkpoint_path']}"
        )

    def on_train_end(self, recipe, **kwargs):
        logger.info(
            f"[SimpleLoggingCallback] 🎉 Training completed successfully! Final step: {recipe.step_scheduler.step}"
        )

    def on_exception(self, recipe, **kwargs):
        exception = kwargs["exception"]
        logger.error(f"[SimpleLoggingCallback] ❌ Training failed: {exception}")


class MetricsCollectorCallback(Callback):
    """
    Example callback that collects metrics for external reporting.

    In a real scenario, this could report to an API, database, or monitoring system.

    Note: In distributed training (multiple GPUs), callbacks run on ALL ranks.
    This example shows TWO ways to handle rank filtering:

    1. Manual checking: if recipe.dist_env.is_main
    2. Using @rank_zero_only decorator (recommended for cleaner code)
    """

    def __init__(self):
        self.training_metrics = []
        self.validation_metrics = []
        self.checkpoints = []

    def on_train_batch_end(self, recipe, **kwargs):
        step = recipe.step_scheduler.step
        metrics = kwargs["train_log_data"].metrics

        # Collect metrics (happens on all ranks, but that's fine for local state)
        self.training_metrics.append(
            {
                "step": step,
                "loss": metrics["loss"],
                "lr": metrics["lr"],
            }
        )

        # Method 1: Manual rank checking
        # In a real use case, only rank 0 should send to external APIs:
        # if recipe.dist_env.is_main and step % 100 == 0:
        #     requests.post('https://api.example.com/metrics', json=metrics)

    @rank_zero_only  # Method 2: Using decorator (cleaner!)
    def on_validation_end(self, recipe, **kwargs):
        val_results = kwargs["val_results"]

        # This only runs on rank 0 thanks to @rank_zero_only
        # val_results is a dict: {"validation": MetricsSample, "squad": MetricsSample, ...}
        for name, log_data in val_results.items():
            self.validation_metrics.append(
                {
                    "step": log_data.step,
                    "epoch": log_data.epoch,
                    "validation_name": name,
                    "metrics": log_data.metrics,  # Full metrics dict (val_loss, accuracy, etc.)
                }
            )

        logger.info(f"[MetricsCollectorCallback] 📊 Collected {len(self.validation_metrics)} validation checkpoints")

    @rank_zero_only
    def on_save_checkpoint(self, recipe, **kwargs):
        checkpoint_info = kwargs["checkpoint_info"]

        # Track checkpoint information for external reporting
        self.checkpoints.append(
            {
                "step": checkpoint_info["step"],
                "epoch": checkpoint_info["epoch"],
                "train_loss": checkpoint_info["train_loss"],
                "val_losses": checkpoint_info["val_losses"],
                "path": checkpoint_info["checkpoint_path"],
            }
        )

        logger.info(
            f"[MetricsCollectorCallback] 💾 Tracked checkpoint {len(self.checkpoints)}: "
            f"step={checkpoint_info['step']}, train_loss={checkpoint_info['train_loss']:.4f}"
        )


def main(default_config_path="examples/llm_finetune/gemma/gemma_3_270m_squad_peft.yaml"):
    """
    Main entry point for fine-tuning with custom callbacks.

    This example shows how to use multiple callbacks simultaneously.

    For faster testing, this uses Gemma 3 270M (the smallest model available).
    You can change the config to use larger models like:
    - examples/llm_finetune/llama3_2/llama3_2_1b_hellaswag.yaml
    - examples/llm_finetune/granite/granite_3_3_2b_instruct_squad_peft.yaml
    """
    cfg = parse_args_and_load_config(default_config_path)

    # Instantiate multiple callbacks
    logging_callback = SimpleLoggingCallback()
    metrics_callback = MetricsCollectorCallback()

    # Pass them to the recipe (they'll be called in order)
    recipe = TrainFinetuneRecipeForNextTokenPrediction(cfg, callbacks=[logging_callback, metrics_callback])

    recipe.setup()

    try:
        recipe.run_train_validation_loop()
    except Exception as e:
        logger.error(f"[Main] Training failed: {e}")
        raise

    # After training, you can access collected metrics
    logger.info(
        f"[MetricsCollectorCallback] 🎉 Training complete! "
        f"Collected {len(metrics_callback.training_metrics)} training steps"
    )


if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions nemo_automodel/_cli/app.py
@@ -58,6 +58,9 @@ def get_recipe_script_path(command: str, domain: str, repo_root: str | Path) ->
        str: Full path to the recipe script
    """
    recipe_name = COMMAND_ALIASES.get(command, command)
    # VLM uses finetune.py instead of train_ft.py
    if domain == "vlm" and command == "finetune":
        recipe_name = "finetune"
    return f"{repo_root}/nemo_automodel/recipes/{domain}/{recipe_name}.py"

