Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- **security:** recurse into object-dtype `.npy` payloads and `.npz` object members with the pickle scanner while preserving CVE-2019-6446 warnings and archive-member context
- **security:** treat legacy `httplib` pickle globals the same as `http.client`, including import-only and `REDUCE` findings in standalone and archived payloads
- **security:** harden TensorFlow weight extraction limits to bound actual tensor payload materialization, including malformed `tensor_content` and string-backed tensors, and continue scanning past oversized `Const` nodes
- **security:** stream TAR members to temp files under size limits instead of buffering whole entries in memory during scan
Expand Down
6 changes: 5 additions & 1 deletion modelaudit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,12 @@ def _group_checks_by_asset(checks_list: list[Any]) -> dict[tuple[str, str], list
check_name = check.get("name", "Unknown Check")
location = check.get("location", "")
primary_asset = _extract_primary_asset_from_location(location)
details = check.get("details")
zip_entry = details.get("zip_entry") if isinstance(details, dict) else None

group_key = (check_name, primary_asset)
asset_group = f"{primary_asset}:{zip_entry}" if isinstance(zip_entry, str) and zip_entry else primary_asset

group_key = (check_name, asset_group)
check_groups[group_key].append(check)

return check_groups
Expand Down
23 changes: 22 additions & 1 deletion modelaudit/scanners/numpy_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import sys
import warnings
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar

from .base import BaseScanner, IssueSeverity, ScanResult
from .pickle_scanner import PickleScanner

# Import NumPy with compatibility handling
try:
Expand Down Expand Up @@ -88,6 +89,17 @@ def _validate_array_dimensions(self, shape: tuple[int, ...]) -> None:
CVE_2019_6446_CVSS = 9.8
CVE_2019_6446_CWE = "CWE-502"

def _scan_embedded_pickle_payload(
    self,
    file_obj: BinaryIO,
    payload_size: int,
    context_path: str,
) -> ScanResult:
    """Reuse PickleScanner analysis for object-dtype NumPy payloads.

    Args:
        file_obj: Open binary handle positioned at the start of the
            embedded pickle payload inside the ``.npy`` file.
        payload_size: Number of bytes of pickle data to scan from the
            handle's current position.
        context_path: Path attributed to findings so issues point at the
            original ``.npy`` file rather than an anonymous stream.

    Returns:
        The ScanResult produced by the delegated pickle scan.
    """
    # Share this scanner's config so blocklists/limits configured for the
    # NumPy scan also govern the embedded pickle scan.
    pickle_scanner = PickleScanner(config=self.config)
    pickle_scanner.current_file_path = context_path
    # NOTE(review): depends on PickleScanner's private _scan_pickle_bytes
    # helper — presumably stable within this codebase; confirm on upgrade.
    return pickle_scanner._scan_pickle_bytes(file_obj, payload_size)

def _validate_dtype(self, dtype: Any) -> None:
"""Validate numpy dtype for security"""
# Check for problematic data types
Expand Down Expand Up @@ -299,6 +311,15 @@ def scan(self, path: str) -> ScanResult:
),
)

f.seek(data_offset)
embedded_result = self._scan_embedded_pickle_payload(
f,
file_size - data_offset,
path,
)
result.issues.extend(embedded_result.issues)
result.checks.extend(embedded_result.checks)

