1616 DataCol ,
1717 FeeInfo ,
1818 FeeMode ,
19+ PositionMode ,
1920 PriceType ,
2021 StopType ,
2122 to_decimal ,
@@ -342,6 +343,7 @@ class Portfolio:
342343 subtract_fees: Whether to subtract fees from the cash balance after an
343344 order is filled.
344345 enable_fractional_shares: Whether to enable trading fractional shares.
346+ position_mode: Position mode for :class:`.Portfolio`.
345347 max_long_positions: Maximum number of long :class:`.Position`\ s that
346348 can be held at a time. If ``None``, then unlimited.
347349 max_short_positions: Maximum number of short :class:`.Position`\ s that
@@ -384,6 +386,7 @@ def __init__(
384386 fee_amount : Optional [float ] = None ,
385387 subtract_fees : bool = False ,
386388 enable_fractional_shares : bool = False ,
389+ position_mode : PositionMode = PositionMode .DEFAULT ,
387390 max_long_positions : Optional [int ] = None ,
388391 max_short_positions : Optional [int ] = None ,
389392 record_stops : Optional [bool ] = False ,
@@ -396,6 +399,7 @@ def __init__(
396399 )
397400 self ._subtract_fees = subtract_fees
398401 self ._enable_fractional_shares = enable_fractional_shares
402+ self ._position_mode = position_mode
399403 self .equity : Decimal = self .cash
400404 self .market_value : Decimal = self .cash
401405 self .fees = Decimal ()
@@ -639,7 +643,7 @@ def buy(
639643 if shares == 0 :
640644 return None
641645 covered = self ._cover (date , symbol , shares , fill_price )
642- bought_shares = self ._buy (
646+ bought_shares = self ._long (
643647 date , symbol , covered .rem_shares , fill_price , limit_price , stops
644648 )
645649 if not covered .filled_shares and not bought_shares :
@@ -724,7 +728,7 @@ def _exit_short(
724728 mfe = mfe ,
725729 )
726730
727- def _buy (
731+ def _long (
728732 self ,
729733 date : np .datetime64 ,
730734 symbol : str ,
@@ -733,6 +737,8 @@ def _buy(
733737 limit_price : Optional [Decimal ],
734738 stops : Optional [Iterable [Stop ]],
735739 ) -> Decimal :
740+ if self ._position_mode == PositionMode .SHORT_ONLY :
741+ return Decimal ()
736742 clamped_shares = self ._clamp_shares (fill_price , shares )
737743 if clamped_shares < shares :
738744 self ._logger .debug_buy_shares_exceed_cash (
@@ -925,6 +931,8 @@ def _short(
925931 and len (self .short_positions ) == self ._max_short_positions
926932 ):
927933 return Decimal ()
934+ if self ._position_mode == PositionMode .LONG_ONLY :
935+ return Decimal ()
928936 if symbol not in self .short_positions :
929937 self .symbols .add (symbol )
930938 pos = Position (symbol = symbol , shares = shares , type = "short" )
0 commit comments