Skip to content

Commit d6224d4

Browse files
committed
fix(auth): use module-local async get_db() in _set_proxy_user_context
Resolves pylint W0621 (redefined-outer-name): the function shadowed the module-local async get_db() at :721 with a local import of mcpgateway.db.get_db. Switched to the existing async context manager, which also provides proper cancellation handling for MCP long-lived connections and removes the explicit try/finally/db.close() boilerplate (addresses review note M4). Signed-off-by: Jonathan Springer <jps@s390x.com>
1 parent 43a217b commit d6224d4

3 files changed

Lines changed: 146 additions & 24 deletions

File tree

.secrets.baseline

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"files": "(?x)( package-lock\\.json$ |Cargo\\.lock$ |uv\\.lock$ |go\\.sum$ |mcpgateway/sri_hashes\\.json$ )|^.secrets.baseline$",
44
"lines": null
55
},
6-
"generated_at": "2026-04-24T21:03:53Z",
6+
"generated_at": "2026-04-25T08:30:54Z",
77
"plugins_used": [
88
{
99
"name": "AWSKeyDetector"
@@ -5456,7 +5456,7 @@
54565456
"hashed_secret": "a4b48a81cdab1e1a5dd37907d6c85ca1c61ddc7c",
54575457
"is_secret": false,
54585458
"is_verified": false,
5459-
"line_number": 572,
5459+
"line_number": 681,
54605460
"type": "Secret Keyword",
54615461
"verified_result": null
54625462
}
@@ -9530,15 +9530,15 @@
95309530
"hashed_secret": "b4c9248600a42f8c38c01b632f392dbcb4c7b19a",
95319531
"is_secret": false,
95329532
"is_verified": false,
9533-
"line_number": 12937,
9533+
"line_number": 13105,
95349534
"type": "Hex High Entropy String",
95359535
"verified_result": null
95369536
},
95379537
{
95389538
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
95399539
"is_secret": false,
95409540
"is_verified": false,
9541-
"line_number": 14112,
9541+
"line_number": 14280,
95429542
"type": "Hex High Entropy String",
95439543
"verified_result": null
95449544
}

