Skip to content

Commit 3be77cd

Browse files
authored
Merge pull request #26 from trustyai-explainability/main
[pull] main from trustyai-explainability:main
2 parents f31e689 + b1d8b77 commit 3be77cd

File tree

10 files changed

+274
-15
lines changed

10 files changed

+274
-15
lines changed

detectors/Dockerfile.builtIn

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ WORKDIR /app
2020
ARG CACHEBUST=1
2121
RUN echo "$CACHEBUST"
2222
COPY ./common /app/detectors/common
23-
COPY ./built_in/* /app
23+
COPY ./built_in/ /app
2424

2525
EXPOSE 8080
2626

detectors/built_in/app.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,51 @@
1-
from fastapi import HTTPException
1+
import logging
2+
3+
from fastapi import HTTPException, Request
24
from contextlib import asynccontextmanager
35
from base_detector_registry import BaseDetectorRegistry
46
from regex_detectors import RegexDetectorRegistry
7+
from custom_detectors_wrapper import CustomDetectorRegistry
58
from file_type_detectors import FileTypeDetectorRegistry
69

710
from prometheus_fastapi_instrumentator import Instrumentator
811
from detectors.common.scheme import ContentAnalysisHttpRequest, ContentsAnalysisResponse
912
from detectors.common.app import DetectorBaseAPI as FastAPI
1013

14+
1115
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Attach the built-in detector registries for the lifetime of the app.

    Before yielding, the regex, file-type and custom registries are set on
    ``app`` under the keys ``"regex"``, ``"file_type"`` and ``"custom"``.
    When the context is left, ``app.cleanup_detector()`` is called.
    """
    app.set_detector(RegexDetectorRegistry(), "regex")
    app.set_detector(FileTypeDetectorRegistry(), "file_type")
    app.set_detector(CustomDetectorRegistry(), "custom")
    try:
        yield
    finally:
        # An exception or cancellation raised while the app is running is
        # re-raised at the ``yield``; ``finally`` keeps the cleanup from
        # being skipped in that case.
        app.cleanup_detector()
1823

1924

2025
app = FastAPI(lifespan=lifespan)
2126
Instrumentator().instrument(app).expose(app)
27+
logging.basicConfig(level=logging.INFO)
28+
logger = logging.getLogger(__name__)
2229

2330

24-
# registry : dict[str, BaseDetectorRegistry] = {
25-
# "regex": RegexDetectorRegistry(),
26-
# "file_type": FileTypeDetectorRegistry(),
27-
# }
28-
2931
@app.post("/api/v1/text/contents", response_model=ContentsAnalysisResponse)
30-
def detect_content(request: ContentAnalysisHttpRequest):
32+
def detect_content(request: ContentAnalysisHttpRequest, raw_request: Request):
33+
logger.info(f"Request for {request.detector_params}")
34+
35+
headers = dict(raw_request.headers)
36+
3137
detections = []
3238
for content in request.contents:
3339
message_detections = []
34-
for detector_kind, detector_registry in app.get_all_detectors().items():
40+
for detector_kind in request.detector_params:
41+
detector_registry = app.get_all_detectors().get(detector_kind)
42+
if detector_registry is None:
43+
raise HTTPException(status_code=400, detail=f"Detector {detector_kind} not found")
3544
if not isinstance(detector_registry, BaseDetectorRegistry):
3645
raise TypeError(f"Detector {detector_kind} is not a valid BaseDetectorRegistry")
37-
if detector_kind in request.detector_params:
46+
else:
3847
try:
39-
message_detections += detector_registry.handle_request(content, request.detector_params)
48+
message_detections += detector_registry.handle_request(content, request.detector_params, headers)
4049
except HTTPException as e:
4150
raise e
4251
except Exception as e:

detectors/built_in/base_detector_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def __init__(self):
88
self.registry = None
99

1010
@abstractmethod
11-
def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
11+
def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]:
1212
pass
1313

1414
def get_registry(self):

detectors/built_in/custom_detectors/__init__.py

Whitespace-only changes.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
def over_100_characters(text: str) -> bool:
    """Return True when ``text`` is strictly longer than 100 characters."""
    limit = 100
    return len(text) > limit
4+
5+
def contains_word(text: str) -> bool:
    """Return True when ``text`` contains the substring "apple", ignoring case."""
    lowered = text.lower()
    return "apple" in lowered
7+
8+
def function_that_needs_headers(text: str, headers: dict) -> bool:
    """Flag the text unless the ``magic-key`` header equals ``"123"``.

    ``text`` is not inspected; the decision depends on ``headers`` only.
    A request without a ``magic-key`` header is flagged as well, rather than
    failing with a ``KeyError``.
    """
    return headers.get('magic-key') != "123"
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import ast
2+
import os
3+
4+
from fastapi import HTTPException
5+
import inspect
6+
import logging
7+
from typing import List, Optional, Callable
8+
9+
10+
from base_detector_registry import BaseDetectorRegistry
11+
from detectors.common.scheme import ContentAnalysisResponse
12+
13+
logger = logging.getLogger(__name__)
14+
15+
def custom_func_wrapper(func: Callable, func_name: str, s: str, headers: dict) -> Optional["ContentAnalysisResponse"]:
    """Run a custom detector function and convert its result into a detection.

    Args:
        func: The custom detector. Called as ``func(s)`` when ``headers`` is
            None, otherwise as ``func(s, headers)``.
        func_name: Name used for the ``detection`` and ``detection_type``
            fields of a boolean match, and in log messages.
        s: The text to analyse.
        headers: Headers to forward to ``func``, or None when it takes none.

    Returns:
        None when ``func`` returns a falsy value. When it returns True, a
        ``ContentAnalysisResponse`` covering all of ``s`` with score 1.0.
        When it returns a non-empty dict, a ``ContentAnalysisResponse`` built
        from that dict's items as keyword arguments.

    Raises:
        TypeError: If ``func`` returns a truthy value that is neither a bool
            nor a dict.
        Exception: Whatever ``func`` raises, or whatever building the
            response from a dict raises, is logged and re-raised unchanged.
    """
    # Log under this module's name instead of through the root logger.
    log = logging.getLogger(__name__)

    try:
        result = func(s) if headers is None else func(s, headers)
    except Exception as e:
        log.error(f"Error when computing custom detector function {func_name}: {e}")
        raise

    if not result:
        return None

    if isinstance(result, bool):
        return ContentAnalysisResponse(
            start=0,
            end=len(s),
            text=s,
            detection_type=func_name,
            detection=func_name,
            score=1.0)

    if isinstance(result, dict):
        try:
            return ContentAnalysisResponse(**result)
        except Exception as e:
            log.error(f"Error when trying to build ContentAnalysisResponse from {func_name} response: {e}")
            raise

    msg = f"Unsupported result type for custom detector function {func_name}, must be bool or dict, got: {type(result)}"
    log.error(msg)
    raise TypeError(msg)
48+
49+
50+
def static_code_analysis(module_path, forbidden_imports=None, forbidden_calls=None):
    """Scan a Python source file for forbidden imports and function calls.

    The file is parsed with ``ast``; it is never imported or executed here.

    Args:
        module_path: Path of the Python file to inspect.
        forbidden_imports: Top-level module names that may not be imported.
            Defaults to ``{"os", "subprocess", "sys", "shutil"}``.
        forbidden_calls: Call names that may not appear. A call is matched
            by its bare name (``eval``) or, for attribute calls, by
            ``<value name>.<attribute>`` (``os.system``). Defaults to
            ``{"eval", "exec", "open", "compile", "input"}``.

    Returns:
        A list of issue strings, empty when nothing was found. A file that
        cannot be parsed yields a single "Failed to parse" entry.
    """
    # NOTE(review): this is a best-effort screen, not a sandbox. With the
    # default sets, e.g. ``__import__("os")`` or ``import importlib`` pass
    # unnoticed.
    if forbidden_imports is None:
        forbidden_imports = {"os", "subprocess", "sys", "shutil"}
    if forbidden_calls is None:
        forbidden_calls = {"eval", "exec", "open", "compile", "input"}

    issues = []
    # Python source defaults to UTF-8; do not depend on the process locale.
    with open(module_path, "r", encoding="utf-8") as f:
        source = f.read()
    try:
        tree = ast.parse(source, filename=module_path)
    except Exception as e:
        issues.append(f"Failed to parse {module_path}: {e}")
        return issues

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # "import a.b" is judged by its top-level package "a".
            for alias in node.names:
                if alias.name.split(".")[0] in forbidden_imports:
                    issues.append(f"- Forbidden import: {alias.name} (line {node.lineno})")
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for "from . import x".
            if node.module and node.module.split(".")[0] in forbidden_imports:
                issues.append(f"- Forbidden import: {node.module} (line {node.lineno})")
        elif isinstance(node, ast.Call):
            func_name = ""
            if isinstance(node.func, ast.Name):
                func_name = node.func.id
            elif isinstance(node.func, ast.Attribute):
                # The value part is empty when the receiver is not a plain name.
                func_name = f"{getattr(node.func.value, 'id', '')}.{node.func.attr}"
            if func_name in forbidden_calls:
                issues.append(f"- Forbidden function call: {func_name} (line {node.lineno})")
    return issues
88+
89+
90+
class CustomDetectorRegistry(BaseDetectorRegistry):
    """Registry exposing the functions of ``custom_detectors.custom_detectors`` as detectors."""

    def __init__(self):
        """Vet and import the custom detectors module, then index its functions.

        After construction ``self.registry`` maps each function name that
        does not start with an underscore to the function, and
        ``self.function_needs_headers`` maps each of those names to whether
        the function has a parameter called ``headers``.

        Raises:
            ImportError: If the static analysis reports any issue.
        """
        super().__init__()

        module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py")
        issues = static_code_analysis(module_path=module_path)
        if issues:
            logger.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}")
            raise ImportError("Unsafe code detected in custom_detectors:\n" + "\n".join(issues))

        # Imported here, not at module top, so the module's code only runs
        # after the static check above has passed.
        import custom_detectors.custom_detectors as custom_detectors

        self.registry = {name: obj for name, obj
                         in inspect.getmembers(custom_detectors, inspect.isfunction)
                         if not name.startswith("_")}
        self.function_needs_headers = {name: "headers" in inspect.signature(obj).parameters
                                       for name, obj in self.registry.items()}
        logger.info(f"Registered the following custom detectors: {self.registry.keys()}")

    def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]:
        """Run the custom functions named in ``detector_params["custom"]`` on ``content``.

        Args:
            content: The text to analyse.
            detector_params: Request parameters. Only the ``"custom"`` entry
                is read; it may be one function name or a list of names. Any
                other type, or a missing entry, produces no detections.
            headers: Request headers, forwarded only to functions that have
                a ``headers`` parameter.

        Returns:
            The detections produced by the requested functions, in request
            order. Functions that do not match contribute nothing.

        Raises:
            HTTPException: 400 when a requested name is not a registered
                function, or when a function fails while running.
        """
        requested = detector_params.get("custom")
        if not isinstance(requested, (list, str)):
            return []

        names = [requested] if isinstance(requested, str) else requested
        detections = []
        for name in names:
            # Non-string entries (possibly unhashable, e.g. a dict) can never
            # name a function; reject them the same way as unknown names.
            func = self.registry.get(name) if isinstance(name, str) else None
            if func is None:
                raise HTTPException(status_code=400, detail=f"Unrecognized custom function: {name}")

            func_headers = headers if self.function_needs_headers.get(name) else None
            try:
                result = custom_func_wrapper(func, name, content, func_headers)
            except Exception as e:
                # Details stay in the log; the client gets a generic message.
                logger.error(e)
                raise HTTPException(status_code=400, detail="Detection error, check detector logs") from e
            if result is not None:
                detections.append(result)
        return detections
125+

detectors/built_in/file_type_detectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def __init__(self):
176176
"yaml-with-schema:$SCHEMA": is_valid_yaml_schema,
177177
}
178178

179-
def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
179+
def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]:
180180
detections = []
181181
if "file_type" in detector_params and isinstance(detector_params["file_type"], (list, str)):
182182
file_types = detector_params["file_type"]

detectors/built_in/regex_detectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __init__(self):
131131
"$CUSTOM_REGEX": custom_regex_documenter,
132132
}
133133

134-
def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
134+
def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]:
135135
detections = []
136136
if "regex" in detector_params and isinstance(detector_params["regex"], (list, str)):
137137
regexes = detector_params["regex"]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
markdown==3.8.2
22
jsonschema==4.24.0
3-
xmlschema==4.1.0
3+
xmlschema==4.1.0
4+
requests==2.32.5
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import importlib
2+
import sys
3+
from http.client import HTTPException
4+
5+
import pytest
6+
import os
7+
from fastapi.testclient import TestClient
8+
9+
10+
CUSTOM_DETECTORS_PATH = os.path.join(
    os.path.dirname(__file__),
    "../../../detectors/built_in/custom_detectors/custom_detectors.py"
)

# Snapshot of the detectors file taken at import time, so tests that
# overwrite the file can put the original back afterwards. The file is read
# and written as UTF-8 explicitly so the round trip does not depend on the
# platform's default encoding.
with open(CUSTOM_DETECTORS_PATH, encoding="utf-8") as f:
    SAFE_CODE = f.read()

# Detector source that imports ``os``; used to exercise the rejection path.
UNSAFE_CODE = '''
import os
def evil(text: str) -> bool:
    os.system("echo haha gottem")
    return True
'''


def write_code_to_custom_detectors(code: str):
    """Overwrite the custom detectors file with ``code``."""
    with open(CUSTOM_DETECTORS_PATH, "w", encoding="utf-8") as f:
        f.write(code)


def restore_safe_code():
    """Write the import-time snapshot back to the custom detectors file."""
    write_code_to_custom_detectors(SAFE_CODE)
32+
33+
34+
class TestCustomDetectors:
    """Tests for the ``custom`` detector kind served by the built-in app."""

    ENDPOINT = "/api/v1/text/contents"

    @staticmethod
    def _payload(text, detector_params):
        """Build a request body carrying a single content string."""
        return {"contents": [text], "detector_params": detector_params}

    @staticmethod
    def _detected_texts(response):
        """Collect the ``text`` field of every detection for the first content."""
        first_content = response.json()[0]
        return [detection["text"] for detection in first_content]

    @pytest.fixture
    def client(self):
        """Test client for the app with a fresh custom registry attached."""
        from detectors.built_in.app import app
        from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry
        app.set_detector(CustomDetectorRegistry(), "custom")
        return TestClient(app)

    @pytest.fixture(autouse=True)
    def cleanup_custom_detectors(self):
        """Put the original detectors file back after every test."""
        yield
        restore_safe_code()

    def test_missing_detector_type(self, client):
        """An unknown detector kind is rejected with a 400."""
        body = self._payload("What is an apple?", {"custom1": ["contains_word"]})
        resp = client.post(self.ENDPOINT, json=body)
        assert resp.status_code == 400
        assert "Detector custom1 not found" in resp.text

    def test_custom_detectors(self, client):
        """A matching custom function reports the whole text."""
        body = self._payload("What is an apple?", {"custom": ["contains_word"]})
        resp = client.post(self.ENDPOINT, json=body)
        assert resp.status_code == 200
        assert "What is an apple?" in self._detected_texts(resp)

    def test_custom_detectors_not_match(self, client):
        """A non-matching custom function reports nothing for the text."""
        msg = "What is an banana?"
        body = self._payload(msg, {"custom": ["contains_word"]})
        resp = client.post(self.ENDPOINT, json=body)
        assert resp.status_code == 200
        assert msg not in self._detected_texts(resp)

    def test_custom_detectors_need_header(self, client):
        """Request headers reach a function that declares ``headers``."""
        msg = "What is an banana?"
        body = self._payload(msg, {"custom": ["function_that_needs_headers"]})

        # shouldn't flag
        resp = client.post(self.ENDPOINT, json=body, headers={"magic-key": "123"})
        assert resp.status_code == 200
        assert msg not in self._detected_texts(resp)

        # should flag
        resp = client.post(self.ENDPOINT, json=body, headers={"magic-key": "wrong"})
        assert resp.status_code == 200
        assert msg in self._detected_texts(resp)

    def test_unsafe_code(self, client):
        """A detectors file importing ``os`` makes the registry refuse to load."""
        write_code_to_custom_detectors(UNSAFE_CODE)
        from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry
        with pytest.raises(ImportError) as excinfo:
            CustomDetectorRegistry()
        message = str(excinfo.value)
        assert "Unsafe code detected" in message
        assert "Forbidden import: os" in message or "os.system" in message

    def test_custom_detectors_func_doesnt_exist(self, client):
        """An unknown custom function name is rejected with a 400."""
        body = self._payload("What is an apple?", {"custom": ["abc"]})
        resp = client.post(self.ENDPOINT, json=body)
        assert resp.status_code == 400
        assert "Unrecognized custom function: abc" in resp.text
115+

0 commit comments

Comments
 (0)