Skip to content

Commit ad53d20

Browse files
committed
feat: add reasoning decorator to browser toolkit
1 parent b8cc4cc commit ad53d20

File tree

6 files changed

+538
-14
lines changed

6 files changed

+538
-14
lines changed

camel/toolkits/_utils.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14+
# =========
15+
import functools
16+
import inspect
17+
18+
from camel.logger import get_logger
19+
logger = get_logger(__name__)
20+
21+
22+
def add_reason_field(func):
23+
"""
24+
Decorator to enable reasoning for tool functions.
25+
1.It modifies the function's signature to add a 'reason' parameter.
26+
2.The 'reason' argument is a string describing why the tool is being called and it is
27+
added to the function docstring.
28+
3.It wraps the original function to ensure its return value includes a 'reason' key.
29+
30+
Note: This decorator can only be applied to functions with a return type of Dict.
31+
32+
"""
33+
sig = inspect.signature(func)
34+
# Check return annotation
35+
if sig.return_annotation is inspect.Signature.empty or (
36+
getattr(sig.return_annotation, '__origin__', None) is not dict and
37+
sig.return_annotation is not dict
38+
):
39+
logger.info(
40+
f"add_reason_field: Function '{func.__name__}' does not have return type Dict. "
41+
"Reasoning will not be applied."
42+
)
43+
return func
44+
45+
# Patch signature
46+
params = list(sig.parameters.values())
47+
# Only add reason if not already present
48+
if "reason" not in sig.parameters:
49+
params.append(
50+
inspect.Parameter(
51+
"reason",
52+
inspect.Parameter.KEYWORD_ONLY,
53+
default="",
54+
annotation=str,
55+
)
56+
)
57+
new_sig = sig.replace(parameters=params)
58+
59+
# Patch docstring
60+
doc = func.__doc__ or ""
61+
lines = doc.splitlines()
62+
63+
# Add reason to Args section
64+
if "Args:" in doc:
65+
for i, line in enumerate(lines):
66+
if "Args:" in line:
67+
indent = line[:line.index("Args:")]
68+
reason_doc = f"{indent} reason (str): The reason why this tool is called."
69+
lines.insert(i + 1, reason_doc)
70+
break
71+
else:
72+
lines.extend(["", "Args:", " reason (str): The reason why this tool is called."])
73+
74+
# Add reason to Returns section
75+
returns_idx = None
76+
for i, line in enumerate(lines):
77+
if "Returns:" in line:
78+
returns_idx = i
79+
indent = line[:line.index("Returns:")]
80+
break
81+
82+
if returns_idx is not None:
83+
# Find where Returns section ends by checking indentation
84+
end_idx = len(lines)
85+
for j in range(returns_idx + 1, len(lines)):
86+
line = lines[j]
87+
# Empty line or line with same/less indentation as "Returns:" = end of section
88+
if line.strip() == "":
89+
continue
90+
if not line.startswith(indent + " "):
91+
end_idx = j
92+
break
93+
94+
# Remove existing reason lines
95+
filtered = []
96+
for i, line in enumerate(lines):
97+
if '"reason"' not in line and "'reason'" not in line:
98+
filtered.append(line)
99+
elif i < returns_idx or i >= end_idx:
100+
filtered.append(line)
101+
102+
lines = filtered
103+
104+
# Recalculate end_idx after filtering
105+
end_idx = len(lines)
106+
for j in range(returns_idx + 1, len(lines)):
107+
line = lines[j]
108+
if line.strip() == "":
109+
continue
110+
if not line.startswith(indent + " "):
111+
end_idx = j
112+
break
113+
114+
# Remove trailing empty lines in Returns section before inserting reason
115+
while end_idx > returns_idx + 1 and lines[end_idx - 1].strip() == "":
116+
end_idx -= 1
117+
118+
# Append reason line at end of Returns section
119+
reason_line = f'{indent} - "reason" (str): tool call reason.'
120+
lines.insert(end_idx, reason_line)
121+
elif "Returns:" not in doc:
122+
lines.extend(["", "Returns:", " dict: The result dictionary.", ' - "reason" (str): tool call reason.'])
123+
124+
doc = "\n".join(lines)
125+
126+
@functools.wraps(func)
127+
async def wrapper(*args, reason: str = "", **kwargs):
128+
result = await func(*args, **kwargs)
129+
# No need to check type here, as enforced by decorator
130+
result["reason"] = reason
131+
return result
132+
wrapper.__signature__ = new_sig
133+
wrapper.__doc__ = doc
134+
return wrapper

