diff --git a/llama_stack_provider_trustyai_fms/detectors/base.py b/llama_stack_provider_trustyai_fms/detectors/base.py
index 6c4c4b9..75d3bd5 100644
--- a/llama_stack_provider_trustyai_fms/detectors/base.py
+++ b/llama_stack_provider_trustyai_fms/detectors/base.py
@@ -6,9 +6,10 @@
 import random
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from enum import Enum, auto
+from enum import Enum, StrEnum, auto
 from typing import Any, ClassVar, Dict, List, Optional, Tuple, cast
 from urllib.parse import urlparse
+import uuid
 
 import httpx
 from llama_stack.apis.inference import (
@@ -26,6 +27,11 @@
     ShieldStore,
     ViolationLevel,
 )
+try:
+    from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
+    _HAS_MODERATION = True
+except ImportError:
+    _HAS_MODERATION = False
 from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from ..config import (
@@ -39,6 +45,12 @@
 # Configure logging
 logger = logging.getLogger(__name__)
 
+if not _HAS_MODERATION:
+    logger.warning(
+        "llama-stack version does not support ModerationObject/ModerationObjectResults. "
+        "The /v1/openai/v1/moderations endpoint will not be available. "
+        "Upgrade to llama-stack >= 0.2.18 for moderation support."
+    )
 
 # Custom exceptions
 class DetectorError(Exception):
@@ -128,7 +140,6 @@ def to_dict(self) -> Dict[str, Any]:
             **({"metadata": self.metadata} if self.metadata else {}),
         }
 
-
 class BaseDetector(Safety, ShieldsProtocolPrivate, ABC):
     """Base class for all safety detectors"""
 
@@ -1792,7 +1803,105 @@ async def process_with_semaphore(orig_idx, message):
                     },
                 )
             )
+
+    if _HAS_MODERATION:
+
+        async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+            """
+            Runs moderation for each input message.
+            Returns a ModerationObject with one ModerationObjectResults per input.
+            """
+            texts = input  # Avoid shadowing the built-in 'input'
+            try:
+                # Cache model -> shield_id lookups for performance
+                if not hasattr(self, "_model_to_shield_id"):
+                    self._model_to_shield_id = {}
+                if model in self._model_to_shield_id:
+                    shield_id = self._model_to_shield_id[model]
+                else:
+                    shield_id = await self._get_shield_id_from_model(model)
+                    self._model_to_shield_id[model] = shield_id
+
+                messages = self._convert_input_to_messages(texts)
+                shield_response = await self.run_shield(shield_id, messages)
+                metadata = (
+                    shield_response.violation.metadata
+                    if shield_response.violation and shield_response.violation.metadata
+                    else {}
+                )
+                results_metadata = metadata.get("results", [])
+                # Index results by message_index for O(1) lookup
+                results_by_index = {r.get("message_index"): r for r in results_metadata}
+
+                moderation_results = []
+                for idx, msg in enumerate(messages):
+                    result = results_by_index.get(idx)
+                    categories = {}
+                    category_scores = {}
+                    category_applied_input_types = {}
+                    flagged = False
+                    if result:
+                        cat = result.get("detection_type")
+                        score = result.get("score")
+                        if isinstance(cat, str) and score is not None:
+                            is_violation = result.get("status") == "violation"
+                            categories[cat] = is_violation
+                            category_scores[cat] = float(score)
+                            category_applied_input_types[cat] = ["text"]
+                            flagged = is_violation
+                        meta = result
+                    else:
+                        meta = {}
+                    moderation_results.append(
+                        ModerationObjectResults(
+                            flagged=flagged,
+                            categories=categories,
+                            category_applied_input_types=category_applied_input_types,
+                            category_scores=category_scores,
+                            user_message=msg.content,
+                            metadata=meta,
+                        )
+                    )
+                return ModerationObject(
+                    id=str(uuid.uuid4()),
+                    model=model,
+                    results=moderation_results,
+                )
+            except Exception as e:
+                # Degrade gracefully: return one unflagged result per input with
+                # the error recorded in metadata instead of raising
+                input_list = [texts] if isinstance(texts, str) else texts
+                return ModerationObject(
+                    id=str(uuid.uuid4()),
+                    model=model,
+                    results=[
+                        ModerationObjectResults(
+                            flagged=False,
+                            categories={},
+                            category_applied_input_types={},
+                            category_scores={},
+                            user_message=msg if isinstance(msg, str) else getattr(msg, "content", str(msg)),
+                            metadata={"error": str(e), "status": "error"},
+                        )
+                        for msg in input_list
+                    ],
+                )
+
+    async def _get_shield_id_from_model(self, model: str) -> str:
+        """Map a model name to a shield_id via provider_resource_id."""
+        shields_response = await self.list_shields()
+        matching_shields = [
+            shield.identifier
+            for shield in shields_response.data
+            if shield.provider_resource_id == model
+        ]
+        if not matching_shields:
+            raise ValueError(
+                f"No shield found for model '{model}'. "
+                f"Available shields: {[s.identifier for s in shields_response.data]}"
+            )
+        if len(matching_shields) > 1:
+            raise ValueError(f"Multiple shields found for model '{model}': {matching_shields}")
+        return matching_shields[0]
+
+    def _convert_input_to_messages(self, texts: str | list[str]) -> List[Message]:
+        """Convert string input(s) to UserMessage objects."""
+        if isinstance(texts, str):
+            inputs = [texts]
+        else:
+            inputs = texts
+        return [UserMessage(content=text) for text in inputs]
 
     async def shutdown(self) -> None:
         """Cleanup resources"""
         logger.info(f"Provider {self._provider_id} shutting down")
diff --git a/pyproject.toml b/pyproject.toml
index a4b75a2..9428c9d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llama-stack-provider-trustyai-fms"
-version = "0.2.0"
+version = "0.2.1"
 description = "Remote safety provider for Llama Stack integrating FMS Guardrails Orchestrator and community detectors"
 authors = [
     {name = "GitHub: m-misiura"}
diff --git a/tests/test_moderation.py b/tests/test_moderation.py
new file mode 100644
index 0000000..3d64d13
--- /dev/null
+++ b/tests/test_moderation.py
@@ -0,0 +1,70 @@
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+
+@pytest.mark.asyncio
+async def test_run_moderation_flagged():
+    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+
+    provider = DetectorProvider(detectors={})
+    provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
+    provider._convert_input_to_messages = MagicMock(
+        return_value=[MagicMock(content="bad message"), MagicMock(content="good message")]
+    )
+
+    # Simulate a shield response with one flagged message and one clean one
+    class FakeViolation:
+        violation_level = "error"
+        user_message = "violation"
+        metadata = {
+            "results": [
+                {"message_index": 0, "detection_type": "LABEL_1", "score": 0.99, "status": "violation"},
+                {"message_index": 1, "detection_type": None, "score": None, "status": "pass"},
+            ]
+        }
+
+    class FakeShieldResponse:
+        violation = FakeViolation()
+
+    provider.run_shield = AsyncMock(return_value=FakeShieldResponse())
+
+    result = await provider.run_moderation(["bad message", "good message"], "test_model")
+    assert len(result.results) == 2
+    assert result.results[0].flagged is True
+    assert result.results[1].flagged is False
+    assert result.results[0].user_message == "bad message"
+    assert result.results[1].user_message == "good message"
+
+
+@pytest.mark.asyncio
+async def test_run_moderation_error():
+    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+
+    provider = DetectorProvider(detectors={})
+    provider._get_shield_id_from_model = AsyncMock(side_effect=Exception("fail"))
+    provider._convert_input_to_messages = MagicMock(return_value=[MagicMock(content="msg")])
+
+    result = await provider.run_moderation(["msg"], "test_model")
+    assert len(result.results) == 1
+    assert result.results[0].flagged is False
+    assert "fail" in result.results[0].metadata["error"]
+
+
+@pytest.mark.asyncio
+async def test_run_moderation_empty_input():
+    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+
+    provider = DetectorProvider(detectors={})
+    provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
+    provider._convert_input_to_messages = MagicMock(return_value=[])
+    provider.run_shield = AsyncMock()
+
+    result = await provider.run_moderation([], "test_model")
+    assert len(result.results) == 0
+
+
+@pytest.mark.asyncio
+async def test_run_moderation_single_string_input():
+    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+
+    provider = DetectorProvider(detectors={})
+    provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
+    provider._convert_input_to_messages = MagicMock(return_value=[MagicMock(content="one message")])
+    provider.run_shield = AsyncMock(
+        return_value=MagicMock(
+            violation=MagicMock(
+                metadata={
+                    "results": [
+                        {"message_index": 0, "detection_type": None, "score": None, "status": "pass"}
+                    ]
+                }
+            )
+        )
+    )
+
+    result = await provider.run_moderation("one message", "test_model")
+    assert len(result.results) == 1
+    assert result.results[0].user_message == "one message"
\ No newline at end of file
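
For reference, a minimal client-side sketch of how the new moderations route could be exercised once this patch is applied. This sketch is not part of the patch: the base URL, port, and the "granite-guardian" model name are illustrative assumptions, and the model value must match a registered shield's provider_resource_id so that _get_shield_id_from_model can resolve it. The endpoint path is the one named in the patch's warning message (/v1/openai/v1/moderations).

import httpx

BASE_URL = "http://localhost:8321"  # assumed llama-stack server address


def moderate(texts: str | list[str], model: str) -> dict:
    """POST to the OpenAI-compatible moderations route and return the JSON body."""
    response = httpx.post(
        f"{BASE_URL}/v1/openai/v1/moderations",
        json={"input": texts, "model": model},
        timeout=30.0,
    )
    response.raise_for_status()
    return response.json()


if __name__ == "__main__":
    # Assumes a shield whose provider_resource_id is "granite-guardian" is registered
    outcome = moderate(["bad message", "good message"], model="granite-guardian")
    for entry in outcome["results"]:
        # Each entry mirrors ModerationObjectResults: flagged, categories,
        # category_scores, category_applied_input_types, user_message, metadata
        print(entry["flagged"], entry.get("category_scores"))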