self._validate_dtype(dtype)
result.add_check(
name="Data Type Safety Check",
Expand Down
70 changes: 40 additions & 30 deletions modelaudit/scanners/zip_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,44 @@ def scan(self, path: str) -> ScanResult:
result.metadata["file_size"] = os.path.getsize(path)
return result

def _rewrite_nested_result_context(
self, scan_result: ScanResult, tmp_path: str, archive_path: str, entry_name: str
) -> None:
"""Rewrite nested result locations so archive members, not temp files, are reported."""
archive_location = f"{archive_path}:{entry_name}"

for issue in scan_result.issues:
if issue.location:
if issue.location.startswith(tmp_path):
issue.location = issue.location.replace(tmp_path, archive_location, 1)
else:
issue.location = f"{archive_location} {issue.location}"
else:
issue.location = archive_location

existing_issue_entry = issue.details.get("zip_entry")
issue.details["zip_entry"] = (
f"{entry_name}:{existing_issue_entry}"
if isinstance(existing_issue_entry, str) and existing_issue_entry
else entry_name
)

for check in scan_result.checks:
if check.location:
if check.location.startswith(tmp_path):
check.location = check.location.replace(tmp_path, archive_location, 1)
else:
check.location = f"{archive_location} {check.location}"
else:
check.location = archive_location

existing_check_entry = check.details.get("zip_entry")
check.details["zip_entry"] = (
f"{entry_name}:{existing_check_entry}"
if isinstance(existing_check_entry, str) and existing_check_entry
else entry_name
)

def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult:
"""Recursively scan a ZIP file and its contents"""
result = ScanResult(scanner_name=self.name)
Expand Down Expand Up @@ -317,16 +355,7 @@ def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult:
if name.lower().endswith(".zip"):
try:
nested_result = self._scan_zip_file(tmp_path, depth + 1)
# Update locations in nested results
for issue in nested_result.issues:
if issue.location and issue.location.startswith(
tmp_path,
):
issue.location = issue.location.replace(
tmp_path,
f"{path}:{name}",
1,
)
self._rewrite_nested_result_context(nested_result, tmp_path, path, name)
result.merge(nested_result)

asset_entry = asset_from_scan_result(
Expand All @@ -348,26 +377,7 @@ def _scan_zip_file(self, path: str, depth: int = 0) -> ScanResult:

# Use core.scan_file to scan with appropriate scanner
file_result = core.scan_file(tmp_path, self.config)

# Update locations in file results
for issue in file_result.issues:
if issue.location:
if issue.location.startswith(tmp_path):
issue.location = issue.location.replace(
tmp_path,
f"{path}:{name}",
1,
)
else:
issue.location = f"{path}:{name} {issue.location}"
else:
issue.location = f"{path}:{name}"

# Add zip entry name to details
if issue.details:
issue.details["zip_entry"] = name
else:
issue.details = {"zip_entry": name}
self._rewrite_nested_result_context(file_result, tmp_path, path, name)

result.merge(file_result)

Expand Down
169 changes: 168 additions & 1 deletion tests/scanners/test_numpy_scanner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from collections.abc import Callable
from pathlib import Path
from typing import Any

import numpy as np

from modelaudit.scanners.base import IssueSeverity
from modelaudit.scanners.base import Check, IssueSeverity, ScanResult
from modelaudit.scanners.numpy_scanner import NumPyScanner


Expand Down Expand Up @@ -100,3 +104,166 @@ def test_structured_with_object_field_triggers_cve(self, tmp_path):

cve_checks = [c for c in result.checks if "CVE-2019-6446" in (c.name + c.message)]
assert len(cve_checks) > 0, "Structured dtype with object field should trigger CVE"


class _ExecPayload:
def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]:
return (exec, ("print('owned')",))


class _SSLPayload:
def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, ...]]:
import ssl

return (ssl.get_server_certificate, (("example.com", 443),))


def _failed_checks(result: ScanResult) -> list[Check]:
    """Return only the checks whose status value is the string ``"failed"``."""
    return list(filter(lambda check: check.status.value == "failed", result.checks))
Comment on lines +109 to +122
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add type hints to newly introduced helpers.

The new helper methods/functions are untyped, which violates the repository Python typing rule.

Proposed fix
+from typing import Any, Callable
+
 class _ExecPayload:
-    def __reduce__(self):
+    def __reduce__(self) -> tuple[Callable[..., object], tuple[str]]:
         return (exec, ("print('owned')",))
