Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- **security:** add exact dangerous-global coverage for `numpy.load`, `site.main`, `_io.FileIO`, `test.support.script_helper.assert_python_ok`, `_osx_support._read_output`, `_aix_support._read_cmd_output`, `_pyrepl.pager.pipe_pager`, `torch.serialization.load`, and `torch._inductor.codecache.compile_file` (9 PickleScan-only loader and execution primitives)
- **security:** treat legacy `httplib` pickle globals the same as `http.client`, including import-only and `REDUCE` findings in standalone and archived payloads
- **security:** harden TensorFlow weight extraction limits to bound actual tensor payload materialization, including malformed `tensor_content` and string-backed tensors, and continue scanning past oversized `Const` nodes
- **security:** stream TAR members to temp files under size limits instead of buffering whole entries in memory during scan
Expand Down
38 changes: 38 additions & 0 deletions modelaudit/config/explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,42 @@
"The 'dill' module extends pickle's capabilities to serialize almost any Python object, including lambda "
"functions and code objects. This significantly increases the attack surface for code execution."
),
"numpy.load": (
"The 'numpy.load' function can recursively deserialize object arrays via pickle support, enabling "
"second-stage payload loading from attacker-controlled files."
),
"site.main": (
"The 'site.main' function executes Python startup path initialization and can trigger module-level "
"execution side effects in attacker-influenced environments."
),
"_io.FileIO": (
"The '_io.FileIO' constructor performs direct file reads and writes, enabling arbitrary local file access "
"without using higher-level safety wrappers."
),
"test.support.script_helper.assert_python_ok": (
"The 'assert_python_ok' helper launches a Python subprocess. In untrusted pickle payloads this is command "
"execution behavior, not benign test plumbing."
),
"_osx_support._read_output": (
"The '_osx_support._read_output' helper executes shell commands to capture output, enabling command "
"execution from deserialization payloads."
),
"_aix_support._read_cmd_output": (
"The '_aix_support._read_cmd_output' helper executes commands and captures process output, creating direct "
"command-execution risk."
),
"_pyrepl.pager.pipe_pager": (
"The '_pyrepl.pager.pipe_pager' helper invokes pager subprocess flows and can be abused for process "
"execution during model loading."
),
"torch.serialization.load": (
"The 'torch.serialization.load' loader performs nested PyTorch and pickle deserialization, which can invoke "
"attacker-controlled reconstruction callables."
),
"torch._inductor.codecache.compile_file": (
"The 'torch._inductor.codecache.compile_file' path compiles and loads generated code artifacts, enabling "
"arbitrary code execution when attacker-controlled."
),
}

# Explanations for dangerous pickle opcodes
Expand Down Expand Up @@ -390,6 +426,8 @@ def get_explanation(category: str, specific_item: str | None = None) -> str | No
# Convenience functions for common use cases
def get_import_explanation(module_name: str) -> str | None:
    """Get explanation for a dangerous import/module.

    Exact dotted entries (e.g. "numpy.load") take precedence; otherwise the
    lookup falls back to the base module (e.g. "os" for "os.system").
    """
    key = module_name if module_name in DANGEROUS_IMPORTS else module_name.split(".", 1)[0]
    return get_explanation("import", key)
Expand Down
173 changes: 93 additions & 80 deletions modelaudit/scanners/pickle_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pickletools
import struct
import time
from typing import IO, Any, BinaryIO, ClassVar, TypeGuard
from typing import Any, BinaryIO, ClassVar, TypeGuard

from modelaudit.analysis.enhanced_pattern_detector import EnhancedPatternDetector, PatternMatch
from modelaudit.analysis.entropy_analyzer import EntropyAnalyzer
Expand Down Expand Up @@ -379,6 +379,7 @@ def _compute_pickle_length(path: str) -> int:
"torch.distributed.rpc.RemoteModule",
# NumPy dangerous functions (Fickling)
"numpy.testing._private.utils.runstring",
"numpy.load",
# pip as callable (CVE-2025-1716: picklescan bypass via pip.main)
"pip.main",
"pip._internal.main",
Expand Down Expand Up @@ -418,6 +419,15 @@ def _compute_pickle_length(path: str) -> int:
"ctypes.cast",
"ctypes.CFUNCTYPE",
"ctypes.WINFUNCTYPE",
# Expanded exact dangerous primitives validated against PickleScan
"site.main",
"_io.FileIO",
"test.support.script_helper.assert_python_ok",
"_osx_support._read_output",
"_aix_support._read_cmd_output",
"_pyrepl.pager.pipe_pager",
"torch.serialization.load",
"torch._inductor.codecache.compile_file",
}

# Module prefixes that are always dangerous (Fickling-based + additional)
Expand Down Expand Up @@ -3791,52 +3801,32 @@ def _create_opcode_sequence_check(self, sequence_result: Any, result: ScanResult
),
)

