Skip to content

Commit 0fae9d5

Browse files
committed
feat: add subquery execution — scalar, EXISTS, IN, derived tables
Implement subquery support across the parser and execution engine: Parser changes: - Add SubqueryParseCallback to ExpressionParser and TableRefParser so (SELECT ...) inside expressions and FROM clauses are fully parsed instead of skipped - Fix CompoundQueryParser to not greedily consume parenthesized subqueries that are actually scalar expressions in SELECT lists - Propagate subquery callback through SelectParser, CompoundQueryParser Execution engine: - SubqueryExecutor<D> bridges expression evaluator to full executor pipeline for scalar, EXISTS, and IN subqueries - DerivedScanOperator materializes inner plan results for FROM subqueries - PlanBuilder creates implicit AGGREGATE nodes for queries with aggregate functions but no GROUP BY (e.g., SELECT MAX(age) FROM users) - PlanBuilder generates synthetic TableInfo for derived tables to enable column name resolution in outer queries - FilterOperator and ProjectOperator accept optional SubqueryExecutor 20 new tests covering: scalar subqueries (MAX/MIN/COUNT), scalar returning 0 rows (NULL), IN subquery, NOT IN, EXISTS, NOT EXISTS, derived tables, subquery with LIMIT, and regression tests.
1 parent ba8e3e7 commit 0fae9d5

16 files changed

Lines changed: 909 additions & 76 deletions

Makefile.new

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ TEST_SRCS = $(TEST_DIR)/test_main.cpp \
7272
$(TEST_DIR)/test_distributed_dml.cpp \
7373
$(TEST_DIR)/test_mysql_executor.cpp \
7474
$(TEST_DIR)/test_pgsql_executor.cpp \
75-
$(TEST_DIR)/test_distributed_real.cpp
75+
$(TEST_DIR)/test_distributed_real.cpp \
76+
$(TEST_DIR)/test_subquery.cpp
7677
TEST_OBJS = $(TEST_SRCS:.cpp=.o)
7778
TEST_TARGET = $(PROJECT_ROOT)/run_tests
7879

include/sql_engine/expression_eval.h

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// expression_eval.h — Recursive AST expression evaluator
22
//
3-
// Core function: evaluate_expression<D>(expr, resolve, functions, arena)
3+
// Core function: evaluate_expression<D>(expr, resolve, functions, arena, subquery_exec)
44
//
55
// Takes a parsed AST expression node and evaluates it recursively against
66
// a row of data. The `resolve` callback maps column names (StringRef) to
@@ -27,13 +27,15 @@
2727
#include "sql_engine/tag_kind_map.h"
2828
#include "sql_engine/like.h"
2929
#include "sql_engine/function_registry.h"
30+
#include "sql_engine/subquery_executor.h"
3031
#include "sql_parser/common.h"
3132
#include "sql_parser/ast.h"
3233
#include "sql_parser/arena.h"
3334
#include <functional>
3435
#include <cstdlib>
3536
#include <cstring>
3637
#include <cmath>
38+
#include <vector>
3739

