@@ -16,7 +16,9 @@
 import logging
 from typing import Optional
 
+import aiohttp
 import requests
+from django.http import StreamingHttpResponse
 from health_check.exceptions import ServiceUnavailable
 
 from ansible_ai_connect.ai.api.exceptions import (
@@ -39,6 +41,8 @@
     MetaData,
     ModelPipelineChatBot,
     ModelPipelineCompletions,
+    ModelPipelineStreamingChatBot,
+    StreamingChatBotParameters,
 )
 from ansible_ai_connect.ai.api.model_pipelines.registry import Register
 from ansible_ai_connect.healthcheck.backends import (
@@ -120,13 +124,12 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
         raise NotImplementedError
 
 
-@Register(api_type="http")
-class HttpChatBotPipeline(HttpMetaData, ModelPipelineChatBot[HttpConfiguration]):
+class HttpChatBotMetaData(HttpMetaData):
 
     def __init__(self, config: HttpConfiguration):
         super().__init__(config=config)
 
-    def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
+    def prepare_data(self, params: ChatBotParameters):
         query = params.query
         conversation_id = params.conversation_id
         provider = params.provider
@@ -142,11 +145,49 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
             data["conversation_id"] = str(conversation_id)
         if system_prompt:
             data["system_prompt"] = str(system_prompt)
+        return data
+
+    def self_test(self) -> Optional[HealthCheckSummary]:
+        summary: HealthCheckSummary = HealthCheckSummary(
+            {
+                MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
+                MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
+            }
+        )
+        try:
+            headers = {"Content-Type": "application/json"}
+            r = requests.get(self.config.inference_url + "/readiness", headers=headers)
+            r.raise_for_status()
+
+            data = r.json()
+            ready = data.get("ready")
+            if not ready:
+                reason = data.get("reason")
+                summary.add_exception(
+                    MODEL_MESH_HEALTH_CHECK_MODELS,
+                    HealthCheckSummaryException(ServiceUnavailable(reason)),
+                )
+
+        except Exception as e:
+            logger.exception(str(e))
+            summary.add_exception(
+                MODEL_MESH_HEALTH_CHECK_MODELS,
+                HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
+            )
+        return summary
+
+
+@Register(api_type="http")
+class HttpChatBotPipeline(HttpChatBotMetaData, ModelPipelineChatBot[HttpConfiguration]):
+
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
 
+    def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
         response = requests.post(
             self.config.inference_url + "/v1/query",
             headers=self.headers,
-            json=data,
+            json=self.prepare_data(params),
             timeout=self.timeout(1),
             verify=self.config.verify_ssl,
         )
@@ -171,31 +212,44 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
             detail = json.loads(response.text).get("detail", "")
             raise ChatbotInternalServerException(detail=detail)
 
-    def self_test(self) -> Optional[HealthCheckSummary]:
-        summary: HealthCheckSummary = HealthCheckSummary(
-            {
-                MODEL_MESH_HEALTH_CHECK_PROVIDER: "http",
-                MODEL_MESH_HEALTH_CHECK_MODELS: "ok",
-            }
-        )
-        try:
-            headers = {"Content-Type": "application/json"}
-            r = requests.get(self.config.inference_url + "/readiness", headers=headers)
-            r.raise_for_status()
 
-            data = r.json()
-            ready = data.get("ready")
-            if not ready:
-                reason = data.get("reason")
-                summary.add_exception(
-                    MODEL_MESH_HEALTH_CHECK_MODELS,
-                    HealthCheckSummaryException(ServiceUnavailable(reason)),
-                )
+class HttpStreamingChatBotMetaData(HttpChatBotMetaData):
 
-        except Exception as e:
-            logger.exception(str(e))
-            summary.add_exception(
-                MODEL_MESH_HEALTH_CHECK_MODELS,
-                HealthCheckSummaryException(ServiceUnavailable(ERROR_MESSAGE), e),
-            )
-        return summary
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
+
+    def prepare_data(self, params: StreamingChatBotParameters):
+        data = super().prepare_data(params)
+
+        media_type = params.media_type
+        if media_type:
+            data["media_type"] = str(media_type)
+
+        return data
+
+
+@Register(api_type="http")
+class HttpStreamingChatBotPipeline(
+    HttpStreamingChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
+):
+
+    def __init__(self, config: HttpConfiguration):
+        super().__init__(config=config)
+
+    def invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
+        raise NotImplementedError
+
+    async def async_invoke(self, params: StreamingChatBotParameters) -> StreamingHttpResponse:
+        async with aiohttp.ClientSession(raise_for_status=True) as session:
+            headers = {
+                "Content-Type": "application/json",
+                "Accept": "application/json,text/event-stream",
+            }
+            async with session.post(
+                self.config.inference_url + "/v1/streaming_query",
+                json=self.prepare_data(params),
+                headers=headers,
+            ) as r:
+                async for chunk in r.content:
+                    logger.debug(chunk)
+                    yield chunk
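
For context (not part of the diff): `async_invoke` is an async generator that relays the `/v1/streaming_query` response chunk by chunk, so a caller can hand it straight to Django's `StreamingHttpResponse`. The sketch below is a minimal illustration of that wiring, assuming an ASGI deployment on Django 4.2+ (which accepts async iterators in `StreamingHttpResponse`); the view name and the `get_streaming_pipeline` / `build_streaming_params` helpers are hypothetical and do not appear in this change.

```python
# Illustrative only: an async Django view consuming HttpStreamingChatBotPipeline.
from django.http import StreamingHttpResponse


async def streaming_chat_view(request):
    pipeline = get_streaming_pipeline()       # hypothetical factory, not in the diff
    params = build_streaming_params(request)  # hypothetical helper, not in the diff

    # async_invoke() yields the text/event-stream body proxied from
    # /v1/streaming_query; StreamingHttpResponse streams it to the client.
    return StreamingHttpResponse(
        pipeline.async_invoke(params),
        content_type="text/event-stream",
    )
```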