Commit dae29f3: Chat Streaming endpoint
Parent: c1de7c2

16 files changed: +373 -39 lines changed

ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py

Lines changed: 84 additions & 30 deletions
@@ -16,7 +16,9 @@
 import logging
 from typing import Optional
 
+import aiohttp
 import requests
+from django.http import StreamingHttpResponse
 from health_check.exceptions import ServiceUnavailable
 
 from ansible_ai_connect.ai.api.exceptions import (
@@ -39,6 +41,8 @@
     MetaData,
     ModelPipelineChatBot,
     ModelPipelineCompletions,
+    ModelPipelineStreamingChatBot,
+    StreamingChatBotParameters,
 )
 from ansible_ai_connect.ai.api.model_pipelines.registry import Register
 from ansible_ai_connect.healthcheck.backends import (
@@ -120,13 +124,12 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
         raise NotImplementedError
 
 
-@Register(api_type="http")
-class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]):
+class HttpChatBotMetaData(HttpMetaData):
 
     def __init__(self, config: HttpConfiguration):
         super().__init__(config=config)
 
-    def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
+    def prepare_data(self, params: ChatBotParameters):
         query = params.query
         conversation_id = params.conversation_id
         provider = params.provider
@@ -142,11 +145,49 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
             data["conversation_id"] = str(conversation_id)
         if system_prompt:
             data["system_prompt"] = str(system_prompt)
+        return data
+
+    def self_test(self) -> Optional[HealthCheckSummary]:
+        summary: HealthCheckSummary = HealthCheckSummary(
+            {
+                MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
+                MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
+            }
+        )
+        try:
+            headers = {"Content-Type": "application/json"}
+            r = requests.get(self.config.inference_url + "/readiness", headers=headers)
+            r.raise_for_status()
+
+            data = r.json()
+            ready = data.get("ready")
+            if not ready:
+                reason = data.get("reason")
+                summary.add_exception(
+                    MODEL_MESH_HEALTH_CHECK_MODELS,
+                    HealthCheckSummaryException(ServiceUnavailable(reason)),
+                )
+
+        except Exception as e:
+            logger.exception(str(e))
+            summary.add_exception(
+                MODEL_MESH_HEALTH_CHECK_MODELS,
+                HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
+            )
+        return summary
+
+
+@Register(api_type="http")
+class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]):
+
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
 
+    def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
         response = requests.post(
             self.config.inference_url + "/v1/query",
             headers=self.headers,
-            json=data,
+            json=self.prepare_data(params),
             timeout=self.timeout(1),
             verify=self.config.verify_ssl,
         )
@@ -171,31 +212,44 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
             detail = json.loads(response.text).get("detail", "")
             raise ChatbotInternalServerException(detail=detail)
 
-    def self_test(self) -> Optional[HealthCheckSummary]:
-        summary: HealthCheckSummary = HealthCheckSummary(
-            {
-                MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
-                MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
-            }
-        )
-        try:
-            headers = {"Content-Type": "application/json"}
-            r = requests.get(self.config.inference_url + "/readiness", headers=headers)
-            r.raise_for_status()
 
-            data = r.json()
-            ready = data.get("ready")
-            if not ready:
-                reason = data.get("reason")
-                summary.add_exception(
-                    MODEL_MESH_HEALTH_CHECK_MODELS,
-                    HealthCheckSummaryException(ServiceUnavailable(reason)),
-                )
+class HttpStreamingChatBotMetaData(HttpChatBotMetaData):
 
-        except Exception as e:
-            logger.exception(str(e))
-            summary.add_exception(
-                MODEL_MESH_HEALTH_CHECK_MODELS,
-                HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
-            )
-        return summary
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
+
+    def prepare_data(self, params: StreamingChatBotParameters):
+        data = super().prepare_data(params)
+
+        media_type = params.media_type
+        if media_type:
+            data["media_type"] = str(media_type)
+
+        return data
+
+
+@Register(api_type="http")
+class HttpStreamingChatBotPipeline(
+    HttpStreamingChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
+):
+
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
+
+    def invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
+        raise NotImplementedError
+
+    async def async_invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
+        async with aiohttp.ClientSession(raise_for_status=True) as session:
+            headers = {
+                "Content-Type": "application/json",
+                "Accept": "application/json,text/event-stream",
+            }
+            async with session.post(
+                self.config.inference_url + "/v1/streaming_query",
+                json=self.prepare_data(params),
+                headers=headers,
+            ) as r:
+                async for chunk in r.content:
+                    logger.debug(chunk)
+                    yield chunk
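Note that async_invoke is an async generator: aiohttp reads the upstream response body as a stream and each chunk is yielded to the caller as it arrives. A minimal standalone sketch of the same pattern (the URL and payload below are placeholders, not values from this commit):

import asyncio

import aiohttp

INFERENCE_URL = "http://localhost:8080"  # placeholder; the real value comes from HttpConfiguration


async def stream_query(payload: dict):
    """Yield raw SSE chunks from a chatbot service, as async_invoke does."""
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json,text/event-stream",
    }
    async with aiohttp.ClientSession(raise_for_status=True) as session:
        async with session.post(
            INFERENCE_URL + "/v1/streaming_query", json=payload, headers=headers
        ) as r:
            # r.content is an aiohttp.StreamReader; iterating it yields bytes,
            # one line at a time for line-delimited SSE output.
            async for chunk in r.content:
                yield chunk


async def main():
    async for chunk in stream_query({"query": "Hello"}):
        print(chunk.decode("utf-8"), end="")


if __name__ == "__main__":
    asyncio.run(main())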

ansible_ai_connect/ai/api/model_pipelines/pipelines.py

Lines changed: 44 additions & 0 deletions
@@ -242,6 +242,33 @@ def init(
 ChatBotResponse = Any
 
 
+@define
+class StreamingChatBotParameters(ChatBotParameters):
+    media_type: str
+
+    @classmethod
+    def init(
+        cls,
+        query: str,
+        provider: Optional[str] = None,
+        model_id: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        media_type: Optional[str] = None,
+    ):
+        return cls(
+            query=query,
+            provider=provider,
+            model_id=model_id,
+            conversation_id=conversation_id,
+            system_prompt=system_prompt,
+            media_type=media_type,
+        )
+
+
+StreamingChatBotResponse = Any
+
+
 class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta):
 
     def __init__(self, config: PIPELINE_CONFIGURATION):
@@ -274,6 +301,9 @@ def alias() -> str:
     def invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
         raise NotImplementedError
 
+    async def async_invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
+        raise NotImplementedError
+
     @abstractmethod
     def self_test(self) -> Optional[HealthCheckSummary]:
         raise NotImplementedError
@@ -381,3 +411,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION):
     @staticmethod
     def alias():
         return "chatbot-service"
+
+
+class ModelPipelineStreamingChatBot(
+    ModelPipeline[PIPELINE_CONFIGURATION, ChatBotParameters, StreamingChatBotResponse],
+    Generic[PIPELINE_CONFIGURATION],
+    metaclass=ABCMeta,
+):
+
+    def __init__(self, config: PIPELINE_CONFIGURATION):
+        super().__init__(config=config)
+
+    @staticmethod
+    def alias():
+        return "streaming-chatbot-service"

ansible_ai_connect/ai/api/model_pipelines/registry.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@
     ModelPipelinePlaybookGeneration,
     ModelPipelineRoleExplanation,
     ModelPipelineRoleGeneration,
+    ModelPipelineStreamingChatBot,
 )
 from ansible_ai_connect.main.settings.types import t_model_mesh_api_type
 
@@ -45,6 +46,7 @@
     ModelPipelinePlaybookExplanation,
     ModelPipelineRoleExplanation,
     ModelPipelineChatBot,
+    ModelPipelineStreamingChatBot,
     PipelineConfiguration,
     Serializer,
 ]
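For illustration only, a hypothetical reduction of what @Register(api_type=...) plus this list of pipeline types plausibly amounts to: a mapping from backend type and abstract pipeline to a concrete class. The project's actual Register implementation lives in this module and may differ.

from collections import defaultdict
from typing import Dict

# Hypothetical registry: api_type -> {abstract pipeline -> concrete class}.
REGISTRY: Dict[str, Dict[type, type]] = defaultdict(dict)


def register(api_type: str):
    def decorator(cls: type) -> type:
        for base in cls.__bases__:  # record under each pipeline base class
            REGISTRY[api_type][base] = cls
        return cls

    return decorator


class ModelPipelineStreamingChatBot:  # stand-in for the real ABC
    pass


@register(api_type="http")
class HttpStreamingChatBotPipeline(ModelPipelineStreamingChatBot):
    pass


print(REGISTRY["http"][ModelPipelineStreamingChatBot].__name__)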

ansible_ai_connect/ai/api/serializers.py

Lines changed: 8 additions & 0 deletions
@@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer):
     )
 
 
+class StreamingChatRequestSerializer(ChatRequestSerializer):
+    media_type = serializers.CharField(
+        required=False,
+        label="Media type",
+        help_text=("A media type to be used in the output from LLM."),
+    )
+
+
 class ReferencedDocumentsSerializer(serializers.Serializer):
     docs_url = serializers.CharField()
     title = serializers.CharField()
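A short usage sketch of the new serializer, to be run inside a configured Django environment for this project; `query` is the required field inherited from ChatRequestSerializer, as the view code below implies.

from ansible_ai_connect.ai.api.serializers import StreamingChatRequestSerializer

serializer = StreamingChatRequestSerializer(
    data={
        "query": "How do I create an EC2 instance?",  # inherited, required
        "media_type": "text/event-stream",            # new, optional
    }
)
serializer.is_valid(raise_exception=True)
print(serializer.validated_data.get("media_type"))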

ansible_ai_connect/ai/api/versions/v1/ai/urls.py

Lines changed: 1 addition & 0 deletions
@@ -25,4 +25,5 @@
     path("generations/role/", views.GenerationRole.as_view(), name="generations/role"),
     path("feedback/", views.Feedback.as_view(), name="feedback"),
     path("chat/", views.Chat.as_view(), name="chat"),
+    path("streaming_chat/", views.StreamingChat.as_view(), name="streaming_chat"),
 ]

ansible_ai_connect/ai/api/versions/v1/ai/views.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
     Feedback,
     GenerationPlaybook,
     GenerationRole,
+    StreamingChat,
 )
 
 __all__ = [
@@ -32,4 +33,5 @@
     "ExplanationRole",
     "Feedback",
     "Chat",
+    "StreamingChat",
 ]

ansible_ai_connect/ai/api/views.py

Lines changed: 83 additions & 0 deletions
@@ -21,6 +21,7 @@
 from attr import asdict
 from django.apps import apps
 from django.conf import settings
+from django.http import StreamingHttpResponse
 from django_prometheus.conf import NAMESPACE
 from drf_spectacular.utils import OpenApiResponse, extend_schema
 from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope
@@ -77,10 +78,12 @@
     ModelPipelinePlaybookGeneration,
     ModelPipelineRoleExplanation,
     ModelPipelineRoleGeneration,
+    ModelPipelineStreamingChatBot,
     PlaybookExplanationParameters,
     PlaybookGenerationParameters,
     RoleExplanationParameters,
     RoleGenerationParameters,
+    StreamingChatBotParameters,
 )
 from ansible_ai_connect.ai.api.pipelines.completions import CompletionsPipeline
 from ansible_ai_connect.ai.api.telemetry import schema1
@@ -134,6 +137,7 @@
     PlaybookGenerationAction,
     RoleGenerationAction,
     SentimentFeedback,
+    StreamingChatRequestSerializer,
     SuggestionQualityFeedback,
 )
 from .telemetry.schema1 import (
@@ -1126,3 +1130,82 @@ def post(self, request) -> Response:
             status=rest_framework_status.HTTP_200_OK,
             headers=headers,
         )
+
+
+class StreamingChat(AACSAPIView):
+    """
+    Send a message to the backend chatbot service and get a streaming reply.
+    """
+
+    class StreamingChatEndpointThrottle(EndpointRateThrottle):
+        scope = "chat"
+
+    permission_classes = [
+        permissions.IsAuthenticated,
+        IsAuthenticatedOrTokenHasScope,
+        IsRHInternalUser | IsTestUser,
+    ]
+    required_scopes = ["read", "write"]
+    schema1_event = schema1.ChatBotOperationalEvent  # TODO
+    request_serializer_class = StreamingChatRequestSerializer
+    throttle_classes = [StreamingChatEndpointThrottle]
+
+    llm: ModelPipelineStreamingChatBot
+
+    def __init__(self):
+        super().__init__()
+        self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineStreamingChatBot)
+
+        self.chatbot_enabled = (
+            self.llm.config.inference_url
+            and self.llm.config.model_id
+            and settings.CHATBOT_DEFAULT_PROVIDER
+        )
+        if self.chatbot_enabled:
+            logger.debug("Chatbot is enabled.")
+        else:
+            logger.debug("Chatbot is not enabled.")
+
+    @extend_schema(
+        request=StreamingChatRequestSerializer,
+        responses={
+            200: ChatResponseSerializer,  # TODO
+            400: OpenApiResponse(description="Bad request"),
+            403: OpenApiResponse(description="Forbidden"),
+            413: OpenApiResponse(description="Prompt too long"),
+            422: OpenApiResponse(description="Validation failed"),
+            500: OpenApiResponse(description="Internal server error"),
+            503: OpenApiResponse(description="Service unavailable"),
+        },
+        summary="Streaming chat request",
+    )
+    def post(self, request) -> Response:
+        if not self.chatbot_enabled:
+            raise ChatbotNotEnabledException()
+
+        req_query = self.validated_data["query"]
+        req_system_prompt = self.validated_data.get("system_prompt")
+        req_provider = self.validated_data.get("provider", settings.CHATBOT_DEFAULT_PROVIDER)
+        conversation_id = self.validated_data.get("conversation_id")
+        media_type = self.validated_data.get("media_type")
+
+        # Initialise Segment Event early, in case of exceptions
+        self.event.chat_prompt = anonymize_struct(req_query)
+        self.event.chat_system_prompt = req_system_prompt
+        self.event.provider_id = req_provider
+        self.event.conversation_id = conversation_id
+        self.event.modelName = self.req_model_id or self.llm.config.model_id
+
+        return StreamingHttpResponse(
+            self.llm.async_invoke(
+                StreamingChatBotParameters.init(
+                    query=req_query,
+                    system_prompt=req_system_prompt,
+                    model_id=self.req_model_id or self.llm.config.model_id,
+                    provider=req_provider,
+                    conversation_id=conversation_id,
+                    media_type=media_type,
+                )
+            ),
+            content_type="text/event-stream",
+        )
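The view hands the pipeline's async generator straight to Django's StreamingHttpResponse with a text/event-stream content type, so clients can read the reply incrementally. A client-side sketch (host, port, and authentication are placeholders; the path comes from the urls.py change above):

import requests

URL = "http://localhost:8000/api/v1/ai/streaming_chat/"  # placeholder host/port
payload = {"query": "How do I install nginx?", "media_type": "text/event-stream"}

# stream=True keeps the connection open; iter_lines() yields each event
# line as soon as the server flushes it. A real deployment also needs auth.
with requests.post(URL, json=payload, stream=True, timeout=60) as response:
    response.raise_for_status()
    for line in response.iter_lines():
        if line:
            print(line.decode("utf-8"))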
