feat(metrics): add EnsembleMSE and EnsembleRMSE streaming metrics #1701
harshaa765 wants to merge 3 commits into
Implements EnsembleMSE and EnsembleRMSE as stateful streaming classes that support both single-pass (__call__) and incremental (update/finalize) accumulation across batches, with optional distributed all-reduce support. Removes the TODO comment in mse.py that requested these routines and adds __init__.py exports for the general metrics subpackage.
    def __call__(self, preds: Tensor, target: Tensor, dim: int = 0) -> Tensor:
        """Initialise with the first batch of ensemble predictions.

        Parameters
        ----------
        preds : Tensor
            Ensemble predictions of shape ``[n_ensemble, *input_shape]``.
        target : Tensor
            Ground-truth tensor of shape ``[*input_shape]``.
        dim : int, optional
            Ensemble dimension of ``preds``, by default 0.

        Returns
        -------
        Tensor
            Running ensemble MSE using samples seen so far.
        """
        if preds.device != self.device:
            raise AssertionError(
                f"Input device, {preds.device}, and Module device, {self.device}, must be the same."
            )
        sq_err = (preds - target) ** 2
        self.sum_sq_err = torch.sum(sq_err, dim=dim)
        self.n = torch.as_tensor([preds.shape[dim]], device=self.device)

        if (
            DistributedManager.is_initialized() and dist.is_initialized()
        ):  # pragma: no cover
            dist.all_reduce(self.sum_sq_err, op=dist.ReduceOp.SUM)
            dist.all_reduce(self.n, op=dist.ReduceOp.SUM)

        return self.sum_sq_err / self.n
__call__ skips _check_shape, silently corrupting accumulated state
EnsembleMSE.__call__ resets self.sum_sq_err to whatever shape torch.sum(sq_err, dim=dim) produces, bypassing _check_shape. If called with incorrectly-shaped preds, self.sum_sq_err is left with a shape inconsistent with self.input_shape. Any subsequent update() call with correctly-shaped inputs will then crash inside self.sum_sq_err += sums with an unhelpful shape-mismatch RuntimeError, even though _check_shape in update() reports no problem — because _check_shape validates against self.input_shape, not against the current shape of self.sum_sq_err. Adding self._check_shape(preds) as the first statement of __call__ closes this gap and is consistent with how update() behaves.
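A minimal sketch of the suggested fix, using a hypothetical stand-in class (`TinyEnsembleMSE`) rather than the PR's actual code. The body of `_check_shape` is not shown in this review, so the version here (ensemble dimension first, remaining dimensions compared against `input_shape`, `ValueError` on mismatch) is an assumption, and the device check and distributed all-reduce are left out.

```python
import torch
from torch import Tensor


class TinyEnsembleMSE:
    """Hypothetical stand-in for EnsembleMSE; not the PR's implementation."""

    def __init__(self, input_shape: tuple):
        self.input_shape = tuple(input_shape)
        self.sum_sq_err = torch.zeros(self.input_shape)
        self.n = torch.zeros([1], dtype=torch.int32)

    def _check_shape(self, preds: Tensor) -> None:
        # Assumed behaviour: everything after the ensemble dimension
        # must match input_shape.
        if tuple(preds.shape[1:]) != self.input_shape:
            raise ValueError(
                f"Expected trailing shape {self.input_shape}, "
                f"got {tuple(preds.shape[1:])}."
            )

    def __call__(self, preds: Tensor, target: Tensor, dim: int = 0) -> Tensor:
        self._check_shape(preds)  # validate before any state is overwritten
        self.sum_sq_err = torch.sum((preds - target) ** 2, dim=dim)
        self.n = torch.as_tensor([preds.shape[dim]])
        return self.sum_sq_err / self.n

    def update(self, preds: Tensor, target: Tensor, dim: int = 0) -> Tensor:
        self._check_shape(preds)
        self.sum_sq_err += torch.sum((preds - target) ** 2, dim=dim)
        self.n += preds.shape[dim]
        return self.sum_sq_err / self.n
```

With the check in place, a wrongly shaped call fails with `ValueError` before `sum_sq_err` is reassigned, so a later correctly shaped `update()` still sees state that matches `input_shape`.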
    def finalize(self) -> Tensor:
        """Compute and store the final ensemble MSE.

        Returns
        -------
        Tensor
            Final ensemble mean squared error.
        """
        self.mse = self.sum_sq_err / self.n
        return self.mse
finalize() silently returns NaN when called on an uninitialized instance
self.n is initialized to torch.zeros([1], dtype=torch.int32), so calling finalize() on a fresh EnsembleMSE (before any __call__ or update) computes 0.0 / 0 = NaN without any error. Variance.finalize() explicitly guards against this with if not (self.n > 1.0): raise ValueError(...). Adding a similar guard (if self.n == 0) in EnsembleMSE.finalize() would keep the API consistent with the stricter existing classes and prevent silent NaN propagation downstream.
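A small sketch of the failure and the proposed guard. `finalize_mse` is a hypothetical free function standing in for `EnsembleMSE.finalize()`, and the error message is illustrative only; the wording used by `Variance.finalize()` is not shown in this review.

```python
import torch
from torch import Tensor


def finalize_mse(sum_sq_err: Tensor, n: Tensor) -> Tensor:
    """Hypothetical free-function version of a guarded finalize()."""
    if int(n.item()) == 0:
        # Without this check, 0.0 / 0 evaluates to NaN with no error.
        raise ValueError("finalize() called before any samples were accumulated.")
    return sum_sq_err / n


# State of a freshly constructed instance, as described in the comment above.
fresh_sum = torch.zeros(3)
fresh_n = torch.zeros([1], dtype=torch.int32)

# Unguarded division on fresh state yields NaN silently.
unguarded = fresh_sum / fresh_n
```

Raising early keeps the NaN from reaching downstream aggregation, where its origin would be harder to trace.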
    def finalize(self) -> Tensor:
        """Compute and store the final ensemble RMSE.

        Returns
        -------
        Tensor
            Final ensemble root mean squared error.
        """
        return torch.sqrt(super().finalize())
EnsembleRMSE.finalize() leaves self.mse set to the MSE value, not RMSE
EnsembleRMSE.finalize() delegates to EnsembleMSE.finalize() via super(), which sets self.mse = sum_sq_err / n (i.e. the MSE) and returns it. The RMSE returned by this method is correct, but no self.rmse attribute is ever stored. Callers who cache the instance and later inspect instance.mse receive the MSE, not RMSE, and instance.rmse raises AttributeError. Storing the final value as self.rmse (parallel to how EnsembleMSE.finalize() stores self.mse) would make the attribute naming symmetrical and predictable.
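One way the suggestion could look, sketched with hypothetical stand-in classes (`TinyMSE`, `TinyRMSE`) that hold only the state `finalize()` needs. The constructor signature is invented for the example and does not reflect the PR's classes.

```python
import torch
from torch import Tensor


class TinyMSE:
    """Hypothetical stand-in holding only the state finalize() needs."""

    def __init__(self, sum_sq_err: Tensor, n: int):
        self.sum_sq_err = sum_sq_err
        self.n = torch.as_tensor([n])

    def finalize(self) -> Tensor:
        self.mse = self.sum_sq_err / self.n
        return self.mse


class TinyRMSE(TinyMSE):
    def finalize(self) -> Tensor:
        # super().finalize() still stores self.mse; the root is stored separately.
        self.rmse = torch.sqrt(super().finalize())
        return self.rmse


metric = TinyRMSE(torch.tensor([8.0, 18.0]), n=2)
result = metric.finalize()
```

After `finalize()`, both `metric.mse` and `metric.rmse` exist, so a caller inspecting a cached instance gets the value the attribute name promises.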
- Add _check_shape call in EnsembleMSE.__call__ to prevent silent state corruption when called with wrong-shaped inputs (matches update() behavior)
- Guard EnsembleMSE.finalize() against n==0 to raise ValueError instead of silently returning NaN (matches Variance.finalize() pattern)
- Store self.rmse in EnsembleRMSE.finalize() for symmetrical attribute naming with EnsembleMSE.mse
- Add test coverage for all three edge cases
Summary
- Adds EnsembleMSE and EnsembleRMSE as stateful streaming metric classes to physicsnemo/metrics/general/ensemble_metrics.py, extending the existing EnsembleMetrics base class
- Supports both single-pass (__call__) and incremental (update/finalize) accumulation across batches, matching the lifecycle pattern of the existing Mean and Variance classes
- Supports an optional distributed all-reduce when DistributedManager and torch.distributed are both initialized
- Removes the "# TODO(Dallas) Introduce Ensemble RMSE and MSE routines." comment in mse.py that requested this feature
- Adds __init__.py exports for physicsnemo/metrics/general/ (previously empty after the license header), exporting EnsembleMSE, EnsembleRMSE, EnsembleMetrics, Mean, Variance, mse, rmse
- Adds test_ensemble_mse_rmse to test/metrics/test_metrics_general.py covering: single-pass correctness, incremental accumulation, finalize(), RMSE = sqrt(MSE) identity, device mismatch error, and wrong-shape error

Test plan
- pytest test/metrics/test_metrics_general.py -v passes on CPU
- pytest test/metrics/test_metrics_general.py -v passes on CUDA (if available)
- EnsembleRMSE.finalize() ** 2 == EnsembleMSE.finalize() identity holds
- Device mismatch raises AssertionError
- Wrong-shape input raises ValueError
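The two numerical properties in the test plan can be sketched with plain tensor arithmetic, without the PR's classes. The shapes and batch sizes below are arbitrary choices for the example, and the comparison uses `torch.allclose` rather than exact equality because the streamed and single-pass sums add the terms in a different order.

```python
import torch

torch.manual_seed(0)
target = torch.randn(4, 5)
batch_a = torch.randn(3, 4, 5)  # 3 ensemble members
batch_b = torch.randn(6, 4, 5)  # 6 more members

# Incremental: accumulate the sum of squared errors and the member count.
sum_sq_err = torch.sum((batch_a - target) ** 2, dim=0)
n = batch_a.shape[0]
sum_sq_err += torch.sum((batch_b - target) ** 2, dim=0)
n += batch_b.shape[0]
mse_streamed = sum_sq_err / n

# Single pass over all members at once.
all_preds = torch.cat([batch_a, batch_b], dim=0)
mse_single = torch.mean((all_preds - target) ** 2, dim=0)

# RMSE is the elementwise square root of the accumulated MSE.
rmse_streamed = torch.sqrt(mse_streamed)

print(torch.allclose(mse_streamed, mse_single))
print(torch.allclose(rmse_streamed ** 2, mse_streamed))
```

Because only the running sum and count are carried between batches, memory use stays at one tensor of shape `input_shape` regardless of how many ensemble members are processed.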