Skip to content

Commit 6be0b04

Browse files
committed
perf: add parallel remote execution, shard-key routing, hash join operator (#26, #27, #28)
1 parent 5e036c7 commit 6be0b04

8 files changed

Lines changed: 765 additions & 37 deletions

File tree

include/sql_engine/distributed_planner.h

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,167 @@ class DistributedPlanner {
276276
}
277277

278278
// Case 2: Sharded -- N RemoteScans + UNION ALL
279-
const auto& shard_list = shards_.get_shards(table->table_name);
279+
// Optimization (#27): if WHERE contains shard_key = <literal> or
280+
// shard_key IN (<literals>), route to only the relevant shard(s).
281+
const auto& full_shard_list = shards_.get_shards(table->table_name);
282+
std::vector<ShardInfo> pruned = prune_shards(table, where_expr, full_shard_list);
280283
return make_sharded_union(table, where_expr, nullptr, 0, nullptr, 0,
281-
nullptr, nullptr, 0, -1, false, shard_list);
284+
nullptr, nullptr, 0, -1, false, pruned);
285+
}
286+
287+
// Shard pruning (#27): analyze WHERE for shard_key = <literal> or
288+
// shard_key IN (<literal_list>). Returns a subset of shards when pruning
289+
// is possible, otherwise returns the full shard list.
290+
std::vector<ShardInfo> prune_shards(const TableInfo* table,
291+
const sql_parser::AstNode* where_expr,
292+
const std::vector<ShardInfo>& all_shards) {
293+
if (!where_expr || all_shards.empty()) return all_shards;
294+
295+
sql_parser::StringRef shard_key = shards_.get_shard_key(table->table_name);
296+
if (!shard_key.ptr || shard_key.len == 0) return all_shards;
297+
298+
// Try to extract shard key literal values from WHERE expression
299+
std::vector<size_t> target_indices;
300+
extract_shard_targets(where_expr, shard_key, table->table_name,
301+
all_shards.size(), target_indices);
302+
303+
if (target_indices.empty()) return all_shards;
304+
305+
// Deduplicate and collect matching shards
306+
std::vector<bool> included(all_shards.size(), false);
307+
for (size_t idx : target_indices) {
308+
if (idx < all_shards.size()) included[idx] = true;
309+
}
310+
std::vector<ShardInfo> result;
311+
for (size_t i = 0; i < all_shards.size(); ++i) {
312+
if (included[i]) result.push_back(all_shards[i]);
313+
}
314+
return result.empty() ? all_shards : result;
315+
}
316+
317+
// Walk a WHERE expression looking for shard_key = <literal> or
318+
// shard_key IN (<literal>, ...). Populates target_indices with
319+
// the shard index for each matched literal.
320+
void extract_shard_targets(const sql_parser::AstNode* expr,
321+
sql_parser::StringRef shard_key,
322+
sql_parser::StringRef table_name,
323+
size_t num_shards,
324+
std::vector<size_t>& target_indices) {
325+
if (!expr) return;
326+
327+
// Check for shard_key = <literal>
328+
if (expr->type == sql_parser::NodeType::NODE_BINARY_OP) {
329+
sql_parser::StringRef op = expr->value();
330+
if (op.len == 1 && op.ptr[0] == '=') {
331+
const sql_parser::AstNode* left_node = expr->first_child;
332+
const sql_parser::AstNode* right_node = left_node ? left_node->next_sibling : nullptr;
333+
if (left_node && right_node) {
334+
// Check if one side is the shard key column and the other is a literal
335+
const sql_parser::AstNode* col_node = nullptr;
336+
const sql_parser::AstNode* lit_node = nullptr;
337+
if (is_shard_key_ref(left_node, shard_key) && is_literal(right_node)) {
338+
col_node = left_node; lit_node = right_node;
339+
} else if (is_shard_key_ref(right_node, shard_key) && is_literal(left_node)) {
340+
col_node = right_node; lit_node = left_node;
341+
}
342+
if (col_node && lit_node) {
343+
size_t idx = literal_to_shard_index(lit_node, table_name, num_shards);
344+
target_indices.push_back(idx);
345+
return;
346+
}
347+
}
348+
}
349+
// Recurse into AND branches
350+
if (op.len == 3 &&
351+
(op.ptr[0] == 'A' || op.ptr[0] == 'a') &&
352+
(op.ptr[1] == 'N' || op.ptr[1] == 'n') &&
353+
(op.ptr[2] == 'D' || op.ptr[2] == 'd')) {
354+
const sql_parser::AstNode* left_node = expr->first_child;
355+
const sql_parser::AstNode* right_node = left_node ? left_node->next_sibling : nullptr;
356+
// For AND, either branch matching is sufficient (both must be true,
357+
// so if one constrains the shard key, we can prune).
358+
std::vector<size_t> left_targets, right_targets;
359+
extract_shard_targets(left_node, shard_key, table_name, num_shards, left_targets);
360+
extract_shard_targets(right_node, shard_key, table_name, num_shards, right_targets);
361+
// Use whichever branch found shard targets (prefer the more selective one)
362+
if (!left_targets.empty() && !right_targets.empty()) {
363+
// Intersect: both constraints must hold
364+
std::vector<bool> lset(num_shards, false), rset(num_shards, false);
365+
for (auto i : left_targets) if (i < num_shards) lset[i] = true;
366+
for (auto i : right_targets) if (i < num_shards) rset[i] = true;
367+
for (size_t i = 0; i < num_shards; ++i) {
368+
if (lset[i] && rset[i]) target_indices.push_back(i);
369+
}
370+
} else if (!left_targets.empty()) {
371+
target_indices.insert(target_indices.end(), left_targets.begin(), left_targets.end());
372+
} else if (!right_targets.empty()) {
373+
target_indices.insert(target_indices.end(), right_targets.begin(), right_targets.end());
374+
}
375+
return;
376+
}
377+
}
378+
379+
// Check for shard_key IN (literal_list)
380+
if (expr->type == sql_parser::NodeType::NODE_IN_LIST) {
381+
const sql_parser::AstNode* col_expr = expr->first_child;
382+
if (col_expr && is_shard_key_ref(col_expr, shard_key)) {
383+
for (const sql_parser::AstNode* item = col_expr->next_sibling; item; item = item->next_sibling) {
384+
if (is_literal(item)) {
385+
target_indices.push_back(literal_to_shard_index(item, table_name, num_shards));
386+
} else {
387+
// Non-literal in IN list -- can't prune
388+
target_indices.clear();
389+
return;
390+
}
391+
}
392+
}
393+
}
394+
}
395+
396+
// Returns true when `node` names the shard key column.
//
// Column references and bare identifiers are compared (case-insensitively)
// by their own value. For a qualified name only the second child -- the
// column part of table.column -- is compared; the qualifier is not inspected.
bool is_shard_key_ref(const sql_parser::AstNode* node, sql_parser::StringRef shard_key) const {
    if (node == nullptr) return false;

    const bool is_bare_name =
        node->type == sql_parser::NodeType::NODE_COLUMN_REF ||
        node->type == sql_parser::NodeType::NODE_IDENTIFIER;
    if (is_bare_name) {
        return node->value().equals_ci(shard_key.ptr, shard_key.len);
    }

    if (node->type != sql_parser::NodeType::NODE_QUALIFIED_NAME) return false;

    // table.column -- the column is the second child.
    const sql_parser::AstNode* qualifier = node->first_child;
    if (qualifier == nullptr) return false;
    const sql_parser::AstNode* column = qualifier->next_sibling;
    if (column == nullptr) return false;
    return column->value().equals_ci(shard_key.ptr, shard_key.len);
}
411+
412+
// True for integer, float and string literal nodes; false for a null pointer
// and for every other node type.
static bool is_literal(const sql_parser::AstNode* node) {
    if (node == nullptr) return false;

    const auto kind = node->type;
    if (kind == sql_parser::NodeType::NODE_LITERAL_INT) return true;
    if (kind == sql_parser::NodeType::NODE_LITERAL_FLOAT) return true;
    return kind == sql_parser::NodeType::NODE_LITERAL_STRING;
}
418+
419+
// Maps a literal node to the index of the shard that owns that key value.
//
//   * string literals  -> shards_.shard_index_for_string(table, ptr, len)
//   * integer literals -> parsed as base-10 and passed to
//                         shards_.shard_index_for_int()
//   * float literals   -> parsed, truncated toward zero, then routed like an
//                         integer
//
// Returns 0 when `lit` is null, `num_shards` is 0, or the node is not one of
// the literal kinds above.
//
// Numeric text is parsed from a NUL-terminated local copy bounded by the
// token length, so strtoll/strtod cannot read past the end of the token
// (NOTE(review): assumes StringRef is a non-terminated {ptr, len} view --
// confirm). The float-to-integer step saturates at the int64 limits and maps
// NaN to 0, because converting an out-of-range double to an integer is
// undefined behaviour.
size_t literal_to_shard_index(const sql_parser::AstNode* lit,
                              sql_parser::StringRef table_name,
                              size_t num_shards) const {
    if (!lit || num_shards == 0) return 0;

    if (lit->type == sql_parser::NodeType::NODE_LITERAL_STRING) {
        sql_parser::StringRef sv = lit->value();
        return shards_.shard_index_for_string(table_name, sv.ptr, sv.len);
    }

    const bool is_int = lit->type == sql_parser::NodeType::NODE_LITERAL_INT;
    const bool is_float = lit->type == sql_parser::NodeType::NODE_LITERAL_FLOAT;
    if (!is_int && !is_float) return 0;

    // Bounded, NUL-terminated copy of the numeric token.
    sql_parser::StringRef sv = lit->value();
    char buf[128];
    const char* text = nullptr;
    if (sv.ptr && sv.len > 0) {
        if (sv.len < sizeof(buf)) {
            for (size_t i = 0; i < sv.len; ++i) buf[i] = sv.ptr[i];
            buf[sv.len] = '\0';
            text = buf;
        } else {
            // Token too long for the local buffer: parse in place, exactly as
            // the previous implementation did, rather than truncating it.
            text = sv.ptr;
        }
    }

    if (is_int) {
        int64_t val = 0;
        if (text) val = std::strtoll(text, nullptr, 10);
        return shards_.shard_index_for_int(table_name, val);
    }

    // Float literal: truncate toward zero with defined behaviour for every
    // input. 2^63 is exactly representable as a double, so the comparisons
    // below are exact.
    const double dv = text ? std::strtod(text, nullptr) : 0.0;
    int64_t iv = 0;
    if (dv >= 9223372036854775808.0) {
        iv = static_cast<int64_t>(9223372036854775807LL);
    } else if (dv <= -9223372036854775808.0) {
        iv = static_cast<int64_t>(-9223372036854775807LL - 1);
    } else if (dv == dv) {  // false only for NaN, which stays 0
        iv = static_cast<int64_t>(dv);
    }
    return shards_.shard_index_for_int(table_name, iv);
}
283441

