Skip to content
Merged
8 changes: 8 additions & 0 deletions datafusion/spark/src/function/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub mod floor;
pub mod hex;
pub mod modulus;
pub mod negative;
pub mod pow;
pub mod rint;
pub mod round;
pub mod trigonometry;
Expand All @@ -42,6 +43,7 @@ make_udf_function!(floor::SparkFloor, floor);
make_udf_function!(hex::SparkHex, hex);
make_udf_function!(modulus::SparkMod, modulus);
make_udf_function!(modulus::SparkPmod, pmod);
make_udf_function!(pow::SparkPow, pow);
make_udf_function!(rint::SparkRint, rint);
make_udf_function!(round::SparkRound, round);
make_udf_function!(unhex::SparkUnhex, unhex);
Expand All @@ -66,6 +68,11 @@ pub mod expr_fn {
export_functions!((hex, "Computes hex value of the given column.", arg1));
export_functions!((modulus, "Returns the remainder of division of the first argument by the second argument.", arg1 arg2));
export_functions!((pmod, "Returns the positive remainder of division of the first argument by the second argument.", arg1 arg2));
export_functions!((
pow,
"Returns base raised to the power of exponent. Returns Infinity for pow(0, negative).",
arg1 arg2
));
export_functions!((
rint,
"Returns the double value that is closest in value to the argument and is equal to a mathematical integer.",
Expand Down Expand Up @@ -102,6 +109,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
hex(),
modulus(),
pmod(),
pow(),
rint(),
round(),
unhex(),
Expand Down
152 changes: 152 additions & 0 deletions datafusion/spark/src/function/math/pow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Spark-compatible `pow` / `power` function.
//!
//! Unlike the default DataFusion (PostgreSQL) implementation, Spark returns
//! `Infinity` for `pow(0, <negative>)` rather than raising an error.

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Float64Array};
use arrow::datatypes::DataType;

use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
};
use datafusion_functions::math::power::PowerFunc;

/// Spark-compatible implementation of `pow` / `power`.
///
/// This type wraps the standard DataFusion [`PowerFunc`] and only overrides
/// evaluation when both arguments are `Float64`; every other argument
/// combination is forwarded to the wrapped function unchanged.
///
/// Behavioural difference from the DataFusion default:
/// - `pow(0, <negative>)` → `Infinity` (IEEE 754 / Spark semantics).
///   The default raises `"zero raised to a negative power is undefined"` to
///   match PostgreSQL.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkPow {
/// Standard DataFusion power function. Supplies the signature, return
/// type and documentation, and evaluates all non-`Float64` inputs.
inner: PowerFunc,
/// Extra names under which this function is exposed (set to `"power"`
/// by [`SparkPow::new`]).
aliases: Vec<String>,
}

impl Default for SparkPow {
fn default() -> Self {
Self::new()
}
}

impl SparkPow {
    /// Creates the Spark-flavoured `pow` function.
    ///
    /// The function itself is named `pow`; `power` is registered as an alias
    /// so that both spellings resolve to the Spark semantics whenever this
    /// crate's functions are active.
    pub fn new() -> Self {
        let inner = PowerFunc::new();
        let aliases = vec![String::from("power")];
        Self { inner, aliases }
    }
}

