Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ Bug Fixes
``"dayofweek"``, and ``"week"``, respectively; now they return objects named
``"days_in_month"``, ``"weekday"``, and ``"weekofyear"`` (:pull:`11270`). By
`Spencer Clark <https://github.com/spencerkclark>`_.
- :py:meth:`DataArray.shift`, :py:meth:`Dataset.shift`, :py:meth:`DataArray.pad`
and :py:meth:`Dataset.pad` no longer raise ``TypeError`` on variables backed by
pandas nullable extension arrays (e.g. ``Int64``, ``Float64``, ``boolean``).
The padding/fill values are now inserted using the extension array's own
missing value (``pd.NA``), preserving the extension dtype instead of letting
numpy coerce it to an ``object`` or numeric ndarray (:issue:`10301`).

Documentation
~~~~~~~~~~~~~
Expand Down
9 changes: 9 additions & 0 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,15 @@ def moveaxis(array, source, destination):

def pad(array, pad_width, **kwargs):
xp = get_array_namespace(array)
if (
isinstance(array, pd.api.extensions.ExtensionArray)
and kwargs.get("mode", "constant") == "constant"
):
# Wrap so that NEP-18 dispatch routes to __extension_duck_array__pad,
# which preserves the extension dtype (and its NA) instead of letting
# numpy coerce to a plain ndarray. See GH #10301. Non-constant modes
# don't introduce fill values, so the historical numpy path is fine.
array = PandasExtensionArray(array)
return xp.pad(array, pad_width, **kwargs)


Expand Down
64 changes: 64 additions & 0 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,70 @@ def __extension_duck_array__where(
return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[call-overload]


def _pad_pair(value: Any, default: Any = 0) -> tuple[Any, Any]:
"""Normalize a numpy-pad-style per-axis argument to a ``(before, after)`` pair.

``np.pad`` accepts ``pad_width`` / ``constant_values`` as a scalar (applied to
both sides), a ``(before, after)`` tuple, or a length-1 sequence wrapping such a
tuple (xarray passes one entry per axis, and extension arrays are 1d). This
collapses all of those forms to a single ``(before, after)`` pair.
"""
if value is None:
value = default
# unwrap a length-1 per-axis sequence, e.g. [(before, after)]
if (
isinstance(value, list | tuple)
and len(value) == 1
and isinstance(value[0], list | tuple)
):
value = value[0]
if isinstance(value, list | tuple):
before, after = value
return before, after
return value, value # scalar -> same on both sides


@implements(np.pad)
def __extension_duck_array__pad(
array: T_ExtensionArray,
pad_width: Any,
mode: str = "constant",
**kwargs: Any,
) -> T_ExtensionArray:
"""Constant-mode padding for a 1d pandas extension array, preserving its dtype.

Building the result from the extension array itself (rather than letting numpy
coerce it to an ndarray) keeps the extension dtype and lets fill values such as
``pd.NA`` round-trip with full type fidelity. Only constant mode is handled here;
the other modes reuse existing values and are routed through numpy by the caller
(:func:`xarray.core.duck_array_ops.pad`).
"""
if mode != "constant":
raise NotImplementedError(
f"Only mode='constant' padding is implemented for pandas extension "
f"arrays, got mode={mode!r}."
)

before, after = _pad_pair(pad_width)
fill_before, fill_after = _pad_pair(kwargs.get("constant_values", 0))

# Build the padding as same-dtype extension arrays and concatenate, so the
# result keeps ``array.dtype`` and fill values (incl. ``pd.NA``) round-trip
# natively. Concatenating per-side honors ``constant_values=(before, after)``.
constructor = array.dtype.construct_array_type()
parts: list[ExtensionArray] = []
if before:
parts.append(
constructor._from_sequence([fill_before] * before, dtype=array.dtype)
)
parts.append(array)
if after:
parts.append(
constructor._from_sequence([fill_after] * after, dtype=array.dtype)
)
return type(array)._concat_same_type(parts)


def _replace_duck(args, replacer: Callable[[PandasExtensionArray], Any]) -> list:
args_as_list = list(args)
for index, value in enumerate(args_as_list):
Expand Down
45 changes: 45 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2009,6 +2009,51 @@ def test_dropna_extension_array(self, extension_array) -> None:
assert filled.dtype == srs.dtype
assert (filled.values == srs.values[1:]).all()

@pytest.mark.parametrize(
"extension_array",
[
pytest.param(pd.array([1, 2, 3], dtype="Int64"), id="Int64"),
pytest.param(pd.array([1.5, 2.5, 3.5], dtype="Float64"), id="Float64"),
pytest.param(pd.array([True, False, True], dtype="boolean"), id="boolean"),
],
)
def test_shift_extension_array(self, extension_array) -> None:
# GH #10301: shift on a nullable/extension dtype used to raise because the
# NA fill value could not be inserted into a numpy-coerced array.
srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array)
da = srs.to_xarray()
shifted = da.shift(index=1)
assert shifted.dtype == da.dtype
pd.testing.assert_extension_array_equal(
shifted.data,
extension_array._from_sequence(
[pd.NA, extension_array[0], extension_array[1]],
dtype=extension_array.dtype,
),
)

