Skip to content

Commit 810e53e

Browse files
gcgoncalvescrivetimihai
authored andcommitted
3446 - Enable public MCP objects on team Servers
Signed-off-by: Gabriel Costa <gabrielcg@proton.me>
1 parent 348b0c8 commit 810e53e

File tree

2 files changed

+113
-80
lines changed

2 files changed

+113
-80
lines changed

mcpgateway/admin.py

Lines changed: 73 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,42 @@ def _owner_access_condition(owner_column, team_column, *, user_email: str, team_
13831383
return owner_column == user_email
13841384

13851385

1386+
def _merge_select_all_ids(form: Any, flag_key: str, all_ids_key: str, checked_list: list[str]) -> list[str]:
1387+
"""Merge server-fetched IDs with UI-checked IDs when "Select All" is active.
1388+
1389+
When the user clicks "Select All" in a paginated list, the browser populates
1390+
*all_ids_key* with IDs fetched from the corresponding /ids endpoint. Because
1391+
that endpoint may be team-scoped, it can miss platform-public items that are
1392+
still visible (and checked) in the UI. Taking the union of both sources
1393+
ensures every explicitly selected item is preserved.
1394+
1395+
Note: both sources are client-supplied form values. Downstream persistence
1396+
code is responsible for enforcing final access control on the merged IDs.
1397+
1398+
Args:
1399+
form: Starlette form object.
1400+
flag_key (str): Form field that signals "Select All" mode (e.g. ``"selectAllTools"``).
1401+
all_ids_key (str): Form field holding the JSON-encoded server-fetched IDs.
1402+
checked_list (list[str]): IDs collected from checked checkboxes in the form.
1403+
1404+
Returns:
1405+
list[str]: Merged, deduplicated list of string IDs; or *checked_list* unchanged
1406+
when Select All is not active or the JSON payload cannot be parsed.
1407+
"""
1408+
if form.get(flag_key) != "true":
1409+
return checked_list
1410+
raw = form.get(all_ids_key) or "[]"
1411+
try:
1412+
server_ids = orjson.loads(raw)
1413+
# Normalise to str to avoid silent int/str duplicates from different sources.
1414+
merged = list({str(i) for i in server_ids} | set(checked_list))
1415+
LOGGER.info("Select All (%s): %d items after merge", all_ids_key, len(merged))
1416+
return merged
1417+
except orjson.JSONDecodeError:
1418+
LOGGER.warning("Failed to parse %s JSON, falling back to checked items", all_ids_key)
1419+
return checked_list
1420+
1421+
13861422
async def _has_permission(
13871423
*,
13881424
db: Session,
@@ -2531,39 +2567,13 @@ async def admin_add_server(request: Request, db: Session = Depends(get_db), user
25312567
try:
25322568
LOGGER.debug(f"User {get_user_email(user)} is adding a new server with name: {form['name']}")
25332569

2534-
# Handle "Select All" for tools
2535-
associated_tools_list = form.getlist("associatedTools")
2536-
if form.get("selectAllTools") == "true":
2537-
# User clicked "Select All" - get all tool IDs from hidden field
2538-
all_tool_ids_json = str(form.get("allToolIds", "[]"))
2539-
try:
2540-
all_tool_ids = orjson.loads(all_tool_ids_json)
2541-
associated_tools_list = all_tool_ids
2542-
LOGGER.info(f"Select All tools enabled: {len(all_tool_ids)} tools selected")
2543-
except orjson.JSONDecodeError:
2544-
LOGGER.warning("Failed to parse allToolIds JSON, falling back to checked tools")
2545-
2546-
# Handle "Select All" for resources
2547-
associated_resources_list = form.getlist("associatedResources")
2548-
if form.get("selectAllResources") == "true":
2549-
all_resource_ids_json = str(form.get("allResourceIds", "[]"))
2550-
try:
2551-
all_resource_ids = orjson.loads(all_resource_ids_json)
2552-
associated_resources_list = all_resource_ids
2553-
LOGGER.info(f"Select All resources enabled: {len(all_resource_ids)} resources selected")
2554-
except orjson.JSONDecodeError:
2555-
LOGGER.warning("Failed to parse allResourceIds JSON, falling back to checked resources")
2556-
2557-
# Handle "Select All" for prompts
2558-
associated_prompts_list = form.getlist("associatedPrompts")
2559-
if form.get("selectAllPrompts") == "true":
2560-
all_prompt_ids_json = str(form.get("allPromptIds", "[]"))
2561-
try:
2562-
all_prompt_ids = orjson.loads(all_prompt_ids_json)
2563-
associated_prompts_list = all_prompt_ids
2564-
LOGGER.info(f"Select All prompts enabled: {len(all_prompt_ids)} prompts selected")
2565-
except orjson.JSONDecodeError:
2566-
LOGGER.warning("Failed to parse allPromptIds JSON, falling back to checked prompts")
2570+
# Handle "Select All" for tools, resources, and prompts.
2571+
# _merge_select_all_ids takes the union of the server-fetched paginated IDs
2572+
# (allToolIds etc.) with the explicitly checked form values so that
2573+
# platform-public items visible in the UI are never silently dropped.
2574+
associated_tools_list = _merge_select_all_ids(form, "selectAllTools", "allToolIds", form.getlist("associatedTools"))
2575+
associated_resources_list = _merge_select_all_ids(form, "selectAllResources", "allResourceIds", form.getlist("associatedResources"))
2576+
associated_prompts_list = _merge_select_all_ids(form, "selectAllPrompts", "allPromptIds", form.getlist("associatedPrompts"))
25672577

25682578
# Handle OAuth 2.0 configuration (RFC 9728)
25692579
oauth_enabled = form.get("oauth_enabled") == "on"
@@ -2709,39 +2719,13 @@ async def admin_edit_server(
27092719

27102720
mod_metadata = MetadataCapture.extract_modification_metadata(request, user, 0)
27112721

2712-
# Handle "Select All" for tools
2713-
associated_tools_list = form.getlist("associatedTools")
2714-
if form.get("selectAllTools") == "true":
2715-
# User clicked "Select All" - get all tool IDs from hidden field
2716-
all_tool_ids_json = str(form.get("allToolIds", "[]"))
2717-
try:
2718-
all_tool_ids = orjson.loads(all_tool_ids_json)
2719-
associated_tools_list = all_tool_ids
2720-
LOGGER.info(f"Select All tools enabled for edit: {len(all_tool_ids)} tools selected")
2721-
except orjson.JSONDecodeError:
2722-
LOGGER.warning("Failed to parse allToolIds JSON, falling back to checked tools")
2723-
2724-
# Handle "Select All" for resources
2725-
associated_resources_list = form.getlist("associatedResources")
2726-
if form.get("selectAllResources") == "true":
2727-
all_resource_ids_json = str(form.get("allResourceIds", "[]"))
2728-
try:
2729-
all_resource_ids = orjson.loads(all_resource_ids_json)
2730-
associated_resources_list = all_resource_ids
2731-
LOGGER.info(f"Select All resources enabled for edit: {len(all_resource_ids)} resources selected")
2732-
except orjson.JSONDecodeError:
2733-
LOGGER.warning("Failed to parse allResourceIds JSON, falling back to checked resources")
2734-
2735-
# Handle "Select All" for prompts
2736-
associated_prompts_list = form.getlist("associatedPrompts")
2737-
if form.get("selectAllPrompts") == "true":
2738-
all_prompt_ids_json = str(form.get("allPromptIds", "[]"))
2739-
try:
2740-
all_prompt_ids = orjson.loads(all_prompt_ids_json)
2741-
associated_prompts_list = all_prompt_ids
2742-
LOGGER.info(f"Select All prompts enabled for edit: {len(all_prompt_ids)} prompts selected")
2743-
except orjson.JSONDecodeError:
2744-
LOGGER.warning("Failed to parse allPromptIds JSON, falling back to checked prompts")
2722+
# Handle "Select All" for tools, resources, and prompts.
2723+
# _merge_select_all_ids takes the union of the server-fetched paginated IDs
2724+
# (allToolIds etc.) with the explicitly checked form values so that
2725+
# platform-public items visible in the UI are never silently dropped.
2726+
associated_tools_list = _merge_select_all_ids(form, "selectAllTools", "allToolIds", form.getlist("associatedTools"))
2727+
associated_resources_list = _merge_select_all_ids(form, "selectAllResources", "allResourceIds", form.getlist("associatedResources"))
2728+
associated_prompts_list = _merge_select_all_ids(form, "selectAllPrompts", "allPromptIds", form.getlist("associatedPrompts"))
27452729

27462730
# Handle OAuth 2.0 configuration (RFC 9728)
27472731
oauth_enabled = form.get("oauth_enabled") == "on"
@@ -8288,14 +8272,19 @@ async def admin_get_all_tool_ids(
82888272
LOGGER.debug(f"Filtering tools by gateway IDs: {non_null_ids}")
82898273

82908274
# Build access conditions
8291-
# When team_id is specified, show ONLY items from that team (team-scoped view)
8292-
# Otherwise, show all accessible items (All Teams view)
8275+
# When team_id is specified, show items from that team plus all platform-public tools
8276+
# (visibility="public") so the "Select All" count and payload match what is actually
8277+
# visible in the edit UI. Public visibility is platform-wide regardless of team ownership.
8278+
# Otherwise, show all accessible items (All Teams view).
82938279
if team_id:
82948280
if team_id in team_ids:
82958281
# Apply visibility check: team/public resources + user's own resources (including private)
8282+
# Also include all platform-public tools so they can be associated with team-owned
8283+
# virtual servers.
82968284
team_access = [
82978285
and_(DbTool.team_id == team_id, DbTool.visibility.in_(["team", "public"])),
82988286
and_(DbTool.team_id == team_id, DbTool.owner_email == user_email),
8287+
DbTool.visibility == "public",
82998288
]
83008289
query = query.where(or_(*team_access))
83018290
LOGGER.debug(f"Filtering tool IDs by team_id: {team_id}")
@@ -9526,15 +9515,20 @@ async def admin_get_all_prompt_ids(
95269515
query = query.where(DbPrompt.enabled.is_(True))
95279516

95289517
# Build access conditions
9529-
# When team_id is specified, show ONLY items from that team (team-scoped view)
9530-
# Otherwise, show all accessible items (All Teams view)
9518+
# When team_id is specified, show items from that team plus all platform-public prompts
9519+
# (visibility="public") so the "Select All" count and payload match what is actually
9520+
# visible in the edit UI. Public visibility is platform-wide regardless of team ownership.
9521+
# Otherwise, show all accessible items (All Teams view).
95319522
if team_id:
9532-
# Team-specific view: only show prompts from the specified team
9523+
# Team-specific view: show prompts from the specified team plus platform-public prompts
95339524
if team_id in team_ids:
95349525
# Apply visibility check: team/public resources + user's own resources (including private)
9526+
# Also include all platform-public prompts so they can be associated with team-owned
9527+
# virtual servers.
95359528
team_access = [
95369529
and_(DbPrompt.team_id == team_id, DbPrompt.visibility.in_(["team", "public"])),
95379530
and_(DbPrompt.team_id == team_id, DbPrompt.owner_email == user_email),
9531+
DbPrompt.visibility == "public",
95389532
]
95399533
query = query.where(or_(*team_access))
95409534
LOGGER.debug(f"Filtering prompt IDs by team_id: {team_id}")
@@ -9606,15 +9600,20 @@ async def admin_get_all_resource_ids(
96069600
query = query.where(DbResource.enabled.is_(True))
96079601

96089602
# Build access conditions
9609-
# When team_id is specified, show ONLY items from that team (team-scoped view)
9610-
# Otherwise, show all accessible items (All Teams view)
9603+
# When team_id is specified, show items from that team plus all platform-public resources
9604+
# (visibility="public") so the "Select All" count and payload match what is actually
9605+
# visible in the edit UI. Public visibility is platform-wide regardless of team ownership.
9606+
# Otherwise, show all accessible items (All Teams view).
96119607
if team_id:
9612-
# Team-specific view: only show resources from the specified team
9608+
# Team-specific view: show resources from the specified team plus platform-public resources
96139609
if team_id in team_ids:
96149610
# Apply visibility check: team/public resources + user's own resources (including private)
9611+
# Also include all platform-public resources so they can be associated with team-owned
9612+
# virtual servers.
96159613
team_access = [
96169614
and_(DbResource.team_id == team_id, DbResource.visibility.in_(["team", "public"])),
96179615
and_(DbResource.team_id == team_id, DbResource.owner_email == user_email),
9616+
DbResource.visibility == "public",
96189617
]
96199618
query = query.where(or_(*team_access))
96209619
LOGGER.debug(f"Filtering resource IDs by team_id: {team_id}")

tests/unit/mcpgateway/test_admin.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -606,9 +606,10 @@ async def test_admin_add_server_select_all_parses_json(self, mock_register_serve
606606
assert result.status_code == 200
607607

608608
server_create = mock_register_server.call_args.args[1]
609-
assert server_create.associated_tools == ["tool-1", "tool-2"]
610-
assert server_create.associated_resources == ["res-1"]
611-
assert server_create.associated_prompts == ["prompt-1", "prompt-2"]
609+
# Merge of allToolIds + associatedTools (public items from UI are preserved)
610+
assert set(server_create.associated_tools) == {"tool-1", "tool-2", "tool-x"}
611+
assert set(server_create.associated_resources) == {"res-1", "res-x"}
612+
assert set(server_create.associated_prompts) == {"prompt-1", "prompt-2", "prompt-x"}
612613
assert server_create.oauth_enabled is True
613614
assert server_create.oauth_config["authorization_servers"] == ["https://idp.example.com"]
614615
assert server_create.oauth_config["scopes_supported"] == ["openid", "profile"]
@@ -880,9 +881,10 @@ async def test_admin_edit_server_select_all_parses_json(self, mock_update_server
880881
assert result.status_code == 200
881882

882883
server_update = mock_update_server.call_args[0][2]
883-
assert server_update.associated_tools == ["tool-1", "tool-2"]
884-
assert server_update.associated_resources == ["res-1"]
885-
assert server_update.associated_prompts == ["prompt-1", "prompt-2"]
884+
# Merge of allToolIds + associatedTools (public items from UI are preserved)
885+
assert set(server_update.associated_tools) == {"tool-1", "tool-2", "tool-x"}
886+
assert set(server_update.associated_resources) == {"res-1", "res-x"}
887+
assert set(server_update.associated_prompts) == {"prompt-1", "prompt-2", "prompt-x"}
886888

887889
@patch.object(ServerService, "update_server")
888890
async def test_admin_edit_server_select_all_json_decode_error(self, mock_update_server, mock_request, mock_db, monkeypatch):
@@ -10323,6 +10325,38 @@ async def test_admin_get_all_tool_ids_gateway_and_team_filters(monkeypatch, mock
1032310325
assert result["count"] == 0
1032410326

1032510327

10328+
@pytest.mark.asyncio
10329+
async def test_admin_get_all_tool_ids_team_scoped_includes_public(monkeypatch, mock_db):
10330+
"""When team_id is set, the SQL query must include a standalone visibility='public'
10331+
condition (not gated by team_id) so that platform-public tools from public MCP
10332+
servers appear in team-scoped Select All fetches and can be associated with
10333+
team-owned virtual servers. Regression test for issue #3446."""
10334+
import re
10335+
from sqlalchemy.dialects import sqlite as sqlite_dialect
10336+
10337+
setup_team_service(monkeypatch, ["team-1"])
10338+
mock_db.execute.return_value.all.return_value = []
10339+
10340+
await admin_get_all_tool_ids(
10341+
include_inactive=False,
10342+
gateway_id=None,
10343+
team_id="team-1",
10344+
db=mock_db,
10345+
user={"email": "user@example.com", "db": mock_db},
10346+
)
10347+
10348+
executed_query = mock_db.execute.call_args[0][0]
10349+
sql = str(executed_query.compile(dialect=sqlite_dialect.dialect(), compile_kwargs={"literal_binds": True}))
10350+
10351+
# A standalone `visibility = 'public'` condition must be present as a top-level
10352+
# OR alternative — not wrapped inside `team_id = '...' AND visibility IN (...)`.
10353+
# This is what makes platform-public tools visible to team-scoped queries.
10354+
assert re.search(r"tools\.visibility\s*=\s*'public'", sql), (
10355+
"Expected a standalone visibility='public' condition in team-scoped tool IDs query. "
10356+
"Platform-public tools must be accessible when associating with team-owned virtual servers."
10357+
)
10358+
10359+
1032610360
@pytest.mark.asyncio
1032710361
async def test_admin_get_all_prompt_ids_gateway_and_team_filters(monkeypatch, mock_db):
1032810362
"""Cover non-null gateway filters and team membership checks for prompt IDs helper."""

0 commit comments

Comments
 (0)