diff --git a/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py index 3ce05be77..cca099948 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py @@ -14,9 +14,12 @@ import json import logging -from typing import Optional +from json import JSONDecodeError +from typing import AsyncGenerator, Optional +import aiohttp import requests +from django.http import StreamingHttpResponse from health_check.exceptions import ServiceUnavailable from ansible_ai_connect.ai.api.exceptions import ( @@ -39,6 +42,9 @@ MetaData, ModelPipelineChatBot, ModelPipelineCompletions, + ModelPipelineStreamingChatBot, + StreamingChatBotParameters, + StreamingChatBotResponse, ) from ansible_ai_connect.ai.api.model_pipelines.registry import Register from ansible_ai_connect.healthcheck.backends import ( @@ -120,8 +126,43 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i raise NotImplementedError +class HttpChatBotMetaData(HttpMetaData): + + def __init__(self, config: HttpConfiguration): + super().__init__(config=config) + + def self_test(self) -> Optional[HealthCheckSummary]: + summary: HealthCheckSummary = HealthCheckSummary( + { + MODEL_MESH_HEALTH_CHECK_PROVIDER: "http", + MODEL_MESH_HEALTH_CHECK_MODELS: "ok", + } + ) + try: + headers = {"Content-Type": "application/json"} + r = requests.get(self.config.inference_url + "/readiness", headers=headers) + r.raise_for_status() + + data = r.json() + ready = data.get("ready") + if not ready: + reason = data.get("reason") + summary.add_exception( + MODEL_MESH_HEALTH_CHECK_MODELS, + HealthCheckSummaryException(ServiceUnavailable(reason)), + ) + + except Exception as e: + logger.exception(str(e)) + summary.add_exception( + MODEL_MESH_HEALTH_CHECK_MODELS, + HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e), + ) + return summary + + @Register(api_type="http") -class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]): +class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]): def __init__(self, config: HttpConfiguration): super().__init__(config=config) @@ -171,31 +212,100 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse: detail = json.loads(response.text).get("detail", "") raise ChatbotInternalServerException(detail=detail) - def self_test(self) -> Optional[HealthCheckSummary]: - summary: HealthCheckSummary = HealthCheckSummary( - { - MODEL_MESH_HEALTH_CHECK_PROVIDER: "http", - MODEL_MESH_HEALTH_CHECK_MODELS: "ok", - } + +@Register(api_type="http") +class HttpStreamingChatBotPipeline( + HttpChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration] +): + + def __init__(self, config: HttpConfiguration): + super().__init__(config=config) + + def invoke(self, params: StreamingChatBotParameters) -> StreamingChatBotResponse: + response = self.get_streaming_http_response(params) + + if response.status_code == 200: + return response + else: + raise ChatbotInternalServerException(detail="Internal server error") + + def get_streaming_http_response( + self, params: StreamingChatBotParameters + ) -> StreamingHttpResponse: + return StreamingHttpResponse( + self.async_invoke(params), + content_type="text/event-stream", ) - try: - headers = {"Content-Type": "application/json"} - r = requests.get(self.config.inference_url + "/readiness", headers=headers) - r.raise_for_status() - data = r.json() - ready = 
data.get("ready") - if not ready: - reason = data.get("reason") - summary.add_exception( - MODEL_MESH_HEALTH_CHECK_MODELS, - HealthCheckSummaryException(ServiceUnavailable(reason)), - ) + async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerator: + async with aiohttp.ClientSession(raise_for_status=True) as session: + headers = { + "Content-Type": "application/json", + "Accept": "application/json,text/event-stream", + } - except Exception as e: - logger.exception(str(e)) - summary.add_exception( - MODEL_MESH_HEALTH_CHECK_MODELS, - HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e), - ) - return summary + query = params.query + conversation_id = params.conversation_id + provider = params.provider + model_id = params.model_id + system_prompt = params.system_prompt + media_type = params.media_type + + data = { + "query": query, + "model": model_id, + "provider": provider, + } + if conversation_id: + data["conversation_id"] = str(conversation_id) + if system_prompt: + data["system_prompt"] = str(system_prompt) + if media_type: + data["media_type"] = str(media_type) + + async with session.post( + self.config.inference_url + "/v1/streaming_query", + json=data, + headers=headers, + raise_for_status=False, + ) as response: + if response.status == 200: + async for chunk in response.content: + try: + if chunk: + s = chunk.decode("utf-8").strip() + if s and s.startswith("data: "): + o = json.loads(s[len("data: ") :]) + if o["event"] == "error": + default_data = { + "response": "(not provided)", + "cause": "(not provided)", + } + data = o.get("data", default_data) + logger.error( + "An error received in chat streaming content:" + + " response=" + + data.get("response") + + ", cause=" + + data.get("cause") + ) + except JSONDecodeError: + pass + logger.debug(chunk) + yield chunk + else: + logging.error( + "Streaming query API returned status code=" + + str(response.status) + + ", reason=" + + str(response.reason) + ) + error = { + "event": "error", + "data": { + "response": f"Non-200 status code ({response.status}) was received.", + "cause": response.reason, + }, + } + yield json.dumps(error).encode("utf-8") + return diff --git a/ansible_ai_connect/ai/api/model_pipelines/http/tests/test_pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/http/tests/test_pipelines.py new file mode 100644 index 000000000..226114f21 --- /dev/null +++ b/ansible_ai_connect/ai/api/model_pipelines/http/tests/test_pipelines.py @@ -0,0 +1,173 @@ +# +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import logging +from unittest import IsolatedAsyncioTestCase +from unittest.mock import patch + +from ansible_ai_connect.test_utils import WisdomLogAwareMixin + +from ...pipelines import StreamingChatBotParameters +from ...tests import mock_pipeline_config +from ..pipelines import HttpStreamingChatBotPipeline + +logger = logging.getLogger(__name__) + + +class TestHttpStreamingChatBotPipeline(IsolatedAsyncioTestCase, WisdomLogAwareMixin): + pipeline: HttpStreamingChatBotPipeline + + STREAM_DATA = [ + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3f"}}, + {"event": "token", "data": {"id": 0, "token": ""}}, + {"event": "token", "data": {"id": 1, "token": "Hello"}}, + {"event": "token", "data": {"id": 2, "token": "!"}}, + {"event": "token", "data": {"id": 3, "token": " I"}}, + {"event": "token", "data": {"id": 4, "token": "'m"}}, + {"event": "token", "data": {"id": 5, "token": " Ansible"}}, + {"event": "token", "data": {"id": 6, "token": " L"}}, + {"event": "token", "data": {"id": 7, "token": "ights"}}, + {"event": "token", "data": {"id": 8, "token": "peed"}}, + {"event": "token", "data": {"id": 9, "token": ","}}, + {"event": "token", "data": {"id": 10, "token": " your"}}, + {"event": "token", "data": {"id": 11, "token": " virtual"}}, + {"event": "token", "data": {"id": 12, "token": " assistant"}}, + {"event": "token", "data": {"id": 13, "token": " for"}}, + {"event": "token", "data": {"id": 14, "token": " all"}}, + {"event": "token", "data": {"id": 15, "token": " things"}}, + {"event": "token", "data": {"id": 16, "token": " Ansible"}}, + {"event": "token", "data": {"id": 17, "token": "."}}, + {"event": "token", "data": {"id": 18, "token": " How"}}, + {"event": "token", "data": {"id": 19, "token": " can"}}, + {"event": "token", "data": {"id": 20, "token": " I"}}, + {"event": "token", "data": {"id": 21, "token": " assist"}}, + {"event": "token", "data": {"id": 22, "token": " you"}}, + {"event": "token", "data": {"id": 23, "token": " today"}}, + {"event": "token", "data": {"id": 24, "token": "?"}}, + {"event": "token", "data": {"id": 25, "token": ""}}, + { + "event": "end", + "data": { + "referenced_documents": [], + "truncated": False, + "input_tokens": 241, + "output_tokens": 25, + }, + }, + ] + + STREAM_DATA_PROMPT_TOO_LONG = [ + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3e"}}, + { + "event": "error", + "data": {"response": "Prompt is too long", "cause": "Prompt length 10000 exceeds LLM"}, + }, + ] + + STREAM_DATA_PROMPT_GENERIC_LLM_ERROR = [ + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3d"}}, + { + "event": "error", + "data": { + "response": "Oops, something went wrong during LLM invocation", + "cause": "A generic LLM error", + }, + }, + ] + + STREAM_DATA_PROMPT_ERROR_WITH_NO_DATA = [ + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3c"}}, + {"event": "error"}, + ] + + def setUp(self): + self.pipeline = HttpStreamingChatBotPipeline(mock_pipeline_config("http")) + + def assertInLog(self, s, logs, number_of_matches_expected=None): + self.assertTrue(self.searchInLogOutput(s, logs, number_of_matches_expected), logs) + + def get_return_value(self, stream_data, status=200): + class MyAsyncContextManager: + def __init__(self, stream_data, status=200): + self.stream_data = stream_data + self.status = status + self.reason = "" + + async def my_async_generator(self): + for data in self.stream_data: + s = json.dumps(data) + yield (f"data: 
{s}\n\n".encode()) + + async def __aenter__(self): + self.content = self.my_async_generator() + self.status = self.status + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + return MyAsyncContextManager(stream_data, status) + + def get_params(self) -> StreamingChatBotParameters: + return StreamingChatBotParameters( + query="Hello", + provider="", + model_id="", + conversation_id=None, + system_prompt=None, + media_type="application/json", + ) + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_with_no_error(self, mock_post): + mock_post.return_value = self.get_return_value(self.STREAM_DATA) + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_prompt_too_long(self, mock_post): + mock_post.return_value = self.get_return_value(self.STREAM_DATA_PROMPT_TOO_LONG) + with self.assertLogs(logger="root", level="ERROR") as log: + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + self.assertInLog("Prompt is too long", log) + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_prompt_generic_llm_error(self, mock_post): + mock_post.return_value = self.get_return_value(self.STREAM_DATA_PROMPT_GENERIC_LLM_ERROR) + with self.assertLogs(logger="root", level="ERROR") as log: + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + self.assertInLog("Oops, something went wrong during LLM invocation", log) + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_internal_server_error(self, mock_post): + mock_post.return_value = self.get_return_value( + self.STREAM_DATA_PROMPT_GENERIC_LLM_ERROR, 500 + ) + with self.assertLogs(logger="root", level="ERROR") as log: + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + self.assertInLog("Streaming query API returned status code=500", log) + + @patch("aiohttp.ClientSession.post") + async def test_async_invoke_error_with_no_data(self, mock_post): + mock_post.return_value = self.get_return_value( + self.STREAM_DATA_PROMPT_ERROR_WITH_NO_DATA, + ) + with self.assertLogs(logger="root", level="ERROR") as log: + async for _ in self.pipeline.async_invoke(self.get_params()): + pass + self.assertInLog("(not provided)", log) diff --git a/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py index c3620ef5c..42a41f848 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py @@ -29,6 +29,7 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, PlaybookExplanationParameters, PlaybookExplanationResponse, PlaybookGenerationParameters, @@ -37,6 +38,8 @@ RoleExplanationResponse, RoleGenerationParameters, RoleGenerationResponse, + StreamingChatBotParameters, + StreamingChatBotResponse, ) from ansible_ai_connect.ai.api.model_pipelines.registry import Register from ansible_ai_connect.healthcheck.backends import HealthCheckSummary @@ -143,3 +146,16 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse: def self_test(self) -> Optional[HealthCheckSummary]: raise NotImplementedError + + +@Register(api_type="nop") +class NopStreamingChatBotPipeline(NopMetaData, ModelPipelineStreamingChatBot[NopConfiguration]): + + def __init__(self, config: NopConfiguration): + super().__init__(config=config) + + def invoke(self, params: 
StreamingChatBotParameters) -> StreamingChatBotResponse: + raise NotImplementedError + + def self_test(self) -> Optional[HealthCheckSummary]: + raise NotImplementedError diff --git a/ansible_ai_connect/ai/api/model_pipelines/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/pipelines.py index f35f8f79c..c2facc626 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/pipelines.py @@ -18,6 +18,7 @@ from attrs import define from django.conf import settings +from django.http import StreamingHttpResponse from rest_framework.request import Request from rest_framework.response import Response @@ -242,6 +243,33 @@ def init( ChatBotResponse = Any +@define +class StreamingChatBotParameters(ChatBotParameters): + media_type: str + + @classmethod + def init( + cls, + query: str, + provider: Optional[str] = None, + model_id: Optional[str] = None, + conversation_id: Optional[str] = None, + system_prompt: Optional[str] = None, + media_type: Optional[str] = None, + ): + return cls( + query=query, + provider=provider, + model_id=model_id, + conversation_id=conversation_id, + system_prompt=system_prompt, + media_type=media_type, + ) + + +StreamingChatBotResponse = StreamingHttpResponse + + class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta): def __init__(self, config: PIPELINE_CONFIGURATION): @@ -381,3 +409,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION): @staticmethod def alias(): return "chatbot-service" + + +class ModelPipelineStreamingChatBot( + ModelPipeline[PIPELINE_CONFIGURATION, StreamingChatBotParameters, StreamingChatBotResponse], + Generic[PIPELINE_CONFIGURATION], + metaclass=ABCMeta, +): + + def __init__(self, config: PIPELINE_CONFIGURATION): + super().__init__(config=config) + + @staticmethod + def alias(): + return "streaming-chatbot-service" diff --git a/ansible_ai_connect/ai/api/model_pipelines/registry.py b/ansible_ai_connect/ai/api/model_pipelines/registry.py index f9ee4c88b..1426099ec 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/registry.py +++ b/ansible_ai_connect/ai/api/model_pipelines/registry.py @@ -30,6 +30,7 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, ) from ansible_ai_connect.main.settings.types import t_model_mesh_api_type @@ -45,6 +46,7 @@ ModelPipelinePlaybookExplanation, ModelPipelineRoleExplanation, ModelPipelineChatBot, + ModelPipelineStreamingChatBot, PipelineConfiguration, Serializer, ] diff --git a/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py b/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py index 4ee8e825d..db4878a0a 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py +++ b/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py @@ -87,6 +87,7 @@ def mock_pipeline_config(pipeline_provider: t_model_mesh_api_type, **kwargs): timeout=extract("timeout", 1000, **kwargs), enable_health_check=extract("enable_health_check", False, **kwargs), verify_ssl=extract("verify_ssl", False, **kwargs), + stream=extract("stream", False, **kwargs), ) case "llamacpp": return LlamaCppConfiguration( diff --git a/ansible_ai_connect/ai/api/serializers.py b/ansible_ai_connect/ai/api/serializers.py index 329e4386b..f3bc7d5a0 100644 --- a/ansible_ai_connect/ai/api/serializers.py +++ b/ansible_ai_connect/ai/api/serializers.py @@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer): ) +class StreamingChatRequestSerializer(ChatRequestSerializer): + 
media_type = serializers.CharField( + required=False, + label="Media type", + help_text=("A media type to be used in the output from LLM."), + ) + + class ReferencedDocumentsSerializer(serializers.Serializer): docs_url = serializers.CharField() title = serializers.CharField() diff --git a/ansible_ai_connect/ai/api/telemetry/schema1.py b/ansible_ai_connect/ai/api/telemetry/schema1.py index d75a74fc7..e9bbea929 100644 --- a/ansible_ai_connect/ai/api/telemetry/schema1.py +++ b/ansible_ai_connect/ai/api/telemetry/schema1.py @@ -229,3 +229,8 @@ class ChatBotFeedbackEvent(ChatBotBaseEvent): @define class ChatBotOperationalEvent(ChatBotBaseEvent): event_name: str = "chatOperationalEvent" + + +@define +class StreamingChatBotOperationalEvent(ChatBotBaseEvent): + event_name: str = "streamingChatOperationalEvent" diff --git a/ansible_ai_connect/ai/api/tests/test_chat_view.py b/ansible_ai_connect/ai/api/tests/test_chat_view.py index 42c843ef1..31adc47f2 100644 --- a/ansible_ai_connect/ai/api/tests/test_chat_view.py +++ b/ansible_ai_connect/ai/api/tests/test_chat_view.py @@ -23,6 +23,7 @@ from django.apps import apps from django.contrib.auth import get_user_model +from django.http import StreamingHttpResponse from django.test import override_settings from ansible_ai_connect.ai.api.exceptions import ( @@ -34,7 +35,10 @@ ChatbotUnauthorizedException, ChatbotValidationException, ) -from ansible_ai_connect.ai.api.model_pipelines.http.pipelines import HttpChatBotPipeline +from ansible_ai_connect.ai.api.model_pipelines.http.pipelines import ( + HttpChatBotPipeline, + HttpStreamingChatBotPipeline, +) from ansible_ai_connect.ai.api.model_pipelines.tests import mock_pipeline_config from ansible_ai_connect.organizations.models import Organization from ansible_ai_connect.test_utils import ( @@ -178,6 +182,7 @@ def json(self): def query_with_no_error(self, payload, mock_post): return self.client.post(self.api_version_reverse("chat"), payload, format="json") + @override_settings(CHATBOT_DEFAULT_PROVIDER="") @mock.patch( "requests.post", side_effect=mocked_requests_post, @@ -475,3 +480,275 @@ def test_not_rh_internal_user(self): finally: if self.user2: self.user2.delete() + + +class TestStreamingChatView(APIVersionTestCaseBase, WisdomServiceAPITestCaseBase): + api_version = "v1" + + def setUp(self): + super().setUp() + (org, _) = Organization.objects.get_or_create(id=123, telemetry_opt_out=False) + self.user.organization = org + self.user.rh_internal = True + + @staticmethod + def mocked_response(*args, **kwargs): + + # Make sure that the given json data is serializable + input = json.dumps(kwargs["json"]) + assert input is not None + + json_response = { + "response": "AAP 2.5 introduces an updated, unified UI.", + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "truncated": False, + "referenced_documents": [], + } + status_code = 200 + + if kwargs["json"]["query"] == TestChatView.PAYLOAD_INTERNAL_SERVER_ERROR["query"]: + status_code = 500 + json_response = { + "detail": "Internal server error", + } + + response = StreamingHttpResponse() + response.status_code = status_code + response.text = json.dumps(json_response) + return response + + @override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom") + @mock.patch( + "ansible_ai_connect.ai.api.model_pipelines.http.pipelines." 
+ "HttpStreamingChatBotPipeline.get_streaming_http_response", + ) + def query_with_status_code_override(self, payload, mock): + mock.return_value = TestStreamingChatView.mocked_response(json=payload) + return self.client.post(self.api_version_reverse("streaming_chat"), payload, format="json") + + @override_settings(CHATBOT_DEFAULT_PROVIDER="") + def query_without_chat_config(self, payload): + return self.client.post(self.api_version_reverse("streaming_chat"), payload, format="json") + + @override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom") + @mock.patch( + "ansible_ai_connect.ai.api.model_pipelines.http.pipelines." + "HttpStreamingChatBotPipeline.get_streaming_http_response", + ) + def query_with_no_error(self, payload, mock): + mock.return_value = TestStreamingChatView.mocked_response(json=payload) + return self.client.post(self.api_version_reverse("streaming_chat"), payload, format="json") + + def assert_test( + self, + payload, + expected_status_code=200, + expected_exception=None, + expected_log_message=None, + user=None, + ): + if user is None: + user = self.user + with ( + patch.object( + apps.get_app_config("ai"), + "get_model_pipeline", + Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))), + ), + self.assertLogs(logger="root", level="DEBUG") as log, + ): + self.client.force_authenticate(user=user) + + if expected_status_code >= 400: + if expected_exception == ChatbotNotEnabledException: + r = self.query_without_chat_config(payload) + else: + r = self.query_with_status_code_override(payload) + else: + r = self.query_with_no_error(payload) + + self.assertEqual(r.status_code, expected_status_code) + if expected_exception is not None: + self.assert_error_detail( + r, expected_exception().default_code, expected_exception().default_detail + ) + self.assertInLog(expected_log_message, log) + return r + + def test_chat(self): + self.assert_test(TestChatView.VALID_PAYLOAD) + + def test_chat_with_conversation_id(self): + self.assert_test(TestChatView.VALID_PAYLOAD_WITH_CONVERSATION_ID) + + def test_chat_not_enabled_exception(self): + self.assert_test( + TestChatView.VALID_PAYLOAD, 503, ChatbotNotEnabledException, "Chatbot is not enabled" + ) + + def test_chat_internal_server_exception(self): + self.assert_test( + TestChatView.PAYLOAD_INTERNAL_SERVER_ERROR, + 500, + ChatbotInternalServerException, + "ChatbotInternalServerException", + ) + + @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") + def test_operational_telemetry(self): + self.user.rh_user_has_seat = True + self.user.organization = Organization.objects.get_or_create(id=1)[0] + self.client.force_authenticate(user=self.user) + with ( + patch.object( + apps.get_app_config("ai"), + "get_model_pipeline", + Mock( + return_value=HttpStreamingChatBotPipeline( + mock_pipeline_config("http", model_id="granite-8b") + ) + ), + ), + self.assertLogs(logger="root", level="DEBUG") as log, + ): + r = self.query_with_no_error(TestChatView.VALID_PAYLOAD_WITH_CONVERSATION_ID) + self.assertEqual(r.status_code, HTTPStatus.OK) + segment_events = self.extractSegmentEventsFromLog(log) + self.assertEqual( + segment_events[0]["properties"]["chat_prompt"], + TestChatView.VALID_PAYLOAD_WITH_CONVERSATION_ID["query"], + ) + self.assertEqual( + segment_events[0]["properties"]["conversation_id"], + TestChatView.VALID_PAYLOAD_WITH_CONVERSATION_ID["conversation_id"], + ) + self.assertEqual(segment_events[0]["properties"]["modelName"], "granite-8b") + self.assertEqual( + segment_events[0]["properties"]["chat_truncated"], + 
TestChatView.JSON_RESPONSE["truncated"],
+            )
+            self.assertEqual(len(segment_events[0]["properties"]["chat_referenced_documents"]), 0)
+
+    @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")
+    def test_operational_telemetry_limit_exceeded(self):
+        q = "".join("hello " for i in range(6500))
+        payload = {
+            "query": q,
+        }
+        self.client.force_authenticate(user=self.user)
+        with (
+            patch.object(
+                apps.get_app_config("ai"),
+                "get_model_pipeline",
+                Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))),
+            ),
+            self.assertLogs(logger="root", level="DEBUG") as log,
+        ):
+            r = self.query_with_no_error(payload)
+            self.assertEqual(r.status_code, 200)
+            segment_events = self.extractSegmentEventsFromLog(log)
+            self.assertEqual(
+                segment_events[0]["properties"]["rh_user_org_id"],
+                123,
+            )
+            self.assertEqual(
+                segment_events[0]["properties"]["chat_response"],
+                "",
+            )
+
+    @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")
+    def test_operational_telemetry_anonymizer(self):
+        self.client.force_authenticate(user=self.user)
+        with (
+            patch.object(
+                apps.get_app_config("ai"),
+                "get_model_pipeline",
+                Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))),
+            ),
+            self.assertLogs(logger="root", level="DEBUG") as log,
+        ):
+            r = self.query_with_no_error(
+                {
+                    "query": "Hello ansible@ansible.com",
+                    "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
+                }
+            )
+            self.assertEqual(r.status_code, HTTPStatus.OK)
+            segment_events = self.extractSegmentEventsFromLog(log)
+            self.assertNotEqual(
+                segment_events[0]["properties"]["chat_prompt"],
+                "Hello ansible@ansible.com",
+            )
+
+    @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")
+    def test_operational_telemetry_with_system_prompt_override(self):
+        self.client.force_authenticate(user=self.user)
+        with (
+            patch.object(
+                apps.get_app_config("ai"),
+                "get_model_pipeline",
+                Mock(
+                    return_value=HttpStreamingChatBotPipeline(
+                        mock_pipeline_config("http", model_id="granite-8b")
+                    )
+                ),
+            ),
+            self.assertLogs(logger="root", level="DEBUG") as log,
+        ):
+            r = self.query_with_no_error(TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE)
+            self.assertEqual(r.status_code, HTTPStatus.OK)
+            segment_events = self.extractSegmentEventsFromLog(log)
+            self.assertEqual(
+                segment_events[0]["properties"]["chat_prompt"],
+                TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["query"],
+            )
+            self.assertEqual(segment_events[0]["properties"]["modelName"], "granite-8b")
+            self.assertEqual(
+                segment_events[0]["properties"]["chat_truncated"],
+                TestChatView.JSON_RESPONSE["truncated"],
+            )
+            self.assertEqual(len(segment_events[0]["properties"]["chat_referenced_documents"]), 0)
+            self.assertEqual(
+                segment_events[0]["properties"]["chat_system_prompt"],
+                TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["system_prompt"],
+            )
+
+    def test_chat_rate_limit(self):
+        # Call chat API five times using self.user
+        for i in range(5):
+            self.assert_test(TestChatView.VALID_PAYLOAD)
+        try:
+            username = "u" + "".join(random.choices(string.digits, k=5))
+            password = "secret"
+            email = "user2@example.com"
+            self.user2 = get_user_model().objects.create_user(
+                username=username,
+                email=email,
+                password=password,
+            )
+            (org, _) = Organization.objects.get_or_create(id=123, telemetry_opt_out=False)
+            self.user2.organization = org
+            self.user2.rh_internal = True
+            # Call chat API five times using self.user2
+            for i in range(5):
+                self.assert_test(TestChatView.VALID_PAYLOAD, user=self.user2)
+            # The next chat API call should be the 11th 
from two users and should receive a 429. + self.assert_test(TestChatView.VALID_PAYLOAD, expected_status_code=429, user=self.user2) + finally: + if self.user2: + self.user2.delete() + + def test_not_rh_internal_user(self): + try: + username = "u" + "".join(random.choices(string.digits, k=5)) + self.user2 = get_user_model().objects.create_user( + username=username, + ) + self.user2.organization = Organization.objects.get_or_create( + id=123, telemetry_opt_out=False + )[0] + self.user2.rh_internal = False + self.assert_test(TestChatView.VALID_PAYLOAD, expected_status_code=403, user=self.user2) + finally: + if self.user2: + self.user2.delete() diff --git a/ansible_ai_connect/ai/api/utils/seated_users_allow_list.py b/ansible_ai_connect/ai/api/utils/seated_users_allow_list.py index 0856a464f..1e2338c22 100644 --- a/ansible_ai_connect/ai/api/utils/seated_users_allow_list.py +++ b/ansible_ai_connect/ai/api/utils/seated_users_allow_list.py @@ -390,4 +390,7 @@ "chatOperationalEvent": { "*": None, }, + "streamingChatOperationalEvent": { + "*": None, + }, } diff --git a/ansible_ai_connect/ai/api/versions/v1/ai/urls.py b/ansible_ai_connect/ai/api/versions/v1/ai/urls.py index 2921b3032..aad4bcdb8 100644 --- a/ansible_ai_connect/ai/api/versions/v1/ai/urls.py +++ b/ansible_ai_connect/ai/api/versions/v1/ai/urls.py @@ -25,4 +25,5 @@ path("generations/role/", views.GenerationRole.as_view(), name="generations/role"), path("feedback/", views.Feedback.as_view(), name="feedback"), path("chat/", views.Chat.as_view(), name="chat"), + path("streaming_chat/", views.StreamingChat.as_view(), name="streaming_chat"), ] diff --git a/ansible_ai_connect/ai/api/versions/v1/ai/views.py b/ansible_ai_connect/ai/api/versions/v1/ai/views.py index 1667003af..f47679e96 100644 --- a/ansible_ai_connect/ai/api/versions/v1/ai/views.py +++ b/ansible_ai_connect/ai/api/versions/v1/ai/views.py @@ -21,6 +21,7 @@ Feedback, GenerationPlaybook, GenerationRole, + StreamingChat, ) __all__ = [ @@ -32,4 +33,5 @@ "ExplanationRole", "Feedback", "Chat", + "StreamingChat", ] diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py index 01dec16d4..768be0da9 100644 --- a/ansible_ai_connect/ai/api/views.py +++ b/ansible_ai_connect/ai/api/views.py @@ -21,6 +21,7 @@ from attr import asdict from django.apps import apps from django.conf import settings +from django.http import StreamingHttpResponse from django_prometheus.conf import NAMESPACE from drf_spectacular.utils import OpenApiResponse, extend_schema from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope @@ -77,10 +78,12 @@ ModelPipelinePlaybookGeneration, ModelPipelineRoleExplanation, ModelPipelineRoleGeneration, + ModelPipelineStreamingChatBot, PlaybookExplanationParameters, PlaybookGenerationParameters, RoleExplanationParameters, RoleGenerationParameters, + StreamingChatBotParameters, ) from ansible_ai_connect.ai.api.pipelines.completions import CompletionsPipeline from ansible_ai_connect.ai.api.telemetry import schema1 @@ -134,6 +137,7 @@ PlaybookGenerationAction, RoleGenerationAction, SentimentFeedback, + StreamingChatRequestSerializer, SuggestionQualityFeedback, ) from .telemetry.schema1 import ( @@ -1126,3 +1130,79 @@ def post(self, request) -> Response: status=rest_framework_status.HTTP_200_OK, headers=headers, ) + + +class StreamingChat(AACSAPIView): + """ + Send a message to the backend chatbot service and get a streaming reply. 
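+
+    The reply is proxied back to the caller as a Server-Sent Events stream
+    ("text/event-stream") produced by the configured
+    ModelPipelineStreamingChatBot pipeline.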
+ """ + + class StreamingChatEndpointThrottle(EndpointRateThrottle): + scope = "chat" + + permission_classes = [ + permissions.IsAuthenticated, + IsAuthenticatedOrTokenHasScope, + IsRHInternalUser | IsTestUser | IsAAPUser, + ] + required_scopes = ["read", "write"] + schema1_event = schema1.StreamingChatBotOperationalEvent + request_serializer_class = StreamingChatRequestSerializer + throttle_classes = [StreamingChatEndpointThrottle] + + llm: ModelPipelineStreamingChatBot + + def __init__(self): + super().__init__() + self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot) + + self.chatbot_enabled = ( + self.llm.config.inference_url + and self.llm.config.model_id + and settings.CHATBOT_DEFAULT_PROVIDER + ) + if self.chatbot_enabled: + logger.debug("Chatbot is enabled.") + else: + logger.debug("Chatbot is not enabled.") + + @extend_schema( + request=StreamingChatRequestSerializer, + responses={ + 200: ChatResponseSerializer, # TODO + 400: OpenApiResponse(description="Bad request"), + 403: OpenApiResponse(description="Forbidden"), + 413: OpenApiResponse(description="Prompt too long"), + 422: OpenApiResponse(description="Validation failed"), + 500: OpenApiResponse(description="Internal server error"), + 503: OpenApiResponse(description="Service unavailable"), + }, + summary="Streaming chat request", + ) + def post(self, request) -> StreamingHttpResponse: + if not self.chatbot_enabled: + raise ChatbotNotEnabledException() + + req_query = self.validated_data["query"] + req_system_prompt = self.validated_data.get("system_prompt") + req_provider = self.validated_data.get("provider", settings.CHATBOT_DEFAULT_PROVIDER) + conversation_id = self.validated_data.get("conversation_id") + media_type = self.validated_data.get("media_type") + + # Initialise Segment Event early, in case of exceptions + self.event.chat_prompt = anonymize_struct(req_query) + self.event.chat_system_prompt = req_system_prompt + self.event.provider_id = req_provider + self.event.conversation_id = conversation_id + self.event.modelName = self.req_model_id or self.llm.config.model_id + + return self.llm.invoke( + StreamingChatBotParameters.init( + query=req_query, + system_prompt=req_system_prompt, + model_id=self.req_model_id or self.llm.config.model_id, + provider=req_provider, + conversation_id=conversation_id, + media_type=media_type, + ) + ) diff --git a/ansible_ai_connect/healthcheck/tests/test_healthcheck.py b/ansible_ai_connect/healthcheck/tests/test_healthcheck.py index 9ac53cb66..8f654b564 100644 --- a/ansible_ai_connect/healthcheck/tests/test_healthcheck.py +++ b/ansible_ai_connect/healthcheck/tests/test_healthcheck.py @@ -159,7 +159,7 @@ def assert_basic_data( self.assert_common_data(data, expected_status, deployed_region) timestamp = data["timestamp"] dependencies = data.get("dependencies", []) - self.assertEqual(10, len(dependencies)) + self.assertEqual(11, len(dependencies)) for dependency in dependencies: self.assertIn( dependency["name"], diff --git a/ansible_ai_connect/main/settings/base.py b/ansible_ai_connect/main/settings/base.py index 9e9cd659f..43adf7cd2 100644 --- a/ansible_ai_connect/main/settings/base.py +++ b/ansible_ai_connect/main/settings/base.py @@ -55,6 +55,7 @@ # Application definition INSTALLED_APPS = [ + "daphne", "django.contrib.admin", "django.contrib.auth", "django.contrib.contenttypes", @@ -95,6 +96,11 @@ "csp.middleware.CSPMiddleware", ] +if os.environ.get("CSRF_TRUSTED_ORIGINS"): + CSRF_TRUSTED_ORIGINS = os.environ.get("CSRF_TRUSTED_ORIGINS").split(",") +else: 
+ CSRF_TRUSTED_ORIGINS = ["http://localhost:8000"] + # Allow Prometheus to scrape metrics ALLOWED_CIDR_NETS = [os.environ.get("ALLOWED_CIDR_NETS", "10.0.0.0/8")] @@ -340,6 +346,11 @@ def is_ssl_enabled(value: str) -> bool: "level": "INFO", "propagate": False, }, + "ansible_ai_connect.ai.api.streaming_chat": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, }, "root": { "handlers": ["console"], @@ -364,6 +375,7 @@ def is_ssl_enabled(value: str) -> bool: ] WSGI_APPLICATION = "ansible_ai_connect.main.wsgi.application" +ASGI_APPLICATION = "ansible_ai_connect.main.asgi.application" # Database # https://docs.djangoproject.com/en/4.1/ref/settings/#databases diff --git a/ansible_ai_connect/main/settings/legacy.py b/ansible_ai_connect/main/settings/legacy.py index 6543f43c8..8bd686263 100644 --- a/ansible_ai_connect/main/settings/legacy.py +++ b/ansible_ai_connect/main/settings/legacy.py @@ -192,6 +192,15 @@ def load_from_env_vars(): "stream": False, }, } + model_pipelines_config["ModelPipelineStreamingChatBot"] = { + "provider": "http", + "config": { + "inference_url": chatbot_service_url or "http://localhost:8000", + "model_id": chatbot_service_model_id or "granite3-8b", + "verify_ssl": model_service_verify_ssl, + "stream": True, + }, + } # Enable Health Checks where we have them implemented model_pipelines_config["ModelPipelineCompletions"]["config"][ diff --git a/ansible_ai_connect/main/tests/test_views.py b/ansible_ai_connect/main/tests/test_views.py index 059f7fc28..0eb1b95ec 100644 --- a/ansible_ai_connect/main/tests/test_views.py +++ b/ansible_ai_connect/main/tests/test_views.py @@ -26,7 +26,9 @@ from django.urls import reverse from rest_framework.test import APITransactionTestCase -from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot +from ansible_ai_connect.ai.api.model_pipelines.pipelines import ( + ModelPipelineStreamingChatBot, +) from ansible_ai_connect.main.settings.base import SOCIAL_AUTH_OIDC_KEY from ansible_ai_connect.main.views import LoginView from ansible_ai_connect.test_utils import ( @@ -341,7 +343,7 @@ def test_chatbot_view_with_rh_user(self): self.assertContains(r, TestChatbotView.CHATBOT_PAGE_TITLE) self.assertContains(r, self.rh_user.username) self.assertContains(r, '') - self.assertContains(r, '') + self.assertContains(r, '') @override_settings(CHATBOT_DEBUG_UI=True) def test_chatbot_view_with_debug_ui(self): @@ -350,12 +352,12 @@ def test_chatbot_view_with_debug_ui(self): self.assertEqual(r.status_code, HTTPStatus.OK) self.assertContains(r, '') - def test_chatbot_view_with_streaming_enabled(self): - llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline( - ModelPipelineChatBot + def test_chatbot_view_with_streaming_disabled(self): + llm: ModelPipelineStreamingChatBot = apps.get_app_config("ai").get_model_pipeline( + ModelPipelineStreamingChatBot ) - llm.config.stream = True + llm.config.inference_url = "" self.client.force_login(user=self.rh_user) r = self.client.get(reverse("chatbot"), {"stream": "true"}) self.assertEqual(r.status_code, HTTPStatus.OK) - self.assertContains(r, '') + self.assertContains(r, '') diff --git a/ansible_ai_connect/main/views.py b/ansible_ai_connect/main/views.py index bb6e470c0..f9db30613 100644 --- a/ansible_ai_connect/main/views.py +++ b/ansible_ai_connect/main/views.py @@ -28,7 +28,10 @@ from rest_framework.renderers import BaseRenderer from rest_framework.views import APIView -from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot 
+from ansible_ai_connect.ai.api.model_pipelines.pipelines import ( + ModelPipelineChatBot, + ModelPipelineStreamingChatBot, +) from ansible_ai_connect.ai.api.permissions import ( IsOrganisationAdministrator, IsOrganisationLightspeedSubscriber, @@ -121,17 +124,26 @@ class ChatbotView(ProtectedTemplateView): IsRHInternalUser | IsTestUser | IsAAPUser, ] - llm: ModelPipelineChatBot chatbot_enabled: bool + streaming_chatbot_enabled: bool def __init__(self): super().__init__() - self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineChatBot) + chat_llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineChatBot) self.chatbot_enabled = ( - self.llm.config.inference_url - and self.llm.config.model_id + chat_llm.config.inference_url + and chat_llm.config.model_id + and settings.CHATBOT_DEFAULT_PROVIDER + ) + streaming_chat_llm = apps.get_app_config("ai").get_model_pipeline( + ModelPipelineStreamingChatBot + ) + self.streaming_chatbot_enabled = ( + streaming_chat_llm.config.inference_url + and streaming_chat_llm.config.model_id and settings.CHATBOT_DEFAULT_PROVIDER ) + self.chatbot_enabled = self.chatbot_enabled or self.streaming_chatbot_enabled def get(self, request): # Open the chatbot page when the chatbot service is configured. @@ -148,7 +160,7 @@ def get_context_data(self, **kwargs): if user and user.is_authenticated: context["user_name"] = user.username context["debug"] = "true" if settings.CHATBOT_DEBUG_UI else "false" - context["stream"] = "true" if self.llm.config.stream else "false" + context["stream"] = "true" if self.streaming_chatbot_enabled else "false" return context diff --git a/pyproject.toml b/pyproject.toml index d92395fc6..ac2009b64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,11 +7,13 @@ name = "ansible-ai-connect" description = "Ansible Lightspeed with IBM watsonx Code Assistant." 
version = "0.1.0" dependencies = [ + 'aiohttp~=3.10.11', 'ansible-core~=2.15.9', 'ansible-anonymizer~=1.5.0', 'ansible-risk-insight~=0.2.7', 'ansible-lint~=24.2.2', 'boto3~=1.26.84', + 'daphne~=4.1.2', 'Django~=4.2.18', 'django-deprecate-fields~=0.1.1', 'django-extensions~=3.2.1', diff --git a/requirements-aarch64.txt b/requirements-aarch64.txt index c3460f09e..89149c82c 100644 --- a/requirements-aarch64.txt +++ b/requirements-aarch64.txt @@ -7,7 +7,9 @@ aiohappyeyeballs==2.3.5 # via aiohttp aiohttp==3.10.11 - # via langchain + # via + # -r requirements.in + # langchain aiosignal==1.3.1 # via aiohttp annotated-types==0.6.0 @@ -30,6 +32,7 @@ argparse==1.4.0 # via uwsgi-readiness-check asgiref==3.8.1 # via + # daphne # django # django-ansible-base asttokens==2.4.1 @@ -39,6 +42,12 @@ attrs==23.2.0 # aiohttp # jsonschema # referencing + # service-identity + # twisted +autobahn==24.4.2 + # via daphne +automat==24.8.1 + # via twisted backcall==0.2.0 # via ipython backoff==2.2.1 @@ -70,14 +79,20 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via black +constantly==23.10.4 + # via twisted cryptography==43.0.1 # via # -r requirements.in # ansible-core + # autobahn # django-ansible-base # jwcrypto # pyopenssl + # service-identity # social-auth-core +daphne==4.1.2 + # via -r requirements.in decorator==5.1.1 # via ipython defusedxml==0.8.0rc2 @@ -179,13 +194,21 @@ httpx==0.27.2 # via # langsmith # ollama +hyperlink==21.0.0 + # via + # autobahn + # twisted idna==3.7 # via # -r requirements.in # anyio # httpx + # hyperlink # requests + # twisted # yarl +incremental==24.7.2 + # via twisted inflection==0.5.1 # via # django-ansible-base @@ -321,10 +344,12 @@ pyasn1==0.6.0 # oauth2client # pyasn1-modules # rsa + # service-identity pyasn1-modules==0.4.0 # via # google-auth # oauth2client + # service-identity pycparser==2.21 # via cffi pydantic==2.9.2 @@ -351,6 +376,7 @@ pyopenssl==24.2.1 # via # -r requirements.in # pydrive2 + # twisted pyparsing==3.1.2 # via httplib2 pyrfc3339==1.1 @@ -426,6 +452,8 @@ segment-analytics-python==2.2.2 # via -r requirements.in semver==3.0.2 # via launchdarkly-server-sdk +service-identity==24.2.0 + # via twisted six==1.16.0 # via # asttokens @@ -472,6 +500,10 @@ traitlets==5.14.3 # via # ipython # matplotlib-inline +twisted[tls]==24.11.0 + # via daphne +txaio==23.1.1 + # via autobahn typing-extensions==4.11.0 # via # django-test-migrations @@ -481,6 +513,7 @@ typing-extensions==4.11.0 # pydantic # pydantic-core # sqlalchemy + # twisted uritemplate==4.1.1 # via # drf-spectacular @@ -508,3 +541,8 @@ yamllint==1.35.1 # via ansible-lint yarl==1.17.2 # via aiohttp +zope-interface==7.2 + # via twisted + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements-x86_64.txt b/requirements-x86_64.txt index d3285b583..d5f3bb9c3 100644 --- a/requirements-x86_64.txt +++ b/requirements-x86_64.txt @@ -7,7 +7,9 @@ aiohappyeyeballs==2.3.5 # via aiohttp aiohttp==3.10.11 - # via langchain + # via + # -r requirements.in + # langchain aiosignal==1.3.1 # via aiohttp annotated-types==0.6.0 @@ -30,6 +32,7 @@ argparse==1.4.0 # via uwsgi-readiness-check asgiref==3.8.1 # via + # daphne # django # django-ansible-base asttokens==2.4.1 @@ -39,6 +42,12 @@ attrs==23.2.0 # aiohttp # jsonschema # referencing + # service-identity + # twisted +autobahn==24.4.2 + # via daphne +automat==24.8.1 + # via twisted backcall==0.2.0 # via ipython backoff==2.2.1 @@ -70,14 +79,20 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via black 
+constantly==23.10.4 + # via twisted cryptography==43.0.1 # via # -r requirements.in # ansible-core + # autobahn # django-ansible-base # jwcrypto # pyopenssl + # service-identity # social-auth-core +daphne==4.1.2 + # via -r requirements.in decorator==5.1.1 # via ipython defusedxml==0.8.0rc2 @@ -179,13 +194,21 @@ httpx==0.27.2 # via # langsmith # ollama +hyperlink==21.0.0 + # via + # autobahn + # twisted idna==3.7 # via # -r requirements.in # anyio # httpx + # hyperlink # requests + # twisted # yarl +incremental==24.7.2 + # via twisted inflection==0.5.1 # via # django-ansible-base @@ -321,10 +344,12 @@ pyasn1==0.6.0 # oauth2client # pyasn1-modules # rsa + # service-identity pyasn1-modules==0.4.0 # via # google-auth # oauth2client + # service-identity pycparser==2.21 # via cffi pydantic==2.9.2 @@ -351,6 +376,7 @@ pyopenssl==24.2.1 # via # -r requirements.in # pydrive2 + # twisted pyparsing==3.1.2 # via httplib2 pyrfc3339==1.1 @@ -426,6 +452,8 @@ segment-analytics-python==2.2.2 # via -r requirements.in semver==3.0.2 # via launchdarkly-server-sdk +service-identity==24.2.0 + # via twisted six==1.16.0 # via # asttokens @@ -472,6 +500,10 @@ traitlets==5.14.3 # via # ipython # matplotlib-inline +twisted[tls]==24.11.0 + # via daphne +txaio==23.1.1 + # via autobahn typing-extensions==4.11.0 # via # django-test-migrations @@ -481,6 +513,7 @@ typing-extensions==4.11.0 # pydantic # pydantic-core # sqlalchemy + # twisted uritemplate==4.1.1 # via # drf-spectacular @@ -508,3 +541,8 @@ yamllint==1.35.1 # via ansible-lint yarl==1.17.2 # via aiohttp +zope-interface==7.2 + # via twisted + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements.in b/requirements.in index eb114d896..5b356033b 100644 --- a/requirements.in +++ b/requirements.in @@ -9,6 +9,7 @@ # - https://peps.python.org/pep-0631 # - https://peps.python.org/pep-0508 # ====================================================================== +aiohttp==3.10.11 ansible-anonymizer==1.5.0 ansible-risk-insight==0.2.7 ansible-lint==24.2.2 @@ -17,6 +18,7 @@ boto3==1.26.84 black==24.3.0 certifi@git+https://github.com/ansible/system-certifi@5aa52ab91f9d579bfe52b5acf30ca799f1a563d9 cryptography==43.0.1 +daphne==4.1.2 Django==4.2.18 django-deprecate-fields==0.1.1 django-extensions==3.2.1 diff --git a/tools/configs/nginx-wisdom.conf b/tools/configs/nginx-wisdom.conf index 3f8327679..70c2bb3dd 100644 --- a/tools/configs/nginx-wisdom.conf +++ b/tools/configs/nginx-wisdom.conf @@ -2,6 +2,11 @@ upstream uwsgi { server unix:///var/run/uwsgi/ansible_wisdom.sock; } +upstream daphne { + server unix:///var/run/daphne/ansible_wisdom.sock; +} + + server { listen 8000 default_server; server_name _; @@ -14,4 +19,12 @@ server { uwsgi_pass uwsgi; include /etc/nginx/uwsgi_params; } + + location /api/v1/ai/streaming_chat/ { + proxy_pass http://daphne; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_redirect off; + } } diff --git a/tools/configs/supervisord.conf b/tools/configs/supervisord.conf index 05a897e5d..44e1826ea 100644 --- a/tools/configs/supervisord.conf +++ b/tools/configs/supervisord.conf @@ -28,6 +28,29 @@ stdout_logfile_maxbytes = 0 stderr_logfile = /dev/stderr stderr_logfile_maxbytes = 0 +[fcgi-program:daphne] +# TCP socket used by Nginx backend upstream +socket=tcp://localhost:9000 + +# When daphne is running in multiple processes, each needs to have a different socket. 
+# In such a case, it is recommended to include process # in the name of socket, but +# then those generated socket names cannot be specified in nginx config file... +# So use this with numprocs=1 for now. See https://github.com/django/daphne/issues/287 +# for more details. +numprocs=1 +command = /var/www/venv/bin/daphne -u /var/run/daphne/ansible_wisdom.sock --fd 0 --access-log - --proxy-headers ansible_ai_connect.main.asgi:application + +autostart = true +autorestart = true +stopwaitsecs = 1 +stopsignal = KILL +stopasgroup = true +killasgroup = true +stdout_logfile = /dev/stdout +stdout_logfile_maxbytes = 0 +stderr_logfile = /dev/stderr +stderr_logfile_maxbytes = 0 + ; [program:test] ; command = sleep infinity diff --git a/tools/openapi-schema/ansible-ai-connect-service.json b/tools/openapi-schema/ansible-ai-connect-service.json index 7f583d958..084b32ed0 100644 --- a/tools/openapi-schema/ansible-ai-connect-service.json +++ b/tools/openapi-schema/ansible-ai-connect-service.json @@ -554,6 +554,77 @@ } } }, + "/api/v1/ai/streaming_chat/": { + "post": { + "operationId": "ai_streaming_chat_create", + "description": "Send a message to the backend chatbot service and get a streaming reply.", + "summary": "Streaming chat request", + "tags": [ + "ai" + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StreamingChatRequest" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/StreamingChatRequest" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/StreamingChatRequest" + } + } + }, + "required": true + }, + "security": [ + { + "oauth2": [ + "read", + "write" + ] + }, + { + "cookieAuth": [] + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatResponse" + } + } + }, + "description": "" + }, + "400": { + "description": "Bad request" + }, + "403": { + "description": "Forbidden" + }, + "413": { + "description": "Prompt too long" + }, + "422": { + "description": "Validation failed" + }, + "500": { + "description": "Internal server error" + }, + "503": { + "description": "Service unavailable" + } + } + } + }, "/api/v1/health/": { "get": { "operationId": "health_retrieve", @@ -2111,6 +2182,42 @@ "value" ] }, + "StreamingChatRequest": { + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "format": "uuid", + "description": "A UUID that identifies the particular conversation is being requested for." + }, + "query": { + "type": "string", + "title": "Query string", + "description": "A query string to be sent to LLM." + }, + "model": { + "type": "string", + "title": "Model name", + "description": "A model to be used on LLM." + }, + "provider": { + "type": "string", + "title": "Provider name", + "description": "A name that identifies a LLM provider." + }, + "system_prompt": { + "type": "string", + "description": "An optional non-default system prompt to be used on LLM (debug mode only)." + }, + "media_type": { + "type": "string", + "description": "A media type to be used in the output from LLM." 
+ } + }, + "required": [ + "query" + ] + }, "SuggestionQualityFeedback": { "type": "object", "properties": { diff --git a/tools/openapi-schema/ansible-ai-connect-service.yaml b/tools/openapi-schema/ansible-ai-connect-service.yaml index 1c4fce437..3ff800164 100644 --- a/tools/openapi-schema/ansible-ai-connect-service.yaml +++ b/tools/openapi-schema/ansible-ai-connect-service.yaml @@ -349,6 +349,50 @@ paths: description: '' '401': description: Unauthorized + /api/v1/ai/streaming_chat/: + post: + operationId: ai_streaming_chat_create + description: Send a message to the backend chatbot service and get a streaming + reply. + summary: Streaming chat request + tags: + - ai + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/StreamingChatRequest' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/StreamingChatRequest' + multipart/form-data: + schema: + $ref: '#/components/schemas/StreamingChatRequest' + required: true + security: + - oauth2: + - read + - write + - cookieAuth: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/ChatResponse' + description: '' + '400': + description: Bad request + '403': + description: Forbidden + '413': + description: Prompt too long + '422': + description: Validation failed + '500': + description: Internal server error + '503': + description: Service unavailable /api/v1/health/: get: operationId: health_retrieve @@ -1434,6 +1478,35 @@ components: required: - feedback - value + StreamingChatRequest: + type: object + properties: + conversation_id: + type: string + format: uuid + description: A UUID that identifies the particular conversation is being + requested for. + query: + type: string + title: Query string + description: A query string to be sent to LLM. + model: + type: string + title: Model name + description: A model to be used on LLM. + provider: + type: string + title: Provider name + description: A name that identifies a LLM provider. + system_prompt: + type: string + description: An optional non-default system prompt to be used on LLM (debug + mode only). + media_type: + type: string + description: A media type to be used in the output from LLM. + required: + - query SuggestionQualityFeedback: type: object properties: diff --git a/wisdom-service.Containerfile b/wisdom-service.Containerfile index 9d20acbf9..88cd80c44 100644 --- a/wisdom-service.Containerfile +++ b/wisdom-service.Containerfile @@ -50,7 +50,7 @@ RUN /var/www/venv/bin/python3.11 -m pip --no-cache-dir install --no-binary=all c RUN /var/www/venv/bin/python3.11 -m pip --no-cache-dir install -r/var/www/ansible-ai-connect-service/requirements.txt RUN /var/www/venv/bin/python3.11 -m pip --no-cache-dir install -e/var/www/ansible-ai-connect-service/ -RUN mkdir /var/run/uwsgi +RUN mkdir /var/run/uwsgi /var/run/daphne RUN echo -e "\ {\n\ @@ -99,6 +99,7 @@ RUN for dir in \ /var/log/supervisor \ /var/run/supervisor \ /var/run/uwsgi \ + /var/run/daphne \ /var/www/wisdom \ /var/log/nginx \ /etc/ari \