Skip to content

Commit 2d36d58

Browse files
committed
feat: add Volcano executor engine — SQL goes in, data comes out
1 parent 0ecbb84 commit 2d36d58

16 files changed

Lines changed: 2579 additions & 1 deletion

Makefile.new

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ TEST_SRCS = $(TEST_DIR)/test_main.cpp \
5353
$(TEST_DIR)/test_eval_integration.cpp \
5454
$(TEST_DIR)/test_catalog.cpp \
5555
$(TEST_DIR)/test_row.cpp \
56-
$(TEST_DIR)/test_plan_builder.cpp
56+
$(TEST_DIR)/test_plan_builder.cpp \
57+
$(TEST_DIR)/test_operators.cpp \
58+
$(TEST_DIR)/test_plan_executor.cpp
5759
TEST_OBJS = $(TEST_SRCS:.cpp=.o)
5860
TEST_TARGET = $(PROJECT_ROOT)/run_tests
5961

include/sql_engine/data_source.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#ifndef SQL_ENGINE_DATA_SOURCE_H
2+
#define SQL_ENGINE_DATA_SOURCE_H
3+
4+
#include "sql_engine/catalog.h"
5+
#include "sql_engine/row.h"
6+
#include <vector>
7+
8+
namespace sql_engine {
9+
10+
class DataSource {
11+
public:
12+
virtual ~DataSource() = default;
13+
virtual const TableInfo* table_info() const = 0;
14+
virtual void open() = 0;
15+
virtual bool next(Row& out) = 0;
16+
virtual void close() = 0;
17+
};
18+
19+
class InMemoryDataSource : public DataSource {
20+
public:
21+
InMemoryDataSource(const TableInfo* table, std::vector<Row> rows)
22+
: table_(table), rows_(std::move(rows)) {}
23+
24+
const TableInfo* table_info() const override { return table_; }
25+
26+
void open() override { cursor_ = 0; }
27+
28+
bool next(Row& out) override {
29+
if (cursor_ >= rows_.size()) return false;
30+
out = rows_[cursor_++];
31+
return true;
32+
}
33+
34+
void close() override { cursor_ = 0; }
35+
36+
private:
37+
const TableInfo* table_;
38+
std::vector<Row> rows_;
39+
size_t cursor_ = 0;
40+
};
41+
42+
} // namespace sql_engine
43+
44+
#endif // SQL_ENGINE_DATA_SOURCE_H

