Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 53 additions & 41 deletions fickling/analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import ast
import json
from abc import ABC, abstractmethod
from ast import unparse
Expand Down Expand Up @@ -325,45 +326,55 @@ class UnsafeImportsML(Analysis):
def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]:
for node in context.pickled.properties.imports:
shortened, _ = context.shorten_code(node)
all_modules = [
node.module.rsplit(".", i)[0] for i in range(0, node.module.count(".") + 1)
]
for module_name in all_modules:
if module_name in self.UNSAFE_MODULES:
# Special handling for builtins - check specific function names
if module_name in BUILTIN_MODULE_NAMES:
for n in node.names:
if n.name not in SAFE_BUILTINS:
risk_info = self.UNSAFE_MODULES[module_name]
yield AnalysisResult(
Severity.LIKELY_OVERTLY_MALICIOUS,
f"`{shortened}` imports `{n.name}` from `{module_name}` "
f"which can execute arbitrary code. {risk_info}",
"UnsafeImportsML",
trigger=shortened,
)
else:
# All other unsafe modules are fully blocked
risk_info = self.UNSAFE_MODULES[module_name]
yield AnalysisResult(
Severity.LIKELY_OVERTLY_MALICIOUS,
f"`{shortened}` uses `{module_name}` that is indicative of a malicious pickle file. {risk_info}",
"UnsafeImportsML",
trigger=shortened,
)
if node.module in self.UNSAFE_IMPORTS:
for n in node.names:
if n.name in self.UNSAFE_IMPORTS[node.module]:
risk_info = self.UNSAFE_IMPORTS[node.module][n.name]
yield AnalysisResult(
Severity.LIKELY_OVERTLY_MALICIOUS,
f"`{shortened}` imports `{n.name}` that is indicative of a malicious pickle file. {risk_info}",
"UnsafeImportsML",
trigger=shortened,
)

match node:
case ast.ImportFrom(module=module, names=names) if module:
modules_to_check = [module]
imported_names = names
case ast.Import(names=names):
modules_to_check = [alias.name for alias in names]
imported_names = []
case _:
continue

