Skip to content

Commit 5e036c7

Browse files
committed
fix: resolve distributed subquery materialization, ORDER BY with PROJECT, derived table aggregate (#22, #23, #24)
Bug #22: DELETE/UPDATE parsers didn't set the subquery parse callback, so subqueries in DML WHERE clauses (e.g., DELETE ... WHERE id IN (SELECT ...)) were skipped instead of parsed. Enable callback on DeleteParser and UpdateParser. Re-enable 2 disabled tests. Bug #23: PlanBuilder placed Sort above Project, causing ORDER BY keys to reference wrong column ordinals after projection. Move Sort before Project in the plan tree so sorting operates on full rows. Update preprocess_aggregates and distributed planner to look through SORT when finding AGGREGATE nodes below PROJECT. Bug #24: DerivedScanOperator now deep-copies inner query result rows into the outer arena, preventing dangling pointers when inner arena memory could be invalidated before the outer executor reads results.
1 parent af4ac9b commit 5e036c7

9 files changed

Lines changed: 130 additions & 45 deletions

File tree

include/sql_engine/distributed_planner.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,11 @@ class DistributedPlanner {
129129
}
130130

131131
case PlanNodeType::PROJECT: {
132-
// Check for PROJECT -> AGGREGATE pattern (or PROJECT -> FILTER -> AGGREGATE)
132+
// Check for PROJECT -> [SORT ->] [FILTER ->] AGGREGATE pattern
133133
// For aggregate queries, we want to handle the whole thing in distribute_aggregate
134134
PlanNode* agg_child = node->left;
135+
if (agg_child && agg_child->type == PlanNodeType::SORT)
136+
agg_child = agg_child->left;
135137
if (agg_child && agg_child->type == PlanNodeType::FILTER)
136138
agg_child = agg_child->left;
137139
if (agg_child && agg_child->type == PlanNodeType::AGGREGATE) {

include/sql_engine/operators/derived_scan_op.h

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,26 @@
22
//
33
// Materializes the inner plan's result set on open(), then yields
44
// rows one at a time from the materialized buffer on next().
5+
// Deep-copies all row data into the outer arena to avoid dangling
6+
// pointers into arenas that may be reset after inner execution.
57

68
#ifndef SQL_ENGINE_OPERATORS_DERIVED_SCAN_OP_H
79
#define SQL_ENGINE_OPERATORS_DERIVED_SCAN_OP_H
810

911
#include "sql_engine/operator.h"
1012
#include "sql_engine/row.h"
13+
#include "sql_parser/arena.h"
1114
#include <vector>
15+
#include <cstring>
1216

1317
namespace sql_engine {
1418

1519
class DerivedScanOperator : public Operator {
1620
public:
17-
// Takes ownership of the inner operator. On open(), pulls all rows
18-
// from the inner operator and stores them. On next(), yields them.
19-
explicit DerivedScanOperator(Operator* inner)
20-
: inner_(inner) {}
21+
// Takes the inner operator and an arena for deep-copying row data.
22+
// The arena must outlive any result set produced by the outer query.
23+
DerivedScanOperator(Operator* inner, sql_parser::Arena& arena)
24+
: inner_(inner), arena_(arena) {}
2125

2226
void open() override {
2327
rows_.clear();
@@ -26,7 +30,7 @@ class DerivedScanOperator : public Operator {
2630
inner_->open();
2731
Row row{};
2832
while (inner_->next(row)) {
29-
rows_.push_back(row);
33+
rows_.push_back(deep_copy_row(row));
3034
}
3135
inner_->close();
3236
}
@@ -45,8 +49,35 @@ class DerivedScanOperator : public Operator {
4549

4650
private:
4751
Operator* inner_;
52+
sql_parser::Arena& arena_;
4853
std::vector<Row> rows_;
4954
size_t cursor_ = 0;
55+
56+
Row deep_copy_row(const Row& src) {
57+
if (!src.values || src.column_count == 0) {
58+
Row result;
59+
result.values = nullptr;
60+
result.column_count = 0;
61+
return result;
62+
}
63+
uint16_t cc = src.column_count;
64+
Row result = make_row(arena_, cc);
65+
for (uint16_t i = 0; i < cc; ++i) {
66+
Value v = src.values[i];
67+
// Deep-copy string data into the arena
68+
if ((v.tag == Value::TAG_STRING ||
69+
v.tag == Value::TAG_DECIMAL ||
70+
v.tag == Value::TAG_BYTES ||
71+
v.tag == Value::TAG_JSON) &&
72+
v.str_val.ptr && v.str_val.len > 0) {
73+
char* buf = static_cast<char*>(arena_.allocate(v.str_val.len));
74+
std::memcpy(buf, v.str_val.ptr, v.str_val.len);
75+
v.str_val.ptr = buf;
76+
}
77+
result.set(i, v);
78+
}
79+
return result;
80+
}
5081
};
5182

5283
} // namespace sql_engine

include/sql_engine/plan_builder.h

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,37 @@ class PlanBuilder {
206206
current = filter;
207207
}
208208

209-
// 5. SELECT list -> Project
209+
// 5. ORDER BY -> Sort (before Project so sort keys resolve against full row)
210+
const sql_parser::AstNode* order_by = find_child(select_ast, sql_parser::NodeType::NODE_ORDER_BY_CLAUSE);
211+
if (order_by) {
212+
PlanNode* sort = make_plan_node(arena_, PlanNodeType::SORT);
213+
uint16_t cnt = count_children(order_by);
214+
sort->sort.count = cnt;
215+
216+
auto** keys = static_cast<const sql_parser::AstNode**>(
217+
arena_.allocate(sizeof(sql_parser::AstNode*) * cnt));
218+
auto* dirs = static_cast<uint8_t*>(arena_.allocate(cnt));
219+
220+
uint16_t idx = 0;
221+
for (const sql_parser::AstNode* item = order_by->first_child; item; item = item->next_sibling) {
222+
// First child is the key expression
223+
keys[idx] = item->first_child;
224+
// Check for DESC direction (second child with "DESC" value)
225+
dirs[idx] = 0; // ASC by default
226+
const sql_parser::AstNode* dir_node = find_child(item, sql_parser::NodeType::NODE_IDENTIFIER);
227+
if (dir_node) {
228+
sql_parser::StringRef dir_val = dir_node->value();
229+
if (dir_val.equals_ci("DESC", 4)) dirs[idx] = 1;
230+
}
231+
++idx;
232+
}
233+
sort->sort.keys = keys;
234+
sort->sort.directions = dirs;
235+
sort->left = current;
236+
current = sort;
237+
}
238+
239+
// 6. SELECT list -> Project
210240
const sql_parser::AstNode* item_list = find_child(select_ast, sql_parser::NodeType::NODE_SELECT_ITEM_LIST);
211241
if (item_list) {
212242
// Check if this is "SELECT *" with a single asterisk and no aliases -- skip Project for bare scan
@@ -246,43 +276,13 @@ class PlanBuilder {
246276
}
247277
}
248278

249-
// 6. DISTINCT -> Distinct
279+
// 7. DISTINCT -> Distinct
250280
if (has_distinct(select_ast)) {
251281
PlanNode* dist = make_plan_node(arena_, PlanNodeType::DISTINCT);
252282
dist->left = current;
253283
current = dist;
254284
}
255285

256-
// 7. ORDER BY -> Sort
257-
const sql_parser::AstNode* order_by = find_child(select_ast, sql_parser::NodeType::NODE_ORDER_BY_CLAUSE);
258-
if (order_by) {
259-
PlanNode* sort = make_plan_node(arena_, PlanNodeType::SORT);
260-
uint16_t cnt = count_children(order_by);
261-
sort->sort.count = cnt;
262-
263-
auto** keys = static_cast<const sql_parser::AstNode**>(
264-
arena_.allocate(sizeof(sql_parser::AstNode*) * cnt));
265-
auto* dirs = static_cast<uint8_t*>(arena_.allocate(cnt));
266-
267-
uint16_t idx = 0;
268-
for (const sql_parser::AstNode* item = order_by->first_child; item; item = item->next_sibling) {
269-
// First child is the key expression
270-
keys[idx] = item->first_child;
271-
// Check for DESC direction (second child with "DESC" value)
272-
dirs[idx] = 0; // ASC by default
273-
const sql_parser::AstNode* dir_node = find_child(item, sql_parser::NodeType::NODE_IDENTIFIER);
274-
if (dir_node) {
275-
sql_parser::StringRef dir_val = dir_node->value();
276-
if (dir_val.equals_ci("DESC", 4)) dirs[idx] = 1;
277-
}
278-
++idx;
279-
}
280-
sort->sort.keys = keys;
281-
sort->sort.directions = dirs;
282-
sort->left = current;
283-
current = sort;
284-
}
285-
286286
// 8. LIMIT -> Limit
287287
const sql_parser::AstNode* limit_clause = find_child(select_ast, sql_parser::NodeType::NODE_LIMIT_CLAUSE);
288288
if (limit_clause) {

include/sql_engine/plan_executor.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,10 @@ class PlanExecutor {
390390

391391
if (node->type != PlanNodeType::PROJECT) return;
392392

393-
// Find AGGREGATE child (possibly through a FILTER for HAVING)
393+
// Find AGGREGATE child (possibly through SORT and/or FILTER for HAVING)
394394
PlanNode* agg_node = node->left;
395+
if (agg_node && agg_node->type == PlanNodeType::SORT)
396+
agg_node = agg_node->left;
395397
if (agg_node && agg_node->type == PlanNodeType::FILTER)
396398
agg_node = agg_node->left;
397399
if (!agg_node || agg_node->type != PlanNodeType::AGGREGATE) return;
@@ -435,9 +437,17 @@ class PlanExecutor {
435437
// output matches what we need.
436438

437439
// Overwrite this PROJECT node to become a pass-through:
438-
// Change it into the node below it.
440+
// Change it into the node below it, preserving any intermediate
441+
// SORT and/or FILTER (HAVING) nodes.
439442
PlanNode* child = node->left;
440-
if (child && child->type == PlanNodeType::FILTER) {
443+
if (child && child->type == PlanNodeType::SORT) {
444+
// PROJECT -> SORT -> [FILTER ->] AGGREGATE
445+
// Replace PROJECT with SORT (preserving the sort)
446+
node->type = PlanNodeType::SORT;
447+
node->sort = child->sort;
448+
node->left = child->left;
449+
node->right = child->right;
450+
} else if (child && child->type == PlanNodeType::FILTER) {
441451
// PROJECT -> FILTER -> AGGREGATE: keep FILTER, remove PROJECT
442452
node->type = PlanNodeType::FILTER;
443453
node->filter.expr = child->filter.expr;
@@ -576,10 +586,11 @@ class PlanExecutor {
576586
}
577587

578588
Operator* build_derived_scan_op(PlanNode* node) {
579-
// Build the inner plan's operator tree and wrap in DerivedScanOperator
589+
// Build the inner plan's operator tree and wrap in DerivedScanOperator.
590+
// Pass the arena so deep-copied row data persists in the outer arena.
580591
Operator* inner = build_operator(node->derived_scan.inner_plan);
581592
if (!inner) return nullptr;
582-
auto op = std::make_unique<DerivedScanOperator>(inner);
593+
auto op = std::make_unique<DerivedScanOperator>(inner, arena_);
583594
Operator* ptr = op.get();
584595
operators_.push_back(std::move(op));
585596
return ptr;

include/sql_parser/delete_parser.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ class DeleteParser {
2222
: tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena),
2323
table_ref_parser_(tokenizer, arena, expr_parser_) {}
2424

25+
void set_subquery_callback(SubqueryParseCallback<D> cb) {
26+
expr_parser_.set_subquery_callback(cb);
27+
}
28+
2529
// Parse DELETE statement (DELETE keyword already consumed).
2630
AstNode* parse() {
2731
AstNode* root = make_node(arena_, NodeType::NODE_DELETE_STMT);

include/sql_parser/update_parser.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ class UpdateParser {
1818
: tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena),
1919
table_ref_parser_(tokenizer, arena, expr_parser_) {}
2020

21+
void set_subquery_callback(SubqueryParseCallback<D> cb) {
22+
expr_parser_.set_subquery_callback(cb);
23+
}
24+
2125
// Parse UPDATE statement (UPDATE keyword already consumed).
2226
AstNode* parse() {
2327
AstNode* root = make_node(arena_, NodeType::NODE_UPDATE_STMT);

src/sql_parser/parser.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ ParseResult Parser<D>::parse_update() {
347347
r.stmt_type = StmtType::UPDATE;
348348

349349
UpdateParser<D> update_parser(tokenizer_, arena_);
350+
update_parser.set_subquery_callback(&parse_subquery_select<D>);
350351
AstNode* ast = update_parser.parse();
351352

352353
if (ast) {
@@ -382,6 +383,7 @@ ParseResult Parser<D>::parse_delete() {
382383
r.stmt_type = StmtType::DELETE_STMT;
383384

384385
DeleteParser<D> delete_parser(tokenizer_, arena_);
386+
delete_parser.set_subquery_callback(&parse_subquery_select<D>);
385387
AstNode* ast = delete_parser.parse();
386388

387389
if (ast) {

tests/test_distributed_dml.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ TEST_F(DistributedDmlTest, InsertThenSelectVerify) {
524524

525525
// DELETE FROM users WHERE id IN (SELECT user_id FROM orders)
526526
// users on shards, orders on orders_backend
527-
TEST_F(DistributedDmlTest, DISABLED_DmlWithCrossShardSubquery) {
527+
TEST_F(DistributedDmlTest, DmlWithCrossShardSubquery) {
528528
// Insert users: id=0 -> shard0, id=1 -> shard1, id=2 -> shard2
529529
execute_distributed_dml("INSERT INTO users (id, name, age) VALUES (0, 'Alice', 25)");
530530
execute_distributed_dml("INSERT INTO users (id, name, age) VALUES (1, 'Bob', 30)");
@@ -568,7 +568,7 @@ TEST_F(DistributedDmlTest, DISABLED_DmlWithCrossShardSubquery) {
568568
}
569569

570570
// UPDATE users SET age = 99 WHERE id IN (SELECT user_id FROM orders)
571-
TEST_F(DistributedDmlTest, DISABLED_UpdateWithCrossShardSubquery) {
571+
TEST_F(DistributedDmlTest, UpdateWithCrossShardSubquery) {
572572
// Insert users
573573
execute_distributed_dml("INSERT INTO users (id, name, age) VALUES (0, 'Alice', 25)");
574574
execute_distributed_dml("INSERT INTO users (id, name, age) VALUES (1, 'Bob', 30)");

tests/test_plan_executor.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,34 @@ TEST_F(PlanExecutorTest, SelectWithLike) {
155155
std::string name(rs.rows[0].get(0).str_val.ptr, rs.rows[0].get(0).str_val.len);
156156
EXPECT_EQ(name, "Alice");
157157
}
158+
159+
// Bug #23: SELECT name FROM users ORDER BY age DESC -- sort by age, return only name
160+
TEST_F(PlanExecutorTest, OrderByColumnNotInSelectList) {
161+
auto rs = run_query("SELECT name FROM users ORDER BY age DESC");
162+
EXPECT_EQ(rs.row_count(), 5u);
163+
EXPECT_EQ(rs.column_count, 1);
164+
// Sorted by age DESC: Eve(35), Bob(30), Alice(25), Dave(22), Carol(17)
165+
std::string first(rs.rows[0].get(0).str_val.ptr, rs.rows[0].get(0).str_val.len);
166+
std::string last(rs.rows[4].get(0).str_val.ptr, rs.rows[4].get(0).str_val.len);
167+
EXPECT_EQ(first, "Eve");
168+
EXPECT_EQ(last, "Carol");
169+
}
170+
171+
// Bug #24: SELECT * FROM (SELECT COUNT(*) AS cnt FROM users) AS t
172+
TEST_F(PlanExecutorTest, DerivedTableWithAggregate) {
173+
auto rs = run_query("SELECT * FROM (SELECT COUNT(*) AS cnt FROM users) AS t");
174+
EXPECT_EQ(rs.row_count(), 1u);
175+
EXPECT_FALSE(rs.rows[0].get(0).is_null());
176+
EXPECT_EQ(rs.rows[0].get(0).int_val, 5);
177+
}
178+
179+
// Bug #24: SELECT * FROM (SELECT dept, COUNT(*) AS cnt FROM users GROUP BY dept) AS t
180+
TEST_F(PlanExecutorTest, DerivedTableWithGroupByAggregate) {
181+
auto rs = run_query("SELECT * FROM (SELECT dept, COUNT(*) AS cnt FROM users GROUP BY dept) AS t");
182+
EXPECT_EQ(rs.row_count(), 2u);
183+
int64_t total = 0;
184+
for (const auto& row : rs.rows) {
185+
total += row.get(1).int_val;
186+
}
187+
EXPECT_EQ(total, 5); // 3 Engineering + 2 Sales
188+
}

0 commit comments

Comments
 (0)