Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- **security:** treat legacy `httplib` pickle globals the same as `http.client`, including import-only and `REDUCE` findings in standalone and archived payloads
- **security:** fail closed on malformed `STACK_GLOBAL` operands when memo lookups are missing or operand types are non-string, while keeping simple truncation-only context informational
- **security:** harden TensorFlow weight extraction limits to bound actual tensor payload materialization, including malformed `tensor_content` and string-backed tensors, and continue scanning past oversized `Const` nodes
- **security:** stream TAR members to temp files under size limits instead of buffering whole entries in memory during scan
- **security:** inspect TensorFlow SavedModel function definitions when scanning for dangerous ops and protobuf string abuse, with function-aware finding locations
Expand Down
144 changes: 123 additions & 21 deletions modelaudit/scanners/pickle_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import pickletools
import reprlib
import struct
import time
from typing import IO, Any, BinaryIO, ClassVar, TypeGuard
Expand Down Expand Up @@ -44,6 +45,38 @@
_RESYNC_BUDGET = 8192 # Max bytes to scan forward when resyncing after an unknown opcode
COPYREG_EXTENSION_MODULE = "__copyreg_extension__"
COPYREG_EXTENSION_PREFIX = "code_"
_STACK_GLOBAL_OPERAND_PREVIEW_MAX = 128
_STACK_GLOBAL_BINARY_PREVIEW_BYTES = 8
_STACK_GLOBAL_OPERAND_PREVIEWER = reprlib.Repr()
_STACK_GLOBAL_OPERAND_PREVIEWER.maxstring = _STACK_GLOBAL_OPERAND_PREVIEW_MAX
_STACK_GLOBAL_OPERAND_PREVIEWER.maxother = _STACK_GLOBAL_OPERAND_PREVIEW_MAX
_STACK_GLOBAL_OPERAND_PREVIEWER.maxlist = 4
_STACK_GLOBAL_OPERAND_PREVIEWER.maxtuple = 4
_STACK_GLOBAL_OPERAND_PREVIEWER.maxset = 4
_STACK_GLOBAL_OPERAND_PREVIEWER.maxfrozenset = 4
_STACK_GLOBAL_OPERAND_PREVIEWER.maxdict = 4


def _format_stack_global_operand_preview(value: Any) -> str:
"""Return a bounded diagnostic preview for malformed STACK_GLOBAL operands."""
if isinstance(value, (bytes, bytearray, memoryview)):
value_len = value.nbytes if isinstance(value, memoryview) else len(value)
prefix_bytes = bytes(value[:_STACK_GLOBAL_BINARY_PREVIEW_BYTES])
suffix = "..." if value_len > _STACK_GLOBAL_BINARY_PREVIEW_BYTES else ""
return f"{type(value).__name__}(len={value_len}, hex=0x{prefix_bytes.hex()}{suffix})"

preview = _STACK_GLOBAL_OPERAND_PREVIEWER.repr(value)
if len(preview) > _STACK_GLOBAL_OPERAND_PREVIEW_MAX:
preview = preview[:_STACK_GLOBAL_OPERAND_PREVIEW_MAX] + "...<truncated>"

preview_value_len: int | None
try:
preview_value_len = len(value)
except Exception:
preview_value_len = None

length_suffix = f" (len={preview_value_len})" if preview_value_len is not None else ""
return f"{type(value).__name__}:{preview}{length_suffix}"


def _genops_with_fallback(file_obj: BinaryIO, *, multi_stream: bool = False) -> Any:
Expand Down Expand Up @@ -1903,20 +1936,27 @@ def _parse_module_function(arg: str) -> tuple[str, str] | None:

def _build_symbolic_reference_maps(
opcodes: list[tuple],
) -> tuple[dict[int, tuple[str, str]], dict[int, tuple[str, str]]]:
) -> tuple[
dict[int, tuple[str, str]],
dict[int, tuple[str, str]],
dict[int, dict[str, str]],
]:
"""
Build symbolic maps of callable references in an opcode stream.

Returns:
Tuple of:
- stack_global_refs: opcode index -> (module, function) for STACK_GLOBAL
- callable_refs: opcode index -> (module, function) for REDUCE/NEWOBJ/OBJ/INST call targets
- malformed_stack_globals: opcode index -> malformed STACK_GLOBAL operand details
"""
stack_global_refs: dict[int, tuple[str, str]] = {}
callable_refs: dict[int, tuple[str, str]] = {}
malformed_stack_globals: dict[int, dict[str, str]] = {}

