Skip to content

Commit 9303f05

Browse files
authored
Merge pull request trustyai-explainability#5 from m-misiura/dynamic-shields-via-api
feat (RHOAIENG-27371) -- add dynamic shield registration
2 parents 66da86f + e651ded commit 9303f05

File tree

3 files changed

+151
-64
lines changed

3 files changed

+151
-64
lines changed

llama_stack_provider_trustyai_fms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def get_adapter_impl(
8888
if isinstance(detector, BaseDetector):
8989
detectors_for_provider[shield_id] = detector
9090

91-
return DetectorProvider(detectors_for_provider)
91+
return DetectorProvider(detectors_for_provider, config)
9292

9393
except Exception as e:
9494
raise DetectorConfigError(

llama_stack_provider_trustyai_fms/detectors/base.py

Lines changed: 149 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -680,8 +680,6 @@ async def update_shield_params(
680680
f"Shield store {self._store_id} updated params for shield {shield_id}: {params}"
681681
)
682682

683-
# In the SimpleShieldStore class, replace the initialize method:
684-
685683
async def initialize(self) -> None:
686684
"""Initialize store and process pending configurations"""
687685
if self._initialized:
@@ -830,6 +828,77 @@ async def get_shield(self, identifier: str) -> Shield:
830828
"Shields must have a valid detector configuration to ensure proper safety checks."
831829
)
832830

831+
async def create_dynamic_shield(
832+
self, shield_id: str, params: Dict[str, Any]
833+
) -> None:
834+
"""Create a dynamic shield configuration from API parameters"""
835+
from ..config import ContentDetectorConfig, ChatDetectorConfig, DetectorParams
836+
837+
logger.info(f"Creating dynamic shield configuration for: {shield_id}")
838+
839+
# Extract shield configuration from API params
840+
shield_type = params.get("type", "content")
841+
confidence_threshold = params.get("confidence_threshold", 0.5)
842+
message_types = params.get("message_types", ["system"])
843+
844+
# Create detector params
845+
detector_params = DetectorParams()
846+
847+
# Handle detectors configuration (like your email_hap example)
848+
if "detectors" in params:
849+
detector_params.detectors = params["detectors"]
850+
851+
# Set orchestrator URL from provider config if available
852+
orchestrator_url = None
853+
if hasattr(self, "_provider_config") and self._provider_config.orchestrator_url:
854+
orchestrator_url = self._provider_config.orchestrator_url
855+
856+
# Create appropriate config based on type
857+
if shield_type == "content":
858+
config = ContentDetectorConfig(
859+
detector_id=shield_id,
860+
confidence_threshold=confidence_threshold,
861+
message_types=set(message_types), # Convert to set as expected
862+
detector_params=detector_params,
863+
# Runtime parameters (all valid from BaseDetectorConfig)
864+
request_timeout=30.0,
865+
max_retries=3,
866+
backoff_factor=1.5,
867+
max_keepalive_connections=5,
868+
max_connections=10,
869+
max_concurrency=10,
870+
# URL configuration
871+
orchestrator_url=orchestrator_url,
872+
detector_url=None, # Will be set if needed
873+
auth_token=None,
874+
)
875+
elif shield_type == "chat":
876+
config = ChatDetectorConfig(
877+
detector_id=shield_id,
878+
confidence_threshold=confidence_threshold,
879+
message_types=set(message_types), # Convert to set as expected
880+
detector_params=detector_params,
881+
# Runtime parameters (all valid from BaseDetectorConfig)
882+
request_timeout=30.0,
883+
max_retries=3,
884+
backoff_factor=1.5,
885+
max_keepalive_connections=5,
886+
max_connections=10,
887+
max_concurrency=10,
888+
# URL configuration
889+
orchestrator_url=orchestrator_url,
890+
detector_url=None,
891+
auth_token=None,
892+
)
893+
else:
894+
raise DetectorValidationError(f"Unknown shield type: {shield_type}")
895+
896+
# Register the configuration
897+
await self.register_detector_config(shield_id, config)
898+
logger.info(
899+
f"Successfully created dynamic configuration for shield: {shield_id}"
900+
)
901+
833902
async def list_shields(self) -> ListShieldsResponse:
834903
"""List all registered shields with their parameters"""
835904
if not self._initialized:
@@ -882,9 +951,14 @@ async def list_shields(self) -> ListShieldsResponse:
882951
class DetectorProvider(Safety, Shields):
883952
"""Provider for managing safety detectors and shields"""
884953

885-
def __init__(self, detectors: Dict[str, BaseDetector]) -> None:
954+
def __init__(
955+
self, detectors: Dict[str, BaseDetector], config: Optional[Any] = None
956+
) -> None:
886957
self.detectors = detectors
887958
self._shield_store: ShieldStore = SimpleShieldStore()
959+
if config:
960+
self._shield_store._provider_config = config
961+
888962
self._shields: Dict[str, Shield] = {}
889963
self._initialized = False
890964
self._provider_id = id(self)
@@ -935,7 +1009,6 @@ def shield_store(self, value: ShieldStore) -> None:
9351009
register_method(detector.config.detector_id, detector.config)
9361010
)
9371011

