diff --git a/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py index 3ce05be77..dc478f433 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py @@ -14,9 +14,11 @@ import json import logging -from typing import Optional +from typing import AsyncGenerator, Optional +import aiohttp import requests +from django.http import StreamingHttpResponse from health_check.exceptions import ServiceUnavailable from ansible_ai_connect.ai.api.exceptions import ( @@ -39,6 +41,8 @@ MetaData, ModelPipelineChatBot, ModelPipelineCompletions, + ModelPipelineStreamingChatBot, + StreamingChatBotParameters, ) from ansible_ai_connect.ai.api.model_pipelines.registry import Register from ansible_ai_connect.healthcheck.backends import ( @@ -120,8 +124,43 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i raise NotImplementedError +class HttpChatBotMetaData(HttpMetaData): + + def __init__(self, config: HttpConfiguration): + super().__init__(config=config) + + def self_test(self) -> Optional[HealthCheckSummary]: + summary: HealthCheckSummary = HealthCheckSummary( + { + MODEL_MESH_HEALTH_CHECK_PROVIDER: "http", + MODEL_MESH_HEALTH_CHECK_MODELS: "ok", + } + ) + try: + headers = {"Content-Type": "application/json"} + r = requests.get(self.config.inference_url + "/readiness", headers=headers) + r.raise_for_status() + + data = r.json() + ready = data.get("ready") + if not ready: + reason = data.get("reason") + summary.add_exception( + MODEL_MESH_HEALTH_CHECK_MODELS, + HealthCheckSummaryException(ServiceUnavailable(reason)), + ) + + except Exception as e: + logger.exception(str(e)) + summary.add_exception( + MODEL_MESH_HEALTH_CHECK_MODELS, + HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e), + ) + return summary + + @Register(api_type="http") -class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]): +class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]): def __init__(self, config: HttpConfiguration): super().__init__(config=config) @@ -171,31 +210,49 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse: detail = json.loads(response.text).get("detail", "") raise ChatbotInternalServerException(detail=detail) - def self_test(self) -> Optional[HealthCheckSummary]: - summary: HealthCheckSummary = HealthCheckSummary( - { - MODEL_MESH_HEALTH_CHECK_PROVIDER: "http", - MODEL_MESH_HEALTH_CHECK_MODELS: "ok", + +@Register(api_type="http") +class HttpStreamingChatBotPipeline( + HttpChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration] +): + + def __init__(self, config: HttpConfiguration): + super().__init__(config=config) + + def invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse: + raise NotImplementedError + + async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerator: + async with aiohttp.ClientSession(raise_for_status=True) as session: + headers = { + "Content-Type": "application/json", + "Accept": "application/json,text/event-stream", } - ) - try: - headers = {"Content-Type": "application/json"} - r = requests.get(self.config.inference_url + "/readiness", headers=headers) - r.raise_for_status() - data = r.json() - ready = data.get("ready") - if not ready: - reason = data.get("reason") - summary.add_exception( - MODEL_MESH_HEALTH_CHECK_MODELS, - HealthCheckSummaryException(ServiceUnavailable(reason)), - ) 
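+            # The upstream chatbot service streams its reply as Server-Sent
+            # Events (hence "Accept: application/json,text/event-stream" above).
+            # Build the JSON payload for /v1/streaming_query from the validated
+            # parameters, adding the optional fields only when they are set,
+            # then relay each raw chunk of the response body as it arrives.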
+ query = params.query + conversation_id = params.conversation_id + provider = params.provider + model_id = params.model_id + system_prompt = params.system_prompt + media_type = params.media_type - except Exception as e: - logger.exception(str(e)) - summary.add_exception( - MODEL_MESH_HEALTH_CHECK_MODELS, - HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e), - ) - return summary + data = { + "query": query, + "model": model_id, + "provider": provider, + } + if conversation_id: + data["conversation_id"] = str(conversation_id) + if system_prompt: + data["system_prompt"] = str(system_prompt) + if media_type: + data["media_type"] = str(media_type) + + async with session.post( + self.config.inference_url + "/v1/streaming_query", + json=data, + headers=headers, + ) as r: + async for chunk in r.content: + logger.debug(chunk) + yield chunk diff --git a/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py index c3620ef5c..42a41f848 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py @@ -29,6 +29,7 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, PlaybookExplanationParameters, PlaybookExplanationResponse, PlaybookGenerationParameters, @@ -37,6 +38,8 @@ RoleExplanationResponse, RoleGenerationParameters, RoleGenerationResponse, + StreamingChatBotParameters, + StreamingChatBotResponse, ) from ansible_ai_connect.ai.api.model_pipelines.registry import Register from ansible_ai_connect.healthcheck.backends import HealthCheckSummary @@ -143,3 +146,16 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse: def self_test(self) -> Optional[HealthCheckSummary]: raise NotImplementedError + + +@Register(api_type="nop") +class NopStreamingChatBotPipeline(NopMetaData, ModelPipelineStreamingChatBot[NopConfiguration]): + + def __init__(self, config: NopConfiguration): + super().__init__(config=config) + + def invoke(self, params: StreamingChatBotParameters) -> StreamingChatBotResponse: + raise NotImplementedError + + def self_test(self) -> Optional[HealthCheckSummary]: + raise NotImplementedError diff --git a/ansible_ai_connect/ai/api/model_pipelines/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/pipelines.py index f35f8f79c..6495bd682 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/pipelines.py @@ -14,7 +14,7 @@ import logging from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, AsyncGenerator, Dict, Generic, Optional, TypeVar from attrs import define from django.conf import settings @@ -242,6 +242,33 @@ def init( ChatBotResponse = Any +@define +class StreamingChatBotParameters(ChatBotParameters): + media_type: str + + @classmethod + def init( + cls, + query: str, + provider: Optional[str] = None, + model_id: Optional[str] = None, + conversation_id: Optional[str] = None, + system_prompt: Optional[str] = None, + media_type: Optional[str] = None, + ): + return cls( + query=query, + provider=provider, + model_id=model_id, + conversation_id=conversation_id, + system_prompt=system_prompt, + media_type=media_type, + ) + + +StreamingChatBotResponse = Any + + class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta): def __init__(self, config: PIPELINE_CONFIGURATION): @@ -274,6 +301,9 @@ def alias() -> str: def 
invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN: raise NotImplementedError + def async_invoke(self, params: PIPELINE_PARAMETERS) -> AsyncGenerator: + raise NotImplementedError + @abstractmethod def self_test(self) -> Optional[HealthCheckSummary]: raise NotImplementedError @@ -381,3 +411,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION): @staticmethod def alias(): return "chatbot-service" + + +class ModelPipelineStreamingChatBot( + ModelPipeline[PIPELINE_CONFIGURATION, StreamingChatBotParameters, StreamingChatBotResponse], + Generic[PIPELINE_CONFIGURATION], + metaclass=ABCMeta, +): + + def __init__(self, config: PIPELINE_CONFIGURATION): + super().__init__(config=config) + + @staticmethod + def alias(): + return "streaming-chatbot-service" diff --git a/ansible_ai_connect/ai/api/model_pipelines/registry.py b/ansible_ai_connect/ai/api/model_pipelines/registry.py index f9ee4c88b..1426099ec 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/registry.py +++ b/ansible_ai_connect/ai/api/model_pipelines/registry.py @@ -30,6 +30,7 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, ) from ansible_ai_connect.main.settings.types import t_model_mesh_api_type @@ -45,6 +46,7 @@ ModelPipelinePlaybookExplanation, ModelPipelineRoleExplanation, ModelPipelineChatBot, + ModelPipelineStreamingChatBot, PipelineConfiguration, Serializer, ] diff --git a/ansible_ai_connect/ai/api/serializers.py b/ansible_ai_connect/ai/api/serializers.py index 329e4386b..f3bc7d5a0 100644 --- a/ansible_ai_connect/ai/api/serializers.py +++ b/ansible_ai_connect/ai/api/serializers.py @@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer): ) +class StreamingChatRequestSerializer(ChatRequestSerializer): + media_type = serializers.CharField( + required=False, + label="Media type", + help_text=("A media type to be used in the output from LLM."), + ) + + class ReferencedDocumentsSerializer(serializers.Serializer): docs_url = serializers.CharField() title = serializers.CharField() diff --git a/ansible_ai_connect/ai/api/versions/v1/ai/urls.py b/ansible_ai_connect/ai/api/versions/v1/ai/urls.py index 2921b3032..aad4bcdb8 100644 --- a/ansible_ai_connect/ai/api/versions/v1/ai/urls.py +++ b/ansible_ai_connect/ai/api/versions/v1/ai/urls.py @@ -25,4 +25,5 @@ path("generations/role/", views.GenerationRole.as_view(), name="generations/role"), path("feedback/", views.Feedback.as_view(), name="feedback"), path("chat/", views.Chat.as_view(), name="chat"), + path("streaming_chat/", views.StreamingChat.as_view(), name="streaming_chat"), ] diff --git a/ansible_ai_connect/ai/api/versions/v1/ai/views.py b/ansible_ai_connect/ai/api/versions/v1/ai/views.py index 1667003af..f47679e96 100644 --- a/ansible_ai_connect/ai/api/versions/v1/ai/views.py +++ b/ansible_ai_connect/ai/api/versions/v1/ai/views.py @@ -21,6 +21,7 @@ Feedback, GenerationPlaybook, GenerationRole, + StreamingChat, ) __all__ = [ @@ -32,4 +33,5 @@ "ExplanationRole", "Feedback", "Chat", + "StreamingChat", ] diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py index 8283ea379..3e32690cb 100644 --- a/ansible_ai_connect/ai/api/views.py +++ b/ansible_ai_connect/ai/api/views.py @@ -21,6 +21,7 @@ from attr import asdict from django.apps import apps from django.conf import settings +from django.http import StreamingHttpResponse from django_prometheus.conf import NAMESPACE from drf_spectacular.utils import OpenApiResponse, extend_schema from 
oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope @@ -77,10 +78,12 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, PlaybookExplanationParameters, PlaybookGenerationParameters, RoleExplanationParameters, RoleGenerationParameters, + StreamingChatBotParameters, ) from ansible_ai_connect.ai.api.pipelines.completions import CompletionsPipeline from ansible_ai_connect.ai.api.telemetry import schema1 @@ -134,6 +137,7 @@ PlaybookGenerationAction, RoleGenerationAction, SentimentFeedback, + StreamingChatRequestSerializer, SuggestionQualityFeedback, ) from .telemetry.schema1 import ( @@ -1126,3 +1130,82 @@ def post(self, request) -> Response: status=rest_framework_status.HTTP_200_OK, headers=headers, ) + + +class StreamingChat(AACSAPIView): + """ + Send a message to the backend chatbot service and get a streaming reply. + """ + + class StreamingChatEndpointThrottle(EndpointRateThrottle): + scope = "chat" + + permission_classes = [ + permissions.IsAuthenticated, + IsAuthenticatedOrTokenHasScope, + IsRHInternalUser | IsTestUser, + ] + required_scopes = ["read", "write"] + schema1_event = schema1.ChatBotOperationalEvent # TODO + request_serializer_class = StreamingChatRequestSerializer + throttle_classes = [StreamingChatEndpointThrottle] + + llm: ModelPipelineStreamingChatBot + + def __init__(self): + super().__init__() + self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot) + + self.chatbot_enabled = ( + self.llm.config.inference_url + and self.llm.config.model_id + and settings.CHATBOT_DEFAULT_PROVIDER + ) + if self.chatbot_enabled: + logger.debug("Chatbot is enabled.") + else: + logger.debug("Chatbot is not enabled.") + + @extend_schema( + request=StreamingChatRequestSerializer, + responses={ + 200: ChatResponseSerializer, # TODO + 400: OpenApiResponse(description="Bad request"), + 403: OpenApiResponse(description="Forbidden"), + 413: OpenApiResponse(description="Prompt too long"), + 422: OpenApiResponse(description="Validation failed"), + 500: OpenApiResponse(description="Internal server error"), + 503: OpenApiResponse(description="Service unavailable"), + }, + summary="Streaming chat request", + ) + def post(self, request) -> Response: + if not self.chatbot_enabled: + raise ChatbotNotEnabledException() + + req_query = self.validated_data["query"] + req_system_prompt = self.validated_data.get("system_prompt") + req_provider = self.validated_data.get("provider", settings.CHATBOT_DEFAULT_PROVIDER) + conversation_id = self.validated_data.get("conversation_id") + media_type = self.validated_data.get("media_type") + + # Initialise Segment Event early, in case of exceptions + self.event.chat_prompt = anonymize_struct(req_query) + self.event.chat_system_prompt = req_system_prompt + self.event.provider_id = req_provider + self.event.conversation_id = conversation_id + self.event.modelName = self.req_model_id or self.llm.config.model_id + + return StreamingHttpResponse( + self.llm.async_invoke( + StreamingChatBotParameters.init( + query=req_query, + system_prompt=req_system_prompt, + model_id=self.req_model_id or self.llm.config.model_id, + provider=req_provider, + conversation_id=conversation_id, + media_type=media_type, + ) + ), + content_type="text/event-stream", + ) diff --git a/ansible_ai_connect/healthcheck/tests/test_healthcheck.py b/ansible_ai_connect/healthcheck/tests/test_healthcheck.py index 9ac53cb66..8f654b564 100644 --- 
a/ansible_ai_connect/healthcheck/tests/test_healthcheck.py +++ b/ansible_ai_connect/healthcheck/tests/test_healthcheck.py @@ -159,7 +159,7 @@ def assert_basic_data( self.assert_common_data(data, expected_status, deployed_region) timestamp = data["timestamp"] dependencies = data.get("dependencies", []) - self.assertEqual(10, len(dependencies)) + self.assertEqual(11, len(dependencies)) for dependency in dependencies: self.assertIn( dependency["name"], diff --git a/ansible_ai_connect/main/settings/base.py b/ansible_ai_connect/main/settings/base.py index a91f834d0..a726c8e45 100644 --- a/ansible_ai_connect/main/settings/base.py +++ b/ansible_ai_connect/main/settings/base.py @@ -55,6 +55,7 @@ # Application definition INSTALLED_APPS = [ + "daphne", "django.contrib.admin", "django.contrib.auth", "django.contrib.contenttypes", @@ -93,6 +94,11 @@ "csp.middleware.CSPMiddleware", ] +if os.environ.get("CSRF_TRUSTED_ORIGINS"): + CSRF_TRUSTED_ORIGINS = os.environ.get("CSRF_TRUSTED_ORIGINS").split(",") +else: + CSRF_TRUSTED_ORIGINS = ["http://localhost:8000"] + # Allow Prometheus to scrape metrics ALLOWED_CIDR_NETS = [os.environ.get("ALLOWED_CIDR_NETS", "10.0.0.0/8")] @@ -294,7 +300,7 @@ def is_ssl_enabled(value: str) -> bool: }, }, "handlers": { - "console": {"class": "logging.StreamHandler", "formatter": "simple", "level": "INFO"}, + "console": {"class": "logging.StreamHandler", "formatter": "simple", "level": "DEBUG"}, }, "loggers": { "django": { @@ -334,6 +340,11 @@ def is_ssl_enabled(value: str) -> bool: "level": "INFO", "propagate": False, }, + "ansible_ai_connect.ai.api.streaming_chat": { + "handlers": ["console"], + "level": "DEBUG", + "propagate": False, + }, }, "root": { "handlers": ["console"], @@ -358,6 +369,7 @@ def is_ssl_enabled(value: str) -> bool: ] WSGI_APPLICATION = "ansible_ai_connect.main.wsgi.application" +ASGI_APPLICATION = "ansible_ai_connect.main.asgi.application" # Database # https://docs.djangoproject.com/en/4.1/ref/settings/#databases @@ -543,6 +555,7 @@ def is_ssl_enabled(value: str) -> bool: # ------------------------------------------ CHATBOT_DEFAULT_PROVIDER = os.getenv("CHATBOT_DEFAULT_PROVIDER") CHATBOT_DEBUG_UI = os.getenv("CHATBOT_DEBUG_UI", "False").lower() == "true" +CHATBOT_STREAMING = os.getenv("CHATBOT_STREAMING", "False").lower() == "true" # ========================================== # ========================================== diff --git a/ansible_ai_connect/main/settings/legacy.py b/ansible_ai_connect/main/settings/legacy.py index 6543f43c8..cc7aef87c 100644 --- a/ansible_ai_connect/main/settings/legacy.py +++ b/ansible_ai_connect/main/settings/legacy.py @@ -192,6 +192,15 @@ def load_from_env_vars(): "stream": False, }, } + model_pipelines_config["ModelPipelineStreamingChatBot"] = { + "provider": "http", + "config": { + "inference_url": chatbot_service_url or "http://localhost:8000", + "model_id": chatbot_service_model_id or "granite3-8b", + "verify_ssl": model_service_verify_ssl, + "stream": False, + }, + } # Enable Health Checks where we have them implemented model_pipelines_config["ModelPipelineCompletions"]["config"][ diff --git a/ansible_ai_connect/main/tests/test_views.py b/ansible_ai_connect/main/tests/test_views.py index 059f7fc28..4da9d433b 100644 --- a/ansible_ai_connect/main/tests/test_views.py +++ b/ansible_ai_connect/main/tests/test_views.py @@ -350,6 +350,7 @@ def test_chatbot_view_with_debug_ui(self): self.assertEqual(r.status_code, HTTPStatus.OK) self.assertContains(r, '