@@
 class _SSLPayload:
-    def __reduce__(self):
+    def __reduce__(self) -> tuple[Callable[..., object], tuple[tuple[str, int]]]:
         import ssl
 
         return (ssl.get_server_certificate, (("example.com", 443),))
@@
-def _failed_checks(result):
+def _failed_checks(result: Any) -> list[Any]:
     return [c for c in result.checks if c.status.value == "failed"]

As per coding guidelines: "Always include type hints in Python code".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/scanners/test_numpy_scanner.py` around lines 107 - 120, Add Python type
hints to the new helpers: annotate _ExecPayload.__reduce__ and
_SSLPayload.__reduce__ to return a Tuple[Callable[..., Any], Tuple[Any, ...]]
and annotate their self parameter as usual; import the needed typing names
(e.g., Callable, Tuple, Any). Also annotate _failed_checks to accept result: Any
(or the specific result type if available) and return List[Any] (or
List[CheckType] if you have a Check type); import List/Any as needed. Ensure all
new function/method signatures use these type annotations to satisfy the
repository typing rule.



def _inject_comment_token_into_npy_payload(path: Path) -> None:
with path.open("rb") as handle:
major, minor = np.lib.format.read_magic(handle)
if (major, minor) == (1, 0):
np.lib.format.read_array_header_1_0(handle)
elif (major, minor) == (2, 0):
np.lib.format.read_array_header_2_0(handle)
else:
read_array_header = getattr(np.lib.format, "_read_array_header", None)
if read_array_header is None:
raise AssertionError(f"Unsupported NumPy header version: {(major, minor)}")
read_array_header(handle, version=(major, minor))
data_offset = handle.tell()
payload = handle.read()

if len(payload) < 2 or payload[0] != 0x80:
raise AssertionError(f"Unexpected pickle payload header: {payload[:4]!r}")

protocol = payload[1]
comment = b"# harmless note"
if protocol >= 4:
comment_op = b"\x8c" + bytes([len(comment)]) + comment
else:
comment_op = b"X" + len(comment).to_bytes(4, "little") + comment

patched = payload[:2] + comment_op + b"0" + payload[2:]
original = path.read_bytes()
path.write_bytes(original[:data_offset] + patched)


def test_object_dtype_numpy_recurses_into_pickle_exec(tmp_path: Path) -> None:
    """Object-dtype .npy payloads are fed through the pickle scanner."""
    target = tmp_path / "malicious_object.npy"
    np.save(target, np.array([_ExecPayload()], dtype=object), allow_pickle=True)

    result = NumPyScanner().scan(str(target))
    failed = _failed_checks(result)

    assert any("CVE-2019-6446" in (check.name + check.message) for check in failed)
    assert any("exec" in check.message.lower() for check in failed)


def test_object_dtype_numpy_recurses_into_pickle_ssl(tmp_path: Path) -> None:
    """Network-capable pickle reductions inside .npy files are surfaced."""
    target = tmp_path / "malicious_ssl_object.npy"
    np.save(target, np.array([_SSLPayload()], dtype=object), allow_pickle=True)

    result = NumPyScanner().scan(str(target))
    failed = _failed_checks(result)

    assert any("CVE-2019-6446" in (check.name + check.message) for check in failed)
    assert any("ssl.get_server_certificate" in check.message for check in failed)


def test_numeric_npz_has_no_pickle_recursion_findings(tmp_path: Path) -> None:
    """Purely numeric .npz archives must not produce pickle findings."""
    from modelaudit.scanners.zip_scanner import ZipScanner

    archive = tmp_path / "numeric_only.npz"
    np.savez(archive, a=np.arange(4), b=np.ones((2, 2), dtype=np.float32))

    result = ZipScanner().scan(str(archive))

    assert all("CVE-2019-6446" not in (check.name + check.message) for check in result.checks)
    assert all("exec" not in check.message.lower() for check in result.checks)
    assert all(issue.details.get("cve_id") != "CVE-2019-6446" for issue in result.issues)
    assert all("exec" not in issue.message.lower() for issue in result.issues)


def test_object_npz_member_recurses_into_pickle_exec_with_member_context(tmp_path: Path) -> None:
    """Findings from an object-dtype .npz member carry the member's name."""
    from modelaudit.scanners.zip_scanner import ZipScanner

    archive = tmp_path / "mixed_object.npz"
    np.savez(
        archive,
        safe=np.array([1, 2, 3], dtype=np.int64),
        payload=np.array([_ExecPayload()], dtype=object),
    )

    result = ZipScanner().scan(str(archive))
    failed = _failed_checks(result)

    assert any(
        "CVE-2019-6446" in (check.name + check.message) and "payload.npy" in str(check.location)
        for check in failed
    )
    assert any(
        "exec" in issue.message.lower() and issue.details.get("zip_entry") == "payload.npy"
        for issue in result.issues
    )


def test_object_dtype_numpy_comment_token_bypass_still_detected(tmp_path: Path) -> None:
    """Injected no-op pickle opcodes must not hide malicious reductions."""
    target = tmp_path / "comment_token.npy"
    np.save(target, np.array([_ExecPayload()], dtype=object), allow_pickle=True)
    _inject_comment_token_into_npy_payload(target)

    result = NumPyScanner().scan(str(target))
    failed = _failed_checks(result)

    assert any("CVE-2019-6446" in (check.name + check.message) for check in failed)
    assert any("exec" in check.message.lower() for check in failed)


def test_benign_object_dtype_numpy_no_nested_critical(tmp_path: Path) -> None:
    """Benign object arrays warn about CVE-2019-6446 without critical findings."""
    target = tmp_path / "benign_object.npy"
    np.save(target, np.array([{"k": "v"}, [1, 2, 3]], dtype=object), allow_pickle=True)

    result = NumPyScanner().scan(str(target))

    assert any("CVE-2019-6446" in (check.name + check.message) for check in result.checks)
    non_cve_criticals = [
        issue
        for issue in result.issues
        if "CVE-2019-6446" not in issue.message and issue.severity == IssueSeverity.CRITICAL
    ]
    assert not non_cve_criticals


def test_benign_object_dtype_npz_no_nested_critical(tmp_path: Path) -> None:
    """Benign object .npz members produce the CVE warning but nothing critical."""
    from modelaudit.scanners.zip_scanner import ZipScanner

    archive = tmp_path / "benign_object.npz"
    np.savez(archive, safe=np.array([{"x": 1}], dtype=object))

    result = ZipScanner().scan(str(archive))

    assert any("CVE-2019-6446" in (check.name + check.message) for check in result.checks)
    assert all(issue.severity != IssueSeverity.CRITICAL for issue in result.issues)


def test_truncated_npy_fails_safely(tmp_path: Path) -> None:
    """A truncated object .npy degrades to an informational finding, not a crash."""
    target = tmp_path / "truncated.npy"
    np.save(target, np.array([_ExecPayload()], dtype=object), allow_pickle=True)
    # Chop the tail off the file to corrupt the embedded pickle stream.
    target.write_bytes(target.read_bytes()[:-8])

    result = NumPyScanner().scan(str(target))

    assert any(issue.severity == IssueSeverity.INFO for issue in result.issues)


def test_corrupted_npz_fails_safely(tmp_path: Path) -> None:
    """A non-zip .npz file fails the scan gracefully with an info issue."""
    from modelaudit.scanners.zip_scanner import ZipScanner

    archive = tmp_path / "corrupt.npz"
    archive.write_bytes(b"not-a-zip")

    result = ZipScanner().scan(str(archive))

    assert result.success is False
    assert any(issue.severity == IssueSeverity.INFO for issue in result.issues)
Loading
Loading