Feature Request: Gradient Statistics Monitoring Callback #21589

@Sanchay117

Description & Motivation

While PyTorch Lightning provides useful callbacks for monitoring learning rates, device statistics, and throughput, there is currently no built-in utility for tracking gradient behavior.

Access to gradient statistics (e.g., norm, mean, variance, sparsity) can significantly help with debugging training instability, tuning hyperparameters, and understanding model behavior.

Pitch

I propose adding a GradientStatsMonitor callback that logs gradient-related statistics during training.

The callback could:

  • Compute global gradient norm across all parameters
  • Optionally compute per-layer gradient norms
  • Track basic statistics such as mean, standard deviation, and fraction of near-zero gradients
  • Log metrics through the existing Lightning logger interface
  • Optionally provide warnings for potential issues (e.g., exploding or vanishing gradients)
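
To make the list above concrete, here is a minimal, framework-free sketch of the arithmetic such a callback might perform. It is not an existing Lightning API: the function names (`gradient_stats`, `gradient_warnings`), the metric keys, and the thresholds are all hypothetical, and gradients are passed in as plain lists of floats to keep the example self-contained.

```python
import math
from typing import Dict, Iterable, List


def gradient_stats(
    grads_by_layer: Dict[str, Iterable[float]],
    near_zero_eps: float = 1e-8,  # hypothetical threshold for "near-zero"
    per_layer: bool = False,
) -> Dict[str, float]:
    """Compute global (and optionally per-layer) gradient statistics."""
    stats: Dict[str, float] = {}
    all_values: List[float] = []

    for name, grads in grads_by_layer.items():
        values = [float(g) for g in grads]
        all_values.extend(values)
        if per_layer:
            # L2 norm of this layer's flattened gradient.
            stats[f"grad_norm/{name}"] = math.sqrt(sum(v * v for v in values))

    n = len(all_values)
    if n == 0:
        return stats  # nothing to report, e.g. no gradients computed yet

    mean = sum(all_values) / n
    variance = sum((v - mean) ** 2 for v in all_values) / n  # population variance

    stats["grad_norm/global"] = math.sqrt(sum(v * v for v in all_values))
    stats["grad/mean"] = mean
    stats["grad/std"] = math.sqrt(variance)
    stats["grad/near_zero_frac"] = sum(abs(v) <= near_zero_eps for v in all_values) / n
    return stats


def gradient_warnings(
    stats: Dict[str, float],
    explode_above: float = 1e3,  # assumed cutoff, would need tuning per model
    vanish_below: float = 1e-7,  # assumed cutoff, would need tuning per model
) -> List[str]:
    """Return human-readable warnings based on the global gradient norm."""
    messages: List[str] = []
    norm = stats.get("grad_norm/global")
    if norm is None:
        return messages
    if math.isnan(norm) or math.isinf(norm):
        messages.append("non-finite global gradient norm")
    elif norm > explode_above:
        messages.append(f"possible exploding gradients: norm={norm:.3e}")
    elif norm < vanish_below:
        messages.append(f"possible vanishing gradients: norm={norm:.3e}")
    return messages


if __name__ == "__main__":
    example = {"fc1": [3.0, 4.0], "fc2": [0.0, 0.0]}
    result = gradient_stats(example, per_layer=True)
    print(result)
    print(gradient_warnings(result))
```

In an actual callback, one plausible place to gather the gradients would be the `on_before_optimizer_step` hook, iterating over `pl_module.named_parameters()` and skipping parameters whose `.grad` is `None`, then passing the resulting dictionary to `pl_module.log_dict`. A real implementation would likely operate on tensors directly rather than Python lists, and the choice of hook and logging frequency are open design questions for this proposal.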

Alternatives

No response

Additional context

No response

cc @lantiga

Labels: feature, needs triage