def _extract_globals_advanced(self, data: IO[bytes], multiple_pickles: bool = True) -> set[tuple[str, str]]:
def _extract_globals_advanced(self, data: BinaryIO, multiple_pickles: bool = True) -> set[tuple[str, str, str]]:
"""Advanced pickle global extraction with STACK_GLOBAL and memo support."""
globals_found: set[tuple[str, str]] = set()
memo: dict[int | str, str] = {}
globals_found: set[tuple[str, str, str]] = set()

last_byte = b"dummy"
while last_byte != b"":
try:
ops: list[tuple[Any, Any, int | None]] = list(pickletools.genops(data))
except Exception as e:
if globals_found:
logger.warning(f"Pickle parsing failed, but found {len(globals_found)} globals: {e}")
return globals_found
# For internal scanner calls (like joblib), don't fail the entire scan
# Just log the issue and return empty set
logger.debug(f"Pickle parsing failed with no globals found: {e}")
return set()

stack_global_refs, _callable_refs = _build_symbolic_reference_maps(ops)

last_byte = data.read(1)
if last_byte:
data.seek(-1, 1)

for n, (opcode, arg, _pos) in enumerate(ops):
op_name = opcode.name
if op_name == "MEMOIZE" and n > 0:
memo[len(memo)] = ops[n - 1][1]
elif op_name in {"PUT", "BINPUT", "LONG_BINPUT"} and n > 0:
memo[arg] = ops[n - 1][1]
elif op_name in {"GLOBAL", "INST"}:
parts = str(arg).split(" ", 1)
if len(parts) == 2:
globals_found.add((parts[0], parts[1]))
elif parts:
globals_found.add((parts[0], ""))
elif op_name == "STACK_GLOBAL":
resolved = stack_global_refs.get(n)
if resolved:
globals_found.add(resolved)
else:
logger.debug(f"STACK_GLOBAL parsing failed at position {n}")
globals_found.add(("unknown", "unknown"))
try:
ops: list[tuple[Any, Any, int | None]] = list(_genops_with_fallback(data, multi_stream=multiple_pickles))
except Exception as e:
logger.debug(f"Pickle parsing failed during advanced global extraction: {e}")
return set()

stack_global_refs, _callable_refs = _build_symbolic_reference_maps(ops)

for n, (opcode, arg, _pos) in enumerate(ops):
op_name = opcode.name
if op_name in {"GLOBAL", "INST"} and isinstance(arg, str):
parsed = _parse_module_function(arg)
if parsed is not None:
globals_found.add((*parsed, op_name))
elif op_name == "STACK_GLOBAL":
resolved = stack_global_refs.get(n)
if resolved:
globals_found.add((*resolved, op_name))
else:
logger.debug(f"STACK_GLOBAL parsing failed at position {n}")
globals_found.add(("unknown", "unknown", op_name))

if not multiple_pickles:
break
return globals_found

def _extract_stack_global_values(
Expand Down Expand Up @@ -4235,7 +4225,7 @@ def get_depth(x):
result.metadata["first_pickle_end_pos"] = first_pickle_end_pos

# Analyze globals extracted from all pickle streams
for mod, func in advanced_globals:
for mod, func, opcode_name in advanced_globals:
if _is_actually_dangerous_global(mod, func, ml_context):
suspicious_count += 1
base_sev = IssueSeverity.WARNING if mod in WARNING_SEVERITY_MODULES else IssueSeverity.CRITICAL
Expand All @@ -4246,7 +4236,7 @@ def get_depth(x):
)
rule_code = get_import_rule_code(mod, func)
if not rule_code:
rule_code = "S205" # STACK_GLOBAL/GLOBAL fallback
rule_code = get_pickle_opcode_rule_code(opcode_name) or "S206"
result.add_check(
name="Advanced Global Reference Check",
passed=False,
Expand All @@ -4257,13 +4247,13 @@ def get_depth(x):
details={
"module": mod,
"function": func,
"opcode": "STACK_GLOBAL",
"opcode": opcode_name,
"ml_context_confidence": ml_context.get(
"overall_confidence",
0,
),
},
why=get_import_explanation(mod),
why=get_import_explanation(f"{mod}.{func}"),
)

# Record successful ML context validation if content appears safe
Expand Down Expand Up @@ -4301,6 +4291,8 @@ def get_depth(x):
)
# Get rule code for this import/module
rule_code = get_import_rule_code(mod, func)
if not rule_code:
rule_code = "S206" # GLOBAL fallback
result.add_check(
name="Global Module Reference Check",
passed=False,
Expand All @@ -4319,7 +4311,7 @@ def get_depth(x):
0,
),
},
why=get_import_explanation(mod),
why=get_import_explanation(f"{mod}.{func}"),
)
else:
# Record successful validation of safe global
Expand Down Expand Up @@ -4433,40 +4425,61 @@ def get_depth(x):
ml_context,
)

