diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 5474894108..10da3b4dc7 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -105,7 +105,7 @@ - [ ] sequence - [ ] shuffle - [ ] slice -- [ ] sort_array +- [x] sort_array ### bitwise_funcs diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8c39ba779d..e8f0372409 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -60,6 +60,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArrayMin] -> CometArrayMin, classOf[ArrayRemove] -> CometArrayRemove, classOf[ArrayRepeat] -> CometArrayRepeat, + classOf[SortArray] -> CometSortArray, classOf[ArraysOverlap] -> CometArraysOverlap, classOf[ArrayUnion] -> CometArrayUnion, classOf[CreateArray] -> CometCreateArray, diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index c82018fe6d..700169d5f2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,11 +21,12 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.QueryPlanSerde._ import org.apache.comet.shims.CometExprShim @@ -200,6 +201,88 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] { } } +object CometSortArray extends CometExpressionSerde[SortArray] { + private def containsFloatingPoint(dt: DataType): Boolean = { + dt match { + case FloatType | DoubleType => true + case ArrayType(elementType, _) => containsFloatingPoint(elementType) + case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType)) + case MapType(keyType, valueType, _) => + containsFloatingPoint(keyType) || containsFloatingPoint(valueType) + case _ => false + } + } + + private def canRank(dt: DataType, nestedInArray: Boolean = false): Boolean = { + dt match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | + _: DoubleType | _: DecimalType => + true + case _: DateType | _: TimestampType | _: TimestampNTZType => + true + // DataFusion's array_sort compares nested arrays through Arrow's rank kernel. + // That kernel does not support Struct or Null child values, + // so array>> and array> would fail at runtime. + case _: NullType if !nestedInArray => + true + case _: BooleanType | _: BinaryType | _: StringType => + true + case ArrayType(elementType, _) => + canRank(elementType, nestedInArray = true) + case StructType(fields) if !nestedInArray => + fields.forall(f => canRank(f.dataType)) + case _ => + false + } + } + + override def getSupportLevel(expr: SortArray): SupportLevel = { + val elementType = expr.base.dataType.asInstanceOf[ArrayType].elementType + + if (!canRank(elementType)) { + Unsupported(Some(s"Sort on array element type $elementType is not supported")) + } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() && + containsFloatingPoint(elementType)) { + Incompatible( + Some( + "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " + + s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " + + s"${CometConf.COMPAT_GUIDE}")) + } else { + Compatible() + } + } + + override def convert( + expr: SortArray, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val arrayExprProto = exprToProtoInternal(expr.base, inputs, binding) + val sortDirectionExprProto = expr.ascendingOrder match { + case Literal(value: Boolean, BooleanType) => + val direction = if (value) "ASC" else "DESC" + exprToProtoInternal(Literal(direction), inputs, binding) + case other => + withInfo(expr, s"ascendingOrder must be a boolean literal: $other") + None + } + val nullOrderingExprProto = expr.ascendingOrder match { + case Literal(value: Boolean, BooleanType) => + val nullOrdering = if (value) "NULLS FIRST" else "NULLS LAST" + exprToProtoInternal(Literal(nullOrdering), inputs, binding) + case _ => None + } + + val sortArrayScalarExpr = + scalarFunctionExprToProto( + "array_sort", + arrayExprProto, + sortDirectionExprProto, + nullOrderingExprProto) + optExprWithInfo(sortArrayScalarExpr, expr, expr.children: _*) + } +} + object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] { override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None) diff --git a/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql b/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql new file mode 100644 index 0000000000..f118d0830f --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql @@ -0,0 +1,199 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_sort_array_int(arr array) USING parquet + +statement +INSERT INTO test_sort_array_int VALUES + (array(3, 1, 4, 1, 5)), + (array(3, NULL, 1, NULL, 2)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_int + +query +SELECT sort_array(arr, true) FROM test_sort_array_int + +query +SELECT sort_array(arr, false) FROM test_sort_array_int + +statement +CREATE TABLE test_sort_array_string(arr array) USING parquet + +statement +INSERT INTO test_sort_array_string VALUES + (array('d', 'c', 'b', 'a')), + (array('b', NULL, 'a')), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_string + +query +SELECT sort_array(arr, true) FROM test_sort_array_string + +query +SELECT sort_array(arr, false) FROM test_sort_array_string + +statement +CREATE TABLE test_sort_array_float(arr array) USING parquet + +statement +INSERT INTO test_sort_array_float VALUES + (array(CAST('NaN' AS DOUBLE), 3.0, 1.0, NULL, -0.0, 0.0)), + (array(CAST('NaN' AS DOUBLE), CAST('NaN' AS DOUBLE), -5.0, 2.0)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_float + +query +SELECT sort_array(arr, true) FROM test_sort_array_float + +query +SELECT sort_array(arr, false) FROM test_sort_array_float + +statement +CREATE TABLE test_sort_array_decimal(arr array) USING parquet + +statement +INSERT INTO test_sort_array_decimal VALUES + (array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0)))), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_decimal + +query +SELECT sort_array(arr, true) FROM test_sort_array_decimal + +query +SELECT sort_array(arr, false) FROM test_sort_array_decimal + +statement +CREATE TABLE test_sort_array_boolean(arr array) USING parquet + +statement +INSERT INTO test_sort_array_boolean VALUES + (array(true, false, true, false)), + (array(true, false, true, NULL, false)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_boolean + +query +SELECT sort_array(arr, true) FROM test_sort_array_boolean + +query +SELECT sort_array(arr, false) FROM test_sort_array_boolean + +statement +CREATE TABLE test_sort_array_struct(arr array>) USING parquet + +statement +INSERT INTO test_sort_array_struct VALUES + (array( + named_struct('a', 2, 'b', 'b'), + named_struct('a', 1, 'b', 'c'), + named_struct('a', 1, 'b', 'a'))), + (array( + named_struct('a', 2, 'b', NULL), + named_struct('a', 1, 'b', 'z'), + named_struct('a', 1, 'b', NULL))), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_struct + +query +SELECT sort_array(arr, false) FROM test_sort_array_struct + +statement +CREATE TABLE test_sort_array_nested(arr array>) USING parquet + +statement +INSERT INTO test_sort_array_nested VALUES + (array(array(2, 3), array(1), array(2, 1))), + (array(array(1, NULL), array(1), NULL)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_nested + +query +SELECT sort_array(arr, false) FROM test_sort_array_nested + +statement +CREATE TABLE test_sort_array_nested_struct(arr array>>) USING parquet + +statement +INSERT INTO test_sort_array_nested_struct VALUES + (array( + array(named_struct('a', 2)), + array(named_struct('a', 1)))), + (array()), + (NULL) + +query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +SELECT sort_array(arr) FROM test_sort_array_nested_struct + +query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +SELECT sort_array(arr, false) FROM test_sort_array_nested_struct + +-- literal arguments +query +SELECT + sort_array(array(3, 1, 4, 1, 5)), + sort_array(array(3, 1, 4, 1, 5), true), + sort_array(array(3, NULL, 1, NULL, 2)), + sort_array(array(3, NULL, 1, NULL, 2), false), + sort_array(array(CAST('NaN' AS DOUBLE), 1.0, NULL, -0.0, 0.0)), + sort_array(array(CAST('NaN' AS DOUBLE), 1.0, NULL, -0.0, 0.0), false), + sort_array(array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0)))), + sort_array( + array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0))), + false), + sort_array(array(true, false, true, false)), + sort_array(array(true, false, true, NULL, false)), + sort_array(array(true, false, true, NULL, false), false), + sort_array( + array( + named_struct('a', 2, 'b', 'b'), + named_struct('a', 1, 'b', 'c'), + named_struct('a', 1, 'b', 'a'))), + sort_array(array(array(2, 3), array(1), array(2, 1))), + sort_array(array(array(1, NULL), array(1), NULL)), + sort_array(array(NULL, NULL)), + sort_array(cast(NULL as array)) + +query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +SELECT sort_array( + array( + array(named_struct('a', 2)), + array(named_struct('a', 1))))