fix: detect risky import-only pickle ML surfaces #696
Conversation
Walkthrough

Adds security detection for risky Torch and NumPy import-only references.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelaudit/scanners/pickle_scanner.py`:
- Around lines 547-560: The current _is_risky_ml_import only checks
RISKY_ML_MODULE_PREFIXES against the full module string and misses parent-module
attribute forms like (mod="torch", func="jit") or stack globals like
["numpy", "distutils"]. Update _is_risky_ml_import to also detect risky
parent/child combos by checking whether (mod, func) matches a risky pair
(derived by splitting each prefix in RISKY_ML_MODULE_PREFIXES and normalizing)
or equals any tuple in RISKY_ML_EXACT_REFS, and ensure ML_SAFE_GLOBALS
exceptions are applied only after this parent-attribute check. Add unit tests
covering GLOBAL torch jit, GLOBAL torch _dynamo, STACK_GLOBAL
["numpy", "distutils"], and corresponding benign cases so regressions are caught.
In `@tests/scanners/test_pickle_scanner.py`:
- Around line 1309-1330: The helper _short_binunicode is duplicated in
test_risky_ml_memoized_stack_global_reuse_is_detected and
test_risky_ml_stack_global_detection; extract it to a single module-level helper
(e.g., define _short_binunicode once near the top of the test file after
imports) and remove the inner definitions from both test functions, then update
the tests to call the shared _short_binunicode; ensure the helper signature and
behavior remain identical so Pickle payload construction in both tests is
unchanged.
In `@tests/test_why_explanations.py`:
- Around line 73-77: The test test_explanations_for_specific_risky_ml_imports
currently only checks get_import_explanation(...) is not None and will pass if a
generic "torch" fallback is returned; update the assertions to verify
ref-specific content by calling get_import_explanation("torch.compile"),
get_import_explanation("torch._dynamo.optimize"), and
get_import_explanation("torch.storage._load_from_bytes") and asserting the
returned string contains expected specific tokens (e.g., "compile", "dynamo",
"storage._load_from_bytes" or the precise check/issue names used by your
explanations) or explicitly asserting it is not equal to the generic torch
fallback string, so the test fails if resolution regresses to the base-module
fallback.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: ASSERTIVE
Plan: Pro
Run ID: f6e5dbf5-a8ff-4c2d-a380-c9d6e5e273ae
📒 Files selected for processing (5)
- CHANGELOG.md
- modelaudit/config/explanations.py
- modelaudit/scanners/pickle_scanner.py
- tests/scanners/test_pickle_scanner.py
- tests/test_why_explanations.py
Actionable comments posted: 3
♻️ Duplicate comments (1)
modelaudit/scanners/pickle_scanner.py (1)
557-560: ⚠️ Potential issue | 🔴 Critical: Dotted qualnames still evade the risky-ML matcher.
Direct parent/child refs are covered now, but protocol-4+ pickle can also resolve dotted name values under the imported module. Shapes like `GLOBAL torch jit.script`, `GLOBAL torch _dynamo.optimize`, or `GLOBAL torch storage._load_from_bytes` still miss `_is_risky_ml_import()` because only `mod` is prefix-matched, so the Step 0 GLOBAL/STACK_GLOBAL/REDUCE checks can still be bypassed.

🔧 Suggested fix
```diff
-RISKY_ML_EXACT_REFS: set[tuple[str, str]] = {
-    ("torch", "compile"),
-    ("torch.storage", "_load_from_bytes"),
-}
+RISKY_ML_EXACT_REFS: set[str] = {
+    "torch.compile",
+    "torch.storage._load_from_bytes",
+}
@@
 def _is_risky_ml_import(mod: str, func: str) -> bool:
     """Return True when module/function matches risky ML import policy."""
