diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index d4955313c79c3..767134bb5bafd 100644 --- a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -18,21 +18,31 @@ use arrow::array::{ArrayRef, RecordBatch}; use arrow_schema::DataType; use arrow_schema::TimeUnit::Nanosecond; -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use criterion::{ + BenchmarkGroup, BenchmarkId, Criterion, criterion_group, criterion_main, + measurement::WallTime, +}; use datafusion::prelude::{DataFrame, SessionContext}; use datafusion_catalog::MemTable; -use datafusion_common::ScalarValue; +use datafusion_common::{Column, ScalarValue}; use datafusion_expr::Expr::Literal; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::{cast, col, lit, not, try_cast, when}; use datafusion_functions::expr_fn::{ btrim, length, regexp_like, regexp_replace, to_timestamp, upper, }; +use std::env; use std::fmt::Write; use std::hint::black_box; use std::ops::Rem; use std::sync::Arc; use tokio::runtime::Runtime; +const FULL_PREDICATE_SWEEP: [usize; 5] = [10, 20, 30, 40, 60]; +const FULL_DEPTH_SWEEP: [usize; 3] = [1, 2, 3]; +const DEFAULT_SWEEP_POINTS: [(usize, usize); 3] = [(10, 1), (30, 2), (60, 3)]; + // This benchmark suite is designed to test the performance of // logical planning with a large plan containing unions, many columns // with a variety of operations in it. @@ -218,7 +228,9 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { fn build_case_heavy_left_join_df(ctx: &SessionContext, rt: &Runtime) -> DataFrame { register_string_table(ctx, 100, 1000); let query = build_case_heavy_left_join_query(30, 1); - rt.block_on(async { ctx.sql(&query).await.unwrap() }) + let df = rt.block_on(async { ctx.sql(&query).await.unwrap() }); + assert_case_heavy_left_join_inference_candidates(&df, 30); + df } fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) -> String { @@ -237,12 +249,17 @@ fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) - query.push_str(" AND "); } - let mut expr = format!("length(l.c{})", i % 20); + let left_payload_col = (i % 19) + 1; + let right_payload_col = ((i + 7) % 19) + 1; + let mut expr = format!( + "CASE WHEN l.c0 IS NOT NULL THEN length(l.c{left_payload_col}) ELSE length(r.c{right_payload_col}) END" + ); for depth in 0..case_depth { - let left_col = (i + depth + 1) % 20; - let right_col = (i + depth + 2) % 20; + let left_col = ((i + depth + 3) % 19) + 1; + let right_col = ((i + depth + 11) % 19) + 1; + let join_key_ref = if (i + depth) % 2 == 0 { "l.c0" } else { "r.c0" }; expr = format!( - "CASE WHEN l.c{left_col} IS NOT NULL THEN {expr} ELSE length(r.c{right_col}) END" + "CASE WHEN {join_key_ref} IS NOT NULL THEN {expr} ELSE CASE WHEN l.c{left_col} IS NOT NULL THEN length(l.c{left_col}) ELSE length(r.c{right_col}) END END" ); } @@ -252,26 +269,6 @@ fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) - query } -fn build_case_heavy_left_join_df_with_push_down_filter( - rt: &Runtime, - predicate_count: usize, - case_depth: usize, - push_down_filter_enabled: bool, -) -> DataFrame { - let ctx = SessionContext::new(); - register_string_table(&ctx, 100, 1000); - if !push_down_filter_enabled { - let removed = ctx.remove_optimizer_rule("push_down_filter"); - assert!( - removed, - "push_down_filter rule should be present in the default optimizer" - ); - } - - let query = build_case_heavy_left_join_query(predicate_count, case_depth); - rt.block_on(async { ctx.sql(&query).await.unwrap() }) -} - fn build_non_case_left_join_query( predicate_count: usize, nesting_depth: usize, @@ -304,10 +301,11 @@ fn build_non_case_left_join_query( query } -fn build_non_case_left_join_df_with_push_down_filter( +fn build_left_join_df_with_push_down_filter( rt: &Runtime, + query_builder: impl Fn(usize, usize) -> String, predicate_count: usize, - nesting_depth: usize, + depth: usize, push_down_filter_enabled: bool, ) -> DataFrame { let ctx = SessionContext::new(); @@ -320,10 +318,138 @@ fn build_non_case_left_join_df_with_push_down_filter( ); } - let query = build_non_case_left_join_query(predicate_count, nesting_depth); + let query = query_builder(predicate_count, depth); rt.block_on(async { ctx.sql(&query).await.unwrap() }) } +fn build_case_heavy_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + case_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + let df = build_left_join_df_with_push_down_filter( + rt, + build_case_heavy_left_join_query, + predicate_count, + case_depth, + push_down_filter_enabled, + ); + assert_case_heavy_left_join_inference_candidates(&df, predicate_count); + df +} + +fn build_non_case_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + nesting_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + build_left_join_df_with_push_down_filter( + rt, + build_non_case_left_join_query, + predicate_count, + nesting_depth, + push_down_filter_enabled, + ) +} + +fn find_filter_predicates(plan: &LogicalPlan) -> Vec { + match plan { + LogicalPlan::Filter(filter) => split_conjunction_owned(filter.predicate.clone()), + LogicalPlan::Projection(projection) => find_filter_predicates(projection.input.as_ref()), + other => panic!("expected benchmark query plan to contain a Filter, found {other:?}"), + } +} + +fn assert_case_heavy_left_join_inference_candidates( + df: &DataFrame, + expected_predicate_count: usize, +) { + let predicates = find_filter_predicates(df.logical_plan()); + assert_eq!(predicates.len(), expected_predicate_count); + + let left_join_key = Column::from_qualified_name("l.c0"); + let right_join_key = Column::from_qualified_name("r.c0"); + + for predicate in predicates { + let column_refs = predicate.column_refs(); + assert!( + column_refs.contains(&&left_join_key) || column_refs.contains(&&right_join_key), + "benchmark predicate should reference a join key: {predicate}" + ); + assert!( + column_refs + .iter() + .any(|col| **col != left_join_key && **col != right_join_key), + "benchmark predicate should reference a non-join column: {predicate}" + ); + } +} + +fn include_full_push_down_filter_sweep() -> bool { + env::var("DATAFUSION_PUSH_DOWN_FILTER_FULL_SWEEP") + .map(|value| value == "1" || value.eq_ignore_ascii_case("true")) + .unwrap_or(false) +} + +fn push_down_filter_sweep_points() -> Vec<(usize, usize)> { + if include_full_push_down_filter_sweep() { + FULL_DEPTH_SWEEP + .into_iter() + .flat_map(|depth| { + FULL_PREDICATE_SWEEP + .into_iter() + .map(move |predicate_count| (predicate_count, depth)) + }) + .collect() + } else { + DEFAULT_SWEEP_POINTS.to_vec() + } +} + +fn bench_push_down_filter_ab( + group: &mut BenchmarkGroup<'_, WallTime>, + rt: &Runtime, + sweep_points: &[(usize, usize)], + build_df: BuildFn, +) where + BuildFn: Fn(&Runtime, usize, usize, bool) -> DataFrame, +{ + for &(predicate_count, depth) in sweep_points { + let with_push_down_filter = build_df(rt, predicate_count, depth, true); + let without_push_down_filter = build_df(rt, predicate_count, depth, false); + + let input_label = format!("predicates={predicate_count},nesting_depth={depth}"); + + group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); + + group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); + } +} + fn criterion_benchmark(c: &mut Criterion) { let baseline_ctx = SessionContext::new(); let case_heavy_ctx = SessionContext::new(); @@ -349,116 +475,40 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - let predicate_sweep = [10, 20, 30, 40, 60]; - let case_depth_sweep = [1, 2, 3]; + let sweep_points = push_down_filter_sweep_points(); let mut hotspot_group = c.benchmark_group("push_down_filter_hotspot_case_heavy_left_join_ab"); - for case_depth in case_depth_sweep { - for predicate_count in predicate_sweep { - let with_push_down_filter = - build_case_heavy_left_join_df_with_push_down_filter( - &rt, - predicate_count, - case_depth, - true, - ); - let without_push_down_filter = - build_case_heavy_left_join_df_with_push_down_filter( - &rt, - predicate_count, - case_depth, - false, - ); - - let input_label = - format!("predicates={predicate_count},case_depth={case_depth}"); - // A/B interpretation: - // - with_push_down_filter: default optimizer path (rule enabled) - // - without_push_down_filter: control path with the rule removed - // Compare both IDs at the same sweep point to isolate rule impact. - hotspot_group.bench_with_input( - BenchmarkId::new("with_push_down_filter", &input_label), - &with_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - hotspot_group.bench_with_input( - BenchmarkId::new("without_push_down_filter", &input_label), - &without_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - } - } + bench_push_down_filter_ab( + &mut hotspot_group, + &rt, + &sweep_points, + |rt, predicate_count, depth, enable| { + build_case_heavy_left_join_df_with_push_down_filter( + rt, + predicate_count, + depth, + enable, + ) + }, + ); hotspot_group.finish(); let mut control_group = c.benchmark_group("push_down_filter_control_non_case_left_join_ab"); - for nesting_depth in case_depth_sweep { - for predicate_count in predicate_sweep { - let with_push_down_filter = build_non_case_left_join_df_with_push_down_filter( - &rt, + bench_push_down_filter_ab( + &mut control_group, + &rt, + &sweep_points, + |rt, predicate_count, depth, enable| { + build_non_case_left_join_df_with_push_down_filter( + rt, predicate_count, - nesting_depth, - true, - ); - let without_push_down_filter = - build_non_case_left_join_df_with_push_down_filter( - &rt, - predicate_count, - nesting_depth, - false, - ); - - let input_label = - format!("predicates={predicate_count},nesting_depth={nesting_depth}"); - control_group.bench_with_input( - BenchmarkId::new("with_push_down_filter", &input_label), - &with_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - control_group.bench_with_input( - BenchmarkId::new("without_push_down_filter", &input_label), - &without_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - } - } + depth, + enable, + ) + }, + ); control_group.finish(); } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 9a1dc5502ee60..a245227382d83 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -67,6 +67,7 @@ pub mod create_drop; pub mod explain_analyze; pub mod joins; mod path_partition; +mod push_down_filter_regressions; mod runtime_config; pub mod select; mod sql_api; diff --git a/datafusion/core/tests/sql/push_down_filter_regressions.rs b/datafusion/core/tests/sql/push_down_filter_regressions.rs new file mode 100644 index 0000000000000..5ad53f33c8b98 --- /dev/null +++ b/datafusion/core/tests/sql/push_down_filter_regressions.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use super::*; +use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; + +const WINDOW_SCALAR_SUBQUERY_SQL: &str = r#" + WITH suppliers AS ( + SELECT * + FROM (VALUES (1, 10.0), (1, 20.0)) AS t(nation, acctbal) + ) + SELECT + ROW_NUMBER() OVER (PARTITION BY nation ORDER BY acctbal DESC) AS rn + FROM suppliers AS s + WHERE acctbal > ( + SELECT AVG(acctbal) FROM suppliers + ) +"#; + +const WINDOW_SCALAR_SUBQUERY_EXPECTED: &[&str] = + &["+----+", "| rn |", "+----+", "| 1 |", "+----+"]; + +fn sqllogictest_style_ctx(push_down_filter_enabled: bool) -> SessionContext { + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(4)); + if !push_down_filter_enabled { + assert!(ctx.remove_optimizer_rule("push_down_filter")); + } + ctx +} + +async fn capture_window_scalar_subquery_plans( + push_down_filter_enabled: bool, +) -> Result<(String, String)> { + let ctx = sqllogictest_style_ctx(push_down_filter_enabled); + let df = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?; + let optimized_plan = df.clone().into_optimized_plan()?; + let physical_plan = df.create_physical_plan().await?; + + Ok(( + optimized_plan.display_indent_schema().to_string(), + displayable(physical_plan.as_ref()).indent(true).to_string(), + )) +} + +async fn assert_window_scalar_subquery(ctx: SessionContext) -> Result<()> { + let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?; + assert_batches_eq!(WINDOW_SCALAR_SUBQUERY_EXPECTED, &results); + Ok(()) +} + +#[tokio::test] +async fn window_scalar_subquery_regression() -> Result<()> { + assert_window_scalar_subquery(SessionContext::new()).await +} + +#[tokio::test] +async fn window_scalar_subquery_sqllogictest_style_regression() -> Result<()> { + assert_window_scalar_subquery(sqllogictest_style_ctx(true)).await +} + +#[tokio::test] +async fn aggregate_regr_functions_regression() -> Result<()> { + let ctx = SessionContext::new(); + let batch = RecordBatch::try_from_iter(vec![ + ( + "c11", + Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])) as ArrayRef, + ), + ( + "c12", + Arc::new(Float64Array::from(vec![2.0, 4.0, 6.0])) as ArrayRef, + ), + ])?; + ctx.register_batch("aggregate_test_100", batch)?; + + let sql = r#" + select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) + from aggregate_test_100 + "#; + + let rows = execute(&ctx, sql).await; + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].len(), 9); + assert!(rows[0].iter().all(|value| value != "NULL")); + + Ok(()) +} + +#[tokio::test] +async fn correlated_in_subquery_regression() -> Result<()> { + let ctx = SessionContext::new(); + let t1 = RecordBatch::try_from_iter(vec![ + ("t1_id", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ( + "t1_name", + Arc::new(StringArray::from(vec!["alpha", "beta"])) as ArrayRef, + ), + ("t1_int", Arc::new(Int32Array::from(vec![1, 0])) as ArrayRef), + ])?; + let t2 = RecordBatch::try_from_iter(vec![( + "t2_id", + Arc::new(Int32Array::from(vec![12, 99])) as ArrayRef, + )])?; + ctx.register_batch("t1", t1)?; + ctx.register_batch("t2", t2)?; + + let sql = r#" + select t1.t1_id, + t1.t1_name, + t1.t1_int + from t1 + where t1.t1_id + 12 in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) + "#; + + let results = ctx.sql(sql).await?.collect().await?; + + assert_batches_sorted_eq!( + &[ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 1 | alpha | 1 |", + "+-------+---------+--------+", + ], + &results + ); + + Ok(()) +} + +#[tokio::test] +async fn natural_join_union_regression() -> Result<()> { + let ctx = SessionContext::new(); + let t1 = RecordBatch::try_from_iter(vec![ + ("v0", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ( + "v2", + Arc::new(Int32Array::from(vec![None, Some(5)])) as ArrayRef, + ), + ])?; + // Keep `v2` only on the left side so the natural join key remains `v0`. + let t2 = RecordBatch::try_from_iter(vec![( + "v0", + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + )])?; + ctx.register_batch("t1", t1)?; + ctx.register_batch("t2", t2)?; + + let sql = r#" + SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 + UNION ALL + SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 WHERE (t1.v2 IS NULL) + "#; + + let results = ctx.sql(sql).await?.collect().await?; + + assert_batches_sorted_eq!( + &[ + "+----+----+", + "| v2 | v0 |", + "+----+----+", + "| | 1 |", + "| | 1 |", + "| 5 | 2 |", + "+----+----+", + ], + &results + ); + + Ok(()) +} + +#[tokio::test(flavor = "current_thread")] +async fn window_scalar_subquery_optimizer_delta() -> Result<()> { + let (enabled_optimized, enabled_physical) = + capture_window_scalar_subquery_plans(true).await?; + let (disabled_optimized, disabled_physical) = + capture_window_scalar_subquery_plans(false).await?; + + assert_eq!(enabled_optimized, disabled_optimized); + assert_eq!(enabled_physical, disabled_physical); + + assert!( + enabled_optimized + .contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)") + ); + assert!(enabled_optimized.contains("Cross Join:")); + assert!( + enabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2") + ); + assert!(enabled_physical.contains("CrossJoinExec")); + + Ok(()) +} diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 36deb0f67d77e..622780cad5085 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,6 +285,55 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Ok(is_evaluate) } +fn strip_plan_wrappers(plan: &LogicalPlan) -> (&LogicalPlan, bool) { + match plan { + LogicalPlan::SubqueryAlias(subquery_alias) => { + let (plan, _) = strip_plan_wrappers(subquery_alias.input.as_ref()); + (plan, true) + } + LogicalPlan::Projection(projection) => { + let (plan, is_derived_relation) = + strip_plan_wrappers(projection.input.as_ref()); + (plan, is_derived_relation) + } + _ => (plan, false), + } +} + +fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool { + matches!( + strip_plan_wrappers(plan).0, + LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty() + ) +} + +fn is_derived_relation(plan: &LogicalPlan) -> bool { + strip_plan_wrappers(plan).1 +} + +fn is_scalar_subquery_cross_join(join: &Join) -> bool { + join.on.is_empty() + && join.filter.is_none() + && ((is_scalar_aggregate_subquery(join.left.as_ref()) + && is_derived_relation(join.right.as_ref())) + || (is_scalar_aggregate_subquery(join.right.as_ref()) + && is_derived_relation(join.left.as_ref()))) +} + +// Keep post-join filters above certain scalar-subquery cross joins to preserve +// behavior for the window-over-scalar-subquery regression shape. +fn should_keep_filter_above_scalar_subquery_cross_join( + join: &Join, + predicate: &Expr, +) -> bool { + if !is_scalar_subquery_cross_join(join) { + return false; + } + + let mut checker = ColumnChecker::new(join.left.schema(), join.right.schema()); + !checker.is_left_only(predicate) && !checker.is_right_only(predicate) +} + /// examine OR clause to see if any useful clauses can be extracted and push down. /// extract at least one qual from each sub clauses of OR clause, then form the quals /// to new OR clause as predicate. @@ -431,7 +480,10 @@ fn push_down_all_join( left_push.push(predicate); } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); - } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { + } else if is_inner_join + && !should_keep_filter_above_scalar_subquery_cross_join(&join, &predicate) + && can_evaluate_as_join_condition(&predicate)? + { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate // and convert to the join on condition join_conditions.push(predicate); @@ -723,24 +775,36 @@ fn infer_join_predicates_impl< inferred_predicates: &mut InferredPredicates, ) -> Result<()> { for predicate in input_predicates { - let mut join_cols_to_replace = HashMap::new(); + let column_refs = predicate.column_refs(); + let join_col_replacements: Vec<_> = column_refs + .iter() + .filter_map(|&col| { + join_col_keys.iter().find_map(|(l, r)| { + if ENABLE_LEFT_TO_RIGHT && col == *l { + Some((col, *r)) + } else if ENABLE_RIGHT_TO_LEFT && col == *r { + Some((col, *l)) + } else { + None + } + }) + }) + .collect(); - for &col in &predicate.column_refs() { - for (l, r) in join_col_keys.iter() { - if ENABLE_LEFT_TO_RIGHT && col == *l { - join_cols_to_replace.insert(col, *r); - break; - } - if ENABLE_RIGHT_TO_LEFT && col == *r { - join_cols_to_replace.insert(col, *l); - break; - } - } + if join_col_replacements.is_empty() { + continue; } - if join_cols_to_replace.is_empty() { + + // For non-inner joins, predicates that reference any non-replaceable + // columns cannot be inferred on the other side. Skip the null-restriction + // helper entirely in that common mixed-reference case. + if !inferred_predicates.is_inner_join + && join_col_replacements.len() != column_refs.len() + { continue; } + let join_cols_to_replace = join_col_replacements.into_iter().collect(); inferred_predicates .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; } @@ -1484,7 +1548,7 @@ mod tests { use crate::simplify_expressions::SimplifyExpressions; use crate::test::udfs::leaf_udf_expr; use crate::test::*; - use datafusion_expr::test::function_stub::sum; + use datafusion_expr::test::function_stub::{avg, sum}; use insta::assert_snapshot; use super::*; @@ -2408,6 +2472,138 @@ mod tests { ) } + #[test] + fn window_over_scalar_subquery_cross_join_keeps_filter_above_join() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? + .alias("s")? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a").alias("acctbal")])? + .aggregate( + Vec::::new(), + vec![avg(col("acctbal")).alias("avg_acctbal")], + )? + .alias("__scalar_sq_1")? + .build()?; + + let window = Expr::from(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::row_number::row_number_udwf(), + ), + vec![], + )) + .partition_by(vec![col("s.nation")]) + .order_by(vec![col("s.acctbal").sort(false, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? + .project(vec![col("s.nation"), col("s.acctbal")])? + .window(vec![window])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[row_number() PARTITION BY [s.nation] ORDER BY [s.acctbal DESC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + Projection: s.nation, s.acctbal + Filter: s.acctbal > __scalar_sq_1.avg_acctbal + Cross Join: + SubqueryAlias: s + Projection: test.a AS nation, test.b AS acctbal + TableScan: test + SubqueryAlias: __scalar_sq_1 + Aggregate: groupBy=[[]], aggr=[[avg(acctbal) AS avg_acctbal]] + Projection: test1.a AS acctbal + TableScan: test1 + " + ) + } + + #[test] + fn window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join() + -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? + .alias("s")? + .project(vec![col("s.nation"), col("s.acctbal")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a").alias("acctbal")])? + .aggregate( + Vec::::new(), + vec![avg(col("acctbal")).alias("avg_acctbal")], + )? + .alias("__scalar_sq_1")? + .build()?; + + let window = Expr::from(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::row_number::row_number_udwf(), + ), + vec![], + )) + .partition_by(vec![col("s.nation")]) + .order_by(vec![col("s.acctbal").sort(false, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? + .project(vec![col("s.nation"), col("s.acctbal")])? + .window(vec![window])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[row_number() PARTITION BY [s.nation] ORDER BY [s.acctbal DESC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + Projection: s.nation, s.acctbal + Filter: s.acctbal > __scalar_sq_1.avg_acctbal + Cross Join: + Projection: s.nation, s.acctbal + SubqueryAlias: s + Projection: test.a AS nation, test.b AS acctbal + TableScan: test + SubqueryAlias: __scalar_sq_1 + Aggregate: groupBy=[[]], aggr=[[avg(acctbal) AS avg_acctbal]] + Projection: test1.a AS acctbal + TableScan: test1 + " + ) + } + + #[test] + fn cross_join_builder_uses_inner_join_with_no_join_keys() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan()?) + .cross_join(test_table_scan_with_name("test1")?)? + .build()?; + + let LogicalPlan::Join(join) = plan else { + panic!("expected join plan"); + }; + + assert_eq!(join.join_type, JoinType::Inner); + assert!(join.on.is_empty()); + assert!(join.filter.is_none()); + + Ok(()) + } + + #[test] + fn scalar_subquery_cross_join_filter_is_treated_as_join_condition_candidate() + -> Result<()> { + let predicate = col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")); + + assert!(can_evaluate_as_join_condition(&predicate)?); + + Ok(()) + } + /// verifies that filters with the same columns are correctly placed #[test] fn filter_2_breaks_limits() -> Result<()> { @@ -2798,6 +2994,41 @@ mod tests { ) } + /// mixed post-left-join predicates that reference a join key plus a + /// non-join column should not be inferred to the preserved side + #[test] + fn filter_using_left_join_with_mixed_join_key_and_non_join_refs() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("c")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(add(col("test2.a"), col("test.b")).gt(lit(1i64)))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a + test.b > Int64(1) + Left Join: test.a = test2.a + Projection: test.a, test.b + TableScan: test + Projection: test2.a, test2.c + TableScan: test2 + " + ) + } + /// post-right-join predicate on a column common to both sides is pushed to both sides #[test] fn filter_using_right_join_on_common() -> Result<()> { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 7e038d2392022..329271a067ee8 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,9 +17,58 @@ //! Utility functions leveraged by the query optimizer rules +mod null_restriction; + use std::collections::{BTreeSet, HashMap, HashSet}; +use std::sync::Arc; + +#[cfg(test)] +use std::cell::Cell; use crate::analyzer::type_coercion::TypeCoercionRewriter; + +/// Null restriction evaluation mode for optimizer tests. +#[cfg(test)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) enum NullRestrictionEvalMode { + Auto, + AuthoritativeOnly, +} + +#[cfg(test)] +thread_local! { + static NULL_RESTRICTION_EVAL_MODE: Cell = + const { Cell::new(NullRestrictionEvalMode::Auto) }; +} + +#[cfg(test)] +pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) { + NULL_RESTRICTION_EVAL_MODE.with(|eval_mode| eval_mode.set(mode)); +} + +#[cfg(test)] +fn null_restriction_eval_mode() -> NullRestrictionEvalMode { + NULL_RESTRICTION_EVAL_MODE.with(Cell::get) +} + +#[cfg(test)] +pub(crate) fn with_null_restriction_eval_mode_for_test( + mode: NullRestrictionEvalMode, + f: impl FnOnce() -> T, +) -> T { + struct NullRestrictionEvalModeReset(NullRestrictionEvalMode); + + impl Drop for NullRestrictionEvalModeReset { + fn drop(&mut self) { + set_null_restriction_eval_mode_for_test(self.0); + } + } + + let previous_mode = null_restriction_eval_mode(); + set_null_restriction_eval_mode_for_test(mode); + let _reset = NullRestrictionEvalModeReset(previous_mode); + f() +} use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; @@ -30,7 +79,6 @@ use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan}; use datafusion_physical_expr::create_physical_expr; use log::{debug, trace}; -use std::sync::Arc; /// Re-export of `NamesPreserver` for backwards compatibility, /// as it was initially placed here and then moved elsewhere. @@ -79,24 +127,46 @@ pub fn is_restrict_null_predicate<'a>( return Ok(true); } - // If result is single `true`, return false; - // If result is single `NULL` or `false`, return true; - Ok( - match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? { - ColumnarValue::Array(array) => { - if array.len() == 1 { - let boolean_array = as_boolean_array(&array)?; - boolean_array.is_null(0) || !boolean_array.value(0) - } else { - false - } - } - ColumnarValue::Scalar(scalar) => matches!( - scalar, - ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) - ), - }, - ) + // Collect join columns so they can be used in both the fast-path check and the + // fallback evaluation path below. + let join_cols: HashSet<&Column> = join_cols_of_predicate.into_iter().collect(); + let column_refs = predicate.column_refs(); + + // Fast path: if the predicate references columns outside the join key set, + // `evaluate_expr_with_null_column` would fail because the null schema only + // contains a placeholder for the join key columns. Callers treat such errors as + // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early + // and avoid the expensive physical-expression compilation pipeline entirely. + if !null_restriction::all_columns_allowed(&column_refs, &join_cols) { + return Ok(false); + } + + #[cfg(test)] + if matches!( + null_restriction_eval_mode(), + NullRestrictionEvalMode::AuthoritativeOnly + ) { + return authoritative_restrict_null_predicate(predicate, join_cols); + } + + if let Some(is_restricting) = + null_restriction::syntactic_restrict_null_predicate(&predicate, &join_cols) + { + #[cfg(debug_assertions)] + { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + )?; + debug_assert_eq!( + is_restricting, authoritative, + "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" + ); + } + return Ok(is_restricting); + } + + authoritative_restrict_null_predicate(predicate, join_cols) } /// Determines if an expression will always evaluate to null. @@ -146,6 +216,30 @@ fn evaluate_expr_with_null_column<'a>( .evaluate(&input_batch) } +fn authoritative_restrict_null_predicate<'a>( + predicate: Expr, + join_cols_of_predicate: impl IntoIterator, +) -> Result { + Ok( + match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? { + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } + } + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) + | ScalarValue::Boolean(Some(false)) + | ScalarValue::Null + ), + }, + ) +} + fn coerce(expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; expr.rewrite(&mut expr_rewrite).data() @@ -154,7 +248,11 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { #[cfg(test)] mod tests { use super::*; - use datafusion_expr::{Operator, binary_expr, case, col, in_list, is_null, lit}; + use std::panic::{AssertUnwindSafe, catch_unwind}; + + use datafusion_expr::{ + Operator, binary_expr, case, col, in_list, is_null, lit, when, + }; #[test] fn expr_is_restrict_null_predicate() -> Result<()> { @@ -193,6 +291,27 @@ mod tests { .otherwise(lit(false))?, true, ), + // CASE 1 WHEN 1 THEN true ELSE false END + ( + case(lit(1i64)) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + false, + ), + // CASE 1 WHEN 1 THEN NULL ELSE false END + ( + case(lit(1i64)) + .when(lit(1i64), lit(ScalarValue::Null)) + .otherwise(lit(false))?, + true, + ), + // CASE true WHEN true THEN false ELSE true END + ( + case(lit(true)) + .when(lit(true), lit(false)) + .otherwise(lit(true))?, + true, + ), // CASE a WHEN 0 THEN false ELSE true END ( case(col("a")) @@ -246,16 +365,171 @@ mod tests { in_list(col("a"), vec![Expr::Literal(ScalarValue::Null, None)], true), true, ), + // CASE WHEN a IS NOT NULL THEN a ELSE b END > 2 + ( + binary_expr( + when(Expr::IsNotNull(Box::new(col("a"))), col("a")) + .otherwise(col("b"))?, + Operator::Gt, + lit(2i64), + ), + true, + ), ]; - let column_a = Column::from_name("a"); for (predicate, expected) in test_cases { - let join_cols_of_predicate = std::iter::once(&column_a); - let actual = - is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; + let join_cols_of_predicate = predicate.column_refs(); + let actual = is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + )?; assert_eq!(actual, expected, "{predicate}"); } + // Keep coverage for the fast path that rejects predicates referencing + // columns outside the provided join key set. + let predicate = binary_expr(col("a"), Operator::Gt, col("b")); + let column_a = Column::from_name("a"); + let actual = + is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a))?; + assert!(!actual, "{predicate}"); + Ok(()) } + + #[test] + fn syntactic_fast_path_matches_authoritative_evaluator() -> Result<()> { + let test_cases = vec![ + is_null(col("a")), + Expr::IsNotNull(Box::new(col("a"))), + binary_expr(col("a"), Operator::Gt, lit(8i64)), + binary_expr(col("a"), Operator::Eq, lit(ScalarValue::Null)), + binary_expr(col("a"), Operator::And, lit(true)), + binary_expr(col("a"), Operator::Or, lit(false)), + Expr::Not(Box::new(col("a").is_true())), + col("a").is_true(), + col("a").is_false(), + col("a").is_unknown(), + col("a").is_not_true(), + col("a").is_not_false(), + col("a").is_not_unknown(), + col("a").between(lit(1i64), lit(10i64)), + binary_expr( + when(Expr::IsNotNull(Box::new(col("a"))), col("a")) + .otherwise(col("b"))?, + Operator::Gt, + lit(2i64), + ), + case(col("a")) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + binary_expr( + case(col("a")) + .when(lit(0i64), lit(true)) + .otherwise(lit(false))?, + Operator::Or, + lit(false), + ), + binary_expr( + case(lit(1i64)) + .when(lit(1i64), lit(ScalarValue::Null)) + .otherwise(lit(false))?, + Operator::IsNotDistinctFrom, + lit(true), + ), + ]; + + for predicate in test_cases { + let join_cols = predicate.column_refs(); + if let Some(syntactic) = null_restriction::syntactic_restrict_null_predicate( + &predicate, &join_cols, + ) { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + ) + .unwrap_or_else(|error| { + panic!( + "authoritative evaluator failed for predicate `{predicate}`: {error}" + ) + }); + assert_eq!( + syntactic, authoritative, + "syntactic fast path disagrees with authoritative evaluator for predicate: {predicate}", + ); + } + } + + Ok(()) + } + + #[test] + fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> { + let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64)); + let join_cols_of_predicate = predicate.column_refs(); + + let auto_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::Auto, + || { + is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + ) + }, + )?; + + let authoritative_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || { + is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + ) + }, + )?; + + assert_eq!(auto_result, authoritative_result); + + Ok(()) + } + + #[test] + fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode() -> Result<()> + { + let predicate = binary_expr(col("a"), Operator::Gt, col("b")); + let column_a = Column::from_name("a"); + + let auto_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::Auto, + || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)), + )?; + + let authoritative_only_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)), + )?; + + assert!(!auto_result, "{predicate}"); + assert!(!authoritative_only_result, "{predicate}"); + + Ok(()) + } + + #[test] + fn null_restriction_eval_mode_guard_restores_on_panic() { + set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); + + let result = catch_unwind(AssertUnwindSafe(|| { + with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || panic!("intentional panic to verify test mode reset"), + ) + })); + + assert!(result.is_err()); + assert_eq!(null_restriction_eval_mode(), NullRestrictionEvalMode::Auto); + } } diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs new file mode 100644 index 0000000000000..6e9920af80acc --- /dev/null +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Syntactic null-restriction evaluator used by optimizer fast paths. + +use std::collections::HashSet; + +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::{BinaryExpr, Expr, Operator}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum NullSubstitutionValue { + /// SQL NULL after substituting join columns with NULL. + Null, + /// Known to be non-null, but value is otherwise unknown. + NonNull, + /// A known boolean outcome from SQL three-valued logic. + Boolean(bool), +} + +pub(super) fn all_columns_allowed( + column_refs: &HashSet<&Column>, + allowed_columns: &HashSet<&Column>, +) -> bool { + column_refs + .iter() + .all(|column| allowed_columns.contains(*column)) +} + +pub(super) fn syntactic_restrict_null_predicate( + predicate: &Expr, + join_cols: &HashSet<&Column>, +) -> Option { + match syntactic_null_substitution_value(predicate, join_cols) { + Some(NullSubstitutionValue::Boolean(value)) => Some(!value), + Some(NullSubstitutionValue::Null) => Some(true), + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +fn not(value: Option) -> Option { + match value { + Some(NullSubstitutionValue::Boolean(value)) => { + Some(NullSubstitutionValue::Boolean(!value)) + } + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +fn binary_boolean_value( + left: Option, + right: Option, + when_short_circuit: bool, +) -> Option { + let short_circuit = Some(NullSubstitutionValue::Boolean(when_short_circuit)); + let identity = Some(NullSubstitutionValue::Boolean(!when_short_circuit)); + + if left == short_circuit || right == short_circuit { + return short_circuit; + } + + match (left, right) { + (value, other) if value == identity => other, + (other, value) if value == identity => other, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } +} + +fn null_check_value( + value: Option, + is_not_null: bool, +) -> Option { + match value { + Some(NullSubstitutionValue::Null) => { + Some(NullSubstitutionValue::Boolean(!is_not_null)) + } + Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { + Some(NullSubstitutionValue::Boolean(is_not_null)) + } + None => None, + } +} + +fn null_if_contains_null( + values: impl IntoIterator>, +) -> Option { + values + .into_iter() + .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) + .then_some(NullSubstitutionValue::Null) +} + +fn strict_null_only( + value: Option, +) -> Option { + value.filter(|value| matches!(value, NullSubstitutionValue::Null)) +} + +fn syntactic_null_substitution_value( + expr: &Expr, + join_cols: &HashSet<&Column>, +) -> Option { + match expr { + Expr::Alias(alias) => { + syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) + } + Expr::Column(column) => join_cols + .contains(column) + .then_some(NullSubstitutionValue::Null), + Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), + Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), + Expr::Not(expr) => { + not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + } + Expr::IsNull(expr) => null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + false, + ), + Expr::IsNotNull(expr) => null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + true, + ), + Expr::Between(between) => null_if_contains_null([ + syntactic_null_substitution_value(between.expr.as_ref(), join_cols), + syntactic_null_substitution_value(between.low.as_ref(), join_cols), + syntactic_null_substitution_value(between.high.as_ref(), join_cols), + ]), + Expr::Cast(cast) => strict_null_only(syntactic_null_substitution_value( + cast.expr.as_ref(), + join_cols, + )), + Expr::TryCast(try_cast) => strict_null_only(syntactic_null_substitution_value( + try_cast.expr.as_ref(), + join_cols, + )), + Expr::Negative(expr) => { + strict_null_only(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + } + Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([ + syntactic_null_substitution_value(like.expr.as_ref(), join_cols), + syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), + ]), + Expr::Exists { .. } + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::Placeholder(_) + | Expr::ScalarVariable(_, _) + | Expr::Unnest(_) + | Expr::GroupingSet(_) + | Expr::WindowFunction(_) + | Expr::ScalarFunction(_) + | Expr::Case(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) => None, + Expr::AggregateFunction(_) => None, + #[expect(deprecated)] + Expr::Wildcard { .. } => None, + } +} + +fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionValue { + match value { + _ if value.is_null() => NullSubstitutionValue::Null, + ScalarValue::Boolean(Some(value)) => NullSubstitutionValue::Boolean(*value), + _ => NullSubstitutionValue::NonNull, + } +} + +fn is_strict_null_binary_op(op: Operator) -> bool { + matches!( + op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + | Operator::LikeMatch + | Operator::ILikeMatch + | Operator::NotLikeMatch + | Operator::NotILikeMatch + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft + | Operator::StringConcat + | Operator::AtArrow + | Operator::ArrowAt + | Operator::Arrow + | Operator::LongArrow + | Operator::HashArrow + | Operator::HashLongArrow + | Operator::AtAt + | Operator::IntegerDivide + | Operator::HashMinus + | Operator::AtQuestion + | Operator::Question + | Operator::QuestionAnd + | Operator::QuestionPipe + | Operator::Colon + ) +} + +fn syntactic_binary_value( + binary_expr: &BinaryExpr, + join_cols: &HashSet<&Column>, +) -> Option { + let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); + let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); + + match binary_expr.op { + Operator::And => binary_boolean_value(left, right, false), + Operator::Or => binary_boolean_value(left, right, true), + Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => None, + op => is_strict_null_binary_op(op) + .then(|| null_if_contains_null([left, right])) + .flatten(), + } +}