2525 model_forecast ,
2626 validation_aggregation ,
2727 back_forecast ,
28- remove_leading_zeros
28+ remove_leading_zeros ,
2929)
3030from autots .models .ensemble import (
3131 EnsembleTemplateGenerator ,
@@ -1099,7 +1099,9 @@ def fit(
10991099 # give a more convenient dict option
11001100 self .best_model_name = self .best_model ['Model' ].iloc [0 ]
11011101 self .best_model_params = json .loads (self .best_model ['ModelParameters' ].iloc [0 ])
1102- self .best_model_transformation_params = json .loads (self .best_model ['TransformationParameters' ].iloc [0 ])
1102+ self .best_model_transformation_params = json .loads (
1103+ self .best_model ['TransformationParameters' ].iloc [0 ]
1104+ )
11031105
11041106 # set flags to check if regressors or ensemble used in final model.
11051107 param_dict = json .loads (self .best_model .iloc [0 ]['ModelParameters' ])
@@ -1334,7 +1336,9 @@ def export_template(
13341336 export_template = export_template .nsmallest (n , columns = ['Score' ])
13351337 if not include_results :
13361338 export_template = export_template [self .template_cols ]
1337- export_template = pd .concat ([self .best_model , export_template ]).drop_duplicates ()
1339+ export_template = pd .concat (
1340+ [self .best_model , export_template ]
1341+ ).drop_duplicates ()
13381342 else :
13391343 raise ValueError ("`models` must be 'all' or 'best'" )
13401344 try :
@@ -1453,7 +1457,9 @@ def import_results(self, filename):
14531457 self .initial_results = self .initial_results .concat (new_obj )
14541458 return self
14551459
1456- def back_forecast (self , column = None , n_splits : int = 3 , verbose : int = 0 ):
1460+ def back_forecast (
1461+ self , column = None , n_splits : int = 3 , tail : int = None , verbose : int = 0
1462+ ):
14571463 """Create forecasts for the historical training data, ie. backcast or back forecast.
14581464
14591465 This actually forecasts on historical data; these are not the fitted model values that other packages often return.
@@ -1463,6 +1469,7 @@ def back_forecast(self, column=None, n_splits: int = 3, verbose: int = 0):
14631469 Args are same as for model_forecast except...
14641470 n_splits (int or str): how many pieces to split the data into. Pass 2 for fastest, or "auto" for best accuracy.
14651471 column (str): to run on only one column, pass its column name. Faster than running on the full dataset.
1472+ tail (int): if not None, the input is trimmed with df.tail(tail) so back_forecast is only run on the `tail` most recent observations.
14661473
14671474 Returns a standard prediction object (access .forecast, .lower_forecast, .upper_forecast)
14681475 """
@@ -1472,18 +1479,24 @@ def back_forecast(self, column=None, n_splits: int = 3, verbose: int = 0):
14721479 input_df = pd .DataFrame (self .df_wide_numeric [column ])
14731480 else :
14741481 input_df = self .df_wide_numeric
1482+ if tail is not None :
1483+ input_df = input_df .tail (tail )
14751484 result = back_forecast (
14761485 df = input_df ,
14771486 model_name = self .best_model_name ,
14781487 model_param_dict = self .best_model_params ,
14791488 model_transform_dict = self .best_model_transformation_params ,
14801489 future_regressor_train = self .future_regressor_train ,
1481- n_splits = n_splits , forecast_length = self .forecast_length ,
1482- frequency = self .frequency , prediction_interval = self .prediction_interval ,
1490+ n_splits = n_splits ,
1491+ forecast_length = self .forecast_length ,
1492+ frequency = self .frequency ,
1493+ prediction_interval = self .prediction_interval ,
14831494 no_negatives = self .no_negatives ,
1484- constraint = self .constraint , holiday_country = self .holiday_country ,
1495+ constraint = self .constraint ,
1496+ holiday_country = self .holiday_country ,
14851497 random_seed = self .random_seed ,
1486- n_jobs = self .n_jobs , verbose = verbose ,
1498+ n_jobs = self .n_jobs ,
1499+ verbose = verbose ,
14871500 )
14881501 return result
14891502
@@ -1604,7 +1617,9 @@ def plot_generation_loss(self, **kwargs):
16041617 ylabel = "Lowest Score" , ** kwargs
16051618 )
16061619
1607- def plot_backforecast (self , series = None , n_splits : int = 3 , start_date = None , ** kwargs ):
1620+ def plot_backforecast (
1621+ self , series = None , n_splits : int = 3 , start_date = None , ** kwargs
1622+ ):
16081623 """Plot the historical data and fit forecast on historic.
16091624
16101625 Args:
@@ -1616,10 +1631,13 @@ def plot_backforecast(self, series=None, n_splits: int = 3, start_date=None, **k
16161631 series = random .choice (self .df_wide_numeric .columns )
16171632 b_df = self .back_forecast (column = series , n_splits = n_splits , verbose = 0 ).forecast
16181633 b_df = b_df .rename (columns = lambda x : str (x ) + "_forecast" )
1619- plot_df = pd .concat ([
1620- pd .DataFrame (self .df_wide_numeric [series ]),
1621- b_df ,
1622- ], axis = 1 )
1634+ plot_df = pd .concat (
1635+ [
1636+ pd .DataFrame (self .df_wide_numeric [series ]),
1637+ b_df ,
1638+ ],
1639+ axis = 1 ,
1640+ )
16231641 if start_date is not None :
16241642 plot_df = plot_df [plot_df .index >= start_date ]
16251643 plot_df = remove_leading_zeros (plot_df )
@@ -1667,6 +1685,8 @@ def plot_backforecast(self, series=None, n_splits: int = 3, start_date=None, **k
16671685 '#EE82EE' ,
16681686 '#00008B' ,
16691687 '#4B0082' ,
1688+ '#0403A7' ,
1689+ "#000000" ,
16701690]
16711691
16721692
0 commit comments