Description & Motivation
While PyTorch Lightning provides useful callbacks for monitoring learning rates, device statistics, and throughput, there is currently no built-in utility for tracking gradient behavior.
Having access to gradient statistics (e.g., norm, mean, variance, sparsity) can significantly help to debug training instability, tune hyperparameters, and better understand model behavior.
Pitch
I propose adding a GradientStatsMonitor callback that logs gradient-related statistics during training.
The callback could:
- Compute global gradient norm across all parameters
- Optionally compute per-layer gradient norms
- Track basic statistics such as mean, standard deviation, and fraction of near-zero gradients
- Log metrics through the existing Lightning logger interface
- Optionally provide warnings for potential issues (e.g., exploding or vanishing gradients)
Alternatives
No response
Additional context
No response
cc @lantiga