Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions hamilton/io/default_data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,72 @@
import os
import pathlib
import pickle
import warnings
from collections.abc import Collection
from typing import Any
from typing import Any, Optional

from hamilton.io.data_adapters import DataLoader, DataSaver
from hamilton.io.utils import get_file_metadata

# Process-wide default allowlist for `PickleLoader`.
#
# `None` means "no allowlist configured". Otherwise this holds the
# `(module, qualname)` pairs that may be reconstructed during unpickling.
# It is replaced wholesale by `set_pickle_loader_allowlist([...])`; a
# `PickleLoader` instance given its own `allowlist=...` uses that instead
# of this value.
_PICKLE_LOADER_ALLOWLIST: Optional[frozenset[tuple[str, str]]] = None


def set_pickle_loader_allowlist(
    allowlist: Optional[Collection[tuple[str, str]]],
) -> None:
    """Install (or clear) the process-wide allowlist used by `PickleLoader`.

    The allowlist is a collection of ``(module, qualname)`` pairs such as
    ``("pandas.core.frame", "DataFrame")``. Both halves of every pair are
    coerced to ``str`` and the result is stored as a ``frozenset`` in the
    module-level ``_PICKLE_LOADER_ALLOWLIST``, replacing any previous value.

    While an allowlist is installed, loaders without their own ``allowlist``
    use a restricted unpickler that raises ``pickle.UnpicklingError`` for any
    global not listed. An allowlist passed directly to
    ``PickleLoader(allowlist=...)`` takes precedence over this value.

    :param allowlist: Pairs to permit, or ``None`` to remove the allowlist and
        return to the unrestricted default, which warns on load.
    """
    global _PICKLE_LOADER_ALLOWLIST
    # The right-hand side is fully evaluated before the rebind, so a malformed
    # entry raises without disturbing the currently installed allowlist.
    _PICKLE_LOADER_ALLOWLIST = (
        None
        if allowlist is None
        else frozenset((str(module), str(qualname)) for module, qualname in allowlist)
    )


class _RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals named in an allowlist.

    A plain ``pickle.load`` lets the stream name any importable callable and
    have it invoked at load time (the ``__reduce__`` protocol / ``REDUCE``
    opcode), so loading caller-supplied bytes amounts to running
    caller-supplied code. Global references are resolved through
    ``find_class``; this subclass checks each ``(module, name)`` request
    against the allowlist first and refuses anything unlisted, so a payload
    referencing e.g. ``os.system`` fails before the callable is even imported.
    """

    def __init__(
        self,
        file,
        allowlist: frozenset[tuple[str, str]],
    ) -> None:
        """Wrap ``file`` and remember the permitted ``(module, name)`` pairs.

        :param file: Binary file-like object to read the pickle stream from.
        :param allowlist: Exact ``(module, name)`` pairs that may be resolved.
        """
        super().__init__(file)
        self._allowlist = allowlist

    def find_class(self, module: str, name: str) -> Any:  # type: ignore[override]
        """Resolve ``module.name`` only if the pair is allowlisted.

        :raises pickle.UnpicklingError: If ``(module, name)`` is not listed.
        """
        requested = (module, name)
        if requested not in self._allowlist:
            message = (
                f"Refusing to load disallowed global '{module}.{name}'. "
                "Add it to the allowlist passed to `PickleLoader(allowlist=...)` "
                "or via `set_pickle_loader_allowlist(...)` if it is trusted."
            )
            raise pickle.UnpicklingError(message)
        return super().find_class(module, name)


@dataclasses.dataclass
class JSONDataLoader(DataLoader):
Expand Down Expand Up @@ -126,7 +186,35 @@ def save_data(self, data: bytes | io.BytesIO) -> dict[str, Any]:

@dataclasses.dataclass
class PickleLoader(DataLoader):
"""Loads a Python object from a pickle file.

Python's ``pickle`` module is not safe to use on data from untrusted
sources -- a maliciously crafted pickle can execute arbitrary code at
load time via the ``__reduce__`` protocol. Hamilton's default has
historically been to call ``pickle.load`` unrestricted, which mirrors
``pandas.read_pickle`` and ``joblib.load`` and assumes the caller wrote
(or otherwise trusts) the file. That assumption is brittle once pickle
files are shared across teams, downloaded as artifacts, or produced
upstream by a system the loader does not control.

