1- import numpy as np
21import torch
32
43from auto_cast .metrics .base import Metric
5- from auto_cast .types import TensorBTC , TensorBTSPlusC
4+ from auto_cast .types import TensorBCTSPlus , TensorBTC
65
76
87class MSE (Metric ):
8+ """Mean Squared Error."""
9+
910 @staticmethod
1011 def score (
11- y_pred : TensorBTSPlusC ,
12- y_true : TensorBTSPlusC ,
13- n_spatial_dims : int ,
12+ y_pred : TensorBCTSPlus , y_true : TensorBCTSPlus , n_spatial_dims : int , ** kwargs
1413 ) -> TensorBTC :
1514 """
16- Mean Squared Error
15+ Compute Mean Squared Error.
1716
1817 Args:
1918 y_pred: Predicted values tensor.
@@ -30,14 +29,14 @@ def score(
3029
3130
3231class MAE (Metric ):
32+ """Mean Absolute Error."""
33+
3334 @staticmethod
3435 def score (
35- y_pred : TensorBTSPlusC ,
36- y_true : TensorBTSPlusC ,
37- n_spatial_dims : int ,
36+ y_pred : TensorBCTSPlus , y_true : TensorBCTSPlus , n_spatial_dims : int , ** kwargs
3837 ) -> TensorBTC :
3938 """
40- Mean Absolute Error
39+ Compute Mean Absolute Error.
4140
4241 Args:
4342 y_pred: Predicted values tensor.
@@ -54,15 +53,18 @@ def score(
5453
5554
5655class NMAE (Metric ):
56+ """Normalized Mean Absolute Error."""
57+
5758 @staticmethod
5859 def score (
59- y_pred : TensorBTSPlusC ,
60- y_true : TensorBTSPlusC ,
60+ y_pred : TensorBCTSPlus ,
61+ y_true : TensorBCTSPlus ,
6162 n_spatial_dims : int ,
6263 eps : float = 1e-7 ,
64+ ** kwargs ,
6365 ) -> TensorBTC :
6466 """
65- Normalized Mean Absolute Error
67+ Compute Normalized Mean Absolute Error.
6668
6769 Args:
6870 y_pred: Predicted values tensor.
@@ -80,16 +82,19 @@ def score(
8082
8183
8284class NMSE (Metric ):
85+ """Normalized Mean Squared Error."""
86+
8387 @staticmethod
8488 def score (
85- y_pred : TensorBTSPlusC ,
86- y_true : TensorBTSPlusC ,
89+ y_pred : TensorBCTSPlus ,
90+ y_true : TensorBCTSPlus ,
8791 n_spatial_dims : int ,
8892 eps : float = 1e-7 ,
8993 norm_mode : str = "norm" ,
94+ ** kwargs ,
9095 ) -> TensorBTC :
9196 """
92- Normalized Mean Squared Error
97+ Compute Normalized Mean Squared Error.
9398
9499 Args:
95100 y_pred: Predicted values tensor.
@@ -98,7 +103,8 @@ def score(
98103 Number of spatial dimensions.
99104 eps: Small value to avoid division by zero. Default is 1e-7.
100105 norm_mode:
101- Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.
106+ Mode for computing the normalization factor. Can be 'norm' or 'std'.
107+ Default is 'norm'.
102108
103109 Returns
104110 -------
@@ -115,14 +121,14 @@ def score(
115121
116122
117123class RMSE (Metric ):
124+ """Root Mean Squared Error."""
125+
118126 @staticmethod
119127 def score (
120- y_pred : TensorBTSPlusC ,
121- y_true : TensorBTSPlusC ,
122- n_spatial_dims : int ,
128+ y_pred : TensorBCTSPlus , y_true : TensorBCTSPlus , n_spatial_dims : int , ** kwargs
123129 ) -> TensorBTC :
124130 """
125- Root Mean Squared Error
131+ Compute Root Mean Squared Error.
126132
127133 Args:
128134 y_pred: Predicted values tensor.
@@ -138,24 +144,28 @@ def score(
138144
139145
140146class NRMSE (Metric ):
147+ """Normalized Root Mean Squared Error."""
148+
141149 @staticmethod
142150 def score (
143- y_pred : TensorBTSPlusC ,
144- y_true : TensorBTSPlusC ,
151+ y_pred : TensorBCTSPlus ,
152+ y_true : TensorBCTSPlus ,
145153 n_spatial_dims : int ,
146154 eps : float = 1e-7 ,
147155 norm_mode : str = "norm" ,
156+ ** kwargs ,
148157 ) -> TensorBTC :
149158 """
150- Normalized Root Mean Squared Error
159+ Compute Normalized Root Mean Squared Error.
151160
152161 Args:
153162 y_pred: Predicted values tensor.
154163 y_true: Target values tensor.
155164 n_spatial_dims: int
156165 Number of spatial dimensions.
157166 eps: Small value to avoid division by zero. Default is 1e-7.
158- norm_mode : Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.
167+ norm_mode : Mode for computing the normalization factor.
168+ Can be 'norm' or 'std'. Default is 'norm'.
159169
160170 Returns
161171 -------
@@ -168,14 +178,14 @@ def score(
168178
169179
170180class VMSE (Metric ):
181+ """Variance Scaled Mean Squared Error."""
182+
171183 @staticmethod
172184 def score (
173- y_pred : TensorBTSPlusC ,
174- y_true : TensorBTSPlusC ,
175- n_spatial_dims : int ,
185+ y_pred : TensorBCTSPlus , y_true : TensorBCTSPlus , n_spatial_dims : int , ** kwargs
176186 ) -> TensorBTC :
177187 """
178- Variance Scaled Mean Squared Error
188+ Compute Variance Scaled Mean Squared Error.
179189
180190 Args:
181191 y_pred: Predicted values tensor.
@@ -191,14 +201,14 @@ def score(
191201
192202
193203class VRMSE (Metric ):
204+ """Variance Scaled Root Mean Squared Error."""
205+
194206 @staticmethod
195207 def score (
196- y_pred : TensorBTSPlusC ,
197- y_true : TensorBTSPlusC ,
198- n_spatial_dims : int ,
208+ y_pred : TensorBCTSPlus , y_true : TensorBCTSPlus , n_spatial_dims : int , ** kwargs
199209 ) -> TensorBTC :
200210 """
201- Root Variance Scaled Mean Squared Error
211+ Compute Root Variance Scaled Mean Squared Error.
202212
203213 Args:
204214 y_pred: Predicted values tensor.
@@ -214,14 +224,14 @@ def score(
214224
215225
216226class LInfinity (Metric ):
227+ """L-Infinity Norm."""
228+
217229 @staticmethod
218230 def score (
219- y_pred : TensorBTSPlusC ,
220- y_true : TensorBTSPlusC ,
221- n_spatial_dims : int ,
231+ y_pred : TensorBCTSPlus , y_true : TensorBCTSPlus , n_spatial_dims : int , ** kwargs
222232 ) -> TensorBTC :
223233 """
224- L-Infinity Norm
234+ Compute L-Infinity Norm.
225235
226236 Args:
227237 x: Input tensor.
0 commit comments