diff --git a/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py index 3ce05be77..cca099948 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py @@ -14,9 +14,12 @@ import json import logging -from typing import Optional +from json import JSONDecodeError +from typing import AsyncGenerator, Optional +import aiohttp import requests +from django.http import StreamingHttpResponse from health_check.exceptions import ServiceUnavailable from ansible_ai_connect.ai.api.exceptions import ( @@ -39,6 +42,9 @@ MetaData, ModelPipelineChatBot, ModelPipelineCompletions, + ModelPipelineStreamingChatBot, + StreamingChatBotParameters, + StreamingChatBotResponse, ) from ansible_ai_connect.ai.api.model_pipelines.registry import Register from ansible_ai_connect.healthcheck.backends import ( @@ -120,8 +126,43 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i raise NotImplementedError +class HttpChatBotMetaData(HttpMetaData): + + def __init__(self, config: HttpConfiguration): + super().__init__(config=config) + + def self_test(self) -> Optional[HealthCheckSummary]: + summary: HealthCheckSummary = HealthCheckSummary( + { + MODEL_MESH_HEALTH_CHECK_PROVIDER: "http", + MODEL_MESH_HEALTH_CHECK_MODELS: "ok", + } + ) + try: + headers = {"Content-Type": "application/json"} + r = requests.get(self.config.inference_url + "/readiness", headers=headers) + r.raise_for_status() + + data = r.json() + ready = data.get("ready") + if not ready: + reason = data.get("reason") + summary.add_exception( + MODEL_MESH_HEALTH_CHECK_MODELS, + HealthCheckSummaryException(ServiceUnavailable(reason)), + ) + + except Exception as e: + logger.exception(str(e)) + summary.add_exception( + MODEL_MESH_HEALTH_CHECK_MODELS, + HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e), + ) + return summary + + @Register(api_type="http") -class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]): +class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]): def __init__(self, config: HttpConfiguration): super().__init__(config=config) @@ -171,31 +212,100 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse: detail = json.loads(response.text).get("detail", "") raise ChatbotInternalServerException(detail=detail) - def self_test(self) -> Optional[HealthCheckSummary]: - summary: HealthCheckSummary = HealthCheckSummary( - { - MODEL_MESH_HEALTH_CHECK_PROVIDER: "http", - MODEL_MESH_HEALTH_CHECK_MODELS: "ok", - } + +@Register(api_type="http") +class HttpStreamingChatBotPipeline( + HttpChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration] +): + + def __init__(self, config: HttpConfiguration): + super().__init__(config=config) + + def invoke(self, params: StreamingChatBotParameters) -> StreamingChatBotResponse: + response = self.get_streaming_http_response(params) + + if response.status_code == 200: + return response + else: + raise ChatbotInternalServerException(detail="Internal server error") + + def get_streaming_http_response( + self, params: StreamingChatBotParameters + ) -> StreamingHttpResponse: + return StreamingHttpResponse( + self.async_invoke(params), + content_type="text/event-stream", ) - try: - headers = {"Content-Type": "application/json"} - r = requests.get(self.config.inference_url + "/readiness", headers=headers) - r.raise_for_status() - data = r.json() - ready = data.get("ready") - if not ready: - reason = data.get("reason") - summary.add_exception( - MODEL_MESH_HEALTH_CHECK_MODELS, - HealthCheckSummaryException(ServiceUnavailable(reason)), - ) + async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerator: + async with aiohttp.ClientSession(raise_for_status=True) as session: + headers = { + "Content-Type": "application/json", + "Accept": "application/json,text/event-stream", + } - except Exception as e: - logger.exception(str(e)) - summary.add_exception( - MODEL_MESH_HEALTH_CHECK_MODELS, - HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e), - ) - return summary + query = params.query + conversation_id = params.conversation_id + provider = params.provider + model_id = params.model_id + system_prompt = params.system_prompt + media_type = params.media_type + + data = { + "query": query, + "model": model_id, + "provider": provider, + } + if conversation_id: + data["conversation_id"] = str(conversation_id) + if system_prompt: + data["system_prompt"] = str(system_prompt) + if media_type: + data["media_type"] = str(media_type) + + async with session.post( + self.config.inference_url + "/v1/streaming_query", + json=data, + headers=headers, + raise_for_status=False, + ) as response: + if response.status == 200: + async for chunk in response.content: + try: + if chunk: + s = chunk.decode("utf-8").strip() + if s and s.startswith("data: "): + o = json.loads(s[len("data: ") :]) + if o["event"] == "error": + default_data = { + "response": "(not provided)", + "cause": "(not provided)", + } + data = o.get("data", default_data) + logger.error( + "An error received in chat streaming content:" + + " response=" + + data.get("response") + + ", cause=" + + data.get("cause") + ) + except JSONDecodeError: + pass + logger.debug(chunk) + yield chunk + else: + logging.error( + "Streaming query API returned status code=" + + str(response.status) + + ", reason=" + + str(response.reason) + ) + error = { + "event": "error", + "data": { + "response": f"Non-200 status code ({response.status}) was received.", + "cause": response.reason, + }, + } + yield json.dumps(error).encode("utf-8") + return diff --git a/ansible_ai_connect/ai/api/model_pipelines/http/tests/test_pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/http/tests/test_pipelines.py new file mode 100644 index 000000000..226114f21 --- /dev/null +++ b/ansible_ai_connect/ai/api/model_pipelines/http/tests/test_pipelines.py @@ -0,0 +1,173 @@ +# +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +from unittest import IsolatedAsyncioTestCase +from unittest.mock import patch + +from ansible_ai_connect.test_utils import WisdomLogAwareMixin + +from ...pipelines import StreamingChatBotParameters +from ...tests import mock_pipeline_config +from ..pipelines import HttpStreamingChatBotPipeline + +logger = logging.getLogger(__name__) + + +class TestHttpStreamingChatBotPipeline(IsolatedAsyncioTestCase, WisdomLogAwareMixin): + pipeline: HttpStreamingChatBotPipeline + + STREAM_DATA = [ + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3f"}}, + {"event": "token", "data": {"id": 0, "token": ""}}, + {"event": "token", "data": {"id": 1, "token": "Hello"}}, + {"event": "token", "data": {"id": 2, "token": "!"}}, + {"event": "token", "data": {"id": 3, "token": " I"}}, + {"event": "token", "data": {"id": 4, "token": "'m"}}, + {"event": "token", "data": {"id": 5, "token": " Ansible"}}, + {"event": "token", "data": {"id": 6, "token": " L"}}, + {"event": "token", "data": {"id": 7, "token": "ights"}}, + {"event": "token", "data": {"id": 8, "token": "peed"}}, + {"event": "token", "data": {"id": 9, "token": ","}}, + {"event": "token", "data": {"id": 10, "token": " your"}}, + {"event": "token", "data": {"id": 11, "token": " virtual"}}, + {"event": "token", "data": {"id": 12, "token": " assistant"}}, + {"event": "token", "data": {"id": 13, "token": " for"}}, + {"event": "token", "data": {"id": 14, "token": " all"}}, + {"event": "token", "data": {"id": 15, "token": " things"}}, + {"event": "token", "data": {"id": 16, "token": " Ansible"}}, + {"event": "token", "data": {"id": 17, "token": "."}}, + {"event": "token", "data": {"id": 18, "token": " How"}}, + {"event": "token", "data": {"id": 19, "token": " can"}}, + {"event": "token", "data": {"id": 20, "token": " I"}}, + {"event": "token", "data": {"id": 21, "token": " assist"}}, + {"event": "token", "data": {"id": 22, "token": " you"}}, + {"event": "token", "data": {"id": 23, "token": " today"}}, + {"event": "token", "data": {"id": 24, "token": "?"}}, + {"event": "token", "data": {"id": 25, "token": ""}}, + { + "event": "end", + "data": { + "referenced_documents": [], + "truncated": False, + "input_tokens": 241, + "output_tokens": 25, + }, + }, + ] + + STREAM_DATA_PROMPT_TOO_LONG = [ + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3e"}}, + { + "event": "error", + "data": {"response": "Prompt is too long", "cause": "Prompt length 10000 exceeds LLM"}, + }, + ] + + STREAM_DATA_PROMPT_GENERIC_LLM_ERROR = [ + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3d"}}, + { + "event": "error", + "data": { + "response": "Oops, something went wrong during LLM invocation", + "cause": "A generic LLM error", + }, + }, + ] + + STREAM_DATA_PROMPT_ERROR_WITH_NO_DATA = [ + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3c"}}, + {"event": "error"}, + ] + + def setUp(self): + self.pipeline = HttpStreamingChatBotPipeline(mock_pipeline_config("http")) + + def assertInLog(self, s, logs, number_of_matches_expected=None): + self.assertTrue(self.searchInLogOutput(s, logs, number_of_matches_expected), logs) + + def get_return_value(self, stream_data, status=200): + class MyAsyncContextManager: + def __init__(self, stream_data, status=200): + self.stream_data = stream_data + self.status = status + self.reason = "" + + async def my_async_generator(self): + for data in self.stream_data: + s = json.dumps(data) + yield (f"data: {s}\n\n".encode()) + + async def __aenter__(self): + self.content = self.my_async_generator() + self.status = self.status + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + return MyAsyncContextManager(stream_data, status) + + def get_params(self) -> StreamingChatBotParameters: + return StreamingChatBotParameters( + query="Hello", + provider="", + model_id="", + conversation_id=None, + system_prompt=None, + media_type="application/json", + ) + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_with_no_error(self, mock_post): + mock_post.return_value = self.get_return_value(self.STREAM_DATA) + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_prompt_too_long(self, mock_post): + mock_post.return_value = self.get_return_value(self.STREAM_DATA_PROMPT_TOO_LONG) + with self.assertLogs(logger="root", level="ERROR") as log: + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + self.assertInLog("Prompt is too long", log) + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_prompt_generic_llm_error(self, mock_post): + mock_post.return_value = self.get_return_value(self.STREAM_DATA_PROMPT_GENERIC_LLM_ERROR) + with self.assertLogs(logger="root", level="ERROR") as log: + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + self.assertInLog("Oops, something went wrong during LLM invocation", log) + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_internal_server_error(self, mock_post): + mock_post.return_value = self.get_return_value( + self.STREAM_DATA_PROMPT_GENERIC_LLM_ERROR, 500 + ) + with self.assertLogs(logger="root", level="ERROR") as log: + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + self.assertInLog("Streaming query API returned status code=500", log) + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_error_with_no_data(self, mock_post): + mock_post.return_value = self.get_return_value( + self.STREAM_DATA_PROMPT_ERROR_WITH_NO_DATA, + ) + with self.assertLogs(logger="root", level="ERROR") as log: + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + self.assertInLog("(not provided)", log) diff --git a/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py index c3620ef5c..42a41f848 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py @@ -29,6 +29,7 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, PlaybookExplanationParameters, PlaybookExplanationResponse, PlaybookGenerationParameters, @@ -37,6 +38,8 @@ RoleExplanationResponse, RoleGenerationParameters, RoleGenerationResponse, + StreamingChatBotParameters, + StreamingChatBotResponse, ) from ansible_ai_connect.ai.api.model_pipelines.registry import Register from ansible_ai_connect.healthcheck.backends import HealthCheckSummary @@ -143,3 +146,16 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse: def self_test(self) -> Optional[HealthCheckSummary]: raise NotImplementedError + + +@Register(api_type="nop") +class NopStreamingChatBotPipeline(NopMetaData, ModelPipelineStreamingChatBot[NopConfiguration]): + + def __init__(self, config: NopConfiguration): + super().__init__(config=config) + + def invoke(self, params: StreamingChatBotParameters) -> StreamingChatBotResponse: + raise NotImplementedError + + def self_test(self) -> Optional[HealthCheckSummary]: + raise NotImplementedError diff --git a/ansible_ai_connect/ai/api/model_pipelines/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/pipelines.py index f35f8f79c..c2facc626 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/pipelines.py @@ -18,6 +18,7 @@ from attrs import define from django.conf import settings +from django.http import StreamingHttpResponse from rest_framework.request import Request from rest_framework.response import Response @@ -242,6 +243,33 @@ def init( ChatBotResponse = Any +@define +class StreamingChatBotParameters(ChatBotParameters): + media_type: str + + @classmethod + def init( + cls, + query: str, + provider: Optional[str] = None, + model_id: Optional[str] = None, + conversation_id: Optional[str] = None, + system_prompt: Optional[str] = None, + media_type: Optional[str] = None, + ): + return cls( + query=query, + provider=provider, + model_id=model_id, + conversation_id=conversation_id, + system_prompt=system_prompt, + media_type=media_type, + ) + + +StreamingChatBotResponse = StreamingHttpResponse + + class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta): def __init__(self, config: PIPELINE_CONFIGURATION): @@ -381,3 +409,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION): @staticmethod def alias(): return "chatbot-service" + + +class ModelPipelineStreamingChatBot( + ModelPipeline[PIPELINE_CONFIGURATION, StreamingChatBotParameters, StreamingChatBotResponse], + Generic[PIPELINE_CONFIGURATION], + metaclass=ABCMeta, +): + + def __init__(self, config: PIPELINE_CONFIGURATION): + super().__init__(config=config) + + @staticmethod + def alias(): + return "streaming-chatbot-service" diff --git a/ansible_ai_connect/ai/api/model_pipelines/registry.py b/ansible_ai_connect/ai/api/model_pipelines/registry.py index f9ee4c88b..1426099ec 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/registry.py +++ b/ansible_ai_connect/ai/api/model_pipelines/registry.py @@ -30,6 +30,7 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, ) from ansible_ai_connect.main.settings.types import t_model_mesh_api_type @@ -45,6 +46,7 @@ ModelPipelinePlaybookExplanation, ModelPipelineRoleExplanation, ModelPipelineChatBot, + ModelPipelineStreamingChatBot, PipelineConfiguration, Serializer, ] diff --git a/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py b/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py index 4ee8e825d..db4878a0a 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py +++ b/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py @@ -87,6 +87,7 @@ def mock_pipeline_config(pipeline_provider: t_model_mesh_api_type, **kwargs): timeout=extract("timeout", 1000, **kwargs), enable_health_check=extract("enable_health_check", False, **kwargs), verify_ssl=extract("verify_ssl", False, **kwargs), + stream=extract("stream", False, **kwargs), ) case "llamacpp": return LlamaCppConfiguration( diff --git a/ansible_ai_connect/ai/api/serializers.py b/ansible_ai_connect/ai/api/serializers.py index 329e4386b..f3bc7d5a0 100644 --- a/ansible_ai_connect/ai/api/serializers.py +++ b/ansible_ai_connect/ai/api/serializers.py @@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer): ) +class StreamingChatRequestSerializer(ChatRequestSerializer): + media_type = serializers.CharField( + required=False, + label="Media type", + help_text=("A media type to be used in the output from LLM."), + ) + + class ReferencedDocumentsSerializer(serializers.Serializer): docs_url = serializers.CharField() title = serializers.CharField() diff --git a/ansible_ai_connect/ai/api/telemetry/schema1.py b/ansible_ai_connect/ai/api/telemetry/schema1.py index d75a74fc7..e9bbea929 100644 --- a/ansible_ai_connect/ai/api/telemetry/schema1.py +++ b/ansible_ai_connect/ai/api/telemetry/schema1.py @@ -229,3 +229,8 @@ class ChatBotFeedbackEvent(ChatBotBaseEvent): @define class ChatBotOperationalEvent(ChatBotBaseEvent): event_name: str = "chatOperationalEvent" + + +@define +class StreamingChatBotOperationalEvent(ChatBotBaseEvent): + event_name: str = "streamingChatOperationalEvent" diff --git a/ansible_ai_connect/ai/api/tests/test_chat_view.py b/ansible_ai_connect/ai/api/tests/test_chat_view.py index 42c843ef1..31adc47f2 100644 --- a/ansible_ai_connect/ai/api/tests/test_chat_view.py +++ b/ansible_ai_connect/ai/api/tests/test_chat_view.py @@ -23,6 +23,7 @@ from django.apps import apps from django.contrib.auth import get_user_model +from django.http import StreamingHttpResponse from django.test import override_settings from ansible_ai_connect.ai.api.exceptions import ( @@ -34,7 +35,10 @@ ChatbotUnauthorizedException, ChatbotValidationException, ) -from ansible_ai_connect.ai.api.model_pipelines.http.pipelines import HttpChatBotPipeline +from ansible_ai_connect.ai.api.model_pipelines.http.pipelines import ( + HttpChatBotPipeline, + HttpStreamingChatBotPipeline, +) from ansible_ai_connect.ai.api.model_pipelines.tests import mock_pipeline_config from ansible_ai_connect.organizations.models import Organization from ansible_ai_connect.test_utils import ( @@ -178,6 +182,7 @@ def json(self): def query_with_no_error(self, payload, mock_post): return self.client.post(self.api_version_reverse("chat"), payload, format="json") + @override_settings(CHATBOT_DEFAULT_PROVIDER="") @mock.patch( "requests.post", side_effect=mocked_requests_post, @@ -475,3 +480,275 @@ def test_not_rh_internal_user(self): finally: if self.user2: self.user2.delete() + + +class TestStreamingChatView(APIVersionTestCaseBase, WisdomServiceAPITestCaseBase): + api_version = "v1" + + def setUp(self): + super().setUp() + (org, _) = Organization.objects.get_or_create(id=123, telemetry_opt_out=False) + self.user.organization = org + self.user.rh_internal = True + + @staticmethod + def mocked_response(*args, **kwargs): + + # Make sure that the given json data is serializable + input = json.dumps(kwargs["json"]) + assert input is not None + + json_response = { + "response": "AAP 2.5 introduces an updated, unified UI.", + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "truncated": False, + "referenced_documents": [], + } + status_code = 200 + + if kwargs["json"]["query"] == TestChatView.PAYLOAD_INTERNAL_SERVER_ERROR["query"]: + status_code = 500 + json_response = { + "detail": "Internal server error", + } + + response = StreamingHttpResponse() + response.status_code = status_code + response.text = json.dumps(json_response) + return response + + @override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom") + @mock.patch( + "ansible_ai_connect.ai.api.model_pipelines.http.pipelines." + "HttpStreamingChatBotPipeline.get_streaming_http_response", + ) + def query_with_status_code_override(self, payload, mock): + mock.return_value = TestStreamingChatView.mocked_response(json=payload) + return self.client.post(self.api_version_reverse("streaming_chat"), payload, format="json") + + @override_settings(CHATBOT_DEFAULT_PROVIDER="") + def query_without_chat_config(self, payload): + return self.client.post(self.api_version_reverse("streaming_chat"), payload, format="json") + + @override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom") + @mock.patch( + "ansible_ai_connect.ai.api.model_pipelines.http.pipelines." + "HttpStreamingChatBotPipeline.get_streaming_http_response", + ) + def query_with_no_error(self, payload, mock): + mock.return_value = TestStreamingChatView.mocked_response(json=payload) + return self.client.post(self.api_version_reverse("streaming_chat"), payload, format="json") + + def assert_test( + self, + payload, + expected_status_code=200, + expected_exception=None, + expected_log_message=None, + user=None, + ): + if user is None: + user = self.user + with ( + patch.object( + apps.get_app_config("ai"), + "get_model_pipeline", + Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))), + ), + self.assertLogs(logger="root", level="DEBUG") as log, + ): + self.client.force_authenticate(user=user) + + if expected_status_code >= 400: + if expected_exception == ChatbotNotEnabledException: + r = self.query_without_chat_config(payload) + else: + r = self.query_with_status_code_override(payload) + else: + r = self.query_with_no_error(payload) + + self.assertEqual(r.status_code, expected_status_code) + if expected_exception is not None: + self.assert_error_detail( + r, expected_exception().default_code, expected_exception().default_detail + ) + self.assertInLog(expected_log_message, log) + return r + + def test_chat(self): + self.assert_test(TestChatView.VALID_PAYLOAD) + + def test_chat_with_conversation_id(self): + self.assert_test(TestChatView.VALID_PAYLOAD_WITH_CONVERSATION_ID) + + def test_chat_not_enabled_exception(self): + self.assert_test( + TestChatView.VALID_PAYLOAD, 503, ChatbotNotEnabledException, "Chatbot is not enabled" + ) + + def test_chat_internal_server_exception(self): + self.assert_test( + TestChatView.PAYLOAD_INTERNAL_SERVER_ERROR, + 500, + ChatbotInternalServerException, + "ChatbotInternalServerException", + ) + + @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") + def test_operational_telemetry(self): + self.user.rh_user_has_seat = True + self.user.organization = Organization.objects.get_or_create(id=1)[0] + self.client.force_authenticate(user=self.user) + with ( + patch.object( + apps.get_app_config("ai"), + "get_model_pipeline", + Mock( + return_value=HttpStreamingChatBotPipeline( + mock_pipeline_config("http", model_id="granite-8b") + ) + ), + ), + self.assertLogs(logger="root", level="DEBUG") as log, + ): + r = self.query_with_no_error(TestChatView.VALID_PAYLOAD_WITH_CONVERSATION_ID) + self.assertEqual(r.status_code, HTTPStatus.OK) + segment_events = self.extractSegmentEventsFromLog(log) + self.assertEqual( + segment_events[0]["properties"]["chat_prompt"], + TestChatView.VALID_PAYLOAD_WITH_CONVERSATION_ID["query"], + ) + self.assertEqual( + segment_events[0]["properties"]["conversation_id"], + TestChatView.VALID_PAYLOAD_WITH_CONVERSATION_ID["conversation_id"], + ) + self.assertEqual(segment_events[0]["properties"]["modelName"], "granite-8b") + self.assertEqual( + segment_events[0]["properties"]["chat_truncated"], + TestChatView.JSON_RESPONSE["truncated"], + ) + self.assertEqual(len(segment_events[0]["properties"]["chat_referenced_documents"]), 0) + + @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") + def test_operational_telemetry_limit_exceeded(self): + q = "".join("hello " for i in range(6500)) + payload = { + "query": q, + } + self.client.force_authenticate(user=self.user) + with ( + patch.object( + apps.get_app_config("ai"), + "get_model_pipeline", + Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))), + ), + self.assertLogs(logger="root", level="DEBUG") as log, + ): + r = self.query_with_no_error(payload) + self.assertEqual(r.status_code, 200) + segment_events = self.extractSegmentEventsFromLog(log) + self.assertEqual( + segment_events[0]["properties"]["rh_user_org_id"], + 123, + ) + self.assertEqual( + segment_events[0]["properties"]["chat_response"], + "", + ) + + @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") + def test_operational_telemetry_anonymizer(self): + self.client.force_authenticate(user=self.user) + with ( + patch.object( + apps.get_app_config("ai"), + "get_model_pipeline", + Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))), + ), + self.assertLogs(logger="root", level="DEBUG") as log, + ): + r = self.query_with_no_error( + { + "query": "Hello ansible@ansible.com", + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + } + ) + self.assertEqual(r.status_code, HTTPStatus.OK) + segment_events = self.extractSegmentEventsFromLog(log) + self.assertNotEqual( + segment_events[0]["properties"]["chat_prompt"], + "Hello ansible@ansible.com", + ) + + @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") + def test_operational_telemetry_with_system_prompt_override(self): + self.client.force_authenticate(user=self.user) + with ( + patch.object( + apps.get_app_config("ai"), + "get_model_pipeline", + Mock( + return_value=HttpStreamingChatBotPipeline( + mock_pipeline_config("http", model_id="granite-8b") + ) + ), + ), + self.assertLogs(logger="root", level="DEBUG") as log, + ): + r = self.query_with_no_error(TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE) + self.assertEqual(r.status_code, HTTPStatus.OK) + segment_events = self.extractSegmentEventsFromLog(log) + self.assertEqual( + segment_events[0]["properties"]["chat_prompt"], + TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["query"], + ) + self.assertEqual(segment_events[0]["properties"]["modelName"], "granite-8b") + self.assertEqual( + segment_events[0]["properties"]["chat_truncated"], + TestChatView.JSON_RESPONSE["truncated"], + ) + self.assertEqual(len(segment_events[0]["properties"]["chat_referenced_documents"]), 0) + self.assertEqual( + segment_events[0]["properties"]["chat_system_prompt"], + TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["system_prompt"], + ) + + def test_chat_rate_limit(self): + # Call chat API five times using self.user + for i in range(5): + self.assert_test(TestChatView.VALID_PAYLOAD) + try: + username = "u" + "".join(random.choices(string.digits, k=5)) + password = "secret" + email = "user2@example.com" + self.user2 = get_user_model().objects.create_user( + username=username, + email=email, + password=password, + ) + (org, _) = Organization.objects.get_or_create(id=123, telemetry_opt_out=False) + self.user2.organization = org + self.user2.rh_internal = True + # Call chart API five times using self.user2 + for i in range(5): + self.assert_test(TestChatView.VALID_PAYLOAD, user=self.user2) + # The next chat API call should be the 11th from two users and should receive a 429. + self.assert_test(TestChatView.VALID_PAYLOAD, expected_status_code=429, user=self.user2) + finally: + if self.user2: + self.user2.delete() + + def test_not_rh_internal_user(self): + try: + username = "u" + "".join(random.choices(string.digits, k=5)) + self.user2 = get_user_model().objects.create_user( + username=username, + ) + self.user2.organization = Organization.objects.get_or_create( + id=123, telemetry_opt_out=False + )[0] + self.user2.rh_internal = False + self.assert_test(TestChatView.VALID_PAYLOAD, expected_status_code=403, user=self.user2) + finally: + if self.user2: + self.user2.delete() diff --git a/ansible_ai_connect/ai/api/utils/seated_users_allow_list.py b/ansible_ai_connect/ai/api/utils/seated_users_allow_list.py index 0856a464f..1e2338c22 100644 --- a/ansible_ai_connect/ai/api/utils/seated_users_allow_list.py +++ b/ansible_ai_connect/ai/api/utils/seated_users_allow_list.py @@ -390,4 +390,7 @@ "chatOperationalEvent": { "*": None, }, + "streamingChatOperationalEvent": { + "*": None, + }, } diff --git a/ansible_ai_connect/ai/api/versions/v1/ai/urls.py b/ansible_ai_connect/ai/api/versions/v1/ai/urls.py index 2921b3032..aad4bcdb8 100644 --- a/ansible_ai_connect/ai/api/versions/v1/ai/urls.py +++ b/ansible_ai_connect/ai/api/versions/v1/ai/urls.py @@ -25,4 +25,5 @@ path("generations/role/", views.GenerationRole.as_view(), name="generations/role"), path("feedback/", views.Feedback.as_view(), name="feedback"), path("chat/", views.Chat.as_view(), name="chat"), + path("streaming_chat/", views.StreamingChat.as_view(), name="streaming_chat"), ] diff --git a/ansible_ai_connect/ai/api/versions/v1/ai/views.py b/ansible_ai_connect/ai/api/versions/v1/ai/views.py index 1667003af..f47679e96 100644 --- a/ansible_ai_connect/ai/api/versions/v1/ai/views.py +++ b/ansible_ai_connect/ai/api/versions/v1/ai/views.py @@ -21,6 +21,7 @@ Feedback, GenerationPlaybook, GenerationRole, + StreamingChat, ) __all__ = [ @@ -32,4 +33,5 @@ "ExplanationRole", "Feedback", "Chat", + "StreamingChat", ] diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py index 01dec16d4..768be0da9 100644 --- a/ansible_ai_connect/ai/api/views.py +++ b/ansible_ai_connect/ai/api/views.py @@ -21,6 +21,7 @@ from attr import asdict from django.apps import apps from django.conf import settings +from django.http import StreamingHttpResponse from django_prometheus.conf import NAMESPACE from drf_spectacular.utils import OpenApiResponse, extend_schema from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope @@ -77,10 +78,12 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, PlaybookExplanationParameters, PlaybookGenerationParameters, RoleExplanationParameters, RoleGenerationParameters, + StreamingChatBotParameters, ) from ansible_ai_connect.ai.api.pipelines.completions import CompletionsPipeline from ansible_ai_connect.ai.api.telemetry import schema1 @@ -134,6 +137,7 @@ PlaybookGenerationAction, RoleGenerationAction, SentimentFeedback, + StreamingChatRequestSerializer, SuggestionQualityFeedback, ) from .telemetry.schema1 import ( @@ -1126,3 +1130,79 @@ def post(self, request) -> Response: status=rest_framework_status.HTTP_200_OK, headers=headers, ) + + +class StreamingChat(AACSAPIView): + """ + Send a message to the backend chatbot service and get a streaming reply. + """ + + class StreamingChatEndpointThrottle(EndpointRateThrottle): + scope = "chat" + + permission_classes = [ + permissions.IsAuthenticated, + IsAuthenticatedOrTokenHasScope, + IsRHInternalUser | IsTestUser | IsAAPUser, + ] + required_scopes = ["read", "write"] + schema1_event = schema1.StreamingChatBotOperationalEvent + request_serializer_class = StreamingChatRequestSerializer + throttle_classes = [StreamingChatEndpointThrottle] + + llm: ModelPipelineStreamingChatBot + + def __init__(self): + super().__init__() + self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot) + + self.chatbot_enabled = ( + self.llm.config.inference_url + and self.llm.config.model_id + and settings.CHATBOT_DEFAULT_PROVIDER + ) + if self.chatbot_enabled: + logger.debug("Chatbot is enabled.") + else: + logger.debug("Chatbot is not enabled.") + + @extend_schema( + request=StreamingChatRequestSerializer, + responses={ + 200: ChatResponseSerializer, # TODO + 400: OpenApiResponse(description="Bad request"), + 403: OpenApiResponse(description="Forbidden"), + 413: OpenApiResponse(description="Prompt too long"), + 422: OpenApiResponse(description="Validation failed"), + 500: OpenApiResponse(description="Internal server error"), + 503: OpenApiResponse(description="Service unavailable"), + }, + summary="Streaming chat request", + ) + def post(self, request) -> StreamingHttpResponse: + if not self.chatbot_enabled: + raise ChatbotNotEnabledException() + + req_query = self.validated_data["query"] + req_system_prompt = self.validated_data.get("system_prompt") + req_provider = self.validated_data.get("provider", settings.CHATBOT_DEFAULT_PROVIDER) + conversation_id = self.validated_data.get("conversation_id") + media_type = self.validated_data.get("media_type") + + # Initialise Segment Event early, in case of exceptions + self.event.chat_prompt = anonymize_struct(req_query) + self.event.chat_system_prompt = req_system_prompt + self.event.provider_id = req_provider + self.event.conversation_id = conversation_id + self.event.modelName = self.req_model_id or self.llm.config.model_id + + return self.llm.invoke( + StreamingChatBotParameters.init( + query=req_query, + system_prompt=req_system_prompt, + model_id=self.req_model_id or self.llm.config.model_id, + provider=req_provider, + conversation_id=conversation_id, + media_type=media_type, + ) + ) diff --git a/ansible_ai_connect/healthcheck/tests/test_healthcheck.py b/ansible_ai_connect/healthcheck/tests/test_healthcheck.py index 9ac53cb66..8f654b564 100644 --- a/ansible_ai_connect/healthcheck/tests/test_healthcheck.py +++ b/ansible_ai_connect/healthcheck/tests/test_healthcheck.py @@ -159,7 +159,7 @@ def assert_basic_data( self.assert_common_data(data, expected_status, deployed_region) timestamp = data["timestamp"] dependencies = data.get("dependencies", []) - self.assertEqual(10, len(dependencies)) + self.assertEqual(11, len(dependencies)) for dependency in dependencies: self.assertIn( dependency["name"], diff --git a/ansible_ai_connect/main/settings/base.py b/ansible_ai_connect/main/settings/base.py index 9e9cd659f..43adf7cd2 100644 --- a/ansible_ai_connect/main/settings/base.py +++ b/ansible_ai_connect/main/settings/base.py @@ -55,6 +55,7 @@ # Application definition INSTALLED_APPS = [ + "daphne", "django.contrib.admin", "django.contrib.auth", "django.contrib.contenttypes", @@ -95,6 +96,11 @@ "csp.middleware.CSPMiddleware", ] +if os.environ.get("CSRF_TRUSTED_ORIGINS"): + CSRF_TRUSTED_ORIGINS = os.environ.get("CSRF_TRUSTED_ORIGINS").split(",") +else: + CSRF_TRUSTED_ORIGINS = ["http://localhost:8000"] + # Allow Prometheus to scrape metrics ALLOWED_CIDR_NETS = [os.environ.get("ALLOWED_CIDR_NETS", "10.0.0.0/8")] @@ -340,6 +346,11 @@ def is_ssl_enabled(value: str) -> bool: "level": "INFO", "propagate": False, }, + "ansible_ai_connect.ai.api.streaming_chat": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, }, "root": { "handlers": ["console"], @@ -364,6 +375,7 @@ def is_ssl_enabled(value: str) -> bool: ] WSGI_APPLICATION = "ansible_ai_connect.main.wsgi.application" +ASGI_APPLICATION = "ansible_ai_connect.main.asgi.application" # Database # https://docs.djangoproject.com/en/4.1/ref/settings/#databases diff --git a/ansible_ai_connect/main/settings/legacy.py b/ansible_ai_connect/main/settings/legacy.py index 6543f43c8..8bd686263 100644 --- a/ansible_ai_connect/main/settings/legacy.py +++ b/ansible_ai_connect/main/settings/legacy.py @@ -192,6 +192,15 @@ def load_from_env_vars(): "stream": False, }, } + model_pipelines_config["ModelPipelineStreamingChatBot"] = { + "provider": "http", + "config": { + "inference_url": chatbot_service_url or "http://localhost:8000", + "model_id": chatbot_service_model_id or "granite3-8b", + "verify_ssl": model_service_verify_ssl, + "stream": True, + }, + } # Enable Health Checks where we have them implemented model_pipelines_config["ModelPipelineCompletions"]["config"][ diff --git a/ansible_ai_connect/main/tests/test_views.py b/ansible_ai_connect/main/tests/test_views.py index 059f7fc28..0eb1b95ec 100644 --- a/ansible_ai_connect/main/tests/test_views.py +++ b/ansible_ai_connect/main/tests/test_views.py @@ -26,7 +26,9 @@ from django.urls import reverse from rest_framework.test import APITransactionTestCase -from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot +from ansible_ai_connect.ai.api.model_pipelines.pipelines import ( + ModelPipelineStreamingChatBot, +) from ansible_ai_connect.main.settings.base import SOCIAL_AUTH_OIDC_KEY from ansible_ai_connect.main.views import LoginView from ansible_ai_connect.test_utils import ( @@ -341,7 +343,7 @@ def test_chatbot_view_with_rh_user(self): self.assertContains(r, TestChatbotView.CHATBOT_PAGE_TITLE) self.assertContains(r, self.rh_user.username) self.assertContains(r, '