# FractionalOps arithmetic overrides
# (python/pyspark/pandas/data_type_ops/num_ops.py).
def add(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
    """Return ``left + right`` for a fractional-typed Series or Index.

    Parameters
    ----------
    left : IndexOpsLike
        The pandas-on-Spark object on the left-hand side.
    right : Any
        The right-hand operand.

    Raises
    ------
    TypeError
        If ``right`` is not a valid operand for numeric arithmetic, or if the
        two operands form a decimal-float mix (SPARK-55818).
    """
    _sanitize_list_like(right)
    # Both rejections share one message. The decimal-float check is applied
    # unconditionally (SPARK-55818) so that the outcome matches pandas
    # regardless of ANSI mode.
    if not is_valid_operand_for_numeric_arithmetic(right) or _is_decimal_float_mixed(
        left, right
    ):
        raise TypeError("Addition can not be applied to given types.")

    operand = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)

    def wrapped_add(lhs: PySparkColumn, rhs: Any) -> PySparkColumn:
        total = PySparkColumn.__add__(lhs, rhs)
        return _cast_back_float(total, left.dtype, right)

    return column_op(wrapped_add)(left, operand)


def sub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
    """Return ``left - right`` for a fractional-typed Series or Index.

    Parameters
    ----------
    left : IndexOpsLike
        The pandas-on-Spark object on the left-hand side.
    right : Any
        The right-hand operand.

    Raises
    ------
    TypeError
        If ``right`` is not a valid operand for numeric arithmetic, or if the
        two operands form a decimal-float mix (SPARK-55818).
    """
    _sanitize_list_like(right)
    # Same two-step rejection as ``add``; the decimal-float check is always
    # applied (SPARK-55818).
    if not is_valid_operand_for_numeric_arithmetic(right) or _is_decimal_float_mixed(
        left, right
    ):
        raise TypeError("Subtraction can not be applied to given types.")

    operand = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)

    def wrapped_sub(lhs: PySparkColumn, rhs: Any) -> PySparkColumn:
        difference = PySparkColumn.__sub__(lhs, rhs)
        return _cast_back_float(difference, left.dtype, right)

    return column_op(wrapped_sub)(left, operand)
# DataFrame method (python/pyspark/pandas/frame.py).
def set_axis(
    self,
    labels: Any,
    axis: Axis = 0,
) -> "DataFrame":
    """
    Assign desired index to given axis.

    Indexes for column or row labels can be changed by assigning a list-like or
    Index.

    Parameters
    ----------
    labels : list-like, Index
        The values for the new index.
    axis : {0 or 'index', 1 or 'columns'}, default 0
        The axis to update. The value 0 or 'index' identifies the row axis.

    Returns
    -------
    DataFrame
        A new DataFrame with the relabelled axis; the caller is not modified.

    Raises
    ------
    ValueError
        If ``axis`` is not a valid axis, or if the number of labels does not
        match the length of the selected axis.
    NotImplementedError
        If ``axis`` is 0 and either ``labels`` is a MultiIndex or the frame
        has multi-level columns.

    See Also
    --------
    DataFrame.rename_axis : Alter the name of the index or columns.

    Examples
    --------
    >>> df = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

    Change the row labels.

    >>> df.set_axis(['a', 'b', 'c'])
       A  B
    a  1  4
    b  2  5
    c  3  6

    Change the column labels.

    >>> df.set_axis(['I', 'II'], axis=1)
       I  II
    0  1   4
    1  2   5
    2  3   6

    >>> df.set_axis(['I', 'II'], axis='columns')
       I  II
    0  1   4
    1  2   5
    2  3   6
    """
    from pyspark.pandas.config import option_context

    axis = validate_axis(axis)
    if axis not in (0, 1):
        raise ValueError("No axis named %s for object type DataFrame" % axis)
    if not isinstance(labels, pd.Index):
        labels = pd.Index(labels)

    # Validate up front for both axes so a mismatch always surfaces as the
    # same pandas-style ValueError.
    expected = len(self.columns) if axis == 1 else len(self)
    if len(labels) != expected:
        raise ValueError(
            "Length mismatch: Expected axis has %d elements, "
            "new values have %d elements" % (expected, len(labels))
        )

    psdf = self.copy()

    if axis == 1:
        # Relabel by position. Going through a {old: new} mapping would
        # collapse duplicate column names onto a single new label.
        psdf.columns = labels
        return psdf

    # axis == 0. DataFrame.set_index selects *existing columns* by name, so it
    # cannot be handed the new labels directly. Attach them as a temporary
    # column first and promote that column to the index.
    if labels.nlevels > 1 or psdf.columns.nlevels > 1:
        raise NotImplementedError(
            "set_axis with axis=0 does not support MultiIndex labels or "
            "multi-level columns."
        )

    tmp_col = "__set_axis_index__"
    while tmp_col in psdf.columns:
        tmp_col = "_" + tmp_col

    # NOTE(review): the labels are attached positionally, i.e. following the
    # frame's current row order, as with any list-valued column assignment in
    # pandas-on-Spark -- confirm this is acceptable for unordered frames.
    with option_context("compute.ops_on_diff_frames", True):
        psdf[tmp_col] = labels.tolist()
    psdf = psdf.set_index(tmp_col)
    # The temporary column name must not leak out as the index name.
    psdf.index.name = labels.name
    return psdf
