Skip to content

Commit a4d6e74

Browse files
DouweMclaude
andauthored
Fix Temporal and DBOS MCP to use cached tools instead of fetching each time (#4331)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a89292b commit a4d6e74

File tree

9 files changed

+3008
-45
lines changed

9 files changed

+3008
-45
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_mcp_server.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22

33
from pydantic_ai import ToolsetTool
44
from pydantic_ai.mcp import MCPServer
5-
from pydantic_ai.tools import AgentDepsT, ToolDefinition
5+
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
66

77
from ._mcp import DBOSMCPToolset
88
from ._utils import StepConfig
99

1010

1111
class DBOSMCPServer(DBOSMCPToolset[AgentDepsT]):
12-
"""A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""
12+
"""A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools into DBOS steps.
13+
14+
Tool definitions are cached across steps to avoid redundant MCP server round-trips,
15+
respecting the wrapped server's `cache_tools` setting.
16+
"""
1317

1418
def __init__(
1519
self,
@@ -23,7 +27,24 @@ def __init__(
2327
step_name_prefix=step_name_prefix,
2428
step_config=step_config,
2529
)
30+
# Cached across steps to avoid redundant MCP connections per step.
31+
# Not invalidated by `tools/list_changed` notifications — users who need
32+
# dynamic tools during a workflow should set `cache_tools=False`.
33+
self._cached_tool_defs: dict[str, ToolDefinition] | None = None
2634

27-
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
35+
@property
36+
def _server(self) -> MCPServer:
2837
assert isinstance(self.wrapped, MCPServer)
29-
return self.wrapped.tool_for_tool_def(tool_def)
38+
return self.wrapped
39+
40+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
41+
return self._server.tool_for_tool_def(tool_def)
42+
43+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
44+
if self._server.cache_tools and self._cached_tool_defs is not None:
45+
return {name: self.tool_for_tool_def(td) for name, td in self._cached_tool_defs.items()}
46+
47+
result = await super().get_tools(ctx)
48+
if self._server.cache_tools:
49+
self._cached_tool_defs = {name: tool.tool_def for name, tool in result.items()}
50+
return result

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _default_setup_logfire() -> Logfire:
1515
import logfire
1616

1717
instance = logfire.configure()
18-
logfire.instrument_pydantic_ai()
18+
instance.instrument_pydantic_ai()
1919
return instance
2020

2121

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,19 @@
66

77
from pydantic_ai import ToolsetTool
88
from pydantic_ai.mcp import MCPServer
9-
from pydantic_ai.tools import AgentDepsT, ToolDefinition
9+
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
1010

1111
from ._mcp import TemporalMCPToolset
1212
from ._run_context import TemporalRunContext
1313

1414

1515
class TemporalMCPServer(TemporalMCPToolset[AgentDepsT]):
16+
"""A wrapper for MCPServer that integrates with Temporal, turning get_tools and call_tool into activities.
17+
18+
Tool definitions are cached across activities to avoid redundant MCP server round-trips,
19+
respecting the wrapped server's `cache_tools` setting.
20+
"""
21+
1622
def __init__(
1723
self,
1824
server: MCPServer,
@@ -31,7 +37,24 @@ def __init__(
3137
deps_type=deps_type,
3238
run_context_type=run_context_type,
3339
)
40+
# Cached across activities to avoid redundant MCP connections per activity.
41+
# Not invalidated by `tools/list_changed` notifications — users who need
42+
# dynamic tools during a workflow should set `cache_tools=False`.
43+
self._cached_tool_defs: dict[str, ToolDefinition] | None = None
3444

35-
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
45+
@property
46+
def _server(self) -> MCPServer:
3647
assert isinstance(self.wrapped, MCPServer)
37-
return self.wrapped.tool_for_tool_def(tool_def)
48+
return self.wrapped
49+
50+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
51+
return self._server.tool_for_tool_def(tool_def)
52+
53+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
54+
if self._server.cache_tools and self._cached_tool_defs is not None:
55+
return {name: self.tool_for_tool_def(td) for name, td in self._cached_tool_defs.items()}
56+
57+
result = await super().get_tools(ctx)
58+
if self._server.cache_tools: # pragma: no branch
59+
self._cached_tool_defs = {name: tool.tool_def for name, tool in result.items()}
60+
return result

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -339,17 +339,22 @@ class MCPServer(AbstractToolset[Any], ABC):
339339
340340
When enabled (default), tools are fetched once and cached until either:
341341
- The server sends a `notifications/tools/list_changed` notification
342-
- The connection is closed
342+
- [`MCPServer.__aexit__`][pydantic_ai.mcp.MCPServer.__aexit__] is called (when the last context exits)
343343
344344
Set to `False` for servers that change tools dynamically without sending notifications.
345+
346+
Note: When using durable execution (Temporal, DBOS), tool definitions are additionally cached
347+
at the wrapper level across activities/steps, to avoid redundant MCP connections. This
348+
wrapper-level cache is not invalidated by `tools/list_changed` notifications.
349+
Set to `False` to disable all caching if tools may change during a workflow.
345350
"""
346351

347352
cache_resources: bool
348353
"""Whether to cache the list of resources.
349354
350355
When enabled (default), resources are fetched once and cached until either:
351356
- The server sends a `notifications/resources/list_changed` notification
352-
- The connection is closed
357+
- [`MCPServer.__aexit__`][pydantic_ai.mcp.MCPServer.__aexit__] is called (when the last context exits)
353358
354359
Set to `False` for servers that change resources dynamically without sending notifications.
355360
"""
@@ -479,20 +484,18 @@ async def list_tools(self) -> list[mcp_types.Tool]:
479484
480485
Tools are cached by default, with cache invalidation on:
481486
- `notifications/tools/list_changed` notifications from the server
482-
- Connection close (cache is cleared in `__aexit__`)
487+
- `__aexit__` when the last context exits
483488
484489
Set `cache_tools=False` for servers that change tools without sending notifications.
485490
"""
491+
if self.cache_tools and self._cached_tools is not None:
492+
return self._cached_tools
493+
486494
async with self:
495+
result = await self._client.list_tools()
487496
if self.cache_tools:
488-
if self._cached_tools is not None:
489-
return self._cached_tools
490-
result = await self._client.list_tools()
491497
self._cached_tools = result.tools
492-
return result.tools
493-
else:
494-
result = await self._client.list_tools()
495-
return result.tools
498+
return result.tools
496499

497500
async def direct_call_tool(
498501
self,
@@ -600,27 +603,25 @@ async def list_resources(self) -> list[Resource]:
600603
601604
Resources are cached by default, with cache invalidation on:
602605
- `notifications/resources/list_changed` notifications from the server
603-
- Connection close (cache is cleared in `__aexit__`)
606+
- `__aexit__` when the last context exits
604607
605608
Set `cache_resources=False` for servers that change resources without sending notifications.
606609
607610
Raises:
608611
MCPError: If the server returns an error.
609612
"""
613+
if self.cache_resources and self._cached_resources is not None:
614+
return self._cached_resources
615+
610616
async with self:
611617
if not self.capabilities.resources:
612618
return []
613619
try:
620+
result = await self._client.list_resources()
621+
resources = [Resource.from_mcp_sdk(r) for r in result.resources]
614622
if self.cache_resources:
615-
if self._cached_resources is not None:
616-
return self._cached_resources
617-
result = await self._client.list_resources()
618-
resources = [Resource.from_mcp_sdk(r) for r in result.resources]
619623
self._cached_resources = resources
620-
return resources
621-
else:
622-
result = await self._client.list_resources()
623-
return [Resource.from_mcp_sdk(r) for r in result.resources]
624+
return resources
624625
except mcp_exceptions.McpError as e:
625626
raise MCPError.from_mcp_sdk(e) from e
626627

0 commit comments

Comments
 (0)