Skip to content

Commit 648f01e

Browse files
committed
feat: add DML execution — INSERT/UPDATE/DELETE local + distributed with shard routing
1 parent 2466c3f commit 648f01e

11 files changed

Lines changed: 1767 additions & 1 deletion

Makefile.new

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ TEST_SRCS = $(TEST_DIR)/test_main.cpp \
5757
$(TEST_DIR)/test_operators.cpp \
5858
$(TEST_DIR)/test_plan_executor.cpp \
5959
$(TEST_DIR)/test_optimizer.cpp \
60-
$(TEST_DIR)/test_distributed_planner.cpp
60+
$(TEST_DIR)/test_distributed_planner.cpp \
61+
$(TEST_DIR)/test_dml.cpp \
62+
$(TEST_DIR)/test_distributed_dml.cpp
6163
TEST_OBJS = $(TEST_SRCS:.cpp=.o)
6264
TEST_TARGET = $(PROJECT_ROOT)/run_tests
6365

include/sql_engine/distributed_planner.h

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
#include "sql_engine/catalog.h"
77
#include "sql_engine/remote_query_builder.h"
88
#include "sql_engine/operators/merge_aggregate_op.h"
9+
#include "sql_engine/expression_eval.h"
10+
#include "sql_engine/function_registry.h"
911
#include "sql_parser/arena.h"
1012
#include "sql_parser/ast.h"
1113
#include "sql_parser/common.h"
1214
#include <cstring>
1315
#include <cstdio>
1416
#include <vector>
17+
#include <unordered_map>
18+
#include <functional>
1519

