Skip to content

Commit 35f71c8

Browse files
jonpspriclaude
andcommitted
fix(security): harden A2A runtime trust boundaries, scoping, and validation
Security fixes: - Add authentication to the /invoke endpoint (HMAC trust header gate) - Add authentication to the A2A proxy catch-all before forwarding - Require non-empty A2A_RUST_AUTH_SECRET at startup to prevent predictable trust headers - Use server-derived owner_email/team_id in agent registration instead of trusting client-provided values - Add visibility scoping to the flush_events internal endpoint - Close visibility bypass on events/replay when task row is absent - Close visibility bypass on push config get/list when agent_id omitted - Validate webhook URL through Pydantic schema (SSRF protection) in internal push config creation endpoint - Add webhook URL validation via validate_core_url on A2APushNotificationConfigCreate schema - Exclude auth_token from A2APushNotificationConfigRead (Field exclude) - Change _check_agent_access_by_id to fail-closed for deleted agents - Make encode_auth_context a hard error instead of silent empty fallback - Sanitize all error responses to external callers — strip internal backend URLs, Python response bodies, and reqwest error details - Redact Redis URL credentials in logs - Use constant-time comparison for session fingerprint validation - Warn at startup if HTTP listener is non-loopback Correctness fixes: - Pass actual JSON-RPC method name through full_authenticate - Add explicit authz match arms for cancel, delete, create, and stream - Remove x-forwarded-for from default session fingerprint headers - Fix cargo fmt conflicts with pragma allowlist comments - Cap retry backoff at 60 seconds to prevent runaway delays Error handling improvements: - Replace expect() in queue worker with graceful error handling for semaphore closure, panicked JoinSet tasks, and missing results - Add logger.exception + db.invalidate fallback to all 11 internal A2A endpoint exception handlers - Upgrade query param decryption failure log from debug to warning - Log warning on auth decoding failure during agent update - Return None from event_store.store_event on serialization failure - Replace .ok() with logged match on proxy response JSON parsing - Replace mutex expect("poisoned") with unwrap_or_else recovery - Upgrade shadow mode payload mismatch log to warning with traceback - Upgrade Redis cache invalidation failure log to warning - Bound resolve_inflight DashMap to prevent unbounded memory growth Comment and code cleanup: - Fix handle_a2a_proxy doc (does not inject trust headers) - Fix coalesce_jobs doc (does not match on timeout) - Fix proxy_task_method doc (also handles cancel, not just reads) - Clarify trust header as keyed SHA-256, not formal HMAC - Renumber handle_a2a_invoke steps sequentially (1-2-3-3a-4-5) - Document _visible_agent_ids admin-bypass divergence - Remove leftover logger.info debug statements from register_agent - Remove commented-out dead code from register_agent - Remove duplicate TOOLS_MANAGE_PLUGINS constant - Remove constant-comparison test (test_version_endpoint_redis_conditions) Test coverage: - Add visibility tests for cancel_task (wrong team, admin, public-only) - Add deny-path tests for events/replay (inaccessible agent, missing task) - Add 32 tests for _check_agent_access_by_id, _visible_agent_ids, get_task, list_tasks, and _check_server_access - Add tests for auth_secret startup rejection (unit + binary) - Update integration tests for trust-gated /invoke and authenticated proxy endpoints Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Jonathan Springer <jps@s390x.com>
1 parent 4655195 commit 35f71c8

33 files changed

+1704
-590
lines changed

.secrets.baseline

Lines changed: 82 additions & 56 deletions
Large diffs are not rendered by default.

mcpgateway/alembic/versions/a2a_v1_domain_models_3f7e9d1a2b4c.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,19 @@ def upgrade() -> None:
135135

136136
# --- Backfill a2a_agent_auth from existing a2a_agents auth columns -------
137137
if "a2a_agent_auth" in inspector.get_table_names():
138+
# Standard
139+
import json
140+
import uuid
141+
138142
conn = op.get_bind()
139143
agents_with_auth = conn.execute(
140144
sa.text("SELECT id, auth_type, auth_value, auth_query_params " "FROM a2a_agents " "WHERE auth_type IS NOT NULL " "AND id NOT IN (SELECT a2a_agent_id FROM a2a_agent_auth)")
141145
).fetchall()
142146
for agent in agents_with_auth:
143-
# Standard
144-
import uuid
147+
# auth_query_params may be a Python dict (from a JSON column);
148+
# serialize it so the untyped text bind works on all drivers.
149+
raw_params = agent[3]
150+
params_str = json.dumps(raw_params) if isinstance(raw_params, (dict, list)) else raw_params
145151

