-
Notifications
You must be signed in to change notification settings - Fork 6
fix: detect dangerous ops in SavedModel function libraries #676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,9 @@ | |
| # Python operations that require special handling | ||
| PYTHON_OPS = ("PyFunc", "PyCall", "PyFuncStateless", "EagerPyFunc") | ||
|
|
||
| # Common checkpoint/restore ops that appear in benign function libraries. | ||
| _FUNCTION_LIBRARY_BENIGN_IO_OPS = {"SaveV2", "RestoreV2", "MergeV2Checkpoints", "ShardedFilename"} | ||
|
|
||
| # Defer protobuf availability check to avoid module-level imports | ||
| HAS_PROTOS: bool | None = None | ||
|
|
||
|
|
@@ -160,15 +163,17 @@ def _scan_saved_model_file(self, path: str) -> ScanResult: | |
| saved_model = SavedModel() | ||
| saved_model.ParseFromString(content) | ||
| for op_info in self._scan_tf_operations(saved_model): | ||
| location_label = op_info.get("location_label") or f"node: {op_info['node_name']}" | ||
| result.add_check( | ||
| name="TensorFlow Operation Security Check", | ||
| passed=False, | ||
| message=f"Dangerous TensorFlow operation: {op_info['operation']}", | ||
| severity=op_info["severity"], | ||
| location=f"{self.current_file_path} (node: {op_info['node_name']})", | ||
| location=f"{self.current_file_path} ({location_label})", | ||
| details={ | ||
| "op_type": op_info["operation"], | ||
| "node_name": op_info["node_name"], | ||
| "location_label": op_info.get("location_label"), | ||
| "meta_graph": op_info.get("meta_graph", "unknown"), | ||
| }, | ||
| why=get_tf_op_explanation(op_info["operation"]), | ||
|
|
@@ -290,6 +295,7 @@ def _scan_tf_operations(self, saved_model: Any) -> list[dict[str, Any]]: | |
| dangerous_ops: list[dict[str, Any]] = [] | ||
| try: | ||
| for meta_graph in saved_model.meta_graphs: | ||
| meta_graph_tag = meta_graph.meta_info_def.tags[0] if meta_graph.meta_info_def.tags else "unknown" | ||
| graph_def = meta_graph.graph_def | ||
| for node in graph_def.node: | ||
| # Skip Python ops here; they are handled by _check_python_op | ||
|
|
@@ -300,12 +306,27 @@ def _scan_tf_operations(self, saved_model: Any) -> list[dict[str, Any]]: | |
| { | ||
| "operation": node.op, | ||
| "node_name": node.name, | ||
| "location_label": f"node: {node.name}", | ||
| "severity": DANGEROUS_TF_OPERATIONS[node.op], | ||
| "meta_graph": ( | ||
| meta_graph.meta_info_def.tags[0] if meta_graph.meta_info_def.tags else "unknown" | ||
| ), | ||
| "meta_graph": meta_graph_tag, | ||
| } | ||
| ) | ||
|
|
||
| for function in graph_def.library.function: | ||
| function_name = function.signature.name or "unknown_function" | ||
| for node in function.node_def: | ||
| if node.op in PYTHON_OPS or node.op in _FUNCTION_LIBRARY_BENIGN_IO_OPS: | ||
| continue | ||
| if node.op in DANGEROUS_TF_OPERATIONS: | ||
| dangerous_ops.append( | ||
| { | ||
| "operation": node.op, | ||
| "node_name": node.name, | ||
| "location_label": f"function: {function_name}, node: {node.name}", | ||
| "severity": DANGEROUS_TF_OPERATIONS[node.op], | ||
| "meta_graph": meta_graph_tag, | ||
| } | ||
| ) | ||
| except Exception as e: # pragma: no cover | ||
| logger.warning(f"Failed to iterate TensorFlow graph: {e}") | ||
| return dangerous_ops | ||
|
|
@@ -331,6 +352,7 @@ def _analyze_saved_model(self, saved_model: Any, result: ScanResult) -> None: | |
|
|
||
| for meta_graph in saved_model.meta_graphs: | ||
| graph_def = meta_graph.graph_def | ||
| meta_graph_tag = meta_graph.meta_info_def.tags[0] if meta_graph.meta_info_def.tags else "unknown" | ||
|
|
||
| # Scan all nodes in the graph for suspicious operations | ||
| for node in graph_def.node: | ||
|
|
@@ -354,14 +376,49 @@ def _analyze_saved_model(self, saved_model: Any, result: ScanResult) -> None: | |
| details={ | ||
| "op_type": node.op, | ||
| "node_name": node.name, | ||
| "meta_graph": ( | ||
| meta_graph.meta_info_def.tags[0] if meta_graph.meta_info_def.tags else "unknown" | ||
| ), | ||
| "meta_graph": meta_graph_tag, | ||
| }, | ||
| why=get_tf_op_explanation(node.op), | ||
| ) | ||
| # else: already reported by generic dangerous-op pass | ||
|
|
||
| for function in graph_def.library.function: | ||
| function_name = function.signature.name or "unknown_function" | ||
|
|
||
| for node in function.node_def: | ||
| # Count all operation types | ||
| op_counts[node.op] = op_counts.get(node.op, 0) + 1 | ||
|
|
||
| if node.op in _FUNCTION_LIBRARY_BENIGN_IO_OPS: | ||
| continue | ||
|
|
||
| if node.op in self.suspicious_ops: | ||
| suspicious_op_found = True | ||
|
|
||
| if node.op in PYTHON_OPS: | ||
| self._check_python_op( | ||
| node, | ||
| result, | ||
| meta_graph, | ||
| location_label=f"function: {function_name}, node: {node.name}", | ||
| ) | ||
| elif node.op not in DANGEROUS_TF_OPERATIONS: | ||
| result.add_check( | ||
| name="TensorFlow Operation Security Check", | ||
| passed=False, | ||
| message=f"Suspicious TensorFlow operation: {node.op}", | ||
| severity=IssueSeverity.CRITICAL, | ||
| location=(f"{self.current_file_path} (function: {function_name}, node: {node.name})"), | ||
| rule_code="S703", | ||
| details={ | ||
| "op_type": node.op, | ||
| "node_name": node.name, | ||
| "function_name": function_name, | ||
| "meta_graph": meta_graph_tag, | ||
| }, | ||
| why=get_tf_op_explanation(node.op), | ||
| ) | ||
|
|
||
|
Comment on lines
+385
to
+421
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Verify the node variable usage and scoping in _analyze_saved_model
ast-grep --pattern $'for node in graph_def.node:
$$$
for node in $_.node_def:
$$$
$$$
'Repository: promptfoo/modelaudit Length of output: 46 🏁 Script executed: git ls-files | grep -E 'tf_savedmodel_scanner\.py$'Repository: promptfoo/modelaudit Length of output: 153 🏁 Script executed: sed -n '350,470p' modelaudit/scanners/tf_savedmodel_scanner.py | cat -nRepository: promptfoo/modelaudit Length of output: 7299 🏁 Script executed: sed -n '470,520p' modelaudit/scanners/tf_savedmodel_scanner.py | cat -nRepository: promptfoo/modelaudit Length of output: 2892 Critical: The new function-library loop (lines 385-420) declares After the function-library inner loop completes, the
Both checks are now inside the function loop but after the inner node loop, with shadowed variable bindings. Fix: Rename the inner loop variable to avoid shadowing (e.g., 🤖 Prompt for AI Agents |
||
| # Check for StatefulPartitionedCall which can contain custom functions | ||
| if node.op == "StatefulPartitionedCall" and hasattr(node, "attr") and "f" in node.attr: | ||
| # These operations can contain arbitrary functions | ||
|
|
@@ -449,7 +506,13 @@ def _analyze_saved_model(self, saved_model: Any, result: ScanResult) -> None: | |
| # Enhanced protobuf vulnerability scanning | ||
| self._scan_protobuf_vulnerabilities(saved_model, result) | ||
|
|
||
| def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> None: | ||
| def _check_python_op( | ||
| self, | ||
| node: Any, | ||
| result: ScanResult, | ||
| meta_graph: Any, | ||
| location_label: str | None = None, | ||
| ) -> None: | ||
| """Check PyFunc/PyCall operations for embedded Python code""" | ||
| # PyFunc and PyCall can embed Python code in various ways: | ||
| # 1. As a string attribute containing Python code | ||
|
|
@@ -458,6 +521,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No | |
|
|
||
| code_found = False | ||
| python_code = None | ||
| node_location = f"{self.current_file_path} ({location_label or f'node: {node.name}'})" | ||
|
|
||
| # Try to extract Python code from node attributes | ||
| if hasattr(node, "attr"): | ||
|
|
@@ -491,7 +555,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No | |
| passed=False, | ||
| message=f"{node.op} operation references dangerous function: {func_name}", | ||
| severity=IssueSeverity.CRITICAL, | ||
| location=f"{self.current_file_path} (node: {node.name})", | ||
| location=node_location, | ||
| rule_code="S902", | ||
| details={ | ||
| "op_type": node.op, | ||
|
|
@@ -523,7 +587,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No | |
| passed=False, | ||
| message=issue_msg, | ||
| severity=severity, | ||
| location=f"{self.current_file_path} (node: {node.name})", | ||
| location=node_location, | ||
| rule_code="S902", | ||
| details={ | ||
| "op_type": node.op, | ||
|
|
@@ -545,7 +609,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No | |
| message=f"{node.op} operation contains suspicious data (possibly obfuscated code)", | ||
| rule_code="S902", | ||
| severity=IssueSeverity.CRITICAL, | ||
| location=f"{self.current_file_path} (node: {node.name})", | ||
| location=node_location, | ||
| details={ | ||
| "op_type": node.op, | ||
| "node_name": node.name, | ||
|
|
@@ -565,7 +629,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No | |
| message=f"{node.op} operation detected (unable to extract Python code)", | ||
| rule_code="S902", | ||
| severity=IssueSeverity.CRITICAL, | ||
| location=f"{self.current_file_path} (node: {node.name})", | ||
| location=node_location, | ||
| details={ | ||
| "op_type": node.op, | ||
| "node_name": node.name, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| ¾°½ûù¦ê¨§ÊüÅúëÎk«¹ä¥®ä¢¹ ðÕ¿ÁÌ(òð¸¨ÓÄÿ2:&33340531854789989996454843123276936653 |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -216,6 +216,21 @@ def test_detect_writefile_operation(tmp_path): | |||||||||||||||||||||||||||||
| assert any(i.why for i in writefile_issues), "Missing explanation for WriteFile detection" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @pytest.mark.skipif(not has_tf_protos(), reason="TensorFlow protobuf stubs unavailable") | ||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("op_name", ["EagerPyFunc", "ReadFile", "WriteFile", "ParseTensor"]) | ||||||||||||||||||||||||||||||
| def test_detect_function_library_dangerous_ops(tmp_path, op_name): | ||||||||||||||||||||||||||||||
| """Dangerous ops in function library node_defs must be detected.""" | ||||||||||||||||||||||||||||||
| model_path = _create_test_savedmodel_with_function_op(tmp_path, op_name, f"function_lib_{op_name.lower()}") | ||||||||||||||||||||||||||||||
| scanner = TensorFlowSavedModelScanner() | ||||||||||||||||||||||||||||||
| result = scanner.scan(model_path) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| op_issues = [i for i in result.issues if i.message and op_name in i.message] | ||||||||||||||||||||||||||||||
| assert op_issues, f"Expected detection for function library op {op_name}" | ||||||||||||||||||||||||||||||
| assert any(i.severity == IssueSeverity.CRITICAL for i in op_issues) | ||||||||||||||||||||||||||||||
| assert any(i.why for i in op_issues), f"Missing explanation for {op_name} detection" | ||||||||||||||||||||||||||||||
| assert any(i.location and "function:" in i.location for i in op_issues) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
Comment on lines
+219
to
+232
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Add missing return type annotation. The test function is missing the required return type annotation. As per coding guidelines: "Use type hints ♻️ Proposed fix `@pytest.mark.skipif`(not has_tf_protos(), reason="TensorFlow protobuf stubs unavailable")
`@pytest.mark.parametrize`("op_name", ["EagerPyFunc", "ReadFile", "WriteFile", "ParseTensor"])
-def test_detect_function_library_dangerous_ops(tmp_path, op_name):
+def test_detect_function_library_dangerous_ops(tmp_path, op_name: str) -> None:
"""Dangerous ops in function library node_defs must be detected."""📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @pytest.mark.skipif(not has_tensorflow(), reason="TensorFlow not installed") | ||||||||||||||||||||||||||||||
| def test_tf_savedmodel_scanner_with_blacklist(tmp_path): | ||||||||||||||||||||||||||||||
| """Test TensorFlow SavedModel scanner with custom blacklist patterns.""" | ||||||||||||||||||||||||||||||
|
|
@@ -312,6 +327,36 @@ def _create_test_savedmodel_with_ops(tmp_path, op_names, model_name=None): | |||||||||||||||||||||||||||||
| return str(model_dir) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _create_test_savedmodel_with_function_op(tmp_path, op_name, model_name=None): | ||||||||||||||||||||||||||||||
| """Create a SavedModel with operation only inside function library.""" | ||||||||||||||||||||||||||||||
| from tensorflow.core.framework.function_pb2 import FunctionDef | ||||||||||||||||||||||||||||||
| from tensorflow.core.framework.node_def_pb2 import NodeDef | ||||||||||||||||||||||||||||||
| from tensorflow.core.protobuf.saved_model_pb2 import SavedModel | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if model_name is None: | ||||||||||||||||||||||||||||||
| model_name = f"test_model_function_{op_name.lower()}" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| model_dir = tmp_path / model_name | ||||||||||||||||||||||||||||||
| model_dir.mkdir() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| saved_model = SavedModel() | ||||||||||||||||||||||||||||||
| meta_graph = saved_model.meta_graphs.add() | ||||||||||||||||||||||||||||||
| meta_graph.meta_info_def.tags.append("serve") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| function = FunctionDef() | ||||||||||||||||||||||||||||||
| function.signature.name = f"dangerous_fn_{op_name.lower()}" | ||||||||||||||||||||||||||||||
| function.node_def.extend([NodeDef(name=f"fn_node_{op_name.lower()}", op=op_name)]) | ||||||||||||||||||||||||||||||
| meta_graph.graph_def.library.function.extend([function]) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| saved_model_path = model_dir / "saved_model.pb" | ||||||||||||||||||||||||||||||
| saved_model_path.write_bytes(saved_model.SerializeToString()) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| variables_dir = model_dir / "variables" | ||||||||||||||||||||||||||||||
| variables_dir.mkdir() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return str(model_dir) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
Comment on lines
+330
to
+358
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Add missing return type annotation. The helper function returns a string but lacks the type annotation. As per coding guidelines: "Always include type hints in Python code." ♻️ Proposed fix-def _create_test_savedmodel_with_function_op(tmp_path, op_name, model_name=None):
+def _create_test_savedmodel_with_function_op(tmp_path, op_name: str, model_name: str | None = None) -> str:
"""Create a SavedModel with operation only inside function library."""🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @pytest.mark.skipif(not has_tensorflow(), reason="TensorFlow not installed") | ||||||||||||||||||||||||||||||
| def test_tf_scanner_explanations_for_all_suspicious_ops(tmp_path): | ||||||||||||||||||||||||||||||
| """Test that all suspicious TensorFlow operations generate explanations.""" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: promptfoo/modelaudit
Length of output: 2969
🏁 Script executed:
Repository: promptfoo/modelaudit
Length of output: 5902
🏁 Script executed:
Repository: promptfoo/modelaudit
Length of output: 2160
MergeV2Checkpointsexclusion from function-library scanning weakens detection coverage.MergeV2Checkpointsis classified as HIGH-risk and documented as capable of manipulating and overwriting checkpoint files. However, it's excluded from function-library detection via_FUNCTION_LIBRARY_BENIGN_IO_OPS, whiletf_metagraph_scanner.pydoes not exclude it. This inconsistency means the same operation receives different treatment across scanners—potentially missing malicious file manipulation in function libraries. UnlikeSaveV2andRestoreV2(standard checkpoint I/O),MergeV2Checkpointsis a manipulation primitive that warrants consistent detection. Either align both scanners or remove this operation from the exclusion set to preserve detection coverage.🤖 Prompt for AI Agents