|
| 1 | +#ifndef SQL_ENGINE_OPERATORS_AGGREGATE_OP_H |
| 2 | +#define SQL_ENGINE_OPERATORS_AGGREGATE_OP_H |
| 3 | + |
| 4 | +#include "sql_engine/operator.h" |
| 5 | +#include "sql_engine/expression_eval.h" |
| 6 | +#include "sql_engine/catalog.h" |
| 7 | +#include "sql_parser/arena.h" |
| 8 | +#include <functional> |
| 9 | +#include <vector> |
| 10 | +#include <unordered_map> |
| 11 | +#include <string> |
| 12 | +#include <cstring> |
| 13 | +#include <cmath> |
| 14 | + |
| 15 | +namespace sql_engine { |
| 16 | + |
| 17 | +template <sql_parser::Dialect D> |
| 18 | +class AggregateOperator : public Operator { |
| 19 | +public: |
| 20 | + AggregateOperator(Operator* child, |
| 21 | + const sql_parser::AstNode** group_by_exprs, |
| 22 | + uint16_t group_count, |
| 23 | + const sql_parser::AstNode** agg_exprs, |
| 24 | + uint16_t agg_count, |
| 25 | + const Catalog& catalog, |
| 26 | + const std::vector<const TableInfo*>& tables, |
| 27 | + FunctionRegistry<D>& functions, |
| 28 | + sql_parser::Arena& arena) |
| 29 | + : child_(child), group_by_exprs_(group_by_exprs), group_count_(group_count), |
| 30 | + agg_exprs_(agg_exprs), agg_count_(agg_count), |
| 31 | + catalog_(catalog), tables_(tables), functions_(functions), arena_(arena) {} |
| 32 | + |
| 33 | + void open() override { |
| 34 | + child_->open(); |
| 35 | + groups_.clear(); |
| 36 | + group_order_.clear(); |
| 37 | + result_idx_ = 0; |
| 38 | + |
| 39 | + // Consume all child rows |
| 40 | + Row row{}; |
| 41 | + while (child_->next(row)) { |
| 42 | + auto resolver = make_resolver(row); |
| 43 | + std::string key = compute_group_key(resolver); |
| 44 | + |
| 45 | + auto it = groups_.find(key); |
| 46 | + if (it == groups_.end()) { |
| 47 | + GroupState state; |
| 48 | + state.group_values.reserve(group_count_); |
| 49 | + for (uint16_t i = 0; i < group_count_; ++i) { |
| 50 | + state.group_values.push_back(evaluate_expression<D>( |
| 51 | + group_by_exprs_[i], resolver, functions_, arena_)); |
| 52 | + } |
| 53 | + state.agg_states.reserve(agg_count_); |
| 54 | + for (uint16_t i = 0; i < agg_count_; ++i) { |
| 55 | + AggState s{}; |
| 56 | + detect_agg_type(agg_exprs_[i], s); |
| 57 | + state.agg_states.push_back(s); |
| 58 | + } |
| 59 | + groups_[key] = std::move(state); |
| 60 | + group_order_.push_back(key); |
| 61 | + it = groups_.find(key); |
| 62 | + } |
| 63 | + |
| 64 | + // Update aggregate states |
| 65 | + auto resolver2 = make_resolver(row); |
| 66 | + for (uint16_t i = 0; i < agg_count_; ++i) { |
| 67 | + update_agg(it->second.agg_states[i], agg_exprs_[i], resolver2); |
| 68 | + } |
| 69 | + } |
| 70 | + child_->close(); |
| 71 | + } |
| 72 | + |
| 73 | + bool next(Row& out) override { |
| 74 | + if (result_idx_ >= group_order_.size()) { |
| 75 | + // If no groups at all and no group-by (whole-table aggregate), emit one row |
| 76 | + if (result_idx_ == 0 && group_count_ == 0 && groups_.empty()) { |
| 77 | + result_idx_ = 1; // only once |
| 78 | + uint16_t cols = group_count_ + agg_count_; |
| 79 | + if (cols == 0) return false; |
| 80 | + out = make_row(arena_, cols); |
| 81 | + for (uint16_t i = 0; i < agg_count_; ++i) { |
| 82 | + AggState s{}; |
| 83 | + detect_agg_type(agg_exprs_[i], s); |
| 84 | + out.set(group_count_ + i, finalize_agg(s)); |
| 85 | + } |
| 86 | + return true; |
| 87 | + } |
| 88 | + return false; |
| 89 | + } |
| 90 | + |
| 91 | + const auto& key = group_order_[result_idx_++]; |
| 92 | + const auto& state = groups_[key]; |
| 93 | + |
| 94 | + uint16_t cols = group_count_ + agg_count_; |
| 95 | + out = make_row(arena_, cols); |
| 96 | + for (uint16_t i = 0; i < group_count_; ++i) { |
| 97 | + out.set(i, state.group_values[i]); |
| 98 | + } |
| 99 | + for (uint16_t i = 0; i < agg_count_; ++i) { |
| 100 | + out.set(group_count_ + i, finalize_agg(state.agg_states[i])); |
| 101 | + } |
| 102 | + return true; |
| 103 | + } |
| 104 | + |
| 105 | + void close() override { |
| 106 | + groups_.clear(); |
| 107 | + group_order_.clear(); |
| 108 | + } |
| 109 | + |
| 110 | +private: |
| 111 | + Operator* child_; |
| 112 | + const sql_parser::AstNode** group_by_exprs_; |
| 113 | + uint16_t group_count_; |
| 114 | + const sql_parser::AstNode** agg_exprs_; |
| 115 | + uint16_t agg_count_; |
| 116 | + const Catalog& catalog_; |
| 117 | + std::vector<const TableInfo*> tables_; |
| 118 | + FunctionRegistry<D>& functions_; |
| 119 | + sql_parser::Arena& arena_; |
| 120 | + |
| 121 | + enum class AggType { COUNT, SUM, AVG, MIN, MAX, EXPR }; |
| 122 | + |
| 123 | + struct AggState { |
| 124 | + AggType type = AggType::EXPR; |
| 125 | + int64_t count = 0; |
| 126 | + double sum = 0.0; |
| 127 | + Value min_val{}; |
| 128 | + Value max_val{}; |
| 129 | + bool has_value = false; |
| 130 | + bool count_star = false; // COUNT(*) |
| 131 | + }; |
| 132 | + |
| 133 | + struct GroupState { |
| 134 | + std::vector<Value> group_values; |
| 135 | + std::vector<AggState> agg_states; |
| 136 | + }; |
| 137 | + |
| 138 | + std::unordered_map<std::string, GroupState> groups_; |
| 139 | + std::vector<std::string> group_order_; |
| 140 | + size_t result_idx_ = 0; |
| 141 | + |
| 142 | + std::function<Value(sql_parser::StringRef)> make_resolver(const Row& row) { |
| 143 | + return [this, &row](sql_parser::StringRef col_name) -> Value { |
| 144 | + uint16_t offset = 0; |
| 145 | + for (const auto* table : tables_) { |
| 146 | + if (!table) continue; |
| 147 | + const ColumnInfo* col = catalog_.get_column(table, col_name); |
| 148 | + if (col) { |
| 149 | + uint16_t idx = offset + col->ordinal; |
| 150 | + if (idx < row.column_count) return row.get(idx); |
| 151 | + } |
| 152 | + offset += table->column_count; |
| 153 | + } |
| 154 | + return value_null(); |
| 155 | + }; |
| 156 | + } |
| 157 | + |
| 158 | + std::string compute_group_key(const std::function<Value(sql_parser::StringRef)>& resolver) { |
| 159 | + if (group_count_ == 0) return ""; |
| 160 | + std::string key; |
| 161 | + for (uint16_t i = 0; i < group_count_; ++i) { |
| 162 | + Value v = evaluate_expression<D>(group_by_exprs_[i], resolver, functions_, arena_); |
| 163 | + key += value_to_string(v); |
| 164 | + key += '\x01'; // separator |
| 165 | + } |
| 166 | + return key; |
| 167 | + } |
| 168 | + |
| 169 | + static std::string value_to_string(const Value& v) { |
| 170 | + if (v.is_null()) return "NULL"; |
| 171 | + switch (v.tag) { |
| 172 | + case Value::TAG_BOOL: return v.bool_val ? "1" : "0"; |
| 173 | + case Value::TAG_INT64: return std::to_string(v.int_val); |
| 174 | + case Value::TAG_UINT64: return std::to_string(v.uint_val); |
| 175 | + case Value::TAG_DOUBLE: return std::to_string(v.double_val); |
| 176 | + case Value::TAG_STRING: return std::string(v.str_val.ptr, v.str_val.len); |
| 177 | + default: return "?"; |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + void detect_agg_type(const sql_parser::AstNode* expr, AggState& state) { |
| 182 | + if (!expr) { state.type = AggType::EXPR; return; } |
| 183 | + |
| 184 | + if (expr->type == sql_parser::NodeType::NODE_FUNCTION_CALL) { |
| 185 | + sql_parser::StringRef name = expr->value(); |
| 186 | + if (name.equals_ci("COUNT", 5)) { |
| 187 | + state.type = AggType::COUNT; |
| 188 | + // Check for COUNT(*) |
| 189 | + const sql_parser::AstNode* arg = expr->first_child; |
| 190 | + if (arg && arg->type == sql_parser::NodeType::NODE_ASTERISK) { |
| 191 | + state.count_star = true; |
| 192 | + } |
| 193 | + return; |
| 194 | + } |
| 195 | + if (name.equals_ci("SUM", 3)) { state.type = AggType::SUM; return; } |
| 196 | + if (name.equals_ci("AVG", 3)) { state.type = AggType::AVG; return; } |
| 197 | + if (name.equals_ci("MIN", 3)) { state.type = AggType::MIN; return; } |
| 198 | + if (name.equals_ci("MAX", 3)) { state.type = AggType::MAX; return; } |
| 199 | + } |
| 200 | + state.type = AggType::EXPR; |
| 201 | + } |
| 202 | + |
| 203 | + void update_agg(AggState& state, const sql_parser::AstNode* expr, |
| 204 | + const std::function<Value(sql_parser::StringRef)>& resolver) { |
| 205 | + switch (state.type) { |
| 206 | + case AggType::COUNT: { |
| 207 | + if (state.count_star) { |
| 208 | + state.count++; |
| 209 | + } else { |
| 210 | + // COUNT(expr) - count non-null values |
| 211 | + const sql_parser::AstNode* arg = expr->first_child; |
| 212 | + Value v = evaluate_expression<D>(arg, resolver, functions_, arena_); |
| 213 | + if (!v.is_null()) state.count++; |
| 214 | + } |
| 215 | + break; |
| 216 | + } |
| 217 | + case AggType::SUM: |
| 218 | + case AggType::AVG: { |
| 219 | + const sql_parser::AstNode* arg = expr->first_child; |
| 220 | + Value v = evaluate_expression<D>(arg, resolver, functions_, arena_); |
| 221 | + if (!v.is_null()) { |
| 222 | + state.sum += v.to_double(); |
| 223 | + state.count++; |
| 224 | + state.has_value = true; |
| 225 | + } |
| 226 | + break; |
| 227 | + } |
| 228 | + case AggType::MIN: { |
| 229 | + const sql_parser::AstNode* arg = expr->first_child; |
| 230 | + Value v = evaluate_expression<D>(arg, resolver, functions_, arena_); |
| 231 | + if (!v.is_null()) { |
| 232 | + if (!state.has_value || compare_values(v, state.min_val) < 0) { |
| 233 | + state.min_val = v; |
| 234 | + state.has_value = true; |
| 235 | + } |
| 236 | + } |
| 237 | + break; |
| 238 | + } |
| 239 | + case AggType::MAX: { |
| 240 | + const sql_parser::AstNode* arg = expr->first_child; |
| 241 | + Value v = evaluate_expression<D>(arg, resolver, functions_, arena_); |
| 242 | + if (!v.is_null()) { |
| 243 | + if (!state.has_value || compare_values(v, state.max_val) > 0) { |
| 244 | + state.max_val = v; |
| 245 | + state.has_value = true; |
| 246 | + } |
| 247 | + } |
| 248 | + break; |
| 249 | + } |
| 250 | + case AggType::EXPR: |
| 251 | + break; |
| 252 | + } |
| 253 | + } |
| 254 | + |
| 255 | + Value finalize_agg(const AggState& state) const { |
| 256 | + switch (state.type) { |
| 257 | + case AggType::COUNT: |
| 258 | + return value_int(state.count); |
| 259 | + case AggType::SUM: |
| 260 | + if (!state.has_value) return value_null(); |
| 261 | + return value_double(state.sum); |
| 262 | + case AggType::AVG: |
| 263 | + if (state.count == 0) return value_null(); |
| 264 | + return value_double(state.sum / static_cast<double>(state.count)); |
| 265 | + case AggType::MIN: |
| 266 | + if (!state.has_value) return value_null(); |
| 267 | + return state.min_val; |
| 268 | + case AggType::MAX: |
| 269 | + if (!state.has_value) return value_null(); |
| 270 | + return state.max_val; |
| 271 | + case AggType::EXPR: |
| 272 | + return value_null(); |
| 273 | + } |
| 274 | + return value_null(); |
| 275 | + } |
| 276 | + |
| 277 | + static int compare_values(const Value& a, const Value& b) { |
| 278 | + if (a.is_null() && b.is_null()) return 0; |
| 279 | + if (a.is_null()) return -1; |
| 280 | + if (b.is_null()) return 1; |
| 281 | + |
| 282 | + // Try numeric comparison |
| 283 | + if (a.is_numeric() && b.is_numeric()) { |
| 284 | + double da = a.to_double(); |
| 285 | + double db = b.to_double(); |
| 286 | + if (da < db) return -1; |
| 287 | + if (da > db) return 1; |
| 288 | + return 0; |
| 289 | + } |
| 290 | + |
| 291 | + // String comparison |
| 292 | + if (a.tag == Value::TAG_STRING && b.tag == Value::TAG_STRING) { |
| 293 | + uint32_t minlen = a.str_val.len < b.str_val.len ? a.str_val.len : b.str_val.len; |
| 294 | + int cmp = std::memcmp(a.str_val.ptr, b.str_val.ptr, minlen); |
| 295 | + if (cmp != 0) return cmp; |
| 296 | + if (a.str_val.len < b.str_val.len) return -1; |
| 297 | + if (a.str_val.len > b.str_val.len) return 1; |
| 298 | + return 0; |
| 299 | + } |
| 300 | + |
| 301 | + return 0; |
| 302 | + } |
| 303 | +}; |
| 304 | + |
| 305 | +} // namespace sql_engine |
| 306 | + |
| 307 | +#endif // SQL_ENGINE_OPERATORS_AGGREGATE_OP_H |
0 commit comments