Skip to content

Commit 4d6e1d8

Browse files
committed
fix: resolve distributed aggregation, sort merge, join columns, and subquery across shards
Bug 1 - Aggregation missing COUNT column: constant folding was replacing aggregate functions (COUNT, SUM, AVG, etc.) with NULL when they had no column references (e.g., COUNT(*)). Fixed by treating aggregate functions as non-foldable in the constant folding rule. Also added output_exprs to MERGE_AGGREGATE for proper column naming.

Bug 2 - Sort+Limit returning wrong results: MySQL DECIMAL values were stored as TAG_STRING instead of TAG_DECIMAL, causing string comparison ("90000" > "110000") instead of numeric. Fixed by returning TAG_DECIMAL from the MySQL executor and adding TAG_DECIMAL handling to Value::to_double() and Value::to_int64().

Bug 3 - Cross-shard JOIN columns NULL: collect_tables() recursed into both sides of SET_OP (UNION ALL) nodes, duplicating table entries and inflating column offsets so the PROJECT resolver addressed beyond the actual row width. Fixed by only collecting tables from the left side of SET_OP nodes.

Bug 4 - Distributed subquery empty: SubqueryExecutor built and executed subquery plans without going through the distributed planner or setting up the remote executor. Fixed by adding a distribute callback to PlanExecutor that Session wires up, and propagating remote_executor to inner subquery executors. Also prevented filter-with-subquery pushdown to remote shards.
1 parent de8fcbb commit 4d6e1d8

8 files changed

Lines changed: 114 additions & 12 deletions

File tree

