Streaming chat endpoint #1527

Merged
merged 13 commits on Mar 3, 2025
164 changes: 137 additions & 27 deletions ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py
@@ -14,9 +14,12 @@

import json
import logging
from typing import Optional
from json import JSONDecodeError
from typing import AsyncGenerator, Optional

import aiohttp
import requests
from django.http import StreamingHttpResponse
from health_check.exceptions import ServiceUnavailable

from ansible_ai_connect.ai.api.exceptions import (
@@ -39,6 +42,9 @@
MetaData,
ModelPipelineChatBot,
ModelPipelineCompletions,
ModelPipelineStreamingChatBot,
StreamingChatBotParameters,
StreamingChatBotResponse,
)
from ansible_ai_connect.ai.api.model_pipelines.registry import Register
from ansible_ai_connect.healthcheck.backends import (
@@ -120,8 +126,43 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
raise NotImplementedError


class HttpChatBotMetaData(HttpMetaData):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

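    # Shared health check for HTTP chatbot pipelines: probe the service's
    # /readiness endpoint and report ServiceUnavailable when it is unreachable
    # or replies with {"ready": false}.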
def self_test(self) -> Optional[HealthCheckSummary]:
summary: HealthCheckSummary = HealthCheckSummary(
{
MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
}
)
try:
headers = {"Content-Type": "application/json"}
r = requests.get(self.config.inference_url + "/readiness", headers=headers)
r.raise_for_status()

data = r.json()
ready = data.get("ready")
if not ready:
reason = data.get("reason")
summary.add_exception(
MODEL_MESH_HEALTH_CHECK_MODELS,
HealthCheckSummaryException(ServiceUnavailable(reason)),
)

except Exception as e:
logger.exception(str(e))
summary.add_exception(
MODEL_MESH_HEALTH_CHECK_MODELS,
HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
)
return summary


@Register(api_type="http")
class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]):
class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]):

def __init__(self, config: HttpConfiguration):
super().__init__(config=config)
Expand Down Expand Up @@ -171,31 +212,100 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
detail = json.loads(response.text).get("detail", "")
raise ChatbotInternalServerException(detail=detail)

    def self_test(self) -> Optional[HealthCheckSummary]:
        summary: HealthCheckSummary = HealthCheckSummary(
            {
                MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
                MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
            }
        )
        try:
            headers = {"Content-Type": "application/json"}
            r = requests.get(self.config.inference_url + "/readiness", headers=headers)
            r.raise_for_status()

            data = r.json()
            ready = data.get("ready")
            if not ready:
                reason = data.get("reason")
                summary.add_exception(
                    MODEL_MESH_HEALTH_CHECK_MODELS,
                    HealthCheckSummaryException(ServiceUnavailable(reason)),
                )

        except Exception as e:
            logger.exception(str(e))
            summary.add_exception(
                MODEL_MESH_HEALTH_CHECK_MODELS,
                HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
            )
        return summary


@Register(api_type="http")
class HttpStreamingChatBotPipeline(
    HttpChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
):

    def __init__(self, config: HttpConfiguration):
        super().__init__(config=config)

    def invoke(self, params: StreamingChatBotParameters) -> StreamingChatBotResponse:
        response = self.get_streaming_http_response(params)

        if response.status_code == 200:
            return response
        else:
            raise ChatbotInternalServerException(detail="Internal server error")

    def get_streaming_http_response(
        self, params: StreamingChatBotParameters
    ) -> StreamingHttpResponse:
        return StreamingHttpResponse(
            self.async_invoke(params),
            content_type="text/event-stream",
        )

    async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerator:
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json,text/event-stream",
            }

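            # Build the JSON payload for /v1/streaming_query; optional fields are
            # included only when the caller provided them.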
query = params.query
conversation_id = params.conversation_id
provider = params.provider
model_id = params.model_id
system_prompt = params.system_prompt
media_type = params.media_type

data = {
"query": query,
"model": model_id,
"provider": provider,
}
if conversation_id:
data["conversation_id"] = str(conversation_id)
if system_prompt:
data["system_prompt"] = str(system_prompt)
if media_type:
data["media_type"] = str(media_type)

async with session.post(
self.config.inference_url + "/v1/streaming_query",
json=data,
headers=headers,
raise_for_status=False,
) as response:
if response.status == 200:
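                    # The service replies with Server-Sent Events, one "data: {...}"
                    # frame per line. Error events are logged here, but every raw
                    # chunk is still forwarded to the client unchanged.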
async for chunk in response.content:
try:
if chunk:
s = chunk.decode("utf-8").strip()
if s and s.startswith("data: "):
o = json.loads(s[len("data: ") :])
if o["event"] == "error":
default_data = {
"response": "(not provided)",
"cause": "(not provided)",
}
data = o.get("data", default_data)
logger.error(
"An error received in chat streaming content:"
+ " response="
+ data.get("response")
+ ", cause="
+ data.get("cause")
)
except JSONDecodeError:
pass
logger.debug(chunk)
yield chunk
else:
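                    # Non-200 response: log the status and emit a single synthetic
                    # "error" event so SSE consumers receive a structured failure.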
logging.error(
"Streaming query API returned status code="
+ str(response.status)
+ ", reason="
+ str(response.reason)
)
error = {
"event": "error",
"data": {
"response": f"Non-200 status code ({response.status}) was received.",
"cause": response.reason,
},
}
yield json.dumps(error).encode("utf-8")
return
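For reviewers who want to exercise the new endpoint by hand, a minimal client sketch follows. It is not part of this change; the payload fields and the /v1/streaming_query suffix mirror the pipeline code above, while the base URL and query string are placeholders.

import asyncio
import json

import aiohttp


async def stream_query(base_url: str, query: str) -> None:
    # POST the same payload shape async_invoke builds, then print each SSE event.
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json,text/event-stream",
    }
    data = {"query": query, "model": "", "provider": ""}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            base_url + "/v1/streaming_query", json=data, headers=headers
        ) as response:
            async for chunk in response.content:
                line = chunk.decode("utf-8").strip()
                if line.startswith("data: "):
                    print(json.loads(line[len("data: ") :]))


if __name__ == "__main__":
    # Placeholder base URL; point this at a running chatbot service.
    asyncio.run(stream_query("http://localhost:8080", "Hello"))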
173 changes: 173 additions & 0 deletions ansible_ai_connect/ai/api/model_pipelines/http/tests/test_pipelines.py
@@ -0,0 +1,173 @@
#
# Copyright Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from unittest import IsolatedAsyncioTestCase
from unittest.mock import patch

from ansible_ai_connect.test_utils import WisdomLogAwareMixin

from ...pipelines import StreamingChatBotParameters
from ...tests import mock_pipeline_config
from ..pipelines import HttpStreamingChatBotPipeline

logger = logging.getLogger(__name__)


class TestHttpStreamingChatBotPipeline(IsolatedAsyncioTestCase, WisdomLogAwareMixin):
pipeline: HttpStreamingChatBotPipeline

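    # Canned event streams mirroring what the service emits: a "start" event,
    # zero or more "token" events, then a terminal "end" or "error" event.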
STREAM_DATA = [
{"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3f"}},
{"event": "token", "data": {"id": 0, "token": ""}},
{"event": "token", "data": {"id": 1, "token": "Hello"}},
{"event": "token", "data": {"id": 2, "token": "!"}},
{"event": "token", "data": {"id": 3, "token": " I"}},
{"event": "token", "data": {"id": 4, "token": "'m"}},
{"event": "token", "data": {"id": 5, "token": " Ansible"}},
{"event": "token", "data": {"id": 6, "token": " L"}},
{"event": "token", "data": {"id": 7, "token": "ights"}},
{"event": "token", "data": {"id": 8, "token": "peed"}},
{"event": "token", "data": {"id": 9, "token": ","}},
{"event": "token", "data": {"id": 10, "token": " your"}},
{"event": "token", "data": {"id": 11, "token": " virtual"}},
{"event": "token", "data": {"id": 12, "token": " assistant"}},
{"event": "token", "data": {"id": 13, "token": " for"}},
{"event": "token", "data": {"id": 14, "token": " all"}},
{"event": "token", "data": {"id": 15, "token": " things"}},
{"event": "token", "data": {"id": 16, "token": " Ansible"}},
{"event": "token", "data": {"id": 17, "token": "."}},
{"event": "token", "data": {"id": 18, "token": " How"}},
{"event": "token", "data": {"id": 19, "token": " can"}},
{"event": "token", "data": {"id": 20, "token": " I"}},
{"event": "token", "data": {"id": 21, "token": " assist"}},
{"event": "token", "data": {"id": 22, "token": " you"}},
{"event": "token", "data": {"id": 23, "token": " today"}},
{"event": "token", "data": {"id": 24, "token": "?"}},
{"event": "token", "data": {"id": 25, "token": ""}},
{
"event": "end",
"data": {
"referenced_documents": [],
"truncated": False,
"input_tokens": 241,
"output_tokens": 25,
},
},
]

STREAM_DATA_PROMPT_TOO_LONG = [
{"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3e"}},
{
"event": "error",
"data": {"response": "Prompt is too long", "cause": "Prompt length 10000 exceeds LLM"},
},
]

STREAM_DATA_PROMPT_GENERIC_LLM_ERROR = [
{"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3d"}},
{
"event": "error",
"data": {
"response": "Oops, something went wrong during LLM invocation",
"cause": "A generic LLM error",
},
},
]

STREAM_DATA_PROMPT_ERROR_WITH_NO_DATA = [
{"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3c"}},
{"event": "error"},
]

def setUp(self):
self.pipeline = HttpStreamingChatBotPipeline(mock_pipeline_config("http"))

def assertInLog(self, s, logs, number_of_matches_expected=None):
self.assertTrue(self.searchInLogOutput(s, logs, number_of_matches_expected), logs)

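    # Stand-in for the async context manager returned by aiohttp's
    # ClientSession.post(), replaying the canned events as SSE-framed bytes.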
def get_return_value(self, stream_data, status=200):
class MyAsyncContextManager:
def __init__(self, stream_data, status=200):
self.stream_data = stream_data
self.status = status
self.reason = ""

async def my_async_generator(self):
for data in self.stream_data:
s = json.dumps(data)
yield (f"data: {s}\n\n".encode())

async def __aenter__(self):
self.content = self.my_async_generator()
self.status = self.status
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
pass

return MyAsyncContextManager(stream_data, status)

def get_params(self) -> StreamingChatBotParameters:
return StreamingChatBotParameters(
query="Hello",
provider="",
model_id="",
conversation_id=None,
system_prompt=None,
media_type="application/json",
)

@patch("aiohttp.ClientSession.post")
async def test_async_invoke_with_no_error(self, mock_post):
mock_post.return_value = self.get_return_value(self.STREAM_DATA)
async for _ in self.pipeline.async_invoke(self.get_params()):
pass

@patch("aiohttp.ClientSession.post")
async def test_async_invoke_prompt_too_long(self, mock_post):
mock_post.return_value = self.get_return_value(self.STREAM_DATA_PROMPT_TOO_LONG)
with self.assertLogs(logger="root", level="ERROR") as log:
async for _ in self.pipeline.async_invoke(self.get_params()):
pass
self.assertInLog("Prompt is too long", log)

@patch("aiohttp.ClientSession.post")
async def test_async_invoke_prompt_generic_llm_error(self, mock_post):
mock_post.return_value = self.get_return_value(self.STREAM_DATA_PROMPT_GENERIC_LLM_ERROR)
with self.assertLogs(logger="root", level="ERROR") as log:
async for _ in self.pipeline.async_invoke(self.get_params()):
pass
self.assertInLog("Oops, something went wrong during LLM invocation", log)

@patch("aiohttp.ClientSession.post")
async def test_async_invoke_internal_server_error(self, mock_post):
mock_post.return_value = self.get_return_value(
self.STREAM_DATA_PROMPT_GENERIC_LLM_ERROR, 500
)
with self.assertLogs(logger="root", level="ERROR") as log:
async for _ in self.pipeline.async_invoke(self.get_params()):
pass
self.assertInLog("Streaming query API returned status code=500", log)

@patch("aiohttp.ClientSession.post")
async def test_async_invoke_error_with_no_data(self, mock_post):
mock_post.return_value = self.get_return_value(
self.STREAM_DATA_PROMPT_ERROR_WITH_NO_DATA,
)
with self.assertLogs(logger="root", level="ERROR") as log:
async for _ in self.pipeline.async_invoke(self.get_params()):
pass
self.assertInLog("(not provided)", log)