diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index e10ab5a3558..2d30fb41eaa 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -409,6 +409,13 @@ def where(condition, x, y): xp = get_array_namespace(condition, x, y) dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + if isinstance( + condition, pd.api.extensions.ExtensionArray + ) and pd.api.types.is_bool_dtype(condition.dtype): + # pandas nullable booleans can contain pd.NA, which cannot be cast + # directly to bool. For masking semantics, treat missing condition + # values as False. + condition = condition.fillna(False) if not is_duck_array(condition): condition = asarray(condition, dtype=dtype, xp=xp) else: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 2242e57e482..3488f65ce1b 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -914,7 +914,13 @@ def join( index = self.index.intersection(other.index) if is_allowed_extension_array_dtype(index.dtype): return type(self)(index, self.dim) - coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype) + try: + coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype) + except TypeError: + # pandas extension dtypes (e.g., CategoricalDtype) are not always + # compatible with numpy's type promotion even when the resulting + # index dtype is a regular NumPy dtype. 
+ coord_dtype = get_valid_numpy_dtype(index) return type(self)(index, self.dim, coord_dtype=coord_dtype) def reindex_like( diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8eb52046a31..e5d791212f8 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2006,6 +2006,42 @@ def test_dropna_extension_array(self, extension_array) -> None: assert filled.dtype == srs.dtype assert (filled.values == srs.values[1:]).all() + @pytest.mark.parametrize( + "dtype", + [ + pytest.param("Int64", id="Int64"), + pytest.param("int64[pyarrow]", id="int64[pyarrow]", marks=requires_pyarrow), + ], + ) + def test_where_nullable_int_no_float_promotion(self, dtype) -> None: + srs = pd.Series([1, 2, None], dtype=dtype, name="v") + da = srs.to_xarray() + actual = da.where(da > 1) + + assert actual.dtype != np.dtype(np.float64) + assert_array_equal(actual.isnull(), [True, False, True]) + assert actual.fillna(0).astype("int64").data.tolist() == [0, 2, 0] + + @requires_pyarrow + def test_date32_pyarrow_coord_align(self) -> None: + dates = pd.date_range("2024-01-01", periods=3, freq="D") + date32_index = pd.Index( + pd.Series(dates, name="date").astype("date32[pyarrow]"), name="time" + ) + + da1 = xr.DataArray( + [10.0, 20.0, 30.0], dims="time", coords={"time": date32_index} + ) + da2 = xr.DataArray( + [1.0, 2.0, 3.0], + dims="time", + coords={"time": pd.date_range("2024-01-01", periods=3, freq="D")}, + ) + + actual1, actual2 = xr.align(da1, da2, join="outer") + assert actual1.sizes["time"] == 3 + assert actual2.sizes["time"] == 3 + def test_rename(self) -> None: da = xr.DataArray( [1, 2, 3], dims="dim", name="name", coords={"coord": ("dim", [5, 6, 7])} diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 94adcc3b935..e35be123598 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -280,6 +280,24 @@ def test_join(self) -> None: assert actual.equals(expected) assert actual.coord_dtype 
== "=U4" + def test_join_categorical_and_object(self) -> None: + index1 = PandasIndex(pd.Index(["A", "B"], dtype=object), "x") + index2 = PandasIndex( + pd.CategoricalIndex(["B", "C"], categories=["A", "B", "C"], name="x"), + "x", + ) + + expected = PandasIndex(pd.Index(["B"], dtype=object), "x") + actual = index1.join(index2) + + assert actual.equals(expected) + assert actual.coord_dtype == np.dtype(object) + + expected = PandasIndex(pd.Index(["A", "B", "C"], dtype=object), "x") + actual = index1.join(index2, how="outer") + assert actual.equals(expected) + assert actual.coord_dtype == np.dtype(object) + def test_reindex_like(self) -> None: index1 = PandasIndex([0, 1, 2], "x") index2 = PandasIndex([1, 2, 3, 4], "x")