Skip to content

Commit

Permalink
Follow Model Pipeline pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
TamiTakamiya committed Feb 14, 2025
1 parent cc39ec0 commit f10ebc5
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 74 deletions.
114 changes: 84 additions & 30 deletions ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import logging
from typing import Optional

import aiohttp
import requests
from django.http import StreamingHttpResponse
from health_check.exceptions import ServiceUnavailable

from ansible_ai_connect.ai.api.exceptions import (
Expand All @@ -39,6 +41,8 @@
MetaData,
ModelPipelineChatBot,
ModelPipelineCompletions,
ModelPipelineStreamingChatBot,
StreamingChatBotParameters,
)
from ansible_ai_connect.ai.api.model_pipelines.registry import Register
from ansible_ai_connect.healthcheck.backends import (
Expand Down Expand Up @@ -120,13 +124,12 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
raise NotImplementedError


@Register(api_type="http")
class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]):
class HttpChatBotMetaData(HttpMetaData):

def __init__(self, config: HttpConfiguration):
    """Store the HTTP pipeline configuration via the HttpMetaData base class."""
    super().__init__(config=config)

def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
def prepare_data(self, params: ChatBotParameters):
query = params.query
conversation_id = params.conversation_id
provider = params.provider
Expand All @@ -142,11 +145,49 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
data["conversation_id"] = str(conversation_id)
if system_prompt:
data["system_prompt"] = str(system_prompt)
return data

def self_test(self) -> Optional[HealthCheckSummary]:
    """Probe the service's /readiness endpoint and report its health.

    Returns a HealthCheckSummary that starts out "ok" and accumulates an
    exception entry when the endpoint is unreachable, returns an HTTP
    error, or reports ``ready`` as false.
    """
    result = HealthCheckSummary(
        {
            MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
            MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
        }
    )
    try:
        response = requests.get(
            self.config.inference_url + "/readiness",
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        payload = response.json()
        if not payload.get("ready"):
            # Service responded but declared itself not ready; surface its reason.
            result.add_exception(
                MODEL_MESH_HEALTH_CHECK_MODELS,
                HealthCheckSummaryException(ServiceUnavailable(payload.get("reason"))),
            )
    except Exception as exc:
        logger.exception(str(exc))
        result.add_exception(
            MODEL_MESH_HEALTH_CHECK_MODELS,
            HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), exc),
        )
    return result


@Register(api_type="http")
class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
response = requests.post(
self.config.inference_url + "/v1/query",
headers=self.headers,
json=data,
json=self.prepare_data(params),
timeout=self.timeout(1),
verify=self.config.verify_ssl,
)
Expand All @@ -171,31 +212,44 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
detail = json.loads(response.text).get("detail", "")
raise ChatbotInternalServerException(detail=detail)

def self_test(self) -> Optional[HealthCheckSummary]:
summary: HealthCheckSummary = HealthCheckSummary(
{
MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
}
)
try:
headers = {"Content-Type": "application/json"}
r = requests.get(self.config.inference_url + "/readiness", headers=headers)
r.raise_for_status()

data = r.json()
ready = data.get("ready")
if not ready:
reason = data.get("reason")
summary.add_exception(
MODEL_MESH_HEALTH_CHECK_MODELS,
HealthCheckSummaryException(ServiceUnavailable(reason)),
)
class HttpStreamingChatBotMetaData(HttpChatBotMetaData):
    """Metadata/payload preparation for the streaming chatbot HTTP pipeline."""

    def __init__(self, config: HttpConfiguration):
        super().__init__(config=config)

    def prepare_data(self, params: StreamingChatBotParameters):
        """Build the request payload: the base chatbot payload plus the
        optional ``media_type`` field when the caller supplied one."""
        payload = super().prepare_data(params)
        if params.media_type:
            payload["media_type"] = str(params.media_type)
        return payload


