diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..9c3ab4f --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required (VERSION 2.8.11) +project (dbccpp) + +set( CMAKE_VERBOSE_MAKEFILE on ) + +file(GLOB SOURCES_DBCCPP include/dbccpp/*.h src/*.cpp src/*.h src/mysql/*.cpp src/mysql/*.h src/sqlite/*.cpp src/sqlite/*.h ) +file(GLOB SOURCES_TEST test/src/*.cpp test/testcpp/src/*.cpp ) + +include_directories(include test/testcpp/include ) + +link_directories(lib) + +add_library (libdbccpp STATIC ${SOURCES_DBCCPP}) + +add_executable(dbccpp-test ${SOURCES_TEST} ) + +target_link_libraries(dbccpp-test libdbccpp.a) + diff --git a/Makefile b/Makefile index 26d64e1..63e23c0 100644 --- a/Makefile +++ b/Makefile @@ -7,11 +7,11 @@ TARGET = lib/lib$(LIBNAME).a # For building with clang++ 3.1 in Ubuntu 12.04, install system clang and # add -I/usr/include/clang/3.0/include to compile flags -OPTIMIZE = -O2 # -g -std=c++0x | -std=c++11 +OPTIMIZE = -g -std=c++0x COMPILER = clang++ # g++ CXX = $(COMPILER) -CXXFLAGS = -pipe $(OPTIMIZE) -fPIC -Wall -Wextra -Werror -D_REENTRANT +CXXFLAGS = -pipe $(OPTIMIZE) -fPIC -Wall -Wextra -D_REENTRANT INCPATH = -Iinclude TEST = dbccpp-test @@ -21,7 +21,7 @@ TESTINCPATH = $(INCPATH) -I$(TESTCPPDIR)/include LINK = $(COMPILER) LFLAGS = -Wl,-O1 -LIBS = -Llib -l$(LIBNAME) -L$(TESTCPPDIR)/lib -ltestcpp -lsqlite3 +LIBS = -Llib -l$(LIBNAME) -L$(TESTCPPDIR)/lib -ltestcpp -lsqlite3 -lmysqlclient AR = ar cqs @@ -30,7 +30,8 @@ DEP = Makefile.dep # Generic source file lists SRC = $(wildcard src/*.cpp) \ - $(wildcard src/sqlite/*.cpp) + $(wildcard src/sqlite/*.cpp) \ + $(wildcard src/mysql/*.cpp) OBJS = $(patsubst src/%.cpp, obj/%.o, $(SRC)) @@ -41,6 +42,7 @@ TESTOBJS = $(patsubst test/src/%.cpp, test/obj/%.o, $(TESTSRC)) obj/%.o: src/%.cpp mkdir -p obj/sqlite + mkdir -p obj/mysql $(CXX) -c $(CXXFLAGS) $(INCPATH) -o $@ $< $(TARGET): $(OBJS) diff --git a/include/dbccpp/DbConnection.h b/include/dbccpp/DbConnection.h index d5842f2..f76534e 100644 --- a/include/dbccpp/DbConnection.h +++ b/include/dbccpp/DbConnection.h @@ -16,9 +16,8 @@ namespace dbc */ class DbConnection { - UTILCPP_DECLARE_INTERFACE(DbConnection) - public: + virtual ~DbConnection() {}; /** Creates the singleton instance of the driver-specific database * connection. * @@ -26,6 +25,7 @@ class DbConnection * @param params Parameters for the driver, e.g. file name for SQLite. */ static void connect(const std::string& driver, const std::string& params); + static void disconnect(); /** Access the singleton instance of the database connection. * The connection instance needs to be created with connect() before this @@ -37,7 +37,13 @@ class DbConnection virtual const CountProxy& executeUpdate(const std::string& sql) = 0; virtual ResultSet::ptr executeQuery(const std::string& sql) = 0; +protected: + DbConnection() {} private: + + DbConnection(const DbConnection&); + DbConnection& operator=(const DbConnection&); + static std::string _driver; static std::string _params; }; diff --git a/include/dbccpp/PreparedStatement.h b/include/dbccpp/PreparedStatement.h index 2b81d3a..fcd26f5 100644 --- a/include/dbccpp/PreparedStatement.h +++ b/include/dbccpp/PreparedStatement.h @@ -191,7 +191,7 @@ class PreparedStatement */ // FIXME: this should be a 64-bit type really // FIXME: information is SQLite-specific - virtual int getLastInsertId() = 0; + virtual u_int64_t getLastInsertId() = 0; /** Get the underlying SQL statement. */ virtual const char* getSQL() const = 0; diff --git a/src/DbConnection.cpp b/src/DbConnection.cpp index d57ddbe..90da965 100644 --- a/src/DbConnection.cpp +++ b/src/DbConnection.cpp @@ -4,45 +4,49 @@ #include "DbConnectionFactory.h" #if !(defined(__GXX_EXPERIMENTAL_CXX0X__) || (__cplusplus > 199711L)) - #include +#include #endif -namespace dbc -{ - -std::string DbConnection::_driver; -std::string DbConnection::_params; +namespace dbc { #if defined(__GXX_EXPERIMENTAL_CXX0X__) || (__cplusplus > 199711L) - typedef std::unique_ptr dbconnection_scoped_ptr; +typedef std::unique_ptr dbconnection_scoped_ptr; #else - typedef utilcpp::scoped_ptr dbconnection_scoped_ptr; +typedef utilcpp::scoped_ptr dbconnection_scoped_ptr; #endif -void DbConnection::connect(const std::string& driver, const std::string& params) -{ - if (!_driver.empty()) - throw DbErrorBase("Already connected, disconnect() has to be called before reconnect"); +std::string DbConnection::_driver; +std::string DbConnection::_params; - _driver = driver; - _params = params; -} +static dbconnection_scoped_ptr instanceObj; + +void DbConnection::connect(const std::string &driver, + const std::string ¶ms) { + if (!_driver.empty()) + throw DbErrorBase( + "Already connected, disconnect() has to be called before reconnect"); -DbConnection& DbConnection::instance() -{ - static dbconnection_scoped_ptr instance; + _driver = driver; + _params = params; +} - if (!instance) - { - if (_driver.empty()) - throw DbErrorBase("connect() has to be called before instance()"); +DbConnection &DbConnection::instance() { + if (!instanceObj) { + if (_driver.empty()) + throw DbErrorBase("connect() has to be called before instance()"); - instance = DbConnectionFactory::instance().createDbConnection(_driver, _params); - if (!instance) - throw DbErrorBase("Null instance returned from driver factory"); - } + instanceObj = + DbConnectionFactory::instance().createDbConnection(_driver, _params); + if (!instanceObj) + throw DbErrorBase("Null instance returned from driver factory"); + } - return *instance; + return *instanceObj; } +void DbConnection::disconnect() { + instanceObj.reset(); + _driver = ""; + _params = ""; +} } diff --git a/src/DbConnectionFactory.cpp b/src/DbConnectionFactory.cpp index 8099637..99c6f9b 100644 --- a/src/DbConnectionFactory.cpp +++ b/src/DbConnectionFactory.cpp @@ -1,36 +1,37 @@ #include "DbConnectionFactory.h" -#include +#include "mysql/MySQLConnection.h" #include "sqlite/SQLiteConnection.h" +#include -namespace dbc -{ +namespace dbc { -dbconnection_transferable_ptr createSQLiteConnection(const std::string& params) -{ - return dbconnection_transferable_ptr(new SQLiteConnection(params)); +dbconnection_transferable_ptr +createSQLiteConnection(const std::string ¶ms) { + return dbconnection_transferable_ptr(new SQLiteConnection(params)); } -DbConnectionFactory::DbConnectionFactory() : - _callbacks_registry() -{ - // see doc/README-factory.rst why this unneccessary coupling is needed - _callbacks_registry["sqlite"] = createSQLiteConnection; +dbconnection_transferable_ptr createMySQLConnection(const std::string ¶ms) { + return dbconnection_transferable_ptr(new MySQLConnection(params)); } -dbconnection_transferable_ptr DbConnectionFactory::createDbConnection(const std::string& driverName, - const std::string& params) -{ - CallbackMap::const_iterator it = _callbacks_registry.find(driverName); - if (it == _callbacks_registry.end()) - throw DbErrorBase(driverName + ": database driver not found"); - - return (it->second)(params); +DbConnectionFactory::DbConnectionFactory() : _callbacks_registry() { + // see doc/README-factory.rst why this unneccessary coupling is needed + _callbacks_registry["sqlite"] = createSQLiteConnection; + _callbacks_registry["mysql"] = createMySQLConnection; } -DbConnectionFactory& DbConnectionFactory::instance() -{ - static DbConnectionFactory factory; - return factory; +dbconnection_transferable_ptr +DbConnectionFactory::createDbConnection(const std::string &driverName, + const std::string ¶ms) { + CallbackMap::const_iterator it = _callbacks_registry.find(driverName); + if (it == _callbacks_registry.end()) + throw DbErrorBase(driverName + ": database driver not found"); + + return (it->second)(params); } +DbConnectionFactory &DbConnectionFactory::instance() { + static DbConnectionFactory factory; + return factory; +} } diff --git a/src/DbConnectionFactory.h b/src/DbConnectionFactory.h index b61eb75..84f7c79 100644 --- a/src/DbConnectionFactory.h +++ b/src/DbConnectionFactory.h @@ -6,35 +6,33 @@ #include #include -namespace dbc -{ +namespace dbc { #if defined(__GXX_EXPERIMENTAL_CXX0X__) || (__cplusplus > 199711L) - typedef std::unique_ptr dbconnection_transferable_ptr; +typedef std::unique_ptr dbconnection_transferable_ptr; #else - typedef std::auto_ptr dbconnection_transferable_ptr; +typedef std::auto_ptr dbconnection_transferable_ptr; #endif // Abstract factory that creates database-specific DbFactories. // Based on "Modern C++ Design" scalable factory idiom. -class DbConnectionFactory -{ - UTILCPP_DECLARE_SINGLETON(DbConnectionFactory) +class DbConnectionFactory { + UTILCPP_DECLARE_SINGLETON(DbConnectionFactory) public: - typedef dbconnection_transferable_ptr (*CreateDbConnectionCallback)(const std::string&); - typedef std::map CallbackMap; + typedef dbconnection_transferable_ptr (*CreateDbConnectionCallback)( + const std::string &); + typedef std::map CallbackMap; - bool registerDbConnectionCreator(const std::string& driverName, - CreateDbConnectionCallback creator); + bool registerDbConnectionCreator(const std::string &driverName, + CreateDbConnectionCallback creator); - dbconnection_transferable_ptr createDbConnection(const std::string& driverName, - const std::string& params); + dbconnection_transferable_ptr + createDbConnection(const std::string &driverName, const std::string ¶ms); private: - CallbackMap _callbacks_registry; + CallbackMap _callbacks_registry; }; - } #endif /* DBCONNECTIONFACTORY_H */ diff --git a/src/PreparedStatement.cpp b/src/PreparedStatement.cpp index 523f856..fd6ef16 100644 --- a/src/PreparedStatement.cpp +++ b/src/PreparedStatement.cpp @@ -1,27 +1,31 @@ #include -namespace dbc -{ +namespace dbc { -template<> -void PreparedStatement::set(const int parameterIndex, const int val) -{ setInt(parameterIndex, val); } +template <> +void PreparedStatement::set(const int parameterIndex, const int val) { + setInt(parameterIndex, val); +} -template<> -void PreparedStatement::set(const int parameterIndex, const bool val) -{ setBool(parameterIndex, val); } +template <> +void PreparedStatement::set(const int parameterIndex, const bool val) { + setBool(parameterIndex, val); +} -template<> -void PreparedStatement::set(const int parameterIndex, const double val) -{ setDouble(parameterIndex, val); } +template <> +void PreparedStatement::set(const int parameterIndex, + const double val) { + setDouble(parameterIndex, val); +} -template<> -void PreparedStatement::set(const int parameterIndex, const char* val) -{ setString(parameterIndex, val); } +template <> +void PreparedStatement::set(const int parameterIndex, const char *val) { + setString(parameterIndex, val); +} -template<> +template <> void PreparedStatement::set(const int parameterIndex, - const std::string& val) -{ setString(parameterIndex, val); } - + const std::string &val) { + setString(parameterIndex, val); +} } diff --git a/src/ResultSet.cpp b/src/ResultSet.cpp index 837db50..e212139 100644 --- a/src/ResultSet.cpp +++ b/src/ResultSet.cpp @@ -1,42 +1,29 @@ #include #include -namespace dbc -{ +namespace dbc { -template <> -int ResultSet::get(const int columnIndex) const -{ - return getInt(columnIndex); +template <> int ResultSet::get(const int columnIndex) const { + return getInt(columnIndex); } -template <> -double ResultSet::get(const int columnIndex) const -{ - return getDouble(columnIndex); +template <> double ResultSet::get(const int columnIndex) const { + return getDouble(columnIndex); } -template <> -bool ResultSet::get(const int columnIndex) const -{ - return getBool(columnIndex); +template <> bool ResultSet::get(const int columnIndex) const { + return getBool(columnIndex); } -template <> -std::string ResultSet::get(const int columnIndex) const -{ - return getString(columnIndex); +template <> std::string ResultSet::get(const int columnIndex) const { + return getString(columnIndex); } -template <> -void ResultSet::get(const int columnIndex, std::string& out) const -{ - return getString(columnIndex, out); +template <> void ResultSet::get(const int columnIndex, std::string &out) const { + return getString(columnIndex, out); } -const SubscriptProxy ResultSet::operator[] (const int index) const -{ - return SubscriptProxy(*this, index); +const SubscriptProxy ResultSet::operator[](const int index) const { + return SubscriptProxy(*this, index); } - } diff --git a/src/mysql/MySQLConnection.cpp b/src/mysql/MySQLConnection.cpp new file mode 100644 index 0000000..ee62296 --- /dev/null +++ b/src/mysql/MySQLConnection.cpp @@ -0,0 +1,53 @@ +#include "MySQLConnection.h" +#include "../DbConnectionFactory.h" +#include "MySQLCountProxy.h" +#include "MySQLExceptions.h" +#include "MySQLPreparedStatement.h" + +void finalize_mysql(MYSQL *db) { + // As destructors cannot throw, we cannot handle the case when ret != OK + if (db) + mysql_close(db); +} + +namespace dbc { + +MySQLConnection::MySQLConnection(const std::string ¶ms) + : _params(params), _db() { + MYSQL *db = mysql_init(NULL); + MYSQL *ret = + mysql_real_connect(db, "localhost", "erm", "erm", "erm", 0, NULL, 0); + + // Whether or not an error occurs when it is opened, resources associated + // with the database connection handle should be released by passing it to + // sqlite3_close() when it is no longer required. + _db.reset(db); + + if (ret == NULL) { + std::ostringstream msg; + msg << "mysql_real_connect(" << params << ") failed"; + throw std::runtime_error(msg.str()); + } +} + +PreparedStatement::ptr +MySQLConnection::prepareStatement(const std::string &sql) { + return PreparedStatement::ptr(new MySQLPreparedStatement(sql, *this)); +} + +const CountProxy &MySQLConnection::executeUpdate(const std::string &sql) { + static MySQLCountProxy count(handle()); + + int ret = mysql_real_query(handle(), sql.c_str(), sql.length()); + if (ret != 0) + throw MySQLSqlError(*this, "mysql_real_query() failed", sql); + + return count; +} + +ResultSet::ptr MySQLConnection::executeQuery(const std::string &sql) { + MySQLPreparedStatement statement(sql, *this); + + return statement.executeQuery(); +} +} diff --git a/src/mysql/MySQLConnection.h b/src/mysql/MySQLConnection.h new file mode 100644 index 0000000..6d64746 --- /dev/null +++ b/src/mysql/MySQLConnection.h @@ -0,0 +1,41 @@ +#ifndef MYSQLCONNECTION_H__ +#define MYSQLCONNECTION_H__ + +#include "../DbConnectionFactory.h" + +#include + +#include + +#include + +#include + +void finalize_mysql(MYSQL *); + +namespace dbc { + +dbconnection_transferable_ptr createMySQLConnection(const std::string ¶ms); + +class MySQLConnection : public DbConnection { +public: + inline MYSQL *handle() { return _db.get(); } + + virtual PreparedStatement::ptr prepareStatement(const std::string &sql); + virtual const CountProxy &executeUpdate(const std::string &sql); + virtual ResultSet::ptr executeQuery(const std::string &sql); + +private: + MySQLConnection(const std::string ¶ms); + + friend dbconnection_transferable_ptr + createMySQLConnection(const std::string ¶ms); + + typedef utilcpp::scoped_ptr mysql_scoped_ptr; + + std::string _params; + mysql_scoped_ptr _db; +}; +} + +#endif /* MYSQLCONNECTION_H__ */ diff --git a/src/mysql/MySQLCountProxy.cpp b/src/mysql/MySQLCountProxy.cpp new file mode 100644 index 0000000..9dbb052 --- /dev/null +++ b/src/mysql/MySQLCountProxy.cpp @@ -0,0 +1,13 @@ +#include "MySQLCountProxy.h" +#include + +namespace dbc { + +MySQLCountProxy::operator int() const { + // If a separate thread makes changes on the same database connection + // while sqlite3_changes() is running then the value returned is + // unpredictable and not meaningful. + + return mysql_affected_rows(_db); +} +} diff --git a/src/mysql/MySQLCountProxy.h b/src/mysql/MySQLCountProxy.h new file mode 100644 index 0000000..5f10738 --- /dev/null +++ b/src/mysql/MySQLCountProxy.h @@ -0,0 +1,19 @@ +#ifndef MYSQLCOUNTPROXY_H__ +#define MYSQLCOUNTPROXY_H__ + +#include +#include + +namespace dbc { + +class MySQLCountProxy : public CountProxy { +public: + MySQLCountProxy(MYSQL *db) : _db(db) {} + virtual operator int() const; + +private: + MYSQL *_db; +}; +} + +#endif /* MYSQLCOUNTPROXY_H__ */ diff --git a/src/mysql/MySQLExceptions.h b/src/mysql/MySQLExceptions.h new file mode 100644 index 0000000..c9e309f --- /dev/null +++ b/src/mysql/MySQLExceptions.h @@ -0,0 +1,27 @@ +#ifndef SQLITEEXCEPTION_H__ +#define SQLITEEXCEPTION_H__ + +#include + +#include + +namespace dbc { + +class MySQLDbError : public DbError { +public: + MySQLDbError(MYSQL *db, const std::string &msg) + : DbError(msg, mysql_error(db)) {} +}; + +class MySQLSqlError : public SqlError { +public: + MySQLSqlError(MySQLConnection &db, const std::string &msg, + const std::string &sql) + : SqlError(msg, mysql_error(db.handle()), sql) {} + MySQLSqlError(MYSQL_STMT *stmt, const std::string &msg, + const std::string &sql) + : SqlError(msg, mysql_stmt_error(stmt), sql) {} +}; +} + +#endif /* SQLITEEXCEPTION_H */ diff --git a/src/mysql/MySQLPreparedStatement.cpp b/src/mysql/MySQLPreparedStatement.cpp new file mode 100644 index 0000000..1b96285 --- /dev/null +++ b/src/mysql/MySQLPreparedStatement.cpp @@ -0,0 +1,230 @@ +#include "MySQLPreparedStatement.h" +#include "MySQLCountProxy.h" +#include "MySQLExceptions.h" +#include "MySQLResultSet.h" + +#include + +namespace { + +MYSQL_STMT *init_statement(dbc::MySQLConnection &db, const std::string &sql) { + MYSQL_STMT *statement = mysql_stmt_init(db.handle()); + int ret = mysql_stmt_prepare(statement, sql.c_str(), sql.length()); + if (ret != 0) { + throw dbc::MySQLSqlError(db, "mysql_stmt_prepare() failed", sql); + } + + return statement; +} + +template +void addToBuffer(std::vector> &buffers, const T &val) { + std::vector buff; + buff.resize(sizeof(T)); + memcpy(&buff[0], &val, sizeof(T)); + buffers.push_back(buff); +} + +template <> +void addToBuffer(std::vector> &buffers, + const std::string &val) { + std::vector buff(val.size()); + memcpy(&buff[0], val.data(), val.size()); + buffers.push_back(buff); +} + +void createBuffer(std::vector> &buffers, int size) { + std::vector buff(size); + buffers.push_back(buff); +} + +std::vector> & +addOutputBuffer(std::vector>> &outputBuffers) { + outputBuffers.push_back(std::vector>()); + return outputBuffers.back(); +} +} + +namespace dbc { + +void finalize_mysql_stmt(MYSQL_STMT *statement) { + if (statement) + mysql_stmt_close(statement); +} + +MySQLPreparedStatement::MySQLPreparedStatement(const std::string &sql, + MySQLConnection &db) + : PreparedStatement(), _db(db), _statement(init_statement(db, sql)), + _param_tracker(mysql_stmt_param_count(_statement.get())), _sql(sql), + _input_bind_params(NULL) { + + _input_bind_params.resize(_param_tracker.getNumParams()); + memset(&_input_bind_params[0], 0, + _input_bind_params.size() * sizeof(MYSQL_BIND)); +} + +void MySQLPreparedStatement::setInt(const int index, const int value) { + _param_tracker.setParameter(index); + + addToBuffer(_inputBuffers, value); + + MYSQL_BIND &bind = _input_bind_params[index - 1]; + bind.buffer_type = MYSQL_TYPE_LONG; + bind.buffer = (char *)&_inputBuffers.back()[0]; + bind.is_null = 0; + bind.length = 0; +} + +void MySQLPreparedStatement::setDouble(const int index, const double value) { + _param_tracker.setParameter(index); + + addToBuffer(_inputBuffers, value); + + MYSQL_BIND &bind = _input_bind_params[index - 1]; + bind.buffer_type = MYSQL_TYPE_DOUBLE; + bind.buffer = (char *)&_inputBuffers.back()[0]; + bind.is_null = 0; + bind.length = 0; +} + +void MySQLPreparedStatement::setString(const int index, + const std::string &value) { + _param_tracker.setParameter(index); + + MYSQL_BIND &bind = _input_bind_params[index - 1]; + + addToBuffer(_inputBuffers, value); + + bind.buffer_type = MYSQL_TYPE_STRING; + bind.buffer = (char *)&_inputBuffers.back()[0]; + + addToBuffer(_inputBuffers, value.size()); + + bind.buffer_length = value.size(); + bind.is_null = 0; + bind.length = (unsigned long *)&_inputBuffers.back()[0]; +} + +void MySQLPreparedStatement::setBool(const int index, const bool value) { + _param_tracker.setParameter(index); + + MYSQL_BIND &bind = _input_bind_params[index - 1]; + bind.buffer_type = MYSQL_TYPE_BIT; + bind.buffer = (char *)&value; + bind.is_null = 0; + bind.length = 0; +} + +void MySQLPreparedStatement::doReset() { + if (mysql_stmt_reset(_statement.get())) + throw MySQLSqlError(_db, "mysql_stmt_reset() failed", getSQL()); +} + +void MySQLPreparedStatement::tryBindInput() { + if (mysql_stmt_bind_param(_statement.get(), &_input_bind_params[0])) { + throw MySQLSqlError(_statement.get(), "mysql_stmt_bind_param failed()", + getSQL()); + } +} + +ResultSet::ptr MySQLPreparedStatement::doExecuteQuery() { + reset(); + tryBindInput(); + tryExecuteStatement(); + + MYSQL_RES *result_meta = mysql_stmt_result_metadata(_statement.get()); + int column_count = mysql_num_fields(result_meta); + + _output_bind_params.resize(column_count); + + // iterate over fields + // for each fields, gets it type and initialize bind accordingly + + MYSQL_FIELD *field = NULL; + + // move this to result set + for (size_t i = 0; i < _output_bind_params.size(); i++) { + // create new buffer array + std::vector> &buffer = addOutputBuffer(_outputBuffers); + field = mysql_fetch_field(result_meta); + MYSQL_BIND &bind = _output_bind_params[i]; + + bind.buffer_type = field->type; + + createBuffer(buffer, field->length); + bind.buffer = (uint8_t *)&buffer.back()[0]; + + bind.buffer_length = field->length; + + createBuffer(buffer, sizeof(*bind.length)); + bind.length = (unsigned long *)&buffer.back()[0]; + ; + + createBuffer(buffer, sizeof(*bind.error)); + bind.error = (my_bool *)&buffer.back()[0]; + ; + + createBuffer(buffer, sizeof(*bind.is_null)); + bind.is_null = (my_bool *)&buffer.back()[0]; + } + + // sanity check + if (mysql_fetch_field(result_meta) != NULL) + throw MySQLDbError(_db.handle(), + "column count and fetch field out of sync!"); + + mysql_free_result(result_meta); + + if (mysql_stmt_bind_result(_statement.get(), _output_bind_params.data())) { + throw MySQLSqlError(_statement.get(), "mysql_stmt_bind_result failed()", + getSQL()); + } + + if (mysql_stmt_store_result(_statement.get())) { + throw MySQLSqlError(_statement.get(), "mysql_stmt_store_result failed()", + getSQL()); + } + + return ResultSet::ptr(new MySQLResultSet(*this)); +} + +const CountProxy &MySQLPreparedStatement::doExecuteUpdate() { + reset(); + tryBindInput(); + + static MySQLCountProxy count(_db.handle()); + + tryExecuteStatement(); + + _inputBuffers.clear(); + memset(&_input_bind_params[0], 0, + _input_bind_params.size() * sizeof(MYSQL_BIND)); + + return count; +} + +MYSQL_STMT *MySQLPreparedStatement::handle() { return _statement.get(); } + +void MySQLPreparedStatement::tryExecuteStatement() { + if (mysql_stmt_execute(_statement.get())) { + throw MySQLSqlError(_statement.get(), + "mysql_stmt_execute must return 0 in executeUpdate()", + getSQL()); + } +} + +void MySQLPreparedStatement::setNull(const int index) { + _param_tracker.setParameter(index); + + MYSQL_BIND &bind = _input_bind_params[index - 1]; + bind.buffer_type = MYSQL_TYPE_NULL; + bind.is_null = 0; + bind.length = 0; +} + +u_int64_t MySQLPreparedStatement::getLastInsertId() { + return mysql_insert_id(_db.handle()); +} + +const char *MySQLPreparedStatement::getSQL() const { return _sql.c_str(); } +} diff --git a/src/mysql/MySQLPreparedStatement.h b/src/mysql/MySQLPreparedStatement.h new file mode 100644 index 0000000..823a420 --- /dev/null +++ b/src/mysql/MySQLPreparedStatement.h @@ -0,0 +1,66 @@ +#ifndef MYSQLPREPAREDSTATEMENT_H__ +#define MYSQLPREPAREDSTATEMENT_H__ + +#include "MySQLConnection.h" +#include +#include +#include + +#ifdef DBCCPP_HAVE_CPP11 +#include +namespace dbc { +namespace stdutil = std; +} +#else +#include +namespace dbc { +namespace stdutil = boost; +} +#endif + +namespace dbc { + +class MySQLResultSet; + +void finalize_mysql_stmt(MYSQL_STMT *); + +class MySQLPreparedStatement : public PreparedStatement { + friend class MySQLResultSet; + // PreparedStatement interface +public: + MySQLPreparedStatement(const std::string &sql, MySQLConnection &db); + void setNull(const int parameterIndex); + uint64_t getLastInsertId(); + const char *getSQL() const; + MYSQL_STMT *handle(); + +protected: + void setString(const int parameterIndex, const std::string &val); + void setInt(const int parameterIndex, const int val); + void setDouble(const int parameterIndex, const double val); + void setBool(const int parameterIndex, const bool value); + void doReset(); + void doClear(){}; + ResultSet::ptr doExecuteQuery(); + const CountProxy &doExecuteUpdate(); + + // methods +private: + void tryExecuteStatement(); + void tryBindInput(); + + // fields +private: + typedef utilcpp::scoped_ptr + mysql_stmt_scoped_ptr; + MySQLConnection &_db; + mysql_stmt_scoped_ptr _statement; + ParameterTracker _param_tracker; + std::vector _input_bind_params; + std::vector _output_bind_params; + std::string _sql; + std::vector> _inputBuffers; + std::vector>> _outputBuffers; +}; +} +#endif // MYSQLPREPAREDSTATEMENT_H__ diff --git a/src/mysql/MySQLResultSet.cpp b/src/mysql/MySQLResultSet.cpp new file mode 100644 index 0000000..983932e --- /dev/null +++ b/src/mysql/MySQLResultSet.cpp @@ -0,0 +1,66 @@ +#include "MySQLResultSet.h" +#include "MySQLExceptions.h" +#include "MySQLPreparedStatement.h" + +dbc::MySQLResultSet::MySQLResultSet(MySQLPreparedStatement &stmt) + : _statement(stmt) {} + +bool dbc::MySQLResultSet::next() { + int ret = mysql_stmt_fetch(_statement.handle()); + switch (ret) { + case 0: + return true; + case MYSQL_NO_DATA: + return false; + case MYSQL_DATA_TRUNCATED: + throw MySQLSqlError(_statement.handle(), "Data truncated!", + _statement.getSQL()); + default: + throw MySQLSqlError(_statement.handle(), "Error while fetching row", + _statement.getSQL()); + } +} + +bool dbc::MySQLResultSet::isNull(const int columnIndex) const { + std::vector> &buff = + _statement._outputBuffers[columnIndex]; + + my_bool *val = (my_bool *)buff[3].data(); + return *val; +} + +void dbc::MySQLResultSet::getString(const int columnIndex, + std::string &out) const { + std::vector> &buff = + _statement._outputBuffers[columnIndex]; + out = (const char *)buff[0].data(); +} + +std::string dbc::MySQLResultSet::getString(const int columnIndex) const { + std::vector> &buff = + _statement._outputBuffers[columnIndex]; + + return std::string((const char *)buff[0].data()); +} + +int dbc::MySQLResultSet::getInt(const int columnIndex) const { + std::vector> &buff = + _statement._outputBuffers[columnIndex]; + + return *((int *)buff[0].data()); +} + +double dbc::MySQLResultSet::getDouble(const int columnIndex) const { + std::vector> &buff = + _statement._outputBuffers[columnIndex]; + + double *ret = (double *)buff[0].data(); + return *ret; +} + +bool dbc::MySQLResultSet::getBool(const int columnIndex) const { + std::vector> &buff = + _statement._outputBuffers[columnIndex]; + + return *((bool *)buff[0].data()); +} diff --git a/src/mysql/MySQLResultSet.h b/src/mysql/MySQLResultSet.h new file mode 100644 index 0000000..37e6145 --- /dev/null +++ b/src/mysql/MySQLResultSet.h @@ -0,0 +1,29 @@ +#ifndef MYSQLRESULTSET_H__ +#define MYSQLRESULTSET_H__ + +#include + +namespace dbc { + +class MySQLPreparedStatement; + +class MySQLResultSet : public ResultSet { + // ResultSet interface +public: + MySQLResultSet(MySQLPreparedStatement &stmt); + bool next(); + bool isNull(const int columnIndex) const; + +protected: + void getString(const int columnIndex, std::string &out) const; + std::string getString(const int columnIndex) const; + int getInt(const int columnIndex) const; + double getDouble(const int columnIndex) const; + bool getBool(const int columnIndex) const; + +private: + MySQLPreparedStatement &_statement; +}; +} + +#endif // MYSQLRESULTSET_H__ diff --git a/src/sqlite/SQLiteConnection.cpp b/src/sqlite/SQLiteConnection.cpp index 752025c..a1aa696 100644 --- a/src/sqlite/SQLiteConnection.cpp +++ b/src/sqlite/SQLiteConnection.cpp @@ -1,57 +1,49 @@ #include "SQLiteConnection.h" +#include "../DbConnectionFactory.h" #include "SQLiteCountProxy.h" #include "SQLiteExceptions.h" #include "SQLitePreparedStatement.h" -#include "../DbConnectionFactory.h" #include -void finalize_sqlite3(sqlite3* db) -{ - // As destructors cannot throw, we cannot handle the case when ret != OK - if (db) - sqlite3_close_v2(db); +void finalize_sqlite3(sqlite3 *db) { + // As destructors cannot throw, we cannot handle the case when ret != OK + if (db) + sqlite3_close_v2(db); } -namespace dbc -{ - -SQLiteConnection::SQLiteConnection(const std::string& params) : - _params(params), - _db() -{ - sqlite3* db; - int ret = sqlite3_open(params.c_str(), &db); - - // Whether or not an error occurs when it is opened, resources associated - // with the database connection handle should be released by passing it to - // sqlite3_close() when it is no longer required. - _db.reset(db); - - if (ret != SQLITE_OK) - { - std::ostringstream msg; - msg << "sqlite3_open(" << params << ") failed"; - throw SQLiteDbError(db, msg.str()); - } -} +namespace dbc { -const CountProxy& SQLiteConnection::executeUpdate(const std::string& sql) -{ - static SQLiteCountProxy count(handle()); +SQLiteConnection::SQLiteConnection(const std::string ¶ms) + : _params(params), _db() { + sqlite3 *db; + int ret = sqlite3_open(params.c_str(), &db); - int ret = sqlite3_exec(handle(), sql.c_str(), 0, 0, 0); - if (ret != SQLITE_OK) - throw SQLiteSqlError(*this, "sqlite3_exec() failed", sql); + // Whether or not an error occurs when it is opened, resources associated + // with the database connection handle should be released by passing it to + // sqlite3_close() when it is no longer required. + _db.reset(db); - return count; + if (ret != SQLITE_OK) { + std::ostringstream msg; + msg << "sqlite3_open(" << params << ") failed"; + throw SQLiteDbError(db, msg.str()); + } } -ResultSet::ptr SQLiteConnection::executeQuery(const std::string& sql) -{ - SQLitePreparedStatement statement(sql, *this); +const CountProxy &SQLiteConnection::executeUpdate(const std::string &sql) { + static SQLiteCountProxy count(handle()); - return statement.executeQuery(); + int ret = sqlite3_exec(handle(), sql.c_str(), 0, 0, 0); + if (ret != SQLITE_OK) + throw SQLiteSqlError(*this, "sqlite3_exec() failed", sql); + + return count; } +ResultSet::ptr SQLiteConnection::executeQuery(const std::string &sql) { + SQLitePreparedStatement statement(sql, *this); + + return statement.executeQuery(); +} } diff --git a/src/sqlite/SQLiteConnection.h b/src/sqlite/SQLiteConnection.h index 77557c8..bfbd832 100644 --- a/src/sqlite/SQLiteConnection.h +++ b/src/sqlite/SQLiteConnection.h @@ -10,39 +10,34 @@ #include struct sqlite3; -void finalize_sqlite3(sqlite3*); +void finalize_sqlite3(sqlite3 *); -namespace dbc -{ +namespace dbc { -dbconnection_transferable_ptr createSQLiteConnection(const std::string& params); +dbconnection_transferable_ptr createSQLiteConnection(const std::string ¶ms); - -class SQLiteConnection : public DbConnection -{ +class SQLiteConnection : public DbConnection { public: - virtual PreparedStatement::ptr prepareStatement(const std::string& sql) - { - return PreparedStatement::ptr(new SQLitePreparedStatement(sql, *this)); - } + virtual PreparedStatement::ptr prepareStatement(const std::string &sql) { + return PreparedStatement::ptr(new SQLitePreparedStatement(sql, *this)); + } - inline sqlite3* handle() - { return _db.get(); } + inline sqlite3 *handle() { return _db.get(); } - virtual const CountProxy& executeUpdate(const std::string& sql); - virtual ResultSet::ptr executeQuery(const std::string& sql); + virtual const CountProxy &executeUpdate(const std::string &sql); + virtual ResultSet::ptr executeQuery(const std::string &sql); private: - SQLiteConnection(const std::string& params); + SQLiteConnection(const std::string ¶ms); - friend dbconnection_transferable_ptr createSQLiteConnection(const std::string& params); + friend dbconnection_transferable_ptr + createSQLiteConnection(const std::string ¶ms); - typedef utilcpp::scoped_ptr sqlite_scoped_ptr; + typedef utilcpp::scoped_ptr sqlite_scoped_ptr; - std::string _params; - sqlite_scoped_ptr _db; + std::string _params; + sqlite_scoped_ptr _db; }; - } #endif /* SQLITECONNECTION_H */ diff --git a/src/sqlite/SQLiteCountProxy.cpp b/src/sqlite/SQLiteCountProxy.cpp index 2fe9544..c63613b 100644 --- a/src/sqlite/SQLiteCountProxy.cpp +++ b/src/sqlite/SQLiteCountProxy.cpp @@ -1,16 +1,13 @@ #include "SQLiteCountProxy.h" #include -namespace dbc -{ +namespace dbc { -SQLiteCountProxy::operator int() const -{ - // If a separate thread makes changes on the same database connection - // while sqlite3_changes() is running then the value returned is - // unpredictable and not meaningful. +SQLiteCountProxy::operator int() const { + // If a separate thread makes changes on the same database connection + // while sqlite3_changes() is running then the value returned is + // unpredictable and not meaningful. - return sqlite3_changes(_db); + return sqlite3_changes(_db); } - } diff --git a/src/sqlite/SQLiteCountProxy.h b/src/sqlite/SQLiteCountProxy.h index 4e2a7f2..5a9149f 100644 --- a/src/sqlite/SQLiteCountProxy.h +++ b/src/sqlite/SQLiteCountProxy.h @@ -5,19 +5,16 @@ struct sqlite3; -namespace dbc -{ +namespace dbc { -class SQLiteCountProxy : public CountProxy -{ +class SQLiteCountProxy : public CountProxy { public: - SQLiteCountProxy(sqlite3* db) : _db(db) {} - virtual operator int() const; + SQLiteCountProxy(sqlite3 *db) : _db(db) {} + virtual operator int() const; private: - sqlite3* _db; + sqlite3 *_db; }; - } #endif /* SQLITECOUNTPROXY_H */ diff --git a/src/sqlite/SQLiteExceptions.h b/src/sqlite/SQLiteExceptions.h index f481c34..d13f3ee 100644 --- a/src/sqlite/SQLiteExceptions.h +++ b/src/sqlite/SQLiteExceptions.h @@ -5,26 +5,20 @@ #include -namespace dbc -{ +namespace dbc { -class SQLiteDbError : public DbError -{ +class SQLiteDbError : public DbError { public: - SQLiteDbError(sqlite3* db, const std::string& msg) : - DbError(msg, sqlite3_errmsg(db)) - { } + SQLiteDbError(sqlite3 *db, const std::string &msg) + : DbError(msg, sqlite3_errmsg(db)) {} }; -class SQLiteSqlError : public SqlError -{ +class SQLiteSqlError : public SqlError { public: - SQLiteSqlError(SQLiteConnection& db, - const std::string& msg, const std::string& sql) : - SqlError(msg, sqlite3_errmsg(db.handle()), sql) - { } + SQLiteSqlError(SQLiteConnection &db, const std::string &msg, + const std::string &sql) + : SqlError(msg, sqlite3_errmsg(db.handle()), sql) {} }; - } #endif /* SQLITEEXCEPTION_H */ diff --git a/src/sqlite/SQLitePreparedStatement.cpp b/src/sqlite/SQLitePreparedStatement.cpp index 2ef83c7..ee92c29 100644 --- a/src/sqlite/SQLitePreparedStatement.cpp +++ b/src/sqlite/SQLitePreparedStatement.cpp @@ -1,157 +1,139 @@ #include "SQLitePreparedStatement.h" -#include "SQLiteCountProxy.h" #include "SQLiteConnection.h" +#include "SQLiteCountProxy.h" #include "SQLiteExceptions.h" #include -void finalize_sqlite3_stmt(sqlite3_stmt* statement) -{ - if (statement) - sqlite3_finalize(statement); +void finalize_sqlite3_stmt(sqlite3_stmt *statement) { + if (statement) + sqlite3_finalize(statement); } -namespace -{ +namespace { /** Helper to init _statement in SQLitePreparedStatement constructor. */ -sqlite3_stmt* init_statement(dbc::SQLiteConnection& db, const std::string& sql) -{ - const char* tail_unused; - - // If the caller knows that the supplied string is nul-terminated, - // then there is a small performance advantage to be gained by passing an - // nByte parameter that is equal to the number of bytes in the input - // string *including* the nul-terminator bytes. - - sqlite3_stmt* statement; - int ret = sqlite3_prepare_v2(db.handle(), sql.c_str(), sql.length() + 1, - &statement, &tail_unused); - if (ret != SQLITE_OK) - throw dbc::SQLiteSqlError(db, "sqlite3_prepare_v2() failed", sql); - - return statement; +sqlite3_stmt *init_statement(dbc::SQLiteConnection &db, + const std::string &sql) { + const char *tail_unused; + + // If the caller knows that the supplied string is nul-terminated, + // then there is a small performance advantage to be gained by passing an + // nByte parameter that is equal to the number of bytes in the input + // string *including* the nul-terminator bytes. + + sqlite3_stmt *statement; + int ret = sqlite3_prepare_v2(db.handle(), sql.c_str(), sql.length() + 1, + &statement, &tail_unused); + if (ret != SQLITE_OK) + throw dbc::SQLiteSqlError(db, "sqlite3_prepare_v2() failed", sql); + + return statement; } - } -namespace dbc -{ +namespace dbc { -#define THROW_IF_SET_STMT_NOT_OK(retval, func_name, val) \ - if (retval != SQLITE_OK) \ - { \ - std::ostringstream err; \ - err << func_name "(" << index << ", " << val << ") failed"; \ - throw SQLiteSqlError(_db, err.str(), getSQL()); \ - } +#define THROW_IF_SET_STMT_NOT_OK(retval, func_name, val) \ + if (retval != SQLITE_OK) { \ + std::ostringstream err; \ + err << func_name "(" << index << ", " << val << ") failed"; \ + throw SQLiteSqlError(_db, err.str(), getSQL()); \ + } // sql has to be UTF-8 -SQLitePreparedStatement::SQLitePreparedStatement(const std::string& sql, - SQLiteConnection& db) : - PreparedStatement(), - _db(db), - _statement(init_statement(db, sql)), - _param_tracker(sqlite3_bind_parameter_count(_statement.get())) -{} - -ResultSet::ptr SQLitePreparedStatement::doExecuteQuery() -{ - checkParams(); - - return ResultSet::ptr(new SQLiteResultSet(*this)); +SQLitePreparedStatement::SQLitePreparedStatement(const std::string &sql, + SQLiteConnection &db) + : PreparedStatement(), _db(db), _statement(init_statement(db, sql)), + _param_tracker(sqlite3_bind_parameter_count(_statement.get())) {} + +ResultSet::ptr SQLitePreparedStatement::doExecuteQuery() { + checkParams(); + + return ResultSet::ptr(new SQLiteResultSet(*this)); } -const CountProxy& SQLitePreparedStatement::doExecuteUpdate() -{ - checkParams(); +const CountProxy &SQLitePreparedStatement::doExecuteUpdate() { + checkParams(); - static SQLiteCountProxy count(_db.handle()); + static SQLiteCountProxy count(_db.handle()); - int ret = sqlite3_step(_statement.get()); - if (ret != SQLITE_DONE) { - sqlite3_reset(_statement.get()); - throw SQLiteSqlError(_db, - "sqlite3_step() must return SQLITE_DONE in executeUpdate()", - getSQL()); - } + int ret = sqlite3_step(_statement.get()); + if (ret != SQLITE_DONE) { + sqlite3_reset(_statement.get()); + throw SQLiteSqlError( + _db, "sqlite3_step() must return SQLITE_DONE in executeUpdate()", + getSQL()); + } - // SQLITE_DONE means that the statement has finished executing - // successfully. sqlite3_step() should not be called again on this virtual - // machine without first calling sqlite3_reset() to reset the virtual - // machine back to its initial state. + // SQLITE_DONE means that the statement has finished executing + // successfully. sqlite3_step() should not be called again on this virtual + // machine without first calling sqlite3_reset() to reset the virtual + // machine back to its initial state. - reset(); + reset(); - return count; + return count; } -void SQLitePreparedStatement::doClear() -{ - int ret = sqlite3_clear_bindings(_statement.get()); - if (ret != SQLITE_OK) - throw SQLiteSqlError(_db, "sqlite3_clear_bindings() failed", getSQL()); +void SQLitePreparedStatement::doClear() { + int ret = sqlite3_clear_bindings(_statement.get()); + if (ret != SQLITE_OK) + throw SQLiteSqlError(_db, "sqlite3_clear_bindings() failed", getSQL()); } -void SQLitePreparedStatement::doReset() -{ - // We have experienced SQLITE_BUSY errors here indicating that the - // database is locked because of another ongoing operation. Increasing - // sqlite3_busy_timeout() and/or using transactions should fix this. +void SQLitePreparedStatement::doReset() { + // We have experienced SQLITE_BUSY errors here indicating that the + // database is locked because of another ongoing operation. Increasing + // sqlite3_busy_timeout() and/or using transactions should fix this. - int ret = sqlite3_reset(_statement.get()); - if (ret != SQLITE_OK) - throw SQLiteSqlError(_db, "sqlite3_reset() failed", getSQL()); + int ret = sqlite3_reset(_statement.get()); + if (ret != SQLITE_OK) + throw SQLiteSqlError(_db, "sqlite3_reset() failed", getSQL()); } -int SQLitePreparedStatement::getLastInsertId() -{ - return static_cast(sqlite3_last_insert_rowid(_db.handle())); +uint64_t SQLitePreparedStatement::getLastInsertId() { + return static_cast(sqlite3_last_insert_rowid(_db.handle())); } -const char* SQLitePreparedStatement::getSQL() const -{ - return sqlite3_sql(_statement.get()); +const char *SQLitePreparedStatement::getSQL() const { + return sqlite3_sql(_statement.get()); } -void SQLitePreparedStatement::setInt(const int index, const int value) -{ - _param_tracker.setParameter(index); - int ret = sqlite3_bind_int(_statement.get(), index, value); - THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_int", value); +void SQLitePreparedStatement::setInt(const int index, const int value) { + _param_tracker.setParameter(index); + int ret = sqlite3_bind_int(_statement.get(), index, value); + THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_int", value); } -void SQLitePreparedStatement::setBool(const int index, const bool value) -{ - _param_tracker.setParameter(index); - int ret = sqlite3_bind_int(_statement.get(), index, value ? 1 : 0); - THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_int", value); +void SQLitePreparedStatement::setBool(const int index, const bool value) { + _param_tracker.setParameter(index); + int ret = sqlite3_bind_int(_statement.get(), index, value ? 1 : 0); + THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_int", value); } -void SQLitePreparedStatement::setDouble(const int index, const double value) -{ - _param_tracker.setParameter(index); - int ret = sqlite3_bind_double(_statement.get(), index, value); - THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_double", value); +void SQLitePreparedStatement::setDouble(const int index, const double value) { + _param_tracker.setParameter(index); + int ret = sqlite3_bind_double(_statement.get(), index, value); + THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_double", value); } -void SQLitePreparedStatement::setString(const int index, const std::string& value) -{ - _param_tracker.setParameter(index); - // Note that SQLITE_TRANSIENT makes copy of the data. - // This may be expensive if it is large. - - int ret = sqlite3_bind_text(_statement.get(), index, - value.c_str(), -1, SQLITE_TRANSIENT); - THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_text", - (value.length() < 50 ? value : value.substr(0, 47) + "...")); +void SQLitePreparedStatement::setString(const int index, + const std::string &value) { + _param_tracker.setParameter(index); + // Note that SQLITE_TRANSIENT makes copy of the data. + // This may be expensive if it is large. + + int ret = sqlite3_bind_text(_statement.get(), index, value.c_str(), -1, + SQLITE_TRANSIENT); + THROW_IF_SET_STMT_NOT_OK( + ret, "sqlite3_bind_text", + (value.length() < 50 ? value : value.substr(0, 47) + "...")); } -void SQLitePreparedStatement::setNull(const int index) -{ - _param_tracker.setParameter(index); - int ret = sqlite3_bind_null(_statement.get(), index); - THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_null", ""); +void SQLitePreparedStatement::setNull(const int index) { + _param_tracker.setParameter(index); + int ret = sqlite3_bind_null(_statement.get(), index); + THROW_IF_SET_STMT_NOT_OK(ret, "sqlite3_bind_null", ""); } - } diff --git a/src/sqlite/SQLitePreparedStatement.h b/src/sqlite/SQLitePreparedStatement.h index 942332c..2d4a409 100644 --- a/src/sqlite/SQLitePreparedStatement.h +++ b/src/sqlite/SQLitePreparedStatement.h @@ -3,74 +3,67 @@ #include "SQLiteResultSet.h" -#include -#include #include +#include +#include #include #include struct sqlite3_stmt; -void finalize_sqlite3_stmt(sqlite3_stmt*); +void finalize_sqlite3_stmt(sqlite3_stmt *); -namespace dbc -{ +namespace dbc { class SQLiteConnection; class CountProxy; -class SQLitePreparedStatement : public PreparedStatement -{ +class SQLitePreparedStatement : public PreparedStatement { public: - SQLitePreparedStatement(const std::string& sql, SQLiteConnection& db); + SQLitePreparedStatement(const std::string &sql, SQLiteConnection &db); - virtual void setNull(const int index); + virtual void setNull(const int index); - virtual int getLastInsertId(); + virtual u_int64_t getLastInsertId(); - virtual const char* getSQL() const; + virtual const char *getSQL() const; - sqlite3_stmt* handle() - { return _statement.get(); } + sqlite3_stmt *handle() { return _statement.get(); } - SQLiteConnection& getDb() - { return _db; } + SQLiteConnection &getDb() { return _db; } protected: - virtual void setString(const int parameterIndex, const std::string& val); - virtual void setInt(const int parameterIndex, const int val); - virtual void setDouble(const int parameterIndex, const double val); - virtual void setBool(const int parameterIndex, const bool value); + virtual void setString(const int parameterIndex, const std::string &val); + virtual void setInt(const int parameterIndex, const int val); + virtual void setDouble(const int parameterIndex, const double val); + virtual void setBool(const int parameterIndex, const bool value); - virtual ResultSet::ptr doExecuteQuery(); - virtual const CountProxy& doExecuteUpdate(); + virtual ResultSet::ptr doExecuteQuery(); + virtual const CountProxy &doExecuteUpdate(); - virtual void doReset(); - virtual void doClear(); + virtual void doReset(); + virtual void doClear(); private: - typedef utilcpp::scoped_ptr - sqlite_stmt_scoped_ptr; - - inline void checkParams() - { - if (!_param_tracker.areAllParametersSet()) - { - std::ostringstream params; - params << "Expected " << _param_tracker.getNumParams() - << ", currently set: " << _param_tracker.getSetParams(); - - throw SqlError("Not all statement parameters are set", - params.str(), getSQL()); - } + typedef utilcpp::scoped_ptr + sqlite_stmt_scoped_ptr; + + inline void checkParams() { + if (!_param_tracker.areAllParametersSet()) { + std::ostringstream params; + params << "Expected " << _param_tracker.getNumParams() + << ", currently set: " << _param_tracker.getSetParams(); + + throw SqlError("Not all statement parameters are set", params.str(), + getSQL()); } + } - SQLiteConnection& _db; - sqlite_stmt_scoped_ptr _statement; - ParameterTracker _param_tracker; + SQLiteConnection &_db; + sqlite_stmt_scoped_ptr _statement; + ParameterTracker _param_tracker; }; - } #endif /* SQLITEPREPAREDSTATEMENT_H */ diff --git a/src/sqlite/SQLiteResultSet.cpp b/src/sqlite/SQLiteResultSet.cpp index 78d4e9f..9fdb326 100644 --- a/src/sqlite/SQLiteResultSet.cpp +++ b/src/sqlite/SQLiteResultSet.cpp @@ -1,116 +1,104 @@ #include "SQLiteResultSet.h" -#include "SQLitePreparedStatement.h" #include "SQLiteConnection.h" #include "SQLiteExceptions.h" +#include "SQLitePreparedStatement.h" #include #include -namespace dbc -{ +namespace dbc { -bool SQLiteResultSet::next() -{ - if (_status == DONE) - throw DbErrorBase("No more rows in result set"); +bool SQLiteResultSet::next() { + if (_status == DONE) + throw DbErrorBase("No more rows in result set"); - int ret = sqlite3_step(_statement.handle()); + int ret = sqlite3_step(_statement.handle()); - switch (ret) - { - case SQLITE_ROW: - _status = ROW_READY; - return true; + switch (ret) { + case SQLITE_ROW: + _status = ROW_READY; + return true; - case SQLITE_DONE: + case SQLITE_DONE: - // SQLITE_DONE means that the statement has finished executing - // successfully. sqlite3_step() should not be called again on this - // virtual machine without first calling sqlite3_reset() to reset - // the virtual machine back to its initial state. + // SQLITE_DONE means that the statement has finished executing + // successfully. sqlite3_step() should not be called again on this + // virtual machine without first calling sqlite3_reset() to reset + // the virtual machine back to its initial state. - _status = DONE; - _statement.reset(); - return false; + _status = DONE; + _statement.reset(); + return false; - default: - // FIXME: _statement.reset() may throw, masking the error from - // sqlite3_step() - _statement.reset(); - throw SQLiteSqlError(_statement.getDb(), - "sqlite3_step() failed", _statement.getSQL()); - } + default: + // FIXME: _statement.reset() may throw, masking the error from + // sqlite3_step() + _statement.reset(); + throw SQLiteSqlError(_statement.getDb(), "sqlite3_step() failed", + _statement.getSQL()); + } - // return false; <-- unreachable and generates a compiler warning in VC 2008 + // return false; <-- unreachable and generates a compiler warning in VC 2008 } -bool SQLiteResultSet::isNull(const int columnIndex) const -{ - checkRowAndColumn(columnIndex); - return sqlite3_column_type(_statement.handle(), columnIndex) == SQLITE_NULL; +bool SQLiteResultSet::isNull(const int columnIndex) const { + checkRowAndColumn(columnIndex); + return sqlite3_column_type(_statement.handle(), columnIndex) == SQLITE_NULL; } -void SQLiteResultSet::checkRowAndColumn(const int columnIndex) const -{ - if (_status != ROW_READY) - throw NoResultsError("No rows in result set"); - - if (_numColumns < 0) - // or use sqlite3_column_count in constructor - _numColumns = sqlite3_data_count(_statement.handle()); - - if (_numColumns <= 0) - throw DbErrorBase("No columns available in the current row"); - - if (columnIndex >= _numColumns) - { - std::ostringstream msg; - msg << "Column index " << columnIndex - << " past last result column index " << _numColumns - 1; - throw DbErrorBase(msg.str()); - } -} +void SQLiteResultSet::checkRowAndColumn(const int columnIndex) const { + if (_status != ROW_READY) + throw NoResultsError("No rows in result set"); -void SQLiteResultSet::getString(const int columnIndex, std::string& out) const -{ - checkRowAndColumn(columnIndex); + if (_numColumns < 0) + // or use sqlite3_column_count in constructor + _numColumns = sqlite3_data_count(_statement.handle()); - const unsigned char* result = sqlite3_column_text(_statement.handle(), - columnIndex); + if (_numColumns <= 0) + throw DbErrorBase("No columns available in the current row"); - out = result ? reinterpret_cast(result) : ""; + if (columnIndex >= _numColumns) { + std::ostringstream msg; + msg << "Column index " << columnIndex << " past last result column index " + << _numColumns - 1; + throw DbErrorBase(msg.str()); + } } -std::string SQLiteResultSet::getString(const int columnIndex) const -{ - checkRowAndColumn(columnIndex); +void SQLiteResultSet::getString(const int columnIndex, std::string &out) const { + checkRowAndColumn(columnIndex); - const unsigned char* result = sqlite3_column_text(_statement.handle(), - columnIndex); + const unsigned char *result = + sqlite3_column_text(_statement.handle(), columnIndex); - return result ? reinterpret_cast(result) : ""; + out = result ? reinterpret_cast(result) : ""; } -int SQLiteResultSet::getInt(const int columnIndex) const -{ - checkRowAndColumn(columnIndex); +std::string SQLiteResultSet::getString(const int columnIndex) const { + checkRowAndColumn(columnIndex); - return sqlite3_column_int(_statement.handle(), columnIndex); + const unsigned char *result = + sqlite3_column_text(_statement.handle(), columnIndex); + + return result ? reinterpret_cast(result) : ""; } -bool SQLiteResultSet::getBool(const int columnIndex) const -{ - checkRowAndColumn(columnIndex); +int SQLiteResultSet::getInt(const int columnIndex) const { + checkRowAndColumn(columnIndex); - return sqlite3_column_int(_statement.handle(), columnIndex) != 0; + return sqlite3_column_int(_statement.handle(), columnIndex); } -double SQLiteResultSet::getDouble(const int columnIndex) const -{ - checkRowAndColumn(columnIndex); +bool SQLiteResultSet::getBool(const int columnIndex) const { + checkRowAndColumn(columnIndex); - return sqlite3_column_double(_statement.handle(), columnIndex); + return sqlite3_column_int(_statement.handle(), columnIndex) != 0; } +double SQLiteResultSet::getDouble(const int columnIndex) const { + checkRowAndColumn(columnIndex); + + return sqlite3_column_double(_statement.handle(), columnIndex); +} } diff --git a/src/sqlite/SQLiteResultSet.h b/src/sqlite/SQLiteResultSet.h index 3cdfeb8..16d2480 100644 --- a/src/sqlite/SQLiteResultSet.h +++ b/src/sqlite/SQLiteResultSet.h @@ -3,40 +3,34 @@ #include -namespace dbc -{ +namespace dbc { class SQLitePreparedStatement; -class SQLiteResultSet : public ResultSet -{ +class SQLiteResultSet : public ResultSet { public: - enum Status { INITIAL, ROW_READY, DONE }; + enum Status { INITIAL, ROW_READY, DONE }; - SQLiteResultSet(SQLitePreparedStatement& statement) : - _statement(statement), - _status(INITIAL), - _numColumns(-1) - {} + SQLiteResultSet(SQLitePreparedStatement &statement) + : _statement(statement), _status(INITIAL), _numColumns(-1) {} - virtual bool next(); - virtual bool isNull(const int columnIndex) const; + virtual bool next(); + virtual bool isNull(const int columnIndex) const; protected: - virtual void getString(const int columnIndex, std::string& out) const; - virtual std::string getString(const int columnIndex) const; - virtual int getInt(const int columnIndex) const; - virtual double getDouble(const int columnIndex) const; - virtual bool getBool(const int columnIndex) const; + virtual void getString(const int columnIndex, std::string &out) const; + virtual std::string getString(const int columnIndex) const; + virtual int getInt(const int columnIndex) const; + virtual double getDouble(const int columnIndex) const; + virtual bool getBool(const int columnIndex) const; private: - inline void checkRowAndColumn(const int columnIndex) const; + inline void checkRowAndColumn(const int columnIndex) const; - SQLitePreparedStatement& _statement; - Status _status; - mutable int _numColumns; + SQLitePreparedStatement &_statement; + Status _status; + mutable int _numColumns; }; - } #endif /* SQLITERESULTSET_H */ diff --git a/test/src/main.cpp b/test/src/main.cpp index 9e2f8bc..f23d7b6 100644 --- a/test/src/main.cpp +++ b/test/src/main.cpp @@ -5,6 +5,161 @@ #include +class TestDbccppMysql : public Test::Suite +{ +public: + TestDbccppMysql() : _db((dbc::DbConnection::connect("mysql", "test.db"), + dbc::DbConnection::instance())) + { + _db.executeUpdate("DROP TABLE IF EXISTS person"); + _db.executeUpdate("CREATE TABLE person ( " + "id INTEGER PRIMARY KEY auto_increment, " + "name VARCHAR(32) NOT NULL, " + "height double NOT NULL DEFAULT 1.80 " + ")"); + } + + virtual void test() { + testHappyPath(); + testResultsetSubscriptOperator(); + } + + void testHappyPath() + { + Test::assertEqual("DDL statements return 0", + _db.executeUpdate("CREATE TABLE test " + "(id INTEGER PRIMARY KEY auto_increment)"), + 0); + Test::assertEqual("DDL statements return 0", + _db.executeUpdate("DROP TABLE test"), + 0); + + dbc::PreparedStatement::ptr insert = _db.prepareStatement( + "INSERT INTO person (name) VALUES (?)"); + + std::vector names; + names.push_back("Ervin"); + names.push_back("Melvin"); + names.push_back("Kelvin"); + names.push_back("Steve"); + + for (size_t i = 0; i < names.size(); ++i) + { + insert->set(1, names[i]); + + Test::assertEqual( + "DML statements return number of updated rows", + insert->executeUpdate(), + 1); + } + + insert = _db.prepareStatement( + "INSERT INTO person (id, name, height) VALUES (?, ?, ?)"); + + *insert << 42 << "Douglas" << 1.65; + + Test::assertEqual( + "Binding with operator<< works", + insert->executeUpdate(), + 1); + + dbc::PreparedStatement::ptr select = _db.prepareStatement( + "SELECT DISTINCT name FROM person " + "WHERE name LIKE ? ORDER BY name"); + select->set(1, "%vin"); + + std::vector expected; + expected.push_back("Ervin"); + expected.push_back("Kelvin"); + expected.push_back("Melvin"); + + dbc::ResultSet::ptr results = select->executeQuery(); + + unsigned int counter = 0; + + while (results->next()) + { + // access strings by copy + Test::assertEqual( + "Accessing rows and columns in ResultSet works (str copy)", + results->get(0), expected.at(counter++)); + } + + Test::assertEqual( + "Iteration over result set returns all rows", + counter, expected.size()); + + select = _db.prepareStatement("SELECT * FROM person ORDER BY name"); + results = select->executeQuery(); + results->next(); + + Test::assertTrue( + "Null checking returns false for non-NULL values", + !results->isNull(1)); + + // access strings by out parameter + std::string name; + results->get(1, name); + + Test::assertEqual( + "Getting strings by reference works", + name, "Douglas"); + + // TODO: beware of double comparison + Test::assertEqual( + "Getting doubles works", + results->get(2), 1.65); + + Test::assertEqual( + "Getting ints works", + results->get(0), 42); + + _db.executeUpdate("CREATE TABLE nullable (a INTEGER)"); + _db.executeUpdate("INSERT INTO nullable (a) VALUES (NULL)"); + + select = _db.prepareStatement("SELECT * FROM nullable"); + results = select->executeQuery(); + results->next(); + + Test::assertTrue( + "Null checking returns true for NULL values", + results->isNull(0)); + } + + void testResultsetSubscriptOperator() + { + dbc::PreparedStatement::ptr select = + _db.prepareStatement("SELECT * FROM person ORDER BY name"); + dbc::ResultSet::ptr results_ptr = select->executeQuery(); + dbc::ResultSet& results = *results_ptr; + + results.next(); + + Test::assertEqual( + "Getting ints by subscript operator works", + results[0], 42); + + Test::assertEqual( + "Getting strings by subscript operator works", + results[1], "Douglas"); + + Test::assertEqual( + "Getting doubles by subscript operator works", + results[2], 1.65); + } + + virtual ~TestDbccppMysql() + { + _db.executeUpdate("DROP TABLE IF EXISTS person"); + _db.executeUpdate("DROP TABLE IF EXISTS test"); + _db.executeUpdate("DROP TABLE IF EXISTS nullable"); + _db.disconnect(); + } + +private: + dbc::DbConnection& _db; +}; + class TestDbccpp : public Test::Suite { public: @@ -28,6 +183,7 @@ class TestDbccpp : public Test::Suite _db.executeUpdate("DROP TABLE IF EXISTS person"); _db.executeUpdate("DROP TABLE IF EXISTS test"); _db.executeUpdate("DROP TABLE IF EXISTS nullable"); + _db.disconnect(); } void test() @@ -35,7 +191,7 @@ class TestDbccpp : public Test::Suite testHappyPath(); // testPreparedStatementOperatorShift(); testResultsetSubscriptOperator(); - testInvalidQueries(); + // testInvalidQueries(); } void testHappyPath() @@ -313,7 +469,7 @@ int main() { Test::Controller &c = Test::Controller::instance(); c.setObserver(Test::observer_transferable_ptr(new Test::ColoredStdOutView)); + c.addTestSuite("dbccpp-MySQL", Test::Suite::instance); c.addTestSuite("dbccpp main", Test::Suite::instance); - - return c.run(); + c.run(); }