fix: detect dangerous ops in SavedModel function libraries#676
fix: detect dangerous ops in SavedModel function libraries#676Daketey wants to merge 1 commit intopromptfoo:mainfrom
Conversation
WalkthroughEnhanced TensorFlow SavedModel scanner with function library scanning, improved operation location tracking via location_label propagation, benign IO operation classification, and new tests validating detection of dangerous operations within SavedModel function definitions. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~22 minutes Possibly related issues
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelaudit/scanners/tf_savedmodel_scanner.py`:
- Around line 30-32: The constant _FUNCTION_LIBRARY_BENIGN_IO_OPS in
tf_savedmodel_scanner.py currently includes "MergeV2Checkpoints", which
suppresses function-library scanning for a HIGH-risk op; remove
"MergeV2Checkpoints" from that set (or otherwise stop treating it as benign) so
its usage in function libraries is detected consistently with
tf_metagraph_scanner.py; update the set _FUNCTION_LIBRARY_BENIGN_IO_OPS to only
contain true benign ops (e.g., keep "SaveV2", "RestoreV2", "ShardedFilename")
and ensure any tests or comments referencing MergeV2Checkpoints are adjusted
accordingly.
- Around line 385-421: The inner loop over function.node_def is shadowing the
outer graph_def.node variable (both use "node"), causing subsequent
StatefulPartitionedCall and Lambda-layer checks to run against the last function
node instead of graph nodes; fix by renaming the inner loop variable (e.g.,
change for node in function.node_def to for fn_node in function.node_def) and
update all uses inside that inner loop (the op_counts increment,
_FUNCTION_LIBRARY_BENIGN_IO_OPS check, PYTHON_OPS branch that calls
self._check_python_op, the DANGEROUS_TF_OPERATIONS branch and
result.add_check/get_tf_op_explanation) so that the outer "node" remains intact
for the later StatefulPartitionedCall and Lambda detection logic which must
iterate over graph_def.node.
In `@tests/scanners/test_tf_savedmodel_scanner.py`:
- Around line 219-232: The test function
test_detect_function_library_dangerous_ops is missing the required return type
annotation; update its definition to include the type hint "-> None" (i.e., def
test_detect_function_library_dangerous_ops(tmp_path, op_name) -> None:) so the
test signature follows the project's guideline for test methods.
- Around line 330-358: The helper function
_create_test_savedmodel_with_function_op returns a string but lacks a return
type annotation; update its signature to include "-> str" (i.e., def
_create_test_savedmodel_with_function_op(tmp_path, op_name, model_name=None) ->
str:) so the function is explicitly annotated as returning a string; no other
behavior changes needed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: ASSERTIVE
Plan: Pro
Run ID: b7b701e1-f98b-43a5-914c-2c44f025a8ac
📒 Files selected for processing (6)
modelaudit/scanners/tf_savedmodel_scanner.pytests/assets/samples/tensorflow/py_func_rce_savedmodel/fingerprint.pbtests/assets/samples/tensorflow/py_func_rce_savedmodel/saved_model.pbtests/assets/samples/tensorflow/py_func_rce_savedmodel/variables/variables.data-00000-of-00001tests/assets/samples/tensorflow/py_func_rce_savedmodel/variables/variables.indextests/scanners/test_tf_savedmodel_scanner.py
| # Common checkpoint/restore ops that appear in benign function libraries. | ||
| _FUNCTION_LIBRARY_BENIGN_IO_OPS = {"SaveV2", "RestoreV2", "MergeV2Checkpoints", "ShardedFilename"} | ||
|
|
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check how MergeV2Checkpoints is handled across all scanners
rg -n "MergeV2Checkpoints" --type=py -C2Repository: promptfoo/modelaudit
Length of output: 2969
🏁 Script executed:
# Check if there's any documentation or comments explaining why these ops are considered benign in function libraries
rg -n "FUNCTION_LIBRARY_BENIGN_IO_OPS|function.librar" --type=py -B5 -A5Repository: promptfoo/modelaudit
Length of output: 5902
🏁 Script executed:
# Also check how SaveV2 and RestoreV2 are handled in metagraph scanner for comparison
rg -n "_EXCLUDE_GENERIC_DANGEROUS|_BENIGN_CHECKPOINT_IO_OPS" --type=py -B3 -A3Repository: promptfoo/modelaudit
Length of output: 2160
MergeV2Checkpoints exclusion from function-library scanning weakens detection coverage.
MergeV2Checkpoints is classified as HIGH-risk and documented as capable of manipulating and overwriting checkpoint files. However, it's excluded from function-library detection via _FUNCTION_LIBRARY_BENIGN_IO_OPS, while tf_metagraph_scanner.py does not exclude it. This inconsistency means the same operation receives different treatment across scanners—potentially missing malicious file manipulation in function libraries. Unlike SaveV2 and RestoreV2 (standard checkpoint I/O), MergeV2Checkpoints is a manipulation primitive that warrants consistent detection. Either align both scanners or remove this operation from the exclusion set to preserve detection coverage.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelaudit/scanners/tf_savedmodel_scanner.py` around lines 30 - 32, The
constant _FUNCTION_LIBRARY_BENIGN_IO_OPS in tf_savedmodel_scanner.py currently
includes "MergeV2Checkpoints", which suppresses function-library scanning for a
HIGH-risk op; remove "MergeV2Checkpoints" from that set (or otherwise stop
treating it as benign) so its usage in function libraries is detected
consistently with tf_metagraph_scanner.py; update the set
_FUNCTION_LIBRARY_BENIGN_IO_OPS to only contain true benign ops (e.g., keep
"SaveV2", "RestoreV2", "ShardedFilename") and ensure any tests or comments
referencing MergeV2Checkpoints are adjusted accordingly.
| for function in graph_def.library.function: | ||
| function_name = function.signature.name or "unknown_function" | ||
|
|
||
| for node in function.node_def: | ||
| # Count all operation types | ||
| op_counts[node.op] = op_counts.get(node.op, 0) + 1 | ||
|
|
||
| if node.op in _FUNCTION_LIBRARY_BENIGN_IO_OPS: | ||
| continue | ||
|
|
||
| if node.op in self.suspicious_ops: | ||
| suspicious_op_found = True | ||
|
|
||
| if node.op in PYTHON_OPS: | ||
| self._check_python_op( | ||
| node, | ||
| result, | ||
| meta_graph, | ||
| location_label=f"function: {function_name}, node: {node.name}", | ||
| ) | ||
| elif node.op not in DANGEROUS_TF_OPERATIONS: | ||
| result.add_check( | ||
| name="TensorFlow Operation Security Check", | ||
| passed=False, | ||
| message=f"Suspicious TensorFlow operation: {node.op}", | ||
| severity=IssueSeverity.CRITICAL, | ||
| location=(f"{self.current_file_path} (function: {function_name}, node: {node.name})"), | ||
| rule_code="S703", | ||
| details={ | ||
| "op_type": node.op, | ||
| "node_name": node.name, | ||
| "function_name": function_name, | ||
| "meta_graph": meta_graph_tag, | ||
| }, | ||
| why=get_tf_op_explanation(node.op), | ||
| ) | ||
|
|
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify the node variable usage and scoping in _analyze_saved_model
ast-grep --pattern $'for node in graph_def.node:
$$$
for node in $_.node_def:
$$$
$$$
'Repository: promptfoo/modelaudit
Length of output: 46
🏁 Script executed:
git ls-files | grep -E 'tf_savedmodel_scanner\.py$'Repository: promptfoo/modelaudit
Length of output: 153
🏁 Script executed:
sed -n '350,470p' modelaudit/scanners/tf_savedmodel_scanner.py | cat -nRepository: promptfoo/modelaudit
Length of output: 7299
🏁 Script executed:
sed -n '470,520p' modelaudit/scanners/tf_savedmodel_scanner.py | cat -nRepository: promptfoo/modelaudit
Length of output: 2892
Critical: node variable shadowing breaks StatefulPartitionedCall and Lambda checks.
The new function-library loop (lines 385-420) declares for node in function.node_def: at line 388, which shadows the outer node from for node in graph_def.node: (line 358).
After the function-library inner loop completes, the node variable at line 422 refers to the last node from function.node_def, not the graph nodes. This breaks:
- StatefulPartitionedCall check (lines 422-469): Now executes only on the last node of each function's nodes, instead of all graph nodes.
- Lambda layer detection (line 470+): Similarly checks only the last node of each function instead of all graph nodes.
Both checks are now inside the function loop but after the inner node loop, with shadowed variable bindings.
Fix: Rename the inner loop variable to avoid shadowing (e.g., fn_node or func_node_def), or restructure to run these checks on the outer graph_def.node loop where they belong.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelaudit/scanners/tf_savedmodel_scanner.py` around lines 385 - 421, The
inner loop over function.node_def is shadowing the outer graph_def.node variable
(both use "node"), causing subsequent StatefulPartitionedCall and Lambda-layer
checks to run against the last function node instead of graph nodes; fix by
renaming the inner loop variable (e.g., change for node in function.node_def to
for fn_node in function.node_def) and update all uses inside that inner loop
(the op_counts increment, _FUNCTION_LIBRARY_BENIGN_IO_OPS check, PYTHON_OPS
branch that calls self._check_python_op, the DANGEROUS_TF_OPERATIONS branch and
result.add_check/get_tf_op_explanation) so that the outer "node" remains intact
for the later StatefulPartitionedCall and Lambda detection logic which must
iterate over graph_def.node.
| @pytest.mark.skipif(not has_tf_protos(), reason="TensorFlow protobuf stubs unavailable") | ||
| @pytest.mark.parametrize("op_name", ["EagerPyFunc", "ReadFile", "WriteFile", "ParseTensor"]) | ||
| def test_detect_function_library_dangerous_ops(tmp_path, op_name): | ||
| """Dangerous ops in function library node_defs must be detected.""" | ||
| model_path = _create_test_savedmodel_with_function_op(tmp_path, op_name, f"function_lib_{op_name.lower()}") | ||
| scanner = TensorFlowSavedModelScanner() | ||
| result = scanner.scan(model_path) | ||
|
|
||
| op_issues = [i for i in result.issues if i.message and op_name in i.message] | ||
| assert op_issues, f"Expected detection for function library op {op_name}" | ||
| assert any(i.severity == IssueSeverity.CRITICAL for i in op_issues) | ||
| assert any(i.why for i in op_issues), f"Missing explanation for {op_name} detection" | ||
| assert any(i.location and "function:" in i.location for i in op_issues) | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion | 🟠 Major
Add missing -> None type hint.
The test function is missing the required return type annotation.
As per coding guidelines: "Use type hints -> None on all test methods."
♻️ Proposed fix
`@pytest.mark.skipif`(not has_tf_protos(), reason="TensorFlow protobuf stubs unavailable")
`@pytest.mark.parametrize`("op_name", ["EagerPyFunc", "ReadFile", "WriteFile", "ParseTensor"])
-def test_detect_function_library_dangerous_ops(tmp_path, op_name):
+def test_detect_function_library_dangerous_ops(tmp_path, op_name: str) -> None:
"""Dangerous ops in function library node_defs must be detected."""📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @pytest.mark.skipif(not has_tf_protos(), reason="TensorFlow protobuf stubs unavailable") | |
| @pytest.mark.parametrize("op_name", ["EagerPyFunc", "ReadFile", "WriteFile", "ParseTensor"]) | |
| def test_detect_function_library_dangerous_ops(tmp_path, op_name): | |
| """Dangerous ops in function library node_defs must be detected.""" | |
| model_path = _create_test_savedmodel_with_function_op(tmp_path, op_name, f"function_lib_{op_name.lower()}") | |
| scanner = TensorFlowSavedModelScanner() | |
| result = scanner.scan(model_path) | |
| op_issues = [i for i in result.issues if i.message and op_name in i.message] | |
| assert op_issues, f"Expected detection for function library op {op_name}" | |
| assert any(i.severity == IssueSeverity.CRITICAL for i in op_issues) | |
| assert any(i.why for i in op_issues), f"Missing explanation for {op_name} detection" | |
| assert any(i.location and "function:" in i.location for i in op_issues) | |
| def test_detect_function_library_dangerous_ops(tmp_path: Path, op_name: str) -> None: |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/scanners/test_tf_savedmodel_scanner.py` around lines 219 - 232, The
test function test_detect_function_library_dangerous_ops is missing the required
return type annotation; update its definition to include the type hint "-> None"
(i.e., def test_detect_function_library_dangerous_ops(tmp_path, op_name) ->
None:) so the test signature follows the project's guideline for test methods.
| def _create_test_savedmodel_with_function_op(tmp_path, op_name, model_name=None): | ||
| """Create a SavedModel with operation only inside function library.""" | ||
| from tensorflow.core.framework.function_pb2 import FunctionDef | ||
| from tensorflow.core.framework.node_def_pb2 import NodeDef | ||
| from tensorflow.core.protobuf.saved_model_pb2 import SavedModel | ||
|
|
||
| if model_name is None: | ||
| model_name = f"test_model_function_{op_name.lower()}" | ||
|
|
||
| model_dir = tmp_path / model_name | ||
| model_dir.mkdir() | ||
|
|
||
| saved_model = SavedModel() | ||
| meta_graph = saved_model.meta_graphs.add() | ||
| meta_graph.meta_info_def.tags.append("serve") | ||
|
|
||
| function = FunctionDef() | ||
| function.signature.name = f"dangerous_fn_{op_name.lower()}" | ||
| function.node_def.extend([NodeDef(name=f"fn_node_{op_name.lower()}", op=op_name)]) | ||
| meta_graph.graph_def.library.function.extend([function]) | ||
|
|
||
| saved_model_path = model_dir / "saved_model.pb" | ||
| saved_model_path.write_bytes(saved_model.SerializeToString()) | ||
|
|
||
| variables_dir = model_dir / "variables" | ||
| variables_dir.mkdir() | ||
|
|
||
| return str(model_dir) | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion | 🟠 Major
Add missing -> str return type hint.
The helper function returns a string but lacks the type annotation.
As per coding guidelines: "Always include type hints in Python code."
♻️ Proposed fix
-def _create_test_savedmodel_with_function_op(tmp_path, op_name, model_name=None):
+def _create_test_savedmodel_with_function_op(tmp_path, op_name: str, model_name: str | None = None) -> str:
"""Create a SavedModel with operation only inside function library."""🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/scanners/test_tf_savedmodel_scanner.py` around lines 330 - 358, The
helper function _create_test_savedmodel_with_function_op returns a string but
lacks a return type annotation; update its signature to include "-> str" (i.e.,
def _create_test_savedmodel_with_function_op(tmp_path, op_name, model_name=None)
-> str:) so the function is explicitly annotated as returning a string; no other
behavior changes needed.
|
@mldangelo closing this PR as #677 has been raised for the fix |
|
@Daketey I apologize for raising a separate PR, if you would like to get this in please feel free to re-open. Would gladly accept the contribution. Separately, please feel free to reach out for some swag. |
|
@mldangelo No worries! I’d be happy to contribute in the future. I’ve also reached out to you via email, please let me know if it didn’t come through for any reason. |
Summary
What
This PR fixes a detection gap in TensorFlow SavedModel scanning by adding analysis of function-library node definitions and including a real attack-vector sample asset. This issue was highlighted in issue #675
Why
Real malicious SavedModel payloads frequently place dangerous operations in function definitions (graph_def.library.function[].node_def) instead of only in top-level graph nodes. Without scanning function definitions, ModelAudit can miss critical threats.
Changes
Security Impact
Improves detection coverage for critical TensorFlow attack vectors that use function-library placement for stealth.
Validation
uv run ruff format --check modelaudit/ tests/uv run ruff check modelaudit/ tests/uv run mypy modelaudit/uv run pytest -n auto -m "not slow and not integration" --maxfail=1Checklist
AGENTS.md.Summary by CodeRabbit
Release Notes
Bug Fixes & Improvements
Tests