diff --git a/octobot_trading/exchange_data/prices/prices_manager.pxd b/octobot_trading/exchange_data/prices/prices_manager.pxd index 454ced78a..67930e518 100644 --- a/octobot_trading/exchange_data/prices/prices_manager.pxd +++ b/octobot_trading/exchange_data/prices/prices_manager.pxd @@ -33,6 +33,7 @@ cdef class PricesManager(util.Initializable): cdef void _reset_prices(self) cdef void _ensure_price_validity(self) cdef bint _are_other_sources_valid(self, str mark_price_source) + cdef bint _is_exchange_mark_price_valid(self) cdef int _compute_mark_price_validity_timeout(self) cpdef bint set_mark_price(self, object mark_price, str mark_price_source) diff --git a/octobot_trading/exchange_data/prices/prices_manager.py b/octobot_trading/exchange_data/prices/prices_manager.py index 63d3ed05c..5b4e87e2d 100644 --- a/octobot_trading/exchange_data/prices/prices_manager.py +++ b/octobot_trading/exchange_data/prices/prices_manager.py @@ -68,14 +68,15 @@ def set_mark_price(self, mark_price, mark_price_source) -> bool: self._set_mark_price_value(mark_price) is_mark_price_updated = True - # set mark price value if MarkPriceSources.RECENT_TRADE_AVERAGE.value has already been updated + # set mark price value if RECENT_TRADE_AVERAGE has already been updated and no EXCHANGE_MARK_PRICE elif mark_price_source == enums.MarkPriceSources.RECENT_TRADE_AVERAGE.value: - if self.mark_price_from_sources.get(enums.MarkPriceSources.RECENT_TRADE_AVERAGE.value, None) is not None: - self._set_mark_price_value(mark_price) - is_mark_price_updated = True - else: - # set time at 0 to ensure invalid price but keep track of initialization - self.mark_price_from_sources[mark_price_source] = (mark_price, 0) + if not self._is_exchange_mark_price_valid(): + if self.mark_price_from_sources.get(enums.MarkPriceSources.RECENT_TRADE_AVERAGE.value, None) is not None: + self._set_mark_price_value(mark_price) + is_mark_price_updated = True + else: + # set time at 0 to ensure invalid price but keep track of initialization + self.mark_price_from_sources[mark_price_source] = (mark_price, 0) # set mark price value if other sources have expired elif mark_price_source in (enums.MarkPriceSources.TICKER_CLOSE_PRICE.value, @@ -130,7 +131,13 @@ def _are_other_sources_valid(self, mark_price_source): self._is_mark_price_valid(source_mark_price[1]): return True return False - + + def _is_exchange_mark_price_valid(self): + if enums.MarkPriceSources.EXCHANGE_MARK_PRICE.value in self.mark_price_from_sources: + return self._is_mark_price_valid( + self.mark_price_from_sources[ + enums.MarkPriceSources.EXCHANGE_MARK_PRICE.value][1]) + def _ensure_price_validity(self): """ Clear the event price validity event if the mark price has expired diff --git a/tests/exchange_data/prices/test_prices_manager.py b/tests/exchange_data/prices/test_prices_manager.py index b36ca180c..997783d0b 100644 --- a/tests/exchange_data/prices/test_prices_manager.py +++ b/tests/exchange_data/prices/test_prices_manager.py @@ -55,19 +55,40 @@ async def test_set_mark_price(prices_manager): assert prices_manager.mark_price == decimal.Decimal(10) check_event_is_set(prices_manager) - + async def test_set_mark_price_for_exchange_source(prices_manager): + prices_manager.set_mark_price(decimal.Decimal(15), MarkPriceSources.TICKER_CLOSE_PRICE.value) + assert prices_manager.mark_price == decimal.Decimal(15) + check_event_is_set(prices_manager) + prices_manager.set_mark_price(decimal.Decimal(30), MarkPriceSources.RECENT_TRADE_AVERAGE.value) + assert prices_manager.mark_price == decimal.Decimal(15) # ignore first call + check_event_is_set(prices_manager) + prices_manager.set_mark_price(decimal.Decimal(30), MarkPriceSources.RECENT_TRADE_AVERAGE.value) + assert prices_manager.mark_price == decimal.Decimal(30) + check_event_is_set(prices_manager) prices_manager.set_mark_price(decimal.Decimal(10), MarkPriceSources.EXCHANGE_MARK_PRICE.value) assert prices_manager.mark_price == decimal.Decimal(10) check_event_is_set(prices_manager) prices_manager.set_mark_price(decimal.Decimal(25), MarkPriceSources.RECENT_TRADE_AVERAGE.value) - assert prices_manager.mark_price == decimal.Decimal(10) # Drop first RT update - prices_manager.set_mark_price(decimal.Decimal(30), MarkPriceSources.RECENT_TRADE_AVERAGE.value) - assert prices_manager.mark_price == decimal.Decimal(30) + assert prices_manager.mark_price == decimal.Decimal(10) # dont override valid exchange mark price + check_event_is_set(prices_manager) prices_manager.set_mark_price(decimal.Decimal(20), MarkPriceSources.TICKER_CLOSE_PRICE.value) - assert prices_manager.mark_price == decimal.Decimal(30) + assert prices_manager.mark_price == decimal.Decimal(10) # dont override valid exchange mark price + check_event_is_set(prices_manager) prices_manager.set_mark_price(decimal.Decimal(15), MarkPriceSources.EXCHANGE_MARK_PRICE.value) assert prices_manager.mark_price == decimal.Decimal(15) + check_event_is_set(prices_manager) + if not os.getenv('CYTHON_IGNORE'): + prices_manager.mark_price_from_sources = {} + prices_manager.set_mark_price(decimal.Decimal(10), MarkPriceSources.CANDLE_CLOSE_PRICE.value) + assert prices_manager.mark_price == decimal.Decimal(10) + check_event_is_set(prices_manager) + prices_manager.set_mark_price(decimal.Decimal(22), MarkPriceSources.EXCHANGE_MARK_PRICE.value) + assert prices_manager.mark_price == decimal.Decimal(22) + check_event_is_set(prices_manager) + prices_manager.set_mark_price(decimal.Decimal(15), MarkPriceSources.CANDLE_CLOSE_PRICE.value) + assert prices_manager.mark_price == decimal.Decimal(22) # dont override valid exchange mark price + check_event_is_set(prices_manager) async def test_set_mark_price_for_ticker_source_only(prices_manager):