Commit 923324f

Merge pull request trustyai-explainability#21 from m-misiura/lls-0.2.19-no-openai-in-moderations
Add run_moderation to the remote provider
2 parents 5b842ab + 1237545 commit 923324f

File tree

3 files changed: +182 -3 lines changed
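Context for the diff below: run_moderation backs the OpenAI-compatible moderations route that the provider logs about (/v1/openai/v1/moderations). A minimal client sketch, assuming a llama-stack server at localhost:8321 and a model name that matches a registered shield's provider_resource_id; the server address, model name, and exact payload shape are assumptions, not part of this commit:

# Hypothetical client-side check of the moderations route; host, port, model name,
# and payload shape are illustrative assumptions, not taken from this commit.
import httpx

BASE_URL = "http://localhost:8321"  # assumed llama-stack server address

resp = httpx.post(
    f"{BASE_URL}/v1/openai/v1/moderations",
    json={
        "input": ["bad message", "good message"],  # str or list[str], matching run_moderation's signature
        "model": "test_model",  # should correspond to a registered shield's provider_resource_id
    },
    timeout=30.0,
)
resp.raise_for_status()
for item in resp.json()["results"]:
    print(item["flagged"], item.get("category_scores"))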

llama_stack_provider_trustyai_fms/detectors/base.py

Lines changed: 111 additions & 2 deletions
@@ -6,9 +6,10 @@
 import random
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from enum import Enum, auto
+from enum import Enum, StrEnum, auto
 from typing import Any, ClassVar, Dict, List, Optional, Tuple, cast
 from urllib.parse import urlparse
+import uuid
 import httpx
 
 from llama_stack.apis.inference import (
@@ -26,6 +27,11 @@
2627
ShieldStore,
2728
ViolationLevel,
2829
)
30+
try:
31+
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
32+
_HAS_MODERATION = True
33+
except ImportError:
34+
_HAS_MODERATION = False
2935
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
3036
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
3137
from ..config import (
@@ -39,6 +45,12 @@
 # Configure logging
 logger = logging.getLogger(__name__)
 
+if not _HAS_MODERATION:
+    logger.warning(
+        "llama-stack version does not support ModerationObject/ModerationObjectResults. "
+        "The /v1/openai/v1/moderations endpoint will not be available. "
+        "Upgrade to llama-stack >= 0.2.18 for moderation support."
+    )
 
 # Custom exceptions
 class DetectorError(Exception):
@@ -128,7 +140,6 @@ def to_dict(self) -> Dict[str, Any]:
             **({"metadata": self.metadata} if self.metadata else {}),
         }
 
-
 class BaseDetector(Safety, ShieldsProtocolPrivate, ABC):
     """Base class for all safety detectors"""
 
@@ -1792,7 +1803,105 @@ async def process_with_semaphore(orig_idx, message):
                     },
                 )
             )
+    if _HAS_MODERATION:
+        async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+            """
+            Runs moderation for each input message.
+            Returns a ModerationObject with one ModerationObjectResults per input.
+            """
+            texts = input  # Avoid shadowing the built-in 'input'
+            try:
+                # Shield ID caching for performance
+                if not hasattr(self, "_model_to_shield_id"):
+                    self._model_to_shield_id = {}
+                if model in self._model_to_shield_id:
+                    shield_id = self._model_to_shield_id[model]
+                else:
+                    shield_id = await self._get_shield_id_from_model(model)
+                    self._model_to_shield_id[model] = shield_id
+
+                messages = self._convert_input_to_messages(texts)
+                shield_response = await self.run_shield(shield_id, messages)
+                metadata = shield_response.violation.metadata if shield_response.violation and shield_response.violation.metadata else {}
+                results_metadata = metadata.get("results", [])
+                # Index results by message_index for O(1) lookup
+                results_by_index = {r.get("message_index"): r for r in results_metadata}
+                moderation_results = []
+                for idx, msg in enumerate(messages):
+                    result = results_by_index.get(idx)
+                    categories = {}
+                    category_scores = {}
+                    category_applied_input_types = {}
+                    flagged = False
+                    if result:
+                        cat = result.get("detection_type")
+                        score = result.get("score")
+                        if isinstance(cat, str) and score is not None:
+                            is_violation = result.get("status") == "violation"
+                            categories[cat] = is_violation
+                            category_scores[cat] = float(score)
+                            category_applied_input_types[cat] = ["text"]
+                            flagged = is_violation
+                        meta = result
+                    else:
+                        meta = {}
+                    moderation_results.append(
+                        ModerationObjectResults(
+                            flagged=flagged,
+                            categories=categories,
+                            category_applied_input_types=category_applied_input_types,
+                            category_scores=category_scores,
+                            user_message=msg.content,
+                            metadata=meta,
+                        )
+                    )
+                return ModerationObject(
+                    id=str(uuid.uuid4()),
+                    model=model,
+                    results=moderation_results,
+                )
+            except Exception as e:
+                input_list = [texts] if isinstance(texts, str) else texts
+                return ModerationObject(
+                    id=str(uuid.uuid4()),
+                    model=model,
+                    results=[
+                        ModerationObjectResults(
+                            flagged=False,
+                            categories={},
+                            category_applied_input_types={},
+                            category_scores={},
+                            user_message=msg if isinstance(msg, str) else getattr(msg, "content", str(msg)),
+                            metadata={"error": str(e), "status": "error"},
+                        )
+                        for msg in input_list
+                    ],
+                )
+
+        async def _get_shield_id_from_model(self, model: str) -> str:
+            """Map model name to shield_id using provider_resource_id."""
+            shields_response = await self.list_shields()
+            matching_shields = [
+                shield.identifier
+                for shield in shields_response.data
+                if shield.provider_resource_id == model
+            ]
+            if not matching_shields:
+                raise ValueError(f"No shield found for model '{model}'. Available shields: {[s.identifier for s in shields_response.data]}")
+            if len(matching_shields) > 1:
+                raise ValueError(f"Multiple shields found for model '{model}': {matching_shields}")
+            return matching_shields[0]
+
+
+        def _convert_input_to_messages(self, texts: str | list[str]) -> List[Message]:
+            """Convert string input(s) to UserMessage objects."""
+            if isinstance(texts, str):
+                inputs = [texts]
+            else:
+                inputs = texts
+            return [UserMessage(content=text) for text in inputs]
 
+
     async def shutdown(self) -> None:
         """Cleanup resources"""
         logger.info(f"Provider {self._provider_id} shutting down")

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llama-stack-provider-trustyai-fms"
-version = "0.2.0"
+version = "0.2.1"
 description = "Remote safety provider for Llama Stack integrating FMS Guardrails Orchestrator and community detectors"
 authors = [
     {name = "GitHub: m-misiura"}

tests/test_moderation.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+@pytest.mark.asyncio
+async def test_run_moderation_flagged():
+    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+
+    provider = DetectorProvider(detectors={})
+    provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
+    provider._convert_input_to_messages = MagicMock(return_value=[
+        MagicMock(content="bad message"), MagicMock(content="good message")
+    ])
+    # Simulate shield_response with one flagged and one not
+    class FakeViolation:
+        violation_level = "error"
+        user_message = "violation"
+        metadata = {
+            "results": [
+                {"message_index": 0, "detection_type": "LABEL_1", "score": 0.99, "status": "violation"},
+                {"message_index": 1, "detection_type": None, "score": None, "status": "pass"},
+            ]
+        }
+    class FakeShieldResponse:
+        violation = FakeViolation()
+    provider.run_shield = AsyncMock(return_value=FakeShieldResponse())
+
+    result = await provider.run_moderation(["bad message", "good message"], "test_model")
+    assert len(result.results) == 2
+    assert result.results[0].flagged is True
+    assert result.results[1].flagged is False
+    assert result.results[0].user_message == "bad message"
+    assert result.results[1].user_message == "good message"
+
+@pytest.mark.asyncio
+async def test_run_moderation_error():
+    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+
+    provider = DetectorProvider(detectors={})
+    provider._get_shield_id_from_model = AsyncMock(side_effect=Exception("fail"))
+    provider._convert_input_to_messages = MagicMock(return_value=[MagicMock(content="msg")])
+
+    result = await provider.run_moderation(["msg"], "test_model")
+    assert len(result.results) == 1
+    assert result.results[0].flagged is False
+    assert "fail" in result.results[0].metadata["error"]
+
+@pytest.mark.asyncio
+async def test_run_moderation_empty_input():
+    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+    provider = DetectorProvider(detectors={})
+    provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
+    provider._convert_input_to_messages = MagicMock(return_value=[])
+    provider.run_shield = AsyncMock()
+    result = await provider.run_moderation([], "test_model")
+    assert len(result.results) == 0
+
+@pytest.mark.asyncio
+async def test_run_moderation_single_string_input():
+    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+    provider = DetectorProvider(detectors={})
+    provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
+    provider._convert_input_to_messages = MagicMock(return_value=[MagicMock(content="one message")])
+    provider.run_shield = AsyncMock(return_value=MagicMock(
+        violation=MagicMock(metadata={"results": [
+            {"message_index": 0, "detection_type": None, "score": None, "status": "pass"}
+        ]})
+    ))
+    result = await provider.run_moderation("one message", "test_model")
+    assert len(result.results) == 1
+    assert result.results[0].user_message == "one message"
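Note on running these tests: they use @pytest.mark.asyncio, so they need the pytest-asyncio plugin, and run_moderation only exists when the installed llama-stack exposes ModerationObject; neither dependency is pinned by this commit. Assuming both are present in the dev environment, something like the following should work:

python -m pytest tests/test_moderation.py -q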

0 commit comments
