diff --git a/CMakeLists.txt b/CMakeLists.txt index 41f4783ca2d..30e9ba2f7e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,7 @@ project(YDB-CPP-SDK VERSION ${YDB_SDK_VERSION} LANGUAGES C CXX ASM) option(YDB_SDK_INSTALL "Install YDB C++ SDK" Off) option(YDB_SDK_TESTS "Build YDB C++ SDK tests" Off) option(YDB_SDK_EXAMPLES "Build YDB C++ SDK examples" On) +option(YDB_SDK_ODBC "Build YDB ODBC driver" On) set(YDB_SDK_GOOGLE_COMMON_PROTOS_TARGET "" CACHE STRING "Name of cmake target preparing google common proto library") option(YDB_SDK_USE_RAPID_JSON "Search for rapid json library in system" ON) @@ -61,6 +62,10 @@ add_subdirectory(util) #_ydb_sdk_validate_public_headers() +if (YDB_SDK_ODBC) + add_subdirectory(odbc) +endif() + if (YDB_SDK_EXAMPLES) add_subdirectory(examples) endif() diff --git a/CMakePresets.json b/CMakePresets.json index ad610dd6dd4..92416e766e8 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -55,6 +55,7 @@ "cacheVariables": { "YDB_SDK_TESTS": "TRUE", "YDB_SDK_EXAMPLES": "TRUE", + "YDB_SDK_ODBC": "TRUE", "ARCADIA_ROOT": "..", "ARCADIA_BUILD_ROOT": "." } diff --git a/cmake/common.cmake b/cmake/common.cmake index 546ce4e81f0..89ebb5eaca1 100644 --- a/cmake/common.cmake +++ b/cmake/common.cmake @@ -115,7 +115,7 @@ function(generate_enum_serilization Tgt Input) endfunction() function(add_global_library_for TgtName MainName) - add_library(${TgtName} STATIC ${ARGN}) + _ydb_sdk_add_library(${TgtName} STATIC ${ARGN}) if(APPLE) target_link_options(${MainName} INTERFACE "SHELL:-Wl,-force_load,$${TgtName}>") else() @@ -182,7 +182,7 @@ endfunction() function(_ydb_sdk_add_library Tgt) cmake_parse_arguments(ARG - "INTERFACE" "" "" + "INTERFACE;OBJECT;SHARED" "" "" ${ARGN} ) @@ -192,6 +192,12 @@ function(_ydb_sdk_add_library Tgt) set(libraryMode "INTERFACE") set(includeMode "INTERFACE") endif() + if (ARG_OBJECT) + set(libraryMode "OBJECT") + endif() + if (ARG_SHARED) + set(libraryMode "SHARED") + endif() add_library(${Tgt} ${libraryMode}) target_include_directories(${Tgt} ${includeMode} $ @@ -201,6 +207,7 @@ function(_ydb_sdk_add_library Tgt) target_compile_definitions(${Tgt} ${includeMode} YDB_SDK_USE_STD_STRING ) + set_property(TARGET ${Tgt} PROPERTY POSITION_INDEPENDENT_CODE ON) endfunction() function(_ydb_sdk_validate_public_headers) @@ -255,4 +262,3 @@ function(_ydb_sdk_validate_public_headers) ) target_include_directories(validate_public_interface PUBLIC ${YDB_SDK_BINARY_DIR}/__validate_headers_dir/include) endfunction() - diff --git a/cmake/external_libs.cmake b/cmake/external_libs.cmake index 22d0603e77c..a252c588ae8 100644 --- a/cmake/external_libs.cmake +++ b/cmake/external_libs.cmake @@ -14,6 +14,10 @@ find_package(Brotli 1.1.0 REQUIRED) find_package(jwt-cpp REQUIRED) find_package(double-conversion REQUIRED) +if (YDB_SDK_ODBC) + find_package(ODBC REQUIRED) +endif() + # RapidJSON if (YDB_SDK_USE_RAPID_JSON) find_package(RapidJSON REQUIRED) diff --git a/cmake/testing.cmake b/cmake/testing.cmake index 1319cb16896..7b4a9763f96 100644 --- a/cmake/testing.cmake +++ b/cmake/testing.cmake @@ -103,3 +103,35 @@ function(add_ydb_test) vcs_info(${YDB_TEST_NAME}) endfunction() + +if (YDB_SDK_ODBC) + function(add_odbc_test) + set(opts "") + set(oneval_args NAME WORKING_DIRECTORY OUTPUT_DIRECTORY) + set(multival_args SOURCES LINK_LIBRARIES LABELS) + cmake_parse_arguments(ODBC_TEST + "${opts}" + "${oneval_args}" + "${multival_args}" + ${ARGN} + ) + + add_ydb_test(GTEST + NAME ${ODBC_TEST_NAME} + SOURCES ${ODBC_TEST_SOURCES} + LINK_LIBRARIES + ${ODBC_TEST_LINK_LIBRARIES} + ODBC::ODBC + LABELS + integration + ${ODBC_TEST_LABELS} + ) + + target_compile_definitions(${ODBC_TEST_NAME} + PRIVATE + ODBC_DRIVER_PATH="$" + ) + + add_dependencies(${ODBC_TEST_NAME} ydb-odbc) + endfunction() +endif() diff --git a/odbc/CMakeLists.txt b/odbc/CMakeLists.txt new file mode 100644 index 00000000000..06386fd31dd --- /dev/null +++ b/odbc/CMakeLists.txt @@ -0,0 +1,58 @@ +add_library(ydb-odbc SHARED + src/utils/attr.cpp + src/utils/escape.cpp + src/utils/cursor.cpp + src/utils/types.cpp + src/utils/util.cpp + src/utils/convert.cpp + src/utils/error_manager.cpp + src/odbc_driver.cpp + src/connection_attr.cpp + src/connection.cpp + src/statement_attr.cpp + src/statement.cpp + src/environment.cpp +) + +target_include_directories(ydb-odbc + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${ODBC_INCLUDE_DIRS} +) + +target_link_libraries(ydb-odbc + PRIVATE + YDB-CPP-SDK::Query + YDB-CPP-SDK::Table + YDB-CPP-SDK::Scheme + YDB-CPP-SDK::Driver +) + +set_target_properties(ydb-odbc PROPERTIES + POSITION_INDEPENDENT_CODE ON +) + +install(TARGETS ydb-odbc + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +install(DIRECTORY include/ + DESTINATION include/ydb-odbc +) + +add_subdirectory(examples) +add_subdirectory(tests) + +include(GNUInstallDirs) + +install(FILES + odbcinst.ini + DESTINATION ${CMAKE_INSTALL_SYSCONFDIR}/odbcinst.d + RENAME ydb-odbc.ini +) + +install(FILES + odbc.ini + DESTINATION ${CMAKE_INSTALL_SYSCONFDIR} +) diff --git a/odbc/README.md b/odbc/README.md new file mode 100644 index 00000000000..c73f9b8704a --- /dev/null +++ b/odbc/README.md @@ -0,0 +1,80 @@ +# YDB ODBC Driver + +ODBC driver for YDB. + +## Requirements + +- CMake 3.10 or higher +- C/C++ compiler with C11 and C++20 support +- YDB C++ SDK +- unixODBC (for Linux/macOS) + +## Build + +```bash +cmake -DYDB_SDK_ODBC=1 --preset release-clang +cmake --build --preset default +``` + +## Configuration + +1. Make sure the driver is registered: +```bash +odbcinst -q -d +``` + +2. Check available data sources: +```bash +odbcinst -q -s +``` + +3. Edit `/etc/odbc.ini` to configure the connection: +```ini +[YDB] +Driver=YDB +Description=YDB Database Connection +Server=your-server:port +Database=/path/to/database +``` + +## Usage + +Example of connecting via isql: +```bash +isql -v YDB +``` + +Example usage in C: +```c +SQLHENV env; +SQLHDBC dbc; +SQLHSTMT stmt; + +// Initialize environment +SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); +SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + +// Connect +SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc); +SQLConnect(dbc, (SQLCHAR*)"YDB", SQL_NTS, + (SQLCHAR*)"", SQL_NTS, + (SQLCHAR*)"", SQL_NTS); + +// Execute query +SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt); +SQLExecDirect(stmt, (SQLCHAR*)"SELECT * FROM mytable", SQL_NTS); + +// Cleanup +SQLFreeHandle(SQL_HANDLE_STMT, stmt); +SQLDisconnect(dbc); +SQLFreeHandle(SQL_HANDLE_DBC, dbc); +SQLFreeHandle(SQL_HANDLE_ENV, env); +``` + +## Parameters + +Use names $p1, $p2, ... for parameter names + +## License + +Apache License 2.0 diff --git a/odbc/examples/CMakeLists.txt b/odbc/examples/CMakeLists.txt new file mode 100644 index 00000000000..88b1f27cc60 --- /dev/null +++ b/odbc/examples/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(basic) +add_subdirectory(scheme) diff --git a/odbc/examples/basic/CMakeLists.txt b/odbc/examples/basic/CMakeLists.txt new file mode 100644 index 00000000000..b99d1175f43 --- /dev/null +++ b/odbc/examples/basic/CMakeLists.txt @@ -0,0 +1,14 @@ +add_executable(odbc_basic + main.cpp +) + +target_link_libraries(odbc_basic + PRIVATE + ODBC::ODBC +) +target_compile_definitions(odbc_basic + PRIVATE + ODBC_DRIVER_PATH="$" +) + +add_dependencies(odbc_basic ydb-odbc) diff --git a/odbc/examples/basic/main.cpp b/odbc/examples/basic/main.cpp new file mode 100644 index 00000000000..8084e32f3d1 --- /dev/null +++ b/odbc/examples/basic/main.cpp @@ -0,0 +1,132 @@ +#include +#include + +#include + +void PrintOdbcError(SQLSMALLINT handleType, SQLHANDLE handle) { + SQLCHAR sqlState[6] = {0}; + SQLINTEGER nativeError = 0; + SQLCHAR message[256] = {0}; + SQLSMALLINT textLength = 0; + SQLGetDiagRec(handleType, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + std::cerr << "ODBC error: [" << sqlState << "] " << message << std::endl; +} + +int main() { + SQLHENV henv = nullptr; + SQLHDBC hdbc = nullptr; + SQLHSTMT hstmt = nullptr; + SQLRETURN ret; + + std::cout << "1. Allocating environment handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &henv); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating environment handle" << std::endl; + return 1; + } + SQLSetEnvAttr(henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + std::cout << "2. Allocating connection handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_DBC, henv, &hdbc); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating connection handle" << std::endl; + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "3. Building connection string" << std::endl; + std::string connStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + SQLCHAR outConnStr[1024] = {0}; + SQLSMALLINT outConnStrLen = 0; + + std::cout << "4. Connecting with SQLDriverConnect" << std::endl; + ret = SQLDriverConnect(hdbc, NULL, (SQLCHAR*)connStr.c_str(), SQL_NTS, + outConnStr, sizeof(outConnStr), &outConnStrLen, SQL_DRIVER_COMPLETE); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error connecting with SQLDriverConnect" << std::endl; + PrintOdbcError(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "5. Allocating statement handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_STMT, hdbc, &hstmt); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating statement handle" << std::endl; + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "6. Executing query" << std::endl; + SQLCHAR query[] = R"( + DECLARE $p1 AS Int64?; + SELECT id, data from test_table WHERE id == $p1; + )"; + + int64_t paramValue = 1; + SQLLEN paramInd = 0; + ret = SQLBindParameter(hstmt, 1, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, ¶mValue, 0, ¶mInd); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error binding parameter" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + ret = SQLExecDirect(hstmt, query, SQL_NTS); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error executing query" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "7. Fetching result" << std::endl; + + SQLLEN ind = 0; + int value1 = 0; + if (SQLBindCol(hstmt, 1, SQL_C_SLONG, &value1, 0, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 1" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + SQLCHAR value2[1024] = {0}; + if (SQLBindCol(hstmt, 2, SQL_C_CHAR, &value2, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 2" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + while ((ret = SQLFetch(hstmt)) == SQL_SUCCESS || ret == SQL_SUCCESS_WITH_INFO) { + if (ret != SQL_SUCCESS) { + std::cerr << "Error fetching result" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + std::cout << "Result column 1: " << value1 << std::endl; + std::cout << "Result column 2: " << value2 << std::endl; + + std::cout << "--------------------------------" << std::endl; + } + + std::cout << "8. Cleaning up" << std::endl; + + SQLCloseCursor(hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + + return 0; +} diff --git a/odbc/examples/scheme/CMakeLists.txt b/odbc/examples/scheme/CMakeLists.txt new file mode 100644 index 00000000000..ffab881aed5 --- /dev/null +++ b/odbc/examples/scheme/CMakeLists.txt @@ -0,0 +1,14 @@ +add_executable(odbc_scheme + main.cpp +) + +target_link_libraries(odbc_scheme + PRIVATE + ODBC::ODBC +) +target_compile_definitions(odbc_scheme + PRIVATE + ODBC_DRIVER_PATH="$" +) + +add_dependencies(odbc_scheme ydb-odbc) diff --git a/odbc/examples/scheme/main.cpp b/odbc/examples/scheme/main.cpp new file mode 100644 index 00000000000..3ae2cd6fe40 --- /dev/null +++ b/odbc/examples/scheme/main.cpp @@ -0,0 +1,116 @@ +#include +#include + +#include + +void PrintOdbcError(SQLSMALLINT handleType, SQLHANDLE handle) { + SQLCHAR sqlState[6] = {0}; + SQLINTEGER nativeError = 0; + SQLCHAR message[256] = {0}; + SQLSMALLINT textLength = 0; + SQLGetDiagRec(handleType, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + std::cerr << "ODBC error: [" << sqlState << "] " << message << std::endl; +} + +int main() { + SQLHENV henv = nullptr; + SQLHDBC hdbc = nullptr; + SQLHSTMT hstmt = nullptr; + SQLRETURN ret; + + std::cout << "1. Allocating environment handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &henv); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating environment handle" << std::endl; + return 1; + } + SQLSetEnvAttr(henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + std::cout << "2. Allocating connection handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_DBC, henv, &hdbc); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating connection handle" << std::endl; + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "3. Building connection string" << std::endl; + std::string connStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + SQLCHAR outConnStr[1024] = {0}; + SQLSMALLINT outConnStrLen = 0; + + std::cout << "4. Connecting with SQLDriverConnect" << std::endl; + ret = SQLDriverConnect(hdbc, NULL, (SQLCHAR*)connStr.c_str(), SQL_NTS, + outConnStr, sizeof(outConnStr), &outConnStrLen, SQL_DRIVER_COMPLETE); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error connecting with SQLDriverConnect" << std::endl; + PrintOdbcError(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "5. Allocating statement handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_STMT, hdbc, &hstmt); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating statement handle" << std::endl; + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "6. Getting tables" << std::endl; + + SQLCHAR pattern[] = "/local"; + SQLCHAR tableType[] = "TABLE"; + + ret = SQLTables(hstmt, NULL, 0, NULL, 0, pattern, SQL_NTS, tableType, SQL_NTS); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error executing query" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "7. Fetching result" << std::endl; + + SQLLEN ind = 0; + SQLCHAR value1[1024] = {0}; + if (SQLBindCol(hstmt, 3, SQL_C_CHAR, &value1, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 1" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + SQLCHAR value2[1024] = {0}; + if (SQLBindCol(hstmt, 4, SQL_C_CHAR, &value2, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 2" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + while ((ret = SQLFetch(hstmt)) == SQL_SUCCESS || ret == SQL_SUCCESS_WITH_INFO) { + if (ret != SQL_SUCCESS) { + std::cerr << "Error fetching result" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + std::cout << "Table name: " << value1 << std::endl; + std::cout << "Table type: " << value2 << std::endl; + + std::cout << "--------------------------------" << std::endl; + } + + std::cout << "8. Cleaning up" << std::endl; + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + + return 0; +} diff --git a/odbc/odbc.ini b/odbc/odbc.ini new file mode 100644 index 00000000000..6335b3ee389 --- /dev/null +++ b/odbc/odbc.ini @@ -0,0 +1,9 @@ +[ODBC Data Sources] +YDB=YDB ODBC Driver + +[YDB] +Driver=YDB +Description=YDB Database Connection +Server=grpc://localhost:2136 +Database=local +AuthMode=none \ No newline at end of file diff --git a/odbc/odbcinst.ini b/odbc/odbcinst.ini new file mode 100644 index 00000000000..fd0b3f27650 --- /dev/null +++ b/odbc/odbcinst.ini @@ -0,0 +1,4 @@ +[YDB] +Description=YDB ODBC Driver +Driver=/home/brgayazov/ydbwork/ydb-cpp-sdk/build/odbc/libydb-odbc.so +Setup=/home/brgayazov/ydbwork/ydb-cpp-sdk/build/odbc/libydb-odbc.so \ No newline at end of file diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp new file mode 100644 index 00000000000..85724670f9d --- /dev/null +++ b/odbc/src/connection.cpp @@ -0,0 +1,273 @@ +#include "connection.h" +#include "statement.h" +#include "utils/error_manager.h" + +#include +#include +#include + +#include +#include + +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + +struct TDriverKey { + std::string Endpoint; + std::string Database; + + bool operator==(const TDriverKey& other) const noexcept { + return Endpoint == other.Endpoint && Database == other.Database; + } +}; + +struct TDriverKeyHash { + size_t operator()(const TDriverKey& key) const noexcept { + return std::hash{}(key.Endpoint) ^ (std::hash{}(key.Database) << 1U); + } +}; + +struct TDriverPool { + std::unordered_map, TDriverKeyHash> DriversByKey; + size_t InsertionsSinceCleanup = 0; +}; + +void CleanupExpiredDrivers(TDriverPool& pool) { + for (auto mapIt = pool.DriversByKey.begin(); mapIt != pool.DriversByKey.end();) { + if (mapIt->second.expired()) { + mapIt = pool.DriversByKey.erase(mapIt); + } else { + ++mapIt; + } + } +} + +std::shared_ptr AcquireSharedDriver(const std::string& endpoint, const std::string& database) { + static TDriverPool pool; + TDriverKey key{endpoint, database}; + auto it = pool.DriversByKey.find(key); + if (it != pool.DriversByKey.end()) { + if (std::shared_ptr existing = it->second.lock()) { + return existing; + } + } + auto driver = std::make_shared( + NYdb::TDriverConfig().SetEndpoint(endpoint).SetDatabase(database)); + pool.DriversByKey[std::move(key)] = driver; + ++pool.InsertionsSinceCleanup; + if (pool.InsertionsSinceCleanup >= 32) { + CleanupExpiredDrivers(pool); + pool.InsertionsSinceCleanup = 0; + } + return driver; +} + +} // namespace + +SQLRETURN TConnection::DriverConnect(const std::string& connectionString) { + std::map params; + size_t pos = 0; + while (pos < connectionString.size()) { + size_t eq = connectionString.find('=', pos); + if (eq == std::string::npos) { + break; + } + + size_t sc = connectionString.find(';', eq); + std::string key = connectionString.substr(pos, eq-pos); + std::string val = connectionString.substr(eq+1, (sc == std::string::npos ? std::string::npos : sc-eq-1)); + params[key] = val; + if (sc == std::string::npos) { + break; + } + pos = sc+1; + } + Endpoint_ = params["Endpoint"]; + Database_ = params["Database"]; + + if (Endpoint_.empty() || Database_.empty()) { + throw TOdbcException("08001", 0, "Missing Endpoint or Database in connection string"); + } + + TConnectionAttributes::NormalizeCatalogPath(Database_); + RecreateYdbClients(); + Attributes_.SetCurrentCatalog(Database_); + + return SQL_SUCCESS; +} + +SQLRETURN TConnection::Connect(const std::string& serverName, + const std::string& userName, + const std::string& auth) { + + char endpoint[256] = {0}; + char database[256] = {0}; + + //SQLGetPrivateProfileString(serverName.c_str(), "Endpoint", "", endpoint, sizeof(endpoint), nullptr); + //SQLGetPrivateProfileString(serverName.c_str(), "Database", "", database, sizeof(database), nullptr); + + Endpoint_ = endpoint; + Database_ = database; + + if (Endpoint_.empty() || Database_.empty()) { + throw TOdbcException("08001", 0, "Missing Endpoint or Database in DSN"); + } + + TConnectionAttributes::NormalizeCatalogPath(Database_); + RecreateYdbClients(); + Attributes_.SetCurrentCatalog(Database_); + + return SQL_SUCCESS; +} + +SQLRETURN TConnection::Disconnect() { + QuerySession_.reset(); + Tx_.reset(); + YdbSchemeClient_.reset(); + YdbTableClient_.reset(); + YdbClient_.reset(); + YdbDriver_.reset(); + return SQL_SUCCESS; +} + +NQuery::TSession& TConnection::GetOrCreateQuerySession() { + if (!QuerySession_) { + auto sessionResult = YdbClient_->GetSession().ExtractValueSync(); + NStatusHelpers::ThrowOnError(sessionResult); + QuerySession_.emplace(std::move(sessionResult.GetSession())); + } + return *QuerySession_; +} + +std::unique_ptr TConnection::CreateStatement() { + return std::make_unique(this); +} + +void TConnection::RemoveStatement(TStatement* stmt) { + Statements_.erase(std::remove_if(Statements_.begin(), Statements_.end(), + [stmt](const std::unique_ptr& s) { return s.get() == stmt; }), Statements_.end()); +} + +SQLRETURN TConnection::SetAutocommit(bool value) { + Attributes_.SetAutocommit(value); + if (Attributes_.GetAutocommit() && Tx_) { + auto status = Tx_->Commit().ExtractValueSync(); + NStatusHelpers::ThrowOnError(status); + Tx_.reset(); + } + return SQL_SUCCESS; +} + +bool TConnection::GetAutocommit() const { + return Attributes_.GetAutocommit(); +} + +SQLRETURN TConnection::SetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength) { + if (attr == SQL_ATTR_CURRENT_CATALOG) { + std::optional rebindDatabase; + SQLRETURN rc = Attributes_.ApplyCatalogChange(value, stringLength, Database_, rebindDatabase, *this); + if (rc != SQL_SUCCESS) { + return rc; + } + if (rebindDatabase) { + RebindToDatabase(*rebindDatabase); + } + return SQL_SUCCESS; + } + return Attributes_.SetConnectAttr(attr, value, stringLength, [this](bool autocommit) { + return SetAutocommit(autocommit); + }, *this); +} + +SQLRETURN TConnection::GetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return Attributes_.GetConnectAttr(attr, value, bufferLength, stringLengthPtr, *this); +} + +NQuery::TTxSettings TConnection::MakeTxSettings() const { + return Attributes_.MakeTxSettings(); +} + +const std::optional& TConnection::GetTx() { + return Tx_; +} + +void TConnection::SetTx(const NQuery::TTransaction& tx) { + Tx_ = tx; +} + +void TConnection::ResetTx() { + Tx_.reset(); +} + +void TConnection::ResetQuerySession() { + QuerySession_.reset(); +} + +SQLRETURN TConnection::CommitTx() { + auto status = Tx_->Commit().ExtractValueSync(); + NStatusHelpers::ThrowOnError(status); + Tx_.reset(); + return SQL_SUCCESS; +} + +SQLRETURN TConnection::RollbackTx() { + auto status = Tx_->Rollback().ExtractValueSync(); + NStatusHelpers::ThrowOnError(status); + Tx_.reset(); + return SQL_SUCCESS; +} + +void TConnection::SetEnvironment(TEnvironment* env){ + if (ParentEnv_){ + throw std::logic_error("Connection already bound to environment"); + } + ParentEnv_ = env; +} + +TEnvironment* TConnection::GetEnvironment(){ + return ParentEnv_; +} + +void TConnection::RecreateYdbClients() { + QuerySession_.reset(); + Tx_.reset(); + YdbSchemeClient_.reset(); + YdbTableClient_.reset(); + YdbClient_.reset(); + YdbDriver_ = AcquireSharedDriver(Endpoint_, Database_); + YdbClient_ = std::make_unique(*YdbDriver_); + YdbSchemeClient_ = std::make_unique(*YdbDriver_); + YdbTableClient_ = std::make_unique(*YdbDriver_); +} + +void TConnection::RebindToDatabase(const std::string& newDatabase) { + std::string db = newDatabase; + TConnectionAttributes::NormalizeCatalogPath(db); + Database_ = std::move(db); + Attributes_.SetCurrentCatalog(Database_); + RecreateYdbClients(); +} + + +std::string TConnection::WrapQueryForCurrentCatalog(const std::string& sql) const { + std::optional rel = Attributes_.ResolveCatalogRoute(Database_).TablePathPrefix; + if (!rel) { + return sql; + } + std::string escapedPrefix; + escapedPrefix.reserve(rel->size() + 8); + for (const char ch : *rel) { + if (ch == '\\' || ch == '"') { + escapedPrefix.push_back('\\'); + } + escapedPrefix.push_back(ch); + } + return "PRAGMA TablePathPrefix = \"" + escapedPrefix + "\";\n" + sql; +} +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection.h b/odbc/src/connection.h new file mode 100644 index 00000000000..dac7721c000 --- /dev/null +++ b/odbc/src/connection.h @@ -0,0 +1,82 @@ +#pragma once + +#include "environment.h" +#include "connection_attr.h" +#include "utils/error_manager.h" + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TStatement; + +class TConnection : public TErrorManager { +private: + std::shared_ptr YdbDriver_; + std::unique_ptr YdbClient_; + std::unique_ptr YdbTableClient_; + std::unique_ptr YdbSchemeClient_; + std::optional Tx_; + std::optional QuerySession_; + + std::vector> Statements_; + std::string Endpoint_; + std::string Database_; + std::string AuthToken_; + TEnvironment* ParentEnv_; + + TConnectionAttributes Attributes_; + + void RecreateYdbClients(); + void RebindToDatabase(const std::string& newDatabase); +public: + SQLRETURN Connect(const std::string& serverName, + const std::string& userName, + const std::string& auth); + + SQLRETURN DriverConnect(const std::string& connectionString); + SQLRETURN Disconnect(); + + std::unique_ptr CreateStatement(); + void RemoveStatement(TStatement* stmt); + + NYdb::NQuery::TQueryClient* GetClient() { return YdbClient_.get(); } + NQuery::TSession& GetOrCreateQuerySession(); + NYdb::NTable::TTableClient* GetTableClient() { return YdbTableClient_.get(); } + NScheme::TSchemeClient* GetSchemeClient() { return YdbSchemeClient_.get(); } + + SQLRETURN SetAutocommit(bool value); + bool GetAutocommit() const; + + SQLRETURN SetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength); + SQLRETURN GetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); + NQuery::TTxSettings MakeTxSettings() const; + + std::string WrapQueryForCurrentCatalog(const std::string& sql) const; + + const std::optional& GetTx(); + void SetTx(const NQuery::TTransaction& tx); + void ResetTx(); + void ResetQuerySession(); + + SQLRETURN CommitTx(); + SQLRETURN RollbackTx(); + + void SetEnvironment(TEnvironment* env); + TEnvironment* GetEnvironment(); +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection_attr.cpp b/odbc/src/connection_attr.cpp new file mode 100644 index 00000000000..6197a4ad4a9 --- /dev/null +++ b/odbc/src/connection_attr.cpp @@ -0,0 +1,308 @@ + +#include "connection_attr.h" +#include "utils/attr.h" +#include "utils/diag.h" + +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + +namespace Catalog { + +void NormalizePath(std::string& path) { + if (path.empty() || path == "/") { + return; + } + const size_t trailingSlashStart = path.find_last_not_of('/'); + if (trailingSlashStart == std::string::npos) { + path.assign("/"); + return; + } + path.erase(trailingSlashStart + 1); +} + +TConnectionAttributes::TCatalogBinding BuildBinding(const std::string& currentCatalog, const std::string& database) { + TConnectionAttributes::TCatalogBinding binding; + binding.Catalog = currentCatalog; + binding.Database = database; + NormalizePath(binding.Catalog); + NormalizePath(binding.Database); + if (binding.Catalog == binding.Database) { + return binding; + } + + const std::string databasePrefix = binding.Database + "/"; + if (binding.Catalog.size() <= databasePrefix.size() || + binding.Catalog.compare(0, databasePrefix.size(), databasePrefix) != 0) { + return binding; + } + + std::string relativeCatalog = binding.Catalog.substr(databasePrefix.size()); + if (!relativeCatalog.empty()) { + binding.RelativeCatalog = std::move(relativeCatalog); + } + return binding; +} + +} // namespace Catalog + +namespace Tx { + +bool IsKnownTxnIsolation(SQLUINTEGER txnIsolation) { + switch (txnIsolation) { + case SQL_TXN_READ_UNCOMMITTED: + case SQL_TXN_READ_COMMITTED: + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return true; + default: + return false; + } +} + +std::optional ResolveTxMode(SQLUINTEGER accessMode, SQLUINTEGER txnIsolation) { + if (accessMode == SQL_MODE_READ_ONLY) { + switch (txnIsolation) { + case SQL_TXN_READ_UNCOMMITTED: + return NQuery::TTxSettings::TS_STALE_RO; + case SQL_TXN_READ_COMMITTED: + return NQuery::TTxSettings::TS_ONLINE_RO; + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return NQuery::TTxSettings::TS_SNAPSHOT_RO; + default: + return std::nullopt; + } + } + + switch (txnIsolation) { + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return NQuery::TTxSettings::TS_SERIALIZABLE_RW; + default: + return std::nullopt; + } +} + +} // namespace Tx + +namespace Autocommit { + +SQLRETURN Get(bool autocommitEnabled, SQLPOINTER value) { + auto* out = reinterpret_cast(value); + *out = autocommitEnabled ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; + return SQL_SUCCESS; +} + +} // namespace Autocommit + +} + +void TConnectionAttributes::NormalizeCatalogPath(std::string& path) { + Catalog::NormalizePath(path); +} + +SQLRETURN TConnectionAttributes::SetAutocommit(bool value) { + Autocommit_ = value; + return SQL_SUCCESS; +} + +bool TConnectionAttributes::GetAutocommit() const { + return Autocommit_; +} + +SQLRETURN TConnectionAttributes::SetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + const std::function& applyAutocommit, + TErrorManager& errors) { + switch (attr) { + case SQL_ATTR_AUTOCOMMIT: + return SetAutocommit(value, applyAutocommit, errors); + case SQL_ATTR_ACCESS_MODE: + return SetAccessMode(value, errors); + case SQL_ATTR_TXN_ISOLATION: + return SetTxnIsolation(value, errors); + case SQL_ATTR_CURRENT_CATALOG: + return SetCurrentCatalog(value, stringLength, errors); + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TConnectionAttributes::GetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + if (!value) { + return Diag::AddNullPointer(errors); + } + if (stringLengthPtr) { + *stringLengthPtr = 0; + } + switch (attr) { + case SQL_ATTR_AUTOCOMMIT: + return GetAutocommit(value); + case SQL_ATTR_ACCESS_MODE: + return GetAccessMode(value); + case SQL_ATTR_TXN_ISOLATION: + return GetTxnIsolation(value); + case SQL_ATTR_CURRENT_CATALOG: + return GetCurrentCatalog(value, bufferLength, stringLengthPtr, errors); + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TConnectionAttributes::SetAutocommit( + SQLPOINTER value, + const std::function& applyAutocommit, + TErrorManager& errors) { + const auto token = ReadIntegerAttrIfIn( + value, + {static_cast(SQL_AUTOCOMMIT_ON), static_cast(SQL_AUTOCOMMIT_OFF)}); + if (!token) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_AUTOCOMMIT"); + } + if (*token == static_cast(SQL_AUTOCOMMIT_ON)) { + return applyAutocommit(true); + } + return applyAutocommit(false); +} + +SQLRETURN TConnectionAttributes::SetAccessMode(SQLPOINTER value, TErrorManager& errors) { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_MODE_READ_WRITE, SQL_MODE_READ_ONLY}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_ACCESS_MODE"); + } + AccessMode_ = *mode; + auto txMode = Tx::ResolveTxMode(AccessMode_, TxnIsolation_); + if (!txMode) { + return errors.AddError( + "HYC00", + 0, + AccessMode_ == SQL_MODE_READ_WRITE + ? "Transaction isolation is not supported for read-write mode" + : "Transaction isolation is not supported for read-only mode"); + } + TxMode_ = *txMode; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::SetTxnIsolation(SQLPOINTER value, TErrorManager& errors) { + const SQLUINTEGER isolation = ReadIntegerAttr(value); + if (!Tx::IsKnownTxnIsolation(isolation)) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_TXN_ISOLATION"); + } + auto txMode = Tx::ResolveTxMode(AccessMode_, isolation); + if (!txMode) { + return errors.AddError("HYC00", 0, "SQL_ATTR_TXN_ISOLATION value is not supported"); + } + TxnIsolation_ = isolation; + TxMode_ = *txMode; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::SetCurrentCatalog(SQLPOINTER value, SQLINTEGER stringLength, TErrorManager& errors) { + if (!value) { + return Diag::AddNullPointer(errors); + } + CurrentCatalog_ = ReadAttributeString(value, stringLength); + Catalog::NormalizePath(CurrentCatalog_); + if (CurrentCatalog_.empty()) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_CURRENT_CATALOG"); + } + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::GetAutocommit(SQLPOINTER value) const { + return Autocommit::Get(Autocommit_, value); +} + +SQLRETURN TConnectionAttributes::GetAccessMode(SQLPOINTER value) const { + auto* out = reinterpret_cast(value); + *out = AccessMode_; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::GetTxnIsolation(SQLPOINTER value) const { + auto* out = reinterpret_cast(value); + *out = TxnIsolation_; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::GetCurrentCatalog( + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + return WriteAttributeString(CurrentCatalog_, value, bufferLength, stringLengthPtr, errors); +} + +NQuery::TTxSettings TConnectionAttributes::MakeTxSettings() const { + switch (TxMode_) { + case NQuery::TTxSettings::TS_ONLINE_RO: + return NQuery::TTxSettings::OnlineRO(); + case NQuery::TTxSettings::TS_STALE_RO: + return NQuery::TTxSettings::StaleRO(); + case NQuery::TTxSettings::TS_SNAPSHOT_RO: + return NQuery::TTxSettings::SnapshotRO(); + case NQuery::TTxSettings::TS_SNAPSHOT_RW: + return NQuery::TTxSettings::SnapshotRW(); + case NQuery::TTxSettings::TS_SERIALIZABLE_RW: + default: + return NQuery::TTxSettings::SerializableRW(); + } +} + +void TConnectionAttributes::SetCurrentCatalog(const std::string& value) { + CurrentCatalog_ = value; + Catalog::NormalizePath(CurrentCatalog_); +} + +const std::string& TConnectionAttributes::GetCurrentCatalog() const { + return CurrentCatalog_; +} + +TConnectionAttributes::TCatalogBinding TConnectionAttributes::BuildCatalogBinding(const std::string& database) const { + return Catalog::BuildBinding(CurrentCatalog_, database); +} + +TConnectionAttributes::TCatalogRoute TConnectionAttributes::ResolveCatalogRoute(const std::string& currentDatabase) const { + const TCatalogBinding binding = BuildCatalogBinding(currentDatabase); + if (binding.Catalog == binding.Database) { + return {binding.Database, std::nullopt}; + } + if (binding.RelativeCatalog) { + return {binding.Database, binding.Catalog}; + } + return {binding.Catalog, std::nullopt}; +} + +SQLRETURN TConnectionAttributes::ApplyCatalogChange( + SQLPOINTER value, + SQLINTEGER stringLength, + const std::string& currentDatabase, + std::optional& rebindDatabase, + TErrorManager& errors) { + SQLRETURN rc = SetCurrentCatalog(value, stringLength, errors); + if (rc != SQL_SUCCESS) { + return rc; + } + const TCatalogRoute route = ResolveCatalogRoute(currentDatabase); + if (route.EffectiveDatabase != currentDatabase) { + rebindDatabase = route.EffectiveDatabase; + } else { + rebindDatabase.reset(); + } + return SQL_SUCCESS; +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection_attr.h b/odbc/src/connection_attr.h new file mode 100644 index 00000000000..607f7bf929f --- /dev/null +++ b/odbc/src/connection_attr.h @@ -0,0 +1,86 @@ +#pragma once + +#include "utils/error_manager.h" + +#include + +#include +#include +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TConnectionAttributes { +public: + struct TCatalogBinding { + std::string Catalog; + std::string Database; + std::optional RelativeCatalog; + }; + + struct TCatalogRoute { + std::string EffectiveDatabase; + std::optional TablePathPrefix; + }; + + SQLRETURN SetAutocommit(bool value); + bool GetAutocommit() const; + + SQLRETURN SetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + const std::function& applyAutocommit, + TErrorManager& errors); + + SQLRETURN GetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + NQuery::TTxSettings MakeTxSettings() const; + void SetCurrentCatalog(const std::string& value); + const std::string& GetCurrentCatalog() const; + TCatalogBinding BuildCatalogBinding(const std::string& database) const; + TCatalogRoute ResolveCatalogRoute(const std::string& currentDatabase) const; + SQLRETURN ApplyCatalogChange( + SQLPOINTER value, + SQLINTEGER stringLength, + const std::string& currentDatabase, + std::optional& rebindDatabase, + TErrorManager& errors); + static void NormalizeCatalogPath(std::string& path); + +private: + SQLRETURN SetAutocommit( + SQLPOINTER value, + const std::function& applyAutocommit, + TErrorManager& errors); + SQLRETURN SetAccessMode(SQLPOINTER value, TErrorManager& errors); + SQLRETURN SetTxnIsolation(SQLPOINTER value, TErrorManager& errors); + SQLRETURN SetCurrentCatalog(SQLPOINTER value, SQLINTEGER stringLength, TErrorManager& errors); + + SQLRETURN GetAutocommit(SQLPOINTER value) const; + SQLRETURN GetAccessMode(SQLPOINTER value) const; + SQLRETURN GetTxnIsolation(SQLPOINTER value) const; + SQLRETURN GetCurrentCatalog( + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + bool Autocommit_ = true; + std::string CurrentCatalog_; + SQLUINTEGER AccessMode_ = SQL_MODE_READ_WRITE; + SQLUINTEGER TxnIsolation_ = SQL_TXN_SERIALIZABLE; + NQuery::TTxSettings::ETransactionMode TxMode_ = NQuery::TTxSettings::TS_SERIALIZABLE_RW; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/environment.cpp b/odbc/src/environment.cpp new file mode 100644 index 00000000000..44e3473d023 --- /dev/null +++ b/odbc/src/environment.cpp @@ -0,0 +1,82 @@ +#include "environment.h" +#include "connection.h" + +namespace NYdb { +namespace NOdbc { + +TEnvironment::TEnvironment() : OdbcVersion_(SQL_OV_ODBC3) {} +TEnvironment::~TEnvironment() {} + +SQLRETURN TEnvironment::SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + switch (attribute) { + case SQL_ATTR_ODBC_VERSION: { + if (!value) { + return AddError("HY009", 0, "Invalid use of null pointer"); + } + OdbcVersion_ = static_cast(reinterpret_cast(value)); + return SQL_SUCCESS; + } + case SQL_ATTR_OUTPUT_NTS: { + if (value && static_cast(reinterpret_cast(value)) != SQL_TRUE) { + return AddError("HY024", 0, "SQL_ATTR_OUTPUT_NTS must be SQL_TRUE"); + } + return SQL_SUCCESS; + } + default: + return AddError("HYC00", 0, "Optional feature not implemented"); + } +} + +void TEnvironment::RegisterConnection(TConnection* conn){ + if (conn == nullptr){ + throw std::invalid_argument("null connection"); + } + connections_.insert(conn); +} + +void TEnvironment::UnregisterConnection(TConnection* conn){ + if (conn == nullptr){ + throw std::invalid_argument("null connection"); + } + connections_.erase(conn); +} + +std::vector TEnvironment::GetConnectionsSnapshot() const { + return std::vector(connections_.begin(), connections_.end()); +} + +SQLRETURN TEnvironment::EndTran(SQLSMALLINT completionType){ + if (completionType != SQL_COMMIT && completionType != SQL_ROLLBACK){ + return AddError("HY012", 0, "Invalid transaction operation code"); + } + bool hasFailures = false; + int failedCount = 0; + + for (auto* conn : connections_) { + if (!conn || !conn->GetTx()) { + continue; + } + try { + if (completionType == SQL_COMMIT) { + conn->CommitTx(); + } else { + conn->RollbackTx(); + } + } catch (const std::exception& ex) { + hasFailures = true; + ++failedCount; + AddError("HY000", 0, ex.what(), SQL_SUCCESS_WITH_INFO); + } catch (...) { + hasFailures = true; + ++failedCount; + AddError("HY000", 0, "Unknown error during ENV-level transaction completion", SQL_SUCCESS_WITH_INFO); + } + } + if (hasFailures) { + AddError("01000", 0, "SQLEndTran(SQL_HANDLE_ENV): some connections failed", SQL_SUCCESS_WITH_INFO); + return SQL_SUCCESS_WITH_INFO; + } + return SQL_SUCCESS; +} +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/environment.h b/odbc/src/environment.h new file mode 100644 index 00000000000..70a785f45d7 --- /dev/null +++ b/odbc/src/environment.h @@ -0,0 +1,34 @@ +#pragma once + +#include "utils/error_manager.h" + +#include +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TConnection; + +class TEnvironment : public TErrorManager { +private: + SQLINTEGER OdbcVersion_; + std::unordered_set connections_; + +public: + TEnvironment(); + ~TEnvironment(); + + SQLRETURN SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength); + + void RegisterConnection(TConnection*); + void UnregisterConnection(TConnection*); + std::vector GetConnectionsSnapshot() const; + + SQLRETURN EndTran(SQLSMALLINT completionType); +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/odbc_driver.cpp b/odbc/src/odbc_driver.cpp new file mode 100644 index 00000000000..97d9ffd8a5a --- /dev/null +++ b/odbc/src/odbc_driver.cpp @@ -0,0 +1,412 @@ +#include "environment.h" +#include "connection.h" +#include "statement.h" + +#include "utils/util.h" +#include "utils/error_manager.h" + +#include +#include + +namespace { + template + Handle* GetHandle(SQLHANDLE handle) { + if (!handle) { + throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid handle", SQL_INVALID_HANDLE); + } + return static_cast(handle); + } +} + +extern "C" { + +SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handleType, + SQLHANDLE inputHandle, + SQLHANDLE* outputHandle) { + if (!outputHandle) { + return SQL_INVALID_HANDLE; + } + + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions( + inputHandle, + [&]() { + auto* const env = new NYdb::NOdbc::TEnvironment(); + *outputHandle = env; + env->SetLastReturnCode(SQL_SUCCESS); + return SQL_SUCCESS; + }, + NYdb::NOdbc::ENullInputHandlePolicy::Allow); + } + + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(inputHandle, [&](auto* env) { + auto conn = std::make_unique(); + conn->SetEnvironment(env); + env->RegisterConnection(conn.get()); + auto* const raw = conn.release(); + *outputHandle = raw; + raw->SetLastReturnCode(SQL_SUCCESS); + return SQL_SUCCESS; + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(inputHandle, [&](auto* conn) { + auto stmt = conn->CreateStatement(); + auto* const raw = stmt.release(); + *outputHandle = raw; + raw->SetLastReturnCode(SQL_SUCCESS); + return SQL_SUCCESS; + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLFreeHandle(SQLSMALLINT handleType, SQLHANDLE handle) { + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [](auto* env) { + delete env; + return SQL_SUCCESS; + }); + } + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [](auto* conn) { + auto* env = conn->GetEnvironment(); + if (env != nullptr){ + env->UnregisterConnection(conn); + } + delete conn; + return SQL_SUCCESS; + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [](auto* stmt) { + if (stmt->GetConnection()) { + stmt->GetConnection()->RemoveStatement(stmt); + } + delete stmt; + return SQL_SUCCESS; + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLSetEnvAttr(SQLHENV environmentHandle, + SQLINTEGER attribute, + SQLPOINTER value, + SQLINTEGER stringLength) { + auto env = static_cast(environmentHandle); + if (!env) { + return SQL_INVALID_HANDLE; + } + + return NYdb::NOdbc::HandleOdbcExceptions(env, [&]() { + return env->SetAttribute(attribute, value, stringLength); + }); +} + +SQLRETURN SQL_API SQLDriverConnect(SQLHDBC connectionHandle, + SQLHWND /*WindowHandle*/, + SQLCHAR* inConnectionString, + SQLSMALLINT stringLength1, + SQLCHAR* /*outConnectionString*/, + SQLSMALLINT /*bufferLength*/, + SQLSMALLINT* /*stringLength2Ptr*/, + SQLUSMALLINT /*driverCompletion*/) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->DriverConnect(NYdb::NOdbc::GetString(inConnectionString, stringLength1)); + }); +} + +SQLRETURN SQL_API SQLConnect(SQLHDBC connectionHandle, + SQLCHAR* serverName, SQLSMALLINT nameLength1, + SQLCHAR* userName, SQLSMALLINT nameLength2, + SQLCHAR* authentication, SQLSMALLINT nameLength3) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->Connect(NYdb::NOdbc::GetString(serverName, nameLength1), + NYdb::NOdbc::GetString(userName, nameLength2), + NYdb::NOdbc::GetString(authentication, nameLength3)); + }); +} + +SQLRETURN SQL_API SQLDisconnect(SQLHDBC connectionHandle) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->Disconnect(); + }); +} + +SQLRETURN SQL_API SQLExecDirect(SQLHSTMT statementHandle, + SQLCHAR* statementText, + SQLINTEGER textLength) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + auto ret = stmt->Prepare(NYdb::NOdbc::GetString(statementText, textLength)); + if (ret != SQL_SUCCESS) { + return ret; + } + return stmt->Execute(); + }); +} + +SQLRETURN SQL_API SQLPrepare(SQLHSTMT statementHandle, + SQLCHAR* statementText, + SQLINTEGER textLength) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Prepare(NYdb::NOdbc::GetString(statementText, textLength)); + }); +} + +SQLRETURN SQL_API SQLExecute(SQLHSTMT statementHandle) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Execute(); + }); +} + +SQLRETURN SQL_API SQLFetch(SQLHSTMT statementHandle) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Fetch(); + }); +} + +SQLRETURN SQL_API SQLGetData(SQLHSTMT statementHandle, + SQLUSMALLINT columnNumber, + SQLSMALLINT targetType, + SQLPOINTER targetValue, + SQLLEN bufferLength, + SQLLEN* strLenOrInd) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->GetData(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); + }); +} + +SQLRETURN SQL_API SQLBindCol(SQLHSTMT statementHandle, + SQLUSMALLINT columnNumber, + SQLSMALLINT targetType, + SQLPOINTER targetValue, + SQLLEN bufferLength, + SQLLEN* strLenOrInd) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->BindCol(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); + }); +} + +SQLRETURN SQL_API SQLGetDiagRec(SQLSMALLINT handleType, + SQLHANDLE handle, + SQLSMALLINT recNumber, + SQLCHAR* sqlState, + SQLINTEGER* nativeError, + SQLCHAR* messageText, + SQLSMALLINT bufferLength, + SQLSMALLINT* textLength) { + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* env) { + return env->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + }); + } + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* conn) { + return conn->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* stmt) { + return stmt->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLGetDiagField(SQLSMALLINT handleType, + SQLHANDLE handle, + SQLSMALLINT recNumber, + SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* env) { + return env->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* conn) { + return conn->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* stmt) { + return stmt->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLBindParameter(SQLHSTMT statementHandle, + SQLUSMALLINT paramNumber, + SQLSMALLINT inputOutputType, + SQLSMALLINT valueType, + SQLSMALLINT parameterType, + SQLULEN columnSize, + SQLSMALLINT decimalDigits, + SQLPOINTER parameterValuePtr, + SQLLEN bufferLength, + SQLLEN* strLenOrIndPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->BindParameter(paramNumber, inputOutputType, valueType, parameterType, columnSize, decimalDigits, parameterValuePtr, bufferLength, strLenOrIndPtr); + }); +} + +SQLRETURN SQL_API SQLEndTran(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT completionType) { + switch (handleType) { + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* conn) { + if (completionType == SQL_COMMIT) { + return conn->CommitTx(); + } else if (completionType == SQL_ROLLBACK) { + return conn->RollbackTx(); + } else { + throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid completion type"); + } + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* stmt) -> SQLRETURN { + auto conn = stmt->GetConnection(); + if (!conn) return SQL_INVALID_HANDLE; + if (completionType == SQL_COMMIT) { + return conn->CommitTx(); + } else if (completionType == SQL_ROLLBACK) { + return conn->RollbackTx(); + } else { + throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid completion type"); + } + }); + } + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* env) -> SQLRETURN { + return env->EndTran(completionType); + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLSetConnectAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->SetConnectAttr(attribute, value, stringLength); + }); +} + +SQLRETURN SQL_API SQLGetConnectAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->GetConnectAttr(attribute, value, bufferLength, stringLengthPtr); + }); +} + +SQLRETURN SQL_API SQLColumns(SQLHSTMT statementHandle, + SQLCHAR* catalogName, SQLSMALLINT nameLength1, + SQLCHAR* schemaName, SQLSMALLINT nameLength2, + SQLCHAR* tableName, SQLSMALLINT nameLength3, + SQLCHAR* columnName, SQLSMALLINT nameLength4) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Columns( + NYdb::NOdbc::GetString(catalogName, nameLength1), + NYdb::NOdbc::GetString(schemaName, nameLength2), + NYdb::NOdbc::GetString(tableName, nameLength3), + NYdb::NOdbc::GetString(columnName, nameLength4)); + }); +} + +SQLRETURN SQL_API SQLTables(SQLHSTMT statementHandle, + SQLCHAR* catalogName, SQLSMALLINT nameLength1, + SQLCHAR* schemaName, SQLSMALLINT nameLength2, + SQLCHAR* tableName, SQLSMALLINT nameLength3, + SQLCHAR* tableType, SQLSMALLINT nameLength4) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Tables( + NYdb::NOdbc::GetString(catalogName, nameLength1), + NYdb::NOdbc::GetString(schemaName, nameLength2), + NYdb::NOdbc::GetString(tableName, nameLength3), + NYdb::NOdbc::GetString(tableType, nameLength4)); + }); +} + +SQLRETURN SQL_API SQLCloseCursor(SQLHSTMT statementHandle) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Close(false); + }); +} + +SQLRETURN SQL_API SQLFreeStmt(SQLHSTMT statementHandle, SQLUSMALLINT option) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) -> SQLRETURN { + switch (option) { + case SQL_CLOSE: + return stmt->Close(true); + case SQL_DROP: + return SQLFreeHandle(SQL_HANDLE_STMT, statementHandle); + case SQL_UNBIND: + stmt->UnbindColumns(); + return SQL_SUCCESS; + case SQL_RESET_PARAMS: + stmt->ResetParams(); + return SQL_SUCCESS; + default: + throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid option"); + } + }); +} + +SQLRETURN SQL_API SQLFetchScroll(SQLHSTMT statementHandle, SQLSMALLINT fetchOrientation, SQLLEN fetchOffset) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + if (fetchOrientation == SQL_FETCH_NEXT) { + return stmt->Fetch(); + } else { + throw NYdb::NOdbc::TOdbcException("HYC00", 0, "Only SQL_FETCH_NEXT is supported"); + } + //TODO other fetch-orientation + }); +} + +SQLRETURN SQL_API SQLRowCount(SQLHSTMT statementHandle, SQLLEN* rowCount) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->RowCount(rowCount); + }); +} + +SQLRETURN SQL_API SQLNumResultCols(SQLHSTMT statementHandle, SQLSMALLINT* colCount) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->NumResultCols(colCount); + }); +} + +SQLRETURN SQL_API SQLSetStmtAttr(SQLHSTMT statementHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->SetStmtAttr(attribute, value, stringLength); + }); +} + +SQLRETURN SQL_API SQLGetStmtAttr( + SQLHSTMT statementHandle, + SQLINTEGER attribute, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->GetStmtAttr(attribute, value, bufferLength, stringLengthPtr); + }); +} + +} diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp new file mode 100644 index 00000000000..5d6bb38f152 --- /dev/null +++ b/odbc/src/statement.cpp @@ -0,0 +1,538 @@ +#include "statement.h" + +#include "utils/convert.h" +#include "utils/types.h" +#include "utils/error_manager.h" +#include "utils/escape.h" +#include "utils/sql_like.h" + +#include +#include +#include +#include + +#include + +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + NYdb::TStatus StatusFrom(const NYdb::TStatus& ydb_status) { + return NYdb::TStatus(ydb_status.GetStatus(), NYdb::NIssue::TIssues(ydb_status.GetIssues())); + } + + + NYdb::TStatus PrefetchFirstPartStatus(NQuery::TExecuteQueryIterator& iterator, std::optional* prefetchedResultPart){ + prefetchedResultPart->reset(); + while (true) { + auto part = iterator.ReadNext().ExtractValueSync(); + if (part.EOS()) { + break; + } + if (!part.IsSuccess()) { + return StatusFrom(part); + + } + if (part.HasResultSet()) { + prefetchedResultPart->emplace(std::move(part)); + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + } + } + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + } +} + +TStatement::TStatement(TConnection* conn) + : Conn_(conn) {} + +SQLRETURN TStatement::Prepare(const std::string& statementText) { + StreamFetchError_ = false; + RowsFetched_ = 0; + Cursor_.reset(); + PreparedQuery_ = statementText; + IsPrepared_ = true; + return SQL_SUCCESS; +} + +SQLRETURN TStatement::Execute() { + if (!IsPrepared_ || PreparedQuery_.empty()) { + throw TOdbcException("HY007", 0, "No prepared statement"); + } + StreamFetchError_ = false; + RowsFetched_ = 0; + Cursor_.reset(); + auto* client = Conn_->GetClient(); + if (!client) { + throw TOdbcException("HY000", 0, "No client connection"); + } + NYdb::TParams params = BuildParams(); + + std::optional iterator; + std::optional prefetchedResultPart; + + if (Conn_->GetAutocommit()){ + Conn_->ResetTx(); + Conn_->ResetQuerySession(); + const NYdb::NRetry::TRetryOperationSettings retrySettings = + MakeAutocommitRetrySettings(); + + NYdb::TStatus execStatus = client->RetryQuerySync( + [this, ¶ms, &iterator, &prefetchedResultPart](NQuery::TSession session) -> NYdb::TStatus{ + auto retry_iterator = CreateExecuteIterator(session, params); + if (!retry_iterator.IsSuccess()) { + return StatusFrom(retry_iterator); + } + std::optional retry_prefetched; + const NYdb::TStatus prefetchStatus = PrefetchFirstPartStatus(retry_iterator, &retry_prefetched); + if (!prefetchStatus.IsSuccess()) { + return prefetchStatus; + } + iterator.emplace(std::move(retry_iterator)); + prefetchedResultPart = std::move(retry_prefetched); + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + }, retrySettings); + + NStatusHelpers::ThrowOnError(execStatus); + } else { + NQuery::TSession& session = Conn_->GetOrCreateQuerySession(); + iterator.emplace(CreateExecuteIterator(session, params)); + NStatusHelpers::ThrowOnError(*iterator); + NStatusHelpers::ThrowOnError(PrefetchFirstPartStatus(*iterator, &prefetchedResultPart)); + } + + if (prefetchedResultPart) { + Cursor_ = CreateExecCursor(this, std::move(*iterator), std::move(prefetchedResultPart)); + } else { + Cursor_.reset(); + } + IsPrepared_ = false; + PreparedQuery_.clear(); + return SQL_SUCCESS; +} + +NYdb::NRetry::TRetryOperationSettings TStatement::MakeAutocommitRetrySettings() { + NYdb::NRetry::TRetryOperationSettings settings; + SQLUINTEGER queryTimeoutSec = Attributes_.GetQueryTimeoutSec(); + if (queryTimeoutSec > 0) { + const TDuration deadline = TDuration::Seconds(queryTimeoutSec); + settings.MaxTimeout(deadline).GetSessionClientTimeout(deadline); + } + return settings; +} + +NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params){ + const std::string sqlText = Attributes_.GetNoScanMode() == SQL_NOSCAN_ON + ? PreparedQuery_ + : RewriteOdbcEscapes(PreparedQuery_); + const std::string queryText = Conn_->WrapQueryForCurrentCatalog(sqlText); + NQuery::TExecuteQuerySettings execSettings; + const SQLUINTEGER queryTimeoutSec = Attributes_.GetQueryTimeoutSec(); + execSettings.ClientTimeout(TDuration::Seconds(queryTimeoutSec)); + if (Conn_->GetAutocommit()) { + const auto txSettings = Conn_->MakeTxSettings(); + if (txSettings.GetMode() == NQuery::TTxSettings::TS_SERIALIZABLE_RW) { + return session.StreamExecuteQuery( + queryText, + NQuery::TTxControl::NoTx(), + params, + execSettings).ExtractValueSync(); + } + return session.StreamExecuteQuery( + queryText, + NQuery::TTxControl::BeginTx(txSettings).CommitTx(), + params, + execSettings).ExtractValueSync(); + } + if (!Conn_->GetTx()) { + auto beginTxResult = session.BeginTransaction(Conn_->MakeTxSettings()).ExtractValueSync(); + NStatusHelpers::ThrowOnError(beginTxResult); + Conn_->SetTx(beginTxResult.GetTransaction()); + } + return session.StreamExecuteQuery( + queryText, + NQuery::TTxControl::Tx(*Conn_->GetTx()).CommitTx(false), + params, + execSettings).ExtractValueSync(); +} + + + +SQLRETURN TStatement::Fetch() { + if (!Cursor_) { + Cursor_.reset(); + return SQL_NO_DATA; + } + const SQLULEN maxRows = Attributes_.GetMaxRows(); + if (maxRows > 0 && RowsFetched_ >= maxRows) { + return SQL_NO_DATA; + } + StreamFetchError_ = false; + if (!Cursor_->Fetch()) { + return StreamFetchError_ ? SQL_ERROR : SQL_NO_DATA; + } + ++RowsFetched_; + return SQL_SUCCESS; +} + +void TStatement::OnStreamPartError(const TStatus& status) { + ClearErrors(); + AddError(status); + StreamFetchError_ = true; +} + +SQLRETURN TStatement::GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (!Cursor_) { + return SQL_NO_DATA; + } + return Cursor_->GetData(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); +} + +void TStatement::FillBoundColumns() { + if (!Cursor_) { + return; + } + for (const auto& col : BoundColumns_) { + Cursor_->GetData(col.ColumnNumber, col.TargetType, col.TargetValue, col.BufferLength, col.StrLenOrInd); + } +} + +SQLRETURN TStatement::BindCol(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (!Cursor_) { + return SQL_NO_DATA; + } + + BoundColumns_.erase(std::remove_if(BoundColumns_.begin(), BoundColumns_.end(), + [columnNumber](const TBoundColumn& col) { return col.ColumnNumber == columnNumber; }), BoundColumns_.end()); + + if (!targetValue) { + return SQL_SUCCESS; + } + BoundColumns_.push_back({columnNumber, targetType, targetValue, bufferLength, strLenOrInd}); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::BindParameter(SQLUSMALLINT paramNumber, + SQLSMALLINT inputOutputType, + SQLSMALLINT valueType, + SQLSMALLINT parameterType, + SQLULEN columnSize, + SQLSMALLINT decimalDigits, + SQLPOINTER parameterValuePtr, + SQLLEN bufferLength, + SQLLEN* strLenOrIndPtr) { + + if (inputOutputType != SQL_PARAM_INPUT) { + throw TOdbcException("HYC00", 0, "Only input parameters are supported"); + } + + BoundParams_.erase(std::remove_if(BoundParams_.begin(), BoundParams_.end(), + [paramNumber](const TBoundParam& p) { return p.ParamNumber == paramNumber; }), BoundParams_.end()); + + if (!parameterValuePtr) { + return SQL_SUCCESS; + } + BoundParams_.push_back({paramNumber, inputOutputType, valueType, parameterType, columnSize, decimalDigits, parameterValuePtr, bufferLength, strLenOrIndPtr}); + return SQL_SUCCESS; +} + +NYdb::TParams TStatement::BuildParams() { + ClearErrors(); + NYdb::TParamsBuilder paramsBuilder; + for (const auto& param : BoundParams_) { + std::string paramName = "$p" + std::to_string(param.ParamNumber); + ConvertParam(param, paramsBuilder.AddParam(paramName)); + } + + return paramsBuilder.Build(); +} + +SQLRETURN TStatement::Columns(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& columnName) { + ClearErrors(); + RowsFetched_ = 0; + Cursor_.reset(); + + std::vector columns = { + {"TABLE_CAT", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_SCHEM", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"COLUMN_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"DATA_TYPE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"TYPE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"COLUMN_SIZE", SQL_INTEGER, 0, SQL_NULLABLE}, + {"BUFFER_LENGTH", SQL_INTEGER, 0, SQL_NULLABLE}, + {"DECIMAL_DIGITS", SQL_INTEGER, 0, SQL_NULLABLE}, + {"NUM_PREC_RADIX", SQL_INTEGER, 0, SQL_NULLABLE}, + {"NULLABLE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"REMARKS", SQL_VARCHAR, 762, SQL_NULLABLE}, + {"COLUMN_DEF", SQL_VARCHAR, 254, SQL_NULLABLE}, + {"SQL_DATA_TYPE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"SQL_DATETIME_SUB", SQL_INTEGER, 0, SQL_NULLABLE}, + {"CHAR_OCTET_LENGTH", SQL_INTEGER, 0, SQL_NULLABLE}, + {"ORDINAL_POSITION", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"IS_NULLABLE", SQL_VARCHAR, 254, SQL_NO_NULLS} + }; + + auto entries = GetPatternEntries(tableName); + if (entries.empty()) { + throw TOdbcException("HYC00", 0, "No tables found"); + } + + TTable table; + table.reserve(entries.size()); + + for (const auto& entry : entries) { + if (entry.Type != NScheme::ESchemeEntryType::Table && + entry.Type != NScheme::ESchemeEntryType::ColumnTable) { + continue; + } + + auto status = Conn_->GetTableClient()->RetryOperationSync([this, path = entry.Name, &table, &columnName](NTable::TSession session) -> TStatus { + auto result = session.DescribeTable(path).ExtractValueSync(); + NStatusHelpers::ThrowOnError(result); + + auto columns = result.GetTableDescription().GetTableColumns(); + + auto columnIt = std::find_if(columns.begin(), columns.end(), [&](const NTable::TTableColumn& column) { + if (Attributes_.GetMetadataId() == SQL_TRUE) { + return column.Name == columnName; + } + if (columnName.empty()) { + return column.Name.empty(); + } + return SqlLikeMatch(column.Name, columnName); + }); + + if (columnIt == columns.end()) { + throw TOdbcException("42S22", 0, "Column not found", SQL_ERROR); + } + + auto column = *columnIt; + + TTypeParser typeParser(column.Type); + + table.push_back({ + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Utf8(path).Build(), + TValueBuilder().Utf8(column.Name).Build(), + TValueBuilder().Int16(GetTypeId(column.Type)).Build(), + TValueBuilder().Utf8(column.Type.ToString()).Build(), + TValueBuilder().OptionalInt32(std::nullopt).Build(), + TValueBuilder().OptionalInt32(std::nullopt).Build(), + TValueBuilder().OptionalInt16(GetDecimalDigits(column.Type)).Build(), + TValueBuilder().OptionalInt16(GetRadix(column.Type)).Build(), + TValueBuilder().Int16(column.NotNull && *column.NotNull ? SQL_NO_NULLS : SQL_NULLABLE).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Int16(GetTypeId(column.Type)).Build(), + TValueBuilder().OptionalInt16(std::nullopt).Build(), + TValueBuilder().OptionalInt32(8).Build(), + TValueBuilder().OptionalInt32(columnIt - columns.begin() + 1).Build(), + TValueBuilder().Utf8(column.NotNull && *column.NotNull ? "NO" : "YES").Build(), + }); + return TStatus(EStatus::SUCCESS, {}); + }); + + NStatusHelpers::ThrowOnError(status); + } + + Cursor_ = CreateVirtualCursor(this, columns, table); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::Tables(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& tableType) { + ClearErrors(); + RowsFetched_ = 0; + Cursor_.reset(); + + std::vector columns = { + {"TABLE_CAT", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_SCHEM", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"TABLE_TYPE", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"REMARKS", SQL_VARCHAR, 254, SQL_NULLABLE} + }; + + auto entries = GetPatternEntries(tableName); + if (entries.empty()) { + throw TOdbcException("HYC00", 0, "No tables found"); + } + + TTable table; + table.reserve(entries.size()); + + for (const auto& entry : entries) { + auto tableType = GetTableType(entry.Type); + if (!tableType) { + continue; + } + + table.push_back({ + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Utf8(entry.Name).Build(), + TValueBuilder().Utf8(*tableType).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + }); + } + + Cursor_ = CreateVirtualCursor(this, columns, table); + return SQL_SUCCESS; +} + +std::vector TStatement::GetPatternEntries(const std::string& pattern) { + std::vector entries; + VisitEntry("", pattern, entries); + return entries; +} + +SQLRETURN TStatement::VisitEntry(const std::string& path, const std::string& pattern, std::vector& resultEntries) { + auto schemeClient = Conn_->GetSchemeClient(); + auto listDirectoryResult = schemeClient->ListDirectory(path + "/").ExtractValueSync(); + NStatusHelpers::ThrowOnError(listDirectoryResult); + + for (const auto& entry : listDirectoryResult.GetChildren()) { + std::string fullPath = path + "/" + entry.Name; + if (entry.Type == NScheme::ESchemeEntryType::Directory || + entry.Type == NScheme::ESchemeEntryType::SubDomain) { + VisitEntry(fullPath, pattern, resultEntries); + } else if (IsPatternMatch(fullPath, pattern)) { + NScheme::TSchemeEntry entryCopy = entry; + entryCopy.Name = fullPath; + resultEntries.push_back(entryCopy); + } + } + return SQL_SUCCESS; +} + +bool TStatement::IsPatternMatch(const std::string& path, const std::string& pattern) { + if (pattern.empty()) { + return true; + } + if (Attributes_.GetMetadataId() == SQL_TRUE) { + return path == pattern; + } + return SqlLikeMatch(path, pattern); +} + +std::optional TStatement::GetTableType(NScheme::ESchemeEntryType type) { + switch (type) { + case NScheme::ESchemeEntryType::Table: + return "TABLE"; + case NScheme::ESchemeEntryType::View: + return "VIEW"; + case NScheme::ESchemeEntryType::ColumnStore: + return "COLUMN_STORE"; + case NScheme::ESchemeEntryType::ColumnTable: + return "COLUMN_TABLE"; + case NScheme::ESchemeEntryType::Sequence: + return "SEQUENCE"; + case NScheme::ESchemeEntryType::Replication: + return "REPLICATION"; + case NScheme::ESchemeEntryType::Topic: + return "TOPIC"; + case NScheme::ESchemeEntryType::ExternalTable: + return "EXTERNAL_TABLE"; + case NScheme::ESchemeEntryType::ExternalDataSource: + return "EXTERNAL_DATA_SOURCE"; + case NScheme::ESchemeEntryType::ResourcePool: + return "RESOURCE_POOL"; + case NScheme::ESchemeEntryType::PqGroup: + return "PQ_GROUP"; + case NScheme::ESchemeEntryType::RtmrVolume: + return "RTMR_VOLUME"; + case NScheme::ESchemeEntryType::BlockStoreVolume: + return "BLOCK_STORE_VOLUME"; + case NScheme::ESchemeEntryType::CoordinationNode: + return "COORDINATION_NODE"; + case NScheme::ESchemeEntryType::Unknown: + return "UNKNOWN"; + case NScheme::ESchemeEntryType::SysView: + return "SYSTEM VIEW"; + case NScheme::ESchemeEntryType::Transfer: + return "TRANSFER"; + case NScheme::ESchemeEntryType::Directory: + case NScheme::ESchemeEntryType::SubDomain: + return std::nullopt; + default: + return std::nullopt; + } +} + +SQLRETURN TStatement::Close(bool force) { + if (!force && !Cursor_) { + throw TOdbcException("24000", 0, "Invalid handle"); + } + + Cursor_.reset(); + RowsFetched_ = 0; + PreparedQuery_.clear(); + IsPrepared_ = false; + ClearErrors(); + return SQL_SUCCESS; +} + +void TStatement::UnbindColumns() { + BoundColumns_.clear(); +} + +void TStatement::ResetParams() { + BoundParams_.clear(); +} + +SQLRETURN TStatement::RowCount(SQLLEN* rowCount) { + if (!rowCount) { + throw TOdbcException("HY000", 0, "Invalid parameter"); + } + + *rowCount = -1; + return SQL_SUCCESS; +} + +SQLRETURN TStatement::NumResultCols(SQLSMALLINT* colCount) { + if (!colCount) { + throw TOdbcException("HY000", 0, "Invalid parameter"); + } + if (!Cursor_) { + *colCount = 0; + return SQL_SUCCESS; + } + *colCount = static_cast(Cursor_->GetColumnMeta().size()); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::SetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength) { + return Attributes_.SetStmtAttr(attr, value, stringLength, *this); +} + +SQLRETURN TStatement::GetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr) { + return Attributes_.GetStmtAttr(attr, value, bufferLength, stringLengthPtr, *this); +} + +SQLRETURN TStatement::GetDiagField( + SQLSMALLINT recNumber, + SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + if (recNumber == 0 && diagIdentifier == SQL_DIAG_ROW_COUNT) { + if (!diagInfoPtr) { + return SQL_ERROR; + } + *reinterpret_cast(diagInfoPtr) = -1; + return SQL_SUCCESS; + } + return TErrorManager::GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/statement.h b/odbc/src/statement.h new file mode 100644 index 00000000000..9f2eb8ade64 --- /dev/null +++ b/odbc/src/statement.h @@ -0,0 +1,89 @@ +#pragma once + +#include "connection.h" +#include "statement_attr.h" +#include "utils/error_manager.h" +#include "utils/bindings.h" +#include "utils/cursor.h" + +#include + +#include +#include + +#include +#include +#include + + +namespace NYdb { +namespace NOdbc { + +class TStatement : public TErrorManager, public IBindingFiller { +public: + TStatement(TConnection* conn); + + SQLRETURN Prepare(const std::string& statementText); + SQLRETURN Execute(); + + SQLRETURN Fetch(); + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + + void FillBoundColumns() override; + void OnStreamPartError(const TStatus& status) override; + + SQLRETURN Close(bool force = false); + void UnbindColumns(); + void ResetParams(); + + SQLRETURN BindCol(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + SQLRETURN BindParameter(SQLUSMALLINT paramNumber, SQLSMALLINT inputOutputType, SQLSMALLINT valueType, SQLSMALLINT parameterType, SQLULEN columnSize, SQLSMALLINT decimalDigits, SQLPOINTER parameterValuePtr, SQLLEN bufferLength, SQLLEN* strLenOrIndPtr); + + SQLRETURN Columns(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& columnName); + + SQLRETURN Tables(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& tableType); + + SQLRETURN RowCount(SQLLEN* rowCount); + SQLRETURN NumResultCols(SQLSMALLINT* colCount); + SQLRETURN SetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength); + SQLRETURN GetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); + + SQLRETURN GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) override; + + TConnection* GetConnection() { + return Conn_; + } + +private: + TConnection* Conn_; + std::unique_ptr Cursor_; + std::string PreparedQuery_; + bool IsPrepared_ = false; + + std::vector BoundColumns_; + std::vector BoundParams_; + bool StreamFetchError_ = false; + SQLULEN RowsFetched_ = 0; + TStatementAttributes Attributes_; + + NYdb::TParams BuildParams(); + + NQuery::TExecuteQueryIterator CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params); + + NYdb::NRetry::TRetryOperationSettings MakeAutocommitRetrySettings(); + std::vector GetPatternEntries(const std::string& pattern); + SQLRETURN VisitEntry(const std::string& path, const std::string& pattern, std::vector& resultEntries); + bool IsPatternMatch(const std::string& path, const std::string& pattern); + std::optional GetTableType(NScheme::ESchemeEntryType type); +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/statement_attr.cpp b/odbc/src/statement_attr.cpp new file mode 100644 index 00000000000..f0baad0016a --- /dev/null +++ b/odbc/src/statement_attr.cpp @@ -0,0 +1,101 @@ +#include "statement_attr.h" + +#include "utils/attr.h" +#include "utils/diag.h" + +#include + +namespace NYdb { +namespace NOdbc { + +SQLRETURN TStatementAttributes::SetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER /*stringLength*/, + TErrorManager& errors) { + switch (attr) { + case SQL_ATTR_QUERY_TIMEOUT: { + const SQLINTEGER timeout = ReadIntegerAttr(value); + if (timeout < 0) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_QUERY_TIMEOUT"); + } + QueryTimeoutSec_ = static_cast(timeout); + return SQL_SUCCESS; + } + case SQL_ATTR_MAX_ROWS: { + const SQLLEN maxRows = ReadIntegerAttr(value); + if (maxRows < 0) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_MAX_ROWS"); + } + MaxRows_ = static_cast(maxRows); + return SQL_SUCCESS; + } + case SQL_ATTR_NOSCAN: { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_NOSCAN_OFF, SQL_NOSCAN_ON}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_NOSCAN"); + } + NoScan_ = *mode; + return SQL_SUCCESS; + } + case SQL_ATTR_METADATA_ID: { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_FALSE, SQL_TRUE}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_METADATA_ID"); + } + MetadataId_ = *mode; + return SQL_SUCCESS; + } + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TStatementAttributes::GetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER /*bufferLength*/, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + if (!value) { + return Diag::AddNullPointer(errors); + } + if (stringLengthPtr) { + *stringLengthPtr = 0; + } + switch (attr) { + case SQL_ATTR_QUERY_TIMEOUT: + *reinterpret_cast(value) = QueryTimeoutSec_; + return SQL_SUCCESS; + case SQL_ATTR_MAX_ROWS: + *reinterpret_cast(value) = MaxRows_; + return SQL_SUCCESS; + case SQL_ATTR_NOSCAN: + *reinterpret_cast(value) = NoScan_; + return SQL_SUCCESS; + case SQL_ATTR_METADATA_ID: + *reinterpret_cast(value) = MetadataId_; + return SQL_SUCCESS; + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLUINTEGER TStatementAttributes::GetQueryTimeoutSec() const noexcept{ + return QueryTimeoutSec_; +} + +SQLULEN TStatementAttributes::GetMaxRows() const noexcept { + return MaxRows_; +} + +SQLULEN TStatementAttributes::GetNoScanMode() const noexcept { + return NoScan_; +} + +SQLULEN TStatementAttributes::GetMetadataId() const noexcept { + return MetadataId_; +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/statement_attr.h b/odbc/src/statement_attr.h new file mode 100644 index 00000000000..b0d6e9bd97f --- /dev/null +++ b/odbc/src/statement_attr.h @@ -0,0 +1,39 @@ +#pragma once + +#include "utils/error_manager.h" + +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TStatementAttributes { +public: + SQLRETURN SetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + TErrorManager& errors); + + SQLRETURN GetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + SQLUINTEGER GetQueryTimeoutSec() const noexcept; + SQLULEN GetMaxRows() const noexcept; + SQLULEN GetNoScanMode() const noexcept; + SQLULEN GetMetadataId() const noexcept; + +private: + SQLUINTEGER QueryTimeoutSec_ = 0; + SQLULEN MaxRows_ = 0; + SQLULEN NoScan_ = SQL_NOSCAN_OFF; + SQLULEN MetadataId_ = SQL_FALSE; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/attr.cpp b/odbc/src/utils/attr.cpp new file mode 100644 index 00000000000..1fb2a83324a --- /dev/null +++ b/odbc/src/utils/attr.cpp @@ -0,0 +1,51 @@ +#include "attr.h" +#include "diag.h" + +#include +#include + +namespace NYdb::NOdbc { + +std::string ReadAttributeString(SQLPOINTER value, SQLINTEGER stringLength) { + const char* const str = static_cast(value); + if (stringLength == SQL_NTS) { + return std::string(str); + } + if (stringLength < 0) { + return {}; + } + return std::string(str, static_cast(stringLength)); +} + +SQLRETURN WriteAttributeString( + const std::string& source, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) { + const SQLINTEGER length = static_cast(source.size()); + if (stringLengthPtr != nullptr) { + *stringLengthPtr = length; + } + if (value == nullptr) { + return SQL_SUCCESS; + } + if (bufferLength <= 0) { + return Diag::AddInvalidBufferLength(errors); + } + + auto* dest = static_cast(value); + const size_t maxData = static_cast(bufferLength - 1); + const size_t nCopy = std::min(source.size(), maxData); + if (nCopy > 0) { + std::memcpy(dest, source.data(), nCopy); + } + dest[nCopy] = 0; + + if (length >= bufferLength) { + return Diag::AddRightTruncated(errors); + } + return SQL_SUCCESS; +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/attr.h b/odbc/src/utils/attr.h new file mode 100644 index 00000000000..34abeed42da --- /dev/null +++ b/odbc/src/utils/attr.h @@ -0,0 +1,46 @@ +#pragma once + +#include "error_manager.h" + +#include +#include +#include + +#include +#include + +namespace NYdb::NOdbc { + +std::string ReadAttributeString(SQLPOINTER value, SQLINTEGER stringLength); + +SQLRETURN WriteAttributeString( + const std::string& source, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors); + +template +T ReadIntegerAttr(SQLPOINTER value) noexcept; + +template +std::optional ReadIntegerAttrIfIn(SQLPOINTER value, std::initializer_list allowed) noexcept; + +template +T ReadIntegerAttr(SQLPOINTER value) noexcept { + return static_cast(reinterpret_cast(value)); +} + +template +std::optional ReadIntegerAttrIfIn(SQLPOINTER value, std::initializer_list allowed) noexcept { + const T token = ReadIntegerAttr(value); + for (const T allowedValue : allowed) { + if (token == allowedValue) { + return token; + } + } + return std::nullopt; +} + + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/bindings.h b/odbc/src/utils/bindings.h new file mode 100644 index 00000000000..2480f5367af --- /dev/null +++ b/odbc/src/utils/bindings.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#include + +namespace NYdb { +namespace NOdbc { + +struct TBoundParam { + SQLUSMALLINT ParamNumber; + SQLSMALLINT InputOutputType; + SQLSMALLINT ValueType; + SQLSMALLINT ParameterType; + SQLULEN ColumnSize; + SQLSMALLINT DecimalDigits; + SQLPOINTER ParameterValuePtr; + SQLLEN BufferLength; + SQLLEN* StrLenOrIndPtr; +}; + +struct TBoundColumn { + SQLUSMALLINT ColumnNumber; + SQLSMALLINT TargetType; + SQLPOINTER TargetValue; + SQLLEN BufferLength; + SQLLEN* StrLenOrInd; +}; + +class IBindingFiller { +public: + virtual void FillBoundColumns() = 0; + virtual void OnStreamPartError([[maybe_unused]] const TStatus& status) { + } + + virtual ~IBindingFiller() = default; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/convert.cpp b/odbc/src/utils/convert.cpp new file mode 100644 index 00000000000..4e415c65521 --- /dev/null +++ b/odbc/src/utils/convert.cpp @@ -0,0 +1,484 @@ +#include "convert.h" + +#include +#include + +#include + +namespace NYdb { +namespace NOdbc { + +template +struct TSqlTypeTraits; + +template<> struct TSqlTypeTraits { using Type = std::string; }; +template<> struct TSqlTypeTraits { using Type = std::string; }; +template<> struct TSqlTypeTraits { using Type = SQLBIGINT; }; +template<> struct TSqlTypeTraits { using Type = SQLUBIGINT; }; +template<> struct TSqlTypeTraits { using Type = SQLINTEGER; }; +template<> struct TSqlTypeTraits { using Type = SQLUINTEGER; }; +template<> struct TSqlTypeTraits { using Type = SQLSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLUSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLSCHAR; }; +template<> struct TSqlTypeTraits { using Type = SQLCHAR; }; +template<> struct TSqlTypeTraits { using Type = SQLDOUBLE; }; +template<> struct TSqlTypeTraits { using Type = SQLFLOAT; }; +template<> struct TSqlTypeTraits { using Type = SQLCHAR; }; + +template +struct TTypedValue { + using TSrcType = typename TSqlTypeTraits::Type; + + TSrcType Data; + + TTypedValue(const TBoundParam& param) { + Data = *static_cast(param.ParameterValuePtr); + } +}; + +template<> +TTypedValue::TTypedValue(const TBoundParam& param) { + Data = std::string(static_cast(param.ParameterValuePtr), param.BufferLength); +} + +template<> +TTypedValue::TTypedValue(const TBoundParam& param) { + Data = std::string(static_cast(param.ParameterValuePtr), param.BufferLength); +} + +class IConverter { +public: + virtual void AddToBuilder(const TBoundParam& param, TParamValueBuilder& builder) = 0; + + virtual ~IConverter() = default; +}; + +template +class TConverter : public IConverter { +public: + virtual void AddToBuilder(const TBoundParam& param, TParamValueBuilder& builder) override { + TTypedValue value(param); + Convert(param, std::move(value.Data), builder); + if (param.StrLenOrIndPtr && *param.StrLenOrIndPtr == SQL_NULL_DATA) { + builder.EmptyOptional(GetType()); + } + builder.Build(); + } + +private: + void Convert(const TBoundParam& param, TTypedValue::TSrcType&& data, TParamValueBuilder& builder); + TType GetType(); +}; + +class TConverterRegistry { +public: + static TConverterRegistry& GetInstance() { + static TConverterRegistry instance; + return instance; + } + + void RegisterConverter(SQLSMALLINT cType, SQLSMALLINT sqlType, std::unique_ptr converter) { + Converters_.emplace(std::make_pair(cType, sqlType), std::move(converter)); + } + + IConverter* GetConverter(SQLSMALLINT cType, SQLSMALLINT sqlType) { + auto it = Converters_.find(std::make_pair(cType, sqlType)); + if (it != Converters_.end()) { + return it->second.get(); + } + return nullptr; + } + +private: + std::map, std::unique_ptr> Converters_; +}; + +#define REGISTER_CONVERTER(CType, SqlType, YdbType) \ + struct TConverterRegistration##CType##SqlType { \ + TConverterRegistration##CType##SqlType() { \ + TConverterRegistry::GetInstance().RegisterConverter(CType, SqlType, std::make_unique>()); \ + } \ + }; \ + static const TConverterRegistration##CType##SqlType converterRegistration##CType##SqlType; \ + template<> \ + TType TConverter::GetType() { \ + return TTypeBuilder().Primitive(YdbType).Build(); \ + } \ + template<> \ + void TConverter::Convert(const TBoundParam& param, TTypedValue::TSrcType&& data, TParamValueBuilder& builder) + +// Integer types + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +// Floating point types + +REGISTER_CONVERTER(SQL_C_FLOAT, SQL_REAL, EPrimitiveType::Float) { + builder.OptionalFloat(data); +} + +REGISTER_CONVERTER(SQL_C_DOUBLE, SQL_FLOAT, EPrimitiveType::Double) { + builder.OptionalDouble(data); +} + +REGISTER_CONVERTER(SQL_C_DOUBLE, SQL_DOUBLE, EPrimitiveType::Double) { + builder.OptionalDouble(data); +} + +// String types + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_CHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_VARCHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_LONGVARCHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +// Binary types + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_BINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_VARBINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_LONGVARBINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +#undef REGISTER_CONVERTER + +SQLRETURN ConvertParam(const TBoundParam& param, TParamValueBuilder& builder) { + auto converter = TConverterRegistry::GetInstance().GetConverter(param.ValueType, param.ParameterType); + if (!converter) { + return SQL_ERROR; + } + + converter->AddToBuilder(param, builder); + return SQL_SUCCESS; +} + +SQLRETURN ConvertColumn(TValueParser& parser, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (parser.IsNull()) { + if (strLenOrInd) { + *strLenOrInd = SQL_NULL_DATA; + } + return SQL_SUCCESS; + } + + if (parser.GetKind() == TTypeParser::ETypeKind::Optional) { + parser.OpenOptional(); + SQLRETURN ret = ConvertColumn(parser, targetType, targetValue, bufferLength, strLenOrInd); + parser.CloseOptional(); + return ret; + } + + if (parser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return SQL_ERROR; + } + + EPrimitiveType ydbType = parser.GetPrimitiveType(); + + switch (targetType) { + case SQL_C_SHORT: + case SQL_C_SSHORT: + { + SQLSMALLINT v = 0; + switch (ydbType) { + case EPrimitiveType::Int16: v = parser.GetInt16(); break; + case EPrimitiveType::Uint16: v = static_cast(parser.GetUint16()); break; + case EPrimitiveType::Int8: v = static_cast(parser.GetInt8()); break; + case EPrimitiveType::Uint8: v = static_cast(parser.GetUint8()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + case EPrimitiveType::Bool: v = parser.GetBool() ? 1 : 0; break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(SQLSMALLINT); + } + return SQL_SUCCESS; + } + case SQL_C_SLONG: + case SQL_C_LONG: + { + int32_t v = 0; + switch (ydbType) { + case EPrimitiveType::Int16: v = static_cast(parser.GetInt16()); break; + case EPrimitiveType::Uint16: v = static_cast(parser.GetUint16()); break; + case EPrimitiveType::Int8: v = static_cast(parser.GetInt8()); break; + case EPrimitiveType::Uint8: v = static_cast(parser.GetUint8()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + case EPrimitiveType::Int64: v = static_cast(parser.GetInt64()); break; + case EPrimitiveType::Uint64: v = static_cast(parser.GetUint64()); break; + case EPrimitiveType::Bool: v = parser.GetBool() ? 1 : 0; break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(int32_t); + } + return SQL_SUCCESS; + } + case SQL_C_SBIGINT: + { + SQLBIGINT v = 0; + switch (ydbType) { + case EPrimitiveType::Int64: v = parser.GetInt64(); break; + case EPrimitiveType::Uint64: v = static_cast(parser.GetUint64()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(SQLBIGINT); + } + return SQL_SUCCESS; + } + case SQL_C_DOUBLE: + { + double v = 0.0; + switch (ydbType) { + case EPrimitiveType::Double: v = parser.GetDouble(); break; + case EPrimitiveType::Float: v = parser.GetFloat(); break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(double); + } + return SQL_SUCCESS; + } + case SQL_C_CHAR: + { + std::string str; + switch (ydbType) { + case EPrimitiveType::Utf8: str = parser.GetUtf8(); break; + case EPrimitiveType::String: str = parser.GetString(); break; + case EPrimitiveType::Json: str = parser.GetJson(); break; + case EPrimitiveType::JsonDocument: str = parser.GetJsonDocument(); break; + case EPrimitiveType::Date: { + const TString t = parser.GetDate().FormatGmTime("%Y-%m-%d"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Date32: { + const i32 days = parser.GetDate32(); + if (days < 0) { + return SQL_ERROR; + } + const TString t = TInstant::Days(static_cast(days)).FormatGmTime("%Y-%m-%d"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Datetime: { + const TString t = parser.GetDatetime().FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Datetime64: { + const std::int64_t secs = parser.GetDatetime64(); + if (secs < 0) { + return SQL_ERROR; + } + const TString t = + TInstant::Seconds(static_cast(static_cast(secs))) + .FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Timestamp: { + const TString t = parser.GetTimestamp().FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Timestamp64: { + const std::int64_t micros = parser.GetTimestamp64(); + if (micros < 0) { + return SQL_ERROR; + } + const TString t = + TInstant::MicroSeconds(static_cast(static_cast(micros))) + .FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::TzDate: str = parser.GetTzDate(); break; + case EPrimitiveType::TzDatetime: str = parser.GetTzDatetime(); break; + case EPrimitiveType::TzTimestamp: str = parser.GetTzTimestamp(); break; + default: return SQL_ERROR; + } + SQLLEN len = str.size(); + if (targetValue && bufferLength > 0) { + SQLLEN copyLen = std::min(len, bufferLength - 1); + memcpy(targetValue, str.data(), copyLen); + reinterpret_cast(targetValue)[copyLen] = 0; + } + if (strLenOrInd) { + *strLenOrInd = len; + } + return SQL_SUCCESS; + } + case SQL_C_BIT: + { + char v = parser.GetBool() ? 1 : 0; + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(char); + } + return SQL_SUCCESS; + } + default: + return SQL_ERROR; + } +} + +} // namespace NYdb +} // namespace NOdbc diff --git a/odbc/src/utils/convert.h b/odbc/src/utils/convert.h new file mode 100644 index 00000000000..9b8140665e8 --- /dev/null +++ b/odbc/src/utils/convert.h @@ -0,0 +1,17 @@ +#pragma once + +#include "bindings.h" + +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +SQLRETURN ConvertParam(const TBoundParam& param, TParamValueBuilder& builder); +SQLRETURN ConvertColumn(TValueParser& parser, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + +} // namespace NYdb +} // namespace NOdbc diff --git a/odbc/src/utils/cursor.cpp b/odbc/src/utils/cursor.cpp new file mode 100644 index 00000000000..26ad393b03a --- /dev/null +++ b/odbc/src/utils/cursor.cpp @@ -0,0 +1,143 @@ +#include "cursor.h" + +#include "convert.h" +#include "types.h" + +#include + +namespace NYdb { +namespace NOdbc { + +class TExecCursor : public ICursor { +public: + TExecCursor(IBindingFiller* bindingFiller, NQuery::TExecuteQueryIterator iterator, + std::optional prefetchedPart) + : BindingFiller_(bindingFiller) + , Iterator_(std::move(iterator)) + , PrefetchedPart_(std::move(prefetchedPart)) + {} + + bool Fetch() override { + while (true) { + if (ResultSetParser_) { + if (ResultSetParser_->TryNextRow()) { + BindingFiller_->FillBoundColumns(); + return true; + } + ResultSetParser_.reset(); + } + NQuery::TExecuteQueryPart part = [&]() { + if (PrefetchedPart_) { + auto p = std::move(*PrefetchedPart_); + PrefetchedPart_.reset(); + return p; + } + return Iterator_.ReadNext().ExtractValueSync(); + }(); + if (part.EOS()) { + return false; + } + if (!part.IsSuccess()) { + BindingFiller_->OnStreamPartError(part); + return false; + } + if (part.HasResultSet()) { + TResultSet rs = part.ExtractResultSet(); + Columns_.clear(); + Columns_.reserve(rs.ColumnsCount()); + for (const auto& col : rs.GetColumnsMeta()) { + Columns_.push_back(TColumnMeta{ + col.Name, + GetTypeId(col.Type), + 0, + IsNullable(col.Type)}); + } + ResultSetParser_ = std::make_unique(rs); + } + } + return false; + } + + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) override { + if (!ResultSetParser_) { + return SQL_NO_DATA; + } + if (columnNumber < 1 || columnNumber > ResultSetParser_->ColumnsCount()) { + return SQL_ERROR; + } + return ConvertColumn(ResultSetParser_->ColumnParser(columnNumber - 1), targetType, targetValue, bufferLength, strLenOrInd); + } + + const std::vector& GetColumnMeta() const override { + return Columns_; + } + +private: + // void GetNextPart() { + // auto part = Iterator_.ReadNext().ExtractValueSync(); + // while (!part.EOS() && part.IsSuccess() && !part.HasResultSet()) { + // part = Iterator_.ReadNext().ExtractValueSync(); + // } + // Part_ = std::move(part); + // } + + IBindingFiller* BindingFiller_; + NQuery::TExecuteQueryIterator Iterator_; + std::optional PrefetchedPart_; + std::unique_ptr ResultSetParser_; + std::vector Columns_; +}; + +class TVirtualCursor : public ICursor { +public: + TVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table) + : BindingFiller_(bindingFiller) + , Columns_(columns) + , Table_(table) + {} + + bool Fetch() override { + Cursor_++; + if (Cursor_ >= static_cast(Table_.size())) { + return false; + } + BindingFiller_->FillBoundColumns(); + return true; + } + + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) override { + if (Cursor_ >= static_cast(Table_.size())) { + return SQL_NO_DATA; + } + if (Cursor_ < 0 || columnNumber < 1 || columnNumber > Columns_.size()) { + return SQL_ERROR; + } + TValueParser parser{Table_[Cursor_][columnNumber - 1]}; + return ConvertColumn(parser, targetType, targetValue, bufferLength, strLenOrInd); + } + + const std::vector& GetColumnMeta() const override { + return Columns_; + } + +private: + IBindingFiller* BindingFiller_; + std::vector Columns_; + TTable Table_; + int64_t Cursor_ = -1; +}; + +std::unique_ptr CreateExecCursor(IBindingFiller* bindingFiller, + NQuery::TExecuteQueryIterator iterator, + std::optional prefetchedPart) { + return std::make_unique(bindingFiller, std::move(iterator), std::move(prefetchedPart)); +} + +std::unique_ptr CreateVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table) { + return std::make_unique(bindingFiller, columns, table); +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/cursor.h b/odbc/src/utils/cursor.h new file mode 100644 index 00000000000..22828f66144 --- /dev/null +++ b/odbc/src/utils/cursor.h @@ -0,0 +1,40 @@ +#pragma once + +#include "bindings.h" + +#include + +#include + +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +struct TColumnMeta { + std::string Name; + SQLSMALLINT SqlType; + SQLULEN Size; + SQLSMALLINT Nullable; +}; + +using TTable = std::vector>; + +class ICursor { +public: + virtual ~ICursor() = default; + virtual bool Fetch() = 0; + virtual SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) = 0; + virtual const std::vector& GetColumnMeta() const = 0; +}; + +std::unique_ptr CreateExecCursor(IBindingFiller* bindingFiller, + NYdb::NQuery::TExecuteQueryIterator iterator, + std::optional prefetchedPart = std::nullopt); +std::unique_ptr CreateVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table); + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/diag.h b/odbc/src/utils/diag.h new file mode 100644 index 00000000000..5e2db740a07 --- /dev/null +++ b/odbc/src/utils/diag.h @@ -0,0 +1,33 @@ +#pragma once + +#include "error_manager.h" + +#include +#include + +namespace NYdb::NOdbc { +namespace Diag { + + inline SQLRETURN AddNullPointer(TErrorManager& errors) { + return errors.AddError("HY009", 0, "Invalid use of null pointer"); + } + + inline SQLRETURN AddNotImplemented(TErrorManager& errors) { + return errors.AddError("HYC00", 0, "Optional feature not implemented"); + } + + inline SQLRETURN AddInvalidAttrValue(TErrorManager& errors, std::string_view attrName) { + return errors.AddError("HY024", 0, "Invalid " + std::string(attrName) + " value"); + } + + inline SQLRETURN AddInvalidBufferLength(TErrorManager& errors) { + return errors.AddError("HY090", 0, "Invalid string or buffer length"); + } + + inline SQLRETURN AddRightTruncated(TErrorManager& errors) { + return errors.AddError("01004", 0, "String data, right truncated", SQL_SUCCESS_WITH_INFO); + } + +} + +} // namespace NYdb::NOdbc::Diag diff --git a/odbc/src/utils/error_manager.cpp b/odbc/src/utils/error_manager.cpp new file mode 100644 index 00000000000..92c8ec1750f --- /dev/null +++ b/odbc/src/utils/error_manager.cpp @@ -0,0 +1,231 @@ +#include "error_manager.h" + +#include +#include +#include +#include + +namespace NYdb { +namespace NOdbc { +namespace { + struct OdbcErrorMapping { + const char* sqlState; + const char* description; + SQLRETURN returnCode; + }; + + const std::unordered_map ERROR_MAPPINGS = { + {EStatus::SUCCESS, {"00000", "Success", SQL_SUCCESS}}, + {EStatus::BAD_REQUEST, {"42000", "Syntax error or access rule violation", SQL_ERROR}}, + {EStatus::UNAUTHORIZED, {"28000", "Invalid authorization specification", SQL_ERROR}}, + {EStatus::INTERNAL_ERROR, {"HY000", "General error", SQL_ERROR}}, + {EStatus::ABORTED, {"25000", "Invalid transaction state", SQL_ERROR}}, + {EStatus::UNAVAILABLE, {"08001", "Client unable to establish connection", SQL_ERROR}}, + {EStatus::OVERLOADED, {"HY000", "General error - server overloaded", SQL_ERROR}}, + {EStatus::SCHEME_ERROR, {"42S02", "Base table or view not found", SQL_ERROR}}, + {EStatus::GENERIC_ERROR, {"HY000", "General error", SQL_ERROR}}, + {EStatus::TIMEOUT, {"HYT00", "Timeout expired", SQL_ERROR}}, + {EStatus::BAD_SESSION, {"08003", "Connection does not exist", SQL_ERROR}}, + {EStatus::PRECONDITION_FAILED, {"23000", "Integrity constraint violation", SQL_ERROR}}, + {EStatus::ALREADY_EXISTS, {"23000", "Integrity constraint violation", SQL_ERROR}}, + {EStatus::NOT_FOUND, {"02000", "No data found", SQL_NO_DATA}}, + {EStatus::SESSION_EXPIRED, {"08003", "Connection does not exist", SQL_ERROR}}, + {EStatus::CANCELLED, {"HY008", "Operation canceled", SQL_ERROR}}, + {EStatus::UNDETERMINED, {"HY000", "General error", SQL_ERROR}}, + {EStatus::UNSUPPORTED, {"HYC00", "Optional feature not implemented", SQL_ERROR}}, + {EStatus::SESSION_BUSY, {"HY000", "General error - session busy", SQL_ERROR}}, + // Transport errors + {EStatus::TRANSPORT_UNAVAILABLE, {"08001", "Client unable to establish connection", SQL_ERROR}}, + {EStatus::CLIENT_RESOURCE_EXHAUSTED, {"HY000", "General error - resource exhausted", SQL_ERROR}}, + {EStatus::CLIENT_DEADLINE_EXCEEDED, {"HYT00", "Timeout expired", SQL_ERROR}}, + {EStatus::CLIENT_INTERNAL_ERROR, {"HY000", "General error", SQL_ERROR}}, + {EStatus::CLIENT_CANCELLED, {"HY008", "Operation canceled", SQL_ERROR}}, + {EStatus::CLIENT_UNAUTHENTICATED, {"28000", "Invalid authorization specification", SQL_ERROR}}, + {EStatus::CLIENT_LIMITS_REACHED, {"HY000", "General error - limits reached", SQL_ERROR}}, + {EStatus::CLIENT_DISCOVERY_FAILED, {"08001", "Client unable to establish connection", SQL_ERROR}}, + {EStatus::CLIENT_CALL_UNIMPLEMENTED, {"HYC00", "Optional feature not implemented", SQL_ERROR}}, + {EStatus::CLIENT_OUT_OF_RANGE, {"22003", "Numeric value out of range", SQL_ERROR}}, + }; + + const OdbcErrorMapping DEFAULT_ERROR_MAPPING = {"HY000", "Unknown YDB error", SQL_ERROR}; + + OdbcErrorMapping GetErrorMappingForStatus(EStatus status) { + auto it = ERROR_MAPPINGS.find(status); + if (it != ERROR_MAPPINGS.end()) { + return it->second; + } + return DEFAULT_ERROR_MAPPING; + } +} // namespace + +namespace { + +SQLRETURN WriteDiagCStr( + const std::string& str, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr, + bool sqlStateField = false) { + std::string storage; + const std::string* src = &str; + if (sqlStateField) { + storage = str; + if (storage.size() < 5) { + storage.append(5U - storage.size(), ' '); + } else { + storage.resize(5U); + } + src = &storage; + } + if (!diagInfoPtr) { + return SQL_ERROR; + } + if (bufferLength < 0) { + return SQL_ERROR; + } + const size_t fullLen = src->size(); + if (stringLengthPtr) { + *stringLengthPtr = static_cast(std::min(fullLen, 0x7FFFU)); + } + if (bufferLength == 0) { + return fullLen == 0 ? SQL_SUCCESS : SQL_SUCCESS_WITH_INFO; + } + auto* out = static_cast(diagInfoPtr); + const size_t maxData = static_cast(bufferLength - 1U); + const size_t copyLen = std::min(fullLen, maxData); + std::memcpy(out, src->data(), copyLen); + out[copyLen] = 0; + return (fullLen > maxData) ? SQL_SUCCESS_WITH_INFO : SQL_SUCCESS; +} + +} // namespace + +SQLRETURN TErrorManager::AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message, SQLRETURN returnCode) { + Errors_.push_back({sqlState, nativeError, message, returnCode}); + LastReturnCode_ = returnCode; + return returnCode; +} + +SQLRETURN TErrorManager::AddError(const TOdbcException& ex) { + Errors_.push_back({ex.GetSqlState(), ex.GetNativeError(), ex.GetMessage(), ex.GetReturnCode()}); + LastReturnCode_ = ex.GetReturnCode(); + return ex.GetReturnCode(); +} + +SQLRETURN TErrorManager::AddError(const TStatus& status) { + auto mapping = GetErrorMappingForStatus(status.GetStatus()); + std::string message = mapping.description; + if (!status.GetIssues().Empty()) { + message += ": " + status.GetIssues().ToString(); + } + Errors_.push_back({mapping.sqlState, static_cast(status.GetStatus()), message, mapping.returnCode}); + LastReturnCode_ = mapping.returnCode; + return mapping.returnCode; +} + +void TErrorManager::ClearErrors() { + Errors_.clear(); +} + +SQLRETURN TErrorManager::GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength) { + if (recNumber < 1 || recNumber > (SQLSMALLINT)Errors_.size()) { + return SQL_NO_DATA; + } + + const auto& err = Errors_[recNumber-1]; + if (sqlState) { + strncpy((char*)sqlState, err.SqlState.c_str(), 6); + } + + if (nativeError) { + *nativeError = err.NativeError; + } + + if (messageText && bufferLength > 0) { + strncpy((char*)messageText, err.Message.c_str(), bufferLength); + if (textLength) { + *textLength = (SQLSMALLINT)std::min((int)err.Message.size(), (int)bufferLength); + } + } + return SQL_SUCCESS; +} + +SQLRETURN TErrorManager::GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr) { + const SQLSMALLINT count = static_cast(Errors_.size()); + if (diagInfoPtr == nullptr) { + return SQL_ERROR; + } + if (recNumber == 0) { + switch (diagIdentifier) { + case SQL_DIAG_RETURNCODE: + *static_cast(diagInfoPtr) = LastReturnCode_; + return SQL_SUCCESS; + case SQL_DIAG_NUMBER: { + *static_cast(diagInfoPtr) = static_cast(count); + return SQL_SUCCESS; + } + case SQL_DIAG_ROW_COUNT: + return SQL_ERROR; + default: + return SQL_ERROR; + } + } + + if (recNumber < 1 || recNumber > count) { + return SQL_NO_DATA; + } + + const auto& err = Errors_[recNumber - 1]; + switch (diagIdentifier) { + case SQL_DIAG_SQLSTATE: + return WriteDiagCStr(err.SqlState, diagInfoPtr, bufferLength, stringLengthPtr, true); + case SQL_DIAG_NATIVE: { + *static_cast(diagInfoPtr) = err.NativeError; + return SQL_SUCCESS; + } + case SQL_DIAG_MESSAGE_TEXT: + return WriteDiagCStr(err.Message, diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_CLASS_ORIGIN: + return WriteDiagCStr("ODBC 3.0", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_SUBCLASS_ORIGIN: + return WriteDiagCStr("ODBC 3.0", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_CONNECTION_NAME: + case SQL_DIAG_SERVER_NAME: + return WriteDiagCStr("", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_COLUMN_NUMBER: + *static_cast(diagInfoPtr) = SQL_COLUMN_NUMBER_UNKNOWN; + return SQL_SUCCESS; + case SQL_DIAG_ROW_NUMBER: + *static_cast(diagInfoPtr) = SQL_ROW_NUMBER_UNKNOWN; + return SQL_SUCCESS; + default: + return SQL_ERROR; + } +} + +SQLRETURN HandleOdbcExceptions( + SQLHANDLE handlePtr, + std::function&& func, + ENullInputHandlePolicy nullInputPolicy) { + if (!handlePtr && nullInputPolicy != ENullInputHandlePolicy::Allow) { + return SQL_INVALID_HANDLE; + } + + try { + const SQLRETURN r = func(); + if (handlePtr) { + static_cast(handlePtr)->SetLastReturnCode(r); + } + return r; + } catch (...) { + if (handlePtr) { + static_cast(handlePtr)->SetLastReturnCode(SQL_ERROR); + } + return SQL_ERROR; + } +} + +} // namespace NOdbc +} // namespace NYdb \ No newline at end of file diff --git a/odbc/src/utils/error_manager.h b/odbc/src/utils/error_manager.h new file mode 100644 index 00000000000..9f91fab8a1d --- /dev/null +++ b/odbc/src/utils/error_manager.h @@ -0,0 +1,117 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace NYdb { +namespace NOdbc { + +struct TErrorInfo { + std::string SqlState; + SQLINTEGER NativeError; + std::string Message; + SQLRETURN ReturnCode; +}; + +using TErrorList = std::vector; + +class TOdbcException : public std::exception { +public: + TOdbcException(const std::string& sqlState, SQLINTEGER nativeError, + const std::string& message, SQLRETURN returnCode = SQL_ERROR) + : SqlState_(sqlState) + , NativeError_(nativeError) + , Message_(message) + , ReturnCode_(returnCode) + {} + + const std::string& GetSqlState() const { + return SqlState_; + } + + SQLINTEGER GetNativeError() const { + return NativeError_; + } + + const std::string& GetMessage() const { + return Message_; + } + + SQLRETURN GetReturnCode() const { + return ReturnCode_; + } + + const char* what() const noexcept override { + return Message_.c_str(); + } + +private: + std::string SqlState_; + SQLINTEGER NativeError_; + std::string Message_; + SQLRETURN ReturnCode_; +}; + +class TErrorManager { +public: + SQLRETURN AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message, SQLRETURN returnCode = SQL_ERROR); + SQLRETURN AddError(const TOdbcException& ex); + SQLRETURN AddError(const TStatus& status); + + void ClearErrors(); + + void SetLastReturnCode(SQLRETURN code) { + LastReturnCode_ = code; + } + [[nodiscard]] SQLRETURN GetLastReturnCode() const { + return LastReturnCode_; + } + + SQLRETURN GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength); + virtual SQLRETURN GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr); + +private: + TErrorList Errors_; + SQLRETURN LastReturnCode_ = SQL_SUCCESS; +}; + +enum class ENullInputHandlePolicy : unsigned char { + Reject, + Allow, +}; + +template +SQLRETURN HandleOdbcExceptions(SQLHANDLE handlePtr, std::function&& func) { + if (!handlePtr) { + return SQL_INVALID_HANDLE; + } + auto handle = static_cast(handlePtr); + + try { + const SQLRETURN ret = func(handle); + handle->SetLastReturnCode(ret); + return ret; + } catch (const NStatusHelpers::TYdbErrorException& ex) { + return handle->AddError(ex.GetStatus()); + } catch (const TOdbcException& ex) { + return handle->AddError(ex); + } catch (const std::exception& ex) { + return handle->AddError("HY000", 0, ex.what()); + } catch (...) { + return handle->AddError("HY000", 0, "Unknown error"); + } +} + +SQLRETURN HandleOdbcExceptions( + SQLHANDLE handlePtr, + std::function&& func, + ENullInputHandlePolicy nullInputPolicy = ENullInputHandlePolicy::Reject); + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/escape.cpp b/odbc/src/utils/escape.cpp new file mode 100644 index 00000000000..5a9c643eb7a --- /dev/null +++ b/odbc/src/utils/escape.cpp @@ -0,0 +1,444 @@ +#include "escape.h" + +#include +#include +#include +#include +#include + +namespace NYdb::NOdbc { +namespace { + +bool EqualNoCase(std::string_view lhs, std::string_view rhs) { + return lhs.size() == rhs.size() && + std::equal(lhs.begin(), lhs.end(), rhs.begin(), [](char leftCh, char rightCh) { + return std::tolower(static_cast(leftCh)) == + std::tolower(static_cast(rightCh)); + }); +} + +void SkipLeadingWhitespace(std::string_view sql, size_t& cursor) { + const auto strEnd = sql.end(); + const auto firstNonSpace = std::find_if_not( + sql.begin() + static_cast(cursor), + strEnd, + [](unsigned char byte) { + return std::isspace(byte) != 0; + }); + cursor = static_cast(firstNonSpace - sql.begin()); +} + +bool ReadIdent(std::string_view sql, size_t& cursor, std::string_view* outIdent) { + SkipLeadingWhitespace(sql, cursor); + const size_t identStart = cursor; + const auto afterIdent = std::find_if_not( + sql.begin() + static_cast(cursor), + sql.end(), + [](unsigned char byte) { + return std::isalpha(byte) != 0 || byte == '_'; + }); + cursor = static_cast(afterIdent - sql.begin()); + if (cursor == identStart) { + return false; + } + *outIdent = std::string_view(sql.data() + identStart, cursor - identStart); + return true; +} + +bool ParseSingleQuoted(std::string_view sql, size_t& cursor, std::string* outValue) { + SkipLeadingWhitespace(sql, cursor); + if (cursor >= sql.size() || sql[cursor] != '\'') { + return false; + } + ++cursor; + outValue->clear(); + while (cursor < sql.size()) { + if (sql[cursor] == '\'') { + if (cursor + 1 < sql.size() && sql[cursor + 1] == '\'') { + outValue->push_back('\''); + cursor += 2; + continue; + } + ++cursor; + return true; + } + outValue->push_back(sql[cursor++]); + } + return false; +} + +size_t FindMatchingCloseBrace(std::string_view sql, size_t openBrace) { + if (openBrace >= sql.size() || sql[openBrace] != '{') { + return std::string_view::npos; + } + int braceDepth = 1; + for (size_t idx = openBrace + 1; idx < sql.size(); ++idx) { + if (sql[idx] == '{') { + ++braceDepth; + } else if (sql[idx] == '}') { + --braceDepth; + if (braceDepth == 0) { + return idx; + } + } + } + return std::string_view::npos; +} + +std::string NormalizeOdbcTimestampLiteral(const std::string& raw) { + std::string normalized = raw; + const auto firstSpace = std::find(normalized.begin(), normalized.end(), ' '); + if (firstSpace != normalized.end()) { + *firstSpace = 'T'; + } + if (std::find(normalized.begin(), normalized.end(), 'Z') == normalized.end()) { + normalized.push_back('Z'); + } + return normalized; +} + +std::string ToUpperAscii(std::string_view sv) { + std::string upper; + upper.resize(sv.size()); + std::transform(sv.begin(), sv.end(), upper.begin(), [](unsigned char byte) { + return static_cast(std::toupper(byte)); + }); + return upper; +} + +std::string MapSqlTypeToken(std::string_view sqlType) { + static const std::unordered_map kMap = { + {"CHAR", "Utf8"}, + {"VARCHAR", "Utf8"}, + {"LONGVARCHAR", "Utf8"}, + {"WCHAR", "Utf8"}, + {"WVARCHAR", "Utf8"}, + {"WLONGVARCHAR", "Utf8"}, + {"BIT", "Bool"}, + {"TINYINT", "Int8"}, + {"SMALLINT", "Int16"}, + {"INTEGER", "Int32"}, + {"BIGINT", "Int64"}, + {"REAL", "Float"}, + {"FLOAT", "Double"}, + {"DOUBLE", "Double"}, + {"DECIMAL", "Decimal(22, 9)"}, + {"NUMERIC", "Decimal(22, 9)"}, + {"BINARY", "String"}, + {"VARBINARY", "String"}, + {"LONGVARBINARY", "String"}, + {"DATE", "Date"}, + {"TIME", "Time"}, + {"TIMESTAMP", "Datetime"}, + {"TYPE_DATE", "Date"}, + {"TYPE_TIME", "Time"}, + {"TYPE_TIMESTAMP", "Datetime"}, + }; + std::string key = ToUpperAscii(sqlType); + const std::string kSql = "SQL_"; + if (key.size() > kSql.size() && key.compare(0, kSql.size(), kSql) == 0) { + key.erase(0, kSql.size()); + } + const auto mapped = kMap.find(key); + if (mapped != kMap.end()) { + return mapped->second; + } + return key; +} + +std::string RewriteOdbcEscapesImpl(std::string_view sql); + + +enum class OdbcBraceKind { + OutputProcedureCall, // {?= call ... } + FnBody, // {fn ...} + OjBody, // {oj ...} + DateLiteral, // {d '...'} + TimeLiteral, // {t '...'} + TimestampLiteral, // {ts '...'} + ProcedureCall, // {call ...} + LikeEscape, // {escape '...'} +}; + +struct OdbcBraceParsed { + OdbcBraceKind Kind; + std::string_view RecurseTail; + std::string QuotedValue; +}; + +std::optional TryParseOutputCallBrace(std::string_view sql, size_t parsePos, size_t closeBrace) { + if (parsePos + 1 >= sql.size() || sql[parsePos] != '?' || sql[parsePos + 1] != '=') { + return std::nullopt; + } + size_t inner = parsePos + 2; + SkipLeadingWhitespace(sql, inner); + std::string_view keyword; + if (!ReadIdent(sql, inner, &keyword) || !EqualNoCase(keyword, "call")) { + return std::nullopt; + } + SkipLeadingWhitespace(sql, inner); + if (inner > closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = OdbcBraceKind::OutputProcedureCall; + parsed.RecurseTail = std::string_view(sql.data() + inner, closeBrace - inner); + return parsed; +} + +std::optional MakeRecurseTailBrace(OdbcBraceKind kind, std::string_view sql, size_t& parsePos, size_t closeBrace) { + SkipLeadingWhitespace(sql, parsePos); + if (parsePos > closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = kind; + parsed.RecurseTail = std::string_view(sql.data() + parsePos, closeBrace - parsePos); + return parsed; +} + +std::optional MakeQuotedBrace(OdbcBraceKind kind, std::string_view sql, size_t& parsePos, size_t closeBrace) { + std::string quotedLit; + if (!ParseSingleQuoted(sql, parsePos, "edLit) || parsePos > closeBrace) { + return std::nullopt; + } + SkipLeadingWhitespace(sql, parsePos); + if (parsePos != closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = kind; + parsed.QuotedValue = std::move(quotedLit); + return parsed; +} + +struct BraceKeywordSpec { + std::string_view Keyword; + OdbcBraceKind Kind; + bool IsQuotedLiteral; +}; + +static constexpr BraceKeywordSpec kBraceKeywordSpecs[] = { + {"fn", OdbcBraceKind::FnBody, false}, + {"oj", OdbcBraceKind::OjBody, false}, + {"d", OdbcBraceKind::DateLiteral, true}, + {"t", OdbcBraceKind::TimeLiteral, true}, + {"ts", OdbcBraceKind::TimestampLiteral, true}, + {"call", OdbcBraceKind::ProcedureCall, false}, + {"escape", OdbcBraceKind::LikeEscape, true}, +}; + +std::optional TryParseOdbcBrace(std::string_view sql, size_t openBrace, size_t closeBrace) { + size_t parsePos = openBrace + 1; + SkipLeadingWhitespace(sql, parsePos); + + if (std::optional outputCall = TryParseOutputCallBrace(sql, parsePos, closeBrace)) { + return outputCall; + } + if (parsePos + 1 < sql.size() && sql[parsePos] == '?' && sql[parsePos + 1] == '=') { + return std::nullopt; + } + + std::string_view token; + if (!ReadIdent(sql, parsePos, &token)) { + return std::nullopt; + } + + for (const BraceKeywordSpec& spec : kBraceKeywordSpecs) { + if (!EqualNoCase(token, spec.Keyword)) { + continue; + } + if (spec.IsQuotedLiteral) { + return MakeQuotedBrace(spec.Kind, sql, parsePos, closeBrace); + } + return MakeRecurseTailBrace(spec.Kind, sql, parsePos, closeBrace); + } + + return std::nullopt; +} + +void AppendRewrittenBrace(std::string& rewritten, const OdbcBraceParsed& parsed) { + switch (parsed.Kind) { + case OdbcBraceKind::OutputProcedureCall: + case OdbcBraceKind::ProcedureCall: + rewritten += "CALL "; + rewritten.append(RewriteOdbcEscapesImpl(parsed.RecurseTail)); + return; + case OdbcBraceKind::FnBody: + case OdbcBraceKind::OjBody: + rewritten.append(RewriteOdbcEscapesImpl(parsed.RecurseTail)); + return; + case OdbcBraceKind::DateLiteral: + rewritten += "CAST('"; + rewritten += parsed.QuotedValue; + rewritten += "' AS Date)"; + return; + case OdbcBraceKind::TimeLiteral: + rewritten += "CAST('"; + rewritten += parsed.QuotedValue; + rewritten += "' AS Time)"; + return; + case OdbcBraceKind::TimestampLiteral: { + const std::string normalizedTs = NormalizeOdbcTimestampLiteral(parsed.QuotedValue); + rewritten += "CAST('"; + rewritten += normalizedTs; + rewritten += "' AS Datetime)"; + return; + } + case OdbcBraceKind::LikeEscape: + rewritten += " ESCAPE '"; + rewritten += parsed.QuotedValue; + rewritten += '\''; + return; + } +} + +std::string RewriteOdbcEscapesImpl(std::string_view sql) { + std::string rewritten; + rewritten.reserve(sql.size()); + + for (size_t readPos = 0; readPos < sql.size();) { + if (sql[readPos] != '{') { + rewritten.push_back(sql[readPos++]); + continue; + } + + const size_t closeBrace = FindMatchingCloseBrace(sql, readPos); + if (closeBrace == std::string_view::npos) { + rewritten.push_back(sql[readPos++]); + continue; + } + + if (std::optional parsedBrace = TryParseOdbcBrace(sql, readPos, closeBrace)) { + AppendRewrittenBrace(rewritten, *parsedBrace); + readPos = closeBrace + 1; + continue; + } + + rewritten.push_back(sql[readPos++]); + } + + return rewritten; +} + +std::string RewriteOdbcConvertCalls(std::string_view sql); + +class TOdbcConvertCallRewriter { +public: + explicit TOdbcConvertCallRewriter(std::string_view sql) + : Sql_(sql) { + Rewritten_.reserve(sql.size()); + } + + std::string TakeResult() && { + return std::move(Rewritten_); + } + + void Run() { + while (SegmentStart_ < Sql_.size()) { + const std::optional convertKeywordPos = FindNextConvertKeyword(SegmentStart_); + if (!convertKeywordPos) { + Rewritten_.append(Sql_.substr(SegmentStart_)); + break; + } + Rewritten_.append(Sql_.substr(SegmentStart_, *convertKeywordPos - SegmentStart_)); + if (!TryRewriteConvertAt(*convertKeywordPos)) { + break; + } + } + } + +private: + static constexpr size_t kConvertTokenLen = 7; + + std::optional FindNextConvertKeyword(size_t from) const { + for (size_t probePos = from; probePos + kConvertTokenLen <= Sql_.size(); ++probePos) { + if (!EqualNoCase(Sql_.substr(probePos, kConvertTokenLen), "CONVERT")) { + continue; + } + size_t afterKeyword = probePos + kConvertTokenLen; + SkipLeadingWhitespace(Sql_, afterKeyword); + if (afterKeyword < Sql_.size() && Sql_[afterKeyword] == '(') { + return probePos; + } + } + return std::nullopt; + } + + bool TryRewriteConvertAt(size_t convertKeywordPos) { + size_t parsePos = convertKeywordPos + kConvertTokenLen; + SkipLeadingWhitespace(Sql_, parsePos); + if (parsePos >= Sql_.size() || Sql_[parsePos] != '(') { + Rewritten_.append(Sql_.substr(convertKeywordPos, kConvertTokenLen)); + SegmentStart_ = convertKeywordPos + kConvertTokenLen; + return true; + } + ++parsePos; + + int parenDepth = 1; + const size_t firstArgStart = parsePos; + std::optional typeCommaPos; + for (; parsePos < Sql_.size(); ++parsePos) { + if (Sql_[parsePos] == '(') { + ++parenDepth; + } else if (Sql_[parsePos] == ')') { + --parenDepth; + } else if (Sql_[parsePos] == ',' && parenDepth == 1) { + typeCommaPos = parsePos; + break; + } + } + if (!typeCommaPos) { + Rewritten_.append(Sql_.substr(convertKeywordPos)); + return false; + } + + const std::string_view firstArg(Sql_.data() + firstArgStart, *typeCommaPos - firstArgStart); + parsePos = *typeCommaPos + 1; + SkipLeadingWhitespace(Sql_, parsePos); + const size_t sqlTypeStart = parsePos; + const auto sqlTypeEnd = std::find_if_not( + Sql_.begin() + static_cast(parsePos), + Sql_.end(), + [](unsigned char byte) { + return std::isalpha(byte) != 0 || byte == '_'; + }); + parsePos = static_cast(sqlTypeEnd - Sql_.begin()); + const std::string_view sqlTypeToken(Sql_.data() + sqlTypeStart, parsePos - sqlTypeStart); + SkipLeadingWhitespace(Sql_, parsePos); + if (parsePos >= Sql_.size() || Sql_[parsePos] != ')') { + Rewritten_.append(Sql_.substr(convertKeywordPos)); + return false; + } + + const std::string yqlType = MapSqlTypeToken(sqlTypeToken); + Rewritten_ += "CAST("; + Rewritten_ += RewriteOdbcConvertCalls(RewriteOdbcEscapesImpl(firstArg)); + Rewritten_ += " AS "; + Rewritten_ += yqlType; + Rewritten_ += ')'; + SegmentStart_ = parsePos + 1; + return true; + } + + std::string_view Sql_; + std::string Rewritten_; + size_t SegmentStart_ = 0; +}; + +std::string RewriteOdbcConvertCalls(std::string_view sql) { + TOdbcConvertCallRewriter rewriter(sql); + rewriter.Run(); + return std::move(rewriter).TakeResult(); +} + +} // namespace + + + +std::string RewriteOdbcEscapes(const std::string& sql) { + std::string afterBraceRewrite = RewriteOdbcEscapesImpl(sql); + return RewriteOdbcConvertCalls(afterBraceRewrite); +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/escape.h b/odbc/src/utils/escape.h new file mode 100644 index 00000000000..7397a128450 --- /dev/null +++ b/odbc/src/utils/escape.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace NYdb::NOdbc { + +std::string RewriteOdbcEscapes(const std::string& sql); + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/sql_like.h b/odbc/src/utils/sql_like.h new file mode 100644 index 00000000000..f51c10ca28c --- /dev/null +++ b/odbc/src/utils/sql_like.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +namespace NYdb::NOdbc { + +// SQL LIKE — '%' is any substring, '_' is any single character. +inline bool SqlLikeMatch(std::string_view text, std::string_view pattern) { + size_t textPos = 0; + size_t patPos = 0; + size_t lastPercentPat = std::string_view::npos; + size_t textStartAfterPercent = 0; + + const size_t textLen = text.size(); + const size_t patLen = pattern.size(); + + while (textPos < textLen) { + const bool morePat = patPos < patLen; + const char patCh = morePat ? pattern[patPos] : '\0'; + + if (morePat && patCh != '%' && (patCh == '_' || patCh == text[textPos])) { + ++textPos; + ++patPos; + continue; + } + + if (morePat && patCh == '%') { + lastPercentPat = patPos++; + textStartAfterPercent = textPos; + continue; + } + + if (lastPercentPat != std::string_view::npos) { + patPos = lastPercentPat + 1; + ++textStartAfterPercent; + textPos = textStartAfterPercent; + continue; + } + + return false; + } + + while (patPos < patLen && pattern[patPos] == '%') { + ++patPos; + } + return patPos == patLen; +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/types.cpp b/odbc/src/utils/types.cpp new file mode 100644 index 00000000000..ce5ead462cc --- /dev/null +++ b/odbc/src/utils/types.cpp @@ -0,0 +1,70 @@ +#include "types.h" + +namespace NYdb { +namespace NOdbc { + +SQLSMALLINT GetTypeId(const TType& type) { + // TODO: implement + return 0; +} + +SQLSMALLINT IsNullable(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() == TTypeParser::ETypeKind::Optional || typeParser.GetKind() == TTypeParser::ETypeKind::Null) { + return SQL_NULLABLE; + } + + return SQL_NO_NULLS; +} + +std::optional GetDecimalDigits(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return std::nullopt; + } + + switch (typeParser.GetPrimitive()) { + case EPrimitiveType::Int64: + return 64; + case EPrimitiveType::Uint64: + return 64; + case EPrimitiveType::Int32: + return 32; + case EPrimitiveType::Uint32: + return 32; + case EPrimitiveType::Int16: + return 16; + case EPrimitiveType::Uint16: + return 16; + case EPrimitiveType::Int8: + return 8; + case EPrimitiveType::Uint8: + return 8; + default: + return std::nullopt; + } +} + +std::optional GetRadix(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return std::nullopt; + } + + switch (typeParser.GetPrimitive()) { + case EPrimitiveType::Int64: + case EPrimitiveType::Uint64: + case EPrimitiveType::Int32: + case EPrimitiveType::Uint32: + case EPrimitiveType::Int16: + case EPrimitiveType::Uint16: + case EPrimitiveType::Int8: + case EPrimitiveType::Uint8: + return 10; + default: + return std::nullopt; + } +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/types.h b/odbc/src/utils/types.h new file mode 100644 index 00000000000..3f481702902 --- /dev/null +++ b/odbc/src/utils/types.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +#include + +namespace NYdb { +namespace NOdbc { + +SQLSMALLINT GetTypeId(const TType& type); +SQLSMALLINT IsNullable(const TType& type); + +std::optional GetDecimalDigits(const TType& type); +std::optional GetRadix(const TType& type); + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/util.cpp b/odbc/src/utils/util.cpp new file mode 100644 index 00000000000..9097ce80dbf --- /dev/null +++ b/odbc/src/utils/util.cpp @@ -0,0 +1,12 @@ +#include "util.h" + +namespace NYdb::NOdbc { + +std::string GetString(SQLCHAR* str, SQLSMALLINT length) { + if (length == SQL_NTS) { + return std::string(reinterpret_cast(str)); + } + return std::string(reinterpret_cast(str), length); +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/util.h b/odbc/src/utils/util.h new file mode 100644 index 00000000000..b17fe2c235f --- /dev/null +++ b/odbc/src/utils/util.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +#include + +namespace NYdb::NOdbc { + +std::string GetString(SQLCHAR* str, SQLSMALLINT length); + +} // namespace NYdb::NOdbc diff --git a/odbc/tests/CMakeLists.txt b/odbc/tests/CMakeLists.txt new file mode 100644 index 00000000000..729c6ee0778 --- /dev/null +++ b/odbc/tests/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(integration) +add_subdirectory(unit) diff --git a/odbc/tests/integration/CMakeLists.txt b/odbc/tests/integration/CMakeLists.txt new file mode 100644 index 00000000000..43925350b02 --- /dev/null +++ b/odbc/tests/integration/CMakeLists.txt @@ -0,0 +1,19 @@ +add_odbc_test(NAME odbc-basic_it + SOURCES + basic_it.cpp +) + +add_odbc_test(NAME odbc-env_it + SOURCES + env_it.cpp +) + +add_odbc_test(NAME odbc-attr_it + SOURCES + attr_it.cpp +) + +add_odbc_test(NAME odbc-stmt_attr_it + SOURCES + stmt_attr_it.cpp +) \ No newline at end of file diff --git a/odbc/tests/integration/attr_it.cpp b/odbc/tests/integration/attr_it.cpp new file mode 100644 index 00000000000..2dc30446498 --- /dev/null +++ b/odbc/tests/integration/attr_it.cpp @@ -0,0 +1,223 @@ +#include "test_utils.h" + +#include +#include + + +TEST(OdbcAttrEnv, OdbcVersionAttr) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, nullptr, 0), SQL_SUCCESS); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrEnv, OutputNtsAttr) { + SQLHENV env; + AllocEnv(&env); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, (void*)SQL_TRUE, 0), SQL_SUCCESS); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, (void*)SQL_FALSE, 0), SQL_SUCCESS); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrConn, AutocommitAttr) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR dropQuery[] = "DROP TABLE IF EXISTS test_attr_autocommit"; + SQLCHAR createQuery[] = + "CREATE TABLE test_attr_autocommit (id Int32, value Int32, PRIMARY KEY (id))"; + SQLCHAR upsertRollbackQuery[] = "UPSERT INTO test_attr_autocommit (id, value) VALUES (1, 100)"; + SQLCHAR upsertCommitQuery[] = "UPSERT INTO test_attr_autocommit (id, value) VALUES (1, 200)"; + SQLCHAR selectQuery[] = "SELECT value FROM test_attr_autocommit WHERE id = 1"; + + CHECK_ODBC_OK(SQLExecDirect(stmt, dropQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsertRollbackQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLEndTran(SQL_HANDLE_DBC, dbc, SQL_ROLLBACK), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, upsertCommitQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLEndTran(SQL_HANDLE_DBC, dbc, SQL_COMMIT), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 200); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0), dbc, SQL_HANDLE_DBC); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrConn, AccessModeAttr) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + + constexpr SQLUINTEGER readWriteMode = SQL_MODE_READ_WRITE; + constexpr SQLUINTEGER readOnlyMode = SQL_MODE_READ_ONLY; + SQLUINTEGER currentMode = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, ¤tMode, sizeof(currentMode), nullptr), SQL_SUCCESS); + ASSERT_EQ(readWriteMode, currentMode); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)readOnlyMode, 0), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, ¤tMode, sizeof(currentMode), nullptr), SQL_SUCCESS); + ASSERT_EQ(readOnlyMode, currentMode); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)readWriteMode, 0), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, ¤tMode, sizeof(currentMode), nullptr), SQL_SUCCESS); + ASSERT_EQ(readWriteMode, currentMode); + + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)9999, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HY024")); + + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLCHAR dropQuery[] = "DROP TABLE IF EXISTS test_attr_read_only"; + SQLCHAR createQuery[] = "CREATE TABLE test_attr_read_only (id Int32, PRIMARY KEY (id))"; + SQLCHAR selectOneQuery[] = "SELECT 1 AS value"; + SQLCHAR upsertQuery[] = "UPSERT INTO test_attr_read_only (id) VALUES (1)"; + CHECK_ODBC_OK(SQLExecDirect(stmt, dropQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)readOnlyMode, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectOneQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLExecDirect(stmt, upsertQuery, SQL_NTS), SQL_ERROR); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrConn, TxnIsolationAttr) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLCHAR selectOneQuery[] = "SELECT 1 AS value"; + + SQLUINTEGER currentIsolation = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, ¤tIsolation, sizeof(currentIsolation), nullptr), SQL_SUCCESS); + ASSERT_EQ(static_cast(SQL_TXN_SERIALIZABLE), currentIsolation); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_REPEATABLE_READ, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectOneQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_READ_COMMITTED, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HYC00")); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, ¤tIsolation, sizeof(currentIsolation), nullptr), SQL_SUCCESS); + ASSERT_EQ(static_cast(SQL_TXN_REPEATABLE_READ), currentIsolation); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)SQL_MODE_READ_ONLY, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_READ_UNCOMMITTED, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectOneQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_READ_COMMITTED, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_REPEATABLE_READ, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_SERIALIZABLE, 0), dbc, SQL_HANDLE_DBC); + + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)9999, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HY024")); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrConn, CurrentCatalogAttr) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + constexpr const char* dbRoot = "/local"; + const std::string catA = std::string(dbRoot) + "/odbc_cat_a"; + const std::string catB = std::string(dbRoot) + "/odbc_cat_b"; + SQLCHAR dropAQuery[] = "DROP TABLE IF EXISTS `odbc_cat_a/probe`"; + SQLCHAR dropBQuery[] = "DROP TABLE IF EXISTS `odbc_cat_b/probe`"; + SQLCHAR createAQuery[] = + "CREATE TABLE `odbc_cat_a/probe` (id Int32, value Int32, PRIMARY KEY (id))"; + SQLCHAR createBQuery[] = + "CREATE TABLE `odbc_cat_b/probe` (id Int32, value Int32, PRIMARY KEY (id))"; + SQLCHAR upsertAQuery[] = "UPSERT INTO `odbc_cat_a/probe` (id, value) VALUES (1, 100)"; + SQLCHAR upsertBQuery[] = "UPSERT INTO `odbc_cat_b/probe` (id, value) VALUES (1, 200)"; + SQLCHAR selectAQuery[] = "SELECT value FROM `odbc_cat_a/probe` WHERE id = 1"; + SQLCHAR selectQuery[] = "SELECT value FROM probe WHERE id = 1"; + + char catalog[256] = {0}; + SQLINTEGER textLen = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &textLen), SQL_SUCCESS); + ASSERT_STREQ(catalog, dbRoot); + + + CHECK_ODBC_OK(SQLExecDirect(stmt, dropAQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, dropBQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createAQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createBQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsertAQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsertBQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectAQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)catA.c_str(), SQL_NTS), dbc, + SQL_HANDLE_DBC); + std::memset(catalog, 0, sizeof(catalog)); + textLen = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &textLen), SQL_SUCCESS); + ASSERT_STREQ(catalog, catA.c_str()); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 100); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + valueInt = 0; + valueInd = 0; + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)catB.c_str(), SQL_NTS), dbc, + SQL_HANDLE_DBC); + std::memset(catalog, 0, sizeof(catalog)); + textLen = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &textLen), SQL_SUCCESS); + ASSERT_STREQ(catalog, catB.c_str()); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 200); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + const std::string catWithSlashes = catB + "///"; + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)catWithSlashes.c_str(), SQL_NTS), dbc, + SQL_HANDLE_DBC); + std::memset(catalog, 0, sizeof(catalog)); + textLen = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &textLen), SQL_SUCCESS); + ASSERT_STREQ(catalog, catB.c_str()); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + diff --git a/odbc/tests/integration/basic_it.cpp b/odbc/tests/integration/basic_it.cpp new file mode 100644 index 00000000000..37973667147 --- /dev/null +++ b/odbc/tests/integration/basic_it.cpp @@ -0,0 +1,101 @@ +#include "test_utils.h" + +TEST(OdbcBasic, SimpleQuery) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + // Simple query + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1 AS one, 'abc' AS str", SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLINTEGER ival = 0; + char sval[16] = {0}; + SQLLEN ival_ind = 0, sval_ind = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &ival, 0, &ival_ind), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, 2, SQL_C_CHAR, sval, sizeof(sval), &sval_ind), SQL_SUCCESS); + ASSERT_EQ(ival, 1); + ASSERT_STREQ(sval, "abc"); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcBasic, ParameterizedQuery) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR query[] = R"( + DECLARE $p1 AS Int32?; + SELECT $p1 + 10 AS res; + )"; + + // Parameterized query + CHECK_ODBC_OK(SQLPrepare(stmt, query, SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLINTEGER param = 5; + CHECK_ODBC_OK(SQLBindParameter(stmt, 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, ¶m, 0, nullptr), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecute(stmt), stmt, SQL_HANDLE_STMT); + + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLINTEGER res = 0; + SQLLEN res_ind = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &res, 0, &res_ind), SQL_SUCCESS); + ASSERT_EQ(res, 15); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcBasic, ColumnBinding) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR query_ddl[] = R"( + DROP TABLE IF EXISTS test_bind; + CREATE TABLE test_bind (id Int32, name Text, PRIMARY KEY (id)); + )"; + + SQLCHAR query[] = R"( + UPSERT INTO test_bind (id, name) VALUES (1, 'foo'), (2, 'bar'); + SELECT id, name FROM test_bind ORDER BY id; + )"; + + CHECK_ODBC_OK(SQLExecDirect(stmt, query_ddl, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, query, SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLINTEGER id = 0; + char name[16] = {0}; + SQLLEN id_ind = 0, name_ind = 0; + ASSERT_EQ(SQLBindCol(stmt, 1, SQL_C_LONG, &id, 0, &id_ind), SQL_SUCCESS); + ASSERT_EQ(SQLBindCol(stmt, 2, SQL_C_CHAR, name, sizeof(name), &name_ind), SQL_SUCCESS); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(id, 1); + ASSERT_STREQ(name, "foo"); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(id, 2); + ASSERT_STREQ(name, "bar"); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/env_it.cpp b/odbc/tests/integration/env_it.cpp new file mode 100644 index 00000000000..952c1459ad6 --- /dev/null +++ b/odbc/tests/integration/env_it.cpp @@ -0,0 +1,90 @@ +#include "test_utils.h" + +namespace { + +void StartManualTx(SQLHDBC dbc, SQLHSTMT* stmt) { + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(*stmt, (SQLCHAR*)"SELECT 1", SQL_NTS), *stmt, SQL_HANDLE_STMT); +} + +} // namespace + +TEST(OdbcEnv, EndTranCommitOnEnv) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + StartManualTx(dbc2, &stmt2); + + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_ENV, env, SQL_COMMIT), env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcEnv, EndTranRollbackOnEnv) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + StartManualTx(dbc2, &stmt2); + + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_ENV, env, SQL_ROLLBACK), env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcEnv, EndTranPartialFailureReturnsInfo) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc2, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc2, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc2, &stmt2), SQL_SUCCESS); + (void)SQLExecDirect(stmt2, (SQLCHAR*)"SELECT FROM", SQL_NTS); + + rc = SQLEndTran(SQL_HANDLE_ENV, env, SQL_COMMIT); + ASSERT_TRUE(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO || rc == SQL_ERROR) + << GetOdbcError(env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/stmt_attr_it.cpp b/odbc/tests/integration/stmt_attr_it.cpp new file mode 100644 index 00000000000..89faf9abed0 --- /dev/null +++ b/odbc/tests/integration/stmt_attr_it.cpp @@ -0,0 +1,334 @@ +#include "test_utils.h" + +#include +#include +#include +#include + +#ifndef SQL_ATTR_METADATA_ID +#define SQL_ATTR_METADATA_ID 10029 +#endif + + +TEST(OdbcStmtAttr, QueryTimeoutAttr) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLUINTEGER timeoutSec = 1; + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_QUERY_TIMEOUT, (SQLPOINTER)(uintptr_t)timeoutSec, 0), + stmt, + SQL_HANDLE_STMT); + + SQLCHAR longQuery[] = + "SELECT COUNT(*) FROM AS_TABLE(ListMap(ListFromRange(1u, 100000000u), ($x)->(AsStruct($x AS v))))"; + ASSERT_EQ(SQLExecDirect(stmt, longQuery, SQL_NTS), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "HYT00")); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, MaxRowsAttr) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR dropQuery[] = "DROP TABLE IF EXISTS test_attr_max_rows"; + SQLCHAR createQuery[] = + "CREATE TABLE test_attr_max_rows (id Int32, value Int32, PRIMARY KEY (id))"; + SQLCHAR upsert1Query[] = "UPSERT INTO test_attr_max_rows (id, value) VALUES (1, 10)"; + SQLCHAR upsert2Query[] = "UPSERT INTO test_attr_max_rows (id, value) VALUES (2, 20)"; + SQLCHAR selectQuery[] = "SELECT value FROM test_attr_max_rows ORDER BY id"; + + CHECK_ODBC_OK(SQLExecDirect(stmt, dropQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsert1Query, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsert2Query, SQL_NTS), stmt, SQL_HANDLE_STMT); + + const SQLULEN maxRows = 1; + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_MAX_ROWS, (SQLPOINTER)(uintptr_t)maxRows, 0), + stmt, + SQL_HANDLE_STMT); + + SQLULEN maxRowsOut = 0; + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_MAX_ROWS, &maxRowsOut, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(maxRowsOut, maxRows); + + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, NoScanAttr) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR selectEscapeFnQuery[] = "SELECT {fn ABS(-12)} AS value"; + + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectEscapeFnQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 12); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_ON, 0), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLExecDirect(stmt, selectEscapeFnQuery, SQL_NTS), SQL_ERROR); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, OdbcEscapeSequences) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), stmt, SQL_HANDLE_STMT); + + { + SQLCHAR convertQuery[] = "SELECT {fn CONVERT(42, SQL_SMALLINT)} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, convertQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLSMALLINT valueSmall = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_SSHORT, &valueSmall, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueSmall, 42); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR convertDoubleQuery[] = "SELECT {fn CONVERT(2.5, SQL_DOUBLE)} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, convertDoubleQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + double valueDouble = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_DOUBLE, &valueDouble, 0, &valueInd), SQL_SUCCESS); + ASSERT_LT(std::fabs(valueDouble - 2.5), 1e-9); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR nestedFnQuery[] = "SELECT {fn {fn ABS(-10)}} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, nestedFnQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 10); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR asciiLowerQuery[] = "SELECT {fn String::AsciiToLower('AbC')} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, asciiLowerQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[32] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "abc"); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR dateQuery[] = "SELECT {d '2024-06-15'} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, dateQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[32] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "2024-06-15"); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR tsQuery[] = "SELECT {ts '2024-06-15 14:30:00'} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, tsQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[64] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "2024-06-15 14:30:00"); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, MetadataIdSqlLikeForTableNames) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR ddl[] = R"( + DROP TABLE IF EXISTS test_odbc_meta_like_a; + DROP TABLE IF EXISTS test_odbc_meta_like_b; + CREATE TABLE test_odbc_meta_like_a (id Int32, PRIMARY KEY (id)); + CREATE TABLE test_odbc_meta_like_b (id Int32, PRIMARY KEY (id)); + )"; + CHECK_ODBC_OK(SQLExecDirect(stmt, ddl, SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLULEN metadataId = SQL_TRUE; + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_METADATA_ID, &metadataId, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(metadataId, SQL_FALSE); + + const char* likePattern = "%/test_odbc_meta_like_%"; + CHECK_ODBC_OK( + SQLTables(stmt, nullptr, 0, nullptr, 0, (SQLCHAR*)likePattern, SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + stmt, + SQL_HANDLE_STMT); + int tableRows = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++tableRows; + } + ASSERT_EQ(tableRows, 2); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_TRUE, 0), + stmt, + SQL_HANDLE_STMT); + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_METADATA_ID, &metadataId, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(metadataId, SQL_TRUE); + + ASSERT_EQ( + SQLTables(stmt, nullptr, 0, nullptr, 0, (SQLCHAR*)likePattern, SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "HYC00")); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + const std::string exactPath = "/local/test_odbc_meta_like_a"; + CHECK_ODBC_OK( + SQLTables(stmt, nullptr, 0, nullptr, 0, (SQLCHAR*)exactPath.c_str(), SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + stmt, + SQL_HANDLE_STMT); + tableRows = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++tableRows; + } + ASSERT_EQ(tableRows, 1); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_FALSE, 0), + stmt, + SQL_HANDLE_STMT); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, MetadataIdSqlLikeForColumnNames) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR ddl[] = R"( + DROP TABLE IF EXISTS test_odbc_meta_col; + CREATE TABLE test_odbc_meta_col (id Int32, value_x Int32, PRIMARY KEY (id)); + )"; + CHECK_ODBC_OK(SQLExecDirect(stmt, ddl, SQL_NTS), stmt, SQL_HANDLE_STMT); + + constexpr SQLUSMALLINT kColumnNameCol = 4; + char colName[256] = {}; + SQLLEN colInd = 0; + const std::string exactTable = "/local/test_odbc_meta_col"; + + { + CHECK_ODBC_OK( + SQLColumns( + stmt, + nullptr, + 0, + nullptr, + 0, + (SQLCHAR*)"%/test_odbc_meta_col", + SQL_NTS, + (SQLCHAR*)"val%", + SQL_NTS), + stmt, + SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, kColumnNameCol, SQL_C_CHAR, colName, sizeof(colName), &colInd), SQL_SUCCESS); + ASSERT_STREQ(colName, "value_x"); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_TRUE, 0), + stmt, + SQL_HANDLE_STMT); + + ASSERT_EQ( + SQLColumns( + stmt, + nullptr, + 0, + nullptr, + 0, + (SQLCHAR*)"%/test_odbc_meta_col", + SQL_NTS, + (SQLCHAR*)"value_x", + SQL_NTS), + SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "HYC00")); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + ASSERT_EQ( + SQLColumns( + stmt, + nullptr, + 0, + nullptr, + 0, + (SQLCHAR*)exactTable.c_str(), + SQL_NTS, + (SQLCHAR*)"val%", + SQL_NTS), + SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "42S22")); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + diff --git a/odbc/tests/integration/test_utils.h b/odbc/tests/integration/test_utils.h new file mode 100644 index 00000000000..950ffef9508 --- /dev/null +++ b/odbc/tests/integration/test_utils.h @@ -0,0 +1,43 @@ +#pragma once + +#include + +#include +#include + +#include +#include + +inline std::string GetOdbcError(SQLHANDLE handle, SQLSMALLINT type) { + SQLCHAR sqlState[6] = {0}; + SQLCHAR message[256] = {0}; + SQLINTEGER nativeError = 0; + SQLSMALLINT textLength = 0; + SQLRETURN rc = SQLGetDiagRec(type, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + if (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) { + return std::string((char*)sqlState) + ": " + (char*)message; + } + return "Unknown ODBC error"; +} + +#define CHECK_ODBC_OK(rc, handle, type) \ + ASSERT_TRUE((rc) == SQL_SUCCESS || (rc) == SQL_SUCCESS_WITH_INFO) << GetOdbcError(handle, type) + +inline const char* kConnStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + +inline bool SqlStatePrefix(const std::string& diag, const char* state5) { + return diag.size() >= 5 && std::strncmp(diag.c_str(), state5, 5) == 0; +} + +inline void AllocEnv(SQLHENV* env) { + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(*env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); +} + +inline void AllocEnvAndConnect(SQLHENV* env, SQLHDBC* dbc) { + AllocEnv(env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, *env, dbc), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + *dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, *dbc, SQL_HANDLE_DBC); +} diff --git a/odbc/tests/unit/CMakeLists.txt b/odbc/tests/unit/CMakeLists.txt new file mode 100644 index 00000000000..d23e837d2f3 --- /dev/null +++ b/odbc/tests/unit/CMakeLists.txt @@ -0,0 +1,33 @@ +add_ydb_test(NAME odbc-convert_ut GTEST + SOURCES + convert_ut.cpp + LINK_LIBRARIES + yutil + api-protos + ydb-odbc + LABELS + unit +) + +add_ydb_test(NAME odbc-escape_ut GTEST + SOURCES + escape_ut.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/utils/escape.cpp + INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + LINK_LIBRARIES + yutil + LABELS + unit +) + +add_ydb_test(NAME odbc-sql_like_ut GTEST + SOURCES + sql_like_ut.cpp + INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + LINK_LIBRARIES + yutil + LABELS + unit +) diff --git a/odbc/tests/unit/convert_ut.cpp b/odbc/tests/unit/convert_ut.cpp new file mode 100644 index 00000000000..f4bad34a366 --- /dev/null +++ b/odbc/tests/unit/convert_ut.cpp @@ -0,0 +1,130 @@ +#include +#undef BOOL + +#include + +#include + +#include + +#include + +using namespace NYdb::NOdbc; +using namespace NYdb; + +template +void CheckProto(const T& value, const std::string& expected) { + std::string protoStr; + google::protobuf::TextFormat::PrintToString(value, &protoStr); + ASSERT_EQ(protoStr, expected); +} + +TEST(OdbcConvert, Int64ToYdb) { + SQLBIGINT v = 42; + TBoundParam param{ + 1, // ParamNumber + SQL_PARAM_INPUT, // InputOutputType + SQL_C_SBIGINT, // ValueType + SQL_BIGINT, // ParameterType + 0, 0, // ColumnSize, DecimalDigits + &v, // ParameterValuePtr + sizeof(v), // BufferLength + nullptr // StrLenOrIndPtr + }; + + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT64\n }\n}\n"); + CheckProto(value->GetProto(), "int64_value: 42\n"); +} + +TEST(OdbcConvert, Uint64ToYdb) { + SQLUBIGINT v = 123; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_UBIGINT, SQL_BIGINT, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UINT64\n }\n}\n"); + CheckProto(value->GetProto(), "uint64_value: 123\n"); +} + +TEST(OdbcConvert, DoubleToYdb) { + SQLDOUBLE v = 3.14; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_DOUBLE, SQL_DOUBLE, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: DOUBLE\n }\n}\n"); + CheckProto(value->GetProto(), "double_value: 3.14\n"); +} + +TEST(OdbcConvert, StringToYdbUtf8) { + const char* str = "hello"; + SQLLEN len = 5; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, len, nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UTF8\n }\n}\n"); + CheckProto(value->GetProto(), "text_value: \"hello\"\n"); +} + +TEST(OdbcConvert, StringToYdbBinary) { + const char* str = "bin\x01\x02"; + SQLLEN len = 5; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_BINARY, 0, 0, (SQLPOINTER)str, len, nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: STRING\n }\n}\n"); + CheckProto(value->GetProto(), "bytes_value: \"bin\\001\\002\"\n"); +} + +TEST(OdbcConvert, Int64NullToYdb) { + SQLBIGINT v = 42; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, &v, sizeof(v), &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT64\n }\n}\n"); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +} + +TEST(OdbcConvert, StringNullToYdb) { + const char* str = "test"; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, 4, &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UTF8\n }\n}\n"); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +} diff --git a/odbc/tests/unit/escape_ut.cpp b/odbc/tests/unit/escape_ut.cpp new file mode 100644 index 00000000000..60b3e582e69 --- /dev/null +++ b/odbc/tests/unit/escape_ut.cpp @@ -0,0 +1,71 @@ +#include "utils/escape.h" + +#include + +using NYdb::NOdbc::RewriteOdbcEscapes; + +TEST(OdbcEscapeRewrite, FnUnwraps) { + EXPECT_EQ(RewriteOdbcEscapes("SELECT {fn ABS(-12)} AS v"), "SELECT ABS(-12) AS v"); +} + +TEST(OdbcEscapeRewrite, FnCaseInsensitive) { + EXPECT_EQ(RewriteOdbcEscapes("{FN LOWER('A')}"), "LOWER('A')"); +} + +TEST(OdbcEscapeRewrite, OjUnwraps) { + EXPECT_EQ(RewriteOdbcEscapes("{oj LEFT OUTER JOIN t ON a=b}"), "LEFT OUTER JOIN t ON a=b"); +} + +TEST(OdbcEscapeRewrite, DateLiteral) { + EXPECT_EQ(RewriteOdbcEscapes("SELECT {d '2024-01-01'}"), "SELECT CAST('2024-01-01' AS Date)"); +} + +TEST(OdbcEscapeRewrite, TimeLiteral) { + EXPECT_EQ(RewriteOdbcEscapes("{t '14:30:00'}"), "CAST('14:30:00' AS Time)"); +} + +TEST(OdbcEscapeRewrite, TimestampLiteralNormalizesSpaceToT) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {ts '2024-06-15 14:30:00'} AS v"), + "SELECT CAST('2024-06-15T14:30:00Z' AS Datetime) AS v"); +} + +TEST(OdbcEscapeRewrite, TimestampLiteralKeepsExistingZ) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {ts '2024-06-15T14:30:00Z'} AS v"), + "SELECT CAST('2024-06-15T14:30:00Z' AS Datetime) AS v"); +} + +TEST(OdbcEscapeRewrite, Call) { + EXPECT_EQ(RewriteOdbcEscapes("{call sp_demo(1, 2)}"), "CALL sp_demo(1, 2)"); +} + +TEST(OdbcEscapeRewrite, OutputCallBecomesPlainCall) { + EXPECT_EQ(RewriteOdbcEscapes("{?= call sp(1)}"), "CALL sp(1)"); +} + +TEST(OdbcEscapeRewrite, EscapeClause) { + EXPECT_EQ(RewriteOdbcEscapes("LIKE 'a%' {escape '\\'}"), "LIKE 'a%' ESCAPE '\\'"); +} + +TEST(OdbcEscapeRewrite, ConvertOdbcToYqlCast) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {fn CONVERT(42, SQL_SMALLINT)} AS v"), + "SELECT CAST(42 AS Int16) AS v"); +} + +TEST(OdbcEscapeRewrite, ConvertNestedInFn) { + EXPECT_EQ(RewriteOdbcEscapes("{fn CONVERT(x, SQL_INTEGER)}"), "CAST(x AS Int32)"); +} + +TEST(OdbcEscapeRewrite, NestedFnEscapes) { + EXPECT_EQ(RewriteOdbcEscapes("{fn {fn ABS(1)}}"), "ABS(1)"); +} + +TEST(OdbcEscapeRewrite, UnknownBraceLeftUnchanged) { + EXPECT_EQ(RewriteOdbcEscapes("{not_a_keyword 1}"), "{not_a_keyword 1}"); +} + +TEST(OdbcEscapeRewrite, EmptyInput) { + EXPECT_EQ(RewriteOdbcEscapes(""), ""); +} diff --git a/odbc/tests/unit/sql_like_ut.cpp b/odbc/tests/unit/sql_like_ut.cpp new file mode 100644 index 00000000000..e0b8d87ee01 --- /dev/null +++ b/odbc/tests/unit/sql_like_ut.cpp @@ -0,0 +1,28 @@ +#include "utils/sql_like.h" + +#include + +using NYdb::NOdbc::SqlLikeMatch; + +TEST(SqlLikeMatch, PercentMatchesSubstring) { + EXPECT_TRUE(SqlLikeMatch("/local/foo_bar", "%foo%")); + EXPECT_TRUE(SqlLikeMatch("/local/pfx_foo_sfx", "%foo%")); + EXPECT_FALSE(SqlLikeMatch("/local/other", "%foo%")); +} + +TEST(SqlLikeMatch, UnderscoreMatchesSingleChar) { + EXPECT_TRUE(SqlLikeMatch("a_c", "a_c")); + EXPECT_TRUE(SqlLikeMatch("abc", "a_c")); + EXPECT_FALSE(SqlLikeMatch("abbc", "a_c")); +} + +TEST(SqlLikeMatch, EmptyPatternMatchesOnlyEmptyText) { + EXPECT_TRUE(SqlLikeMatch("", "")); + EXPECT_FALSE(SqlLikeMatch("anything", "")); +} + +TEST(SqlLikeMatch, PercentAtEnds) { + EXPECT_TRUE(SqlLikeMatch("hello", "%hello%")); + EXPECT_TRUE(SqlLikeMatch("hello", "hel%")); + EXPECT_TRUE(SqlLikeMatch("hello", "%llo")); +} diff --git a/tests/unit/library/operation_id/CMakeLists.txt b/tests/unit/library/operation_id/CMakeLists.txt index 86d3fd5131d..06c568af97c 100644 --- a/tests/unit/library/operation_id/CMakeLists.txt +++ b/tests/unit/library/operation_id/CMakeLists.txt @@ -5,6 +5,7 @@ add_ydb_test(NAME operation_id_ut yutil cpp-testing-unittest_main library-operation_id + lib-operation_id-protos cpp-testing-unittest LABELS unit