diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst index ccecb360a5293..aabcf671ff0f9 100644 --- a/python/docs/source/reference/pyspark.pandas/frame.rst +++ b/python/docs/source/reference/pyspark.pandas/frame.rst @@ -203,6 +203,7 @@ Reindexing / Selection / Label manipulation DataFrame.rename DataFrame.rename_axis DataFrame.reset_index + DataFrame.set_axis DataFrame.set_index DataFrame.swapaxes DataFrame.swaplevel diff --git a/python/docs/source/reference/pyspark.pandas/series.rst b/python/docs/source/reference/pyspark.pandas/series.rst index 67dc3582c27dc..b5aa672d531a4 100644 --- a/python/docs/source/reference/pyspark.pandas/series.rst +++ b/python/docs/source/reference/pyspark.pandas/series.rst @@ -197,6 +197,7 @@ Reindexing / Selection / Label manipulation Series.reindex Series.reindex_like Series.reset_index + Series.set_axis Series.sample Series.searchsorted Series.swaplevel diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index fb86e37999eb4..f0303ceb77f3d 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -4244,6 +4244,100 @@ def style(self) -> "Styler": else: return self._to_internal_pandas().style + def set_axis( + self, + labels: Union[pd.Index, List[Any], "Index"], + *, + axis: Axis = 0, + ) -> "DataFrame": + """ + Assign desired index to given axis. + + Parameters + ---------- + labels : list-like or Index + The values for the new index. + axis : {{0 or 'index', 1 or 'columns'}}, default 0 + The axis to update. + + Returns + ------- + DataFrame + A new DataFrame with the updated axis labels. + + Raises + ------ + ValueError + If the length of `labels` does not match the length of the axis being updated. + + See Also + -------- + DataFrame.rename : Alter the axis labels of :class:`DataFrame`. + DataFrame.set_index : Set the DataFrame index using existing columns. 
+ + Examples + -------- + >>> df = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + Replace column labels (axis=1): + + >>> df.set_axis(["a", "b"], axis=1) + a b + 0 1 4 + 1 2 5 + 2 3 6 + + Replace index labels (axis=0): + + >>> df.set_axis(["x", "y", "z"], axis=0) # doctest: +SKIP + A B + x 1 4 + y 2 5 + z 3 6 + """ + axis = validate_axis(axis) + if axis == 1: + psdf = self.copy() + psdf.columns = ( + labels.to_pandas() + if isinstance(labels, ps.Index) + else labels + if isinstance(labels, pd.Index) + else pd.Index(labels) + ) + return psdf + else: + pdf_labels = labels.to_pandas() if isinstance(labels, ps.Index) else pd.Index(labels) + + psdf = self.reset_index(drop=True) + sdf = psdf._internal.spark_frame + + seq_col = verify_temp_column_name(sdf, "__set_axis_seq__") + sdf = InternalFrame.attach_distributed_sequence_column(sdf, seq_col) + + pdf_index = pdf_labels.to_frame(index=False) + # Use temp names to avoid collisions with existing columns. + temp_index_columns = [ + verify_temp_column_name(sdf, "__set_axis_index_{}__".format(i)) + for i in range(len(pdf_index.columns)) + ] + pdf_index.columns = temp_index_columns + pdf_index[seq_col] = range(len(pdf_index)) + sdf_labels = default_session().createDataFrame(pdf_index) + + joined = sdf.join(sdf_labels, on=seq_col, how="inner").drop(seq_col) + + internal = psdf._internal.copy( + spark_frame=joined, + index_spark_columns=[joined[n] for n in temp_index_columns], + index_names=[(n,) if not isinstance(n, tuple) else n for n in pdf_labels.names], + index_fields=[ + InternalField.from_struct_field(joined.schema[n]) for n in temp_index_columns + ], + data_spark_columns=[joined[n] for n in psdf._internal.data_spark_column_names], + ) + return DataFrame(internal) + def set_index( self, keys: Union[Name, List[Name]], @@ -14203,10 +14297,6 @@ def _infer_objects_fallback(self, *args: Any, **kwargs: Any) -> "DataFrame": _f = self._build_fallback_method("infer_objects") return _f(*args, **kwargs) - def 
_set_axis_fallback(self, *args: Any, **kwargs: Any) -> "DataFrame": - _f = self._build_fallback_method("set_axis") - return _f(*args, **kwargs) - def __getattr__(self, key: str) -> Any: if key.startswith("__"): raise AttributeError(key) diff --git a/python/pyspark/pandas/missing/frame.py b/python/pyspark/pandas/missing/frame.py index bdfa7574dc3d3..1299f2dbde618 100644 --- a/python/pyspark/pandas/missing/frame.py +++ b/python/pyspark/pandas/missing/frame.py @@ -41,7 +41,6 @@ class MissingPandasLikeDataFrame: convert_dtypes = _unsupported_function("convert_dtypes") infer_objects = _unsupported_function("infer_objects") reorder_levels = _unsupported_function("reorder_levels") - set_axis = _unsupported_function("set_axis") to_period = _unsupported_function("to_period") to_sql = _unsupported_function("to_sql") to_timestamp = _unsupported_function("to_timestamp") diff --git a/python/pyspark/pandas/missing/series.py b/python/pyspark/pandas/missing/series.py index 08f21f46b2cc1..cb3400e1e8668 100644 --- a/python/pyspark/pandas/missing/series.py +++ b/python/pyspark/pandas/missing/series.py @@ -39,7 +39,6 @@ class MissingPandasLikeSeries: convert_dtypes = _unsupported_function("convert_dtypes") infer_objects = _unsupported_function("infer_objects") reorder_levels = _unsupported_function("reorder_levels") - set_axis = _unsupported_function("set_axis") to_period = _unsupported_function("to_period") to_sql = _unsupported_function("to_sql") to_timestamp = _unsupported_function("to_timestamp") diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 6496137824de2..7a413d4f84ac0 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -1259,6 +1259,53 @@ def name(self) -> Name: def name(self, name: Name) -> None: self.rename(name, inplace=True) + # TODO: Currently, changing index labels taking dictionary/Series is not supported. 
def set_axis(
    self,
    labels: Union[pd.Index, List[Any], "ps.Index"],
    *,
    axis: Axis = 0,
) -> "Series":
    """
    Assign desired index to given axis.

    Parameters
    ----------
    labels : list-like or Index
        The values for the new index. Must contain exactly one label per row.
    axis : {0 or 'index'}, default 0
        The axis to update. For Series, only 0/'index' is valid.

    Returns
    -------
    Series
        A new Series with the updated axis labels. The Series name is preserved.

    Raises
    ------
    ValueError
        If the length of `labels` does not match the length of the Series,
        or if axis is not 0/'index'.

    See Also
    --------
    Series.rename : Alter Series index labels or name.

    Examples
    --------
    >>> s = ps.Series([1, 2, 3])

    >>> s.set_axis(["a", "b", "c"])  # doctest: +SKIP
    a    1
    b    2
    c    3
    dtype: int64
    """
    if validate_axis(axis) != 0:
        # Report the axis the way the caller spelled it (e.g. 'columns'),
        # not the integer it was normalized to.
        raise ValueError("No axis named {0} for object type Series".format(axis))

    # DataFrame.set_axis(axis=0) lines rows up with labels through an inner join on
    # a generated sequence column, so a wrong-sized `labels` would silently drop
    # rows (or labels) instead of failing. Validate up front to honour the
    # documented ValueError.
    # NOTE(review): len(self) presumably triggers a Spark count job - confirm the
    # extra pass is acceptable here.
    num_labels, num_rows = len(labels), len(self)
    if num_labels != num_rows:
        raise ValueError(
            "Length mismatch: Expected axis has {0} elements, "
            "new values have {1} elements".format(num_rows, num_labels)
        )

    # Round-trip through a single-column frame and restore the original name.
    return first_series(self.to_frame().set_axis(labels, axis=0)).rename(self.name)

