Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 76 additions & 12 deletions modelaudit/scanners/tf_savedmodel_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
# Python operations that require special handling
PYTHON_OPS = ("PyFunc", "PyCall", "PyFuncStateless", "EagerPyFunc")

# Common checkpoint/restore ops that appear in benign function libraries.
_FUNCTION_LIBRARY_BENIGN_IO_OPS = {"SaveV2", "RestoreV2", "MergeV2Checkpoints", "ShardedFilename"}

Comment on lines +30 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check how MergeV2Checkpoints is handled across all scanners
rg -n "MergeV2Checkpoints" --type=py -C2

Repository: promptfoo/modelaudit

Length of output: 2969


🏁 Script executed:

# Check if there's any documentation or comments explaining why these ops are considered benign in function libraries
rg -n "FUNCTION_LIBRARY_BENIGN_IO_OPS|function.librar" --type=py -B5 -A5

Repository: promptfoo/modelaudit

Length of output: 5902


🏁 Script executed:

# Also check how SaveV2 and RestoreV2 are handled in metagraph scanner for comparison
rg -n "_EXCLUDE_GENERIC_DANGEROUS|_BENIGN_CHECKPOINT_IO_OPS" --type=py -B3 -A3

Repository: promptfoo/modelaudit

Length of output: 2160


MergeV2Checkpoints exclusion from function-library scanning weakens detection coverage.

MergeV2Checkpoints is classified as HIGH-risk and documented as capable of manipulating and overwriting checkpoint files. However, it's excluded from function-library detection via _FUNCTION_LIBRARY_BENIGN_IO_OPS, while tf_metagraph_scanner.py does not exclude it. This inconsistency means the same operation receives different treatment across scanners—potentially missing malicious file manipulation in function libraries. Unlike SaveV2 and RestoreV2 (standard checkpoint I/O), MergeV2Checkpoints is a manipulation primitive that warrants consistent detection. Either align both scanners or remove this operation from the exclusion set to preserve detection coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelaudit/scanners/tf_savedmodel_scanner.py` around lines 30 - 32, The
constant _FUNCTION_LIBRARY_BENIGN_IO_OPS in tf_savedmodel_scanner.py currently
includes "MergeV2Checkpoints", which suppresses function-library scanning for a
HIGH-risk op; remove "MergeV2Checkpoints" from that set (or otherwise stop
treating it as benign) so its usage in function libraries is detected
consistently with tf_metagraph_scanner.py; update the set
_FUNCTION_LIBRARY_BENIGN_IO_OPS to only contain true benign ops (e.g., keep
"SaveV2", "RestoreV2", "ShardedFilename") and ensure any tests or comments
referencing MergeV2Checkpoints are adjusted accordingly.

# Defer protobuf availability check to avoid module-level imports
HAS_PROTOS: bool | None = None

Expand Down Expand Up @@ -160,15 +163,17 @@ def _scan_saved_model_file(self, path: str) -> ScanResult:
saved_model = SavedModel()
saved_model.ParseFromString(content)
for op_info in self._scan_tf_operations(saved_model):
location_label = op_info.get("location_label") or f"node: {op_info['node_name']}"
result.add_check(
name="TensorFlow Operation Security Check",
passed=False,
message=f"Dangerous TensorFlow operation: {op_info['operation']}",
severity=op_info["severity"],
location=f"{self.current_file_path} (node: {op_info['node_name']})",
location=f"{self.current_file_path} ({location_label})",
details={
"op_type": op_info["operation"],
"node_name": op_info["node_name"],
"location_label": op_info.get("location_label"),
"meta_graph": op_info.get("meta_graph", "unknown"),
},
why=get_tf_op_explanation(op_info["operation"]),
Expand Down Expand Up @@ -290,6 +295,7 @@ def _scan_tf_operations(self, saved_model: Any) -> list[dict[str, Any]]:
dangerous_ops: list[dict[str, Any]] = []
try:
for meta_graph in saved_model.meta_graphs:
meta_graph_tag = meta_graph.meta_info_def.tags[0] if meta_graph.meta_info_def.tags else "unknown"
graph_def = meta_graph.graph_def
for node in graph_def.node:
# Skip Python ops here; they are handled by _check_python_op
Expand All @@ -300,12 +306,27 @@ def _scan_tf_operations(self, saved_model: Any) -> list[dict[str, Any]]:
{
"operation": node.op,
"node_name": node.name,
"location_label": f"node: {node.name}",
"severity": DANGEROUS_TF_OPERATIONS[node.op],
"meta_graph": (
meta_graph.meta_info_def.tags[0] if meta_graph.meta_info_def.tags else "unknown"
),
"meta_graph": meta_graph_tag,
}
)