146152
conn.execute(
147153
sa.text(
@@ -153,7 +159,7 @@ def upgrade() -> None:
153159
"agent_id": agent[0],
154160
"auth_type": agent[1],
155161
"auth_value": agent[2],
156-
"auth_query_params": agent[3],
162+
"auth_query_params": params_str,
157163
},
158164
)
159165

mcpgateway/alembic/versions/ffe4494639d3_add_a2a_task_events_table.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def upgrade() -> None:
2929
op.create_table(
3030
"a2a_task_events",
3131
sa.Column("id", sa.String(36), primary_key=True, nullable=False),
32+
sa.Column("a2a_agent_id", sa.String(36), sa.ForeignKey("a2a_agents.id", ondelete="CASCADE"), nullable=True),
3233
sa.Column("task_id", sa.String(255), nullable=False),
3334
sa.Column("event_id", sa.String(36), nullable=False),
3435
sa.Column("sequence", sa.BigInteger(), nullable=False),
@@ -38,6 +39,7 @@ def upgrade() -> None:
3839
)
3940
op.create_index("ix_a2a_task_events_task_id", "a2a_task_events", ["task_id"])
4041
op.create_index("ix_a2a_task_events_task_seq", "a2a_task_events", ["task_id", "sequence"])
42+
op.create_index("ix_a2a_task_events_agent_id", "a2a_task_events", ["a2a_agent_id"])
4143

4244

4345
def downgrade() -> None:

mcpgateway/db.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,8 +1258,6 @@ class Permissions:
12581258
TOOLS_EXECUTE = "tools.execute"
12591259
TOOLS_MANAGE_PLUGINS = "tools.manage_plugins"
12601260

1261-
TOOLS_MANAGE_PLUGINS = "tools.manage_plugins"
1262-
12631261
# Resource permissions
12641262
RESOURCES_CREATE = "resources.create"
12651263
RESOURCES_READ = "resources.read"
@@ -5066,6 +5064,7 @@ class A2ATaskEvent(Base):
50665064
__tablename__ = "a2a_task_events"
50675065

50685066
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
5067+
a2a_agent_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("a2a_agents.id", ondelete="CASCADE"), nullable=True, index=True)
50695068
task_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
50705069
event_id: Mapped[str] = mapped_column(String(36), nullable=False)
50715070
sequence: Mapped[int] = mapped_column(BigInteger, nullable=False)

mcpgateway/main.py

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8997,10 +8997,14 @@ async def handle_internal_a2a_agent_resolve(request: Request, agent_name: str):
89978997

89988998
return ORJSONResponse(status_code=200, content=result)
89998999
except Exception:
9000+
logger.exception("Internal A2A endpoint error")
90009001
try:
90019002
db.rollback()
90029003
except Exception:
9003-
pass
9004+
try:
9005+
db.invalidate()
9006+
except Exception:
9007+
pass # nosec B110
90049008
raise
90059009
finally:
90069010
db.close()
@@ -9044,10 +9048,14 @@ async def handle_internal_a2a_agent_card(request: Request, agent_name: str):
90449048
return ORJSONResponse(status_code=404, content={"error": f"agent '{agent_name}' not found"})
90459049
return ORJSONResponse(status_code=200, content=card)
90469050
except Exception:
9051+
logger.exception("Internal A2A endpoint error")
90479052
try:
90489053
db.rollback()
90499054
except Exception:
9050-
pass
9055+
try:
9056+
db.invalidate()
9057+
except Exception:
9058+
pass # nosec B110
90519059
raise
90529060
finally:
90539061
db.close()
@@ -9080,10 +9088,14 @@ async def handle_internal_a2a_tasks_get(request: Request):
90809088
return ORJSONResponse(status_code=404, content={"error": f"task '{task_id}' not found"})
90819089
return ORJSONResponse(status_code=200, content=task)
90829090
except Exception:
9091+
logger.exception("Internal A2A endpoint error")
90839092
try:
90849093
db.rollback()
90859094
except Exception:
9086-
pass
9095+
try:
9096+
db.invalidate()
9097+
except Exception:
9098+
pass # nosec B110
90879099
raise
90889100
finally:
90899101
db.close()
@@ -9116,10 +9128,14 @@ async def handle_internal_a2a_tasks_list(request: Request):
91169128
tasks = service.list_tasks(db, agent_id=agent_id, state=state, limit=limit, offset=offset, user_email=user_email, token_teams=token_teams)
91179129
return ORJSONResponse(status_code=200, content={"tasks": tasks})
91189130
except Exception:
9131+
logger.exception("Internal A2A endpoint error")
91199132
try:
91209133
db.rollback()
91219134
except Exception:
9122-
pass
9135+
try:
9136+
db.invalidate()
9137+
except Exception:
9138+
pass # nosec B110
91239139
raise
91249140
finally:
91259141
db.close()
@@ -9152,10 +9168,14 @@ async def handle_internal_a2a_tasks_cancel(request: Request):
91529168
return ORJSONResponse(status_code=404, content={"error": f"task '{task_id}' not found"})
91539169
return ORJSONResponse(status_code=200, content=task)
91549170
except Exception:
9171+
logger.exception("Internal A2A endpoint error")
91559172
try:
91569173
db.rollback()
91579174
except Exception:
9158-
pass
9175+
try:
9176+
db.invalidate()
9177+
except Exception:
9178+
pass # nosec B110
91599179
raise
91609180
finally:
91619181
db.close()
@@ -9174,20 +9194,31 @@ async def handle_internal_a2a_push_create(request: Request):
91749194
if not body.get("a2a_agent_id") or not body.get("task_id") or not body.get("webhook_url"):
91759195
return ORJSONResponse(status_code=400, content={"error": "a2a_agent_id, task_id, and webhook_url are required"})
91769196

