Skip to content

Commit 2224676

Browse files
committed
fix(graph): harden safe_eval against DoS, fix BoolOp short-circuit, add depth limit
- cap exponentiation at MAX_EXPONENT (100) to prevent hangs from 9**9**9
- bound string/list repetition to MAX_REPEAT (10k) to block memory bombs
- rewrite visit_BoolOp to evaluate lazily, matching Python semantics (e.g. 'x and len(x)' no longer crashes when x is None)
- add MAX_DEPTH (50) recursion limit to prevent stack overflow
- validate method receiver types (e.g. .get() only on dict, .lower() only on str)
- add 42 tests covering all four vulnerability categories

Closes #5109
1 parent f36add8 commit 2224676

File tree

2 files changed

+339
-37
lines changed

2 files changed

+339
-37
lines changed

core/framework/graph/safe_eval.py

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,44 @@
22
import operator
33
from typing import Any
44

5+
# Limits to prevent resource exhaustion via crafted expressions
6+
MAX_EXPONENT = 100 # cap for ** operator
7+
MAX_REPEAT = 10_000 # cap for string/list repetition with *
8+
9+
10+
def _safe_pow(base: Any, exp: Any) -> Any:
11+
"""Bounded power operator to prevent DoS via huge exponents."""
12+
if isinstance(exp, (int, float)) and abs(exp) > MAX_EXPONENT:
13+
raise ValueError(
14+
f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})"
15+
)
16+
return operator.pow(base, exp)
17+
18+
19+
def _safe_mult(a: Any, b: Any) -> Any:
20+
"""Bounded multiplication that prevents huge string/list repetitions."""
21+
if isinstance(a, (str, list, tuple, bytes)) and isinstance(b, int):
22+
if b > MAX_REPEAT:
23+
raise ValueError(
24+
f"Repeat count {b} exceeds maximum allowed ({MAX_REPEAT})"
25+
)
26+
elif isinstance(b, (str, list, tuple, bytes)) and isinstance(a, int):
27+
if a > MAX_REPEAT:
28+
raise ValueError(
29+
f"Repeat count {a} exceeds maximum allowed ({MAX_REPEAT})"
30+
)
31+
return operator.mul(a, b)
32+
33+
534
# Safe operators whitelist
635
SAFE_OPERATORS = {
736
ast.Add: operator.add,
837
ast.Sub: operator.sub,
9-
ast.Mult: operator.mul,
38+
ast.Mult: _safe_mult,
1039
ast.Div: operator.truediv,
1140
ast.FloorDiv: operator.floordiv,
1241
ast.Mod: operator.mod,
13-
ast.Pow: operator.pow,
42+
ast.Pow: _safe_pow,
1443
ast.LShift: operator.lshift,
1544
ast.RShift: operator.rshift,
1645
ast.BitOr: operator.or_,
@@ -52,16 +81,39 @@
5281
"any": any,
5382
}
5483

84+
# Method whitelist keyed by method name; each value is the tuple of receiver
# types the method may be called on. Any (method, type) pair not listed here
# is not auto-approved.
SAFE_METHODS: dict[str, tuple[type, ...]] = {
    method_name: (receiver_type,)
    for receiver_type, method_names in (
        (dict, ("get", "keys", "values", "items")),
        (str, ("lower", "upper", "strip", "split")),
    )
    for method_name in method_names
}
96+
5597

5698
class SafeEvalVisitor(ast.NodeVisitor):
99+
MAX_DEPTH = 50 # prevent stack overflow from deeply nested expressions
100+
57101
def __init__(self, context: dict[str, Any]) -> None:
    """Create a visitor bound to *context*.

    Args:
        context: Mapping stored on the instance as ``self.context``
            (presumably the names an expression may reference; confirm
            against the name-lookup handler).
    """
    self.context = context
    # Current nesting level of visit() calls; compared against MAX_DEPTH.
    self._depth = 0
59104

