Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions modelaudit/scanners/pickle_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3790,15 +3790,41 @@ def _create_opcode_sequence_check(self, sequence_result: Any, result: ScanResult
),
)

def _extract_globals_advanced(self, data: IO[bytes], multiple_pickles: bool = True) -> set[tuple[str, str]]:
def _extract_globals_advanced(
self,
data: IO[bytes],
multiple_pickles: bool = True,
scan_start_time: float | None = None,
) -> set[tuple[str, str]]:
"""Advanced pickle global extraction with STACK_GLOBAL and memo support."""
globals_found: set[tuple[str, str]] = set()
memo: dict[int | str, str] = {}
extracted_opcodes = 0
effective_scan_start_time = scan_start_time if scan_start_time is not None else self.scan_start_time

last_byte = b"dummy"
while last_byte != b"":
extraction_truncated = False
try:
ops: list[tuple[Any, Any, int | None]] = list(pickletools.genops(data))
ops: list[tuple[Any, Any, int | None]] = []
for op in pickletools.genops(data):
ops.append(op)
extracted_opcodes += 1

if extracted_opcodes > self.max_opcodes:
logger.warning(
f"Advanced global extraction stopped after exceeding max_opcodes ({self.max_opcodes})"
)
extraction_truncated = True
break

if (
effective_scan_start_time is not None
and (time.time() - effective_scan_start_time) > self.timeout
):
logger.warning(f"Advanced global extraction stopped after exceeding timeout ({self.timeout}s)")
extraction_truncated = True
break
except Exception as e:
if globals_found:
logger.warning(f"Pickle parsing failed, but found {len(globals_found)} globals: {e}")
Expand Down Expand Up @@ -3836,6 +3862,8 @@ def _extract_globals_advanced(self, data: IO[bytes], multiple_pickles: bool = Tr

if not multiple_pickles:
break
if extraction_truncated:
break
return globals_found

def _extract_stack_global_values(
Expand Down
36 changes: 36 additions & 0 deletions tests/scanners/test_pickle_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,5 +1210,41 @@ def _raise_memory_error(*args: object, **kwargs: object) -> object:
assert format_validation_checks[0].details["exception_type"] == "MemoryError"


def test_extract_globals_advanced_respects_max_opcodes(monkeypatch: pytest.MonkeyPatch) -> None:
    """Advanced extraction should stop after max_opcodes instead of parsing everything."""
    from io import BytesIO

    # A scanner with a very small opcode budget.
    scanner = PickleScanner({"max_opcodes": 3})

    # Fake opcode stream that is longer than the budget allows; each entry
    # mimics pickletools.genops output: (opcode-object-with-.name, arg, pos).
    def _long_op_stream(_stream: object):
        for idx in range(10):
            yield (type("Op", (), {"name": "NOP"})(), idx, idx)

    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _long_op_stream)

    extracted = scanner._extract_globals_advanced(data=BytesIO(b"x"), multiple_pickles=False)

    # No GLOBAL/STACK_GLOBAL ops were yielded, so nothing should be found,
    # and the call must return (i.e. truncation kicked in) rather than hang.
    assert extracted == set()


def test_extract_globals_advanced_respects_scan_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
    """Advanced extraction should honor active scan timeout budget."""
    from io import BytesIO

    # One-second budget, with a scan that "started" at t=100.
    scanner = PickleScanner({"timeout": 1})
    scanner.scan_start_time = 100.0

    # Fake opcode stream shaped like pickletools.genops output:
    # (opcode-object-with-.name, arg, pos).
    def _long_op_stream(_stream: object):
        for idx in range(10):
            yield (type("Op", (), {"name": "NOP"})(), idx, idx)

    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.pickletools.genops", _long_op_stream)
    # Freeze the clock at t=102 so the 1s budget is already exceeded.
    monkeypatch.setattr("modelaudit.scanners.pickle_scanner.time.time", lambda: 102.0)

    extracted = scanner._extract_globals_advanced(data=BytesIO(b"x"), multiple_pickles=False)

    # Timeout truncation must abort extraction with no globals collected.
    assert extracted == set()


# Allow running this test module directly (pytest is the usual entry point).
if __name__ == "__main__":
    unittest.main()
Loading