Chat Streaming endpoint
TamiTakamiya committed Feb 14, 2025
1 parent c1de7c2 commit dae29f3
Showing 16 changed files with 373 additions and 39 deletions.
114 changes: 84 additions & 30 deletions ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py
@@ -16,7 +16,9 @@
import logging
from typing import Optional

import aiohttp
import requests
from django.http import StreamingHttpResponse
from health_check.exceptions import ServiceUnavailable

from ansible_ai_connect.ai.api.exceptions import (
@@ -39,6 +41,8 @@
MetaData,
ModelPipelineChatBot,
ModelPipelineCompletions,
ModelPipelineStreamingChatBot,
StreamingChatBotParameters,
)
from ansible_ai_connect.ai.api.model_pipelines.registry import Register
from ansible_ai_connect.healthcheck.backends import (
@@ -120,13 +124,12 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
raise NotImplementedError


@Register(api_type="http")
class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]):
class HttpChatBotMetaData(HttpMetaData):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

    def prepare_data(self, params: ChatBotParameters):
query = params.query
conversation_id = params.conversation_id
provider = params.provider
@@ -142,11 +145,49 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
data["conversation_id"] = str(conversation_id)
if system_prompt:
data["system_prompt"] = str(system_prompt)
return data

def self_test(self) -> Optional[HealthCheckSummary]:
summary: HealthCheckSummary = HealthCheckSummary(
{
MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
}
)
try:
headers = {"Content-Type": "application/json"}
r = requests.get(self.config.inference_url + "/readiness", headers=headers)
r.raise_for_status()

data = r.json()
ready = data.get("ready")
if not ready:
reason = data.get("reason")
summary.add_exception(
MODEL_MESH_HEALTH_CHECK_MODELS,
HealthCheckSummaryException(ServiceUnavailable(reason)),
)

except Exception as e:
logger.exception(str(e))
summary.add_exception(
MODEL_MESH_HEALTH_CHECK_MODELS,
HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
)
return summary


@Register(api_type="http")
class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
response = requests.post(
self.config.inference_url + "/v1/query",
headers=self.headers,
            json=self.prepare_data(params),
timeout=self.timeout(1),
verify=self.config.verify_ssl,
)
@@ -171,31 +212,44 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
detail = json.loads(response.text).get("detail", "")
raise ChatbotInternalServerException(detail=detail)

class HttpStreamingChatBotMetaData(HttpChatBotMetaData):

    def __init__(self, config: HttpConfiguration):
        super().__init__(config=config)

    def prepare_data(self, params: StreamingChatBotParameters):
        data = super().prepare_data(params)

        media_type = params.media_type
        if media_type:
            data["media_type"] = str(media_type)

        return data


@Register(api_type="http")
class HttpStreamingChatBotPipeline(
HttpStreamingChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

def invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
raise NotImplementedError

async def async_invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
async with aiohttp.ClientSession(raise_for_status=True) as session:
headers = {
"Content-Type": "application/json",
"Accept": "application/json,text/event-stream",
}
async with session.post(
self.config.inference_url + "/v1/streaming_query",
json=self.prepare_data(params),
headers=headers,
) as r:
async for chunk in r.content:
logger.debug(chunk)
yield chunk
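
For orientation (not part of the commit): a minimal sketch of how the new async generator might be drained outside of Django. The pipeline instance, query, and media type below are assumptions for illustration only.

from ansible_ai_connect.ai.api.model_pipelines.http.pipelines import (
    HttpStreamingChatBotPipeline,
)
from ansible_ai_connect.ai.api.model_pipelines.pipelines import (
    StreamingChatBotParameters,
)


async def drain(pipeline: HttpStreamingChatBotPipeline) -> None:
    # async_invoke() is an async generator; each chunk is one raw bytes line
    # of the backend's /v1/streaming_query SSE response.
    params = StreamingChatBotParameters.init(
        query="How do I install nginx?",  # illustrative query
        media_type="text/event-stream",   # assumed value, passed through as-is
    )
    async for chunk in pipeline.async_invoke(params):
        print(chunk.decode("utf-8"), end="")

# Driving it requires a configured pipeline, e.g.:
#     asyncio.run(drain(HttpStreamingChatBotPipeline(config)))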
44 changes: 44 additions & 0 deletions ansible_ai_connect/ai/api/model_pipelines/pipelines.py
@@ -242,6 +242,33 @@ def init(
ChatBotResponse = Any


@define
class StreamingChatBotParameters(ChatBotParameters):
media_type: str

@classmethod
def init(
cls,
query: str,
provider: Optional[str] = None,
model_id: Optional[str] = None,
conversation_id: Optional[str] = None,
system_prompt: Optional[str] = None,
media_type: Optional[str] = None,
):
return cls(
query=query,
provider=provider,
model_id=model_id,
conversation_id=conversation_id,
system_prompt=system_prompt,
media_type=media_type,
)


StreamingChatBotResponse = Any


class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta):

def __init__(self, config: PIPELINE_CONFIGURATION):
@@ -274,6 +301,9 @@ def alias() -> str:
def invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
raise NotImplementedError

async def async_invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
raise NotImplementedError

@abstractmethod
def self_test(self) -> Optional[HealthCheckSummary]:
raise NotImplementedError
@@ -381,3 +411,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION):
@staticmethod
def alias():
return "chatbot-service"


class ModelPipelineStreamingChatBot(
ModelPipeline[PIPELINE_CONFIGURATION, ChatBotParameters, StreamingChatBotResponse],
Generic[PIPELINE_CONFIGURATION],
metaclass=ABCMeta,
):

def __init__(self, config: PIPELINE_CONFIGURATION):
super().__init__(config=config)

@staticmethod
def alias():
return "streaming-chatbot-service"
2 changes: 2 additions & 0 deletions ansible_ai_connect/ai/api/model_pipelines/registry.py
@@ -30,6 +30,7 @@
ModelPipelinePlaybookGeneration,
ModelPipelineRoleExplanation,
ModelPipelineRoleGeneration,
ModelPipelineStreamingChatBot,
)
from ansible_ai_connect.main.settings.types import t_model_mesh_api_type

Expand All @@ -45,6 +46,7 @@
ModelPipelinePlaybookExplanation,
ModelPipelineRoleExplanation,
ModelPipelineChatBot,
ModelPipelineStreamingChatBot,
PipelineConfiguration,
Serializer,
]
8 changes: 8 additions & 0 deletions ansible_ai_connect/ai/api/serializers.py
@@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer):
)


class StreamingChatRequestSerializer(ChatRequestSerializer):
media_type = serializers.CharField(
required=False,
label="Media type",
        help_text=("A media type to be used in the output from the LLM."),
)


class ReferencedDocumentsSerializer(serializers.Serializer):
docs_url = serializers.CharField()
title = serializers.CharField()
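A quick sketch of the new serializer in use; the payload values are made up. media_type is the only new field, while query and the rest are inherited from ChatRequestSerializer.

from ansible_ai_connect.ai.api.serializers import StreamingChatRequestSerializer

serializer = StreamingChatRequestSerializer(
    data={
        "query": "How do I restart sshd?",  # inherited from ChatRequestSerializer
        "media_type": "text/event-stream",  # new, optional
    }
)
serializer.is_valid(raise_exception=True)
print(serializer.validated_data.get("media_type"))
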
1 change: 1 addition & 0 deletions ansible_ai_connect/ai/api/versions/v1/ai/urls.py
@@ -25,4 +25,5 @@
path("generations/role/", views.GenerationRole.as_view(), name="generations/role"),
path("feedback/", views.Feedback.as_view(), name="feedback"),
path("chat/", views.Chat.as_view(), name="chat"),
path("streaming_chat/", views.StreamingChat.as_view(), name="streaming_chat"),
]
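
Only the trailing streaming_chat/ segment is defined here; assuming the v1 AI routes are mounted under /api/v1/ai/, a client might consume the endpoint like this (host, token, and payload are placeholders):

import requests

with requests.post(
    "https://service.example.com/api/v1/ai/streaming_chat/",  # assumed prefix
    json={"query": "Ping all hosts", "media_type": "text/event-stream"},
    headers={
        "Authorization": "Bearer <token>",  # placeholder credentials
        "Accept": "text/event-stream",
    },
    stream=True,
    timeout=60,
) as response:
    response.raise_for_status()
    for line in response.iter_lines():
        if line:
            print(line.decode("utf-8"))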
2 changes: 2 additions & 0 deletions ansible_ai_connect/ai/api/versions/v1/ai/views.py
@@ -21,6 +21,7 @@
Feedback,
GenerationPlaybook,
GenerationRole,
StreamingChat,
)

__all__ = [
@@ -32,4 +33,5 @@
"ExplanationRole",
"Feedback",
"Chat",
"StreamingChat",
]
83 changes: 83 additions & 0 deletions ansible_ai_connect/ai/api/views.py
@@ -21,6 +21,7 @@
from attr import asdict
from django.apps import apps
from django.conf import settings
from django.http import StreamingHttpResponse
from django_prometheus.conf import NAMESPACE
from drf_spectacular.utils import OpenApiResponse, extend_schema
from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope
@@ -77,10 +78,12 @@
ModelPipelinePlaybookGeneration,
ModelPipelineRoleExplanation,
ModelPipelineRoleGeneration,
ModelPipelineStreamingChatBot,
PlaybookExplanationParameters,
PlaybookGenerationParameters,
RoleExplanationParameters,
RoleGenerationParameters,
StreamingChatBotParameters,
)
from ansible_ai_connect.ai.api.pipelines.completions import CompletionsPipeline
from ansible_ai_connect.ai.api.telemetry import schema1
@@ -134,6 +137,7 @@
PlaybookGenerationAction,
RoleGenerationAction,
SentimentFeedback,
StreamingChatRequestSerializer,
SuggestionQualityFeedback,
)
from .telemetry.schema1 import (
@@ -1126,3 +1130,82 @@ def post(self, request) -> Response:
status=rest_framework_status.HTTP_200_OK,
headers=headers,
)


class StreamingChat(AACSAPIView):
"""
Send a message to the backend chatbot service and get a streaming reply.
"""

class StreamingChatEndpointThrottle(EndpointRateThrottle):
scope = "chat"

permission_classes = [
permissions.IsAuthenticated,
IsAuthenticatedOrTokenHasScope,
IsRHInternalUser | IsTestUser,
]
required_scopes = ["read", "write"]
schema1_event = schema1.ChatBotOperationalEvent # TODO
request_serializer_class = StreamingChatRequestSerializer
throttle_classes = [StreamingChatEndpointThrottle]

llm: ModelPipelineStreamingChatBot

def __init__(self):
super().__init__()
self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot)

self.chatbot_enabled = (
self.llm.config.inference_url
and self.llm.config.model_id
and settings.CHATBOT_DEFAULT_PROVIDER
)
if self.chatbot_enabled:
logger.debug("Chatbot is enabled.")
else:
logger.debug("Chatbot is not enabled.")

@extend_schema(
request=StreamingChatRequestSerializer,
responses={
200: ChatResponseSerializer, # TODO
400: OpenApiResponse(description="Bad request"),
403: OpenApiResponse(description="Forbidden"),
413: OpenApiResponse(description="Prompt too long"),
422: OpenApiResponse(description="Validation failed"),
500: OpenApiResponse(description="Internal server error"),
503: OpenApiResponse(description="Service unavailable"),
},
summary="Streaming chat request",
)
def post(self, request) -> Response:
if not self.chatbot_enabled:
raise ChatbotNotEnabledException()

req_query = self.validated_data["query"]
req_system_prompt = self.validated_data.get("system_prompt")
req_provider = self.validated_data.get("provider", settings.CHATBOT_DEFAULT_PROVIDER)
conversation_id = self.validated_data.get("conversation_id")
media_type = self.validated_data.get("media_type")

# Initialise Segment Event early, in case of exceptions
self.event.chat_prompt = anonymize_struct(req_query)
self.event.chat_system_prompt = req_system_prompt
self.event.provider_id = req_provider
self.event.conversation_id = conversation_id
self.event.modelName = self.req_model_id or self.llm.config.model_id

return StreamingHttpResponse(
self.llm.async_invoke(
StreamingChatBotParameters.init(
query=req_query,
system_prompt=req_system_prompt,
model_id=self.req_model_id or self.llm.config.model_id,
provider=req_provider,
conversation_id=conversation_id,
media_type=media_type,
)
),
content_type="text/event-stream",
)
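
The view hands the async generator to StreamingHttpResponse unchanged; served over ASGI, Django (4.2+) iterates it and flushes each chunk to the client as it arrives. A stripped-down sketch of the same pattern, with illustrative names:

from django.http import StreamingHttpResponse


async def event_stream():
    # Stand-in for ModelPipelineStreamingChatBot.async_invoke(): yields
    # pre-encoded SSE frames one at a time.
    for frame in (b"data: hello\n\n", b"data: world\n\n"):
        yield frame


def streaming_view(request):
    # Django accepts an async iterator here when the app is served over ASGI.
    return StreamingHttpResponse(event_stream(), content_type="text/event-stream")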