# CVE-2025-32434 is specific to torch.load() and
# should only be referenced for PyTorch file formats
_ext = os.path.splitext(self.current_file_path)[1].lower()
_is_pytorch_file = _ext in {".pt", ".pth"} or (
_ext == ".bin" and "pytorch" in ml_context.get("frameworks", {})
)
if _is_pytorch_file:
_reduce_msg = (
f"Found REDUCE opcode with non-allowlisted global: {associated_global}. "
f"This may indicate CVE-2025-32434 exploitation (RCE via torch.load)"
)
if is_actually_dangerous:
_reduce_msg = f"Found REDUCE opcode invoking dangerous global: {associated_global}"
_reduce_details: dict[str, Any] = {
"position": pos,
"opcode": opcode.name,
"associated_global": associated_global,
"cve_id": "CVE-2025-32434",
"ml_context_confidence": ml_context.get(
"overall_confidence",
0,
),
}
else:
_reduce_msg = (
f"Found REDUCE opcode with non-allowlisted global: {associated_global}"
# CVE-2025-32434 is specific to torch.load() and
# should only be referenced for PyTorch file formats
_ext = os.path.splitext(self.current_file_path)[1].lower()
_is_pytorch_file = _ext in {".pt", ".pth"} or (
_ext == ".bin" and "pytorch" in ml_context.get("frameworks", {})
)
_reduce_details = {
"position": pos,
"opcode": opcode.name,
"associated_global": associated_global,
"ml_context_confidence": ml_context.get(
"overall_confidence",
0,
),
}
if _is_pytorch_file:
_reduce_msg = (
f"Found REDUCE opcode with non-allowlisted global: {associated_global}. "
f"This may indicate CVE-2025-32434 exploitation (RCE via torch.load)"
)
_reduce_details = {
"position": pos,
"opcode": opcode.name,
"associated_global": associated_global,
"cve_id": "CVE-2025-32434",
"cvss": 9.8,
"cwe": "CWE-502",
"description": (
"RCE when loading models with torch.load(weights_only=True)"
),
"remediation": (
"Upgrade to PyTorch 2.6.0 or later, and avoid "
"torch.load(weights_only=True) with untrusted models."
),
"ml_context_confidence": ml_context.get(
"overall_confidence",
0,
),
}
else:
_reduce_msg = (
f"Found REDUCE opcode with non-allowlisted global: {associated_global}"
)
_reduce_details = {
"position": pos,
"opcode": opcode.name,
"associated_global": associated_global,
"ml_context_confidence": ml_context.get(
"overall_confidence",
0,
),
}

result.add_check(
name="REDUCE Opcode Safety Check",
Expand Down Expand Up @@ -4809,7 +4822,7 @@ def get_depth(x):
0,
),
},
why=get_import_explanation(mod),
why=get_import_explanation(f"{mod}.{func}"),
)
else:
# Record successful validation of safe STACK_GLOBAL
Expand Down Expand Up @@ -4974,7 +4987,7 @@ def get_depth(x):
# (e.g. HuggingFace cache stores files as hash blobs without extensions)
has_joblib_globals = any(
mod in {"joblib", "sklearn", "numpy"} or mod.startswith(("joblib.", "sklearn.", "numpy."))
for mod, _func in advanced_globals
for mod, _func, _opcode in advanced_globals
)
is_joblib_content = is_serialization_ext or (not file_ext and has_joblib_globals)

Expand Down Expand Up @@ -5002,15 +5015,15 @@ def get_depth(x):
# legitimate PyTorch structures and no dangerous global references appear.
global_validation_context = {"is_ml_content": False, "overall_confidence": 0.0, "frameworks": {}}
has_dangerous_advanced_global = any(
_is_actually_dangerous_global(mod, func, global_validation_context) for mod, func in advanced_globals
_is_actually_dangerous_global(mod, func, global_validation_context)
for mod, func, _opcode in advanced_globals
)
has_pytorch_advanced_global = any(
mod == "torch" or mod.startswith("torch.") for mod, _func in advanced_globals
mod == "torch" or mod.startswith("torch.") for mod, _func, _opcode in advanced_globals
)
has_ordereddict_global = ("collections", "OrderedDict") in advanced_globals or (
"torch",
"OrderedDict",
) in advanced_globals
has_ordereddict_global = any(
mod == "collections" and func == "OrderedDict" for mod, func, _opcode in advanced_globals
) or any(mod == "torch" and func == "OrderedDict" for mod, func, _opcode in advanced_globals)
has_legitimate_pytorch_globals = (
bool(advanced_globals)
and (has_pytorch_advanced_global or has_ordereddict_global)
Expand Down
Loading
Loading