Skip to content

Commit 21b3ff4

Browse files
committed
Hide the chatbot tool_call/ tool_response JSON data
Hide the chatbot tool_call/ tool_response JSON data issue: https://issues.redhat.com/browse/AAP-57513 Signed-off-by: Djebran Lezzoum <ldjebran@gmail.com>
1 parent 2f5f706 commit 21b3ff4

File tree

4 files changed

+197
-3
lines changed

4 files changed

+197
-3
lines changed

ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,9 @@ async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerat
380380
async for chunk in response.content:
381381
try:
382382
if chunk:
383-
s = chunk.decode("utf-8").strip()
384-
if s and s.startswith("data: "):
385-
o = json.loads(s[len("data: ") :])
383+
chunk_string = chunk.decode("utf-8").strip()
384+
if chunk_string and chunk_string.startswith("data: "):
385+
o = json.loads(chunk_string[len("data: ") :])
386386
event = o.get("event")
387387
if event == "error":
388388
default_data = {
@@ -406,6 +406,29 @@ async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerat
406406
conversation_id = data.get("conversation_id")
407407
ev.conversation_id = conversation_id
408408
self.send_schema1_event(ev)
409+
elif event in ("tool_call", "tool_result"):
410+
if not settings.CHATBOT_RETURN_TOOL_CALL:
411+
# do not return tool_call event to final user response
412+
# and send an empty token data instead
413+
# include also the original tool_call/tool_response
414+
data = o.get("data", {"id": 0})
415+
chunk_id = data.get("id")
416+
logger.debug(
417+
"hide tool_call/tool_result from final result, "
418+
"original chunk: %s",
419+
chunk_string,
420+
)
421+
new_chunk_data = {
422+
"event": "token",
423+
"data": {"id": chunk_id, "token": ""},
424+
"original": o,
425+
}
426+
new_chunk_data_json = json.dumps(new_chunk_data)
427+
chunk = (
428+
b"data: "
429+
+ new_chunk_data_json.encode("utf-8")
430+
+ b"\n"
431+
)
409432
elif event == "end":
410433
ev.phase = event
411434
default_data = {

ansible_ai_connect/ai/api/model_pipelines/http/tests/test_pipelines.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from unittest import IsolatedAsyncioTestCase
2020
from unittest.mock import MagicMock, patch
2121

22+
from django.test import override_settings
23+
2224
from ansible_ai_connect.ai.api.model_pipelines.http.configuration import (
2325
HttpConfiguration,
2426
)
@@ -176,6 +178,172 @@ async def test_async_invoke_with_no_error(self, mock_post):
176178
pass
177179
self.assertEqual(self.call_counter, 2)
178180

181+
@patch("aiohttp.ClientSession.post")
182+
@override_settings(CHATBOT_RETURN_TOOL_CALL=False)
183+
async def test_async_invoke_tool_call_hidden_with_no_error(self, mock_post):
184+
tool_call_event_id = 5
185+
tool_result_event_id = 6
186+
tool_call_event = {
187+
"event": "tool_call",
188+
"data": {
189+
"id": tool_call_event_id,
190+
"token": {
191+
"tool_name": "knowledge_search",
192+
"arguments": {"query": "Exploratory Data Analysis"},
193+
},
194+
},
195+
}
196+
tool_result_event = {
197+
"event": "tool_result",
198+
"data": {
199+
"id": tool_result_event_id,
200+
"token": {
201+
"tool_name": "knowledge_search",
202+
"summary": "knowledge_search tool found 5 chunks:",
203+
},
204+
},
205+
}
206+
stream_data = [
207+
{"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3f"}},
208+
{"event": "token", "data": {"id": 0, "token": ""}},
209+
tool_call_event,
210+
tool_result_event,
211+
{"event": "token", "data": {"id": 24, "token": "some data"}},
212+
{"event": "token", "data": {"id": 25, "token": ""}},
213+
{
214+
"event": "end",
215+
"data": {
216+
"referenced_documents": [
217+
{
218+
"doc_title": "Document 1",
219+
"doc_url": "https://example.com/document1",
220+
},
221+
{
222+
"title": "Document 2",
223+
"docs_url": "https://example.com/document2",
224+
},
225+
],
226+
"truncated": False,
227+
"input_tokens": 241,
228+
"output_tokens": 25,
229+
},
230+
},
231+
]
232+
233+
mock_post.return_value = self.get_return_value(stream_data)
234+
with patch(
235+
"ansible_ai_connect.ai.api.model_pipelines.http.pipelines"
236+
".HttpStreamingChatBotPipeline.send_schema1_event",
237+
wraps=self.send_event,
238+
):
239+
tool_calls_data_counter = 0
240+
events_counter = 0
241+
async for chunk in self.pipeline.async_invoke(self.get_params()):
242+
chunk_string = chunk.decode("utf-8")
243+
if chunk_string.startswith("data: "):
244+
chuck_data = json.loads(chunk_string.lstrip("data: "))
245+
if events_counter == 2:
246+
# ensure the event type has been changed to simple token
247+
self.assertEqual(chuck_data["event"], "token")
248+
# ensure the data token is empty
249+
self.assertEqual(chuck_data["data"]["token"], "")
250+
# ensure the event id is preserved
251+
self.assertEqual(chuck_data["data"]["id"], tool_call_event_id)
252+
# ensure the original event is in the chunk data
253+
self.assertEqual(chuck_data["original"], tool_call_event)
254+
tool_calls_data_counter += 1
255+
if events_counter == 3:
256+
# ensure the event type has been changed to simple token
257+
self.assertEqual(chuck_data["event"], "token")
258+
# ensure the data token is empty
259+
self.assertEqual(chuck_data["data"]["token"], "")
260+
# ensure the event id is preserved
261+
self.assertEqual(chuck_data["data"]["id"], tool_result_event_id)
262+
# ensure the original event is in the chunk data
263+
self.assertEqual(chuck_data["original"], tool_result_event)
264+
tool_calls_data_counter += 1
265+
events_counter += 1
266+
self.assertEqual(tool_calls_data_counter, 2)
267+
self.assertEqual(events_counter, len(stream_data))
268+
self.assertEqual(self.call_counter, 2)
269+
270+
@patch("aiohttp.ClientSession.post")
271+
@override_settings(CHATBOT_RETURN_TOOL_CALL=True)
272+
async def test_async_invoke_tool_call_preserved_with_no_error(self, mock_post):
273+
tool_call_event_id = 5
274+
tool_result_event_id = 6
275+
tool_call_event = {
276+
"event": "tool_call",
277+
"data": {
278+
"id": tool_call_event_id,
279+
"token": {
280+
"tool_name": "knowledge_search",
281+
"arguments": {"query": "Exploratory Data Analysis"},
282+
},
283+
},
284+
}
285+
tool_result_event = {
286+
"event": "tool_result",
287+
"data": {
288+
"id": tool_result_event_id,
289+
"token": {
290+
"tool_name": "knowledge_search",
291+
"summary": "knowledge_search tool found 5 chunks:",
292+
},
293+
},
294+
}
295+
stream_data = [
296+
{"event": "start", "data": {"conversation_id": "92766ddd-dfc8-4830-b269-7a4b3dbc7c3f"}},
297+
{"event": "token", "data": {"id": 0, "token": ""}},
298+
tool_call_event,
299+
tool_result_event,
300+
{"event": "token", "data": {"id": 24, "token": "some data"}},
301+
{"event": "token", "data": {"id": 25, "token": ""}},
302+
{
303+
"event": "end",
304+
"data": {
305+
"referenced_documents": [
306+
{
307+
"doc_title": "Document 1",
308+
"doc_url": "https://example.com/document1",
309+
},
310+
{
311+
"title": "Document 2",
312+
"docs_url": "https://example.com/document2",
313+
},
314+
],
315+
"truncated": False,
316+
"input_tokens": 241,
317+
"output_tokens": 25,
318+
},
319+
},
320+
]
321+
322+
mock_post.return_value = self.get_return_value(stream_data)
323+
with patch(
324+
"ansible_ai_connect.ai.api.model_pipelines.http.pipelines"
325+
".HttpStreamingChatBotPipeline.send_schema1_event",
326+
wraps=self.send_event,
327+
):
328+
tool_calls_data_counter = 0
329+
events_counter = 0
330+
async for chunk in self.pipeline.async_invoke(self.get_params()):
331+
chunk_string = chunk.decode("utf-8")
332+
if chunk_string.startswith("data: "):
333+
chuck_data = json.loads(chunk_string.lstrip("data: "))
334+
if events_counter == 2:
335+
# ensure the tool_call has not changed
336+
self.assertEqual(chuck_data, tool_call_event)
337+
tool_calls_data_counter += 1
338+
if events_counter == 3:
339+
# ensure the tool_result has not changed
340+
self.assertEqual(chuck_data, tool_result_event)
341+
tool_calls_data_counter += 1
342+
events_counter += 1
343+
self.assertEqual(tool_calls_data_counter, 2)
344+
self.assertEqual(events_counter, len(stream_data))
345+
self.assertEqual(self.call_counter, 2)
346+
179347
@patch("aiohttp.ClientSession.post")
180348
async def test_async_invoke_prompt_too_long(self, mock_post):
181349
mock_post.return_value = self.get_return_value(self.STREAM_DATA_PROMPT_TOO_LONG)

ansible_ai_connect/main/settings/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,8 @@ def is_ssl_enabled(value: str) -> bool:
596596
CHATBOT_DEBUG_UI = os.getenv("CHATBOT_DEBUG_UI", "False").lower() == "true"
597597
CHATBOT_DEFAULT_SYSTEM_PROMPT = os.getenv("CHATBOT_DEFAULT_SYSTEM_PROMPT")
598598
CHATBOT_API_KEY = os.getenv("CHATBOT_API_KEY")
599+
# by default do not return chatbot tool_call event
600+
CHATBOT_RETURN_TOOL_CALL = os.environ.get("CHATBOT_RETURN_TOOL_CALL", "False").lower() == "true"
599601
# ==========================================
600602

601603
# ==========================================

tools/docker-compose/compose.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ services:
7777
- CHATBOT_DEFAULT_PROVIDER=${CHATBOT_DEFAULT_PROVIDER}
7878
- CHATBOT_DEFAULT_MODEL=${CHATBOT_DEFAULT_MODEL}
7979
- CHATBOT_DEFAULT_SYSTEM_PROMPT=${CHATBOT_DEFAULT_SYSTEM_PROMPT}
80+
- CHATBOT_RETURN_TOOL_CALL=${CHATBOT_RETURN_TOOL_CALL}
8081
- ANSIBLE_AI_MODEL_MESH_CONFIG=${ANSIBLE_AI_MODEL_MESH_CONFIG}
8182
- ANSIBLE_AI_ENABLE_ROLE_GEN_ENDPOINT=${ANSIBLE_AI_ENABLE_ROLE_GEN_ENDPOINT}
8283
- AAP_API_URL=${AAP_API_URL}

0 commit comments

Comments
 (0)