diff --git a/native/Cargo.lock b/native/Cargo.lock index 328e8b3727..e8957e0065 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1938,6 +1938,7 @@ dependencies = [ "regex", "serde", "serde_json", + "sha2", "thiserror 2.0.18", "tokio", "twox-hash", diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 361deae182..8962f66494 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -390,7 +390,6 @@ fn prepare_datafusion_session_context( // register UDFs from datafusion-spark crate fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default())); - session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default())); @@ -398,6 +397,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkLastDay::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkNextDay::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default())); diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala index b059199735..a138ea023d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/hash.scala +++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, Sha1, Sha2, XxHash64} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal, Murmur3Hash, Sha1, Sha2, XxHash64} import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, IntegerType, LongType, MapType, StringType, StructType} import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -76,6 +76,12 @@ object CometSha2 extends CometExpressionSerde[Sha2] { return None } + // Fall back to Spark for literal input to avoid native engine crash (#3340) + if (expr.left.isInstanceOf[Literal]) { + withInfo(expr, "Sha2 with literal input falls back to Spark") + return None + } + // It's possible for spark to dynamically compute the number of bits from input // expression, however DataFusion does not support that yet. if (!expr.right.foldable) { diff --git a/spark/src/test/resources/sql-tests/expressions/hash/hash.sql b/spark/src/test/resources/sql-tests/expressions/hash/hash.sql index 35031ea7e4..550f34b13a 100644 --- a/spark/src/test/resources/sql-tests/expressions/hash/hash.sql +++ b/spark/src/test/resources/sql-tests/expressions/hash/hash.sql @@ -25,8 +25,10 @@ statement INSERT INTO test VALUES ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999), ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) query -SELECT md5(col), md5(cast(a as string)), md5(cast(b as string)), hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) FROM test +SELECT md5(col), md5(cast(a as string)), md5(cast(b as string)), hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), sha1(col), sha1(cast(a as string)), sha1(cast(b as string)), sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) FROM test -- literal arguments -query ignore(https://github.com/apache/datafusion-comet/issues/3340) +-- sha2 with literal input falls back to Spark to avoid native engine crash (#3340) +query expect_fallback(Sha2 with literal input falls back to Spark) SELECT md5('Spark SQL'), sha1('test'), sha2('test', 256), hash('test'), xxhash64('test') + diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index eeaf1ed911..cbde9ed6b6 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1891,8 +1891,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), |crc32(col), crc32(cast(a as string)), crc32(cast(b as string)), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), - |sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) + |sha1(col), sha1(cast(a as string)), sha1(cast(b as string)), + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) |from test |""".stripMargin) } @@ -2002,8 +2002,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), |crc32(col), crc32(cast(a as string)), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), - |sha1(col), sha1(cast(a as string)) + |sha1(col), sha1(cast(a as string)), + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) |from test |""".stripMargin) }