
Commit 96194a1

apis, alt
# What does this PR do?

Introduces a narrower `InferenceProvider` protocol for inference providers to implement, and makes the user-facing `Inference` API extend it with two new read-only endpoints, `list_chat_completions` and `get_chat_completion`. Provider compliance checks now validate implementations against `InferenceProvider`, and the OpenAPI generator files its endpoints under the existing `Inference` tag.

## Test Plan
1 parent ff247e3 commit 96194a1

14 files changed: +1241 −895 lines

docs/_static/llama-stack-spec.html

Lines changed: 679 additions & 512 deletions
Large diffs are not rendered by default.

docs/_static/llama-stack-spec.yaml

Lines changed: 473 additions & 354 deletions
Large diffs are not rendered by default.

docs/openapi_generator/pyopenapi/generator.py

Lines changed: 3 additions & 1 deletion
@@ -759,7 +759,7 @@ def _build_operation(self, op: EndpointOperation) -> Operation:
         )

         return Operation(
-            tags=[op.defining_class.__name__],
+            tags=[op.defining_class.__name__ if op.defining_class.__name__ != "InferenceProvider" else "Inference"],
             summary=None,
             # summary=doc_string.short_description,
             description=description,
@@ -805,6 +805,8 @@ def generate(self) -> Document:
         operation_tags: List[Tag] = []
         for cls in endpoint_classes:
             doc_string = parse_type(cls)
+            if cls.__name__ == "InferenceProvider":
+                continue
             operation_tags.append(
                 Tag(
                     name=cls.__name__,
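
Net effect of these two hunks: operations defined on the new `InferenceProvider` protocol are still published under the familiar `Inference` tag, and no separate `InferenceProvider` tag appears in the generated spec. A minimal sketch of the remapping rule with the `Operation`/`Tag` plumbing stripped out (the standalone `tag_for` helper is ours, for illustration only):

class InferenceProvider: ...
class Safety: ...


def tag_for(defining_class: type) -> str:
    # Mirrors the conditional added in _build_operation: endpoints whose
    # defining class is InferenceProvider are filed under "Inference".
    name = defining_class.__name__
    return "Inference" if name == "InferenceProvider" else name


assert tag_for(InferenceProvider) == "Inference"
assert tag_for(Safety) == "Safety"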

llama_stack/apis/inference/inference.py

Lines changed: 58 additions & 6 deletions
@@ -820,14 +820,30 @@ class BatchChatCompletionResponse(BaseModel):
     batch: list[ChatCompletionResponse]


+@json_schema_type
+class ChatCompletion(BaseModel):
+    id: str
+    created: int
+    model: str
+    messages: list[OpenAIMessageParam]
+
+
+@json_schema_type
+class ListChatCompletionsResponse(BaseModel):
+    data: list[ChatCompletion]
+    has_more: bool
+
+
+class Order(Enum):
+    asc = "asc"
+    desc = "desc"
+
+
 @runtime_checkable
 @trace_protocol
-class Inference(Protocol):
-    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
-
-    This API provides the raw interface to the underlying models. Two kinds of models are supported:
-    - LLM models: these models generate "raw" and "chat" (conversational) completions.
-    - Embedding models: these models generate embeddings to be used for semantic search.
+class InferenceProvider(Protocol):
+    """
+    This protocol defines the interface that should be implemented by all inference providers.
     """

     model_store: ModelStore | None = None
@@ -1040,3 +1056,39 @@ async def openai_chat_completion(
         :param user: (Optional) The user to use
         """
         ...
+
+
+class Inference(InferenceProvider):
+    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
+
+    This API provides the raw interface to the underlying models. Two kinds of models are supported:
+    - LLM models: these models generate "raw" and "chat" (conversational) completions.
+    - Embedding models: these models generate embeddings to be used for semantic search.
+    """
+
+    @webmethod(route="/inference/chat-completion", method="GET")
+    async def list_chat_completions(
+        self,
+        after: str | None = None,
+        limit: int | None = 20,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListChatCompletionsResponse:
+        """List all chat completions.
+
+        :param after: The ID of the last chat completion to return.
+        :param limit: The maximum number of chat completions to return.
+        :param model: The model to filter by.
+        :param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
+        :returns: A ListChatCompletionsResponse.
+        """
+        raise NotImplementedError("List chat completions is not implemented")
+
+    @webmethod(route="/inference/chat-completion/{completion_id}", method="GET")
+    async def get_chat_completion(self, completion_id: str) -> ChatCompletion:
+        """Describe a chat completion by its ID.
+
+        :param completion_id: ID of the chat completion.
+        :returns: A ChatCompletion.
+        """
+        raise NotImplementedError("Get chat completion is not implemented")

llama_stack/distribution/resolver.py

Lines changed: 10 additions & 4 deletions
@@ -13,7 +13,7 @@
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.files import Files
-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import Inference, InferenceProvider
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
@@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
     }


+def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
+    return {
+        **api_protocol_map(),
+        Api.inference: InferenceProvider,
+    }
+
+
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
@@ -302,9 +309,6 @@ async def instantiate_provider(
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
 ):
-    protocols = api_protocol_map()
-    additional_protocols = additional_protocols_map()
-
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
@@ -342,6 +346,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

+    protocols = api_protocol_map_for_compliance_check()
+    additional_protocols = additional_protocols_map()
     # TODO: check compliance for special tool groups
     # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
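
The point of the new map: routing and documentation keep seeing the full user-facing `Inference` API, but `check_protocol_compliance` now validates inference providers against the narrower `InferenceProvider`, so they are not forced to implement the listing endpoints. A small sketch of the intended behavior (the import path for `Api` is an assumption):

from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.distribution.resolver import (
    api_protocol_map,
    api_protocol_map_for_compliance_check,
)
from llama_stack.providers.datatypes import Api  # import path assumed

# Routing still targets the full user-facing API...
assert api_protocol_map()[Api.inference] is Inference
# ...while compliance checks use the narrower provider protocol.
assert api_protocol_map_for_compliance_check()[Api.inference] is InferenceProvider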

llama_stack/providers/inline/inference/meta_reference/inference.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
-    Inference,
+    InferenceProvider,
     InterleavedContent,
     LogProbConfig,
     Message,
@@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
     OpenAICompletionToLlamaStackMixin,
     OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
-    Inference,
+    InferenceProvider,
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:

llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@

 from llama_stack.apis.inference import (
     CompletionResponse,
-    Inference,
+    InferenceProvider,
     InterleavedContent,
     LogProbConfig,
     Message,
@@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
-    Inference,
+    InferenceProvider,
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: SentenceTransformersInferenceConfig) -> None:

llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import CerebrasCompatConfig


-async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .cerebras import CerebrasCompatInferenceAdapter

llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import FireworksCompatConfig


-async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .fireworks import FireworksCompatInferenceAdapter

llama_stack/providers/remote/inference/groq_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import GroqCompatConfig


-async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .groq import GroqCompatInferenceAdapter

llama_stack/providers/remote/inference/ollama/ollama.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@
     EmbeddingsResponse,
     EmbeddingTaskType,
     GrammarResponseFormat,
-    Inference,
+    InferenceProvider,
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
@@ -82,7 +82,7 @@


 class OllamaInferenceAdapter(
-    Inference,
+    InferenceProvider,
     ModelsProtocolPrivate,
 ):
     def __init__(self, url: str) -> None:

llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import SambaNovaCompatConfig


-async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .sambanova import SambaNovaCompatInferenceAdapter

llama_stack/providers/remote/inference/together_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import TogetherCompatConfig


-async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .together import TogetherCompatInferenceAdapter

llama_stack/providers/utils/inference/litellm_openai_mixin.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
     ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
-    Inference,
+    InferenceProvider,
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
@@ -59,7 +59,7 @@

 class LiteLLMOpenAIMixin(
     ModelRegistryHelper,
-    Inference,
+    InferenceProvider,
     NeedsRequestProviderData,
 ):
     # TODO: avoid exposing the litellm specific model names to the user.
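
All of the provider-side edits above follow the same mechanical pattern: swap `Inference` for `InferenceProvider` in imports, base-class lists, and `get_adapter_impl` return annotations. A hedged sketch of what a hypothetical third-party adapter would now look like (the adapter and config names are invented for illustration):

from dataclasses import dataclass

from llama_stack.apis.inference import InferenceProvider


@dataclass
class ExampleCompatConfig:  # hypothetical config, for illustration
    url: str
    api_key: str | None = None


class ExampleCompatInferenceAdapter(InferenceProvider):
    # Implement the provider-facing surface (chat_completion, completion,
    # embeddings, ...); the stack-level list/get chat-completion endpoints
    # live on Inference and are not part of this contract.
    def __init__(self, config: ExampleCompatConfig) -> None:
        self.config = config


async def get_adapter_impl(config: ExampleCompatConfig, _deps) -> InferenceProvider:
    # Mirrors the factories above: annotate with the provider protocol,
    # not the full user-facing Inference API.
    return ExampleCompatInferenceAdapter(config)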