9197+
# Validate webhook URL through the schema to enforce SSRF protection.
91779198
# First-Party
9199+
from mcpgateway.schemas import A2APushNotificationConfigCreate # pylint: disable=import-outside-toplevel
91789200
from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel
91799201

9202+
try:
9203+
validated = A2APushNotificationConfigCreate(**body)
9204+
except Exception as validation_err:
9205+
return ORJSONResponse(status_code=400, content={"error": f"invalid push config: {validation_err}"})
9206+
91809207
user_email, token_teams = _get_internal_a2a_scope_context(request)
91819208
service = A2AAgentService()
91829209
if not service._check_agent_access_by_id(db, body["a2a_agent_id"], user_email, token_teams): # pylint: disable=protected-access
91839210
return ORJSONResponse(status_code=404, content={"error": "agent not found"})
9184-
cfg = service.create_push_config(db, body)
9211+
cfg = service.create_push_config(db, validated.model_dump())
91859212
return ORJSONResponse(status_code=200, content=cfg)
91869213
except Exception:
9214+
logger.exception("Internal A2A endpoint error")
91879215
try:
91889216
db.rollback()
91899217
except Exception:
9190-
pass
9218+
try:
9219+
db.invalidate()
9220+
except Exception:
9221+
pass # nosec B110
91919222
raise
91929223
finally:
91939224
db.close()
@@ -9213,17 +9244,21 @@ async def handle_internal_a2a_push_get(request: Request):
92139244

92149245
user_email, token_teams = _get_internal_a2a_scope_context(request)
92159246
service = A2AAgentService()
9216-
if agent_id and not service._check_agent_access_by_id(db, agent_id, user_email, token_teams): # pylint: disable=protected-access
9217-
return ORJSONResponse(status_code=404, content={"error": f"push config for task '{task_id}' not found"})
92189247
cfg = service.get_push_config(db, task_id, agent_id=agent_id)
92199248
if cfg is None:
92209249
return ORJSONResponse(status_code=404, content={"error": f"push config for task '{task_id}' not found"})
9250+
if not service._check_agent_access_by_id(db, cfg["a2a_agent_id"], user_email, token_teams): # pylint: disable=protected-access
9251+
return ORJSONResponse(status_code=404, content={"error": f"push config for task '{task_id}' not found"})
92219252
return ORJSONResponse(status_code=200, content=cfg)
92229253
except Exception:
9254+
logger.exception("Internal A2A endpoint error")
92239255
try:
92249256
db.rollback()
92259257
except Exception:
9226-
pass
9258+
try:
9259+
db.invalidate()
9260+
except Exception:
9261+
pass # nosec B110
92279262
raise
92289263
finally:
92299264
db.close()
@@ -9247,15 +9282,19 @@ async def handle_internal_a2a_push_list(request: Request):
92479282

