diff --git a/chromadb/api/async_client.py b/chromadb/api/async_client.py index a167c83e290..3b9424404fe 100644 --- a/chromadb/api/async_client.py +++ b/chromadb/api/async_client.py @@ -1,5 +1,5 @@ import httpx -from typing import Optional, Sequence +from typing import Optional, Sequence, cast from uuid import UUID from overrides import override @@ -103,7 +103,44 @@ async def from_system_async( database: str = DEFAULT_DATABASE, ) -> "AsyncClient": """Create a client from an existing system. This is useful for testing and debugging.""" - return await AsyncClient.create(tenant, database, system.settings) + self = cls.__new__(cls) + self._identifier = SharedSystemClient._populate_data_from_system(system) + SharedSystemClient._increment_refcount(self._identifier) + + try: + self.tenant = tenant + self.database = database + + self._server = self._system.instance(AsyncServerAPI) + + user_identity = await self.get_user_identity() + + maybe_tenant, maybe_database = maybe_set_tenant_and_database( + user_identity, + overwrite_singleton_tenant_database_access_from_auth=system.settings.chroma_overwrite_singleton_tenant_database_access_from_auth, + user_provided_tenant=tenant, + user_provided_database=database, + ) + if maybe_tenant: + self.tenant = maybe_tenant + if maybe_database: + self.database = maybe_database + + self._admin_client = AsyncAdminClient.from_system(self._system) + await self._validate_tenant_database( + tenant=self.tenant, database=self.database + ) + + self._submit_client_start_event() + + return self + except Exception: + if hasattr(self, "_admin_client"): + SharedSystemClient._release_system( + cast("AsyncAdminClient", self._admin_client)._identifier + ) + SharedSystemClient._release_system(self._identifier) + raise @classmethod @override diff --git a/chromadb/test/test_async_client_from_system.py b/chromadb/test/test_async_client_from_system.py new file mode 100644 index 00000000000..eb237adb3b7 --- /dev/null +++ b/chromadb/test/test_async_client_from_system.py @@ -0,0 +1,39 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +from chromadb.api.async_client import AsyncClient +from chromadb.auth import UserIdentity +from chromadb.config import Settings, System + + +def test_async_client_from_system_async_reuses_provided_system() -> None: + settings = Settings( + chroma_api_impl="chromadb.api.async_fastapi.AsyncFastAPI", + chroma_server_host="localhost", + chroma_server_http_port=9000, + ) + system = MagicMock(spec=System) + system.settings = settings + system.instance.return_value = MagicMock() + + with patch.object( + AsyncClient, + "get_user_identity", + new=AsyncMock( + return_value=UserIdentity( + user_id="test-user", + tenant="default_tenant", + databases=["default_database"], + ) + ), + ): + with patch.object(AsyncClient, "_validate_tenant_database", new=AsyncMock()): + with patch.object(AsyncClient, "_submit_client_start_event"): + with patch( + "chromadb.api.async_client.AsyncAdminClient.from_system", + return_value=MagicMock(), + ): + client = asyncio.run(AsyncClient.from_system_async(system)) + + assert client._system is system + client.clear_system_cache()