1010from tqdm .auto import tqdm
1111from copy import deepcopy
1212import concurrent .futures
13+ import multiprocessing
14+ import sys
1315
1416import numpy as np
1517import pandas as pd
1618
19+
1720logger = logging .getLogger ('prophet' )
1821
1922
@@ -107,7 +110,7 @@ def map(self, func, *iterables):
107110 for args in zip(*iterables)
108111 ]
109112 return results
110-
113+
111114 disable_tqdm: if True it disables the progress bar that would otherwise show up when parallel=None
112115 extra_output_columns: A String or List of Strings e.g. 'trend' or ['trend'].
113116 Additional columns to 'yhat' and 'ds' to be returned in output.
@@ -116,27 +119,27 @@ def map(self, func, *iterables):
116119 -------
117120 A pd.DataFrame with the forecast, actual value and cutoff.
118121 """
119-
122+
120123 if model .history is None :
121124 raise Exception ('Model has not been fit. Fitting the model provides contextual parameters for cross validation.' )
122-
125+
123126 df = model .history .copy ().reset_index (drop = True )
124127 horizon = pd .Timedelta (horizon )
125128 predict_columns = ['ds' , 'yhat' ]
126-
129+
127130 if model .uncertainty_samples :
128131 predict_columns .extend (['yhat_lower' , 'yhat_upper' ])
129132
130133 if extra_output_columns is not None :
131134 if isinstance (extra_output_columns , str ):
132135 extra_output_columns = [extra_output_columns ]
133136 predict_columns .extend ([c for c in extra_output_columns if c not in predict_columns ])
134-
137+
135138 # Identify the largest seasonality period
136139 period_max = 0.
137140 for s in model .seasonalities .values ():
138141 period_max = max (period_max , s ['period' ])
139- seasonality_dt = pd .Timedelta (str (period_max ) + ' days' )
142+ seasonality_dt = pd .Timedelta (str (period_max ) + ' days' )
140143
141144 if cutoffs is None :
142145 # Set period
@@ -152,15 +155,15 @@ def map(self, func, *iterables):
152155 cutoffs = generate_cutoffs (df , horizon , initial , period )
153156 else :
154157 # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history
155- if min (cutoffs ) <= df ['ds' ].min ():
158+ if min (cutoffs ) <= df ['ds' ].min ():
156159 raise ValueError ("Minimum cutoff value is not strictly greater than min date in history" )
157160 # max value of cutoffs is <= (end date minus horizon)
158- end_date_minus_horizon = df ['ds' ].max () - horizon
159- if max (cutoffs ) > end_date_minus_horizon :
161+ end_date_minus_horizon = df ['ds' ].max () - horizon
162+ if max (cutoffs ) > end_date_minus_horizon :
160163 raise ValueError ("Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining" )
161164 initial = cutoffs [0 ] - df ['ds' ].min ()
162-
163- # Check if the initial window
165+
166+ # Check if the initial window
164167 # (that is, the amount of time between the start of the history and the first cutoff)
165168 # is less than the maximum seasonality period
166169 if initial < seasonality_dt :
@@ -175,7 +178,11 @@ def map(self, func, *iterables):
175178 if parallel == "threads" :
176179 pool = concurrent .futures .ThreadPoolExecutor ()
177180 elif parallel == "processes" :
178- pool = concurrent .futures .ProcessPoolExecutor ()
181+ if sys .platform .startswith ("win" ) or sys .platform == "darwin" :
182+ ctx = multiprocessing .get_context ("spawn" )
183+ else :
184+ ctx = multiprocessing .get_context ("forkserver" )
185+ pool = concurrent .futures .ProcessPoolExecutor (mp_context = ctx )
179186 elif parallel == "dask" :
180187 try :
181188 from dask .distributed import get_client
@@ -204,7 +211,7 @@ def map(self, func, *iterables):
204211
205212 else :
206213 predicts = [
207- single_cutoff_forecast (df , model , cutoff , horizon , predict_columns )
214+ single_cutoff_forecast (df , model , cutoff , horizon , predict_columns )
208215 for cutoff in (tqdm (cutoffs ) if not disable_tqdm else cutoffs )
209216 ]
210217
@@ -334,7 +341,7 @@ def register_performance_metric(func):
334341 df: Cross-validation results dataframe.
335342 w: Aggregation window size.
336343
337- Registered metric should return following
344+ Registered metric should return the following
338345 -------
339346 Dataframe with columns horizon and metric.
340347 """
@@ -382,7 +389,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.1, monthly=False):
382389 use ['mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'coverage'].
383390 rolling_window: Proportion of data to use in each rolling window for
384391 computing the metrics. Should be in [0, 1] to average.
385- monthly: monthly=True will compute horizons as numbers of calendar months
392+ monthly: monthly=True will compute horizons as numbers of calendar months
386393 from the cutoff date, starting from 0 for the cutoff month.
387394
388395 Returns
@@ -477,7 +484,7 @@ def rolling_mean_by_h(x, h, w, name):
477484 res_x = res_x [(trailing_i + 1 ):]
478485
479486 return pd .DataFrame ({'horizon' : res_h , name : res_x })
480-
487+
481488
482489
483490def rolling_median_by_h (x , h , w , name ):
0 commit comments