include/sql_engine/operator.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef SQL_ENGINE_OPERATOR_H
2+
#define SQL_ENGINE_OPERATOR_H
3+
4+
#include "sql_engine/row.h"
5+
6+
namespace sql_engine {
7+
8+
class Operator {
9+
public:
10+
virtual ~Operator() = default;
11+
virtual void open() = 0;
12+
virtual bool next(Row& out) = 0;
13+
virtual void close() = 0;
14+
};
15+
16+
} // namespace sql_engine
17+
18+
#endif // SQL_ENGINE_OPERATOR_H
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
#ifndef SQL_ENGINE_OPERATORS_AGGREGATE_OP_H
2+
#define SQL_ENGINE_OPERATORS_AGGREGATE_OP_H
3+
4+
#include "sql_engine/operator.h"
5+
#include "sql_engine/expression_eval.h"
6+
#include "sql_engine/catalog.h"
7+
#include "sql_parser/arena.h"
8+
#include <functional>
9+
#include <vector>
10+
#include <unordered_map>
11+
#include <string>
12+
#include <cstring>
13+
#include <cmath>
14+
15+
namespace sql_engine {
16+
17+
template <sql_parser::Dialect D>
18+
class AggregateOperator : public Operator {
19+
public:
20+
AggregateOperator(Operator* child,
21+
const sql_parser::AstNode** group_by_exprs,
22+
uint16_t group_count,
23+
const sql_parser::AstNode** agg_exprs,
24+
uint16_t agg_count,
25+
const Catalog& catalog,
26+
const std::vector<const TableInfo*>& tables,
27+
FunctionRegistry<D>& functions,
28+
sql_parser::Arena& arena)
29+
: child_(child), group_by_exprs_(group_by_exprs), group_count_(group_count),
30+
agg_exprs_(agg_exprs), agg_count_(agg_count),
31+
catalog_(catalog), tables_(tables), functions_(functions), arena_(arena) {}
32+
33+
void open() override {
34+
child_->open();
35+
groups_.clear();
36+
group_order_.clear();
37+
result_idx_ = 0;
38+
39+
// Consume all child rows
40+
Row row{};
41+
while (child_->next(row)) {
42+
auto resolver = make_resolver(row);
43+
std::string key = compute_group_key(resolver);
44+
45+
auto it = groups_.find(key);
46+
if (it == groups_.end()) {
47+
GroupState state;
48+
state.group_values.reserve(group_count_);
49+
for (uint16_t i = 0; i < group_count_; ++i) {
50+
state.group_values.push_back(evaluate_expression<D>(
51+
group_by_exprs_[i], resolver, functions_, arena_));
52+
}
53+
state.agg_states.reserve(agg_count_);
54+
for (uint16_t i = 0; i < agg_count_; ++i) {
55+
AggState s{};
56+
detect_agg_type(agg_exprs_[i], s);
57+
state.agg_states.push_back(s);
58+
}
59+
groups_[key] = std::move(state);
60+
group_order_.push_back(key);
61+
it = groups_.find(key);
62+
}
63+
64+
// Update aggregate states
65+
auto resolver2 = make_resolver(row);
66+
for (uint16_t i = 0; i < agg_count_; ++i) {
67+
update_agg(it->second.agg_states[i], agg_exprs_[i], resolver2);
68+
}
69+
}
70+
child_->close();
71+
}
72+
73+
bool next(Row& out) override {
74+
if (result_idx_ >= group_order_.size()) {
75+
// If no groups at all and no group-by (whole-table aggregate), emit one row
76+
if (result_idx_ == 0 && group_count_ == 0 && groups_.empty()) {
77+
result_idx_ = 1; // only once
78+
uint16_t cols = group_count_ + agg_count_;
79+
if (cols == 0) return false;
80+
out = make_row(arena_, cols);
81+
for (uint16_t i = 0; i < agg_count_; ++i) {
82+
AggState s{};
83+
detect_agg_type(agg_exprs_[i], s);
84+
out.set(group_count_ + i, finalize_agg(s));
85+
}
86+
return true;
87+
}
88+
return false;
89+
}
90+
91+
const auto& key = group_order_[result_idx_++];
92+
const auto& state = groups_[key];
93+
94+
uint16_t cols = group_count_ + agg_count_;
95+
out = make_row(arena_, cols);
96+
for (uint16_t i = 0; i < group_count_; ++i) {
97+
out.set(i, state.group_values[i]);
98+
}
99+
for (uint16_t i = 0; i < agg_count_; ++i) {
100+
out.set(group_count_ + i, finalize_agg(state.agg_states[i]));
101+
}
102+
return true;
103+
}
104+
105+
void close() override {
106+
groups_.clear();
107+
group_order_.clear();
108+
}
109+
110+
private:
111+
Operator* child_;
112+
const sql_parser::AstNode** group_by_exprs_;
113+
uint16_t group_count_;
114+
const sql_parser::AstNode** agg_exprs_;
115+
uint16_t agg_count_;
116+
const Catalog& catalog_;
117+
std::vector<const TableInfo*> tables_;
118+
FunctionRegistry<D>& functions_;
119+
sql_parser::Arena& arena_;
120+
121+
enum class AggType { COUNT, SUM, AVG, MIN, MAX, EXPR };
122+
123+
struct AggState {
124+
AggType type = AggType::EXPR;
125+
int64_t count = 0;
126+
double sum = 0.0;
127+
Value min_val{};
128+
Value max_val{};
129+
bool has_value = false;
130+
bool count_star = false; // COUNT(*)
131+
};
132+
133+
struct GroupState {
134+
std::vector<Value> group_values;
135+
std::vector<AggState> agg_states;
136+
};
137+
138+
std::unordered_map<std::string, GroupState> groups_;
139+
std::vector<std::string> group_order_;
140+
size_t result_idx_ = 0;
141+
142+
std::function<Value(sql_parser::StringRef)> make_resolver(const Row& row) {
143+
return [this, &row](sql_parser::StringRef col_name) -> Value {
144+
uint16_t offset = 0;
145+
for (const auto* table : tables_) {
146+
if (!table) continue;
147+
const ColumnInfo* col = catalog_.get_column(table, col_name);
148+
if (col) {
149+
uint16_t idx = offset + col->ordinal;
150+
if (idx < row.column_count) return row.get(idx);
151+
}
152+
offset += table->column_count;
153+
}
154+
return value_null();
155+
};
156+
}
157+
158+
std::string compute_group_key(const std::function<Value(sql_parser::StringRef)>& resolver) {
159+
if (group_count_ == 0) return "";
160+
std::string key;
161+
for (uint16_t i = 0; i < group_count_; ++i) {
162+
Value v = evaluate_expression<D>(group_by_exprs_[i], resolver, functions_, arena_);
163+
key += value_to_string(v);
164+
key += '\x01'; // separator
165+
}
166+
return key;
167+
}
168+
169+
static std::string value_to_string(const Value& v) {
170+
if (v.is_null()) return "NULL";
171+
switch (v.tag) {
172+
case Value::TAG_BOOL: return v.bool_val ? "1" : "0";
173+
case Value::TAG_INT64: return std::to_string(v.int_val);
174+
case Value::TAG_UINT64: return std::to_string(v.uint_val);
175+
case Value::TAG_DOUBLE: return std::to_string(v.double_val);
176+
case Value::TAG_STRING: return std::string(v.str_val.ptr, v.str_val.len);
177+
default: return "?";
178+
}
179+
}
180+
181+
void detect_agg_type(const sql_parser::AstNode* expr, AggState& state) {
182+
if (!expr) { state.type = AggType::EXPR; return; }
183+
184+
if (expr->type == sql_parser::NodeType::NODE_FUNCTION_CALL) {
185+
sql_parser::StringRef name = expr->value();
186+
if (name.equals_ci("COUNT", 5)) {
187+
state.type = AggType::COUNT;
188+
// Check for COUNT(*)
189+
const sql_parser::AstNode* arg = expr->first_child;
190+
if (arg && arg->type == sql_parser::NodeType::NODE_ASTERISK) {
191+
state.count_star = true;
192+
}
193+
return;
194+
}
195+
if (name.equals_ci("SUM", 3)) { state.type = AggType::SUM; return; }
196+
if (name.equals_ci("AVG", 3)) { state.type = AggType::AVG; return; }
197+
if (name.equals_ci("MIN", 3)) { state.type = AggType::MIN; return; }
198+
if (name.equals_ci("MAX", 3)) { state.type = AggType::MAX; return; }
199+
}
200+
state.type = AggType::EXPR;
201+
}
202+
203+
void update_agg(AggState& state, const sql_parser::AstNode* expr,
204+
const std::function<Value(sql_parser::StringRef)>& resolver) {
205+
switch (state.type) {
206+
case AggType::COUNT: {
207+
if (state.count_star) {
208+
state.count++;
209+
} else {
210+
// COUNT(expr) - count non-null values
211+
const sql_parser::AstNode* arg = expr->first_child;
212+
Value v = evaluate_expression<D>(arg, resolver, functions_, arena_);
213+
if (!v.is_null()) state.count++;
214+
}
215+
break;
216+
}
217+
case AggType::SUM:
218+
case AggType::AVG: {
219+
const sql_parser::AstNode* arg = expr->first_child;
220+
Value v = evaluate_expression<D>(arg, resolver, functions_, arena_);
221+
if (!v.is_null()) {
222+
state.sum += v.to_double();
223+
state.count++;
224+
state.has_value = true;
225+
}
226+
break;
227+
}
228+
case AggType::MIN: {
229+
const sql_parser::AstNode* arg = expr->first_child;
230+
Value v = evaluate_expression<D>(arg, resolver, functions_, arena_);
231+
if (!v.is_null()) {
232+
if (!state.has_value || compare_values(v, state.min_val) < 0) {
233+
state.min_val = v;
234+
state.has_value = true;
235+
}
236+
}
237+
break;
238+
}
239+
case AggType::MAX: {
240+
const sql_parser::AstNode* arg = expr->first_child;
241+
Value v = evaluate_expression<D>(arg, resolver, functions_, arena_);
242+
if (!v.is_null()) {
243+
if (!state.has_value || compare_values(v, state.max_val) > 0) {
244+
state.max_val = v;
245+
state.has_value = true;
246+
}
247+
}
248+
break;
249+
}
250+
case AggType::EXPR:
251+
break;
252+
}
253+
}
254+
255+
Value finalize_agg(const AggState& state) const {
256+
switch (state.type) {
257+
case AggType::COUNT:
258+
return value_int(state.count);
259+
case AggType::SUM:
260+
if (!state.has_value) return value_null();
261+
return value_double(state.sum);
262+
case AggType::AVG:
263+
if (state.count == 0) return value_null();
264+
return value_double(state.sum / static_cast<double>(state.count));
265+
case AggType::MIN:
266+
if (!state.has_value) return value_null();
267+
return state.min_val;
268+
case AggType::MAX:
269+
if (!state.has_value) return value_null();
270+
return state.max_val;
271+
case AggType::EXPR:
272+
return value_null();
273+
}
274+
return value_null();
275+
}
276+
277+
static int compare_values(const Value& a, const Value& b) {
278+
if (a.is_null() && b.is_null()) return 0;
279+
if (a.is_null()) return -1;
280+
if (b.is_null()) return 1;
281+
282+
// Try numeric comparison
283+
if (a.is_numeric() && b.is_numeric()) {
284+
double da = a.to_double();
285+
double db = b.to_double();
286+
if (da < db) return -1;
287+
if (da > db) return 1;
288+
return 0;
289+
}
290+
291+
// String comparison
292+
if (a.tag == Value::TAG_STRING && b.tag == Value::TAG_STRING) {
293+
uint32_t minlen = a.str_val.len < b.str_val.len ? a.str_val.len : b.str_val.len;
294+
int cmp = std::memcmp(a.str_val.ptr, b.str_val.ptr, minlen);
295+
if (cmp != 0) return cmp;
296+
if (a.str_val.len < b.str_val.len) return -1;
297+
if (a.str_val.len > b.str_val.len) return 1;
298+
return 0;
299+
}
300+
301+
return 0;
302+
}
303+
};
304+
305+
} // namespace sql_engine
306+
307+
#endif // SQL_ENGINE_OPERATORS_AGGREGATE_OP_H

0 commit comments

Comments
 (0)