Skip to content

Commit df21f0f

Browse files
committed
Add SQL engine type system: NULL semantics, coercion, functions, registry
Tasks 2-8 of the type system plan: - NULL semantics: three-valued AND/OR/NOT, propagate_null - Type coercion: MySQL (permissive) and PostgreSQL (strict) via template specialization - Function registry: dialect-templated FunctionRegistry with case-insensitive lookup - Arithmetic functions: ABS, CEIL, FLOOR, ROUND, TRUNCATE, MOD, POWER, SQRT, SIGN - Comparison functions: COALESCE, NULLIF, IFNULL, IF, LEAST, GREATEST - String functions: CONCAT, CONCAT_WS, LENGTH, UPPER, LOWER, SUBSTRING, TRIM, REPLACE, etc. - Cast functions: cast_value<D>() with MySQL lenient and PostgreSQL strict parsing - register_builtins() wiring for both dialects (38 MySQL, 36 PostgreSQL functions) 650 tests pass (454 parser + 196 type system), zero warnings.
1 parent 2085b17 commit df21f0f

16 files changed

Lines changed: 2262 additions & 3 deletions

Makefile.new

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ CPPFLAGS = -I./include -I./third_party/googletest/googletest/include
44

55
PROJECT_ROOT = .
66
SRC_DIR = $(PROJECT_ROOT)/src/sql_parser
7+
ENGINE_SRC_DIR = $(PROJECT_ROOT)/src/sql_engine
78
INCLUDE_DIR = $(PROJECT_ROOT)/include/sql_parser
89
TEST_DIR = $(PROJECT_ROOT)/tests
910

@@ -12,6 +13,10 @@ LIB_SRCS = $(SRC_DIR)/arena.cpp $(SRC_DIR)/parser.cpp
1213
LIB_OBJS = $(LIB_SRCS:.cpp=.o)
1314
LIB_TARGET = $(PROJECT_ROOT)/libsqlparser.a
1415

16+
# SQL Engine sources
17+
ENGINE_SRCS = $(ENGINE_SRC_DIR)/function_registry.cpp
18+
ENGINE_OBJS = $(ENGINE_SRCS:.cpp=.o)
19+
1520
# Google Test library
1621
GTEST_DIR = $(PROJECT_ROOT)/third_party/googletest/googletest
1722
GTEST_SRC = $(GTEST_DIR)/src/gtest-all.cc
@@ -34,7 +39,14 @@ TEST_SRCS = $(TEST_DIR)/test_main.cpp \
3439
$(TEST_DIR)/test_compound.cpp \
3540
$(TEST_DIR)/test_digest.cpp \
3641
$(TEST_DIR)/test_misc_stmts.cpp \
37-
$(TEST_DIR)/test_value.cpp
42+
$(TEST_DIR)/test_value.cpp \
43+
$(TEST_DIR)/test_null_semantics.cpp \
44+
$(TEST_DIR)/test_coercion.cpp \
45+
$(TEST_DIR)/test_arithmetic.cpp \
46+
$(TEST_DIR)/test_comparison.cpp \
47+
$(TEST_DIR)/test_string_funcs.cpp \
48+
$(TEST_DIR)/test_cast.cpp \
49+
$(TEST_DIR)/test_registry.cpp
3850
TEST_OBJS = $(TEST_SRCS:.cpp=.o)
3951
TEST_TARGET = $(PROJECT_ROOT)/run_tests
4052

@@ -68,6 +80,10 @@ $(LIB_TARGET): $(LIB_OBJS)
6880
$(SRC_DIR)/%.o: $(SRC_DIR)/%.cpp
6981
$(CXX) $(CXXFLAGS) $(CPPFLAGS) -c $< -o $@
7082