camel/toolkits/hybrid_browser_toolkit/config_loader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ToolkitConfig:
5858
log_dir: Optional[str] = None
5959
session_id: Optional[str] = None
6060
enabled_tools: Optional[list] = None
61+
enable_reasoning: bool = False
6162

6263

6364
class ConfigLoader:
@@ -123,6 +124,8 @@ def from_kwargs(cls, **kwargs) -> 'ConfigLoader':
123124
toolkit_kwargs["session_id"] = value
124125
elif key == "enabledTools":
125126
toolkit_kwargs["enabled_tools"] = value
127+
elif key == "enableReasoning":
128+
toolkit_kwargs["enable_reasoning"] = value
126129
elif key == "fullVisualMode":
127130
browser_kwargs["full_visual_mode"] = value
128131

camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class HybridBrowserToolkit(BaseToolkit):
3939
cache_dir (str): Directory for caching. Defaults to "tmp/".
4040
enabled_tools (Optional[List[str]]): List of enabled tools.
4141
Defaults to None.
42+
enable_reasoning (bool): Whether to enable reasoning when agent
43+
is using the browser toolkit. Defaults to False. agent will
44+
provide explanations for its actions when this is enabled.
4245
browser_log_to_file (bool): Whether to log browser actions to
4346
file. Defaults to False.
4447
log_dir (Optional[str]): Custom directory path for log files.
@@ -93,6 +96,7 @@ def __new__(
9396
stealth: bool = False,
9497
cache_dir: Optional[str] = None,
9598
enabled_tools: Optional[List[str]] = None,
99+
enable_reasoning: bool = False,
96100
browser_log_to_file: bool = False,
97101
log_dir: Optional[str] = None,
98102
session_id: Optional[str] = None,
@@ -127,6 +131,9 @@ def __new__(
127131
cache_dir (str): Directory for caching. Defaults to "tmp/".
128132
enabled_tools (Optional[List[str]]): List of enabled tools.
129133
Defaults to None.
134+
enable_reasoning (bool): Whether to enable reasoning when agent
135+
is using the browser toolkit. Defaults to False. agent will
136+
provide explanations for its actions when this is enabled.
130137
browser_log_to_file (bool): Whether to log browser actions to
131138
file. Defaults to False.
132139
log_dir (Optional[str]): Custom directory path for log files.
@@ -182,6 +189,7 @@ def __new__(
182189
stealth=stealth,
183190
cache_dir=cache_dir,
184191
enabled_tools=enabled_tools,
192+
enable_reasoning=enable_reasoning,
185193
browser_log_to_file=browser_log_to_file,
186194
log_dir=log_dir,
187195
session_id=session_id,

camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit_ts.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from camel.toolkits.function_tool import FunctionTool
3333
from camel.utils.commons import dependencies_required
3434

35+
from camel.toolkits._utils import add_reason_field
36+
3537
from .config_loader import ConfigLoader
3638
from .ws_wrapper import WebSocketBrowserWrapper, high_level_action
3739

@@ -99,6 +101,7 @@ def __init__(
99101
stealth: bool = False,
100102
cache_dir: Optional[str] = None,
101103
enabled_tools: Optional[List[str]] = None,
104+
enable_reasoning: bool = False,
102105
browser_log_to_file: bool = False,
103106
log_dir: Optional[str] = None,
104107
session_id: Optional[str] = None,
@@ -128,6 +131,9 @@ def __init__(
128131
cache_dir (str): Directory for caching. Defaults to "tmp/".
129132
enabled_tools (Optional[List[str]]): List of enabled tools.
130133
Defaults to None.
134+
enable_reasoning (bool): Whether to enable reasoning when agent
135+
is using the browser toolkit. Defaults to False. agent will
136+
provide explanations for its actions when this is enabled.
131137
browser_log_to_file (bool): Whether to log browser actions to
132138
file. Defaults to False.
133139
log_dir (Optional[str]): Custom directory path for log files.
@@ -187,6 +193,7 @@ def __init__(
187193
log_dir=log_dir,
188194
session_id=session_id,
189195
enabled_tools=enabled_tools,
196+
enable_reasoning=enable_reasoning,
190197
connect_over_cdp=connect_over_cdp,
191198
cdp_url=cdp_url,
192199
cdp_keep_current_page=cdp_keep_current_page,
@@ -206,17 +213,22 @@ def __init__(
206213
"is True, the browser will keep the current page and not "
207214
"navigate to any URL."
208215
)
216+
# toolkit settings
217+
self._cache_dir = toolkit_config.cache_dir
218+
self._browser_log_to_file = toolkit_config.browser_log_to_file
219+
self._enabled_tools = toolkit_config.enabled_tools
220+
self._enable_reasoning = toolkit_config.enable_reasoning
209221

222+
# browser settings
210223
self._headless = browser_config.headless
211224
self._user_data_dir = browser_config.user_data_dir
212225
self._stealth = browser_config.stealth
213-
self._cache_dir = toolkit_config.cache_dir
214-
self._browser_log_to_file = toolkit_config.browser_log_to_file
215226
self._default_start_url = browser_config.default_start_url
216227
self._session_id = toolkit_config.session_id or "default"
217228
self._viewport_limit = browser_config.viewport_limit
218229
self._full_visual_mode = browser_config.full_visual_mode
219230

231+
# timeout settings
220232
self._default_timeout = browser_config.default_timeout
221233
self._short_timeout = browser_config.short_timeout
222234
self._navigation_timeout = browser_config.navigation_timeout
@@ -227,24 +239,30 @@ def __init__(
227239
browser_config.dom_content_loaded_timeout
228240
)
229241

230-
if enabled_tools is None:
231-
self.enabled_tools = self.DEFAULT_TOOLS.copy()
242+
if self._enabled_tools is None:
243+
self._enabled_tools = self.DEFAULT_TOOLS.copy()
232244
else:
233245
invalid_tools = [
234-
tool for tool in enabled_tools if tool not in self.ALL_TOOLS
246+
tool for tool in self._enabled_tools if tool not in self.ALL_TOOLS
235247
]
236248
if invalid_tools:
237249
raise ValueError(
238250
f"Invalid tools specified: {invalid_tools}. "
239251
f"Available tools: {self.ALL_TOOLS}"
240252
)
241-
self.enabled_tools = enabled_tools.copy()
242253

243-
logger.info(f"Enabled tools: {self.enabled_tools}")
254+
logger.info(f"Enabled tools: {self._enabled_tools}")
244255

245256
self._ws_wrapper: Optional[WebSocketBrowserWrapper] = None
246257
self._ws_config = self.config_loader.to_ws_config()
247258

259+
# Dynamically wrap tool methods if reasoning is enabled
260+
if self._enable_reasoning:
261+
for tool_name in self._enabled_tools:
262+
method = getattr(self, tool_name, None)
263+
if method and callable(method):
264+
setattr(self, tool_name, add_reason_field(method))
265+
248266
async def _ensure_ws_wrapper(self):
249267
"""Ensure WebSocket wrapper is initialized."""
250268
if self._ws_wrapper is None:
@@ -1913,7 +1931,7 @@ def clone_for_new_session(
19131931
stealth=self._stealth,
19141932
cache_dir=f"{self._cache_dir.rstrip('/')}_clone_"
19151933
f"{new_session_id}/",
1916-
enabled_tools=self.enabled_tools.copy(),
1934+
enabled_tools=self._enabled_tools.copy(),
19171935
browser_log_to_file=self._browser_log_to_file,
19181936
session_id=new_session_id,
19191937
default_start_url=self._default_start_url,
@@ -1958,16 +1976,15 @@ def get_tools(self) -> List[FunctionTool]:
19581976
"browser_sheet_read": self.browser_sheet_read,
19591977
}
19601978

1961-
enabled_tools = []
1962-
1963-
for tool_name in self.enabled_tools:
1979+
tools = []
1980+
for tool_name in self._enabled_tools:
19641981
if tool_name in tool_map:
19651982
tool = FunctionTool(
19661983
cast(Callable[..., Any], tool_map[tool_name])
19671984
)
1968-
enabled_tools.append(tool)
1985+
tools.append(tool)
19691986
else:
19701987
logger.warning(f"Unknown tool name: {tool_name}")
19711988

1972-
logger.info(f"Returning {len(enabled_tools)} enabled tools")
1973-
return enabled_tools
1989+
logger.info(f"Returning {len(tools)} enabled tools")
1990+
return tools

0 commit comments

Comments
 (0)