11// expression_eval.h — Recursive AST expression evaluator
22//
3- // Core function: evaluate_expression<D>(expr, resolve, functions, arena)
3+ // Core function: evaluate_expression<D>(expr, resolve, functions, arena, subquery_exec )
44//
55// Takes a parsed AST expression node and evaluates it recursively against
66// a row of data. The `resolve` callback maps column names (StringRef) to
2727#include " sql_engine/tag_kind_map.h"
2828#include " sql_engine/like.h"
2929#include " sql_engine/function_registry.h"
30+ #include " sql_engine/subquery_executor.h"
3031#include " sql_parser/common.h"
3132#include " sql_parser/ast.h"
3233#include " sql_parser/arena.h"
3334#include < functional>
3435#include < cstdlib>
3536#include < cstring>
3637#include < cmath>
38+ #include < vector>
3739
3840namespace sql_engine {
3941
@@ -97,7 +99,8 @@ template <Dialect D>
9799Value evaluate_expression (const AstNode* expr,
98100 const std::function<Value(StringRef)>& resolve,
99101 FunctionRegistry<D>& functions,
100- Arena& arena) {
102+ Arena& arena,
103+ SubqueryExecutor<D>* subquery_exec = nullptr) {
101104 if (!expr) return value_null ();
102105
103106 switch (expr->type ) {
@@ -163,7 +166,7 @@ Value evaluate_expression(const AstNode* expr,
163166 // ---- Wrapper: unwrap and evaluate first child ----
164167
165168 case NodeType::NODE_EXPRESSION: {
166- return evaluate_expression<D>(expr->first_child , resolve, functions, arena);
169+ return evaluate_expression<D>(expr->first_child , resolve, functions, arena, subquery_exec );
167170 }
168171
169172 // ---- Unary operators ----
@@ -172,7 +175,7 @@ Value evaluate_expression(const AstNode* expr,
172175 StringRef op = expr->value ();
173176 const AstNode* operand_node = expr->first_child ;
174177 if (!operand_node) return value_null ();
175- Value operand = evaluate_expression<D>(operand_node, resolve, functions, arena);
178+ Value operand = evaluate_expression<D>(operand_node, resolve, functions, arena, subquery_exec );
176179 if (op.len == 1 && op.ptr [0 ] == ' -' ) {
177180 // Unary minus
178181 if (operand.is_null ()) return value_null ();
@@ -209,28 +212,28 @@ Value evaluate_expression(const AstNode* expr,
209212
210213 // --- Short-circuit: AND ---
211214 if (detail::ref_equals_ci (op, " AND" , 3 )) {
212- Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
215+ Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec );
213216 // If left is FALSE -> FALSE immediately
214217 if (!left_val.is_null () && left_val.tag == Value::TAG_BOOL && !left_val.bool_val )
215218 return value_bool (false );
216- Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
219+ Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec );
217220 return null_semantics::eval_and (left_val, right_val);
218221 }
219222
220223 // --- Short-circuit: OR ---
221224 if (detail::ref_equals_ci (op, " OR" , 2 )) {
222- Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
225+ Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec );
223226 // If left is TRUE -> TRUE immediately
224227 if (!left_val.is_null () && left_val.tag == Value::TAG_BOOL && left_val.bool_val )
225228 return value_bool (true );
226- Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
229+ Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec );
227230 return null_semantics::eval_or (left_val, right_val);
228231 }
229232
230233 // --- IS / IS NOT (never return NULL) ---
231234 if (detail::ref_equals_ci (op, " IS" , 2 )) {
232- Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
233- Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
235+ Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec );
236+ Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec );
234237 // IS TRUE: left is truthy and not null
235238 if (!right_val.is_null () && right_val.tag == Value::TAG_BOOL && right_val.bool_val ) {
236239 // IS TRUE
@@ -251,8 +254,8 @@ Value evaluate_expression(const AstNode* expr,
251254 return value_null ();
252255 }
253256 if (detail::ref_equals_ci (op, " IS NOT" , 6 )) {
254- Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
255- Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
257+ Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec );
258+ Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec );
256259 // IS NOT TRUE = NOT (IS TRUE)
257260 if (!right_val.is_null () && right_val.tag == Value::TAG_BOOL && right_val.bool_val ) {
258261 if (left_val.is_null ()) return value_bool (true );
@@ -274,8 +277,8 @@ Value evaluate_expression(const AstNode* expr,
274277
275278 // --- LIKE ---
276279 if (detail::ref_equals_ci (op, " LIKE" , 4 )) {
277- Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
278- Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
280+ Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec );
281+ Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec );
279282 if (left_val.is_null () || right_val.is_null ()) return value_null ();
280283 // Coerce both to strings if not already
281284 if (left_val.tag != Value::TAG_STRING)
@@ -290,8 +293,8 @@ Value evaluate_expression(const AstNode* expr,
290293 if (op.len == 2 && op.ptr [0 ] == ' |' && op.ptr [1 ] == ' |' ) {
291294 if constexpr (D == Dialect::PostgreSQL) {
292295 // String concatenation
293- Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
294- Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
296+ Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec );
297+ Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec );
295298 if (left_val.is_null () || right_val.is_null ()) return value_null ();
296299 // Coerce to string
297300 if (left_val.tag != Value::TAG_STRING)
@@ -308,17 +311,17 @@ Value evaluate_expression(const AstNode* expr,
308311 return value_string (StringRef{buf, total});
309312 } else {
310313 // MySQL: || is OR
311- Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
314+ Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec );
312315 if (!left_val.is_null () && left_val.tag == Value::TAG_BOOL && left_val.bool_val )
313316 return value_bool (true );
314- Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
317+ Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec );
315318 return null_semantics::eval_or (left_val, right_val);
316319 }
317320 }
318321
319322 // --- Standard binary: evaluate both sides, null-propagate, coerce, apply ---
320- Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena);
321- Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena);
323+ Value left_val = evaluate_expression<D>(left_node, resolve, functions, arena, subquery_exec );
324+ Value right_val = evaluate_expression<D>(right_node, resolve, functions, arena, subquery_exec );
322325
323326 // NULL propagation
324327 if (left_val.is_null () || right_val.is_null ()) return value_null ();
@@ -448,12 +451,12 @@ Value evaluate_expression(const AstNode* expr,
448451 // ---- IS NULL / IS NOT NULL (never return NULL) ----
449452
450453 case NodeType::NODE_IS_NULL: {
451- Value child = evaluate_expression<D>(expr->first_child , resolve, functions, arena);
454+ Value child = evaluate_expression<D>(expr->first_child , resolve, functions, arena, subquery_exec );
452455 return value_bool (child.is_null ());
453456 }
454457
455458 case NodeType::NODE_IS_NOT_NULL: {
456- Value child = evaluate_expression<D>(expr->first_child , resolve, functions, arena);
459+ Value child = evaluate_expression<D>(expr->first_child , resolve, functions, arena, subquery_exec );
457460 return value_bool (!child.is_null ());
458461 }
459462
@@ -465,9 +468,9 @@ Value evaluate_expression(const AstNode* expr,
465468 const AstNode* high_node = detail::nth_child (expr, 2 );
466469 if (!expr_node || !low_node || !high_node) return value_null ();
467470
468- Value val = evaluate_expression<D>(expr_node, resolve, functions, arena);
471+ Value val = evaluate_expression<D>(expr_node, resolve, functions, arena, subquery_exec );
469472 Value low = evaluate_expression<D>(low_node, resolve, functions, arena);
470- Value high = evaluate_expression<D>(high_node, resolve, functions, arena);
473+ Value high = evaluate_expression<D>(high_node, resolve, functions, arena, subquery_exec );
471474
472475 // NULL propagation
473476 if (val.is_null () || low.is_null () || high.is_null ()) return value_null ();
@@ -515,14 +518,25 @@ Value evaluate_expression(const AstNode* expr,
515518 const AstNode* expr_node = expr->first_child ;
516519 if (!expr_node) return value_null ();
517520
518- Value val = evaluate_expression<D>(expr_node, resolve, functions, arena);
521+ Value val = evaluate_expression<D>(expr_node, resolve, functions, arena, subquery_exec );
519522 if (val.is_null ()) return value_null ();
520523
521524 bool found = false ;
522525 bool has_null = false ;
523526
527+ // Collect all IN-list values. If a child is a subquery, expand it.
528+ std::vector<Value> in_values;
524529 for (const AstNode* item = expr_node->next_sibling ; item; item = item->next_sibling ) {
525- Value item_val = evaluate_expression<D>(item, resolve, functions, arena);
530+ if (item->type == NodeType::NODE_SUBQUERY && subquery_exec && item->first_child ) {
531+ // IN (subquery) -- execute and expand
532+ std::vector<Value> set_vals = subquery_exec->execute_set (item, resolve);
533+ for (auto & sv : set_vals) in_values.push_back (sv);
534+ } else {
535+ in_values.push_back (evaluate_expression<D>(item, resolve, functions, arena, subquery_exec));
536+ }
537+ }
538+
539+ for (const auto & item_val : in_values) {
526540 if (item_val.is_null ()) {
527541 has_null = true ;
528542 continue ;
@@ -561,13 +575,13 @@ Value evaluate_expression(const AstNode* expr,
561575 // Simple CASE: children = [case_expr, when1, then1, when2, then2, ..., else?]
562576 const AstNode* case_node = expr->first_child ;
563577 if (!case_node) return value_null ();
564- Value case_val = evaluate_expression<D>(case_node, resolve, functions, arena);
578+ Value case_val = evaluate_expression<D>(case_node, resolve, functions, arena, subquery_exec );
565579
566580 const AstNode* child = case_node->next_sibling ;
567581 uint32_t remaining = count - 1 ; // excluding case_expr
568582
569583 while (child && child->next_sibling ) {
570- Value when_val = evaluate_expression<D>(child, resolve, functions, arena);
584+ Value when_val = evaluate_expression<D>(child, resolve, functions, arena, subquery_exec );
571585 const AstNode* then_node = child->next_sibling ;
572586
573587 // Compare case_val = when_val
@@ -590,7 +604,7 @@ Value evaluate_expression(const AstNode* expr,
590604 }
591605
592606 if (match) {
593- return evaluate_expression<D>(then_node, resolve, functions, arena);
607+ return evaluate_expression<D>(then_node, resolve, functions, arena, subquery_exec );
594608 }
595609
596610 child = then_node->next_sibling ;
@@ -599,7 +613,7 @@ Value evaluate_expression(const AstNode* expr,
599613
600614 // Check for ELSE (one remaining child)
601615 if (child && remaining == 1 ) {
602- return evaluate_expression<D>(child, resolve, functions, arena);
616+ return evaluate_expression<D>(child, resolve, functions, arena, subquery_exec );
603617 }
604618 return value_null ();
605619 } else {
@@ -608,7 +622,7 @@ Value evaluate_expression(const AstNode* expr,
608622 uint32_t remaining = count;
609623
610624 while (child && child->next_sibling ) {
611- Value when_val = evaluate_expression<D>(child, resolve, functions, arena);
625+ Value when_val = evaluate_expression<D>(child, resolve, functions, arena, subquery_exec );
612626 const AstNode* then_node = child->next_sibling ;
613627
614628 // Evaluate WHEN condition as boolean
@@ -621,7 +635,7 @@ Value evaluate_expression(const AstNode* expr,
621635 }
622636
623637 if (is_true) {
624- return evaluate_expression<D>(then_node, resolve, functions, arena);
638+ return evaluate_expression<D>(then_node, resolve, functions, arena, subquery_exec );
625639 }
626640
627641 child = then_node->next_sibling ;
@@ -630,7 +644,7 @@ Value evaluate_expression(const AstNode* expr,
630644
631645 // Check for ELSE (one remaining child)
632646 if (child && remaining % 2 == 1 ) {
633- return evaluate_expression<D>(child, resolve, functions, arena);
647+ return evaluate_expression<D>(child, resolve, functions, arena, subquery_exec );
634648 }
635649 return value_null ();
636650 }
@@ -651,14 +665,26 @@ Value evaluate_expression(const AstNode* expr,
651665 uint32_t i = 0 ;
652666 for (const AstNode* arg = expr->first_child ; arg && i < MAX_ARGS;
653667 arg = arg->next_sibling , ++i) {
654- new (&args[i]) Value (evaluate_expression<D>(arg, resolve, functions, arena));
668+ new (&args[i]) Value (evaluate_expression<D>(arg, resolve, functions, arena, subquery_exec ));
655669 }
656670 return entry->impl (args, static_cast <uint16_t >(i), arena);
657671 }
658672
659673 // ---- Deferred node types (return value_null) ----
660674
661- case NodeType::NODE_SUBQUERY: return value_null (); // requires full executor
675+ case NodeType::NODE_SUBQUERY: {
676+ // If the subquery has a parsed SELECT child and we have an executor, run it
677+ if (subquery_exec && expr->first_child ) {
678+ // Check if this is an EXISTS subquery (flags == 1)
679+ if (expr->flags == 1 ) {
680+ bool exists = subquery_exec->execute_exists (expr, resolve);
681+ return value_bool (exists);
682+ }
683+ // Otherwise treat as scalar subquery
684+ return subquery_exec->execute_scalar (expr, resolve);
685+ }
686+ return value_null (); // no executor or no parsed child
687+ }
662688 case NodeType::NODE_TUPLE: return value_null (); // requires row/tuple value type
663689 case NodeType::NODE_ARRAY_CONSTRUCTOR: return value_null (); // requires array value type
664690 case NodeType::NODE_ARRAY_SUBSCRIPT: return value_null (); // requires array support
0 commit comments