for module in modules_to_check:
all_modules = [module.rsplit(".", i)[0] for i in range(0, module.count(".") + 1)]
for module_name in all_modules:
if module_name in self.UNSAFE_MODULES:
# Special handling for builtins - check specific function names
if module_name in BUILTIN_MODULE_NAMES:
for n in imported_names:
if n.name not in SAFE_BUILTINS:
risk_info = self.UNSAFE_MODULES[module_name]
yield AnalysisResult(
Severity.LIKELY_OVERTLY_MALICIOUS,
f"`{shortened}` imports `{n.name}` from `{module_name}` "
f"which can execute arbitrary code. {risk_info}",
"UnsafeImportsML",
trigger=shortened,
)
else:
# All other unsafe modules are fully blocked
risk_info = self.UNSAFE_MODULES[module_name]
yield AnalysisResult(
Severity.LIKELY_OVERTLY_MALICIOUS,
f"`{shortened}` uses `{module_name}` that is indicative of a malicious pickle file. {risk_info}",
"UnsafeImportsML",
trigger=shortened,
)
if module in self.UNSAFE_IMPORTS:
for n in imported_names:
if n.name in self.UNSAFE_IMPORTS[module]:
risk_info = self.UNSAFE_IMPORTS[module][n.name]
yield AnalysisResult(
Severity.LIKELY_OVERTLY_MALICIOUS,
f"`{shortened}` imports `{n.name}` that is indicative of a malicious pickle file. {risk_info}",
"UnsafeImportsML",
trigger=shortened,
)
# NOTE(boyan): Special case with eval?
# Copy pasted from pickled.unsafe_imports() original implementation
elif "eval" in (n.name for n in node.names):
if "eval" in (n.name for n in imported_names):
yield AnalysisResult(
Severity.LIKELY_OVERTLY_MALICIOUS,
f"`{shortened}` imports `eval` which can execute arbitrary code",
Expand Down Expand Up @@ -425,10 +436,11 @@ def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]:
class UnsafeImports(Analysis):
def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]:
for node in context.pickled.unsafe_imports():
if node.module in BUILTIN_MODULE_NAMES and all(
n.name in SAFE_BUILTINS for n in node.names
):
continue
if isinstance(node, ast.ImportFrom):
if node.module in BUILTIN_MODULE_NAMES and all(
n.name in SAFE_BUILTINS for n in node.names
):
continue
shortened, _ = context.shorten_code(node)
yield AnalysisResult(
Severity.LIKELY_OVERTLY_MALICIOUS,
Expand Down
2 changes: 1 addition & 1 deletion fickling/fickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def __post_init__(self) -> None:


def is_std_module(module_name: str) -> bool:
    """Return whether ``module_name`` belongs to a builtin/stdlib package.

    Only the top-level package (the text before the first ``.``) is looked up
    in ``BUILTIN_STDLIB_MODULE_NAMES``, so dotted submodule paths such as
    ``collections.abc`` are classified by their parent package
    (``collections``) rather than requiring an exact match on the full path.

    Args:
        module_name: Module path as it appears in an import, possibly dotted.

    Returns:
        ``True`` if the top-level package name is in
        ``BUILTIN_STDLIB_MODULE_NAMES``, ``False`` otherwise.
    """
    # An exact-match test on the full dotted path would miss every stdlib
    # submodule, so reduce the path to its first component before the lookup.
    top_level_package = module_name.split(".")[0]
    return top_level_package in BUILTIN_STDLIB_MODULE_NAMES


def extract_identifier_from_ast_node(
Expand Down
76 changes: 76 additions & 0 deletions test/test_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from unittest import TestCase

import fickling.fickle as op
from fickling.analysis import (
Analyzer,
UnsafeImports,
UnsafeImportsML,
check_safety,
)
from fickling.fickle import Pickled


class TestImportMatchingGaps(TestCase):
    """Regression tests for gaps in how analysis passes handle imports."""

    def test_stdlib_submodule_not_flagged_as_nonstandard(self):
        """Stdlib submodules like collections.abc should not trigger NonStandardImports."""
        # GLOBAL opcode referencing a dotted stdlib module path.
        opcodes = [
            op.Proto.create(4),
            op.Global.create("collections.abc", "Mapping"),
            op.EmptyTuple(),
            op.Reduce(),
            op.Stop(),
        ]
        report = check_safety(Pickled(opcodes))

        flagged = []
        for finding in report.results:
            if finding.analysis_name == "NonStandardImports":
                flagged.append(finding)

        self.assertEqual(
            len(flagged),
            0,
            "collections.abc should not be flagged as non-standard",
        )

    def test_eval_import_from_unsafe_imports_ml_module(self):
        """Eval check must not be skipped when module is in UNSAFE_IMPORTS.

        Not a real payload (_io.eval doesn't exist), just a regression trigger.
        """
        # STACK_GLOBAL builds the import from the two preceding unicode strings.
        opcodes = [
            op.Proto.create(4),
            op.ShortBinUnicode("_io"),
            op.ShortBinUnicode("eval"),
            op.StackGlobal(),
            op.EmptyTuple(),
            op.Reduce(),
            op.Stop(),
        ]
        report = check_safety(Pickled(opcodes))

        eval_findings = []
        for finding in report.results:
            if finding.analysis_name != "UnsafeImportsML":
                continue
            # A finding's message may be None; treat that as an empty string.
            if "eval" in (finding.message or ""):
                eval_findings.append(finding)

        self.assertGreater(
            len(eval_findings),
            0,
            "UnsafeImportsML should flag 'from _io import eval'",
        )

    def test_ext1_ast_import_does_not_crash_analysis(self):
        """Ext1 generates ast.Import nodes; both analysis passes must handle them."""
        pickled = Pickled([op.Proto.create(2), op.Ext1(1), op.Stop()])

        # Must not raise AttributeError: 'Import' object has no attribute 'module'
        for analysis in (UnsafeImportsML(), UnsafeImports()):
            with self.subTest(analysis=type(analysis).__name__):
                outcome = Analyzer([analysis]).analyze(pickled)
                self.assertIsNotNone(outcome)
Loading