Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 155 additions & 2 deletions burr/integrations/serde/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,94 @@
# Pickle serde registration
# This is not automatically registered because we want to register
# it based on class type.
import io
import pickle
import warnings
from typing import Iterable, List, Optional, Tuple

from burr.core import serde

# Type alias: an allowlist entry is ``(module, qualname)``.
# ``qualname`` may be ``"*"`` to permit any class in the given module.
AllowlistEntry = Tuple[str, str]

def register_type_to_pickle(cls):

class SecurityWarning(Warning):
"""Warning issued when pickle deserialization proceeds without an allowlist."""


# Global allowlist for pickle deserialization. When set, only
# ``(module, qualname)`` pairs in this list are permitted to be
# reconstructed during ``pickle.loads``.
_global_allowlist: Optional[List[AllowlistEntry]] = None


def set_pickle_serde_allowlist(allowlist: Optional[Iterable[AllowlistEntry]]) -> None:
"""Set a process-wide allowlist of ``(module, qualname)`` pairs that may be
reconstructed by ``deserialize_pickle``.

Each entry is a tuple ``(module_name, class_name)``. Use ``"*"`` as the class
name to permit any class in that module:

.. code-block:: python

from burr.integrations.serde import pickle as burr_pickle

burr_pickle.set_pickle_serde_allowlist([
("myapp.models", "User"),
("myapp.models", "Address"),
("myapp.types", "*"), # any class in myapp.types
])

When the allowlist is set, any attempt to deserialize a class outside the
allowlist raises ``pickle.UnpicklingError``. Pass ``None`` to clear it.

:param allowlist: Iterable of ``(module, qualname)`` tuples, or ``None``.
"""
global _global_allowlist
_global_allowlist = list(allowlist) if allowlist is not None else None


def _is_allowed(module: str, name: str, allowlist: Iterable[AllowlistEntry]) -> bool:
for allowed_module, allowed_name in allowlist:
if module == allowed_module and (allowed_name == "*" or allowed_name == name):
return True
return False


class _RestrictedUnpickler(pickle.Unpickler):
"""Pickle unpickler that only resolves classes present in an allowlist.

See https://docs.python.org/3/library/pickle.html#restricting-globals
for the standard library guidance this is modeled on.
"""

def __init__(self, file, allowlist: Iterable[AllowlistEntry], **kwargs):
super().__init__(file, **kwargs)
self._allowlist = list(allowlist)

def find_class(self, module: str, name: str):
if not _is_allowed(module, name, self._allowlist):
raise pickle.UnpicklingError(
f"Refusing to load pickled class '{module}.{name}': not in the "
f"pickle serde allowlist. Add it via the ``allowlist`` argument "
f"to register_type_to_pickle()/deserialize, or call "
f"burr.integrations.serde.pickle.set_pickle_serde_allowlist([...])."
)
return super().find_class(module, name)


def _restricted_loads(data: bytes, allowlist: Iterable[AllowlistEntry], **kwargs):
return _RestrictedUnpickler(io.BytesIO(data), allowlist, **kwargs).load()


# Tracks call sites that have already received a "no allowlist" warning, so the
# warning is emitted at most once per registration site rather than once per
# deserialization (which could be noisy on hot paths).
_warned_call_sites: set = set()


def register_type_to_pickle(cls, allowlist: Optional[Iterable[AllowlistEntry]] = None):
"""Register a class to be serialized/deserialized using pickle.

Note: `pickle_kwargs` are passed to the pickle.dumps and pickle.loads functions.
Expand All @@ -40,9 +122,40 @@ def __init__(self, name, email):
from burr.integrations.serde import pickle
pickle.register_type_to_pickle(User) # this will register the User class to be serialized/deserialized using pickle.

Trust model
-----------
Pickle is, by design, capable of executing arbitrary code during
deserialization. If the persistence backend (SQLite file, Redis, S3, the
local filesystem, etc.) can be written to by an untrusted party, a tampered
payload can trigger remote code execution when burr restores application
state.

To mitigate this, pass an ``allowlist`` of permitted ``(module, qualname)``
pairs. When set, deserialization will refuse to import any class outside
that list and raise ``pickle.UnpicklingError`` instead. You can also set a
process-wide default via
:func:`set_pickle_serde_allowlist`. If no allowlist is configured the legacy
behavior is preserved but a :class:`SecurityWarning` is emitted once per
call site.

.. code-block:: python

pickle.register_type_to_pickle(
User,
allowlist=[("myapp.models", "User")],
)

:param cls: The class to register
:param allowlist: Optional iterable of ``(module, qualname)`` pairs that
deserialization is permitted to import. ``"*"`` may be used as the
qualname to allow any class in a given module. If ``None``, falls back
to the global allowlist set via
:func:`set_pickle_serde_allowlist`, and if that is also ``None`` the
legacy unrestricted behavior is used with a :class:`SecurityWarning`.
"""
local_allowlist: Optional[List[AllowlistEntry]] = (
list(allowlist) if allowlist is not None else None
)