impl ScalarUDFImpl for SparkPow {
fn name(&self) -> &str {
"pow"
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn signature(&self) -> &Signature {
self.inner.signature()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// Only Float64 × Float64 needs the Spark override.
// Decimal / integer / mixed-type paths are delegated to the standard
// PowerFunc which already handles them correctly (decimal can't
// represent Infinity anyway).
match args.args.as_slice() {
[base, exponent]
if matches!(base.data_type(), DataType::Float64)
&& matches!(exponent.data_type(), DataType::Float64) => {}
_ => return self.inner.invoke_with_args(args),
}

let num_rows = args.number_rows;

// ── Scalar × Scalar fast path ────────────────────────────────────────
// Pattern-match on the slice to avoid any ownership issues.
if let [
ColumnarValue::Scalar(ScalarValue::Float64(base)),
ColumnarValue::Scalar(ScalarValue::Float64(exp)),
] = args.args.as_slice()
{
// base and exp are &Option<f64>; Option<f64> is Copy.
let result = (*base).zip(*exp).map(|(base, exp)| {
if base == 0.0 && exp < 0.0 {
f64::INFINITY
} else {
base.powf(exp)
}
});
return Ok(ColumnarValue::Scalar(ScalarValue::Float64(result)));
}

// ── Array path ───────────────────────────────────────────────────────
let [base, exponent] = take_function_args(self.name(), &args.args)?;

let base_arr: ArrayRef = base.to_array(num_rows)?;
let exp_arr: ArrayRef = exponent.to_array(num_rows)?;
Comment on lines +113 to +117

let base_f64 = base_arr
.as_any()
.downcast_ref::<Float64Array>()
.expect("base must be Float64Array");
let exp_f64 = exp_arr
.as_any()
.downcast_ref::<Float64Array>()
.expect("exponent must be Float64Array");

// Spark: 0^negative = +Infinity (covers both 0.0 and -0.0)
// IEEE 754: 0.0^-1.0 = +Infinity, -0.0^-1.0 = -Infinity
// Thus we need an explicit guard for base == 0.0 to ensure +Infinity.
let result: Float64Array = base_f64
.iter()
.zip(exp_f64.iter())
.map(|(base, exp)| match (base, exp) {
(Some(base), Some(exp)) => {
if base == 0.0 && exp < 0.0 {
Some(f64::INFINITY)
} else {
Some(base.powf(exp))
}
}
_ => None,
})
.collect();

Ok(ColumnarValue::Array(Arc::new(result)))
}

fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
}
154 changes: 151 additions & 3 deletions datafusion/sqllogictest/test_files/spark/math/pow.slt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,154 @@
# https://github.com/apache/datafusion/issues/15914

## Original Query: SELECT pow(2, 3);
## PySpark 3.5.5 Result: {'pow(2, 3)': 8.0, 'typeof(pow(2, 3))': 'double', 'typeof(2)': 'int', 'typeof(3)': 'int'}
#query
#SELECT pow(2::int, 3::int);
## PySpark 3.5.5 Result: {'pow(2, 3)': 8.0, 'typeof(pow(2, 3))': 'double'}
## DataFusion: pow(int, int) returns int. Sqllogictest prints 8.
query R
SELECT pow(2::int, 3::int);
----
8

## Spark returns Infinity for pow(0, negative) — see https://github.com/apache/datafusion/issues/22598
## PostgreSQL / DataFusion default raises an error instead.
## PySpark 3.5.5: spark.sql("select pow(0, -1)").show() => Infinity

query R
SELECT pow(0::double, -1::double);
----
Infinity

query R
SELECT power(0::double, -1::double);
----
Infinity

query R
SELECT pow(0.0, -1.0);
----
Infinity

# nulls
query R
SELECT pow(CAST(NULL AS DOUBLE), 1.0);
----
NULL

query R
SELECT pow(1.0, CAST(NULL AS DOUBLE));
----
NULL

query R
SELECT pow(CAST(NULL AS DOUBLE), CAST(NULL AS DOUBLE));
----
NULL

# nans
query R
SELECT pow(CAST('NaN' AS DOUBLE), 1.0);
----
NaN

query R
SELECT pow(1.0, CAST('NaN' AS DOUBLE));
----
1

query R
SELECT pow(CAST('NaN' AS DOUBLE), 0.0);
----
1

# -0, +0
query R
SELECT pow(0.0, 1.0);
----
0

query R
SELECT pow(CAST('-0.0' AS DOUBLE), 1.0);
----
0

query R
SELECT pow(0.0, -1.0);
----
Infinity

query R
SELECT pow(CAST('-0.0' AS DOUBLE), -1.0);
----
Infinity

# -inf, +inf
query R
SELECT pow(CAST('Infinity' AS DOUBLE), 1.0);
----
Infinity

query R
SELECT pow(CAST('Infinity' AS DOUBLE), -1.0);
----
0

query R
SELECT pow(CAST('-Infinity' AS DOUBLE), 1.0);
----
-Infinity

query R
SELECT pow(CAST('-Infinity' AS DOUBLE), 2.0);
----
Infinity

query R
SELECT pow(2.0, CAST('Infinity' AS DOUBLE));
----
Infinity

query R
SELECT pow(0.5, CAST('Infinity' AS DOUBLE));
----
0

query R
SELECT pow(2.0, CAST('-Infinity' AS DOUBLE));
----
0

query R
SELECT pow(0.5, CAST('-Infinity' AS DOUBLE));
----
Infinity

# Test Array x Array
statement ok
CREATE TABLE t1(a DOUBLE, b DOUBLE) AS VALUES
(0.0, -1.0),
(2.0, 3.0),
(CAST(NULL AS DOUBLE), 1.0);

query R
SELECT pow(a, b) FROM t1;
----
Infinity
8
NULL

statement ok
DROP TABLE t1;

# Test Scalar x Array
statement ok
CREATE TABLE t2(b DOUBLE) AS VALUES
(-1.0),
(2.0);

query R
SELECT pow(0.0, b) FROM t2;
----
Infinity
0

statement ok
DROP TABLE t2;
Loading