Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] /streaming_chat endpoint PoC #1527

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 84 additions & 30 deletions ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import logging
from typing import AsyncGenerator, Optional

import aiohttp
import requests
from django.http import StreamingHttpResponse
from health_check.exceptions import ServiceUnavailable

from ansible_ai_connect.ai.api.exceptions import (
Expand All @@ -39,6 +41,8 @@
MetaData,
ModelPipelineChatBot,
ModelPipelineCompletions,
ModelPipelineStreamingChatBot,
StreamingChatBotParameters,
)
from ansible_ai_connect.ai.api.model_pipelines.registry import Register
from ansible_ai_connect.healthcheck.backends import (
Expand Down Expand Up @@ -120,13 +124,12 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
raise NotImplementedError


@Register(api_type="http")
class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]):
class HttpChatBotMetaData(HttpMetaData):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
def prepare_data(self, params: ChatBotParameters):
query = params.query
conversation_id = params.conversation_id
provider = params.provider
Expand All @@ -142,11 +145,49 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
data["conversation_id"] = str(conversation_id)
if system_prompt:
data["system_prompt"] = str(system_prompt)
return data

def self_test(self) -> Optional[HealthCheckSummary]:
    """Probe the chatbot service's /readiness endpoint and report health.

    Returns a HealthCheckSummary that stays "ok" unless the service is
    unreachable, answers with a non-2xx status, or reports ready=false,
    in which case a ServiceUnavailable exception is attached to the summary.
    """
    summary: HealthCheckSummary = HealthCheckSummary(
        {
            MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
            MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
        }
    )
    try:
        headers = {"Content-Type": "application/json"}
        # requests has no default timeout: without one, a hung service would
        # block the health check forever. Reuse the same budget invoke() uses.
        r = requests.get(
            self.config.inference_url + "/readiness",
            headers=headers,
            timeout=self.timeout(1),
        )
        r.raise_for_status()

        data = r.json()
        ready = data.get("ready")
        if not ready:
            # Service is up but not ready; surface its own reason string.
            reason = data.get("reason")
            summary.add_exception(
                MODEL_MESH_HEALTH_CHECK_MODELS,
                HealthCheckSummaryException(ServiceUnavailable(reason)),
            )

    except Exception as e:
        # Any transport/HTTP/parsing failure is reported as unavailable.
        logger.exception(str(e))
        summary.add_exception(
            MODEL_MESH_HEALTH_CHECK_MODELS,
            HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
        )
    return summary


@Register(api_type="http")
class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
response = requests.post(
self.config.inference_url + "/v1/query",
headers=self.headers,
json=data,
json=self.prepare_data(params),
timeout=self.timeout(1),
verify=self.config.verify_ssl,
)
Expand All @@ -171,31 +212,44 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
detail = json.loads(response.text).get("detail", "")
raise ChatbotInternalServerException(detail=detail)

def self_test(self) -> Optional[HealthCheckSummary]:
summary: HealthCheckSummary = HealthCheckSummary(
{
MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
}
)
try:
headers = {"Content-Type": "application/json"}
r = requests.get(self.config.inference_url + "/readiness", headers=headers)
r.raise_for_status()

data = r.json()
ready = data.get("ready")
if not ready:
reason = data.get("reason")
summary.add_exception(
MODEL_MESH_HEALTH_CHECK_MODELS,
HealthCheckSummaryException(ServiceUnavailable(reason)),
)
class HttpStreamingChatBotMetaData(HttpChatBotMetaData):

except Exception as e:
logger.exception(str(e))
summary.add_exception(
MODEL_MESH_HEALTH_CHECK_MODELS,
HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
)
return summary
def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

def prepare_data(self, params: StreamingChatBotParameters):
    """Build the request payload, appending the optional media type for streaming."""
    payload = super().prepare_data(params)

    # Only forward a media type when the caller actually supplied one.
    if params.media_type:
        payload["media_type"] = str(params.media_type)

    return payload


@Register(api_type="http")
class HttpStreamingChatBotPipeline(
    HttpStreamingChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
):
    """HTTP pipeline that streams chatbot replies chunk-by-chunk (SSE-style)."""

    def __init__(self, config: HttpConfiguration):
        super().__init__(config=config)

    def invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
        # Streaming is only available through the async entry point.
        raise NotImplementedError

    async def async_invoke(
        self, params: StreamingChatBotParameters
    ) -> AsyncGenerator[bytes, None]:
        """Yield raw response chunks from the service's /v1/streaming_query.

        This is an async generator (it yields bytes chunks), so the return
        annotation is AsyncGenerator rather than StreamingHttpResponse; the
        view wraps the generator in a StreamingHttpResponse.
        Raises aiohttp.ClientResponseError on non-2xx replies
        (raise_for_status=True).
        """
        # NOTE(review): no request timeout is set here — TODO align with the
        # sync pipeline's self.timeout(1) budget once its unit/shape for
        # aiohttp.ClientTimeout is confirmed.
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json,text/event-stream",
            }
            async with session.post(
                self.config.inference_url + "/v1/streaming_query",
                json=self.prepare_data(params),
                headers=headers,
                # Honour verify_ssl like the sync invoke() does
                # (ssl=False disables certificate verification; None = default).
                ssl=None if self.config.verify_ssl else False,
            ) as r:
                async for chunk in r.content:
                    logger.debug(chunk)
                    yield chunk
