feat: implement get chat completions APIs #2200

Status: Open · wants to merge 1 commit into main
11 changes: 11 additions & 0 deletions llama_stack/distribution/build.py
@@ -43,8 +43,16 @@ def get_provider_dependencies(
# Extract providers based on config type
if isinstance(config, DistributionTemplate):
providers = config.providers
run_configs = config.run_configs
additional_pip_packages: list[str] = []
if run_configs:
for run_config in run_configs.values():
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
if run_config_.inference_store:
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
@@ -72,6 +80,9 @@ def get_provider_dependencies(
else:
normal_deps.append(package)

if additional_pip_packages:
normal_deps.extend(additional_pip_packages)

return list(set(normal_deps)), list(set(special_deps))


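Context for the dependency collection above: a SqlStoreConfig exposes the pip packages its backend needs via a `pip_packages` property, which is what `run_config_.inference_store.pip_packages` reads. A minimal sketch of that contract (this class is illustrative, not the real llama_stack definition):

```python
from pydantic import BaseModel


class ExampleSqlStoreConfig(BaseModel):
    """Hypothetical stand-in for a SqlStoreConfig backend."""

    db_path: str = "inference_store.db"

    @property
    def pip_packages(self) -> list[str]:
        # An async SQLite backend would need a driver such as aiosqlite;
        # a Postgres backend would return e.g. ["asyncpg"] instead.
        return ["aiosqlite"]
```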
12 changes: 12 additions & 0 deletions llama_stack/distribution/datatypes.py
@@ -26,6 +26,7 @@
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig

LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@@ -297,6 +298,13 @@ class StackRunConfig(BaseModel):
a default SQLite store will be used.""",
)

inference_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the inference API. If not specified,
a default SQLite store will be used.""",
)

# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)
@@ -345,6 +353,10 @@ class BuildConfig(BaseModel):
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",
)
additional_pip_packages: list[str] = Field(
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)

@field_validator("external_providers_dir")
@classmethod
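With the two new fields above, a run config can opt into chat-completion persistence. A minimal sketch, assuming only the `SqliteSqlStoreConfig(db_path=...)` constructor that appears in `inference_store.py` below:

```python
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

# When inference_store is left unset on StackRunConfig, the router simply
# skips persisting completions and the new list/get APIs raise NotImplementedError.
store_config = SqliteSqlStoreConfig(db_path="/tmp/demo/inference_store.db")
# run_config = StackRunConfig(..., inference_store=store_config)
```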
12 changes: 8 additions & 4 deletions llama_stack/distribution/resolver.py
@@ -140,7 +140,7 @@ async def resolve_impls(

sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

return await instantiate_providers(sorted_providers, router_apis, dist_registry)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)


def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@@ -243,7 +243,10 @@ def sort_providers_by_deps(


async def instantiate_providers(
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
sorted_providers: list[tuple[str, ProviderWithSpec]],
router_apis: set[Api],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
) -> dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
@@ -258,7 +261,7 @@
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)

if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@@ -308,6 +311,7 @@ async def instantiate_provider(
deps: dict[Api, Any],
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
):
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
@@ -327,7 +331,7 @@
method = "get_auto_router_impl"

config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"

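Net effect of the signature changes in this file: the run config is threaded from the top-level resolver down to the auto-router factories, roughly:

```python
# Sketch of the call chain after this change (signatures abbreviated):
#
# resolve_impls(run_config, ...)
#   -> instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
#        -> instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
#             -> get_auto_router_impl(api, routing_table, deps, run_config)  # auto-routed APIs only
```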
11 changes: 10 additions & 1 deletion llama_stack/distribution/routers/__init__.py
@@ -7,8 +7,10 @@
from typing import Any

from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.stack import StackRunConfig
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore

from .routing_tables import (
BenchmarksRoutingTable,
@@ -45,7 +47,9 @@ async def get_routing_table_impl(
return impl


async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
async def get_auto_router_impl(
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
@@ -76,6 +80,11 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]

if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store

impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl
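For reference, the same wiring can be reproduced by hand (e.g. in a test); a sketch that assumes only what this diff shows, namely `InferenceStore(config)` followed by `initialize()`:

```python
import asyncio

from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def make_store() -> InferenceStore:
    # Mirrors the branch above: build the store, then create its tables.
    store = InferenceStore(SqliteSqlStoreConfig(db_path="/tmp/inference_store.db"))
    await store.initialize()
    return store


store = asyncio.run(make_store())
# InferenceRouter would then receive it as the "store" keyword argument.
```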
36 changes: 34 additions & 2 deletions llama_stack/distribution/routers/routers.py
@@ -32,8 +32,11 @@
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
OpenAICompletionWithInputMessages,
Order,
ResponseFormat,
SamplingParams,
StopReason,
@@ -73,6 +76,8 @@
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
from llama_stack.providers.utils.telemetry.tracing import get_current_span

logger = get_logger(name=__name__, category="core")
@@ -141,10 +146,12 @@ def __init__(
self,
routing_table: RoutingTable,
telemetry: Telemetry | None = None,
store: InferenceStore | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
self.store = store
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
@@ -607,9 +614,34 @@ async def openai_chat_completion(

provider = self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
return await provider.openai_chat_completion(**params)
response_stream = await provider.openai_chat_completion(**params)
if self.store:
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
else:
return response_stream
else:
return await self._nonstream_openai_chat_completion(provider, params)
response = await self._nonstream_openai_chat_completion(provider, params)
if self.store:
await self.store.store_chat_completion(response, messages)
return response

async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
if self.store:
return await self.store.list_chat_completions(after, limit, model, order)
else:
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")

async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
if self.store:
return await self.store.get_chat_completion(completion_id)
else:
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")

async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
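Once a store is configured, the new router methods can be exercised directly. A sketch, where `router` is assumed to be an `InferenceRouter` wired with a store as in `routers/__init__.py` above:

```python
from llama_stack.apis.inference import Order


async def show_history(router) -> None:
    # Raises NotImplementedError if no inference store is configured.
    page = await router.list_chat_completions(limit=10, order=Order.desc)
    for completion in page.data:
        print(completion.id, completion.model, completion.created)

    if page.data:
        # Fetch one completion together with its stored input messages.
        full = await router.get_chat_completion(page.data[0].id)
        print(full.input_messages)
```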
112 changes: 112 additions & 0 deletions llama_stack/providers/utils/inference/inference_store.py
@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse,
OpenAIChatCompletion,
OpenAICompletionWithInputMessages,
OpenAIMessageParam,
Order,
)
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR

from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl


class InferenceStore:
def __init__(self, sql_store_config: SqlStoreConfig):
if not sql_store_config:
sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
)
self.sql_store = sqlstore_impl(sql_store_config)

async def initialize(self):
"""Create the necessary tables if they don't exist."""
await self.sql_store.create_table(
"chat_completions",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created": ColumnType.INTEGER,
"model": ColumnType.STRING,
"choices": ColumnType.JSON,
"input_messages": ColumnType.JSON,
},
)

async def store_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
data = chat_completion.model_dump()

await self.sql_store.insert(
"chat_completions",
{
"id": data["id"],
"created": data["created"],
"model": data["model"],
"choices": data["choices"],
"input_messages": [message.model_dump() for message in input_messages],
},
)

async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""
List chat completions from the database.

:param after: Return results after this chat completion ID (pagination cursor).
:param limit: The maximum number of chat completions to return.
:param model: The model to filter by.
:param order: The order to sort the chat completions by.
"""
# TODO: support 'after' pagination
if after:
raise NotImplementedError("Pagination with 'after' is not yet supported")
if not order:
order = Order.desc

rows = await self.sql_store.fetch_all(
"chat_completions",
where={"model": model} if model else None,
order_by=[("created", order.value)],
limit=limit,
)

data = [
OpenAICompletionWithInputMessages(
id=row["id"],
created=row["created"],
model=row["model"],
choices=row["choices"],
input_messages=row["input_messages"],
)
for row in rows
]
return ListOpenAIChatCompletionResponse(
data=data,
# TODO: implement has_more
has_more=False,
first_id=data[0].id if data else "",
last_id=data[-1].id if data else "",
)

async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id})
if not row:
raise ValueError(f"Chat completion with id {completion_id} not found") from None
return OpenAICompletionWithInputMessages(
id=row["id"],
created=row["created"],
model=row["model"],
choices=row["choices"],
input_messages=row["input_messages"],
)
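Taken together, the store can also be used on its own; a minimal sketch against a throwaway SQLite file:

```python
import asyncio

from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def main() -> None:
    store = InferenceStore(SqliteSqlStoreConfig(db_path="/tmp/chat_completions.db"))
    await store.initialize()
    # store_chat_completion() is normally called by InferenceRouter after each
    # completion; here we only read back whatever has been persisted.
    page = await store.list_chat_completions(limit=5)
    print(page.has_more, page.first_id, page.last_id)


asyncio.run(main())
```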