|
2 | 2 | import operator |
3 | 3 | from typing import Any |
4 | 4 |
|
| 5 | +# Limits to prevent resource exhaustion via crafted expressions |
| 6 | +MAX_EXPONENT = 100 # cap for ** operator |
| 7 | +MAX_REPEAT = 10_000 # cap for string/list repetition with * |
| 8 | + |
| 9 | + |
| 10 | +def _safe_pow(base: Any, exp: Any) -> Any: |
| 11 | + """Bounded power operator to prevent DoS via huge exponents.""" |
| 12 | + if isinstance(exp, (int, float)) and abs(exp) > MAX_EXPONENT: |
| 13 | + raise ValueError( |
| 14 | + f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})" |
| 15 | + ) |
| 16 | + return operator.pow(base, exp) |
| 17 | + |
| 18 | + |
| 19 | +def _safe_mult(a: Any, b: Any) -> Any: |
| 20 | + """Bounded multiplication that prevents huge string/list repetitions.""" |
| 21 | + if isinstance(a, (str, list, tuple, bytes)) and isinstance(b, int): |
| 22 | + if b > MAX_REPEAT: |
| 23 | + raise ValueError( |
| 24 | + f"Repeat count {b} exceeds maximum allowed ({MAX_REPEAT})" |
| 25 | + ) |
| 26 | + elif isinstance(b, (str, list, tuple, bytes)) and isinstance(a, int): |
| 27 | + if a > MAX_REPEAT: |
| 28 | + raise ValueError( |
| 29 | + f"Repeat count {a} exceeds maximum allowed ({MAX_REPEAT})" |
| 30 | + ) |
| 31 | + return operator.mul(a, b) |
| 32 | + |
| 33 | + |
5 | 34 | # Safe operators whitelist |
6 | 35 | SAFE_OPERATORS = { |
7 | 36 | ast.Add: operator.add, |
8 | 37 | ast.Sub: operator.sub, |
9 | | - ast.Mult: operator.mul, |
| 38 | + ast.Mult: _safe_mult, |
10 | 39 | ast.Div: operator.truediv, |
11 | 40 | ast.FloorDiv: operator.floordiv, |
12 | 41 | ast.Mod: operator.mod, |
13 | | - ast.Pow: operator.pow, |
| 42 | + ast.Pow: _safe_pow, |
14 | 43 | ast.LShift: operator.lshift, |
15 | 44 | ast.RShift: operator.rshift, |
16 | 45 | ast.BitOr: operator.or_, |
|
52 | 81 | "any": any, |
53 | 82 | } |
54 | 83 |
|
| 84 | +# Method whitelist with allowed receiver types. |
| 85 | +# Only these (method, type) combinations are auto-approved. |
| 86 | +SAFE_METHODS: dict[str, tuple[type, ...]] = { |
| 87 | + "get": (dict,), |
| 88 | + "keys": (dict,), |
| 89 | + "values": (dict,), |
| 90 | + "items": (dict,), |
| 91 | + "lower": (str,), |
| 92 | + "upper": (str,), |
| 93 | + "strip": (str,), |
| 94 | + "split": (str,), |
| 95 | +} |
| 96 | + |
55 | 97 |
|
56 | 98 | class SafeEvalVisitor(ast.NodeVisitor): |
| 99 | + MAX_DEPTH = 50 # prevent stack overflow from deeply nested expressions |
| 100 | + |
57 | 101 | def __init__(self, context: dict[str, Any]): |
58 | 102 | self.context = context |
| 103 | + self._depth = 0 |
59 | 104 |
|
60 | 105 | def visit(self, node: ast.AST) -> Any: |
61 | | - # Override visit to prevent default behavior and ensure only explicitly allowed nodes work |
62 | | - method = "visit_" + node.__class__.__name__ |
63 | | - visitor = getattr(self, method, self.generic_visit) |
64 | | - return visitor(node) |
| 106 | + self._depth += 1 |
| 107 | + if self._depth > self.MAX_DEPTH: |
| 108 | + raise ValueError( |
| 109 | + f"Expression nesting depth exceeds limit ({self.MAX_DEPTH})" |
| 110 | + ) |
| 111 | + try: |
| 112 | + method = "visit_" + node.__class__.__name__ |
| 113 | + visitor = getattr(self, method, self.generic_visit) |
| 114 | + return visitor(node) |
| 115 | + finally: |
| 116 | + self._depth -= 1 |
65 | 117 |
|
66 | 118 | def generic_visit(self, node: ast.AST): |
67 | 119 | raise ValueError(f"Use of {node.__class__.__name__} is not allowed") |
@@ -115,11 +167,23 @@ def visit_Compare(self, node: ast.Compare) -> Any: |
115 | 167 | return True |
116 | 168 |
|
117 | 169 | def visit_BoolOp(self, node: ast.BoolOp) -> Any: |
118 | | - values = [self.visit(v) for v in node.values] |
| 170 | + # Lazy evaluation matching Python short-circuit semantics. |
| 171 | + # `x and y` returns x if x is falsy, otherwise y. |
| 172 | + # `x or y` returns x if x is truthy, otherwise y. |
119 | 173 | if isinstance(node.op, ast.And): |
120 | | - return all(values) |
| 174 | + result: Any = True |
| 175 | + for v in node.values: |
| 176 | + result = self.visit(v) |
| 177 | + if not result: |
| 178 | + return result |
| 179 | + return result |
121 | 180 | elif isinstance(node.op, ast.Or): |
122 | | - return any(values) |
| 181 | + result = False |
| 182 | + for v in node.values: |
| 183 | + result = self.visit(v) |
| 184 | + if result: |
| 185 | + return result |
| 186 | + return result |
123 | 187 | raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed") |
124 | 188 |
|
125 | 189 | def visit_IfExp(self, node: ast.IfExp) -> Any: |
@@ -171,42 +235,29 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: |
171 | 235 | raise AttributeError(f"Object has no attribute '{node.attr}'") |
172 | 236 |
|
173 | 237 | def visit_Call(self, node: ast.Call) -> Any: |
174 | | - # Only allow calling whitelisted functions |
175 | | - func = self.visit(node.func) |
176 | | - |
177 | | - # Check if the function object itself is in our whitelist values |
178 | | - # This is tricky because `func` is the actual function object, |
179 | | - # but we also want to verify it came from a safe place. |
180 | | - # Easier: Check if node.func is a Name and that name is in SAFE_FUNCTIONS. |
| 238 | + # Only allow calling whitelisted functions or type-checked methods. |
181 | 239 |
|
182 | 240 | is_safe = False |
| 241 | + |
183 | 242 | if isinstance(node.func, ast.Name): |
184 | 243 | if node.func.id in SAFE_FUNCTIONS: |
185 | 244 | is_safe = True |
186 | 245 |
|
187 | | - # Also allow methods on objects if they are safe? |
188 | | - # E.g. "somestring".lower() or list.append() (if we allowed mutation, but we don't for now) |
189 | | - # For now, restrict to SAFE_FUNCTIONS whitelist for global calls and deny method calls |
190 | | - # unless we explicitly add safe methods. |
191 | | - # Allowing method calls on strings/lists (split, join, get) is commonly needed. |
192 | | - |
193 | 246 | if isinstance(node.func, ast.Attribute): |
194 | | - # Method call. |
195 | | - # Allow basic safe methods? |
196 | | - # For security, start strict. Only helper functions. |
197 | | - # Re-visiting: User might want 'output.get("key")'. |
198 | 247 | method_name = node.func.attr |
199 | | - if method_name in [ |
200 | | - "get", |
201 | | - "keys", |
202 | | - "values", |
203 | | - "items", |
204 | | - "lower", |
205 | | - "upper", |
206 | | - "strip", |
207 | | - "split", |
208 | | - ]: |
209 | | - is_safe = True |
| 248 | + allowed_types = SAFE_METHODS.get(method_name) |
| 249 | + if allowed_types: |
| 250 | + # Evaluate the receiver and check its type before approving |
| 251 | + receiver = self.visit(node.func.value) |
| 252 | + if isinstance(receiver, allowed_types): |
| 253 | + is_safe = True |
| 254 | + # Build the bound method directly so we don't re-visit |
| 255 | + func = getattr(receiver, method_name) |
| 256 | + args = [self.visit(arg) for arg in node.args] |
| 257 | + keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords} |
| 258 | + return func(*args, **keywords) |
| 259 | + |
| 260 | + func = self.visit(node.func) |
210 | 261 |
|
211 | 262 | if not is_safe and func not in SAFE_FUNCTIONS.values(): |
212 | 263 | raise ValueError("Call to function/method is not allowed") |
|
0 commit comments