Skip to content

Commit f10ebc5

Browse files
committed
Follow Model Pipeline pattern
1 parent cc39ec0 commit f10ebc5

File tree

9 files changed

+236
-74
lines changed

9 files changed

+236
-74
lines changed

ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import logging
1717
from typing import Optional
1818

19+
import aiohttp
1920
import requests
21+
from django.http import StreamingHttpResponse
2022
from health_check.exceptions import ServiceUnavailable
2123

2224
from ansible_ai_connect.ai.api.exceptions import (
@@ -39,6 +41,8 @@
3941
MetaData,
4042
ModelPipelineChatBot,
4143
ModelPipelineCompletions,
44+
ModelPipelineStreamingChatBot,
45+
StreamingChatBotParameters,
4246
)
4347
from ansible_ai_connect.ai.api.model_pipelines.registry import Register
4448
from ansible_ai_connect.healthcheck.backends import (
@@ -120,13 +124,12 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
120124
raise NotImplementedError
121125

122126

123-
@Register(api_type="http")
124-
class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]):
127+
class HttpChatBotMetaData(HttpMetaData):
125128

126129
def __init__(self, config: HttpConfiguration):
127130
super().__init__(config=config)
128131

129-
def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
132+
def prepare_data(self, params: ChatBotParameters):
130133
query = params.query
131134
conversation_id = params.conversation_id
132135
provider = params.provider
@@ -142,11 +145,49 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
142145
data["conversation_id"] = str(conversation_id)
143146
if system_prompt:
144147
data["system_prompt"] = str(system_prompt)
148+
return data
149+
150+
def self_test(self) -> Optional[HealthCheckSummary]:
    """Probe the chatbot service's /readiness endpoint and report health.

    Returns:
        A HealthCheckSummary that is "ok" when the service reports ready,
        and carries a ServiceUnavailable entry otherwise (including when
        the HTTP request itself fails).
    """
    summary: HealthCheckSummary = HealthCheckSummary(
        {
            MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
            MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
        }
    )
    try:
        headers = {"Content-Type": "application/json"}
        r = requests.get(
            self.config.inference_url + "/readiness",
            headers=headers,
            # Consistent with invoke(): bound the request and honour the
            # configured SSL-verification setting, so a health probe can
            # neither hang indefinitely nor silently skip TLS checks.
            timeout=self.timeout(1),
            verify=self.config.verify_ssl,
        )
        r.raise_for_status()

        data = r.json()
        ready = data.get("ready")
        if not ready:
            reason = data.get("reason")
            summary.add_exception(
                MODEL_MESH_HEALTH_CHECK_MODELS,
                HealthCheckSummaryException(ServiceUnavailable(reason)),
            )

    except Exception as e:
        # Health checks are a top-level boundary: log and fold any failure
        # into the summary instead of propagating it.
        logger.exception(str(e))
        summary.add_exception(
            MODEL_MESH_HEALTH_CHECK_MODELS,
            HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
        )
    return summary
178+
179+
180+
@Register(api_type="http")
181+
class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]):
182+
183+
def __init__(self, config: HttpConfiguration):
184+
super().__init__(config=config)
145185

186+
def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
146187
response = requests.post(
147188
self.config.inference_url + "/v1/query",
148189
headers=self.headers,
149-
json=data,
190+
json=self.prepare_data(params),
150191
timeout=self.timeout(1),
151192
verify=self.config.verify_ssl,
152193
)
@@ -171,31 +212,44 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
171212
detail = json.loads(response.text).get("detail", "")
172213
raise ChatbotInternalServerException(detail=detail)
173214

174-
def self_test(self) -> Optional[HealthCheckSummary]:
175-
summary: HealthCheckSummary = HealthCheckSummary(
176-
{
177-
MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
178-
MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
179-
}
180-
)
181-
try:
182-
headers = {"Content-Type": "application/json"}
183-
r = requests.get(self.config.inference_url + "/readiness", headers=headers)
184-
r.raise_for_status()
185215

186-
data = r.json()
187-
ready = data.get("ready")
188-
if not ready:
189-
reason = data.get("reason")
190-
summary.add_exception(
191-
MODEL_MESH_HEALTH_CHECK_MODELS,
192-
HealthCheckSummaryException(ServiceUnavailable(reason)),
193-
)
216+
class HttpStreamingChatBotMetaData(HttpChatBotMetaData):
    """Chatbot metadata that augments the request payload for streaming."""

    def __init__(self, config: HttpConfiguration):
        super().__init__(config=config)

    def prepare_data(self, params: StreamingChatBotParameters):
        """Build the JSON payload, adding ``media_type`` when one is supplied."""
        payload = super().prepare_data(params)
        if params.media_type:
            payload["media_type"] = str(params.media_type)
        return payload