@Register(api_type="http")
class HttpStreamingChatBotPipeline(
HttpStreamingChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

def invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
raise NotImplementedError

async def async_invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
async with aiohttp.ClientSession(raise_for_status=True) as session:
headers = {
"Content-Type": "application/json",
"Accept": "application/json,text/event-stream",
}
async with session.post(
self.config.inference_url + "/v1/streaming_query",
json=self.prepare_data(params),
headers=headers,
) as r:
async for chunk in r.content:
logger.debug(chunk)
yield chunk
44 changes: 44 additions & 0 deletions ansible_ai_connect/ai/api/model_pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,33 @@ def init(
ChatBotResponse = Any


@define
class StreamingChatBotParameters(ChatBotParameters):
    """Chatbot parameters extended with an optional output media type."""

    # Requested media type for the streamed output; init() defaults it to
    # None, so the annotation is Optional[str] (the bare ``str`` annotation
    # contradicted that default).
    media_type: Optional[str]

    @classmethod
    def init(
        cls,
        query: str,
        provider: Optional[str] = None,
        model_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        system_prompt: Optional[str] = None,
        media_type: Optional[str] = None,
    ):
        """Convenience constructor; every field except ``query`` defaults to None."""
        return cls(
            query=query,
            provider=provider,
            model_id=model_id,
            conversation_id=conversation_id,
            system_prompt=system_prompt,
            media_type=media_type,
        )


StreamingChatBotResponse = Any


class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta):

def __init__(self, config: PIPELINE_CONFIGURATION):
Expand Down Expand Up @@ -274,6 +301,9 @@ def alias() -> str:
def invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
raise NotImplementedError

async def async_invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
    """Asynchronous counterpart of invoke(); not abstract so that only
    pipelines supporting streaming need to override it."""
    raise NotImplementedError

@abstractmethod
def self_test(self) -> Optional[HealthCheckSummary]:
raise NotImplementedError
Expand Down Expand Up @@ -381,3 +411,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION):
@staticmethod
def alias():
return "chatbot-service"


class ModelPipelineStreamingChatBot(
    ModelPipeline[
        PIPELINE_CONFIGURATION, StreamingChatBotParameters, StreamingChatBotResponse
    ],
    Generic[PIPELINE_CONFIGURATION],
    metaclass=ABCMeta,
):
    """Abstract base for streaming chatbot pipelines.

    Parameterized with StreamingChatBotParameters (not plain
    ChatBotParameters) so implementations are typed against the
    parameters class that carries ``media_type``, matching what the
    concrete streaming pipelines actually receive.
    """

    def __init__(self, config: PIPELINE_CONFIGURATION):
        super().__init__(config=config)

    @staticmethod
    def alias():
        # Registry key used to look up this pipeline type.
        return "streaming-chatbot-service"
2 changes: 2 additions & 0 deletions ansible_ai_connect/ai/api/model_pipelines/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ModelPipelinePlaybookGeneration,
ModelPipelineRoleExplanation,
ModelPipelineRoleGeneration,
ModelPipelineStreamingChatBot,
)
from ansible_ai_connect.main.settings.types import t_model_mesh_api_type

Expand All @@ -45,6 +46,7 @@
ModelPipelinePlaybookExplanation,
ModelPipelineRoleExplanation,
ModelPipelineChatBot,
ModelPipelineStreamingChatBot,
PipelineConfiguration,
Serializer,
]
Expand Down
8 changes: 8 additions & 0 deletions ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer):
)


class StreamingChatRequestSerializer(ChatRequestSerializer):
    """Validates a streaming chat request: everything ChatRequestSerializer
    accepts, plus an optional output media type."""

    # Optional; forwarded to the model pipeline as the "media_type" field.
    media_type = serializers.CharField(
        required=False,
        label="Media type",
        help_text=("A media type to be used in the output from LLM."),
    )


class ReferencedDocumentsSerializer(serializers.Serializer):
docs_url = serializers.CharField()
title = serializers.CharField()
Expand Down
40 changes: 0 additions & 40 deletions ansible_ai_connect/ai/api/streaming_chat.py

This file was deleted.

2 changes: 1 addition & 1 deletion ansible_ai_connect/ai/api/versions/v1/ai/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ansible_ai_connect.ai.api.streaming_chat import StreamingChat
from ansible_ai_connect.ai.api.views import (
Chat,
Completions,
Expand All @@ -22,6 +21,7 @@
Feedback,
GenerationPlaybook,
GenerationRole,
StreamingChat,
)

__all__ = [
Expand Down
Loading

0 comments on commit f10ebc5

Please sign in to comment.