|
19 | 19 | from unittest import IsolatedAsyncioTestCase |
20 | 20 | from unittest.mock import MagicMock, patch |
21 | 21 |
|
| 22 | +from django.test import override_settings |
| 23 | + |
22 | 24 | from ansible_ai_connect.ai.api.model_pipelines.http.configuration import ( |
23 | 25 | HttpConfiguration, |
24 | 26 | ) |
@@ -176,6 +178,172 @@ async def test_async_invoke_with_no_error(self, mock_post): |
176 | 178 | pass |
177 | 179 | self.assertEqual(self.call_counter, 2) |
178 | 180 |
|
| 181 | + @patch("aiohttp.ClientSession.post") |
| 182 | + @override_settings(CHATBOT_RETURN_TOOL_CALL=False) |
| 183 | + async def test_async_invoke_tool_call_hidden_with_no_error(self, mock_post): |
| 184 | + tool_call_event_id = 5 |
| 185 | + tool_result_event_id = 6 |
| 186 | + tool_call_event = { |
| 187 | + "event": "tool_call", |
| 188 | + "data": { |
| 189 | + "id": tool_call_event_id, |
| 190 | + "token": { |
| 191 | + "tool_name": "knowledge_search", |
| 192 | + "arguments": {"query": "Exploratory Data Analysis"}, |
| 193 | + }, |
| 194 | + }, |
| 195 | + } |
| 196 | + tool_result_event = { |
| 197 | + "event": "tool_result", |
| 198 | + "data": { |
| 199 | + "id": tool_result_event_id, |
| 200 | + "token": { |
| 201 | + "tool_name": "knowledge_search", |
| 202 | + "summary": "knowledge_search tool found 5 chunks:", |
| 203 | + }, |
| 204 | + }, |
| 205 | + } |
| 206 | + stream_data = [ |
| 207 | + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3f"}}, |
| 208 | + {"event": "token", "data": {"id": 0, "token": ""}}, |
| 209 | + tool_call_event, |
| 210 | + tool_result_event, |
| 211 | + {"event": "token", "data": {"id": 24, "token": "some data"}}, |
| 212 | + {"event": "token", "data": {"id": 25, "token": ""}}, |
| 213 | + { |
| 214 | + "event": "end", |
| 215 | + "data": { |
| 216 | + "referenced_documents": [ |
| 217 | + { |
| 218 | + "doc_title": "Document 1", |
| 219 | + "doc_url": "https://example.com/document1", |
| 220 | + }, |
| 221 | + { |
| 222 | + "title": "Document 2", |
| 223 | + "docs_url": "https://example.com/document2", |
| 224 | + }, |
| 225 | + ], |
| 226 | + "truncated": False, |
| 227 | + "input_tokens": 241, |
| 228 | + "output_tokens": 25, |
| 229 | + }, |
| 230 | + }, |
| 231 | + ] |
| 232 | + |
| 233 | + mock_post.return_value = self.get_return_value(stream_data) |
| 234 | + with patch( |
| 235 | + "ansible_ai_connect.ai.api.model_pipelines.http.pipelines" |
| 236 | + ".HttpStreamingChatBotPipeline.send_schema1_event", |
| 237 | + wraps=self.send_event, |
| 238 | + ): |
| 239 | + tool_calls_data_counter = 0 |
| 240 | + events_counter = 0 |
| 241 | + async for chunk in self.pipeline.async_invoke(self.get_params()): |
| 242 | + chunk_string = chunk.decode("utf-8") |
| 243 | + if chunk_string.startswith("data: "): |
| 244 | + chuck_data = json.loads(chunk_string.lstrip("data: ")) |
| 245 | + if events_counter == 2: |
| 246 | + # ensure the event type has been changed to simple token |
| 247 | + self.assertEqual(chuck_data["event"], "token") |
| 248 | + # ensure the data token is empty |
| 249 | + self.assertEqual(chuck_data["data"]["token"], "") |
| 250 | + # ensure the event id is preserved |
| 251 | + self.assertEqual(chuck_data["data"]["id"], tool_call_event_id) |
| 252 | + # ensure the original event is in the chunk data |
| 253 | + self.assertEqual(chuck_data["original"], tool_call_event) |
| 254 | + tool_calls_data_counter += 1 |
| 255 | + if events_counter == 3: |
| 256 | + # ensure the event type has been changed to simple token |
| 257 | + self.assertEqual(chuck_data["event"], "token") |
| 258 | + # ensure the data token is empty |
| 259 | + self.assertEqual(chuck_data["data"]["token"], "") |
| 260 | + # ensure the event id is preserved |
| 261 | + self.assertEqual(chuck_data["data"]["id"], tool_result_event_id) |
| 262 | + # ensure the original event is in the chunk data |
| 263 | + self.assertEqual(chuck_data["original"], tool_result_event) |
| 264 | + tool_calls_data_counter += 1 |
| 265 | + events_counter += 1 |
| 266 | + self.assertEqual(tool_calls_data_counter, 2) |
| 267 | + self.assertEqual(events_counter, len(stream_data)) |
| 268 | + self.assertEqual(self.call_counter, 2) |
| 269 | + |
| 270 | + @patch("aiohttp.ClientSession.post") |
| 271 | + @override_settings(CHATBOT_RETURN_TOOL_CALL=True) |
| 272 | + async def test_async_invoke_tool_call_preserved_with_no_error(self, mock_post): |
| 273 | + tool_call_event_id = 5 |
| 274 | + tool_result_event_id = 6 |
| 275 | + tool_call_event = { |
| 276 | + "event": "tool_call", |
| 277 | + "data": { |
| 278 | + "id": tool_call_event_id, |
| 279 | + "token": { |
| 280 | + "tool_name": "knowledge_search", |
| 281 | + "arguments": {"query": "Exploratory Data Analysis"}, |
| 282 | + }, |
| 283 | + }, |
| 284 | + } |
| 285 | + tool_result_event = { |
| 286 | + "event": "tool_result", |
| 287 | + "data": { |
| 288 | + "id": tool_result_event_id, |
| 289 | + "token": { |
| 290 | + "tool_name": "knowledge_search", |
| 291 | + "summary": "knowledge_search tool found 5 chunks:", |
| 292 | + }, |
| 293 | + }, |
| 294 | + } |
| 295 | + stream_data = [ |
| 296 | + {"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3f"}}, |
| 297 | + {"event": "token", "data": {"id": 0, "token": ""}}, |
| 298 | + tool_call_event, |
| 299 | + tool_result_event, |
| 300 | + {"event": "token", "data": {"id": 24, "token": "some data"}}, |
| 301 | + {"event": "token", "data": {"id": 25, "token": ""}}, |
| 302 | + { |
| 303 | + "event": "end", |
| 304 | + "data": { |
| 305 | + "referenced_documents": [ |
| 306 | + { |
| 307 | + "doc_title": "Document 1", |
| 308 | + "doc_url": "https://example.com/document1", |
| 309 | + }, |
| 310 | + { |
| 311 | + "title": "Document 2", |
| 312 | + "docs_url": "https://example.com/document2", |
| 313 | + }, |
| 314 | + ], |
| 315 | + "truncated": False, |
| 316 | + "input_tokens": 241, |
| 317 | + "output_tokens": 25, |
| 318 | + }, |
| 319 | + }, |
| 320 | + ] |
| 321 | + |
| 322 | + mock_post.return_value = self.get_return_value(stream_data) |
| 323 | + with patch( |
| 324 | + "ansible_ai_connect.ai.api.model_pipelines.http.pipelines" |
| 325 | + ".HttpStreamingChatBotPipeline.send_schema1_event", |
| 326 | + wraps=self.send_event, |
| 327 | + ): |
| 328 | + tool_calls_data_counter = 0 |
| 329 | + events_counter = 0 |
| 330 | + async for chunk in self.pipeline.async_invoke(self.get_params()): |
| 331 | + chunk_string = chunk.decode("utf-8") |
| 332 | + if chunk_string.startswith("data: "): |
| 333 | + chuck_data = json.loads(chunk_string.lstrip("data: ")) |
| 334 | + if events_counter == 2: |
| 335 | + # ensure the tool_call has not changed |
| 336 | + self.assertEqual(chuck_data, tool_call_event) |
| 337 | + tool_calls_data_counter += 1 |
| 338 | + if events_counter == 3: |
| 339 | + # ensure the tool_result has not changed |
| 340 | + self.assertEqual(chuck_data, tool_result_event) |
| 341 | + tool_calls_data_counter += 1 |
| 342 | + events_counter += 1 |
| 343 | + self.assertEqual(tool_calls_data_counter, 2) |
| 344 | + self.assertEqual(events_counter, len(stream_data)) |
| 345 | + self.assertEqual(self.call_counter, 2) |
| 346 | + |
179 | 347 | @patch("aiohttp.ClientSession.post") |
180 | 348 | async def test_async_invoke_prompt_too_long(self, mock_post): |
181 | 349 | mock_post.return_value = self.get_return_value(self.STREAM_DATA_PROMPT_TOO_LONG) |
|
0 commit comments