33import datetime as dt
44import logging
55import warnings
6- from typing import TYPE_CHECKING
6+ from typing import TYPE_CHECKING , assert_never , overload
77
88from backtest_lib .backtest ._helpers import _to_pydt
99from backtest_lib .backtest .results import BacktestResults
1515 PerfectWorldPlanGenerator ,
1616)
1717from backtest_lib .market import MarketView , get_pastview_type_from_backend
18- from backtest_lib .portfolio import CashPortfolio , Portfolio , WeightedPortfolio
18+ from backtest_lib .portfolio import (
19+ Cash ,
20+ FractionalQuantityPortfolio ,
21+ Portfolio ,
22+ QuantityPortfolio ,
23+ WeightedPortfolio ,
24+ )
1925from backtest_lib .strategy import Strategy
2026from backtest_lib .strategy .context import StrategyContext
2127
@@ -105,14 +111,14 @@ class Backtest:
105111 >>> bt = btl.Backtest(hold_strategy, market, uniform_portfolio(universe))
106112 >>> results = bt.run()
107113 >>> results.annualized_return
108- -0.00553564 ...
114+ -0.00035649 ...
109115 """
110116
111117 strategy : Strategy
112118 universe : Universe
113119 market_view : MarketView
114- initial_portfolio : WeightedPortfolio
115- _current_portfolio : WeightedPortfolio
120+ initial_portfolio : Portfolio
121+ _current_portfolio : Portfolio
116122 settings : BacktestSettings
117123 _schedule : DecisionSchedule
118124 _backend : type [PastView ]
@@ -122,7 +128,7 @@ def __init__(
122128 self ,
123129 strategy : Strategy ,
124130 market_view : MarketView ,
125- initial_portfolio : WeightedPortfolio ,
131+ initial_portfolio : Portfolio | Cash ,
126132 universe : Universe | None = None ,
127133 settings : BacktestSettings = _DEFAULT_BACKTEST_SETTINGS ,
128134 * ,
@@ -133,7 +139,7 @@ def __init__(
133139 self .strategy = strategy
134140 self .universe = universe or market_view .securities
135141 self .market_view = market_view
136- if isinstance (initial_portfolio , CashPortfolio ):
142+ if isinstance (initial_portfolio , Cash ):
137143 initial_portfolio = initial_portfolio .materialize (
138144 universe = self .universe , backend = backend
139145 )
@@ -160,6 +166,7 @@ def __init__(
160166 def run (self , ctx : StrategyContext | None = None ) -> BacktestResults :
161167 schedule_it = iter (self ._schedule )
162168 next_decision_period = next (schedule_it )
169+ advance_schedule = False
163170 if ctx is None :
164171 ctx = StrategyContext ()
165172 output_holdings : list [VectorMapping [str , float ]] = []
@@ -176,15 +183,7 @@ def run(self, ctx: StrategyContext | None = None) -> BacktestResults:
176183 f" { self ._current_portfolio .total_value } " ,
177184 )
178185 if ctx .now >= _to_pydt (next_decision_period ):
179- try :
180- next_decision_period = next (schedule_it )
181- except StopIteration :
182- logger .debug (
183- "Reached end of decision schedule, breaking from backtest loop"
184- f" at { ctx .now } (period { i } )." ,
185- )
186- break
187-
186+ advance_schedule = True
188187 # NOTE: we are using close prices here. this is an implicit assumption.
189188 # the user may want to use (low+high)/2, mid price, VWAP/TWAP.
190189 result = self ._engine .execute_strategy (
@@ -195,15 +194,7 @@ def run(self, ctx: StrategyContext | None = None) -> BacktestResults:
195194 prices = yesterday_prices ,
196195 )
197196
198- # problem: we convert into quantities a lot even when we may not have
199- # to. i.e when the user is using a PerfectWorldGenerator/Executor and
200- # only wants to pass target weights.
201- # idea: allow a decorator or something on the user's strategy to say:
202- # @weight_based or @quantity_based or something else so we can optimize
203- # our backtesting behaviour due to not neeeding constant conversions.
204- portfolio_after_decision = result .after .into_quantities (
205- yesterday_prices
206- )
197+ portfolio_after_decision = result .after
207198 logger .debug (
208199 f"engine output for { ctx .now } : { result .after .holdings } , "
209200 f"cash: { result .after .cash } "
@@ -216,9 +207,7 @@ def run(self, ctx: StrategyContext | None = None) -> BacktestResults:
216207 ctx .now ,
217208 )
218209 else :
219- portfolio_after_decision = self ._current_portfolio .into_quantities (
220- yesterday_prices
221- )
210+ portfolio_after_decision = self ._current_portfolio
222211 # TODO: the weights can be calculated as part of the results calculation,
223212 pf_as_weights = portfolio_after_decision .into_weighted (
224213 prices = yesterday_prices
@@ -234,43 +223,93 @@ def run(self, ctx: StrategyContext | None = None) -> BacktestResults:
234223
235224 self ._current_portfolio = inter_day_adjusted_portfolio
236225 yesterday_prices = today_prices
226+ if advance_schedule :
227+ try :
228+ next_decision_period = next (schedule_it )
229+ advance_schedule = False
230+ except StopIteration :
231+ logger .debug (
232+ "Reached end of decision schedule, breaking from backtest loop"
233+ f" at { ctx .now } (period { i } )." ,
234+ )
235+ break
237236
238237 allocation_history : PastView = self ._backend .from_security_mappings (
239238 output_holdings ,
240- self .market_view .periods [: i - 1 ],
239+ self .market_view .periods [:i ],
241240 )
242241 results = BacktestResults .from_weights_market_initial_capital (
243242 weights = allocation_history ,
244- market = self .market_view .truncated_to (i - 1 ),
243+ market = self .market_view .truncated_to (i ),
245244 backend = self ._backend ,
246245 )
247246 return results
248247
249248
249+ @overload
250250def _apply_inter_period_price_changes (
251251 portfolio : WeightedPortfolio ,
252252 pct_change : UniverseMapping [float ],
253- ) -> WeightedPortfolio :
254- prev_cash = portfolio .cash
255- prev_hold = portfolio .holdings
256- # logger.debug(
257- # f"Holdings length: {len(prev_hold)}, pct_change length: {len(pct_change)}, "
258- # f"hold: {prev_hold}, pct_change: {pct_change}"
259- # )
260-
261- new_total_holdings_weight = prev_hold * pct_change
262- new_total_weight = prev_cash + new_total_holdings_weight .sum ()
263-
264- new_cash = prev_cash / new_total_weight
265- new_holdings = new_total_holdings_weight / new_total_weight
266-
267- return WeightedPortfolio (
268- cash = new_cash ,
269- holdings = new_holdings ,
270- universe = new_holdings .keys (), # brittle, review this
271- total_value = portfolio .total_value * new_total_weight ,
272- constructor_backend = portfolio ._backend ,
273- )
253+ ) -> WeightedPortfolio : ...
254+
255+
256+ @overload
257+ def _apply_inter_period_price_changes (
258+ portfolio : QuantityPortfolio ,
259+ pct_change : UniverseMapping [float ],
260+ ) -> QuantityPortfolio : ...
261+
262+
263+ @overload
264+ def _apply_inter_period_price_changes (
265+ portfolio : FractionalQuantityPortfolio ,
266+ pct_change : UniverseMapping [float ],
267+ ) -> FractionalQuantityPortfolio : ...
268+
269+
270+ def _apply_inter_period_price_changes (
271+ portfolio : Portfolio ,
272+ pct_change : UniverseMapping [float ],
273+ ) -> Portfolio :
274+ if isinstance (portfolio , WeightedPortfolio ):
275+ prev_cash = portfolio .cash
276+ prev_hold = portfolio .holdings
277+
278+ new_total_holdings_weight = prev_hold * pct_change
279+ new_total_weight = prev_cash + new_total_holdings_weight .sum ()
280+
281+ new_cash = prev_cash / new_total_weight
282+ new_holdings = new_total_holdings_weight / new_total_weight
283+
284+ return WeightedPortfolio (
285+ cash = new_cash ,
286+ holdings = new_holdings ,
287+ universe = new_holdings .keys (), # brittle, review this
288+ total_value = portfolio .total_value * new_total_weight ,
289+ constructor_backend = portfolio ._backend ,
290+ )
291+ elif isinstance (portfolio , QuantityPortfolio ):
292+ value_changes = portfolio .holdings * pct_change
293+ new_total_value = value_changes .sum ()
294+ return QuantityPortfolio (
295+ universe = portfolio .universe ,
296+ holdings = portfolio .holdings ,
297+ cash = portfolio .cash ,
298+ total_value = new_total_value ,
299+ constructor_backend = portfolio ._backend ,
300+ )
301+ elif isinstance (portfolio , FractionalQuantityPortfolio ):
302+ value_changes = portfolio .holdings * pct_change
303+ new_total_value = value_changes .sum ()
304+ return FractionalQuantityPortfolio (
305+ universe = portfolio .universe ,
306+ holdings = portfolio .holdings ,
307+ cash = portfolio .cash ,
308+ total_value = new_total_value ,
309+ constructor_backend = portfolio ._backend ,
310+ )
311+ else :
312+ assert_never (portfolio )
274313
275314
276315def _check_tradable (
0 commit comments