938-
# Add this helper method to the DetectorProvider class
9391012
def _prepare_shield_for_storage(self, shield: Shield) -> dict:
9401013
"""Prepare shield for storage by ensuring all attributes are properly serialized"""
9411014
# Create a complete dictionary manually to ensure all fields are included
@@ -1014,8 +1087,6 @@ async def initialize(self) -> None:
10141087
logger.error(f"Provider {self._provider_id} initialization failed: {e}")
10151088
raise
10161089

1017-
# In DetectorProvider class
1018-
10191090
async def list_shields(self) -> ListShieldsResponse:
10201091
"""List all registered shields with their parameters"""
10211092
if not self._initialized:
@@ -1133,76 +1204,92 @@ async def get_shield(self, identifier: str) -> Shield:
11331204

11341205
raise DetectorValidationError(f"Failed to get shield: {identifier}")
11351206

1136-
# In DetectorProvider class
1137-
1138-
async def register_shield(
1139-
self,
1140-
shield_id: str,
1141-
provider_shield_id: Optional[str] = None,
1142-
provider_id: Optional[str] = None,
1143-
params: Optional[Dict[str, Any]] = None,
1144-
) -> Shield:
1145-
"""Register a new shield"""
1207+
async def register_shield(self, shield: Shield) -> Shield:
1208+
"""Register a shield dynamically from API request"""
11461209
if not self._initialized:
11471210
await self.initialize()
11481211

1149-
# Get the string identifier regardless of input type
1150-
if hasattr(shield_id, "identifier"):
1151-
shield_identifier = shield_id.identifier
1152-
else:
1153-
shield_identifier = str(shield_id)
1154-
1155-
logger.debug(f"Registering shield with ID: {shield_identifier}")
1212+
shield_identifier = shield.identifier
1213+
logger.info(
1214+
f"Provider {self._provider_id} registering shield: {shield_identifier}"
1215+
)
11561216

