forked from OpenHands/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_agent_delegation.py
More file actions
304 lines (253 loc) · 10.6 KB
/
Copy pathtest_agent_delegation.py
File metadata and controls
304 lines (253 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import asyncio
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from uuid import uuid4
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream
from openhands.events.action import (
AgentDelegateAction,
AgentFinishAction,
MessageAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event, RecallType
from openhands.events.observation.agent import RecallObservation
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_event_stream():
    """Build an EventStream backed by an empty in-memory file store.

    The stream id is ``test-`` followed by a fresh UUID, so every test that
    requests this fixture gets its own stream.
    """
    stream_id = f'test-{uuid4()}'
    return EventStream(
        sid=stream_id,
        file_store=InMemoryFileStore({}),
        max_delay_time=0,
    )
@pytest.fixture
def mock_parent_agent():
    """Return a MagicMock standing in for the delegating (parent) agent.

    The mock is spec'd against ``Agent`` and carries a spec'd ``LLM`` mock
    with real ``Metrics`` and ``LLMConfig`` objects, plus a real
    ``AgentConfig``.
    """
    parent = MagicMock(spec=Agent)
    parent.name = 'ParentAgent'

    # Configure the LLM mock on its own, then attach it to the agent.
    parent_llm = MagicMock(spec=LLM)
    parent_llm.metrics = Metrics()
    parent_llm.config = LLMConfig()
    parent.llm = parent_llm

    parent.config = AgentConfig()
    parent.workspace_mount_path_in_sandbox_store_in_session = True
    parent.streaming_llm = None
    return parent
@pytest.fixture
def mock_child_agent():
    """Return a MagicMock standing in for the delegated (child) agent.

    The mock is spec'd against ``Agent`` and carries a spec'd ``LLM`` mock
    with real ``Metrics`` and ``LLMConfig`` objects, plus a real
    ``AgentConfig``.
    """
    child = MagicMock(spec=Agent)
    child.name = 'ChildAgent'

    # Configure the LLM mock on its own, then attach it to the agent.
    child_llm = MagicMock(spec=LLM)
    child_llm.metrics = Metrics()
    child_llm.config = LLMConfig()
    child.llm = child_llm

    child.config = AgentConfig()
    child.streaming_llm = None
    return child
# Create separate mock functions so we can track calls
async def mock_process_event(*args, **kwargs):
    """Async stand-in used as the side effect for ``process_single_event_for_mem0``.

    Prints the first argument it received and returns an empty list.

    Args:
        *args: Positional arguments forwarded by the patched call site.
        **kwargs: Keyword arguments forwarded by the patched call site.

    Returns:
        An empty list.
    """
    # The call site may pass its payload by keyword (or pass nothing at all);
    # indexing args[0] unconditionally would raise IndexError in that case.
    first_arg = args[0] if args else next(iter(kwargs.values()), None)
    print(f'Mock process_single_event_for_mem0 called with: {first_arg}')
    return []
async def mock_webhook_rag(*args, **kwargs):
    """Async stand-in used as the side effect for ``webhook_rag_conversation``.

    Prints the first argument it received and returns ``True``.

    Args:
        *args: Positional arguments forwarded by the patched call site.
        **kwargs: Keyword arguments forwarded by the patched call site.

    Returns:
        Always ``True``.
    """
    # The call site may pass its payload by keyword (or pass nothing at all);
    # indexing args[0] unconditionally would raise IndexError in that case.
    first_arg = args[0] if args else next(iter(kwargs.values()), None)
    print(f'Mock webhook_rag_conversation called with: {first_arg}')
    return True
@pytest.mark.asyncio
async def test_delegation_flow(
    mock_parent_agent, mock_child_agent, mock_event_stream, monkeypatch
):
    """
    Test that when the parent agent delegates to a child, the parent's delegate
    is set, and once the child finishes, the parent is cleaned up properly.

    Outline:
      1. Stub out HTTP, environment lookup, the Mem0 client and agent class
         resolution so no external service is touched.
      2. Make the parent agent's ``step`` return an ``AgentDelegateAction`` and
         feed the parent controller a user message.
      3. Check that a delegate controller exists and the parent iteration
         advanced.
      4. Finish the child, poke the parent with another message, and check the
         delegate was cleared and the parent's iteration is at least the
         child's.
    """

    # Mock the httpx client to prevent any actual HTTP requests
    class MockResponse:
        """Minimal successful response returned by the patched ``post``."""

        def __init__(self):
            self.status_code = 200
            self.text = 'OK'

        def json(self):
            return {'status': 'success'}

        def raise_for_status(self):
            pass

    async def mock_post(*args, **kwargs):
        return MockResponse()

    # Apply mock to httpx AsyncClient post method
    monkeypatch.setattr('httpx.AsyncClient.post', mock_post)

    # We also need to disable any internal client creation logic
    # NOTE(review): this stub hands back `default` for every key other than
    # THESIS_AUTH_SERVER_URL, whatever the real environment holds - confirm
    # that masking the other variables is intended.
    monkeypatch.setattr(
        'openhands.server.thesis_auth.os.getenv',
        lambda x, default=None: (
            'http://fake-url' if x == 'THESIS_AUTH_SERVER_URL' else default
        ),
    )

    # Mock Mem0Client to avoid actual initialization attempts during tests
    mock_mem0_client = MagicMock()
    mock_mem0_client.is_available = False
    monkeypatch.setattr(
        'openhands.server.mem0.Mem0Client', MagicMock(return_value=mock_mem0_client)
    )

    # Mock the agent class resolution so that AgentController can instantiate
    # mock_child_agent. This goes through monkeypatch (rather than assigning to
    # Agent.get_cls directly) so the real classmethod is restored at teardown
    # and the stub cannot leak into other tests.
    def child_agent_factory(
        llm, config, workspace_mount_path_in_sandbox_store_in_session=None
    ):
        return mock_child_agent

    monkeypatch.setattr(Agent, 'get_cls', Mock(return_value=child_agent_factory))

    process_patch = patch(
        'openhands.controller.agent_controller.process_single_event_for_mem0',
        new=AsyncMock(side_effect=mock_process_event),
    )
    webhook_patch = patch(
        'openhands.controller.agent_controller.webhook_rag_conversation',
        new=AsyncMock(side_effect=mock_webhook_rag),
    )

    # Apply both patches
    with process_patch, webhook_patch:
        # Create parent controller
        parent_state = State(max_iterations=10)
        parent_controller = AgentController(
            agent=mock_parent_agent,
            event_stream=mock_event_stream,
            max_iterations=10,
            sid='parent',
            confirmation_mode=False,
            headless_mode=True,
            initial_state=parent_state,
        )

        # Setup Memory to catch RecallActions
        mock_memory = MagicMock(spec=Memory)
        mock_memory.event_stream = mock_event_stream

        def on_event(event: Event):
            """Answer every RecallAction with a canned RecallObservation."""
            if isinstance(event, RecallAction):
                # create a RecallObservation
                microagent_observation = RecallObservation(
                    recall_type=RecallType.KNOWLEDGE,
                    content='Found info',
                )
                microagent_observation._cause = event.id  # ignore attr-defined warning
                mock_event_stream.add_event(
                    microagent_observation, EventSource.ENVIRONMENT
                )

        mock_memory.on_event = on_event
        mock_event_stream.subscribe(
            EventStreamSubscriber.MEMORY, mock_memory.on_event, mock_memory
        )

        # Setup a delegate action from the parent
        delegate_action = AgentDelegateAction(agent='ChildAgent', inputs={'test': True})
        mock_parent_agent.step.return_value = delegate_action

        # Simulate a user message event to cause parent.step() to run
        message_action = MessageAction(content='please delegate now')
        message_action._source = EventSource.USER
        await parent_controller._on_event(message_action)

        # Give time for the async step() to execute
        await asyncio.sleep(1)

        # Verify that a RecallObservation was added to the event stream
        events = list(mock_event_stream.get_events())
        assert (
            mock_event_stream.get_latest_event_id() == 3
        )  # Microagents and AgentChangeState

        # a RecallObservation and an AgentDelegateAction should be in the list
        assert any(isinstance(event, RecallObservation) for event in events)
        assert any(isinstance(event, AgentDelegateAction) for event in events)

        # Verify that a delegate agent controller is created
        assert (
            parent_controller.delegate is not None
        ), "Parent's delegate controller was not set."

        # The parent's iteration should have incremented
        assert (
            parent_controller.state.iteration == 1
        ), 'Parent iteration should be incremented after step.'

        # Now simulate that the child increments local iteration and finishes its subtask
        delegate_controller = parent_controller.delegate
        delegate_controller.state.iteration = 5  # child had some steps
        delegate_controller.state.outputs = {'delegate_result': 'done'}

        # Mock _react_to_exception to prevent errors
        async def mock_react_to_exception(*args, **kwargs):
            pass

        # Apply the mock to both controllers
        monkeypatch.setattr(
            delegate_controller, '_react_to_exception', mock_react_to_exception
        )
        monkeypatch.setattr(
            parent_controller, '_react_to_exception', mock_react_to_exception
        )

        # Mock the update_agent_knowledge_base function in Agent to prevent problems
        mock_child_agent.update_agent_knowledge_base = Mock()

        # The child is done, so we simulate it finishing:
        child_finish_action = AgentFinishAction()
        await delegate_controller._on_event(child_finish_action)

        # Send a dummy event to parent controller to trigger delegate cleanup check
        dummy_message = MessageAction(content='Dummy event to check delegate status')
        dummy_message._source = EventSource.USER
        await parent_controller._on_event(dummy_message)

        # Verify parent is cleaned up
        assert (
            parent_controller.delegate is None
        ), "Parent's delegate should be cleaned up after finishing."

        # Instead of checking for exact iteration, check that it has been updated from the child
        # using "greater than or equal" to handle possible additional increments
        assert (
            parent_controller.state.iteration >= 5
        ), "Parent should have adopted at least child's iteration count."

        # Cleanup: close the controller while the patches are still active,
        # matching test_delegate_step_different_states.
        await parent_controller.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        AgentState.RUNNING,
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_parent_agent, mock_event_stream, delegate_state
):
    """Check that the delegate is kept while RUNNING and dropped otherwise.

    A mocked delegate reporting ``delegate_state`` is attached to a fresh
    controller, a user message is delivered through ``controller.on_event``
    from a worker thread, and the controller's delegate, iteration count and
    the delegate's ``close`` call are then inspected.
    """
    controller = AgentController(
        agent=mock_parent_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Attach a fully mocked delegate that reports the parametrized state.
    fake_delegate = AsyncMock()
    controller.delegate = fake_delegate
    fake_delegate.state.iteration = 5
    fake_delegate.state.outputs = {'result': 'test'}
    fake_delegate.agent.name = 'TestDelegate'
    fake_delegate.get_agent_state = Mock(return_value=delegate_state)
    fake_delegate._step = AsyncMock()
    fake_delegate.close = AsyncMock()

    def call_on_event_with_new_loop():
        """
        Runs in the worker thread: install a fresh event loop there so that the
        run_until_complete() calls inside controller.on_event(...) find a valid
        loop, deliver one user message, then close the loop.
        """
        thread_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(thread_loop)
            user_message = MessageAction(content='Test message')
            user_message._source = EventSource.USER
            controller.on_event(user_message)
        finally:
            thread_loop.close()

    running_loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as pool:
        await running_loop.run_in_executor(pool, call_on_event_with_new_loop)

    if delegate_state != AgentState.RUNNING:
        # Any terminal state: delegate dropped, iteration taken from it, closed once.
        assert controller.delegate is None
        assert controller.state.iteration == 5
        fake_delegate.close.assert_called_once()
    else:
        # Still running: delegate stays attached and untouched.
        assert controller.delegate is not None
        assert controller.state.iteration == 0
        fake_delegate.close.assert_not_called()

    await controller.close()