Skip to content

Commit 7cf436c

Browse files
author
Owen Kaplan
committed
feat: add threading.Lock on async stream; expand interrupts to other AgentBase instances
1 parent 4d79d32 commit 7cf436c

2 files changed

Lines changed: 94 additions & 18 deletions

File tree

src/strands/agent/_agent_as_tool.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import copy
88
import logging
9+
import threading
910
from typing import Any
1011

1112
from typing_extensions import override
@@ -81,6 +82,10 @@ def __init__(
8182
# messages/state attributes.
8283
self._initial_messages: Messages = []
8384
self._initial_state: AgentState = AgentState()
85+
# Serialize access so _reset_agent_state + stream_async are atomic.
86+
# threading.Lock (not asyncio.Lock) because run_async() may create
87+
# separate event loops in different threads.
88+
self._lock = threading.Lock()
8489

8590
if not preserve_context:
8691
from .agent import Agent
@@ -157,20 +162,38 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
157162

158163
tool_use_id = tool_use["toolUseId"]
159164

160-
# Determine if we are resuming the sub-agent from an interrupt.
161-
if self._is_sub_agent_interrupted():
162-
prompt = self._build_interrupt_responses()
163-
logger.debug(
164-
"tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt",
165+
# Serialize access to the underlying agent. _reset_agent_state() mutates
166+
# the agent before stream_async acquires its own lock, so a concurrent
167+
# call would corrupt an in-flight invocation.
168+
if not self._lock.acquire(blocking=False):
169+
logger.warning(
170+
"tool_name=<%s>, tool_use_id=<%s> | agent is already processing a request",
165171
self._tool_name,
166172
tool_use_id,
167173
)
168-
elif not self._preserve_context:
169-
self._reset_agent_state(tool_use_id)
170-
171-
logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id)
174+
yield ToolResultEvent(
175+
{
176+
"toolUseId": tool_use_id,
177+
"status": "error",
178+
"content": [{"text": f"Agent '{self._tool_name}' is already processing a request"}],
179+
}
180+
)
181+
return
172182

173183
try:
184+
# Determine if we are resuming the sub-agent from an interrupt.
185+
if self._is_sub_agent_interrupted():
186+
prompt = self._build_interrupt_responses()
187+
logger.debug(
188+
"tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt",
189+
self._tool_name,
190+
tool_use_id,
191+
)
192+
elif not self._preserve_context:
193+
self._reset_agent_state(tool_use_id)
194+
195+
logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id)
196+
174197
result = None
175198
async for event in self._agent.stream_async(prompt):
176199
if "result" in event:
@@ -224,6 +247,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
224247
"content": [{"text": f"Agent error: {e}"}],
225248
}
226249
)
250+
finally:
251+
self._lock.release()
227252

228253
def _reset_agent_state(self, tool_use_id: str) -> None:
229254
"""Reset the wrapped agent to its initial state.
@@ -250,11 +275,8 @@ def _reset_agent_state(self, tool_use_id: str) -> None:
250275

251276
def _is_sub_agent_interrupted(self) -> bool:
    """Check whether the wrapped agent is in an activated interrupt state.

    The interrupt state is looked up by attribute instead of an ``isinstance``
    check, so any wrapped agent that exposes ``_interrupt_state`` is handled.

    Returns:
        True if the wrapped agent has a non-None ``_interrupt_state`` whose
        ``activated`` value is truthy; False if the attribute is missing,
        is None, or ``activated`` is falsy.
    """
    interrupt_state = getattr(self._agent, "_interrupt_state", None)
    if interrupt_state is None:
        return False
    # Coerce explicitly so the annotated ``bool`` return type holds even when a
    # duck-typed agent stores a non-bool value in ``activated``.
    return bool(interrupt_state.activated)
258280

259281
def _build_interrupt_responses(self) -> list[InterruptResponseContent]:
260282
"""Build interrupt response payloads from the sub-agent's interrupt state.
@@ -266,14 +288,13 @@ def _build_interrupt_responses(self) -> list[InterruptResponseContent]:
266288
Returns:
267289
List of interrupt response content blocks for resuming the sub-agent.
268290
"""
269-
from .agent import Agent
270-
271-
if not isinstance(self._agent, Agent):
291+
interrupt_state = getattr(self._agent, "_interrupt_state", None)
292+
if interrupt_state is None:
272293
return []
273294

274295
return [
275296
{"interruptResponse": {"interruptId": interrupt.id, "response": interrupt.response}}
276-
for interrupt in self._agent._interrupt_state.interrupts.values()
297+
for interrupt in interrupt_state.interrupts.values()
277298
if interrupt.response is not None
278299
]
279300

tests/strands/agent/test_agent_as_tool.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def mock_agent():
2323
agent = MagicMock()
2424
agent.name = "test_agent"
2525
agent.description = "A test agent"
26+
# Prevent MagicMock from auto-creating _interrupt_state on access,
27+
# so getattr checks in AgentAsTool correctly detect its absence.
28+
agent._interrupt_state = None
2629
return agent
2730

2831

@@ -615,3 +618,55 @@ async def test_build_interrupt_responses(fake_agent):
615618
# Only interrupt_a has a response
616619
assert len(responses) == 1
617620
assert responses[0] == {"interruptResponse": {"interruptId": "id-a", "response": "yes"}}
621+
622+
623+
# --- concurrency ---
624+
625+
626+
@pytest.mark.asyncio
async def test_stream_rejects_concurrent_call(tool, mock_agent, tool_use, agent_result):
    """A call made while the lock is held yields exactly one error ToolResultEvent."""
    mock_agent.stream_async.return_value = _mock_stream_async(agent_result)

    # Hold the lock ourselves to stand in for another in-flight invocation.
    with tool._lock:
        emitted = []
        async for item in tool.stream(tool_use, {}):
            emitted.append(item)

        assert len(emitted) == 1
        only_event = emitted[0]
        assert isinstance(only_event, ToolResultEvent)
        assert only_event["tool_result"]["status"] == "error"
        assert "already processing" in only_event["tool_result"]["content"][0]["text"]
        # The wrapped agent must never be reached when the call is rejected.
        mock_agent.stream_async.assert_not_called()
643+
644+
645+
@pytest.mark.asyncio
async def test_stream_releases_lock_after_completion(tool, mock_agent, tool_use, agent_result):
    """Once a stream is fully consumed the lock is free and a follow-up call succeeds."""
    mock_agent.stream_async.return_value = _mock_stream_async(agent_result)

    # Drain the first invocation completely; its events are not inspected.
    async for _event in tool.stream(tool_use, {}):
        continue

    assert not tool._lock.locked()

    # Provide a fresh mock stream for the second invocation.
    mock_agent.stream_async.return_value = _mock_stream_async(agent_result)
    second_run = []
    async for item in tool.stream(tool_use, {}):
        second_run.append(item)

    tool_results = [item for item in second_run if isinstance(item, ToolResultEvent)]
    assert len(tool_results) == 1
    assert tool_results[0]["tool_result"]["status"] == "success"
662+
663+
664+
@pytest.mark.asyncio
async def test_stream_releases_lock_after_error(tool, mock_agent, tool_use):
    """The lock is free again after a stream in which the wrapped agent raises."""
    mock_agent.stream_async.side_effect = RuntimeError("boom")

    # Consume the stream to the end so its cleanup has run before the check.
    async for _event in tool.stream(tool_use, {}):
        continue

    assert not tool._lock.locked()

0 commit comments

Comments
 (0)