Skip to content

Commit 1d7d854

Browse files
fix: time index intersection for coefficient of variation (#2202)
* fix: properly take the intersected time indexes for the coefficient of variation * fix: computing rmse on ndarray directly * fix: forgot sqrt for rmse in coef of variation * fix: update type of return in docstring, taking into consideration the multi_ts and multivariate decorator, which convert arrays into list * update changelog * update changelog --------- Co-authored-by: dennisbader <[email protected]>
1 parent 8cb04f6 commit 1d7d854

File tree

2 files changed

+21
-20
lines changed

2 files changed

+21
-20
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1818
- Added option to exclude some `group_cols` from being added as static covariates when using `TimeSeries.from_group_dataframe()` with parameter `drop_group_cols`.
1919

2020
**Fixed**
21+
- Fixed a bug in `coefficient_of_variaton()` with `intersect=True`, where the coefficient was not computed on the intersection. [#2202](https://github.com/unit8co/darts/pull/2202) by [Antoine Madrona](https://github.com/madtoinou).
2122

2223
### For developers of the library:
2324

darts/metrics/metrics.py

+20-20
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from functools import wraps
99
from inspect import signature
10-
from typing import Callable, Optional, Sequence, Tuple, Union
10+
from typing import Callable, List, Optional, Sequence, Tuple, Union
1111
from warnings import warn
1212

1313
import numpy as np
@@ -27,7 +27,7 @@
2727
# care of dealing with Sequence[TimeSeries] and multivariate TimeSeries on its own (See mase() implementation).
2828

2929

30-
def multi_ts_support(func):
30+
def multi_ts_support(func) -> Union[float, List[float]]:
3131
"""
3232
This decorator further adapts the metrics that took as input two univariate/multivariate ``TimeSeries`` instances,
3333
adding support for equally-sized sequences of ``TimeSeries`` instances. The decorator computes the pairwise metric
@@ -107,7 +107,7 @@ def wrapper_multi_ts_support(*args, **kwargs):
107107
return wrapper_multi_ts_support
108108

109109

110-
def multivariate_support(func):
110+
def multivariate_support(func) -> Union[float, List[float]]:
111111
"""
112112
This decorator transforms a metric function that takes as input two univariate TimeSeries instances
113113
into a function that takes two equally-sized multivariate TimeSeries instances, computes the pairwise univariate
@@ -279,7 +279,7 @@ def mae(
279279
280280
Returns
281281
-------
282-
float
282+
Union[float, List[float]]
283283
The Mean Absolute Error (MAE)
284284
"""
285285

@@ -336,7 +336,7 @@ def mse(
336336
337337
Returns
338338
-------
339-
float
339+
Union[float, List[float]]
340340
The Mean Squared Error (MSE)
341341
"""
342342

@@ -393,7 +393,7 @@ def rmse(
393393
394394
Returns
395395
-------
396-
float
396+
Union[float, List[float]]
397397
The Root Mean Squared Error (RMSE)
398398
"""
399399
return np.sqrt(mse(actual_series, pred_series, intersect))
@@ -448,7 +448,7 @@ def rmsle(
448448
449449
Returns
450450
-------
451-
float
451+
Union[float, List[float]]
452452
The Root Mean Squared Log Error (RMSLE)
453453
"""
454454

@@ -510,15 +510,15 @@ def coefficient_of_variation(
510510
511511
Returns
512512
-------
513-
float
513+
Union[float, List[float]]
514514
The Coefficient of Variation
515515
"""
516516

517-
return (
518-
100
519-
* rmse(actual_series, pred_series, intersect)
520-
/ actual_series.pd_dataframe(copy=False).mean().mean()
517+
y_true, y_pred = _get_values_or_raise(
518+
actual_series, pred_series, intersect, remove_nan_union=True
521519
)
520+
# not calling rmse as y_true and y_pred are np.ndarray
521+
return 100 * np.sqrt(np.mean((y_true - y_pred) ** 2)) / y_true.mean()
522522

523523

524524
@multi_ts_support
@@ -577,7 +577,7 @@ def mape(
577577
578578
Returns
579579
-------
580-
float
580+
Union[float, List[float]]
581581
The Mean Absolute Percentage Error (MAPE)
582582
"""
583583

@@ -650,7 +650,7 @@ def smape(
650650
651651
Returns
652652
-------
653-
float
653+
Union[float, List[float]]
654654
The symmetric Mean Absolute Percentage Error (sMAPE)
655655
"""
656656

@@ -725,7 +725,7 @@ def mase(
725725
726726
Returns
727727
-------
728-
float
728+
Union[float, List[float]]
729729
The Mean Absolute Scaled Error (MASE)
730730
"""
731731

@@ -907,7 +907,7 @@ def ope(
907907
908908
Returns
909909
-------
910-
float
910+
Union[float, List[float]]
911911
The Overall Percentage Error (OPE)
912912
"""
913913

@@ -977,7 +977,7 @@ def marre(
977977
978978
Returns
979979
-------
980-
float
980+
Union[float, List[float]]
981981
The Mean Absolute Ranged Relative Error (MARRE)
982982
"""
983983

@@ -1042,7 +1042,7 @@ def r2_score(
10421042
10431043
Returns
10441044
-------
1045-
float
1045+
Union[float, List[float]]
10461046
The Coefficient of Determination :math:`R^2`
10471047
"""
10481048
y1, y2 = _get_values_or_raise(
@@ -1185,7 +1185,7 @@ def rho_risk(
11851185
11861186
Returns
11871187
-------
1188-
float
1188+
Union[float, List[float]]
11891189
The rho-risk metric
11901190
"""
11911191

@@ -1263,7 +1263,7 @@ def quantile_loss(
12631263
12641264
Returns
12651265
-------
1266-
float
1266+
Union[float, List[float]]
12671267
The quantile loss metric
12681268
"""
12691269

0 commit comments

Comments
 (0)