Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 46 additions & 37 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,51 @@ def __setitem__(self, key, value):
raise


# Workaround that lets pickle resolve the function-local
# flush_only_netcdf_file class by attaching it to a module-level holder.
# https://stackoverflow.com/questions/72766345/attributeerror-cant-pickle-local-object-in-multiprocessing
# TODO: Remove this after upstreaming the fixes to scipy.
class _PickleWorkaround:
    """Module-level holder that gives locally defined classes a picklable name.

    pickle stores a class as a reference (module name plus ``__qualname__``)
    and, when dumping, checks that the reference resolves to the *same* class
    object.  A registered class is therefore exposed as an attribute of this
    holder and its ``__qualname__`` is rewritten to point at that attribute.
    """

    # Populated by ``add_cls``; only declared here for type checkers.
    flush_only_netcdf_file: type[scipy.io.netcdf_file]

    @classmethod
    def add_cls(cls, new_class: type[Any]) -> None:
        """Register ``new_class`` under its ``__name__`` unless that name is taken.

        The first class registered under a given name wins.  Replacing it with
        a later, distinct class object would leave the earlier class with a
        qualname that resolves to a different object, which makes every
        instance of the earlier class unpicklable.

        Parameters
        ----------
        new_class
            The class to expose as ``<holder>.<new_class.__name__>``.
        """
        if new_class.__name__ in vars(cls):
            # Already registered: keep the existing class identity stable and
            # leave ``new_class`` untouched.
            return
        setattr(cls, new_class.__name__, new_class)
        new_class.__qualname__ = cls.__qualname__ + "." + new_class.__name__
# Cached class created once so its identity is stable for pickle.
# The class must not be re-created on each call to _open_scipy_netcdf;
# otherwise pickle sees a different class object when looking up the
# qualname and raises PicklingError (GH#11323).
#
# We set __qualname__ to a module-level name so pickle can always
# resolve the class via ``xarray.backends.scipy_.flush_only_netcdf_file``.
_flush_only_class: type[Any] | None = None


def _get_flush_only_class() -> type[Any]:
global _flush_only_class
if _flush_only_class is None:
import scipy.io

# TODO: Remove this after upstreaming these fixes.
class flush_only_netcdf_file(scipy.io.netcdf_file):
# scipy.io.netcdf_file.close() incorrectly closes file objects that
# were passed in as constructor arguments:
# https://github.com/scipy/scipy/issues/13905

# Instead of closing such files, only call flush(), which is
# equivalent as long as the netcdf_file object is not mmapped.
# This suffices to keep BytesIO objects open long enough to read
# their contents from to_netcdf(), but underlying files still get
# closed when the netcdf_file is garbage collected (via __del__),
# and will need to be fixed upstream in scipy.
def close(self):
if hasattr(self, "fp") and not self.fp.closed:
self.flush()
self.fp.seek(0) # allow file to be read again

def __del__(self):
# Remove the __del__ method, which in scipy is aliased to close().
# These files need to be closed explicitly by xarray.
pass

flush_only_netcdf_file.__qualname__ = "flush_only_netcdf_file"
_flush_only_class = flush_only_netcdf_file
# Make the class accessible as a module attribute so pickle can
# resolve it by qualname ``xarray.backends.scipy_.flush_only_netcdf_file``.
import sys

sys.modules[__name__].flush_only_netcdf_file = _flush_only_class # type: ignore[attr-defined]
return _flush_only_class


def _open_scipy_netcdf(
Expand All @@ -143,33 +178,7 @@ def _open_scipy_netcdf(
) -> scipy.io.netcdf_file:
import scipy.io

# TODO: Remove this after upstreaming these fixes.
class flush_only_netcdf_file(scipy.io.netcdf_file):
# scipy.io.netcdf_file.close() incorrectly closes file objects that
# were passed in as constructor arguments:
# https://github.com/scipy/scipy/issues/13905

# Instead of closing such files, only call flush(), which is
# equivalent as long as the netcdf_file object is not mmapped.
# This suffices to keep BytesIO objects open long enough to read
# their contents from to_netcdf(), but underlying files still get
# closed when the netcdf_file is garbage collected (via __del__),
# and will need to be fixed upstream in scipy.
def close(self):
if hasattr(self, "fp") and not self.fp.closed:
self.flush()
self.fp.seek(0) # allow file to be read again

def __del__(self):
# Remove the __del__ method, which in scipy is aliased to close().
# These files need to be closed explicitly by xarray.
pass

_PickleWorkaround.add_cls(flush_only_netcdf_file)

netcdf_file = (
_PickleWorkaround.flush_only_netcdf_file if flush_only else scipy.io.netcdf_file
)
netcdf_file = _get_flush_only_class() if flush_only else scipy.io.netcdf_file

# if the string ends with .gz, then gunzip and open as netcdf file
if isinstance(filename, str) and filename.endswith(".gz"):
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4579,6 +4579,21 @@ def roundtrip(
with self.open(saved, **open_kwargs) as ds:
yield ds

def test_pickle_after_multiple_opens_from_bytes(self) -> None:
    """A dataset opened from bytes stays picklable after a second open."""
    # Regression test for GH#11323: opening two scipy-backed datasets
    # from BytesIO objects would overwrite the cached flush_only class,
    # making the first dataset unpicklable.
    original = Dataset({"foo": ("x", [1, 2, 3])})
    netcdf_bytes = bytes(original.to_netcdf(engine=self.engine))
    # Context managers close ds1 even if opening the second dataset raises.
    with open_dataset(BytesIO(netcdf_bytes), engine=self.engine) as ds1:
        # The second dataset is only opened to trigger the regression; it is
        # kept open while ds1 is pickled.
        with open_dataset(BytesIO(netcdf_bytes), engine=self.engine):
            with pickle.loads(pickle.dumps(ds1)) as unpickled:
                assert_identical(unpickled, original)

@pytest.mark.asyncio
@pytest.mark.skip(reason="NetCDF backends don't support async loading")
async def test_load_async(self) -> None:
Expand Down
Loading