66#include " sql_engine/catalog.h"
77#include " sql_engine/remote_query_builder.h"
88#include " sql_engine/operators/merge_aggregate_op.h"
9+ #include " sql_engine/expression_eval.h"
10+ #include " sql_engine/function_registry.h"
911#include " sql_parser/arena.h"
1012#include " sql_parser/ast.h"
1113#include " sql_parser/common.h"
1214#include < cstring>
1315#include < cstdio>
1416#include < vector>
17+ #include < unordered_map>
18+ #include < functional>
1519
1620namespace sql_engine {
1721
@@ -28,6 +32,24 @@ class DistributedPlanner {
2832 return distribute_node (plan);
2933 }
3034
35+ // Distribute a DML plan node for remote execution.
36+ // Returns a new plan tree with REMOTE_SCAN nodes (for DML, the remote
37+ // scan carries the DML SQL; the executor calls execute_dml on it).
38+ PlanNode* distribute_dml (PlanNode* plan) {
39+ if (!plan) return nullptr ;
40+
41+ switch (plan->type ) {
42+ case PlanNodeType::INSERT_PLAN:
43+ return distribute_insert (plan);
44+ case PlanNodeType::UPDATE_PLAN:
45+ return distribute_update (plan);
46+ case PlanNodeType::DELETE_PLAN:
47+ return distribute_delete (plan);
48+ default :
49+ return plan;
50+ }
51+ }
52+
3153private:
3254 const ShardMap& shards_;
3355 const Catalog& catalog_;
@@ -670,6 +692,298 @@ class DistributedPlanner {
670692 return result;
671693 }
672694
695+ // ---- DML distribution ----
696+
697+ PlanNode* distribute_insert (PlanNode* plan) {
698+ const auto & ip = plan->insert_plan ;
699+ const TableInfo* table = ip.table ;
700+ if (!table || !shards_.has_table (table->table_name )) return plan;
701+
702+ if (!shards_.is_sharded (table->table_name )) {
703+ // Unsharded: single remote INSERT
704+ sql_parser::StringRef sql = qb_.build_insert (
705+ table, ip.columns , ip.column_count , ip.value_rows , ip.row_count );
706+ return make_remote_scan (shards_.get_backend (table->table_name ), sql, table);
707+ }
708+
709+ // Sharded: group rows by shard key value
710+ sql_parser::StringRef shard_key = shards_.get_shard_key (table->table_name );
711+ if (!shard_key.ptr ) return plan;
712+
713+ // Find shard key column ordinal in the column list
714+ int shard_col_idx = -1 ;
715+ if (ip.columns && ip.column_count > 0 ) {
716+ for (uint16_t i = 0 ; i < ip.column_count ; ++i) {
717+ if (ip.columns [i] && ip.columns [i]->value ().equals_ci (shard_key.ptr , shard_key.len )) {
718+ shard_col_idx = static_cast <int >(i);
719+ break ;
720+ }
721+ }
722+ } else if (table) {
723+ // No explicit column list -- match by table column order
724+ for (uint16_t i = 0 ; i < table->column_count ; ++i) {
725+ if (table->columns [i].name .equals_ci (shard_key.ptr , shard_key.len )) {
726+ shard_col_idx = static_cast <int >(i);
727+ break ;
728+ }
729+ }
730+ }
731+
732+ if (shard_col_idx < 0 ) {
733+ // Can't determine shard -- send to all (scatter)
734+ // For INSERT, this is an error in practice. Fall back to first shard.
735+ sql_parser::StringRef sql = qb_.build_insert (
736+ table, ip.columns , ip.column_count , ip.value_rows , ip.row_count );
737+ return make_remote_scan (shards_.get_backend (table->table_name ), sql, table);
738+ }
739+
740+ const auto & shard_list = shards_.get_shards (table->table_name );
741+
742+ // Group rows by shard: evaluate the shard key value in each row,
743+ // hash to determine target shard
744+ // Map: shard_index -> list of row indices
745+ std::unordered_map<size_t , std::vector<uint16_t >> shard_rows;
746+ auto null_resolve = [](sql_parser::StringRef) -> Value { return value_null (); };
747+
748+ for (uint16_t ri = 0 ; ri < ip.row_count ; ++ri) {
749+ const sql_parser::AstNode* row_ast = ip.value_rows [ri];
750+ if (!row_ast) continue ;
751+
752+ // Get the shard key value expression (nth child of the row)
753+ const sql_parser::AstNode* expr = row_ast->first_child ;
754+ for (int j = 0 ; j < shard_col_idx && expr; ++j) {
755+ expr = expr->next_sibling ;
756+ }
757+
758+ // Evaluate to get the value, then hash to determine shard
759+ size_t shard_idx = 0 ;
760+ if (expr) {
761+ // Simple hashing: convert to int64 and mod by shard count
762+ Value v = evaluate_shard_key_value (expr);
763+ if (v.tag == Value::TAG_INT64) {
764+ shard_idx = static_cast <size_t >(
765+ std::abs (v.int_val ) % static_cast <int64_t >(shard_list.size ()));
766+ } else if (v.tag == Value::TAG_STRING && v.str_val .ptr ) {
767+ // Simple string hash
768+ uint64_t h = 0 ;
769+ for (uint32_t k = 0 ; k < v.str_val .len ; ++k) {
770+ h = h * 31 + static_cast <uint8_t >(v.str_val .ptr [k]);
771+ }
772+ shard_idx = static_cast <size_t >(h % shard_list.size ());
773+ }
774+ }
775+ shard_rows[shard_idx].push_back (ri);
776+ }
777+
778+ // Generate per-shard INSERT SQL
779+ if (shard_rows.size () == 1 ) {
780+ auto it = shard_rows.begin ();
781+ // If all rows go to one shard, send the original INSERT
782+ if (it->second .size () == ip.row_count ) {
783+ sql_parser::StringRef sql = qb_.build_insert (
784+ table, ip.columns , ip.column_count , ip.value_rows , ip.row_count );
785+ return make_remote_scan (shard_list[it->first ].backend_name .c_str (), sql, table);
786+ }
787+ }
788+
789+ // Build per-shard INSERT nodes, combine with UNION ALL (for plan structure)
790+ PlanNode* current = nullptr ;
791+ for (auto & [shard_idx, row_indices] : shard_rows) {
792+ // Build a subset value_rows array
793+ uint16_t sub_count = static_cast <uint16_t >(row_indices.size ());
794+ auto ** sub_rows = static_cast <const sql_parser::AstNode**>(
795+ arena_.allocate (sizeof (sql_parser::AstNode*) * sub_count));
796+ for (uint16_t i = 0 ; i < sub_count; ++i) {
797+ sub_rows[i] = ip.value_rows [row_indices[i]];
798+ }
799+
800+ sql_parser::StringRef sql = qb_.build_insert (
801+ table, ip.columns , ip.column_count , sub_rows, sub_count);
802+ PlanNode* rs = make_remote_scan (shard_list[shard_idx].backend_name .c_str (), sql, table);
803+
804+ if (!current) {
805+ current = rs;
806+ } else {
807+ PlanNode* union_node = make_plan_node (arena_, PlanNodeType::SET_OP);
808+ union_node->set_op .op = SET_OP_UNION;
809+ union_node->set_op .all = true ;
810+ union_node->left = current;
811+ union_node->right = rs;
812+ current = union_node;
813+ }
814+ }
815+
816+ return current ? current : plan;
817+ }
818+
819+ PlanNode* distribute_update (PlanNode* plan) {
820+ const auto & up = plan->update_plan ;
821+ const TableInfo* table = up.table ;
822+ if (!table || !shards_.has_table (table->table_name )) return plan;
823+
824+ if (!shards_.is_sharded (table->table_name )) {
825+ // Unsharded: single remote UPDATE
826+ sql_parser::StringRef sql = qb_.build_update (
827+ table, up.set_columns , up.set_exprs , up.set_count , up.where_expr );
828+ return make_remote_scan (shards_.get_backend (table->table_name ), sql, table);
829+ }
830+
831+ // Sharded: check if WHERE references the shard key
832+ sql_parser::StringRef shard_key = shards_.get_shard_key (table->table_name );
833+ const auto & shard_list = shards_.get_shards (table->table_name );
834+
835+ int target_shard = find_shard_from_where (up.where_expr , shard_key, shard_list.size ());
836+
837+ if (target_shard >= 0 ) {
838+ // Route to specific shard
839+ sql_parser::StringRef sql = qb_.build_update (
840+ table, up.set_columns , up.set_exprs , up.set_count , up.where_expr );
841+ return make_remote_scan (shard_list[target_shard].backend_name .c_str (), sql, table);
842+ }
843+
844+ // Scatter to all shards
845+ return scatter_dml_to_shards (table, shard_list, [&]() {
846+ return qb_.build_update (
847+ table, up.set_columns , up.set_exprs , up.set_count , up.where_expr );
848+ });
849+ }
850+
851+ PlanNode* distribute_delete (PlanNode* plan) {
852+ const auto & dp = plan->delete_plan ;
853+ const TableInfo* table = dp.table ;
854+ if (!table || !shards_.has_table (table->table_name )) return plan;
855+
856+ if (!shards_.is_sharded (table->table_name )) {
857+ // Unsharded: single remote DELETE
858+ sql_parser::StringRef sql = qb_.build_delete (table, dp.where_expr );
859+ return make_remote_scan (shards_.get_backend (table->table_name ), sql, table);
860+ }
861+
862+ // Sharded: check if WHERE references the shard key
863+ sql_parser::StringRef shard_key = shards_.get_shard_key (table->table_name );
864+ const auto & shard_list = shards_.get_shards (table->table_name );
865+
866+ int target_shard = find_shard_from_where (dp.where_expr , shard_key, shard_list.size ());
867+
868+ if (target_shard >= 0 ) {
869+ // Route to specific shard
870+ sql_parser::StringRef sql = qb_.build_delete (table, dp.where_expr );
871+ return make_remote_scan (shard_list[target_shard].backend_name .c_str (), sql, table);
872+ }
873+
874+ // Scatter to all shards
875+ return scatter_dml_to_shards (table, shard_list, [&]() {
876+ return qb_.build_delete (table, dp.where_expr );
877+ });
878+ }
879+
880+ // Evaluate a shard key expression from a VALUES row (simple: literal values only)
881+ Value evaluate_shard_key_value (const sql_parser::AstNode* expr) {
882+ if (!expr) return value_null ();
883+ if (expr->type == sql_parser::NodeType::NODE_LITERAL_INT) {
884+ sql_parser::StringRef val = expr->value ();
885+ int64_t n = 0 ;
886+ for (uint32_t i = 0 ; i < val.len ; ++i) {
887+ char c = val.ptr [i];
888+ if (c >= ' 0' && c <= ' 9' ) n = n * 10 + (c - ' 0' );
889+ }
890+ return value_int (n);
891+ }
892+ if (expr->type == sql_parser::NodeType::NODE_LITERAL_STRING) {
893+ return value_string (expr->value ());
894+ }
895+ return value_null ();
896+ }
897+
898+ // Check if a WHERE expression contains shard_key = <literal>.
899+ // Returns the target shard index, or -1 if not determinable.
900+ int find_shard_from_where (const sql_parser::AstNode* where_expr,
901+ sql_parser::StringRef shard_key,
902+ size_t shard_count) {
903+ if (!where_expr || !shard_key.ptr || shard_count == 0 ) return -1 ;
904+
905+ // Look for binary_op '=' with one side being the shard key column
906+ if (where_expr->type == sql_parser::NodeType::NODE_BINARY_OP) {
907+ sql_parser::StringRef op = where_expr->value ();
908+ if (op.len == 1 && op.ptr [0 ] == ' =' ) {
909+ const sql_parser::AstNode* left = where_expr->first_child ;
910+ const sql_parser::AstNode* right = left ? left->next_sibling : nullptr ;
911+ if (!left || !right) return -1 ;
912+
913+ // Check if left is the shard key column and right is a literal (or vice versa)
914+ const sql_parser::AstNode* col_node = nullptr ;
915+ const sql_parser::AstNode* val_node = nullptr ;
916+
917+ if (is_column_ref (left, shard_key)) {
918+ col_node = left;
919+ val_node = right;
920+ } else if (is_column_ref (right, shard_key)) {
921+ col_node = right;
922+ val_node = left;
923+ }
924+
925+ if (col_node && val_node) {
926+ Value v = evaluate_shard_key_value (val_node);
927+ if (v.tag == Value::TAG_INT64) {
928+ return static_cast <int >(
929+ std::abs (v.int_val ) % static_cast <int64_t >(shard_count));
930+ }
931+ if (v.tag == Value::TAG_STRING && v.str_val .ptr ) {
932+ uint64_t h = 0 ;
933+ for (uint32_t k = 0 ; k < v.str_val .len ; ++k) {
934+ h = h * 31 + static_cast <uint8_t >(v.str_val .ptr [k]);
935+ }
936+ return static_cast <int >(h % shard_count);
937+ }
938+ }
939+ }
940+
941+ // Check AND: both sides might contain the shard key
942+ if (op.equals_ci (" AND" , 3 )) {
943+ const sql_parser::AstNode* left = where_expr->first_child ;
944+ const sql_parser::AstNode* right = left ? left->next_sibling : nullptr ;
945+ int r = find_shard_from_where (left, shard_key, shard_count);
946+ if (r >= 0 ) return r;
947+ return find_shard_from_where (right, shard_key, shard_count);
948+ }
949+ }
950+
951+ return -1 ;
952+ }
953+
954+ bool is_column_ref (const sql_parser::AstNode* node, sql_parser::StringRef col_name) {
955+ if (!node) return false ;
956+ if (node->type == sql_parser::NodeType::NODE_COLUMN_REF ||
957+ node->type == sql_parser::NodeType::NODE_IDENTIFIER) {
958+ return node->value ().equals_ci (col_name.ptr , col_name.len );
959+ }
960+ return false ;
961+ }
962+
963+ // Scatter DML SQL to all shards, combining results via UNION ALL
964+ PlanNode* scatter_dml_to_shards (const TableInfo* table,
965+ const std::vector<ShardInfo>& shard_list,
966+ std::function<sql_parser::StringRef()> build_sql) {
967+ if (shard_list.empty ()) return nullptr ;
968+
969+ PlanNode* current = nullptr ;
970+ for (const auto & shard : shard_list) {
971+ sql_parser::StringRef sql = build_sql ();
972+ PlanNode* rs = make_remote_scan (shard.backend_name .c_str (), sql, table);
973+ if (!current) {
974+ current = rs;
975+ } else {
976+ PlanNode* union_node = make_plan_node (arena_, PlanNodeType::SET_OP);
977+ union_node->set_op .op = SET_OP_UNION;
978+ union_node->set_op .all = true ;
979+ union_node->left = current;
980+ union_node->right = rs;
981+ current = union_node;
982+ }
983+ }
984+ return current;
985+ }
986+
673987 // Case 6: Distributed DISTINCT
674988 PlanNode* distribute_distinct (PlanNode* distinct_node) {
675989 // Check child for sharded scan
0 commit comments