include/sql_engine/distributed_planner.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,10 @@ class DistributedPlanner {
116116

117117
case PlanNodeType::FILTER: {
118118
// Check if child is a SCAN -- push filter to remote
119-
if (node->left && node->left->type == PlanNodeType::SCAN) {
119+
// (but not if the filter contains a subquery, which must be
120+
// evaluated locally after distributed subquery execution)
121+
if (node->left && node->left->type == PlanNodeType::SCAN &&
122+
!has_subquery(node->filter.expr)) {
120123
return distribute_scan(node->left, node->filter.expr,
121124
nullptr, nullptr, nullptr, false);
122125
}
@@ -140,8 +143,6 @@ class DistributedPlanner {
140143
// Extract aggregate info from the PROJECT select list
141144
push_agg_exprs_from_project(node, agg_child);
142145
PlanNode* dist_agg = distribute_aggregate(agg_child);
143-
// If it was distributed (MERGE_AGGREGATE), no need for PROJECT wrapper
144-
// because the merge already produces the right columns
145146
if (dist_agg && dist_agg->type == PlanNodeType::MERGE_AGGREGATE) {
146147
// Re-add FILTER (HAVING) if present
147148
if (node->left && node->left->type == PlanNodeType::FILTER) {
@@ -583,6 +584,18 @@ class DistributedPlanner {
583584
arena_.allocate(merge_ops.size()));
584585
std::memcpy(merge->merge_aggregate.merge_ops, merge_ops.data(), merge_ops.size());
585586

587+
// Store original output expressions for column naming:
588+
// group_by expressions + aggregate expressions
589+
uint16_t out_count = agg_node->aggregate.group_count + agg_node->aggregate.agg_count;
590+
auto** out_exprs = static_cast<const sql_parser::AstNode**>(
591+
arena_.allocate(sizeof(sql_parser::AstNode*) * out_count));
592+
for (uint16_t i = 0; i < agg_node->aggregate.group_count; ++i)
593+
out_exprs[i] = agg_node->aggregate.group_by[i];
594+
for (uint16_t i = 0; i < agg_node->aggregate.agg_count; ++i)
595+
out_exprs[agg_node->aggregate.group_count + i] = agg_node->aggregate.agg_exprs[i];
596+
merge->merge_aggregate.output_exprs = out_exprs;
597+
merge->merge_aggregate.output_expr_count = out_count;
598+
586599
// Set left to first child for compatibility with tree walkers
587600
if (!children.empty()) merge->left = children[0];
588601

include/sql_engine/plan_executor.h

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "sql_engine/dml_result.h"
4949
#include "sql_engine/mutable_data_source.h"
5050
#include "sql_engine/catalog_resolver.h"
51+
#include "sql_parser/emitter.h"
5152

5253
#include <unordered_map>
5354
#include <string>
@@ -81,6 +82,11 @@ class PlanExecutor {
8182
// Access the subquery executor (for operators that need it)
8283
SubqueryExecutor<D>* subquery_executor() { return &subquery_exec_; }
8384

85+
// Set a distribute callback for subquery plans (used by Session to inject
86+
// distributed planning into subquery execution).
87+
using DistributeFn = std::function<PlanNode*(PlanNode*)>;
88+
void set_distribute_fn(DistributeFn fn) { distribute_fn_ = std::move(fn); }
89+
8490
ResultSet execute(PlanNode* plan) {
8591
if (!plan) return {};
8692

@@ -141,15 +147,22 @@ class PlanExecutor {
141147
std::unordered_map<std::string, MutableDataSource*> mutable_sources_;
142148
std::vector<std::unique_ptr<Operator>> operators_;
143149
RemoteExecutor* remote_executor_ = nullptr;
150+
DistributeFn distribute_fn_;
144151
SubqueryExecutor<D> subquery_exec_;
145152
sql_parser::Arena subquery_plan_arena_{65536, 1048576};
146153
std::function<Value(sql_parser::StringRef)> outer_resolver_;
147154

148155
void setup_subquery_executor() {
149-
// Build plan callback: uses PlanBuilder with our catalog and arena
156+
// Build plan callback: uses PlanBuilder with our catalog and arena.
157+
// If a distribute function is set (for distributed/sharded execution),
158+
// apply it to the plan so subqueries also go through the distributed planner.
150159
subquery_exec_.set_build_plan([this](const sql_parser::AstNode* ast) -> PlanNode* {
151160
PlanBuilder<D> builder(catalog_, arena_);
152-
return builder.build(ast);
161+
PlanNode* plan = builder.build(ast);
162+
if (plan && distribute_fn_) {
163+
plan = distribute_fn_(plan);
164+
}
165+
return plan;
153166
});
154167
// Execute plan callback: create a fresh executor for the subquery
155168
// to avoid interfering with the outer operator tree.
@@ -169,6 +182,8 @@ class PlanExecutor {
169182
inner_exec.add_data_source(kv.first.c_str(), kv.second);
170183
}
171184
}
185+
if (remote_executor_)
186+
inner_exec.set_remote_executor(remote_executor_);
172187
return inner_exec.execute(plan);
173188
});
174189
// Correlated execution: pass outer resolver as fallback.
@@ -192,6 +207,8 @@ class PlanExecutor {
192207
inner_exec.add_data_source(kv.first.c_str(), kv.second);
193208
}
194209
}
210+
if (remote_executor_)
211+
inner_exec.set_remote_executor(remote_executor_);
195212
inner_exec.set_outer_resolver(outer_resolve);
196213
return inner_exec.execute(plan);
197214
});
@@ -503,6 +520,13 @@ class PlanExecutor {
503520
collect_tables(node->merge_sort.children[i], tables);
504521
return;
505522
}
523+
// For SET_OP (UNION ALL), only collect from left side to avoid
524+
// duplicating table entries when the same table appears on both
525+
// sides (e.g., sharded UNION ALL of RemoteScans for the same table).
526+
if (node->type == PlanNodeType::SET_OP) {
527+
collect_tables(node->left, tables);
528+
return;
529+
}
506530
collect_tables(node->left, tables);
507531
collect_tables(node->right, tables);
508532
}
@@ -536,8 +560,10 @@ class PlanExecutor {
536560
if (node->remote_scan.table) return node->remote_scan.table->column_count;
537561
return 0;
538562
case PlanNodeType::MERGE_AGGREGATE:
539-
// group keys + output agg columns
540-
return count_columns(node->left);
563+
// Use stored output expression count if available
564+
if (node->merge_aggregate.output_expr_count > 0)
565+
return node->merge_aggregate.output_expr_count;
566+
return node->merge_aggregate.group_key_count + node->merge_aggregate.merge_op_count;
541567
case PlanNodeType::MERGE_SORT:
542568
return count_columns(node->left);
543569
case PlanNodeType::INSERT_PLAN:
@@ -970,6 +996,31 @@ class PlanExecutor {
970996
build_column_names(plan->derived_scan.inner_plan, rs);
971997
}
972998
break;
999+
case PlanNodeType::MERGE_AGGREGATE: {
1000+
// Use the stored output expressions for column naming
1001+
if (plan->merge_aggregate.output_exprs &&
1002+
plan->merge_aggregate.output_expr_count > 0) {
1003+
for (uint16_t i = 0; i < plan->merge_aggregate.output_expr_count; ++i) {
1004+
const sql_parser::AstNode* expr = plan->merge_aggregate.output_exprs[i];
1005+
if (expr) {
1006+
// Use the emitter to produce a readable name
1007+
sql_parser::Emitter<D> emitter(arena_);
1008+
emitter.emit(expr);
1009+
sql_parser::StringRef name = emitter.result();
1010+
if (name.ptr && name.len > 0) {
1011+
rs.column_names.emplace_back(name.ptr, name.len);
1012+
} else {
1013+
rs.column_names.push_back("?column?");
1014+
}
1015+
} else {
1016+
rs.column_names.push_back("?column?");
1017+
}
1018+
}
1019+
} else {
1020+
build_column_names(plan->left, rs);
1021+
}
1022+
break;
1023+
}
9731024
default:
9741025
// For wrapping operators, recurse to child
9751026
build_column_names(plan->left, rs);

include/sql_engine/plan_node.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ struct PlanNode {
113113
uint8_t* merge_ops; // parallel to agg columns
114114
uint16_t merge_op_count;
115115
uint16_t group_key_count; // number of leading group-by columns
116+
// Original output column expressions (for column naming)
117+
const sql_parser::AstNode** output_exprs;
118+
uint16_t output_expr_count;
116119
} merge_aggregate;
117120

118121
struct {

include/sql_engine/rules/constant_folding.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,17 @@ namespace rules {
2323

2424
namespace detail_cf {
2525

26-
// Check if an expression has any column references
26+
// Check if an expression is a call to an aggregate function (COUNT, SUM, AVG, MIN, MAX)
27+
inline bool is_aggregate_function(const sql_parser::AstNode* expr) {
28+
if (!expr || expr->type != sql_parser::NodeType::NODE_FUNCTION_CALL) return false;
29+
sql_parser::StringRef name = expr->value();
30+
return name.equals_ci("COUNT", 5) || name.equals_ci("SUM", 3) ||
31+
name.equals_ci("AVG", 3) || name.equals_ci("MIN", 3) ||
32+
name.equals_ci("MAX", 3);
33+
}
34+
35+
// Check if an expression has any column references or aggregate functions
36+
// (aggregate functions depend on row data and must not be folded)
2737
inline bool has_column_ref(const sql_parser::AstNode* expr) {
2838
if (!expr) return false;
2939
switch (expr->type) {
@@ -33,6 +43,10 @@ inline bool has_column_ref(const sql_parser::AstNode* expr) {
3343
case sql_parser::NodeType::NODE_IDENTIFIER:
3444
// Identifiers in expression context are column references
3545
return true;
46+
case sql_parser::NodeType::NODE_FUNCTION_CALL:
47+
// Aggregate functions depend on row data, not foldable
48+
if (is_aggregate_function(expr)) return true;
49+
break;
3650
default:
3751
break;
3852
}

include/sql_engine/session.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,16 @@ class Session {
203203
executor.add_mutable_data_source(kv.first.c_str(), kv.second);
204204
if (remote_executor_)
205205
executor.set_remote_executor(remote_executor_);
206+
// If sharding is configured, provide a distribute callback so that
207+
// subqueries also go through the distributed planner.
208+
if (shard_map_ && remote_executor_) {
209+
executor.set_distribute_fn(
210+
[this](PlanNode* plan) -> PlanNode* {
211+
DistributedPlanner<D> dp(*shard_map_, catalog_, parser_.arena(),
212+
remote_executor_, &functions_);
213+
return dp.distribute(plan);
214+
});
215+
}
206216
}
207217
};
208218

include/sql_engine/value.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "sql_engine/types.h"
55
#include "sql_parser/common.h" // StringRef
66
#include <cstdint>
7+
#include <cstdlib>
78
#include <type_traits>
89

910
namespace sql_engine {
@@ -55,6 +56,15 @@ struct Value {
5556
case TAG_INT64: return static_cast<double>(int_val);
5657
case TAG_UINT64: return static_cast<double>(uint_val);
5758
case TAG_DOUBLE: return double_val;
59+
case TAG_DECIMAL:
60+
if (str_val.ptr && str_val.len > 0) {
61+
char buf[128];
62+
uint32_t n = str_val.len < 127 ? str_val.len : 127;
63+
for (uint32_t i = 0; i < n; ++i) buf[i] = str_val.ptr[i];
64+
buf[n] = '\0';
65+
return std::strtod(buf, nullptr);
66+
}
67+
return 0.0;
5868
default: return 0.0;
5969
}
6070
}
@@ -66,6 +76,7 @@ struct Value {
6676
case TAG_INT64: return int_val;
6777
case TAG_UINT64: return static_cast<int64_t>(uint_val);
6878
case TAG_DOUBLE: return static_cast<int64_t>(double_val);
79+
case TAG_DECIMAL: return static_cast<int64_t>(to_double());
6980
default: return 0;
7081
}
7182
}

src/sql_engine/mysql_remote_executor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ Value MySQLRemoteExecutor::mysql_field_to_value(
165165

166166
case MYSQL_TYPE_DECIMAL:
167167
case MYSQL_TYPE_NEWDECIMAL: {
168-
// Store as string in arena
168+
// Store as decimal (numeric string) in arena
169169
sql_parser::StringRef s = arena_.allocate_string(data, static_cast<uint32_t>(length));
170-
return value_string(s);
170+
return value_decimal(s);
171171
}
172172

173173
case MYSQL_TYPE_DATE: {

tests/test_mysql_executor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ TEST_F(MySQLExecutorTest, DecimalType) {
112112
sql_parser::StringRef sql{q, static_cast<uint32_t>(strlen(q))};
113113
auto rs = exec_->execute("test_mysql", sql);
114114
ASSERT_EQ(rs.row_count(), 1u);
115-
// DECIMAL comes as string
116-
EXPECT_EQ(rs.rows[0].get(0).tag, sql_engine::Value::TAG_STRING);
115+
// DECIMAL comes as numeric decimal type
116+
EXPECT_EQ(rs.rows[0].get(0).tag, sql_engine::Value::TAG_DECIMAL);
117117
}
118118

119119
TEST_F(MySQLExecutorTest, NullHandling) {

0 commit comments

Comments
 (0)