diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 93f335e625b..f6ad4ec5856 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -162,6 +162,9 @@ Deprecations Bug Fixes ~~~~~~~~~ +- Fix :py:meth:`DataArray.sortby` and :py:meth:`Dataset.sortby` raising ``KeyError`` + when passing a tuple of dimension names, e.g. ``da.sortby(da.dims)`` (:issue:`4821`). + By `Timothy Hodson `_. - Fix multi-coordinate indexes being dropped in :py:meth:`DataArray._replace_maybe_drop_dims` (e.g. after reducing over an unrelated dimension) and in :py:meth:`Dataset._copy_listed` (e.g. when subsetting a Dataset by variable names). Both paths now consult diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d0df9bc061b..0dea8832066 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -5230,10 +5230,10 @@ def dot( def sortby( self, variables: ( - Hashable + str | DataArray - | Sequence[Hashable | DataArray] - | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]] + | Iterable[Hashable | DataArray] + | Callable[[Self], str | DataArray | Iterable[Hashable | DataArray]] ), ascending: bool = True, ) -> Self: @@ -5255,7 +5255,7 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + variables : str, DataArray, iterable of Hashable or DataArray, or Callable 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are used to sort this array. If a callable, the callable is passed this object, and the result is used as the value for cond. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a046c5d0f9e..29d25c9f2e5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8071,10 +8071,10 @@ def roll( def sortby( self, variables: ( - Hashable + str | DataArray - | Sequence[Hashable | DataArray] - | Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]] + | Iterable[Hashable | DataArray] + | Callable[[Self], str | DataArray | Iterable[Hashable | DataArray]] ), ascending: bool = True, ) -> Self: @@ -8097,7 +8097,7 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + variables : str, DataArray, iterable of Hashable or DataArray, or Callable 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are used to sort this array. If a callable, the callable is passed this object, and the result is used as the value for cond. @@ -8149,8 +8149,12 @@ def sortby( if callable(variables): variables = variables(self) - if not isinstance(variables, list): + if isinstance(variables, (str, DataArray)) or not isinstance( + variables, Iterable + ): variables = [variables] + else: + variables = list(variables) arrays = [v if isinstance(v, DataArray) else self[v] for v in variables] aligned_vars = align(self, *arrays, join="left") aligned_self = cast("Self", aligned_vars[0]) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8eb52046a31..d99be397085 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4558,6 +4558,11 @@ def test_sortby(self) -> None: actual = da.sortby(["x", "y"]) assert_equal(actual, expected) + # test tuple of dimension names (GH4821) + expected = sorted2d + actual = da.sortby(("x", "y")) + assert_equal(actual, expected) + @requires_bottleneck def test_rank(self) -> None: # floats diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index cd4f2c10eb2..54507890571 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7323,6 +7323,11 @@ def test_sortby(self) -> None: actual = ds.sortby(["x", "y"], ascending=False) assert_equal(actual, ds) + # test tuple of dimension names (GH4821) + expected = sorted2d + actual = ds.sortby(("x", "y")) + assert_equal(actual, expected) + def test_sortby_descending_nans(self) -> None: # Regression test for https://github.com/pydata/xarray/issues/7358 # NaN values should remain at the end when sorting in descending order