Skip to content

Commit 67f1a9d

Browse files
committed
feat: add window functions (ROW_NUMBER, RANK, SUM/COUNT/AVG OVER) and CTEs (WITH clause)
Window functions: ROW_NUMBER, RANK, DENSE_RANK, SUM, COUNT, AVG, MIN, MAX, LAG, LEAD, FIRST_VALUE, LAST_VALUE with PARTITION BY and ORDER BY support. CTEs: WITH name AS (SELECT ...) with materialized execution, multiple CTEs, and proper alias propagation for aggregated CTE queries.
1 parent 4d6e1d8 commit 67f1a9d

18 files changed

Lines changed: 1447 additions & 3 deletions

Makefile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ TEST_SRCS = $(TEST_DIR)/test_main.cpp \
7777
$(TEST_DIR)/test_local_txn.cpp \
7878
$(TEST_DIR)/test_session.cpp \
7979
$(TEST_DIR)/test_single_backend_txn.cpp \
80-
$(TEST_DIR)/test_distributed_txn.cpp
80+
$(TEST_DIR)/test_distributed_txn.cpp \
81+
$(TEST_DIR)/test_window.cpp \
82+
$(TEST_DIR)/test_cte.cpp
8183
TEST_OBJS = $(TEST_SRCS:.cpp=.o)
8284
TEST_TARGET = $(PROJECT_ROOT)/run_tests
8385

include/sql_engine/catalog.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ class Catalog {
5959
// Find a column in a table by name. Returns nullptr if not found.
6060
virtual const ColumnInfo* get_column(const TableInfo* table,
6161
sql_parser::StringRef column_name) const = 0;
62+
63+
// Register a pre-built TableInfo (e.g., for CTEs). The base implementation is a no-op; catalogs that support registration override it.
64+
virtual void register_table(const TableInfo* /*table*/) {}
6265
};
6366

6467
} // namespace sql_engine

