|
7 | 7 |
|
8 | 8 | from functools import wraps
|
9 | 9 | from inspect import signature
|
10 |
| -from typing import Callable, Optional, Sequence, Tuple, Union |
| 10 | +from typing import Callable, List, Optional, Sequence, Tuple, Union |
11 | 11 | from warnings import warn
|
12 | 12 |
|
13 | 13 | import numpy as np
|
|
27 | 27 | # care of dealing with Sequence[TimeSeries] and multivariate TimeSeries on its own (See mase() implementation).
|
28 | 28 |
|
29 | 29 |
|
30 |
| -def multi_ts_support(func): |
| 30 | +def multi_ts_support(func) -> Union[float, List[float]]: |
31 | 31 | """
|
32 | 32 | This decorator further adapts the metrics that took as input two univariate/multivariate ``TimeSeries`` instances,
|
33 | 33 | adding support for equally-sized sequences of ``TimeSeries`` instances. The decorator computes the pairwise metric
|
@@ -107,7 +107,7 @@ def wrapper_multi_ts_support(*args, **kwargs):
|
107 | 107 | return wrapper_multi_ts_support
|
108 | 108 |
|
109 | 109 |
|
110 |
| -def multivariate_support(func): |
| 110 | +def multivariate_support(func) -> Union[float, List[float]]: |
111 | 111 | """
|
112 | 112 | This decorator transforms a metric function that takes as input two univariate TimeSeries instances
|
113 | 113 | into a function that takes two equally-sized multivariate TimeSeries instances, computes the pairwise univariate
|
@@ -279,7 +279,7 @@ def mae(
|
279 | 279 |
|
280 | 280 | Returns
|
281 | 281 | -------
|
282 |
| - float |
| 282 | + Union[float, List[float]] |
283 | 283 | The Mean Absolute Error (MAE)
|
284 | 284 | """
|
285 | 285 |
|
@@ -336,7 +336,7 @@ def mse(
|
336 | 336 |
|
337 | 337 | Returns
|
338 | 338 | -------
|
339 |
| - float |
| 339 | + Union[float, List[float]] |
340 | 340 | The Mean Squared Error (MSE)
|
341 | 341 | """
|
342 | 342 |
|
@@ -393,7 +393,7 @@ def rmse(
|
393 | 393 |
|
394 | 394 | Returns
|
395 | 395 | -------
|
396 |
| - float |
| 396 | + Union[float, List[float]] |
397 | 397 | The Root Mean Squared Error (RMSE)
|
398 | 398 | """
|
399 | 399 | return np.sqrt(mse(actual_series, pred_series, intersect))
|
@@ -448,7 +448,7 @@ def rmsle(
|
448 | 448 |
|
449 | 449 | Returns
|
450 | 450 | -------
|
451 |
| - float |
| 451 | + Union[float, List[float]] |
452 | 452 | The Root Mean Squared Log Error (RMSLE)
|
453 | 453 | """
|
454 | 454 |
|
@@ -510,15 +510,15 @@ def coefficient_of_variation(
|
510 | 510 |
|
511 | 511 | Returns
|
512 | 512 | -------
|
513 |
| - float |
| 513 | + Union[float, List[float]] |
514 | 514 | The Coefficient of Variation
|
515 | 515 | """
|
516 | 516 |
|
517 |
| - return ( |
518 |
| - 100 |
519 |
| - * rmse(actual_series, pred_series, intersect) |
520 |
| - / actual_series.pd_dataframe(copy=False).mean().mean() |
| 517 | + y_true, y_pred = _get_values_or_raise( |
| 518 | + actual_series, pred_series, intersect, remove_nan_union=True |
521 | 519 | )
|
| 520 | + # not calling rmse as y_true and y_pred are np.ndarray |
| 521 | + return 100 * np.sqrt(np.mean((y_true - y_pred) ** 2)) / y_true.mean() |
522 | 522 |
|
523 | 523 |
|
524 | 524 | @multi_ts_support
|
@@ -577,7 +577,7 @@ def mape(
|
577 | 577 |
|
578 | 578 | Returns
|
579 | 579 | -------
|
580 |
| - float |
| 580 | + Union[float, List[float]] |
581 | 581 | The Mean Absolute Percentage Error (MAPE)
|
582 | 582 | """
|
583 | 583 |
|
@@ -650,7 +650,7 @@ def smape(
|
650 | 650 |
|
651 | 651 | Returns
|
652 | 652 | -------
|
653 |
| - float |
| 653 | + Union[float, List[float]] |
654 | 654 | The symmetric Mean Absolute Percentage Error (sMAPE)
|
655 | 655 | """
|
656 | 656 |
|
@@ -725,7 +725,7 @@ def mase(
|
725 | 725 |
|
726 | 726 | Returns
|
727 | 727 | -------
|
728 |
| - float |
| 728 | + Union[float, List[float]] |
729 | 729 | The Mean Absolute Scaled Error (MASE)
|
730 | 730 | """
|
731 | 731 |
|
@@ -907,7 +907,7 @@ def ope(
|
907 | 907 |
|
908 | 908 | Returns
|
909 | 909 | -------
|
910 |
| - float |
| 910 | + Union[float, List[float]] |
911 | 911 | The Overall Percentage Error (OPE)
|
912 | 912 | """
|
913 | 913 |
|
@@ -977,7 +977,7 @@ def marre(
|
977 | 977 |
|
978 | 978 | Returns
|
979 | 979 | -------
|
980 |
| - float |
| 980 | + Union[float, List[float]] |
981 | 981 | The Mean Absolute Ranged Relative Error (MARRE)
|
982 | 982 | """
|
983 | 983 |
|
@@ -1042,7 +1042,7 @@ def r2_score(
|
1042 | 1042 |
|
1043 | 1043 | Returns
|
1044 | 1044 | -------
|
1045 |
| - float |
| 1045 | + Union[float, List[float]] |
1046 | 1046 | The Coefficient of Determination :math:`R^2`
|
1047 | 1047 | """
|
1048 | 1048 | y1, y2 = _get_values_or_raise(
|
@@ -1185,7 +1185,7 @@ def rho_risk(
|
1185 | 1185 |
|
1186 | 1186 | Returns
|
1187 | 1187 | -------
|
1188 |
| - float |
| 1188 | + Union[float, List[float]] |
1189 | 1189 | The rho-risk metric
|
1190 | 1190 | """
|
1191 | 1191 |
|
@@ -1263,7 +1263,7 @@ def quantile_loss(
|
1263 | 1263 |
|
1264 | 1264 | Returns
|
1265 | 1265 | -------
|
1266 |
| - float |
| 1266 | + Union[float, List[float]] |
1267 | 1267 | The quantile loss metric
|
1268 | 1268 | """
|
1269 | 1269 |
|
|
0 commit comments