diff --git a/locast/candle_storage/candle_storage.py b/locast/candle_storage/candle_storage.py index b1bbbcf..bce5dff 100644 --- a/locast/candle_storage/candle_storage.py +++ b/locast/candle_storage/candle_storage.py @@ -16,6 +16,14 @@ async def retrieve_cluster( resolution: ResolutionDetail, ) -> List[Candle]: ... + async def retrieve_newest_candles( + self, + exchange: Exchange, + market: str, + resolution: ResolutionDetail, + amount: int, + ) -> List[Candle]: ... + async def delete_cluster( self, exchange: Exchange, diff --git a/locast/candle_storage/sql/sqlite_candle_storage.py b/locast/candle_storage/sql/sqlite_candle_storage.py index c8847f3..a2fc9db 100644 --- a/locast/candle_storage/sql/sqlite_candle_storage.py +++ b/locast/candle_storage/sql/sqlite_candle_storage.py @@ -62,6 +62,38 @@ async def retrieve_cluster( else: return [] + async def retrieve_newest_candles( + self, + exchange: Exchange, + market: str, + resolution: ResolutionDetail, + amount: int, + ) -> List[Candle]: + with Session(self._engine) as session: + if foreign_keys := self._look_up_foreign_keys( + exchange, + market, + resolution, + session, + ): + sqlite_exchange, sqlite_market, sqlite_resolution = foreign_keys + + stmnt = ( + select(SqliteCandle) + .where( + (SqliteCandle.exchange_id == sqlite_exchange.id) + & (SqliteCandle.market_id == sqlite_market.id) + & (SqliteCandle.resolution_id == sqlite_resolution.id) + ) + .order_by(desc(SqliteCandle.started_at)) + .limit(amount) + ) + + results = session.exec(stmnt) + return self._to_candles(list(results.all())) + else: + return [] + async def delete_cluster( self, exchange: Exchange, diff --git a/locast/store_manager/store_manager.py b/locast/store_manager/store_manager.py index 2a89b7f..91568c3 100644 --- a/locast/store_manager/store_manager.py +++ b/locast/store_manager/store_manager.py @@ -71,13 +71,36 @@ async def retrieve_cluster( ) -> List[Candle]: cluster_info = await self.get_cluster_info(exchange, market, resolution) - if not cluster_info.newest_candle: + if cluster_info.size == 0: raise MissingClusterException( f"Cluster does not exist for market {market} and resolution {resolution.notation}." ) return await self._candle_storage.retrieve_cluster(exchange, market, resolution) + async def retrieve_newest_candles( + self, + exchange: Exchange, + market: str, + resolution: ResolutionDetail, + amount: int, + ) -> List[Candle]: + cluster_info = await self.get_cluster_info(exchange, market, resolution) + + if cluster_info.size == 0: + raise MissingClusterException( + f"Cluster does not exist for market {market} and resolution {resolution.notation}." + ) + + amount_to_retrieve = min(amount, cluster_info.size) + + return await self._candle_storage.retrieve_newest_candles( + exchange, + market, + resolution, + amount_to_retrieve, + ) + async def update_cluster( self, exchange: Exchange, @@ -138,7 +161,7 @@ async def _check_horizon( market: str, resolution: ResolutionDetail, start_date: datetime, - ): + ) -> datetime: if not (horizon := self._horizon_cache.get(f"{market}_{resolution.notation}")): horizon = await self._candle_fetcher.find_horizon(market, resolution) self._horizon_cache[f"{market}_{resolution.notation}"] = horizon diff --git a/tests/candle_storage/sql/test_sqlite_candle_storage.py b/tests/candle_storage/sql/test_sqlite_candle_storage.py index e20ec4c..1835c8d 100644 --- a/tests/candle_storage/sql/test_sqlite_candle_storage.py +++ b/tests/candle_storage/sql/test_sqlite_candle_storage.py @@ -77,7 +77,7 @@ async def test_store_candles_results_in_correct_storage_state( @pytest.mark.parametrize("amount", few_amounts) @pytest.mark.asyncio -async def test_retrieve_candles_results_in_correct_cluster( +async def test_retrieve_cluster_results_in_correct_cluster( sqlite_candle_storage_memory: SqliteCandleStorage, amount: int, ) -> None: @@ -103,7 +103,7 @@ async def test_retrieve_candles_results_in_correct_cluster( @pytest.mark.asyncio -async def test_retrieve_candles_results_in_empty_list( +async def test_retrieve_cluster_results_in_empty_list( sqlite_candle_storage_memory: SqliteCandleStorage, ) -> None: # given @@ -120,6 +120,90 @@ async def test_retrieve_candles_results_in_empty_list( assert len(retrieved_candles) == 0 +@pytest.mark.asyncio +async def test_retrieve_newest_candles_results_in_correct_list( + sqlite_candle_storage_memory: SqliteCandleStorage, +) -> None: + # given + storage = sqlite_candle_storage_memory + + exchange = Exchange.DYDX_V4 + res = ResolutionDetail(Seconds.ONE_MINUTE, "1MIN") + start_date = string_to_datetime("2022-01-01T00:00:00.000Z") + market = "ETH-USD" + amount_mocked = 100 + amount_retreived = 10 + + candles = mock_dydx_v4_candles(market, res, amount_mocked, start_date) + await storage.store_candles(candles) + + # when + retrieved_candles = await storage.retrieve_newest_candles( + exchange, + market, + res, + amount_retreived, + ) + + # then + assert len(retrieved_candles) == amount_retreived + assert candles[0].started_at == retrieved_candles[0].started_at + + +@pytest.mark.asyncio +async def test_retrieve_newest_candles_corrects_amount_to_cluster_size( + sqlite_candle_storage_memory: SqliteCandleStorage, +) -> None: + # given + storage = sqlite_candle_storage_memory + + exchange = Exchange.DYDX_V4 + res = ResolutionDetail(Seconds.ONE_MINUTE, "1MIN") + start_date = string_to_datetime("2022-01-01T00:00:00.000Z") + market = "ETH-USD" + amount_mocked = 50 + amount_retreived = 60 + + candles = mock_dydx_v4_candles(market, res, amount_mocked, start_date) + await storage.store_candles(candles) + + # when cluster size is less than requested to retrieve + retrieved_candles = await storage.retrieve_newest_candles( + exchange, + market, + res, + amount_retreived, + ) + + # then retrieved size equals cluster size + assert len(retrieved_candles) == amount_mocked + assert candles[0].started_at == retrieved_candles[0].started_at + + +@pytest.mark.asyncio +async def test_retrieve_newest_candles_results_in_empty_list( + sqlite_candle_storage_memory: SqliteCandleStorage, +) -> None: + # given + storage = sqlite_candle_storage_memory + + exchange = Exchange.DYDX_V4 + res = ResolutionDetail(Seconds.ONE_MINUTE, "1MIN") + market = "ETH-USD" + amount = 10 + + # when no cluster in storage + retrieved_candles = await storage.retrieve_newest_candles( + exchange, + market, + res, + amount, + ) + + # then + assert len(retrieved_candles) == 0 + + @pytest.mark.parametrize("amount", few_amounts) @pytest.mark.asyncio async def test_delete_cluster_results_in_correct_state(