diff --git a/doc/whats-new.rst b/doc/whats-new.rst index effb199f18e..aada9ac8643 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,12 @@ Deprecations Bug Fixes ~~~~~~~~~ +- The zarr backend now writes boolean arrays with native ``bool`` dtype instead + of converting them to ``int8``. Zarr supports ``bool`` natively, so the + ``BooleanCoder`` (which was designed for NetCDF compatibility) is now skipped + for zarr writes. Existing zarr stores written with the old ``int8`` encoding + are still read correctly. (:issue:`2937`, :pull:`11318`) + By `Evan Lyall `_. - Fix a major performance regression in :py:meth:`Coordinates.to_index` (and consequently :py:meth:`Dataset.to_dataframe`) caused by converting the cached code ndarrays into Python lists (:issue:`11305`). diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index d9279dc2de9..4b037747786 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -562,7 +562,12 @@ def encode_zarr_variable(var, needs_copy=True, name=None): A variable which has been encoded as described above. """ - var = conventions.encode_cf_variable(var, name=name) + coders = [ + c + for c in conventions._default_encode_cf_coders() + if not isinstance(c, coding.variables.BooleanCoder) + ] + var = conventions.encode_cf_variable(var, name=name, coders=coders) var = ensure_dtype_not_object(var, name=name) # zarr allows unicode, but not variable-length strings, so it's both diff --git a/xarray/conventions.py b/xarray/conventions.py index d3ee05e5da1..860d011bccb 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -65,8 +65,22 @@ def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None: ) +def _default_encode_cf_coders(): + """Return the default list of coders used by encode_cf_variable.""" + return [ + CFDatetimeCoder(), + CFTimedeltaCoder(), + variables.CFScaleOffsetCoder(), + variables.CFMaskCoder(), + variables.NativeEnumCoder(), + variables.NonStringCoder(), + variables.DefaultFillvalueCoder(), + variables.BooleanCoder(), + ] + + def encode_cf_variable( - var: Variable, needs_copy: bool = True, name: T_Name = None + var: Variable, needs_copy: bool = True, name: T_Name = None, coders=None ) -> Variable: """ Converts a Variable into a Variable which follows some @@ -81,6 +95,8 @@ def encode_cf_variable( ---------- var : Variable A variable holding un-encoded data. + coders : list of VariableCoder, optional + List of coders to apply. If None, uses the default CF coder chain. Returns ------- @@ -89,16 +105,10 @@ def encode_cf_variable( """ ensure_not_multiindex(var, name=name) - for coder in [ - CFDatetimeCoder(), - CFTimedeltaCoder(), - variables.CFScaleOffsetCoder(), - variables.CFMaskCoder(), - variables.NativeEnumCoder(), - variables.NonStringCoder(), - variables.DefaultFillvalueCoder(), - variables.BooleanCoder(), - ]: + if coders is None: + coders = _default_encode_cf_coders() + + for coder in coders: var = coder.encode(var, name=name) for attr_name in CF_RELATED_DATA: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e42bfc2cd9f..8ccdf735e0b 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2709,6 +2709,60 @@ def roundtrip( async def test_load_async(self) -> None: await super().test_load_async() + def test_roundtrip_boolean_dtype(self) -> None: + original = create_boolean_data() + assert original["x"].dtype == "bool" + with self.create_zarr_target() as store_target: + self.save(original, store_target, consolidated=False) + # Verify on-disk zarr array uses native bool dtype (not int8) + zg = zarr.open_group(store_target, mode="r") + zarr_arr = zg["x"] + assert isinstance(zarr_arr, zarr.Array) + assert zarr_arr.dtype == np.dtype("bool") + assert "dtype" not in zarr_arr.attrs + with self.open( + store_target, backend_kwargs={"consolidated": False} + ) as actual: + assert_identical(original, actual) + assert actual["x"].dtype == "bool" + # Verify second roundtrip also preserves bool + with self.roundtrip(actual) as actual2: + assert_identical(original, actual2) + assert actual2["x"].dtype == "bool" + + def test_roundtrip_boolean_dtype_legacy_int8(self) -> None: + """Verify backward compat: old-style int8 + attrs['dtype']='bool' decodes to bool.""" + original = create_boolean_data() + with self.create_zarr_target() as store_target: + zg = zarr.open_group(store_target, mode="w") + data_int8 = original["x"].values.astype("i1") + is_v3_format = has_zarr_v3 and zg.metadata.zarr_format == 3 + if is_v3_format: + arr = zg.create_array( + "x", + shape=data_int8.shape, + dtype=data_int8.dtype, + fill_value=-1, + dimension_names=("t", "x"), + ) + else: + arr = zg.create_array( + "x", + shape=data_int8.shape, + dtype=data_int8.dtype, + fill_value=-1, + ) + arr[:] = data_int8 + arr.attrs["dtype"] = "bool" + arr.attrs["units"] = "-" + if not is_v3_format: + arr.attrs["_ARRAY_DIMENSIONS"] = ["t", "x"] + with self.open( + store_target, backend_kwargs={"consolidated": False} + ) as actual: + assert actual["x"].dtype == "bool" + np.testing.assert_array_equal(actual["x"].values, original["x"].values) + def test_roundtrip_bytes_with_fill_value(self): pytest.xfail("Broken by Zarr 3.0.7") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 38b835fd3d5..3c7cfe292c9 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -42,6 +42,31 @@ def test_booltype_array(self) -> None: assert_array_equal(bx.transpose((1, 0)), x.transpose((1, 0))) +class TestEncodeCfVariableCoders: + def test_empty_coders_is_identity(self) -> None: + var = Variable(["x"], np.array([True, False, True]), {"units": "test"}) + result = conventions.encode_cf_variable(var, coders=[]) + assert result.dtype == bool + assert_array_equal(result.values, var.values) + + def test_custom_coders_excludes_boolean_coder(self) -> None: + var = Variable(["x"], np.array([True, False, True])) + coders = [ + c + for c in conventions._default_encode_cf_coders() + if not isinstance(c, coding.variables.BooleanCoder) + ] + result = conventions.encode_cf_variable(var, coders=coders) + assert result.dtype == bool + assert "dtype" not in result.attrs + + def test_default_coders_encodes_bool_to_int8(self) -> None: + var = Variable(["x"], np.array([True, False, True])) + result = conventions.encode_cf_variable(var) + assert result.dtype == np.int8 + assert result.attrs.get("dtype") == "bool" + + class TestNativeEndiannessArray: def test(self) -> None: x = np.arange(5, dtype=">i8")