229+
230+
231+
@Register(api_type="http")
class HttpStreamingChatBotPipeline(
    HttpStreamingChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
):
    """Streams chatbot responses from the HTTP service's streaming endpoint."""

    def __init__(self, config: HttpConfiguration):
        super().__init__(config=config)

    def invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
        # Streaming is only available through async_invoke().
        raise NotImplementedError

    async def async_invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
        """Yield raw response chunks from /v1/streaming_query.

        NOTE(review): despite the annotation, this is an async generator of
        byte chunks; wrapping into a StreamingHttpResponse is presumably done
        by the caller — confirm at the call site.
        """
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json,text/event-stream",
            }
            async with session.post(
                self.config.inference_url + "/v1/streaming_query",
                json=self.prepare_data(params),
                headers=headers,
                # Honour the configured SSL-verification setting, consistent
                # with the synchronous pipeline's verify=self.config.verify_ssl.
                # aiohttp: ssl=None -> default verification, ssl=False -> off.
                ssl=None if self.config.verify_ssl else False,
            ) as r:
                async for chunk in r.content:
                    logger.debug(chunk)
                    yield chunk

ansible_ai_connect/ai/api/model_pipelines/pipelines.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,33 @@ def init(
242242
ChatBotResponse = Any
243243

244244

245+
@define
class StreamingChatBotParameters(ChatBotParameters):
    """ChatBotParameters extended with a client-requested response media type."""

    # Optional: init() defaults it to None, so the annotation must allow None.
    media_type: Optional[str]

    @classmethod
    def init(
        cls,
        query: str,
        provider: Optional[str] = None,
        model_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        system_prompt: Optional[str] = None,
        media_type: Optional[str] = None,
    ):
        """Alternate constructor mirroring ChatBotParameters.init plus media_type."""
        return cls(
            query=query,
            provider=provider,
            model_id=model_id,
            conversation_id=conversation_id,
            system_prompt=system_prompt,
            media_type=media_type,
        )


StreamingChatBotResponse = Any
270+
271+
245272
class MetaData(Generic[PIPELINE_CONFIGURATION], metaclass=ABCMeta):
246273

247274
def __init__(self, config: PIPELINE_CONFIGURATION):
@@ -274,6 +301,9 @@ def alias() -> str:
274301
def invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
275302
raise NotImplementedError
276303

304+
async def async_invoke(self, params: PIPELINE_PARAMETERS) -> PIPELINE_RETURN:
    # Async/streaming counterpart of invoke(); unsupported by default —
    # streaming-capable pipelines override this.
    raise NotImplementedError
306+
277307
@abstractmethod
278308
def self_test(self) -> Optional[HealthCheckSummary]:
279309
raise NotImplementedError
@@ -381,3 +411,17 @@ def __init__(self, config: PIPELINE_CONFIGURATION):
381411
@staticmethod
382412
def alias():
383413
return "chatbot-service"
414+
415+
416+
class ModelPipelineStreamingChatBot(
    ModelPipeline[PIPELINE_CONFIGURATION, StreamingChatBotParameters, StreamingChatBotResponse],
    Generic[PIPELINE_CONFIGURATION],
    metaclass=ABCMeta,
):
    """Abstract pipeline for streaming chatbot implementations.

    Bound to StreamingChatBotParameters (not plain ChatBotParameters): the
    streaming request carries the extra ``media_type`` field that concrete
    implementations' async_invoke() expects.
    """

    def __init__(self, config: PIPELINE_CONFIGURATION):
        super().__init__(config=config)

    @staticmethod
    def alias():
        # Key under which this pipeline type is looked up in configuration.
        return "streaming-chatbot-service"

ansible_ai_connect/ai/api/model_pipelines/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ModelPipelinePlaybookGeneration,
3131
ModelPipelineRoleExplanation,
3232
ModelPipelineRoleGeneration,
33+
ModelPipelineStreamingChatBot,
3334
)
3435
from ansible_ai_connect.main.settings.types import t_model_mesh_api_type
3536

@@ -45,6 +46,7 @@
4546
ModelPipelinePlaybookExplanation,
4647
ModelPipelineRoleExplanation,
4748
ModelPipelineChatBot,
49+
ModelPipelineStreamingChatBot,
4850
PipelineConfiguration,
4951
Serializer,
5052
]

ansible_ai_connect/ai/api/serializers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,14 @@ class ChatRequestSerializer(serializers.Serializer):
352352
)
353353

354354

355+
class StreamingChatRequestSerializer(ChatRequestSerializer):
    """ChatRequestSerializer plus an optional media type for streamed output."""

    media_type = serializers.CharField(
        required=False,
        label="Media type",
        # User-facing API doc string; grammar fixed ("from the LLM").
        help_text="A media type to be used in the output from the LLM.",
    )
361+
362+
355363
class ReferencedDocumentsSerializer(serializers.Serializer):
356364
docs_url = serializers.CharField()
357365
title = serializers.CharField()

ansible_ai_connect/ai/api/streaming_chat.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

ansible_ai_connect/ai/api/versions/v1/ai/views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from ansible_ai_connect.ai.api.streaming_chat import StreamingChat
1615
from ansible_ai_connect.ai.api.views import (
1716
Chat,
1817
Completions,
@@ -22,6 +21,7 @@
2221
Feedback,
2322
GenerationPlaybook,
2423
GenerationRole,
24+
StreamingChat,
2525
)
2626

2727
__all__ = [

0 commit comments

Comments
 (0)