83+
# SQL Engine objects
84+
$(ENGINE_SRC_DIR)/%.o: $(ENGINE_SRC_DIR)/%.cpp
85+
$(CXX) $(CXXFLAGS) $(CPPFLAGS) -c $< -o $@
86+
7187
# Google Test object
7288
$(GTEST_OBJ): $(GTEST_SRC)
7389
$(CXX) $(CXXFLAGS) $(GTEST_CPPFLAGS) -c $< -o $@
@@ -79,8 +95,8 @@ $(TEST_DIR)/%.o: $(TEST_DIR)/%.cpp
7995
test: $(TEST_TARGET)
8096
./$(TEST_TARGET)
8197

82-
$(TEST_TARGET): $(TEST_OBJS) $(GTEST_OBJ) $(LIB_TARGET)
83-
$(CXX) $(CXXFLAGS) -o $@ $(TEST_OBJS) $(GTEST_OBJ) -L$(PROJECT_ROOT) -lsqlparser -lpthread
98+
$(TEST_TARGET): $(TEST_OBJS) $(GTEST_OBJ) $(LIB_TARGET) $(ENGINE_OBJS)
99+
$(CXX) $(CXXFLAGS) -o $@ $(TEST_OBJS) $(GTEST_OBJ) $(ENGINE_OBJS) -L$(PROJECT_ROOT) -lsqlparser -lpthread
84100

