diff --git a/src/ast/query.rs b/src/ast/query.rs index ff617a38e..b4d3fdb2b 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -156,12 +156,12 @@ pub enum SetExpr { /// UNION/EXCEPT/INTERSECT of two queries /// A set operation combining two query expressions. SetOperation { + /// Left operand of the set operation. + left: Box, /// The set operator used (e.g. `UNION`, `EXCEPT`). op: SetOperator, /// Optional quantifier (`ALL`, `DISTINCT`, etc.). set_quantifier: SetQuantifier, - /// Left operand of the set operation. - left: Box, /// Right operand of the set operation. right: Box, }, @@ -442,6 +442,7 @@ impl SelectModifiers { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +#[cfg_attr(feature = "visitor", visit(with = "visit_select"))] pub struct Select { /// Token for the `SELECT` keyword pub select_token: AttachedToken, diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 5d841655b..5f9b37489 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -17,7 +17,7 @@ //! Recursive visitors for ast Nodes. See [`Visitor`] for more details. -use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value}; +use crate::ast::{Expr, ObjectName, Query, Select, Statement, TableFactor, Value}; use core::ops::ControlFlow; /// A type that can be visited by a [`Visitor`]. See [`Visitor`] for @@ -207,6 +207,16 @@ pub trait Visitor { ControlFlow::Continue(()) } + /// Invoked for any [Select] that appear in the AST before visiting children + fn pre_visit_select(&mut self, _select: &Select) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any [Select] that appear in the AST after visiting children + fn post_visit_select(&mut self, _select: &Select) -> ControlFlow { + ControlFlow::Continue(()) + } + /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow { ControlFlow::Continue(()) @@ -319,6 +329,16 @@ pub trait VisitorMut { ControlFlow::Continue(()) } + /// Invoked for any [Select] that appear in the AST before visiting children + fn pre_visit_select(&mut self, _select: &mut Select) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any [Select] that appear in the AST after visiting children + fn post_visit_select(&mut self, _select: &mut Select) -> ControlFlow { + ControlFlow::Continue(()) + } + /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow { ControlFlow::Continue(()) @@ -709,6 +729,16 @@ mod tests { ControlFlow::Continue(()) } + fn pre_visit_select(&mut self, select: &Select) -> ControlFlow { + self.visited.push(format!("PRE: SELECT: {select}")); + ControlFlow::Continue(()) + } + + fn post_visit_select(&mut self, select: &Select) -> ControlFlow { + self.visited.push(format!("POST: SELECT: {select}")); + ControlFlow::Continue(()) + } + fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow { self.visited.push(format!("PRE: RELATION: {relation}")); ControlFlow::Continue(()) @@ -779,10 +809,12 @@ mod tests { vec![ "PRE: STATEMENT: SELECT * FROM table_name AS my_table", "PRE: QUERY: SELECT * FROM table_name AS my_table", + "PRE: SELECT: SELECT * FROM table_name AS my_table", "PRE: TABLE FACTOR: table_name AS my_table", "PRE: RELATION: table_name", "POST: RELATION: table_name", "POST: TABLE FACTOR: table_name AS my_table", + "POST: SELECT: SELECT * FROM table_name AS my_table", "POST: QUERY: SELECT * FROM table_name AS my_table", "POST: STATEMENT: SELECT * FROM table_name AS my_table", ], @@ -792,6 +824,7 @@ mod tests { vec![ "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", + "PRE: SELECT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", @@ -806,6 +839,7 @@ mod tests { "PRE: EXPR: t2.t1_id", "POST: EXPR: t2.t1_id", "POST: EXPR: t1.id = t2.t1_id", + "POST: SELECT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", ], @@ -815,20 +849,24 @@ mod tests { vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", + "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", "PRE: QUERY: SELECT column FROM t2", + "PRE: SELECT: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: SELECT: SELECT column FROM t2", "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", + "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", ], @@ -838,20 +876,24 @@ mod tests { vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", + "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", "PRE: QUERY: SELECT column FROM t2", + "PRE: SELECT: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: SELECT: SELECT column FROM t2", "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", + "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", ], @@ -861,24 +903,30 @@ mod tests { vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", + "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", "PRE: QUERY: SELECT column FROM t2", + "PRE: SELECT: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: SELECT: SELECT column FROM t2", "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", + "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", + "PRE: SELECT: SELECT * FROM t3", "PRE: TABLE FACTOR: t3", "PRE: RELATION: t3", "POST: RELATION: t3", "POST: TABLE FACTOR: t3", + "POST: SELECT: SELECT * FROM t3", "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", ], @@ -892,6 +940,7 @@ mod tests { vec![ "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + "PRE: SELECT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)", "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)", "PRE: TABLE FACTOR: monthly_sales", "PRE: RELATION: monthly_sales", @@ -912,6 +961,7 @@ mod tests { "PRE: EXPR: 'APR'", "POST: EXPR: 'APR'", "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)", + "POST: SELECT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)", "PRE: EXPR: EMPID", "POST: EXPR: EMPID", "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",