1620
namespace sql_engine {
1721

@@ -28,6 +32,24 @@ class DistributedPlanner {
2832
return distribute_node(plan);
2933
}
3034

35+
// Distribute a DML plan node for remote execution.
36+
// Returns a new plan tree with REMOTE_SCAN nodes (for DML, the remote
37+
// scan carries the DML SQL; the executor calls execute_dml on it).
38+
PlanNode* distribute_dml(PlanNode* plan) {
39+
if (!plan) return nullptr;
40+
41+
switch (plan->type) {
42+
case PlanNodeType::INSERT_PLAN:
43+
return distribute_insert(plan);
44+
case PlanNodeType::UPDATE_PLAN:
45+
return distribute_update(plan);
46+
case PlanNodeType::DELETE_PLAN:
47+
return distribute_delete(plan);
48+
default:
49+
return plan;
50+
}
51+
}
52+
3153
private:
3254
const ShardMap& shards_;
3355
const Catalog& catalog_;
@@ -670,6 +692,298 @@ class DistributedPlanner {
670692
return result;
671693
}
672694

695+
// ---- DML distribution ----
696+
697+
PlanNode* distribute_insert(PlanNode* plan) {
698+
const auto& ip = plan->insert_plan;
699+
const TableInfo* table = ip.table;
700+
if (!table || !shards_.has_table(table->table_name)) return plan;
701+
702+
if (!shards_.is_sharded(table->table_name)) {
703+
// Unsharded: single remote INSERT
704+
sql_parser::StringRef sql = qb_.build_insert(
705+
table, ip.columns, ip.column_count, ip.value_rows, ip.row_count);
706+
return make_remote_scan(shards_.get_backend(table->table_name), sql, table);
707+
}
708+
709+
// Sharded: group rows by shard key value
710+
sql_parser::StringRef shard_key = shards_.get_shard_key(table->table_name);
711+
if (!shard_key.ptr) return plan;
712+
713+
// Find shard key column ordinal in the column list
714+
int shard_col_idx = -1;
715+
if (ip.columns && ip.column_count > 0) {
716+
for (uint16_t i = 0; i < ip.column_count; ++i) {
717+
if (ip.columns[i] && ip.columns[i]->value().equals_ci(shard_key.ptr, shard_key.len)) {
718+
shard_col_idx = static_cast<int>(i);
719+
break;
720+
}
721+
}
722+
} else if (table) {
723+
// No explicit column list -- match by table column order
724+
for (uint16_t i = 0; i < table->column_count; ++i) {
725+
if (table->columns[i].name.equals_ci(shard_key.ptr, shard_key.len)) {
726+
shard_col_idx = static_cast<int>(i);
727+
break;
728+
}
729+
}
730+
}
731+
732+
if (shard_col_idx < 0) {
733+
// Can't determine shard -- send to all (scatter)
734+
// For INSERT, this is an error in practice. Fall back to first shard.
735+
sql_parser::StringRef sql = qb_.build_insert(
736+
table, ip.columns, ip.column_count, ip.value_rows, ip.row_count);
737+
return make_remote_scan(shards_.get_backend(table->table_name), sql, table);
738+
}
739+
740+
const auto& shard_list = shards_.get_shards(table->table_name);
741+
742+
// Group rows by shard: evaluate the shard key value in each row,
743+
// hash to determine target shard
744+
// Map: shard_index -> list of row indices
745+
std::unordered_map<size_t, std::vector<uint16_t>> shard_rows;
746+
auto null_resolve = [](sql_parser::StringRef) -> Value { return value_null(); };
747+
748+
for (uint16_t ri = 0; ri < ip.row_count; ++ri) {
749+
const sql_parser::AstNode* row_ast = ip.value_rows[ri];
750+
if (!row_ast) continue;
751+
752+
// Get the shard key value expression (nth child of the row)
753+
const sql_parser::AstNode* expr = row_ast->first_child;
754+
for (int j = 0; j < shard_col_idx && expr; ++j) {
755+
expr = expr->next_sibling;
756+
}
757+
758+
// Evaluate to get the value, then hash to determine shard
759+
size_t shard_idx = 0;
760+
if (expr) {
761+
// Simple hashing: convert to int64 and mod by shard count
762+
Value v = evaluate_shard_key_value(expr);
763+
if (v.tag == Value::TAG_INT64) {
764+
shard_idx = static_cast<size_t>(
765+
std::abs(v.int_val) % static_cast<int64_t>(shard_list.size()));
766+
} else if (v.tag == Value::TAG_STRING && v.str_val.ptr) {
767+
// Simple string hash
768+
uint64_t h = 0;
769+
for (uint32_t k = 0; k < v.str_val.len; ++k) {
770+
h = h * 31 + static_cast<uint8_t>(v.str_val.ptr[k]);
771+
}
772+
shard_idx = static_cast<size_t>(h % shard_list.size());
773+
}
774+
}
775+
shard_rows[shard_idx].push_back(ri);
776+
}
777+
778+
// Generate per-shard INSERT SQL
779+
if (shard_rows.size() == 1) {
780+
auto it = shard_rows.begin();
781+
// If all rows go to one shard, send the original INSERT
782+
if (it->second.size() == ip.row_count) {
783+
sql_parser::StringRef sql = qb_.build_insert(
784+
table, ip.columns, ip.column_count, ip.value_rows, ip.row_count);
785+
return make_remote_scan(shard_list[it->first].backend_name.c_str(), sql, table);
786+
}
787+
}
788+
789+
// Build per-shard INSERT nodes, combine with UNION ALL (for plan structure)
790+
PlanNode* current = nullptr;
791+
for (auto& [shard_idx, row_indices] : shard_rows) {
792+
// Build a subset value_rows array
793+
uint16_t sub_count = static_cast<uint16_t>(row_indices.size());
794+
auto** sub_rows = static_cast<const sql_parser::AstNode**>(
795+
arena_.allocate(sizeof(sql_parser::AstNode*) * sub_count));
796+
for (uint16_t i = 0; i < sub_count; ++i) {
797+
sub_rows[i] = ip.value_rows[row_indices[i]];
798+
}
799+
800+
sql_parser::StringRef sql = qb_.build_insert(
801+
table, ip.columns, ip.column_count, sub_rows, sub_count);
802+
PlanNode* rs = make_remote_scan(shard_list[shard_idx].backend_name.c_str(), sql, table);
803+
804+
if (!current) {
805+
current = rs;
806+
} else {
807+
PlanNode* union_node = make_plan_node(arena_, PlanNodeType::SET_OP);
808+
union_node->set_op.op = SET_OP_UNION;
809+
union_node->set_op.all = true;
810+
union_node->left = current;
811+
union_node->right = rs;
812+
current = union_node;
813+
}
814+
}
815+
816+
return current ? current : plan;
817+
}
818+
819+
// Distribute an UPDATE. Unsharded tables get a single remote statement;
// sharded tables are routed to one shard when the WHERE clause pins the
// shard key, otherwise the statement is scattered to every shard.
PlanNode* distribute_update(PlanNode* plan) {
    const auto& up = plan->update_plan;
    const TableInfo* table = up.table;
    if (!table || !shards_.has_table(table->table_name)) return plan;

    // The UPDATE SQL is identical for every target backend.
    auto render = [&]() {
        return qb_.build_update(
            table, up.set_columns, up.set_exprs, up.set_count, up.where_expr);
    };

    if (!shards_.is_sharded(table->table_name)) {
        // Unsharded: one remote UPDATE against the table's backend.
        return make_remote_scan(shards_.get_backend(table->table_name), render(), table);
    }

    // Sharded: try to pin a single shard from the WHERE clause.
    sql_parser::StringRef shard_key = shards_.get_shard_key(table->table_name);
    const auto& shard_list = shards_.get_shards(table->table_name);
    int target_shard = find_shard_from_where(up.where_expr, shard_key, shard_list.size());

    if (target_shard >= 0) {
        // WHERE fixes the shard key -- route to exactly one shard.
        return make_remote_scan(shard_list[target_shard].backend_name.c_str(), render(), table);
    }

    // Shard not determinable: scatter the UPDATE to all shards.
    return scatter_dml_to_shards(table, shard_list, render);
}
850+
851+
// Distribute a DELETE. Unsharded tables get a single remote statement;
// sharded tables are routed to one shard when the WHERE clause pins the
// shard key, otherwise the statement is scattered to every shard.
PlanNode* distribute_delete(PlanNode* plan) {
    const auto& dp = plan->delete_plan;
    const TableInfo* table = dp.table;
    if (!table || !shards_.has_table(table->table_name)) return plan;

    // The DELETE SQL is identical for every target backend.
    auto render = [&]() { return qb_.build_delete(table, dp.where_expr); };

    if (!shards_.is_sharded(table->table_name)) {
        // Unsharded: one remote DELETE against the table's backend.
        return make_remote_scan(shards_.get_backend(table->table_name), render(), table);
    }

    // Sharded: try to pin a single shard from the WHERE clause.
    sql_parser::StringRef shard_key = shards_.get_shard_key(table->table_name);
    const auto& shard_list = shards_.get_shards(table->table_name);
    int target_shard = find_shard_from_where(dp.where_expr, shard_key, shard_list.size());

    if (target_shard >= 0) {
        // WHERE fixes the shard key -- route to exactly one shard.
        return make_remote_scan(shard_list[target_shard].backend_name.c_str(), render(), table);
    }

    // Shard not determinable: scatter the DELETE to all shards.
    return scatter_dml_to_shards(table, shard_list, render);
}
879+
880+
// Evaluate a shard key expression from a VALUES row (simple: literal values only)
881+
Value evaluate_shard_key_value(const sql_parser::AstNode* expr) {
882+
if (!expr) return value_null();
883+
if (expr->type == sql_parser::NodeType::NODE_LITERAL_INT) {
884+
sql_parser::StringRef val = expr->value();
885+
int64_t n = 0;
886+
for (uint32_t i = 0; i < val.len; ++i) {
887+
char c = val.ptr[i];
888+
if (c >= '0' && c <= '9') n = n * 10 + (c - '0');
889+
}
890+
return value_int(n);
891+
}
892+
if (expr->type == sql_parser::NodeType::NODE_LITERAL_STRING) {
893+
return value_string(expr->value());
894+
}
895+
return value_null();
896+
}
897+
898+
// Check if a WHERE expression contains shard_key = <literal>.
899+
// Returns the target shard index, or -1 if not determinable.
900+
int find_shard_from_where(const sql_parser::AstNode* where_expr,
901+
sql_parser::StringRef shard_key,
902+
size_t shard_count) {
903+
if (!where_expr || !shard_key.ptr || shard_count == 0) return -1;
904+
905+
// Look for binary_op '=' with one side being the shard key column
906+
if (where_expr->type == sql_parser::NodeType::NODE_BINARY_OP) {
907+
sql_parser::StringRef op = where_expr->value();
908+
if (op.len == 1 && op.ptr[0] == '=') {
909+
const sql_parser::AstNode* left = where_expr->first_child;
910+
const sql_parser::AstNode* right = left ? left->next_sibling : nullptr;
911+
if (!left || !right) return -1;
912+
913+
// Check if left is the shard key column and right is a literal (or vice versa)
914+
const sql_parser::AstNode* col_node = nullptr;
915+
const sql_parser::AstNode* val_node = nullptr;
916+
917+
if (is_column_ref(left, shard_key)) {
918+
col_node = left;
919+
val_node = right;
920+
} else if (is_column_ref(right, shard_key)) {
921+
col_node = right;
922+
val_node = left;
923+
}
924+
925+
if (col_node && val_node) {
926+
Value v = evaluate_shard_key_value(val_node);
927+
if (v.tag == Value::TAG_INT64) {
928+
return static_cast<int>(
929+
std::abs(v.int_val) % static_cast<int64_t>(shard_count));
930+
}
931+
if (v.tag == Value::TAG_STRING && v.str_val.ptr) {
932+
uint64_t h = 0;
933+
for (uint32_t k = 0; k < v.str_val.len; ++k) {
934+
h = h * 31 + static_cast<uint8_t>(v.str_val.ptr[k]);
935+
}
936+
return static_cast<int>(h % shard_count);
937+
}
938+
}
939+
}
940+
941+
// Check AND: both sides might contain the shard key
942+
if (op.equals_ci("AND", 3)) {
943+
const sql_parser::AstNode* left = where_expr->first_child;
944+
const sql_parser::AstNode* right = left ? left->next_sibling : nullptr;
945+
int r = find_shard_from_where(left, shard_key, shard_count);
946+
if (r >= 0) return r;
947+
return find_shard_from_where(right, shard_key, shard_count);
948+
}
949+
}
950+
951+
return -1;
952+
}
953+
954+
// True when `node` is a column reference or bare identifier whose name
// matches `col_name` case-insensitively; false for any other node kind.
bool is_column_ref(const sql_parser::AstNode* node, sql_parser::StringRef col_name) {
    if (node == nullptr) return false;
    const bool is_name_node =
        node->type == sql_parser::NodeType::NODE_COLUMN_REF ||
        node->type == sql_parser::NodeType::NODE_IDENTIFIER;
    return is_name_node && node->value().equals_ci(col_name.ptr, col_name.len);
}
962+
963+
// Scatter DML SQL to all shards, combining results via UNION ALL
964+
PlanNode* scatter_dml_to_shards(const TableInfo* table,
965+
const std::vector<ShardInfo>& shard_list,
966+
std::function<sql_parser::StringRef()> build_sql) {
967+
if (shard_list.empty()) return nullptr;
968+
969+
PlanNode* current = nullptr;
970+
for (const auto& shard : shard_list) {
971+
sql_parser::StringRef sql = build_sql();
972+
PlanNode* rs = make_remote_scan(shard.backend_name.c_str(), sql, table);
973+
if (!current) {
974+
current = rs;
975+
} else {
976+
PlanNode* union_node = make_plan_node(arena_, PlanNodeType::SET_OP);
977+
union_node->set_op.op = SET_OP_UNION;
978+
union_node->set_op.all = true;
979+
union_node->left = current;
980+
union_node->right = rs;
981+
current = union_node;
982+
}
983+
}
984+
return current;
985+
}
986+
673987
// Case 6: Distributed DISTINCT
674988
PlanNode* distribute_distinct(PlanNode* distinct_node) {
675989
// Check child for sharded scan

0 commit comments

Comments
 (0)