for function in graph_def.library.function:
function_name = function.signature.name or "unknown_function"
for node in function.node_def:
if node.op in PYTHON_OPS or node.op in _FUNCTION_LIBRARY_BENIGN_IO_OPS:
continue
if node.op in DANGEROUS_TF_OPERATIONS:
dangerous_ops.append(
{
"operation": node.op,
"node_name": node.name,
"location_label": f"function: {function_name}, node: {node.name}",
"severity": DANGEROUS_TF_OPERATIONS[node.op],
"meta_graph": meta_graph_tag,
}
)
except Exception as e: # pragma: no cover
logger.warning(f"Failed to iterate TensorFlow graph: {e}")
return dangerous_ops
Expand All @@ -331,6 +352,7 @@ def _analyze_saved_model(self, saved_model: Any, result: ScanResult) -> None:

for meta_graph in saved_model.meta_graphs:
graph_def = meta_graph.graph_def
meta_graph_tag = meta_graph.meta_info_def.tags[0] if meta_graph.meta_info_def.tags else "unknown"

# Scan all nodes in the graph for suspicious operations
for node in graph_def.node:
Expand All @@ -354,14 +376,49 @@ def _analyze_saved_model(self, saved_model: Any, result: ScanResult) -> None:
details={
"op_type": node.op,
"node_name": node.name,
"meta_graph": (
meta_graph.meta_info_def.tags[0] if meta_graph.meta_info_def.tags else "unknown"
),
"meta_graph": meta_graph_tag,
},
why=get_tf_op_explanation(node.op),
)
# else: already reported by generic dangerous-op pass

for function in graph_def.library.function:
function_name = function.signature.name or "unknown_function"

for node in function.node_def:
# Count all operation types
op_counts[node.op] = op_counts.get(node.op, 0) + 1

if node.op in _FUNCTION_LIBRARY_BENIGN_IO_OPS:
continue

if node.op in self.suspicious_ops:
suspicious_op_found = True

if node.op in PYTHON_OPS:
self._check_python_op(
node,
result,
meta_graph,
location_label=f"function: {function_name}, node: {node.name}",
)
elif node.op not in DANGEROUS_TF_OPERATIONS:
result.add_check(
name="TensorFlow Operation Security Check",
passed=False,
message=f"Suspicious TensorFlow operation: {node.op}",
severity=IssueSeverity.CRITICAL,
location=(f"{self.current_file_path} (function: {function_name}, node: {node.name})"),
rule_code="S703",
details={
"op_type": node.op,
"node_name": node.name,
"function_name": function_name,
"meta_graph": meta_graph_tag,
},
why=get_tf_op_explanation(node.op),
)

Comment on lines +385 to +421
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify the node variable usage and scoping in _analyze_saved_model
ast-grep --pattern $'for node in graph_def.node:
  $$$
  for node in $_.node_def:
    $$$
  $$$
'

Repository: promptfoo/modelaudit

Length of output: 46


🏁 Script executed:

git ls-files | grep -E 'tf_savedmodel_scanner\.py$'

Repository: promptfoo/modelaudit

Length of output: 153


🏁 Script executed:

sed -n '350,470p' modelaudit/scanners/tf_savedmodel_scanner.py | cat -n

Repository: promptfoo/modelaudit

Length of output: 7299


🏁 Script executed:

sed -n '470,520p' modelaudit/scanners/tf_savedmodel_scanner.py | cat -n

Repository: promptfoo/modelaudit

Length of output: 2892


Critical: node variable shadowing breaks StatefulPartitionedCall and Lambda checks.

The new function-library loop (lines 385-420) declares for node in function.node_def: at line 388, which shadows the outer node from for node in graph_def.node: (line 358).

After the function-library inner loop completes, the node variable at line 422 refers to the last node from function.node_def, not the graph nodes. This breaks:

  1. StatefulPartitionedCall check (lines 422-469): Now executes only on the last node of each function's nodes, instead of all graph nodes.
  2. Lambda layer detection (line 470+): Similarly checks only the last node of each function instead of all graph nodes.

Both checks are now inside the function loop but after the inner node loop, with shadowed variable bindings.