@serde.serialize.register(cls)
def serialize_pickle(value: cls, pickle_kwargs: dict = None, **kwargs) -> dict:
Expand All @@ -61,14 +174,54 @@ def serialize_pickle(value: cls, pickle_kwargs: dict = None, **kwargs) -> dict:
}

@serde.deserializer.register("pickle")
def deserialize_pickle(value: dict, pickle_kwargs: dict = None, **kwargs) -> cls:
def deserialize_pickle(
value: dict,
pickle_kwargs: dict = None,
allowlist: Optional[Iterable[AllowlistEntry]] = None,
**kwargs,
) -> cls:
"""Deserializes the value using pickle.

Resolution order for the effective allowlist:

1. The ``allowlist`` kwarg passed into deserialize (e.g. via
``State.deserialize(..., allowlist=...)``).
2. The ``allowlist`` argument passed to ``register_type_to_pickle``.
3. The process-wide default from
:func:`set_pickle_serde_allowlist`.
4. None — falls back to the legacy unrestricted ``pickle.loads`` path
and emits a :class:`SecurityWarning`.

:param value: the value to deserialize from.
:param pickle_kwargs: note required. Optional.
:param allowlist: per-call allowlist override, see above.
:param kwargs:
:return: object of type cls
"""
if pickle_kwargs is None:
pickle_kwargs = {}
effective_allowlist: Optional[List[AllowlistEntry]] = None
if allowlist is not None:
effective_allowlist = list(allowlist)
elif local_allowlist is not None:
effective_allowlist = local_allowlist
elif _global_allowlist is not None:
effective_allowlist = _global_allowlist

if effective_allowlist is not None:
return _restricted_loads(value["value"], effective_allowlist, **pickle_kwargs)

# Legacy path — warn once per registration site.
site_key = (cls.__module__, cls.__qualname__)
if site_key not in _warned_call_sites:
_warned_call_sites.add(site_key)
warnings.warn(
f"Deserializing pickled class '{cls.__module__}.{cls.__qualname__}' "
"without an allowlist. This is a remote-code-execution risk if the "
"persistence backend is writable by untrusted parties. Pass "
"allowlist=[(module, qualname), ...] to register_type_to_pickle() "
"or call burr.integrations.serde.pickle.set_pickle_serde_allowlist([...]).",
SecurityWarning,
stacklevel=2,
)
return pickle.loads(value["value"], **pickle_kwargs)
180 changes: 170 additions & 10 deletions tests/integrations/serde/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
# specific language governing permissions and limitations
# under the License.

import pickle as _pickle
import warnings

import pytest

from burr.core import serde, state
from burr.integrations.serde import pickle
from burr.integrations.serde.pickle import SecurityWarning, _is_allowed, set_pickle_serde_allowlist


class User:
Expand All @@ -25,22 +31,176 @@ def __init__(self, name, email):
self.email = email


class Address:
def __init__(self, city):
self.city = city


class _ReduceCanary:
"""When unpickled, would call ``int("1")`` instead of doing the normal
object reconstruction. We use this to exercise the ``find_class`` hook of
the restricted unpickler without invoking anything actually dangerous.
"""

def __reduce__(self):
return (int, ("1",))


def test_serde_of_pickle_object():
pickle.register_type_to_pickle(User)
user = User(name="John Doe", email="john.doe@example.com")
og = state.State({"user": user, "test": "test"})
serialized = og.serialize()
assert serialized == {
with warnings.catch_warnings():
warnings.simplefilter("ignore", SecurityWarning)
serialized = og.serialize()
ng = state.State.deserialize(serialized)
assert serialized["user"][serde.KEY] == "pickle"
assert isinstance(ng["user"], User)
assert ng["user"].name == "John Doe"
assert ng["user"].email == "john.doe@example.com"


def test_is_allowed_logic():
assert _is_allowed("m", "C", [("m", "C")]) is True
assert _is_allowed("m", "C", [("m", "*")]) is True
assert _is_allowed("m", "C", [("other", "*")]) is False
assert _is_allowed("m", "C", [("m", "D")]) is False
assert _is_allowed("m", "C", []) is False


def test_malicious_reduce_payload_blocked_with_allowlist():
"""A pickle payload referencing a class outside the allowlist must be
refused at find_class() time, before any object is reconstructed."""
pickle.register_type_to_pickle(User, allowlist=[(__name__, "User")])

malicious_bytes = _pickle.dumps(_ReduceCanary())
tampered = {
"user": {
serde.KEY: "pickle",
"value": b"\x80\x04\x95Q\x00\x00\x00\x00\x00\x00\x00\x8c\x0btest_pi"
b"ckle\x94\x8c\x04User\x94\x93\x94)\x81\x94}\x94(\x8c\x04na"
b"me\x94\x8c\x08John Doe\x94\x8c\x05email\x94\x8c\x14john"
b".doe@example.com\x94ub.",
},
"test": "test",
"value": malicious_bytes,
}
}
with pytest.raises(ValueError, match="not in the pickle serde allowlist"):
state.State.deserialize(tampered)


