From d8bc265f323d6c667e1246f389896f8d26344f8c Mon Sep 17 00:00:00 2001 From: Elijah Ben Izzy Date: Sun, 14 Jun 2026 13:08:41 -0700 Subject: [PATCH] Add optional module allowlist to PickleLoader deserialization PickleLoader has historically called pickle.load on caller-supplied bytes without restriction, mirroring pandas.read_pickle / joblib.load. That puts the safety burden entirely on the caller, which is brittle once pickle files are shared across teams or produced by upstream systems. Introduce an optional allowlist of (module, qualname) pairs, either per-instance via PickleLoader(allowlist=...) or process-wide via set_pickle_loader_allowlist(...). When configured, a RestrictedUnpickler overrides find_class to reject anything off the allowlist, raising UnpicklingError before any global is resolved. When no allowlist is set, behavior is unchanged for backward compatibility but a warning is now emitted on every load to signal the unrestricted mode. Tests cover: the legacy unrestricted-with-warning path, allowlist permitting a legitimate roundtrip, per-instance allowlist blocking a malicious __reduce__ payload, module-level allowlist applying to fresh instances, and per-instance overriding the module-level default. Co-Authored-By: Claude Opus 4.7 (1M context) --- hamilton/io/default_data_loaders.py | 110 +++++++++++++++++++++++++++- tests/io/test_default_adapters.py | 106 ++++++++++++++++++++++++++- 2 files changed, 213 insertions(+), 3 deletions(-) diff --git a/hamilton/io/default_data_loaders.py b/hamilton/io/default_data_loaders.py index 1810e1fc8..00b110cb3 100644 --- a/hamilton/io/default_data_loaders.py +++ b/hamilton/io/default_data_loaders.py @@ -21,12 +21,72 @@ import os import pathlib import pickle +import warnings from collections.abc import Collection -from typing import Any +from typing import Any, Optional from hamilton.io.data_adapters import DataLoader, DataSaver from hamilton.io.utils import get_file_metadata +# Module-level allowlist for `PickleLoader`. When set (non-None), only the +# listed `(module, qualname)` pairs may be reconstructed during unpickling. +# Set via `set_pickle_loader_allowlist([...])` or by passing `allowlist=...` +# to a `PickleLoader` instance directly. See `PickleLoader` for details. +_PICKLE_LOADER_ALLOWLIST: Optional[frozenset[tuple[str, str]]] = None + + +def set_pickle_loader_allowlist( + allowlist: Optional[Collection[tuple[str, str]]], +) -> None: + """Configure the module-level allowlist applied to every `PickleLoader`. + + Each entry is a ``(module, qualname)`` pair, e.g. ``("pandas.core.frame", "DataFrame")``. + When the allowlist is configured, the loader uses a restricted unpickler + that only reconstructs objects whose ``(module, qualname)`` pair is in + the allowlist. Anything else raises ``pickle.UnpicklingError``. + + Pass ``None`` to clear the allowlist and restore the unrestricted (and + noisy: a warning is emitted on each load) default behavior. + + Per-instance allowlists passed to ``PickleLoader(allowlist=...)`` take + precedence over the module-level value. + """ + global _PICKLE_LOADER_ALLOWLIST + if allowlist is None: + _PICKLE_LOADER_ALLOWLIST = None + else: + _PICKLE_LOADER_ALLOWLIST = frozenset((str(m), str(n)) for m, n in allowlist) + + +class _RestrictedUnpickler(pickle.Unpickler): + """`pickle.Unpickler` subclass that rejects classes outside an allowlist. + + The standard library's ``pickle`` module permits arbitrary callables to + be invoked at load time via the ``REDUCE`` opcode (the ``__reduce__`` + protocol). That makes a raw ``pickle.load`` on caller-supplied bytes a + code-execution primitive. This subclass overrides ``find_class`` to + consult an allowlist before resolving a global reference, so payloads + that try to import e.g. ``os.system`` are rejected before any side + effects occur. + """ + + def __init__( + self, + file, + allowlist: frozenset[tuple[str, str]], + ) -> None: + super().__init__(file) + self._allowlist = allowlist + + def find_class(self, module: str, name: str) -> Any: # type: ignore[override] + if (module, name) in self._allowlist: + return super().find_class(module, name) + raise pickle.UnpicklingError( + f"Refusing to load disallowed global '{module}.{name}'. " + "Add it to the allowlist passed to `PickleLoader(allowlist=...)` " + "or via `set_pickle_loader_allowlist(...)` if it is trusted." + ) + @dataclasses.dataclass class JSONDataLoader(DataLoader): @@ -126,7 +186,35 @@ def save_data(self, data: bytes | io.BytesIO) -> dict[str, Any]: @dataclasses.dataclass class PickleLoader(DataLoader): + """Loads a Python object from a pickle file. + + Python's ``pickle`` module is not safe to use on data from untrusted + sources -- a maliciously crafted pickle can execute arbitrary code at + load time via the ``__reduce__`` protocol. Hamilton's default has + historically been to call ``pickle.load`` unrestricted, which mirrors + ``pandas.read_pickle`` and ``joblib.load`` and assumes the caller wrote + (or otherwise trusts) the file. That assumption is brittle once pickle + files are shared across teams, downloaded as artifacts, or produced + upstream by a system the loader does not control. + + To opt into a safer mode, pass ``allowlist`` to this loader or call + :func:`set_pickle_loader_allowlist` to set a process-wide allowlist. + Each entry is a ``(module, qualname)`` pair; the loader will only + reconstruct objects whose qualified name appears in the allowlist and + will raise ``pickle.UnpicklingError`` otherwise. When no allowlist is + configured, the loader still loads the file (for backward compatibility) + but emits a warning to signal that the deserialization is unrestricted. + + :param path: Filesystem path of the pickle file to load. + :param allowlist: Optional collection of ``(module, qualname)`` pairs + permitted during unpickling. If ``None`` (the default), falls back + to the module-level allowlist configured via + :func:`set_pickle_loader_allowlist`; if that is also ``None``, the + load is unrestricted and emits a warning. + """ + path: str + allowlist: Optional[Collection[tuple[str, str]]] = None @classmethod def applicable_types(cls) -> Collection[type]: @@ -136,9 +224,27 @@ def applicable_types(cls) -> Collection[type]: def name(cls) -> str: return "pickle" + def _resolve_allowlist(self) -> Optional[frozenset[tuple[str, str]]]: + if self.allowlist is not None: + return frozenset((str(m), str(n)) for m, n in self.allowlist) + return _PICKLE_LOADER_ALLOWLIST + def load_data(self, type_: type[object]) -> tuple[object, dict[str, Any]]: + allowlist = self._resolve_allowlist() with open(self.path, "rb") as f: - return pickle.load(f), get_file_metadata(self.path) + if allowlist is None: + warnings.warn( + "PickleLoader is loading without an allowlist; pickle " + "deserialization can execute arbitrary code if the file " + "is not fully trusted. Configure an allowlist via " + "`PickleLoader(allowlist=...)` or " + "`hamilton.io.default_data_loaders.set_pickle_loader_allowlist(...)`.", + stacklevel=2, + ) + obj = pickle.load(f) + else: + obj = _RestrictedUnpickler(f, allowlist).load() + return obj, get_file_metadata(self.path) @dataclasses.dataclass diff --git a/tests/io/test_default_adapters.py b/tests/io/test_default_adapters.py index 8a1c30cb8..39294ed8e 100644 --- a/tests/io/test_default_adapters.py +++ b/tests/io/test_default_adapters.py @@ -18,10 +18,18 @@ import io import json import pathlib +import pickle import pytest -from hamilton.io.default_data_loaders import JSONDataLoader, JSONDataSaver, RawFileDataSaverBytes +from hamilton.io.default_data_loaders import ( + JSONDataLoader, + JSONDataSaver, + PickleLoader, + PickleSaver, + RawFileDataSaverBytes, + set_pickle_loader_allowlist, +) from hamilton.io.utils import FILE_METADATA @@ -91,3 +99,99 @@ def test_json_load_object_and_array(data, tmp_path: pathlib.Path): assert JSONDataLoader.applicable_types() == [dict, list] assert data == loaded_data + + +# Tracks side effects from a malicious pickle payload so we can assert the +# restricted unpickler refused to execute it. +_MALICIOUS_PAYLOAD_RAN = [] + + +def _record_side_effect(marker): + _MALICIOUS_PAYLOAD_RAN.append(marker) + return marker + + +class _MaliciousPayload: + """Object whose `__reduce__` triggers an arbitrary call at load time.""" + + def __reduce__(self): + return (_record_side_effect, ("pwned",)) + + +@pytest.fixture(autouse=True) +def _clear_pickle_allowlist(): + # Make sure the module-level allowlist does not leak across tests. + set_pickle_loader_allowlist(None) + _MALICIOUS_PAYLOAD_RAN.clear() + yield + set_pickle_loader_allowlist(None) + _MALICIOUS_PAYLOAD_RAN.clear() + + +def test_pickle_roundtrip_without_allowlist_warns(tmp_path: pathlib.Path): + """Backward-compat path: no allowlist -> load still works but warns.""" + data_path = tmp_path / "obj.pkl" + payload = {"a": 1, "b": [1, 2, 3]} + + PickleSaver(path=str(data_path)).save_data(payload) + loader = PickleLoader(path=str(data_path)) + + with pytest.warns(UserWarning, match="without an allowlist"): + loaded, _ = loader.load_data(dict) + + assert loaded == payload + + +def test_pickle_allowlist_permits_listed_types(tmp_path: pathlib.Path): + """Legitimate object whose module is on the allowlist roundtrips.""" + data_path = tmp_path / "obj.pkl" + payload = {"a": 1, "b": [1, 2, 3]} + + PickleSaver(path=str(data_path)).save_data(payload) + loader = PickleLoader(path=str(data_path), allowlist=[("builtins", "dict")]) + + loaded, _ = loader.load_data(dict) + assert loaded == payload + + +def test_pickle_allowlist_blocks_malicious_reduce(tmp_path: pathlib.Path): + """Malicious __reduce__ payload is rejected before side effects occur.""" + data_path = tmp_path / "evil.pkl" + with open(data_path, "wb") as f: + pickle.dump(_MaliciousPayload(), f) + + # Allowlist that does NOT include the helper used by the payload. + loader = PickleLoader(path=str(data_path), allowlist=[("builtins", "dict")]) + + with pytest.raises(pickle.UnpicklingError, match="Refusing to load"): + loader.load_data(object) + + assert _MALICIOUS_PAYLOAD_RAN == [] # side effect did not run + + +def test_pickle_module_level_allowlist_blocks_malicious(tmp_path: pathlib.Path): + """Module-level `set_pickle_loader_allowlist` applies to fresh instances.""" + data_path = tmp_path / "evil.pkl" + with open(data_path, "wb") as f: + pickle.dump(_MaliciousPayload(), f) + + set_pickle_loader_allowlist([("builtins", "dict")]) + loader = PickleLoader(path=str(data_path)) + + with pytest.raises(pickle.UnpicklingError): + loader.load_data(object) + + assert _MALICIOUS_PAYLOAD_RAN == [] + + +def test_pickle_instance_allowlist_overrides_module_default(tmp_path: pathlib.Path): + """Per-instance allowlist takes precedence over the module-level one.""" + data_path = tmp_path / "obj.pkl" + PickleSaver(path=str(data_path)).save_data({"a": 1}) + + # Module-level allowlist permits nothing useful, but the instance overrides it. + set_pickle_loader_allowlist([("nonexistent", "Thing")]) + loader = PickleLoader(path=str(data_path), allowlist=[("builtins", "dict")]) + + loaded, _ = loader.load_data(dict) + assert loaded == {"a": 1}