diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 8ddfa60dd8..445546c02a 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -877,7 +877,7 @@ class ToolUpdate(BaseModelWithConfigDict): auth: Optional[AuthenticationValues] = Field(None, description="Authentication credentials (Basic or Bearer Token or custom headers) if required") gateway_id: Optional[str] = Field(None, description="id of gateway for the tool") tags: Optional[List[str]] = Field(None, description="Tags for categorizing the tool") - visibility: Optional[str] = Field(default="public", description="Visibility level: private, team, or public") + visibility: Optional[Literal["private", "team", "public"]] = Field(None, description="Visibility level: private, team, or public") # Passthrough REST fields base_url: Optional[str] = Field(None, description="Base URL for REST passthrough") diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index ac95ebb875..b8d78ddda7 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -487,6 +487,8 @@ def __init__(self, name: str, enabled: bool = True, tool_id: Optional[int] = Non self.tool_id = tool_id if visibility == "team": vis_label = "Team-level" + elif visibility == "private": + vis_label = "Private" else: vis_label = "Public" message = f"{vis_label} Tool already exists with name: {name}" @@ -4105,6 +4107,59 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head with perf_tracker.track_operation("tool_invocation", name): pass # Duration already captured above + @staticmethod + def _check_tool_name_conflict(db: Session, custom_name: str, visibility: str, tool_id: str, team_id: Optional[str] = None, owner_email: Optional[str] = None) -> None: + """Raise ToolNameConflictError if another tool with the same name exists in the target visibility scope. + + Args: + db: The SQLAlchemy database session. + custom_name: The custom name to check for conflicts. + visibility: The target visibility scope (``public``, ``team``, or ``private``). + tool_id: The ID of the tool being updated (excluded from the conflict search). + team_id: Required when *visibility* is ``team``; scopes the uniqueness check to this team. + owner_email: Required when *visibility* is ``private``; scopes the uniqueness check to this owner. + + Raises: + ToolNameConflictError: If a conflicting tool already exists in the target scope. + """ + if visibility == "public": + existing_tool = get_for_update( + db, + DbTool, + where=and_( + DbTool.custom_name == custom_name, + DbTool.visibility == "public", + DbTool.id != tool_id, + ), + ) + elif visibility == "team" and team_id: + existing_tool = get_for_update( + db, + DbTool, + where=and_( + DbTool.custom_name == custom_name, + DbTool.visibility == "team", + DbTool.team_id == team_id, + DbTool.id != tool_id, + ), + ) + elif visibility == "private" and owner_email: + existing_tool = get_for_update( + db, + DbTool, + where=and_( + DbTool.custom_name == custom_name, + DbTool.visibility == "private", + DbTool.owner_email == owner_email, + DbTool.id != tool_id, + ), + ) + else: + logger.warning("Skipping conflict check for tool %s: visibility=%r requires %s but none provided", tool_id, visibility, "team_id" if visibility == "team" else "owner_email") + return + if existing_tool: + raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility) + async def update_tool( self, db: Session, @@ -4175,40 +4230,29 @@ async def update_tool( if not await permission_service.check_resource_ownership(user_email, tool): raise PermissionError("Only the owner can update this tool") + # Track whether a name change occurred (before tool.name is mutated) + name_is_changing = bool(tool_update.name and tool_update.name != tool.name) + # Check for name change and ensure uniqueness - if tool_update.name and tool_update.name != tool.name: - # Check for existing tool with the same name and visibility - if tool_update.visibility.lower() == "public": - # Check for existing public tool with the same name (row-locked) - existing_tool = get_for_update( - db, - DbTool, - where=and_( - DbTool.custom_name == tool_update.custom_name, - DbTool.visibility == "public", - DbTool.id != tool.id, - ), - ) - if existing_tool: - raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility) - elif tool_update.visibility.lower() == "team" and tool_update.team_id: - # Check for existing team tool with the same name - existing_tool = get_for_update( - db, - DbTool, - where=and_( - DbTool.custom_name == tool_update.custom_name, - DbTool.visibility == "team", - DbTool.team_id == tool_update.team_id, - DbTool.id != tool.id, - ), - ) - if existing_tool: - raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility) + if name_is_changing: + # Always derive ownership fields from the DB record — never trust client-provided team_id/owner_email + tool_visibility_ref = tool.visibility if tool_update.visibility is None else tool_update.visibility.lower() + if tool_update.custom_name is not None: + custom_name_ref = tool_update.custom_name + elif tool.name == tool.custom_name: + custom_name_ref = tool_update.name # custom_name will track the rename + else: + custom_name_ref = tool.custom_name # custom_name stays unchanged + self._check_tool_name_conflict(db, custom_name_ref, tool_visibility_ref, tool.id, team_id=tool.team_id, owner_email=tool.owner_email) if tool_update.custom_name is None and tool.name == tool.custom_name: tool.custom_name = tool_update.name tool.name = tool_update.name + # Check for conflicts when visibility changes without a name change + if tool_update.visibility is not None and tool_update.visibility.lower() != tool.visibility and not name_is_changing: + new_visibility = tool_update.visibility.lower() + self._check_tool_name_conflict(db, tool.custom_name, new_visibility, tool.id, team_id=tool.team_id, owner_email=tool.owner_email) + if tool_update.custom_name is not None: tool.custom_name = tool_update.custom_name if tool_update.displayName is not None: @@ -4239,8 +4283,6 @@ async def update_tool( tool.auth_type = tool_update.auth.auth_type if tool_update.auth.auth_value is not None: tool.auth_value = tool_update.auth.auth_value - else: - tool.auth_type = None # Update tags if provided if tool_update.tags is not None: diff --git a/tests/unit/mcpgateway/services/test_tool_service_coverage.py b/tests/unit/mcpgateway/services/test_tool_service_coverage.py index bac078112c..cf27076070 100644 --- a/tests/unit/mcpgateway/services/test_tool_service_coverage.py +++ b/tests/unit/mcpgateway/services/test_tool_service_coverage.py @@ -182,6 +182,32 @@ def mock_tool(mock_gateway): return tool +def _make_tool_update(**overrides) -> MagicMock: + """Build a ToolUpdate MagicMock with all fields defaulting to None.""" + update = MagicMock(spec=ToolUpdate) + defaults = dict( + name=None, + custom_name=None, + displayName=None, + url=None, + description=None, + integration_type=None, + request_type=None, + headers=None, + input_schema=None, + output_schema=None, + annotations=None, + jsonpath_filter=None, + visibility=None, + auth=None, + tags=None, + ) + defaults.update(overrides) + for attr, value in defaults.items(): + setattr(update, attr, value) + return update + + # ═════════════════════════════════════════════════════════════════════════════ # 1. Module-level lazy singletons: _get_registry_cache, _get_tool_lookup_cache # ═════════════════════════════════════════════════════════════════════════════ @@ -4152,7 +4178,7 @@ async def test_update_multiple_fields(self, tool_service): tool.description = "old desc" tool.original_description = "old desc" tool.version = 3 - tool.team_id = None + tool.team_id = "team-1" tool.visibility = "public" tool_update = MagicMock(spec=ToolUpdate) @@ -4169,7 +4195,6 @@ async def test_update_multiple_fields(self, tool_service): tool_update.annotations = None tool_update.jsonpath_filter = None tool_update.visibility = "team" - tool_update.team_id = "team-1" tool_update.auth = None tool_update.tags = ["api", "v2"] @@ -4266,7 +4291,6 @@ async def test_team_name_conflict_on_update(self, tool_service): tool_update.annotations = None tool_update.jsonpath_filter = None tool_update.visibility = "team" - tool_update.team_id = "team-1" tool_update.auth = None tool_update.tags = None @@ -4281,6 +4305,344 @@ async def test_team_name_conflict_on_update(self, tool_service): with pytest.raises(ToolNameConflictError): await tool_service.update_tool(db, "t1", tool_update) + @pytest.mark.asyncio + async def test_update_tool_visibility_none_team_conflict_check(self, tool_service, mock_tool): + """When ToolUpdate.visibility is None and tool.visibility is team, should check team conflicts and allow successful update.""" + tool = mock_tool + tool.id = "t1" + tool.name = "old_name" + tool.custom_name = "old_name" + tool.team_id = "team-1" + tool.visibility = "team" + tool.version = 1 + + tool_update = _make_tool_update(name="conflict_name", custom_name="conflict_name", description="Updated description") + + # Test 1: Conflict detection - should raise error when name conflicts + existing = MagicMock() + existing.custom_name = "conflict_name" + existing.enabled = True + existing.id = "t2" + existing.visibility = "team" + + db = MagicMock() + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + # Test 2: Successful update - should work when no conflict and preserve visibility + tool_update.name = "new_name" + tool_update.custom_name = "new_name" + + with ( + patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, None]), + patch.object(tool_service, "_notify_tool_updated", AsyncMock()), + patch.object(tool_service, "convert_tool_to_read", return_value={"id": "t1", "name": "new_name", "visibility": "team"}), + ): + result = await tool_service.update_tool(db, "t1", tool_update) + + assert result is not None + assert tool.name == "new_name" + assert tool.custom_name == "new_name" + assert tool.description == "Updated description" + assert tool.visibility == "team" + assert tool.team_id == "team-1" + assert tool.version == 2 + + @pytest.mark.asyncio + async def test_update_tool_visibility_none_public_conflict_check(self, tool_service, mock_tool): + """When ToolUpdate.visibility is None and tool.visibility is public, should check public conflicts and allow successful update.""" + tool = mock_tool + tool.id = "t1" + tool.name = "old_name" + tool.custom_name = "old_name" + tool.team_id = None + tool.visibility = "public" + tool.version = 5 + + tool_update = _make_tool_update(name="conflict_name", custom_name="conflict_name", displayName="Updated Tool") + + # Test 1: Conflict detection + existing = MagicMock() + existing.custom_name = "conflict_name" + existing.enabled = True + existing.id = "t2" + existing.visibility = "public" + + db = MagicMock() + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + # Test 2: Successful update + tool_update.name = "new_public_name" + tool_update.custom_name = "new_public_name" + + with ( + patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, None]), + patch.object(tool_service, "_notify_tool_updated", AsyncMock()), + patch.object(tool_service, "convert_tool_to_read", return_value={"id": "t1", "name": "new_public_name", "visibility": "public"}), + ): + result = await tool_service.update_tool(db, "t1", tool_update) + + assert result is not None + assert tool.name == "new_public_name" + assert tool.custom_name == "new_public_name" + assert tool.display_name == "Updated Tool" + assert tool.visibility == "public" + assert tool.team_id is None + assert tool.version == 6 + + @pytest.mark.asyncio + async def test_update_tool_team_id_fallback_from_db(self, tool_service, mock_tool): + """When ToolUpdate.team_id is None, should fall back to tool.team_id for conflict check.""" + tool = mock_tool + tool.id = "t1" + tool.name = "old_name" + tool.custom_name = "old_name" + tool.team_id = "team-1" + tool.visibility = "team" + tool.version = 1 + + tool_update = _make_tool_update(name="conflict_name", custom_name="conflict_name") + + existing = MagicMock() + existing.custom_name = "conflict_name" + existing.enabled = True + existing.id = "t2" + existing.visibility = "team" + + db = MagicMock() + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + @pytest.mark.asyncio + async def test_update_tool_custom_name_fallback_in_conflict_check(self, tool_service, mock_tool): + """When ToolUpdate.custom_name is None, conflict check should use tool_update.name.""" + tool = mock_tool + tool.id = "t1" + tool.name = "old_name" + tool.custom_name = "old_name" + tool.team_id = None + tool.visibility = "public" + tool.version = 1 + + tool_update = _make_tool_update(name="conflict_name", visibility="public") + + existing = MagicMock() + existing.custom_name = "conflict_name" + existing.enabled = True + existing.id = "t2" + existing.visibility = "public" + + db = MagicMock() + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + @pytest.mark.asyncio + async def test_update_tool_visibility_change_conflict_check(self, tool_service, mock_tool): + """Changing visibility without changing name should still check for conflicts in the target scope.""" + tool = mock_tool + tool.id = "t1" + tool.name = "my_tool" + tool.custom_name = "my_tool" + tool.team_id = "team-1" + tool.visibility = "team" + tool.version = 1 + + tool_update = _make_tool_update(visibility="public") + + existing = MagicMock() + existing.custom_name = "my_tool" + existing.enabled = True + existing.id = "t2" + existing.visibility = "public" + + db = MagicMock() + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + @pytest.mark.asyncio + async def test_update_tool_visibility_change_to_team_conflict(self, tool_service, mock_tool): + """Changing visibility from public to team should check for team-scoped conflicts.""" + tool = mock_tool + tool.id = "t1" + tool.name = "my_tool" + tool.custom_name = "my_tool" + tool.team_id = "team-1" + tool.visibility = "public" + tool.version = 1 + + tool_update = _make_tool_update(visibility="team") + + existing = MagicMock() + existing.custom_name = "my_tool" + existing.enabled = True + existing.id = "t2" + existing.visibility = "team" + + db = MagicMock() + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + @pytest.mark.asyncio + async def test_update_tool_auth_not_reset_when_omitted(self, tool_service, mock_tool): + """When ToolUpdate.auth is None (not provided), existing auth_type should be preserved.""" + tool = mock_tool + tool.id = "t1" + tool.name = "my_tool" + tool.custom_name = "my_tool" + tool.team_id = None + tool.visibility = "public" + tool.auth_type = "bearer" + tool.auth_value = "secret-token" + tool.version = 1 + + tool_update = _make_tool_update(description="Updated description") + + db = MagicMock() + with ( + patch("mcpgateway.services.tool_service.get_for_update", return_value=tool), + patch.object(tool_service, "_notify_tool_updated", AsyncMock()), + patch.object(tool_service, "convert_tool_to_read", return_value={"id": "t1"}), + ): + result = await tool_service.update_tool(db, "t1", tool_update) + + assert result is not None + assert tool.auth_type == "bearer" # Should NOT have been wiped + assert tool.auth_value == "secret-token" + assert tool.description == "Updated description" + + @pytest.mark.asyncio + async def test_update_tool_private_visibility_name_conflict(self, tool_service, mock_tool): + """Renaming a private tool should check for conflicts scoped by owner_email.""" + tool = mock_tool + tool.id = "t1" + tool.name = "old_name" + tool.custom_name = "old_name" + tool.team_id = None + tool.visibility = "private" + tool.owner_email = "owner@example.com" + tool.version = 1 + + tool_update = _make_tool_update(name="conflict_name", custom_name="conflict_name") + + existing = MagicMock() + existing.custom_name = "conflict_name" + existing.enabled = True + existing.id = "t2" + existing.visibility = "private" + + db = MagicMock() + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + @pytest.mark.asyncio + async def test_update_tool_visibility_change_to_private_conflict(self, tool_service, mock_tool): + """Changing visibility to private should check for owner-scoped conflicts.""" + tool = mock_tool + tool.id = "t1" + tool.name = "my_tool" + tool.custom_name = "my_tool" + tool.team_id = None + tool.visibility = "public" + tool.owner_email = "owner@example.com" + tool.version = 1 + + tool_update = _make_tool_update(visibility="private") + + existing = MagicMock() + existing.custom_name = "my_tool" + existing.enabled = True + existing.id = "t2" + existing.visibility = "private" + + db = MagicMock() + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + @pytest.mark.asyncio + async def test_update_tool_rename_with_divergent_custom_name(self, tool_service, mock_tool): + """When custom_name differs from name, renaming should check conflict against the existing custom_name (which won't change).""" + tool = mock_tool + tool.id = "t1" + tool.name = "internal_name" + tool.custom_name = "display_name" # Diverges from name + tool.team_id = None + tool.visibility = "public" + tool.version = 1 + + # Rename the internal name — custom_name should stay "display_name" + tool_update = _make_tool_update(name="new_internal_name") + + # A tool with custom_name "display_name" already exists in public scope + existing = MagicMock() + existing.custom_name = "display_name" + existing.enabled = True + existing.id = "t2" + existing.visibility = "public" + + db = MagicMock() + # The conflict check should query for "display_name" (the unchanged custom_name), + # NOT "new_internal_name" + with patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, existing]): + with pytest.raises(ToolNameConflictError): + await tool_service.update_tool(db, "t1", tool_update) + + @pytest.mark.asyncio + async def test_update_tool_rename_with_divergent_custom_name_no_conflict(self, tool_service, mock_tool): + """When custom_name differs from name, renaming should succeed if custom_name has no conflict.""" + tool = mock_tool + tool.id = "t1" + tool.name = "internal_name" + tool.custom_name = "display_name" # Diverges from name + tool.team_id = None + tool.visibility = "public" + tool.version = 1 + + tool_update = _make_tool_update(name="new_internal_name") + + db = MagicMock() + with ( + patch("mcpgateway.services.tool_service.get_for_update", side_effect=[tool, None]), + patch.object(tool_service, "_notify_tool_updated", AsyncMock()), + patch.object(tool_service, "convert_tool_to_read", return_value={"id": "t1", "name": "new_internal_name"}), + ): + result = await tool_service.update_tool(db, "t1", tool_update) + + assert result is not None + assert tool.name == "new_internal_name" + assert tool.custom_name == "display_name" # Should NOT have changed + assert tool.version == 2 + + def test_check_tool_name_conflict_skips_team_without_team_id(self, tool_service): + """When visibility is 'team' but team_id is None, should log warning and skip conflict check.""" + db = MagicMock() + with patch("mcpgateway.services.tool_service.logger") as mock_logger: + tool_service._check_tool_name_conflict(db, "my_tool", "team", "t1", team_id=None, owner_email=None) + mock_logger.warning.assert_called_once() + assert mock_logger.warning.call_args[0][3] == "team_id" + + def test_check_tool_name_conflict_skips_private_without_owner_email(self, tool_service): + """When visibility is 'private' but owner_email is None, should log warning and skip conflict check.""" + db = MagicMock() + with patch("mcpgateway.services.tool_service.logger") as mock_logger: + tool_service._check_tool_name_conflict(db, "my_tool", "private", "t1", team_id=None, owner_email=None) + mock_logger.warning.assert_called_once() + assert mock_logger.warning.call_args[0][3] == "owner_email" + + def test_tool_name_conflict_error_private_label(self): + """ToolNameConflictError should label private conflicts as 'Private', not 'Public'.""" + err = ToolNameConflictError("my_tool", enabled=True, tool_id="t1", visibility="private") + assert "Private" in str(err) + assert "Public" not in str(err) + @pytest.mark.asyncio async def test_permission_check_on_update(self, tool_service): """Non-owner user gets PermissionError on update."""