diff --git a/hamilton/io/default_data_loaders.py b/hamilton/io/default_data_loaders.py index 1810e1fc8..00b110cb3 100644 --- a/hamilton/io/default_data_loaders.py +++ b/hamilton/io/default_data_loaders.py @@ -21,12 +21,72 @@ import os import pathlib import pickle +import warnings from collections.abc import Collection -from typing import Any +from typing import Any, Optional from hamilton.io.data_adapters import DataLoader, DataSaver from hamilton.io.utils import get_file_metadata +# Module-level allowlist for `PickleLoader`. When set (non-None), only the +# listed `(module, qualname)` pairs may be reconstructed during unpickling. +# Set via `set_pickle_loader_allowlist([...])` or by passing `allowlist=...` +# to a `PickleLoader` instance directly. See `PickleLoader` for details. +_PICKLE_LOADER_ALLOWLIST: Optional[frozenset[tuple[str, str]]] = None + + +def set_pickle_loader_allowlist( + allowlist: Optional[Collection[tuple[str, str]]], +) -> None: + """Configure the module-level allowlist applied to every `PickleLoader`. + + Each entry is a ``(module, qualname)`` pair, e.g. ``("pandas.core.frame", "DataFrame")``. + When the allowlist is configured, the loader uses a restricted unpickler + that only reconstructs objects whose ``(module, qualname)`` pair is in + the allowlist. Anything else raises ``pickle.UnpicklingError``. + + Pass ``None`` to clear the allowlist and restore the unrestricted (and + noisy: a warning is emitted on each load) default behavior. + + Per-instance allowlists passed to ``PickleLoader(allowlist=...)`` take + precedence over the module-level value. + """ + global _PICKLE_LOADER_ALLOWLIST + if allowlist is None: + _PICKLE_LOADER_ALLOWLIST = None + else: + _PICKLE_LOADER_ALLOWLIST = frozenset((str(m), str(n)) for m, n in allowlist) + + +class _RestrictedUnpickler(pickle.Unpickler): + """`pickle.Unpickler` subclass that rejects classes outside an allowlist. + + The standard library's ``pickle`` module permits arbitrary callables to + be invoked at load time via the ``REDUCE`` opcode (the ``__reduce__`` + protocol). That makes a raw ``pickle.load`` on caller-supplied bytes a + code-execution primitive. This subclass overrides ``find_class`` to + consult an allowlist before resolving a global reference, so payloads + that try to import e.g. ``os.system`` are rejected before any side + effects occur. + """ + + def __init__( + self, + file, + allowlist: frozenset[tuple[str, str]], + ) -> None: + super().__init__(file) + self._allowlist = allowlist + + def find_class(self, module: str, name: str) -> Any: # type: ignore[override] + if (module, name) in self._allowlist: + return super().find_class(module, name) + raise pickle.UnpicklingError( + f"Refusing to load disallowed global '{module}.{name}'. " + "Add it to the allowlist passed to `PickleLoader(allowlist=...)` " + "or via `set_pickle_loader_allowlist(...)` if it is trusted." + ) + @dataclasses.dataclass class JSONDataLoader(DataLoader): @@ -126,7 +186,35 @@ def save_data(self, data: bytes | io.BytesIO) -> dict[str, Any]: @dataclasses.dataclass class PickleLoader(DataLoader): + """Loads a Python object from a pickle file. + + Python's ``pickle`` module is not safe to use on data from untrusted + sources -- a maliciously crafted pickle can execute arbitrary code at + load time via the ``__reduce__`` protocol. Hamilton's default has + historically been to call ``pickle.load`` unrestricted, which mirrors + ``pandas.read_pickle`` and ``joblib.load`` and assumes the caller wrote + (or otherwise trusts) the file. That assumption is brittle once pickle + files are shared across teams, downloaded as artifacts, or produced + upstream by a system the loader does not control. + + To opt into a safer mode, pass ``allowlist`` to this loader or call + :func:`set_pickle_loader_allowlist` to set a process-wide allowlist. + Each entry is a ``(module, qualname)`` pair; the loader will only + reconstruct objects whose qualified name appears in the allowlist and + will raise ``pickle.UnpicklingError`` otherwise. When no allowlist is + configured, the loader still loads the file (for backward compatibility) + but emits a warning to signal that the deserialization is unrestricted. + + :param path: Filesystem path of the pickle file to load. + :param allowlist: Optional collection of ``(module, qualname)`` pairs + permitted during unpickling. If ``None`` (the default), falls back + to the module-level allowlist configured via + :func:`set_pickle_loader_allowlist`; if that is also ``None``, the + load is unrestricted and emits a warning. + """ + path: str + allowlist: Optional[Collection[tuple[str, str]]] = None @classmethod def applicable_types(cls) -> Collection[type]: @@ -136,9 +224,27 @@ def applicable_types(cls) -> Collection[type]: def name(cls) -> str: return "pickle" + def _resolve_allowlist(self) -> Optional[frozenset[tuple[str, str]]]: + if self.allowlist is not None: + return frozenset((str(m), str(n)) for m, n in self.allowlist) + return _PICKLE_LOADER_ALLOWLIST + def load_data(self, type_: type[object]) -> tuple[object, dict[str, Any]]: + allowlist = self._resolve_allowlist() with open(self.path, "rb") as f: - return pickle.load(f), get_file_metadata(self.path) + if allowlist is None: + warnings.warn( + "PickleLoader is loading without an allowlist; pickle " + "deserialization can execute arbitrary code if the file " + "is not fully trusted. Configure an allowlist via " + "`PickleLoader(allowlist=...)` or " + "`hamilton.io.default_data_loaders.set_pickle_loader_allowlist(...)`.", + stacklevel=2, + ) + obj = pickle.load(f) + else: + obj = _RestrictedUnpickler(f, allowlist).load() + return obj, get_file_metadata(self.path) @dataclasses.dataclass diff --git a/tests/io/test_default_adapters.py b/tests/io/test_default_adapters.py index 8a1c30cb8..39294ed8e 100644 --- a/tests/io/test_default_adapters.py +++ b/tests/io/test_default_adapters.py @@ -18,10 +18,18 @@ import io import json import pathlib +import pickle import pytest -from hamilton.io.default_data_loaders import JSONDataLoader, JSONDataSaver, RawFileDataSaverBytes +from hamilton.io.default_data_loaders import ( + JSONDataLoader, + JSONDataSaver, + PickleLoader, + PickleSaver, + RawFileDataSaverBytes, + set_pickle_loader_allowlist, +) from hamilton.io.utils import FILE_METADATA @@ -91,3 +99,99 @@ def test_json_load_object_and_array(data, tmp_path: pathlib.Path): assert JSONDataLoader.applicable_types() == [dict, list] assert data == loaded_data + + +# Tracks side effects from a malicious pickle payload so we can assert the +# restricted unpickler refused to execute it. +_MALICIOUS_PAYLOAD_RAN = [] + + +def _record_side_effect(marker): + _MALICIOUS_PAYLOAD_RAN.append(marker) + return marker + + +class _MaliciousPayload: + """Object whose `__reduce__` triggers an arbitrary call at load time.""" + + def __reduce__(self): + return (_record_side_effect, ("pwned",)) + + +@pytest.fixture(autouse=True) +def _clear_pickle_allowlist(): + # Make sure the module-level allowlist does not leak across tests. + set_pickle_loader_allowlist(None) + _MALICIOUS_PAYLOAD_RAN.clear() + yield + set_pickle_loader_allowlist(None) + _MALICIOUS_PAYLOAD_RAN.clear() + + +def test_pickle_roundtrip_without_allowlist_warns(tmp_path: pathlib.Path): + """Backward-compat path: no allowlist -> load still works but warns.""" + data_path = tmp_path / "obj.pkl" + payload = {"a": 1, "b": [1, 2, 3]} + + PickleSaver(path=str(data_path)).save_data(payload) + loader = PickleLoader(path=str(data_path)) + + with pytest.warns(UserWarning, match="without an allowlist"): + loaded, _ = loader.load_data(dict) + + assert loaded == payload + + +def test_pickle_allowlist_permits_listed_types(tmp_path: pathlib.Path): + """Legitimate object whose module is on the allowlist roundtrips.""" + data_path = tmp_path / "obj.pkl" + payload = {"a": 1, "b": [1, 2, 3]} + + PickleSaver(path=str(data_path)).save_data(payload) + loader = PickleLoader(path=str(data_path), allowlist=[("builtins", "dict")]) + + loaded, _ = loader.load_data(dict) + assert loaded == payload + + +def test_pickle_allowlist_blocks_malicious_reduce(tmp_path: pathlib.Path): + """Malicious __reduce__ payload is rejected before side effects occur.""" + data_path = tmp_path / "evil.pkl" + with open(data_path, "wb") as f: + pickle.dump(_MaliciousPayload(), f) + + # Allowlist that does NOT include the helper used by the payload. + loader = PickleLoader(path=str(data_path), allowlist=[("builtins", "dict")]) + + with pytest.raises(pickle.UnpicklingError, match="Refusing to load"): + loader.load_data(object) + + assert _MALICIOUS_PAYLOAD_RAN == [] # side effect did not run + + +def test_pickle_module_level_allowlist_blocks_malicious(tmp_path: pathlib.Path): + """Module-level `set_pickle_loader_allowlist` applies to fresh instances.""" + data_path = tmp_path / "evil.pkl" + with open(data_path, "wb") as f: + pickle.dump(_MaliciousPayload(), f) + + set_pickle_loader_allowlist([("builtins", "dict")]) + loader = PickleLoader(path=str(data_path)) + + with pytest.raises(pickle.UnpicklingError): + loader.load_data(object) + + assert _MALICIOUS_PAYLOAD_RAN == [] + + +def test_pickle_instance_allowlist_overrides_module_default(tmp_path: pathlib.Path): + """Per-instance allowlist takes precedence over the module-level one.""" + data_path = tmp_path / "obj.pkl" + PickleSaver(path=str(data_path)).save_data({"a": 1}) + + # Module-level allowlist permits nothing useful, but the instance overrides it. + set_pickle_loader_allowlist([("nonexistent", "Thing")]) + loader = PickleLoader(path=str(data_path), allowlist=[("builtins", "dict")]) + + loaded, _ = loader.load_data(dict) + assert loaded == {"a": 1}