Skip to content

Commit d20deca

Browse files
committed
fixing docs for metrics
1 parent 2a258fd commit d20deca

4 files changed

Lines changed: 58 additions & 44 deletions

File tree

src/auto_cast/metrics/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,4 @@
22

33
__all__ = ["MAE", "MSE", "NMAE", "NMSE", "NRMSE", "RMSE", "VMSE", "VRMSE", "LInfinity"]
44

5-
ALL_METRICS = (
6-
MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity
7-
)
5+
ALL_METRICS = (MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity)

src/auto_cast/metrics/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@
22
import torch
33
from torch import nn
44

5+
from auto_cast.types import TensorBCTSPlus
6+
57

68
class Metric(nn.Module):
79
"""
8-
Decorator for metrics that standardizes the input arguments and checks the dimensions of the input tensors.
10+
Base class for metrics.
11+
12+
This class standardizes the input arguments and
13+
checks the dimensions of the input tensors.
914
1015
Args:
1116
f: function
@@ -46,5 +51,7 @@ def forward(self, *args, **kwargs):
4651
return self.score(y_pred, y_true, n_spatial_dims, **kwargs)
4752

4853
@staticmethod
49-
def score(y_pred, y_true, n_spatial_dims, **kwargs):
54+
def score(
55+
y_pred: TensorBCTSPlus, y_true: TensorBCTSPlus, n_spatial_dims: int, **kwargs
56+
):
5057
raise NotImplementedError

src/auto_cast/metrics/spatiotemporal.py

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
1-
import numpy as np
21
import torch
32

43
from auto_cast.metrics.base import Metric
5-
from auto_cast.types import TensorBTC, TensorBTSPlusC
4+
from auto_cast.types import TensorBCTSPlus, TensorBTC
65

76

87
class MSE(Metric):
8+
"""Mean Squared Error."""
9+
910
@staticmethod
1011
def score(
11-
y_pred: TensorBTSPlusC,
12-
y_true: TensorBTSPlusC,
13-
n_spatial_dims: int,
12+
y_pred: TensorBCTSPlus, y_true: TensorBCTSPlus, n_spatial_dims: int, **kwargs
1413
) -> TensorBTC:
1514
"""
16-
Mean Squared Error
15+
Compute Mean Squared Error.
1716
1817
Args:
1918
y_pred: Predicted values tensor.
@@ -30,14 +29,14 @@ def score(
3029

3130

3231
class MAE(Metric):
32+
"""Mean Absolute Error."""
33+
3334
@staticmethod
3435
def score(
35-
y_pred: TensorBTSPlusC,
36-
y_true: TensorBTSPlusC,
37-
n_spatial_dims: int,
36+
y_pred: TensorBCTSPlus, y_true: TensorBCTSPlus, n_spatial_dims: int, **kwargs
3837
) -> TensorBTC:
3938
"""
40-
Mean Absolute Error
39+
Compute Mean Absolute Error.
4140
4241
Args:
4342
y_pred: Predicted values tensor.
@@ -54,15 +53,18 @@ def score(
5453

5554

5655
class NMAE(Metric):
56+
"""Normalized Mean Absolute Error."""
57+
5758
@staticmethod
5859
def score(
59-
y_pred: TensorBTSPlusC,
60-
y_true: TensorBTSPlusC,
60+
y_pred: TensorBCTSPlus,
61+
y_true: TensorBCTSPlus,
6162
n_spatial_dims: int,
6263
eps: float = 1e-7,
64+
**kwargs,
6365
) -> TensorBTC:
6466
"""
65-
Normalized Mean Absolute Error
67+
Compute Normalized Mean Absolute Error.
6668
6769
Args:
6870
y_pred: Predicted values tensor.
@@ -80,16 +82,19 @@ def score(
8082

8183

8284
class NMSE(Metric):
85+
"""Normalized Mean Squared Error."""
86+
8387
@staticmethod
8488
def score(
85-
y_pred: TensorBTSPlusC,
86-
y_true: TensorBTSPlusC,
89+
y_pred: TensorBCTSPlus,
90+
y_true: TensorBCTSPlus,
8791
n_spatial_dims: int,
8892
eps: float = 1e-7,
8993
norm_mode: str = "norm",
94+
**kwargs,
9095
) -> TensorBTC:
9196
"""
92-
Normalized Mean Squared Error
97+
Compute Normalized Mean Squared Error.
9398
9499
Args:
95100
y_pred: Predicted values tensor.
@@ -98,7 +103,8 @@ def score(
98103
Number of spatial dimensions.
99104
eps: Small value to avoid division by zero. Default is 1e-7.
100105
norm_mode:
101-
Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.
106+
Mode for computing the normalization factor. Can be 'norm' or 'std'.
107+
Default is 'norm'.
102108
103109
Returns
104110
-------
@@ -115,14 +121,14 @@ def score(
115121

116122

117123
class RMSE(Metric):
124+
"""Root Mean Squared Error."""
125+
118126
@staticmethod
119127
def score(
120-
y_pred: TensorBTSPlusC,
121-
y_true: TensorBTSPlusC,
122-
n_spatial_dims: int,
128+
y_pred: TensorBCTSPlus, y_true: TensorBCTSPlus, n_spatial_dims: int, **kwargs
123129
) -> TensorBTC:
124130
"""
125-
Root Mean Squared Error
131+
Compute Root Mean Squared Error.
126132
127133
Args:
128134
y_pred: Predicted values tensor.
@@ -138,24 +144,28 @@ def score(
138144

139145

140146
class NRMSE(Metric):
147+
"""Normalized Root Mean Squared Error."""
148+
141149
@staticmethod
142150
def score(
143-
y_pred: TensorBTSPlusC,
144-
y_true: TensorBTSPlusC,
151+
y_pred: TensorBCTSPlus,
152+
y_true: TensorBCTSPlus,
145153
n_spatial_dims: int,
146154
eps: float = 1e-7,
147155
norm_mode: str = "norm",
156+
**kwargs,
148157
) -> TensorBTC:
149158
"""
150-
Normalized Root Mean Squared Error
159+
Compute Normalized Root Mean Squared Error.
151160
152161
Args:
153162
y_pred: Predicted values tensor.
154163
y_true: Target values tensor.
155164
n_spatial_dims: int
156165
Number of spatial dimensions.
157166
eps: Small value to avoid division by zero. Default is 1e-7.
158-
norm_mode : Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.
167+
norm_mode : Mode for computing the normalization factor.
168+
Can be 'norm' or 'std'. Default is 'norm'.
159169
160170
Returns
161171
-------
@@ -168,14 +178,14 @@ def score(
168178

169179

170180
class VMSE(Metric):
181+
"""Variance Scaled Mean Squared Error."""
182+
171183
@staticmethod
172184
def score(
173-
y_pred: TensorBTSPlusC,
174-
y_true: TensorBTSPlusC,
175-
n_spatial_dims: int,
185+
y_pred: TensorBCTSPlus, y_true: TensorBCTSPlus, n_spatial_dims: int, **kwargs
176186
) -> TensorBTC:
177187
"""
178-
Variance Scaled Mean Squared Error
188+
Compute Variance Scaled Mean Squared Error.
179189
180190
Args:
181191
y_pred: Predicted values tensor.
@@ -191,14 +201,14 @@ def score(
191201

192202

193203
class VRMSE(Metric):
204+
"""Variance Scaled Root Mean Squared Error."""
205+
194206
@staticmethod
195207
def score(
196-
y_pred: TensorBTSPlusC,
197-
y_true: TensorBTSPlusC,
198-
n_spatial_dims: int,
208+
y_pred: TensorBCTSPlus, y_true: TensorBCTSPlus, n_spatial_dims: int, **kwargs
199209
) -> TensorBTC:
200210
"""
201-
Root Variance Scaled Mean Squared Error
211+
Compute Root Variance Scaled Mean Squared Error.
202212
203213
Args:
204214
y_pred: Predicted values tensor.
@@ -214,14 +224,14 @@ def score(
214224

215225

216226
class LInfinity(Metric):
227+
"""L-Infinity Norm."""
228+
217229
@staticmethod
218230
def score(
219-
y_pred: TensorBTSPlusC,
220-
y_true: TensorBTSPlusC,
221-
n_spatial_dims: int,
231+
y_pred: TensorBCTSPlus, y_true: TensorBCTSPlus, n_spatial_dims: int, **kwargs
222232
) -> TensorBTC:
223233
"""
224-
L-Infinity Norm
234+
Compute L-Infinity Norm.
225235
226236
Args:
227237
x: Input tensor.

tests/metrics/test_metrics.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66

77

88
@pytest.mark.parametrize("metric", ALL_METRICS)
9-
109
def test_spatiotemporal_metrics(metric):
1110
y_pred: TensorBTSPlusC = torch.ones((2, 3, 4, 5))
1211
y_true: TensorBTSPlusC = torch.ones((2, 3, 4, 5))
1312
n_spatial_dims = 1
1413

15-
error = metric()(y_pred, y_true, n_spatial_dims)
14+
error = metric()(y_pred, y_true, n_spatial_dims)
1615
assert torch.allclose(error.nansum(), torch.tensor(0.0))

0 commit comments

Comments
 (0)