|
81 | 81 | from mcpgateway.common.models import ListResourceTemplatesResult, LogLevel, Root |
82 | 82 | from mcpgateway.common.validators import SecurityValidator |
83 | 83 | from mcpgateway.config import settings |
84 | | -from mcpgateway.db import refresh_slugs_on_startup, SessionLocal |
| 84 | +from mcpgateway.db import A2AAgent as DbA2AAgent, A2APushNotificationConfig, A2ATask as DbA2ATask, refresh_slugs_on_startup, SessionLocal |
85 | 85 | from mcpgateway.db import Tool as DbTool |
86 | 86 | from mcpgateway.handlers.sampling import SamplingHandler |
87 | 87 | from mcpgateway.middleware.compression import SSEAwareCompressMiddleware |
|
111 | 111 | from mcpgateway.schemas import ( |
112 | 112 | A2AAgentCreate, |
113 | 113 | A2AAgentRead, |
| 114 | + A2APushNotificationConfigCreate, |
114 | 115 | A2AAgentUpdate, |
115 | 116 | CursorPaginatedA2AAgentsResponse, |
116 | 117 | CursorPaginatedGatewaysResponse, |
|
143 | 144 | ToolUpdate, |
144 | 145 | ) |
145 | 146 | from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService |
| 147 | +from mcpgateway.services.a2a_server_service import A2AServerService |
146 | 148 | from mcpgateway.services.cancellation_service import cancellation_service |
147 | 149 | from mcpgateway.services.completion_service import CompletionService |
148 | 150 | from mcpgateway.services.content_security import ContentSizeError, ContentTypeError |
@@ -8962,19 +8964,12 @@ async def handle_internal_a2a_agent_resolve(request: Request, agent_name: str): |
8962 | 8964 |
|
8963 | 8965 | db = SessionLocal() |
8964 | 8966 | try: |
8965 | | - # First-Party |
8966 | | - from mcpgateway.db import A2AAgent as DbA2AAgent # pylint: disable=import-outside-toplevel |
8967 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
8968 | | - |
8969 | 8967 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
8970 | 8968 | service = A2AAgentService() |
8971 | | - agent = db.query(DbA2AAgent).filter(DbA2AAgent.name == agent_name, DbA2AAgent.enabled == True).first() # noqa: E712 |
| 8969 | + agent = db.query(DbA2AAgent).filter(DbA2AAgent.name == agent_name, DbA2AAgent.enabled.is_(True)).first() |
8972 | 8970 | if not agent: |
8973 | | - # First-Party |
8974 | | - from mcpgateway.services.a2a_server_service import A2AServerService # pylint: disable=import-outside-toplevel |
8975 | | - |
8976 | | - server_service = A2AServerService() |
8977 | | - server_agent = server_service.resolve_server_agent(db, agent_name, user_email=user_email, token_teams=token_teams) |
| 8971 | + a2a_server_service = A2AServerService() |
| 8972 | + server_agent = a2a_server_service.resolve_server_agent(db, agent_name, user_email=user_email, token_teams=token_teams) |
8978 | 8973 | if server_agent: |
8979 | 8974 | return ORJSONResponse(status_code=200, content=server_agent) |
8980 | 8975 | return ORJSONResponse(status_code=404, content={"error": f"agent '{agent_name}' not found"}) |
@@ -9023,27 +9018,18 @@ async def handle_internal_a2a_agent_card(request: Request, agent_name: str): |
9023 | 9018 |
|
9024 | 9019 | db = SessionLocal() |
9025 | 9020 | try: |
9026 | | - # First-Party |
9027 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9028 | | - |
9029 | 9021 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9030 | 9022 | service = A2AAgentService() |
9031 | 9023 |
|
9032 | 9024 | # Check agent visibility before building the card to avoid loading |
9033 | 9025 | # sensitive relationship data for agents the caller cannot see. |
9034 | | - # First-Party |
9035 | | - from mcpgateway.db import A2AAgent as DbA2AAgent # pylint: disable=import-outside-toplevel |
9036 | | - |
9037 | | - agent = db.query(DbA2AAgent).filter(DbA2AAgent.name == agent_name, DbA2AAgent.enabled == True).first() # noqa: E712 |
| 9026 | + agent = db.query(DbA2AAgent).filter(DbA2AAgent.name == agent_name, DbA2AAgent.enabled.is_(True)).first() |
9038 | 9027 | card = None |
9039 | 9028 | if agent is not None and service._check_agent_access(agent, user_email, token_teams): # pylint: disable=protected-access |
9040 | 9029 | card = service.get_agent_card(db, agent_name) |
9041 | 9030 | if card is None: |
9042 | | - # First-Party |
9043 | | - from mcpgateway.services.a2a_server_service import A2AServerService # pylint: disable=import-outside-toplevel |
9044 | | - |
9045 | | - server_service = A2AServerService() |
9046 | | - card = server_service.get_server_agent_card(db, agent_name, user_email=user_email, token_teams=token_teams) |
| 9031 | + a2a_server_service = A2AServerService() |
| 9032 | + card = a2a_server_service.get_server_agent_card(db, agent_name, user_email=user_email, token_teams=token_teams) |
9047 | 9033 | if card is None: |
9048 | 9034 | return ORJSONResponse(status_code=404, content={"error": f"agent '{agent_name}' not found"}) |
9049 | 9035 | return ORJSONResponse(status_code=200, content=card) |
@@ -9078,9 +9064,6 @@ async def handle_internal_a2a_tasks_get(request: Request): |
9078 | 9064 | if agent_id is not None and not isinstance(agent_id, str): |
9079 | 9065 | return ORJSONResponse(status_code=400, content={"error": "agent_id must be a string"}) |
9080 | 9066 |
|
9081 | | - # First-Party |
9082 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9083 | | - |
9084 | 9067 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9085 | 9068 | service = A2AAgentService() |
9086 | 9069 | task = service.get_task(db, task_id, agent_id=agent_id, user_email=user_email, token_teams=token_teams) |
@@ -9120,9 +9103,6 @@ async def handle_internal_a2a_tasks_list(request: Request): |
9120 | 9103 | limit = min(int(body.get("limit", 100)), 1000) |
9121 | 9104 | offset = max(int(body.get("offset", 0)), 0) |
9122 | 9105 |
|
9123 | | - # First-Party |
9124 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9125 | | - |
9126 | 9106 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9127 | 9107 | service = A2AAgentService() |
9128 | 9108 | tasks = service.list_tasks(db, agent_id=agent_id, state=state, limit=limit, offset=offset, user_email=user_email, token_teams=token_teams) |
@@ -9158,9 +9138,6 @@ async def handle_internal_a2a_tasks_cancel(request: Request): |
9158 | 9138 | if agent_id is not None and not isinstance(agent_id, str): |
9159 | 9139 | return ORJSONResponse(status_code=400, content={"error": "agent_id must be a string"}) |
9160 | 9140 |
|
9161 | | - # First-Party |
9162 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9163 | | - |
9164 | 9141 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9165 | 9142 | service = A2AAgentService() |
9166 | 9143 | task = service.cancel_task(db, task_id, agent_id=agent_id, user_email=user_email, token_teams=token_teams) |
@@ -9195,10 +9172,6 @@ async def handle_internal_a2a_push_create(request: Request): |
9195 | 9172 | return ORJSONResponse(status_code=400, content={"error": "a2a_agent_id, task_id, and webhook_url are required"}) |
9196 | 9173 |
|
9197 | 9174 | # Validate webhook URL through the schema to enforce SSRF protection. |
9198 | | - # First-Party |
9199 | | - from mcpgateway.schemas import A2APushNotificationConfigCreate # pylint: disable=import-outside-toplevel |
9200 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9201 | | - |
9202 | 9175 | try: |
9203 | 9176 | validated = A2APushNotificationConfigCreate(**body) |
9204 | 9177 | except Exception as validation_err: |
@@ -9239,9 +9212,6 @@ async def handle_internal_a2a_push_get(request: Request): |
9239 | 9212 | if not task_id: |
9240 | 9213 | return ORJSONResponse(status_code=400, content={"error": "task_id is required"}) |
9241 | 9214 |
|
9242 | | - # First-Party |
9243 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9244 | | - |
9245 | 9215 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9246 | 9216 | service = A2AAgentService() |
9247 | 9217 | cfg = service.get_push_config(db, task_id, agent_id=agent_id) |
@@ -9277,9 +9247,6 @@ async def handle_internal_a2a_push_list(request: Request): |
9277 | 9247 | agent_id = body.get("agent_id") |
9278 | 9248 | task_id = body.get("task_id") |
9279 | 9249 |
|
9280 | | - # First-Party |
9281 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9282 | | - |
9283 | 9250 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9284 | 9251 | service = A2AAgentService() |
9285 | 9252 | configs = service.list_push_configs(db, agent_id=agent_id, task_id=task_id) |
@@ -9314,10 +9281,6 @@ async def handle_internal_a2a_push_delete(request: Request): |
9314 | 9281 | if not config_id: |
9315 | 9282 | return ORJSONResponse(status_code=400, content={"error": "config_id is required"}) |
9316 | 9283 |
|
9317 | | - # First-Party |
9318 | | - from mcpgateway.db import A2APushNotificationConfig # pylint: disable=import-outside-toplevel |
9319 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9320 | | - |
9321 | 9284 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9322 | 9285 | service = A2AAgentService() |
9323 | 9286 | cfg = db.query(A2APushNotificationConfig).filter(A2APushNotificationConfig.id == config_id).first() |
@@ -9355,10 +9318,6 @@ async def handle_internal_a2a_events_flush(request: Request): |
9355 | 9318 | if not events: |
9356 | 9319 | return ORJSONResponse(status_code=200, content={"count": 0}) |
9357 | 9320 |
|
9358 | | - # First-Party |
9359 | | - from mcpgateway.db import A2ATask as DbA2ATask # pylint: disable=import-outside-toplevel |
9360 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9361 | | - |
9362 | 9321 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9363 | 9322 | service = A2AAgentService() |
9364 | 9323 |
|
@@ -9403,10 +9362,6 @@ async def handle_internal_a2a_events_replay(request: Request): |
9403 | 9362 | if not task_id: |
9404 | 9363 | return ORJSONResponse(status_code=400, content={"error": "task_id required"}) |
9405 | 9364 |
|
9406 | | - # First-Party |
9407 | | - from mcpgateway.db import A2ATask as DbA2ATask # pylint: disable=import-outside-toplevel |
9408 | | - from mcpgateway.services.a2a_service import A2AAgentService # pylint: disable=import-outside-toplevel |
9409 | | - |
9410 | 9365 | user_email, token_teams = _get_internal_a2a_scope_context(request) |
9411 | 9366 | service = A2AAgentService() |
9412 | 9367 | task_row = db.query(DbA2ATask).filter(DbA2ATask.task_id == task_id).first() |
|
0 commit comments