From c9079626b0f607f130af798f25235b62672133c8 Mon Sep 17 00:00:00 2001 From: Henry Bui Date: Fri, 5 Jun 2026 14:39:15 +0700 Subject: [PATCH] [SPARK-55818][PANDAS/PS] Decimal-float mixed arithmetic should always raise TypeError In the Pandas API on Spark, when performing arithmetic between a float Series and a decimal.Decimal scalar, the behavior was inconsistent: - With ANSI mode ON: TypeError is raised (correct) - With ANSI mode OFF: Operation completes silently (incorrect) Native pandas always raises TypeError in this case. This commit fixes FractionalOps to always raise TypeError for decimal-float mixed arithmetic operations (add, sub), regardless of ANSI mode setting. Other operations (mul, truediv, floordiv, mod, rmul, rmod) already checked ANSI mode; they are unchanged as a separate concern. Resolves: https://issues.apache.org/jira/browse/SPARK-55818 --- .../pyspark/pandas/data_type_ops/num_ops.py | 29 +++++++ .../test_decimal_float_arithmetic.py | 86 +++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 python/pyspark/pandas/tests/data_type_ops/test_decimal_float_arithmetic.py diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index e2a3180331eb9..a5447e01f2efe 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -439,6 +439,35 @@ class FractionalOps(NumericOps): def pretty_name(self) -> str: return "fractions" + def add(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: + _sanitize_list_like(right) + if not is_valid_operand_for_numeric_arithmetic(right): + raise TypeError("Addition can not be applied to given types.") + # Always raise TypeError for decimal-float mixed operations (SPARK-55818) + # This matches pandas behavior regardless of ANSI mode. + if _is_decimal_float_mixed(left, right): + raise TypeError("Addition can not be applied to given types.") + new_right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type) + + def wrapped_add(lc: PySparkColumn, rc: Any) -> PySparkColumn: + return _cast_back_float(PySparkColumn.__add__(lc, rc), left.dtype, right) + + return column_op(wrapped_add)(left, new_right) + + def sub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: + _sanitize_list_like(right) + if not is_valid_operand_for_numeric_arithmetic(right): + raise TypeError("Subtraction can not be applied to given types.") + # Always raise TypeError for decimal-float mixed operations (SPARK-55818) + if _is_decimal_float_mixed(left, right): + raise TypeError("Subtraction can not be applied to given types.") + new_right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type) + + def wrapped_sub(lc: PySparkColumn, rc: Any) -> PySparkColumn: + return _cast_back_float(PySparkColumn.__sub__(lc, rc), left.dtype, right) + + return column_op(wrapped_sub)(left, new_right) + def mul(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) if not is_valid_operand_for_numeric_arithmetic(right): diff --git a/python/pyspark/pandas/tests/data_type_ops/test_decimal_float_arithmetic.py b/python/pyspark/pandas/tests/data_type_ops/test_decimal_float_arithmetic.py new file mode 100644 index 0000000000000..453184cfab43c --- /dev/null +++ b/python/pyspark/pandas/tests/data_type_ops/test_decimal_float_arithmetic.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import decimal + +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase + + +# This file contains test cases for SPARK-55818: +# "Decimal-float mixed arithmetic should always raise TypeError" +# https://issues.apache.org/jira/browse/SPARK-55818 +class DecimalFloatArithmeticMixin: + """ + Tests that float Series + decimal.Decimal scalar always raises TypeError, + matching native pandas behavior, regardless of ANSI mode. + """ + + def test_float_add_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + psser + decimal.Decimal("1.5") + + def test_float_sub_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + psser - decimal.Decimal("1.5") + + def test_float_mul_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + psser * decimal.Decimal("2") + + def test_float_truediv_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + psser / decimal.Decimal("2") + + def test_float_floordiv_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + psser // decimal.Decimal("2") + + def test_float_mod_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + psser % decimal.Decimal("2") + + def test_radd_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + decimal.Decimal("1.5") + psser + + def test_rsub_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + decimal.Decimal("1.5") - psser + + def test_rmul_decimal_raises(self): + psser = ps.Series([1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + decimal.Decimal("2") * psser + + +class DecimalFloatArithmeticTests(DecimalFloatArithmeticMixin, PandasOnSparkTestCase): + pass + + +if __name__ == "__main__": + import unittest + + unittest.main()