60105
def visit(self, node: ast.AST) -> Any:
61-
# Override visit to prevent default behavior and ensure only explicitly allowed nodes work
62-
method = "visit_" + node.__class__.__name__
63-
visitor = getattr(self, method, self.generic_visit)
64-
return visitor(node)
106+
self._depth += 1
107+
if self._depth > self.MAX_DEPTH:
108+
raise ValueError(
109+
f"Expression nesting depth exceeds limit ({self.MAX_DEPTH})"
110+
)
111+
try:
112+
method = "visit_" + node.__class__.__name__
113+
visitor = getattr(self, method, self.generic_visit)
114+
return visitor(node)
115+
finally:
116+
self._depth -= 1
65117

66118
def generic_visit(self, node: ast.AST):
    """Reject any node type that has no dedicated ``visit_*`` handler.

    ``visit`` uses this method as the fallback handler, so an unsupported
    node is refused instead of being traversed.

    Raises:
        ValueError: Always; the message names the rejected node class.
    """
    raise ValueError(f"Use of {node.__class__.__name__} is not allowed")
@@ -115,11 +167,23 @@ def visit_Compare(self, node: ast.Compare) -> Any:
115167
return True
116168

117169
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
118-
values = [self.visit(v) for v in node.values]
170+
# Lazy evaluation matching Python short-circuit semantics.
171+
# `x and y` returns x if x is falsy, otherwise y.
172+
# `x or y` returns x if x is truthy, otherwise y.
119173
if isinstance(node.op, ast.And):
120-
return all(values)
174+
result: Any = True
175+
for v in node.values:
176+
result = self.visit(v)
177+
if not result:
178+
return result
179+
return result
121180
elif isinstance(node.op, ast.Or):
122-
return any(values)
181+
result = False
182+
for v in node.values:
183+
result = self.visit(v)
184+
if result:
185+
return result
186+
return result
123187
raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed")
124188

125189
def visit_IfExp(self, node: ast.IfExp) -> Any:
@@ -171,42 +235,29 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
171235
raise AttributeError(f"Object has no attribute '{node.attr}'")
172236

173237
def visit_Call(self, node: ast.Call) -> Any:
174-
# Only allow calling whitelisted functions
175-
func = self.visit(node.func)
176-
177-
# Check if the function object itself is in our whitelist values
178-
# This is tricky because `func` is the actual function object,
179-
# but we also want to verify it came from a safe place.
180-
# Easier: Check if node.func is a Name and that name is in SAFE_FUNCTIONS.
238+
# Only allow calling whitelisted functions or type-checked methods.
181239

182240
is_safe = False
241+
183242
if isinstance(node.func, ast.Name):
184243
if node.func.id in SAFE_FUNCTIONS:
185244
is_safe = True
186245

187-
# Also allow methods on objects if they are safe?
188-
# E.g. "somestring".lower() or list.append() (if we allowed mutation, but we don't for now)
189-
# For now, restrict to SAFE_FUNCTIONS whitelist for global calls and deny method calls
190-
# unless we explicitly add safe methods.
191-
# Allowing method calls on strings/lists (split, join, get) is commonly needed.
192-
193246
if isinstance(node.func, ast.Attribute):
194-
# Method call.
195-
# Allow basic safe methods?
196-
# For security, start strict. Only helper functions.
197-
# Re-visiting: User might want 'output.get("key")'.
198247
method_name = node.func.attr
199-
if method_name in [
200-
"get",
201-
"keys",
202-
"values",
203-
"items",
204-
"lower",
205-
"upper",
206-
"strip",
207-
"split",
208-
]:
209-
is_safe = True
248+
allowed_types = SAFE_METHODS.get(method_name)
249+
if allowed_types:
250+
# Evaluate the receiver and check its type before approving
251+
receiver = self.visit(node.func.value)
252+
if isinstance(receiver, allowed_types):
253+
is_safe = True
254+
# Build the bound method directly so we don't re-visit
255+
func = getattr(receiver, method_name)
256+
args = [self.visit(arg) for arg in node.args]
257+
keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords}
258+
return func(*args, **keywords)
259+
260+
func = self.visit(node.func)
210261

211262
if not is_safe and func not in SAFE_FUNCTIONS.values():
212263
raise ValueError("Call to function/method is not allowed")

0 commit comments

Comments
 (0)