# Series method (python/pyspark/pandas/series.py).
def set_axis(
    self,
    labels: Any,
    axis: Axis = 0,
) -> "Series":
    """
    Assign desired index to given axis.

    Row labels can be replaced by passing a list-like or an Index.

    Parameters
    ----------
    labels : list-like, Index
        The values for the new index.
    axis : {0 or 'index'}, default 0
        The axis to update. The value 0 or 'index' identifies the row axis.

    Returns
    -------
    Series
        An object of same type as caller.

    Raises
    ------
    ValueError
        If ``axis`` does not resolve to the row axis, or if the number of
        labels differs from the length of the Series.

    See Also
    --------
    Series.rename_axis : Alter the name of the index or columns.

    Examples
    --------
    >>> s = ps.Series([1, 2, 3])
    >>> s.set_axis(['a', 'b', 'c'])
    a    1
    b    2
    c    3
    dtype: int64
    """
    from pyspark.pandas.utils import validate_axis

    # A Series only has a row axis; reject anything else.
    if validate_axis(axis) != 0:
        raise ValueError("No axis named %s for object type Series" % validate_axis(axis))

    new_index = labels if isinstance(labels, pd.Index) else pd.Index(labels)

    if len(new_index) != len(self):
        raise ValueError(
            "Length mismatch: Expected axis has %d elements, "
            "new values have %d elements" % (len(self), len(new_index))
        )

    # NOTE(review): this relies on the intermediate DataFrame accepting
    # assignment to ``index`` -- confirm that setter is supported.
    frame = self.to_frame()
    frame.index = new_index
    return first_series(frame)
