|
12 | 12 |
|
13 | 13 | logger = logging.getLogger(__name__) |
14 | 14 |
|
15 | | -def custom_func_wrapper(func: Callable, func_name: str, s: str) -> Optional[ContentAnalysisResponse]: |
| 15 | +def custom_func_wrapper(func: Callable, func_name: str, s: str, headers: dict) -> Optional[ContentAnalysisResponse]: |
16 | 16 | """Convert a some f(text)->bool into a Detector response""" |
| 17 | + sig = inspect.signature(func) |
17 | 18 | try: |
18 | | - result = func(s) |
| 19 | + if headers is not None: |
| 20 | + result = func(s, headers) |
| 21 | + else: |
| 22 | + result = func(s) |
| 23 | + |
19 | 24 | except Exception as e: |
20 | 25 | logging.error(f"Error when computing custom detector function {func_name}: {e}") |
21 | 26 | raise e |
@@ -96,17 +101,19 @@ def __init__(self): |
96 | 101 | self.registry = {name: obj for name, obj |
97 | 102 | in inspect.getmembers(custom_detectors, inspect.isfunction) |
98 | 103 | if not name.startswith("_")} |
| 104 | + self.function_needs_headers = {name: "headers" in inspect.signature(obj).parameters for name, obj in self.registry.items() } |
99 | 105 | logger.info(f"Registered the following custom detectors: {self.registry.keys()}") |
100 | 106 |
|
101 | | - def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]: |
| 107 | + def handle_request(self, content: str, detector_params: dict, headers: dict) -> List[ContentAnalysisResponse]: |
102 | 108 | detections = [] |
103 | 109 | if "custom" in detector_params and isinstance(detector_params["custom"], (list, str)): |
104 | 110 | custom_functions = detector_params["custom"] |
105 | 111 | custom_functions = [custom_functions] if isinstance(custom_functions, str) else custom_functions |
106 | 112 | for custom_function in custom_functions: |
107 | 113 | if self.registry.get(custom_function): |
108 | 114 | try: |
109 | | - result = custom_func_wrapper(self.registry[custom_function], custom_function, content) |
| 115 | + func_headers = headers if self.function_needs_headers.get(custom_function) else None |
| 116 | + result = custom_func_wrapper(self.registry[custom_function], custom_function, content, func_headers) |
110 | 117 | if result is not None: |
111 | 118 | detections.append(result) |
112 | 119 | except Exception as e: |
|
0 commit comments