Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.pandas/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ Reindexing / Selection / Label manipulation
DataFrame.rename
DataFrame.rename_axis
DataFrame.reset_index
DataFrame.set_axis
DataFrame.set_index
DataFrame.swapaxes
DataFrame.swaplevel
Expand Down
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.pandas/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ Reindexing / Selection / Label manipulation
Series.reindex
Series.reindex_like
Series.reset_index
Series.set_axis
Series.sample
Series.searchsorted
Series.swaplevel
Expand Down
98 changes: 94 additions & 4 deletions python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4244,6 +4244,100 @@ def style(self) -> "Styler":
else:
return self._to_internal_pandas().style

def set_axis(
self,
labels: Union[pd.Index, List[Any], "Index"],
*,
axis: Axis = 0,
) -> "DataFrame":
"""
Assign desired index to given axis.

Parameters
----------
labels : list-like or Index
The values for the new index.
axis : {{0 or 'index', 1 or 'columns'}}, default 0
The axis to update.

Returns
-------
DataFrame
A new DataFrame with the updated axis labels.

Raises
------
ValueError
If the length of `labels` does not match the length of the axis being updated.

See Also
--------
DataFrame.rename : Alter the axis labels of :class:`DataFrame`.
DataFrame.set_index : Set the DataFrame index using existing columns.

Examples
--------
>>> df = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

Replace column labels (axis=1):

>>> df.set_axis(["a", "b"], axis=1)
a b
0 1 4
1 2 5
2 3 6

Replace index labels (axis=0):

>>> df.set_axis(["x", "y", "z"], axis=0) # doctest: +SKIP
A B
x 1 4
y 2 5
z 3 6
"""
axis = validate_axis(axis)
if axis == 1:
psdf = self.copy()
psdf.columns = (
labels.to_pandas()
if isinstance(labels, ps.Index)
else labels
if isinstance(labels, pd.Index)
else pd.Index(labels)
)
return psdf
else:
pdf_labels = labels.to_pandas() if isinstance(labels, ps.Index) else pd.Index(labels)

psdf = self.reset_index(drop=True)
sdf = psdf._internal.spark_frame

seq_col = verify_temp_column_name(sdf, "__set_axis_seq__")
sdf = InternalFrame.attach_distributed_sequence_column(sdf, seq_col)

pdf_index = pdf_labels.to_frame(index=False)
# Use temp names to avoid collisions with existing columns.
temp_index_columns = [
verify_temp_column_name(sdf, "__set_axis_index_{}__".format(i))
for i in range(len(pdf_index.columns))
]
pdf_index.columns = temp_index_columns
pdf_index[seq_col] = range(len(pdf_index))
sdf_labels = default_session().createDataFrame(pdf_index)

joined = sdf.join(sdf_labels, on=seq_col, how="inner").drop(seq_col)

internal = psdf._internal.copy(
spark_frame=joined,
index_spark_columns=[joined[n] for n in temp_index_columns],
index_names=[(n,) if not isinstance(n, tuple) else n for n in pdf_labels.names],
index_fields=[
InternalField.from_struct_field(joined.schema[n]) for n in temp_index_columns
],
data_spark_columns=[joined[n] for n in psdf._internal.data_spark_column_names],
)
return DataFrame(internal)

