Skip to content

Commit ff4b9d3

Browse files
committed
Add to_series to official Timeseries contract, change as_series to to_series
Represent results as actual Timeseries object rather than lists Make Backtest preserve the backend string so that it can get different structures for that backend rather than only the PastView
1 parent b81aace commit ff4b9d3

File tree

13 files changed

+154
-72
lines changed

13 files changed

+154
-72
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def aapl_momentum_with_liquidity(
109109

110110
# liquidity filter: average recent volume
111111
if aapl_volume is not None and len(aapl_volume) >= vol_window:
112-
avg_vol = aapl_volume[-vol_window:].as_series().mean()
112+
avg_vol = aapl_volume[-vol_window:].mean()
113113
vol_ok = avg_vol is not None and avg_vol > 0
114114
else:
115115
avg_vol = None

src/backtest_lib/backtest/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class Backtest:
121121
_current_portfolio: Portfolio
122122
settings: BacktestSettings
123123
_schedule: DecisionSchedule
124-
_backend: type[PastView]
124+
_backend: str
125125
_engine: Engine
126126

127127
def __init__(
@@ -155,7 +155,7 @@ def __init__(
155155
self._schedule = make_decision_schedule(market_view.periods)
156156
else:
157157
self._schedule = decision_schedule
158-
self._backend = get_pastview_type_from_backend(backend)
158+
self._backend = backend
159159

160160
self._engine = engine or make_engine(
161161
PerfectWorldPlanGenerator(),
@@ -234,7 +234,9 @@ def run(self, ctx: StrategyContext | None = None) -> BacktestResults:
234234
)
235235
break
236236

237-
allocation_history: PastView = self._backend.from_security_mappings(
237+
allocation_history: PastView = get_pastview_type_from_backend(
238+
self._backend
239+
).from_security_mappings(
238240
output_holdings,
239241
self.market_view.periods[:i],
240242
)

src/backtest_lib/backtest/results.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
from functools import cached_property
77
from typing import TYPE_CHECKING, Any
88

9+
from backtest_lib.market import (
10+
get_pastview_type_from_backend,
11+
get_timeseries_type_from_backend,
12+
)
13+
from backtest_lib.market.timeseries import Timeseries
14+
915
if TYPE_CHECKING:
1016
from backtest_lib.market import MarketView, PastView
1117
from backtest_lib.market.timeseries import Comparable
@@ -28,12 +34,12 @@ class BacktestResults[IndexT: Comparable]:
2834
asset_returns: PastView[float, IndexT] = field(repr=False)
2935
initial_capital: float
3036

31-
portfolio_returns: list[float]
32-
nav: list[float]
33-
drawdowns: list[float]
34-
gross_exposure: list[float]
35-
net_exposure: list[float]
36-
turnover: list[float]
37+
portfolio_returns: Timeseries[float, IndexT]
38+
nav: Timeseries[float, IndexT]
39+
drawdowns: Timeseries[float, IndexT]
40+
gross_exposure: Timeseries[float, IndexT]
41+
net_exposure: Timeseries[float, IndexT]
42+
turnover: Timeseries[float, IndexT]
3743

3844
total_return: float
3945
annualized_return: float
@@ -43,7 +49,9 @@ class BacktestResults[IndexT: Comparable]:
4349
avg_turnover: float
4450

4551
market: MarketView[IndexT]
46-
_backend: type[PastView]
52+
_backend: str
53+
_backend_pastview_type: type[PastView]
54+
_backend_timeseries_type: type[Timeseries]
4755

4856
@staticmethod
4957
def from_weights_and_returns(
@@ -54,7 +62,7 @@ def from_weights_and_returns(
5462
initial_capital: float = 1.0,
5563
periods_per_year: float = 252.0,
5664
risk_free_annual: float | None = None,
57-
backend: type[PastView],
65+
backend: str = "polars",
5866
) -> BacktestResults[Any]:
5967
"""
6068
Build results from pre-computed per-security simple returns.
@@ -158,18 +166,20 @@ def _std(xs: list[float]) -> float:
158166
)
159167
sharpe = (annualized_return - risk_free_annual) / annualized_volatility
160168

169+
timeseries_type = get_timeseries_type_from_backend(backend)
170+
161171
return BacktestResults(
162172
periods=periods,
163173
securities=securities,
164174
weights=weights,
165175
asset_returns=returns,
166176
initial_capital=initial_capital,
167-
portfolio_returns=portfolio_returns,
168-
nav=nav,
169-
drawdowns=drawdowns,
170-
gross_exposure=gross_exposure,
171-
net_exposure=net_exposure,
172-
turnover=turnover,
177+
portfolio_returns=timeseries_type.from_vectors(portfolio_returns, periods),
178+
nav=timeseries_type.from_vectors(nav, periods),
179+
drawdowns=timeseries_type.from_vectors(drawdowns, periods),
180+
gross_exposure=timeseries_type.from_vectors(gross_exposure, periods),
181+
net_exposure=timeseries_type.from_vectors(net_exposure, periods),
182+
turnover=timeseries_type.from_vectors(turnover, periods),
173183
total_return=total_return,
174184
annualized_return=annualized_return,
175185
annualized_volatility=annualized_volatility,
@@ -178,6 +188,8 @@ def _std(xs: list[float]) -> float:
178188
avg_turnover=avg_turnover,
179189
market=market,
180190
_backend=backend,
191+
_backend_pastview_type=get_pastview_type_from_backend(backend),
192+
_backend_timeseries_type=timeseries_type,
181193
)
182194

183195
@staticmethod
@@ -188,7 +200,7 @@ def from_weights_market_initial_capital(
188200
*,
189201
periods_per_year: float = 252.0,
190202
risk_free_annual: float | None = 0.02,
191-
backend: type[PastView],
203+
backend: str,
192204
) -> BacktestResults[Any]:
193205
"""
194206
Convenience constructor that derives per-security returns from
@@ -244,7 +256,9 @@ def from_weights_market_initial_capital(
244256
.collect()
245257
)
246258

247-
asset_returns: PastView = backend.from_dataframe(asset_returns_df)
259+
backend_pastview_type = get_pastview_type_from_backend(backend)
260+
261+
asset_returns: PastView = backend_pastview_type.from_dataframe(asset_returns_df)
248262

249263
results = BacktestResults.from_weights_and_returns(
250264
weights=weights,
@@ -273,17 +287,16 @@ def quantities_held(self) -> PastView[float, IndexT]:
273287
)
274288
if dtype.is_numeric()
275289
]
276-
nav_series = pl.Series(self.nav)
277290

278291
qtys = joined.select(
279292
"date",
280293
*[
281-
(pl.col(c) * nav_series / pl.col(f"{c}_p")).alias(c)
294+
(pl.col(c) * self.nav.to_series() / pl.col(f"{c}_p")).alias(c)
282295
for c in numeric_cols
283296
],
284297
)
285298

286-
return self._backend.from_dataframe(qtys.collect())
299+
return self._backend_pastview_type.from_dataframe(qtys.collect())
287300

288301
@cached_property
289302
def values_held(self) -> PastView[float, IndexT]:
@@ -297,11 +310,10 @@ def values_held(self) -> PastView[float, IndexT]:
297310
)
298311
if dtype.is_numeric()
299312
]
300-
nav_series = pl.Series(self.nav)
301313

302314
values = weights.select(
303315
"date",
304-
*[(pl.col(c) * nav_series).alias(c) for c in numeric_cols],
316+
*[(pl.col(c) * self.nav.to_series()).alias(c) for c in numeric_cols],
305317
)
306318

307-
return self._backend.from_dataframe(values.collect())
319+
return self._backend_pastview_type.from_dataframe(values.collect())

src/backtest_lib/examples/ewma.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@
150150
"metadata": {},
151151
"outputs": [],
152152
"source": [
153-
"aapl_prices = past_cost_prices.by_security[\"AAPL\"].as_series()\n",
153+
"aapl_prices = past_cost_prices.by_security[\"AAPL\"].to_series()\n",
154154
"\n",
155155
"print(aapl_prices.pct_change() + 1)\n",
156156
"print((aapl_prices.pct_change() + 1).cum_prod())"

src/backtest_lib/examples/full_sp500_test.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@
199199
"final_allocations = pl.DataFrame(\n",
200200
" data={\n",
201201
" \"security\": final_allocations_series.names,\n",
202-
" \"value\": final_allocations_series.as_series(),\n",
202+
" \"value\": final_allocations_series.to_series(),\n",
203203
" },\n",
204204
").filter(pl.col(\"value\") > 0)\n",
205205
"\n",
@@ -225,7 +225,7 @@
225225
"outputs": [],
226226
"source": [
227227
"nvda_pnl = list(\n",
228-
" (results.asset_returns.by_security[\"NVDA\"].as_series() + 1)\n",
228+
" (results.asset_returns.by_security[\"NVDA\"].to_series() + 1)\n",
229229
" .cum_prod()\n",
230230
" .rolling_mean(10, min_samples=10),\n",
231231
")\n",

src/backtest_lib/examples/sp500.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@
298298
"metadata": {},
299299
"outputs": [],
300300
"source": [
301-
"neg_series = bt._current_portfolio.holdings.as_series()\n",
301+
"neg_series = bt._current_portfolio.holdings.to_series()\n",
302302
"neg_mass = neg_series.clip(upper_bound=0).sum()\n",
303303
"pos_mass = neg_series.clip(lower_bound=0).sum()\n",
304304
"pos_only_series = neg_series.clip(lower_bound=0)\n",
@@ -381,8 +381,8 @@
381381
"metadata": {},
382382
"outputs": [],
383383
"source": [
384-
"results.holdings.by_security[\"AAPL\"].as_series().min()\n",
385-
"results.weights.by_security[\"AAPL\"].as_series().min()"
384+
"results.holdings.by_security[\"AAPL\"].to_series().min()\n",
385+
"results.weights.by_security[\"AAPL\"].to_series().min()"
386386
]
387387
},
388388
{

src/backtest_lib/market/polars_impl/_axis.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@ def from_series(date_s: pl.Series, fmt: str = "%Y-%m-%d") -> PeriodAxis:
3939
date_s = date_s.cast(pl.Datetime("us"))
4040
labels = tuple(date_s.dt.strftime(fmt).to_list())
4141
dt64 = date_s.to_numpy().astype("datetime64[us]", copy=False)
42-
# assert np.all(dt64[1:] >= dt64[:-1])
4342
return PeriodAxis(dt64, labels, {lbl: i for i, lbl in enumerate(labels)})
4443

45-
def take(self, idxs: Sequence[int] | NDArray[np.int64]) -> PeriodAxis:
44+
def take(self, idxs: Sequence[int] | NDArray[np.integer]) -> PeriodAxis:
4645
"""
4746
Creates a new PeriodAxis from a sequence of
4847
integer indices contained in the period axis.

src/backtest_lib/market/polars_impl/_plotting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def line(
6060
class PolarsUniverseMappingPlotAccessor(UniverseMappingPlotAccessor):
6161
def __init__(self, obj: PolarsUniverseMapping):
6262
self._obj = obj
63-
self._series = obj.as_series()
63+
self._series = obj.to_series()
6464

6565
def __call__(
6666
self,
@@ -165,7 +165,7 @@ def stacked_bar(
165165
class PolarsTimeseriesPlotAccessor(TimeseriesPlotAccessor):
166166
def __init__(self, obj: PolarsTimeseries):
167167
self._obj = obj
168-
self._series = obj.as_series()
168+
self._series = obj.to_series()
169169

170170
def __call__(self, **kwargs) -> alt.Chart:
171171
return self.line(**kwargs)

0 commit comments

Comments
 (0)