|
| 1 | +#ifndef SQL_ENGINE_COERCION_H |
| 2 | +#define SQL_ENGINE_COERCION_H |
| 3 | + |
#include "sql_engine/types.h"
#include "sql_engine/value.h"
#include "sql_parser/common.h" // Dialect, StringRef
#include "sql_parser/arena.h" // Arena
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
| 11 | + |
| 12 | +namespace sql_engine { |
| 13 | + |
| 14 | +using sql_parser::Dialect; |
| 15 | +using sql_parser::Arena; |
| 16 | + |
| 17 | +template <Dialect D> |
| 18 | +struct CoercionRules { |
| 19 | + // Can `from` be implicitly coerced to `to`? |
| 20 | + static bool can_coerce(SqlType::Kind from, SqlType::Kind to); |
| 21 | + |
| 22 | + // What common type should two operands be promoted to? |
| 23 | + static SqlType::Kind common_type(SqlType::Kind left, SqlType::Kind right); |
| 24 | + |
| 25 | + // Perform the coercion. Returns new Value with target tag, or NULL on failure. |
| 26 | + static Value coerce_value(const Value& val, Value::Tag target, Arena& arena); |
| 27 | +}; |
| 28 | + |
| 29 | +// ----- Helper: parse int from string (MySQL: lenient, stops at first non-digit) ----- |
| 30 | +namespace detail { |
| 31 | + |
/// Lenient base-10 integer parse of the first `len` bytes of `s`.
///
/// The input need not be NUL-terminated: up to 63 bytes are copied into a
/// local scratch buffer, terminated, and handed to std::strtoll, which reads
/// the longest valid numeric prefix and ignores whatever follows it.
///
/// @param s    start of the text (not required to be NUL-terminated)
/// @param len  number of bytes available at `s`; only the first 63 are examined
/// @param out  receives the parsed value (0 for empty input, and the
///             strtoll result - 0 - when nothing could be parsed)
/// @return true for empty input or when at least one character was consumed
///         by the conversion; false otherwise
inline bool parse_int_lenient(const char* s, uint32_t len, int64_t& out) {
    // An empty string is treated as the number zero.
    if (len == 0) {
        out = 0;
        return true;
    }

    char scratch[64];
    const uint32_t copied = (len >= 63) ? 63u : len;
    std::memcpy(scratch, s, copied);
    scratch[copied] = '\0';

    char* stop = nullptr;
    out = std::strtoll(scratch, &stop, 10);

    // strtoll leaves `stop` at the start of the buffer when no conversion happened.
    const bool consumed_any = (stop != scratch);
    return consumed_any;
}
| 42 | + |
/// Strict base-10 integer parse of the first `len` bytes of `s`.
///
/// The whole input must form one integer: optional leading whitespace and
/// sign (as accepted by std::strtoll), at least one digit, then nothing but
/// trailing spaces. Unlike the lenient parser, this one rejects:
///   - empty input and input with no digits at all (e.g. "   "),
///   - any trailing characters other than spaces,
///   - values that do not fit in a signed 64-bit integer,
///   - inputs of 64 bytes or more, which cannot be examined in full by the
///     fixed scratch buffer and would otherwise be judged on a truncated prefix.
///
/// @param s    start of the text (not required to be NUL-terminated)
/// @param len  number of bytes available at `s`
/// @param out  receives the parsed value on success; its content is
///             unspecified when the function returns false
/// @return true only when the entire input is a valid in-range integer
inline bool parse_int_strict(const char* s, uint32_t len, int64_t& out) {
    char buf[64];
    if (len == 0 || len >= sizeof(buf)) return false;

    std::memcpy(buf, s, len);
    buf[len] = '\0';

    // strtoll reports overflow only through errno, so clear it first and
    // restore the caller's value afterwards.
    const int saved_errno = errno;
    errno = 0;
    char* end = nullptr;
    out = std::strtoll(buf, &end, 10);
    const bool out_of_range = (errno == ERANGE);
    errno = saved_errno;

    // No digits consumed: must be checked BEFORE skipping trailing spaces,
    // otherwise an all-space input would look fully consumed.
    if (end == buf) return false;
    if (out_of_range) return false;

    // Strict: entire string must be consumed (ignoring trailing spaces).
    while (*end == ' ') ++end;
    return end == buf + len;
}
| 55 | + |
/// Lenient floating-point parse of the first `len` bytes of `s`.
///
/// Up to 127 bytes are copied into a NUL-terminated scratch buffer and passed
/// to std::strtod, which converts the longest valid numeric prefix and
/// ignores the remainder.
///
/// @param s    start of the text (not required to be NUL-terminated)
/// @param len  number of bytes available at `s`; only the first 127 are examined
/// @param out  receives the parsed value (0.0 for empty input, and the
///             strtod result when nothing could be parsed)
/// @return true for empty input or when at least one character was consumed
///         by the conversion; false otherwise
inline bool parse_double_lenient(const char* s, uint32_t len, double& out) {
    // An empty string is treated as zero.
    if (len == 0) {
        out = 0.0;
        return true;
    }

    char scratch[128];
    const uint32_t copied = (len >= 127) ? 127u : len;
    std::memcpy(scratch, s, copied);
    scratch[copied] = '\0';

    char* stop = nullptr;
    out = std::strtod(scratch, &stop);

    // `stop` still pointing at the buffer start means no conversion was performed.
    const bool consumed_any = (stop != scratch);
    return consumed_any;
}
| 66 | + |
| 67 | +inline StringRef int_to_string(int64_t v, Arena& arena) { |
| 68 | + char buf[32]; |
| 69 | + int n = std::snprintf(buf, sizeof(buf), "%lld", (long long)v); |
| 70 | + return arena.allocate_string(buf, static_cast<uint32_t>(n)); |
| 71 | +} |
| 72 | + |
| 73 | +inline StringRef double_to_string(double v, Arena& arena) { |
| 74 | + char buf[64]; |
| 75 | + int n = std::snprintf(buf, sizeof(buf), "%g", v); |
| 76 | + return arena.allocate_string(buf, static_cast<uint32_t>(n)); |
| 77 | +} |
| 78 | + |
| 79 | +} // namespace detail |
| 80 | + |
| 81 | +// ----- MySQL specialization (permissive) ----- |
| 82 | + |
| 83 | +template <> |
| 84 | +inline bool CoercionRules<Dialect::MySQL>::can_coerce(SqlType::Kind from, SqlType::Kind to) { |
| 85 | + if (from == to) return true; |
| 86 | + if (from == SqlType::NULL_TYPE) return true; // NULL coerces to anything |
| 87 | + |
| 88 | + // Numeric within-category: always |
| 89 | + SqlType probe_from{from}; SqlType probe_to{to}; |
| 90 | + if (probe_from.is_numeric() && probe_to.is_numeric()) return true; |
| 91 | + |
| 92 | + // String <-> Numeric: MySQL allows |
| 93 | + if (probe_from.is_string() && probe_to.is_numeric()) return true; |
| 94 | + if (probe_from.is_numeric() && probe_to.is_string()) return true; |
| 95 | + |
| 96 | + // String <-> Temporal: MySQL allows (parse attempt) |
| 97 | + if (probe_from.is_string() && probe_to.is_temporal()) return true; |
| 98 | + if (probe_from.is_temporal() && probe_to.is_string()) return true; |
| 99 | + |
| 100 | + // Temporal within-category: allowed |
| 101 | + if (probe_from.is_temporal() && probe_to.is_temporal()) return true; |
| 102 | + |
| 103 | + return false; |
| 104 | +} |
| 105 | + |
| 106 | +template <> |
| 107 | +inline SqlType::Kind CoercionRules<Dialect::MySQL>::common_type(SqlType::Kind left, SqlType::Kind right) { |
| 108 | + if (left == right) return left; |
| 109 | + |
| 110 | + SqlType pl{left}; SqlType pr{right}; |
| 111 | + |
| 112 | + // Both numeric: promote up the hierarchy |
| 113 | + if (pl.is_numeric() && pr.is_numeric()) { |
| 114 | + // BOOL < TINYINT < ... < BIGINT < FLOAT < DOUBLE < DECIMAL |
| 115 | + // Use the wider kind |
| 116 | + return left > right ? left : right; |
| 117 | + } |
| 118 | + |
| 119 | + // One string, one numeric: MySQL promotes to numeric (DOUBLE) |
| 120 | + if ((pl.is_numeric() && pr.is_string()) || (pl.is_string() && pr.is_numeric())) { |
| 121 | + return SqlType::DOUBLE; |
| 122 | + } |
| 123 | + |
| 124 | + // Both temporal: promote to wider |
| 125 | + if (pl.is_temporal() && pr.is_temporal()) { |
| 126 | + return left > right ? left : right; |
| 127 | + } |
| 128 | + |
| 129 | + // Fallback: STRING |
| 130 | + return SqlType::VARCHAR; |
| 131 | +} |
| 132 | + |
| 133 | +template <> |
| 134 | +inline Value CoercionRules<Dialect::MySQL>::coerce_value(const Value& val, Value::Tag target, Arena& arena) { |
| 135 | + if (val.tag == target) return val; |
| 136 | + if (val.is_null()) return value_null(); |
| 137 | + |
| 138 | + switch (target) { |
| 139 | + case Value::TAG_INT64: { |
| 140 | + if (val.tag == Value::TAG_BOOL) return value_int(val.bool_val ? 1 : 0); |
| 141 | + if (val.tag == Value::TAG_UINT64) return value_int(static_cast<int64_t>(val.uint_val)); |
| 142 | + if (val.tag == Value::TAG_DOUBLE) return value_int(static_cast<int64_t>(val.double_val)); |
| 143 | + if (val.tag == Value::TAG_STRING || val.tag == Value::TAG_DECIMAL) { |
| 144 | + int64_t out; |
| 145 | + if (detail::parse_int_lenient(val.str_val.ptr, val.str_val.len, out)) |
| 146 | + return value_int(out); |
| 147 | + } |
| 148 | + return value_null(); |
| 149 | + } |
| 150 | + case Value::TAG_DOUBLE: { |
| 151 | + if (val.tag == Value::TAG_BOOL) return value_double(val.bool_val ? 1.0 : 0.0); |
| 152 | + if (val.tag == Value::TAG_INT64) return value_double(static_cast<double>(val.int_val)); |
| 153 | + if (val.tag == Value::TAG_UINT64) return value_double(static_cast<double>(val.uint_val)); |
| 154 | + if (val.tag == Value::TAG_STRING || val.tag == Value::TAG_DECIMAL) { |
| 155 | + double out; |
| 156 | + if (detail::parse_double_lenient(val.str_val.ptr, val.str_val.len, out)) |
| 157 | + return value_double(out); |
| 158 | + } |
| 159 | + return value_null(); |
| 160 | + } |
| 161 | + case Value::TAG_STRING: { |
| 162 | + if (val.tag == Value::TAG_INT64) |
| 163 | + return value_string(detail::int_to_string(val.int_val, arena)); |
| 164 | + if (val.tag == Value::TAG_UINT64) { |
| 165 | + char buf[32]; |
| 166 | + int n = std::snprintf(buf, sizeof(buf), "%llu", (unsigned long long)val.uint_val); |
| 167 | + return value_string(arena.allocate_string(buf, static_cast<uint32_t>(n))); |
| 168 | + } |
| 169 | + if (val.tag == Value::TAG_DOUBLE) |
| 170 | + return value_string(detail::double_to_string(val.double_val, arena)); |
| 171 | + if (val.tag == Value::TAG_BOOL) |
| 172 | + return value_string(val.bool_val |
| 173 | + ? arena.allocate_string("1", 1) |
| 174 | + : arena.allocate_string("0", 1)); |
| 175 | + return value_null(); |
| 176 | + } |
| 177 | + case Value::TAG_BOOL: { |
| 178 | + if (val.tag == Value::TAG_INT64) return value_bool(val.int_val != 0); |
| 179 | + if (val.tag == Value::TAG_DOUBLE) return value_bool(val.double_val != 0.0); |
| 180 | + return value_null(); |
| 181 | + } |
| 182 | + default: |
| 183 | + return value_null(); |
| 184 | + } |
| 185 | +} |
| 186 | + |
| 187 | +// ----- PostgreSQL specialization (strict) ----- |
| 188 | + |
| 189 | +template <> |
| 190 | +inline bool CoercionRules<Dialect::PostgreSQL>::can_coerce(SqlType::Kind from, SqlType::Kind to) { |
| 191 | + if (from == to) return true; |
| 192 | + if (from == SqlType::NULL_TYPE) return true; |
| 193 | + |
| 194 | + SqlType probe_from{from}; SqlType probe_to{to}; |
| 195 | + |
| 196 | + // Within-category numeric promotions only |
| 197 | + if (probe_from.is_numeric() && probe_to.is_numeric()) { |
| 198 | + // Only allow widening: kind value must increase (narrower to wider) |
| 199 | + return to > from; |
| 200 | + } |
| 201 | + |
| 202 | + // Temporal within-category promotions |
| 203 | + if (probe_from.is_temporal() && probe_to.is_temporal()) { |
| 204 | + return to > from; |
| 205 | + } |
| 206 | + |
| 207 | + // No cross-category implicit coercion in PostgreSQL |
| 208 | + return false; |
| 209 | +} |
| 210 | + |
| 211 | +template <> |
| 212 | +inline SqlType::Kind CoercionRules<Dialect::PostgreSQL>::common_type(SqlType::Kind left, SqlType::Kind right) { |
| 213 | + if (left == right) return left; |
| 214 | + |
| 215 | + SqlType pl{left}; SqlType pr{right}; |
| 216 | + |
| 217 | + // Both numeric: promote to wider |
| 218 | + if (pl.is_numeric() && pr.is_numeric()) { |
| 219 | + return left > right ? left : right; |
| 220 | + } |
| 221 | + |
| 222 | + // Both temporal: promote to wider |
| 223 | + if (pl.is_temporal() && pr.is_temporal()) { |
| 224 | + return left > right ? left : right; |
| 225 | + } |
| 226 | + |
| 227 | + // Cross-category: return UNKNOWN (error) |
| 228 | + return SqlType::UNKNOWN; |
| 229 | +} |
| 230 | + |
| 231 | +template <> |
| 232 | +inline Value CoercionRules<Dialect::PostgreSQL>::coerce_value(const Value& val, Value::Tag target, Arena& /*arena*/) { |
| 233 | + if (val.tag == target) return val; |
| 234 | + if (val.is_null()) return value_null(); |
| 235 | + |
| 236 | + // PostgreSQL: only within-category promotions |
| 237 | + switch (target) { |
| 238 | + case Value::TAG_INT64: { |
| 239 | + if (val.tag == Value::TAG_BOOL) return value_int(val.bool_val ? 1 : 0); |
| 240 | + if (val.tag == Value::TAG_UINT64) return value_int(static_cast<int64_t>(val.uint_val)); |
| 241 | + // String -> int: NOT allowed implicitly in PostgreSQL |
| 242 | + return value_null(); |
| 243 | + } |
| 244 | + case Value::TAG_DOUBLE: { |
| 245 | + if (val.tag == Value::TAG_BOOL) return value_double(val.bool_val ? 1.0 : 0.0); |
| 246 | + if (val.tag == Value::TAG_INT64) return value_double(static_cast<double>(val.int_val)); |
| 247 | + if (val.tag == Value::TAG_UINT64) return value_double(static_cast<double>(val.uint_val)); |
| 248 | + // String -> double: NOT allowed implicitly |
| 249 | + return value_null(); |
| 250 | + } |
| 251 | + case Value::TAG_BOOL: { |
| 252 | + // PostgreSQL does not implicitly convert int to bool |
| 253 | + return value_null(); |
| 254 | + } |
| 255 | + default: |
| 256 | + return value_null(); |
| 257 | + } |
| 258 | +} |
| 259 | + |
| 260 | +} // namespace sql_engine |
| 261 | + |
| 262 | +#endif // SQL_ENGINE_COERCION_H |
0 commit comments