-    if (mod, func) in RISKY_ML_EXACT_REFS:
+    full_ref = f"{mod}.{func}" if func else mod
+    if full_ref in RISKY_ML_EXACT_REFS:
         return True
     if (mod, func) in RISKY_ML_PARENT_CHILD_REFS:
         return True
-    return any(mod == prefix or mod.startswith(f"{prefix}.") for prefix in RISKY_ML_MODULE_PREFIXES)
+    return any(
+        candidate == prefix or candidate.startswith(f"{prefix}.")
+        for candidate in (mod, full_ref)
+        for prefix in RISKY_ML_MODULE_PREFIXES
+    )
```

Run this to verify the gap and the underlying pickle behavior:
```bash
#!/bin/bash
set -euo pipefail

echo "== Current matcher =="
rg -n -A8 -B4 'RISKY_ML_MODULE_PREFIXES|RISKY_ML_EXACT_REFS|def _is_risky_ml_import' modelaudit/scanners/pickle_scanner.py
echo

echo "== CPython pickle dotted-name resolution =="
python - <<'PY'
import inspect
import pathlib
import pickle

source_file = inspect.getsourcefile(pickle._Unpickler)
print(source_file)
if source_file:
    text = pathlib.Path(source_file).read_text()
    marker = "def find_class"
    start = text.find(marker)
    if start >= 0:
        print(text[start:start + 900])
PY
echo

echo "== Current logic on dotted risky refs =="
python - <<'PY'
RISKY_ML_MODULE_PREFIXES = (
    "torch.jit",
    "torch._dynamo",
    "torch._inductor",
    "numpy.f2py",
    "numpy.distutils",
)
RISKY_ML_EXACT_REFS = {
    ("torch", "compile"),
    ("torch.storage", "_load_from_bytes"),
}

def _split_parent_child_ref(prefix: str) -> tuple[str, str]:
    parent, _separator, child = prefix.rpartition(".")
    return parent, child

RISKY_ML_PARENT_CHILD_REFS = frozenset(
    _split_parent_child_ref(prefix) for prefix in RISKY_ML_MODULE_PREFIXES
)

def current(mod: str, func: str) -> bool:
    if (mod, func) in RISKY_ML_EXACT_REFS:
        return True
    if (mod, func) in RISKY_ML_PARENT_CHILD_REFS:
        return True
    return any(mod == prefix or mod.startswith(f"{prefix}.") for prefix in RISKY_ML_MODULE_PREFIXES)

for case in [
    ("torch", "jit.script"),
    ("torch", "_dynamo.optimize"),
    ("torch", "storage._load_from_bytes"),
]:
    print(f"{case}: {current(*case)}")
