From f10ebc5e858f1ede2bfaa83a44bbd6a2a4ce0f56 Mon Sep 17 00:00:00 2001
From: Tami Takamiya
Date: Sat, 8 Feb 2025 15:44:19 -0500
Subject: [PATCH] Follow Model Pipeline pattern

---
 .../ai/api/model_pipelines/http/pipelines.py | 114 +++++++++++++-----
 .../ai/api/model_pipelines/pipelines.py      |  44 +++++++
 .../ai/api/model_pipelines/registry.py       |   2 +
 ansible_ai_connect/ai/api/serializers.py     |   8 ++
 ansible_ai_connect/ai/api/streaming_chat.py  |  40 ------
 .../ai/api/versions/v1/ai/views.py           |   2 +-
 ansible_ai_connect/ai/api/views.py           |  83 +++++++++++++
 ansible_ai_connect/main/settings/legacy.py   |   9 ++
 ansible_ai_connect/main/views.py             |   8 +-
 9 files changed, 236 insertions(+), 74 deletions(-)
 delete mode 100644 ansible_ai_connect/ai/api/streaming_chat.py

diff --git a/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py
index 3ce05be77..41b021eb7 100644
--- a/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py
+++ b/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py
@@ -16,7 +16,9 @@
 import logging
 from typing import Optional
 
+import aiohttp
 import requests
+from django.http import StreamingHttpResponse
 from health_check.exceptions import ServiceUnavailable
 
 from ansible_ai_connect.ai.api.exceptions import (
@@ -39,6 +41,8 @@
     MetaData,
     ModelPipelineChatBot,
     ModelPipelineCompletions,
+    ModelPipelineStreamingChatBot,
+    StreamingChatBotParameters,
 )
 from ansible_ai_connect.ai.api.model_pipelines.registry import Register
 from ansible_ai_connect.healthcheck.backends import (
@@ -120,13 +124,12 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
         raise NotImplementedError
 
 
-@Register(api_type="http")
-class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]):
+class HttpChatBotMetaData(HttpMetaData):
 
     def __init__(self, config: HttpConfiguration):
         super().__init__(config=config)
 
-    def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
+    def prepare_data(self, params: ChatBotParameters):
         query = params.query
         conversation_id = params.conversation_id
         provider = params.provider
@@ -142,11 +145,49 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
             data["conversation_id"] = str(conversation_id)
         if system_prompt:
             data["system_prompt"] = str(system_prompt)
+        return data
+
+    def self_test(self) -> Optional[HealthCheckSummary]:
+        summary: HealthCheckSummary = HealthCheckSummary(
+            {
+                MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
+                MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
+            }
+        )
+        try:
+            headers = {"Content-Type": "application/json"}
+            r = requests.get(self.config.inference_url + "/readiness", headers=headers)
+            r.raise_for_status()
+
+            data = r.json()
+            ready = data.get("ready")
+            if not ready:
+                reason = data.get("reason")
+                summary.add_exception(
+                    MODEL_MESH_HEALTH_CHECK_MODELS,
+                    HealthCheckSummaryException(ServiceUnavailable(reason)),
+                )
+
+        except Exception as e:
+            logger.exception(str(e))
+            summary.add_exception(
+                MODEL_MESH_HEALTH_CHECK_MODELS,
+                HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
+            )
+        return summary
+
+
+@Register(api_type="http")
+class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]):
+
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
+
+    def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
         response = requests.post(
             self.config.inference_url + "/v1/query",
             headers=self.headers,
-            json=data,
+            json=self.prepare_data(params),
             timeout=self.timeout(1),
             verify=self.config.verify_ssl,
         )
@@ -171,31 +212,44 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
             detail = json.loads(response.text).get("detail", "")
             raise ChatbotInternalServerException(detail=detail)
 
-    def self_test(self) -> Optional[HealthCheckSummary]:
-        summary: HealthCheckSummary = HealthCheckSummary(
-            {
-                MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
-                MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
-            }
-        )
-        try:
-            headers = {"Content-Type": "application/json"}
-            r = requests.get(self.config.inference_url + "/readiness", headers=headers)
-            r.raise_for_status()
-            data = r.json()
-            ready = data.get("ready")
-            if not ready:
-                reason = data.get("reason")
-                summary.add_exception(
-                    MODEL_MESH_HEALTH_CHECK_MODELS,
-                    HealthCheckSummaryException(ServiceUnavailable(reason)),
-                )
+class HttpStreamingChatBotMetaData(HttpChatBotMetaData):
 
-        except Exception as e:
-            logger.exception(str(e))
-            summary.add_exception(
-                MODEL_MESH_HEALTH_CHECK_MODELS,
-                HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
-            )
-        return summary
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
+
+    def prepare_data(self, params: StreamingChatBotParameters):
+        data = super().prepare_data(params)
+
+        media_type = params.media_type
+        if media_type:
+            data["media_type"] = str(media_type)
+
+        return data
+
+
+@Register(api_type="http")
+class HttpStreamingChatBotPipeline(
+    HttpStreamingChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
+):
+
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
+
+    def invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
+        raise NotImplementedError
+
+    async def async_invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
+        async with aiohttp.ClientSession(raise_for_status=True) as session:
+            headers = {
+                "Content-Type": "application/json",
+                "Accept": "application/json,text/event-stream",
+            }
+            async with session.post(
+                self.config.inference_url + "/v1/streaming_query",
+                json=self.prepare_data(params),
+                headers=headers,
+            ) as r:
+                async for chunk in r.content:
+                    logger.debug(chunk)
+                    yield chunk
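The hunk above splits payload construction (`prepare_data`) from transport, so the blocking `HttpChatBotPipeline` and the new `HttpStreamingChatBotPipeline` share one code path, and `async_invoke` becomes an async generator that relays raw SSE bytes from the chatbot service. A minimal consumption sketch, assuming a configured Django environment; the prompt value is hypothetical, and in the service itself the generator is driven by `StreamingHttpResponse` (see the views.py hunk below):

    from django.apps import apps

    from ansible_ai_connect.ai.api.model_pipelines.pipelines import (
        ModelPipelineStreamingChatBot,
        StreamingChatBotParameters,
    )


    async def drain_stream() -> None:
        # Resolve the registered streaming pipeline, as views.py does below.
        llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot)
        params = StreamingChatBotParameters.init(
            query="How do I install a collection?",  # hypothetical prompt
        )
        # Each chunk is one bytes fragment of the SSE stream.
        async for chunk in llm.async_invoke(params):
            print(chunk.decode(), end="")

    # Run with asyncio.run(drain_stream()) once Django settings are loaded.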
diff --git a/ansible_ai_connect/ai/api/model_pipelines/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/pipelines.py
index f35f8f79c..8c66f1d53 100644
--- a/ansible_ai_connect/ai/api/model_pipelines/pipelines.py
+++ b/ansible_ai_connect/ai/api/model_pipelines/pipelines.py
@@ -242,6 +242,33 @@ def init(
 ChatBotResponse = Any
 
 
+@define
+class StreamingChatBotParameters(ChatBotParameters):
+    media_type: str
+
+    @classmethod
+    def init(
+        cls,
+        query: str,
+        provider: Optional[str] = None,
+        model_id: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        media_type: Optional[str] = None,
+    ):
+        return cls(
+            query=query,
+            provider=provider,
+            model_id=model_id,
+            conversation_id=conversation_id,
+            system_prompt=system_prompt,
+            media_type=media_type,
+        )
+
+
+StreamingChatBotResponse = Any
+
+
 class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta):
 
     def __init__(self, config: PIPELINE_CONFIGURATION):
@@ -274,6 +301,9 @@ def alias() -> str:
     def invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
         raise NotImplementedError
 
+    async def async_invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
+        raise NotImplementedError
+
     @abstractmethod
     def self_test(self) -> Optional[HealthCheckSummary]:
         raise NotImplementedError
@@ -381,3 +411,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION):
     @staticmethod
     def alias():
         return "chatbot-service"
+
+
+class ModelPipelineStreamingChatBot(
+    ModelPipeline[PIPELINE_CONFIGURATION, ChatBotParameters, StreamingChatBotResponse],
+    Generic[PIPELINE_CONFIGURATION],
+    metaclass=ABCMeta,
+):
+
+    def __init__(self, config: PIPELINE_CONFIGURATION):
+        super().__init__(config=config)
+
+    @staticmethod
+    def alias():
+        return "streaming-chatbot-service"
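`ModelPipelineStreamingChatBot` gives streaming chat the same pluggable shape as every other pipeline: a provider registers a concrete class under the `streaming-chatbot-service` alias and the registry wires it up from configuration. A hypothetical sketch of a second provider, assuming only what the diff above defines (the `dummy` api_type and the class itself are illustrative, not part of this patch):

    from ansible_ai_connect.ai.api.model_pipelines.pipelines import (
        ModelPipelineStreamingChatBot,
        StreamingChatBotParameters,
        StreamingChatBotResponse,
    )
    from ansible_ai_connect.ai.api.model_pipelines.registry import Register


    @Register(api_type="dummy")  # hypothetical provider key
    class DummyStreamingChatBotPipeline(ModelPipelineStreamingChatBot):

        def invoke(self, params: StreamingChatBotParameters) -> StreamingChatBotResponse:
            raise NotImplementedError  # streaming-only, like the http pipeline

        async def async_invoke(self, params: StreamingChatBotParameters):
            # Emit a canned SSE stream instead of calling a real service.
            for token in (b"data: hello\n\n", b"data: world\n\n"):
                yield token

        def self_test(self):
            return None  # no health check in this illustrative provider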
diff --git a/ansible_ai_connect/ai/api/model_pipelines/registry.py b/ansible_ai_connect/ai/api/model_pipelines/registry.py
index f9ee4c88b..1426099ec 100644
--- a/ansible_ai_connect/ai/api/model_pipelines/registry.py
+++ b/ansible_ai_connect/ai/api/model_pipelines/registry.py
@@ -30,6 +30,7 @@
     ModelPipelinePlaybookGeneration,
     ModelPipelineRoleExplanation,
     ModelPipelineRoleGeneration,
+    ModelPipelineStreamingChatBot,
 )
 from ansible_ai_connect.main.settings.types import t_model_mesh_api_type
 
@@ -45,6 +46,7 @@
     ModelPipelinePlaybookExplanation,
     ModelPipelineRoleExplanation,
     ModelPipelineChatBot,
+    ModelPipelineStreamingChatBot,
     PipelineConfiguration,
     Serializer,
 ]
diff --git a/ansible_ai_connect/ai/api/serializers.py b/ansible_ai_connect/ai/api/serializers.py
index 329e4386b..f3bc7d5a0 100644
--- a/ansible_ai_connect/ai/api/serializers.py
+++ b/ansible_ai_connect/ai/api/serializers.py
@@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer):
     )
 
 
+class StreamingChatRequestSerializer(ChatRequestSerializer):
+    media_type = serializers.CharField(
+        required=False,
+        label="Media type",
+        help_text=("A media type to be used in the output from LLM."),
+    )
+
+
 class ReferencedDocumentsSerializer(serializers.Serializer):
     docs_url = serializers.CharField()
     title = serializers.CharField()
diff --git a/ansible_ai_connect/ai/api/streaming_chat.py b/ansible_ai_connect/ai/api/streaming_chat.py
deleted file mode 100644
index b5f286cfd..000000000
--- a/ansible_ai_connect/ai/api/streaming_chat.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import logging
-
-import aiohttp
-from django.apps import apps
-from django.http import StreamingHttpResponse
-from rest_framework.decorators import parser_classes
-from rest_framework.parsers import JSONParser
-from rest_framework.views import APIView
-
-from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot
-
-logger = logging.getLogger(__name__)
-
-
-class StreamingChat(APIView):
-
-    def __init__(self):
-        self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineChatBot)
-
-    async def call_chatservice(self, request):
-        async with aiohttp.ClientSession(raise_for_status=True) as session:
-            headers = {
-                "Content-Type": "application/json",
-                "Accept": "application/json,text/event-stream",
-            }
-            async with session.post(
-                self.llm.config.inference_url + "/v1/streaming_query",
-                json=request.data,
-                headers=headers,
-            ) as r:
-                async for chunk in r.content:
-                    logger.debug(chunk)
-                    yield chunk
-
-    @parser_classes([JSONParser])
-    def post(self, request):
-        return StreamingHttpResponse(
-            self.call_chatservice(request),
-            content_type="text/event-stream",
-        )
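`StreamingChatRequestSerializer` only layers the optional `media_type` field on top of `ChatRequestSerializer`, so existing chat payloads stay valid. A quick validation sketch (payload values are hypothetical; `query` is inherited from the base serializer):

    from ansible_ai_connect.ai.api.serializers import StreamingChatRequestSerializer

    serializer = StreamingChatRequestSerializer(
        data={
            "query": "What is an Ansible role?",  # inherited, required
            "media_type": "text/event-stream",    # new, optional
        }
    )
    assert serializer.is_valid(), serializer.errors
    print(serializer.validated_data)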
diff --git a/ansible_ai_connect/ai/api/versions/v1/ai/views.py b/ansible_ai_connect/ai/api/versions/v1/ai/views.py
index 5582458a0..f47679e96 100644
--- a/ansible_ai_connect/ai/api/versions/v1/ai/views.py
+++ b/ansible_ai_connect/ai/api/versions/v1/ai/views.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ansible_ai_connect.ai.api.streaming_chat import StreamingChat
 from ansible_ai_connect.ai.api.views import (
     Chat,
     Completions,
@@ -22,6 +21,7 @@
     Feedback,
     GenerationPlaybook,
     GenerationRole,
+    StreamingChat,
 )
 
 __all__ = [
diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py
index 02bcc18e2..98716e006 100644
--- a/ansible_ai_connect/ai/api/views.py
+++ b/ansible_ai_connect/ai/api/views.py
@@ -21,6 +21,7 @@
 from attr import asdict
 from django.apps import apps
 from django.conf import settings
+from django.http import StreamingHttpResponse
 from django_prometheus.conf import NAMESPACE
 from drf_spectacular.utils import OpenApiResponse, extend_schema
 from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope
@@ -77,10 +78,12 @@
     ModelPipelinePlaybookGeneration,
     ModelPipelineRoleExplanation,
     ModelPipelineRoleGeneration,
+    ModelPipelineStreamingChatBot,
     PlaybookExplanationParameters,
     PlaybookGenerationParameters,
     RoleExplanationParameters,
     RoleGenerationParameters,
+    StreamingChatBotParameters,
 )
 from ansible_ai_connect.ai.api.pipelines.completions import CompletionsPipeline
 from ansible_ai_connect.ai.api.telemetry import schema1
@@ -134,6 +137,7 @@
     PlaybookGenerationAction,
     RoleGenerationAction,
     SentimentFeedback,
+    StreamingChatRequestSerializer,
     SuggestionQualityFeedback,
 )
 from .telemetry.schema1 import (
@@ -1126,3 +1130,82 @@ def post(self, request) -> Response:
             status=rest_framework_status.HTTP_200_OK,
             headers=headers,
         )
+
+
+class StreamingChat(AACSAPIView):
+    """
+    Send a message to the backend chatbot service and get a streaming reply.
+    """
+
+    class StreamingChatEndpointThrottle(EndpointRateThrottle):
+        scope = "chat"
+
+    permission_classes = [
+        permissions.IsAuthenticated,
+        IsAuthenticatedOrTokenHasScope,
+        IsRHInternalUser | IsTestUser,
+    ]
+    required_scopes = ["read", "write"]
+    schema1_event = schema1.ChatBotOperationalEvent  # TODO
+    request_serializer_class = StreamingChatRequestSerializer
+    throttle_classes = [StreamingChatEndpointThrottle]
+
+    llm: ModelPipelineStreamingChatBot
+
+    def __init__(self):
+        super().__init__()
+        self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot)
+
+        self.chatbot_enabled = (
+            self.llm.config.inference_url
+            and self.llm.config.model_id
+            and settings.CHATBOT_DEFAULT_PROVIDER
+        )
+        if self.chatbot_enabled:
+            logger.debug("Chatbot is enabled.")
+        else:
+            logger.debug("Chatbot is not enabled.")
+
+    @extend_schema(
+        request=StreamingChatRequestSerializer,
+        responses={
+            200: ChatResponseSerializer,  # TODO
+            400: OpenApiResponse(description="Bad request"),
+            403: OpenApiResponse(description="Forbidden"),
+            413: OpenApiResponse(description="Prompt too long"),
+            422: OpenApiResponse(description="Validation failed"),
+            500: OpenApiResponse(description="Internal server error"),
+            503: OpenApiResponse(description="Service unavailable"),
+        },
+        summary="Streaming chat request",
+    )
+    def post(self, request) -> Response:
+        if not self.chatbot_enabled:
+            raise ChatbotNotEnabledException()
+
+        req_query = self.validated_data["query"]
+        req_system_prompt = self.validated_data.get("system_prompt")
+        req_provider = self.validated_data.get("provider", settings.CHATBOT_DEFAULT_PROVIDER)
+        conversation_id = self.validated_data.get("conversation_id")
+        media_type = self.validated_data.get("media_type")
+
+        # Initialise Segment Event early, in case of exceptions
+        self.event.chat_prompt = anonymize_struct(req_query)
+        self.event.chat_system_prompt = req_system_prompt
+        self.event.provider_id = req_provider
+        self.event.conversation_id = conversation_id
+        self.event.modelName = self.req_model_id or self.llm.config.model_id
+
+        return StreamingHttpResponse(
+            self.llm.async_invoke(
+                StreamingChatBotParameters.init(
+                    query=req_query,
+                    system_prompt=req_system_prompt,
+                    model_id=self.req_model_id or self.llm.config.model_id,
+                    provider=req_provider,
+                    conversation_id=conversation_id,
+                    media_type=media_type,
+                )
+            ),
+            content_type="text/event-stream",
+        )
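Since the view hands the pipeline's async generator straight to `StreamingHttpResponse` with a `text/event-stream` content type, clients can read the reply incrementally. A hedged client-side sketch; the route is an assumption based on the v1 view registration above, which this patch does not show:

    import requests

    with requests.post(
        "http://localhost:8080/api/v1/ai/streaming_chat/",  # hypothetical route
        json={"query": "What is Ansible Lightspeed?"},
        headers={"Accept": "text/event-stream"},
        stream=True,
        timeout=60,
    ) as r:
        r.raise_for_status()
        # Print each SSE line as it arrives instead of buffering the body.
        for line in r.iter_lines():
            if line:
                print(line.decode())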
diff --git a/ansible_ai_connect/main/settings/legacy.py b/ansible_ai_connect/main/settings/legacy.py
index 6543f43c8..cc7aef87c 100644
--- a/ansible_ai_connect/main/settings/legacy.py
+++ b/ansible_ai_connect/main/settings/legacy.py
@@ -192,6 +192,15 @@ def load_from_env_vars():
             "stream": False,
         },
     }
+    model_pipelines_config["ModelPipelineStreamingChatBot"] = {
+        "provider": "http",
+        "config": {
+            "inference_url": chatbot_service_url or "http://localhost:8000",
+            "model_id": chatbot_service_model_id or "granite3-8b",
+            "verify_ssl": model_service_verify_ssl,
+            "stream": False,
+        },
+    }
 
     # Enable Health Checks where we have them implemented
     model_pipelines_config["ModelPipelineCompletions"]["config"][
diff --git a/ansible_ai_connect/main/views.py b/ansible_ai_connect/main/views.py
index 60ca899bb..d81b6607b 100644
--- a/ansible_ai_connect/main/views.py
+++ b/ansible_ai_connect/main/views.py
@@ -28,7 +28,9 @@
 from rest_framework.renderers import BaseRenderer
 from rest_framework.views import APIView
 
-from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot
+from ansible_ai_connect.ai.api.model_pipelines.pipelines import (
+    ModelPipelineStreamingChatBot,
+)
 from ansible_ai_connect.ai.api.permissions import (
     IsOrganisationAdministrator,
     IsOrganisationLightspeedSubscriber,
@@ -121,12 +123,12 @@ class ChatbotView(ProtectedTemplateView):
         IsRHInternalUser | IsTestUser,
     ]
 
-    llm: ModelPipelineChatBot
+    llm: ModelPipelineStreamingChatBot
     chatbot_enabled: bool
 
     def __init__(self):
         super().__init__()
-        self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineChatBot)
+        self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot)
         self.chatbot_enabled = (
             self.llm.config.inference_url
             and self.llm.config.model_id