Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
- [ ] sequence
- [ ] shuffle
- [ ] slice
- [ ] sort_array
- [x] sort_array

### bitwise_funcs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[ArrayMin] -> CometArrayMin,
classOf[ArrayRemove] -> CometArrayRemove,
classOf[ArrayRepeat] -> CometArrayRepeat,
classOf[SortArray] -> CometSortArray,
classOf[ArraysOverlap] -> CometArraysOverlap,
classOf[ArrayUnion] -> CometArrayUnion,
classOf[CreateArray] -> CometCreateArray,
Expand Down
82 changes: 81 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ package org.apache.comet.serde

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.shims.CometExprShim
Expand Down Expand Up @@ -200,6 +201,85 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
}
}

object CometSortArray extends CometExpressionSerde[SortArray] {
private def containsFloatingPoint(dt: DataType): Boolean = {
dt match {
case FloatType | DoubleType => true
case ArrayType(elementType, _) => containsFloatingPoint(elementType)
case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
case MapType(keyType, valueType, _) =>
containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
case _ => false
}
}

private def canRank(dt: DataType, nestedInArray: Boolean = false): Boolean = {
dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: DecimalType =>
true
case _: DateType | _: TimestampType | _: TimestampNTZType =>
true
case _: NullType =>
true
case _: BooleanType | _: BinaryType | _: StringType =>
true
case ArrayType(elementType, _) =>
canRank(elementType, nestedInArray = true)
case StructType(fields) if !nestedInArray =>
fields.forall(f => canRank(f.dataType))
case _ =>
false
}
}

override def getSupportLevel(expr: SortArray): SupportLevel = {
val elementType = expr.base.dataType.asInstanceOf[ArrayType].elementType

if (!canRank(elementType)) {
Unsupported(Some(s"Sort on array element type $elementType is not supported"))
} else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
containsFloatingPoint(elementType)) {
Incompatible(
Some(
"Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
s"${CometConf.COMPAT_GUIDE}"))
} else {
Compatible()
}
}

override def convert(
expr: SortArray,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val arrayExprProto = exprToProtoInternal(expr.base, inputs, binding)
val sortDirectionExprProto = expr.ascendingOrder match {
case Literal(value: Boolean, BooleanType) =>
val direction = if (value) "ASC" else "DESC"
exprToProtoInternal(Literal(direction), inputs, binding)
case other =>
withInfo(expr, s"ascendingOrder must be a boolean literal: $other")
None
}
val nullOrderingExprProto = expr.ascendingOrder match {
case Literal(value: Boolean, BooleanType) =>
val nullOrdering = if (value) "NULLS FIRST" else "NULLS LAST"
exprToProtoInternal(Literal(nullOrdering), inputs, binding)
case _ => None
}

val sortArrayScalarExpr =
scalarFunctionExprToProto(
"array_sort",
arrayExprProto,
sortDirectionExprProto,
nullOrderingExprProto)
optExprWithInfo(sortArrayScalarExpr, expr, expr.children: _*)
}
}

object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] {

override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None)
Expand Down
199 changes: 199 additions & 0 deletions spark/src/test/resources/sql-tests/expressions/array/sort_array.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- ConfigMatrix: parquet.enable.dictionary=false,true

statement
CREATE TABLE test_sort_array_int(arr array<int>) USING parquet

statement
INSERT INTO test_sort_array_int VALUES
(array(3, 1, 4, 1, 5)),
(array(3, NULL, 1, NULL, 2)),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_int

query
SELECT sort_array(arr, true) FROM test_sort_array_int

query
SELECT sort_array(arr, false) FROM test_sort_array_int

statement
CREATE TABLE test_sort_array_string(arr array<string>) USING parquet

statement
INSERT INTO test_sort_array_string VALUES
(array('d', 'c', 'b', 'a')),
(array('b', NULL, 'a')),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_string

query
SELECT sort_array(arr, true) FROM test_sort_array_string

query
SELECT sort_array(arr, false) FROM test_sort_array_string

