Skip to content

Commit 8c4d7de

Browse files
committed
feat: add prepared statement cache with parse_and_cache, execute, and bindings-aware emitter
1 parent 2a9afd0 commit 8c4d7de

7 files changed

Lines changed: 555 additions & 4 deletions

File tree

Makefile.new

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ TEST_SRCS = $(TEST_DIR)/test_main.cpp \
2626
$(TEST_DIR)/test_expression.cpp \
2727
$(TEST_DIR)/test_set.cpp \
2828
$(TEST_DIR)/test_select.cpp \
29-
$(TEST_DIR)/test_emitter.cpp
29+
$(TEST_DIR)/test_emitter.cpp \
30+
$(TEST_DIR)/test_stmt_cache.cpp
3031
TEST_OBJS = $(TEST_SRCS:.cpp=.o)
3132
TEST_TARGET = $(PROJECT_ROOT)/run_tests
3233

include/sql_parser/emitter.h

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
#include "sql_parser/ast.h"
66
#include "sql_parser/arena.h"
77
#include "sql_parser/string_builder.h"
8+
#include "sql_parser/parse_result.h"
9+
#include <cstdio>
810

911
namespace sql_parser {
1012

1113
template <Dialect D>
1214
class Emitter {
1315
public:
14-
explicit Emitter(Arena& arena) : sb_(arena) {}
16+
explicit Emitter(Arena& arena, const ParamBindings* bindings = nullptr)
17+
: sb_(arena), bindings_(bindings), placeholder_index_(0) {}
1518

1619
void emit(const AstNode* node) {
1720
if (!node) return;
@@ -22,6 +25,8 @@ class Emitter {
2225

2326
private:
2427
StringBuilder sb_;
28+
const ParamBindings* bindings_;
29+
uint16_t placeholder_index_;
2530

2631
void emit_node(const AstNode* node) {
2732
switch (node->type) {
@@ -65,13 +70,16 @@ class Emitter {
6570
case NodeType::NODE_CASE_WHEN: emit_case_when(node); break;
6671
case NodeType::NODE_SUBQUERY: emit_value(node); break;
6772

73+
// ---- Leaf nodes (emit value directly) ----
74+
case NodeType::NODE_PLACEHOLDER:
75+
emit_placeholder(node); break;
76+
6877
// ---- Leaf nodes (emit value directly) ----
6978
case NodeType::NODE_LITERAL_INT:
7079
case NodeType::NODE_LITERAL_FLOAT:
7180
case NodeType::NODE_LITERAL_NULL:
7281
case NodeType::NODE_COLUMN_REF:
7382
case NodeType::NODE_ASTERISK:
74-
case NodeType::NODE_PLACEHOLDER:
7583
case NodeType::NODE_IDENTIFIER:
7684
emit_value(node); break;
7785

@@ -93,6 +101,43 @@ class Emitter {
93101
sb_.append_char('\'');
94102
}
95103

104+
void emit_placeholder(const AstNode* node) {
105+
if (bindings_ && placeholder_index_ < bindings_->count) {
106+
const BoundValue& bv = bindings_->values[placeholder_index_];
107+
++placeholder_index_;
108+
switch (bv.type) {
109+
case BoundValue::INT:
110+
{ char buf[32]; int n = snprintf(buf, sizeof(buf), "%lld", (long long)bv.int_val);
111+
sb_.append(buf, n); }
112+
break;
113+
case BoundValue::FLOAT:
114+
{ char buf[64]; int n = snprintf(buf, sizeof(buf), "%g", (double)bv.float32_val);
115+
sb_.append(buf, n); }
116+
break;
117+
case BoundValue::DOUBLE:
118+
{ char buf[64]; int n = snprintf(buf, sizeof(buf), "%g", bv.float64_val);
119+
sb_.append(buf, n); }
120+
break;
121+
case BoundValue::STRING:
122+
case BoundValue::DATETIME:
123+
case BoundValue::DECIMAL:
124+
sb_.append_char('\'');
125+
sb_.append(bv.str_val);
126+
sb_.append_char('\'');
127+
break;
128+
case BoundValue::BLOB:
129+
sb_.append(bv.str_val);
130+
break;
131+
case BoundValue::NULL_VAL:
132+
sb_.append("NULL", 4);
133+
break;
134+
}
135+
} else {
136+
// No binding available -- emit placeholder as-is
137+
emit_value(node);
138+
}
139+
}
140+
96141
// ---- SET ----
97142

98143
void emit_set_stmt(const AstNode* node) {

include/sql_parser/parse_result.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,26 @@ struct ErrorInfo {
1111
StringRef message;
1212
};
1313

14+
struct BoundValue {
15+
enum Type : uint8_t { INT, FLOAT, DOUBLE, STRING, BLOB, NULL_VAL, DATETIME, DECIMAL };
16+
Type type = NULL_VAL;
17+
union {
18+
int64_t int_val;
19+
float float32_val;
20+
double float64_val;
21+
StringRef str_val;
22+
};
23+
24+
BoundValue() : type(NULL_VAL), int_val(0) {}
25+
BoundValue(const BoundValue&) = default;
26+
BoundValue& operator=(const BoundValue&) = default;
27+
};
28+
29+
struct ParamBindings {
30+
BoundValue* values = nullptr;
31+
uint16_t count = 0;
32+
};
33+
1434
struct ParseResult {
1535
enum Status : uint8_t { OK = 0, PARTIAL, ERROR };
1636

@@ -24,6 +44,8 @@ struct ParseResult {
2444
StringRef schema_name;
2545
StringRef database_name;
2646

47+
ParamBindings bindings; // populated by execute()
48+
2749
bool ok() const { return status == OK; }
2850
bool has_remaining() const { return !remaining.empty(); }
2951
};

include/sql_parser/parser.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
#include "sql_parser/tokenizer.h"
77
#include "sql_parser/ast.h"
88
#include "sql_parser/parse_result.h"
9+
#include "sql_parser/stmt_cache.h"
910

1011
namespace sql_parser {
1112

1213
struct ParserConfig {
1314
size_t arena_block_size = 65536; // 64KB
1415
size_t arena_max_size = 1048576; // 1MB
16+
size_t stmt_cache_capacity = 128;
1517
};
1618

1719
template <Dialect D>
@@ -35,9 +37,15 @@ class Parser {
3537
// Access the arena (for emitter use)
3638
Arena& arena() { return arena_; }
3739

40+
// Prepared statement support
41+
ParseResult parse_and_cache(const char* sql, size_t len, uint32_t stmt_id);
42+
ParseResult execute(uint32_t stmt_id, const ParamBindings& params);
43+
void prepare_cache_evict(uint32_t stmt_id);
44+
3845
private:
3946
Arena arena_;
4047
Tokenizer<D> tokenizer_;
48+
StmtCache stmt_cache_;
4149

4250
// Classifier: dispatches to the right extractor/parser
4351
ParseResult classify_and_dispatch();

include/sql_parser/stmt_cache.h

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
#ifndef SQL_PARSER_STMT_CACHE_H
2+
#define SQL_PARSER_STMT_CACHE_H
3+
4+
#include "sql_parser/ast.h"
5+
#include "sql_parser/common.h"
6+
#include "sql_parser/parse_result.h"
7+
#include <cstdlib>
8+
#include <cstring>
9+
#include <unordered_map>
10+
#include <list>
11+
12+
namespace sql_parser {
13+
14+
// Deep-copy an AST tree from arena to heap memory.
15+
// The returned tree must be freed with free_ast().
16+
inline AstNode* deep_copy_ast(const AstNode* src) {
17+
if (!src) return nullptr;
18+
19+
AstNode* dst = static_cast<AstNode*>(std::malloc(sizeof(AstNode)));
20+
if (!dst) return nullptr;
21+
22+
dst->type = src->type;
23+
dst->flags = src->flags;
24+
dst->first_child = nullptr;
25+
dst->next_sibling = nullptr;
26+
27+
// Deep-copy value string to heap
28+
if (src->value_ptr && src->value_len > 0) {
29+
char* val_copy = static_cast<char*>(std::malloc(src->value_len));
30+
if (val_copy) {
31+
std::memcpy(val_copy, src->value_ptr, src->value_len);
32+
}
33+
dst->value_ptr = val_copy;
34+
dst->value_len = src->value_len;
35+
} else {
36+
dst->value_ptr = nullptr;
37+
dst->value_len = 0;
38+
}
39+
40+
// Recursively copy children
41+
const AstNode* src_child = src->first_child;
42+
AstNode* prev_dst_child = nullptr;
43+
while (src_child) {
44+
AstNode* dst_child = deep_copy_ast(src_child);
45+
if (dst_child) {
46+
if (!dst->first_child) {
47+
dst->first_child = dst_child;
48+
} else if (prev_dst_child) {
49+
prev_dst_child->next_sibling = dst_child;
50+
}
51+
prev_dst_child = dst_child;
52+
}
53+
src_child = src_child->next_sibling;
54+
}
55+
56+
return dst;
57+
}
58+
59+
// Free a heap-allocated AST tree (produced by deep_copy_ast).
60+
inline void free_ast(AstNode* node) {
61+
if (!node) return;
62+
// Free children first
63+
AstNode* child = node->first_child;
64+
while (child) {
65+
AstNode* next = child->next_sibling;
66+
free_ast(child);
67+
child = next;
68+
}
69+
// Free value string
70+
if (node->value_ptr) {
71+
std::free(const_cast<char*>(node->value_ptr));
72+
}
73+
std::free(node);
74+
}
75+
76+
// Cached entry for a prepared statement.
77+
struct CachedStmt {
78+
uint32_t stmt_id;
79+
StmtType stmt_type;
80+
AstNode* ast; // heap-allocated deep copy
81+
82+
~CachedStmt() {
83+
free_ast(ast);
84+
}
85+
86+
// Non-copyable
87+
CachedStmt(const CachedStmt&) = delete;
88+
CachedStmt& operator=(const CachedStmt&) = delete;
89+
CachedStmt(CachedStmt&& o) noexcept
90+
: stmt_id(o.stmt_id), stmt_type(o.stmt_type), ast(o.ast) {
91+
o.ast = nullptr;
92+
}
93+
CachedStmt& operator=(CachedStmt&& o) noexcept {
94+
if (this != &o) {
95+
free_ast(ast);
96+
stmt_id = o.stmt_id;
97+
stmt_type = o.stmt_type;
98+
ast = o.ast;
99+
o.ast = nullptr;
100+
}
101+
return *this;
102+
}
103+
104+
CachedStmt() : stmt_id(0), stmt_type(StmtType::UNKNOWN), ast(nullptr) {}
105+
CachedStmt(uint32_t id, StmtType type, AstNode* a)
106+
: stmt_id(id), stmt_type(type), ast(a) {}
107+
};
108+
109+
// Fixed-capacity LRU cache for prepared statements.
110+
class StmtCache {
111+
public:
112+
explicit StmtCache(size_t capacity = 128) : capacity_(capacity) {}
113+
114+
~StmtCache() { clear(); }
115+
116+
// Non-copyable
117+
StmtCache(const StmtCache&) = delete;
118+
StmtCache& operator=(const StmtCache&) = delete;
119+
120+
// Store a prepared statement. Deep-copies the AST from the arena.
121+
// Evicts LRU entry if at capacity.
122+
bool store(uint32_t stmt_id, StmtType stmt_type, const AstNode* ast) {
123+
// If already exists, remove old entry
124+
evict(stmt_id);
125+
126+
AstNode* copy = deep_copy_ast(ast);
127+
if (!copy && ast) return false;
128+
129+
// Evict LRU if at capacity
130+
if (lru_.size() >= capacity_) {
131+
auto& oldest = lru_.back();
132+
map_.erase(oldest.stmt_id);
133+
lru_.pop_back();
134+
}
135+
136+
lru_.emplace_front(stmt_id, stmt_type, copy);
137+
map_[stmt_id] = lru_.begin();
138+
return true;
139+
}
140+
141+
// Look up a cached statement. Returns nullptr if not found.
142+
// Moves the entry to front of LRU.
143+
const CachedStmt* lookup(uint32_t stmt_id) {
144+
auto it = map_.find(stmt_id);
145+
if (it == map_.end()) return nullptr;
146+
// Move to front (most recently used)
147+
lru_.splice(lru_.begin(), lru_, it->second);
148+
return &(*it->second);
149+
}
150+
151+
// Evict a specific statement.
152+
void evict(uint32_t stmt_id) {
153+
auto it = map_.find(stmt_id);
154+
if (it != map_.end()) {
155+
lru_.erase(it->second);
156+
map_.erase(it);
157+
}
158+
}
159+
160+
// Clear all entries.
161+
void clear() {
162+
lru_.clear();
163+
map_.clear();
164+
}
165+
166+
size_t size() const { return map_.size(); }
167+
size_t capacity() const { return capacity_; }
168+
169+
private:
170+
size_t capacity_;
171+
std::list<CachedStmt> lru_;
172+
std::unordered_map<uint32_t, std::list<CachedStmt>::iterator> map_;
173+
};
174+
175+
} // namespace sql_parser
176+
177+
#endif // SQL_PARSER_STMT_CACHE_H

src/sql_parser/parser.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ namespace sql_parser {
77

88
template <Dialect D>
99
Parser<D>::Parser(const ParserConfig& config)
10-
: arena_(config.arena_block_size, config.arena_max_size) {}
10+
: arena_(config.arena_block_size, config.arena_max_size),
11+
stmt_cache_(config.stmt_cache_capacity) {}
1112

1213
template <Dialect D>
1314
void Parser<D>::reset() {
@@ -402,6 +403,38 @@ ParseResult Parser<D>::extract_unknown(const Token& /* first */) {
402403
return r;
403404
}
404405

406+
// ---- Prepared statement support ----
407+
408+
template <Dialect D>
409+
ParseResult Parser<D>::parse_and_cache(const char* sql, size_t len, uint32_t stmt_id) {
410+
ParseResult r = parse(sql, len);
411+
if (r.ast) {
412+
stmt_cache_.store(stmt_id, r.stmt_type, r.ast);
413+
}
414+
return r;
415+
}
416+
417+
template <Dialect D>
418+
ParseResult Parser<D>::execute(uint32_t stmt_id, const ParamBindings& params) {
419+
ParseResult r;
420+
const CachedStmt* cached = stmt_cache_.lookup(stmt_id);
421+
if (!cached) {
422+
r.status = ParseResult::ERROR;
423+
r.stmt_type = StmtType::UNKNOWN;
424+
return r;
425+
}
426+
r.status = ParseResult::OK;
427+
r.stmt_type = cached->stmt_type;
428+
r.ast = cached->ast;
429+
r.bindings = params;
430+
return r;
431+
}
432+
433+
template <Dialect D>
434+
void Parser<D>::prepare_cache_evict(uint32_t stmt_id) {
435+
stmt_cache_.evict(stmt_id);
436+
}
437+
405438
// ---- Explicit template instantiations ----
406439

407440
template class Parser<Dialect::MySQL>;

0 commit comments

Comments
 (0)