PY
```

Expected: the stdlib source shows dotted-name handling in `find_class`, and all three cases print `False`. As per coding guidelines, "Preserve or strengthen security detections; test both benign and malicious samples when adding scanner/feature changes". Also applies to: lines 1822-1828.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelaudit/scanners/pickle_scanner.py` around lines 557 - 560, The matcher misses dotted qualnames (e.g., GLOBAL "torch" "jit.script" or "torch" "storage._load_from_bytes") because _is_risky_ml_import only checks (mod, func) and module-prefixes; update _is_risky_ml_import (and related checks that use RISKY_ML_EXACT_REFS / RISKY_ML_PARENT_CHILD_REFS / RISKY_ML_MODULE_PREFIXES) to also handle dotted func values by splitting func on the first '.' and: 1) treating the leftmost token as a possible submodule (new_mod = f"{mod}.{first}") and checking (new_mod, rest) against RISKY_ML_EXACT_REFS and parent/child refs, and 2) treating the leftmost token as a child name for parent-child checks (i.e., check (mod, first) against RISKY_ML_PARENT_CHILD_REFS and also check module-prefix matching for f"{mod}.{first}" with RISKY_ML_MODULE_PREFIXES); this ensures cases like ("torch","jit.script"), ("torch","_dynamo.optimize"), and ("torch","storage._load_from_bytes") are detected.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/scanners/test_pickle_scanner.py`:
- Around line 1259-1265: The parametrized test only includes deeper refs
(torch.jit, torch._dynamo, numpy.distutils); add explicit parent-import cases
for the bare modules by adding ("torch", "_inductor",
b"\x80\x02ctorch\n_inductor\n.") and ("numpy", "f2py",
b"\x80\x02cnumpy\nf2py\n.") to the pytest.mark.parametrize tuple so GLOBAL-style
imports are exercised; ensure these new entries follow the same benign/malicious
payload patterns used by the existing cases so both detection and non-detection
behavior for functions like the test function that consumes
module_name/func_name/payload is validated.
- Around line 1345-1369: The test
test_risky_ml_parent_attribute_stack_global_detection currently only asserts
that the STACK_GLOBAL parent/attribute check failed (CheckStatus.FAILED) which
would miss regressions that downgrade the detection to WARNING; update the
assertions to require the check status be CRITICAL instead of FAILED by locating
the failing_checks logic that filters result.checks for "STACK_GLOBAL Module
Check" and change the expectation to CheckStatus.CRITICAL, and also update the
subsequent issue assertion to ensure the issue severity or message reflects a
CRITICAL-level detection for numpy.distutils so the test fails if the check is
downgraded.
- Around line 1372-1390: The test currently flags the original STACK_GLOBAL
instead of the BINGET recall, so change the constructed payload in
test_risky_ml_memoized_stack_global_reuse_is_detected to ensure the detection
must come from the memo recall: create a harmless STACK_GLOBAL first (so it
won't be flagged), then emit the risky STACK_GLOBAL and MEMOIZE it into index 0,
POP the stack, then BINGET 0 and STOP so the scanner must detect the risky
import when the memoized entry is recalled (update the payload byte sequence in
the payload variable accordingly).
---
Duplicate comments:
In `@modelaudit/scanners/pickle_scanner.py`:
- Around line 557-560: The matcher misses dotted qualnames (e.g., GLOBAL "torch"
"jit.script" or "torch" "storage._load_from_bytes") because _is_risky_ml_import
only checks (mod, func) and module-prefixes; update _is_risky_ml_import (and
related checks that use RISKY_ML_EXACT_REFS / RISKY_ML_PARENT_CHILD_REFS /
RISKY_ML_MODULE_PREFIXES) to also handle dotted func values by splitting func on
the first '.' and: 1) treating the leftmost token as a possible submodule
(new_mod = f"{mod}.{first}") and checking (new_mod, rest) against
RISKY_ML_EXACT_REFS and parent/child refs, and 2) treating the leftmost token as
a child name for parent-child checks (i.e., check (mod, first) against
RISKY_ML_PARENT_CHILD_REFS and also check module-prefix matching for
f"{mod}.{first}" with RISKY_ML_MODULE_PREFIXES); this ensures cases like
("torch","jit.script"), ("torch","_dynamo.optimize"), and
("torch","storage._load_from_bytes") are detected.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: ASSERTIVE
Plan: Pro
Run ID: edf07cff-9afb-4cc2-9165-c50775b203f3
📒 Files selected for processing (3)
- modelaudit/scanners/pickle_scanner.py
- tests/scanners/test_pickle_scanner.py
- tests/test_why_explanations.py
```python
@pytest.mark.parametrize(
    ("module_name", "func_name", "payload"),
    [
        ("torch", "jit", b"\x80\x02ctorch\njit\n."),
        ("torch", "_dynamo", b"\x80\x02ctorch\n_dynamo\n."),
        ("numpy", "distutils", b"\x80\x02cnumpy\ndistutils\n."),
    ],
```
🧹 Nitpick | 🔵 Trivial
Add bare torch._inductor and numpy.f2py parent-import cases.
The PR adds those import-only refs, but this parametrization only exercises torch.jit, torch._dynamo, and numpy.distutils. A regression in GLOBAL torch\n_inductor\n. or GLOBAL numpy\nf2py\n. would still pass because the suite only covers deeper refs for those families elsewhere.
➕ Suggested additions
```diff
     [
         ("torch", "jit", b"\x80\x02ctorch\njit\n."),
         ("torch", "_dynamo", b"\x80\x02ctorch\n_dynamo\n."),
+        ("torch", "_inductor", b"\x80\x02ctorch\n_inductor\n."),
+        ("numpy", "f2py", b"\x80\x02cnumpy\nf2py\n."),
         ("numpy", "distutils", b"\x80\x02cnumpy\ndistutils\n."),
     ],
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/scanners/test_pickle_scanner.py` around lines 1259 - 1265, The
parametrized test only includes deeper refs (torch.jit, torch._dynamo,
numpy.distutils); add explicit parent-import cases for the bare modules by
adding ("torch", "_inductor", b"\x80\x02ctorch\n_inductor\n.") and ("numpy",
"f2py", b"\x80\x02cnumpy\nf2py\n.") to the pytest.mark.parametrize tuple so
GLOBAL-style imports are exercised; ensure these new entries follow the same
benign/malicious payload patterns used by the existing cases so both detection
and non-detection behavior for functions like the test function that consumes
module_name/func_name/payload is validated.
```python
def test_risky_ml_parent_attribute_stack_global_detection(tmp_path: Path) -> None:
    """STACK_GLOBAL parent/attribute refs should trigger the risky ML policy."""
    payload = bytearray(b"\x80\x04")
    payload += _short_binunicode(b"numpy")
    payload += _short_binunicode(b"distutils")
    payload += b"\x93"  # STACK_GLOBAL
    payload += b"."  # STOP

    path = tmp_path / "numpy_distutils_stack_global.pkl"
    path.write_bytes(payload)

    result = PickleScanner().scan(str(path))
    failing_checks = [
        check
        for check in result.checks
        if check.name == "STACK_GLOBAL Module Check"
        and check.status == CheckStatus.FAILED
        and check.details.get("module") == "numpy"
        and check.details.get("function") == "distutils"
    ]

    assert failing_checks, f"Expected STACK_GLOBAL numpy.distutils detection. Checks: {result.checks}"
    assert any("numpy.distutils" in issue.message for issue in result.issues), (
        f"Expected numpy.distutils issue. Issues: {[issue.message for issue in result.issues]}"
    )
```
🧹 Nitpick | 🔵 Trivial
Assert CRITICAL on the parent/attribute STACK_GLOBAL path too.
This branch is exercising the new risky-parent STACK_GLOBAL policy, but it only checks for FAILED. A regression that downgrades numpy.distutils to WARNING would still keep this test green.
➕ Suggested assertion
```diff
     assert failing_checks, f"Expected STACK_GLOBAL numpy.distutils detection. Checks: {result.checks}"
+    assert all(check.severity == IssueSeverity.CRITICAL for check in failing_checks), (
+        f"Expected CRITICAL STACK_GLOBAL finding for numpy.distutils, got: "
+        f"{[(check.severity, check.message) for check in failing_checks]}"
+    )
     assert any("numpy.distutils" in issue.message for issue in result.issues), (
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/scanners/test_pickle_scanner.py` around lines 1345 - 1369, The test
test_risky_ml_parent_attribute_stack_global_detection currently only asserts
that the STACK_GLOBAL parent/attribute check failed (CheckStatus.FAILED) which
would miss regressions that downgrade the detection to WARNING; update the
assertions to require the check status be CRITICAL instead of FAILED by locating
the failing_checks logic that filters result.checks for "STACK_GLOBAL Module
Check" and change the expectation to CheckStatus.CRITICAL, and also update the
subsequent issue assertion to ensure the issue severity or message reflects a
CRITICAL-level detection for numpy.distutils so the test fails if the check is
downgraded.
```python
def test_risky_ml_memoized_stack_global_reuse_is_detected(tmp_path: Path) -> None:
    """Memoized risky STACK_GLOBAL references should remain detectable on recall."""
    payload = bytearray(b"\x80\x04")
    payload += _short_binunicode(b"torch._dynamo")
    payload += _short_binunicode(b"optimize")
    payload += b"\x93"  # STACK_GLOBAL
    payload += b"\x94"  # MEMOIZE index 0
    payload += b"0"  # POP
    payload += b"h\x00"  # BINGET 0
    payload += b"."  # STOP

    path = tmp_path / "torch_dynamo_memo.pkl"
    path.write_bytes(payload)

    result = PickleScanner().scan(str(path))
    assert any("torch._dynamo.optimize" in issue.message for issue in result.issues), (
        f"Expected memoized risky import detection. Issues: {[i.message for i in result.issues]}"
    )
```
🧹 Nitpick | 🔵 Trivial
Make the recalled memo entry drive the detection.
Right now the dangerous path is still the original STACK_GLOBAL; the test stops immediately after BINGET 0. If memoized reuse resolution breaks, this can still pass because the first STACK_GLOBAL already produced the issue.
➕ Stronger regression shape
```diff
     payload += b"\x93"  # STACK_GLOBAL
     payload += b"\x94"  # MEMOIZE index 0
     payload += b"0"  # POP
     payload += b"h\x00"  # BINGET 0
-    payload += b"."  # STOP
+    payload += b")R."  # EMPTY_TUPLE + REDUCE + STOP
@@
-    assert any("torch._dynamo.optimize" in issue.message for issue in result.issues), (
-        f"Expected memoized risky import detection. Issues: {[i.message for i in result.issues]}"
-    )
+    reduce_checks = [
+        check
+        for check in result.checks
+        if check.name == "REDUCE Opcode Safety Check"
+        and check.status == CheckStatus.FAILED
+        and check.details.get("associated_global") == "torch._dynamo.optimize"
+    ]
+    assert reduce_checks, (
+        f"Expected BINGET-resolved REDUCE detection for torch._dynamo.optimize. Checks: {result.checks}"
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/scanners/test_pickle_scanner.py` around lines 1372 - 1390, The test
currently flags the original STACK_GLOBAL instead of the BINGET recall, so
change the constructed payload in
test_risky_ml_memoized_stack_global_reuse_is_detected to ensure the detection
must come from the memo recall: create a harmless STACK_GLOBAL first (so it
won't be flagged), then emit the risky STACK_GLOBAL and MEMOIZE it into index 0,
POP the stack, then BINGET 0 and STOP so the scanner must detect the risky
import when the memoized entry is recalled (update the payload byte sequence in
the payload variable accordingly).
Summary
- Detect risky import-only pickle references: `torch.jit`, `torch._dynamo`, `torch._inductor`, `torch.compile`, `torch.storage._load_from_bytes`, `numpy.f2py`, and `numpy.distutils`
- Add a ref-specific explanation for `torch.compile` and add regressions for safe reconstruction and state-dict-style payloads

Testing
- `uv run ruff format modelaudit/ tests/`
- `uv run ruff check --fix modelaudit/ tests/`
- `uv run ruff check modelaudit/ tests/`
- `uv run ruff format --check modelaudit/ tests/`
- `uv run mypy modelaudit/`
- `uv run pytest tests/scanners/test_pickle_scanner.py -q -k "risky_ml or state_dict or why"`
- `uv run pytest tests/test_why_explanations.py -q`
- `uv run pytest -n auto -m "not slow and not integration" --maxfail=1`

QA
- `torch.compile` import-only pickle produced critical findings with the exact risky-ML explanation
- `collections.OrderedDict` global/REDUCE validations

Summary by CodeRabbit
New Features
Tests
Documentation