Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions llama_stack_provider_trustyai_fms/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from enum import Enum, StrEnum, auto
from typing import Any, ClassVar, Dict, List, Optional, Tuple, cast
from urllib.parse import urlparse
import uuid
import httpx

from llama_stack.apis.inference import (
Expand All @@ -26,6 +27,11 @@
ShieldStore,
ViolationLevel,
)
try:
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
_HAS_MODERATION = True
except ImportError:
_HAS_MODERATION = False
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from ..config import (
Expand All @@ -39,6 +45,12 @@
# Configure logging
logger = logging.getLogger(__name__)

# Warn once at import time when the installed llama-stack is too old to expose
# the moderation API types guarded by the _HAS_MODERATION import check above.
if not _HAS_MODERATION:
    logger.warning(
        "llama-stack version does not support ModerationObject/ModerationObjectResults. "
        "The /v1/openai/v1/moderations endpoint will not be available. "
        "Upgrade to llama-stack >= 0.2.18 for moderation support."
    )

# Custom exceptions
class DetectorError(Exception):
Expand Down Expand Up @@ -128,7 +140,6 @@ def to_dict(self) -> Dict[str, Any]:
**({"metadata": self.metadata} if self.metadata else {}),
}


class BaseDetector(Safety, ShieldsProtocolPrivate, ABC):
"""Base class for all safety detectors"""

Expand Down Expand Up @@ -1792,7 +1803,95 @@ async def process_with_semaphore(orig_idx, message):
},
)
)
if _HAS_MODERATION:

    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
        """Run moderation over each input message.

        Resolves *model* to a registered shield, runs the shield over the
        converted messages, and folds the per-message detector results into
        one ModerationObjectResults per input.

        Args:
            input: A single text or a list of texts to moderate.
            model: Model name; mapped to a shield via provider_resource_id.

        Returns:
            A ModerationObject with one result per input message. On any
            failure a non-flagged fallback result (with the error recorded
            in metadata) is returned for each input instead of raising.
        """
        try:
            shield_id = await self._get_shield_id_from_model(model)
            messages = self._convert_input_to_messages(input)
            shield_response = await self.run_shield(shield_id, messages)
            metadata = (
                shield_response.violation.metadata
                if shield_response.violation and shield_response.violation.metadata
                else {}
            )
            results_metadata = metadata.get("results", [])
            moderation_results = [
                self._moderation_result_for_message(idx, msg, results_metadata)
                for idx, msg in enumerate(messages)
            ]
            return ModerationObject(
                id=str(uuid.uuid4()),
                model=model,
                results=moderation_results,
            )
        except Exception as e:
            # On error, return a safe (non-flagged) fallback for each input
            # rather than propagating, so callers always get a ModerationObject.
            input_list = [input] if isinstance(input, str) else input
            return ModerationObject(
                id=str(uuid.uuid4()),
                model=model,
                results=[
                    ModerationObjectResults(
                        flagged=False,
                        categories={},
                        category_applied_input_types={},
                        category_scores={},
                        user_message=msg if isinstance(msg, str) else getattr(msg, "content", str(msg)),
                        metadata={"error": str(e), "status": "error"},
                    )
                    for msg in input_list
                ],
            )

    def _moderation_result_for_message(
        self, idx: int, msg: Any, results_metadata: List[Dict[str, Any]]
    ) -> ModerationObjectResults:
        """Build the per-message moderation result from detector metadata.

        Looks up the detector result whose "message_index" matches *idx*;
        a message with no matching result is reported as not flagged with
        empty category maps.
        """
        result = next((r for r in results_metadata if r.get("message_index") == idx), None)
        categories: Dict[str, bool] = {}
        category_scores: Dict[str, float] = {}
        category_applied_input_types: Dict[str, List[str]] = {}
        flagged = False
        meta: Dict[str, Any] = {}
        if result:
            cat = result.get("detection_type")
            score = result.get("score")
            # NOTE(review): only a single detection_type/score is surfaced per
            # message here; if the detector API can report multiple categories
            # for one message, this needs to iterate over all of them - confirm
            # against the orchestrator response schema.
            if isinstance(cat, str) and score is not None:
                is_violation = result.get("status") == "violation"
                categories[cat] = is_violation
                category_scores[cat] = float(score)
                category_applied_input_types[cat] = ["text"]
                flagged = is_violation
            meta = result
        return ModerationObjectResults(
            flagged=flagged,
            categories=categories,
            category_applied_input_types=category_applied_input_types,
            category_scores=category_scores,
            user_message=msg.content,
            metadata=meta,
        )