44 changes: 44 additions & 0 deletions ansible_ai_connect/ai/api/model_pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,33 @@ def init(
ChatBotResponse = Any


@define
class StreamingChatBotParameters(ChatBotParameters):
    """ChatBotParameters extended with the media type requested for streamed output."""

    # Desired output media type for the streamed reply; optional — init()
    # defaults it to None, so the annotation is Optional[str].
    media_type: Optional[str]

    @classmethod
    def init(
        cls,
        query: str,
        provider: Optional[str] = None,
        model_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        system_prompt: Optional[str] = None,
        media_type: Optional[str] = None,
    ):
        """Alternate constructor mirroring ChatBotParameters.init, plus media_type."""
        return cls(
            query=query,
            provider=provider,
            model_id=model_id,
            conversation_id=conversation_id,
            system_prompt=system_prompt,
            media_type=media_type,
        )


# Streamed chat responses are currently untyped (raw chunks from the service);
# alias kept for symmetry with ChatBotResponse.
StreamingChatBotResponse = Any


class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta):

def __init__(self, config: PIPELINE_CONFIGURATION):
Expand Down Expand Up @@ -274,6 +301,9 @@ def alias() -> str:
def invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
raise NotImplementedError

async def async_invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
    """Asynchronous/streaming counterpart to invoke().

    Deliberately not abstract: pipelines without streaming support inherit
    this default, which raises NotImplementedError.
    """
    raise NotImplementedError

@abstractmethod
def self_test(self) -> Optional[HealthCheckSummary]:
raise NotImplementedError
Expand Down Expand Up @@ -381,3 +411,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION):
@staticmethod
def alias():
return "chatbot-service"


class ModelPipelineStreamingChatBot(
    # Parameter type is StreamingChatBotParameters (not plain ChatBotParameters):
    # concrete implementations consume the streaming-specific fields (media_type).
    ModelPipeline[PIPELINE_CONFIGURATION, StreamingChatBotParameters, StreamingChatBotResponse],
    Generic[PIPELINE_CONFIGURATION],
    metaclass=ABCMeta,
):
    """Abstract pipeline for chatbot back-ends that stream their replies."""

    def __init__(self, config: PIPELINE_CONFIGURATION):
        super().__init__(config=config)

    @staticmethod
    def alias():
        # Registry/settings key for this pipeline family.
        return "streaming-chatbot-service"
2 changes: 2 additions & 0 deletions ansible_ai_connect/ai/api/model_pipelines/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ModelPipelinePlaybookGeneration,
ModelPipelineRoleExplanation,
ModelPipelineRoleGeneration,
ModelPipelineStreamingChatBot,
)
from ansible_ai_connect.main.settings.types import t_model_mesh_api_type

Expand All @@ -45,6 +46,7 @@
ModelPipelinePlaybookExplanation,
ModelPipelineRoleExplanation,
ModelPipelineChatBot,
ModelPipelineStreamingChatBot,
PipelineConfiguration,
Serializer,
]
Expand Down
8 changes: 8 additions & 0 deletions ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer):
)


class StreamingChatRequestSerializer(ChatRequestSerializer):
    """Chat request serializer with an extra, optional media-type field for streaming."""

    media_type = serializers.CharField(
        required=False,
        label="Media type",
        help_text="A media type to be used in the output from LLM.",
    )


class ReferencedDocumentsSerializer(serializers.Serializer):
docs_url = serializers.CharField()
title = serializers.CharField()
Expand Down
1 change: 1 addition & 0 deletions ansible_ai_connect/ai/api/versions/v1/ai/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@
path("generations/role/", views.GenerationRole.as_view(), name="generations/role"),
path("feedback/", views.Feedback.as_view(), name="feedback"),
path("chat/", views.Chat.as_view(), name="chat"),
path("streaming_chat/", views.StreamingChat.as_view(), name="streaming_chat"),
]
2 changes: 2 additions & 0 deletions ansible_ai_connect/ai/api/versions/v1/ai/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Feedback,
GenerationPlaybook,
GenerationRole,
StreamingChat,
)

