diff --git a/eventsourcingdb/client.py b/eventsourcingdb/client.py index ad16525..c197a21 100644 --- a/eventsourcingdb/client.py +++ b/eventsourcingdb/client.py @@ -45,12 +45,6 @@ async def __aexit__( ) -> None: await self.__http_client.__aexit__(exc_type, exc_val, exc_tb) - async def initialize(self) -> None: - await self.__http_client.initialize() - - async def close(self) -> None: - await self.__http_client.close() - @property def http_client(self) -> HttpClient: return self.__http_client diff --git a/eventsourcingdb/http_client/http_client.py b/eventsourcingdb/http_client/http_client.py index 14a890a..548563a 100644 --- a/eventsourcingdb/http_client/http_client.py +++ b/eventsourcingdb/http_client/http_client.py @@ -3,8 +3,6 @@ import aiohttp from aiohttp import ClientSession -from ..errors.custom_error import CustomError - from .get_get_headers import get_get_headers from .get_post_headers import get_post_headers from .response import Response @@ -21,7 +19,7 @@ def __init__( self.__session: ClientSession | None = None async def __aenter__(self) -> 'HttpClient': - await self.initialize() + await self.__initialize() return self async def __aexit__( @@ -30,12 +28,16 @@ async def __aexit__( exc_val: BaseException | None = None, exc_tb: TracebackType | None = None, ) -> None: - await self.close() + await self.__close() + + async def __initialize(self) -> None: + # If a session already exists, close it first to prevent leaks + if self.__session is not None: + await self.__session.close() - async def initialize(self) -> None: - self.__session = aiohttp.ClientSession() + self.__session = aiohttp.ClientSession(connector_owner=True) - async def close(self) -> None: + async def __close(self) -> None: if self.__session is not None: await self.__session.close() self.__session = None @@ -49,7 +51,7 @@ def join_segments(first: str, *rest: str) -> str: async def post(self, path: str, request_body: str) -> Response: if self.__session is None: - await self.initialize() + await self.__initialize() url_path = HttpClient.join_segments(self.__base_url, path) headers = get_post_headers(self.__api_token) @@ -70,8 +72,7 @@ async def get( with_authorization: bool = True, ) -> Response: if self.__session is None: - raise CustomError( - "HTTP client session not initialized. Call initialize() before making requests.") + await self.__initialize() async def __request_executor() -> Response: url_path = HttpClient.join_segments(self.__base_url, path) diff --git a/tests/shared/database.py b/tests/shared/database.py index 552748b..7576caa 100644 --- a/tests/shared/database.py +++ b/tests/shared/database.py @@ -35,14 +35,10 @@ def _create_container(cls, api_token, image_tag) -> Container: @staticmethod async def _initialize_clients(container, api_token) -> tuple[Client, Client]: with_authorization_client = container.get_client() - await with_authorization_client.initialize() - with_invalid_url_client = Client( base_url='http://localhost.invalid', api_token=api_token ) - await with_invalid_url_client.initialize() - return with_authorization_client, with_invalid_url_client @classmethod @@ -111,6 +107,17 @@ def get_client(self, client_type: str = CLIENT_TYPE_WITH_AUTH) -> Client: raise ValueError(f'Unknown client type: {client_type}') + async def __aenter__(self) -> 'Database': + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_val: BaseException | None = None, + exc_tb: object | None = None + ) -> None: + await self.stop() + def get_base_url(self) -> str: return self.__container.get_base_url() @@ -118,5 +125,18 @@ def get_api_token(self) -> str: return self.__container.get_api_token() async def stop(self) -> None: + if self.__with_authorization_client: + try: + await self.__with_authorization_client.__aexit__(None, None, None) + except (ConnectionError) as e: + logging.warning("Error closing authorization client: %s", e) + + if self.__with_invalid_url_client: + try: + await self.__with_invalid_url_client.__aexit__(None, None, None) + except (ConnectionError) as e: + logging.warning("Error closing invalid URL client: %s", e) + + # Then stop the container if (container := getattr(self.__class__, '_Database__container', None)): container.stop() diff --git a/tests/shared/start_local_http_server.py b/tests/shared/start_local_http_server.py index d2f3611..db8ca19 100644 --- a/tests/shared/start_local_http_server.py +++ b/tests/shared/start_local_http_server.py @@ -73,6 +73,4 @@ def stop_server() -> None: f'http://localhost:{local_http_server.port}', 'access-token', ) - await client.initialize() - return client, stop_server