Commit 96585f5

feat: sktime auto adapter (#291)
* feat: auto adapt sktime models in tc agent init
* docs: update sktime example for auto-adapting

---------

Co-authored-by: spolisar <22416070+spolisar@users.noreply.github.com>
1 parent 8ae0d67 commit 96585f5

2 files changed: +110 -40 lines changed

docs/examples/sktime.ipynb

Lines changed: 79 additions & 40 deletions
@@ -33,16 +33,48 @@
     "import pandas as pd"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "d7bc0155",
+   "metadata": {},
+   "source": [
+    "## Setup the sktime model\n",
+    "\n",
+    "Sktime models can be passed in the `forecasters` argument when initializing the TimeCopilot agent, where they will be wrapped in an adapter with an alias based on the type name. \n",
+    "\n",
+    "If multiple sktime forecasters of the same type are passed, each model after the first will be wrapped in an adapter with an alias that has `'_n'` appended to it, with `n` being incremented by 1 for each additional occurrence of the same model type. \n",
+    "\n",
+    "For example, if you pass two `TrendForecaster` sktime models, the first one will have an alias of `'sktime.TrendForecaster'` and the second one will have an alias of `'sktime.TrendForecaster_2'`.\n",
+    "\n",
+    "If you would rather specify the alias yourself, you will need to adapt the model manually with `SKTimeAdapter`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "f4b60a78",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sktime.forecasting.trend import TrendForecaster\n",
+    "\n",
+    "trend_forecaster = TrendForecaster()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "694b7854",
    "metadata": {},
    "source": [
-    "## Setup the sktime model and adapt it to TimeCopilot\n",
+    "### Manually adapt sktime model\n",
     "\n",
-    "sktime models need to be adapted to work properly with TimeCopilot. This is done by creating your model with sktime and passing it through SKTimeAdapter. Some sktime models may require more configuration to function properly with the data you intend to use it on. For example, when using sktime's NaiveForecaster with yearly data you might want to initialize it with an `sp` argument of `12` like this `NaiveForecaster(sp=12)`.\n",
+    "If you would rather decide on the alias yourself, you will need to manually adapt the model with `SKTimeAdapter`.\n",
     "\n",
-    "The `Alias` argument should also be provided, especially if you plan on adding multiple sktime forecasters. If you add multiple sktime models without specifying aliases, TimeCopilot will not be able to properly call all of them."
+    "The `model` argument should be an sktime `Forecaster` model. The `alias` argument should be a string that uniquely identifies the model.\n",
+    "\n",
+    "After adapting the model you would pass it in the `forecasters` argument when initializing the TimeCopilot agent.\n",
+    "\n",
+    "If you add multiple manually adapted sktime models of the same type without specifying aliases, TimeCopilot may not be able to properly call all of them."
    ]
   },
   {
@@ -57,9 +89,16 @@
     "\n",
     "trend_forecaster = TrendForecaster()\n",
     "\n",
-    "adapted_model = SKTimeAdapter(\n",
+    "manually_adapted_model = SKTimeAdapter(\n",
     "    model=trend_forecaster,\n",
     "    alias=\"TrendForecaster\",\n",
+    ")\n",
+    "\n",
+    "tc = timecopilot.TimeCopilot(\n",
+    "    llm=\"openai:gpt-4o\",\n",
+    "    forecasters=[\n",
+    "        manually_adapted_model\n",
+    "    ]\n",
     ")"
    ]
   },
@@ -83,7 +122,7 @@
     "tc = timecopilot.TimeCopilot(\n",
     "    llm=\"openai:gpt-4o\",\n",
     "    forecasters=[\n",
-    "        adapted_model,\n",
+    "        trend_forecaster,\n",
     "    ],\n",
     ")"
    ]
@@ -93,7 +132,7 @@
    "id": "401d5b6f",
    "metadata": {},
    "source": [
-    "### Extending default model list with an sktime adapted model\n",
+    "### Extending default model list with an sktime model\n",
     "\n",
     "If you want to use the default list with the addition of your sktime model you could make a copy of the default list and append your model to it:"
    ]
@@ -106,7 +145,7 @@
    "outputs": [],
    "source": [
     "model_list = timecopilot.agent.DEFAULT_MODELS.copy()\n",
-    "model_list.append(adapted_model)\n",
+    "model_list.append(trend_forecaster)\n",
     "\n",
     "tc = timecopilot.TimeCopilot(\n",
     "    llm=\"openai:gpt-4o\",\n",
@@ -143,9 +182,9 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "1it [00:00, 4.70it/s]\n",
-      "1it [00:00, 223.32it/s]\n",
-      "11it [00:00, 77.11it/s]\n"
+      "1it [00:00, 4.52it/s]\n",
+      "1it [00:00, 220.94it/s]\n",
+      "11it [00:00, 83.97it/s]\n"
      ]
     }
    ],
@@ -157,15 +196,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 6,
    "id": "7355c143",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "The 'AirPassengers' time series has a series length of 144 with a clear seasonal pattern identified using key features. The high 'seasonal_strength' of 0.981 suggests strong seasonality, evident from the 12-month seasonal period. The time series also exhibits trends, shown by a 'trend' score of 0.997, and moderate curvature at 1.069. The high autocorrelation 'x_acf1' at 0.948 indicates the persistence of patterns over time. The Holt-Winters parameters suggest a stable level (alpha ~1) with no trend component (beta ~0) and significant seasonal smoothing (gamma ~0.75). These features suggest that both trend and seasonality are prominent and need to be captured by the model.\n"
+      "The 'AirPassengers' data exhibits a strong trend with a stability of 0.933, indicating a consistent pattern over time. It also has a very strong seasonal component (seasonal_strength 0.982) with a seasonal period of 12 months, typical for annual data. The high x_acf1 (0.948) indicates strong autocorrelation, implying that past values are highly predictive of future values, suggesting that a time series model with a trend and seasonal structure would be appropriate. The entropy of the series is relatively low (0.429), indicating a high level of predictability, which substantiates the use of deterministic trend models. These features warrant the use of models that incorporate seasonality and trends, such as Holt's linear trend method or even seasonal naive models.\n"
      ]
     }
    ],
@@ -175,7 +214,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "id": "86acfa60",
    "metadata": {},
    "outputs": [
@@ -202,7 +241,7 @@
        " <th></th>\n",
        " <th>unique_id</th>\n",
        " <th>ds</th>\n",
-       " <th>TrendForecaster</th>\n",
+       " <th>sktime.TrendForecaster</th>\n",
        " </tr>\n",
        " </thead>\n",
        " <tbody>\n",
@@ -355,34 +394,34 @@
        "</div>"
       ],
       "text/plain": [
-       " unique_id ds TrendForecaster\n",
-       "0 AirPassengers 1961-01-01 473.023018\n",
-       "1 AirPassengers 1961-02-01 475.729097\n",
-       "2 AirPassengers 1961-03-01 478.173296\n",
-       "3 AirPassengers 1961-04-01 480.879374\n",
-       "4 AirPassengers 1961-05-01 483.498159\n",
-       "5 AirPassengers 1961-06-01 486.204237\n",
-       "6 AirPassengers 1961-07-01 488.823023\n",
-       "7 AirPassengers 1961-08-01 491.529101\n",
-       "8 AirPassengers 1961-09-01 494.235179\n",
-       "9 AirPassengers 1961-10-01 496.853964\n",
-       "10 AirPassengers 1961-11-01 499.560042\n",
-       "11 AirPassengers 1961-12-01 502.178827\n",
-       "12 AirPassengers 1962-01-01 504.884906\n",
-       "13 AirPassengers 1962-02-01 507.590984\n",
-       "14 AirPassengers 1962-03-01 510.035183\n",
-       "15 AirPassengers 1962-04-01 512.741261\n",
-       "16 AirPassengers 1962-05-01 515.360046\n",
-       "17 AirPassengers 1962-06-01 518.066125\n",
-       "18 AirPassengers 1962-07-01 520.684910\n",
-       "19 AirPassengers 1962-08-01 523.390988\n",
-       "20 AirPassengers 1962-09-01 526.097066\n",
-       "21 AirPassengers 1962-10-01 528.715851\n",
-       "22 AirPassengers 1962-11-01 531.421929\n",
-       "23 AirPassengers 1962-12-01 534.040714"
+       " unique_id ds sktime.TrendForecaster\n",
+       "0 AirPassengers 1961-01-01 473.023018\n",
+       "1 AirPassengers 1961-02-01 475.729097\n",
+       "2 AirPassengers 1961-03-01 478.173296\n",
+       "3 AirPassengers 1961-04-01 480.879374\n",
+       "4 AirPassengers 1961-05-01 483.498159\n",
+       "5 AirPassengers 1961-06-01 486.204237\n",
+       "6 AirPassengers 1961-07-01 488.823023\n",
+       "7 AirPassengers 1961-08-01 491.529101\n",
+       "8 AirPassengers 1961-09-01 494.235179\n",
+       "9 AirPassengers 1961-10-01 496.853964\n",
+       "10 AirPassengers 1961-11-01 499.560042\n",
+       "11 AirPassengers 1961-12-01 502.178827\n",
+       "12 AirPassengers 1962-01-01 504.884906\n",
+       "13 AirPassengers 1962-02-01 507.590984\n",
+       "14 AirPassengers 1962-03-01 510.035183\n",
+       "15 AirPassengers 1962-04-01 512.741261\n",
+       "16 AirPassengers 1962-05-01 515.360046\n",
+       "17 AirPassengers 1962-06-01 518.066125\n",
+       "18 AirPassengers 1962-07-01 520.684910\n",
+       "19 AirPassengers 1962-08-01 523.390988\n",
+       "20 AirPassengers 1962-09-01 526.097066\n",
+       "21 AirPassengers 1962-10-01 528.715851\n",
+       "22 AirPassengers 1962-11-01 531.421929\n",
+       "23 AirPassengers 1962-12-01 534.040714"
       ]
      },
-     "execution_count": 8,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
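
The notebook text in this diff warns that manually adapted models without distinct aliases may not all be callable. A minimal sketch of why, assuming the agent stores forecasters in a dict keyed by `alias` (the `agent.py` change in this commit shows such a dict comprehension). `FakeForecaster` is a hypothetical stand-in used only for illustration, not a TimeCopilot class:

```python
from dataclasses import dataclass


@dataclass
class FakeForecaster:
    """Hypothetical stand-in that exposes only an `alias` attribute."""

    alias: str
    note: str


forecasters = [
    FakeForecaster(alias="TrendForecaster", note="first"),
    FakeForecaster(alias="TrendForecaster", note="second"),
    FakeForecaster(alias="NaiveForecaster", note="third"),
]

# Same shape as the dict comprehension in TimeCopilot.__init__:
# a later entry with a repeated alias replaces the earlier one.
by_alias = {f.alias: f for f in forecasters}

print(len(by_alias))                     # 2
print(by_alias["TrendForecaster"].note)  # second
```

Only two of the three stand-ins survive, which is why each manually adapted model needs its own alias.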

timecopilot/agent.py

Lines changed: 31 additions & 0 deletions
@@ -33,6 +33,7 @@
 from tsfeatures.tsfeatures import _get_feats
 
 from .forecaster import Forecaster, TimeCopilotForecaster
+from .models.adapters.sktime import SKTimeAdapter
 from .models.prophet import Prophet
 from .models.stats import (
     ADIDA,
@@ -388,6 +389,18 @@ def _transform_anomalies_to_text(anomalies_df: pd.DataFrame) -> str:
     return output
 
 
+def _is_sktime_forecaster(obj: object) -> bool:
+    """
+    Helper function for checking if an object is an sktime model by checking if
+    sktime's BaseForecaster class is in its inheritance tree.
+    """
+    mro_types = type(obj).__mro__
+    for t in mro_types:
+        if t.__name__ == "BaseForecaster" and "sktime" in t.__module__:
+            return True
+    return False
+
+
 class TimeCopilot:
     """
     TimeCopilot: An AI agent for comprehensive time series analysis.
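
The helper added in the hunk above detects sktime models by class name and module path rather than with `isinstance`. A minimal standalone sketch of that check follows. The three classes are hypothetical stand-ins (the base class only pretends to live in an sktime module), so the snippet runs without sktime installed:

```python
def is_sktime_forecaster(obj: object) -> bool:
    """Name-based check mirroring the helper in the hunk above."""
    return any(
        t.__name__ == "BaseForecaster" and "sktime" in t.__module__
        for t in type(obj).__mro__
    )


# Hypothetical stand-ins for illustration only.
class BaseForecaster:
    # Pretend this class was defined inside an sktime module.
    __module__ = "sktime.forecasting.base"


class MyTrend(BaseForecaster):
    pass


class Unrelated:
    pass


print(is_sktime_forecaster(MyTrend()))    # True
print(is_sktime_forecaster(Unrelated()))  # False
```

One consequence of matching on names: any class called `BaseForecaster` whose module path contains `sktime` satisfies the check, whether or not it is sktime's own base class.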
@@ -421,6 +434,24 @@ def __init__(
 
         if forecasters is None:
             forecasters = DEFAULT_MODELS
+        combined_forecasters = []
+        sktime_forecasters = []
+        for f in forecasters:
+            if _is_sktime_forecaster(f):
+                sktime_forecasters.append(f)
+            else:
+                combined_forecasters.append(f)
+        type_counts: dict[str, int] = {}
+        for f in sktime_forecasters:
+            alias = "sktime." + type(f).__name__
+            if type(f).__name__ in type_counts:
+                type_counts[type(f).__name__] += 1
+                alias += f"_{type_counts[type(f).__name__]}"
+            else:
+                type_counts[type(f).__name__] = 1
+            adapted = SKTimeAdapter(f, alias=alias)
+            combined_forecasters.append(adapted)
+        forecasters = combined_forecasters
         self.forecasters = {forecaster.alias: forecaster for forecaster in forecasters}
         if "SeasonalNaive" not in self.forecasters:
             self.forecasters["SeasonalNaive"] = SeasonalNaive()
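
The alias numbering in the hunk above can be sketched in isolation. The function below is a hypothetical helper, not part of TimeCopilot. It takes plain type names instead of model objects so it runs without sktime, and it follows the same counting rule: no suffix for the first occurrence, then `_2`, `_3`, and so on.

```python
def auto_aliases(type_names: list[str]) -> list[str]:
    """Return 'sktime.<Name>' per entry, adding '_2', '_3', ... for repeats."""
    counts: dict[str, int] = {}
    aliases = []
    for name in type_names:
        alias = "sktime." + name
        if name in counts:
            # Second and later occurrences get a numeric suffix.
            counts[name] += 1
            alias += f"_{counts[name]}"
        else:
            counts[name] = 1
        aliases.append(alias)
    return aliases


names = ["TrendForecaster", "NaiveForecaster", "TrendForecaster", "TrendForecaster"]
print(auto_aliases(names))
# ['sktime.TrendForecaster', 'sktime.NaiveForecaster',
#  'sktime.TrendForecaster_2', 'sktime.TrendForecaster_3']
```

Note that in the hunk itself the adapted sktime models are appended after all non-sktime forecasters, so when the two kinds are interleaved in the input the resulting order differs from the order passed in.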