marker = object()
unknown = object()
missing_memo = object()
stack: list[Any] = []
memo: dict[int | str, Any] = {}
next_memo_index = 0
Expand All @@ -1939,6 +1979,15 @@ def _pop_to_mark() -> list[Any]:
def _is_ref(value: Any) -> TypeGuard[tuple[str, str]]:
return isinstance(value, tuple) and len(value) == 2 and isinstance(value[0], str) and isinstance(value[1], str)

def _classify_stack_global_operand(value: Any) -> tuple[str, str]:
if isinstance(value, str):
return "string", value
if value is missing_memo:
return "missing_memo", "unknown"
if value is unknown:
return "unknown", "unknown"
return "non_string", _format_stack_global_operand_preview(value)

for i, (opcode, arg, _pos) in enumerate(opcodes):
name = opcode.name

Expand Down Expand Up @@ -1972,6 +2021,20 @@ def _is_ref(value: Any) -> TypeGuard[tuple[str, str]]:
stack_global_refs[i] = ref
stack.append(ref)
else:
module_kind, module_value = _classify_stack_global_operand(mod_name)
function_kind, function_value = _classify_stack_global_operand(func_name)
reason = "insufficient_context"
if "missing_memo" in {module_kind, function_kind}:
reason = "missing_memo"
elif "non_string" in {module_kind, function_kind}:
reason = "mixed_or_non_string"
malformed_stack_globals[i] = {
"module_kind": module_kind,
"module": module_value,
"function_kind": function_kind,
"function": function_value,
"reason": reason,
}
stack.append(unknown)
continue

Expand All @@ -1994,7 +2057,7 @@ def _is_ref(value: Any) -> TypeGuard[tuple[str, str]]:
continue

if name in {"GET", "BINGET", "LONG_BINGET"}:
stack.append(memo.get(arg, unknown))
stack.append(memo.get(arg, missing_memo))
continue

if name == "MARK":
Expand Down Expand Up @@ -2095,20 +2158,12 @@ def _is_ref(value: Any) -> TypeGuard[tuple[str, str]]:
stack.append(unknown)
continue

if name == "STOP":
# Clear memo at stream boundaries so that a safe memo entry from
# stream 1 cannot be inherited by a dangerous callable in stream 2
# (cross-stream memo contamination).
memo.clear()
next_memo_index = 0
stack.clear()
if name in {"NONE", "NEWTRUE", "NEWFALSE"}:
stack.append(None if name == "NONE" else name == "NEWTRUE")
continue

if name in {
"PERSID",
"NONE",
"NEWTRUE",
"NEWFALSE",
"INT",
"BININT",
"BININT1",
Expand All @@ -2127,7 +2182,7 @@ def _is_ref(value: Any) -> TypeGuard[tuple[str, str]]:
"BINUNICODE",
"BINUNICODE8",
}:
stack.append(unknown)
stack.append(arg)

if name == "STOP":
# Reset memo and stack at pickle stream boundaries so that
Expand All @@ -2137,7 +2192,7 @@ def _is_ref(value: Any) -> TypeGuard[tuple[str, str]]:
stack.clear()
next_memo_index = 0

return stack_global_refs, callable_refs
return stack_global_refs, callable_refs, malformed_stack_globals