1157-
# Check if shield already exists by string identifier
1217+
# Check if shield already exists
11581218
if shield_identifier in self._shields:
1159-
shield = self._shields[shield_identifier]
1160-
logger.debug(
1161-
f"Shield {shield_identifier} already registered, returning existing instance with params: {shield.params}"
1219+
existing_shield = self._shields[shield_identifier]
1220+
logger.info(
1221+
f"Shield {shield_identifier} already exists, returning existing"
11621222
)
1163-
return shield
1223+
return existing_shield
11641224

1165-
# Get or create shield from store
1166-
shield = await self._shield_store.get_shield(shield_identifier)
1167-
if not shield:
1168-
raise DetectorValidationError(
1169-
f"Failed to create shield: {shield_identifier}"
1225+
try:
1226+
existing_shield = await self._shield_store.get_shield(shield_identifier)
1227+
if existing_shield:
1228+
self._shields[shield_identifier] = existing_shield
1229+
logger.info(f"Created shield {shield_identifier} from existing config")
1230+
return existing_shield
1231+
except DetectorValidationError:
1232+
# Config doesn't exist - this is expected for dynamic shields
1233+
logger.info(
1234+
f"No existing config for {shield_identifier}, creating from API params"
11701235
)
11711236

1172-
# Update fields if provided
1173-
if provider_id:
1174-
shield.provider_id = provider_id
1175-
if provider_shield_id:
1176-
shield.provider_resource_id = provider_shield_id
1177-
if params is not None:
1178-
shield.params = params
1237+
# dynamic shield creation
1238+
if shield.params:
1239+
logger.info(
1240+
f"Creating dynamic shield {shield_identifier} with params: {shield.params}"
1241+
)
11791242

1180-
# Ensure shield parameters exist even if not provided
1181-
if not shield.params:
1182-
detector = next(
1183-
(
1184-
d
1185-
for d in self.detectors.values()
1186-
if d.config.detector_id == shield_identifier
1187-
),
1188-
None,
1243+
# Create a detector config from the API parameters
1244+
await self._shield_store.create_dynamic_shield(
1245+
shield_identifier, shield.params
11891246
)
1190-
if detector:
1191-
shield.params = self._generate_shield_params(detector)
1192-
logger.debug(
1193-
f"Generated missing parameters for shield {shield_identifier}: {shield.params}"
1247+
1248+
try:
1249+
dynamic_shield = await self._shield_store.get_shield(shield_identifier)
1250+
self._shields[shield_identifier] = dynamic_shield
1251+
1252+
# CREATE AND REGISTER DETECTOR INSTANCE FOR DYNAMIC SHIELD
1253+
config = self._shield_store._detector_configs.get(shield_identifier)
1254+
if config:
1255+
# Import detector classes
1256+
from ..detectors.content import ContentDetector
1257+
from ..detectors.chat import ChatDetector
1258+
1259+
# Create detector instance based on type
1260+
if config.is_chat:
1261+
detector_instance = ChatDetector(config)
1262+
else:
1263+
detector_instance = ContentDetector(config)
1264+
1265+
# Initialize the detector
1266+
await detector_instance.initialize()
1267+
detector_instance.shield_store = self._shield_store
1268+
1269+
# Add to detectors dictionary
1270+
self.detectors[shield_identifier] = detector_instance
1271+
1272+
# Register shield with the detector
1273+
await detector_instance.register_shield(dynamic_shield)
1274+
1275+
logger.info(
1276+
f"Created and registered detector instance for dynamic shield: {shield_identifier}"
1277+
)
1278+
1279+
logger.info(f"Successfully created dynamic shield: {shield_identifier}")
1280+
return dynamic_shield
1281+
1282+
except Exception as e:
1283+
logger.error(
1284+
f"Failed to create dynamic shield {shield_identifier}: {e}"
1285+
)
1286+
raise DetectorValidationError(
1287+
f"Failed to create dynamic shield {shield_identifier}: {e}"
11941288
)
11951289

1196-
# Store shield by string identifier
1197-
self._shields[shield_identifier] = shield
1198-
logger.debug(
1199-
f"Shield {shield_identifier} registered with params: {shield.params}"
1290+
raise DetectorValidationError(
1291+
f"Cannot create shield '{shield_identifier}': no detector configuration found and no API parameters provided"
12001292
)
1201-
# Register with detectors
1202-
for detector in self.detectors.values():
1203-
await detector.register_shield(shield)
1204-
1205-
return shield
12061293

12071294
def _generate_shield_params(self, detector) -> Dict[str, Any]:
12081295
"""Generate shield parameters from detector config"""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "llama-stack-provider-trustyai-fms"
7-
version = "0.1.1"
7+
version = "0.1.2"
88
description = "Remote safety provider for Llama Stack integrating FMS Guardrails Orchestrator and community detectors"
99
authors = [
1010
{name = "GitHub: m-misiura"}

0 commit comments

Comments
 (0)