From 02a5b6c3cdf64e89b53d3bd6eba67989927a6c59 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 7 Apr 2026 15:04:54 -0700 Subject: [PATCH 1/2] [SPARK-56383][SQL] Add tests for partition filter extraction from mixed predicates Add tests to DataSourceV2EnhancedPartitionFilterSuite covering the case where getPartitionFiltersAndDataFilters extracts additional partition filters from predicates that reference both partition and data columns (via extractPredicatesWithinOutputSet). --- ...SourceV2EnhancedPartitionFilterSuite.scala | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala index d49ff779e7376..135edac6c9a77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala @@ -409,6 +409,151 @@ class DataSourceV2EnhancedPartitionFilterSuite } } + test("extract partition filter from translated OR with mixed partition and data references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " + + "PARTITIONED BY (part_col)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x'), ('a', 'other'), ('b', 'y'), ('c', 'z')") + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "(part_col = 'a' AND data = 'x') OR (part_col = 'b' AND data = 'y')") + checkAnswer(df, Seq(Row("a", "x"), Row("b", "y"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a", "b")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col")) + } + } + + test("extract partition filter from untranslatable OR with mixed partition and data references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " + + "PARTITIONED BY (part_col)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x'), ('b', 'y'), ('c', 'z')") + + spark.udf.register("my_upper_extract", (s: String) => + if (s == null) null else s.toUpperCase(Locale.ROOT)) + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "(my_upper_extract(part_col) = 'A' AND data = 'x') OR " + + "(my_upper_extract(part_col) = 'B' AND data = 'y')") + checkAnswer(df, Seq(Row("a", "x"), Row("b", "y"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a", "b")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col")) + } + } + + test("extract partition filter from OR with one partition-only and one mixed filter") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " + + "PARTITIONED BY (part_col)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x'), ('a', 'other'), ('b', 'y'), ('b', 'other'), ('c', 'z')") + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "part_col = 'a' OR (part_col = 'b' AND data = 'y')") + checkAnswer(df, Seq(Row("a", "x"), Row("a", "other"), Row("b", "y"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a", "b")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col")) + } + } + + test("extract multi-column partition filter from OR with mixed partition and data references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " + + s"USING $v2Source PARTITIONED BY (p1, p2)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x', 'd1'), ('a', 'y', 'd2'), ('b', 'x', 'd3'), ('b', 'y', 'd4')") + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "(p1 = 'a' AND p2 = 'x' AND data = 'd1') OR (p1 = 'b' AND p2 = 'y' AND data = 'd4')") + checkAnswer(df, Seq(Row("a", "x", "d1"), Row("b", "y", "d4"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a/x", "b/y")) + assertReferencedPartitionFieldOrdinals(df, Array(0, 1), Array("p1", "p2")) + } + } + + test("two partition predicates pushed: UDF on p1 and " + + "extracted filter on p2 from mixed OR") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " + + s"USING $v2Source PARTITIONED BY (p1, p2)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x', 'd1'), " + + "('a', 'y', 'd4'), " + + "('b', 'x', 'd3'), " + + "('b', 'y', 'd4'), " + + "('c', 'z', 'd5')") + + spark.udf.register("my_upper_multi", (s: String) => + if (s == null) null else s.toUpperCase(Locale.ROOT)) + + // my_upper_multi(p1) = 'A' is untranslatable and partition-only, so it is + // a partition filter. The OR mixes p2 and data; extractPredicatesWithinOutputSet + // infers (p2 = 'x' OR p2 = 'y') as an additional partition filter. + // Both are pushed as separate PartitionPredicates. + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "my_upper_multi(p1) = 'A' AND " + + "((p2 = 'x' AND data = 'd1') OR (p2 = 'y' AND data = 'd4'))") + checkAnswer(df, Seq(Row("a", "x", "d1"), Row("a", "y", "d4"))) + assertPushedPartitionPredicates(df, 2) + assertScanReturnsPartitionKeys(df, Set("a/x", "a/y")) + } + } + + test("nested partition: extract partition filter from " + + "OR with mixed references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName " + + s"(s struct, data string) USING $v2Source " + + "PARTITIONED BY (s.tz)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "(named_struct('tz', 'LA', 'x', 1), 'a'), " + + "(named_struct('tz', 'NY', 'x', 2), 'b'), " + + "(named_struct('tz', 'SF', 'x', 3), 'c')") + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "(s.tz = 'LA' AND data = 'a') OR (s.tz = 'NY' AND data = 'b')") + checkAnswer(df, Seq(Row(Row("LA", 1), "a"), Row(Row("NY", 2), "b"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("LA", "NY")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("s.tz")) + } + } + + test("nested partition: two partition predicates from " + + "UDF and extracted mixed OR") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName " + + s"(s struct, data string) USING $v2Source " + + "PARTITIONED BY (s.tz)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "(named_struct('tz', 'LA', 'x', 1), 'a'), " + + "(named_struct('tz', 'la', 'x', 2), 'b'), " + + "(named_struct('tz', 'NY', 'x', 3), 'c'), " + + "(named_struct('tz', 'SF', 'x', 4), 'd')") + + spark.udf.register("my_upper_nested2", (s: String) => + if (s == null) null else s.toUpperCase(Locale.ROOT)) + + // my_upper_nested2(s.tz) = 'LA' is untranslatable and partition-only. + // The OR mixes s.tz and data; extractPredicatesWithinOutputSet + // infers (s.tz = 'LA' OR s.tz = 'la') as an additional partition filter. + // Both are pushed as separate PartitionPredicates. + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "my_upper_nested2(s.tz) = 'LA' AND " + + "((s.tz = 'LA' AND data = 'a') OR (s.tz = 'la' AND data = 'b'))") + checkAnswer(df, Seq(Row(Row("LA", 1), "a"), Row(Row("la", 2), "b"))) + assertPushedPartitionPredicates(df, 2) + assertScanReturnsPartitionKeys(df, Set("LA", "la")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("s.tz")) + } + } + private def assertTranslatableBeforeUntranslatableInPostScan(df: DataFrame): Unit = { val postScanFilterExec = df.queryExecution.executedPlan.collect { case f @ FilterExec(_, _) if f.exists(_.isInstanceOf[BatchScanExec]) => f From d56c2d30715215b41f1c8949c1ea054b1b2a4005 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 7 Apr 2026 22:19:07 -0700 Subject: [PATCH 2/2] Update test titles --- ...taSourceV2EnhancedPartitionFilterSuite.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala index 135edac6c9a77..a3c2fbdc83d9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala @@ -478,7 +478,7 @@ class DataSourceV2EnhancedPartitionFilterSuite } test("two partition predicates pushed: UDF on p1 and " + - "extracted filter on p2 from mixed OR") { + "extracted filter on p2 from mixed data and partition references") { withTable(partFilterTableName) { sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " + s"USING $v2Source PARTITIONED BY (p1, p2)") @@ -492,9 +492,8 @@ class DataSourceV2EnhancedPartitionFilterSuite spark.udf.register("my_upper_multi", (s: String) => if (s == null) null else s.toUpperCase(Locale.ROOT)) - // my_upper_multi(p1) = 'A' is untranslatable and partition-only, so it is - // a partition filter. The OR mixes p2 and data; extractPredicatesWithinOutputSet - // infers (p2 = 'x' OR p2 = 'y') as an additional partition filter. + // my_upper_multi(p1) = 'A' is untranslatable and partition-only, so it is a partition filter. + // The OR mixes p2 and data; we infer (p2 = 'x' OR p2 = 'y') as a partition filter. // Both are pushed as separate PartitionPredicates. val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + "my_upper_multi(p1) = 'A' AND " + @@ -506,7 +505,7 @@ class DataSourceV2EnhancedPartitionFilterSuite } test("nested partition: extract partition filter from " + - "OR with mixed references") { + "OR with mixed data and partition references") { withTable(partFilterTableName) { sql(s"CREATE TABLE $partFilterTableName " + s"(s struct, data string) USING $v2Source " + @@ -526,7 +525,7 @@ class DataSourceV2EnhancedPartitionFilterSuite } test("nested partition: two partition predicates from " + - "UDF and extracted mixed OR") { + "UDF and extracted mixed data and partition references") { withTable(partFilterTableName) { sql(s"CREATE TABLE $partFilterTableName " + s"(s struct, data string) USING $v2Source " + @@ -540,9 +539,9 @@ class DataSourceV2EnhancedPartitionFilterSuite spark.udf.register("my_upper_nested2", (s: String) => if (s == null) null else s.toUpperCase(Locale.ROOT)) - // my_upper_nested2(s.tz) = 'LA' is untranslatable and partition-only. - // The OR mixes s.tz and data; extractPredicatesWithinOutputSet - // infers (s.tz = 'LA' OR s.tz = 'la') as an additional partition filter. + // my_upper_nested2(s.tz) = 'LA' is untranslatable and partition-only, + // it is a partition filter. + // The OR mixes s.tz and data; we infer (s.tz = 'LA' OR s.tz = 'la') as an partition filter. // Both are pushed as separate PartitionPredicates. val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + "my_upper_nested2(s.tz) = 'LA' AND " +