|
6 | 6 | import random |
7 | 7 | from abc import ABC, abstractmethod |
8 | 8 | from dataclasses import dataclass |
9 | | -from enum import Enum, auto |
| 9 | +from enum import Enum, StrEnum, auto |
10 | 10 | from typing import Any, ClassVar, Dict, List, Optional, Tuple, cast |
11 | 11 | from urllib.parse import urlparse |
| 12 | +import uuid |
12 | 13 | import httpx |
13 | 14 |
|
14 | 15 | from llama_stack.apis.inference import ( |
|
26 | 27 | ShieldStore, |
27 | 28 | ViolationLevel, |
28 | 29 | ) |
# Probe for the moderation API at import time: ModerationObject /
# ModerationObjectResults were added in llama-stack 0.2.18, so older
# installs must degrade gracefully instead of failing to import.
try:
    from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
except ImportError:
    _HAS_MODERATION = False
else:
    _HAS_MODERATION = True
29 | 35 | from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields |
30 | 36 | from llama_stack.providers.datatypes import ShieldsProtocolPrivate |
31 | 37 | from ..config import ( |
|
39 | 45 | # Configure logging |
40 | 46 | logger = logging.getLogger(__name__) |
41 | 47 |
|
# One-time startup notice: when the installed llama-stack predates the
# moderation API, the moderations endpoint is simply not exposed.
if not _HAS_MODERATION:
    logger.warning(
        "llama-stack version does not support ModerationObject/ModerationObjectResults. "
        "The /v1/openai/v1/moderations endpoint will not be available. "
        "Upgrade to llama-stack >= 0.2.18 for moderation support."
    )
42 | 54 |
|
43 | 55 | # Custom exceptions |
44 | 56 | class DetectorError(Exception): |
@@ -128,7 +140,6 @@ def to_dict(self) -> Dict[str, Any]: |
128 | 140 | **({"metadata": self.metadata} if self.metadata else {}), |
129 | 141 | } |
130 | 142 |
|
131 | | - |
132 | 143 | class BaseDetector(Safety, ShieldsProtocolPrivate, ABC): |
133 | 144 | """Base class for all safety detectors""" |
134 | 145 |
|
@@ -1792,7 +1803,105 @@ async def process_with_semaphore(orig_idx, message): |
1792 | 1803 | }, |
1793 | 1804 | ) |
1794 | 1805 | ) |
| 1806 | + if _HAS_MODERATION: |
| 1807 | + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: |
| 1808 | + """ |
| 1809 | + Runs moderation for each input message. |
| 1810 | + Returns a ModerationObject with one ModerationObjectResults per input. |
| 1811 | + """ |
| 1812 | + texts = input # Avoid shadowing the built-in 'input' |
| 1813 | + try: |
| 1814 | + # Shield ID caching for performance |
| 1815 | + if not hasattr(self, "_model_to_shield_id"): |
| 1816 | + self._model_to_shield_id = {} |
| 1817 | + if model in self._model_to_shield_id: |
| 1818 | + shield_id = self._model_to_shield_id[model] |
| 1819 | + else: |
| 1820 | + shield_id = await self._get_shield_id_from_model(model) |
| 1821 | + self._model_to_shield_id[model] = shield_id |
| 1822 | + |
| 1823 | + messages = self._convert_input_to_messages(texts) |
| 1824 | + shield_response = await self.run_shield(shield_id, messages) |
| 1825 | + metadata = shield_response.violation.metadata if shield_response.violation and shield_response.violation.metadata else {} |
| 1826 | + results_metadata = metadata.get("results", []) |
| 1827 | + # Index results by message_index for O(1) lookup |
| 1828 | + results_by_index = {r.get("message_index"): r for r in results_metadata} |
| 1829 | + moderation_results = [] |
| 1830 | + for idx, msg in enumerate(messages): |
| 1831 | + result = results_by_index.get(idx) |
| 1832 | + categories = {} |
| 1833 | + category_scores = {} |
| 1834 | + category_applied_input_types = {} |
| 1835 | + flagged = False |
| 1836 | + if result: |
| 1837 | + cat = result.get("detection_type") |
| 1838 | + score = result.get("score") |
| 1839 | + if isinstance(cat, str) and score is not None: |
| 1840 | + is_violation = result.get("status") == "violation" |
| 1841 | + categories[cat] = is_violation |
| 1842 | + category_scores[cat] = float(score) |
| 1843 | + category_applied_input_types[cat] = ["text"] |
| 1844 | + flagged = is_violation |
| 1845 | + meta = result |
| 1846 | + else: |
| 1847 | + meta = {} |
| 1848 | + moderation_results.append( |
| 1849 | + ModerationObjectResults( |
| 1850 | + flagged=flagged, |
| 1851 | + categories=categories, |
| 1852 | + category_applied_input_types=category_applied_input_types, |
| 1853 | + category_scores=category_scores, |
| 1854 | + user_message=msg.content, |
| 1855 | + metadata=meta, |
| 1856 | + ) |
| 1857 | + ) |
| 1858 | + return ModerationObject( |
| 1859 | + id=str(uuid.uuid4()), |
| 1860 | + model=model, |
| 1861 | + results=moderation_results, |
| 1862 | + ) |
| 1863 | + except Exception as e: |
| 1864 | + input_list = [texts] if isinstance(texts, str) else texts |
| 1865 | + return ModerationObject( |
| 1866 | + id=str(uuid.uuid4()), |
| 1867 | + model=model, |
| 1868 | + results=[ |
| 1869 | + ModerationObjectResults( |
| 1870 | + flagged=False, |
| 1871 | + categories={}, |
| 1872 | + category_applied_input_types={}, |
| 1873 | + category_scores={}, |
| 1874 | + user_message=msg if isinstance(msg, str) else getattr(msg, "content", str(msg)), |
| 1875 | + metadata={"error": str(e), "status": "error"}, |
| 1876 | + ) |
| 1877 | + for msg in input_list |
| 1878 | + ], |
| 1879 | + ) |
| 1880 | + |
| 1881 | + async def _get_shield_id_from_model(self, model: str) -> str: |
| 1882 | + """Map model name to shield_id using provider_resource_id.""" |
| 1883 | + shields_response = await self.list_shields() |
| 1884 | + matching_shields = [ |
| 1885 | + shield.identifier |
| 1886 | + for shield in shields_response.data |
| 1887 | + if shield.provider_resource_id == model |
| 1888 | + ] |
| 1889 | + if not matching_shields: |
| 1890 | + raise ValueError(f"No shield found for model '{model}'. Available shields: {[s.identifier for s in shields_response.data]}") |
| 1891 | + if len(matching_shields) > 1: |
| 1892 | + raise ValueError(f"Multiple shields found for model '{model}': {matching_shields}") |
| 1893 | + return matching_shields[0] |
| 1894 | + |
| 1895 | + |
| 1896 | + def _convert_input_to_messages(self, texts: str | list[str]) -> List[Message]: |
| 1897 | + """Convert string input(s) to UserMessage objects.""" |
| 1898 | + if isinstance(texts, str): |
| 1899 | + inputs = [texts] |
| 1900 | + else: |
| 1901 | + inputs = texts |
| 1902 | + return [UserMessage(content=text) for text in inputs] |
1795 | 1903 |
|
| 1904 | + |
1796 | 1905 | async def shutdown(self) -> None: |
1797 | 1906 | """Cleanup resources""" |
1798 | 1907 | logger.info(f"Provider {self._provider_id} shutting down") |
|
0 commit comments