__all__ = [
Expand All @@ -32,4 +33,5 @@
"ExplanationRole",
"Feedback",
"Chat",
"StreamingChat",
]
83 changes: 83 additions & 0 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from attr import asdict
from django.apps import apps
from django.conf import settings
from django.http import StreamingHttpResponse
from django_prometheus.conf import NAMESPACE
from drf_spectacular.utils import OpenApiResponse, extend_schema
from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope
Expand Down Expand Up @@ -77,10 +78,12 @@
ModelPipelinePlaybookGeneration,
ModelPipelineRoleExplanation,
ModelPipelineRoleGeneration,
ModelPipelineStreamingChatBot,
PlaybookExplanationParameters,
PlaybookGenerationParameters,
RoleExplanationParameters,
RoleGenerationParameters,
StreamingChatBotParameters,
)
from ansible_ai_connect.ai.api.pipelines.completions import CompletionsPipeline
from ansible_ai_connect.ai.api.telemetry import schema1
Expand Down Expand Up @@ -134,6 +137,7 @@
PlaybookGenerationAction,
RoleGenerationAction,
SentimentFeedback,
StreamingChatRequestSerializer,
SuggestionQualityFeedback,
)
from .telemetry.schema1 import (
Expand Down Expand Up @@ -1126,3 +1130,82 @@ def post(self, request) -> Response:
status=rest_framework_status.HTTP_200_OK,
headers=headers,
)


class StreamingChat(AACSAPIView):
    """
    Send a message to the backend chatbot service and get a streaming reply.
    """

    class StreamingChatEndpointThrottle(EndpointRateThrottle):
        # Shares the "chat" scope so streaming and non-streaming chat requests
        # count against the same rate-limit bucket.
        scope = "chat"

    permission_classes = [
        permissions.IsAuthenticated,
        IsAuthenticatedOrTokenHasScope,
        IsRHInternalUser | IsTestUser,
    ]
    required_scopes = ["read", "write"]
    schema1_event = schema1.ChatBotOperationalEvent  # TODO
    request_serializer_class = StreamingChatRequestSerializer
    throttle_classes = [StreamingChatEndpointThrottle]

    # Streaming pipeline, resolved once at construction time.
    llm: ModelPipelineStreamingChatBot

    def __init__(self):
        super().__init__()
        self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot)

        # The chatbot is usable only when an inference URL, a model id and a
        # default provider are all configured.
        self.chatbot_enabled = (
            self.llm.config.inference_url
            and self.llm.config.model_id
            and settings.CHATBOT_DEFAULT_PROVIDER
        )
        if self.chatbot_enabled:
            logger.debug("Chatbot is enabled.")
        else:
            logger.debug("Chatbot is not enabled.")

    @extend_schema(
        request=StreamingChatRequestSerializer,
        responses={
            200: ChatResponseSerializer,  # TODO
            400: OpenApiResponse(description="Bad request"),
            403: OpenApiResponse(description="Forbidden"),
            413: OpenApiResponse(description="Prompt too long"),
            422: OpenApiResponse(description="Validation failed"),
            500: OpenApiResponse(description="Internal server error"),
            503: OpenApiResponse(description="Service unavailable"),
        },
        summary="Streaming chat request",
    )
    def post(self, request) -> StreamingHttpResponse:
        """Validate the chat request and stream the model's reply to the client.

        Returns a Django StreamingHttpResponse (not a DRF Response) wrapping
        the pipeline's async generator. Raises ChatbotNotEnabledException when
        the pipeline is not fully configured.
        """
        if not self.chatbot_enabled:
            raise ChatbotNotEnabledException()

        req_query = self.validated_data["query"]
        req_system_prompt = self.validated_data.get("system_prompt")
        req_provider = self.validated_data.get("provider", settings.CHATBOT_DEFAULT_PROVIDER)
        conversation_id = self.validated_data.get("conversation_id")
        media_type = self.validated_data.get("media_type")

        # Initialise Segment Event early, in case of exceptions
        self.event.chat_prompt = anonymize_struct(req_query)
        self.event.chat_system_prompt = req_system_prompt
        self.event.provider_id = req_provider
        self.event.conversation_id = conversation_id
        self.event.modelName = self.req_model_id or self.llm.config.model_id

        # async_invoke is an async generator; StreamingHttpResponse consumes it
        # and relays each chunk to the client as a server-sent-event stream.
        return StreamingHttpResponse(
            self.llm.async_invoke(
                StreamingChatBotParameters.init(
                    query=req_query,
                    system_prompt=req_system_prompt,
                    model_id=self.req_model_id or self.llm.config.model_id,
                    provider=req_provider,
                    conversation_id=conversation_id,
                    media_type=media_type,
                )
            ),
            content_type="text/event-stream",
        )
Loading
Loading