include/sql_engine/in_memory_catalog.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class InMemoryCatalog : public Catalog {
2020
// Remove a table.
2121
void drop_table(const char* schema, const char* table);
2222

23+
// Register a pre-built TableInfo (e.g., from CTE materialization)
24+
void register_table(const TableInfo* table) override;
25+
2326
// Catalog interface
2427
const TableInfo* get_table(sql_parser::StringRef name) const override;
2528
const TableInfo* get_table(sql_parser::StringRef schema,

include/sql_engine/operators/window_op.h

Lines changed: 543 additions & 0 deletions
Large diffs are not rendered by default.

include/sql_engine/plan_builder.h

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "sql_parser/common.h"
2626
#include "sql_parser/arena.h"
2727
#include <cstring>
28+
#include <vector>
2829

2930
namespace sql_engine {
3031

@@ -45,6 +46,9 @@ class PlanBuilder {
4546
if (stmt_ast->type == sql_parser::NodeType::NODE_COMPOUND_QUERY) {
4647
return build_compound(stmt_ast);
4748
}
49+
if (stmt_ast->type == sql_parser::NodeType::NODE_CTE) {
50+
return build_cte(stmt_ast);
51+
}
4852
return nullptr;
4953
}
5054

@@ -96,13 +100,34 @@ class PlanBuilder {
96100
return JOIN_INNER;
97101
}
98102

103+
// Check if an expression is a window function (NODE_WINDOW_FUNCTION)
104+
static bool is_window_function(const sql_parser::AstNode* expr) {
105+
if (!expr) return false;
106+
return expr->type == sql_parser::NodeType::NODE_WINDOW_FUNCTION;
107+
}
108+
109+
// Check if SELECT list contains any window functions
110+
static bool has_window_functions(const sql_parser::AstNode* select_ast) {
111+
const sql_parser::AstNode* items = find_child(select_ast, sql_parser::NodeType::NODE_SELECT_ITEM_LIST);
112+
if (!items) return false;
113+
for (const sql_parser::AstNode* item = items->first_child; item; item = item->next_sibling) {
114+
if (item->first_child && is_window_function(item->first_child)) return true;
115+
}
116+
return false;
117+
}
118+
99119
// Check if an expression (or any descendant) contains an aggregate function call.
100120
// Does NOT recurse into subqueries -- aggregates inside subqueries belong
101121
// to the subquery's own aggregation, not the outer query.
122+
// Does NOT recurse into window functions -- aggregates inside OVER() are
123+
// handled by the window operator, not the aggregate operator.
102124
static bool has_aggregate(const sql_parser::AstNode* expr) {
103125
if (!expr) return false;
104126
// Do not recurse into subqueries
105127
if (expr->type == sql_parser::NodeType::NODE_SUBQUERY) return false;
128+
// Do not recurse into window functions - aggregates inside OVER()
129+
// belong to the window operator
130+
if (expr->type == sql_parser::NodeType::NODE_WINDOW_FUNCTION) return false;
106131
if (expr->type == sql_parser::NodeType::NODE_FUNCTION_CALL) {
107132
sql_parser::StringRef name = expr->value();
108133
if (name.equals_ci("COUNT", 5) || name.equals_ci("SUM", 3) ||
@@ -206,6 +231,43 @@ class PlanBuilder {
206231
current = filter;
207232
}
208233

234+
// 4b. WINDOW -> Window node (if SELECT list has window functions)
235+
if (has_window_functions(select_ast)) {
236+
const sql_parser::AstNode* wnd_items = find_child(select_ast, sql_parser::NodeType::NODE_SELECT_ITEM_LIST);
237+
if (wnd_items) {
238+
// Collect window function expressions and full select list
239+
std::vector<const sql_parser::AstNode*> win_exprs;
240+
uint16_t total_count = count_children(wnd_items);
241+
242+
auto** sel_exprs = static_cast<const sql_parser::AstNode**>(
243+
arena_.allocate(sizeof(sql_parser::AstNode*) * total_count));
244+
auto** sel_aliases = static_cast<const sql_parser::AstNode**>(
245+
arena_.allocate(sizeof(sql_parser::AstNode*) * total_count));
246+
uint16_t idx = 0;
247+
for (const sql_parser::AstNode* item = wnd_items->first_child; item; item = item->next_sibling) {
248+
sel_exprs[idx] = item->first_child;
249+
sel_aliases[idx] = find_child(item, sql_parser::NodeType::NODE_ALIAS);
250+
if (item->first_child && is_window_function(item->first_child)) {
251+
win_exprs.push_back(item->first_child);
252+
}
253+
++idx;
254+
}
255+
256+
PlanNode* wnd = make_plan_node(arena_, PlanNodeType::WINDOW);
257+
uint16_t wc = static_cast<uint16_t>(win_exprs.size());
258+
auto** warr = static_cast<const sql_parser::AstNode**>(
259+
arena_.allocate(sizeof(sql_parser::AstNode*) * wc));
260+
for (uint16_t i = 0; i < wc; ++i) warr[i] = win_exprs[i];
261+
wnd->window.window_exprs = warr;
262+
wnd->window.window_count = wc;
263+
wnd->window.select_exprs = sel_exprs;
264+
wnd->window.select_aliases = sel_aliases;
265+
wnd->window.select_count = total_count;
266+
wnd->left = current;
267+
current = wnd;
268+
}
269+
}
270+
209271
// 5. ORDER BY -> Sort (before Project so sort keys resolve against full row)
210272
const sql_parser::AstNode* order_by = find_child(select_ast, sql_parser::NodeType::NODE_ORDER_BY_CLAUSE);
211273
if (order_by) {
@@ -236,9 +298,10 @@ class PlanBuilder {
236298
current = sort;
237299
}
238300

239-
// 6. SELECT list -> Project
301+
// 6. SELECT list -> Project (skip if WINDOW node handles it)
302+
bool has_window = has_window_functions(select_ast);
240303
const sql_parser::AstNode* item_list = find_child(select_ast, sql_parser::NodeType::NODE_SELECT_ITEM_LIST);
241-
if (item_list) {
304+
if (item_list && !has_window) {
242305
// Check if this is "SELECT *" with a single asterisk and no aliases -- skip Project for bare scan
243306
bool is_star_only = false;
244307
const sql_parser::AstNode* first_item = item_list->first_child;
@@ -329,6 +392,23 @@ class PlanBuilder {
329392
return result;
330393
}
331394

395+
// Build plan for CTE (WITH clause)
396+
// The CTE node has CTE_DEFINITION children, with the last child being the main SELECT.
397+
// We build the plan for the main SELECT, and store CTE definitions separately
398+
// for the executor to materialize.
399+
PlanNode* build_cte(const sql_parser::AstNode* cte_ast) {
400+
// The last child is the main query (SELECT_STMT or COMPOUND_QUERY)
401+
const sql_parser::AstNode* main_query = nullptr;
402+
for (const sql_parser::AstNode* c = cte_ast->first_child; c; c = c->next_sibling) {
403+
if (c->type == sql_parser::NodeType::NODE_SELECT_STMT ||
404+
c->type == sql_parser::NodeType::NODE_COMPOUND_QUERY) {
405+
main_query = c;
406+
}
407+
}
408+
if (!main_query) return nullptr;
409+
return build(main_query);
410+
}
411+
332412
// Build plan from FROM clause
333413
PlanNode* build_from(const sql_parser::AstNode* from_clause) {
334414
PlanNode* current = nullptr;

0 commit comments

Comments
 (0)