Skip to content

Commit 2a803d7

Browse files
authored
Merge branch 'trustyai-explainability:main' into incubation
2 parents 76ff899 + cbe5b63 commit 2a803d7

21 files changed

+528
-202
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ At the moment, the following detectors are supported:
1313
- `builtIn` -- Small, lightweight detection functions that are deployed out-of-the-box alongside the [Guardrails Orchestrator](https://github.com/foundation-model-stack/fms-guardrails-orchestrator). The built-in detectors provide a number of heuristic or algorithmic detection functions, such as:
1414
- Regex-based detections, with pre-written regexes for flagging various Personally Identifiable Information items like emails or phone numbers, as well as the ability to provide custom regexes
1515
- File-type validations, for verifying if model input/output is valid JSON, XML, or YAML
16-
16+
- "Custom" detectors which can be defined with raw Python code, which allows for easy declaration of custom detection functions that can perform arbitrarily complex logic.
1717

1818
## Building
1919

detectors/Dockerfile.builtIn

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ ARG CACHEBUST=1
2121
RUN echo "$CACHEBUST"
2222
COPY ./common /app/detectors/common
2323
COPY ./built_in/ /app
24+
ENV PROMETHEUS_MULTIPROC_DIR="/tmp/prometheus_multiproc_dir"
25+
RUN mkdir -p $PROMETHEUS_MULTIPROC_DIR && chmod 777 $PROMETHEUS_MULTIPROC_DIR
2426

2527
EXPOSE 8080
2628

2729
# for backwards compatibility with existing k8s deployment configs
2830
RUN mkdir /app/bin &&\
29-
echo '#!/bin/bash' > /app/bin/regex-detector &&\
31+
echo '#!/bin/bash' > /app/bin/regex-detector &&\
3032
echo "uvicorn app:app --workers 4 --host 0.0.0.0 --port 8080 --log-config /app/detectors/common/log_conf.yaml" >> /app/bin/regex-detector &&\
3133
chmod +x /app/bin/regex-detector
3234
CMD ["/app/bin/regex-detector"]

detectors/Dockerfile.hf

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,12 @@ FROM builder
9696
WORKDIR /app
9797
ARG CACHEBUST=1
9898
RUN echo "$CACHEBUST"
99-
COPY ./common /common
99+
COPY ./common /app/detectors/common
100+
COPY ./huggingface/detector.py /app/detectors/huggingface/
101+
RUN mkdir /common; cp /app/detectors/common/log_conf.yaml /common/
100102
COPY ./huggingface/app.py /app
101-
COPY ./huggingface/detector.py /app
102103

103104
EXPOSE 8000
104-
CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
105+
CMD ["uvicorn", "app:app", "--workers", "1", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
105106

106107
# gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000

detectors/built_in/app.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from custom_detectors_wrapper import CustomDetectorRegistry
88
from file_type_detectors import FileTypeDetectorRegistry
99

10-
from prometheus_fastapi_instrumentator import Instrumentator
11-
from prometheus_client import Gauge
10+
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST, CollectorRegistry, multiprocess
11+
from starlette.responses import Response
1212
from detectors.common.scheme import ContentAnalysisHttpRequest, ContentsAnalysisResponse
1313
from detectors.common.app import DetectorBaseAPI as FastAPI
1414

@@ -21,17 +21,23 @@ async def lifespan(app: FastAPI):
2121
CustomDetectorRegistry()
2222
]:
2323
app.set_detector(detector_registry, detector_registry.registry_name)
24-
detector_registry.add_instruments(app.state.instruments)
24+
detector_registry.set_instruments(app.state.instruments)
2525
yield
2626
app.cleanup_detector()
2727

2828

2929
app = FastAPI(lifespan=lifespan)
30-
Instrumentator().instrument(app).expose(app)
3130
logging.basicConfig(level=logging.INFO)
3231
logger = logging.getLogger(__name__)
3332

3433

34+
@app.get("/metrics")
35+
def metrics():
36+
registry = CollectorRegistry()
37+
multiprocess.MultiProcessCollector(registry)
38+
data = generate_latest(registry)
39+
return Response(data, media_type=CONTENT_TYPE_LATEST)
40+
3541
@app.post("/api/v1/text/contents", response_model=ContentsAnalysisResponse)
3642
def detect_content(request: ContentAnalysisHttpRequest, raw_request: Request):
3743
logger.info(f"Request for {request.detector_params}")

detectors/built_in/base_detector_registry.py

Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
import contextlib
21
import logging
32
from abc import ABC, abstractmethod
4-
import time
5-
from http.client import HTTPException
3+
from fastapi import HTTPException
64
from typing import List
75

6+
from detectors.common.instrumented_detector import InstrumentedDetector
87
from detectors.common.scheme import ContentAnalysisResponse
98

10-
class BaseDetectorRegistry(ABC):
9+
class BaseDetectorRegistry(InstrumentedDetector, ABC):
1110
def __init__(self, registry_name):
11+
super().__init__(registry_name)
1212
self.registry = None
13-
self.registry_name = registry_name
1413

15-
# prometheus
16-
self.instruments = {}
1714

1815
@abstractmethod
1916
def handle_request(self, content: str, detector_params: dict, headers: dict, **kwargs) -> List[ContentAnalysisResponse]:
@@ -22,35 +19,6 @@ def handle_request(self, content: str, detector_params: dict, headers: dict, **k
2219
def get_registry(self):
2320
return self.registry
2421

25-
def add_instruments(self, gauges):
26-
self.instruments = gauges
27-
28-
def increment_detector_instruments(self, function_name: str, is_detection: bool):
29-
"""Increment the detection and request counters, automatically update rates"""
30-
if self.instruments.get("requests"):
31-
self.instruments["requests"].labels(self.registry_name, function_name).inc()
32-
33-
# The labels() function will initialize the counters if not already created.
34-
# This prevents the counters not existing until they are first incremented
35-
# If the counters have already been created, this is just a cheap dict.get() call
36-
if self.instruments.get("errors"):
37-
_ = self.instruments["errors"].labels(self.registry_name, function_name)
38-
if self.instruments.get("runtime"):
39-
_ = self.instruments["runtime"].labels(self.registry_name, function_name)
40-
41-
# create and/or increment the detection counter
42-
if self.instruments.get("detections"):
43-
detection_counter = self.instruments["detections"].labels(self.registry_name, function_name)
44-
if is_detection:
45-
detection_counter.inc()
46-
47-
48-
def increment_error_instruments(self, function_name: str):
49-
"""Increment the error counter, update the rate gauges"""
50-
if self.instruments.get("errors"):
51-
self.instruments["errors"].labels(self.registry_name, function_name).inc()
52-
53-
5422
def throw_internal_detector_error(self, function_name: str, logger: logging.Logger, exception: Exception, increment_requests: bool):
5523
"""consistent handling of internal errors within a detection function"""
5624
if increment_requests and self.instruments.get("requests"):
@@ -60,16 +28,6 @@ def throw_internal_detector_error(self, function_name: str, logger: logging.Logg
6028
raise HTTPException(status_code=500, detail="Detection error, check detector logs")
6129

6230

63-
@contextlib.contextmanager
64-
def instrument_runtime(self, function_name: str):
65-
try:
66-
start_time = time.time()
67-
yield
68-
if self.instruments.get("runtime"):
69-
self.instruments["runtime"].labels(self.registry_name, function_name).inc(time.time() - start_time)
70-
finally:
71-
pass
72-
7331
def get_detection_functions_from_params(self, params: dict):
7432
"""Parse the request parameters to extract and normalize detection functions as iterable list"""
7533
if self.registry_name in params and isinstance(params[self.registry_name], (list, str)):
Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,36 @@
11
"""
2-
This is an example custom_detectors.py file. Here, you can define any arbitrary Python code as a
3-
Guardrail detector.
2+
This is an example custom_detectors.py file. Overwrite this file to define custom guardrailing
3+
logic!
44
5-
The following rules apply:
6-
1) Each function defined in this file (except for those starting with "_") will be registered as a detector
7-
2) Functions that accept a parameter "headers" will receive the inbound request headers as a parameter
8-
3) Functions may either return a boolean or a dict:
9-
3a) Return values that evaluate to false (e.g., {}, "", None, etc) are treated as non-detections
10-
3b) Boolean responses of "true" are considered a detection
11-
3c) Dict response must be parseable as a ContentAnalysisResponse object (see example below)
12-
4) This code may not import "os", "subprocess", "sys", or "shutil" for security reasons
13-
5) This code may not call "eval", "exec", "open", "compile", or "input" for security reasons
5+
See [docs/custom_detectors.md](../../docs/custom_detectors.md) for more details.
146
"""
157

16-
# example boolean-returning function
17-
def over_100_characters(text: str) -> bool:
18-
return len(text)>100
19-
20-
# example dict-returning function
21-
def contains_word(text: str) -> dict:
22-
detection = "apple" in text.lower()
23-
if detection:
24-
detection_position = text.find("apple")
25-
return {
26-
"start":detection_position, # start position of detection in text
27-
"end": detection_position+5, # end position of detection in text
28-
"text": text, # "the flagged text, or some arbitrary message to return to the user"
29-
"detection_type": "content_check", #detection_type -> use these fields to define your detector taxonomy as you see fit
30-
"detection": "forbidden_word: apple", ##detection -> use these fields to define your detector taxonomy as you see fit
31-
"score": 1.0 # the score/severity/probability of the detection
32-
}
33-
else:
34-
return {}
35-
36-
def _this_function_will_not_be_exposed():
37-
pass
38-
39-
def function_that_needs_headers(text: str, headers: dict) -> bool:
40-
return headers['magic-key'] != "123"
8+
import time
9+
def slow_func(text: str) -> bool:
10+
time.sleep(.25)
11+
return False
12+
13+
from prometheus_client import Counter
4114

15+
prompt_rejection_counter = Counter(
16+
"trustyai_guardrails_system_prompt_rejections",
17+
"Number of rejections by the system prompt",
18+
)
19+
20+
@use_instruments(instruments=[prompt_rejection_counter])
21+
def has_metrics(text: str) -> bool:
22+
if "sorry" in text:
23+
prompt_rejection_counter.inc()
24+
return False
25+
26+
background_metric = Counter(
27+
"trustyai_guardrails_background_metric",
28+
"Runs some logic in the background without blocking the /detections call"
29+
)
30+
@use_instruments(instruments=[background_metric])
31+
@non_blocking(return_value=False)
32+
def background_function(text: str) -> bool:
33+
time.sleep(.25)
34+
if "sorry" in text:
35+
background_metric.inc()
36+
return False

detectors/built_in/custom_detectors_wrapper.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,71 @@
11
import ast
2+
import logging
3+
import importlib.util
4+
import inspect
5+
import functools
26
import os
3-
import traceback
7+
import sys
48

9+
from concurrent.futures import ThreadPoolExecutor
510
from fastapi import HTTPException
6-
import inspect
7-
import logging
811
from typing import List, Optional, Callable
912

10-
1113
from base_detector_registry import BaseDetectorRegistry
14+
from detectors.common.app import METRIC_PREFIX
1215
from detectors.common.scheme import ContentAnalysisResponse
1316

1417
logger = logging.getLogger(__name__)
1518

19+
def use_instruments(instruments: List):
20+
"""Use this decorator to register the provided Prometheus instruments with the main /metrics endpoint"""
21+
def inner_layer_1(func):
22+
@functools.wraps(func)
23+
def inner_layer_2(*args, **kwargs):
24+
return func(*args, **kwargs)
25+
26+
# check to see if "func" is already decorated, and only add the prometheus instruments field into the original function
27+
target = get_underlying_function(func)
28+
setattr(target, "prometheus_instruments", instruments)
29+
return inner_layer_2
30+
return inner_layer_1
31+
32+
def non_blocking(return_value):
33+
"""
34+
Use this decorator to run the guardrail as a non-blocking background thread.
35+
36+
The `return_value` is returned instantly to the caller of the /api/v1/text/contents, while
37+
the logic inside the function will run asynchronously in the background.
38+
"""
39+
def inner_layer_1(func):
40+
@functools.wraps(func)
41+
def inner_layer_2(*args, **kwargs):
42+
executor = getattr(non_blocking, "_executor", None)
43+
if executor is None:
44+
executor = ThreadPoolExecutor()
45+
non_blocking._executor = executor
46+
def runner():
47+
try:
48+
func(*args, **kwargs)
49+
except Exception as e:
50+
logging.error(f"Exception in non-blocking guardrail {func.__name__}: {e}")
51+
executor.submit(runner)
52+
53+
# check to see if "func" is already decorated by `use_instruments`, and grab the instruments if so
54+
target = get_underlying_function(func)
55+
if hasattr(target, "prometheus_instruments"):
56+
setattr(target, "prometheus_instruments", target.prometheus_instruments)
57+
return return_value
58+
return inner_layer_2
59+
return inner_layer_1
60+
61+
forbidden_names = [use_instruments.__name__, non_blocking.__name__]
62+
63+
def get_underlying_function(func):
64+
if hasattr(func, "__wrapped__"):
65+
return get_underlying_function(func.__wrapped__)
66+
return func
67+
68+
1669
def custom_func_wrapper(func: Callable, func_name: str, s: str, headers: dict) -> Optional[ContentAnalysisResponse]:
1770
"""Convert a some f(text)->bool into a Detector response"""
1871
sig = inspect.signature(func)
@@ -92,17 +145,42 @@ class CustomDetectorRegistry(BaseDetectorRegistry):
92145
def __init__(self):
93146
super().__init__("custom")
94147

148+
# check the imported code for potential security issues
95149
issues = static_code_analysis(module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py"))
96150
if issues:
97151
logging.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}")
98152
raise ImportError(f"Unsafe code detected in custom_detectors:\n" + "\n".join(issues))
99153

100-
import custom_detectors.custom_detectors as custom_detectors
154+
# grab custom detectors module
155+
module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py")
156+
spec = importlib.util.spec_from_file_location("custom_detectors.custom_detectors", module_path)
157+
custom_detectors = importlib.util.module_from_spec(spec)
158+
159+
# inject any user utility functions into the code automatically
160+
inject_imports = {
161+
"use_instruments": use_instruments,
162+
"non_blocking": non_blocking,
163+
}
164+
for name, mod in inject_imports.items():
165+
setattr(custom_detectors, name, mod)
166+
167+
# load the module
168+
sys.modules["custom_detectors.custom_detectors"] = custom_detectors
169+
spec.loader.exec_module(custom_detectors)
101170

102171
self.registry = {name: obj for name, obj
103172
in inspect.getmembers(custom_detectors, inspect.isfunction)
104-
if not name.startswith("_")}
173+
if not name.startswith("_") and name not in forbidden_names}
105174
self.function_needs_headers = {name: "headers" in inspect.signature(obj).parameters for name, obj in self.registry.items() }
175+
176+
# check if functions have requested user prometheus metrics
177+
for name, func in self.registry.items():
178+
target = get_underlying_function(func)
179+
if getattr(target, "prometheus_instruments", False):
180+
instruments = target.prometheus_instruments
181+
for instrument in instruments:
182+
super().add_instrument(instrument)
183+
106184
logger.info(f"Registered the following custom detectors: {self.registry.keys()}")
107185

108186

0 commit comments

Comments
 (0)