Skip to content
This repository was archived by the owner on Jul 15, 2024. It is now read-only.

Commit ff8f8f9

Browse files
committed
23.7.4.1
1 parent 4748dc8 commit ff8f8f9

File tree

4 files changed

+54
-27
lines changed

4 files changed

+54
-27
lines changed

example/train.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pandas as pd
22

3-
from lightwood.data.splitter import stratify
3+
from dataprep_ml.splitters import stratify
44
from lightwood.api.high_level import ProblemDefinition, predictor_from_code, json_ai_from_problem, code_from_json_ai
55

66

@@ -11,6 +11,7 @@
1111
train_df, _, _ = stratify(df, pct_train=0.8, pct_dev=0, pct_test=0.2, stratify_on=gby, seed=1, reshuffle=False)
1212

1313
pdef = ProblemDefinition.from_dict({'target': 'Traffic', # column to forecast
14+
'fit_on_all': False,
1415
'timeseries_settings': {
1516
'window': 10, # qty of previous data to use when predicting
1617
'horizon': 5, # forecast horizon length
@@ -22,11 +23,13 @@
2223
p_name = 'arrival_forecast_example'
2324
json_ai = json_ai_from_problem(train_df, problem_definition=pdef)
2425

25-
# specify a quick mixer for this example
26+
# specify a quick mixer configuration for this example
2627
json_ai.model['args']['submodels'] = [
2728
{
28-
"module": "ETSMixer",
29-
"args": {}
29+
"module": "SkTime",
30+
"args": {
31+
'model_path': '"theta.ThetaForecaster"',
32+
}
3033
}
3134
]
3235

@@ -39,4 +42,3 @@
3942
predictor.save(f'./{p_name}.pkl')
4043
with open(f'./{p_name}.py', 'wb') as fp:
4144
fp.write(predictor_class_code.encode('utf-8'))
42-

example/visualize.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
To visualize from a Jupyter notebook, refer to example/visualize.ipynb
55
"""
66
import pandas as pd
7-
from lightwood.data.splitter import stratify
7+
from dataprep_ml.splitters import stratify
88
from lightwood.api.high_level import predictor_from_state
99
from mindsdb_forecast_visualizer.core.dispatcher import forecast
1010

@@ -28,8 +28,8 @@
2828
seed=1,
2929
reshuffle=False)
3030

31-
# Specify series and plot
32-
subset = None # [{'Country': 'UK'}, {'Country': 'US'}] # None will plot all available series
31+
# Specify series and plot. `None` will plot all available series.
32+
subset = [{'Country': 'UK'}, {'Country': 'US'}, {'Country': 'Japan'}, {'Country': 'NZ'}]
3333

3434
forecast(
3535
predictor,

mindsdb_forecast_visualizer/__about__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__title__ = 'mindsdb_forecast_visualizer'
22
__package_name__ = 'mindsdb_forecast_visualizer'
3-
__version__ = '22.8.4.0'
3+
__version__ = '23.7.4.1'
44
__description__ = "Companion package to visualizer forecasts made with MindsDB predictors."
55
__email__ = "[email protected]"
66
__author__ = 'MindsDB Inc'

mindsdb_forecast_visualizer/core/forecaster.py

+43-18
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from copy import deepcopy
44
from itertools import product
55
from collections import OrderedDict
6+
import datetime
67

8+
import numpy as np
79
import pandas as pd
810
from mindsdb_forecast_visualizer.core.plotter import plot
911

10-
from lightwood.data.cleaner import _standardize_datetime
12+
from dataprep_ml.cleaners import _standardize_datetime
1113

1214

1315
def forecast(model,
@@ -33,7 +35,6 @@ def forecast(model,
3335

3436
if show_insample and len(backfill) == 0:
3537
raise Exception("You must pass a dataframe with the predictor's training data to show in-sample forecasts.")
36-
predargs['time_format'] = 'infer'
3738

3839
# instantiate series according to groups
3940
group_values = OrderedDict()
@@ -58,14 +59,20 @@ def forecast(model,
5859
if g == ():
5960
g = '__default'
6061
try:
61-
filtered_backfill, test_data = get_group(g, subset, data, backfill, group_keys, order)
62+
filtered_backfill, filtered_data = get_group(g, subset, data, backfill, group_keys, order)
6263

63-
if test_data.shape[0] > 0:
64+
if filtered_data.shape[0] > 0:
6465
print(f'Plotting for group {g}...')
65-
original_test_data = test_data
66-
test_data = test_data.iloc[[0]] # library only supports plotting first horizon inside test dataset
6766

68-
filtered_data = pd.concat([filtered_backfill.iloc[-warm_start_offset:], test_data])
67+
# check offset for warm start
68+
special_mixers = ['GluonTSMixer', 'NHitsMixer']
69+
if hasattr(model.ensemble, 'indexes_by_accuracy') and \
70+
(model.mixers[model.ensemble.indexes_by_accuracy[0]].__class__.__name__ in special_mixers):
71+
filtered_data = pd.concat([filtered_backfill.iloc[-warm_start_offset:], filtered_data.iloc[[0]]])
72+
else:
73+
filtered_data = pd.concat([filtered_backfill.iloc[-warm_start_offset:], filtered_data])
74+
75+
6976
if not tss.allow_incomplete_history:
7077
assert filtered_data.shape[0] > tss.window
7178

@@ -83,24 +90,26 @@ def forecast(model,
8390

8491
# forecast & divide into in-sample and out-sample predictions, if required
8592
if show_insample:
93+
offset = predargs.get('forecast_offset', 0)
8694
predargs['forecast_offset'] = -len(filtered_backfill)
8795
model_fit = model.predict(filtered_backfill, args=predargs)
96+
predargs['forecast_offset'] = offset
8897
else:
8998
model_fit = None
9099
if len(filtered_backfill) > 0:
91-
time_target += [t for t in filtered_backfill[tss.order_by]]
92100
pred_target += [None for _ in range(len(filtered_backfill))]
93101
conf_lower += [None for _ in range(len(filtered_backfill))]
94102
conf_upper += [None for _ in range(len(filtered_backfill))]
95103
anomalies += [None for _ in range(len(filtered_backfill))]
96104

97105
predargs['forecast_offset'] = -warm_start_offset
98106
model_forecast = model.predict(filtered_data, args=predargs).iloc[warm_start_offset:]
99-
real_target += [r for r in original_test_data[target]][:tss.horizon]
107+
filtered_data = filtered_data.iloc[warm_start_offset:]
108+
real_target += [float(r) for r in filtered_data[target]][:tss.horizon]
100109

101-
# edge case: convert one-step-ahead predictions to unitary lists
110+
# convert one-step-ahead predictions to unitary lists
102111
if not isinstance(model_forecast['prediction'].iloc[0], list):
103-
for k in ['prediction', 'lower', 'upper'] + [f'order_{i}' for i in tss.order_by]:
112+
for k in ['prediction', 'lower', 'upper'] + [f'order_{tss.order_by}']:
104113
model_forecast[k] = model_forecast[k].apply(lambda x: [x])
105114
if show_insample:
106115
model_fit[k] = model_fit[k].apply(lambda x: [x])
@@ -109,10 +118,11 @@ def forecast(model,
109118
pred_target += [p[0] for p in model_fit['prediction']]
110119
conf_lower += [p[0] for p in model_fit['lower']]
111120
conf_upper += [p[0] for p in model_fit['upper']]
121+
time_target += [p[0] for p in model_fit[f'order_{order}']]
112122
if 'anomaly' in model_fit.columns:
113123
anomalies += [p for p in model_fit['anomaly']]
114124

115-
# forecast always corresponds to predicted arrays for the first out-of-sample query data point
125+
# forecast corresponds to predicted arrays for the first out-of-sample query data point
116126
fcst = {
117127
'prediction': model_forecast['prediction'].iloc[0],
118128
'lower': model_forecast['lower'].iloc[0],
@@ -134,10 +144,23 @@ def forecast(model,
134144
pred_target += [p for p in fcst['prediction']]
135145
conf_lower += [p for p in fcst['lower']]
136146
conf_upper += [p for p in fcst['upper']]
137-
time_target += [r for r in original_test_data[tss.order_by]][:tss.horizon]
147+
148+
# fix timestamps
149+
time_target = [pd.to_datetime(p).timestamp() for p in filtered_data[order]]
150+
try:
151+
delta = model.ts_analysis['deltas'][g]
152+
except:
153+
delta = model.ts_analysis['deltas'].get(tuple([str(gg) for gg in g]),
154+
model.ts_analysis['deltas']['__default'])
155+
156+
for i in range(len(pred_target) - len(time_target)):
157+
time_target.insert(0, time_target[0] - delta)
158+
159+
160+
time_target = [datetime.datetime.utcfromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S') for ts in time_target]
138161

139162
# round confidences
140-
conf = model_forecast['confidence'].values.mean()
163+
conf = np.array([np.array(l) for l in model_forecast['confidence'].values]).mean()
141164

142165
# set titles and legends
143166
if g != ():
@@ -161,6 +184,8 @@ def forecast(model,
161184
anomalies=anomalies if show_anomaly else None,
162185
separate=separate)
163186
fig.show()
187+
else:
188+
print(f"No data for group {g}. Skipping...")
164189

165190
except Exception:
166191
print(f"Error in group {g}:")
@@ -173,11 +198,11 @@ def get_group(g, subset, data, backfill, group_keys, order):
173198
group_dict = {k: v for k, v in zip(group_keys, g)}
174199

175200
if subset is None or group_dict in subset:
176-
filtered_data = deepcopy(data)
177-
filtered_backfill = deepcopy(backfill)
201+
filtered_data = data
202+
filtered_backfill = backfill
178203
for k, v in group_dict.items():
179-
filtered_data = filtered_data[filtered_data[k] == v]
180-
filtered_backfill = filtered_backfill[filtered_backfill[k] == v]
204+
filtered_data = deepcopy(filtered_data[filtered_data[k] == v])
205+
filtered_backfill = deepcopy(filtered_backfill[filtered_backfill[k] == v])
181206

182207
filtered_data = filtered_data.drop_duplicates(subset=order)
183208
filtered_backfill = filtered_backfill.drop_duplicates(subset=order)

0 commit comments

Comments (0)