@pytest.mark.parametrize(
"extension_array",
[
pytest.param(pd.array([1, 2, 3], dtype="Int64"), id="Int64"),
pytest.param(pd.array([1.5, 2.5, 3.5], dtype="Float64"), id="Float64"),
pytest.param(pd.array([True, False, True], dtype="boolean"), id="boolean"),
],
)
def test_pad_extension_array(self, extension_array) -> None:
# GH #10301: pad on a nullable/extension dtype keeps the dtype + native NA.
srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array)
da = srs.to_xarray()
padded = da.pad(index=(1, 1))
assert padded.dtype == da.dtype
pd.testing.assert_extension_array_equal(
padded.data,
extension_array._from_sequence(
[pd.NA, *list(extension_array), pd.NA],
dtype=extension_array.dtype,
),
)

def test_rename(self) -> None:
da = xr.DataArray(
[1, 2, 3], dims="dim", name="name", coords={"coord": ("dim", [5, 6, 7])}
Expand Down
28 changes: 28 additions & 0 deletions xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
least_squares,
mean,
np_timedelta64_to_float,
pad,
pd_timedelta_to_float,
push,
py_timedelta_to_float,
Expand Down Expand Up @@ -210,6 +211,33 @@ def test_concatenate_extension_duck_array(self, categorical1, categorical2):
== type(categorical1)._concat_same_type((categorical1, categorical2))
).all()

@pytest.mark.parametrize(
"values,dtype",
[
([1, 2, 3], "Int64"),
([1.5, 2.5, 3.5], "Float64"),
([True, False, True], "boolean"),
],
)
def test_pad_extension_duck_array(self, values, dtype):
# GH #10301: padding (used by shift/pad) must keep the extension dtype
# and use its native NA instead of letting numpy coerce to an ndarray.
array = pd.array(values, dtype=dtype)
padded = pad(array, [(1, 0)], mode="constant", constant_values=pd.NA)
assert isinstance(padded, PandasExtensionArray)
assert padded.dtype == array.dtype
pd.testing.assert_extension_array_equal(
padded.array,
array._from_sequence([pd.NA, *list(array)], dtype=array.dtype),
)

def test_pad_extension_duck_array_per_side_fill(self):
# constant_values=(before, after) is honored independently per side
array = pd.array([1, 2, 3], dtype="Int64")
padded = pad(array, [(1, 1)], mode="constant", constant_values=(0, 9))
assert isinstance(padded, PandasExtensionArray)
assert list(padded.array) == [0, 1, 2, 3, 9]

@requires_pyarrow
def test_extension_array_pyarrow_concatenate(self, arrow1, arrow2):
concatenated = concatenate(
Expand Down
Loading