Commit c2d90bc

impl

# What does this PR do?
## Test Plan

1 parent 047303e · commit c2d90bc

48 files changed: +979 −13 lines (large commit; some files are hidden by default and not shown below)

llama_stack/distribution/datatypes.py

Lines changed: 8 additions & 0 deletions
@@ -26,6 +26,7 @@
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig

 LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
 LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@@ -297,6 +298,13 @@ class StackRunConfig(BaseModel):
 a default SQLite store will be used.""",
     )

+    inference_store: SqlStoreConfig | None = Field(
+        default=None,
+        description="""
+Configuration for the persistence store used by the inference API. If not specified,
+a default SQLite store will be used.""",
+    )
+
     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)
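
For context, a minimal sketch of how a distribution might opt into the new field when building a run config programmatically. `with_inference_store` is a hypothetical helper, not part of this commit; every `StackRunConfig` field other than `inference_store` is unrelated here and elided. `SqlalchemySqlStoreConfig` and its `engine_str` parameter are taken from the new `InferenceStore` module later in this diff.

```python
# Hypothetical helper (not part of this commit): opt an existing run config
# into chat-completion persistence backed by a SQLite file.
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlalchemySqlStoreConfig


def with_inference_store(run_config: StackRunConfig) -> StackRunConfig:
    # model_copy(update=...) is standard Pydantic v2; inference_store defaults
    # to None, in which case the router skips store wiring entirely.
    return run_config.model_copy(
        update={
            "inference_store": SqlalchemySqlStoreConfig(
                engine_str="sqlite:///inference_store.db",
            )
        }
    )
```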

llama_stack/distribution/resolver.py

Lines changed: 8 additions & 4 deletions
@@ -140,7 +140,7 @@ async def resolve_impls(

     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

-    return await instantiate_providers(sorted_providers, router_apis, dist_registry)
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)


 def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:

@@ -243,7 +243,10 @@ def sort_providers_by_deps(


 async def instantiate_providers(
-    sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
+    sorted_providers: list[tuple[str, ProviderWithSpec]],
+    router_apis: set[Api],
+    dist_registry: DistributionRegistry,
+    run_config: StackRunConfig,
 ) -> dict:
     """Instantiates providers asynchronously while managing dependencies."""
     impls: dict[Api, Any] = {}

@@ -258,7 +261,7 @@
         if isinstance(provider.spec, RoutingTableProviderSpec):
             inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

-        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
+        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)

         if api_str.startswith("inner-"):
             inner_impls_by_provider_id[api_str][provider.provider_id] = impl

@@ -308,6 +311,7 @@ async def instantiate_provider(
     deps: dict[Api, Any],
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
+    run_config: StackRunConfig,
 ):
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):

@@ -327,7 +331,7 @@
         method = "get_auto_router_impl"

         config = None
-        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
+        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
     elif isinstance(provider_spec, RoutingTableProviderSpec):
         method = "get_routing_table_impl"

llama_stack/distribution/routers/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -7,8 +7,10 @@
 from typing import Any

 from llama_stack.distribution.datatypes import RoutedProtocol
+from llama_stack.distribution.stack import StackRunConfig
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
+from llama_stack.providers.utils.inference.inference_store import InferenceStore

 from .routing_tables import (
     BenchmarksRoutingTable,

@@ -45,7 +47,9 @@ async def get_routing_table_impl(
     return impl


-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
+async def get_auto_router_impl(
+    api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
+) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,

@@ -76,6 +80,11 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict
         if dep_api in deps:
             api_to_dep_impl[dep_name] = deps[dep_api]

+    if api == Api.inference and run_config.inference_store:
+        inference_store = InferenceStore(run_config.inference_store)
+        await inference_store.initialize()
+        api_to_dep_impl["store"] = inference_store
+
     impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
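
Put differently, for the inference API the factory above now behaves roughly like this sketch (`build_inference_router` is a hypothetical name; `InferenceRouter` and `InferenceStore` are the classes from this commit):

```python
# Equivalent wiring for the inference API, written out flat (hypothetical
# helper, mirroring the conditional block added above).
from llama_stack.distribution.routers.routers import InferenceRouter
from llama_stack.providers.utils.inference.inference_store import InferenceStore


async def build_inference_router(routing_table, run_config):
    store = None
    if run_config.inference_store:
        store = InferenceStore(run_config.inference_store)
        await store.initialize()  # creates the chat_completions table
    # "store" is passed as a keyword, matching api_to_dep_impl["store"]
    router = InferenceRouter(routing_table, store=store)
    await router.initialize()
    return router
```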

llama_stack/distribution/routers/routers.py

Lines changed: 34 additions & 2 deletions
@@ -32,8 +32,11 @@
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
+    ListOpenAIChatCompletionResponse,
     LogProbConfig,
     Message,
+    OpenAICompletionWithInputMessages,
+    Order,
     ResponseFormat,
     SamplingParams,
     StopReason,

@@ -73,6 +76,8 @@
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
+from llama_stack.providers.utils.inference.inference_store import InferenceStore
+from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
 from llama_stack.providers.utils.telemetry.tracing import get_current_span

 logger = get_logger(name=__name__, category="core")

@@ -141,10 +146,12 @@ def __init__(
         self,
         routing_table: RoutingTable,
         telemetry: Telemetry | None = None,
+        store: InferenceStore | None = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
         self.telemetry = telemetry
+        self.store = store
         if self.telemetry:
             self.tokenizer = Tokenizer.get_instance()
             self.formatter = ChatFormat(self.tokenizer)

@@ -607,9 +614,34 @@ async def openai_chat_completion(

         provider = self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
-            return await provider.openai_chat_completion(**params)
+            response_stream = await provider.openai_chat_completion(**params)
+            if self.store:
+                return stream_and_store_openai_completion(response_stream, model, self.store, messages)
+            else:
+                return response_stream
         else:
-            return await self._nonstream_openai_chat_completion(provider, params)
+            response = await self._nonstream_openai_chat_completion(provider, params)
+            if self.store:
+                await self.store.store_chat_completion(response, messages)
+            return response
+
+    async def list_chat_completions(
+        self,
+        after: str | None = None,
+        limit: int | None = 20,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIChatCompletionResponse:
+        if self.store:
+            return await self.store.list_chat_completions(after, limit, model, order)
+        else:
+            raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
+
+    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+        if self.store:
+            return await self.store.get_chat_completion(completion_id)
+        else:
+            raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")

     async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
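
`stream_and_store_openai_completion` lives in the new `stream_utils.py`, which is among the files hidden in this view. A plausible shape for it, assuming it stays transparent to the client while buffering chunks; `assemble_completion` is a hypothetical helper, and the real assembly logic is in the hidden file:

```python
# Sketch only: yield every chunk through unchanged, then persist the assembled
# completion once the provider's stream is exhausted.
from collections.abc import AsyncIterator
from typing import Any


async def stream_and_store_sketch(
    response_stream: AsyncIterator[Any],  # chunks from provider.openai_chat_completion
    model: str,
    store: Any,  # InferenceStore
    input_messages: list[Any],
) -> AsyncIterator[Any]:
    chunks = []
    async for chunk in response_stream:
        chunks.append(chunk)
        yield chunk  # the client sees the stream unmodified
    # assemble_completion is hypothetical: fold the streamed deltas into a
    # final OpenAIChatCompletion so it can be stored like the non-streaming path.
    completion = assemble_completion(chunks, model)
    await store.store_chat_completion(completion, input_messages)
```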
llama_stack/providers/utils/inference/inference_store.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.inference import (
+    ListOpenAIChatCompletionResponse,
+    OpenAIChatCompletion,
+    OpenAICompletionWithInputMessages,
+    OpenAIMessageParam,
+    Order,
+)
+from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
+
+from ..sqlstore.api import ColumnDefinition, ColumnType
+from ..sqlstore.sqlstore import SqlalchemySqlStoreConfig, SqlStoreConfig, sqlstore_impl
+
+
+class InferenceStore:
+    def __init__(self, sql_store_config: SqlStoreConfig):
+        if not sql_store_config:
+            sql_store_config = SqlalchemySqlStoreConfig(
+                engine_str="sqlite:///" + (RUNTIME_BASE_DIR / "sqlstore.db").as_posix()
+            )
+        self.sql_store = sqlstore_impl(sql_store_config)
+
+    async def initialize(self):
+        """Create the necessary tables if they don't exist."""
+        await self.sql_store.create_table(
+            "chat_completions",
+            {
+                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
+                "created": ColumnType.INTEGER,
+                "model": ColumnType.STRING,
+                "choices": ColumnType.JSON,
+                "input_messages": ColumnType.JSON,
+            },
+        )
+
+    async def store_chat_completion(
+        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
+    ) -> None:
+        data = chat_completion.model_dump()
+
+        await self.sql_store.insert(
+            "chat_completions",
+            {
+                "id": data["id"],
+                "created": data["created"],
+                "model": data["model"],
+                "choices": data["choices"],
+                "input_messages": [message.model_dump() for message in input_messages],
+            },
+        )
+
+    async def list_chat_completions(
+        self,
+        after: str | None = None,
+        limit: int | None = 50,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIChatCompletionResponse:
+        """
+        List chat completions from the database.
+
+        :param after: The ID of the last chat completion to return.
+        :param limit: The maximum number of chat completions to return.
+        :param model: The model to filter by.
+        :param order: The order to sort the chat completions by.
+        """
+        # TODO: support after
+        if after:
+            raise NotImplementedError("After is not supported for SQLite")
+        if not order:
+            order = Order.desc
+
+        rows = await self.sql_store.fetch_all(
+            "chat_completions",
+            where={"model": model} if model else None,
+            order_by=[("created", order.value)],
+            limit=limit,
+        )
+
+        data = [
+            OpenAICompletionWithInputMessages(
+                id=row["id"],
+                created=row["created"],
+                model=row["model"],
+                choices=row["choices"],
+                input_messages=row["input_messages"],
+            )
+            for row in rows
+        ]
+        return ListOpenAIChatCompletionResponse(
+            data=data,
+            # TODO: implement has_more
+            has_more=False,
+            first_id=data[0].id if data else "",
+            last_id=data[-1].id if data else "",
+        )
+
+    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+        row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id})
+        if not row:
+            raise ValueError(f"Chat completion with id {completion_id} not found") from None
+        return OpenAICompletionWithInputMessages(
+            id=row["id"],
+            created=row["created"],
+            model=row["model"],
+            choices=row["choices"],
+            input_messages=row["input_messages"],
+        )
