Skip to content

Commit b7ecc1b

Browse files
committed
Add CompoundQueryParser for UNION/INTERSECT/EXCEPT with precedence
Implements Pratt-style precedence parsing where INTERSECT binds tighter than UNION/EXCEPT. Handles parenthesized nesting, optional ALL modifier, and trailing ORDER BY/LIMIT on compound results. Returns bare NODE_SELECT_STMT when no set operator is present.
1 parent 13a1e50 commit b7ecc1b

1 file changed

Lines changed: 267 additions & 0 deletions

File tree

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#ifndef SQL_PARSER_COMPOUND_QUERY_PARSER_H
2+
#define SQL_PARSER_COMPOUND_QUERY_PARSER_H
3+
4+
#include "sql_parser/common.h"
5+
#include "sql_parser/token.h"
6+
#include "sql_parser/tokenizer.h"
7+
#include "sql_parser/ast.h"
8+
#include "sql_parser/arena.h"
9+
#include "sql_parser/select_parser.h"
10+
#include "sql_parser/expression_parser.h"
11+
12+
namespace sql_parser {
13+
14+
// Flag on NODE_SET_OPERATION to indicate ALL
15+
static constexpr uint16_t FLAG_SET_OP_ALL = 0x01;
16+
17+
template <Dialect D>
18+
class CompoundQueryParser {
19+
public:
20+
CompoundQueryParser(Tokenizer<D>& tokenizer, Arena& arena)
21+
: tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena) {}
22+
23+
// Parse a compound query (or a plain SELECT if no set operator follows).
24+
// Returns NODE_SELECT_STMT for plain selects, NODE_COMPOUND_QUERY for compounds.
25+
AstNode* parse() {
26+
AstNode* result = parse_compound_expr(0);
27+
if (!result) return nullptr;
28+
29+
// If the result is a set operation, wrap in COMPOUND_QUERY and parse trailing clauses
30+
if (result->type == NodeType::NODE_SET_OPERATION) {
31+
AstNode* compound = make_node(arena_, NodeType::NODE_COMPOUND_QUERY);
32+
if (!compound) return nullptr;
33+
compound->add_child(result);
34+
35+
// Parse trailing ORDER BY (applies to whole compound)
36+
if (tok_.peek().type == TokenType::TK_ORDER) {
37+
tok_.skip();
38+
if (tok_.peek().type == TokenType::TK_BY) tok_.skip();
39+
AstNode* order_by = parse_order_by();
40+
if (order_by) compound->add_child(order_by);
41+
}
42+
43+
// Parse trailing LIMIT (applies to whole compound)
44+
if (tok_.peek().type == TokenType::TK_LIMIT) {
45+
tok_.skip();
46+
AstNode* limit = parse_limit();
47+
if (limit) compound->add_child(limit);
48+
}
49+
50+
return compound;
51+
}
52+
53+
// No set operator found -- return the bare SELECT as-is
54+
return result;
55+
}
56+
57+
private:
58+
Tokenizer<D>& tok_;
59+
Arena& arena_;
60+
ExpressionParser<D> expr_parser_;
61+
62+
// Precedence levels
63+
static constexpr int PREC_UNION_EXCEPT = 1;
64+
static constexpr int PREC_INTERSECT = 2;
65+
66+
// Get the precedence of a set operator token, or 0 if not a set operator
67+
static int get_set_op_precedence(TokenType type) {
68+
switch (type) {
69+
case TokenType::TK_UNION: return PREC_UNION_EXCEPT;
70+
case TokenType::TK_EXCEPT: return PREC_UNION_EXCEPT;
71+
case TokenType::TK_INTERSECT: return PREC_INTERSECT;
72+
default: return 0;
73+
}
74+
}
75+
76+
// Check if a token is a set operator
77+
static bool is_set_operator(TokenType type) {
78+
return type == TokenType::TK_UNION ||
79+
type == TokenType::TK_INTERSECT ||
80+
type == TokenType::TK_EXCEPT;
81+
}
82+
83+
// Parse a compound expression with minimum precedence (Pratt-style)
84+
AstNode* parse_compound_expr(int min_prec) {
85+
AstNode* left = parse_operand();
86+
if (!left) return nullptr;
87+
88+
while (true) {
89+
Token t = tok_.peek();
90+
int prec = get_set_op_precedence(t.type);
91+
if (prec == 0 || prec <= min_prec) break;
92+
93+
// Consume the set operator
94+
tok_.skip();
95+
StringRef op_text = t.text;
96+
97+
// Check for optional ALL
98+
uint16_t flags = 0;
99+
if (tok_.peek().type == TokenType::TK_ALL) {
100+
tok_.skip();
101+
flags = FLAG_SET_OP_ALL;
102+
}
103+
104+
// Parse right operand with current precedence as min (left-associative)
105+
AstNode* right = parse_compound_expr(prec);
106+
if (!right) return nullptr;
107+
108+
// Build NODE_SET_OPERATION with left and right as children
109+
AstNode* setop = make_node(arena_, NodeType::NODE_SET_OPERATION, op_text);
110+
if (!setop) return nullptr;
111+
setop->flags = flags;
112+
setop->add_child(left);
113+
setop->add_child(right);
114+
115+
left = setop;
116+
}
117+
118+
return left;
119+
}
120+
121+
// Parse a single operand: parenthesized compound or plain SELECT
122+
AstNode* parse_operand() {
123+
if (tok_.peek().type == TokenType::TK_LPAREN) {
124+
tok_.skip(); // consume '('
125+
126+
// Could be a parenthesized compound query or a parenthesized SELECT
127+
AstNode* inner = nullptr;
128+
if (tok_.peek().type == TokenType::TK_SELECT ||
129+
tok_.peek().type == TokenType::TK_LPAREN) {
130+
// Parse the inner compound expression recursively
131+
// Need to consume SELECT keyword first if present
132+
if (tok_.peek().type == TokenType::TK_SELECT) {
133+
tok_.skip(); // consume SELECT
134+
// Create a SelectParser that will parse from after SELECT
135+
SelectParser<D> sp(tok_, arena_);
136+
AstNode* select = sp.parse();
137+
138+
// Check if a set operator follows inside the parens
139+
if (is_set_operator(tok_.peek().type)) {
140+
// There's a compound inside the parens
141+
// We need to continue parsing with select as the left operand
142+
inner = continue_compound_from(select, 0);
143+
} else {
144+
inner = select;
145+
}
146+
} else {
147+
// Nested parenthesized: ((SELECT ...))
148+
inner = parse_compound_expr(0);
149+
}
150+
}
151+
152+
// Expect closing ')'
153+
if (tok_.peek().type == TokenType::TK_RPAREN) {
154+
tok_.skip();
155+
}
156+
157+
return inner;
158+
}
159+
160+
// Not parenthesized -- must be a plain SELECT
161+
// Note: SELECT keyword was already consumed by the classifier
162+
// The tokenizer is positioned right after SELECT
163+
SelectParser<D> sp(tok_, arena_);
164+
return sp.parse();
165+
}
166+
167+
// Continue parsing compound from an already-parsed left operand
168+
AstNode* continue_compound_from(AstNode* left, int min_prec) {
169+
if (!left) return nullptr;
170+
171+
while (true) {
172+
Token t = tok_.peek();
173+
int prec = get_set_op_precedence(t.type);
174+
if (prec == 0 || prec <= min_prec) break;
175+
176+
tok_.skip();
177+
StringRef op_text = t.text;
178+
179+
uint16_t flags = 0;
180+
if (tok_.peek().type == TokenType::TK_ALL) {
181+
tok_.skip();
182+
flags = FLAG_SET_OP_ALL;
183+
}
184+
185+
// Inside parens, operand must start with SELECT or (
186+
AstNode* right = nullptr;
187+
if (tok_.peek().type == TokenType::TK_SELECT) {
188+
tok_.skip();
189+
SelectParser<D> sp(tok_, arena_);
190+
AstNode* rsel = sp.parse();
191+
// Check for more operators at higher precedence
192+
right = continue_compound_from(rsel, prec);
193+
} else if (tok_.peek().type == TokenType::TK_LPAREN) {
194+
right = parse_operand(); // handles nested parens
195+
right = continue_compound_from(right, prec);
196+
}
197+
198+
if (!right) return nullptr;
199+
200+
AstNode* setop = make_node(arena_, NodeType::NODE_SET_OPERATION, op_text);
201+
if (!setop) return nullptr;
202+
setop->flags = flags;
203+
setop->add_child(left);
204+
setop->add_child(right);
205+
206+
left = setop;
207+
}
208+
209+
return left;
210+
}
211+
212+
// Parse trailing ORDER BY for compound result
213+
AstNode* parse_order_by() {
214+
AstNode* order_by = make_node(arena_, NodeType::NODE_ORDER_BY_CLAUSE);
215+
if (!order_by) return nullptr;
216+
217+
while (true) {
218+
AstNode* expr = expr_parser_.parse();
219+
if (!expr) break;
220+
221+
AstNode* item = make_node(arena_, NodeType::NODE_ORDER_BY_ITEM);
222+
item->add_child(expr);
223+
224+
// Optional ASC/DESC
225+
Token dir = tok_.peek();
226+
if (dir.type == TokenType::TK_ASC || dir.type == TokenType::TK_DESC) {
227+
tok_.skip();
228+
item->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, dir.text));
229+
}
230+
231+
order_by->add_child(item);
232+
233+
if (tok_.peek().type == TokenType::TK_COMMA) {
234+
tok_.skip();
235+
} else {
236+
break;
237+
}
238+
}
239+
return order_by;
240+
}
241+
242+
// Parse trailing LIMIT for compound result
243+
AstNode* parse_limit() {
244+
AstNode* limit = make_node(arena_, NodeType::NODE_LIMIT_CLAUSE);
245+
if (!limit) return nullptr;
246+
247+
AstNode* first = expr_parser_.parse();
248+
if (first) limit->add_child(first);
249+
250+
if (tok_.peek().type == TokenType::TK_OFFSET) {
251+
tok_.skip();
252+
AstNode* offset = expr_parser_.parse();
253+
if (offset) limit->add_child(offset);
254+
} else if (tok_.peek().type == TokenType::TK_COMMA) {
255+
// MySQL: LIMIT offset, count
256+
tok_.skip();
257+
AstNode* count = expr_parser_.parse();
258+
if (count) limit->add_child(count);
259+
}
260+
261+
return limit;
262+
}
263+
};
264+
265+
} // namespace sql_parser
266+
267+
#endif // SQL_PARSER_COMPOUND_QUERY_PARSER_H

0 commit comments

Comments
 (0)