diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 9d5f33e8947..13dff16aee3 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -122,16 +122,51 @@ def __setitem__(self, key, value): raise -# This is a dirty workaround to allow pickling of the flush_only_netcdf_file class. -# https://stackoverflow.com/questions/72766345/attributeerror-cant-pickle-local-object-in-multiprocessing -# TODO: Remove this after upstreaming the fixes to scipy. -class _PickleWorkaround: - flush_only_netcdf_file: type[scipy.io.netcdf_file] - - @classmethod - def add_cls(cls, new_class: type[Any]) -> None: - setattr(cls, new_class.__name__, new_class) - new_class.__qualname__ = cls.__qualname__ + "." + new_class.__name__ +# Cached class created once so its identity is stable for pickle. +# The class must not be re-created on each call to _open_scipy_netcdf; +# otherwise pickle sees a different class object when looking up the +# qualname and raises PicklingError (GH#11323). +# +# We set __qualname__ to a module-level name so pickle can always +# resolve the class via ``xarray.backends.scipy_.flush_only_netcdf_file``. +_flush_only_class: type[Any] | None = None + + +def _get_flush_only_class() -> type[Any]: + global _flush_only_class + if _flush_only_class is None: + import scipy.io + + # TODO: Remove this after upstreaming these fixes. + class flush_only_netcdf_file(scipy.io.netcdf_file): + # scipy.io.netcdf_file.close() incorrectly closes file objects that + # were passed in as constructor arguments: + # https://github.com/scipy/scipy/issues/13905 + + # Instead of closing such files, only call flush(), which is + # equivalent as long as the netcdf_file object is not mmapped. + # This suffices to keep BytesIO objects open long enough to read + # their contents from to_netcdf(), but underlying files still get + # closed when the netcdf_file is garbage collected (via __del__), + # and will need to be fixed upstream in scipy. + def close(self): + if hasattr(self, "fp") and not self.fp.closed: + self.flush() + self.fp.seek(0) # allow file to be read again + + def __del__(self): + # Remove the __del__ method, which in scipy is aliased to close(). + # These files need to be closed explicitly by xarray. + pass + + flush_only_netcdf_file.__qualname__ = "flush_only_netcdf_file" + _flush_only_class = flush_only_netcdf_file + # Make the class accessible as a module attribute so pickle can + # resolve it by qualname ``xarray.backends.scipy_.flush_only_netcdf_file``. + import sys + + sys.modules[__name__].flush_only_netcdf_file = _flush_only_class # type: ignore[attr-defined] + return _flush_only_class def _open_scipy_netcdf( @@ -143,33 +178,7 @@ def _open_scipy_netcdf( ) -> scipy.io.netcdf_file: import scipy.io - # TODO: Remove this after upstreaming these fixes. - class flush_only_netcdf_file(scipy.io.netcdf_file): - # scipy.io.netcdf_file.close() incorrectly closes file objects that - # were passed in as constructor arguments: - # https://github.com/scipy/scipy/issues/13905 - - # Instead of closing such files, only call flush(), which is - # equivalent as long as the netcdf_file object is not mmapped. - # This suffices to keep BytesIO objects open long enough to read - # their contents from to_netcdf(), but underlying files still get - # closed when the netcdf_file is garbage collected (via __del__), - # and will need to be fixed upstream in scipy. - def close(self): - if hasattr(self, "fp") and not self.fp.closed: - self.flush() - self.fp.seek(0) # allow file to be read again - - def __del__(self): - # Remove the __del__ method, which in scipy is aliased to close(). - # These files need to be closed explicitly by xarray. - pass - - _PickleWorkaround.add_cls(flush_only_netcdf_file) - - netcdf_file = ( - _PickleWorkaround.flush_only_netcdf_file if flush_only else scipy.io.netcdf_file - ) + netcdf_file = _get_flush_only_class() if flush_only else scipy.io.netcdf_file # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 4e08b71260b..33a9e3c9deb 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4579,6 +4579,21 @@ def roundtrip( with self.open(saved, **open_kwargs) as ds: yield ds + def test_pickle_after_multiple_opens_from_bytes(self) -> None: + # Regression test for GH#11323: opening two scipy-backed datasets + # from BytesIO objects would overwrite the cached flush_only class, + # making the first dataset unpicklable. + original = Dataset({"foo": ("x", [1, 2, 3])}) + netcdf_bytes = bytes(original.to_netcdf(engine=self.engine)) + ds1 = open_dataset(BytesIO(netcdf_bytes), engine=self.engine) + ds2 = open_dataset(BytesIO(netcdf_bytes), engine=self.engine) + try: + with pickle.loads(pickle.dumps(ds1)) as unpickled: + assert_identical(unpickled, original) + finally: + ds1.close() + ds2.close() + @pytest.mark.asyncio @pytest.mark.skip(reason="NetCDF backends don't support async loading") async def test_load_async(self) -> None: