Commit 6093af5

apis, alt
# What does this PR do?

## Test Plan
1 parent ff247e3 commit 6093af5

File tree

15 files changed: +1509 -1022 lines changed


docs/_static/llama-stack-spec.html

Lines changed: 753 additions & 511 deletions
Large diffs are not rendered by default.

docs/_static/llama-stack-spec.yaml

Lines changed: 664 additions & 480 deletions
Large diffs are not rendered by default.

docs/openapi_generator/pyopenapi/generator.py

Lines changed: 3 additions & 1 deletion
@@ -759,7 +759,7 @@ def _build_operation(self, op: EndpointOperation) -> Operation:
         )

         return Operation(
-            tags=[op.defining_class.__name__],
+            tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
             summary=None,
             # summary=doc_string.short_description,
             description=description,
@@ -805,6 +805,8 @@ def generate(self) -> Document:
         operation_tags: List[Tag] = []
         for cls in endpoint_classes:
             doc_string = parse_type(cls)
+            if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
+                continue
             operation_tags.append(
                 Tag(
                     name=cls.__name__,
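
To see what the `getattr` fallback and the namespace filter accomplish together, here is a minimal self-contained sketch (stand-in classes for illustration, not the generator's real ones): operations defined on a class that sets API_NAMESPACE get tagged under that namespace, and the class itself is skipped when emitting tags, so the spec ends up with a single "Inference" tag.

# Sketch: how API_NAMESPACE folds a provider protocol's operations into
# its parent API's tag (hypothetical stand-in classes).
class Inference:
    pass

class InferenceProvider:
    API_NAMESPACE: str = "Inference"  # operations here are tagged "Inference"

def tag_for(defining_class: type) -> str:
    # Mirrors the generator change: prefer API_NAMESPACE, fall back to the class name.
    return getattr(defining_class, "API_NAMESPACE", defining_class.__name__)

assert tag_for(InferenceProvider) == "Inference"
assert tag_for(Inference) == "Inference"

# The tag list skips any class whose namespace points at another class,
# so "InferenceProvider" never shows up as its own tag in the spec.
endpoint_classes = [Inference, InferenceProvider]
tags = [
    cls.__name__
    for cls in endpoint_classes
    if not (hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__)
]
assert tags == ["Inference"]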

llama_stack/apis/inference/inference.py

Lines changed: 59 additions & 6 deletions
@@ -820,16 +820,33 @@ class BatchChatCompletionResponse(BaseModel):
     batch: list[ChatCompletionResponse]


+class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
+    input_messages: list[OpenAIMessageParam]
+
+
+@json_schema_type
+class ListOpenAIChatCompletionResponse(BaseModel):
+    data: list[OpenAICompletionWithInputMessages]
+    has_more: bool
+    first_id: str
+    last_id: str
+    object: Literal["list"] = "list"
+
+
+class Order(Enum):
+    asc = "asc"
+    desc = "desc"
+
+
 @runtime_checkable
 @trace_protocol
-class Inference(Protocol):
-    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
-
-    This API provides the raw interface to the underlying models. Two kinds of models are supported:
-    - LLM models: these models generate "raw" and "chat" (conversational) completions.
-    - Embedding models: these models generate embeddings to be used for semantic search.
+class InferenceProvider(Protocol):
+    """
+    This protocol defines the interface that should be implemented by all inference providers.
     """

+    API_NAMESPACE: str = "Inference"
+
     model_store: ModelStore | None = None

     @webmethod(route="/inference/completion", method="POST")
@@ -1040,3 +1057,39 @@ async def openai_chat_completion(
         :param user: (Optional) The user to use
         """
         ...
+
+
+class Inference(InferenceProvider):
+    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
+
+    This API provides the raw interface to the underlying models. Two kinds of models are supported:
+    - LLM models: these models generate "raw" and "chat" (conversational) completions.
+    - Embedding models: these models generate embeddings to be used for semantic search.
+    """
+
+    @webmethod(route="/openai/v1/chat/completions", method="GET")
+    async def list_chat_completions(
+        self,
+        after: str | None = None,
+        limit: int | None = 20,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIChatCompletionResponse:
+        """List all chat completions.
+
+        :param after: The ID of the last chat completion to return.
+        :param limit: The maximum number of chat completions to return.
+        :param model: The model to filter by.
+        :param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
+        :returns: A ListOpenAIChatCompletionResponse.
+        """
+        raise NotImplementedError("List chat completions is not implemented")
+
+    @webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
+    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+        """Describe a chat completion by its ID.

+        :param completion_id: ID of the chat completion.
+        :returns: A OpenAICompletionWithInputMessages.
+        """
+        raise NotImplementedError("Get chat completion is not implemented")

llama_stack/distribution/resolver.py

Lines changed: 10 additions & 4 deletions
@@ -13,7 +13,7 @@
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.files import Files
-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import Inference, InferenceProvider
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
@@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
     }


+def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
+    return {
+        **api_protocol_map(),
+        Api.inference: InferenceProvider,
+    }
+
+
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
@@ -302,9 +309,6 @@ async def instantiate_provider(
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
 ):
-    protocols = api_protocol_map()
-    additional_protocols = additional_protocols_map()
-
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
@@ -342,6 +346,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

+    protocols = api_protocol_map_for_compliance_check()
+    additional_protocols = additional_protocols_map()
     # TODO: check compliance for special tool groups
     # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
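
The net effect of swapping the map used for compliance checks: a provider implementation only has to satisfy InferenceProvider, and the listing/retrieval endpoints that live on Inference are not demanded of it. A trimmed-down illustration of the idea (simplified stand-in protocols, not the repo's actual check_protocol_compliance):

from typing import Protocol, runtime_checkable

@runtime_checkable
class InferenceProvider(Protocol):
    async def chat_completion(self, messages: list) -> str: ...

@runtime_checkable
class Inference(InferenceProvider, Protocol):
    # User-facing extras that a provider need not implement.
    async def list_chat_completions(self) -> list: ...

class MyAdapter:  # a hypothetical provider implementation
    async def chat_completion(self, messages: list) -> str:
        return "ok"

adapter = MyAdapter()
assert isinstance(adapter, InferenceProvider)  # provider-level check passes
assert not isinstance(adapter, Inference)      # full-API check would fail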

llama_stack/providers/inline/inference/meta_reference/inference.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
-    Inference,
+    InferenceProvider,
     InterleavedContent,
     LogProbConfig,
     Message,
@@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
     OpenAICompletionToLlamaStackMixin,
     OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
-    Inference,
+    InferenceProvider,
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:

llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@

 from llama_stack.apis.inference import (
     CompletionResponse,
-    Inference,
+    InferenceProvider,
     InterleavedContent,
     LogProbConfig,
     Message,
@@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
-    Inference,
+    InferenceProvider,
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: SentenceTransformersInferenceConfig) -> None:

llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import CerebrasCompatConfig


-async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .cerebras import CerebrasCompatInferenceAdapter

llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import FireworksCompatConfig


-async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .fireworks import FireworksCompatInferenceAdapter

llama_stack/providers/remote/inference/groq_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import GroqCompatConfig


-async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .groq import GroqCompatInferenceAdapter

llama_stack/providers/remote/inference/llama_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import LlamaCompatConfig


-async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .llama import LlamaCompatInferenceAdapter

llama_stack/providers/remote/inference/ollama/ollama.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@
     EmbeddingsResponse,
     EmbeddingTaskType,
     GrammarResponseFormat,
-    Inference,
+    InferenceProvider,
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
@@ -82,7 +82,7 @@


 class OllamaInferenceAdapter(
-    Inference,
+    InferenceProvider,
     ModelsProtocolPrivate,
 ):
     def __init__(self, url: str) -> None:

llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import SambaNovaCompatConfig


-async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .sambanova import SambaNovaCompatInferenceAdapter

llama_stack/providers/remote/inference/together_openai_compat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import InferenceProvider

 from .config import TogetherCompatConfig


-async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference:
+async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
     # import dynamically so the import is used only when it is needed
     from .together import TogetherCompatInferenceAdapter

llama_stack/providers/utils/inference/litellm_openai_mixin.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
     ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
-    Inference,
+    InferenceProvider,
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
@@ -59,7 +59,7 @@

 class LiteLLMOpenAIMixin(
     ModelRegistryHelper,
-    Inference,
+    InferenceProvider,
     NeedsRequestProviderData,
 ):
     # TODO: avoid exposing the litellm specific model names to the user.