def test_legitimate_object_roundtrips_when_module_on_allowlist():
pickle.register_type_to_pickle(User, allowlist=[(__name__, "User")])
user = User(name="Alice", email="a@example.com")
og = state.State({"user": user})
serialized = og.serialize()
ng = state.State.deserialize(serialized)
assert isinstance(ng["user"], User)
assert ng["user"].name == "John Doe"
assert ng["user"].email == "john.doe@example.com"
assert ng["user"].name == "Alice"


def test_module_level_allowlist_applies_to_fresh_registration():
"""A class registered after set_pickle_serde_allowlist() should pick up the
process-wide default."""
set_pickle_serde_allowlist([(__name__, "*")])
try:
pickle.register_type_to_pickle(Address)
addr = Address(city="Berlin")
og = state.State({"addr": addr})
serialized = og.serialize()
ng = state.State.deserialize(serialized)
assert isinstance(ng["addr"], Address)
assert ng["addr"].city == "Berlin"

# And a tampered payload referencing a class outside the module is
# still refused under the global allowlist.
tampered = {
"addr": {
serde.KEY: "pickle",
"value": _pickle.dumps(_ReduceCanary()),
}
}
with pytest.raises(ValueError, match="not in the pickle serde allowlist"):
state.State.deserialize(tampered)
finally:
set_pickle_serde_allowlist(None)


def test_instance_allowlist_overrides_module_default():
"""A per-call allowlist passed to register_type_to_pickle() should take
precedence over the process-wide default."""
# Global default permits everything in this module.
set_pickle_serde_allowlist([(__name__, "*")])
try:
# But the per-registration allowlist only permits User. So a payload
# referencing _ReduceCanary (also in this module) should be blocked.
pickle.register_type_to_pickle(User, allowlist=[(__name__, "User")])

tampered = {
"user": {
serde.KEY: "pickle",
"value": _pickle.dumps(_ReduceCanary()),
}
}
with pytest.raises(ValueError, match="not in the pickle serde allowlist"):
state.State.deserialize(tampered)

# And legitimate User payloads still work under the stricter local
# allowlist.
u = User(name="Bob", email="b@example.com")
og = state.State({"user": u})
ng = state.State.deserialize(og.serialize())
assert isinstance(ng["user"], User)
finally:
set_pickle_serde_allowlist(None)


def test_per_call_allowlist_overrides_registration_and_global():
"""The allowlist kwarg passed at deserialize time should take precedence
over both the registration-time allowlist and the global default."""
set_pickle_serde_allowlist([(__name__, "User")])
try:
pickle.register_type_to_pickle(User, allowlist=[(__name__, "User")])
# Legitimate User payload
u = User(name="Carol", email="c@example.com")
serialized = state.State({"user": u}).serialize()

# Pass an allowlist that does NOT include User; deserialization should
# fail even though both the registration and global allowlists permit
# it.
with pytest.raises(ValueError, match="not in the pickle serde allowlist"):
state.State.deserialize(serialized, allowlist=[("nothing", "Nothing")])

# And allowlist that includes User succeeds.
ng = state.State.deserialize(serialized, allowlist=[(__name__, "User")])
assert isinstance(ng["user"], User)
finally:
set_pickle_serde_allowlist(None)


def test_legacy_path_emits_security_warning_once_per_site():
"""When no allowlist is configured anywhere, deserialization should still
work (backward compatibility) but emit a SecurityWarning. The warning
should fire at most once per registration site."""
# Reset the dedup set so the warning will fire for this test, then put it
# back so we don't disturb other tests.
from burr.integrations.serde.pickle import _warned_call_sites

saved = set(_warned_call_sites)
_warned_call_sites.clear()
try:
set_pickle_serde_allowlist(None)
pickle.register_type_to_pickle(User)
u = User(name="Dana", email="d@example.com")
serialized = state.State({"user": u}).serialize()

with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
ng1 = state.State.deserialize(serialized)
ng2 = state.State.deserialize(serialized)
sec_warnings = [w for w in caught if issubclass(w.category, SecurityWarning)]
# Exactly one SecurityWarning across the two deserialize calls.
assert len(sec_warnings) == 1, (
f"expected 1 SecurityWarning, got {len(sec_warnings)}: "
f"{[str(w.message) for w in sec_warnings]}"
)
assert isinstance(ng1["user"], User)
assert isinstance(ng2["user"], User)
finally:
_warned_call_sites.clear()
_warned_call_sites.update(saved)
Loading