3840
namespace sql_engine {
3941

@@ -97,7 +99,8 @@ template <Dialect D>
9799
Value evaluate_expression(const AstNode* expr,
98100
const std::function<Value(StringRef)>& resolve,
99101
FunctionRegistry<D>& functions,
100-
Arena& arena) {
102+
Arena& arena,
103+
SubqueryExecutor<D>* subquery_exec = nullptr) {
101104
if (!expr) return value_null();
102105

103106
switch (expr->type) {
@@ -163,7 +166,7 @@ Value evaluate_expression(const AstNode* expr,
163166
// ---- Wrapper: unwrap and evaluate first child ----
164167

165168
case NodeType::NODE_EXPRESSION: {
166-
return evaluate_expression<D>(expr->first_child, resolve, functions, arena);
169+
return evaluate_expression<D>(expr->first_child, resolve, functions, arena, subquery_exec);
167170
}
168171

169172
// ---- Unary operators ----
@@ -172,7 +175,7 @@ Value evaluate_expression(const AstNode* expr,
172175
StringRef op = expr->value();
173176
const AstNode* operand_node = expr->first_child;
174177
if (!operand_node) return value_null();
175-
Value operand = evaluate_expression<D>(operand_node, resolve, functions, arena);
178+
Value operand = evaluate_expression<D>(operand_node, resolve, functions, arena, subquery_exec);
176179
if (op.len == 1 && op.ptr[0] == '-') {
177180
// Unary minus
178181
if (operand.is_null()) return value_null();
@@ -209,28 +212,28 @@ Value evaluate_expression(const AstNode* expr,
209212

210213
// --- Short-circuit: AND ---
211214
if (detail::ref_equals_ci(op, "AND", 3)) {
212-
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
215+
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec);
213216
// If left is FALSE -> FALSE immediately
214217
if (!left_val.is_null() && left_val.tag == Value::TAG_BOOL && !left_val.bool_val)
215218
return value_bool(false);
216-
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
219+
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec);
217220
return null_semantics::eval_and(left_val, right_val);
218221
}
219222

220223
// --- Short-circuit: OR ---
221224
if (detail::ref_equals_ci(op, "OR", 2)) {
222-
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
225+
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec);
223226
// If left is TRUE -> TRUE immediately
224227
if (!left_val.is_null() && left_val.tag == Value::TAG_BOOL && left_val.bool_val)
225228
return value_bool(true);
226-
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
229+
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec);
227230
return null_semantics::eval_or(left_val, right_val);
228231
}
229232

230233
// --- IS / IS NOT (never return NULL) ---
231234
if (detail::ref_equals_ci(op, "IS", 2)) {
232-
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
233-
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
235+
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec);
236+
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec);
234237
// IS TRUE: left is truthy and not null
235238
if (!right_val.is_null() && right_val.tag == Value::TAG_BOOL && right_val.bool_val) {
236239
// IS TRUE
@@ -251,8 +254,8 @@ Value evaluate_expression(const AstNode* expr,
251254
return value_null();
252255
}
253256
if (detail::ref_equals_ci(op, "IS NOT", 6)) {
254-
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
255-
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
257+
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec);
258+
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec);
256259
// IS NOT TRUE = NOT (IS TRUE)
257260
if (!right_val.is_null() && right_val.tag == Value::TAG_BOOL && right_val.bool_val) {
258261
if (left_val.is_null()) return value_bool(true);
@@ -274,8 +277,8 @@ Value evaluate_expression(const AstNode* expr,
274277

275278
// --- LIKE ---
276279
if (detail::ref_equals_ci(op, "LIKE", 4)) {
277-
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
278-
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
280+
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec);
281+
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec);
279282
if (left_val.is_null() || right_val.is_null()) return value_null();
280283
// Coerce both to strings if not already
281284
if (left_val.tag != Value::TAG_STRING)
@@ -290,8 +293,8 @@ Value evaluate_expression(const AstNode* expr,
290293
if (op.len == 2 && op.ptr[0] == '|' && op.ptr[1] == '|') {
291294
if constexpr (D == Dialect::PostgreSQL) {
292295
// String concatenation
293-
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
294-
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
296+
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec);
297+
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec);
295298
if (left_val.is_null() || right_val.is_null()) return value_null();
296299
// Coerce to string
297300
if (left_val.tag != Value::TAG_STRING)
@@ -308,17 +311,17 @@ Value evaluate_expression(const AstNode* expr,
308311
return value_string(StringRef{buf, total});
309312
} else {
310313
// MySQL: || is OR
311-
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
314+
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec);
312315
if (!left_val.is_null() && left_val.tag == Value::TAG_BOOL && left_val.bool_val)
313316
return value_bool(true);
314-
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
317+
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec);
315318
return null_semantics::eval_or(left_val, right_val);
316319
}
317320
}
318321

319322
// --- Standard binary: evaluate both sides, null-propagate, coerce, apply ---
320-
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
321-
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
323+
Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec);
324+
Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec);
322325

