diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 6dd55f1d7e4be..9439d8c590acc 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -382,30 +382,39 @@ fn get_exprs_except_skipped(
     }
 }
 
-/// For each column specified in the USING JOIN condition, the JOIN plan outputs it twice
-/// (once for each join side), but an unqualified wildcard should include it only once.
-/// This function returns the columns that should be excluded.
+/// When a JOIN has a USING clause, the join columns appear in the output
+/// schema once per side (for inner/outer joins) or once total (for semi/anti
+/// joins). An unqualified wildcard should include each USING column only once.
+/// This function returns the duplicate columns that should be excluded.
 fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
-    let using_columns = plan.using_columns()?;
-    let excluded = using_columns
-        .into_iter()
-        // For each USING JOIN condition, only expand to one of each join column in projection
-        .flat_map(|cols| {
-            let mut cols = cols.into_iter().collect::<Vec<_>>();
-            // sort join columns to make sure we consistently keep the same
-            // qualified column
-            cols.sort();
-            let mut out_column_names: HashSet<String> = HashSet::new();
-            cols.into_iter().filter_map(move |c| {
-                if out_column_names.contains(&c.name) {
-                    Some(c)
-                } else {
-                    out_column_names.insert(c.name);
-                    None
-                }
-            })
-        })
-        .collect::<HashSet<_>>();
+    let output_columns: HashSet<_> = plan.schema().columns().into_iter().collect();
+    let mut excluded = HashSet::new();
+    for cols in plan.using_columns()? {
+        // `using_columns()` returns join columns from both sides regardless of
+        // the join type. For semi/anti joins, only one side's columns appear in
+        // the output schema. Filter to output columns so that columns from the
+        // non-output side don't participate in the deduplication process below
+        // and displace real output columns.
+        let mut cols: Vec<_> = cols
+            .into_iter()
+            .filter(|c| output_columns.contains(c))
+            .collect();
+
+        // Sort so we keep the same qualified column, regardless of HashSet
+        // iteration order.
+        cols.sort();
+
+        // Keep only one column per name from the columns set, adding any
+        // duplicates to the excluded set.
+        let mut seen_names = HashSet::new();
+        for col in cols {
+            if seen_names.contains(col.name.as_str()) {
+                excluded.insert(col); // exclude columns with already seen name
+            } else {
+                seen_names.insert(col.name.clone()); // mark column name as seen
+            }
+        }
+    }
     Ok(excluded)
 }
 
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 950a79ddb0b5e..fd606af3a6af0 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -5004,6 +5004,73 @@ fn test_using_join_wildcard_schema() {
     );
 }
 
+/// `SELECT *` over a semi/anti join with USING must expand to exactly the
+/// columns of the side the join outputs, whichever relation is on that side.
+#[test]
+fn test_using_join_wildcard_schema_semi_anti() {
+    let s_columns = &["s.x1", "s.x2", "s.x3"];
+    let t_columns = &["t.x1", "t.x2", "t.x3"];
+
+    let sql = "WITH
+        s AS (SELECT 1 AS x1, 2 AS x2, 3 AS x3),
+        t AS (SELECT 1 AS x1, 4 AS x2, 5 AS x3)
+        SELECT * FROM s LEFT SEMI JOIN t USING (x1)";
+    let plan = logical_plan(sql).unwrap();
+    assert_eq!(plan.schema().field_names(), s_columns);
+
+    let sql = "WITH
+        s AS (SELECT 1 AS x1, 2 AS x2, 3 AS x3),
+        t AS (SELECT 1 AS x1, 4 AS x2, 5 AS x3)
+        SELECT * FROM t RIGHT SEMI JOIN s USING (x1)";
+    let plan = logical_plan(sql).unwrap();
+    assert_eq!(plan.schema().field_names(), s_columns);
+
+    let sql = "WITH
+        s AS (SELECT 1 AS x1, 2 AS x2, 3 AS x3),
+        t AS (SELECT 1 AS x1, 4 AS x2, 5 AS x3)
+        SELECT * FROM s LEFT ANTI JOIN t USING (x1)";
+    let plan = logical_plan(sql).unwrap();
+    assert_eq!(plan.schema().field_names(), s_columns);
+
+    let sql = "WITH
+        s AS (SELECT 1 AS x1, 2 AS x2, 3 AS x3),
+        t AS (SELECT 1 AS x1, 4 AS x2, 5 AS x3)
+        SELECT * FROM t RIGHT ANTI JOIN s USING (x1)";
+    let plan = logical_plan(sql).unwrap();
+    assert_eq!(plan.schema().field_names(), s_columns);
+
+    // Same as above, but with swapped s and t sides.
+    // Tests the issue fixed with #20990.
+
+    let sql = "WITH
+        s AS (SELECT 1 AS x1, 2 AS x2, 3 AS x3),
+        t AS (SELECT 1 AS x1, 4 AS x2, 5 AS x3)
+        SELECT * FROM t LEFT SEMI JOIN s USING (x1)";
+    let plan = logical_plan(sql).unwrap();
+    assert_eq!(plan.schema().field_names(), t_columns);
+
+    let sql = "WITH
+        s AS (SELECT 1 AS x1, 2 AS x2, 3 AS x3),
+        t AS (SELECT 1 AS x1, 4 AS x2, 5 AS x3)
+        SELECT * FROM s RIGHT SEMI JOIN t USING (x1)";
+    let plan = logical_plan(sql).unwrap();
+    assert_eq!(plan.schema().field_names(), t_columns);
+
+    let sql = "WITH
+        s AS (SELECT 1 AS x1, 2 AS x2, 3 AS x3),
+        t AS (SELECT 1 AS x1, 4 AS x2, 5 AS x3)
+        SELECT * FROM t LEFT ANTI JOIN s USING (x1)";
+    let plan = logical_plan(sql).unwrap();
+    assert_eq!(plan.schema().field_names(), t_columns);
+
+    let sql = "WITH
+        s AS (SELECT 1 AS x1, 2 AS x2, 3 AS x3),
+        t AS (SELECT 1 AS x1, 4 AS x2, 5 AS x3)
+        SELECT * FROM s RIGHT ANTI JOIN t USING (x1)";
+    let plan = logical_plan(sql).unwrap();
+    assert_eq!(plan.schema().field_names(), t_columns);
+}
+
 #[test]
 fn test_2_nested_lateral_join_with_the_deepest_join_referencing_the_outer_most_relation()
  {