diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 5af901a6bfb6f..73ea5ff95e295 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -186,6 +186,11 @@ harness = false name = "signum" required-features = ["math_expressions"] +[[bench]] +harness = false +name = "atan2" +required-features = ["math_expressions"] + [[bench]] harness = false name = "substr_index" diff --git a/datafusion/functions/benches/atan2.rs b/datafusion/functions/benches/atan2.rs new file mode 100644 index 0000000000000..f1c9756a0cc08 --- /dev/null +++ b/datafusion/functions/benches/atan2.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::atan2; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let atan2_fn = atan2(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let y_f32 = Arc::new(create_primitive_array::(size, 0.2)); + let x_f32 = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(y_f32), ColumnarValue::Array(x_f32)]; + let f32_arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field_f32 = Field::new("f", DataType::Float32, true).into(); + + c.bench_function(&format!("atan2 f32 array: {size}"), |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: f32_args.clone(), + arg_fields: f32_arg_fields.clone(), + number_rows: size, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let y_f64 = Arc::new(create_primitive_array::(size, 0.2)); + let x_f64 = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(y_f64), ColumnarValue::Array(x_f64)]; + let f64_arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field_f64 = Field::new("f", DataType::Float64, true).into(); + + c.bench_function(&format!("atan2 f64 array: {size}"), |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: f64_args.clone(), + arg_fields: f64_arg_fields.clone(), + number_rows: size, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + } + + let scalar_f32_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), + ]; + let scalar_f32_arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("b", DataType::Float32, false).into(), + ]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("atan2 f32 scalar", |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), + ]; + let scalar_f64_arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("b", DataType::Float64, false).into(), + ]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("atan2 f64 scalar", |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 4adc331fef669..380877b593643 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -332,7 +332,8 @@ macro_rules! make_math_binary_udf { use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; - use datafusion_common::{Result, exec_err}; + use datafusion_common::utils::take_function_args; + use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::TypeSignature; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -393,37 +394,76 @@ macro_rules! make_math_binary_udf { &self, args: ScalarFunctionArgs, ) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float64Type>( - y, - x, - |y, x| f64::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ - } - DataType::Float32 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float32Type>( - y, - x, - |y, x| f32::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ - } - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ); + let ScalarFunctionArgs { + args, return_field, .. + } = args; + let return_type = return_field.data_type(); + let [y, x] = take_function_args(self.name(), args)?; + + match (y, x) { + ( + ColumnarValue::Scalar(y_scalar), + ColumnarValue::Scalar(x_scalar), + ) => match (&y_scalar, &x_scalar) { + (y, x) if y.is_null() || x.is_null() => { + ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(return_type, None) + } + ( + ScalarValue::Float64(Some(yv)), + ScalarValue::Float64(Some(xv)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + f64::$BINARY_FUNC(*yv, *xv), + )))), + ( + ScalarValue::Float32(Some(yv)), + ScalarValue::Float32(Some(xv)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some( + f32::$BINARY_FUNC(*yv, *xv), + )))), + _ => internal_err!( + "Unexpected scalar types for function {}: {:?}, {:?}", + self.name(), + y_scalar.data_type(), + x_scalar.data_type() + ), + }, + (y, x) => { + let args = ColumnarValue::values_to_arrays(&[y, x])?; + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = + arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = + arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + other => { + return internal_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) } - }; - - Ok(ColumnarValue::Array(arr)) + } } fn documentation(&self) -> Option<&Documentation> {