Skip to content

Commit b1d8b77

Browse files
authored
Merge pull request #52 from RobGeada/PassthroughHeaders
Feat: Add optional passthrough headers to custom functions
2 parents 0b4d33a + 8e6647c commit b1d8b77

File tree

7 files changed

+45
-17
lines changed

7 files changed

+45
-17
lines changed

detectors/built_in/app.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from fastapi import HTTPException
3+
from fastapi import HTTPException, Request
44
from contextlib import asynccontextmanager
55
from base_detector_registry import BaseDetectorRegistry
66
from regex_detectors import RegexDetectorRegistry
@@ -29,9 +29,11 @@ async def lifespan(app: FastAPI):
2929

3030

3131
@app.post("/api/v1/text/contents", response_model=ContentsAnalysisResponse)
32-
def detect_content(request: ContentAnalysisHttpRequest):
32+
def detect_content(request: ContentAnalysisHttpRequest, raw_request: Request):
3333
logger.info(f"Request for {request.detector_params}")
3434

35+
headers = dict(raw_request.headers)
36+
3537
detections = []
3638
for content in request.contents:
3739
message_detections = []
@@ -43,7 +45,7 @@ def detect_content(request: ContentAnalysisHttpRequest):
4345
raise TypeError(f"Detector {detector_kind} is not a valid BaseDetectorRegistry")
4446
else:
4547
try:
46-
message_detections += detector_registry.handle_request(content, request.detector_params)
48+
message_detections += detector_registry.handle_request(content, request.detector_params, headers)
4749
except HTTPException as e:
4850
raise e
4951
except Exception as e:

detectors/built_in/base_detector_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def __init__(self):
88
self.registry = None
99

1010
@abstractmethod
11-
def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
11+
def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]:
1212
pass
1313

1414
def get_registry(self):

detectors/built_in/custom_detectors/custom_detectors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ def over_100_characters(text: str) -> bool:
44

55
def contains_word(text: str) -> bool:
66
return "apple" in text.lower()
7+
8+
def function_that_needs_headers(text: str, headers: dict) -> bool:
9+
return headers['magic-key'] != "123"

detectors/built_in/custom_detectors_wrapper.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,15 @@
1212

1313
logger = logging.getLogger(__name__)
1414

15-
def custom_func_wrapper(func: Callable, func_name: str, s: str) -> Optional[ContentAnalysisResponse]:
15+
def custom_func_wrapper(func: Callable, func_name: str, s: str, headers: dict) -> Optional[ContentAnalysisResponse]:
1616
"""Convert some f(text)->bool into a Detector response"""
17+
sig = inspect.signature(func)
1718
try:
18-
result = func(s)
19+
if headers is not None:
20+
result = func(s, headers)
21+
else:
22+
result = func(s)
23+
1924
except Exception as e:
2025
logging.error(f"Error when computing custom detector function {func_name}: {e}")
2126
raise e
@@ -96,17 +101,19 @@ def __init__(self):
96101
self.registry = {name: obj for name, obj
97102
in inspect.getmembers(custom_detectors, inspect.isfunction)
98103
if not name.startswith("_")}
104+
self.function_needs_headers = {name: "headers" in inspect.signature(obj).parameters for name, obj in self.registry.items() }
99105
logger.info(f"Registered the following custom detectors: {self.registry.keys()}")
100106

101-
def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
107+
def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]:
102108
detections = []
103109
if "custom" in detector_params and isinstance(detector_params["custom"], (list, str)):
104110
custom_functions = detector_params["custom"]
105111
custom_functions = [custom_functions] if isinstance(custom_functions, str) else custom_functions
106112
for custom_function in custom_functions:
107113
if self.registry.get(custom_function):
108114
try:
109-
result = custom_func_wrapper(self.registry[custom_function], custom_function, content)
115+
func_headers = headers if self.function_needs_headers.get(custom_function) else None
116+
result = custom_func_wrapper(self.registry[custom_function], custom_function, content, func_headers)
110117
if result is not None:
111118
detections.append(result)
112119
except Exception as e:

detectors/built_in/file_type_detectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def __init__(self):
176176
"yaml-with-schema:$SCHEMA": is_valid_yaml_schema,
177177
}
178178

179-
def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
179+
def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]:
180180
detections = []
181181
if "file_type" in detector_params and isinstance(detector_params["file_type"], (list, str)):
182182
file_types = detector_params["file_type"]

detectors/built_in/regex_detectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __init__(self):
131131
"$CUSTOM_REGEX": custom_regex_documenter,
132132
}
133133

134-
def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
134+
def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]:
135135
detections = []
136136
if "regex" in detector_params and isinstance(detector_params["regex"], (list, str)):
137137
regexes = detector_params["regex"]

tests/detectors/builtIn/test_custom.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,8 @@
1212
"../../../detectors/built_in/custom_detectors/custom_detectors.py"
1313
)
1414

15-
SAFE_CODE = """
16-
def over_100_characters(text: str) -> bool:
17-
return len(text)>100
18-
19-
def contains_word(text: str) -> bool:
20-
return "apple" in text.lower()
21-
"""
15+
with open(CUSTOM_DETECTORS_PATH) as f:
16+
SAFE_CODE = f.read()
2217

2318
UNSAFE_CODE = '''
2419
import os
@@ -80,6 +75,27 @@ def test_custom_detectors_not_match(self, client):
8075
texts = [d["text"] for d in resp.json()[0]]
8176
assert msg not in texts
8277

78+
def test_custom_detectors_need_header(self, client):
79+
msg = "What is an banana?"
80+
payload = {
81+
"contents": [msg],
82+
"detector_params": {"custom": ["function_that_needs_headers"]}
83+
}
84+
85+
# shouldn't flag: magic-key equals the expected "123", so the function returns False
86+
headers = {"magic-key": "123"}
87+
resp = client.post("/api/v1/text/contents", json=payload, headers=headers)
88+
assert resp.status_code == 200
89+
texts = [d["text"] for d in resp.json()[0]]
90+
assert msg not in texts
91+
92+
# should flag: magic-key differs from the expected "123", so the function returns True
93+
headers = {"magic-key": "wrong"}
94+
resp = client.post("/api/v1/text/contents", json=payload, headers=headers)
95+
assert resp.status_code == 200
96+
texts = [d["text"] for d in resp.json()[0]]
97+
assert msg in texts
98+
8399
def test_unsafe_code(self, client):
84100
write_code_to_custom_detectors(UNSAFE_CODE)
85101
from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry

0 commit comments

Comments
 (0)