def set_index(
self,
keys: Union[Name, List[Name]],
Expand Down Expand Up @@ -14203,10 +14297,6 @@ def _infer_objects_fallback(self, *args: Any, **kwargs: Any) -> "DataFrame":
_f = self._build_fallback_method("infer_objects")
return _f(*args, **kwargs)

def _set_axis_fallback(self, *args: Any, **kwargs: Any) -> "DataFrame":
_f = self._build_fallback_method("set_axis")
return _f(*args, **kwargs)

def __getattr__(self, key: str) -> Any:
if key.startswith("__"):
raise AttributeError(key)
Expand Down
1 change: 0 additions & 1 deletion python/pyspark/pandas/missing/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class MissingPandasLikeDataFrame:
convert_dtypes = _unsupported_function("convert_dtypes")
infer_objects = _unsupported_function("infer_objects")
reorder_levels = _unsupported_function("reorder_levels")
set_axis = _unsupported_function("set_axis")
to_period = _unsupported_function("to_period")
to_sql = _unsupported_function("to_sql")
to_timestamp = _unsupported_function("to_timestamp")
Expand Down
1 change: 0 additions & 1 deletion python/pyspark/pandas/missing/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class MissingPandasLikeSeries:
convert_dtypes = _unsupported_function("convert_dtypes")
infer_objects = _unsupported_function("infer_objects")
reorder_levels = _unsupported_function("reorder_levels")
set_axis = _unsupported_function("set_axis")
to_period = _unsupported_function("to_period")
to_sql = _unsupported_function("to_sql")
to_timestamp = _unsupported_function("to_timestamp")
Expand Down
47 changes: 47 additions & 0 deletions python/pyspark/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,6 +1259,53 @@ def name(self) -> Name:
def name(self, name: Name) -> None:
self.rename(name, inplace=True)

# TODO: Currently, changing index labels taking dictionary/Series is not supported.
def set_axis(
self,
labels: Union[pd.Index, List[Any], "ps.Index"],
*,
axis: Axis = 0,
) -> "Series":
"""
Assign desired index to given axis.

Parameters
----------
labels : list-like or Index
The values for the new index.
axis : {{0 or 'index'}}, default 0
The axis to update. For Series, only 0/'index' is valid.

Returns
-------
Series
A new Series with the updated axis labels.

Raises
------
ValueError
If the length of `labels` does not match the length of the axis being updated,
or if axis is not 0/'index'.

See Also
--------
Series.rename : Alter Series index labels or name.

Examples
--------
>>> s = ps.Series([1, 2, 3])

>>> s.set_axis(["a", "b", "c"]) # doctest: +SKIP
a 1
b 2
c 3
dtype: int64
"""
axis = validate_axis(axis)
if axis != 0:
raise ValueError("No axis named {0} for object type Series".format(axis))
return first_series(self.to_frame().set_axis(labels, axis=0)).rename(self.name)

# TODO: Currently, changing index labels taking dictionary/Series is not supported.
def rename(
self, index: Optional[Union[Name, Callable[[Any], Any]]] = None, **kwargs: Any
Expand Down
45 changes: 45 additions & 0 deletions python/pyspark/pandas/tests/frame/test_reindexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,51 @@ def test_sample(self):
with self.assertRaises(NotImplementedError):
psdf.sample(n=1)

def test_set_axis(self):
pdf = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
psdf = ps.from_pandas(pdf)

# axis=1: replace column labels
self.assert_eq(
pdf.set_axis(["a", "b"], axis=1),
psdf.set_axis(["a", "b"], axis=1),
)

# axis=1: replace with pd.Index (preserves name)
self.assert_eq(
pdf.set_axis(pd.Index(["x", "y"], name="cols"), axis=1),
psdf.set_axis(pd.Index(["x", "y"], name="cols"), axis=1),
)

# axis=1: MultiIndex columns
midx = pd.MultiIndex.from_tuples([("X", "a"), ("Y", "b")])
self.assert_eq(
pdf.set_axis(midx, axis=1),
psdf.set_axis(midx, axis=1),
)

# axis=0: replace index labels
self.assert_eq(
pdf.set_axis(["x", "y", "z"], axis=0),
psdf.set_axis(["x", "y", "z"], axis=0).sort_index(),
)

# axis=0: default axis is 0
self.assert_eq(
pdf.set_axis(["x", "y", "z"]),
psdf.set_axis(["x", "y", "z"]).sort_index(),
)

# axis=0: replace with pd.Index (preserves name)
self.assert_eq(
pdf.set_axis(pd.Index(["a", "b", "c"], name="idx"), axis=0),
psdf.set_axis(pd.Index(["a", "b", "c"], name="idx"), axis=0).sort_index(),
)

# axis=1: length mismatch error
with self.assertRaises(ValueError):
psdf.set_axis(["a"], axis=1)


class FrameReidexingTests(
FrameReindexingMixin,
Expand Down
20 changes: 20 additions & 0 deletions python/pyspark/pandas/tests/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,26 @@ def test_transform(self):
):
psser.transform(lambda x: x + 1, axis=1)

def test_set_axis(self):
pser = pd.Series([1, 2, 3], name="x")
psser = ps.from_pandas(pser)

# axis=0: replace index labels
self.assert_eq(
pser.set_axis(["a", "b", "c"]),
psser.set_axis(["a", "b", "c"]).sort_index(),
)

# axis=0: replace with pd.Index (preserves name)
self.assert_eq(
pser.set_axis(pd.Index(["a", "b", "c"], name="idx")),
psser.set_axis(pd.Index(["a", "b", "c"], name="idx")).sort_index(),
)

# preserves Series name
result = psser.set_axis(["a", "b", "c"])
self.assertEqual(result.name, "x")


class SeriesTests(
SeriesTestsMixin,
Expand Down