def _find_stack_global_strings(
Expand Down Expand Up @@ -2585,7 +2640,7 @@ def _is_dangerous_ref(mod: str, func: str) -> bool:
return is_suspicious_global(mod, func)

if stack_global_refs is None or callable_refs is None:
computed_stack_refs, computed_callable_refs = _build_symbolic_reference_maps(opcodes)
computed_stack_refs, computed_callable_refs, _ = _build_symbolic_reference_maps(opcodes)
else:
computed_stack_refs, computed_callable_refs = stack_global_refs, callable_refs

Expand Down Expand Up @@ -2732,7 +2787,7 @@ def check_opcode_sequence(
return suspicious_patterns # Return empty list for legitimate ML content

if stack_global_refs is None or callable_refs is None:
computed_stack_refs, computed_callable_refs = _build_symbolic_reference_maps(opcodes)
computed_stack_refs, computed_callable_refs, _ = _build_symbolic_reference_maps(opcodes)
else:
computed_stack_refs, computed_callable_refs = stack_global_refs, callable_refs

Expand Down Expand Up @@ -3809,7 +3864,7 @@ def _extract_globals_advanced(self, data: IO[bytes], multiple_pickles: bool = Tr
logger.debug(f"Pickle parsing failed with no globals found: {e}")
return set()

stack_global_refs, _callable_refs = _build_symbolic_reference_maps(ops)
stack_global_refs, _callable_refs, _ = _build_symbolic_reference_maps(ops)

last_byte = data.read(1)
if last_byte:
Expand Down Expand Up @@ -4108,7 +4163,7 @@ def _scan_pickle_bytes(self, file_obj: BinaryIO, file_size: int) -> ScanResult:

# ML CONTEXT FILTERING: Analyze ML context once for the entire pickle
ml_context = _detect_ml_context(opcodes)
stack_global_refs, callable_refs = _build_symbolic_reference_maps(opcodes)
stack_global_refs, callable_refs, malformed_stack_globals = _build_symbolic_reference_maps(opcodes)

# CVE-2025-32434 specific opcode sequence analysis - REMOVED
# Now only show CVE info in REDUCE opcode detection messages
Expand Down Expand Up @@ -4831,8 +4886,55 @@ def get_depth(x):
rule_code=None, # Passing check
)
else:
# Only warn about insufficient context if not ML content
if not ml_context.get("is_ml_content", False):
malformed = malformed_stack_globals.get(i)
if malformed and malformed.get("reason") != "insufficient_context":
suspicious_count += 1
module_hint = malformed.get("module", "unknown")
function_hint = malformed.get("function", "unknown")
module_kind = malformed.get("module_kind", "unknown")
function_kind = malformed.get("function_kind", "unknown")
reason = malformed.get("reason", "mixed_or_non_string")
module_looks_dangerous = (
module_kind == "string"
and module_hint not in {"", "unknown"}
and _is_dangerous_module(module_hint)
)
severity = IssueSeverity.CRITICAL if module_looks_dangerous else IssueSeverity.WARNING
if reason == "missing_memo":
message = (
"STACK_GLOBAL references missing or invalid memoized operand(s): "
f"module={module_hint} ({module_kind}), function={function_hint} ({function_kind})"
)
else:
message = (
"Malformed STACK_GLOBAL operand types can hide dangerous imports: "
f"module={module_hint} ({module_kind}), function={function_hint} ({function_kind})"
)

result.add_check(
name="STACK_GLOBAL Context Check",
passed=False,
message=message,
severity=severity,
location=f"{self.current_file_path} (pos {pos})",
rule_code="S205",
details={
"position": pos,
"opcode": opcode.name,
"module": module_hint,
"function": function_hint,
"module_kind": module_kind,
"function_kind": function_kind,
"reason": reason,
"ml_context_confidence": ml_context.get("overall_confidence", 0),
},
why=(
"STACK_GLOBAL should be formed from two string operands. Non-string operands "
"or missing memoized values indicate a malformed-by-design payload and are "
"treated as a security finding under fail-closed handling."
),
)
elif not ml_context.get("is_ml_content", False):
result.add_check(
name="STACK_GLOBAL Context Check",
passed=False,
Expand Down Expand Up @@ -5908,7 +6010,7 @@ def _detect_cve_2026_24747_sequences(self, opcodes: list[tuple], file_size: int)
# Pre-compute symbolic references for STACK_GLOBAL resolution.
# This handles BINUNICODE8, memoized strings (BINGET/LONG_BINGET),
# and indirect stack flows that a narrow lookback would miss.
stack_global_refs, _ = _build_symbolic_reference_maps(opcodes)
stack_global_refs, _, _ = _build_symbolic_reference_maps(opcodes)

for i, (opcode, _arg, pos) in enumerate(opcodes):
if opcode.name not in ("SETITEM", "SETITEMS"):
Expand Down
Loading
Loading