Skip to content

Commit 70b0783

Browse files
committed
Adds long-only and short-only config options.
1 parent e00644d commit 70b0783

File tree

5 files changed

+56
-4
lines changed

5 files changed

+56
-4
lines changed

src/pybroker/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,20 @@ class FeeInfo(NamedTuple):
173173
order_type: Literal["buy", "sell"]
174174

175175

176+
class PositionMode(Enum):
177+
"""Position mode for backtesting.
178+
179+
Attributes:
180+
DEFAULT: Long and short positions.
181+
LONG_ONLY: Long-only positions.
182+
SHORT_ONLY: Short-only positions.
183+
"""
184+
185+
DEFAULT = "default"
186+
LONG_ONLY = "long_only"
187+
SHORT_ONLY = "short_only"
188+
189+
176190
class BarData:
177191
r"""Contains data for a series of bars. Each field is a
178192
:class:`numpy.ndarray` that contains bar values in the series. The values

src/pybroker/config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
(see LICENSE for details).
77
"""
88

9-
from pybroker.common import BarData, FeeInfo, FeeMode, PriceType
9+
from pybroker.common import BarData, FeeInfo, FeeMode, PositionMode, PriceType
1010
from dataclasses import dataclass, field
1111
from decimal import Decimal
1212
from typing import Callable, Optional, Union
@@ -35,6 +35,12 @@ class StrategyConfig:
3535
Set to ``True`` for crypto trading. Defaults to ``False``.
3636
round_fill_price: Whether to round fill prices to the nearest cent.
3737
Defaults to ``True``.
38+
position_mode: Position mode for :class:`pybroker.strategy.Strategy`.
39+
Supports one of:
40+
41+
- ``DEFAULT``: Long and short positions.
42+
- ``LONG_ONLY``: Long-only positions.
43+
- ``SHORT_ONLY``: Short-only positions.
3844
max_long_positions: Maximum number of long positions that can be held
3945
at any time in :class:`pybroker.portfolio.Portfolio`. Unlimited
4046
when ``None``. Defaults to ``None``.
@@ -81,6 +87,7 @@ class StrategyConfig:
8187
subtract_fees: bool = field(default=False)
8288
enable_fractional_shares: bool = field(default=False)
8389
round_fill_price: bool = field(default=True)
90+
position_mode: PositionMode = field(default=PositionMode.DEFAULT)
8491
max_long_positions: Optional[int] = field(default=None)
8592
max_short_positions: Optional[int] = field(default=None)
8693
buy_delay: int = field(default=1)

src/pybroker/portfolio.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DataCol,
1717
FeeInfo,
1818
FeeMode,
19+
PositionMode,
1920
PriceType,
2021
StopType,
2122
to_decimal,
@@ -342,6 +343,7 @@ class Portfolio:
342343
subtract_fees: Whether to subtract fees from the cash balance after an
343344
order is filled.
344345
enable_fractional_shares: Whether to enable trading fractional shares.
346+
position_mode: Position mode for :class:`.Portfolio`.
345347
max_long_positions: Maximum number of long :class:`.Position`\ s that
346348
can be held at a time. If ``None``, then unlimited.
347349
max_short_positions: Maximum number of short :class:`.Position`\ s that
@@ -384,6 +386,7 @@ def __init__(
384386
fee_amount: Optional[float] = None,
385387
subtract_fees: bool = False,
386388
enable_fractional_shares: bool = False,
389+
position_mode: PositionMode = PositionMode.DEFAULT,
387390
max_long_positions: Optional[int] = None,
388391
max_short_positions: Optional[int] = None,
389392
record_stops: Optional[bool] = False,
@@ -396,6 +399,7 @@ def __init__(
396399
)
397400
self._subtract_fees = subtract_fees
398401
self._enable_fractional_shares = enable_fractional_shares
402+
self._position_mode = position_mode
399403
self.equity: Decimal = self.cash
400404
self.market_value: Decimal = self.cash
401405
self.fees = Decimal()
@@ -639,7 +643,7 @@ def buy(
639643
if shares == 0:
640644
return None
641645
covered = self._cover(date, symbol, shares, fill_price)
642-
bought_shares = self._buy(
646+
bought_shares = self._long(
643647
date, symbol, covered.rem_shares, fill_price, limit_price, stops
644648
)
645649
if not covered.filled_shares and not bought_shares:
@@ -724,7 +728,7 @@ def _exit_short(
724728
mfe=mfe,
725729
)
726730

727-
def _buy(
731+
def _long(
728732
self,
729733
date: np.datetime64,
730734
symbol: str,
@@ -733,6 +737,8 @@ def _buy(
733737
limit_price: Optional[Decimal],
734738
stops: Optional[Iterable[Stop]],
735739
) -> Decimal:
740+
if self._position_mode == PositionMode.SHORT_ONLY:
741+
return Decimal()
736742
clamped_shares = self._clamp_shares(fill_price, shares)
737743
if clamped_shares < shares:
738744
self._logger.debug_buy_shares_exceed_cash(
@@ -925,6 +931,8 @@ def _short(
925931
and len(self.short_positions) == self._max_short_positions
926932
):
927933
return Decimal()
934+
if self._position_mode == PositionMode.LONG_ONLY:
935+
return Decimal()
928936
if symbol not in self.short_positions:
929937
self.symbols.add(symbol)
930938
pos = Position(symbol=symbol, shares=shares, type="short")

src/pybroker/strategy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,7 @@ def walkforward(
12511251
self._config.fee_amount,
12521252
self._config.subtract_fees,
12531253
self._fractional_shares_enabled(),
1254+
self._config.position_mode,
12541255
self._config.max_long_positions,
12551256
self._config.max_short_positions,
12561257
self._config.return_stops,

tests/test_portfolio.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pytest
1212
from collections import deque
1313
from decimal import Decimal
14-
from pybroker.common import FeeMode, PriceType, StopType
14+
from pybroker.common import FeeMode, PositionMode, PriceType, StopType
1515
from pybroker.portfolio import Portfolio, Stop
1616
from pybroker.scope import ColumnScope, PriceScope
1717

@@ -3536,3 +3536,25 @@ def test_mae_mfe_when_long_position():
35363536
assert len(portfolio.trades) == 1
35373537
assert portfolio.trades[0].mae == low_price - fill_price
35383538
assert portfolio.trades[0].mfe == high_price - fill_price
3539+
3540+
3541+
def test_long_only_mode():
3542+
cash = 100_000
3543+
portfolio = Portfolio(cash, position_mode=PositionMode.LONG_ONLY)
3544+
portfolio.buy(DATE_1, SYMBOL_1, 100, FILL_PRICE_1)
3545+
portfolio.sell(DATE_2, SYMBOL_1, 200, FILL_PRICE_1)
3546+
assert not portfolio.long_positions
3547+
assert not portfolio.short_positions
3548+
assert len(portfolio.trades) == 1
3549+
assert portfolio.trades[0].shares == 100
3550+
3551+
3552+
def test_short_only_mode():
3553+
cash = 100_000
3554+
portfolio = Portfolio(cash, position_mode=PositionMode.SHORT_ONLY)
3555+
portfolio.sell(DATE_1, SYMBOL_1, 100, FILL_PRICE_1)
3556+
portfolio.buy(DATE_2, SYMBOL_1, 200, FILL_PRICE_1)
3557+
assert not portfolio.long_positions
3558+
assert not portfolio.short_positions
3559+
assert len(portfolio.trades) == 1
3560+
assert portfolio.trades[0].shares == 100

0 commit comments

Comments
 (0)