# Test cases for SPARK-55818:
# "Decimal-float mixed arithmetic should always raise TypeError"
# https://issues.apache.org/jira/browse/SPARK-55818
# (python/pyspark/pandas/tests/data_type_ops/test_decimal_float_arithmetic.py)
class DecimalFloatArithmeticMixin:
    """
    Checks that arithmetic between a float Series and a ``decimal.Decimal``
    scalar raises TypeError, matching native pandas behavior, regardless of
    ANSI mode.
    """

    def _assert_type_error(self, operation):
        # Build a fresh float Series for every case and require that applying
        # ``operation`` to it raises TypeError.
        psser = ps.Series([1.0, 2.0, 3.0])
        with self.assertRaises(TypeError):
            operation(psser)

    # Series on the left-hand side.

    def test_float_add_decimal_raises(self):
        self._assert_type_error(lambda psser: psser + decimal.Decimal("1.5"))

    def test_float_sub_decimal_raises(self):
        self._assert_type_error(lambda psser: psser - decimal.Decimal("1.5"))

    def test_float_mul_decimal_raises(self):
        self._assert_type_error(lambda psser: psser * decimal.Decimal("2"))

    def test_float_truediv_decimal_raises(self):
        self._assert_type_error(lambda psser: psser / decimal.Decimal("2"))

    def test_float_floordiv_decimal_raises(self):
        self._assert_type_error(lambda psser: psser // decimal.Decimal("2"))

    def test_float_mod_decimal_raises(self):
        self._assert_type_error(lambda psser: psser % decimal.Decimal("2"))

    # Decimal scalar on the left-hand side (reflected operators).

    def test_radd_decimal_raises(self):
        self._assert_type_error(lambda psser: decimal.Decimal("1.5") + psser)

    def test_rsub_decimal_raises(self):
        self._assert_type_error(lambda psser: decimal.Decimal("1.5") - psser)

    def test_rmul_decimal_raises(self):
        self._assert_type_error(lambda psser: decimal.Decimal("2") * psser)


class DecimalFloatArithmeticTests(DecimalFloatArithmeticMixin, PandasOnSparkTestCase):
    pass


if __name__ == "__main__":
    import unittest

    unittest.main()
# Test cases for 'DataFrame.set_axis'
# https://issues.apache.org/jira/browse/SPARK-56375
# (python/pyspark/pandas/tests/frame/test_set_axis.py)
class FrameSetAxisMixin:
    """Compares ``DataFrame.set_axis`` in pandas-on-Spark against pandas."""

    def test_set_axis_index(self):
        pdf = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
        psdf = ps.from_pandas(pdf)

        cases = [
            (["x", "y", "z"], {}),  # axis=0 (index), list labels
            (["a", "b", "c"], {"axis": "index"}),  # axis=0 with "index" string
            (pd.Index([10, 20, 30]), {}),  # axis=0 with pd.Index
        ]
        for labels, kwargs in cases:
            self.assert_eq(pdf.set_axis(labels, **kwargs), psdf.set_axis(labels, **kwargs))

    def test_set_axis_columns(self):
        pdf = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
        psdf = ps.from_pandas(pdf)

        cases = [
            (["I", "II"], 1),  # axis=1 (columns), list labels
            (["col1", "col2"], "columns"),  # axis=1 with "columns" string
            (pd.Index(["X", "Y"]), 1),  # axis=1 with pd.Index
        ]
        for labels, axis in cases:
            self.assert_eq(pdf.set_axis(labels, axis=axis), psdf.set_axis(labels, axis=axis))

    def test_set_axis_errors(self):
        psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

        failing_calls = [
            (["x", "y"], {}),  # only 2 labels for 3 rows
            (["X"], {"axis": 1}),  # only 1 label for 2 columns
            (["x", "y", "z"], {"axis": 2}),  # invalid axis
        ]
        for labels, kwargs in failing_calls:
            with self.assertRaises(ValueError):
                psdf.set_axis(labels, **kwargs)

    def test_set_axis_numeric_index(self):
        pdf = pd.DataFrame({"A": [10, 20, 30], "B": [40, 50, 60]}, index=[0, 1, 2])
        psdf = ps.from_pandas(pdf)
        new_labels = [100, 200, 300]

        self.assert_eq(pdf.set_axis(new_labels), psdf.set_axis(new_labels))


class FrameSetAxisTests(FrameSetAxisMixin, PandasOnSparkTestCase):
    pass


if __name__ == "__main__":
    import unittest

    from pyspark.testing.utils import PySparkTestCase  # noqa

    unittest.main()
# Test cases for 'Series.set_axis'
# https://issues.apache.org/jira/browse/SPARK-56375
# (python/pyspark/pandas/tests/series/test_set_axis.py)
class SeriesSetAxisMixin:
    """Compares ``Series.set_axis`` in pandas-on-Spark against pandas."""

    def test_set_axis_index(self):
        pser = pd.Series([1, 2, 3])
        psser = ps.from_pandas(pser)

        cases = [
            (["a", "b", "c"], {}),  # list labels
            (["x", "y", "z"], {"axis": "index"}),  # axis="index" string
            (pd.Index([10, 20, 30]), {}),  # pd.Index
        ]
        for labels, kwargs in cases:
            self.assert_eq(pser.set_axis(labels, **kwargs), psser.set_axis(labels, **kwargs))

    def test_set_axis_errors(self):
        psser = ps.Series([1, 2, 3])

        failing_calls = [
            (["a", "b"], {}),  # only 2 labels for 3 elements
            (["a", "b", "c"], {"axis": 1}),  # Series only supports 0/'index'
        ]
        for labels, kwargs in failing_calls:
            with self.assertRaises(ValueError):
                psser.set_axis(labels, **kwargs)

    def test_set_axis_named(self):
        pser = pd.Series([10, 20, 30], name="myval")
        psser = ps.from_pandas(pser)
        new_labels = ["x", "y", "z"]

        self.assert_eq(pser.set_axis(new_labels), psser.set_axis(new_labels))


class SeriesSetAxisTests(SeriesSetAxisMixin, PandasOnSparkTestCase):
    pass


if __name__ == "__main__":
    import unittest

    unittest.main()