Skip to content

Commit 9824198

Browse files
committed
feat: add reasoning decorator to browser toolkit
1 parent b8cc4cc commit 9824198

File tree

6 files changed

+568
-20
lines changed

6 files changed

+568
-20
lines changed

camel/toolkits/_utils.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14+
# =========
15+
import functools
16+
import inspect
17+
18+
from camel.logger import get_logger
19+
20+
logger = get_logger(__name__)
21+
22+
23+
def add_reason_field(func):
24+
"""
25+
Decorator to enable reasoning for tool functions.
26+
1.It modifies the function's signature to add a 'reason' parameter.
27+
2.The 'reason' argument is a string describing why the tool is being called
28+
and it is added to the function docstring.
29+
3.It wraps the original function to ensure
30+
its return value includes a 'reason' key.
31+
32+
Note: This decorator can only be applied to
33+
functions with a return type of Dict.
34+
35+
"""
36+
sig = inspect.signature(func)
37+
# Check return annotation
38+
if sig.return_annotation is inspect.Signature.empty or (
39+
getattr(sig.return_annotation, '__origin__', None) is not dict
40+
and sig.return_annotation is not dict
41+
):
42+
logger.info(
43+
f"add_reason_field: Function '{func.__name__}' "
44+
"does not have return type Dict. "
45+
"Reasoning will not be applied."
46+
)
47+
return func
48+
49+
# Patch signature
50+
params = list(sig.parameters.values())
51+
# Only add reason if not already present
52+
if "reason" not in sig.parameters:
53+
params.append(
54+
inspect.Parameter(
55+
"reason",
56+
inspect.Parameter.KEYWORD_ONLY,
57+
default="",
58+
annotation=str,
59+
)
60+
)
61+
new_sig = sig.replace(parameters=params)
62+
63+
# Patch docstring
64+
doc = func.__doc__ or ""
65+
lines = doc.splitlines()
66+
67+
# Add reason to Args section
68+
if "Args:" in doc:
69+
for i, line in enumerate(lines):
70+
if "Args:" in line:
71+
indent = line[: line.index("Args:")]
72+
reason_doc = (
73+
f"{indent} reason (str): The reason why this "
74+
+ "tool is called."
75+
)
76+
lines.insert(i + 1, reason_doc)
77+
break
78+
else:
79+
lines.extend(
80+
[
81+
"",
82+
"Args:",
83+
" reason (str): The reason why this tool is called.",
84+
]
85+
)
86+
87+
# Add reason to Returns section
88+
returns_idx = None
89+
for i, line in enumerate(lines):
90+
if "Returns:" in line:
91+
returns_idx = i
92+
indent = line[: line.index("Returns:")]
93+
break
94+
95+
if returns_idx is not None:
96+
# Find where Returns section ends by checking indentation
97+
end_idx = len(lines)
98+
for j in range(returns_idx + 1, len(lines)):
99+
line = lines[j]
100+
if line.strip() == "":
101+
continue
102+
if not line.startswith(indent + " "):
103+
end_idx = j
104+
break
105+
106+
# Remove existing reason lines
107+
filtered = []
108+
for i, line in enumerate(lines):
109+
if '"reason"' not in line and "'reason'" not in line:
110+
filtered.append(line)
111+
elif i < returns_idx or i >= end_idx:
112+
filtered.append(line)
113+
114+
lines = filtered
115+
116+
# Recalculate end_idx after filtering
117+
end_idx = len(lines)
118+
for j in range(returns_idx + 1, len(lines)):
119+
line = lines[j]
120+
if line.strip() == "":
121+
continue
122+
if not line.startswith(indent + " "):
123+
end_idx = j
124+
break
125+
126+
while end_idx > returns_idx + 1 and lines[end_idx - 1].strip() == "":
127+
end_idx -= 1
128+
129+
# Append reason line at end of Returns section
130+
reason_line = f'{indent} - "reason" (str): tool call reason.'
131+
lines.insert(end_idx, reason_line)
132+
elif "Returns:" not in doc:
133+
lines.extend(
134+
[
135+
"",
136+
"Returns:",
137+
" dict: The result dictionary.",
138+
' - "reason" (str): tool call reason.',
139+
]
140+
)
141+
142+
doc = "\n".join(lines)
143+
144+
@functools.wraps(func)
145+
async def wrapper(*args, reason: str = "", **kwargs):
146+
result = await func(*args, **kwargs)
147+
# No need to check type here, as enforced by decorator
148+
result["reason"] = reason
149+
return result
150+
151+
wrapper.__signature__ = new_sig
152+
wrapper.__doc__ = doc
153+
return wrapper