Fix: Rename the inner loop variable to avoid shadowing (e.g., fn_node or func_node_def), or restructure to run these checks on the outer graph_def.node loop where they belong.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelaudit/scanners/tf_savedmodel_scanner.py` around lines 385 - 421, The
inner loop over function.node_def is shadowing the outer graph_def.node variable
(both use "node"), causing subsequent StatefulPartitionedCall and Lambda-layer
checks to run against the last function node instead of graph nodes; fix by
renaming the inner loop variable (e.g., change for node in function.node_def to
for fn_node in function.node_def) and update all uses inside that inner loop
(the op_counts increment, _FUNCTION_LIBRARY_BENIGN_IO_OPS check, PYTHON_OPS
branch that calls self._check_python_op, the DANGEROUS_TF_OPERATIONS branch and
result.add_check/get_tf_op_explanation) so that the outer "node" remains intact
for the later StatefulPartitionedCall and Lambda detection logic which must
iterate over graph_def.node.

# Check for StatefulPartitionedCall which can contain custom functions
if node.op == "StatefulPartitionedCall" and hasattr(node, "attr") and "f" in node.attr:
# These operations can contain arbitrary functions
Expand Down Expand Up @@ -449,7 +506,13 @@ def _analyze_saved_model(self, saved_model: Any, result: ScanResult) -> None:
# Enhanced protobuf vulnerability scanning
self._scan_protobuf_vulnerabilities(saved_model, result)

def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> None:
def _check_python_op(
self,
node: Any,
result: ScanResult,
meta_graph: Any,
location_label: str | None = None,
) -> None:
"""Check PyFunc/PyCall operations for embedded Python code"""
# PyFunc and PyCall can embed Python code in various ways:
# 1. As a string attribute containing Python code
Expand All @@ -458,6 +521,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No

code_found = False
python_code = None
node_location = f"{self.current_file_path} ({location_label or f'node: {node.name}'})"

# Try to extract Python code from node attributes
if hasattr(node, "attr"):
Expand Down Expand Up @@ -491,7 +555,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No
passed=False,
message=f"{node.op} operation references dangerous function: {func_name}",
severity=IssueSeverity.CRITICAL,
location=f"{self.current_file_path} (node: {node.name})",
location=node_location,
rule_code="S902",
details={
"op_type": node.op,
Expand Down Expand Up @@ -523,7 +587,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No
passed=False,
message=issue_msg,
severity=severity,
location=f"{self.current_file_path} (node: {node.name})",
location=node_location,
rule_code="S902",
details={
"op_type": node.op,
Expand All @@ -545,7 +609,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No
message=f"{node.op} operation contains suspicious data (possibly obfuscated code)",
rule_code="S902",
severity=IssueSeverity.CRITICAL,
location=f"{self.current_file_path} (node: {node.name})",
location=node_location,
details={
"op_type": node.op,
"node_name": node.name,
Expand All @@ -565,7 +629,7 @@ def _check_python_op(self, node: Any, result: ScanResult, meta_graph: Any) -> No
message=f"{node.op} operation detected (unable to extract Python code)",
rule_code="S902",
severity=IssueSeverity.CRITICAL,
location=f"{self.current_file_path} (node: {node.name})",
location=node_location,
details={
"op_type": node.op,
"node_name": node.name,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
¾°½ûù¦ê¨§ʌüÅúë΋k«¹ä¥®ä¢›¹ —ðՑ¿ÁŸÌ(ò𸨭ÓÄÿ†2:&33340531854789989996454843123276936653
Binary file not shown.
Binary file not shown.
Binary file not shown.
45 changes: 45 additions & 0 deletions tests/scanners/test_tf_savedmodel_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,21 @@ def test_detect_writefile_operation(tmp_path):
assert any(i.why for i in writefile_issues), "Missing explanation for WriteFile detection"


@pytest.mark.skipif(not has_tf_protos(), reason="TensorFlow protobuf stubs unavailable")
@pytest.mark.parametrize("op_name", ["EagerPyFunc", "ReadFile", "WriteFile", "ParseTensor"])
def test_detect_function_library_dangerous_ops(tmp_path, op_name: str) -> None:
    """Dangerous ops in function library node_defs must be detected.

    Builds a SavedModel whose only dangerous op lives inside the graph's
    function library (not in the top-level node list) and asserts the scanner
    still reports it as CRITICAL, with an explanation and a location label
    that identifies the containing function.
    """
    model_path = _create_test_savedmodel_with_function_op(tmp_path, op_name, f"function_lib_{op_name.lower()}")
    scanner = TensorFlowSavedModelScanner()
    result = scanner.scan(model_path)

    # Only issues whose message mentions the op under test are relevant.
    op_issues = [i for i in result.issues if i.message and op_name in i.message]
    assert op_issues, f"Expected detection for function library op {op_name}"
    assert any(i.severity == IssueSeverity.CRITICAL for i in op_issues)
    assert any(i.why for i in op_issues), f"Missing explanation for {op_name} detection"
    # Location must point into the function library, not a top-level graph node.
    assert any(i.location and "function:" in i.location for i in op_issues)

Comment on lines +219 to +232
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add missing -> None type hint.

The test function is missing the required return type annotation.

As per coding guidelines: "Use type hints -> None on all test methods."

♻️ Proposed fix
 `@pytest.mark.skipif`(not has_tf_protos(), reason="TensorFlow protobuf stubs unavailable")
 `@pytest.mark.parametrize`("op_name", ["EagerPyFunc", "ReadFile", "WriteFile", "ParseTensor"])
-def test_detect_function_library_dangerous_ops(tmp_path, op_name):
+def test_detect_function_library_dangerous_ops(tmp_path, op_name: str) -> None:
     """Dangerous ops in function library node_defs must be detected."""
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@pytest.mark.skipif(not has_tf_protos(), reason="TensorFlow protobuf stubs unavailable")
@pytest.mark.parametrize("op_name", ["EagerPyFunc", "ReadFile", "WriteFile", "ParseTensor"])
def test_detect_function_library_dangerous_ops(tmp_path, op_name):
"""Dangerous ops in function library node_defs must be detected."""
model_path = _create_test_savedmodel_with_function_op(tmp_path, op_name, f"function_lib_{op_name.lower()}")
scanner = TensorFlowSavedModelScanner()
result = scanner.scan(model_path)
op_issues = [i for i in result.issues if i.message and op_name in i.message]
assert op_issues, f"Expected detection for function library op {op_name}"
assert any(i.severity == IssueSeverity.CRITICAL for i in op_issues)
assert any(i.why for i in op_issues), f"Missing explanation for {op_name} detection"
assert any(i.location and "function:" in i.location for i in op_issues)
def test_detect_function_library_dangerous_ops(tmp_path: Path, op_name: str) -> None:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/scanners/test_tf_savedmodel_scanner.py` around lines 219 - 232, The
test function test_detect_function_library_dangerous_ops is missing the required
return type annotation; update its definition to include the type hint "-> None"
(i.e., def test_detect_function_library_dangerous_ops(tmp_path, op_name) ->
None:) so the test signature follows the project's guideline for test methods.