mcpgateway/transports/streamablehttp_transport.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4550,11 +4550,12 @@ async def _set_proxy_user_context(proxy_user: str) -> dict[str, Any] | None:
45504550
"""
45514551
# First-Party
45524552
from mcpgateway.auth import _resolve_teams_from_db # pylint: disable=import-outside-toplevel
4553-
from mcpgateway.db import get_db # pylint: disable=import-outside-toplevel
45544553
from mcpgateway.services.email_auth_service import EmailAuthService # pylint: disable=import-outside-toplevel
45554554

4556-
db = next(get_db())
4557-
try:
4555+
# Use the module-local async get_db() context manager (line 721) rather than
4556+
# mcpgateway.db.get_db: it provides proper cancellation handling for MCP
4557+
# handlers cancelled mid-auth (client disconnect, timeout).
4558+
async with get_db() as db:
45584559
auth_service = EmailAuthService(db)
45594560
user_info = await auth_service.get_user_by_email(proxy_user)
45604561

@@ -4590,8 +4591,6 @@ async def _set_proxy_user_context(proxy_user: str) -> dict[str, Any] | None:
45904591
# pre-authentication contract of set_trace_context_from_teams.
45914592
set_trace_context_from_teams(token_teams or [], user_email=proxy_user, is_admin=is_admin, auth_method="proxy")
45924593
return None
4593-
finally:
4594-
db.close()
45954594

45964595

45974596
def get_streamable_http_auth_context() -> dict[str, Any]:

tests/unit/mcpgateway/transports/test_streamablehttp_transport.py

Lines changed: 138 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4893,11 +4893,11 @@ async def send(msg):
48934893
mock_user.is_active = True
48944894
mock_user.email = "proxy_user@example.com"
48954895

4896-
with patch("mcpgateway.db.get_db") as mock_get_db, patch(
4897-
"mcpgateway.services.email_auth_service.EmailAuthService"
4898-
) as mock_auth_service, patch(
4899-
"mcpgateway.auth._resolve_teams_from_db", new_callable=AsyncMock
4900-
) as mock_resolve_teams:
4896+
with (
4897+
patch("mcpgateway.db.get_db") as mock_get_db,
4898+
patch("mcpgateway.services.email_auth_service.EmailAuthService") as mock_auth_service,
4899+
patch("mcpgateway.auth._resolve_teams_from_db", new_callable=AsyncMock) as mock_resolve_teams,
4900+
):
49014901
mock_get_db.return_value = iter([Mock()])
49024902
mock_auth_service.return_value.get_user_by_email = AsyncMock(return_value=mock_user)
49034903
mock_resolve_teams.return_value = []
@@ -4944,11 +4944,11 @@ async def send(msg):
49444944
mock_user.is_active = True
49454945
mock_user.email = "proxy_fallback@example.com"
49464946

4947-
with patch("mcpgateway.db.get_db") as mock_get_db, patch(
4948-
"mcpgateway.services.email_auth_service.EmailAuthService"
4949-
) as mock_auth_service, patch(
4950-
"mcpgateway.auth._resolve_teams_from_db", new_callable=AsyncMock
4951-
) as mock_resolve_teams:
4947+
with (
4948+
patch("mcpgateway.db.get_db") as mock_get_db,
4949+
patch("mcpgateway.services.email_auth_service.EmailAuthService") as mock_auth_service,
4950+
patch("mcpgateway.auth._resolve_teams_from_db", new_callable=AsyncMock) as mock_resolve_teams,
4951+
):
49524952
mock_get_db.return_value = iter([Mock()])
49534953
mock_auth_service.return_value.get_user_by_email = AsyncMock(return_value=mock_user)
49544954
mock_resolve_teams.return_value = []
@@ -4964,6 +4964,129 @@ async def send(msg):
49644964
assert user_ctx["is_admin"] is False
49654965

49664966

4967+
# ---------------------------------------------------------------------------
4968+
# Proxy auth: disabled user rejected via _set_proxy_user_context (line 4567, 4727)
4969+
# ---------------------------------------------------------------------------
4970+
4971+
4972+
@pytest.mark.asyncio
4973+
async def test_streamable_http_auth_proxy_rejects_disabled_user(monkeypatch):
4974+
"""Disabled user gets 401 'Account disabled' through proxy auth (lines 4567, 4727)."""
4975+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.mcp_client_auth_enabled", False)
4976+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.trust_proxy_auth", True)
4977+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.trust_proxy_auth_dangerously", True)
4978+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.proxy_user_header", "x-forwarded-user")
4979+
4980+
scope = _make_scope(
4981+
"/servers/1/mcp",
4982+
headers=[
4983+
(b"x-forwarded-user", b"disabled@example.com"),
4984+
],
4985+
)
4986+
sent = []
4987+
4988+
async def send(msg):
4989+
sent.append(msg)
4990+
4991+
mock_user = Mock()
4992+
mock_user.is_admin = True
4993+
mock_user.is_active = False
4994+
mock_user.email = "disabled@example.com"
4995+
4996+
with patch("mcpgateway.db.get_db") as mock_get_db, patch("mcpgateway.services.email_auth_service.EmailAuthService") as mock_auth_service:
4997+
mock_get_db.return_value = iter([Mock()])
4998+
mock_auth_service.return_value.get_user_by_email = AsyncMock(return_value=mock_user)
4999+
5000+
result = await streamable_http_auth(scope, None, send)
5001+
5002+
assert result is False
5003+
starts = [m for m in sent if m.get("type") == "http.response.start"]
5004+
assert starts and starts[0]["status"] == 401
5005+
bodies = [m for m in sent if m.get("type") == "http.response.body"]
5006+
assert bodies and b"Account disabled" in bodies[0]["body"]
5007+
5008+
5009+
# ---------------------------------------------------------------------------
5010+
# Proxy auth: admin bypass when user not in DB (lines 4572-4575)
5011+
# ---------------------------------------------------------------------------
5012+
5013+
5014+
@pytest.mark.asyncio
5015+
async def test_streamable_http_auth_proxy_admin_bypass_no_db_record(monkeypatch):
5016+
"""Platform admin email gets admin bypass when not in DB and require_user_in_db=False (lines 4572-4575)."""
5017+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.mcp_client_auth_enabled", False)
5018+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.trust_proxy_auth", True)
5019+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.trust_proxy_auth_dangerously", True)
5020+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.proxy_user_header", "x-forwarded-user")
5021+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.require_user_in_db", False)
5022+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.platform_admin_email", "admin@example.com")
5023+
5024+
scope = _make_scope(
5025+
"/servers/1/mcp",
5026+
headers=[
5027+
(b"x-forwarded-user", b"admin@example.com"),
5028+
],
5029+
)
5030+
sent = []
5031+
5032+
async def send(msg):
5033+
sent.append(msg)
5034+
5035+
with patch("mcpgateway.db.get_db") as mock_get_db, patch("mcpgateway.services.email_auth_service.EmailAuthService") as mock_auth_service:
5036+
mock_get_db.return_value = iter([Mock()])
5037+
mock_auth_service.return_value.get_user_by_email = AsyncMock(return_value=None)
5038+
5039+
result = await streamable_http_auth(scope, None, send)
5040+
5041+
assert result is True
5042+
assert sent == []
5043+
5044+
user_ctx = tr.user_context_var.get()
5045+
assert user_ctx["email"] == "admin@example.com"
5046+
assert user_ctx["teams"] is None
5047+
assert user_ctx["is_admin"] is True
5048+
assert user_ctx["auth_method"] == "proxy"
5049+
5050+
5051+
# ---------------------------------------------------------------------------
5052+
# Proxy auth: unknown user rejected (line 4577, 4727)
5053+
# ---------------------------------------------------------------------------
5054+
5055+
5056+
@pytest.mark.asyncio
5057+
async def test_streamable_http_auth_proxy_rejects_unknown_user(monkeypatch):
5058+
"""Unknown user (not in DB, not platform admin) gets 401 'User not found' (lines 4577, 4727)."""
5059+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.mcp_client_auth_enabled", False)
5060+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.trust_proxy_auth", True)
5061+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.trust_proxy_auth_dangerously", True)
5062+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.proxy_user_header", "x-forwarded-user")
5063+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.require_user_in_db", True)
5064+
monkeypatch.setattr("mcpgateway.transports.streamablehttp_transport.settings.platform_admin_email", "admin@example.com")
5065+
5066+
scope = _make_scope(
5067+
"/servers/1/mcp",
5068+
headers=[
5069+
(b"x-forwarded-user", b"unknown@example.com"),
5070+
],
5071+
)
5072+
sent = []
5073+
5074+
async def send(msg):
5075+
sent.append(msg)
5076+
5077+
with patch("mcpgateway.db.get_db") as mock_get_db, patch("mcpgateway.services.email_auth_service.EmailAuthService") as mock_auth_service:
5078+
mock_get_db.return_value = iter([Mock()])
5079+
mock_auth_service.return_value.get_user_by_email = AsyncMock(return_value=None)
5080+
5081+
result = await streamable_http_auth(scope, None, send)
5082+
5083+
assert result is False
5084+
starts = [m for m in sent if m.get("type") == "http.response.start"]
5085+
assert starts and starts[0]["status"] == 401
5086+
bodies = [m for m in sent if m.get("type") == "http.response.body"]
5087+
assert bodies and b"User not found in database" in bodies[0]["body"]
5088+
5089+
49675090
@pytest.mark.asyncio
49685091
async def test_streamable_http_auth_proxy_user_context_on_valid_jwt(monkeypatch):
49695092
"""Proxy auth takes precedence even when a valid JWT Bearer header is present."""
@@ -4995,11 +5118,11 @@ async def send(msg):
49955118
mock_user.is_active = True
49965119
mock_user.email = "proxy_user@example.com"
49975120

4998-
with patch("mcpgateway.db.get_db") as mock_get_db, patch(
4999-
"mcpgateway.services.email_auth_service.EmailAuthService"
5000-
) as mock_auth_service, patch(
5001-
"mcpgateway.auth._resolve_teams_from_db", new_callable=AsyncMock
5002-
) as mock_resolve_teams:
5121+
with (
5122+
patch("mcpgateway.db.get_db") as mock_get_db,
5123+
patch("mcpgateway.services.email_auth_service.EmailAuthService") as mock_auth_service,
5124+
patch("mcpgateway.auth._resolve_teams_from_db", new_callable=AsyncMock) as mock_resolve_teams,
5125+
):
50035126
mock_get_db.return_value = iter([Mock()])
50045127
mock_auth_service.return_value.get_user_by_email = AsyncMock(return_value=mock_user)
50055128
mock_resolve_teams.return_value = []

0 commit comments

Comments
 (0)