Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 53 additions & 13 deletions mcpgateway/services/prompt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,22 +440,30 @@ async def register_prompt(
PromptError: For other prompt registration errors

Examples:
>>> import logging
>>> logging.disable(logging.CRITICAL)
>>> from mcpgateway.services.prompt_service import PromptService
>>> from unittest.mock import MagicMock
>>> from unittest.mock import AsyncMock, MagicMock
>>> service = PromptService()
>>> db = MagicMock()
>>> prompt = MagicMock()
>>> prompt.template = "Hello {{ name }}"
>>> prompt.name = "test-prompt"
>>> prompt.custom_name = None
>>> prompt.display_name = None
>>> prompt.arguments = []
>>> db.execute.return_value.scalar_one_or_none.return_value = None
>>> db.add = MagicMock()
>>> db.commit = MagicMock()
>>> db.refresh = MagicMock()
>>> service._notify_prompt_added = MagicMock()
>>> service._notify_prompt_added = AsyncMock()
>>> service.convert_prompt_to_read = MagicMock(return_value={})
>>> import asyncio
>>> try:
... asyncio.run(service.register_prompt(db, prompt))
... except Exception:
... pass
>>> logging.disable(logging.NOTSET)
"""
try:
# Validate template syntax
Expand Down Expand Up @@ -690,16 +698,31 @@ async def register_prompts_bulk(
PromptError: If bulk registration fails critically

Examples:
>>> import logging
>>> logging.disable(logging.CRITICAL)
>>> from mcpgateway.services.prompt_service import PromptService
>>> from unittest.mock import MagicMock
>>> service = PromptService()
>>> db = MagicMock()
>>> prompts = [MagicMock(), MagicMock()]
>>> p1 = MagicMock()
>>> p1.name = "prompt-1"
>>> p1.template = "Hello"
>>> p1.custom_name = None
>>> p1.display_name = None
>>> p1.arguments = []
>>> p2 = MagicMock()
>>> p2.name = "prompt-2"
>>> p2.template = "World"
>>> p2.custom_name = None
>>> p2.display_name = None
>>> p2.arguments = []
>>> prompts = [p1, p2]
>>> import asyncio
>>> try:
... result = asyncio.run(service.register_prompts_bulk(db, prompts))
... except Exception:
... pass
>>> logging.disable(logging.NOTSET)
"""
if not prompts:
return {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
Expand Down Expand Up @@ -1721,20 +1744,31 @@ async def update_prompt(
PromptError: For other update errors

Examples:
>>> import logging
>>> logging.disable(logging.CRITICAL)
>>> from mcpgateway.services.prompt_service import PromptService
>>> from unittest.mock import MagicMock
>>> from unittest.mock import AsyncMock, MagicMock
>>> service = PromptService()
>>> db = MagicMock()
>>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock()
>>> existing = MagicMock()
>>> existing.custom_name = "test-prompt"
>>> existing.name = "test-prompt"
>>> existing.gateway = None
>>> db.execute.return_value.scalar_one_or_none.return_value = existing
>>> db.commit = MagicMock()
>>> db.refresh = MagicMock()
>>> service._notify_prompt_updated = MagicMock()
>>> service._notify_prompt_updated = AsyncMock()
>>> service.convert_prompt_to_read = MagicMock(return_value={})
>>> update = MagicMock()
>>> update.name = None
>>> update.visibility = None
>>> update.team_id = None
>>> import asyncio
>>> try:
... asyncio.run(service.update_prompt(db, 'prompt_name', MagicMock()))
... asyncio.run(service.update_prompt(db, 'prompt_name', update))
... except Exception:
... pass
>>> logging.disable(logging.NOTSET)
"""
try:
# Acquire a row-level lock for the prompt being updated to make
Expand Down Expand Up @@ -1979,22 +2013,25 @@ async def set_prompt_state(self, db: Session, prompt_id: int, activate: bool, us
PermissionError: If user doesn't own the prompt.

Examples:
>>> import logging
>>> logging.disable(logging.CRITICAL)
>>> from mcpgateway.services.prompt_service import PromptService
>>> from unittest.mock import MagicMock
>>> from unittest.mock import AsyncMock, MagicMock
>>> service = PromptService()
>>> db = MagicMock()
>>> prompt = MagicMock()
>>> db.get.return_value = prompt
>>> db.commit = MagicMock()
>>> db.refresh = MagicMock()
>>> service._notify_prompt_activated = MagicMock()
>>> service._notify_prompt_deactivated = MagicMock()
>>> service._notify_prompt_activated = AsyncMock()
>>> service._notify_prompt_deactivated = AsyncMock()
>>> service.convert_prompt_to_read = MagicMock(return_value={})
>>> import asyncio
>>> try:
... asyncio.run(service.set_prompt_state(db, 1, True))
... result = asyncio.run(service.set_prompt_state(db, 1, True))
... except Exception:
... pass
>>> logging.disable(logging.NOTSET)
"""
try:
# Use nowait=True to fail fast if row is locked, preventing lock contention under high load
Expand Down Expand Up @@ -2174,20 +2211,23 @@ async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_emai
Exception: For unexpected errors.

Examples:
>>> import logging
>>> logging.disable(logging.CRITICAL)
>>> from mcpgateway.services.prompt_service import PromptService
>>> from unittest.mock import MagicMock
>>> from unittest.mock import AsyncMock, MagicMock
>>> service = PromptService()
>>> db = MagicMock()
>>> prompt = MagicMock()
>>> db.get.return_value = prompt
>>> db.delete = MagicMock()
>>> db.commit = MagicMock()
>>> service._notify_prompt_deleted = MagicMock()
>>> service._notify_prompt_deleted = AsyncMock()
>>> import asyncio
>>> try:
... asyncio.run(service.delete_prompt(db, '123'))
... except Exception:
... pass
>>> logging.disable(logging.NOTSET)
"""
try:
prompt = db.get(DbPrompt, prompt_id)
Expand Down
3 changes: 3 additions & 0 deletions mcpgateway/services/root_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,12 +457,15 @@ async def _notify_subscribers(self, event: Dict) -> None:
True

Test error handling with closed queue:
>>> import logging
>>> logging.disable(logging.CRITICAL)
>>> from unittest.mock import AsyncMock
>>> service = RootService()
>>> bad_queue = AsyncMock()
>>> bad_queue.put.side_effect = Exception("Queue error")
>>> service._subscribers.append(bad_queue)
>>> asyncio.run(service._notify_subscribers({"type": "test"}))
>>> logging.disable(logging.NOTSET)
"""
for queue in self._subscribers:
try:
Expand Down
3 changes: 3 additions & 0 deletions mcpgateway/toolops/utils/db_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,16 @@ def query_tool_auth(tool_id, db: Session):
decoded-encoded-val

>>> # Case 2: Exception Handling
>>> import logging
>>> logging.disable(logging.CRITICAL)
>>> mock_db_fail = MagicMock()
>>> mock_db_fail.query.side_effect = Exception("DB Connection Error")

>>> with patch(f"{mod_path}.Tool"):
... auth = query_tool_auth("tool-2", mock_db_fail)
... print(auth)
None
>>> logging.disable(logging.NOTSET)
"""
tool_auth = None
try:
Expand Down
Loading