92489283
user_email, token_teams = _get_internal_a2a_scope_context(request)
92499284
service = A2AAgentService()
9250-
if agent_id and not service._check_agent_access_by_id(db, agent_id, user_email, token_teams): # pylint: disable=protected-access
9251-
return ORJSONResponse(status_code=200, content={"configs": []})
92529285
configs = service.list_push_configs(db, agent_id=agent_id, task_id=task_id)
9253-
return ORJSONResponse(status_code=200, content={"configs": configs})
9286+
# Filter configs to only those whose owning agent is visible to the caller.
9287+
visible = [c for c in configs if service._check_agent_access_by_id(db, c["a2a_agent_id"], user_email, token_teams)] # pylint: disable=protected-access
9288+
return ORJSONResponse(status_code=200, content={"configs": visible})
92549289
except Exception:
9290+
logger.exception("Internal A2A endpoint error")
92559291
try:
92569292
db.rollback()
92579293
except Exception:
9258-
pass
9294+
try:
9295+
db.invalidate()
9296+
except Exception:
9297+
pass # nosec B110
92599298
raise
92609299
finally:
92619300
db.close()
@@ -9289,10 +9328,14 @@ async def handle_internal_a2a_push_delete(request: Request):
92899328
return ORJSONResponse(status_code=404, content={"error": f"push config '{config_id}' not found"})
92909329
return ORJSONResponse(status_code=200, content={"deleted": True})
92919330
except Exception:
9331+
logger.exception("Internal A2A endpoint error")
92929332
try:
92939333
db.rollback()
92949334
except Exception:
9295-
pass
9335+
try:
9336+
db.invalidate()
9337+
except Exception:
9338+
pass # nosec B110
92969339
raise
92979340
finally:
92989341
db.close()
@@ -9313,16 +9356,32 @@ async def handle_internal_a2a_events_flush(request: Request):
93139356
return ORJSONResponse(status_code=200, content={"count": 0})
93149357

93159358
# First-Party
9359+
from mcpgateway.db import A2ATask as DbA2ATask # pylint: disable=import-outside-toplevel
93169360
from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel
93179361

9362+
user_email, token_teams = _get_internal_a2a_scope_context(request)
93189363
service = A2AAgentService()
9364+
9365+
# Verify the caller has access to the agents that own the referenced tasks.
9366+
task_ids = {e["task_id"] for e in events if "task_id" in e}
9367+
if task_ids:
9368+
tasks = db.query(DbA2ATask).filter(DbA2ATask.task_id.in_(task_ids)).all()
9369+
agent_ids = {t.a2a_agent_id for t in tasks}
9370+
for agent_id in agent_ids:
9371+
if not service._check_agent_access_by_id(db, agent_id, user_email, token_teams): # pylint: disable=protected-access
9372+
return ORJSONResponse(status_code=403, content={"error": "access denied for one or more referenced tasks"})
9373+
93199374
count = service.flush_events(db, events)
93209375
return ORJSONResponse(status_code=200, content={"count": count})
93219376
except Exception:
9377+
logger.exception("Internal A2A endpoint error")
93229378
try:
93239379
db.rollback()
93249380
except Exception:
9325-
pass
9381+
try:
9382+
db.invalidate()
9383+
except Exception:
9384+
pass # nosec B110
93269385
raise
93279386
finally:
93289387
db.close()
@@ -9351,15 +9410,21 @@ async def handle_internal_a2a_events_replay(request: Request):
93519410
user_email, token_teams = _get_internal_a2a_scope_context(request)
93529411
service = A2AAgentService()
93539412
task_row = db.query(DbA2ATask).filter(DbA2ATask.task_id == task_id).first()
9354-
if task_row and not service._check_agent_access_by_id(db, task_row.a2a_agent_id, user_email, token_teams): # pylint: disable=protected-access
9413+
if task_row is None:
9414+
return ORJSONResponse(status_code=404, content={"error": "task not found"})
9415+
if not service._check_agent_access_by_id(db, task_row.a2a_agent_id, user_email, token_teams): # pylint: disable=protected-access
93559416
return ORJSONResponse(status_code=404, content={"error": "task not found"})
93569417
events = service.replay_events(db, task_id, after_sequence, limit=limit)
93579418
return ORJSONResponse(status_code=200, content={"events": events})
93589419
except Exception:
9420+
logger.exception("Internal A2A endpoint error")
93599421
try:
93609422
db.rollback()
93619423
except Exception:
9362-
pass
9424+
try:
9425+
db.invalidate()
9426+
except Exception:
9427+
pass # nosec B110
93639428
raise
93649429
finally:
93659430
db.close()

