Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions mcpgateway/services/export_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from mcpgateway.db import Resource as DbResource
from mcpgateway.db import Server as DbServer
from mcpgateway.db import Tool as DbTool
from mcpgateway.utils.services_auth import encode_auth

# Service singletons are imported lazily in __init__ to avoid circular imports

Expand Down Expand Up @@ -518,7 +519,8 @@ async def _export_gateways(self, db: Session, tags: Optional[List[str]], include
auth_type, auth_value = auth_data_map[gateway.id]
if auth_value:
gateway_data["auth_type"] = auth_type
gateway_data["auth_value"] = auth_value
# DbGateway.auth_value is JSON (dict); export format expects encoded string.
gateway_data["auth_value"] = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value
else:
# Auth value is not masked, use as-is
gateway_data["auth_type"] = gateway.auth_type
Expand Down Expand Up @@ -939,7 +941,9 @@ async def _export_selected_gateways(self, db: Session, gateway_ids: List[str], u
if db_gateway.auth_type:
gateway_data["auth_type"] = db_gateway.auth_type
if db_gateway.auth_value:
gateway_data["auth_value"] = db_gateway.auth_value
# DbGateway.auth_value is JSON (dict); export format expects an encoded string.
raw = db_gateway.auth_value
gateway_data["auth_value"] = encode_auth(raw) if isinstance(raw, dict) else raw
# Include query param auth if present
if db_gateway.auth_type == "query_param" and getattr(db_gateway, "auth_query_params", None):
gateway_data["auth_query_params"] = db_gateway.auth_query_params
Expand Down
28 changes: 22 additions & 6 deletions mcpgateway/services/gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,8 +831,7 @@ async def register_gateway(
elif hasattr(gateway, "auth_headers") and gateway.auth_headers:
# Convert list of {key, value} to dict
header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
# Keep encoded form for persistence, but pass raw headers for initialization
auth_value = encode_auth(header_dict) # Encode the dict for consistency
auth_value = header_dict # store plain dict, consistent with update path and DB column type
authentication_headers = {str(k): str(v) for k, v in header_dict.items()}

elif isinstance(auth_value, str) and auth_value:
Expand Down Expand Up @@ -885,6 +884,11 @@ async def register_gateway(
auth_value = None
oauth_config = None

# DbTool.auth_value is Mapped[Optional[str]] (Text), so encode the dict before
# storing it there. DbGateway.auth_value is Mapped[Optional[Dict]] (JSON) and
# receives the plain dict directly (see assignment above).
tool_auth_value = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value

tools = [
DbTool(
original_name=tool.name,
Expand All @@ -902,7 +906,7 @@ async def register_gateway(
annotations=tool.annotations,
jsonpath_filter=tool.jsonpath_filter,
auth_type=auth_type,
auth_value=auth_value,
auth_value=tool_auth_value,
# Federation metadata
created_by=created_by or "system",
created_from_ip=created_from_ip,
Expand Down Expand Up @@ -4132,8 +4136,20 @@ def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGate
or existing_tool.jsonpath_filter != tool.jsonpath_filter
)

# Check authentication and visibility changes
auth_fields_changed = existing_tool.auth_type != gateway.auth_type or existing_tool.auth_value != gateway.auth_value or existing_tool.visibility != gateway.visibility
# Check authentication and visibility changes.
# DbTool.auth_value is Text (encoded str); DbGateway.auth_value is JSON (dict).
# encode_auth() uses a random nonce, so comparing ciphertext would always
# differ even when the plaintext hasn't changed. Compare on decoded
# (plaintext) values instead, and only encode on the write path.
# If decoding fails (legacy/corrupt data), fall back to direct comparison.
try:
gateway_auth_plain = gateway.auth_value if isinstance(gateway.auth_value, dict) else (decode_auth(gateway.auth_value) if gateway.auth_value else {})
existing_tool_auth_plain = decode_auth(existing_tool.auth_value) if existing_tool.auth_value else {}
auth_value_changed = existing_tool_auth_plain != gateway_auth_plain
except Exception:
gateway_tool_auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
auth_value_changed = existing_tool.auth_value != gateway_tool_auth_value
auth_fields_changed = existing_tool.auth_type != gateway.auth_type or auth_value_changed or existing_tool.visibility != gateway.visibility

if basic_fields_changed or schema_fields_changed or auth_fields_changed:
fields_to_update = True
Expand All @@ -4151,7 +4167,7 @@ def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGate
existing_tool.output_schema = tool.output_schema
existing_tool.jsonpath_filter = tool.jsonpath_filter
existing_tool.auth_type = gateway.auth_type
existing_tool.auth_value = gateway.auth_value
existing_tool.auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
existing_tool.visibility = gateway.visibility
logger.debug(f"Updated existing tool: {tool.name}")
else:
Expand Down
11 changes: 8 additions & 3 deletions mcpgateway/services/tool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2835,7 +2835,8 @@ async def invoke_tool(
"name": gateway.name,
"url": gateway.url,
"auth_type": gateway.auth_type,
"auth_value": gateway.auth_value,
# DbGateway.auth_value is JSON (dict); downstream code expects an encoded str.
"auth_value": encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
"auth_query_params": gateway.auth_query_params,
"oauth_config": gateway.oauth_config,
"ca_certificate": gateway.ca_certificate,
Expand Down Expand Up @@ -3022,7 +3023,9 @@ async def invoke_tool(
gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway and isinstance(gateway_payload.get("oauth_config"), dict) else None
if has_gateway and gateway is not None:
runtime_gateway_auth_value = getattr(gateway, "auth_value", None)
if isinstance(runtime_gateway_auth_value, str):
if isinstance(runtime_gateway_auth_value, dict):
gateway_auth_value = encode_auth(runtime_gateway_auth_value)
elif isinstance(runtime_gateway_auth_value, str):
gateway_auth_value = runtime_gateway_auth_value
runtime_gateway_query_params = getattr(gateway, "auth_query_params", None)
if isinstance(runtime_gateway_query_params, dict):
Expand Down Expand Up @@ -3053,7 +3056,9 @@ async def invoke_tool(
tool_oauth_config = hydrated_tool_oauth_config
if has_gateway and tool_auth_row.gateway:
hydrated_gateway_auth_value = getattr(tool_auth_row.gateway, "auth_value", None)
if isinstance(hydrated_gateway_auth_value, str):
if isinstance(hydrated_gateway_auth_value, dict):
gateway_auth_value = encode_auth(hydrated_gateway_auth_value)
elif isinstance(hydrated_gateway_auth_value, str):
gateway_auth_value = hydrated_gateway_auth_value
hydrated_gateway_query_params = getattr(tool_auth_row.gateway, "auth_query_params", None)
if isinstance(hydrated_gateway_query_params, dict):
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/mcpgateway/services/test_export_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,3 +1696,81 @@ async def test_export_selected_resources_scoped_no_visible_match_returns_empty(e
)
assert exported == []
mock_db.execute.assert_not_called()


@pytest.mark.asyncio
async def test_export_gateways_masked_auth_encodes_dict_auth_value(export_service, mock_db):
"""Batch-fetched auth_value that is a dict (authheaders) must be encoded before export."""
# First-Party
from mcpgateway.config import settings
from mcpgateway.utils.services_auth import decode_auth

auth_dict = {"X-Custom-Auth": "my-token", "X-Org-ID": "org-42"}

gateway = GatewayRead(
id="gw1",
name="gw-authheaders",
url="https://gw.example.com",
description="gateway with authheaders",
transport="SSE",
capabilities={},
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
enabled=True,
reachable=True,
last_seen=datetime.now(timezone.utc),
auth_type="authheaders",
auth_value=settings.masked_auth_value,
auth_username=None,
auth_password=None,
auth_token=None,
auth_header_key=None,
auth_header_value=None,
tags=[],
slug="gw-authheaders",
passthrough_headers=None,
)

export_service.gateway_service.list_gateways.return_value = ([gateway], None)
# DB batch query returns the raw dict (JSON column value for authheaders)
mock_db.execute.return_value.all.return_value = [("gw1", "authheaders", auth_dict)]

exported = await export_service._export_gateways(mock_db, None, False)

assert len(exported) == 1
assert exported[0]["auth_type"] == "authheaders"
# auth_value must be an encoded string, not the raw dict
assert isinstance(exported[0]["auth_value"], str)
assert decode_auth(exported[0]["auth_value"]) == auth_dict


@pytest.mark.asyncio
async def test_export_selected_gateways_encodes_dict_auth_value(export_service, mock_db):
"""_export_selected_gateways must encode dict auth_value (authheaders) before export."""
# First-Party
from mcpgateway.utils.services_auth import decode_auth

auth_dict = {"X-Custom-Auth": "my-token"}

db_gateway = MagicMock()
db_gateway.id = "gw1"
db_gateway.name = "gw-authheaders"
db_gateway.url = "https://gw.example.com"
db_gateway.description = "desc"
db_gateway.transport = "SSE"
db_gateway.capabilities = {}
db_gateway.is_active = True
db_gateway.tags = []
db_gateway.passthrough_headers = []
db_gateway.auth_type = "authheaders"
db_gateway.auth_value = auth_dict # JSON column stores plain dict
db_gateway.auth_query_params = None

mock_db.execute.return_value.scalars.return_value.all.return_value = [db_gateway]

exported = await export_service._export_selected_gateways(mock_db, ["gw1"])

assert len(exported) == 1
assert exported[0]["auth_type"] == "authheaders"
assert isinstance(exported[0]["auth_value"], str)
assert decode_auth(exported[0]["auth_value"]) == auth_dict
168 changes: 167 additions & 1 deletion tests/unit/mcpgateway/services/test_gateway_service_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Standard
import tempfile
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, Mock

# Third-Party
import pytest
Expand All @@ -12,7 +13,7 @@
from mcpgateway.config import settings
from mcpgateway.schemas import GatewayRead
from mcpgateway.services.gateway_service import GatewayConnectionError, GatewayNameConflictError, GatewayService, OAuthToolValidationError
from mcpgateway.utils.services_auth import decode_auth
from mcpgateway.utils.services_auth import decode_auth, encode_auth
from mcpgateway.validation.tags import validate_tags_field


Expand Down Expand Up @@ -133,3 +134,168 @@ def test_gateway_service_validate_tools_valueerror(monkeypatch):
with pytest.raises(GatewayConnectionError) as excinfo:
service._validate_tools([{"name": "tool-other"}])
assert "ValueError" in str(excinfo.value)


@pytest.mark.asyncio
async def test_authheaders_auth_value_stored_as_dict(monkeypatch):
"""Verify that registering a gateway with authheaders stores auth_value as a plain dict.

auth_value DB column is Mapped[Optional[Dict[str, str]]] (JSON). Storing a string
in that column causes the driver to write JSON null, which breaks health checks
and the auto-refresh loop. The creation path must store the plain dict, consistent
with the update path and the column type annotation.
"""
# Verify the type contract: encode_auth() returns str, NOT a dict.
# This is why storing its result in a Dict-typed JSON column produces null.
encoded = encode_auth({"X-Key": "value"})
assert isinstance(encoded, str), "encode_auth must return str β€” storing it in a dict JSON column yields null"

# Build a minimal gateway with authheaders
# Standard
from types import SimpleNamespace as NS

gateway = NS(
name="test-gw",
url="http://localhost:8000/mcp",
description=None,
transport="sse",
tags=[],
passthrough_headers=None,
auth_type="authheaders",
auth_value=None,
auth_headers=[
{"key": "X-Custom-Auth-Header", "value": "my-token"},
{"key": "X-Custom-User-ID", "value": "user-123"},
],
auth_query_param_key=None,
auth_query_param_value=None,
auth_query_params=None,
oauth_config=None,
one_time_auth=False,
ca_certificate=None,
ca_certificate_sig=None,
signing_algorithm=None,
visibility="public",
enabled=True,
team_id=None,
owner_email=None,
gateway_mode="cache",
)

# First-Party
from mcpgateway.schemas import ToolCreate

fake_tool = ToolCreate(name="echo", integration_type="REST", request_type="POST", url="http://localhost:8000/mcp")

service = GatewayService()
service._check_gateway_uniqueness = MagicMock(return_value=None)
service._initialize_gateway = AsyncMock(return_value=({"tools": {}}, [fake_tool], [], []))
service._notify_gateway_added = AsyncMock()

monkeypatch.setattr("mcpgateway.services.gateway_service.get_for_update", lambda *_a, **_kw: None)
monkeypatch.setattr(
"mcpgateway.services.gateway_service.GatewayRead.model_validate",
lambda x: MagicMock(masked=lambda: x),
)

db = MagicMock()
db.flush = Mock()
db.refresh = Mock()

# Snapshot at db.add() time β€” tools flow through the gateway relationship (gateway.tools=tools),
# not separate db.add() calls. _prepare_gateway_for_read() later mutates db_gateway.auth_value
# to an encoded string for the GatewayRead response; we capture before that mutation.
# First-Party
from mcpgateway.db import Gateway as DbGateway

captured_gw: dict = {}
captured_tool_auth_values: list = []

def _capture_add(obj):
if isinstance(obj, DbGateway):
captured_gw["auth_value"] = obj.auth_value # snapshot before any mutation
for t in obj.tools or []:
captured_tool_auth_values.append(t.auth_value)

db.add = Mock(side_effect=_capture_add)

await service.register_gateway(db, gateway)

# --- DbGateway assertion ---
# auth_value must be a plain dict β€” NOT a string.
# A string stored in a Mapped[Optional[Dict[str, str]]] JSON column is written as JSON null.
assert "auth_value" in captured_gw, "db.add was never called with a DbGateway object"
assert isinstance(captured_gw["auth_value"], dict), f"DbGateway.auth_value must be dict for authheaders auth type, got {type(captured_gw['auth_value'])}: {captured_gw['auth_value']!r}"
assert captured_gw["auth_value"] == {"X-Custom-Auth-Header": "my-token", "X-Custom-User-ID": "user-123"}

# --- DbTool assertion ---
# DbTool.auth_value is Mapped[Optional[str]] (Text), so it must be an encoded string,
# not a raw dict. tool_service.py calls decode_auth() on it at read-time.
assert len(captured_tool_auth_values) == 1, "expected exactly one DbTool to be added"
assert isinstance(captured_tool_auth_values[0], str), f"DbTool.auth_value must be an encoded string for Text column, got {type(captured_tool_auth_values[0])}: {captured_tool_auth_values[0]!r}"
# Decoding must recover the original headers dict
assert decode_auth(captured_tool_auth_values[0]) == {"X-Custom-Auth-Header": "my-token", "X-Custom-User-ID": "user-123"}


def test_update_or_create_tools_authheaders_no_spurious_update():
"""Verify _update_or_create_tools does NOT trigger a spurious update when the
gateway's auth_value dict matches the existing tool's encoded auth_value.

encode_auth() uses os.urandom(12) for the AES-GCM nonce, so comparing
ciphertext would always differ even when the plaintext is identical. The
comparison must use decoded/plaintext values to avoid write amplification
on every health-check refresh cycle.
"""
# Standard
from types import SimpleNamespace as NS

service = GatewayService()

auth_dict = {"X-My-Header": "secret-val"}
encoded = encode_auth(auth_dict)
original_encoded = encoded # save for byte-for-byte comparison

# Existing tool already has the correctly encoded auth_value stored
existing = MagicMock()
existing.original_name = "my-tool"
existing.url = "http://gw.example.com/mcp"
existing.description = "desc"
existing.original_description = "desc"
existing.integration_type = "MCP"
existing.request_type = "POST"
existing.headers = {}
existing.input_schema = {}
existing.output_schema = None
existing.jsonpath_filter = None
existing.auth_type = "authheaders"
existing.auth_value = encoded # Text column β€” already encoded
existing.visibility = "public"

db = MagicMock()
db.execute.return_value.scalars.return_value.all.return_value = [existing]

tool = NS(
name="my-tool",
description="desc",
input_schema={},
output_schema=None,
request_type="POST",
headers={},
annotations=None,
jsonpath_filter=None,
)

gateway = MagicMock()
gateway.id = "gw-1"
gateway.url = "http://gw.example.com/mcp"
gateway.auth_type = "authheaders"
gateway.auth_value = auth_dict # JSON column β€” plain dict
gateway.visibility = "public"

result = service._update_or_create_tools(db, [tool], gateway, "update")

# No new tools returned
assert result == []
# auth_value must be the EXACT same string β€” no spurious re-encryption
assert existing.auth_value is original_encoded, f"auth_value was spuriously rewritten: {existing.auth_value!r} != {original_encoded!r}"
assert decode_auth(existing.auth_value) == auth_dict
Loading
Loading