85101
# Benchmark objects
86102
$(GBENCH_DIR)/src/%.o: $(GBENCH_DIR)/src/%.cc
@@ -118,6 +134,7 @@ $(BENCH_COMPARE_TARGET): $(BENCH_DIR)/bench_main.o $(BENCH_COMPARE_OBJ) $(GBENCH
118134

119135
clean:
120136
rm -f $(LIB_OBJS) $(LIB_TARGET) $(TEST_OBJS) $(GTEST_OBJ) $(TEST_TARGET)
137+
rm -f $(ENGINE_OBJS)
121138
rm -f $(BENCH_OBJS) $(GBENCH_OBJS) $(BENCH_TARGET) $(CORPUS_TEST_TARGET)
122139
rm -f $(BENCH_COMPARE_OBJ) $(BENCH_COMPARE_TARGET)
123140
@echo "Cleaned."

include/sql_engine/coercion.h

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
#ifndef SQL_ENGINE_COERCION_H
2+
#define SQL_ENGINE_COERCION_H
3+
4+
#include "sql_engine/types.h"
5+
#include "sql_engine/value.h"
6+
#include "sql_parser/common.h" // Dialect, StringRef
7+
#include "sql_parser/arena.h" // Arena
8+
#include <cstdlib>
9+
#include <cstdio>
10+
#include <cstring>
11+
12+
namespace sql_engine {
13+
14+
using sql_parser::Dialect;
15+
using sql_parser::Arena;
16+
17+
template <Dialect D>
18+
struct CoercionRules {
19+
// Can `from` be implicitly coerced to `to`?
20+
static bool can_coerce(SqlType::Kind from, SqlType::Kind to);
21+
22+
// What common type should two operands be promoted to?
23+
static SqlType::Kind common_type(SqlType::Kind left, SqlType::Kind right);
24+
25+
// Perform the coercion. Returns new Value with target tag, or NULL on failure.
26+
static Value coerce_value(const Value& val, Value::Tag target, Arena& arena);
27+
};
28+
29+
// ----- Helper: parse int from string (MySQL: lenient, stops at first non-digit) -----
30+
namespace detail {
31+
32+
inline bool parse_int_lenient(const char* s, uint32_t len, int64_t& out) {
33+
if (len == 0) { out = 0; return true; }
34+
char* end = nullptr;
35+
char buf[64];
36+
uint32_t n = len < 63 ? len : 63;
37+
std::memcpy(buf, s, n);
38+
buf[n] = '\0';
39+
out = std::strtoll(buf, &end, 10);
40+
return end != buf; // at least one char consumed
41+
}
42+
43+
inline bool parse_int_strict(const char* s, uint32_t len, int64_t& out) {
44+
if (len == 0) return false;
45+
char* end = nullptr;
46+
char buf[64];
47+
uint32_t n = len < 63 ? len : 63;
48+
std::memcpy(buf, s, n);
49+
buf[n] = '\0';
50+
out = std::strtoll(buf, &end, 10);
51+
// Strict: entire string must be consumed (ignoring trailing whitespace)
52+
while (end && *end == ' ') ++end;
53+
return end == buf + n;
54+
}
55+
56+
inline bool parse_double_lenient(const char* s, uint32_t len, double& out) {
57+
if (len == 0) { out = 0.0; return true; }
58+
char buf[128];
59+
uint32_t n = len < 127 ? len : 127;
60+
std::memcpy(buf, s, n);
61+
buf[n] = '\0';
62+
char* end = nullptr;
63+
out = std::strtod(buf, &end);
64+
return end != buf;
65+
}
66+
67+
inline StringRef int_to_string(int64_t v, Arena& arena) {
68+
char buf[32];
69+
int n = std::snprintf(buf, sizeof(buf), "%lld", (long long)v);
70+
return arena.allocate_string(buf, static_cast<uint32_t>(n));
71+
}
72+
73+
inline StringRef double_to_string(double v, Arena& arena) {
74+
char buf[64];
75+
int n = std::snprintf(buf, sizeof(buf), "%g", v);
76+
return arena.allocate_string(buf, static_cast<uint32_t>(n));
77+
}
78+
79+
} // namespace detail
80+
81+
// ----- MySQL specialization (permissive) -----
82+
83+
template <>
84+
inline bool CoercionRules<Dialect::MySQL>::can_coerce(SqlType::Kind from, SqlType::Kind to) {
85+
if (from == to) return true;
86+
if (from == SqlType::NULL_TYPE) return true; // NULL coerces to anything
87+
88+
// Numeric within-category: always
89+
SqlType probe_from{from}; SqlType probe_to{to};
90+
if (probe_from.is_numeric() && probe_to.is_numeric()) return true;
91+
92+
// String <-> Numeric: MySQL allows
93+
if (probe_from.is_string() && probe_to.is_numeric()) return true;
94+
if (probe_from.is_numeric() && probe_to.is_string()) return true;
95+
96+
// String <-> Temporal: MySQL allows (parse attempt)
97+
if (probe_from.is_string() && probe_to.is_temporal()) return true;
98+
if (probe_from.is_temporal() && probe_to.is_string()) return true;
99+
100+
// Temporal within-category: allowed
101+
if (probe_from.is_temporal() && probe_to.is_temporal()) return true;
102+
103+
return false;
104+
}
105+
106+
template <>
107+
inline SqlType::Kind CoercionRules<Dialect::MySQL>::common_type(SqlType::Kind left, SqlType::Kind right) {
108+
if (left == right) return left;
109+
110+
SqlType pl{left}; SqlType pr{right};
111+
112+
// Both numeric: promote up the hierarchy
113+
if (pl.is_numeric() && pr.is_numeric()) {
114+
// BOOL < TINYINT < ... < BIGINT < FLOAT < DOUBLE < DECIMAL
115+
// Use the wider kind
116+
return left > right ? left : right;
117+
}
118+
119+
// One string, one numeric: MySQL promotes to numeric (DOUBLE)
120+
if ((pl.is_numeric() && pr.is_string()) || (pl.is_string() && pr.is_numeric())) {
121+
return SqlType::DOUBLE;
122+
}
123+
124+
// Both temporal: promote to wider
125+
if (pl.is_temporal() && pr.is_temporal()) {
126+
return left > right ? left : right;
127+
}
128+
129+
// Fallback: STRING
130+
return SqlType::VARCHAR;
131+
}
132+
133+
template <>
134+
inline Value CoercionRules<Dialect::MySQL>::coerce_value(const Value& val, Value::Tag target, Arena& arena) {
135+
if (val.tag == target) return val;
136+
if (val.is_null()) return value_null();
137+
138+
switch (target) {
139+
case Value::TAG_INT64: {
140+
if (val.tag == Value::TAG_BOOL) return value_int(val.bool_val ? 1 : 0);
141+
if (val.tag == Value::TAG_UINT64) return value_int(static_cast<int64_t>(val.uint_val));
142+
if (val.tag == Value::TAG_DOUBLE) return value_int(static_cast<int64_t>(val.double_val));
143+
if (val.tag == Value::TAG_STRING || val.tag == Value::TAG_DECIMAL) {
144+
int64_t out;
145+
if (detail::parse_int_lenient(val.str_val.ptr, val.str_val.len, out))
146+
return value_int(out);
147+
}
148+
return value_null();
149+
}
150+
case Value::TAG_DOUBLE: {
151+
if (val.tag == Value::TAG_BOOL) return value_double(val.bool_val ? 1.0 : 0.0);
152+
if (val.tag == Value::TAG_INT64) return value_double(static_cast<double>(val.int_val));
153+
if (val.tag == Value::TAG_UINT64) return value_double(static_cast<double>(val.uint_val));
154+
if (val.tag == Value::TAG_STRING || val.tag == Value::TAG_DECIMAL) {
155+
double out;
156+
if (detail::parse_double_lenient(val.str_val.ptr, val.str_val.len, out))
157+
return value_double(out);
158+
}
159+
return value_null();
160+
}
161+
case Value::TAG_STRING: {
162+
if (val.tag == Value::TAG_INT64)
163+
return value_string(detail::int_to_string(val.int_val, arena));
164+
if (val.tag == Value::TAG_UINT64) {
165+
char buf[32];
166+
int n = std::snprintf(buf, sizeof(buf), "%llu", (unsigned long long)val.uint_val);
167+
return value_string(arena.allocate_string(buf, static_cast<uint32_t>(n)));
168+
}
169+
if (val.tag == Value::TAG_DOUBLE)
170+
return value_string(detail::double_to_string(val.double_val, arena));
171+
if (val.tag == Value::TAG_BOOL)
172+
return value_string(val.bool_val
173+
? arena.allocate_string("1", 1)
174+
: arena.allocate_string("0", 1));
175+
return value_null();
176+
}
177+
case Value::TAG_BOOL: {
178+
if (val.tag == Value::TAG_INT64) return value_bool(val.int_val != 0);
179+
if (val.tag == Value::TAG_DOUBLE) return value_bool(val.double_val != 0.0);
180+
return value_null();
181+
}
182+
default:
183+
return value_null();
184+
}
185+
}
186+
187+
// ----- PostgreSQL specialization (strict) -----
188+
189+
template <>
190+
inline bool CoercionRules<Dialect::PostgreSQL>::can_coerce(SqlType::Kind from, SqlType::Kind to) {
191+
if (from == to) return true;
192+
if (from == SqlType::NULL_TYPE) return true;
193+
194+
SqlType probe_from{from}; SqlType probe_to{to};
195+
196+
// Within-category numeric promotions only
197+
if (probe_from.is_numeric() && probe_to.is_numeric()) {
198+
// Only allow widening: kind value must increase (narrower to wider)
199+
return to > from;
200+
}
201+
202+
// Temporal within-category promotions
203+
if (probe_from.is_temporal() && probe_to.is_temporal()) {
204+
return to > from;
205+
}
206+
207+
// No cross-category implicit coercion in PostgreSQL
208+
return false;
209+
}
210+
211+
template <>
212+
inline SqlType::Kind CoercionRules<Dialect::PostgreSQL>::common_type(SqlType::Kind left, SqlType::Kind right) {
213+
if (left == right) return left;
214+
215+
SqlType pl{left}; SqlType pr{right};
216+
217+
// Both numeric: promote to wider
218+
if (pl.is_numeric() && pr.is_numeric()) {
219+
return left > right ? left : right;
220+
}
221+
222+
// Both temporal: promote to wider
223+
if (pl.is_temporal() && pr.is_temporal()) {
224+
return left > right ? left : right;
225+
}
226+
227+
// Cross-category: return UNKNOWN (error)
228+
return SqlType::UNKNOWN;
229+
}
230+
231+
template <>
232+
inline Value CoercionRules<Dialect::PostgreSQL>::coerce_value(const Value& val, Value::Tag target, Arena& /*arena*/) {
233+
if (val.tag == target) return val;
234+
if (val.is_null()) return value_null();
235+
236+
// PostgreSQL: only within-category promotions
237+
switch (target) {
238+
case Value::TAG_INT64: {
239+
if (val.tag == Value::TAG_BOOL) return value_int(val.bool_val ? 1 : 0);
240+
if (val.tag == Value::TAG_UINT64) return value_int(static_cast<int64_t>(val.uint_val));
241+
// String -> int: NOT allowed implicitly in PostgreSQL
242+
return value_null();
243+
}
244+
case Value::TAG_DOUBLE: {
245+
if (val.tag == Value::TAG_BOOL) return value_double(val.bool_val ? 1.0 : 0.0);
246+
if (val.tag == Value::TAG_INT64) return value_double(static_cast<double>(val.int_val));
247+
if (val.tag == Value::TAG_UINT64) return value_double(static_cast<double>(val.uint_val));
248+
// String -> double: NOT allowed implicitly
249+
return value_null();
250+
}
251+
case Value::TAG_BOOL: {
252+
// PostgreSQL does not implicitly convert int to bool
253+
return value_null();
254+
}
255+
default:
256+
return value_null();
257+
}
258+
}
259+
260+
} // namespace sql_engine
261+
262+
#endif // SQL_ENGINE_COERCION_H
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#ifndef SQL_ENGINE_FUNCTION_REGISTRY_H
2+
#define SQL_ENGINE_FUNCTION_REGISTRY_H
3+
4+
#include "sql_engine/value.h"
5+
#include "sql_parser/common.h"
6+
#include "sql_parser/arena.h"
7+
#include <cstdint>
8+
#include <cstring>
9+
10+
namespace sql_engine {
11+
12+
using sql_parser::Dialect;
13+
using sql_parser::Arena;
14+
15+
// Function signature: takes array of args, count, arena for allocations.
16+
using SqlFunction = Value(*)(const Value* args, uint16_t arg_count, Arena& arena);
17+
18+
struct FunctionEntry {
19+
const char* name; // uppercased canonical name
20+
uint32_t name_len;
21+
SqlFunction impl;
22+
uint8_t min_args;
23+
uint8_t max_args; // 255 = variadic
24+
};
25+
26+
template <Dialect D>
27+
class FunctionRegistry {
28+
public:
29+
static constexpr uint32_t MAX_FUNCTIONS = 256;
30+
31+
void register_function(const FunctionEntry& entry) {
32+
if (count_ < MAX_FUNCTIONS) {
33+
entries_[count_++] = entry;
34+
}
35+
}
36+
37+
const FunctionEntry* lookup(const char* name, uint32_t name_len) const {
38+
for (uint32_t i = 0; i < count_; ++i) {
39+
if (entries_[i].name_len == name_len &&
40+
ci_compare(entries_[i].name, name, name_len) == 0) {
41+
return &entries_[i];
42+
}
43+
}
44+
return nullptr;
45+
}
46+
47+
// Register all built-in functions for this dialect.
48+
// Implemented in function_registry.cpp.
49+
void register_builtins();
50+
51+
uint32_t size() const { return count_; }
52+
53+
private:
54+
FunctionEntry entries_[MAX_FUNCTIONS] = {};
55+
uint32_t count_ = 0;
56+
57+
static int ci_compare(const char* a, const char* b, uint32_t len) {
58+
for (uint32_t i = 0; i < len; ++i) {
59+
char ca = a[i]; if (ca >= 'a' && ca <= 'z') ca -= 32;
60+
char cb = b[i]; if (cb >= 'a' && cb <= 'z') cb -= 32;
61+
if (ca != cb) return ca < cb ? -1 : 1;
62+
}
63+
return 0;
64+
}
65+
};
66+
67+
} // namespace sql_engine
68+
69+
#endif // SQL_ENGINE_FUNCTION_REGISTRY_H

0 commit comments

Comments
 (0)