Skip to content

Commit 9d0dcf7

Browse files
committed
feat: add SET deep parser with full AST for all SET variants
1 parent 2d0cfda commit 9d0dcf7

3 files changed

Lines changed: 434 additions & 2 deletions

File tree

include/sql_parser/set_parser.h

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#ifndef SQL_PARSER_SET_PARSER_H
2+
#define SQL_PARSER_SET_PARSER_H
3+
4+
#include "sql_parser/common.h"
5+
#include "sql_parser/token.h"
6+
#include "sql_parser/tokenizer.h"
7+
#include "sql_parser/ast.h"
8+
#include "sql_parser/arena.h"
9+
#include "sql_parser/expression_parser.h"
10+
11+
namespace sql_parser {
12+
13+
template <Dialect D>
14+
class SetParser {
15+
public:
16+
SetParser(Tokenizer<D>& tokenizer, Arena& arena)
17+
: tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena) {}
18+
19+
// Parse a SET statement (SET keyword already consumed by classifier).
20+
// Returns the root NODE_SET_STMT node, or nullptr on failure.
21+
AstNode* parse() {
22+
AstNode* root = make_node(arena_, NodeType::NODE_SET_STMT);
23+
if (!root) return nullptr;
24+
25+
Token next = tok_.peek();
26+
27+
// SET NAMES ...
28+
if (next.type == TokenType::TK_NAMES) {
29+
tok_.skip();
30+
AstNode* names_node = parse_set_names();
31+
if (names_node) root->add_child(names_node);
32+
return root;
33+
}
34+
35+
// SET CHARACTER SET ... or SET CHARSET ...
36+
if (next.type == TokenType::TK_CHARACTER) {
37+
tok_.skip();
38+
// Expect SET keyword
39+
if (tok_.peek().type == TokenType::TK_SET) {
40+
tok_.skip();
41+
}
42+
AstNode* charset_node = parse_set_charset();
43+
if (charset_node) root->add_child(charset_node);
44+
return root;
45+
}
46+
if (next.type == TokenType::TK_CHARSET) {
47+
tok_.skip();
48+
AstNode* charset_node = parse_set_charset();
49+
if (charset_node) root->add_child(charset_node);
50+
return root;
51+
}
52+
53+
// SET [GLOBAL|SESSION] TRANSACTION ...
54+
// Need to check for scope + TRANSACTION or just TRANSACTION
55+
if (next.type == TokenType::TK_TRANSACTION) {
56+
tok_.skip();
57+
AstNode* txn_node = parse_set_transaction(StringRef{});
58+
if (txn_node) root->add_child(txn_node);
59+
return root;
60+
}
61+
62+
if (next.type == TokenType::TK_GLOBAL || next.type == TokenType::TK_SESSION) {
63+
Token scope_tok = tok_.next_token();
64+
if (tok_.peek().type == TokenType::TK_TRANSACTION) {
65+
tok_.skip();
66+
AstNode* txn_node = parse_set_transaction(scope_tok.text);
67+
if (txn_node) root->add_child(txn_node);
68+
return root;
69+
}
70+
// Not TRANSACTION — it's SET GLOBAL var = expr
71+
// Fall through to variable assignment with scope
72+
AstNode* assignment = parse_variable_assignment(&scope_tok);
73+
if (assignment) root->add_child(assignment);
74+
// Parse remaining comma-separated assignments
75+
while (tok_.peek().type == TokenType::TK_COMMA) {
76+
tok_.skip();
77+
AstNode* next_assign = parse_variable_assignment(nullptr);
78+
if (next_assign) root->add_child(next_assign);
79+
}
80+
return root;
81+
}
82+
83+
// PostgreSQL: SET LOCAL var = expr
84+
if constexpr (D == Dialect::PostgreSQL) {
85+
if (next.type == TokenType::TK_LOCAL) {
86+
Token scope_tok = tok_.next_token();
87+
AstNode* assignment = parse_variable_assignment(&scope_tok);
88+
if (assignment) root->add_child(assignment);
89+
return root;
90+
}
91+
}
92+
93+
// SET var = expr [, var = expr, ...]
94+
AstNode* assignment = parse_variable_assignment(nullptr);
95+
if (assignment) root->add_child(assignment);
96+
while (tok_.peek().type == TokenType::TK_COMMA) {
97+
tok_.skip();
98+
AstNode* next_assign = parse_variable_assignment(nullptr);
99+
if (next_assign) root->add_child(next_assign);
100+
}
101+
102+
return root;
103+
}
104+
105+
private:
106+
Tokenizer<D>& tok_;
107+
Arena& arena_;
108+
ExpressionParser<D> expr_parser_;
109+
110+
// SET NAMES charset [COLLATE collation]
111+
AstNode* parse_set_names() {
112+
AstNode* node = make_node(arena_, NodeType::NODE_SET_NAMES);
113+
if (!node) return nullptr;
114+
115+
// charset name or DEFAULT
116+
Token charset = tok_.next_token();
117+
node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, charset.text));
118+
119+
// Optional COLLATE
120+
if (tok_.peek().type == TokenType::TK_COLLATE) {
121+
tok_.skip();
122+
Token collation = tok_.next_token();
123+
node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, collation.text));
124+
}
125+
return node;
126+
}
127+
128+
// SET CHARACTER SET charset / SET CHARSET charset
129+
AstNode* parse_set_charset() {
130+
AstNode* node = make_node(arena_, NodeType::NODE_SET_CHARSET);
131+
if (!node) return nullptr;
132+
133+
Token charset = tok_.next_token();
134+
node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, charset.text));
135+
return node;
136+
}
137+
138+
// SET [GLOBAL|SESSION] TRANSACTION ...
139+
AstNode* parse_set_transaction(StringRef scope) {
140+
AstNode* node = make_node(arena_, NodeType::NODE_SET_TRANSACTION);
141+
if (!node) return nullptr;
142+
143+
if (!scope.empty()) {
144+
node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, scope));
145+
}
146+
147+
// ISOLATION LEVEL ... or READ ONLY/WRITE
148+
Token next = tok_.peek();
149+
if (next.type == TokenType::TK_ISOLATION) {
150+
tok_.skip();
151+
if (tok_.peek().type == TokenType::TK_LEVEL) tok_.skip();
152+
153+
// READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE
154+
Token level = tok_.next_token();
155+
if (level.type == TokenType::TK_READ) {
156+
Token sublevel = tok_.next_token();
157+
// Combine "READ COMMITTED" or "READ UNCOMMITTED"
158+
StringRef combined{level.text.ptr,
159+
static_cast<uint32_t>((sublevel.text.ptr + sublevel.text.len) - level.text.ptr)};
160+
node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined));
161+
} else if (level.type == TokenType::TK_REPEATABLE) {
162+
Token read_tok = tok_.next_token(); // READ
163+
StringRef combined{level.text.ptr,
164+
static_cast<uint32_t>((read_tok.text.ptr + read_tok.text.len) - level.text.ptr)};
165+
node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined));
166+
} else {
167+
// SERIALIZABLE
168+
node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, level.text));
169+
}
170+
} else if (next.type == TokenType::TK_READ) {
171+
tok_.skip();
172+
Token rw = tok_.next_token(); // ONLY or WRITE
173+
StringRef combined{next.text.ptr,
174+
static_cast<uint32_t>((rw.text.ptr + rw.text.len) - next.text.ptr)};
175+
node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined));
176+
}
177+
178+
return node;
179+
}
180+
181+
// Parse a single variable assignment: [scope] target = expr
182+
// scope_token is non-null if GLOBAL/SESSION/LOCAL was already consumed
183+
AstNode* parse_variable_assignment(const Token* scope_token) {
184+
AstNode* assignment = make_node(arena_, NodeType::NODE_VAR_ASSIGNMENT);
185+
if (!assignment) return nullptr;
186+
187+
// Build the variable target
188+
AstNode* target = make_node(arena_, NodeType::NODE_VAR_TARGET);
189+
if (!target) return nullptr;
190+
191+
if (scope_token) {
192+
target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, scope_token->text));
193+
}
194+
195+
Token var = tok_.peek();
196+
if (var.type == TokenType::TK_AT) {
197+
// User variable @name
198+
tok_.skip();
199+
Token name = tok_.next_token();
200+
StringRef full{var.text.ptr,
201+
static_cast<uint32_t>((name.text.ptr + name.text.len) - var.text.ptr)};
202+
target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, full));
203+
} else if (var.type == TokenType::TK_DOUBLE_AT) {
204+
// System variable @@[scope.]name
205+
tok_.skip();
206+
Token name = tok_.next_token();
207+
StringRef full{var.text.ptr,
208+
static_cast<uint32_t>((name.text.ptr + name.text.len) - var.text.ptr)};
209+
// Check for @@scope.name
210+
if (tok_.peek().type == TokenType::TK_DOT) {
211+
tok_.skip();
212+
Token actual_name = tok_.next_token();
213+
full = StringRef{var.text.ptr,
214+
static_cast<uint32_t>((actual_name.text.ptr + actual_name.text.len) - var.text.ptr)};
215+
}
216+
target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, full));
217+
} else {
218+
// Plain variable name
219+
Token name = tok_.next_token();
220+
target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text));
221+
}
222+
223+
assignment->add_child(target);
224+
225+
// Expect = or := (MySQL) or TO (PostgreSQL)
226+
Token eq = tok_.peek();
227+
if (eq.type == TokenType::TK_EQUAL || eq.type == TokenType::TK_COLON_EQUAL) {
228+
tok_.skip();
229+
} else if constexpr (D == Dialect::PostgreSQL) {
230+
if (eq.type == TokenType::TK_TO) {
231+
tok_.skip();
232+
}
233+
}
234+
235+
// Parse RHS expression
236+
AstNode* rhs = expr_parser_.parse();
237+
if (rhs) assignment->add_child(rhs);
238+
239+
return assignment;
240+
}
241+
};
242+
243+
} // namespace sql_parser
244+
245+
#endif // SQL_PARSER_SET_PARSER_H

src/sql_parser/parser.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "sql_parser/parser.h"
2+
#include "sql_parser/expression_parser.h"
3+
#include "sql_parser/set_parser.h"
24

35
namespace sql_parser {
46

@@ -74,8 +76,18 @@ ParseResult Parser<D>::parse_select() {
7476
template <Dialect D>
7577
ParseResult Parser<D>::parse_set() {
7678
ParseResult r;
77-
r.status = ParseResult::PARTIAL;
7879
r.stmt_type = StmtType::SET;
80+
81+
SetParser<D> set_parser(tokenizer_, arena_);
82+
AstNode* ast = set_parser.parse();
83+
84+
if (ast) {
85+
r.status = ParseResult::OK;
86+
r.ast = ast;
87+
} else {
88+
r.status = ParseResult::PARTIAL;
89+
}
90+
7991
scan_to_end(r);
8092
return r;
8193
}

0 commit comments

Comments
 (0)