Skip to content

feat(metrics): add EnsembleMSE and EnsembleRMSE streaming metrics#1701

Open
harshaa765 wants to merge 3 commits into
NVIDIA:mainfrom
harshaa765:add-ensemble-mse-rmse-metrics
Open

feat(metrics): add EnsembleMSE and EnsembleRMSE streaming metrics#1701
harshaa765 wants to merge 3 commits into
NVIDIA:mainfrom
harshaa765:add-ensemble-mse-rmse-metrics

Conversation

@harshaa765

Copy link
Copy Markdown

Summary

  • Adds EnsembleMSE and EnsembleRMSE as stateful streaming metric classes to physicsnemo/metrics/general/ensemble_metrics.py, extending the existing EnsembleMetrics base class
  • Both classes support single-pass (__call__) and incremental (update / finalize) accumulation across batches, matching the lifecycle pattern of the existing Mean and Variance classes
  • Distributed all-reduce is applied automatically when DistributedManager and torch.distributed are both initialized
  • Removes the # TODO(Dallas) Introduce Ensemble RMSE and MSE routines. comment in mse.py that requested this feature
  • Adds __init__.py exports for physicsnemo/metrics/general/ (previously empty after the license header), exporting EnsembleMSE, EnsembleRMSE, EnsembleMetrics, Mean, Variance, mse, rmse
  • Adds test_ensemble_mse_rmse to test/metrics/test_metrics_general.py covering: single-pass correctness, incremental accumulation, finalize(), RMSE = sqrt(MSE) identity, device mismatch error, and wrong-shape error

Test plan

  • pytest test/metrics/test_metrics_general.py -v passes on CPU
  • pytest test/metrics/test_metrics_general.py -v passes on CUDA (if available)
  • EnsembleRMSE.finalize() ** 2 == EnsembleMSE.finalize() identity holds
  • Device mismatch raises AssertionError
  • Wrong input shape raises ValueError

Implements EnsembleMSE and EnsembleRMSE as stateful streaming classes
that support both single-pass (__call__) and incremental (update/finalize)
accumulation across batches, with optional distributed all-reduce support.
Removes the TODO comment in mse.py that requested these routines and adds
__init__.py exports for the general metrics subpackage.
@copy-pr-bot

copy-pr-bot Bot commented Jun 6, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented Jun 6, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces EnsembleMSE and EnsembleRMSE as stateful streaming metric classes that extend the existing EnsembleMetrics base, fulfilling the long-standing TODO in mse.py. The math, distributed all-reduce wiring, and __init__.py exports are all correct.

  • EnsembleMSE / EnsembleRMSE implementation — both follow the __call__updatefinalize lifecycle of the existing Mean class, accumulating sum_sq_err and n across batches; EnsembleRMSE wraps every return with torch.sqrt.
  • Three gaps in defensive validation__call__ skips _check_shape (unlike update), finalize has no guard for n == 0 (unlike Variance.finalize), and EnsembleRMSE.finalize stores self.mse rather than self.rmse, leaving the instance attribute misleading for callers who inspect it after the fact.

Important Files Changed

Filename Overview
physicsnemo/metrics/general/ensemble_metrics.py Adds EnsembleMSE and EnsembleRMSE streaming metric classes; core math and distributed all-reduce logic are correct, but call skips _check_shape (risking silent state corruption), finalize() has no guard against n=0, and EnsembleRMSE.finalize() stores self.mse instead of self.rmse.
physicsnemo/metrics/general/init.py Adds public exports for EnsembleMSE, EnsembleRMSE, EnsembleMetrics, Mean, Variance, mse, and rmse; straightforward and correct.
physicsnemo/metrics/general/mse.py Removes the TODO comment requesting ensemble MSE/RMSE routines; no functional change.
test/metrics/test_metrics_general.py Adds test_ensemble_mse_rmse covering single-pass, incremental accumulation, finalize, RMSE identity, device mismatch, and wrong-shape error; does not test call with wrong shape or finalize on an uninitialized instance.