statement
CREATE TABLE test_sort_array_float(arr array<double>) USING parquet

statement
INSERT INTO test_sort_array_float VALUES
(array(CAST('NaN' AS DOUBLE), 3.0, 1.0, NULL, -0.0, 0.0)),
(array(CAST('NaN' AS DOUBLE), CAST('NaN' AS DOUBLE), -5.0, 2.0)),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_float

query
SELECT sort_array(arr, true) FROM test_sort_array_float

query
SELECT sort_array(arr, false) FROM test_sort_array_float

statement
CREATE TABLE test_sort_array_decimal(arr array<decimal(10, 0)>) USING parquet

statement
INSERT INTO test_sort_array_decimal VALUES
(array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0)))),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_decimal

query
SELECT sort_array(arr, true) FROM test_sort_array_decimal

query
SELECT sort_array(arr, false) FROM test_sort_array_decimal

statement
CREATE TABLE test_sort_array_boolean(arr array<boolean>) USING parquet

statement
INSERT INTO test_sort_array_boolean VALUES
(array(true, false, true, false)),
(array(true, false, true, NULL, false)),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_boolean

query
SELECT sort_array(arr, true) FROM test_sort_array_boolean

query
SELECT sort_array(arr, false) FROM test_sort_array_boolean

statement
CREATE TABLE test_sort_array_struct(arr array<struct<a:int,b:string>>) USING parquet

statement
INSERT INTO test_sort_array_struct VALUES
(array(
named_struct('a', 2, 'b', 'b'),
named_struct('a', 1, 'b', 'c'),
named_struct('a', 1, 'b', 'a'))),
(array(
named_struct('a', 2, 'b', NULL),
named_struct('a', 1, 'b', 'z'),
named_struct('a', 1, 'b', NULL))),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_struct

query
SELECT sort_array(arr, false) FROM test_sort_array_struct

statement
CREATE TABLE test_sort_array_nested(arr array<array<int>>) USING parquet

statement
INSERT INTO test_sort_array_nested VALUES
(array(array(2, 3), array(1), array(2, 1))),
(array(array(1, NULL), array(1), NULL)),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_nested

query
SELECT sort_array(arr, false) FROM test_sort_array_nested

statement
CREATE TABLE test_sort_array_nested_struct(arr array<array<struct<a:int>>>) USING parquet

statement
INSERT INTO test_sort_array_nested_struct VALUES
(array(
array(named_struct('a', 2)),
array(named_struct('a', 1)))),
(array()),
(NULL)

query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType)
SELECT sort_array(arr) FROM test_sort_array_nested_struct

query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType)
SELECT sort_array(arr, false) FROM test_sort_array_nested_struct

-- literal arguments
query
SELECT
sort_array(array(3, 1, 4, 1, 5)),
sort_array(array(3, 1, 4, 1, 5), true),
sort_array(array(3, NULL, 1, NULL, 2)),
sort_array(array(3, NULL, 1, NULL, 2), false),
sort_array(array(CAST('NaN' AS DOUBLE), 1.0, NULL, -0.0, 0.0)),
sort_array(array(CAST('NaN' AS DOUBLE), 1.0, NULL, -0.0, 0.0), false),
sort_array(array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0)))),
sort_array(
array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0))),
false),
sort_array(array(true, false, true, false)),
sort_array(array(true, false, true, NULL, false)),
sort_array(array(true, false, true, NULL, false), false),
sort_array(
array(
named_struct('a', 2, 'b', 'b'),
named_struct('a', 1, 'b', 'c'),
named_struct('a', 1, 'b', 'a'))),
sort_array(array(array(2, 3), array(1), array(2, 1))),
sort_array(array(array(1, NULL), array(1), NULL)),
sort_array(array(NULL, NULL)),
sort_array(cast(NULL as array<int>))

query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType)
SELECT sort_array(
array(
array(named_struct('a', 2)),
array(named_struct('a', 1))))
Loading