284442
// Build N RemoteScans with UNION ALL
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
#ifndef SQL_ENGINE_OPERATORS_HASH_JOIN_OP_H
2+
#define SQL_ENGINE_OPERATORS_HASH_JOIN_OP_H
3+
4+
#include "sql_engine/operator.h"
5+
#include "sql_engine/expression_eval.h"
6+
#include "sql_engine/catalog.h"
7+
#include "sql_engine/plan_node.h"
8+
#include "sql_parser/arena.h"
9+
#include <functional>
10+
#include <vector>
11+
#include <unordered_map>
12+
13+
namespace sql_engine {
14+
15+
// HashJoinOperator — builds a hash table on the right (build) side keyed by
16+
// the equi-join column, then probes it with each row from the left (probe)
17+
// side. O(n + m) for equi-joins, vs O(n * m) for nested-loop.
18+
//
19+
// Supports INNER and LEFT equi-joins. Falls back to NestedLoopJoinOperator
20+
// for non-equi or CROSS joins (selection is done in PlanExecutor::build_join).
21+
template <sql_parser::Dialect D>
22+
class HashJoinOperator : public Operator {
23+
public:
24+
HashJoinOperator(Operator* left, Operator* right,
25+
uint8_t join_type,
26+
uint16_t left_join_col, // ordinal of join key in left row
27+
uint16_t right_join_col, // ordinal of join key in right row
28+
uint16_t left_cols, uint16_t right_cols,
29+
sql_parser::Arena& arena)
30+
: left_(left), right_(right), join_type_(join_type),
31+
left_join_col_(left_join_col), right_join_col_(right_join_col),
32+
left_cols_(left_cols), right_cols_(right_cols),
33+
arena_(arena) {}
34+
35+
void open() override {
36+
// Build phase: consume right side into hash table keyed by join column
37+
hash_table_.clear();
38+
right_->open();
39+
Row r{};
40+
while (right_->next(r)) {
41+
uint64_t h = hash_value(r.get(right_join_col_));
42+
hash_table_[h].push_back(copy_row(r));
43+
}
44+
right_->close();
45+
46+
// Probe side
47+
left_->open();
48+
probe_exhausted_ = false;
49+
has_probe_row_ = false;
50+
match_idx_ = 0;
51+
current_matches_ = nullptr;
52+
probe_matched_ = false;
53+
}
54+
55+
bool next(Row& out) override {
56+
uint16_t total_cols = left_cols_ + right_cols_;
57+
58+
while (true) {
59+
// Try to emit from current matches
60+
if (has_probe_row_ && current_matches_) {
61+
while (match_idx_ < current_matches_->size()) {
62+
const Row& rr = (*current_matches_)[match_idx_];
63+
match_idx_++;
64+
// Verify actual equality (not just hash match)
65+
if (values_equal(probe_row_.get(left_join_col_), rr.get(right_join_col_))) {
66+
probe_matched_ = true;
67+
out = combine_rows(probe_row_, rr, total_cols);
68+
return true;
69+
}
70+
}
71+
}
72+
73+
// Done with current probe row's matches
74+
if (has_probe_row_ && join_type_ == JOIN_LEFT && !probe_matched_) {
75+
// LEFT JOIN: emit probe row with NULLs for right side
76+
out = make_row(arena_, total_cols);
77+
for (uint16_t i = 0; i < left_cols_ && i < probe_row_.column_count; ++i)
78+
out.set(i, probe_row_.get(i));
79+
for (uint16_t i = 0; i < right_cols_; ++i)
80+
out.set(left_cols_ + i, value_null());
81+
has_probe_row_ = false;
82+
return true;
83+
}
84+
85+
// Advance to next probe row
86+
if (probe_exhausted_) return false;
87+
if (!left_->next(probe_row_)) {
88+
probe_exhausted_ = true;
89+
return false;
90+
}
91+
92+
has_probe_row_ = true;
93+
probe_matched_ = false;
94+
match_idx_ = 0;
95+
96+
uint64_t h = hash_value(probe_row_.get(left_join_col_));
97+
auto it = hash_table_.find(h);
98+
current_matches_ = (it != hash_table_.end()) ? &it->second : nullptr;
99+
}
100+
}
101+
102+
void close() override {
103+
left_->close();
104+
hash_table_.clear();
105+
}
106+
107+
private:
108+
Operator* left_;
109+
Operator* right_;
110+
uint8_t join_type_;
111+
uint16_t left_join_col_;
112+
uint16_t right_join_col_;
113+
uint16_t left_cols_;
114+
uint16_t right_cols_;
115+
sql_parser::Arena& arena_;
116+
117+
std::unordered_map<uint64_t, std::vector<Row>> hash_table_;
118+
Row probe_row_{};
119+
bool has_probe_row_ = false;
120+
bool probe_exhausted_ = false;
121+
bool probe_matched_ = false;
122+
size_t match_idx_ = 0;
123+
const std::vector<Row>* current_matches_ = nullptr;
124+
125+
Row copy_row(const Row& src) {
126+
Row dst = make_row(arena_, src.column_count);
127+
for (uint16_t i = 0; i < src.column_count; ++i)
128+
dst.set(i, src.get(i));
129+
return dst;
130+
}
131+
132+
Row combine_rows(const Row& left, const Row& right, uint16_t total_cols) {
133+
Row out = make_row(arena_, total_cols);
134+
for (uint16_t i = 0; i < left_cols_ && i < left.column_count; ++i)
135+
out.set(i, left.get(i));
136+
for (uint16_t i = 0; i < right_cols_ && i < right.column_count; ++i)
137+
out.set(left_cols_ + i, right.get(i));
138+
return out;
139+
}
140+
141+
static uint64_t hash_value(const Value& v) {
142+
if (v.is_null()) return 0;
143+
switch (v.tag) {
144+
case Value::TAG_BOOL: return std::hash<bool>{}(v.bool_val);
145+
case Value::TAG_INT64: return std::hash<int64_t>{}(v.int_val);
146+
case Value::TAG_UINT64: return std::hash<uint64_t>{}(v.uint_val);
147+
case Value::TAG_DOUBLE: return std::hash<double>{}(v.double_val);
148+
case Value::TAG_STRING: {
149+
// FNV-1a hash
150+
uint64_t h = 14695981039346656037ULL;
151+
for (uint32_t i = 0; i < v.str_val.len; ++i) {
152+
h ^= static_cast<uint64_t>(static_cast<unsigned char>(v.str_val.ptr[i]));
153+
h *= 1099511628211ULL;
154+
}
155+
return h;
156+
}
157+
default: return 0;
158+
}
159+
}
160+
161+
static bool values_equal(const Value& a, const Value& b) {
162+
if (a.is_null() || b.is_null()) return false;
163+
if (a.tag == b.tag) {
164+
switch (a.tag) {
165+
case Value::TAG_BOOL: return a.bool_val == b.bool_val;
166+
case Value::TAG_INT64: return a.int_val == b.int_val;
167+
case Value::TAG_UINT64: return a.uint_val == b.uint_val;
168+
case Value::TAG_DOUBLE: return a.double_val == b.double_val;
169+
case Value::TAG_STRING:
170+
return a.str_val.len == b.str_val.len &&
171+
(a.str_val.len == 0 ||
172+
std::memcmp(a.str_val.ptr, b.str_val.ptr, a.str_val.len) == 0);
173+
default: return false;
174+
}
175+
}
176+
// Cross-type numeric comparison
177+
if (a.is_numeric() && b.is_numeric()) {
178+
return a.to_double() == b.to_double();
179+
}
180+
return false;
181+
}
182+
};
183+
184+
} // namespace sql_engine
185+
186+
#endif // SQL_ENGINE_OPERATORS_HASH_JOIN_OP_H

0 commit comments

Comments
 (0)