Enable creation of custom performance metrics #2599
Merged: 9 commits on Oct 18, 2024
242 changes: 171 additions & 71 deletions docs/_docs/diagnostics.md
@@ -86,45 +86,45 @@ df_cv.head()
<tr>
<th>0</th>
<td>2010-02-16</td>
<td>8.959678</td>
<td>8.470035</td>
<td>9.451618</td>
<td>8.954582</td>
<td>8.462876</td>
<td>9.452305</td>
<td>8.242493</td>
<td>2010-02-15</td>
</tr>
<tr>
<th>1</th>
<td>2010-02-17</td>
<td>8.726195</td>
<td>8.236734</td>
<td>9.219616</td>
<td>8.720932</td>
<td>8.222682</td>
<td>9.242788</td>
<td>8.008033</td>
<td>2010-02-15</td>
</tr>
<tr>
<th>2</th>
<td>2010-02-18</td>
<td>8.610011</td>
<td>8.104834</td>
<td>9.125484</td>
<td>8.604608</td>
<td>8.066920</td>
<td>9.144968</td>
<td>8.045268</td>
<td>2010-02-15</td>
</tr>
<tr>
<th>3</th>
<td>2010-02-19</td>
<td>8.532004</td>
<td>7.985031</td>
<td>9.041575</td>
<td>8.526379</td>
<td>8.029189</td>
<td>9.043045</td>
<td>7.928766</td>
<td>2010-02-15</td>
</tr>
<tr>
<th>4</th>
<td>2010-02-20</td>
<td>8.274090</td>
<td>7.779034</td>
<td>8.745627</td>
<td>8.268247</td>
<td>7.749520</td>
<td>8.741847</td>
<td>7.745003</td>
<td>2010-02-15</td>
</tr>
@@ -154,6 +154,20 @@ df_cv2 = cross_validation(m, cutoffs=cutoffs, horizon='365 days')
The `performance_metrics` utility can be used to compute some useful statistics of the prediction performance (`yhat`, `yhat_lower`, and `yhat_upper` compared to `y`), as a function of the distance from the cutoff (how far into the future the prediction was). The statistics computed are mean squared error (MSE), root mean squared error (RMSE), mean absolute error (MAE), mean absolute percent error (MAPE), median absolute percent error (MDAPE) and coverage of the `yhat_lower` and `yhat_upper` estimates. These are computed on a rolling window of the predictions in `df_cv` after sorting by horizon (`ds` minus `cutoff`). By default 10% of the predictions will be included in each window, but this can be changed with the `rolling_window` argument.
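That rolling aggregation can be illustrated with plain pandas on synthetic numbers (a sketch, not the Prophet internals, which live in `prophet.diagnostics`):

```python
import numpy as np
import pandas as pd

# Hypothetical per-prediction squared errors, already sorted by horizon
df = pd.DataFrame({
    'horizon': pd.to_timedelta(np.arange(1, 101), unit='D'),
    'se': np.linspace(1.0, 2.0, 100),
})
w = max(int(0.1 * len(df)), 1)          # default: 10% of predictions per window
df['mse'] = df['se'].rolling(w).mean()  # rolling-window MSE as a function of horizon
print(df['mse'].dropna().iloc[0])       # mean of the first w squared errors
```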



In Python, you can also create a custom performance metric using the `register_performance_metric` decorator. The metric function should accept the following arguments:

- `df`: Cross-validation results dataframe.

- `w`: Aggregation window size.

and return:

- A dataframe with columns `horizon` and the metric name.


```R
# R
df.p <- performance_metrics(df.cv)
@@ -200,57 +214,143 @@ df_p.head()
<tr>
<th>0</th>
<td>37 days</td>
<td>0.493764</td>
<td>0.702683</td>
<td>0.504754</td>
<td>0.058485</td>
<td>0.049922</td>
<td>0.058774</td>
<td>0.674052</td>
<td>0.493358</td>
<td>0.702395</td>
<td>0.503977</td>
<td>0.058376</td>
<td>0.049365</td>
<td>0.058677</td>
<td>0.676565</td>
</tr>
<tr>
<th>1</th>
<td>38 days</td>
<td>0.499522</td>
<td>0.706769</td>
<td>0.509723</td>
<td>0.059060</td>
<td>0.049389</td>
<td>0.059409</td>
<td>0.672910</td>
<td>0.499112</td>
<td>0.706478</td>
<td>0.508946</td>
<td>0.058951</td>
<td>0.049135</td>
<td>0.059312</td>
<td>0.675423</td>
</tr>
<tr>
<th>2</th>
<td>39 days</td>
<td>0.521614</td>
<td>0.722229</td>
<td>0.515793</td>
<td>0.059657</td>
<td>0.049540</td>
<td>0.060131</td>
<td>0.670169</td>
<td>0.521344</td>
<td>0.722042</td>
<td>0.515016</td>
<td>0.059547</td>
<td>0.049225</td>
<td>0.060034</td>
<td>0.672682</td>
</tr>
<tr>
<th>3</th>
<td>40 days</td>
<td>0.528760</td>
<td>0.727159</td>
<td>0.518634</td>
<td>0.059961</td>
<td>0.049232</td>
<td>0.060504</td>
<td>0.671311</td>
<td>0.528651</td>
<td>0.727084</td>
<td>0.517873</td>
<td>0.059852</td>
<td>0.049072</td>
<td>0.060409</td>
<td>0.676336</td>
</tr>
<tr>
<th>4</th>
<td>41 days</td>
<td>0.536078</td>
<td>0.732174</td>
<td>0.519585</td>
<td>0.060036</td>
<td>0.049389</td>
<td>0.060641</td>
<td>0.678849</td>
<td>0.536149</td>
<td>0.732222</td>
<td>0.518843</td>
<td>0.059927</td>
<td>0.049135</td>
<td>0.060548</td>
<td>0.681361</td>
</tr>
</tbody>
</table>
</div>



```python
# Python
import numpy as np
import pandas as pd

from prophet.diagnostics import register_performance_metric, rolling_mean_by_h

@register_performance_metric
def mase(df, w):
    """Mean absolute scaled error

    Parameters
    ----------
    df: Cross-validation results dataframe.
    w: Aggregation window size.

    Returns
    -------
    Dataframe with columns horizon and mase.
    """
    e = df['y'] - df['yhat']
    d = np.abs(np.diff(df['y'])).sum() / (df['y'].shape[0] - 1)
    se = np.abs(e / d)
    if w < 0:
        return pd.DataFrame({'horizon': df['horizon'], 'mase': se})
    return rolling_mean_by_h(
        x=se.values, h=df['horizon'].values, w=w, name='mase'
    )

df_mase = performance_metrics(df_cv, metrics=['mase'])
df_mase.head()
```



<div>
<style scoped>
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}

.dataframe tbody tr th {
vertical-align: top;
}

.dataframe thead th {
text-align: right;
}
</style>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>horizon</th>
<th>mase</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>37 days</td>
<td>0.522946</td>
</tr>
<tr>
<th>1</th>
<td>38 days</td>
<td>0.528102</td>
</tr>
<tr>
<th>2</th>
<td>39 days</td>
<td>0.534401</td>
</tr>
<tr>
<th>3</th>
<td>40 days</td>
<td>0.537365</td>
</tr>
<tr>
<th>4</th>
<td>41 days</td>
<td>0.538372</td>
</tr>
</tbody>
</table>
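The scaling denominator `d` in the metric above is the mean absolute one-step naive error of `y`; its arithmetic can be checked on hand-picked synthetic numbers (not Prophet output):

```python
import numpy as np
import pandas as pd

# Hypothetical cross-validation frame with hand-checkable values
df = pd.DataFrame({'y': [10.0, 12.0, 11.0, 13.0],
                   'yhat': [9.0, 12.5, 10.0, 13.5]})
e = df['y'] - df['yhat']                                     # 1.0, -0.5, 1.0, -0.5
d = np.abs(np.diff(df['y'])).sum() / (df['y'].shape[0] - 1)  # (2 + 1 + 2) / 3
se = np.abs(e / d)                                           # scaled absolute errors
print(se.round(2).tolist())  # [0.6, 0.3, 0.6, 0.3]
```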
@@ -271,7 +371,7 @@ from prophet.plot import plot_cross_validation_metric
fig = plot_cross_validation_metric(df_cv, metric='mape')
```

![png](/prophet/static/diagnostics_files/diagnostics_17_0.png)
![png](/prophet/static/diagnostics_files/diagnostics_18_0.png)


The size of the rolling window in the figure can be changed with the optional argument `rolling_window`, which specifies the proportion of forecasts to use in each rolling window. The default is 0.1, corresponding to 10% of rows from `df_cv` included in each window; increasing this will lead to a smoother average curve in the figure. The `initial` period should be long enough to capture all of the components of the model, in particular seasonalities and extra regressors: at least a year for yearly seasonality, at least a week for weekly seasonality, etc.
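The smoothing effect of a larger `rolling_window` can be seen on synthetic errors with plain pandas (a sketch, not Prophet's plotting code): a wider window damps the variability of the rolling mean.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
err = pd.Series(rng.normal(1.0, 0.5, 1000))       # synthetic per-horizon errors

narrow = err.rolling(int(0.1 * len(err))).mean()  # like rolling_window=0.1
wide = err.rolling(int(0.5 * len(err))).mean()    # like rolling_window=0.5
print(narrow.std() > wide.std())                  # wider window -> smoother curve
```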
@@ -355,33 +455,33 @@ for params in all_params:
tuning_results = pd.DataFrame(all_params)
tuning_results['rmse'] = rmses
print(tuning_results)
```

changepoint_prior_scale seasonality_prior_scale rmse
0 0.001 0.01 0.757694
1 0.001 0.10 0.743399
2 0.001 1.00 0.753387
3 0.001 10.00 0.762890
4 0.010 0.01 0.542315
5 0.010 0.10 0.535546
6 0.010 1.00 0.527008
7 0.010 10.00 0.541544
8 0.100 0.01 0.524835
9 0.100 0.10 0.516061
10 0.100 1.00 0.521406
11 0.100 10.00 0.518580
12 0.500 0.01 0.532140
13 0.500 0.10 0.524668
14 0.500 1.00 0.521130
15 0.500 10.00 0.522980

```python
# Python
best_params = all_params[np.argmin(rmses)]
print(best_params)
```
{'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.1}


Alternatively, the loop above could itself be parallelized, distributing work across parameter combinations rather than across cutoffs.
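A sketch of that alternative using only the standard library; `fit_and_score` here is a hypothetical stand-in for the body of the loop above (fit Prophet with the given params, cross-validate, return the RMSE):

```python
import itertools
from concurrent.futures import ProcessPoolExecutor

param_grid = {
    'changepoint_prior_scale': [0.001, 0.01, 0.1, 0.5],
    'seasonality_prior_scale': [0.01, 0.1, 1.0, 10.0],
}
all_params = [dict(zip(param_grid.keys(), v))
              for v in itertools.product(*param_grid.values())]

def fit_and_score(params):
    # Stand-in: in the real loop this fits Prophet(**params), runs
    # cross_validation, and returns the RMSE from performance_metrics.
    return sum(params.values())

if __name__ == '__main__':
    with ProcessPoolExecutor() as pool:
        rmses = list(pool.map(fit_and_score, all_params))
    print(len(rmses))  # one score per parameter combination
```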

Binary file removed docs/static/diagnostics_files/diagnostics_16_0.png
Binary file modified docs/static/diagnostics_files/diagnostics_17_0.png
Binary file modified docs/static/diagnostics_files/diagnostics_4_0.png