Skip to content

Commit 73beee0

Browse files
authored
feat: Add shutdown functionality to LlamaStackAsLibraryClient and AsyncLlamaStackAsLibraryClient (ogx-ai#4642)
# What does this PR do? Adds shutdown functionality to LlamaStackAsLibraryClient and AsyncLlamaStackAsLibraryClient. See the change in docs/docs/distributions/importing_as_library.mdx for details. Closes ogx-ai#4641 ## Test Plan The test script from ogx-ai#4641, modified to shut down AsyncLlamaStackAsLibraryClient, exits. ``` """ Minimal test to reproduce the race condition in _load_sentence_transformer_model via the Llama Stack API. """ import asyncio import tempfile from io import BytesIO from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient CONFIG = """ version: 2 image_name: test apis: [files, vector_io, inference] providers: inference: - provider_id: st provider_type: inline::sentence-transformers files: - provider_id: fs provider_type: inline::localfs config: storage_dir: {tmpdir}/files metadata_store: {{table_name: files, backend: sql_default}} vector_io: - provider_id: faiss provider_type: inline::faiss config: persistence: {{namespace: faiss, backend: kv_default}} storage: backends: kv_default: {{type: kv_sqlite, db_path: {tmpdir}/kv.db}} sql_default: {{type: sql_sqlite, db_path: {tmpdir}/sql.db}} stores: metadata: {{namespace: registry, backend: kv_default}} inference: {{table_name: inference, backend: sql_default}} conversations: {{table_name: conversations, backend: sql_default}} prompts: {{namespace: prompts, backend: kv_default}} metadata_store: {{type: sqlite, db_path: {tmpdir}/meta.db}} vector_stores: default_provider_id: faiss default_embedding_model: {{provider_id: st, model_id: all-MiniLM-L6-v2}} file_batch_params: {{max_concurrent_files_per_batch: 5}} registered_resources: models: - model_id: all-MiniLM-L6-v2 provider_id: st provider_model_id: all-MiniLM-L6-v2 model_type: embedding metadata: {{embedding_dimension: 384}} """ async def main(): with tempfile.TemporaryDirectory() as tmpdir: cfg_path = f"{tmpdir}/config.yaml" with open(cfg_path, "w") as f: f.write(CONFIG.format(tmpdir=tmpdir)) async with 
AsyncLlamaStackAsLibraryClient(cfg_path) as client: # Upload 5 files file_ids = [] for i in range(5): with BytesIO(f"Test content {i}".encode() * 50) as buf: buf.name = f"test_{i}.txt" file_ids.append((await client.files.create(file=buf, purpose="assistants")).id) # Create vector store and batch - triggers concurrent embedding loads vs = await client.vector_stores.create(name="test") batch = await client.vector_stores.file_batches.create( vector_store_id=vs.id, file_ids=file_ids ) # Wait for completion while batch.status == "in_progress": await asyncio.sleep(0.5) batch = await client.vector_stores.file_batches.retrieve( vector_store_id=vs.id, batch_id=batch.id ) print(f"Result: {batch.file_counts.completed} succeeded, {batch.file_counts.failed} failed") if batch.file_counts.failed: print("*** RACE CONDITION DETECTED ***") if __name__ == "__main__": import faulthandler import signal faulthandler.register(signal.SIGUSR1) asyncio.run(main()) ```
1 parent 0fae351 commit 73beee0

13 files changed

Lines changed: 816 additions & 4 deletions

File tree

docs/docs/distributions/importing_as_library.mdx

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,67 @@ If you've created a [custom distribution](./building_distro), you can also use t
3838
```python
3939
client = LlamaStackAsLibraryClient(config_path)
4040
```
41+
42+
## Resource Management
43+
44+
When you're done using the client, you should properly release resources such as database connections. There are two ways to do this:
45+
46+
### Using Context Managers (Recommended)
47+
48+
The easiest and most Pythonic way is to use the client as a context manager, which automatically handles cleanup:
49+
50+
```python
51+
# Synchronous client
52+
from llama_stack.core.library_client import LlamaStackAsLibraryClient
53+
54+
with LlamaStackAsLibraryClient("starter") as client:
55+
response = client.models.list()
56+
# Client is automatically shut down here
57+
```
58+
59+
For the async client:
60+
61+
```python
62+
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
63+
64+
async with AsyncLlamaStackAsLibraryClient("starter") as client:
65+
response = await client.models.list()
66+
# Client is automatically shut down here
67+
```
68+
69+
### Using Explicit shutdown()
70+
71+
Alternatively, you can manually call `shutdown()` when you're done:
72+
73+
```python
74+
# Synchronous client
75+
client = LlamaStackAsLibraryClient("starter")
76+
try:
77+
# ... use the client ...
78+
response = client.models.list()
79+
finally:
80+
client.shutdown()
81+
```
82+
83+
For the async client:
84+
85+
```python
86+
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
87+
88+
client = AsyncLlamaStackAsLibraryClient("starter")
89+
await client.initialize()
90+
try:
91+
# ... use the client ...
92+
response = await client.models.list()
93+
finally:
94+
await client.shutdown()
95+
```
96+
97+
The `shutdown()` method:
98+
- Closes all database connections (SQLite, PostgreSQL, etc.)
99+
- Releases any held resources
100+
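- Can be called multiple times safely on the async client (idempotent). On the synchronous client, call it only once: the first call closes the client's event loop, so a second call raises `RuntimeError`.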
101+
102+
:::tip
103+
If you don't call `shutdown()` or use a context manager, your program may hang on exit while waiting for background threads to complete, especially when using SQLite-based storage backends.
104+
:::

src/llama_stack/core/library_client.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,45 @@ def initialize(self):
161161
"""
162162
pass
163163

164+
def shutdown(self) -> None:
165+
"""Shutdown the client and release all resources.
166+
167+
This method should be called when you're done using the client to properly
168+
close database connections and release other resources. Failure to call this
169+
method may result in the program hanging on exit while waiting for background
170+
threads to complete.
171+
172+
This method closes the client's event loop, so call it at most once; a second call raises RuntimeError because the loop is already closed.
173+
174+
Example:
175+
client = LlamaStackAsLibraryClient("starter")
176+
# ... use the client ...
177+
client.shutdown()
178+
"""
179+
loop = self.loop
180+
asyncio.set_event_loop(loop)
181+
try:
182+
loop.run_until_complete(self.async_client.shutdown())
183+
finally:
184+
loop.close()
185+
asyncio.set_event_loop(None)
186+
187+
def __enter__(self) -> "LlamaStackAsLibraryClient":
188+
"""Enter the context manager.
189+
190+
The client is already initialized in __init__, so this just returns self.
191+
192+
Example:
193+
with LlamaStackAsLibraryClient("starter") as client:
194+
response = client.models.list()
195+
# Client is automatically shut down here
196+
"""
197+
return self
198+
199+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
200+
"""Exit the context manager and shut down the client."""
201+
self.shutdown()
202+
164203
def request(self, *args, **kwargs):
165204
loop = self.loop
166205
asyncio.set_event_loop(loop)
@@ -224,6 +263,7 @@ def __init__(
224263
self.custom_provider_registry = custom_provider_registry
225264
self.provider_data = provider_data
226265
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
266+
self.stack: Stack | None = None
227267

228268
def _remove_root_logger_handlers(self):
229269
"""
@@ -246,9 +286,9 @@ async def initialize(self) -> bool:
246286
try:
247287
self.route_impls = None
248288

249-
stack = Stack(self.config, self.custom_provider_registry)
250-
await stack.initialize()
251-
self.impls = stack.impls
289+
self.stack = Stack(self.config, self.custom_provider_registry)
290+
await self.stack.initialize()
291+
self.impls = self.stack.impls
252292
except ModuleNotFoundError as _e:
253293
cprint(_e.msg, color="red", file=sys.stderr)
254294
cprint(
@@ -283,6 +323,43 @@ async def initialize(self) -> bool:
283323
self.route_impls = initialize_route_impls(self.impls)
284324
return True
285325

326+
async def shutdown(self) -> None:
327+
"""Shutdown the client and release all resources.
328+
329+
This method should be called when you're done using the client to properly
330+
close database connections and release other resources. Failure to call this
331+
method may result in the program hanging on exit while waiting for background
332+
threads to complete.
333+
334+
This method is idempotent and can be called multiple times safely.
335+
336+
Example:
337+
client = AsyncLlamaStackAsLibraryClient("starter")
338+
await client.initialize()
339+
# ... use the client ...
340+
await client.shutdown()
341+
"""
342+
if self.stack:
343+
await self.stack.shutdown()
344+
self.stack = None
345+
346+
async def __aenter__(self) -> "AsyncLlamaStackAsLibraryClient":
347+
"""Enter the async context manager.
348+
349+
Initializes the client and returns it.
350+
351+
Example:
352+
async with AsyncLlamaStackAsLibraryClient("starter") as client:
353+
response = await client.models.list()
354+
# Client is automatically shut down here
355+
"""
356+
await self.initialize()
357+
return self
358+
359+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
360+
"""Exit the async context manager and shut down the client."""
361+
await self.shutdown()
362+
286363
async def request(
287364
self,
288365
cast_to: Any,

src/llama_stack/core/stack.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def cb(task):
654654
async def shutdown(self):
655655
for impl in self.impls.values():
656656
impl_name = impl.__class__.__name__
657-
logger.info(f"Shutting down {impl_name}")
657+
logger.debug(f"Shutting down {impl_name}")
658658
try:
659659
if hasattr(impl, "shutdown"):
660660
await asyncio.wait_for(impl.shutdown(), timeout=5)
@@ -676,6 +676,20 @@ async def shutdown(self):
676676
if REGISTRY_REFRESH_TASK:
677677
REGISTRY_REFRESH_TASK.cancel()
678678

679+
# Shutdown storage backends
680+
from llama_stack.core.storage.kvstore.kvstore import shutdown_kvstore_backends
681+
from llama_stack.core.storage.sqlstore.sqlstore import shutdown_sqlstore_backends
682+
683+
try:
684+
await shutdown_kvstore_backends()
685+
except Exception as e:
686+
logger.exception(f"Failed to shutdown KV store backends: {e}")
687+
688+
try:
689+
await shutdown_sqlstore_backends()
690+
except Exception as e:
691+
logger.exception(f"Failed to shutdown SQL store backends: {e}")
692+
679693

680694
async def refresh_registry_once(impls: dict[Api, Any]):
681695
logger.debug("refreshing registry")

src/llama_stack/core/storage/kvstore/kvstore.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
6262
async def delete(self, key: str) -> None:
6363
del self._store[key]
6464

65+
async def shutdown(self) -> None:
66+
self._store.clear()
67+
6568

6669
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
6770
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
@@ -126,3 +129,11 @@ async def kvstore_impl(reference: KVStoreReference) -> KVStore:
126129
await impl.initialize()
127130
_KVSTORE_INSTANCES[cache_key] = impl
128131
return impl
132+
133+
134+
async def shutdown_kvstore_backends() -> None:
135+
"""Shutdown all cached KV store instances."""
136+
global _KVSTORE_INSTANCES
137+
for instance in _KVSTORE_INSTANCES.values():
138+
await instance.shutdown()
139+
_KVSTORE_INSTANCES.clear()

src/llama_stack/core/storage/kvstore/mongodb/mongodb.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,8 @@ async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
8383
async for doc in cursor:
8484
result.append(doc["key"])
8585
return result
86+
87+
async def shutdown(self) -> None:
88+
if self.conn:
89+
await self.conn.close()
90+
self.conn = None

src/llama_stack/core/storage/kvstore/postgres/postgres.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,11 @@ async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
123123
(start_key, end_key),
124124
)
125125
return [row[0] for row in cursor.fetchall()]
126+
127+
async def shutdown(self) -> None:
128+
if self._cursor:
129+
self._cursor.close()
130+
self._cursor = None
131+
if self._conn:
132+
self._conn.close()
133+
self._conn = None

src/llama_stack/core/storage/kvstore/redis/redis.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,8 @@ async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
9999
if cursor == 0:
100100
break
101101
return result
102+
103+
async def shutdown(self) -> None:
104+
if self._redis:
105+
await self._redis.close()
106+
self._redis = None

src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ def set_sqlite_pragma(dbapi_conn, connection_record):
107107

108108
return engine
109109

110+
async def shutdown(self) -> None:
111+
"""Dispose the session maker's engine and close all connections."""
112+
# The async_session holds a reference to the engine created in __init__
113+
if self.async_session:
114+
engine = self.async_session.kw.get("bind")
115+
if engine:
116+
await engine.dispose()
117+
110118
async def create_table(
111119
self,
112120
table: str,

src/llama_stack/core/storage/sqlstore/sqlstore.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,11 @@ def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> Non
8585
_SQLSTORE_LOCKS.clear()
8686
for name, cfg in backends.items():
8787
_SQLSTORE_BACKENDS[name] = cfg
88+
89+
90+
async def shutdown_sqlstore_backends() -> None:
91+
"""Shutdown all cached SQL store instances."""
92+
global _SQLSTORE_INSTANCES
93+
for instance in _SQLSTORE_INSTANCES.values():
94+
await instance.shutdown()
95+
_SQLSTORE_INSTANCES.clear()

src/llama_stack_api/internal/kvstore.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,7 @@ async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
2222

2323
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
2424

25+
async def shutdown(self) -> None: ...
26+
2527

2628
__all__ = ["KVStore"]

0 commit comments

Comments
 (0)