To opt into a safer mode, pass ``allowlist`` to this loader or call
:func:`set_pickle_loader_allowlist` to set a process-wide allowlist.
Each entry is a ``(module, qualname)`` pair; the loader will only
reconstruct objects whose qualified name appears in the allowlist and
will raise ``pickle.UnpicklingError`` otherwise. When no allowlist is
configured, the loader still loads the file (for backward compatibility)
but emits a warning to signal that the deserialization is unrestricted.

:param path: Filesystem path of the pickle file to load.
:param allowlist: Optional collection of ``(module, qualname)`` pairs
permitted during unpickling. If ``None`` (the default), falls back
to the module-level allowlist configured via
:func:`set_pickle_loader_allowlist`; if that is also ``None``, the
load is unrestricted and emits a warning.
"""

path: str
allowlist: Optional[Collection[tuple[str, str]]] = None

@classmethod
def applicable_types(cls) -> Collection[type]:
Expand All @@ -136,9 +224,27 @@ def applicable_types(cls) -> Collection[type]:
def name(cls) -> str:
return "pickle"

def _resolve_allowlist(self) -> Optional[frozenset[tuple[str, str]]]:
    """Return the allowlist that applies to this loader, or ``None``.

    The instance's own ``allowlist`` wins when it is set (even if empty); it is
    normalized to a ``frozenset`` of ``(str, str)`` pairs. Otherwise the
    module-level ``_PICKLE_LOADER_ALLOWLIST`` is returned as-is, which may be
    ``None`` when nothing has been configured.
    """
    own_allowlist = self.allowlist
    if own_allowlist is None:
        return _PICKLE_LOADER_ALLOWLIST
    return frozenset((str(module), str(qualname)) for module, qualname in own_allowlist)

def load_data(self, type_: type[object]) -> tuple[object, dict[str, Any]]:
    """Unpickle the object stored at ``self.path`` and return it with file metadata.

    If an allowlist applies (per-instance, else module-level), the file is read
    through ``_RestrictedUnpickler`` so only allowlisted globals can be
    reconstructed. With no allowlist the file is loaded with plain
    ``pickle.load`` for backward compatibility, and a warning is emitted
    because that path can execute arbitrary code from the file.

    :param type_: Requested type; not consulted by this loader.
    :return: ``(loaded_object, metadata)`` where ``metadata`` comes from
        ``get_file_metadata(self.path)``.
    :raises pickle.UnpicklingError: When an allowlist is active and the file
        references a global that is not on it.
    """
    allowlist = self._resolve_allowlist()
    with open(self.path, "rb") as f:
        if allowlist is None:
            # SECURITY: unrestricted pickle.load on purpose (legacy default).
            # Only safe for files the caller fully trusts.
            warnings.warn(
                "PickleLoader is loading without an allowlist; pickle "
                "deserialization can execute arbitrary code if the file "
                "is not fully trusted. Configure an allowlist via "
                "`PickleLoader(allowlist=...)` or "
                "`hamilton.io.default_data_loaders.set_pickle_loader_allowlist(...)`.",
                stacklevel=2,
            )
            obj = pickle.load(f)
        else:
            obj = _RestrictedUnpickler(f, allowlist).load()
    # Metadata is gathered after the file handle is closed.
    return obj, get_file_metadata(self.path)


@dataclasses.dataclass
Expand Down
106 changes: 105 additions & 1 deletion tests/io/test_default_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,18 @@
import io
import json
import pathlib
import pickle

import pytest

from hamilton.io.default_data_loaders import JSONDataLoader, JSONDataSaver, RawFileDataSaverBytes
from hamilton.io.default_data_loaders import (
JSONDataLoader,
JSONDataSaver,
PickleLoader,
PickleSaver,
RawFileDataSaverBytes,
set_pickle_loader_allowlist,
)
from hamilton.io.utils import FILE_METADATA


Expand Down Expand Up @@ -91,3 +99,99 @@ def test_json_load_object_and_array(data, tmp_path: pathlib.Path):

assert JSONDataLoader.applicable_types() == [dict, list]
assert data == loaded_data


# Records every marker passed to `_record_side_effect`. If this list is
# non-empty after a load attempt, the callable embedded in a pickled payload
# actually ran, i.e. the restricted unpickler did not refuse it.
_MALICIOUS_PAYLOAD_RAN = []


def _record_side_effect(marker):
    """Note that a payload callable ran by storing ``marker``, then return it."""
    _MALICIOUS_PAYLOAD_RAN.extend([marker])
    return marker


class _MaliciousPayload:
    """Stand-in for a hostile pickle: unpickling it calls `_record_side_effect`."""

    def __reduce__(self):
        # pickle stores (callable, args) and invokes callable(*args) on load.
        load_time_callable = _record_side_effect
        load_time_args = ("pwned",)
        return (load_time_callable, load_time_args)


@pytest.fixture(autouse=True)
def _clear_pickle_allowlist():
    """Reset shared pickle-test state before and after every test.

    Clears the module-level allowlist and the side-effect log so neither can
    leak from one test into the next.
    """

    def _reset_shared_state():
        set_pickle_loader_allowlist(None)
        _MALICIOUS_PAYLOAD_RAN.clear()

    _reset_shared_state()
    yield
    _reset_shared_state()


def test_pickle_roundtrip_without_allowlist_warns(tmp_path: pathlib.Path):
    """With no allowlist anywhere, loading still succeeds but emits a warning."""
    expected = {"a": 1, "b": [1, 2, 3]}
    pickle_file = str(tmp_path / "obj.pkl")

    PickleSaver(path=pickle_file).save_data(expected)
    unrestricted_loader = PickleLoader(path=pickle_file)

    with pytest.warns(UserWarning, match="without an allowlist"):
        result = unrestricted_loader.load_data(dict)

    # load_data returns (object, metadata); only the object matters here.
    assert result[0] == expected


def test_pickle_allowlist_permits_listed_types(tmp_path: pathlib.Path):
    """An object whose class is on the allowlist roundtrips via the restricted path.

    The payload is a ``collections.OrderedDict`` on purpose. Plain builtin
    containers and scalars (``dict``, ``list``, ``int``, ``str``) are rebuilt
    from pickle opcodes without calling ``Unpickler.find_class``, so a plain
    dict would load under *any* allowlist and the test could never fail.
    ``OrderedDict`` is pickled as a global reference, so the allowlist lookup
    is actually exercised.
    """
    import collections  # local import: only needed by this test

    data_path = tmp_path / "obj.pkl"
    payload = collections.OrderedDict(a=1, b=[1, 2, 3])

    PickleSaver(path=str(data_path)).save_data(payload)
    loader = PickleLoader(
        path=str(data_path),
        allowlist=[("collections", "OrderedDict")],
    )

    loaded, _ = loader.load_data(dict)
    assert isinstance(loaded, collections.OrderedDict)
    assert loaded == payload


def test_pickle_allowlist_blocks_malicious_reduce(tmp_path: pathlib.Path):
    """A hostile ``__reduce__`` payload is refused and its callable never runs."""
    evil_file = tmp_path / "evil.pkl"
    evil_file.write_bytes(pickle.dumps(_MaliciousPayload()))

    # The payload's helper is deliberately absent from this allowlist.
    restricted_loader = PickleLoader(
        path=str(evil_file),
        allowlist=[("builtins", "dict")],
    )

    with pytest.raises(pickle.UnpicklingError, match="Refusing to load"):
        restricted_loader.load_data(object)

    # Nothing was recorded, so the embedded callable was never invoked.
    assert _MALICIOUS_PAYLOAD_RAN == []


def test_pickle_module_level_allowlist_blocks_malicious(tmp_path: pathlib.Path):
    """A loader created with no allowlist of its own obeys the module-level one."""
    evil_file = tmp_path / "evil.pkl"
    evil_file.write_bytes(pickle.dumps(_MaliciousPayload()))

    set_pickle_loader_allowlist([("builtins", "dict")])
    default_loader = PickleLoader(path=str(evil_file))

    with pytest.raises(pickle.UnpicklingError):
        default_loader.load_data(object)

    # The payload's callable must not have been executed.
    assert _MALICIOUS_PAYLOAD_RAN == []


def test_pickle_instance_allowlist_overrides_module_default(tmp_path: pathlib.Path):
    """Per-instance allowlist takes precedence over the module-level one.

    The payload is a ``collections.OrderedDict`` because it is pickled as a
    global reference and therefore goes through the allowlist check. A plain
    dict is rebuilt without consulting ``find_class`` and would load under the
    module-level allowlist too, so it could not detect a precedence bug.
    """
    import collections  # local import: only needed by this test

    data_path = tmp_path / "obj.pkl"
    payload = collections.OrderedDict(a=1)
    PickleSaver(path=str(data_path)).save_data(payload)

    # The module-level allowlist would reject OrderedDict; the instance-level
    # allowlist permits it, so a successful load proves the instance value won.
    set_pickle_loader_allowlist([("nonexistent", "Thing")])
    loader = PickleLoader(
        path=str(data_path),
        allowlist=[("collections", "OrderedDict")],
    )

    loaded, _ = loader.load_data(dict)
    assert isinstance(loaded, collections.OrderedDict)
    assert loaded == payload
Loading