
fix: detect risky import-only pickle ML surfaces#696

Open
mldangelo wants to merge 2 commits into main from feat/pickle-risky-ml-imports

Conversation


@mldangelo mldangelo commented Mar 13, 2026

Summary

  • flag risky import-only pickle references for torch.jit, torch._dynamo, torch._inductor, torch.compile, torch.storage._load_from_bytes, numpy.f2py, and numpy.distutils
  • reuse the risky-ML policy across GLOBAL, STACK_GLOBAL, and REDUCE-associated callable handling
  • resolve dotted import explanations for exact refs like torch.compile and add regressions for safe reconstruction and state-dict-style payloads
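
The policy the bullets above describe can be sketched as a small matcher. The constant and function names below follow the PR summary, but this is a hypothetical simplification of the real code in modelaudit/scanners/pickle_scanner.py, which also handles parent/child pairs and dotted refs:

```python
# Sketch of the risky import-only policy: a module-prefix list plus
# exact (module, attribute) references. Names mirror the PR summary
# but are illustrative, not the actual scanner source.
RISKY_ML_MODULE_PREFIXES = (
    "torch.jit",
    "torch._dynamo",
    "torch._inductor",
    "numpy.f2py",
    "numpy.distutils",
)
RISKY_ML_EXACT_REFS = {
    ("torch", "compile"),
    ("torch.storage", "_load_from_bytes"),
}

def is_risky_ml_import(mod: str, func: str) -> bool:
    """Return True when a pickled GLOBAL/STACK_GLOBAL ref matches the policy."""
    if (mod, func) in RISKY_ML_EXACT_REFS:
        return True
    # Prefix match covers submodules, e.g. torch.jit._script
    return any(
        mod == prefix or mod.startswith(prefix + ".")
        for prefix in RISKY_ML_MODULE_PREFIXES
    )

print(is_risky_ml_import("torch.jit", "script"))          # risky prefix
print(is_risky_ml_import("torch", "compile"))             # risky exact ref
print(is_risky_ml_import("collections", "OrderedDict"))   # safe
```

The later review comments point out that a single prefix check like this misses dotted qualnames such as ("torch", "jit.script"), which is exactly the gap the second review pass flags.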

Testing

  • uv run ruff format modelaudit/ tests/
  • uv run ruff check --fix modelaudit/ tests/
  • uv run ruff check modelaudit/ tests/
  • uv run ruff format --check modelaudit/ tests/
  • uv run mypy modelaudit/
  • uv run pytest tests/scanners/test_pickle_scanner.py -q -k "risky_ml or state_dict or why"
  • uv run pytest tests/test_why_explanations.py -q
  • uv run pytest -n auto -m "not slow and not integration" --maxfail=1

QA

  • manual text-output scan of a torch.compile import-only pickle produced critical findings with the exact risky-ML explanation
  • manual JSON-output scan of a safe state-dict-style payload remained clean; the only pickle-specific checks were passing collections.OrderedDict global/REDUCE validations
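
The import-only payload from the first QA check can be reproduced and inspected without ever unpickling it. A minimal sketch, assuming only the stdlib pickletools module, disassembles a protocol-2 pickle whose sole content is a GLOBAL reference to torch.compile:

```python
import io
import pickletools

# Import-only pickle: PROTO 2, a single GLOBAL opcode for torch.compile,
# then STOP. No REDUCE means nothing is ever called at load time.
payload = b"\x80\x02ctorch\ncompile\n."

out = io.StringIO()
pickletools.dis(payload, out=out)  # static disassembly; imports nothing
print(out.getvalue())
```

pickletools.dis only decodes opcodes, so this is safe to run even on hostile samples, unlike pickle.loads.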

Summary by CodeRabbit

  • New Features

    • Enhanced security scanning to flag risky ML framework import patterns in pickled artifacts while preserving safe model state reconstruction paths.
  • Tests

    • Added extensive tests validating detection across various pickle opcodes and reuse/memoization scenarios.
  • Documentation

    • Updated changelog with the security improvement note.


coderabbitai bot commented Mar 13, 2026

Walkthrough

Adds security detection for risky Torch and NumPy import-only references (e.g., torch.jit, torch._dynamo, torch._inductor, torch.compile, torch.storage._load_from_bytes, numpy.f2py, numpy.distutils), updates import-explanation lookup to consider parent modules, and expands pickle scanner checks and tests to flag those risky ML refs while preserving safe reconstruction patterns.

Changes

Cohort / File(s) — Summary

  • Changelog & Config (CHANGELOG.md, modelaudit/config/explanations.py): adds a security note and new dangerous-import explanations for specific Torch/NumPy modules; expands get_import_explanation to check parent module paths for matching explanations.
  • Pickle scanner logic (modelaudit/scanners/pickle_scanner.py): adds risky-ML policy constants (RISKY_ML_MODULE_PREFIXES, RISKY_ML_EXACT_REFS, RISKY_ML_PARENT_CHILD_REFS), the helpers _split_parent_child_ref and _is_risky_ml_import(), integrates risky-ML checks into GLOBAL/REDUCE/STACK_GLOBAL detection, and switches to fully-qualified import-explanation formatting.
  • Tests (tests/scanners/test_pickle_scanner.py, tests/test_why_explanations.py): introduces extensive tests covering GLOBAL, REDUCE, STACK_GLOBAL, memoization, safe reconstruction patterns, and explicit explanation checks for the new risky ML import references.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐇 I nibbled bytes and sniffed the trail,
Found torch and f2py in a hidden mail,
I flagged the risky hops with care,
Let safe state-dicts bounce right there,
Now pickles dance without a scare. 🥕

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check — ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title directly and clearly describes the main change: detecting risky import-only pickle references for ML-related modules.
  • Docstring Coverage — ✅ Passed: docstring coverage is 88.00%, above the required 80.00% threshold.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelaudit/scanners/pickle_scanner.py`:
- Around line 547-560: The current _is_risky_ml_import only checks
RISKY_ML_MODULE_PREFIXES against the full module string and misses parent-module
+ attribute forms like (mod="torch", func="jit") or stack globals like
["numpy","distutils"]; update _is_risky_ml_import to also detect risky
parent/child combos by checking if (mod, func) matches a risky pair (derive by
splitting each prefix in RISKY_ML_MODULE_PREFIXES and normalizing) or if (mod,
func) equals any tuple in RISKY_ML_EXACT_REFS, and ensure ML_SAFE_GLOBALS
exceptions are applied only after this parent-attribute check; add unit tests
covering GLOBAL torch jit, GLOBAL torch _dynamo, STACK_GLOBAL
["numpy","distutils"] and corresponding benign cases so regressions are caught.

In `@tests/scanners/test_pickle_scanner.py`:
- Around line 1309-1330: The helper _short_binunicode is duplicated in
test_risky_ml_memoized_stack_global_reuse_is_detected and
test_risky_ml_stack_global_detection; extract it to a single module-level helper
(e.g., define _short_binunicode once near the top of the test file after
imports) and remove the inner definitions from both test functions, then update
the tests to call the shared _short_binunicode; ensure the helper signature and
behavior remain identical so Pickle payload construction in both tests is
unchanged.

In `@tests/test_why_explanations.py`:
- Around line 73-77: The test test_explanations_for_specific_risky_ml_imports
currently only checks get_import_explanation(...) is not None and will pass if a
generic "torch" fallback is returned; update the assertions to verify
ref-specific content by calling get_import_explanation("torch.compile"),
get_import_explanation("torch._dynamo.optimize"), and
get_import_explanation("torch.storage._load_from_bytes") and asserting the
returned string contains expected specific tokens (e.g., "compile", "dynamo",
"storage._load_from_bytes" or the precise check/issue names used by your
explanations) or explicitly asserting it is not equal to the generic torch
fallback string, so the test fails if resolution regresses to the base-module
fallback.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: ASSERTIVE

Plan: Pro

Run ID: f6e5dbf5-a8ff-4c2d-a380-c9d6e5e273ae

📥 Commits

Reviewing files that changed from the base of the PR and between ea7c6d9 and 938a3dd.

📒 Files selected for processing (5)
  • CHANGELOG.md
  • modelaudit/config/explanations.py
  • modelaudit/scanners/pickle_scanner.py
  • tests/scanners/test_pickle_scanner.py
  • tests/test_why_explanations.py

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
modelaudit/scanners/pickle_scanner.py (1)

557-560: ⚠️ Potential issue | 🔴 Critical

Dotted qualnames still evade the risky-ML matcher.

Direct parent/child refs are covered now, but protocol-4+ pickle can also resolve dotted name values under the imported module. Shapes like GLOBAL torch jit.script, GLOBAL torch _dynamo.optimize, or GLOBAL torch storage._load_from_bytes still miss _is_risky_ml_import() because only mod is prefix-matched, so the Step 0 GLOBAL/STACK_GLOBAL/REDUCE checks can still be bypassed.

🔧 Suggested fix
-RISKY_ML_EXACT_REFS: set[tuple[str, str]] = {
-    ("torch", "compile"),
-    ("torch.storage", "_load_from_bytes"),
-}
+RISKY_ML_EXACT_REFS: set[str] = {
+    "torch.compile",
+    "torch.storage._load_from_bytes",
+}
@@
 def _is_risky_ml_import(mod: str, func: str) -> bool:
     """Return True when module/function matches risky ML import policy."""
-    if (mod, func) in RISKY_ML_EXACT_REFS:
+    full_ref = f"{mod}.{func}" if func else mod
+    if full_ref in RISKY_ML_EXACT_REFS:
         return True
     if (mod, func) in RISKY_ML_PARENT_CHILD_REFS:
         return True
-    return any(mod == prefix or mod.startswith(f"{prefix}.") for prefix in RISKY_ML_MODULE_PREFIXES)
+    return any(
+        candidate == prefix or candidate.startswith(f"{prefix}.")
+        for candidate in (mod, full_ref)
+        for prefix in RISKY_ML_MODULE_PREFIXES
+    )

Run this to verify the gap and the underlying pickle behavior:

#!/bin/bash
set -euo pipefail

echo "== Current matcher =="
rg -n -A8 -B4 'RISKY_ML_MODULE_PREFIXES|RISKY_ML_EXACT_REFS|def _is_risky_ml_import' modelaudit/scanners/pickle_scanner.py

echo
echo "== CPython pickle dotted-name resolution =="
python - <<'PY'
import inspect
import pathlib
import pickle

source_file = inspect.getsourcefile(pickle._Unpickler)
print(source_file)
if source_file:
    text = pathlib.Path(source_file).read_text()
    marker = "def find_class"
    start = text.find(marker)
    if start >= 0:
        print(text[start:start + 900])
PY

echo
echo "== Current logic on dotted risky refs =="
python - <<'PY'
RISKY_ML_MODULE_PREFIXES = (
    "torch.jit",
    "torch._dynamo",
    "torch._inductor",
    "numpy.f2py",
    "numpy.distutils",
)
RISKY_ML_EXACT_REFS = {
    ("torch", "compile"),
    ("torch.storage", "_load_from_bytes"),
}

def _split_parent_child_ref(prefix: str) -> tuple[str, str]:
    parent, _separator, child = prefix.rpartition(".")
    return parent, child

RISKY_ML_PARENT_CHILD_REFS = frozenset(
    _split_parent_child_ref(prefix) for prefix in RISKY_ML_MODULE_PREFIXES
)

def current(mod: str, func: str) -> bool:
    if (mod, func) in RISKY_ML_EXACT_REFS:
        return True
    if (mod, func) in RISKY_ML_PARENT_CHILD_REFS:
        return True
    return any(mod == prefix or mod.startswith(f"{prefix}.") for prefix in RISKY_ML_MODULE_PREFIXES)

for case in [
    ("torch", "jit.script"),
    ("torch", "_dynamo.optimize"),
    ("torch", "storage._load_from_bytes"),
]:
    print(f"{case}: {current(*case)}")
PY

Expected: the stdlib source shows dotted-name handling in find_class, and all three cases print False. As per coding guidelines, "Preserve or strengthen security detections; test both benign and malicious samples when adding scanner/feature changes".

Also applies to: 1822-1828
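
The dotted-qualname behavior this comment relies on can be demonstrated with a benign stdlib target instead of torch. Under protocol 4+, find_class resolves dotted names with a getattr chain, so STACK_GLOBAL ("os", "path.join") loads os.path.join even though "path.join" is not a plain attribute name:

```python
import os
import pickle

def short_binunicode(data: bytes) -> bytes:
    # SHORT_BINUNICODE: opcode 0x8c, 1-byte length, UTF-8 payload
    return b"\x8c" + bytes([len(data)]) + data

payload = bytearray(b"\x80\x04")           # PROTO 4
payload += short_binunicode(b"os")         # module (pushed first)
payload += short_binunicode(b"path.join")  # dotted qualname
payload += b"\x93"                         # STACK_GLOBAL
payload += b"."                            # STOP

obj = pickle.loads(bytes(payload))
print(obj is os.path.join)  # True: dotted names resolve under proto 4
```

This is why a scanner that prefix-matches only the module string can be bypassed: the attacker moves the risky path segments into the name half of the pair.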

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelaudit/scanners/pickle_scanner.py` around lines 557 - 560, The matcher
misses dotted qualnames (e.g., GLOBAL "torch" "jit.script" or "torch"
"storage._load_from_bytes") because _is_risky_ml_import only checks (mod, func)
and module-prefixes; update _is_risky_ml_import (and related checks that use
RISKY_ML_EXACT_REFS / RISKY_ML_PARENT_CHILD_REFS / RISKY_ML_MODULE_PREFIXES) to
also handle dotted func values by splitting func on the first '.' and: 1)
treating the leftmost token as a possible submodule (new_mod = f"{mod}.{first}")
and checking (new_mod, rest) against RISKY_ML_EXACT_REFS and parent/child refs,
and 2) treating the leftmost token as a child name for parent-child checks
(i.e., check (mod, first) against RISKY_ML_PARENT_CHILD_REFS and also check
module-prefix matching for f"{mod}.{first}" with RISKY_ML_MODULE_PREFIXES); this
ensures cases like ("torch","jit.script"), ("torch","_dynamo.optimize"), and
("torch","storage._load_from_bytes") are detected.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/scanners/test_pickle_scanner.py`:
- Around line 1259-1265: The parametrized test only includes deeper refs
(torch.jit, torch._dynamo, numpy.distutils); add explicit parent-import cases
for the bare modules by adding ("torch", "_inductor",
b"\x80\x02ctorch\n_inductor\n.") and ("numpy", "f2py",
b"\x80\x02cnumpy\nf2py\n.") to the pytest.mark.parametrize tuple so GLOBAL-style
imports are exercised; ensure these new entries follow the same benign/malicious
payload patterns used by the existing cases so both detection and non-detection
behavior for functions like the test function that consumes
module_name/func_name/payload is validated.
- Around line 1345-1369: The test
test_risky_ml_parent_attribute_stack_global_detection currently only asserts
that the STACK_GLOBAL parent/attribute check failed (CheckStatus.FAILED) which
would miss regressions that downgrade the detection to WARNING; update the
assertions to require the check status be CRITICAL instead of FAILED by locating
the failing_checks logic that filters result.checks for "STACK_GLOBAL Module
Check" and change the expectation to CheckStatus.CRITICAL, and also update the
subsequent issue assertion to ensure the issue severity or message reflects a
CRITICAL-level detection for numpy.distutils so the test fails if the check is
downgraded.
- Around line 1372-1390: The test currently flags the original STACK_GLOBAL
instead of the BINGET recall, so change the constructed payload in
test_risky_ml_memoized_stack_global_reuse_is_detected to ensure the detection
must come from the memo recall: create a harmless STACK_GLOBAL first (so it
won't be flagged), then emit the risky STACK_GLOBAL and MEMOIZE it into index 0,
POP the stack, then BINGET 0 and STOP so the scanner must detect the risky
import when the memoized entry is recalled (update the payload byte sequence in
the payload variable accordingly).

---

Duplicate comments:
In `@modelaudit/scanners/pickle_scanner.py`:
- Around line 557-560: The matcher misses dotted qualnames (e.g., GLOBAL "torch"
"jit.script" or "torch" "storage._load_from_bytes") because _is_risky_ml_import
only checks (mod, func) and module-prefixes; update _is_risky_ml_import (and
related checks that use RISKY_ML_EXACT_REFS / RISKY_ML_PARENT_CHILD_REFS /
RISKY_ML_MODULE_PREFIXES) to also handle dotted func values by splitting func on
the first '.' and: 1) treating the leftmost token as a possible submodule
(new_mod = f"{mod}.{first}") and checking (new_mod, rest) against
RISKY_ML_EXACT_REFS and parent/child refs, and 2) treating the leftmost token as
a child name for parent-child checks (i.e., check (mod, first) against
RISKY_ML_PARENT_CHILD_REFS and also check module-prefix matching for
f"{mod}.{first}" with RISKY_ML_MODULE_PREFIXES); this ensures cases like
("torch","jit.script"), ("torch","_dynamo.optimize"), and
("torch","storage._load_from_bytes") are detected.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: ASSERTIVE

Plan: Pro

Run ID: edf07cff-9afb-4cc2-9165-c50775b203f3

📥 Commits

Reviewing files that changed from the base of the PR and between 938a3dd and c4b0bc8.

📒 Files selected for processing (3)
  • modelaudit/scanners/pickle_scanner.py
  • tests/scanners/test_pickle_scanner.py
  • tests/test_why_explanations.py

Comment on lines +1259 to +1265
@pytest.mark.parametrize(
    ("module_name", "func_name", "payload"),
    [
        ("torch", "jit", b"\x80\x02ctorch\njit\n."),
        ("torch", "_dynamo", b"\x80\x02ctorch\n_dynamo\n."),
        ("numpy", "distutils", b"\x80\x02cnumpy\ndistutils\n."),
    ],

🧹 Nitpick | 🔵 Trivial

Add bare torch._inductor and numpy.f2py parent-import cases.

The PR adds those import-only refs, but this parametrization only exercises torch.jit, torch._dynamo, and numpy.distutils. A regression in GLOBAL torch\n_inductor\n. or GLOBAL numpy\nf2py\n. would still pass because the suite only covers deeper refs for those families elsewhere.

➕ Suggested additions
     [
         ("torch", "jit", b"\x80\x02ctorch\njit\n."),
         ("torch", "_dynamo", b"\x80\x02ctorch\n_dynamo\n."),
+        ("torch", "_inductor", b"\x80\x02ctorch\n_inductor\n."),
+        ("numpy", "f2py", b"\x80\x02cnumpy\nf2py\n."),
         ("numpy", "distutils", b"\x80\x02cnumpy\ndistutils\n."),
     ],
Based on learnings, "Preserve or strengthen security detections; test both benign and malicious samples when adding scanner/feature changes".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/scanners/test_pickle_scanner.py` around lines 1259 - 1265, The
parametrized test only includes deeper refs (torch.jit, torch._dynamo,
numpy.distutils); add explicit parent-import cases for the bare modules by
adding ("torch", "_inductor", b"\x80\x02ctorch\n_inductor\n.") and ("numpy",
"f2py", b"\x80\x02cnumpy\nf2py\n.") to the pytest.mark.parametrize tuple so
GLOBAL-style imports are exercised; ensure these new entries follow the same
benign/malicious payload patterns used by the existing cases so both detection
and non-detection behavior for functions like the test function that consumes
module_name/func_name/payload is validated.

Comment on lines +1345 to +1369
def test_risky_ml_parent_attribute_stack_global_detection(tmp_path: Path) -> None:
    """STACK_GLOBAL parent/attribute refs should trigger the risky ML policy."""
    payload = bytearray(b"\x80\x04")
    payload += _short_binunicode(b"numpy")
    payload += _short_binunicode(b"distutils")
    payload += b"\x93"  # STACK_GLOBAL
    payload += b"."  # STOP

    path = tmp_path / "numpy_distutils_stack_global.pkl"
    path.write_bytes(payload)

    result = PickleScanner().scan(str(path))
    failing_checks = [
        check
        for check in result.checks
        if check.name == "STACK_GLOBAL Module Check"
        and check.status == CheckStatus.FAILED
        and check.details.get("module") == "numpy"
        and check.details.get("function") == "distutils"
    ]

    assert failing_checks, f"Expected STACK_GLOBAL numpy.distutils detection. Checks: {result.checks}"
    assert any("numpy.distutils" in issue.message for issue in result.issues), (
        f"Expected numpy.distutils issue. Issues: {[issue.message for issue in result.issues]}"
    )

🧹 Nitpick | 🔵 Trivial

Assert CRITICAL on the parent/attribute STACK_GLOBAL path too.

This branch is exercising the new risky-parent STACK_GLOBAL policy, but it only checks for FAILED. A regression that downgrades numpy.distutils to WARNING would still keep this test green.

➕ Suggested assertion
     assert failing_checks, f"Expected STACK_GLOBAL numpy.distutils detection. Checks: {result.checks}"
+    assert all(check.severity == IssueSeverity.CRITICAL for check in failing_checks), (
+        f"Expected CRITICAL STACK_GLOBAL finding for numpy.distutils, got: "
+        f"{[(check.severity, check.message) for check in failing_checks]}"
+    )
     assert any("numpy.distutils" in issue.message for issue in result.issues), (
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/scanners/test_pickle_scanner.py` around lines 1345 - 1369, The test
test_risky_ml_parent_attribute_stack_global_detection currently only asserts
that the STACK_GLOBAL parent/attribute check failed (CheckStatus.FAILED) which
would miss regressions that downgrade the detection to WARNING; update the
assertions to require the check status be CRITICAL instead of FAILED by locating
the failing_checks logic that filters result.checks for "STACK_GLOBAL Module
Check" and change the expectation to CheckStatus.CRITICAL, and also update the
subsequent issue assertion to ensure the issue severity or message reflects a
CRITICAL-level detection for numpy.distutils so the test fails if the check is
downgraded.

Comment on lines +1372 to +1390
def test_risky_ml_memoized_stack_global_reuse_is_detected(tmp_path: Path) -> None:
    """Memoized risky STACK_GLOBAL references should remain detectable on recall."""

    payload = bytearray(b"\x80\x04")
    payload += _short_binunicode(b"torch._dynamo")
    payload += _short_binunicode(b"optimize")
    payload += b"\x93"  # STACK_GLOBAL
    payload += b"\x94"  # MEMOIZE index 0
    payload += b"0"  # POP
    payload += b"h\x00"  # BINGET 0
    payload += b"."  # STOP

    path = tmp_path / "torch_dynamo_memo.pkl"
    path.write_bytes(payload)

    result = PickleScanner().scan(str(path))
    assert any("torch._dynamo.optimize" in issue.message for issue in result.issues), (
        f"Expected memoized risky import detection. Issues: {[i.message for i in result.issues]}"
    )

🧹 Nitpick | 🔵 Trivial

Make the recalled memo entry drive the detection.

Right now the dangerous path is still the original STACK_GLOBAL; the test stops immediately after BINGET 0. If memoized reuse resolution breaks, this can still pass because the first STACK_GLOBAL already produced the issue.

➕ Stronger regression shape
     payload += b"\x93"  # STACK_GLOBAL
     payload += b"\x94"  # MEMOIZE index 0
     payload += b"0"  # POP
     payload += b"h\x00"  # BINGET 0
-    payload += b"."  # STOP
+    payload += b")R."  # EMPTY_TUPLE + REDUCE + STOP
@@
-    assert any("torch._dynamo.optimize" in issue.message for issue in result.issues), (
-        f"Expected memoized risky import detection. Issues: {[i.message for i in result.issues]}"
-    )
+    reduce_checks = [
+        check
+        for check in result.checks
+        if check.name == "REDUCE Opcode Safety Check"
+        and check.status == CheckStatus.FAILED
+        and check.details.get("associated_global") == "torch._dynamo.optimize"
+    ]
+    assert reduce_checks, (
+        f"Expected BINGET-resolved REDUCE detection for torch._dynamo.optimize. Checks: {result.checks}"
+    )
Based on learnings, "Preserve or strengthen security detections; test both benign and malicious samples when adding scanner/feature changes".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/scanners/test_pickle_scanner.py` around lines 1372 - 1390, The test
currently flags the original STACK_GLOBAL instead of the BINGET recall, so
change the constructed payload in
test_risky_ml_memoized_stack_global_reuse_is_detected to ensure the detection
must come from the memo recall: create a harmless STACK_GLOBAL first (so it
won't be flagged), then emit the risky STACK_GLOBAL and MEMOIZE it into index 0,
POP the stack, then BINGET 0 and STOP so the scanner must detect the risky
import when the memoized entry is recalled (update the payload byte sequence in
the payload variable accordingly).
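
The memoized payload under discussion can be checked byte-for-byte with a static disassembly; this sketch only inspects the opcode stream and never unpickles it, so it is safe to run without torch installed:

```python
import io
import pickletools

def short_binunicode(data: bytes) -> bytes:
    # SHORT_BINUNICODE: opcode 0x8c, 1-byte length, UTF-8 payload
    return b"\x8c" + bytes([len(data)]) + data

payload = bytearray(b"\x80\x04")                 # PROTO 4
payload += short_binunicode(b"torch._dynamo")
payload += short_binunicode(b"optimize")
payload += b"\x93"   # STACK_GLOBAL
payload += b"\x94"   # MEMOIZE -> memo[0]
payload += b"0"      # POP
payload += b"h\x00"  # BINGET 0 (recall the memoized global)
payload += b"."      # STOP

out = io.StringIO()
pickletools.dis(bytes(payload), out=out)
text = out.getvalue()
print("MEMOIZE" in text and "BINGET" in text)
```

Disassembling the payload confirms the MEMOIZE/POP/BINGET shape the review comment wants the scanner to resolve, independent of whatever detection the scanner actually reports.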
