Skip to content

Commit 68df903

Browse files
author
Valentina Bojan
committed
fix tests
1 parent 8df2e88 commit 68df903

2 files changed

Lines changed: 24 additions & 18 deletions

File tree

src/uipath_langchain/agent/guardrails/actions/escalate_action.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,13 @@ def _process_llm_escalation_response(
331331
return {}
332332

333333
reviewed_tool_calls_obj = json.loads(reviewed_outputs_json)
334-
reviewed_tool_calls_list = reviewed_tool_calls_obj.get("tool_calls")
334+
reviewed_tool_calls_list = (
335+
reviewed_tool_calls_obj.get("tool_calls")
336+
if "tool_calls" in reviewed_tool_calls_obj
337+
else None
338+
)
335339

336-
if not reviewed_tool_calls_list:
340+
if not reviewed_tool_calls_obj:
337341
return {}
338342

339343
# Track if tool calls were successfully processed

tests/agent/guardrails/actions/test_escalate_action.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ async def test_node_interrupts_with_correct_message_data(
159159
assert call_args.data["GuardrailResult"] == "Validation failed"
160160

161161
if stage == ExecutionStage.PRE_EXECUTION:
162-
assert call_args.data["Inputs"] == "Test message"
162+
assert call_args.data["Inputs"] == "\"Test message\""
163163
assert "Outputs" not in call_args.data
164164
else:
165-
assert call_args.data["Inputs"] == "Test message"
166-
assert call_args.data["Outputs"] == "Output message"
165+
assert call_args.data["Inputs"] == "\"Test message\""
166+
assert call_args.data["Outputs"] == "\"Output message\""
167167

168168
@pytest.mark.asyncio
169169
@patch("uipath_langchain.agent.guardrails.actions.escalate_action.interrupt")
@@ -222,7 +222,7 @@ async def test_node_post_agent_interrupts_with_correct_agent_result_data(
222222
assert call_args.data["ExecutionStage"] == "PostExecution"
223223
assert call_args.data["GuardrailResult"] == "Validation failed"
224224

225-
assert call_args.data["Inputs"] == "User prompt message"
225+
assert call_args.data["Inputs"] == "\"User prompt message\""
226226
assert call_args.data["Outputs"] == '{"ok": true}'
227227

228228
@pytest.mark.asyncio
@@ -489,12 +489,13 @@ async def test_post_execution_ai_message_with_tool_calls_extraction(
489489

490490
# Verify interrupt was called with tool calls (name and args) in Outputs and Inputs
491491
call_args = mock_interrupt.call_args[0][0]
492-
assert call_args.data["Inputs"] == "Input message"
492+
assert call_args.data["Inputs"] == "\"Input message\""
493493
tool_outputs = call_args.data["Outputs"]
494-
parsed = json.loads(tool_outputs)
495-
assert len(parsed) == 1 # Tool call data with name and args
496-
assert parsed[0]["name"] == "test_tool"
497-
assert parsed[0]["args"] == {"content": {"input": "test"}}
494+
parsed_obj = json.loads(tool_outputs)
495+
parsed_list = parsed_obj["tool_calls"]
496+
assert len(parsed_list) == 1 # Tool call data with name and args
497+
assert parsed_list[0]["name"] == "test_tool"
498+
assert parsed_list[0]["args"] == {"content": {"input": "test"}}
498499

499500
@pytest.mark.asyncio
500501
@pytest.mark.parametrize(
@@ -614,7 +615,7 @@ async def test_post_execution_ai_message_with_reviewed_outputs_and_tool_calls(
614615
guardrail.description = "Test description"
615616

616617
reviewed_tool_args = {"updated": "tool_content"}
617-
reviewed_outputs = [{"name": "test_tool", "args": reviewed_tool_args}]
618+
reviewed_outputs = {"tool_calls": [{"name": "test_tool", "args": reviewed_tool_args}]}
618619
mock_escalation_result = MagicMock()
619620
mock_escalation_result.action = "Approve"
620621
mock_escalation_result.data = {"ReviewedOutputs": json.dumps(reviewed_outputs)}
@@ -822,7 +823,7 @@ async def test_node_interrupts_with_correct_data_pre_tool(self, mock_interrupt):
822823
call_args = mock_interrupt.call_args[0][0]
823824

824825
assert call_args.data["GuardrailName"] == "Test Guardrail"
825-
assert call_args.data["Component"] == "tool"
826+
assert call_args.data["Component"] == "test_tool"
826827
assert call_args.data["ExecutionStage"] == "PreExecution"
827828
assert call_args.data["Inputs"] == '{"input": "test"}'
828829

@@ -1422,7 +1423,7 @@ async def test_extract_llm_content_pre_execution_empty_content(self):
14221423
ai_message, ExecutionStage.PRE_EXECUTION
14231424
)
14241425

1425-
assert result == ""
1426+
assert result == "\"\""
14261427

14271428
@pytest.mark.asyncio
14281429
async def test_extract_llm_content_post_execution_tool_calls_no_content_field(self):
@@ -1447,11 +1448,12 @@ async def test_extract_llm_content_post_execution_tool_calls_no_content_field(se
14471448
)
14481449

14491450
assert isinstance(result, str)
1450-
parsed = json.loads(result)
1451+
parsed_obj = json.loads(result)
1452+
parsed_list = parsed_obj["tool_calls"]
14511453
# Should extract tool call data with name and args
1452-
assert len(parsed) == 1
1453-
assert parsed[0]["name"] == "tool_without_content"
1454-
assert parsed[0]["args"] == {"param": "value"}
1454+
assert len(parsed_list) == 1
1455+
assert parsed_list[0]["name"] == "tool_without_content"
1456+
assert parsed_list[0]["args"] == {"param": "value"}
14551457

14561458
@pytest.mark.asyncio
14571459
async def test_validate_message_count_empty_messages_raises_exception(self):

0 commit comments

Comments
 (0)