From 9e85ce309aaf32bfd2faced0b6904d637330d7d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81d=C3=A1m=20Lippai?= Date: Wed, 22 Apr 2026 19:47:39 -0400 Subject: [PATCH 1/3] Simplify extension array set_dims fallback --- xarray/core/duck_array_ops.py | 7 +++++ xarray/core/indexes.py | 8 ++++- xarray/core/variable.py | 7 ++++- xarray/tests/test_dataarray.py | 54 ++++++++++++++++++++++++++++++++++ xarray/tests/test_indexes.py | 18 ++++++++++++ 5 files changed, 92 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index e10ab5a3558..7c8660f751e 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -409,6 +409,13 @@ def where(condition, x, y): xp = get_array_namespace(condition, x, y) dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + if isinstance(condition, pd.api.extensions.ExtensionArray) and pd.api.types.is_bool_dtype( + condition.dtype + ): + # pandas nullable booleans can contain , which cannot be cast + # directly to bool. For masking semantics, treat missing condition + # values as False. + condition = condition.fillna(False) if not is_duck_array(condition): condition = asarray(condition, dtype=dtype, xp=xp) else: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 2242e57e482..3488f65ce1b 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -914,7 +914,13 @@ def join( index = self.index.intersection(other.index) if is_allowed_extension_array_dtype(index.dtype): return type(self)(index, self.dim) - coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype) + try: + coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype) + except TypeError: + # pandas extension dtypes (e.g., CategoricalDtype) are not always + # compatible with numpy's type promotion even when the resulting + # index dtype is a regular NumPy dtype. + coord_dtype = get_valid_numpy_dtype(index) return type(self)(index, self.dim, coord_dtype=coord_dtype) def reindex_like( diff --git a/xarray/core/variable.py b/xarray/core/variable.py index b4cdf5cf6ca..415f662a15f 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1489,7 +1489,12 @@ def set_dims(self, dim, shape=None): # than the full "broadcast_to" semantics indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) # TODO: switch this to ._data once we teach ExplicitlyIndexed arrays to handle indexers with None. - expanded_data = self.data[indexer] + try: + expanded_data = self.data[indexer] + except IndexError: + # Some pandas ExtensionArray backends (notably ArrowExtensionArray) + # don't support tuple indexers containing `None`. + expanded_data = np.asarray(self.data, dtype=object)[indexer] else: # elif shape is not None: dims_map = dict(zip(dim, shape, strict=True)) tmp_shape = tuple(dims_map[d] for d in expanded_dims) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8eb52046a31..1c6d88ede02 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2006,6 +2006,60 @@ def test_dropna_extension_array(self, extension_array) -> None: assert filled.dtype == srs.dtype assert (filled.values == srs.values[1:]).all() + @pytest.mark.parametrize( + "dtype", + [ + pytest.param("Int64", id="Int64"), + pytest.param("int64[pyarrow]", id="int64[pyarrow]", marks=requires_pyarrow), + ], + ) + def test_where_nullable_int_no_float_promotion(self, dtype) -> None: + srs = pd.Series([1, 2, None], dtype=dtype, name="v") + da = srs.to_xarray() + actual = da.where(da > 1) + + assert actual.dtype != np.dtype(np.float64) + assert_array_equal(actual.isnull(), [True, False, True]) + assert actual.fillna(0).astype("int64").data.tolist() == [0, 2, 0] + + @pytest.mark.parametrize( + "dtype", + [ + pytest.param("Int64", id="Int64"), + pytest.param("int64[pyarrow]", id="int64[pyarrow]", marks=requires_pyarrow), + ], + ) + def test_concat_nullable_int_no_float_promotion(self, dtype) -> None: + srs = pd.Series([1, 2, None], dtype=dtype, name="v") + da = srs.to_xarray() + actual = xr.concat([da, da], dim="rep") + + assert actual.dtype != np.dtype(np.float64) + assert_array_equal(actual.isnull(), [[False, False, True], [False, False, True]]) + actual_values = actual.data.tolist() + assert actual_values[0][:2] == [1, 2] + assert actual_values[1][:2] == [1, 2] + assert pd.isna(actual_values[0][2]) + assert pd.isna(actual_values[1][2]) + + @requires_pyarrow + def test_date32_pyarrow_coord_align(self) -> None: + dates = pd.date_range("2024-01-01", periods=3, freq="D") + date32_index = pd.Index( + pd.Series(dates, name="date").astype("date32[pyarrow]"), name="time" + ) + + da1 = xr.DataArray([10.0, 20.0, 30.0], dims="time", coords={"time": date32_index}) + da2 = xr.DataArray( + [1.0, 2.0, 3.0], + dims="time", + coords={"time": pd.date_range("2024-01-01", periods=3, freq="D")}, + ) + + actual1, actual2 = xr.align(da1, da2, join="outer") + assert actual1.sizes["time"] == 3 + assert actual2.sizes["time"] == 3 + def test_rename(self) -> None: da = xr.DataArray( [1, 2, 3], dims="dim", name="name", coords={"coord": ("dim", [5, 6, 7])} diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 94adcc3b935..e35be123598 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -280,6 +280,24 @@ def test_join(self) -> None: assert actual.equals(expected) assert actual.coord_dtype == "=U4" + def test_join_categorical_and_object(self) -> None: + index1 = PandasIndex(pd.Index(["A", "B"], dtype=object), "x") + index2 = PandasIndex( + pd.CategoricalIndex(["B", "C"], categories=["A", "B", "C"], name="x"), + "x", + ) + + expected = PandasIndex(pd.Index(["B"], dtype=object), "x") + actual = index1.join(index2) + + assert actual.equals(expected) + assert actual.coord_dtype == np.dtype(object) + + expected = PandasIndex(pd.Index(["A", "B", "C"], dtype=object), "x") + actual = index1.join(index2, how="outer") + assert actual.equals(expected) + assert actual.coord_dtype == np.dtype(object) + def test_reindex_like(self) -> None: index1 = PandasIndex([0, 1, 2], "x") index2 = PandasIndex([1, 2, 3, 4], "x") From 90ca1adfa59de9cca7cd54f26963f7189bc4bce4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81d=C3=A1m=20Lippai?= Date: Wed, 22 Apr 2026 19:48:10 -0400 Subject: [PATCH 2/3] Drop pandas-indexing workaround and keep focused EA tests --- xarray/core/variable.py | 7 +------ xarray/tests/test_dataarray.py | 20 -------------------- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 415f662a15f..b4cdf5cf6ca 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1489,12 +1489,7 @@ def set_dims(self, dim, shape=None): # than the full "broadcast_to" semantics indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) # TODO: switch this to ._data once we teach ExplicitlyIndexed arrays to handle indexers with None. - try: - expanded_data = self.data[indexer] - except IndexError: - # Some pandas ExtensionArray backends (notably ArrowExtensionArray) - # don't support tuple indexers containing `None`. - expanded_data = np.asarray(self.data, dtype=object)[indexer] + expanded_data = self.data[indexer] else: # elif shape is not None: dims_map = dict(zip(dim, shape, strict=True)) tmp_shape = tuple(dims_map[d] for d in expanded_dims) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 1c6d88ede02..be468ffe4e6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2022,26 +2022,6 @@ def test_where_nullable_int_no_float_promotion(self, dtype) -> None: assert_array_equal(actual.isnull(), [True, False, True]) assert actual.fillna(0).astype("int64").data.tolist() == [0, 2, 0] - @pytest.mark.parametrize( - "dtype", - [ - pytest.param("Int64", id="Int64"), - pytest.param("int64[pyarrow]", id="int64[pyarrow]", marks=requires_pyarrow), - ], - ) - def test_concat_nullable_int_no_float_promotion(self, dtype) -> None: - srs = pd.Series([1, 2, None], dtype=dtype, name="v") - da = srs.to_xarray() - actual = xr.concat([da, da], dim="rep") - - assert actual.dtype != np.dtype(np.float64) - assert_array_equal(actual.isnull(), [[False, False, True], [False, False, True]]) - actual_values = actual.data.tolist() - assert actual_values[0][:2] == [1, 2] - assert actual_values[1][:2] == [1, 2] - assert pd.isna(actual_values[0][2]) - assert pd.isna(actual_values[1][2]) - @requires_pyarrow def test_date32_pyarrow_coord_align(self) -> None: dates = pd.date_range("2024-01-01", periods=3, freq="D") From 449c821aa98f91f8f12902c585cd6131f9ed1263 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Apr 2026 23:53:49 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/duck_array_ops.py | 6 +++--- xarray/tests/test_dataarray.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 7c8660f751e..2d30fb41eaa 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -409,9 +409,9 @@ def where(condition, x, y): xp = get_array_namespace(condition, x, y) dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool - if isinstance(condition, pd.api.extensions.ExtensionArray) and pd.api.types.is_bool_dtype( - condition.dtype - ): + if isinstance( + condition, pd.api.extensions.ExtensionArray + ) and pd.api.types.is_bool_dtype(condition.dtype): # pandas nullable booleans can contain , which cannot be cast # directly to bool. For masking semantics, treat missing condition # values as False. diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index be468ffe4e6..e5d791212f8 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2029,7 +2029,9 @@ def test_date32_pyarrow_coord_align(self) -> None: pd.Series(dates, name="date").astype("date32[pyarrow]"), name="time" ) - da1 = xr.DataArray([10.0, 20.0, 30.0], dims="time", coords={"time": date32_index}) + da1 = xr.DataArray( + [10.0, 20.0, 30.0], dims="time", coords={"time": date32_index} + ) da2 = xr.DataArray( [1.0, 2.0, 3.0], dims="time",