Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,17 +1231,26 @@ class NumpyReader(ImageReader):
npz_keys: if loading npz file, only load the specified keys, if None, load all the items.
stack the loaded items together to construct a new first dimension.
channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel.
allow_pickle: if True, allows loading pickled contents from NPY/NPZ files. Note that the default value of False
prevents the risk of remote code execution, set this to True only for loading known trusted data. If this
argument is False and pickled data is loaded, a ValueError will be raised.
kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args:
https://numpy.org/doc/stable/reference/generated/numpy.load.html

"""

def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs):
def __init__(
self,
npz_keys: KeysCollection | None = None,
channel_dim: str | int | None = None,
allow_pickle: bool = False,
**kwargs,
):
super().__init__()
if npz_keys is not None:
npz_keys = ensure_tuple(npz_keys)
self.npz_keys = npz_keys
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.allow_pickle = allow_pickle
self.kwargs = kwargs

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
Expand All @@ -1267,14 +1276,16 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
More details about available args:
https://numpy.org/doc/stable/reference/generated/numpy.load.html

Raises:
ValueError: when `self.allow_pickle` is False but loaded data contains pickled objects.
"""
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
img = np.load(name, allow_pickle=True, **kwargs_)
img = np.load(name, allow_pickle=self.allow_pickle, **kwargs_)
if Path(name).name.endswith(".npz"):
# load expected items from NPZ file
npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys
Expand Down
11 changes: 11 additions & 0 deletions tests/data/test_numpy_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ def test_npy_pickle(self):
np.save(filepath, test_data, allow_pickle=True)

reader = NumpyReader()

with self.assertRaises(ValueError):
reader.get_data(reader.read(filepath))

reader = NumpyReader(allow_pickle=True)
result = reader.get_data(reader.read(filepath))[0].item()

np.testing.assert_allclose(result["test"].shape, test_data["test"].shape)
np.testing.assert_allclose(result["test"], test_data["test"])

Expand All @@ -92,6 +98,11 @@ def test_kwargs(self):
np.save(filepath, test_data, allow_pickle=True)

reader = NumpyReader(mmap_mode="r")

with self.assertRaises(ValueError):
reader.get_data(reader.read(filepath, mmap_mode=None))

reader = NumpyReader(mmap_mode="r", allow_pickle=True)
result = reader.get_data(reader.read(filepath, mmap_mode=None))[0].item()
np.testing.assert_allclose(result["test"].shape, test_data["test"].shape)

Expand Down
Loading