323326
// NULL propagation
324327
if (left_val.is_null() || right_val.is_null()) return value_null();
@@ -448,12 +451,12 @@ Value evaluate_expression(const AstNode* expr,
448451
// ---- IS NULL / IS NOT NULL (never return NULL) ----
449452

450453
case NodeType::NODE_IS_NULL: {
451-
Value child = evaluate_expression<D>(expr->first_child, resolve, functions, arena);
454+
Value child = evaluate_expression<D>(expr->first_child, resolve, functions, arena, subquery_exec);
452455
return value_bool(child.is_null());
453456
}
454457

455458
case NodeType::NODE_IS_NOT_NULL: {
456-
Value child = evaluate_expression<D>(expr->first_child, resolve, functions, arena);
459+
Value child = evaluate_expression<D>(expr->first_child, resolve, functions, arena, subquery_exec);
457460
return value_bool(!child.is_null());
458461
}
459462

@@ -465,9 +468,9 @@ Value evaluate_expression(const AstNode* expr,
465468
const AstNode* high_node = detail::nth_child(expr, 2);
466469
if (!expr_node || !low_node || !high_node) return value_null();
467470

468-
Value val = evaluate_expression<D>(expr_node, resolve, functions, arena);
471+
Value val = evaluate_expression<D>(expr_node, resolve, functions, arena, subquery_exec);
469472
Value low = evaluate_expression<D>(low_node, resolve, functions, arena);
470-
Value high = evaluate_expression<D>(high_node, resolve, functions, arena);
473+
Value high = evaluate_expression<D>(high_node, resolve, functions, arena, subquery_exec);
471474

472475
// NULL propagation
473476
if (val.is_null() || low.is_null() || high.is_null()) return value_null();
@@ -515,14 +518,25 @@ Value evaluate_expression(const AstNode* expr,
515518
const AstNode* expr_node = expr->first_child;
516519
if (!expr_node) return value_null();
517520

518-
Value val = evaluate_expression<D>(expr_node, resolve, functions, arena);
521+
Value val = evaluate_expression<D>(expr_node, resolve, functions, arena, subquery_exec);
519522
if (val.is_null()) return value_null();
520523

521524
bool found = false;
522525
bool has_null = false;
523526

527+
// Collect all IN-list values. If a child is a subquery, expand it.
528+
std::vector<Value> in_values;
524529
for (const AstNode* item = expr_node->next_sibling; item; item = item->next_sibling) {
525-
Value item_val = evaluate_expression<D>(item, resolve, functions, arena);
530+
if (item->type == NodeType::NODE_SUBQUERY && subquery_exec && item->first_child) {
531+
// IN (subquery) -- execute and expand
532+
std::vector<Value> set_vals = subquery_exec->execute_set(item, resolve);
533+
for (auto& sv : set_vals) in_values.push_back(sv);
534+
} else {
535+
in_values.push_back(evaluate_expression<D>(item, resolve, functions, arena, subquery_exec));
536+
}
537+
}
538+
539+
for (const auto& item_val : in_values) {
526540
if (item_val.is_null()) {
527541
has_null = true;
528542
continue;
@@ -561,13 +575,13 @@ Value evaluate_expression(const AstNode* expr,
561575
// Simple CASE: children = [case_expr, when1, then1, when2, then2, ..., else?]
562576
const AstNode* case_node = expr->first_child;
563577
if (!case_node) return value_null();
564-
Value case_val = evaluate_expression<D>(case_node, resolve, functions, arena);
578+
Value case_val = evaluate_expression<D>(case_node, resolve, functions, arena, subquery_exec);
565579

566580
const AstNode* child = case_node->next_sibling;
567581
uint32_t remaining = count - 1; // excluding case_expr
568582

569583
while (child && child->next_sibling) {
570-
Value when_val = evaluate_expression<D>(child, resolve, functions, arena);
584+
Value when_val = evaluate_expression<D>(child, resolve, functions, arena, subquery_exec);
571585
const AstNode* then_node = child->next_sibling;
572586

573587
// Compare case_val = when_val
@@ -590,7 +604,7 @@ Value evaluate_expression(const AstNode* expr,
590604
}
591605

592606
if (match) {
593-
return evaluate_expression<D>(then_node, resolve, functions, arena);
607+
return evaluate_expression<D>(then_node, resolve, functions, arena, subquery_exec);
594608
}
595609

596610
child = then_node->next_sibling;
@@ -599,7 +613,7 @@ Value evaluate_expression(const AstNode* expr,
599613

600614
// Check for ELSE (one remaining child)
601615
if (child && remaining == 1) {
602-
return evaluate_expression<D>(child, resolve, functions, arena);
616+
return evaluate_expression<D>(child, resolve, functions, arena, subquery_exec);
603617
}
604618
return value_null();
605619
} else {
@@ -608,7 +622,7 @@ Value evaluate_expression(const AstNode* expr,
608622
uint32_t remaining = count;
609623

610624
while (child && child->next_sibling) {
611-
Value when_val = evaluate_expression<D>(child, resolve, functions, arena);
625+
Value when_val = evaluate_expression<D>(child, resolve, functions, arena, subquery_exec);
612626
const AstNode* then_node = child->next_sibling;
613627

614628
// Evaluate WHEN condition as boolean
@@ -621,7 +635,7 @@ Value evaluate_expression(const AstNode* expr,
621635
}
622636

623637
if (is_true) {
624-
return evaluate_expression<D>(then_node, resolve, functions, arena);
638+
return evaluate_expression<D>(then_node, resolve, functions, arena, subquery_exec);
625639
}
626640

627641
child = then_node->next_sibling;
@@ -630,7 +644,7 @@ Value evaluate_expression(const AstNode* expr,
630644

631645
// Check for ELSE (one remaining child)
632646
if (child && remaining % 2 == 1) {
633-
return evaluate_expression<D>(child, resolve, functions, arena);
647+
return evaluate_expression<D>(child, resolve, functions, arena, subquery_exec);
634648
}
635649
return value_null();
636650
}
@@ -651,14 +665,26 @@ Value evaluate_expression(const AstNode* expr,
651665
uint32_t i = 0;
652666
for (const AstNode* arg = expr->first_child; arg && i < MAX_ARGS;
653667
arg = arg->next_sibling, ++i) {
654-
new (&args[i]) Value(evaluate_expression<D>(arg, resolve, functions, arena));
668+
new (&args[i]) Value(evaluate_expression<D>(arg, resolve, functions, arena, subquery_exec));
655669
}
656670
return entry->impl(args, static_cast<uint16_t>(i), arena);
657671
}
658672

659673
// ---- Deferred node types (return value_null) ----
660674

661-
case NodeType::NODE_SUBQUERY: return value_null(); // requires full executor
675+
case NodeType::NODE_SUBQUERY: {
676+
// If the subquery has a parsed SELECT child and we have an executor, run it
677+
if (subquery_exec && expr->first_child) {
678+
// Check if this is an EXISTS subquery (flags == 1)
679+
if (expr->flags == 1) {
680+
bool exists = subquery_exec->execute_exists(expr, resolve);
681+
return value_bool(exists);
682+
}
683+
// Otherwise treat as scalar subquery
684+
return subquery_exec->execute_scalar(expr, resolve);
685+
}
686+
return value_null(); // no executor or no parsed child
687+
}
662688
case NodeType::NODE_TUPLE: return value_null(); // requires row/tuple value type
663689
case NodeType::NODE_ARRAY_CONSTRUCTOR: return value_null(); // requires array value type
664690
case NodeType::NODE_ARRAY_SUBSCRIPT: return value_null(); // requires array support
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// derived_scan_op.h -- DerivedScanOperator for subqueries in FROM clause
2+
//
3+
// Materializes the inner plan's result set on open(), then yields
4+
// rows one at a time from the materialized buffer on next().
5+
6+
#ifndef SQL_ENGINE_OPERATORS_DERIVED_SCAN_OP_H
7+
#define SQL_ENGINE_OPERATORS_DERIVED_SCAN_OP_H
8+
9+
#include "sql_engine/operator.h"
10+
#include "sql_engine/row.h"
11+
#include <vector>
12+
13+
namespace sql_engine {
14+
15+
class DerivedScanOperator : public Operator {
16+
public:
17+
// Takes ownership of the inner operator. On open(), pulls all rows
18+
// from the inner operator and stores them. On next(), yields them.
19+
explicit DerivedScanOperator(Operator* inner)
20+
: inner_(inner) {}
21+
22+
void open() override {
23+
rows_.clear();
24+
cursor_ = 0;
25+
if (inner_) {
26+
inner_->open();
27+
Row row{};
28+
while (inner_->next(row)) {
29+
rows_.push_back(row);
30+
}
31+
inner_->close();
32+
}
33+
}
34+
35+
bool next(Row& out) override {
36+
if (cursor_ >= rows_.size()) return false;
37+
out = rows_[cursor_++];
38+
return true;
39+
}
40+
41+
void close() override {
42+
rows_.clear();
43+
cursor_ = 0;
44+
}
45+
46+
private:
47+
Operator* inner_;
48+
std::vector<Row> rows_;
49+
size_t cursor_ = 0;
50+
};
51+
52+
} // namespace sql_engine
53+
54+
#endif // SQL_ENGINE_OPERATORS_DERIVED_SCAN_OP_H

include/sql_engine/operators/filter_op.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "sql_engine/operator.h"
55
#include "sql_engine/expression_eval.h"
66
#include "sql_engine/catalog.h"
7+
#include "sql_engine/subquery_executor.h"
78
#include "sql_parser/arena.h"
89
#include <functional>
910
#include <vector>
@@ -18,9 +19,11 @@ class FilterOperator : public Operator {
1819
const Catalog& catalog,
1920
const std::vector<const TableInfo*>& tables,
2021
FunctionRegistry<D>& functions,
21-
sql_parser::Arena& arena)
22+
sql_parser::Arena& arena,
23+
SubqueryExecutor<D>* subquery_exec = nullptr)
2224
: child_(child), expr_(expr), catalog_(catalog),
23-
tables_(tables), functions_(functions), arena_(arena) {}
25+
tables_(tables), functions_(functions), arena_(arena),
26+
subquery_exec_(subquery_exec) {}
2427

2528
void open() override {
2629
child_->open();
@@ -30,7 +33,7 @@ class FilterOperator : public Operator {
3033
while (child_->next(out)) {
3134
// Build a resolver that can look up column names in any of the known tables
3235
auto resolver = make_multi_table_resolver(out);
33-
Value result = evaluate_expression<D>(expr_, resolver, functions_, arena_);
36+
Value result = evaluate_expression<D>(expr_, resolver, functions_, arena_, subquery_exec_);
3437
if (is_truthy(result)) return true;
3538
}
3639
return false;
@@ -47,6 +50,7 @@ class FilterOperator : public Operator {
4750
std::vector<const TableInfo*> tables_;
4851
FunctionRegistry<D>& functions_;
4952
sql_parser::Arena& arena_;
53+
SubqueryExecutor<D>* subquery_exec_ = nullptr;
5054

5155
std::function<Value(sql_parser::StringRef)> make_multi_table_resolver(const Row& row) {
5256
return [this, &row](sql_parser::StringRef col_name) -> Value {

0 commit comments

Comments
 (0)