camel/toolkits/hybrid_browser_toolkit/config_loader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ToolkitConfig:
5858
log_dir: Optional[str] = None
5959
session_id: Optional[str] = None
6060
enabled_tools: Optional[list] = None
61+
enable_reasoning: bool = False
6162

6263

6364
class ConfigLoader:
@@ -123,6 +124,8 @@ def from_kwargs(cls, **kwargs) -> 'ConfigLoader':
123124
toolkit_kwargs["session_id"] = value
124125
elif key == "enabledTools":
125126
toolkit_kwargs["enabled_tools"] = value
127+
elif key == "enableReasoning":
128+
toolkit_kwargs["enable_reasoning"] = value
126129
elif key == "fullVisualMode":
127130
browser_kwargs["full_visual_mode"] = value
128131

camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class HybridBrowserToolkit(BaseToolkit):
3939
cache_dir (str): Directory for caching. Defaults to "tmp/".
4040
enabled_tools (Optional[List[str]]): List of enabled tools.
4141
Defaults to None.
42+
enable_reasoning (bool): Whether to enable reasoning when agent
43+
is using the browser toolkit. Defaults to False. agent will
44+
provide explanations for its actions when this is enabled.
4245
browser_log_to_file (bool): Whether to log browser actions to
4346
file. Defaults to False.
4447
log_dir (Optional[str]): Custom directory path for log files.
@@ -93,6 +96,7 @@ def __new__(
9396
stealth: bool = False,
9497
cache_dir: Optional[str] = None,
9598
enabled_tools: Optional[List[str]] = None,
99+
enable_reasoning: bool = False,
96100
browser_log_to_file: bool = False,
97101
log_dir: Optional[str] = None,
98102
session_id: Optional[str] = None,
@@ -127,6 +131,9 @@ def __new__(
127131
cache_dir (str): Directory for caching. Defaults to "tmp/".
128132
enabled_tools (Optional[List[str]]): List of enabled tools.
129133
Defaults to None.
134+
enable_reasoning (bool): Whether to enable reasoning when agent
135+
is using the browser toolkit. Defaults to False. agent will
136+
provide explanations for its actions when this is enabled.
130137
browser_log_to_file (bool): Whether to log browser actions to
131138
file. Defaults to False.
132139
log_dir (Optional[str]): Custom directory path for log files.
@@ -182,6 +189,7 @@ def __new__(
182189
stealth=stealth,
183190
cache_dir=cache_dir,
184191
enabled_tools=enabled_tools,
192+
enable_reasoning=enable_reasoning,
185193
browser_log_to_file=browser_log_to_file,
186194
log_dir=log_dir,
187195
session_id=session_id,

camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit_ts.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from camel.logger import get_logger
3030
from camel.messages import BaseMessage
31+
from camel.toolkits._utils import add_reason_field
3132
from camel.toolkits.base import BaseToolkit, RegisteredAgentToolkit
3233
from camel.toolkits.function_tool import FunctionTool
3334
from camel.utils.commons import dependencies_required
@@ -99,6 +100,7 @@ def __init__(
99100
stealth: bool = False,
100101
cache_dir: Optional[str] = None,
101102
enabled_tools: Optional[List[str]] = None,
103+
enable_reasoning: bool = False,
102104
browser_log_to_file: bool = False,
103105
log_dir: Optional[str] = None,
104106
session_id: Optional[str] = None,
@@ -128,6 +130,9 @@ def __init__(
128130
cache_dir (str): Directory for caching. Defaults to "tmp/".
129131
enabled_tools (Optional[List[str]]): List of enabled tools.
130132
Defaults to None.
133+
enable_reasoning (bool): Whether to enable reasoning when agent
134+
is using the browser toolkit. Defaults to False. agent will
135+
provide explanations for its actions when this is enabled.
131136
browser_log_to_file (bool): Whether to log browser actions to
132137
file. Defaults to False.
133138
log_dir (Optional[str]): Custom directory path for log files.
@@ -187,6 +192,7 @@ def __init__(
187192
log_dir=log_dir,
188193
session_id=session_id,
189194
enabled_tools=enabled_tools,
195+
enable_reasoning=enable_reasoning,
190196
connect_over_cdp=connect_over_cdp,
191197
cdp_url=cdp_url,
192198
cdp_keep_current_page=cdp_keep_current_page,
@@ -206,17 +212,22 @@ def __init__(
206212
"is True, the browser will keep the current page and not "
207213
"navigate to any URL."
208214
)
215+
# toolkit settings
216+
self._cache_dir = toolkit_config.cache_dir
217+
self._browser_log_to_file = toolkit_config.browser_log_to_file
218+
self._enabled_tools = toolkit_config.enabled_tools
219+
self._enable_reasoning = toolkit_config.enable_reasoning
209220

221+
# browser settings
210222
self._headless = browser_config.headless
211223
self._user_data_dir = browser_config.user_data_dir
212224
self._stealth = browser_config.stealth
213-
self._cache_dir = toolkit_config.cache_dir
214-
self._browser_log_to_file = toolkit_config.browser_log_to_file
215225
self._default_start_url = browser_config.default_start_url
216226
self._session_id = toolkit_config.session_id or "default"
217227
self._viewport_limit = browser_config.viewport_limit
218228
self._full_visual_mode = browser_config.full_visual_mode
219229

230+
# timeout settings
220231
self._default_timeout = browser_config.default_timeout
221232
self._short_timeout = browser_config.short_timeout
222233
self._navigation_timeout = browser_config.navigation_timeout
@@ -227,24 +238,32 @@ def __init__(
227238
browser_config.dom_content_loaded_timeout
228239
)
229240

230-
if enabled_tools is None:
231-
self.enabled_tools = self.DEFAULT_TOOLS.copy()
241+
if self._enabled_tools is None:
242+
self._enabled_tools = self.DEFAULT_TOOLS.copy()
232243
else:
233244
invalid_tools = [
234-
tool for tool in enabled_tools if tool not in self.ALL_TOOLS
245+
tool
246+
for tool in self._enabled_tools
247+
if tool not in self.ALL_TOOLS
235248
]
236249
if invalid_tools:
237250
raise ValueError(
238251
f"Invalid tools specified: {invalid_tools}. "
239252
f"Available tools: {self.ALL_TOOLS}"
240253
)
241-
self.enabled_tools = enabled_tools.copy()
242254

243-
logger.info(f"Enabled tools: {self.enabled_tools}")
255+
logger.info(f"Enabled tools: {self._enabled_tools}")
244256

245257
self._ws_wrapper: Optional[WebSocketBrowserWrapper] = None
246258
self._ws_config = self.config_loader.to_ws_config()
247259

260+
# Dynamically wrap tool methods if reasoning is enabled
261+
if self._enable_reasoning:
262+
for tool_name in self._enabled_tools:
263+
method = getattr(self, tool_name, None)
264+
if method and callable(method):
265+
setattr(self, tool_name, add_reason_field(method))
266+
248267
async def _ensure_ws_wrapper(self):
249268
"""Ensure WebSocket wrapper is initialized."""
250269
if self._ws_wrapper is None:
@@ -1913,7 +1932,9 @@ def clone_for_new_session(
19131932
stealth=self._stealth,
19141933
cache_dir=f"{self._cache_dir.rstrip('/')}_clone_"
19151934
f"{new_session_id}/",
1916-
enabled_tools=self.enabled_tools.copy(),
1935+
enabled_tools=(self._enabled_tools.copy()
1936+
if self._enabled_tools
1937+
else None),
19171938
browser_log_to_file=self._browser_log_to_file,
19181939
session_id=new_session_id,
19191940
default_start_url=self._default_start_url,
@@ -1958,16 +1979,16 @@ def get_tools(self) -> List[FunctionTool]:
19581979
"browser_sheet_read": self.browser_sheet_read,
19591980
}
19601981

1961-
enabled_tools = []
1962-
1963-
for tool_name in self.enabled_tools:
1964-
if tool_name in tool_map:
1965-
tool = FunctionTool(
1966-
cast(Callable[..., Any], tool_map[tool_name])
1967-
)
1968-
enabled_tools.append(tool)
1969-
else:
1970-
logger.warning(f"Unknown tool name: {tool_name}")
1982+
tools = []
1983+
if self._enabled_tools is not None:
1984+
for tool_name in self._enabled_tools:
1985+
if tool_name in tool_map:
1986+
tool = FunctionTool(
1987+
cast(Callable[..., Any], tool_map[tool_name])
1988+
)
1989+
tools.append(tool)
1990+
else:
1991+
logger.warning(f"Unknown tool name: {tool_name}")
19711992

1972-
logger.info(f"Returning {len(enabled_tools)} enabled tools")
1973-
return enabled_tools
1993+
logger.info(f"Returning {len(tools)} enabled tools")
1994+
return tools

0 commit comments

Comments
 (0)