Skip to content

Commit e299a5a

Browse files
Merge pull request MervinPraison#495 from MervinPraison/claude/issue-410-20250524_055025
Fix agent tool calls expecting strings for integer parameters
2 parents 03e2bb9 + 54cb014 commit e299a5a

4 files changed

Lines changed: 222 additions & 3 deletions

File tree

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,44 @@ def generate_task(self) -> 'Task':
526526
tools=self.tools
527527
)
528528

529+
def _cast_arguments(self, func, arguments):
530+
"""Cast arguments to their expected types based on function signature."""
531+
if not callable(func) or not arguments:
532+
return arguments
533+
534+
try:
535+
sig = inspect.signature(func)
536+
casted_args = {}
537+
538+
for param_name, arg_value in arguments.items():
539+
if param_name in sig.parameters:
540+
param = sig.parameters[param_name]
541+
if param.annotation != inspect.Parameter.empty:
542+
# Handle common type conversions
543+
if param.annotation == int and isinstance(arg_value, (str, float)):
544+
try:
545+
casted_args[param_name] = int(float(arg_value))
546+
except (ValueError, TypeError):
547+
casted_args[param_name] = arg_value
548+
elif param.annotation == float and isinstance(arg_value, (str, int)):
549+
try:
550+
casted_args[param_name] = float(arg_value)
551+
except (ValueError, TypeError):
552+
casted_args[param_name] = arg_value
553+
elif param.annotation == bool and isinstance(arg_value, str):
554+
casted_args[param_name] = arg_value.lower() in ('true', '1', 'yes', 'on')
555+
else:
556+
casted_args[param_name] = arg_value
557+
else:
558+
casted_args[param_name] = arg_value
559+
else:
560+
casted_args[param_name] = arg_value
561+
562+
return casted_args
563+
except Exception as e:
564+
logging.debug(f"Type casting failed for {getattr(func, '__name__', 'unknown function')}: {e}")
565+
return arguments
566+
529567
def execute_tool(self, function_name, arguments):
530568
"""
531569
Execute a tool dynamically based on the function name and arguments.
@@ -576,19 +614,22 @@ def execute_tool(self, function_name, arguments):
576614
run_params = {k: v for k, v in arguments.items()
577615
if k in inspect.signature(instance.run).parameters
578616
and k != 'self'}
579-
return instance.run(**run_params)
617+
casted_params = self._cast_arguments(instance.run, run_params)
618+
return instance.run(**casted_params)
580619

581620
# CrewAI: If it's a class with an _run method, instantiate and call _run
582621
elif inspect.isclass(func) and hasattr(func, '_run'):
583622
instance = func()
584623
run_params = {k: v for k, v in arguments.items()
585624
if k in inspect.signature(instance._run).parameters
586625
and k != 'self'}
587-
return instance._run(**run_params)
626+
casted_params = self._cast_arguments(instance._run, run_params)
627+
return instance._run(**casted_params)
588628

589629
# Otherwise treat as regular function
590630
elif callable(func):
591-
return func(**arguments)
631+
casted_arguments = self._cast_arguments(func, arguments)
632+
return func(**casted_arguments)
592633
except Exception as e:
593634
error_msg = str(e)
594635
logging.error(f"Error executing tool {function_name}: {error_msg}")

tests/unit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Unit tests for PraisonAI Agents

tests/unit/agent/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Agent unit tests
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""
2+
Unit tests for Agent type casting functionality.
3+
4+
Tests the _cast_arguments() method that converts string arguments
5+
to their expected types based on function signatures.
6+
7+
Issue: #410 - Agent calling tool calls always expects strings even if its integer
8+
"""
9+
10+
import unittest
11+
from unittest.mock import Mock, patch
12+
import sys
13+
import os
14+
15+
# Add the source directory to the path for imports
16+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../src/praisonai-agents'))
17+
18+
from praisonaiagents.agent.agent import Agent
19+
20+
21+
class TestAgentTypeCasting(unittest.TestCase):
22+
"""Test cases for Agent type casting functionality."""
23+
24+
def setUp(self):
25+
"""Set up test fixtures."""
26+
self.agent = Agent(
27+
name="TestAgent",
28+
role="Type Casting Tester",
29+
goal="Test type casting functionality"
30+
)
31+
32+
def test_cast_arguments_integer_conversion(self):
33+
"""Test casting string arguments to integers."""
34+
# Define a test function with integer parameter
35+
def test_function(count: int) -> str:
36+
return f"Count: {count}"
37+
38+
# Mock arguments as they would come from JSON (all strings)
39+
arguments = {"count": "42"}
40+
41+
# Test the casting (when _cast_arguments method exists)
42+
if hasattr(self.agent, '_cast_arguments'):
43+
casted_args = self.agent._cast_arguments(test_function, arguments)
44+
self.assertEqual(casted_args["count"], 42)
45+
self.assertIsInstance(casted_args["count"], int)
46+
else:
47+
self.skipTest("_cast_arguments method not implemented yet")
48+
49+
def test_cast_arguments_float_conversion(self):
50+
"""Test casting string arguments to floats."""
51+
def test_function(price: float) -> str:
52+
return f"Price: ${price}"
53+
54+
arguments = {"price": "3.14"}
55+
56+
if hasattr(self.agent, '_cast_arguments'):
57+
casted_args = self.agent._cast_arguments(test_function, arguments)
58+
self.assertEqual(casted_args["price"], 3.14)
59+
self.assertIsInstance(casted_args["price"], float)
60+
else:
61+
self.skipTest("_cast_arguments method not implemented yet")
62+
63+
def test_cast_arguments_boolean_conversion(self):
64+
"""Test casting string arguments to booleans."""
65+
def test_function(enabled: bool) -> str:
66+
return f"Enabled: {enabled}"
67+
68+
# Test various boolean representations
69+
test_cases = [
70+
({"enabled": "true"}, True),
71+
({"enabled": "True"}, True),
72+
({"enabled": "false"}, False),
73+
({"enabled": "False"}, False),
74+
]
75+
76+
if hasattr(self.agent, '_cast_arguments'):
77+
for arguments, expected in test_cases:
78+
with self.subTest(arguments=arguments):
79+
casted_args = self.agent._cast_arguments(test_function, arguments)
80+
self.assertEqual(casted_args["enabled"], expected)
81+
self.assertIsInstance(casted_args["enabled"], bool)
82+
else:
83+
self.skipTest("_cast_arguments method not implemented yet")
84+
85+
def test_cast_arguments_mixed_types(self):
86+
"""Test casting mixed argument types."""
87+
def test_function(count: int, price: float, enabled: bool, name: str) -> str:
88+
return f"Mixed: {count}, {price}, {enabled}, {name}"
89+
90+
arguments = {
91+
"count": "10",
92+
"price": "99.99",
93+
"enabled": "true",
94+
"name": "test_item"
95+
}
96+
97+
if hasattr(self.agent, '_cast_arguments'):
98+
casted_args = self.agent._cast_arguments(test_function, arguments)
99+
100+
self.assertEqual(casted_args["count"], 10)
101+
self.assertIsInstance(casted_args["count"], int)
102+
103+
self.assertEqual(casted_args["price"], 99.99)
104+
self.assertIsInstance(casted_args["price"], float)
105+
106+
self.assertEqual(casted_args["enabled"], True)
107+
self.assertIsInstance(casted_args["enabled"], bool)
108+
109+
# String should remain unchanged
110+
self.assertEqual(casted_args["name"], "test_item")
111+
self.assertIsInstance(casted_args["name"], str)
112+
else:
113+
self.skipTest("_cast_arguments method not implemented yet")
114+
115+
def test_cast_arguments_no_annotations(self):
116+
"""Test that functions without type annotations remain unchanged."""
117+
def test_function(value):
118+
return f"Value: {value}"
119+
120+
arguments = {"value": "42"}
121+
122+
if hasattr(self.agent, '_cast_arguments'):
123+
casted_args = self.agent._cast_arguments(test_function, arguments)
124+
# Without annotations, should remain as string
125+
self.assertEqual(casted_args["value"], "42")
126+
self.assertIsInstance(casted_args["value"], str)
127+
else:
128+
self.skipTest("_cast_arguments method not implemented yet")
129+
130+
def test_cast_arguments_conversion_failure_graceful(self):
131+
"""Test graceful fallback when type conversion fails."""
132+
def test_function(count: int) -> str:
133+
return f"Count: {count}"
134+
135+
# Invalid integer string should fallback gracefully
136+
arguments = {"count": "not_a_number"}
137+
138+
if hasattr(self.agent, '_cast_arguments'):
139+
casted_args = self.agent._cast_arguments(test_function, arguments)
140+
# Should fallback to original string value
141+
self.assertEqual(casted_args["count"], "not_a_number")
142+
self.assertIsInstance(casted_args["count"], str)
143+
else:
144+
self.skipTest("_cast_arguments method not implemented yet")
145+
146+
def test_cast_arguments_already_correct_type(self):
147+
"""Test that arguments already of correct type are not modified."""
148+
def test_function(count: int) -> str:
149+
return f"Count: {count}"
150+
151+
# Already an integer
152+
arguments = {"count": 42}
153+
154+
if hasattr(self.agent, '_cast_arguments'):
155+
casted_args = self.agent._cast_arguments(test_function, arguments)
156+
self.assertEqual(casted_args["count"], 42)
157+
self.assertIsInstance(casted_args["count"], int)
158+
else:
159+
self.skipTest("_cast_arguments method not implemented yet")
160+
161+
def test_cast_arguments_with_none_values(self):
162+
"""Test handling of None values."""
163+
def test_function(optional_count: int = None) -> str:
164+
return f"Count: {optional_count}"
165+
166+
arguments = {"optional_count": None}
167+
168+
if hasattr(self.agent, '_cast_arguments'):
169+
casted_args = self.agent._cast_arguments(test_function, arguments)
170+
self.assertIsNone(casted_args["optional_count"])
171+
else:
172+
self.skipTest("_cast_arguments method not implemented yet")
173+
174+
175+
if __name__ == '__main__':
176+
unittest.main()

0 commit comments

Comments
 (0)