@pytest.mark.skipif(not has_tensorflow(), reason="TensorFlow not installed")
def test_tf_savedmodel_scanner_with_blacklist(tmp_path):
"""Test TensorFlow SavedModel scanner with custom blacklist patterns."""
Expand Down Expand Up @@ -312,6 +327,36 @@ def _create_test_savedmodel_with_ops(tmp_path, op_names, model_name=None):
return str(model_dir)


def _create_test_savedmodel_with_function_op(tmp_path, op_name: str, model_name: str | None = None) -> str:
    """Create a SavedModel with an operation only inside the function library.

    Args:
        tmp_path: pytest ``tmp_path`` fixture directory to create the model under.
        op_name: TensorFlow op type to embed in a function-library node_def.
        model_name: Optional directory name; derived from ``op_name`` when omitted.

    Returns:
        Path (as a string) to the created SavedModel directory.
    """
    from tensorflow.core.framework.function_pb2 import FunctionDef
    from tensorflow.core.framework.node_def_pb2 import NodeDef
    from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

    if model_name is None:
        model_name = f"test_model_function_{op_name.lower()}"

    model_dir = tmp_path / model_name
    model_dir.mkdir()

    # Minimal SavedModel: one meta graph tagged "serve" whose graph has no
    # top-level nodes — the dangerous op exists ONLY in the function library,
    # so detection must come from the function-library scanning path.
    saved_model = SavedModel()
    meta_graph = saved_model.meta_graphs.add()
    meta_graph.meta_info_def.tags.append("serve")

    function = FunctionDef()
    function.signature.name = f"dangerous_fn_{op_name.lower()}"
    function.node_def.extend([NodeDef(name=f"fn_node_{op_name.lower()}", op=op_name)])
    meta_graph.graph_def.library.function.extend([function])

    saved_model_path = model_dir / "saved_model.pb"
    saved_model_path.write_bytes(saved_model.SerializeToString())

    # Empty variables/ directory so the layout matches a real SavedModel export.
    variables_dir = model_dir / "variables"
    variables_dir.mkdir()

    return str(model_dir)

Comment on lines +330 to +358
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add missing -> str return type hint.

The helper function returns a string but lacks the type annotation.

As per coding guidelines: "Always include type hints in Python code."

♻️ Proposed fix
-def _create_test_savedmodel_with_function_op(tmp_path, op_name, model_name=None):
+def _create_test_savedmodel_with_function_op(tmp_path, op_name: str, model_name: str | None = None) -> str:
     """Create a SavedModel with operation only inside function library."""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/scanners/test_tf_savedmodel_scanner.py` around lines 330 - 358, The
helper function _create_test_savedmodel_with_function_op returns a string but
lacks a return type annotation; update its signature to include "-> str" (i.e.,
def _create_test_savedmodel_with_function_op(tmp_path, op_name, model_name=None)
-> str:) so the function is explicitly annotated as returning a string; no other
behavior changes needed.


@pytest.mark.skipif(not has_tensorflow(), reason="TensorFlow not installed")
def test_tf_scanner_explanations_for_all_suspicious_ops(tmp_path):
"""Test that all suspicious TensorFlow operations generate explanations."""
Expand Down
Loading