diff --git a/datafusion/optimizer/src/extract_leaf_expressions.rs b/datafusion/optimizer/src/extract_leaf_expressions.rs index 151bca8278836..d04261456d600 100644 --- a/datafusion/optimizer/src/extract_leaf_expressions.rs +++ b/datafusion/optimizer/src/extract_leaf_expressions.rs @@ -167,72 +167,22 @@ mod tests { use super::*; use crate::optimize_projections::OptimizeProjections; + use crate::test::udfs::PlacementTestUDF; use crate::test::*; use crate::{Optimizer, OptimizerContext}; - use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{Expr, ExpressionPlacement}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - TypeSignature, col, lit, logical_plan::builder::LogicalPlanBuilder, + ScalarUDF, col, lit, logical_plan::builder::LogicalPlanBuilder, }; - use datafusion_expr::{Expr, ExpressionPlacement}; - - /// A mock UDF that simulates a leaf-pushable function like `get_field`. - /// It returns `MoveTowardsLeafNodes` when its first argument is Column or MoveTowardsLeafNodes. - #[derive(Debug, PartialEq, Eq, Hash)] - struct MockLeafFunc { - signature: Signature, - } - - impl MockLeafFunc { - fn new() -> Self { - Self { - signature: Signature::new( - TypeSignature::Any(2), - datafusion_expr::Volatility::Immutable, - ), - } - } - } - - impl ScalarUDFImpl for MockLeafFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "mock_leaf" - } - - fn signature(&self) -> &Signature { - &self.signature - } - fn return_type(&self, _args: &[DataType]) -> Result { - Ok(DataType::Utf8) - } - - fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - unimplemented!("This is only used for testing optimization") - } - - fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { - // Return MoveTowardsLeafNodes if first arg is Column or MoveTowardsLeafNodes - // (like get_field does) - match args.first() { - Some(ExpressionPlacement::Column) - | Some(ExpressionPlacement::MoveTowardsLeafNodes) => { - ExpressionPlacement::MoveTowardsLeafNodes - } - _ => ExpressionPlacement::KeepInPlace, - } - } - } - - fn mock_leaf(expr: Expr, name: &str) -> Expr { + fn leaf_udf(expr: Expr, name: &str) -> Expr { Expr::ScalarFunction(ScalarFunction::new_udf( - Arc::new(ScalarUDF::new_from_impl(MockLeafFunc::new())), + Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes), + )), vec![expr, lit(name)], )) } @@ -251,9 +201,8 @@ mod tests { /// 3. **After Pushdown** - + PushDownLeafProjections /// 4. **Optimized** - + final OptimizeProjections fn format_optimization_stages(plan: &LogicalPlan) -> Result { - let ctx = OptimizerContext::new().with_max_passes(1); - let run = |rules: Vec>| -> Result { + let ctx = OptimizerContext::new().with_max_passes(1); let optimizer = Optimizer::with_rules(rules); let optimized = optimizer.optimize(plan.clone(), &ctx, |_, _| {})?; Ok(format!("{optimized}")) @@ -318,7 +267,7 @@ mod tests { fn test_extract_from_filter() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan.clone()) - .filter(mock_leaf(col("user"), "status").eq(lit("active")))? + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? .select(vec![ table_scan .schema() @@ -330,7 +279,7 @@ mod tests { assert_stages!(plan, @r#" ## Original Plan Projection: test.id - Filter: mock_leaf(test.user, Utf8("status")) = Utf8("active") + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") TableScan: test projection=[id, user] ## After Extraction @@ -371,12 +320,12 @@ mod tests { fn test_extract_from_projection() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![mock_leaf(col("user"), "name")])? + .project(vec![leaf_udf(col("user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) TableScan: test projection=[user] ## After Extraction @@ -395,7 +344,7 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ - mock_leaf(col("user"), "name") + leaf_udf(col("user"), "name") .is_not_null() .alias("has_name"), ])? @@ -403,7 +352,7 @@ mod tests { assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")) IS NOT NULL AS has_name + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name TableScan: test projection=[user] ## After Extraction @@ -442,7 +391,7 @@ mod tests { #[test] fn test_filter_with_deduplication() -> Result<()> { let table_scan = test_table_scan_with_struct()?; - let field_access = mock_leaf(col("user"), "name"); + let field_access = leaf_udf(col("user"), "name"); // Filter with the same expression used twice let plan = LogicalPlanBuilder::from(table_scan) .filter( @@ -455,7 +404,7 @@ mod tests { assert_stages!(plan, @r#" ## Original Plan - Filter: mock_leaf(test.user, Utf8("name")) IS NOT NULL AND mock_leaf(test.user, Utf8("name")) IS NULL + Filter: leaf_udf(test.user, Utf8("name")) IS NOT NULL AND leaf_udf(test.user, Utf8("name")) IS NULL TableScan: test projection=[id, user] ## After Extraction @@ -473,12 +422,12 @@ mod tests { fn test_already_leaf_expression_in_filter() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(mock_leaf(col("user"), "name").eq(lit("test")))? + .filter(leaf_udf(col("user"), "name").eq(lit("test")))? .build()?; assert_stages!(plan, @r#" ## Original Plan - Filter: mock_leaf(test.user, Utf8("name")) = Utf8("test") + Filter: leaf_udf(test.user, Utf8("name")) = Utf8("test") TableScan: test projection=[id, user] ## After Extraction @@ -498,12 +447,12 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![mock_leaf(col("user"), "status")], vec![count(lit(1))])? + .aggregate(vec![leaf_udf(col("user"), "status")], vec![count(lit(1))])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Aggregate: groupBy=[[mock_leaf(test.user, Utf8("status"))]], aggr=[[COUNT(Int32(1))]] + Aggregate: groupBy=[[leaf_udf(test.user, Utf8("status"))]], aggr=[[COUNT(Int32(1))]] TableScan: test projection=[user] ## After Extraction @@ -525,13 +474,13 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("user")], - vec![count(mock_leaf(col("user"), "value"))], + vec![count(leaf_udf(col("user"), "value"))], )? .build()?; assert_stages!(plan, @r#" ## Original Plan - Aggregate: groupBy=[[test.user]], aggr=[[COUNT(mock_leaf(test.user, Utf8("value")))]] + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(leaf_udf(test.user, Utf8("value")))]] TableScan: test projection=[user] ## After Extraction @@ -549,14 +498,14 @@ mod tests { fn test_projection_with_filter_combined() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(mock_leaf(col("user"), "status").eq(lit("active")))? - .project(vec![mock_leaf(col("user"), "name")])? + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .project(vec![leaf_udf(col("user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")) - Filter: mock_leaf(test.user, Utf8("status")) = Utf8("active") + Projection: leaf_udf(test.user, Utf8("name")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") TableScan: test projection=[user] ## After Extraction @@ -574,12 +523,12 @@ mod tests { fn test_projection_preserves_alias() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![mock_leaf(col("user"), "name").alias("username")])? + .project(vec![leaf_udf(col("user"), "name").alias("username")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")) AS username + Projection: leaf_udf(test.user, Utf8("name")) AS username TableScan: test projection=[user] ## After Extraction @@ -600,14 +549,14 @@ mod tests { fn test_projection_different_field_from_filter() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(mock_leaf(col("user"), "value").gt(lit(150)))? - .project(vec![col("user"), mock_leaf(col("user"), "label")])? + .filter(leaf_udf(col("user"), "value").gt(lit(150)))? + .project(vec![col("user"), leaf_udf(col("user"), "label")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: test.user, mock_leaf(test.user, Utf8("label")) - Filter: mock_leaf(test.user, Utf8("value")) > Int32(150) + Projection: test.user, leaf_udf(test.user, Utf8("label")) + Filter: leaf_udf(test.user, Utf8("value")) > Int32(150) TableScan: test projection=[user] ## After Extraction @@ -624,14 +573,14 @@ mod tests { #[test] fn test_projection_deduplication() -> Result<()> { let table_scan = test_table_scan_with_struct()?; - let field = mock_leaf(col("user"), "name"); + let field = leaf_udf(col("user"), "name"); let plan = LogicalPlanBuilder::from(table_scan) .project(vec![field.clone(), field.clone().alias("name2")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")), mock_leaf(test.user, Utf8("name")) AS name2 + Projection: leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("name")) AS name2 TableScan: test projection=[user] ## After Extraction @@ -655,12 +604,12 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) .sort(vec![col("user").sort(true, true)])? - .project(vec![mock_leaf(col("user"), "name")])? + .project(vec![leaf_udf(col("user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) Sort: test.user ASC NULLS FIRST TableScan: test projection=[user] @@ -681,12 +630,12 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) .limit(0, Some(10))? - .project(vec![mock_leaf(col("user"), "name")])? + .project(vec![leaf_udf(col("user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) Limit: skip=0, fetch=10 TableScan: test projection=[user] @@ -710,13 +659,13 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("user")], - vec![count(mock_leaf(col("user"), "value")).alias("cnt")], + vec![count(leaf_udf(col("user"), "value")).alias("cnt")], )? .build()?; assert_stages!(plan, @r#" ## Original Plan - Aggregate: groupBy=[[test.user]], aggr=[[COUNT(mock_leaf(test.user, Utf8("value"))) AS cnt]] + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(leaf_udf(test.user, Utf8("value"))) AS cnt]] TableScan: test projection=[user] ## After Extraction @@ -762,14 +711,14 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ - mock_leaf(col("user"), "name").alias("__datafusion_extracted_manual"), + leaf_udf(col("user"), "name").alias("__datafusion_extracted_manual"), col("user"), ])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")) AS __datafusion_extracted_manual, test.user + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_manual, test.user TableScan: test projection=[user] ## After Extraction @@ -788,14 +737,14 @@ mod tests { fn test_merge_into_existing_extracted_projection() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(mock_leaf(col("user"), "status").eq(lit("active")))? - .filter(mock_leaf(col("user"), "name").is_not_null())? + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .filter(leaf_udf(col("user"), "name").is_not_null())? .build()?; assert_stages!(plan, @r#" ## Original Plan - Filter: mock_leaf(test.user, Utf8("name")) IS NOT NULL - Filter: mock_leaf(test.user, Utf8("status")) = Utf8("active") + Filter: leaf_udf(test.user, Utf8("name")) IS NOT NULL + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") TableScan: test projection=[id, user] ## After Extraction @@ -815,12 +764,12 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("user")])? - .project(vec![mock_leaf(col("user"), "name")])? + .project(vec![leaf_udf(col("user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) TableScan: test projection=[user] ## After Extraction @@ -889,14 +838,14 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(mock_leaf(col("user"), "status").eq(lit("active")))? - .aggregate(vec![mock_leaf(col("user"), "name")], vec![count(lit(1))])? + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .aggregate(vec![leaf_udf(col("user"), "name")], vec![count(lit(1))])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Aggregate: groupBy=[[mock_leaf(test.user, Utf8("name"))]], aggr=[[COUNT(Int32(1))]] - Filter: mock_leaf(test.user, Utf8("status")) = Utf8("active") + Aggregate: groupBy=[[leaf_udf(test.user, Utf8("name"))]], aggr=[[COUNT(Int32(1))]] + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") TableScan: test projection=[user] ## After Extraction @@ -915,14 +864,14 @@ mod tests { fn test_merge_with_new_columns() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(mock_leaf(col("a"), "x").eq(lit(1)))? - .filter(mock_leaf(col("b"), "y").eq(lit(2)))? + .filter(leaf_udf(col("a"), "x").eq(lit(1)))? + .filter(leaf_udf(col("b"), "y").eq(lit(2)))? .build()?; assert_stages!(plan, @r#" ## Original Plan - Filter: mock_leaf(test.b, Utf8("y")) = Int32(2) - Filter: mock_leaf(test.a, Utf8("x")) = Int32(1) + Filter: leaf_udf(test.b, Utf8("y")) = Int32(2) + Filter: leaf_udf(test.a, Utf8("x")) = Int32(1) TableScan: test projection=[a, b, c] ## After Extraction @@ -960,8 +909,8 @@ mod tests { right, JoinType::Inner, ( - vec![mock_leaf(col("user"), "id")], - vec![mock_leaf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], ), None, )? @@ -969,7 +918,7 @@ mod tests { assert_stages!(plan, @r#" ## Original Plan - Inner Join: mock_leaf(test.user, Utf8("id")) = mock_leaf(right.user, Utf8("id")) + Inner Join: leaf_udf(test.user, Utf8("id")) = leaf_udf(right.user, Utf8("id")) TableScan: test projection=[id, user] TableScan: right projection=[id, user] @@ -998,14 +947,14 @@ mod tests { JoinType::Inner, vec![ col("test.user").eq(col("right.user")), - mock_leaf(col("test.user"), "status").eq(lit("active")), + leaf_udf(col("test.user"), "status").eq(lit("active")), ], )? .build()?; assert_stages!(plan, @r#" ## Original Plan - Inner Join: Filter: test.user = right.user AND mock_leaf(test.user, Utf8("status")) = Utf8("active") + Inner Join: Filter: test.user = right.user AND leaf_udf(test.user, Utf8("status")) = Utf8("active") TableScan: test projection=[id, user] TableScan: right projection=[id, user] @@ -1034,15 +983,15 @@ mod tests { JoinType::Inner, vec![ col("test.user").eq(col("right.user")), - mock_leaf(col("test.user"), "status").eq(lit("active")), - mock_leaf(col("right.user"), "role").eq(lit("admin")), + leaf_udf(col("test.user"), "status").eq(lit("active")), + leaf_udf(col("right.user"), "role").eq(lit("admin")), ], )? .build()?; assert_stages!(plan, @r#" ## Original Plan - Inner Join: Filter: test.user = right.user AND mock_leaf(test.user, Utf8("status")) = Utf8("active") AND mock_leaf(right.user, Utf8("role")) = Utf8("admin") + Inner Join: Filter: test.user = right.user AND leaf_udf(test.user, Utf8("status")) = Utf8("active") AND leaf_udf(right.user, Utf8("role")) = Utf8("admin") TableScan: test projection=[id, user] TableScan: right projection=[id, user] @@ -1099,18 +1048,18 @@ mod tests { right, JoinType::Inner, ( - vec![mock_leaf(col("user"), "id")], - vec![mock_leaf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], ), None, )? - .filter(mock_leaf(col("test.user"), "status").eq(lit("active")))? + .filter(leaf_udf(col("test.user"), "status").eq(lit("active")))? .build()?; assert_stages!(plan, @r#" ## Original Plan - Filter: mock_leaf(test.user, Utf8("status")) = Utf8("active") - Inner Join: mock_leaf(test.user, Utf8("id")) = mock_leaf(right.user, Utf8("id")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + Inner Join: leaf_udf(test.user, Utf8("id")) = leaf_udf(right.user, Utf8("id")) TableScan: test projection=[id, user] TableScan: right projection=[id, user] @@ -1137,14 +1086,14 @@ mod tests { let plan = LogicalPlanBuilder::from(left) .join(right, JoinType::Inner, (vec!["id"], vec!["id"]), None)? .project(vec![ - mock_leaf(col("test.user"), "status"), - mock_leaf(col("right.user"), "role"), + leaf_udf(col("test.user"), "status"), + leaf_udf(col("right.user"), "role"), ])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(test.user, Utf8("status")), mock_leaf(right.user, Utf8("role")) + Projection: leaf_udf(test.user, Utf8("status")), leaf_udf(right.user, Utf8("role")) Inner Join: test.id = right.id TableScan: test projection=[id, user] TableScan: right projection=[id, user] @@ -1171,12 +1120,12 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("user").alias("x")])? .filter(col("x").is_not_null())? - .project(vec![mock_leaf(col("x"), "a")])? + .project(vec![leaf_udf(col("x"), "a")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(x, Utf8("a")) + Projection: leaf_udf(x, Utf8("a")) Filter: x IS NOT NULL Projection: test.user AS x TableScan: test projection=[user] @@ -1199,12 +1148,12 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("user").alias("x")])? .filter(col("x").is_not_null())? - .project(vec![mock_leaf(col("x"), "a").is_not_null()])? + .project(vec![leaf_udf(col("x"), "a").is_not_null()])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(x, Utf8("a")) IS NOT NULL + Projection: leaf_udf(x, Utf8("a")) IS NOT NULL Filter: x IS NOT NULL Projection: test.user AS x TableScan: test projection=[user] @@ -1226,12 +1175,12 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("user").alias("x")])? - .filter(mock_leaf(col("x"), "a").eq(lit("active")))? + .filter(leaf_udf(col("x"), "a").eq(lit("active")))? .build()?; assert_stages!(plan, @r#" ## Original Plan - Filter: mock_leaf(x, Utf8("a")) = Utf8("active") + Filter: leaf_udf(x, Utf8("a")) = Utf8("active") Projection: test.user AS x TableScan: test projection=[user] @@ -1256,12 +1205,12 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) .alias("sub")? - .project(vec![mock_leaf(col("sub.user"), "name")])? + .project(vec![leaf_udf(col("sub.user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(sub.user, Utf8("name")) + Projection: leaf_udf(sub.user, Utf8("name")) SubqueryAlias: sub TableScan: test projection=[user] @@ -1282,14 +1231,14 @@ mod tests { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) .alias("sub")? - .filter(mock_leaf(col("sub.user"), "status").eq(lit("active")))? - .project(vec![mock_leaf(col("sub.user"), "name")])? + .filter(leaf_udf(col("sub.user"), "status").eq(lit("active")))? + .project(vec![leaf_udf(col("sub.user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(sub.user, Utf8("name")) - Filter: mock_leaf(sub.user, Utf8("status")) = Utf8("active") + Projection: leaf_udf(sub.user, Utf8("name")) + Filter: leaf_udf(sub.user, Utf8("status")) = Utf8("active") SubqueryAlias: sub TableScan: test projection=[user] @@ -1311,12 +1260,12 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .alias("inner_sub")? .alias("outer_sub")? - .project(vec![mock_leaf(col("outer_sub.user"), "name")])? + .project(vec![leaf_udf(col("outer_sub.user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: mock_leaf(outer_sub.user, Utf8("name")) + Projection: leaf_udf(outer_sub.user, Utf8("name")) SubqueryAlias: outer_sub SubqueryAlias: inner_sub TableScan: test projection=[user] @@ -1357,63 +1306,21 @@ mod tests { ") } - /// A variant of MockLeafFunc with the same `name()` but a different concrete type. - /// Used to verify that deduplication uses `Expr` equality, not `schema_name`. - #[derive(Debug, PartialEq, Eq, Hash)] - struct MockLeafFuncVariant { - signature: Signature, - } - - impl MockLeafFuncVariant { - fn new() -> Self { - Self { - signature: Signature::new( - TypeSignature::Any(2), - datafusion_expr::Volatility::Immutable, - ), - } - } - } - - impl ScalarUDFImpl for MockLeafFuncVariant { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "mock_leaf" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _args: &[DataType]) -> Result { - Ok(DataType::Utf8) - } - - fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - unimplemented!("This is only used for testing optimization") - } - - fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { - match args.first() { - Some(ExpressionPlacement::Column) - | Some(ExpressionPlacement::MoveTowardsLeafNodes) => { - ExpressionPlacement::MoveTowardsLeafNodes - } - _ => ExpressionPlacement::KeepInPlace, - } - } - } - /// Two UDFs with the same `name()` but different concrete types should NOT be /// deduplicated -- they are semantically different expressions that happen to /// collide on `schema_name()`. #[test] fn test_different_udfs_same_schema_name_not_deduplicated() -> Result<()> { - let udf_a = Arc::new(ScalarUDF::new_from_impl(MockLeafFunc::new())); - let udf_b = Arc::new(ScalarUDF::new_from_impl(MockLeafFuncVariant::new())); + let udf_a = Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes) + .with_id(1), + )); + let udf_b = Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes) + .with_id(2), + )); let expr_a = Expr::ScalarFunction(ScalarFunction::new_udf( udf_a, @@ -1449,7 +1356,7 @@ mod tests { assert_stages!(plan, @r#" ## Original Plan Projection: test.id - Filter: mock_leaf(test.user, Utf8("field")) = Utf8("a") AND mock_leaf(test.user, Utf8("field")) = Utf8("b") + Filter: leaf_udf(test.user, Utf8("field")) = Utf8("a") AND leaf_udf(test.user, Utf8("field")) = Utf8("b") TableScan: test projection=[id, user] ## After Extraction @@ -1473,14 +1380,14 @@ mod tests { fn test_extraction_pushdown_through_filter_with_extracted_predicate() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(mock_leaf(col("user"), "status").eq(lit("active")))? - .project(vec![col("id"), mock_leaf(col("user"), "name")])? + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .project(vec![col("id"), leaf_udf(col("user"), "name")])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: test.id, mock_leaf(test.user, Utf8("name")) - Filter: mock_leaf(test.user, Utf8("status")) = Utf8("active") + Projection: test.id, leaf_udf(test.user, Utf8("name")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") TableScan: test projection=[id, user] ## After Extraction @@ -1498,7 +1405,7 @@ mod tests { #[test] fn test_extraction_pushdown_same_expr_in_filter_and_projection() -> Result<()> { let table_scan = test_table_scan_with_struct()?; - let field_expr = mock_leaf(col("user"), "status"); + let field_expr = leaf_udf(col("user"), "status"); let plan = LogicalPlanBuilder::from(table_scan) .filter(field_expr.clone().gt(lit(5)))? .project(vec![col("id"), field_expr])? @@ -1506,8 +1413,8 @@ mod tests { assert_stages!(plan, @r#" ## Original Plan - Projection: test.id, mock_leaf(test.user, Utf8("status")) - Filter: mock_leaf(test.user, Utf8("status")) > Int32(5) + Projection: test.id, leaf_udf(test.user, Utf8("status")) + Filter: leaf_udf(test.user, Utf8("status")) > Int32(5) TableScan: test projection=[id, user] ## After Extraction @@ -1536,20 +1443,20 @@ mod tests { JoinType::Left, vec![ col("test.id").eq(col("right.id")), - mock_leaf(col("right.user"), "status").gt(lit(5)), + leaf_udf(col("right.user"), "status").gt(lit(5)), ], )? .project(vec![ col("test.id"), - mock_leaf(col("test.user"), "name"), - mock_leaf(col("right.user"), "status"), + leaf_udf(col("test.user"), "name"), + leaf_udf(col("right.user"), "status"), ])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: test.id, mock_leaf(test.user, Utf8("name")), mock_leaf(right.user, Utf8("status")) - Left Join: Filter: test.id = right.id AND mock_leaf(right.user, Utf8("status")) > Int32(5) + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(right.user, Utf8("status")) + Left Join: Filter: test.id = right.id AND leaf_udf(right.user, Utf8("status")) > Int32(5) TableScan: test projection=[id, user] TableScan: right projection=[id, user] @@ -1570,18 +1477,18 @@ mod tests { fn test_pure_extraction_proj_push_through_filter() -> Result<()> { let table_scan = test_table_scan_with_struct()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(mock_leaf(col("user"), "status").gt(lit(5)))? + .filter(leaf_udf(col("user"), "status").gt(lit(5)))? .project(vec![ col("id"), - mock_leaf(col("user"), "name"), - mock_leaf(col("user"), "status"), + leaf_udf(col("user"), "name"), + leaf_udf(col("user"), "status"), ])? .build()?; assert_stages!(plan, @r#" ## Original Plan - Projection: test.id, mock_leaf(test.user, Utf8("name")), mock_leaf(test.user, Utf8("status")) - Filter: mock_leaf(test.user, Utf8("status")) > Int32(5) + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("status")) + Filter: leaf_udf(test.user, Utf8("status")) > Int32(5) TableScan: test projection=[id, user] ## After Extraction diff --git a/datafusion/optimizer/src/test/udfs.rs b/datafusion/optimizer/src/test/udfs.rs index 0e68568decf85..9164603dba3d5 100644 --- a/datafusion/optimizer/src/test/udfs.rs +++ b/datafusion/optimizer/src/test/udfs.rs @@ -21,7 +21,7 @@ use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ ColumnarValue, Expr, ExpressionPlacement, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + ScalarUDFImpl, Signature, TypeSignature, }; /// A configurable test UDF for optimizer tests. @@ -30,6 +30,7 @@ use datafusion_expr::{ pub struct PlacementTestUDF { signature: Signature, placement: ExpressionPlacement, + id: usize, } impl Default for PlacementTestUDF { @@ -41,15 +42,30 @@ impl Default for PlacementTestUDF { impl PlacementTestUDF { pub fn new() -> Self { Self { - signature: Signature::exact(vec![DataType::UInt32], Volatility::Immutable), + // Accept any one or two arguments and return UInt32 for testing purposes. + // The actual types don't matter since this UDF is not intended for execution. + signature: Signature::new( + TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]), + datafusion_expr::Volatility::Immutable, + ), placement: ExpressionPlacement::MoveTowardsLeafNodes, + id: 0, } } + /// Set the expression placement for this UDF, which is used by optimizer rules to determine where in the plan the expression should be placed. + /// This also resets the name of the UDF to a default based on the placement. pub fn with_placement(mut self, placement: ExpressionPlacement) -> Self { self.placement = placement; self } + + /// Set the id of the UDF. + /// This is an arbitrary made up field to allow creating multiple distinct UDFs with the same placement. + pub fn with_id(mut self, id: usize) -> Self { + self.id = id; + self + } } impl ScalarUDFImpl for PlacementTestUDF {