Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mcpgateway/services/tool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4734,6 +4734,7 @@ async def invoke_tool(
a2a_agent_auth_type = a2a_agent.auth_type
a2a_agent_auth_value = a2a_agent.auth_value
a2a_agent_auth_query_params = a2a_agent.auth_query_params
a2a_agent_passthrough_headers = a2a_agent.passthrough_headers or []

# ═══════════════════════════════════════════════════════════════════════════
# CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
Expand Down Expand Up @@ -5744,6 +5745,13 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head
# A2A tool invocation using pre-extracted agent data (extracted in Phase 2 before db.close())
headers = {"Content-Type": "application/json"}

if request_headers and a2a_agent_passthrough_headers:
inbound_headers = {k.lower(): v for k, v in request_headers.items()}
for header_name in a2a_agent_passthrough_headers:
header_value = inbound_headers.get(header_name.lower())
if header_value:
headers[header_name] = header_value

# Plugin hook: tool pre-invoke for A2A
plugin_manager = await self._get_plugin_manager(plugin_context_id)
if plugin_manager and plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE) and not skip_pre_invoke:
Expand Down
55 changes: 54 additions & 1 deletion tests/unit/mcpgateway/services/test_tool_service_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7716,7 +7716,7 @@ async def fake_request(method, url, json=None, headers=None):
# ---------------------------------------------------------------------------


def _make_a2a_agent(*, enabled=True, agent_type="jsonrpc", auth_type=None, auth_value=None, auth_query_params=None):
def _make_a2a_agent(*, enabled=True, agent_type="jsonrpc", auth_type=None, auth_value=None, auth_query_params=None, passthrough_headers=None):
"""Create a mock A2A agent."""
agent = MagicMock()
agent.id = "agent-uuid-1"
Expand All @@ -7728,6 +7728,7 @@ def _make_a2a_agent(*, enabled=True, agent_type="jsonrpc", auth_type=None, auth_
agent.auth_type = auth_type
agent.auth_value = auth_value
agent.auth_query_params = auth_query_params
agent.passthrough_headers = passthrough_headers
return agent


Expand Down Expand Up @@ -7812,6 +7813,58 @@ async def fake_post(url, json=None, headers=None):

return await tool_service.invoke_tool(db, "test_tool", {"query": "test"})

@pytest.mark.asyncio
async def test_a2a_invoke_tool_forwards_passthrough_headers(self, tool_service):
"""invoke_tool A2A path forwards whitelisted passthrough headers and drops the rest."""
tp = _make_tool_payload(
integration_type="A2A",
request_type="POST",
annotations={"a2a_agent_id": "agent-uuid-1"},
)
db = MagicMock()
a2a_agent = _make_a2a_agent(passthrough_headers=["x-user-id"])
db.execute = MagicMock(return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=a2a_agent)))

mock_http_response = MagicMock()
mock_http_response.status_code = 200
mock_http_response.json = MagicMock(return_value={"response": "ok"})

captured = {}

async def fake_post(url, json=None, headers=None):
captured["headers"] = headers or {}
return mock_http_response

with (
_setup_cache_for_invoke(tp),
patch.object(tool_service, "_check_tool_access", AsyncMock(return_value=True)),
patch("mcpgateway.services.tool_service.global_config_cache") as mock_gcc,
patch("mcpgateway.services.tool_service.current_trace_id") as mock_trace,
patch("mcpgateway.services.tool_service.create_span") as mock_span_ctx,
patch("mcpgateway.services.metrics_buffer_service.get_metrics_buffer_service") as mock_mbuf,
patch("mcpgateway.services.tool_service.compute_passthrough_headers_cached", return_value={}),
):
mock_gcc.get_passthrough_headers = MagicMock(return_value=[])
mock_trace.get = MagicMock(return_value=None)
mock_span_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock())
mock_span_ctx.return_value.__exit__ = MagicMock(return_value=False)
mock_mbuf.return_value = MagicMock()

tool_service._http_client = AsyncMock()
tool_service._http_client.post = fake_post

result = await tool_service.invoke_tool(
db,
"test_tool",
{"query": "hi"},
request_headers={"X-User-Id": "u1", "X-Secret": "leak"},
)

assert result.is_error is False
forwarded = {k.lower(): v for k, v in captured["headers"].items()}
assert forwarded.get("x-user-id") == "u1"
assert "x-secret" not in forwarded

@pytest.mark.asyncio
async def test_a2a_invoke_tool_response_value_none(self, tool_service):
"""invoke_tool A2A path: response['response'] is None -> serialized as 'null'."""
Expand Down