Skip to content

Commit 1682a0c

Browse files
mananpatel320Manan PatelUnshure
authored
fix: forward _meta to MCP tool calls and fix model_dump alias seriali… (#1918)
Co-authored-by: Manan Patel <mananptl@amazon.com> Co-authored-by: Nicholas Clegg <ncclegg@amazon.com>
1 parent 94fc8dd commit 1682a0c

6 files changed

Lines changed: 170 additions & 16 deletions

File tree

src/strands/tools/mcp/mcp_client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,7 @@ def _create_call_tool_coroutine(
568568
name: str,
569569
arguments: dict[str, Any] | None,
570570
read_timeout_seconds: timedelta | None,
571+
meta: dict[str, Any] | None = None,
571572
) -> Coroutine[Any, Any, MCPCallToolResult]:
572573
"""Create the appropriate coroutine for calling a tool.
573574
@@ -578,6 +579,7 @@ def _create_call_tool_coroutine(
578579
name: Name of the tool to call.
579580
arguments: Optional arguments to pass to the tool.
580581
read_timeout_seconds: Optional timeout for the tool call.
582+
meta: Optional metadata to pass to the tool call per MCP spec (_meta).
581583
582584
Returns:
583585
A coroutine that will execute the tool call.
@@ -598,7 +600,7 @@ async def _call_as_task() -> MCPCallToolResult:
598600

599601
async def _call_tool_direct() -> MCPCallToolResult:
600602
return await cast(ClientSession, self._background_thread_session).call_tool(
601-
name, arguments, read_timeout_seconds
603+
name, arguments, read_timeout_seconds, meta=meta
602604
)
603605

604606
return _call_tool_direct()
@@ -609,6 +611,7 @@ def call_tool_sync(
609611
name: str,
610612
arguments: dict[str, Any] | None = None,
611613
read_timeout_seconds: timedelta | None = None,
614+
meta: dict[str, Any] | None = None,
612615
) -> MCPToolResult:
613616
"""Synchronously calls a tool on the MCP server.
614617
@@ -620,6 +623,7 @@ def call_tool_sync(
620623
name: Name of the tool to call
621624
arguments: Optional arguments to pass to the tool
622625
read_timeout_seconds: Optional timeout for the tool call
626+
meta: Optional metadata to pass to the tool call per MCP spec (_meta)
623627
624628
Returns:
625629
MCPToolResult: The result of the tool call
@@ -629,7 +633,7 @@ def call_tool_sync(
629633
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
630634

631635
try:
632-
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds)
636+
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta)
633637
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result()
634638
return self._handle_tool_result(tool_use_id, call_tool_result)
635639
except Exception as e:
@@ -642,6 +646,7 @@ async def call_tool_async(
642646
name: str,
643647
arguments: dict[str, Any] | None = None,
644648
read_timeout_seconds: timedelta | None = None,
649+
meta: dict[str, Any] | None = None,
645650
) -> MCPToolResult:
646651
"""Asynchronously calls a tool on the MCP server.
647652
@@ -653,6 +658,7 @@ async def call_tool_async(
653658
name: Name of the tool to call
654659
arguments: Optional arguments to pass to the tool
655660
read_timeout_seconds: Optional timeout for the tool call
661+
meta: Optional metadata to pass to the tool call per MCP spec (_meta)
656662
657663
Returns:
658664
MCPToolResult: The result of the tool call
@@ -662,7 +668,7 @@ async def call_tool_async(
662668
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
663669

664670
try:
665-
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds)
671+
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta)
666672
future = self._invoke_on_background_thread(coro)
667673
call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future)
668674
return self._handle_tool_result(tool_use_id, call_tool_result)

src/strands/tools/mcp/mcp_instrumentation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,10 @@ def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwar
9090
if hasattr(request.root, "params") and request.root.params:
9191
# Handle Pydantic models
9292
if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"):
93-
params_dict = request.root.params.model_dump()
93+
params_dict = request.root.params.model_dump(by_alias=True)
9494
# Add _meta with tracing context
95-
meta = params_dict.setdefault("_meta", {})
95+
meta = params_dict.get("_meta") if params_dict.get("_meta") is not None else {}
96+
params_dict["_meta"] = meta
9697
propagate.get_global_textmap().inject(meta)
9798

9899
# Recreate the Pydantic model with the updated data

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_
124124
with MCPClient(mock_transport["transport_callable"]) as client:
125125
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
126126

127-
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)
127+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
128128

129129
assert result["status"] == expected_status
130130
assert result["toolUseId"] == "test-123"
@@ -153,7 +153,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session):
153153
with MCPClient(mock_transport["transport_callable"]) as client:
154154
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
155155

156-
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)
156+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
157157

158158
assert result["status"] == "success"
159159
assert result["toolUseId"] == "test-123"
@@ -180,6 +180,51 @@ def test_call_tool_sync_exception(mock_transport, mock_session):
180180
assert "Test exception" in result["content"][0]["text"]
181181

182182

183+
def test_call_tool_sync_forwards_meta(mock_transport, mock_session):
184+
"""Test that call_tool_sync forwards meta to ClientSession.call_tool."""
185+
mock_content = MCPTextContent(type="text", text="Test message")
186+
mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content])
187+
meta = {"com.example/request_id": "abc-123"}
188+
189+
with MCPClient(mock_transport["transport_callable"]) as client:
190+
result = client.call_tool_sync(
191+
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta
192+
)
193+
194+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta)
195+
assert result["status"] == "success"
196+
197+
198+
@pytest.mark.asyncio
199+
async def test_call_tool_async_forwards_meta(mock_transport, mock_session):
200+
"""Test that call_tool_async forwards meta to ClientSession.call_tool."""
201+
mock_content = MCPTextContent(type="text", text="Test message")
202+
mock_result = MCPCallToolResult(isError=False, content=[mock_content])
203+
mock_session.call_tool.return_value = mock_result
204+
meta = {"com.example/request_id": "abc-123"}
205+
206+
with MCPClient(mock_transport["transport_callable"]) as client:
207+
with (
208+
patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe,
209+
patch("asyncio.wrap_future") as mock_wrap_future,
210+
):
211+
mock_future = MagicMock()
212+
mock_run_coroutine_threadsafe.return_value = mock_future
213+
214+
async def mock_awaitable():
215+
return mock_result
216+
217+
mock_wrap_future.return_value = mock_awaitable()
218+
219+
result = await client.call_tool_async(
220+
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta
221+
)
222+
223+
mock_run_coroutine_threadsafe.assert_called_once()
224+
225+
assert result["status"] == "success"
226+
227+
183228
@pytest.mark.asyncio
184229
@pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")])
185230
async def test_call_tool_async_status(mock_transport, mock_session, is_error, expected_status):
@@ -584,7 +629,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session):
584629
with MCPClient(mock_transport["transport_callable"]) as client:
585630
result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={})
586631

587-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
632+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
588633
assert result["status"] == "success"
589634
assert len(result["content"]) == 1
590635
assert result["content"][0]["text"] == "inner text"
@@ -609,7 +654,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock
609654
with MCPClient(mock_transport["transport_callable"]) as client:
610655
result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={})
611656

612-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
657+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
613658
assert result["status"] == "success"
614659
assert len(result["content"]) == 1
615660
assert result["content"][0]["text"] == '{"k":"v"}'
@@ -635,7 +680,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session):
635680
with MCPClient(mock_transport["transport_callable"]) as client:
636681
result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={})
637682

638-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
683+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
639684
assert result["status"] == "success"
640685
assert len(result["content"]) == 1
641686
assert "image" in result["content"][0]
@@ -660,7 +705,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s
660705
with MCPClient(mock_transport["transport_callable"]) as client:
661706
result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={})
662707

663-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
708+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
664709
assert result["status"] == "success"
665710
assert len(result["content"]) == 0 # Content should be dropped
666711

@@ -683,7 +728,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses
683728
with MCPClient(mock_transport["transport_callable"]) as client:
684729
result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={})
685730

686-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
731+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
687732
assert result["status"] == "success"
688733
assert len(result["content"]) == 1
689734
assert "key: value" in result["content"][0]["text"]
@@ -710,7 +755,7 @@ def __init__(self):
710755
with MCPClient(mock_transport["transport_callable"]) as client:
711756
result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={})
712757

713-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None)
758+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
714759
assert result["status"] == "success"
715760
assert len(result["content"]) == 0 # Unknown resource type should be dropped
716761

@@ -762,7 +807,7 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se
762807
with MCPClient(mock_transport["transport_callable"]) as client:
763808
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
764809

765-
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)
810+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
766811

767812
assert result["status"] == "success"
768813
assert result["toolUseId"] == "test-123"

tests/strands/tools/mcp/test_mcp_instrumentation.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ class MockPydanticParams:
328328
def __init__(self, **data):
329329
self._data = data
330330

331-
def model_dump(self):
331+
def model_dump(self, by_alias=False):
332332
return self._data.copy()
333333

334334
@classmethod
@@ -431,6 +431,32 @@ def test_patch_mcp_client_injects_context_pydantic_model(self):
431431
# Verify the params object is still a MockPydanticParams (or dict if fallback occurred)
432432
assert hasattr(mock_request.root.params, "model_dump") or isinstance(mock_request.root.params, dict)
433433

434+
def test_patch_mcp_client_preserves_existing_meta_pydantic(self):
435+
"""Test that instrumentation preserves existing _meta values in Pydantic models."""
436+
mock_request = MagicMock()
437+
mock_request.root.method = "tools/call"
438+
439+
# Pydantic model with existing _meta (returned via by_alias=True)
440+
mock_params = MockPydanticParams(_meta={"com.example/request_id": "abc-123"}, name="echo")
441+
mock_request.root.params = mock_params
442+
443+
with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap:
444+
mcp_instrumentation()
445+
patch_function = mock_wrap.call_args_list[0][0][2]
446+
447+
mock_wrapped = MagicMock()
448+
449+
with patch.object(propagate, "get_global_textmap") as mock_textmap:
450+
mock_textmap_instance = MagicMock()
451+
mock_textmap.return_value = mock_textmap_instance
452+
453+
patch_function(mock_wrapped, None, [mock_request], {})
454+
455+
# Verify the reconstructed params use the key "_meta" (alias) not "meta" (Python name)
456+
validated_params = mock_request.root.params.model_dump(by_alias=True)
457+
assert "_meta" in validated_params
458+
assert validated_params["_meta"]["com.example/request_id"] == "abc-123"
459+
434460
def test_patch_mcp_client_injects_context_dict_params(self):
435461
"""Test that the client patch injects OpenTelemetry context into dict params."""
436462
# Create a mock request with tools/call method and dict params
@@ -507,7 +533,7 @@ class FailingMockPydanticParams:
507533
def __init__(self, **data):
508534
self._data = data
509535

510-
def model_dump(self):
536+
def model_dump(self, by_alias=False):
511537
return self._data.copy()
512538

513539
def model_validate(self, data):

tests_integ/mcp/echo_server.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Literal
2121

2222
from mcp.server import FastMCP
23+
from mcp.server.fastmcp import Context
2324
from mcp.types import BlobResourceContents, CallToolResult, EmbeddedResource, TextContent, TextResourceContents
2425
from pydantic import BaseModel
2526

@@ -48,6 +49,13 @@ def start_echo_server():
4849
def echo(to_echo: str) -> str:
4950
return to_echo
5051

52+
@mcp.tool(description="Echos back the _meta received in the request", structured_output=False)
53+
def echo_meta(ctx: Context) -> str:
54+
meta = ctx.request_context.meta
55+
if meta is None:
56+
return json.dumps(None)
57+
return json.dumps(meta.model_dump(exclude_none=True))
58+
5159
# FastMCP automatically constructs structured output schema from method signature
5260
@mcp.tool(description="Echos response back with structured content", structured_output=True)
5361
def echo_with_structured_content(to_echo: str) -> EchoResponse:

tests_integ/mcp/test_mcp_client.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,74 @@ def test_mcp_client_without_structured_content():
238238
assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}]
239239

240240

241+
def test_call_tool_sync_with_meta():
242+
"""Test that call_tool_sync forwards meta to the MCP server."""
243+
stdio_mcp_client = MCPClient(
244+
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
245+
)
246+
247+
with stdio_mcp_client:
248+
result = stdio_mcp_client.call_tool_sync(
249+
tool_use_id="test-meta-sync",
250+
name="echo_meta",
251+
arguments={},
252+
meta={"com.example/request_id": "abc-123"},
253+
)
254+
255+
assert result["status"] == "success"
256+
received_meta = json.loads(result["content"][0]["text"])
257+
assert received_meta["com.example/request_id"] == "abc-123"
258+
259+
260+
@pytest.mark.asyncio
261+
async def test_call_tool_async_with_meta():
262+
"""Test that call_tool_async forwards meta to the MCP server."""
263+
stdio_mcp_client = MCPClient(
264+
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
265+
)
266+
267+
with stdio_mcp_client:
268+
result = await stdio_mcp_client.call_tool_async(
269+
tool_use_id="test-meta-async",
270+
name="echo_meta",
271+
arguments={},
272+
meta={"com.example/request_id": "def-456"},
273+
)
274+
275+
assert result["status"] == "success"
276+
received_meta = json.loads(result["content"][0]["text"])
277+
assert received_meta["com.example/request_id"] == "def-456"
278+
279+
280+
def test_instrumentation_preserves_meta_on_tool_call():
281+
"""Test that OTel instrumentation sets _meta that reaches the MCP server."""
282+
from unittest.mock import MagicMock, patch
283+
284+
# Mock the propagator to always inject a known value, bypassing the need for
285+
# an active span on the background thread where send_request runs
286+
mock_textmap = MagicMock()
287+
mock_textmap.inject = lambda carrier, **kwargs: carrier.update({"traceparent": "00-abc-def-01"})
288+
289+
with patch("opentelemetry.propagate.get_global_textmap", return_value=mock_textmap):
290+
stdio_mcp_client = MCPClient(
291+
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
292+
)
293+
294+
with stdio_mcp_client:
295+
result = stdio_mcp_client.call_tool_sync(
296+
tool_use_id="test-instrumentation",
297+
name="echo_meta",
298+
arguments={},
299+
)
300+
301+
assert result["status"] == "success"
302+
received_meta = json.loads(result["content"][0]["text"])
303+
# OTel instrumentation should have injected _meta with tracing context
304+
assert received_meta is not None
305+
assert isinstance(received_meta, dict)
306+
assert received_meta["traceparent"] == "00-abc-def-01"
307+
308+
241309
@pytest.mark.skipif(
242310
condition=os.environ.get("GITHUB_ACTIONS") == "true",
243311
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",

0 commit comments

Comments
 (0)