async def _get_shield_id_from_model(self, model: str) -> str:
"""Map model name to shield_id using provider_resource_id."""
shields_response = await self.list_shields()
matching_shields = [
shield.identifier
for shield in shields_response.data
if shield.provider_resource_id == model
]
if not matching_shields:
raise ValueError(f"No shield found for model '{model}'. Available shields: {[s.identifier for s in shields_response.data]}")
if len(matching_shields) > 1:
raise ValueError(f"Multiple shields found for model '{model}': {matching_shields}")
return matching_shields[0]

def _convert_input_to_messages(self, input: str | list[str]) -> List[Message]:
    """Convert raw string input(s) into UserMessage objects.

    Args:
        input: A single text or a list of texts.

    Returns:
        One UserMessage per input text, in order.
    """
    # Normalize a bare string to a one-element list so both forms share a path.
    inputs = [input] if isinstance(input, str) else input
    return [UserMessage(content=text) for text in inputs]


async def shutdown(self) -> None:
"""Cleanup resources"""
logger.info(f"Provider {self._provider_id} shutting down")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "llama-stack-provider-trustyai-fms"
version = "0.2.0"
version = "0.2.1"
description = "Remote safety provider for Llama Stack integrating FMS Guardrails Orchestrator and community detectors"
authors = [
{name = "GitHub: m-misiura"}
Expand Down
45 changes: 45 additions & 0 deletions tests/test_moderation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
from unittest.mock import AsyncMock, MagicMock

@pytest.mark.asyncio
async def test_run_moderation_flagged():
    """One flagged and one clean message produce matching per-message results."""
    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider

    provider = DetectorProvider(detectors={})
    provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
    provider._convert_input_to_messages = MagicMock(return_value=[
        MagicMock(content="bad message"), MagicMock(content="good message")
    ])

    # Simulate a shield response where message 0 is a violation and message 1 passes.
    class FakeViolation:
        violation_level = "error"
        user_message = "violation"
        metadata = {
            "results": [
                {"message_index": 0, "detection_type": "LABEL_1", "score": 0.99, "status": "violation"},
                {"message_index": 1, "detection_type": None, "score": None, "status": "pass"},
            ]
        }

    class FakeShieldResponse:
        violation = FakeViolation()

    provider.run_shield = AsyncMock(return_value=FakeShieldResponse())

    result = await provider.run_moderation(["bad message", "good message"], "test_model")
    assert len(result.results) == 2
    assert result.results[0].flagged is True
    assert result.results[1].flagged is False
    assert result.results[0].user_message == "bad message"
    assert result.results[1].user_message == "good message"

@pytest.mark.asyncio
async def test_run_moderation_error():
    """A shield-resolution failure yields a non-flagged fallback with the error in metadata."""
    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider

    provider = DetectorProvider(detectors={})
    provider._get_shield_id_from_model = AsyncMock(side_effect=Exception("fail"))
    provider._convert_input_to_messages = MagicMock(return_value=[MagicMock(content="msg")])

    result = await provider.run_moderation(["msg"], "test_model")
    assert len(result.results) == 1
    assert result.results[0].flagged is False
    assert "fail" in result.results[0].metadata["error"]
Loading