Reviews (1): Last reviewed commit: "feat(metrics): add EnsembleMSE and Ensem..." | Re-trigger Greptile

Comment on lines +453 to +484
def __call__(self, preds: Tensor, target: Tensor, dim: int = 0) -> Tensor:
"""Initialise with the first batch of ensemble predictions.

Parameters
----------
preds : Tensor
Ensemble predictions of shape ``[n_ensemble, *input_shape]``.
target : Tensor
Ground-truth tensor of shape ``[*input_shape]``.
dim : int, optional
Ensemble dimension of ``preds``, by default 0.

Returns
-------
Tensor
Running ensemble MSE using samples seen so far.
"""
if preds.device != self.device:
raise AssertionError(
f"Input device, {preds.device}, and Module device, {self.device}, must be the same."
)
sq_err = (preds - target) ** 2
self.sum_sq_err = torch.sum(sq_err, dim=dim)
self.n = torch.as_tensor([preds.shape[dim]], device=self.device)

if (
DistributedManager.is_initialized() and dist.is_initialized()
): # pragma: no cover
dist.all_reduce(self.sum_sq_err, op=dist.ReduceOp.SUM)
dist.all_reduce(self.n, op=dist.ReduceOp.SUM)

return self.sum_sq_err / self.n

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 __call__ skips _check_shape, silently corrupting accumulated state

EnsembleMSE.__call__ resets self.sum_sq_err to whatever shape torch.sum(sq_err, dim=dim) produces, bypassing _check_shape. If called with incorrectly-shaped preds, self.sum_sq_err is left with a shape inconsistent with self.input_shape. Any subsequent update() call with correctly-shaped inputs will then crash inside self.sum_sq_err += sums with an unhelpful shape-mismatch RuntimeError, even though _check_shape in update() reports no problem — because _check_shape validates against self.input_shape, not against the current shape of self.sum_sq_err. Adding self._check_shape(preds) as the first statement of __call__ closes this gap and is consistent with how update() behaves.

Comment on lines +526 to +535
def finalize(self) -> Tensor:
"""Compute and store the final ensemble MSE.

Returns
-------
Tensor
Final ensemble mean squared error.
"""
self.mse = self.sum_sq_err / self.n
return self.mse

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 finalize() silently returns NaN when called on an uninitialized instance

self.n is initialized to torch.zeros([1], dtype=torch.int32), so calling finalize() on a fresh EnsembleMSE (before any __call__ or update) computes 0.0 / 0 = NaN without any error. Variance.finalize() explicitly guards against this with if not (self.n > 1.0): raise ValueError(...). Adding a similar guard (if self.n == 0) in EnsembleMSE.finalize() would keep the API consistent with the stricter existing classes and prevent silent NaN propagation downstream.

Comment on lines +589 to +597
def finalize(self) -> Tensor:
"""Compute and store the final ensemble RMSE.

Returns
-------
Tensor
Final ensemble root mean squared error.
"""
return torch.sqrt(super().finalize())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 EnsembleRMSE.finalize() leaves self.mse set to the MSE value, not RMSE

EnsembleRMSE.finalize() delegates to EnsembleMSE.finalize() via super(), which sets self.mse = sum_sq_err / n (i.e. the MSE) and returns it. The RMSE returned by this method is correct, but no self.rmse attribute is ever stored. Callers who cache the instance and later inspect instance.mse receive the MSE, not RMSE, and instance.rmse raises AttributeError. Storing the final value as self.rmse (parallel to how EnsembleMSE.finalize() stores self.mse) would make the attribute naming symmetrical and predictable.

Harshdeep Sharma and others added 2 commits June 8, 2026 07:56
- Add _check_shape call in EnsembleMSE.__call__ to prevent silent state
  corruption when called with wrong-shaped inputs (matches update() behavior)
- Guard EnsembleMSE.finalize() against n==0 to raise ValueError instead of
  silently returning NaN (matches Variance.finalize() pattern)
- Store self.rmse in EnsembleRMSE.finalize() for symmetrical attribute naming
  with EnsembleMSE.mse
- Add test coverage for all three edge cases
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant