-
Notifications
You must be signed in to change notification settings - Fork 7
Add run_moderation to the remote provider #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,9 +6,10 @@ | |||||||||||
| import random | ||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||
| from dataclasses import dataclass | ||||||||||||
| from enum import Enum, auto | ||||||||||||
| from enum import Enum, StrEnum, auto | ||||||||||||
| from typing import Any, ClassVar, Dict, List, Optional, Tuple, cast | ||||||||||||
| from urllib.parse import urlparse | ||||||||||||
| import uuid | ||||||||||||
| import httpx | ||||||||||||
|
|
||||||||||||
| from llama_stack.apis.inference import ( | ||||||||||||
|
|
@@ -26,6 +27,11 @@ | |||||||||||
| ShieldStore, | ||||||||||||
| ViolationLevel, | ||||||||||||
| ) | ||||||||||||
| try: | ||||||||||||
| from llama_stack.apis.safety import ModerationObject, ModerationObjectResults | ||||||||||||
| _HAS_MODERATION = True | ||||||||||||
| except ImportError: | ||||||||||||
| _HAS_MODERATION = False | ||||||||||||
| from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields | ||||||||||||
| from llama_stack.providers.datatypes import ShieldsProtocolPrivate | ||||||||||||
| from ..config import ( | ||||||||||||
|
|
@@ -39,6 +45,12 @@ | |||||||||||
| # Configure logging | ||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||
|
|
||||||||||||
if not _HAS_MODERATION:
    # The guarded import of ModerationObject/ModerationObjectResults failed,
    # so say once, at import time, that moderation is unavailable.
    logger.warning(
        "llama-stack version does not support "
        "ModerationObject/ModerationObjectResults. "
        "The /v1/openai/v1/moderations endpoint "
        "will not be available. "
        "Upgrade to llama-stack >= 0.2.18 "
        "for moderation support."
    )
|
|
||||||||||||
| # Custom exceptions | ||||||||||||
| class DetectorError(Exception): | ||||||||||||
|
|
@@ -128,7 +140,6 @@ def to_dict(self) -> Dict[str, Any]: | |||||||||||
| **({"metadata": self.metadata} if self.metadata else {}), | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class BaseDetector(Safety, ShieldsProtocolPrivate, ABC): | ||||||||||||
| """Base class for all safety detectors""" | ||||||||||||
|
|
||||||||||||
|
|
@@ -1792,7 +1803,95 @@ async def process_with_semaphore(orig_idx, message): | |||||||||||
| }, | ||||||||||||
| ) | ||||||||||||
| ) | ||||||||||||
| if _HAS_MODERATION: | ||||||||||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
    """Run moderation over one or more input strings.

    ``model`` is resolved to a shield id with ``_get_shield_id_from_model``,
    the inputs are turned into messages with ``_convert_input_to_messages``,
    and ``run_shield`` is called once for all of them. Per-message
    detections are read from the ``"results"`` list of the violation
    metadata and matched to messages through their ``"message_index"``.

    Args:
        input: A single string or a list of strings to moderate.
        model: Model name used to look up the shield.

    Returns:
        A ``ModerationObject`` holding one ``ModerationObjectResults`` per
        message. If any step raises, the exception is logged and a fallback
        object is returned instead, with one unflagged result per input
        whose metadata is ``{"error": <message>, "status": "error"}``.

    Note:
        The fallback reports ``flagged=False``. Callers that must fail
        closed have to check ``metadata["status"] == "error"`` themselves.
    """
    try:
        shield_id = await self._get_shield_id_from_model(model)
        messages = self._convert_input_to_messages(input)
        shield_response = await self.run_shield(shield_id, messages)

        violation = shield_response.violation
        metadata = violation.metadata if violation and violation.metadata else {}
        results_metadata = metadata.get("results", [])

        moderation_results = []
        for idx, msg in enumerate(messages):
            # First entry reported for this message position, if any.
            result = next(
                (r for r in results_metadata if r.get("message_index") == idx),
                None,
            )
            categories: dict[str, bool] = {}
            category_scores: dict[str, float] = {}
            category_applied_input_types: dict[str, list[str]] = {}
            flagged = False
            # Bound on every path so a result without a usable
            # detection_type/score can never reuse the previous message's
            # metadata or hit an unbound name.
            meta = result if result else {}

            # NOTE(review): only a single detection_type/score pair is mapped
            # per message. A reviewer of this change suggested handling
            # several categories per message if the detector API can return
            # them; the result schema for that case is not defined here.
            if result:
                cat = result.get("detection_type")
                score = result.get("score")
                if isinstance(cat, str) and score is not None:
                    is_violation = result.get("status") == "violation"
                    categories[cat] = is_violation
                    category_scores[cat] = float(score)
                    category_applied_input_types[cat] = ["text"]
                    flagged = is_violation

            moderation_results.append(
                ModerationObjectResults(
                    flagged=flagged,
                    categories=categories,
                    category_applied_input_types=category_applied_input_types,
                    category_scores=category_scores,
                    user_message=msg.content,
                    metadata=meta,
                )
            )

        return ModerationObject(
            id=str(uuid.uuid4()),
            model=model,
            results=moderation_results,
        )
    except Exception as e:
        # Deliberate best-effort: the endpoint still answers, but the failure
        # is logged so an unflagged fallback is never produced silently.
        logger.exception(
            "run_moderation failed for model %r; returning unflagged fallback results",
            model,
        )
        input_list = [input] if isinstance(input, str) else input
        return ModerationObject(
            id=str(uuid.uuid4()),
            model=model,
            results=[
                ModerationObjectResults(
                    flagged=False,
                    categories={},
                    category_applied_input_types={},
                    category_scores={},
                    user_message=msg if isinstance(msg, str) else getattr(msg, "content", str(msg)),
                    metadata={"error": str(e), "status": "error"},
                )
                for msg in input_list
            ],
        )
|
|
||||||||||||
async def _get_shield_id_from_model(self, model: str) -> str:
    """Return the identifier of the one shield whose provider_resource_id is ``model``.

    Args:
        model: Model name to match against each shield's
            ``provider_resource_id``.

    Returns:
        The ``identifier`` of the single matching shield.

    Raises:
        ValueError: If no shield matches, or if more than one does.
    """
    listing = await self.list_shields()

    candidates = []
    for entry in listing.data:
        if entry.provider_resource_id == model:
            candidates.append(entry.identifier)

    if len(candidates) == 1:
        return candidates[0]
    if candidates:
        raise ValueError(f"Multiple shields found for model '{model}': {candidates}")

    available = [s.identifier for s in listing.data]
    raise ValueError(f"No shield found for model '{model}'. Available shields: {available}")
|
|
||||||||||||
| def _convert_input_to_messages(self, input: str | list[str]) -> List[Message]: | ||||||||||||
| """Convert string input(s) to UserMessage objects.""" | ||||||||||||
| if isinstance(input, str): | ||||||||||||
| inputs = [input] | ||||||||||||
| else: | ||||||||||||
| inputs = input | ||||||||||||
|
||||||||||||
| if isinstance(input, str): | |
| inputs = [input] | |
| else: | |
| inputs = input | |
| inputs = [input] if isinstance(input, str) else input |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sourcery-ai review
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure! I'm generating a new review now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @m-misiura, I've posted a new review for you!
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| import pytest | ||
| from unittest.mock import AsyncMock, MagicMock | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_run_moderation_flagged(): | ||
|
Comment on lines
+4
to
+5
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Missing test for empty input and single string input edge cases. Please add tests for empty list input and single string input to ensure run_moderation handles these cases correctly. |
||
| from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider | ||
|
|
||
| provider = DetectorProvider(detectors={}) | ||
| provider._get_shield_id_from_model = AsyncMock(return_value="test_shield") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Consider adding a test for multiple shields found for a model. Please add a test that triggers the multiple shields exception and verifies the error is correctly reflected in the moderation results metadata. Suggested implementation: import pytest
from unittest.mock import AsyncMock, MagicMock
class MultipleShieldsFoundError(Exception):
pass
@pytest.mark.asyncio
async def test_run_moderation_multiple_shields_error():
from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
provider = DetectorProvider(detectors={})
# Simulate multiple shields found by raising the error
provider._get_shield_id_from_model = AsyncMock(side_effect=MultipleShieldsFoundError("Multiple shields found for model"))
provider._convert_input_to_messages = MagicMock(return_value=[
MagicMock(content="test message")
])
# Run moderation and check error in metadata
result = await provider.run_moderation("test_model", "test input")
assert result["metadata"]["error"] == "Multiple shields found for model"
@pytest.mark.asyncio
async def test_run_moderation_flagged():
|
||
| provider._convert_input_to_messages = MagicMock(return_value=[ | ||
| MagicMock(content="bad message"), MagicMock(content="good message") | ||
| ]) | ||
| # Simulate shield_response with one flagged and one not | ||
| class FakeViolation: | ||
| violation_level = "error" | ||
| user_message = "violation" | ||
| metadata = { | ||
| "results": [ | ||
| {"message_index": 0, "detection_type": "LABEL_1", "score": 0.99, "status": "violation"}, | ||
| {"message_index": 1, "detection_type": None, "score": None, "status": "pass"}, | ||
| ] | ||
| } | ||
| class FakeShieldResponse: | ||
| violation = FakeViolation() | ||
| provider.run_shield = AsyncMock(return_value=FakeShieldResponse()) | ||
|
|
||
| result = await provider.run_moderation(["bad message", "good message"], "test_model") | ||
| assert len(result.results) == 2 | ||
Check notice: Code scanning / Bandit. Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
| assert result.results[0].flagged is True | ||
Check notice: Code scanning / Bandit. Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
| assert result.results[1].flagged is False | ||
Check notice: Code scanning / Bandit. Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
| assert result.results[0].user_message == "bad message" | ||
Check notice: Code scanning / Bandit. Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
| assert result.results[1].user_message == "good message" | ||
Check notice: Code scanning / Bandit. Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_run_moderation_error(): | ||
| from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider | ||
|
|
||
| provider = DetectorProvider(detectors={}) | ||
| provider._get_shield_id_from_model = AsyncMock(side_effect=Exception("fail")) | ||
| provider._convert_input_to_messages = MagicMock(return_value=[MagicMock(content="msg")]) | ||
|
|
||
| result = await provider.run_moderation(["msg"], "test_model") | ||
| assert len(result.results) == 1 | ||
Check notice: Code scanning / Bandit. Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
| assert result.results[0].flagged is False | ||
|
||
| assert "fail" in result.results[0].metadata["error"] | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (code-quality): We've found these issues:
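- Move assignment closer to its usage within a block (move-assign-in-block)
- Use named expression to simplify assignment and conditional (use-named-expression)
- Low code quality found (low-code-quality)

Explanation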
The quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.
How can you solve this?
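It might be worth refactoring this function to make it shorter and more readable.
- Reduce the function length by extracting pieces of functionality out into
  their own functions. This is the most important thing you can do - ideally a
  function should be less than 10 lines.
- Reduce nesting, perhaps by introducing guard clauses to return early.
- Ensure that variables are tightly scoped, so that code using related concepts
  sits together within the function rather than being scattered.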