Skip to content

Commit 471541a

Browse files
[FIX][API]: store authheadersauth_value as dict to prevent JSON null on persist (#3510)
* store authheaders auth_value as dict instead of encoded string Signed-off-by: Shoumi <shoumimukherjee@gmail.com> * additional fixes for tool Signed-off-by: Shoumi <shoumimukherjee@gmail.com> * encode DbGateway.auth_value dict for DbTool/export/tool-invoke paths Signed-off-by: Shoumi <shoumimukherjee@gmail.com> * fix coverage Signed-off-by: Shoumi <shoumimukherjee@gmail.com> * fix: isort import order and add authheaders dict export tests Fix import ordering in export_service.py (isort), apply black formatting to test assertions, and add two regression tests covering the dict→encode_auth branch in both export gateway paths. Signed-off-by: Mihai Criveti <crivetimihai@gmail.com> * fix: compare plaintext auth values to avoid spurious tool updates encode_auth() uses os.urandom(12) for AES-GCM nonce, so each call produces different ciphertext even for identical plaintext. Comparing encoded values in _update_or_create_tools() caused every health-check refresh to detect a false auth change and rewrite all tools. Fix: decode existing tool auth_value and compare against plaintext gateway dict. Only encode on the actual write path when auth content truly changed. Also strengthen the test to assert byte-for-byte identity of existing auth_value (no spurious re-encryption). Signed-off-by: Mihai Criveti <crivetimihai@gmail.com> * test: assert encode_auth is called in hydration path The test previously relied on "Invalid tool type" exit without verifying the dict→string encoding actually happened. Add spy on encode_auth to assert it is called with the expected dict during cache-miss hydration. Signed-off-by: Mihai Criveti <crivetimihai@gmail.com> --------- Signed-off-by: Shoumi <shoumimukherjee@gmail.com> Signed-off-by: Mihai Criveti <crivetimihai@gmail.com> Co-authored-by: Mihai Criveti <crivetimihai@gmail.com>
1 parent 89cc891 commit 471541a

File tree

6 files changed

+348
-12
lines changed

6 files changed

+348
-12
lines changed

mcpgateway/services/export_service.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from mcpgateway.db import Resource as DbResource
3333
from mcpgateway.db import Server as DbServer
3434
from mcpgateway.db import Tool as DbTool
35+
from mcpgateway.utils.services_auth import encode_auth
3536

3637
# Service singletons are imported lazily in __init__ to avoid circular imports
3738

@@ -518,7 +519,8 @@ async def _export_gateways(self, db: Session, tags: Optional[List[str]], include
518519
auth_type, auth_value = auth_data_map[gateway.id]
519520
if auth_value:
520521
gateway_data["auth_type"] = auth_type
521-
gateway_data["auth_value"] = auth_value
522+
# DbGateway.auth_value is JSON (dict); export format expects encoded string.
523+
gateway_data["auth_value"] = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value
522524
else:
523525
# Auth value is not masked, use as-is
524526
gateway_data["auth_type"] = gateway.auth_type
@@ -939,7 +941,9 @@ async def _export_selected_gateways(self, db: Session, gateway_ids: List[str], u
939941
if db_gateway.auth_type:
940942
gateway_data["auth_type"] = db_gateway.auth_type
941943
if db_gateway.auth_value:
942-
gateway_data["auth_value"] = db_gateway.auth_value
944+
# DbGateway.auth_value is JSON (dict); export format expects an encoded string.
945+
raw = db_gateway.auth_value
946+
gateway_data["auth_value"] = encode_auth(raw) if isinstance(raw, dict) else raw
943947
# Include query param auth if present
944948
if db_gateway.auth_type == "query_param" and getattr(db_gateway, "auth_query_params", None):
945949
gateway_data["auth_query_params"] = db_gateway.auth_query_params

mcpgateway/services/gateway_service.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -831,8 +831,7 @@ async def register_gateway(
831831
elif hasattr(gateway, "auth_headers") and gateway.auth_headers:
832832
# Convert list of {key, value} to dict
833833
header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
834-
# Keep encoded form for persistence, but pass raw headers for initialization
835-
auth_value = encode_auth(header_dict) # Encode the dict for consistency
834+
auth_value = header_dict # store plain dict, consistent with update path and DB column type
836835
authentication_headers = {str(k): str(v) for k, v in header_dict.items()}
837836

838837
elif isinstance(auth_value, str) and auth_value:
@@ -885,6 +884,11 @@ async def register_gateway(
885884
auth_value = None
886885
oauth_config = None
887886

887+
# DbTool.auth_value is Mapped[Optional[str]] (Text), so encode the dict before
888+
# storing it there. DbGateway.auth_value is Mapped[Optional[Dict]] (JSON) and
889+
# receives the plain dict directly (see assignment above).
890+
tool_auth_value = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value
891+
888892
tools = [
889893
DbTool(
890894
original_name=tool.name,
@@ -902,7 +906,7 @@ async def register_gateway(
902906
annotations=tool.annotations,
903907
jsonpath_filter=tool.jsonpath_filter,
904908
auth_type=auth_type,
905-
auth_value=auth_value,
909+
auth_value=tool_auth_value,
906910
# Federation metadata
907911
created_by=created_by or "system",
908912
created_from_ip=created_from_ip,
@@ -4132,8 +4136,20 @@ def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGate
41324136
or existing_tool.jsonpath_filter != tool.jsonpath_filter
41334137
)
41344138

4135-
# Check authentication and visibility changes
4136-
auth_fields_changed = existing_tool.auth_type != gateway.auth_type or existing_tool.auth_value != gateway.auth_value or existing_tool.visibility != gateway.visibility
4139+
# Check authentication and visibility changes.
4140+
# DbTool.auth_value is Text (encoded str); DbGateway.auth_value is JSON (dict).
4141+
# encode_auth() uses a random nonce, so comparing ciphertext would always
4142+
# differ even when the plaintext hasn't changed. Compare on decoded
4143+
# (plaintext) values instead, and only encode on the write path.
4144+
# If decoding fails (legacy/corrupt data), fall back to direct comparison.
4145+
try:
4146+
gateway_auth_plain = gateway.auth_value if isinstance(gateway.auth_value, dict) else (decode_auth(gateway.auth_value) if gateway.auth_value else {})
4147+
existing_tool_auth_plain = decode_auth(existing_tool.auth_value) if existing_tool.auth_value else {}
4148+
auth_value_changed = existing_tool_auth_plain != gateway_auth_plain
4149+
except Exception:
4150+
gateway_tool_auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
4151+
auth_value_changed = existing_tool.auth_value != gateway_tool_auth_value
4152+
auth_fields_changed = existing_tool.auth_type != gateway.auth_type or auth_value_changed or existing_tool.visibility != gateway.visibility
41374153

41384154
if basic_fields_changed or schema_fields_changed or auth_fields_changed:
41394155
fields_to_update = True
@@ -4151,7 +4167,7 @@ def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGate
41514167
existing_tool.output_schema = tool.output_schema
41524168
existing_tool.jsonpath_filter = tool.jsonpath_filter
41534169
existing_tool.auth_type = gateway.auth_type
4154-
existing_tool.auth_value = gateway.auth_value
4170+
existing_tool.auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
41554171
existing_tool.visibility = gateway.visibility
41564172
logger.debug(f"Updated existing tool: {tool.name}")
41574173
else:

mcpgateway/services/tool_service.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,7 +2835,8 @@ async def invoke_tool(
28352835
"name": gateway.name,
28362836
"url": gateway.url,
28372837
"auth_type": gateway.auth_type,
2838-
"auth_value": gateway.auth_value,
2838+
# DbGateway.auth_value is JSON (dict); downstream code expects an encoded str.
2839+
"auth_value": encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
28392840
"auth_query_params": gateway.auth_query_params,
28402841
"oauth_config": gateway.oauth_config,
28412842
"ca_certificate": gateway.ca_certificate,
@@ -3022,7 +3023,9 @@ async def invoke_tool(
30223023
gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway and isinstance(gateway_payload.get("oauth_config"), dict) else None
30233024
if has_gateway and gateway is not None:
30243025
runtime_gateway_auth_value = getattr(gateway, "auth_value", None)
3025-
if isinstance(runtime_gateway_auth_value, str):
3026+
if isinstance(runtime_gateway_auth_value, dict):
3027+
gateway_auth_value = encode_auth(runtime_gateway_auth_value)
3028+
elif isinstance(runtime_gateway_auth_value, str):
30263029
gateway_auth_value = runtime_gateway_auth_value
30273030
runtime_gateway_query_params = getattr(gateway, "auth_query_params", None)
30283031
if isinstance(runtime_gateway_query_params, dict):
@@ -3053,7 +3056,9 @@ async def invoke_tool(
30533056
tool_oauth_config = hydrated_tool_oauth_config
30543057
if has_gateway and tool_auth_row.gateway:
30553058
hydrated_gateway_auth_value = getattr(tool_auth_row.gateway, "auth_value", None)
3056-
if isinstance(hydrated_gateway_auth_value, str):
3059+
if isinstance(hydrated_gateway_auth_value, dict):
3060+
gateway_auth_value = encode_auth(hydrated_gateway_auth_value)
3061+
elif isinstance(hydrated_gateway_auth_value, str):
30573062
gateway_auth_value = hydrated_gateway_auth_value
30583063
hydrated_gateway_query_params = getattr(tool_auth_row.gateway, "auth_query_params", None)
30593064
if isinstance(hydrated_gateway_query_params, dict):

tests/unit/mcpgateway/services/test_export_service.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,3 +1696,81 @@ async def test_export_selected_resources_scoped_no_visible_match_returns_empty(e
16961696
)
16971697
assert exported == []
16981698
mock_db.execute.assert_not_called()
1699+
1700+
1701+
@pytest.mark.asyncio
1702+
async def test_export_gateways_masked_auth_encodes_dict_auth_value(export_service, mock_db):
1703+
"""Batch-fetched auth_value that is a dict (authheaders) must be encoded before export."""
1704+
# First-Party
1705+
from mcpgateway.config import settings
1706+
from mcpgateway.utils.services_auth import decode_auth
1707+
1708+
auth_dict = {"X-Custom-Auth": "my-token", "X-Org-ID": "org-42"}
1709+
1710+
gateway = GatewayRead(
1711+
id="gw1",
1712+
name="gw-authheaders",
1713+
url="https://gw.example.com",
1714+
description="gateway with authheaders",
1715+
transport="SSE",
1716+
capabilities={},
1717+
created_at=datetime.now(timezone.utc),
1718+
updated_at=datetime.now(timezone.utc),
1719+
enabled=True,
1720+
reachable=True,
1721+
last_seen=datetime.now(timezone.utc),
1722+
auth_type="authheaders",
1723+
auth_value=settings.masked_auth_value,
1724+
auth_username=None,
1725+
auth_password=None,
1726+
auth_token=None,
1727+
auth_header_key=None,
1728+
auth_header_value=None,
1729+
tags=[],
1730+
slug="gw-authheaders",
1731+
passthrough_headers=None,
1732+
)
1733+
1734+
export_service.gateway_service.list_gateways.return_value = ([gateway], None)
1735+
# DB batch query returns the raw dict (JSON column value for authheaders)
1736+
mock_db.execute.return_value.all.return_value = [("gw1", "authheaders", auth_dict)]
1737+
1738+
exported = await export_service._export_gateways(mock_db, None, False)
1739+
1740+
assert len(exported) == 1
1741+
assert exported[0]["auth_type"] == "authheaders"
1742+
# auth_value must be an encoded string, not the raw dict
1743+
assert isinstance(exported[0]["auth_value"], str)
1744+
assert decode_auth(exported[0]["auth_value"]) == auth_dict
1745+
1746+
1747+
@pytest.mark.asyncio
1748+
async def test_export_selected_gateways_encodes_dict_auth_value(export_service, mock_db):
1749+
"""_export_selected_gateways must encode dict auth_value (authheaders) before export."""
1750+
# First-Party
1751+
from mcpgateway.utils.services_auth import decode_auth
1752+
1753+
auth_dict = {"X-Custom-Auth": "my-token"}
1754+
1755+
db_gateway = MagicMock()
1756+
db_gateway.id = "gw1"
1757+
db_gateway.name = "gw-authheaders"
1758+
db_gateway.url = "https://gw.example.com"
1759+
db_gateway.description = "desc"
1760+
db_gateway.transport = "SSE"
1761+
db_gateway.capabilities = {}
1762+
db_gateway.is_active = True
1763+
db_gateway.tags = []
1764+
db_gateway.passthrough_headers = []
1765+
db_gateway.auth_type = "authheaders"
1766+
db_gateway.auth_value = auth_dict # JSON column stores plain dict
1767+
db_gateway.auth_query_params = None
1768+
1769+
mock_db.execute.return_value.scalars.return_value.all.return_value = [db_gateway]
1770+
1771+
exported = await export_service._export_selected_gateways(mock_db, ["gw1"])
1772+
1773+
assert len(exported) == 1
1774+
assert exported[0]["auth_type"] == "authheaders"
1775+
assert isinstance(exported[0]["auth_value"], str)
1776+
assert decode_auth(exported[0]["auth_value"]) == auth_dict

tests/unit/mcpgateway/services/test_gateway_service_helpers.py

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Standard
55
import tempfile
66
from types import SimpleNamespace
7+
from unittest.mock import AsyncMock, MagicMock, Mock
78

89
# Third-Party
910
import pytest
@@ -12,7 +13,7 @@
1213
from mcpgateway.config import settings
1314
from mcpgateway.schemas import GatewayRead
1415
from mcpgateway.services.gateway_service import GatewayConnectionError, GatewayNameConflictError, GatewayService, OAuthToolValidationError
15-
from mcpgateway.utils.services_auth import decode_auth
16+
from mcpgateway.utils.services_auth import decode_auth, encode_auth
1617
from mcpgateway.validation.tags import validate_tags_field
1718

1819

@@ -133,3 +134,168 @@ def test_gateway_service_validate_tools_valueerror(monkeypatch):
133134
with pytest.raises(GatewayConnectionError) as excinfo:
134135
service._validate_tools([{"name": "tool-other"}])
135136
assert "ValueError" in str(excinfo.value)
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_authheaders_auth_value_stored_as_dict(monkeypatch):
141+
"""Verify that registering a gateway with authheaders stores auth_value as a plain dict.
142+
143+
auth_value DB column is Mapped[Optional[Dict[str, str]]] (JSON). Storing a string
144+
in that column causes the driver to write JSON null, which breaks health checks
145+
and the auto-refresh loop. The creation path must store the plain dict, consistent
146+
with the update path and the column type annotation.
147+
"""
148+
# Verify the type contract: encode_auth() returns str, NOT a dict.
149+
# This is why storing its result in a Dict-typed JSON column produces null.
150+
encoded = encode_auth({"X-Key": "value"})
151+
assert isinstance(encoded, str), "encode_auth must return str — storing it in a dict JSON column yields null"
152+
153+
# Build a minimal gateway with authheaders
154+
# Standard
155+
from types import SimpleNamespace as NS
156+
157+
gateway = NS(
158+
name="test-gw",
159+
url="http://localhost:8000/mcp",
160+
description=None,
161+
transport="sse",
162+
tags=[],
163+
passthrough_headers=None,
164+
auth_type="authheaders",
165+
auth_value=None,
166+
auth_headers=[
167+
{"key": "X-Custom-Auth-Header", "value": "my-token"},
168+
{"key": "X-Custom-User-ID", "value": "user-123"},
169+
],
170+
auth_query_param_key=None,
171+
auth_query_param_value=None,
172+
auth_query_params=None,
173+
oauth_config=None,
174+
one_time_auth=False,
175+
ca_certificate=None,
176+
ca_certificate_sig=None,
177+
signing_algorithm=None,
178+
visibility="public",
179+
enabled=True,
180+
team_id=None,
181+
owner_email=None,
182+
gateway_mode="cache",
183+
)
184+
185+
# First-Party
186+
from mcpgateway.schemas import ToolCreate
187+
188+
fake_tool = ToolCreate(name="echo", integration_type="REST", request_type="POST", url="http://localhost:8000/mcp")
189+
190+
service = GatewayService()
191+
service._check_gateway_uniqueness = MagicMock(return_value=None)
192+
service._initialize_gateway = AsyncMock(return_value=({"tools": {}}, [fake_tool], [], []))
193+
service._notify_gateway_added = AsyncMock()
194+
195+
monkeypatch.setattr("mcpgateway.services.gateway_service.get_for_update", lambda *_a, **_kw: None)
196+
monkeypatch.setattr(
197+
"mcpgateway.services.gateway_service.GatewayRead.model_validate",
198+
lambda x: MagicMock(masked=lambda: x),
199+
)
200+
201+
db = MagicMock()
202+
db.flush = Mock()
203+
db.refresh = Mock()
204+
205+
# Snapshot at db.add() time — tools flow through the gateway relationship (gateway.tools=tools),
206+
# not separate db.add() calls. _prepare_gateway_for_read() later mutates db_gateway.auth_value
207+
# to an encoded string for the GatewayRead response; we capture before that mutation.
208+
# First-Party
209+
from mcpgateway.db import Gateway as DbGateway
210+
211+
captured_gw: dict = {}
212+
captured_tool_auth_values: list = []
213+
214+
def _capture_add(obj):
215+
if isinstance(obj, DbGateway):
216+
captured_gw["auth_value"] = obj.auth_value # snapshot before any mutation
217+
for t in obj.tools or []:
218+
captured_tool_auth_values.append(t.auth_value)
219+
220+
db.add = Mock(side_effect=_capture_add)
221+
222+
await service.register_gateway(db, gateway)
223+
224+
# --- DbGateway assertion ---
225+
# auth_value must be a plain dict — NOT a string.
226+
# A string stored in a Mapped[Optional[Dict[str, str]]] JSON column is written as JSON null.
227+
assert "auth_value" in captured_gw, "db.add was never called with a DbGateway object"
228+
assert isinstance(captured_gw["auth_value"], dict), f"DbGateway.auth_value must be dict for authheaders auth type, got {type(captured_gw['auth_value'])}: {captured_gw['auth_value']!r}"
229+
assert captured_gw["auth_value"] == {"X-Custom-Auth-Header": "my-token", "X-Custom-User-ID": "user-123"}
230+
231+
# --- DbTool assertion ---
232+
# DbTool.auth_value is Mapped[Optional[str]] (Text), so it must be an encoded string,
233+
# not a raw dict. tool_service.py calls decode_auth() on it at read-time.
234+
assert len(captured_tool_auth_values) == 1, "expected exactly one DbTool to be added"
235+
assert isinstance(captured_tool_auth_values[0], str), f"DbTool.auth_value must be an encoded string for Text column, got {type(captured_tool_auth_values[0])}: {captured_tool_auth_values[0]!r}"
236+
# Decoding must recover the original headers dict
237+
assert decode_auth(captured_tool_auth_values[0]) == {"X-Custom-Auth-Header": "my-token", "X-Custom-User-ID": "user-123"}
238+
239+
240+
def test_update_or_create_tools_authheaders_no_spurious_update():
241+
"""Verify _update_or_create_tools does NOT trigger a spurious update when the
242+
gateway's auth_value dict matches the existing tool's encoded auth_value.
243+
244+
encode_auth() uses os.urandom(12) for the AES-GCM nonce, so comparing
245+
ciphertext would always differ even when the plaintext is identical. The
246+
comparison must use decoded/plaintext values to avoid write amplification
247+
on every health-check refresh cycle.
248+
"""
249+
# Standard
250+
from types import SimpleNamespace as NS
251+
252+
service = GatewayService()
253+
254+
auth_dict = {"X-My-Header": "secret-val"}
255+
encoded = encode_auth(auth_dict)
256+
original_encoded = encoded # save for byte-for-byte comparison
257+
258+
# Existing tool already has the correctly encoded auth_value stored
259+
existing = MagicMock()
260+
existing.original_name = "my-tool"
261+
existing.url = "http://gw.example.com/mcp"
262+
existing.description = "desc"
263+
existing.original_description = "desc"
264+
existing.integration_type = "MCP"
265+
existing.request_type = "POST"
266+
existing.headers = {}
267+
existing.input_schema = {}
268+
existing.output_schema = None
269+
existing.jsonpath_filter = None
270+
existing.auth_type = "authheaders"
271+
existing.auth_value = encoded # Text column — already encoded
272+
existing.visibility = "public"
273+
274+
db = MagicMock()
275+
db.execute.return_value.scalars.return_value.all.return_value = [existing]
276+
277+
tool = NS(
278+
name="my-tool",
279+
description="desc",
280+
input_schema={},
281+
output_schema=None,
282+
request_type="POST",
283+
headers={},
284+
annotations=None,
285+
jsonpath_filter=None,
286+
)
287+
288+
gateway = MagicMock()
289+
gateway.id = "gw-1"
290+
gateway.url = "http://gw.example.com/mcp"
291+
gateway.auth_type = "authheaders"
292+
gateway.auth_value = auth_dict # JSON column — plain dict
293+
gateway.visibility = "public"
294+
295+
result = service._update_or_create_tools(db, [tool], gateway, "update")
296+
297+
# No new tools returned
298+
assert result == []
299+
# auth_value must be the EXACT same string — no spurious re-encryption
300+
assert existing.auth_value is original_encoded, f"auth_value was spuriously rewritten: {existing.auth_value!r} != {original_encoded!r}"
301+
assert decode_auth(existing.auth_value) == auth_dict

0 commit comments

Comments
 (0)