mcpgateway/schemas.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5575,6 +5575,12 @@ class A2APushNotificationConfigCreate(BaseModel):
55755575
events: Optional[List[str]] = None
55765576
enabled: bool = True
55775577

5578+
@field_validator("webhook_url")
5579+
@classmethod
5580+
def validate_webhook_url(cls, v: str) -> str:
5581+
"""Validate webhook URL for scheme, SSRF, and dangerous patterns."""
5582+
return validate_core_url(v, "Webhook URL")
5583+
55785584

55795585
class A2APushNotificationConfigRead(BaseModel):
55805586
"""Schema for reading a push notification webhook configuration."""
@@ -5585,7 +5591,7 @@ class A2APushNotificationConfigRead(BaseModel):
55855591
a2a_agent_id: str
55865592
task_id: str
55875593
webhook_url: str
5588-
auth_token: Optional[str] = None
5594+
auth_token: Optional[str] = Field(default=None, exclude=True)
55895595
events: Optional[List[str]] = None
55905596
enabled: bool
55915597
created_at: datetime
@@ -5595,6 +5601,7 @@ class A2APushNotificationConfigRead(BaseModel):
55955601
class A2ATaskEventCreate(BaseModel):
55965602
"""Schema for creating a task event log entry."""
55975603

5604+
a2a_agent_id: Optional[str] = None
55985605
task_id: str
55995606
event_id: str
56005607
sequence: int
@@ -5608,6 +5615,7 @@ class A2ATaskEventRead(BaseModel):
56085615
model_config = ConfigDict(from_attributes=True)
56095616

56105617
id: str
5618+
a2a_agent_id: Optional[str] = None
56115619
task_id: str
56125620
event_id: str
56135621
sequence: int

mcpgateway/services/a2a_protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def prepare_a2a_invocation(
393393
decrypted = decode_auth(encrypted_value)
394394
auth_query_params_decrypted[str(param_key)] = str(decrypted.get(param_key, ""))
395395
except Exception: # nosec B112
396-
logger.debug("Failed to decrypt query param %r for A2A agent invocation", param_key, exc_info=True)
396+
logger.warning("Failed to decrypt query param %r for A2A agent invocation — invocation proceeds without this credential", param_key, exc_info=True)
397397
continue
398398
if auth_query_params_decrypted:
399399
target_endpoint_url = apply_query_param_auth(target_endpoint_url, auth_query_params_decrypted)

mcpgateway/services/a2a_server_service.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def _check_server_access(server: DbServer, user_email: Optional[str], token_team
3939
return True
4040

4141
if token_teams is None and user_email is None:
42+
logger.debug("admin bypass: granting access to %s server %s", server.visibility, getattr(server, "name", "?"))
4243
return True
4344

4445
if not user_email:
@@ -313,8 +314,11 @@ def resolve_task_mapping(self, db: Session, server_id: str, server_task_id: str)
313314
def _find_a2a_interface(self, db: Session, server_id: str) -> Optional[DbServerInterface]:
314315
"""Return the first enabled ServerInterface with an A2A protocol for *server_id*.
315316
316-
The match is case-insensitive on the protocol field (e.g. "a2a", "A2A",
317-
"a2a/v1" all qualify).
317+
Matches any protocol whose lower-cased value starts with ``a2a``
318+
(e.g. ``a2a``, ``a2a/v1``, ``a2a-jsonrpc``). When a server exposes
319+
multiple A2A interfaces, the first one (by DB insertion order) is
320+
returned — this avoids ``MultipleResultsFound`` when both v0.3 and
321+
v1 bindings exist.
318322
319323
Args:
320324
db: Database session.
@@ -323,9 +327,14 @@ def _find_a2a_interface(self, db: Session, server_id: str) -> Optional[DbServerI
323327
Returns:
324328
Matching ServerInterface ORM instance, or None.
325329
"""
326-
query = select(DbServerInterface).where(
327-
DbServerInterface.server_id == server_id,
328-
DbServerInterface.enabled == True, # noqa: E712
329-
func.lower(DbServerInterface.protocol).in_(["a2a", "a2a/v1", "a2a/v0.3"]),
330+
query = (
331+
select(DbServerInterface)
332+
.where(
333+
DbServerInterface.server_id == server_id,
334+
DbServerInterface.enabled == True, # noqa: E712
335+
func.lower(DbServerInterface.protocol).like("a2a%"),
336+
)
337+
.order_by(DbServerInterface.created_at.desc())
338+
.limit(1)
330339
)
331340
return db.execute(query).scalar_one_or_none()

0 commit comments

Comments
 (0)