Skip to content

Commit 3b7c1eb

Browse files
authored
Merge pull request trustyai-explainability#56 from m-misiura/fix-detector-url-dynamic-shield
Add unregister shield method and attempt to fix bugs in dynamic shield registration
2 parents 6db7a75 + dc7b9c7 commit 3b7c1eb

File tree

1 file changed

+39
-4
lines changed
  • llama_stack_provider_trustyai_fms/detectors

1 file changed

+39
-4
lines changed

llama_stack_provider_trustyai_fms/detectors/base.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,8 +837,11 @@ async def create_dynamic_shield(
837837
confidence_threshold = params.get("confidence_threshold", 0.5)
838838
message_types = params.get("message_types", ["system"])
839839

840-
# Create detector params
841-
detector_params = DetectorParams()
840+
# Create detector params from API params
841+
detector_params_dict = params.get("detector_params", {})
842+
if not isinstance(detector_params_dict, dict):
843+
detector_params_dict = {}
844+
detector_params = DetectorParams(**detector_params_dict)
842845

843846
# Handle detectors configuration (like your email_hap example)
844847
if "detectors" in params:
@@ -865,7 +868,7 @@ async def create_dynamic_shield(
865868
max_concurrency=10,
866869
# URL configuration
867870
orchestrator_url=orchestrator_url,
868-
detector_url=None, # Will be set if needed
871+
detector_url=params.get("detector_url"),
869872
auth_token=params.get("auth_token"),
870873
verify_ssl=params.get("verify_ssl", True),
871874
ssl_cert_path=params.get("ssl_cert_path"),
@@ -887,7 +890,7 @@ async def create_dynamic_shield(
887890
max_concurrency=10,
888891
# URL configuration
889892
orchestrator_url=orchestrator_url,
890-
detector_url=None,
893+
detector_url=params.get("detector_url"),
891894
auth_token=params.get("auth_token"),
892895
verify_ssl=params.get("verify_ssl", True),
893896
ssl_cert_path=params.get("ssl_cert_path"),
@@ -1299,6 +1302,38 @@ async def register_shield(self, shield: Shield) -> Shield:
12991302
f"Cannot create shield '{shield_identifier}': no detector configuration found and no API parameters provided"
13001303
)
13011304

1305+
async def unregister_shield(self, identifier: str) -> None:
1306+
"""Unregister a shield and remove it from the provider"""
1307+
logger.info(f"Provider {self._provider_id} unregistering shield: {identifier}")
1308+
1309+
# Remove from provider's shields dictionary
1310+
if identifier in self._shields:
1311+
del self._shields[identifier]
1312+
logger.info(f"Removed shield {identifier} from provider's _shields")
1313+
1314+
# Remove from shield store
1315+
if (
1316+
hasattr(self._shield_store, "_shields")
1317+
and identifier in self._shield_store._shields
1318+
):
1319+
del self._shield_store._shields[identifier]
1320+
logger.info(f"Removed shield {identifier} from shield store's _shields")
1321+
1322+
# Remove detector config from shield store
1323+
if (
1324+
hasattr(self._shield_store, "_detector_configs")
1325+
and identifier in self._shield_store._detector_configs
1326+
):
1327+
del self._shield_store._detector_configs[identifier]
1328+
logger.info(f"Removed detector config for {identifier} from shield store")
1329+
1330+
# Remove detector instance if it's a dynamic shield
1331+
if identifier in self.detectors:
1332+
del self.detectors[identifier]
1333+
logger.info(f"Removed detector instance for {identifier}")
1334+
1335+
logger.info(f"Successfully unregistered shield: {identifier}")
1336+
13021337
def _generate_shield_params(self, detector) -> dict[str, Any]:
13031338
"""Generate shield parameters from detector config"""
13041339
shield_id = detector.config.detector_id

0 commit comments

Comments
 (0)