diff --git a/mcpgateway/services/export_service.py b/mcpgateway/services/export_service.py index a295b88c45..e955ce45b9 100644 --- a/mcpgateway/services/export_service.py +++ b/mcpgateway/services/export_service.py @@ -32,6 +32,7 @@ from mcpgateway.db import Resource as DbResource from mcpgateway.db import Server as DbServer from mcpgateway.db import Tool as DbTool +from mcpgateway.utils.services_auth import encode_auth # Service singletons are imported lazily in __init__ to avoid circular imports @@ -518,7 +519,8 @@ async def _export_gateways(self, db: Session, tags: Optional[List[str]], include auth_type, auth_value = auth_data_map[gateway.id] if auth_value: gateway_data["auth_type"] = auth_type - gateway_data["auth_value"] = auth_value + # DbGateway.auth_value is JSON (dict); export format expects encoded string. + gateway_data["auth_value"] = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value else: # Auth value is not masked, use as-is gateway_data["auth_type"] = gateway.auth_type @@ -939,7 +941,9 @@ async def _export_selected_gateways(self, db: Session, gateway_ids: List[str], u if db_gateway.auth_type: gateway_data["auth_type"] = db_gateway.auth_type if db_gateway.auth_value: - gateway_data["auth_value"] = db_gateway.auth_value + # DbGateway.auth_value is JSON (dict); export format expects an encoded string. + raw = db_gateway.auth_value + gateway_data["auth_value"] = encode_auth(raw) if isinstance(raw, dict) else raw # Include query param auth if present if db_gateway.auth_type == "query_param" and getattr(db_gateway, "auth_query_params", None): gateway_data["auth_query_params"] = db_gateway.auth_query_params diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index 1a76b945b5..a26bc233a4 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -831,8 +831,7 @@ async def register_gateway( elif hasattr(gateway, "auth_headers") and gateway.auth_headers: # Convert list of {key, value} to dict header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")} - # Keep encoded form for persistence, but pass raw headers for initialization - auth_value = encode_auth(header_dict) # Encode the dict for consistency + auth_value = header_dict # store plain dict, consistent with update path and DB column type authentication_headers = {str(k): str(v) for k, v in header_dict.items()} elif isinstance(auth_value, str) and auth_value: @@ -885,6 +884,11 @@ async def register_gateway( auth_value = None oauth_config = None + # DbTool.auth_value is Mapped[Optional[str]] (Text), so encode the dict before + # storing it there. DbGateway.auth_value is Mapped[Optional[Dict]] (JSON) and + # receives the plain dict directly (see assignment above). + tool_auth_value = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value + tools = [ DbTool( original_name=tool.name, @@ -902,7 +906,7 @@ async def register_gateway( annotations=tool.annotations, jsonpath_filter=tool.jsonpath_filter, auth_type=auth_type, - auth_value=auth_value, + auth_value=tool_auth_value, # Federation metadata created_by=created_by or "system", created_from_ip=created_from_ip, @@ -4132,8 +4136,20 @@ def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGate or existing_tool.jsonpath_filter != tool.jsonpath_filter ) - # Check authentication and visibility changes - auth_fields_changed = existing_tool.auth_type != gateway.auth_type or existing_tool.auth_value != gateway.auth_value or existing_tool.visibility != gateway.visibility + # Check authentication and visibility changes. + # DbTool.auth_value is Text (encoded str); DbGateway.auth_value is JSON (dict). + # encode_auth() uses a random nonce, so comparing ciphertext would always + # differ even when the plaintext hasn't changed. Compare on decoded + # (plaintext) values instead, and only encode on the write path. + # If decoding fails (legacy/corrupt data), fall back to direct comparison. + try: + gateway_auth_plain = gateway.auth_value if isinstance(gateway.auth_value, dict) else (decode_auth(gateway.auth_value) if gateway.auth_value else {}) + existing_tool_auth_plain = decode_auth(existing_tool.auth_value) if existing_tool.auth_value else {} + auth_value_changed = existing_tool_auth_plain != gateway_auth_plain + except Exception: + gateway_tool_auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value + auth_value_changed = existing_tool.auth_value != gateway_tool_auth_value + auth_fields_changed = existing_tool.auth_type != gateway.auth_type or auth_value_changed or existing_tool.visibility != gateway.visibility if basic_fields_changed or schema_fields_changed or auth_fields_changed: fields_to_update = True @@ -4151,7 +4167,7 @@ def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGate existing_tool.output_schema = tool.output_schema existing_tool.jsonpath_filter = tool.jsonpath_filter existing_tool.auth_type = gateway.auth_type - existing_tool.auth_value = gateway.auth_value + existing_tool.auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value existing_tool.visibility = gateway.visibility logger.debug(f"Updated existing tool: {tool.name}") else: diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 973d1f2757..ceed7afd2e 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -2835,7 +2835,8 @@ async def invoke_tool( "name": gateway.name, "url": gateway.url, "auth_type": gateway.auth_type, - "auth_value": gateway.auth_value, + # DbGateway.auth_value is JSON (dict); downstream code expects an encoded str. + "auth_value": encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value, "auth_query_params": gateway.auth_query_params, "oauth_config": gateway.oauth_config, "ca_certificate": gateway.ca_certificate, @@ -3022,7 +3023,9 @@ async def invoke_tool( gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway and isinstance(gateway_payload.get("oauth_config"), dict) else None if has_gateway and gateway is not None: runtime_gateway_auth_value = getattr(gateway, "auth_value", None) - if isinstance(runtime_gateway_auth_value, str): + if isinstance(runtime_gateway_auth_value, dict): + gateway_auth_value = encode_auth(runtime_gateway_auth_value) + elif isinstance(runtime_gateway_auth_value, str): gateway_auth_value = runtime_gateway_auth_value runtime_gateway_query_params = getattr(gateway, "auth_query_params", None) if isinstance(runtime_gateway_query_params, dict): @@ -3053,7 +3056,9 @@ async def invoke_tool( tool_oauth_config = hydrated_tool_oauth_config if has_gateway and tool_auth_row.gateway: hydrated_gateway_auth_value = getattr(tool_auth_row.gateway, "auth_value", None) - if isinstance(hydrated_gateway_auth_value, str): + if isinstance(hydrated_gateway_auth_value, dict): + gateway_auth_value = encode_auth(hydrated_gateway_auth_value) + elif isinstance(hydrated_gateway_auth_value, str): gateway_auth_value = hydrated_gateway_auth_value hydrated_gateway_query_params = getattr(tool_auth_row.gateway, "auth_query_params", None) if isinstance(hydrated_gateway_query_params, dict): diff --git a/tests/unit/mcpgateway/services/test_export_service.py b/tests/unit/mcpgateway/services/test_export_service.py index 37b08d46e1..672fc4b530 100644 --- a/tests/unit/mcpgateway/services/test_export_service.py +++ b/tests/unit/mcpgateway/services/test_export_service.py @@ -1696,3 +1696,81 @@ async def test_export_selected_resources_scoped_no_visible_match_returns_empty(e ) assert exported == [] mock_db.execute.assert_not_called() + + +@pytest.mark.asyncio +async def test_export_gateways_masked_auth_encodes_dict_auth_value(export_service, mock_db): + """Batch-fetched auth_value that is a dict (authheaders) must be encoded before export.""" + # First-Party + from mcpgateway.config import settings + from mcpgateway.utils.services_auth import decode_auth + + auth_dict = {"X-Custom-Auth": "my-token", "X-Org-ID": "org-42"} + + gateway = GatewayRead( + id="gw1", + name="gw-authheaders", + url="https://gw.example.com", + description="gateway with authheaders", + transport="SSE", + capabilities={}, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + enabled=True, + reachable=True, + last_seen=datetime.now(timezone.utc), + auth_type="authheaders", + auth_value=settings.masked_auth_value, + auth_username=None, + auth_password=None, + auth_token=None, + auth_header_key=None, + auth_header_value=None, + tags=[], + slug="gw-authheaders", + passthrough_headers=None, + ) + + export_service.gateway_service.list_gateways.return_value = ([gateway], None) + # DB batch query returns the raw dict (JSON column value for authheaders) + mock_db.execute.return_value.all.return_value = [("gw1", "authheaders", auth_dict)] + + exported = await export_service._export_gateways(mock_db, None, False) + + assert len(exported) == 1 + assert exported[0]["auth_type"] == "authheaders" + # auth_value must be an encoded string, not the raw dict + assert isinstance(exported[0]["auth_value"], str) + assert decode_auth(exported[0]["auth_value"]) == auth_dict + + +@pytest.mark.asyncio +async def test_export_selected_gateways_encodes_dict_auth_value(export_service, mock_db): + """_export_selected_gateways must encode dict auth_value (authheaders) before export.""" + # First-Party + from mcpgateway.utils.services_auth import decode_auth + + auth_dict = {"X-Custom-Auth": "my-token"} + + db_gateway = MagicMock() + db_gateway.id = "gw1" + db_gateway.name = "gw-authheaders" + db_gateway.url = "https://gw.example.com" + db_gateway.description = "desc" + db_gateway.transport = "SSE" + db_gateway.capabilities = {} + db_gateway.is_active = True + db_gateway.tags = [] + db_gateway.passthrough_headers = [] + db_gateway.auth_type = "authheaders" + db_gateway.auth_value = auth_dict # JSON column stores plain dict + db_gateway.auth_query_params = None + + mock_db.execute.return_value.scalars.return_value.all.return_value = [db_gateway] + + exported = await export_service._export_selected_gateways(mock_db, ["gw1"]) + + assert len(exported) == 1 + assert exported[0]["auth_type"] == "authheaders" + assert isinstance(exported[0]["auth_value"], str) + assert decode_auth(exported[0]["auth_value"]) == auth_dict diff --git a/tests/unit/mcpgateway/services/test_gateway_service_helpers.py b/tests/unit/mcpgateway/services/test_gateway_service_helpers.py index 698a96a24f..ccfc265b64 100644 --- a/tests/unit/mcpgateway/services/test_gateway_service_helpers.py +++ b/tests/unit/mcpgateway/services/test_gateway_service_helpers.py @@ -4,6 +4,7 @@ # Standard import tempfile from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, Mock # Third-Party import pytest @@ -12,7 +13,7 @@ from mcpgateway.config import settings from mcpgateway.schemas import GatewayRead from mcpgateway.services.gateway_service import GatewayConnectionError, GatewayNameConflictError, GatewayService, OAuthToolValidationError -from mcpgateway.utils.services_auth import decode_auth +from mcpgateway.utils.services_auth import decode_auth, encode_auth from mcpgateway.validation.tags import validate_tags_field @@ -133,3 +134,168 @@ def test_gateway_service_validate_tools_valueerror(monkeypatch): with pytest.raises(GatewayConnectionError) as excinfo: service._validate_tools([{"name": "tool-other"}]) assert "ValueError" in str(excinfo.value) + + +@pytest.mark.asyncio +async def test_authheaders_auth_value_stored_as_dict(monkeypatch): + """Verify that registering a gateway with authheaders stores auth_value as a plain dict. + + auth_value DB column is Mapped[Optional[Dict[str, str]]] (JSON). Storing a string + in that column causes the driver to write JSON null, which breaks health checks + and the auto-refresh loop. The creation path must store the plain dict, consistent + with the update path and the column type annotation. + """ + # Verify the type contract: encode_auth() returns str, NOT a dict. + # This is why storing its result in a Dict-typed JSON column produces null. + encoded = encode_auth({"X-Key": "value"}) + assert isinstance(encoded, str), "encode_auth must return str — storing it in a dict JSON column yields null" + + # Build a minimal gateway with authheaders + # Standard + from types import SimpleNamespace as NS + + gateway = NS( + name="test-gw", + url="http://localhost:8000/mcp", + description=None, + transport="sse", + tags=[], + passthrough_headers=None, + auth_type="authheaders", + auth_value=None, + auth_headers=[ + {"key": "X-Custom-Auth-Header", "value": "my-token"}, + {"key": "X-Custom-User-ID", "value": "user-123"}, + ], + auth_query_param_key=None, + auth_query_param_value=None, + auth_query_params=None, + oauth_config=None, + one_time_auth=False, + ca_certificate=None, + ca_certificate_sig=None, + signing_algorithm=None, + visibility="public", + enabled=True, + team_id=None, + owner_email=None, + gateway_mode="cache", + ) + + # First-Party + from mcpgateway.schemas import ToolCreate + + fake_tool = ToolCreate(name="echo", integration_type="REST", request_type="POST", url="http://localhost:8000/mcp") + + service = GatewayService() + service._check_gateway_uniqueness = MagicMock(return_value=None) + service._initialize_gateway = AsyncMock(return_value=({"tools": {}}, [fake_tool], [], [])) + service._notify_gateway_added = AsyncMock() + + monkeypatch.setattr("mcpgateway.services.gateway_service.get_for_update", lambda *_a, **_kw: None) + monkeypatch.setattr( + "mcpgateway.services.gateway_service.GatewayRead.model_validate", + lambda x: MagicMock(masked=lambda: x), + ) + + db = MagicMock() + db.flush = Mock() + db.refresh = Mock() + + # Snapshot at db.add() time — tools flow through the gateway relationship (gateway.tools=tools), + # not separate db.add() calls. _prepare_gateway_for_read() later mutates db_gateway.auth_value + # to an encoded string for the GatewayRead response; we capture before that mutation. + # First-Party + from mcpgateway.db import Gateway as DbGateway + + captured_gw: dict = {} + captured_tool_auth_values: list = [] + + def _capture_add(obj): + if isinstance(obj, DbGateway): + captured_gw["auth_value"] = obj.auth_value # snapshot before any mutation + for t in obj.tools or []: + captured_tool_auth_values.append(t.auth_value) + + db.add = Mock(side_effect=_capture_add) + + await service.register_gateway(db, gateway) + + # --- DbGateway assertion --- + # auth_value must be a plain dict — NOT a string. + # A string stored in a Mapped[Optional[Dict[str, str]]] JSON column is written as JSON null. + assert "auth_value" in captured_gw, "db.add was never called with a DbGateway object" + assert isinstance(captured_gw["auth_value"], dict), f"DbGateway.auth_value must be dict for authheaders auth type, got {type(captured_gw['auth_value'])}: {captured_gw['auth_value']!r}" + assert captured_gw["auth_value"] == {"X-Custom-Auth-Header": "my-token", "X-Custom-User-ID": "user-123"} + + # --- DbTool assertion --- + # DbTool.auth_value is Mapped[Optional[str]] (Text), so it must be an encoded string, + # not a raw dict. tool_service.py calls decode_auth() on it at read-time. + assert len(captured_tool_auth_values) == 1, "expected exactly one DbTool to be added" + assert isinstance(captured_tool_auth_values[0], str), f"DbTool.auth_value must be an encoded string for Text column, got {type(captured_tool_auth_values[0])}: {captured_tool_auth_values[0]!r}" + # Decoding must recover the original headers dict + assert decode_auth(captured_tool_auth_values[0]) == {"X-Custom-Auth-Header": "my-token", "X-Custom-User-ID": "user-123"} + + +def test_update_or_create_tools_authheaders_no_spurious_update(): + """Verify _update_or_create_tools does NOT trigger a spurious update when the + gateway's auth_value dict matches the existing tool's encoded auth_value. + + encode_auth() uses os.urandom(12) for the AES-GCM nonce, so comparing + ciphertext would always differ even when the plaintext is identical. The + comparison must use decoded/plaintext values to avoid write amplification + on every health-check refresh cycle. + """ + # Standard + from types import SimpleNamespace as NS + + service = GatewayService() + + auth_dict = {"X-My-Header": "secret-val"} + encoded = encode_auth(auth_dict) + original_encoded = encoded # save for byte-for-byte comparison + + # Existing tool already has the correctly encoded auth_value stored + existing = MagicMock() + existing.original_name = "my-tool" + existing.url = "http://gw.example.com/mcp" + existing.description = "desc" + existing.original_description = "desc" + existing.integration_type = "MCP" + existing.request_type = "POST" + existing.headers = {} + existing.input_schema = {} + existing.output_schema = None + existing.jsonpath_filter = None + existing.auth_type = "authheaders" + existing.auth_value = encoded # Text column — already encoded + existing.visibility = "public" + + db = MagicMock() + db.execute.return_value.scalars.return_value.all.return_value = [existing] + + tool = NS( + name="my-tool", + description="desc", + input_schema={}, + output_schema=None, + request_type="POST", + headers={}, + annotations=None, + jsonpath_filter=None, + ) + + gateway = MagicMock() + gateway.id = "gw-1" + gateway.url = "http://gw.example.com/mcp" + gateway.auth_type = "authheaders" + gateway.auth_value = auth_dict # JSON column — plain dict + gateway.visibility = "public" + + result = service._update_or_create_tools(db, [tool], gateway, "update") + + # No new tools returned + assert result == [] + # auth_value must be the EXACT same string — no spurious re-encryption + assert existing.auth_value is original_encoded, f"auth_value was spuriously rewritten: {existing.auth_value!r} != {original_encoded!r}" + assert decode_auth(existing.auth_value) == auth_dict diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index e59c05b76b..534457c28b 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -2588,6 +2588,73 @@ async def test_invoke_tool_cache_hit_hydrates_auth_material_from_db(self, tool_s assert response.content[0].text == "Invalid tool type" assert test_db.execute.called + @pytest.mark.asyncio + async def test_invoke_tool_cache_hit_hydrates_authheaders_dict_from_db(self, tool_service, test_db): + """Cache-miss hydration encodes DbGateway.auth_value dict (JSON col) to str for downstream use.""" + cached_payload = { + "status": "active", + "tool": { + "id": "tool-cache-2", + "name": "cached_tool_ah", + "original_name": "cached_tool_ah", + "url": "http://example.com/tool", + "integration_type": "ABC", + "request_type": "POST", + "auth_type": "authheaders", + "headers": {}, + "annotations": {}, + "jsonpath_filter": "", + "output_schema": {}, + "enabled": True, + "reachable": True, + "visibility": "public", + "owner_email": None, + "team_id": None, + "gateway_id": "gw-cache-2", + }, + "gateway": { + "id": "gw-cache-2", + "name": "cached-gw-ah", + "url": "http://example.com/gateway", + "auth_type": "authheaders", + "passthrough_headers": [], + }, + } + + lookup_cache = SimpleNamespace( + enabled=True, + get=AsyncMock(return_value=cached_payload), + set=AsyncMock(), + set_negative=AsyncMock(), + ) + + # DbGateway.auth_value is a JSON dict — the path under test encodes it + hydrated_gateway = SimpleNamespace( + auth_value={"X-Custom-Auth-Header": "my-token"}, + auth_query_params=None, + oauth_config=None, + ) + hydrated_tool = SimpleNamespace( + auth_value=None, + oauth_config=None, + gateway=hydrated_gateway, + ) + hydration_result = Mock() + hydration_result.scalar_one_or_none.return_value = hydrated_tool + test_db.execute = Mock(return_value=hydration_result) + + with ( + patch("mcpgateway.services.tool_service._get_tool_lookup_cache", return_value=lookup_cache), + patch("mcpgateway.services.tool_service.global_config_cache.get_passthrough_headers", return_value=[]), + patch("mcpgateway.services.tool_service.encode_auth", wraps=encode_auth) as spy_encode, + ): + response = await tool_service.invoke_tool(test_db, "cached_tool_ah", {"param": "value"}, request_headers=None) + + assert response.content[0].text == "Invalid tool type" + assert test_db.execute.called + # Verify the hydration path actually called encode_auth on the dict + spy_encode.assert_called_once_with({"X-Custom-Auth-Header": "my-token"}) + @pytest.mark.asyncio async def test_invoke_tool_mcp_tool_basic_auth(self, tool_service, mock_tool, mock_gateway, test_db): """Test invoking an invalid tool type."""