A lightweight, hook-based training loop for PyTorch. trainer-tools abstracts away the boilerplate of training loops while remaining fully customizable through a flexible hook system.
- Hook System: Customize every step of the training lifecycle (before/after batch, step, epoch, fit).
- Built-in Integrations: Ships with hooks for experiment tracking (wandb or trackio), progress bars, and checkpointing.
- Optimization: Easy Automatic Mixed Precision (AMP), Gradient Accumulation, and Gradient Clipping.
- Metrics: Robust metric tracking and logging to JSONL or external trackers.
- Memory Profiling: Built-in tools to debug CUDA memory leaks.
pip install trainer-tools
# With optional integrations
pip install "trainer-tools[wandb]" # Weights & Biases logging
pip install "trainer-tools[trackio]" # Trackio logging
pip install "trainer-tools[hydra]" # Hydra config management
pip install "trainer-tools[all]" # All optional dependencies

Here is a minimal example of training a simple model:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from trainer_tools.trainer import Trainer
from trainer_tools.hooks import MetricsHook, Accuracy, Loss, ProgressBarHook
# 1. Prepare Data
x = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
ds = TensorDataset(x, y)
dl = DataLoader(ds, batch_size=32)
# 2. Define Model
model = nn.Sequential(nn.Linear(10, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 3. Setup Hooks
metrics = MetricsHook(metrics=[Accuracy(), Loss()])
pbar = ProgressBarHook()
# 4. Train
trainer = Trainer(
    model=model,
    train_dl=dl,
    valid_dl=dl,
    optim=optimizer,
    loss_func=nn.CrossEntropyLoss(),
    epochs=5,
    hooks=[metrics, pbar],
    device="cuda" if torch.cuda.is_available() else "cpu",
)
trainer.fit()

Trainer.fit() runs a standard PyTorch training loop but exposes hook points at every meaningful stage so you can inject logic without touching the core loop:
fit()
├── before_fit
├── for each epoch:
│   ├── before_epoch
│   ├── for each training batch:
│   │   ├── before_step
│   │   ├── predict() → trainer.preds
│   │   ├── after_pred
│   │   ├── get_loss() → trainer.loss / trainer.loss_t # float loss / tensor loss
│   │   ├── after_loss
│   │   ├── loss_t.backward()
│   │   ├── after_backward
│   │   ├── opt.step() / opt.zero_grad()
│   │   └── after_step
│   ├── before_valid
│   ├── for each validation batch:
│   │   └── (same as above, no backward/opt steps)
│   └── after_epoch
└── after_fit (or after_cancel on KeyboardInterrupt)
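The dispatch order in the tree above can be sketched in plain Python. This is a toy illustration, not the library's implementation: `MiniTrainer`, `call_hooks`, and `CountingHook` are hypothetical names, no model is trained, and the choice to fire `after_backward` and `after_step` only during the training phase is an assumption.

```python
def call_hooks(hooks, name, trainer):
    """Invoke hook.<name>(trainer) on every hook that defines that method."""
    for hook in hooks:
        method = getattr(hook, name, None)
        if method is not None:
            method(trainer)


class MiniTrainer:
    """Toy stand-in for Trainer: walks the hook points, does no real training."""

    def __init__(self, epochs, n_train, n_valid, hooks):
        self.epochs, self.n_train, self.n_valid = epochs, n_train, n_valid
        self.hooks = hooks
        self.epoch = 0
        self.step = 0
        self.training = True

    def _one_batch(self):
        for name in ("before_step", "after_pred", "after_loss"):
            call_hooks(self.hooks, name, self)
        if self.training:
            # Assumption: these two points fire only in the training phase.
            call_hooks(self.hooks, "after_backward", self)
            self.step += 1
            call_hooks(self.hooks, "after_step", self)

    def fit(self):
        try:
            call_hooks(self.hooks, "before_fit", self)
            for epoch in range(self.epochs):
                self.epoch = epoch
                self.training = True
                call_hooks(self.hooks, "before_epoch", self)
                for _ in range(self.n_train):
                    self._one_batch()
                self.training = False
                call_hooks(self.hooks, "before_valid", self)
                for _ in range(self.n_valid):
                    self._one_batch()
                call_hooks(self.hooks, "after_epoch", self)
            call_hooks(self.hooks, "after_fit", self)
        except KeyboardInterrupt:
            call_hooks(self.hooks, "after_cancel", self)


class CountingHook:
    """Hypothetical hook that implements only the points it cares about."""

    def __init__(self):
        self.events = []

    def before_epoch(self, trainer):
        self.events.append(f"epoch {trainer.epoch} start")

    def after_step(self, trainer):
        self.events.append(f"step {trainer.step}")

    def after_fit(self, trainer):
        self.events.append("done")


hook = CountingHook()
MiniTrainer(epochs=2, n_train=2, n_valid=1, hooks=[hook]).fit()
print(hook.events)
```

A hook only needs to define the methods for the stages it cares about; every other hook point is skipped for that hook.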
At any hook point you have access to the live trainer state:
| Attribute | Description |
|---|---|
| trainer.model | The nn.Module being trained |
| trainer.opt | The optimizer |
| trainer.epoch / trainer.step | Current epoch / global step count |
| trainer.batch | The raw batch from the dataloader |
| trainer.preds | Model predictions (set after predict()) |
| trainer.loss / trainer.loss_t | Scalar loss value / loss tensor |
| trainer.training | True during the train phase, False during validation |
| trainer.dl | DataLoader currently in use |
| trainer.config | Optional config object (e.g. Hydra DictConfig) |
Three boolean flags let a hook short-circuit the default behavior for a single step:
| Flag | Effect when set to True |
|---|---|
| trainer.skip_backward | Skips loss.backward() |
| trainer.skip_opt_step | Skips opt.step() |
| trainer.skip_zero_grad | Skips opt.zero_grad() |
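As an illustration of how a hook might use these flags, the sketch below drops the update for any step whose loss is NaN or infinite. It runs against a stand-in object rather than the real Trainer: `SkipNonFiniteHook` and `run_step` are hypothetical, and resetting the flags at the start of each step is an assumption about how "for a single step" is implemented. In real use the hook would subclass BaseHook.

```python
import math
from types import SimpleNamespace


class SkipNonFiniteHook:
    """Hypothetical hook: drop the update when the loss is NaN or infinite."""

    def __init__(self):
        self.skipped = 0

    def after_loss(self, trainer):
        if trainer.training and not math.isfinite(trainer.loss):
            trainer.skip_backward = True
            trainer.skip_opt_step = True
            self.skipped += 1


def run_step(trainer, hook, loss):
    """Stand-in for one training step; only the flag handling is modelled."""
    # Assumption: the flags are cleared at the start of every step.
    trainer.skip_backward = False
    trainer.skip_opt_step = False
    trainer.skip_zero_grad = False
    trainer.loss = loss
    hook.after_loss(trainer)
    if not trainer.skip_backward:
        trainer.backward_calls += 1
    if not trainer.skip_opt_step:
        trainer.opt_steps += 1
    if not trainer.skip_zero_grad:
        trainer.zero_grad_calls += 1


trainer = SimpleNamespace(
    training=True, backward_calls=0, opt_steps=0, zero_grad_calls=0
)
hook = SkipNonFiniteHook()
for loss in (0.7, float("nan"), 0.5):
    run_step(trainer, hook, loss)

print(trainer.backward_calls, trainer.opt_steps, trainer.zero_grad_calls, hook.skipped)
```

Here the second step is skipped for backward and the optimizer update, while zero_grad still runs because its flag was left untouched.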
Hooks are sorted by their ord attribute before each call. Lower values run first. The default is 0. This guarantees correct ordering when hooks depend on one another (e.g., CheckpointHook runs before LRSchedulerHook so a restored scheduler state is used from the beginning).
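A minimal sketch of that ordering rule, using stand-in classes instead of the library's hooks. The class names and the specific ord values (-10, 10) are hypothetical; only the default of 0 and "lower runs first" come from the description above.

```python
class Hook:
    """Minimal stand-in for BaseHook with the default ordering value."""

    ord = 0

    def before_fit(self, log):
        log.append(type(self).__name__)


class RestoreStateHook(Hook):
    ord = -10  # hypothetical value: runs before anything reads restored state


class BuildSchedulerHook(Hook):
    ord = 10  # hypothetical value: runs after state has been restored


class LoggingHook(Hook):
    pass  # keeps the default ord of 0


# Registration order is deliberately scrambled.
hooks = [BuildSchedulerHook(), LoggingHook(), RestoreStateHook()]

log = []
for hook in sorted(hooks, key=lambda h: h.ord):
    hook.before_fit(log)
print(log)
```

Python's sorted() is stable, so in this sketch hooks that share an ord value keep their registration order; whether the library gives the same guarantee is not stated in this document.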
trainer-tools relies on BaseHook. You can create custom behavior by subclassing it:
from trainer_tools.hooks import BaseHook

class MyCustomHook(BaseHook):
    def after_step(self, trainer):
        # Print the scalar loss every 100 global steps.
        if trainer.step % 100 == 0:
            print(f"Current Loss: {trainer.loss}")

ProgressBarHook displays tqdm progress bars for epochs and batches. It shows a running training loss that updates every freq steps.
from trainer_tools.hooks import ProgressBarHook
pbar = ProgressBarHook(freq=10)

MetricsHook is the central hub for computing, aggregating, and logging metrics. It supports logging to the console, a JSONL history file, Weights & Biases, or Trackio.
from trainer_tools.hooks import MetricsHook, Loss
from trainer_tools.hooks.metrics import Accuracy
metrics = MetricsHook(
    metrics=[Loss(), Accuracy()],
    tracker_type="wandb", # or "trackio" / None
    project="my-project",
    name="run_1",
)

Metrics are split into phases ("step" or "epoch") and automatically prefixed with train_ / valid_.
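To make the prefixing concrete, here is a small sketch of epoch-level aggregation. `EpochMetrics` is a hypothetical helper, not part of trainer-tools, and the unweighted mean over batches and the shape of the JSONL record are assumptions.

```python
import json
from collections import defaultdict


class EpochMetrics:
    """Toy aggregator: running mean per metric, keyed by phase-prefixed name."""

    def __init__(self):
        self._total = defaultdict(float)
        self._count = defaultdict(int)

    def update(self, name, value, training):
        # The phase decides the prefix, mirroring train_ / valid_ above.
        key = ("train_" if training else "valid_") + name
        self._total[key] += value
        self._count[key] += 1

    def summary(self):
        return {key: self._total[key] / self._count[key] for key in self._total}


agg = EpochMetrics()
agg.update("loss", 0.75, training=True)
agg.update("loss", 0.25, training=True)
agg.update("loss", 0.25, training=False)

summary = agg.summary()
line = json.dumps({"epoch": 0, **summary})  # one record of a JSONL history file
print(line)
```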
LRSchedulerHook wraps any PyTorch LRScheduler and calls sched.step() after every optimizer update.
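Because the scheduler is stepped once per optimizer update, schedule lengths such as T_max count optimizer steps rather than epochs. The sketch below evaluates the closed-form cosine annealing curve from the PyTorch documentation in plain Python to show what that means for a 1000-step schedule; `cosine_annealing_lr` is an illustrative helper, not a library function.

```python
import math


def cosine_annealing_lr(step, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    cos_term = 1 + math.cos(math.pi * step / t_max)
    return eta_min + (base_lr - eta_min) * cos_term / 2


# Learning rate at a few optimizer steps of a 1000-step schedule.
for step in (0, 250, 500, 750, 1000):
    print(step, cosine_annealing_lr(step, base_lr=1e-3, t_max=1000))
```

With T_max=1000 the learning rate reaches its minimum after 1000 optimizer updates, regardless of how many epochs that spans.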
import torch
from trainer_tools.hooks import LRSchedulerHook
hook = LRSchedulerHook(
    sched_fn=lambda opt: torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000)
)

AMPHook enables Automatic Mixed Precision training by wrapping the forward pass in torch.autocast and managing a GradScaler. It supports both float16 and bfloat16.
import torch
from trainer_tools.hooks import AMPHook
amp = AMPHook(dtype=torch.bfloat16, device_type="cuda")

GradClipHook clips gradient norms before each optimizer step to stabilize training.
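The arithmetic behind norm clipping can be sketched without PyTorch. The function below follows the same recipe as torch.nn.utils.clip_grad_norm_ (measure the global L2 norm, then scale every gradient by max_norm / (norm + 1e-6), capped at 1). That the hook delegates to this PyTorch utility is an assumption, and plain lists stand in for gradient tensors.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale gradients in place so their global L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = min(1.0, max_norm / (total_norm + eps))
    for grad in grads:
        for i in range(len(grad)):
            grad[i] *= scale
    return total_norm


def global_norm(grads):
    return math.sqrt(sum(g * g for grad in grads for g in grad))


grads = [[3.0], [4.0]]  # two "parameters" with a global norm of 5.0
norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = global_norm(grads)

small = [[0.3], [0.4]]  # already below the threshold, left unchanged
clip_grad_norm(small, max_norm=1.0)

print(norm_before, norm_after, small)
```

The direction of the gradient is preserved; only its length is reduced, and gradients already within the limit pass through untouched.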
from trainer_tools.hooks import GradClipHook
clip = GradClipHook(max_norm=1.0)

GradientAccumulationHook accumulates gradients over multiple micro-batches before calling opt.step(), increasing the effective batch size without the memory cost of a larger batch.
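Why accumulation is equivalent to a larger batch can be checked with a one-parameter model. The sketch assumes equal-sized micro-batches and that each micro-batch loss is divided by the number of accumulation steps; whether the hook applies that scaling itself is an assumption, and `grad_mse` is an illustrative helper.

```python
def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of mean((w * x - y) ** 2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [float(i) for i in range(1, 9)]  # one "large" batch of 8 samples
ys = [2.0 * x + 1.0 for x in xs]

# Gradient computed on the whole batch at once.
full_batch_grad = grad_mse(w, xs, ys)

# Same gradient built up from 4 micro-batches of 2 samples each.
accumulation_steps = 4
micro = len(xs) // accumulation_steps
accumulated_grad = 0.0
for k in range(accumulation_steps):
    lo, hi = k * micro, (k + 1) * micro
    # Assumption: each micro-batch loss is divided by accumulation_steps.
    accumulated_grad += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps

print(full_batch_grad, accumulated_grad)
```

Only one micro-batch of activations is alive at a time, which is where the memory saving comes from; the price is four forward/backward passes per optimizer step.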
from trainer_tools.hooks import GradientAccumulationHook
accum = GradientAccumulationHook(accumulation_steps=4)

CheckpointHook saves and restores model, optimizer, scheduler, scaler, and RNG states. It supports both a "best" (by a tracked metric) and a "latest" save strategy, and keeps a configurable number of recent checkpoints.
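The bookkeeping behind "keep the last N" and "track the best" can be sketched as follows. `CheckpointPolicy` is a hypothetical helper that writes no files; it assumes a lower metric value is better (as for a validation loss) and says nothing about how the real hook combines the two strategies or names its files.

```python
from collections import deque


class CheckpointPolicy:
    """Toy bookkeeping for which checkpoints to keep; no files are written."""

    def __init__(self, keep_last=3):
        self.keep_last = keep_last
        self.recent = deque()
        self.deleted = []
        self.best_step = None
        self.best_value = None

    def on_save(self, step):
        """Register a periodic save and evict the oldest beyond keep_last."""
        self.recent.append(step)
        while len(self.recent) > self.keep_last:
            self.deleted.append(self.recent.popleft())

    def on_metric(self, step, value):
        """Return True when value improves on the best seen (lower is better)."""
        if self.best_value is None or value < self.best_value:
            self.best_step, self.best_value = step, value
            return True
        return False


policy = CheckpointPolicy(keep_last=3)
for step in range(500, 3000, 500):  # saves at 500, 1000, ..., 2500
    policy.on_save(step)
for step, valid_loss in [(500, 0.9), (1000, 0.7), (1500, 0.8)]:
    policy.on_metric(step, valid_loss)

print(list(policy.recent), policy.deleted, policy.best_step)
```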
from trainer_tools.hooks import CheckpointHook
ckpt = CheckpointHook(
    save_dir="checkpoints/",
    save_every_steps=500,
    keep_last=3,
    save_strategy="best", # or "latest"
    metric_name="valid_loss",
    resume_path="checkpoints/step_1000", # optional
)

EMAHook maintains an exponential moving average (EMA) of model weights. Validation is automatically run against the EMA model, which often generalizes better than the raw weights. The EMA state is saved and restored with CheckpointHook.
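The update rule itself is short. The sketch below applies the standard EMA formula to a dictionary of floats standing in for model parameters; `ema_update` is an illustrative helper, and details such as warm-up or how buffers are handled in the real hook are not covered.

```python
def ema_update(ema_params, model_params, decay):
    """In-place EMA step: ema = decay * ema + (1 - decay) * current."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1 - decay) * value


model_params = {"w": 1.0}
ema_params = dict(model_params)  # the EMA starts as a copy of the weights

for new_w in (0.0, 0.0):  # pretend two optimizer steps moved w to 0
    model_params["w"] = new_w
    ema_update(ema_params, model_params, decay=0.9)

print(ema_params["w"])
```

The averaged weight trails the raw weight, which is the smoothing effect. With a decay of 0.9999 the average spans roughly 1 / (1 - decay) = 10,000 recent steps.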
from trainer_tools.hooks import EMAHook
ema = EMAHook(decay=0.9999)

BatchTransformHook applies on-GPU data augmentations or pre-processing transforms to inputs and/or targets at the start of each batch. Separate transforms can be provided for training and validation.
from trainer_tools.hooks import BatchTransformHook
import torchvision.transforms.v2 as T
aug = BatchTransformHook(
    x_tfm=T.RandomHorizontalFlip(),
    x_tfms_valid=None, # no aug during validation
)

AccelerateHook integrates HuggingFace Accelerate for distributed training (DDP/FSDP), mixed precision, and gradient accumulation in a single hook. When it is used, do not add AMPHook, GradClipHook, or GradientAccumulationHook; Accelerate handles all of that.
from trainer_tools.hooks.accelerate import AccelerateHook
accel = AccelerateHook(
    gradient_accumulation_steps=4,
    max_grad_norm=1.0,
    mixed_precision="bf16",
)