# TODO: Currently, changing index labels taking dictionary/Series is not supported.
def rename( self, index: Optional[Union[Name, Callable[[Any], Any]]] = None, **kwargs: Any diff --git a/python/pyspark/pandas/tests/frame/test_reindexing.py b/python/pyspark/pandas/tests/frame/test_reindexing.py index 3e170ad2f4495..879d29aa3e979 100644 --- a/python/pyspark/pandas/tests/frame/test_reindexing.py +++ b/python/pyspark/pandas/tests/frame/test_reindexing.py @@ -956,6 +956,51 @@ def test_sample(self): with self.assertRaises(NotImplementedError): psdf.sample(n=1) + def test_set_axis(self): + pdf = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + psdf = ps.from_pandas(pdf) + + # axis=1: replace column labels + self.assert_eq( + pdf.set_axis(["a", "b"], axis=1), + psdf.set_axis(["a", "b"], axis=1), + ) + + # axis=1: replace with pd.Index (preserves name) + self.assert_eq( + pdf.set_axis(pd.Index(["x", "y"], name="cols"), axis=1), + psdf.set_axis(pd.Index(["x", "y"], name="cols"), axis=1), + ) + + # axis=1: MultiIndex columns + midx = pd.MultiIndex.from_tuples([("X", "a"), ("Y", "b")]) + self.assert_eq( + pdf.set_axis(midx, axis=1), + psdf.set_axis(midx, axis=1), + ) + + # axis=0: replace index labels + self.assert_eq( + pdf.set_axis(["x", "y", "z"], axis=0), + psdf.set_axis(["x", "y", "z"], axis=0).sort_index(), + ) + + # axis=0: default axis is 0 + self.assert_eq( + pdf.set_axis(["x", "y", "z"]), + psdf.set_axis(["x", "y", "z"]).sort_index(), + ) + + # axis=0: replace with pd.Index (preserves name) + self.assert_eq( + pdf.set_axis(pd.Index(["a", "b", "c"], name="idx"), axis=0), + psdf.set_axis(pd.Index(["a", "b", "c"], name="idx"), axis=0).sort_index(), + ) + + # axis=1: length mismatch error + with self.assertRaises(ValueError): + psdf.set_axis(["a"], axis=1) + class FrameReidexingTests( FrameReindexingMixin, diff --git a/python/pyspark/pandas/tests/series/test_series.py b/python/pyspark/pandas/tests/series/test_series.py index d4578bc01a935..6169f9b95784c 100644 --- a/python/pyspark/pandas/tests/series/test_series.py +++ 
b/python/pyspark/pandas/tests/series/test_series.py @@ -874,6 +874,26 @@ def test_transform(self): ): psser.transform(lambda x: x + 1, axis=1) + def test_set_axis(self): + pser = pd.Series([1, 2, 3], name="x") + psser = ps.from_pandas(pser) + + # axis=0: replace index labels + self.assert_eq( + pser.set_axis(["a", "b", "c"]), + psser.set_axis(["a", "b", "c"]).sort_index(), + ) + + # axis=0: replace with pd.Index (preserves name) + self.assert_eq( + pser.set_axis(pd.Index(["a", "b", "c"], name="idx")), + psser.set_axis(pd.Index(["a", "b", "c"], name="idx")).sort_index(), + ) + + # preserves Series name + result = psser.set_axis(["a", "b", "c"]) + self.assertEqual(result.name, "x") + class SeriesTests( SeriesTestsMixin,