
Commit 461d5f4

impl

# What does this PR do?
## Test Plan

1 parent 047303e

42 files changed: +600 −7 lines (only a subset of the changed files is shown below)


llama_stack/distribution/datatypes.py (8 additions, 0 deletions)
```diff
@@ -25,6 +25,7 @@
 from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.datatypes import Api, ProviderSpec
+from llama_stack.providers.utils.inference.inference_store import InferenceStoreConfig
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig
 
 LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
@@ -297,6 +298,13 @@ class StackRunConfig(BaseModel):
         a default SQLite store will be used.""",
     )
 
+    inference_store: InferenceStoreConfig | None = Field(
+        default=None,
+        description="""
+Configuration for the persistence store used by the inference API. If not specified,
+a default SQLite store will be used.""",
+    )
+
     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)
```
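For illustration, a minimal sketch of what the new field accepts; `SqliteInferenceStoreConfig` is defined later in this commit, while the `db_path` value and standalone construction here are hypothetical:

```python
# Hypothetical usage of the new StackRunConfig store config (db_path is illustrative).
from llama_stack.providers.utils.inference.inference_store import SqliteInferenceStoreConfig

store_config = SqliteInferenceStoreConfig(db_path="/tmp/inference_store.db")
# The discriminator field defaults to "sqlite", the only store type in this commit.
print(store_config.model_dump())  # {'type': 'sqlite', 'db_path': '/tmp/inference_store.db'}
```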

llama_stack/distribution/resolver.py (8 additions, 4 deletions)
```diff
@@ -140,7 +140,7 @@ async def resolve_impls(
 
     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
 
-    return await instantiate_providers(sorted_providers, router_apis, dist_registry)
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
 
 
 def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@@ -243,7 +243,10 @@ def sort_providers_by_deps(
 
 
 async def instantiate_providers(
-    sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
+    sorted_providers: list[tuple[str, ProviderWithSpec]],
+    router_apis: set[Api],
+    dist_registry: DistributionRegistry,
+    run_config: StackRunConfig,
 ) -> dict:
     """Instantiates providers asynchronously while managing dependencies."""
     impls: dict[Api, Any] = {}
@@ -258,7 +261,7 @@ async def instantiate_providers(
         if isinstance(provider.spec, RoutingTableProviderSpec):
             inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
 
-        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
+        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
 
         if api_str.startswith("inner-"):
            inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@@ -308,6 +311,7 @@ async def instantiate_provider(
     deps: dict[Api, Any],
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
+    run_config: StackRunConfig,
 ):
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):
@@ -327,7 +331,7 @@
         method = "get_auto_router_impl"
 
         config = None
-        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
+        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
     elif isinstance(provider_spec, RoutingTableProviderSpec):
         method = "get_routing_table_impl"
 
```

llama_stack/distribution/routers/__init__.py (8 additions, 1 deletion)
```diff
@@ -7,8 +7,10 @@
 from typing import Any
 
 from llama_stack.distribution.datatypes import RoutedProtocol
+from llama_stack.distribution.stack import StackRunConfig
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
+from llama_stack.providers.utils.inference.inference_store import inference_store_impl
 
 from .routing_tables import (
     BenchmarksRoutingTable,
@@ -45,7 +47,9 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
+async def get_auto_router_impl(
+    api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
+) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -76,6 +80,9 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict
         if dep_api in deps:
             api_to_dep_impl[dep_name] = deps[dep_api]
 
+    if api == Api.inference and run_config.inference_store:
+        api_to_dep_impl["store"] = await inference_store_impl(run_config.inference_store)
+
     impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
```

llama_stack/distribution/routers/routers.py (34 additions, 2 deletions)
```diff
@@ -32,8 +32,11 @@
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
+    ListOpenAIChatCompletionResponse,
     LogProbConfig,
     Message,
+    OpenAICompletionWithInputMessages,
+    Order,
     ResponseFormat,
     SamplingParams,
     StopReason,
@@ -73,6 +76,8 @@
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
+from llama_stack.providers.utils.inference.inference_store import InferenceStore
+from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
 from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
 logger = get_logger(name=__name__, category="core")
@@ -141,10 +146,12 @@ def __init__(
         self,
         routing_table: RoutingTable,
         telemetry: Telemetry | None = None,
+        store: InferenceStore | None = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
         self.telemetry = telemetry
+        self.store = store
         if self.telemetry:
             self.tokenizer = Tokenizer.get_instance()
             self.formatter = ChatFormat(self.tokenizer)
@@ -607,9 +614,34 @@ async def openai_chat_completion(
 
         provider = self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
-            return await provider.openai_chat_completion(**params)
+            response_stream = await provider.openai_chat_completion(**params)
+            if self.store:
+                return stream_and_store_openai_completion(response_stream, model, self.store, messages)
+            else:
+                return response_stream
         else:
-            return await self._nonstream_openai_chat_completion(provider, params)
+            response = await self._nonstream_openai_chat_completion(provider, params)
+            if self.store:
+                await self.store.store_chat_completion(response, messages)
+            return response
+
+    async def list_chat_completions(
+        self,
+        after: str | None = None,
+        limit: int | None = 20,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIChatCompletionResponse:
+        if self.store:
+            return await self.store.list_chat_completions(after, limit, model, order)
+        else:
+            raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
+
+    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+        if self.store:
+            return await self.store.get_chat_completion(completion_id)
+        else:
+            raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
 
     async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
```
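A minimal sketch of the router's new read APIs when no store is configured; constructing `InferenceRouter` with a bare routing table is an illustrative shortcut, not how the stack wires it:

```python
# Sketch: without an inference store, the new read APIs raise NotImplementedError.
import asyncio

from llama_stack.distribution.routers.routers import InferenceRouter

async def main():
    router = InferenceRouter(routing_table=None)  # no telemetry, no store
    try:
        await router.list_chat_completions()
    except NotImplementedError as e:
        print(e)  # List chat completions is not supported: inference store is not configured.

asyncio.run(main())
```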
llama_stack/providers/utils/inference/inference_store.py (new file, 76 additions)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Annotated, Literal, Protocol

from pydantic import BaseModel, Field

from llama_stack.apis.inference import (
    ListOpenAIChatCompletionResponse,
    OpenAIChatCompletion,
    OpenAICompletionWithInputMessages,
    OpenAIMessageParam,
    Order,
)
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR


class InferenceStoreType(Enum):
    sqlite = "sqlite"


class SqliteInferenceStoreConfig(BaseModel):
    type: Literal["sqlite"] = InferenceStoreType.sqlite.value
    db_path: str = Field(
        default=(RUNTIME_BASE_DIR / "inference_store.db").as_posix(),
        description="File path for the sqlite database",
    )

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, db_name: str = "inference_store.db"):
        return {
            "type": "sqlite",
            "db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
        }


InferenceStoreConfig = Annotated[
    SqliteInferenceStoreConfig,
    Field(discriminator="type", default=InferenceStoreType.sqlite.value),
]


class InferenceStore(Protocol):
    async def initialize(self) -> None: ...

    async def store_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None: ...

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 20,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse: ...

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: ...


async def inference_store_impl(config: InferenceStoreConfig) -> InferenceStore:
    if config.type == InferenceStoreType.sqlite.value:
        from .stores.sqlite import SqliteInferenceStore

        impl = SqliteInferenceStore(config.db_path)
    else:
        # Enum classes have no .values() method; enumerate the members instead.
        raise ValueError(
            f"Unknown inference store type {config.type}: available types are {[t.value for t in InferenceStoreType]}"
        )

    await impl.initialize()
    return impl
```
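To make the factory concrete, a minimal sketch of building and initializing a store from its config; the path is illustrative, and the empty listing assumes a fresh database:

```python
# Sketch: construct a sqlite-backed store via the factory in this file and query it.
import asyncio

from llama_stack.providers.utils.inference.inference_store import (
    SqliteInferenceStoreConfig,
    inference_store_impl,
)

async def main():
    config = SqliteInferenceStoreConfig(db_path="/tmp/demo_inference_store.db")
    store = await inference_store_impl(config)  # initialize() creates the table
    listing = await store.list_chat_completions(limit=5)
    print(listing.data)  # [] on a fresh database

asyncio.run(main())
```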
llama_stack/providers/utils/inference/stores/sqlite.py (new file, 132 additions)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os

import aiosqlite

from llama_stack.apis.inference import (
    ListOpenAIChatCompletionResponse,
    OpenAIChatCompletion,
    OpenAICompletionWithInputMessages,
    OpenAIMessageParam,
    Order,
)

from ..inference_store import InferenceStore


class SqliteInferenceStore(InferenceStore):
    def __init__(self, conn_string: str):
        self.conn_string = conn_string

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
        # Create the parent directory if the path has one; makedirs("") would raise.
        parent_dir = os.path.dirname(self.conn_string)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        async with aiosqlite.connect(self.conn_string) as conn:
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS chat_completions (
                    id TEXT PRIMARY KEY,
                    created INTEGER,
                    model TEXT,
                    choices TEXT,
                    input_messages TEXT
                )
                """
            )
            await conn.commit()

    async def store_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        data = chat_completion.model_dump()

        async with aiosqlite.connect(self.conn_string) as conn:
            await conn.execute(
                """
                INSERT INTO chat_completions (id, created, model, choices, input_messages)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    data["id"],
                    data["created"],
                    data["model"],
                    json.dumps(data["choices"]),
                    json.dumps([message.model_dump() for message in input_messages]),
                ),
            )
            await conn.commit()

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 20,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        """
        List chat completions from the database.

        :param after: The ID of the last chat completion to return.
        :param limit: The maximum number of chat completions to return.
        :param model: The model to filter by.
        :param order: The order to sort the chat completions by.
        """
        # TODO: support after
        if after:
            raise NotImplementedError("After is not supported for SQLite")
        if not order:
            order = Order.desc

        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            # Bind the model filter and limit as parameters; interpolating the model
            # name directly would produce invalid SQL (an unquoted string literal)
            # and open an injection hole. order.value is a fixed enum value
            # ("asc"/"desc"), so interpolating it is safe.
            where_clause = "WHERE model = ?" if model else ""
            params: list = [model] if model else []
            params.append(limit)
            cursor = await conn.execute(
                f"""
                SELECT * FROM chat_completions
                {where_clause}
                ORDER BY created {order.value}
                LIMIT ?
                """,
                params,
            )
            rows = await cursor.fetchall()

            data = [
                OpenAICompletionWithInputMessages(
                    id=row["id"],
                    created=row["created"],
                    model=row["model"],
                    choices=json.loads(row["choices"]),
                    input_messages=json.loads(row["input_messages"]),
                )
                for row in rows
            ]
            return ListOpenAIChatCompletionResponse(
                data=data,
                # TODO: implement has_more
                has_more=False,
                # Guard against an empty result set, where data[0] would raise IndexError.
                first_id=data[0].id if data else "",
                last_id=data[-1].id if data else "",
            )

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            cursor = await conn.execute("SELECT * FROM chat_completions WHERE id = ?", (completion_id,))
            row = await cursor.fetchone()
            if row is None:
                raise ValueError(f"Chat completion with id {completion_id} not found")
            return OpenAICompletionWithInputMessages(
                id=row["id"],
                created=row["created"],
                model=row["model"],
                choices=json.loads(row["choices"]),
                input_messages=json.loads(row["input_messages"]),
            )
```
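Finally, a small sketch exercising the store directly; the path and completion id are illustrative:

```python
# Sketch: initialize the sqlite store and probe the error path for a missing id.
import asyncio

from llama_stack.providers.utils.inference.stores.sqlite import SqliteInferenceStore

async def main():
    store = SqliteInferenceStore("/tmp/demo_inference_store.db")
    await store.initialize()  # creates chat_completions if it does not exist
    try:
        await store.get_chat_completion("chatcmpl-does-not-exist")
    except ValueError as e:
        print(e)  # Chat completion with id chatcmpl-does-not-exist not found

asyncio.run(main())
```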
