From c37dccd22b26ca4d947fa0c904585e4592a582e9 Mon Sep 17 00:00:00 2001 From: Ylonies Date: Tue, 7 Apr 2026 14:12:28 +0000 Subject: [PATCH 1/9] some fixes --- odbc/CMakeLists.txt | 3 +- odbc/src/connection.cpp | 17 ++++++++ odbc/src/connection.h | 4 ++ odbc/src/odbc_driver.cpp | 39 +++++++++++++++-- odbc/src/statement.cpp | 75 ++++++++++++++++++++++++++------ odbc/src/statement.h | 5 +++ odbc/src/utils/bindings.h | 5 +++ odbc/src/utils/cursor.cpp | 22 +++++++--- odbc/src/utils/cursor.h | 5 ++- odbc/src/utils/error_manager.cpp | 54 ++++++++++++++++++++++- odbc/src/utils/error_manager.h | 12 ++++- 11 files changed, 212 insertions(+), 29 deletions(-) diff --git a/odbc/CMakeLists.txt b/odbc/CMakeLists.txt index 799c9b89b19..9919870702d 100644 --- a/odbc/CMakeLists.txt +++ b/odbc/CMakeLists.txt @@ -23,7 +23,6 @@ target_link_libraries(ydb-odbc YDB-CPP-SDK::Table YDB-CPP-SDK::Scheme YDB-CPP-SDK::Driver - ODBC::ODBC ) set_target_properties(ydb-odbc PROPERTIES @@ -43,7 +42,7 @@ add_subdirectory(tests) include(GNUInstallDirs) -install(FILES +install(FILES odbcinst.ini DESTINATION ${CMAKE_INSTALL_SYSCONFDIR}/odbcinst.d RENAME ydb-odbc.ini diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp index 7ed7679e015..eb142108334 100644 --- a/odbc/src/connection.cpp +++ b/odbc/src/connection.cpp @@ -79,11 +79,24 @@ SQLRETURN TConnection::Connect(const std::string& serverName, } SQLRETURN TConnection::Disconnect() { + QuerySession_.reset(); + Tx_.reset(); + YdbSchemeClient_.reset(); + YdbTableClient_.reset(); YdbClient_.reset(); YdbDriver_.reset(); return SQL_SUCCESS; } +NQuery::TSession& TConnection::GetOrCreateQuerySession() { + if (!QuerySession_) { + auto sessionResult = YdbClient_->GetSession().ExtractValueSync(); + NStatusHelpers::ThrowOnError(sessionResult); + QuerySession_.emplace(std::move(sessionResult.GetSession())); + } + return *QuerySession_; +} + std::unique_ptr TConnection::CreateStatement() { return std::make_unique(this); } @@ -115,6 +128,10 @@ void TConnection::SetTx(const NQuery::TTransaction& tx) { Tx_ = tx; } +void TConnection::Reset() { + Tx_.reset(); +} + SQLRETURN TConnection::CommitTx() { auto status = Tx_->Commit().ExtractValueSync(); NStatusHelpers::ThrowOnError(status); diff --git a/odbc/src/connection.h b/odbc/src/connection.h index ad69b0f171c..a0ce4acb991 100644 --- a/odbc/src/connection.h +++ b/odbc/src/connection.h @@ -27,6 +27,8 @@ class TConnection : public TErrorManager { std::unique_ptr YdbTableClient_; std::unique_ptr YdbSchemeClient_; std::optional Tx_; + /** Одна сессия KQP на ODBC-соединение: DDL/DML/SELECT видят одну и ту же схему без «новой» сессии на каждый Execute. */ + std::optional QuerySession_; std::vector> Statements_; std::string Endpoint_; @@ -47,6 +49,7 @@ class TConnection : public TErrorManager { void RemoveStatement(TStatement* stmt); NYdb::NQuery::TQueryClient* GetClient() { return YdbClient_.get(); } + NQuery::TSession& GetOrCreateQuerySession(); NYdb::NTable::TTableClient* GetTableClient() { return YdbTableClient_.get(); } NScheme::TSchemeClient* GetSchemeClient() { return YdbSchemeClient_.get(); } @@ -55,6 +58,7 @@ class TConnection : public TErrorManager { const std::optional& GetTx(); void SetTx(const NQuery::TTransaction& tx); + void Reset(); SQLRETURN CommitTx(); SQLRETURN RollbackTx(); diff --git a/odbc/src/odbc_driver.cpp b/odbc/src/odbc_driver.cpp index c047f770837..f26bd55c828 100644 --- a/odbc/src/odbc_driver.cpp +++ b/odbc/src/odbc_driver.cpp @@ -29,10 +29,13 @@ SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handleType, switch (handleType) { case SQL_HANDLE_ENV: { - return NYdb::NOdbc::HandleOdbcExceptions(inputHandle, [&]() { - *outputHandle = new NYdb::NOdbc::TEnvironment(); - return SQL_SUCCESS; - }); + return NYdb::NOdbc::HandleOdbcExceptions( + inputHandle, + [&]() { + *outputHandle = new NYdb::NOdbc::TEnvironment(); + return SQL_SUCCESS; + }, + NYdb::NOdbc::ENullInputHandlePolicy::Allow); } case SQL_HANDLE_DBC: { @@ -208,6 +211,34 @@ SQLRETURN SQL_API SQLGetDiagRec(SQLSMALLINT handleType, } } +SQLRETURN SQL_API SQLGetDiagField(SQLSMALLINT handleType, + SQLHANDLE handle, + SQLSMALLINT recNumber, + SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* env) { + return env->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* conn) { + return conn->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* stmt) { + return stmt->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + default: + return SQL_ERROR; + } +} + SQLRETURN SQL_API SQLBindParameter(SQLHSTMT statementHandle, SQLUSMALLINT paramNumber, SQLSMALLINT inputOutputType, diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp index b61b8f07eb2..b7fddc624bf 100644 --- a/odbc/src/statement.cpp +++ b/odbc/src/statement.cpp @@ -14,6 +14,7 @@ TStatement::TStatement(TConnection* conn) : Conn_(conn) {} SQLRETURN TStatement::Prepare(const std::string& statementText) { + StreamFetchError_ = false; Cursor_.reset(); PreparedQuery_ = statementText; IsPrepared_ = true; @@ -24,40 +25,86 @@ SQLRETURN TStatement::Execute() { if (!IsPrepared_ || PreparedQuery_.empty()) { throw TOdbcException("HY007", 0, "No prepared statement"); } + StreamFetchError_ = false; Cursor_.reset(); auto* client = Conn_->GetClient(); if (!client) { throw TOdbcException("HY000", 0, "No client connection"); } NYdb::TParams params = BuildParams(); - - if (!Conn_->GetTx()) { - auto sessionResult = client->GetSession().ExtractValueSync(); - NStatusHelpers::ThrowOnError(sessionResult); - - auto session = sessionResult.GetSession(); - auto beginTxResult = session.BeginTransaction(NQuery::TTxSettings::SerializableRW()).ExtractValueSync(); - NStatusHelpers::ThrowOnError(beginTxResult); - Conn_->SetTx(beginTxResult.GetTransaction()); + if (Conn_->GetAutocommit()){ + Conn_->Reset(); } - auto session = Conn_->GetTx()->GetSession(); - auto iterator = session.StreamExecuteQuery(PreparedQuery_, - NQuery::TTxControl::Tx(*Conn_->GetTx()).CommitTx(Conn_->GetAutocommit()), params).ExtractValueSync(); + + auto& session = Conn_->GetOrCreateQuerySession(); + + auto iterator = CreateExecuteIterator(session, params); NStatusHelpers::ThrowOnError(iterator); - Cursor_ = CreateExecCursor(this, std::move(iterator)); + std::optional prefetchedResultPart = PrefetchFirstResultPart(iterator); + if (prefetchedResultPart) { + Cursor_ = CreateExecCursor(this, std::move(iterator), std::move(prefetchedResultPart)); + } else { + Cursor_.reset(); + } IsPrepared_ = false; PreparedQuery_.clear(); return SQL_SUCCESS; } +NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params){ + if (Conn_->GetAutocommit()) { + return session.StreamExecuteQuery( + PreparedQuery_, + NQuery::TTxControl::NoTx(), + params).ExtractValueSync(); + } + if (!Conn_->GetTx()) { + auto beginTxResult = session.BeginTransaction(NQuery::TTxSettings::SerializableRW()).ExtractValueSync(); + NStatusHelpers::ThrowOnError(beginTxResult); + Conn_->SetTx(beginTxResult.GetTransaction()); + } + return session.StreamExecuteQuery( + PreparedQuery_, + NQuery::TTxControl::Tx(*Conn_->GetTx()).CommitTx(false), + params).ExtractValueSync(); +} + +std::optional TStatement::PrefetchFirstResultPart(NQuery::TExecuteQueryIterator& iterator){ + std::optional prefetchedResultPart; + while (true) { + auto part = iterator.ReadNext().ExtractValueSync(); + if (part.EOS()) { + break; + } + if (!part.IsSuccess()) { + NStatusHelpers::ThrowOnError(part); + } + if (part.HasResultSet()) { + prefetchedResultPart.emplace(std::move(part)); + break; + } + } + return prefetchedResultPart; +} + SQLRETURN TStatement::Fetch() { if (!Cursor_) { Cursor_.reset(); return SQL_NO_DATA; } - return Cursor_->Fetch() ? SQL_SUCCESS : SQL_NO_DATA; + StreamFetchError_ = false; + if (!Cursor_->Fetch()) { + return StreamFetchError_ ? SQL_ERROR : SQL_NO_DATA; + } + return SQL_SUCCESS; +} + +void TStatement::OnStreamPartError(const TStatus& status) { + ClearErrors(); + AddError(status); + StreamFetchError_ = true; } SQLRETURN TStatement::GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, diff --git a/odbc/src/statement.h b/odbc/src/statement.h index 8bed3534986..f17780957bb 100644 --- a/odbc/src/statement.h +++ b/odbc/src/statement.h @@ -30,6 +30,7 @@ class TStatement : public TErrorManager, public IBindingFiller { SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); void FillBoundColumns() override; + void OnStreamPartError(const TStatus& status) override; SQLRETURN Close(bool force = false); void UnbindColumns(); @@ -63,9 +64,13 @@ class TStatement : public TErrorManager, public IBindingFiller { std::vector BoundColumns_; std::vector BoundParams_; + bool StreamFetchError_ = false; NYdb::TParams BuildParams(); + NQuery::TExecuteQueryIterator CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params); + std::optional PrefetchFirstResultPart(NQuery::TExecuteQueryIterator& iterator); + std::vector GetPatternEntries(const std::string& pattern); SQLRETURN VisitEntry(const std::string& path, const std::string& pattern, std::vector& resultEntries); bool IsPatternMatch(const std::string& path, const std::string& pattern); diff --git a/odbc/src/utils/bindings.h b/odbc/src/utils/bindings.h index df76de4e951..443d9787d70 100644 --- a/odbc/src/utils/bindings.h +++ b/odbc/src/utils/bindings.h @@ -3,6 +3,8 @@ #include #include +#include + namespace NYdb { namespace NOdbc { @@ -29,6 +31,9 @@ struct TBoundColumn { class IBindingFiller { public: virtual void FillBoundColumns() = 0; + virtual void OnStreamPartError(const TStatus& status) { + (void)status; + } virtual ~IBindingFiller() = default; }; diff --git a/odbc/src/utils/cursor.cpp b/odbc/src/utils/cursor.cpp index fbd10588aba..efbcea9a419 100644 --- a/odbc/src/utils/cursor.cpp +++ b/odbc/src/utils/cursor.cpp @@ -8,9 +8,11 @@ namespace NOdbc { class TExecCursor : public ICursor { public: - TExecCursor(IBindingFiller* bindingFiller, NQuery::TExecuteQueryIterator iterator) + TExecCursor(IBindingFiller* bindingFiller, NQuery::TExecuteQueryIterator iterator, + std::optional prefetchedPart) : BindingFiller_(bindingFiller) , Iterator_(std::move(iterator)) + , PrefetchedPart_(std::move(prefetchedPart)) {} bool Fetch() override { @@ -22,11 +24,19 @@ class TExecCursor : public ICursor { } ResultSetParser_.reset(); } - auto part = Iterator_.ReadNext().ExtractValueSync(); + NQuery::TExecuteQueryPart part = [&]() { + if (PrefetchedPart_) { + auto p = std::move(*PrefetchedPart_); + PrefetchedPart_.reset(); + return p; + } + return Iterator_.ReadNext().ExtractValueSync(); + }(); if (part.EOS()) { return false; } if (!part.IsSuccess()) { + BindingFiller_->OnStreamPartError(part); return false; } if (part.HasResultSet()) { @@ -62,7 +72,7 @@ class TExecCursor : public ICursor { IBindingFiller* BindingFiller_; NQuery::TExecuteQueryIterator Iterator_; - // std::optional Part_; + std::optional PrefetchedPart_; std::unique_ptr ResultSetParser_; std::vector Columns_; }; @@ -107,8 +117,10 @@ class TVirtualCursor : public ICursor { int64_t Cursor_ = -1; }; -std::unique_ptr CreateExecCursor(IBindingFiller* bindingFiller, NQuery::TExecuteQueryIterator iterator) { - return std::make_unique(bindingFiller, std::move(iterator)); +std::unique_ptr CreateExecCursor(IBindingFiller* bindingFiller, + NQuery::TExecuteQueryIterator iterator, + std::optional prefetchedPart) { + return std::make_unique(bindingFiller, std::move(iterator), std::move(prefetchedPart)); } std::unique_ptr CreateVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table) { diff --git a/odbc/src/utils/cursor.h b/odbc/src/utils/cursor.h index e4b13ed5215..22828f66144 100644 --- a/odbc/src/utils/cursor.h +++ b/odbc/src/utils/cursor.h @@ -6,6 +6,7 @@ #include +#include #include #include @@ -30,7 +31,9 @@ class ICursor { virtual const std::vector& GetColumnMeta() const = 0; }; -std::unique_ptr CreateExecCursor(IBindingFiller* bindingFiller, NYdb::NQuery::TExecuteQueryIterator iterator); +std::unique_ptr CreateExecCursor(IBindingFiller* bindingFiller, + NYdb::NQuery::TExecuteQueryIterator iterator, + std::optional prefetchedPart = std::nullopt); std::unique_ptr CreateVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table); } // namespace NOdbc diff --git a/odbc/src/utils/error_manager.cpp b/odbc/src/utils/error_manager.cpp index da9f339af89..fbb577e3824 100644 --- a/odbc/src/utils/error_manager.cpp +++ b/odbc/src/utils/error_manager.cpp @@ -104,8 +104,58 @@ SQLRETURN TErrorManager::GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQ return SQL_SUCCESS; } -SQLRETURN HandleOdbcExceptions(SQLHANDLE handlePtr, std::function&& func) { - if (!handlePtr) { +SQLRETURN TErrorManager::GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr) { + const SQLSMALLINT count = static_cast(Errors_.size()); + + if (recNumber == 0) { + if (diagIdentifier == SQL_DIAG_NUMBER) { + if (!diagInfoPtr) { + return SQL_ERROR; + } + *static_cast(diagInfoPtr) = count; + return SQL_SUCCESS; + } + return SQL_NO_DATA; + } + + if (recNumber < 1 || recNumber > count) { + return SQL_NO_DATA; + } + + const auto& err = Errors_[recNumber - 1]; + switch (diagIdentifier) { + case SQL_DIAG_SQLSTATE: + if (!diagInfoPtr) { + return SQL_ERROR; + } + strncpy((char*)diagInfoPtr, err.SqlState.c_str(), 6); + return SQL_SUCCESS; + case SQL_DIAG_NATIVE: + if (!diagInfoPtr) { + return SQL_ERROR; + } + *static_cast(diagInfoPtr) = err.NativeError; + return SQL_SUCCESS; + case SQL_DIAG_MESSAGE_TEXT: + if (!diagInfoPtr || bufferLength <= 0) { + return SQL_ERROR; + } + strncpy((char*)diagInfoPtr, err.Message.c_str(), bufferLength); + if (stringLengthPtr) { + *stringLengthPtr = static_cast(err.Message.size()); + } + return SQL_SUCCESS; + default: + return SQL_NO_DATA; + } +} + +SQLRETURN HandleOdbcExceptions( + SQLHANDLE handlePtr, + std::function&& func, + ENullInputHandlePolicy nullInputPolicy) { + if (!handlePtr && nullInputPolicy != ENullInputHandlePolicy::Allow) { return SQL_INVALID_HANDLE; } diff --git a/odbc/src/utils/error_manager.h b/odbc/src/utils/error_manager.h index 1e31349964d..5f72a69f563 100644 --- a/odbc/src/utils/error_manager.h +++ b/odbc/src/utils/error_manager.h @@ -66,11 +66,18 @@ class TErrorManager { SQLRETURN GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength); + SQLRETURN GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr); private: TErrorList Errors_; }; +enum class ENullInputHandlePolicy : unsigned char { + Reject, + Allow, +}; + template SQLRETURN HandleOdbcExceptions(SQLHANDLE handlePtr, std::function&& func) { if (!handlePtr) { @@ -91,7 +98,10 @@ SQLRETURN HandleOdbcExceptions(SQLHANDLE handlePtr, std::function&& func); +SQLRETURN HandleOdbcExceptions( + SQLHANDLE handlePtr, + std::function&& func, + ENullInputHandlePolicy nullInputPolicy = ENullInputHandlePolicy::Reject); } // namespace NOdbc } // namespace NYdb From 062a3a0b73b03db96427156decbf6183b6f11e8f Mon Sep 17 00:00:00 2001 From: Ylonies Date: Tue, 7 Apr 2026 16:18:04 +0000 Subject: [PATCH 2/9] env features EndTran for env + tests --- odbc/src/connection.cpp | 11 +++ odbc/src/connection.h | 5 +- odbc/src/environment.cpp | 51 ++++++++++++++ odbc/src/environment.h | 9 +++ odbc/src/odbc_driver.cpp | 16 +++-- odbc/tests/integration/CMakeLists.txt | 5 ++ odbc/tests/integration/basic_it.cpp | 24 +------ odbc/tests/integration/env_it.cpp | 99 +++++++++++++++++++++++++++ odbc/tests/integration/test_utils.h | 25 +++++++ 9 files changed, 217 insertions(+), 28 deletions(-) create mode 100644 odbc/tests/integration/env_it.cpp create mode 100644 odbc/tests/integration/test_utils.h diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp index eb142108334..29b6758b2a8 100644 --- a/odbc/src/connection.cpp +++ b/odbc/src/connection.cpp @@ -146,5 +146,16 @@ SQLRETURN TConnection::RollbackTx() { return SQL_SUCCESS; } +void TConnection::SetEnvironment(TEnvironment* env){ + if (ParentEnv_){ + throw std::logic_error("Connection already bound to environment"); + } + ParentEnv_ = env; +} + +TEnvironment* TConnection::GetEnvironment(){ + return ParentEnv_; +} + } // namespace NOdbc } // namespace NYdb diff --git a/odbc/src/connection.h b/odbc/src/connection.h index a0ce4acb991..f048e1cef4f 100644 --- a/odbc/src/connection.h +++ b/odbc/src/connection.h @@ -27,13 +27,13 @@ class TConnection : public TErrorManager { std::unique_ptr YdbTableClient_; std::unique_ptr YdbSchemeClient_; std::optional Tx_; - /** Одна сессия KQP на ODBC-соединение: DDL/DML/SELECT видят одну и ту же схему без «новой» сессии на каждый Execute. */ std::optional QuerySession_; std::vector> Statements_; std::string Endpoint_; std::string Database_; std::string AuthToken_; + TEnvironment* ParentEnv_; bool Autocommit_ = true; @@ -62,6 +62,9 @@ class TConnection : public TErrorManager { SQLRETURN CommitTx(); SQLRETURN RollbackTx(); + + void SetEnvironment(TEnvironment* env); + TEnvironment* GetEnvironment(); }; } // namespace NOdbc diff --git a/odbc/src/environment.cpp b/odbc/src/environment.cpp index 541ca9e2160..e66af68f11f 100644 --- a/odbc/src/environment.cpp +++ b/odbc/src/environment.cpp @@ -13,5 +13,56 @@ SQLRETURN TEnvironment::SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQL return SQL_SUCCESS; } +void TEnvironment::RegisterConnection(TConnection* conn){ + if (conn == nullptr){ + throw std::invalid_argument("null connection"); + } + connections_.insert(conn); +} + +void TEnvironment::UnregisterConnection(TConnection* conn){ + if (conn == nullptr){ + throw std::invalid_argument("null connection"); + } + connections_.erase(conn); +} + +std::vector TEnvironment::GetConnectionsSnapshot() const { + return std::vector(connections_.begin(), connections_.end()); +} + +SQLRETURN TEnvironment::EndTran(SQLSMALLINT completionType){ + if (completionType != SQL_COMMIT && completionType != SQL_ROLLBACK){ + return AddError("HY012", 0, "Invalid transaction operation code"); + } + bool hasFailures = false; + int failedCount = 0; + + for (auto* conn : connections_) { + if (!conn || !conn->GetTx()) { + continue; + } + try { + if (completionType == SQL_COMMIT) { + conn->CommitTx(); + } else { + conn->RollbackTx(); + } + } catch (const std::exception& ex) { + hasFailures = true; + ++failedCount; + AddError("HY000", 0, ex.what(), SQL_SUCCESS_WITH_INFO); + } catch (...) { + hasFailures = true; + ++failedCount; + AddError("HY000", 0, "Unknown error during ENV-level transaction completion", SQL_SUCCESS_WITH_INFO); + } + } + if (hasFailures) { + AddError("01000", 0, "SQLEndTran(SQL_HANDLE_ENV): some connections failed", SQL_SUCCESS_WITH_INFO); + return SQL_SUCCESS_WITH_INFO; + } + return SQL_SUCCESS; +} } // namespace NOdbc } // namespace NYdb diff --git a/odbc/src/environment.h b/odbc/src/environment.h index 5258b722492..70a785f45d7 100644 --- a/odbc/src/environment.h +++ b/odbc/src/environment.h @@ -4,6 +4,8 @@ #include #include +#include +#include namespace NYdb { namespace NOdbc { @@ -13,12 +15,19 @@ class TConnection; class TEnvironment : public TErrorManager { private: SQLINTEGER OdbcVersion_; + std::unordered_set connections_; public: TEnvironment(); ~TEnvironment(); SQLRETURN SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength); + + void RegisterConnection(TConnection*); + void UnregisterConnection(TConnection*); + std::vector GetConnectionsSnapshot() const; + + SQLRETURN EndTran(SQLSMALLINT completionType); }; } // namespace NOdbc diff --git a/odbc/src/odbc_driver.cpp b/odbc/src/odbc_driver.cpp index f26bd55c828..2993c76fcef 100644 --- a/odbc/src/odbc_driver.cpp +++ b/odbc/src/odbc_driver.cpp @@ -39,8 +39,11 @@ SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handleType, } case SQL_HANDLE_DBC: { - return NYdb::NOdbc::HandleOdbcExceptions(inputHandle, [&]() { - *outputHandle = new NYdb::NOdbc::TConnection(); + return NYdb::NOdbc::HandleOdbcExceptions(inputHandle, [&](auto* env) { + auto conn = std::make_unique(); + conn->SetEnvironment(env); + env->RegisterConnection(conn.get()); + *outputHandle = conn.release(); return SQL_SUCCESS; }); } @@ -66,6 +69,10 @@ SQLRETURN SQL_API SQLFreeHandle(SQLSMALLINT handleType, SQLHANDLE handle) { } case SQL_HANDLE_DBC: { return NYdb::NOdbc::HandleOdbcExceptions(handle, [](auto* conn) { + auto* env = conn->GetEnvironment(); + if (env != nullptr){ + env->UnregisterConnection(conn); + } delete conn; return SQL_SUCCESS; }); @@ -281,8 +288,9 @@ SQLRETURN SQL_API SQLEndTran(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLI }); } case SQL_HANDLE_ENV: { - // TODO: if's list of connections in ENV, go through them and commit/rollback transactions - return SQL_SUCCESS; + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* env) -> SQLRETURN { + return env->EndTran(completionType); + }); } default: return SQL_ERROR; diff --git a/odbc/tests/integration/CMakeLists.txt b/odbc/tests/integration/CMakeLists.txt index e1aad9d3913..0360679931c 100644 --- a/odbc/tests/integration/CMakeLists.txt +++ b/odbc/tests/integration/CMakeLists.txt @@ -2,3 +2,8 @@ add_odbc_test(NAME odbc-basic_it SOURCES basic_it.cpp ) + +add_odbc_test(NAME odbc-env_it + SOURCES + env_it.cpp +) diff --git a/odbc/tests/integration/basic_it.cpp b/odbc/tests/integration/basic_it.cpp index b4c7078ac4e..37973667147 100644 --- a/odbc/tests/integration/basic_it.cpp +++ b/odbc/tests/integration/basic_it.cpp @@ -1,26 +1,4 @@ -#include - -#include -#include - -#include - - -#define CHECK_ODBC_OK(rc, handle, type) \ - ASSERT_TRUE((rc) == SQL_SUCCESS || (rc) == SQL_SUCCESS_WITH_INFO) << GetOdbcError(handle, type) - -std::string GetOdbcError(SQLHANDLE handle, SQLSMALLINT type) { - SQLCHAR sqlState[6], message[256]; - SQLINTEGER nativeError; - SQLSMALLINT textLength; - SQLRETURN rc = SQLGetDiagRec(type, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); - if (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) { - return std::string((char*)sqlState) + ": " + (char*)message; - } - return "Unknown ODBC error"; -} - -const char* kConnStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; +#include "test_utils.h" TEST(OdbcBasic, SimpleQuery) { SQLHENV env; diff --git a/odbc/tests/integration/env_it.cpp b/odbc/tests/integration/env_it.cpp new file mode 100644 index 00000000000..fd351d127af --- /dev/null +++ b/odbc/tests/integration/env_it.cpp @@ -0,0 +1,99 @@ +#include "test_utils.h" + +namespace { + +void AllocEnvAndConnect(SQLHENV* env, SQLHDBC* dbc) { + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(*env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, *env, dbc), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + *dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, *dbc, SQL_HANDLE_DBC); +} + +void StartManualTx(SQLHDBC dbc, SQLHSTMT* stmt) { + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(*stmt, (SQLCHAR*)"SELECT 1", SQL_NTS), *stmt, SQL_HANDLE_STMT); +} + +} // namespace + +TEST(OdbcEnv, EndTranCommitOnEnv) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + StartManualTx(dbc2, &stmt2); + + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_ENV, env, SQL_COMMIT), env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcEnv, EndTranRollbackOnEnv) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + StartManualTx(dbc2, &stmt2); + + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_ENV, env, SQL_ROLLBACK), env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcEnv, EndTranPartialFailureReturnsInfo) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc2, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc2, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc2, &stmt2), SQL_SUCCESS); + (void)SQLExecDirect(stmt2, (SQLCHAR*)"SELECT FROM", SQL_NTS); + + rc = SQLEndTran(SQL_HANDLE_ENV, env, SQL_COMMIT); + ASSERT_TRUE(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO || rc == SQL_ERROR) + << GetOdbcError(env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/test_utils.h b/odbc/tests/integration/test_utils.h new file mode 100644 index 00000000000..c43272f0f54 --- /dev/null +++ b/odbc/tests/integration/test_utils.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include +#include + +#include + +#define CHECK_ODBC_OK(rc, handle, type) \ + ASSERT_TRUE((rc) == SQL_SUCCESS || (rc) == SQL_SUCCESS_WITH_INFO) << GetOdbcError(handle, type) + +inline std::string GetOdbcError(SQLHANDLE handle, SQLSMALLINT type) { + SQLCHAR sqlState[6] = {0}; + SQLCHAR message[256] = {0}; + SQLINTEGER nativeError = 0; + SQLSMALLINT textLength = 0; + SQLRETURN rc = SQLGetDiagRec(type, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + if (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) { + return std::string((char*)sqlState) + ": " + (char*)message; + } + return "Unknown ODBC error"; +} + +inline const char* kConnStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; From fe35daac060d4873f89cd1f5d459fbe6c068943a Mon Sep 17 00:00:00 2001 From: Ylonies Date: Wed, 8 Apr 2026 11:54:28 +0000 Subject: [PATCH 3/9] attributes --- odbc/src/connection.cpp | 24 ++++- odbc/src/connection.h | 8 +- odbc/src/connection_attributes.cpp | 141 +++++++++++++++++++++++++++++ odbc/src/connection_attributes.h | 48 ++++++++++ odbc/src/environment.cpp | 20 +++- odbc/src/odbc_driver.cpp | 20 ++-- odbc/src/statement.cpp | 13 ++- 7 files changed, 250 insertions(+), 24 deletions(-) create mode 100644 odbc/src/connection_attributes.cpp create mode 100644 odbc/src/connection_attributes.h diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp index 29b6758b2a8..b1049163a5d 100644 --- a/odbc/src/connection.cpp +++ b/odbc/src/connection.cpp @@ -2,9 +2,8 @@ #include "statement.h" #include "utils/error_manager.h" -#include -#include #include +#include #include #include @@ -107,8 +106,8 @@ void TConnection::RemoveStatement(TStatement* stmt) { } SQLRETURN TConnection::SetAutocommit(bool value) { - Autocommit_ = value; - if (Autocommit_ && Tx_) { + Attributes_.SetAutocommit(value); + if (Attributes_.GetAutocommit() && Tx_) { auto status = Tx_->Commit().ExtractValueSync(); NStatusHelpers::ThrowOnError(status); Tx_.reset(); @@ -117,7 +116,22 @@ SQLRETURN TConnection::SetAutocommit(bool value) { } bool TConnection::GetAutocommit() const { - return Autocommit_; + return Attributes_.GetAutocommit(); +} + +SQLRETURN TConnection::SetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength) { + return Attributes_.SetConnectAttr(attr, value, stringLength, [this](bool autocommit) { + return SetAutocommit(autocommit); + }, *this); +} + +SQLRETURN TConnection::GetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return Attributes_.GetConnectAttr(attr, value, bufferLength, stringLengthPtr, *this); +} + +NQuery::TTxSettings TConnection::MakeTxSettings() const { + return Attributes_.MakeTxSettings(); } const std::optional& TConnection::GetTx() { diff --git a/odbc/src/connection.h b/odbc/src/connection.h index f048e1cef4f..e1c9028fa84 100644 --- a/odbc/src/connection.h +++ b/odbc/src/connection.h @@ -1,6 +1,7 @@ #pragma once #include "environment.h" +#include "connection_attributes.h" #include "utils/error_manager.h" #include @@ -35,8 +36,7 @@ class TConnection : public TErrorManager { std::string AuthToken_; TEnvironment* ParentEnv_; - bool Autocommit_ = true; - + TConnectionAttributes Attributes_; public: SQLRETURN Connect(const std::string& serverName, const std::string& userName, @@ -56,6 +56,10 @@ class TConnection : public TErrorManager { SQLRETURN SetAutocommit(bool value); bool GetAutocommit() const; + SQLRETURN SetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength); + SQLRETURN GetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); + NQuery::TTxSettings MakeTxSettings() const; + const std::optional& GetTx(); void SetTx(const NQuery::TTransaction& tx); void Reset(); diff --git a/odbc/src/connection_attributes.cpp b/odbc/src/connection_attributes.cpp new file mode 100644 index 00000000000..61bfa7f8caa --- /dev/null +++ b/odbc/src/connection_attributes.cpp @@ -0,0 +1,141 @@ +#include "connection_attributes.h" + +#include + +namespace NYdb { +namespace NOdbc { + +std::optional TConnectionAttributes::ResolveTxMode(SQLUINTEGER accessMode, SQLUINTEGER txnIsolation) { + if (accessMode == SQL_MODE_READ_ONLY) { + switch (txnIsolation) { + case SQL_TXN_READ_UNCOMMITTED: + return NQuery::TTxSettings::TS_STALE_RO; + case SQL_TXN_READ_COMMITTED: + return NQuery::TTxSettings::TS_ONLINE_RO; + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return NQuery::TTxSettings::TS_SNAPSHOT_RO; + default: + return std::nullopt; + } + } + + switch (txnIsolation) { + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return NQuery::TTxSettings::TS_SERIALIZABLE_RW; + default: + return std::nullopt; + } +} + +SQLRETURN TConnectionAttributes::SetAutocommit(bool value) { + Autocommit_ = value; + return SQL_SUCCESS; +} + +bool TConnectionAttributes::GetAutocommit() const { + return Autocommit_; +} + +SQLRETURN TConnectionAttributes::SetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER /*stringLength*/, + const std::function& applyAutocommit, + TErrorManager& errors) { + switch (attr) { + case SQL_ATTR_AUTOCOMMIT: { + const intptr_t val = reinterpret_cast(value); + if (val == static_cast(SQL_AUTOCOMMIT_ON)) { + return applyAutocommit(true); + } + if (val == static_cast(SQL_AUTOCOMMIT_OFF)) { + return applyAutocommit(false); + } + return errors.AddError("HY024", 0, "Invalid SQL_ATTR_AUTOCOMMIT value"); + } + case SQL_ATTR_ACCESS_MODE: { + const intptr_t val = reinterpret_cast(value); + if (val == static_cast(SQL_MODE_READ_WRITE)) { + AccessMode_ = SQL_MODE_READ_WRITE; + auto txMode = ResolveTxMode(AccessMode_, TxnIsolation_); + if (!txMode) { + return errors.AddError("HYC00", 0, "Transaction isolation is not supported for read-write mode"); + } + TxMode_ = *txMode; + return SQL_SUCCESS; + } + if (val == static_cast(SQL_MODE_READ_ONLY)) { + AccessMode_ = SQL_MODE_READ_ONLY; + auto txMode = ResolveTxMode(AccessMode_, TxnIsolation_); + if (!txMode) { + return errors.AddError("HYC00", 0, "Transaction isolation is not supported for read-only mode"); + } + TxMode_ = *txMode; + return SQL_SUCCESS; + } + return errors.AddError("HY024", 0, "Invalid SQL_ATTR_ACCESS_MODE value"); + } + case SQL_ATTR_TXN_ISOLATION: { + const intptr_t val = reinterpret_cast(value); + const SQLUINTEGER isolation = static_cast(val); + auto txMode = ResolveTxMode(AccessMode_, isolation); + if (!txMode) { + return errors.AddError("HYC00", 0, "SQL_ATTR_TXN_ISOLATION value is not supported"); + } + TxnIsolation_ = isolation; + TxMode_ = *txMode; + return SQL_SUCCESS; + } + default: + return errors.AddError("HYC00", 0, "Optional feature not implemented"); + } +} + +SQLRETURN TConnectionAttributes::GetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER /*bufferLength*/, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + if (!value) { + return errors.AddError("HY009", 0, "Invalid use of null pointer"); + } + if (stringLengthPtr) { + *stringLengthPtr = 0; + } + auto* out = reinterpret_cast(value); + switch (attr) { + case SQL_ATTR_AUTOCOMMIT: + *out = GetAutocommit() ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; + return SQL_SUCCESS; + case SQL_ATTR_ACCESS_MODE: + *out = AccessMode_; + return SQL_SUCCESS; + case SQL_ATTR_TXN_ISOLATION: + *out = TxnIsolation_; + return SQL_SUCCESS; + default: + return errors.AddError("HYC00", 0, "Optional feature not implemented"); + } +} + +NQuery::TTxSettings TConnectionAttributes::MakeTxSettings() const { + switch (TxMode_) { + case NQuery::TTxSettings::TS_ONLINE_RO: + return NQuery::TTxSettings::OnlineRO(); + case NQuery::TTxSettings::TS_STALE_RO: + return NQuery::TTxSettings::StaleRO(); + case NQuery::TTxSettings::TS_SNAPSHOT_RO: + return NQuery::TTxSettings::SnapshotRO(); + case NQuery::TTxSettings::TS_SNAPSHOT_RW: + return NQuery::TTxSettings::SnapshotRW(); + case NQuery::TTxSettings::TS_SERIALIZABLE_RW: + default: + return NQuery::TTxSettings::SerializableRW(); + } +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection_attributes.h b/odbc/src/connection_attributes.h new file mode 100644 index 00000000000..7b2f0fc7221 --- /dev/null +++ b/odbc/src/connection_attributes.h @@ -0,0 +1,48 @@ +#pragma once + +#include "utils/error_manager.h" + +#include + +#include +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TConnectionAttributes { +public: + SQLRETURN SetAutocommit(bool value); + bool GetAutocommit() const; + + SQLRETURN SetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + const std::function& applyAutocommit, + TErrorManager& errors); + + SQLRETURN GetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + NQuery::TTxSettings MakeTxSettings() const; + +private: + static std::optional ResolveTxMode(SQLUINTEGER accessMode, SQLUINTEGER txnIsolation); + +private: + bool Autocommit_ = true; + SQLUINTEGER AccessMode_ = SQL_MODE_READ_WRITE; + SQLUINTEGER TxnIsolation_ = SQL_TXN_SERIALIZABLE; + NQuery::TTxSettings::ETransactionMode TxMode_ = NQuery::TTxSettings::TS_SERIALIZABLE_RW; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/environment.cpp b/odbc/src/environment.cpp index e66af68f11f..44e3473d023 100644 --- a/odbc/src/environment.cpp +++ b/odbc/src/environment.cpp @@ -8,9 +8,23 @@ TEnvironment::TEnvironment() : OdbcVersion_(SQL_OV_ODBC3) {} TEnvironment::~TEnvironment() {} SQLRETURN TEnvironment::SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { - // TODO: реализовать обработку атрибутов - OdbcVersion_ = attribute == SQL_ATTR_ODBC_VERSION ? reinterpret_cast(value) : 0; - return SQL_SUCCESS; + switch (attribute) { + case SQL_ATTR_ODBC_VERSION: { + if (!value) { + return AddError("HY009", 0, "Invalid use of null pointer"); + } + OdbcVersion_ = static_cast(reinterpret_cast(value)); + return SQL_SUCCESS; + } + case SQL_ATTR_OUTPUT_NTS: { + if (value && static_cast(reinterpret_cast(value)) != SQL_TRUE) { + return AddError("HY024", 0, "SQL_ATTR_OUTPUT_NTS must be SQL_TRUE"); + } + return SQL_SUCCESS; + } + default: + return AddError("HYC00", 0, "Optional feature not implemented"); + } } void TEnvironment::RegisterConnection(TConnection* conn){ diff --git a/odbc/src/odbc_driver.cpp b/odbc/src/odbc_driver.cpp index 2993c76fcef..6b516c63bb8 100644 --- a/odbc/src/odbc_driver.cpp +++ b/odbc/src/odbc_driver.cpp @@ -299,17 +299,14 @@ SQLRETURN SQL_API SQLEndTran(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLI SQLRETURN SQL_API SQLSetConnectAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { - if (attribute == SQL_ATTR_AUTOCOMMIT) { - if ((intptr_t)value == SQL_AUTOCOMMIT_ON) { - return conn->SetAutocommit(true); - } else if ((intptr_t)value == SQL_AUTOCOMMIT_OFF) { - return conn->SetAutocommit(false); - } else { - throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid autocommit value"); - } - } - // TODO: other attributes - throw NYdb::NOdbc::TOdbcException("HYC00", 0, "Optional feature not implemented"); + return conn->SetConnectAttr(attribute, value, stringLength); + }); +} + +SQLRETURN SQL_API SQLGetConnectAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->GetConnectAttr(attribute, value, bufferLength, stringLengthPtr); }); } @@ -373,6 +370,7 @@ SQLRETURN SQL_API SQLFetchScroll(SQLHSTMT statementHandle, SQLSMALLINT fetchOrie } else { throw NYdb::NOdbc::TOdbcException("HYC00", 0, "Only SQL_FETCH_NEXT is supported"); } + //TODO other fetch-orientation }); } diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp index b7fddc624bf..6f714d0b0bf 100644 --- a/odbc/src/statement.cpp +++ b/odbc/src/statement.cpp @@ -36,7 +36,7 @@ SQLRETURN TStatement::Execute() { if (Conn_->GetAutocommit()){ Conn_->Reset(); } - + auto& session = Conn_->GetOrCreateQuerySession(); auto iterator = CreateExecuteIterator(session, params); @@ -55,13 +55,20 @@ SQLRETURN TStatement::Execute() { NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params){ if (Conn_->GetAutocommit()) { + const auto txSettings = Conn_->MakeTxSettings(); + if (txSettings.GetMode() == NQuery::TTxSettings::TS_SERIALIZABLE_RW) { + return session.StreamExecuteQuery( + PreparedQuery_, + NQuery::TTxControl::NoTx(), + params).ExtractValueSync(); + } return session.StreamExecuteQuery( PreparedQuery_, - NQuery::TTxControl::NoTx(), + NQuery::TTxControl::BeginTx(txSettings).CommitTx(), params).ExtractValueSync(); } if (!Conn_->GetTx()) { - auto beginTxResult = session.BeginTransaction(NQuery::TTxSettings::SerializableRW()).ExtractValueSync(); + auto beginTxResult = session.BeginTransaction(Conn_->MakeTxSettings()).ExtractValueSync(); NStatusHelpers::ThrowOnError(beginTxResult); Conn_->SetTx(beginTxResult.GetTransaction()); } From 8a48e2d794bf5bd126347be0af49b79c2453b598 Mon Sep 17 00:00:00 2001 From: Ylonies Date: Thu, 9 Apr 2026 10:30:42 +0000 Subject: [PATCH 4/9] fix --- odbc/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/odbc/CMakeLists.txt b/odbc/CMakeLists.txt index 9919870702d..0390b20698c 100644 --- a/odbc/CMakeLists.txt +++ b/odbc/CMakeLists.txt @@ -5,6 +5,7 @@ add_library(ydb-odbc SHARED src/utils/convert.cpp src/utils/error_manager.cpp src/odbc_driver.cpp + src/connection_attributes.cpp src/connection.cpp src/statement.cpp src/environment.cpp From 97a83f701fec2d068a163bda1b433ea9e0f7933a Mon Sep 17 00:00:00 2001 From: Ylonies Date: Fri, 17 Apr 2026 13:37:51 +0300 Subject: [PATCH 5/9] driver pool --- odbc/src/connection.cpp | 91 +++++++++++++++++++++++++++++++++++++++++ odbc/src/connection.h | 2 +- 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp index b1049163a5d..908edd962f0 100644 --- a/odbc/src/connection.cpp +++ b/odbc/src/connection.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -13,6 +14,60 @@ namespace NYdb { namespace NOdbc { +namespace { + +struct TDriverKey { + std::string Endpoint; + std::string Database; + + bool operator==(const TDriverKey& other) const noexcept { + return Endpoint == other.Endpoint && Database == other.Database; + } +}; + +struct TDriverKeyHash { + size_t operator()(const TDriverKey& key) const noexcept { + return std::hash{}(key.Endpoint) ^ (std::hash{}(key.Database) << 1U); + } +}; + +struct TDriverPool { + std::unordered_map, TDriverKeyHash> DriversByKey; + size_t InsertionsSinceCleanup = 0; +}; + +void CleanupExpiredDrivers(TDriverPool& pool) { + for (auto mapIt = pool.DriversByKey.begin(); mapIt != pool.DriversByKey.end();) { + if (mapIt->second.expired()) { + mapIt = pool.DriversByKey.erase(mapIt); + } else { + ++mapIt; + } + } +} + +std::shared_ptr AcquireSharedDriver(const std::string& endpoint, const std::string& database) { + static TDriverPool pool; + TDriverKey key{endpoint, database}; + auto it = pool.DriversByKey.find(key); + if (it != pool.DriversByKey.end()) { + if (std::shared_ptr existing = it->second.lock()) { + return existing; + } + } + auto driver = std::make_shared( + NYdb::TDriverConfig().SetEndpoint(endpoint).SetDatabase(database)); + pool.DriversByKey[std::move(key)] = driver; + ++pool.InsertionsSinceCleanup; + if (pool.InsertionsSinceCleanup >= 32) { + CleanupExpiredDrivers(pool); + pool.InsertionsSinceCleanup = 0; + } + return driver; +} + +} // namespace + SQLRETURN TConnection::DriverConnect(const std::string& connectionString) { std::map params; size_t pos = 0; @@ -171,5 +226,41 @@ TEnvironment* TConnection::GetEnvironment(){ return ParentEnv_; } +void TConnection::RecreateYdbClients() { + QuerySession_.reset(); + Tx_.reset(); + YdbSchemeClient_.reset(); + YdbTableClient_.reset(); + YdbClient_.reset(); + YdbDriver_ = AcquireSharedDriver(Endpoint_, Database_); + YdbClient_ = std::make_unique(*YdbDriver_); + YdbSchemeClient_ = std::make_unique(*YdbDriver_); + YdbTableClient_ = std::make_unique(*YdbDriver_); +} + +void TConnection::RebindToDatabase(const std::string& newDatabase) { + std::string db = newDatabase; + TConnectionAttributes::NormalizeCatalogPath(db); + Database_ = std::move(db); + Attributes_.SetCurrentCatalog(Database_); + RecreateYdbClients(); +} + + +std::string TConnection::WrapQueryForCurrentCatalog(const std::string& sql) const { + std::optional rel = Attributes_.ResolveCatalogRoute(Database_).TablePathPrefix; + if (!rel) { + return sql; + } + std::string escapedPrefix; + escapedPrefix.reserve(rel->size() + 8); + for (const char ch : *rel) { + if (ch == '\\' || ch == '"') { + escapedPrefix.push_back('\\'); + } + escapedPrefix.push_back(ch); + } + return "PRAGMA TablePathPrefix = \"" + escapedPrefix + "\";\n" + sql; +} } // namespace NOdbc } // namespace NYdb diff --git a/odbc/src/connection.h b/odbc/src/connection.h index e1c9028fa84..71bb6664454 100644 --- a/odbc/src/connection.h +++ b/odbc/src/connection.h @@ -23,7 +23,7 @@ class TStatement; class TConnection : public TErrorManager { private: - std::unique_ptr YdbDriver_; + std::shared_ptr YdbDriver_; std::unique_ptr YdbClient_; std::unique_ptr YdbTableClient_; std::unique_ptr YdbSchemeClient_; From 42111889ebfc3cbee8b46556dbabd4611b59662c Mon Sep 17 00:00:00 2001 From: Ylonies Date: Fri, 17 Apr 2026 13:55:29 +0300 Subject: [PATCH 6/9] conn attributes Combine connection attribute routing and error-localization updates into one focused commit while keeping driver pool changes separate. Made-with: Cursor --- odbc/CMakeLists.txt | 3 +- odbc/src/connection.cpp | 31 +-- odbc/src/connection.h | 10 +- odbc/src/connection_attr.cpp | 308 ++++++++++++++++++++++++++ odbc/src/connection_attr.h | 86 +++++++ odbc/src/connection_attributes.cpp | 141 ------------ odbc/src/connection_attributes.h | 48 ---- odbc/src/statement.cpp | 7 +- odbc/src/utils/attr.cpp | 51 +++++ odbc/src/utils/attr.h | 46 ++++ odbc/src/utils/diag.h | 33 +++ odbc/tests/integration/CMakeLists.txt | 5 + odbc/tests/integration/attr_it.cpp | 244 ++++++++++++++++++++ 13 files changed, 804 insertions(+), 209 deletions(-) create mode 100644 odbc/src/connection_attr.cpp create mode 100644 odbc/src/connection_attr.h delete mode 100644 odbc/src/connection_attributes.cpp delete mode 100644 odbc/src/connection_attributes.h create mode 100644 odbc/src/utils/attr.cpp create mode 100644 odbc/src/utils/attr.h create mode 100644 odbc/src/utils/diag.h create mode 100644 odbc/tests/integration/attr_it.cpp diff --git a/odbc/CMakeLists.txt b/odbc/CMakeLists.txt index 0390b20698c..5071c42f85d 100644 --- a/odbc/CMakeLists.txt +++ b/odbc/CMakeLists.txt @@ -1,11 +1,12 @@ add_library(ydb-odbc SHARED + src/utils/attr.cpp src/utils/cursor.cpp src/utils/types.cpp src/utils/util.cpp src/utils/convert.cpp src/utils/error_manager.cpp src/odbc_driver.cpp - src/connection_attributes.cpp + src/connection_attr.cpp src/connection.cpp src/statement.cpp src/environment.cpp diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp index 908edd962f0..a52d5036f04 100644 --- a/odbc/src/connection.cpp +++ b/odbc/src/connection.cpp @@ -93,13 +93,9 @@ SQLRETURN TConnection::DriverConnect(const std::string& connectionString) { throw TOdbcException("08001", 0, "Missing Endpoint or Database in connection string"); } - YdbDriver_ = std::make_unique(NYdb::TDriverConfig() - .SetEndpoint(Endpoint_) - .SetDatabase(Database_)); - - YdbClient_ = std::make_unique(*YdbDriver_); - YdbSchemeClient_ = std::make_unique(*YdbDriver_); - YdbTableClient_ = std::make_unique(*YdbDriver_); + TConnectionAttributes::NormalizeCatalogPath(Database_); + RecreateYdbClients(); + Attributes_.SetCurrentCatalog(Database_); return SQL_SUCCESS; } @@ -121,13 +117,9 @@ SQLRETURN TConnection::Connect(const std::string& serverName, throw TOdbcException("08001", 0, "Missing Endpoint or Database in DSN"); } - YdbDriver_ = std::make_unique(NYdb::TDriverConfig() - .SetEndpoint(Endpoint_) - .SetDatabase(Database_)); - - YdbClient_ = std::make_unique(*YdbDriver_); - YdbSchemeClient_ = std::make_unique(*YdbDriver_); - YdbTableClient_ = std::make_unique(*YdbDriver_); + TConnectionAttributes::NormalizeCatalogPath(Database_); + RecreateYdbClients(); + Attributes_.SetCurrentCatalog(Database_); return SQL_SUCCESS; } @@ -175,6 +167,17 @@ bool TConnection::GetAutocommit() const { } SQLRETURN TConnection::SetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength) { + if (attr == SQL_ATTR_CURRENT_CATALOG) { + std::optional rebindDatabase; + SQLRETURN rc = Attributes_.ApplyCatalogChange(value, stringLength, Database_, rebindDatabase, *this); + if (rc != SQL_SUCCESS) { + return rc; + } + if (rebindDatabase) { + RebindToDatabase(*rebindDatabase); + } + return SQL_SUCCESS; + } return Attributes_.SetConnectAttr(attr, value, stringLength, [this](bool autocommit) { return SetAutocommit(autocommit); }, *this); diff --git a/odbc/src/connection.h b/odbc/src/connection.h index 71bb6664454..284ac36cf65 100644 --- a/odbc/src/connection.h +++ b/odbc/src/connection.h @@ -1,7 +1,7 @@ #pragma once #include "environment.h" -#include "connection_attributes.h" +#include "connection_attr.h" #include "utils/error_manager.h" #include @@ -13,8 +13,9 @@ #include #include -#include +#include #include +#include namespace NYdb { namespace NOdbc { @@ -37,6 +38,9 @@ class TConnection : public TErrorManager { TEnvironment* ParentEnv_; TConnectionAttributes Attributes_; + + void RecreateYdbClients(); + void RebindToDatabase(const std::string& newDatabase); public: SQLRETURN Connect(const std::string& serverName, const std::string& userName, @@ -60,6 +64,8 @@ class TConnection : public TErrorManager { SQLRETURN GetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); NQuery::TTxSettings MakeTxSettings() const; + std::string WrapQueryForCurrentCatalog(const std::string& sql) const; + const std::optional& GetTx(); void SetTx(const NQuery::TTransaction& tx); void Reset(); diff --git a/odbc/src/connection_attr.cpp b/odbc/src/connection_attr.cpp new file mode 100644 index 00000000000..6197a4ad4a9 --- /dev/null +++ b/odbc/src/connection_attr.cpp @@ -0,0 +1,308 @@ + +#include "connection_attr.h" +#include "utils/attr.h" +#include "utils/diag.h" + +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + +namespace Catalog { + +void NormalizePath(std::string& path) { + if (path.empty() || path == "/") { + return; + } + const size_t trailingSlashStart = path.find_last_not_of('/'); + if (trailingSlashStart == std::string::npos) { + path.assign("/"); + return; + } + path.erase(trailingSlashStart + 1); +} + +TConnectionAttributes::TCatalogBinding BuildBinding(const std::string& currentCatalog, const std::string& database) { + TConnectionAttributes::TCatalogBinding binding; + binding.Catalog = currentCatalog; + binding.Database = database; + NormalizePath(binding.Catalog); + NormalizePath(binding.Database); + if (binding.Catalog == binding.Database) { + return binding; + } + + const std::string databasePrefix = binding.Database + "/"; + if (binding.Catalog.size() <= databasePrefix.size() || + binding.Catalog.compare(0, databasePrefix.size(), databasePrefix) != 0) { + return binding; + } + + std::string relativeCatalog = binding.Catalog.substr(databasePrefix.size()); + if (!relativeCatalog.empty()) { + binding.RelativeCatalog = std::move(relativeCatalog); + } + return binding; +} + +} // namespace Catalog + +namespace Tx { + +bool IsKnownTxnIsolation(SQLUINTEGER txnIsolation) { + switch (txnIsolation) { + case SQL_TXN_READ_UNCOMMITTED: + case SQL_TXN_READ_COMMITTED: + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return true; + default: + return false; + } +} + +std::optional ResolveTxMode(SQLUINTEGER accessMode, SQLUINTEGER txnIsolation) { + if (accessMode == SQL_MODE_READ_ONLY) { + switch (txnIsolation) { + case SQL_TXN_READ_UNCOMMITTED: + return NQuery::TTxSettings::TS_STALE_RO; + case SQL_TXN_READ_COMMITTED: + return NQuery::TTxSettings::TS_ONLINE_RO; + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return NQuery::TTxSettings::TS_SNAPSHOT_RO; + default: + return std::nullopt; + } + } + + switch (txnIsolation) { + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return NQuery::TTxSettings::TS_SERIALIZABLE_RW; + default: + return std::nullopt; + } +} + +} // namespace Tx + +namespace Autocommit { + +SQLRETURN Get(bool autocommitEnabled, SQLPOINTER value) { + auto* out = reinterpret_cast(value); + *out = autocommitEnabled ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; + return SQL_SUCCESS; +} + +} // namespace Autocommit + +} + +void TConnectionAttributes::NormalizeCatalogPath(std::string& path) { + Catalog::NormalizePath(path); +} + +SQLRETURN TConnectionAttributes::SetAutocommit(bool value) { + Autocommit_ = value; + return SQL_SUCCESS; +} + +bool TConnectionAttributes::GetAutocommit() const { + return Autocommit_; +} + +SQLRETURN TConnectionAttributes::SetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + const std::function& applyAutocommit, + TErrorManager& errors) { + switch (attr) { + case SQL_ATTR_AUTOCOMMIT: + return SetAutocommit(value, applyAutocommit, errors); + case SQL_ATTR_ACCESS_MODE: + return SetAccessMode(value, errors); + case SQL_ATTR_TXN_ISOLATION: + return SetTxnIsolation(value, errors); + case SQL_ATTR_CURRENT_CATALOG: + return SetCurrentCatalog(value, stringLength, errors); + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TConnectionAttributes::GetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + if (!value) { + return Diag::AddNullPointer(errors); + } + if (stringLengthPtr) { + *stringLengthPtr = 0; + } + switch (attr) { + case SQL_ATTR_AUTOCOMMIT: + return GetAutocommit(value); + case SQL_ATTR_ACCESS_MODE: + return GetAccessMode(value); + case SQL_ATTR_TXN_ISOLATION: + return GetTxnIsolation(value); + case SQL_ATTR_CURRENT_CATALOG: + return GetCurrentCatalog(value, bufferLength, stringLengthPtr, errors); + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TConnectionAttributes::SetAutocommit( + SQLPOINTER value, + const std::function& applyAutocommit, + TErrorManager& errors) { + const auto token = ReadIntegerAttrIfIn( + value, + {static_cast(SQL_AUTOCOMMIT_ON), static_cast(SQL_AUTOCOMMIT_OFF)}); + if (!token) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_AUTOCOMMIT"); + } + if (*token == static_cast(SQL_AUTOCOMMIT_ON)) { + return applyAutocommit(true); + } + return applyAutocommit(false); +} + +SQLRETURN TConnectionAttributes::SetAccessMode(SQLPOINTER value, TErrorManager& errors) { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_MODE_READ_WRITE, SQL_MODE_READ_ONLY}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_ACCESS_MODE"); + } + AccessMode_ = *mode; + auto txMode = Tx::ResolveTxMode(AccessMode_, TxnIsolation_); + if (!txMode) { + return errors.AddError( + "HYC00", + 0, + AccessMode_ == SQL_MODE_READ_WRITE + ? "Transaction isolation is not supported for read-write mode" + : "Transaction isolation is not supported for read-only mode"); + } + TxMode_ = *txMode; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::SetTxnIsolation(SQLPOINTER value, TErrorManager& errors) { + const SQLUINTEGER isolation = ReadIntegerAttr(value); + if (!Tx::IsKnownTxnIsolation(isolation)) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_TXN_ISOLATION"); + } + auto txMode = Tx::ResolveTxMode(AccessMode_, isolation); + if (!txMode) { + return errors.AddError("HYC00", 0, "SQL_ATTR_TXN_ISOLATION value is not supported"); + } + TxnIsolation_ = isolation; + TxMode_ = *txMode; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::SetCurrentCatalog(SQLPOINTER value, SQLINTEGER stringLength, TErrorManager& errors) { + if (!value) { + return Diag::AddNullPointer(errors); + } + CurrentCatalog_ = ReadAttributeString(value, stringLength); + Catalog::NormalizePath(CurrentCatalog_); + if (CurrentCatalog_.empty()) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_CURRENT_CATALOG"); + } + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::GetAutocommit(SQLPOINTER value) const { + return Autocommit::Get(Autocommit_, value); +} + +SQLRETURN TConnectionAttributes::GetAccessMode(SQLPOINTER value) const { + auto* out = reinterpret_cast(value); + *out = AccessMode_; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::GetTxnIsolation(SQLPOINTER value) const { + auto* out = reinterpret_cast(value); + *out = TxnIsolation_; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::GetCurrentCatalog( + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + return WriteAttributeString(CurrentCatalog_, value, bufferLength, stringLengthPtr, errors); +} + +NQuery::TTxSettings TConnectionAttributes::MakeTxSettings() const { + switch (TxMode_) { + case NQuery::TTxSettings::TS_ONLINE_RO: + return NQuery::TTxSettings::OnlineRO(); + case NQuery::TTxSettings::TS_STALE_RO: + return NQuery::TTxSettings::StaleRO(); + case NQuery::TTxSettings::TS_SNAPSHOT_RO: + return NQuery::TTxSettings::SnapshotRO(); + case NQuery::TTxSettings::TS_SNAPSHOT_RW: + return NQuery::TTxSettings::SnapshotRW(); + case NQuery::TTxSettings::TS_SERIALIZABLE_RW: + default: + return NQuery::TTxSettings::SerializableRW(); + } +} + +void TConnectionAttributes::SetCurrentCatalog(const std::string& value) { + CurrentCatalog_ = value; + Catalog::NormalizePath(CurrentCatalog_); +} + +const std::string& TConnectionAttributes::GetCurrentCatalog() const { + return CurrentCatalog_; +} + +TConnectionAttributes::TCatalogBinding TConnectionAttributes::BuildCatalogBinding(const std::string& database) const { + return Catalog::BuildBinding(CurrentCatalog_, database); +} + +TConnectionAttributes::TCatalogRoute TConnectionAttributes::ResolveCatalogRoute(const std::string& currentDatabase) const { + const TCatalogBinding binding = BuildCatalogBinding(currentDatabase); + if (binding.Catalog == binding.Database) { + return {binding.Database, std::nullopt}; + } + if (binding.RelativeCatalog) { + return {binding.Database, binding.Catalog}; + } + return {binding.Catalog, std::nullopt}; +} + +SQLRETURN TConnectionAttributes::ApplyCatalogChange( + SQLPOINTER value, + SQLINTEGER stringLength, + const std::string& currentDatabase, + std::optional& rebindDatabase, + TErrorManager& errors) { + SQLRETURN rc = SetCurrentCatalog(value, stringLength, errors); + if (rc != SQL_SUCCESS) { + return rc; + } + const TCatalogRoute route = ResolveCatalogRoute(currentDatabase); + if (route.EffectiveDatabase != currentDatabase) { + rebindDatabase = route.EffectiveDatabase; + } else { + rebindDatabase.reset(); + } + return SQL_SUCCESS; +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection_attr.h b/odbc/src/connection_attr.h new file mode 100644 index 00000000000..607f7bf929f --- /dev/null +++ b/odbc/src/connection_attr.h @@ -0,0 +1,86 @@ +#pragma once + +#include "utils/error_manager.h" + +#include + +#include +#include +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TConnectionAttributes { +public: + struct TCatalogBinding { + std::string Catalog; + std::string Database; + std::optional RelativeCatalog; + }; + + struct TCatalogRoute { + std::string EffectiveDatabase; + std::optional TablePathPrefix; + }; + + SQLRETURN SetAutocommit(bool value); + bool GetAutocommit() const; + + SQLRETURN SetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + const std::function& applyAutocommit, + TErrorManager& errors); + + SQLRETURN GetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + NQuery::TTxSettings MakeTxSettings() const; + void SetCurrentCatalog(const std::string& value); + const std::string& GetCurrentCatalog() const; + TCatalogBinding BuildCatalogBinding(const std::string& database) const; + TCatalogRoute ResolveCatalogRoute(const std::string& currentDatabase) const; + SQLRETURN ApplyCatalogChange( + SQLPOINTER value, + SQLINTEGER stringLength, + const std::string& currentDatabase, + std::optional& rebindDatabase, + TErrorManager& errors); + static void NormalizeCatalogPath(std::string& path); + +private: + SQLRETURN SetAutocommit( + SQLPOINTER value, + const std::function& applyAutocommit, + TErrorManager& errors); + SQLRETURN SetAccessMode(SQLPOINTER value, TErrorManager& errors); + SQLRETURN SetTxnIsolation(SQLPOINTER value, TErrorManager& errors); + SQLRETURN SetCurrentCatalog(SQLPOINTER value, SQLINTEGER stringLength, TErrorManager& errors); + + SQLRETURN GetAutocommit(SQLPOINTER value) const; + SQLRETURN GetAccessMode(SQLPOINTER value) const; + SQLRETURN GetTxnIsolation(SQLPOINTER value) const; + SQLRETURN GetCurrentCatalog( + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + bool Autocommit_ = true; + std::string CurrentCatalog_; + SQLUINTEGER AccessMode_ = SQL_MODE_READ_WRITE; + SQLUINTEGER TxnIsolation_ = SQL_TXN_SERIALIZABLE; + NQuery::TTxSettings::ETransactionMode TxMode_ = NQuery::TTxSettings::TS_SERIALIZABLE_RW; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection_attributes.cpp b/odbc/src/connection_attributes.cpp deleted file mode 100644 index 61bfa7f8caa..00000000000 --- a/odbc/src/connection_attributes.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "connection_attributes.h" - -#include - -namespace NYdb { -namespace NOdbc { - -std::optional TConnectionAttributes::ResolveTxMode(SQLUINTEGER accessMode, SQLUINTEGER txnIsolation) { - if (accessMode == SQL_MODE_READ_ONLY) { - switch (txnIsolation) { - case SQL_TXN_READ_UNCOMMITTED: - return NQuery::TTxSettings::TS_STALE_RO; - case SQL_TXN_READ_COMMITTED: - return NQuery::TTxSettings::TS_ONLINE_RO; - case SQL_TXN_REPEATABLE_READ: - case SQL_TXN_SERIALIZABLE: - return NQuery::TTxSettings::TS_SNAPSHOT_RO; - default: - return std::nullopt; - } - } - - switch (txnIsolation) { - case SQL_TXN_REPEATABLE_READ: - case SQL_TXN_SERIALIZABLE: - return NQuery::TTxSettings::TS_SERIALIZABLE_RW; - default: - return std::nullopt; - } -} - -SQLRETURN TConnectionAttributes::SetAutocommit(bool value) { - Autocommit_ = value; - return SQL_SUCCESS; -} - -bool TConnectionAttributes::GetAutocommit() const { - return Autocommit_; -} - -SQLRETURN TConnectionAttributes::SetConnectAttr( - SQLINTEGER attr, - SQLPOINTER value, - SQLINTEGER /*stringLength*/, - const std::function& applyAutocommit, - TErrorManager& errors) { - switch (attr) { - case SQL_ATTR_AUTOCOMMIT: { - const intptr_t val = reinterpret_cast(value); - if (val == static_cast(SQL_AUTOCOMMIT_ON)) { - return applyAutocommit(true); - } - if (val == static_cast(SQL_AUTOCOMMIT_OFF)) { - return applyAutocommit(false); - } - return errors.AddError("HY024", 0, "Invalid SQL_ATTR_AUTOCOMMIT value"); - } - case SQL_ATTR_ACCESS_MODE: { - const intptr_t val = reinterpret_cast(value); - if (val == static_cast(SQL_MODE_READ_WRITE)) { - AccessMode_ = SQL_MODE_READ_WRITE; - auto txMode = ResolveTxMode(AccessMode_, TxnIsolation_); - if (!txMode) { - return errors.AddError("HYC00", 0, "Transaction isolation is not supported for read-write mode"); - } - TxMode_ = *txMode; - return SQL_SUCCESS; - } - if (val == static_cast(SQL_MODE_READ_ONLY)) { - AccessMode_ = SQL_MODE_READ_ONLY; - auto txMode = ResolveTxMode(AccessMode_, TxnIsolation_); - if (!txMode) { - return errors.AddError("HYC00", 0, "Transaction isolation is not supported for read-only mode"); - } - TxMode_ = *txMode; - return SQL_SUCCESS; - } - return errors.AddError("HY024", 0, "Invalid SQL_ATTR_ACCESS_MODE value"); - } - case SQL_ATTR_TXN_ISOLATION: { - const intptr_t val = reinterpret_cast(value); - const SQLUINTEGER isolation = static_cast(val); - auto txMode = ResolveTxMode(AccessMode_, isolation); - if (!txMode) { - return errors.AddError("HYC00", 0, "SQL_ATTR_TXN_ISOLATION value is not supported"); - } - TxnIsolation_ = isolation; - TxMode_ = *txMode; - return SQL_SUCCESS; - } - default: - return errors.AddError("HYC00", 0, "Optional feature not implemented"); - } -} - -SQLRETURN TConnectionAttributes::GetConnectAttr( - SQLINTEGER attr, - SQLPOINTER value, - SQLINTEGER /*bufferLength*/, - SQLINTEGER* stringLengthPtr, - TErrorManager& errors) const { - if (!value) { - return errors.AddError("HY009", 0, "Invalid use of null pointer"); - } - if (stringLengthPtr) { - *stringLengthPtr = 0; - } - auto* out = reinterpret_cast(value); - switch (attr) { - case SQL_ATTR_AUTOCOMMIT: - *out = GetAutocommit() ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; - return SQL_SUCCESS; - case SQL_ATTR_ACCESS_MODE: - *out = AccessMode_; - return SQL_SUCCESS; - case SQL_ATTR_TXN_ISOLATION: - *out = TxnIsolation_; - return SQL_SUCCESS; - default: - return errors.AddError("HYC00", 0, "Optional feature not implemented"); - } -} - -NQuery::TTxSettings TConnectionAttributes::MakeTxSettings() const { - switch (TxMode_) { - case NQuery::TTxSettings::TS_ONLINE_RO: - return NQuery::TTxSettings::OnlineRO(); - case NQuery::TTxSettings::TS_STALE_RO: - return NQuery::TTxSettings::StaleRO(); - case NQuery::TTxSettings::TS_SNAPSHOT_RO: - return NQuery::TTxSettings::SnapshotRO(); - case NQuery::TTxSettings::TS_SNAPSHOT_RW: - return NQuery::TTxSettings::SnapshotRW(); - case NQuery::TTxSettings::TS_SERIALIZABLE_RW: - default: - return NQuery::TTxSettings::SerializableRW(); - } -} - -} // namespace NOdbc -} // namespace NYdb diff --git a/odbc/src/connection_attributes.h b/odbc/src/connection_attributes.h deleted file mode 100644 index 7b2f0fc7221..00000000000 --- a/odbc/src/connection_attributes.h +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once - -#include "utils/error_manager.h" - -#include - -#include -#include - -#include -#include - -namespace NYdb { -namespace NOdbc { - -class TConnectionAttributes { -public: - SQLRETURN SetAutocommit(bool value); - bool GetAutocommit() const; - - SQLRETURN SetConnectAttr( - SQLINTEGER attr, - SQLPOINTER value, - SQLINTEGER stringLength, - const std::function& applyAutocommit, - TErrorManager& errors); - - SQLRETURN GetConnectAttr( - SQLINTEGER attr, - SQLPOINTER value, - SQLINTEGER bufferLength, - SQLINTEGER* stringLengthPtr, - TErrorManager& errors) const; - - NQuery::TTxSettings MakeTxSettings() const; - -private: - static std::optional ResolveTxMode(SQLUINTEGER accessMode, SQLUINTEGER txnIsolation); - -private: - bool Autocommit_ = true; - SQLUINTEGER AccessMode_ = SQL_MODE_READ_WRITE; - SQLUINTEGER TxnIsolation_ = SQL_TXN_SERIALIZABLE; - NQuery::TTxSettings::ETransactionMode TxMode_ = NQuery::TTxSettings::TS_SERIALIZABLE_RW; -}; - -} // namespace NOdbc -} // namespace NYdb diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp index 6f714d0b0bf..dd657f304ba 100644 --- a/odbc/src/statement.cpp +++ b/odbc/src/statement.cpp @@ -54,16 +54,17 @@ SQLRETURN TStatement::Execute() { } NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params){ + const std::string queryText = Conn_->WrapQueryForCurrentCatalog(PreparedQuery_); if (Conn_->GetAutocommit()) { const auto txSettings = Conn_->MakeTxSettings(); if (txSettings.GetMode() == NQuery::TTxSettings::TS_SERIALIZABLE_RW) { return session.StreamExecuteQuery( - PreparedQuery_, + queryText, NQuery::TTxControl::NoTx(), params).ExtractValueSync(); } return session.StreamExecuteQuery( - PreparedQuery_, + queryText, NQuery::TTxControl::BeginTx(txSettings).CommitTx(), params).ExtractValueSync(); } @@ -73,7 +74,7 @@ NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession Conn_->SetTx(beginTxResult.GetTransaction()); } return session.StreamExecuteQuery( - PreparedQuery_, + queryText, NQuery::TTxControl::Tx(*Conn_->GetTx()).CommitTx(false), params).ExtractValueSync(); } diff --git a/odbc/src/utils/attr.cpp b/odbc/src/utils/attr.cpp new file mode 100644 index 00000000000..1fb2a83324a --- /dev/null +++ b/odbc/src/utils/attr.cpp @@ -0,0 +1,51 @@ +#include "attr.h" +#include "diag.h" + +#include +#include + +namespace NYdb::NOdbc { + +std::string ReadAttributeString(SQLPOINTER value, SQLINTEGER stringLength) { + const char* const str = static_cast(value); + if (stringLength == SQL_NTS) { + return std::string(str); + } + if (stringLength < 0) { + return {}; + } + return std::string(str, static_cast(stringLength)); +} + +SQLRETURN WriteAttributeString( + const std::string& source, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) { + const SQLINTEGER length = static_cast(source.size()); + if (stringLengthPtr != nullptr) { + *stringLengthPtr = length; + } + if (value == nullptr) { + return SQL_SUCCESS; + } + if (bufferLength <= 0) { + return Diag::AddInvalidBufferLength(errors); + } + + auto* dest = static_cast(value); + const size_t maxData = static_cast(bufferLength - 1); + const size_t nCopy = std::min(source.size(), maxData); + if (nCopy > 0) { + std::memcpy(dest, source.data(), nCopy); + } + dest[nCopy] = 0; + + if (length >= bufferLength) { + return Diag::AddRightTruncated(errors); + } + return SQL_SUCCESS; +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/attr.h b/odbc/src/utils/attr.h new file mode 100644 index 00000000000..34abeed42da --- /dev/null +++ b/odbc/src/utils/attr.h @@ -0,0 +1,46 @@ +#pragma once + +#include "error_manager.h" + +#include +#include +#include + +#include +#include + +namespace NYdb::NOdbc { + +std::string ReadAttributeString(SQLPOINTER value, SQLINTEGER stringLength); + +SQLRETURN WriteAttributeString( + const std::string& source, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors); + +template +T ReadIntegerAttr(SQLPOINTER value) noexcept; + +template +std::optional ReadIntegerAttrIfIn(SQLPOINTER value, std::initializer_list allowed) noexcept; + +template +T ReadIntegerAttr(SQLPOINTER value) noexcept { + return static_cast(reinterpret_cast(value)); +} + +template +std::optional ReadIntegerAttrIfIn(SQLPOINTER value, std::initializer_list allowed) noexcept { + const T token = ReadIntegerAttr(value); + for (const T allowedValue : allowed) { + if (token == allowedValue) { + return token; + } + } + return std::nullopt; +} + + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/diag.h b/odbc/src/utils/diag.h new file mode 100644 index 00000000000..5e2db740a07 --- /dev/null +++ b/odbc/src/utils/diag.h @@ -0,0 +1,33 @@ +#pragma once + +#include "error_manager.h" + +#include +#include + +namespace NYdb::NOdbc { +namespace Diag { + + inline SQLRETURN AddNullPointer(TErrorManager& errors) { + return errors.AddError("HY009", 0, "Invalid use of null pointer"); + } + + inline SQLRETURN AddNotImplemented(TErrorManager& errors) { + return errors.AddError("HYC00", 0, "Optional feature not implemented"); + } + + inline SQLRETURN AddInvalidAttrValue(TErrorManager& errors, std::string_view attrName) { + return errors.AddError("HY024", 0, "Invalid " + std::string(attrName) + " value"); + } + + inline SQLRETURN AddInvalidBufferLength(TErrorManager& errors) { + return errors.AddError("HY090", 0, "Invalid string or buffer length"); + } + + inline SQLRETURN AddRightTruncated(TErrorManager& errors) { + return errors.AddError("01004", 0, "String data, right truncated", SQL_SUCCESS_WITH_INFO); + } + +} + +} // namespace NYdb::NOdbc::Diag diff --git a/odbc/tests/integration/CMakeLists.txt b/odbc/tests/integration/CMakeLists.txt index 0360679931c..39128437ced 100644 --- a/odbc/tests/integration/CMakeLists.txt +++ b/odbc/tests/integration/CMakeLists.txt @@ -7,3 +7,8 @@ add_odbc_test(NAME odbc-env_it SOURCES env_it.cpp ) + +add_odbc_test(NAME odbc-attr_it + SOURCES + attr_it.cpp +) diff --git a/odbc/tests/integration/attr_it.cpp b/odbc/tests/integration/attr_it.cpp new file mode 100644 index 00000000000..514278f8e67 --- /dev/null +++ b/odbc/tests/integration/attr_it.cpp @@ -0,0 +1,244 @@ +#include "test_utils.h" + +#include +#include + +namespace { + +bool SqlStatePrefix(const std::string& diag, const char* state5) { + return diag.size() >= 5 && std::strncmp(diag.c_str(), state5, 5) == 0; +} + +void AllocEnv(SQLHENV* env) { + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(*env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); +} + +void AllocEnvAndConnect(SQLHENV* env, SQLHDBC* dbc) { + AllocEnv(env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, *env, dbc), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + *dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, *dbc, SQL_HANDLE_DBC); +} + +} // namespace + +TEST(OdbcAttrEnv, OdbcVersionAttr) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, nullptr, 0), SQL_SUCCESS); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrEnv, OutputNtsAttr) { + SQLHENV env; + AllocEnv(&env); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, (void*)SQL_TRUE, 0), SQL_SUCCESS); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, (void*)SQL_FALSE, 0), SQL_SUCCESS); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrConn, AutocommitAttr) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR dropQuery[] = "DROP TABLE IF EXISTS test_attr_autocommit"; + SQLCHAR createQuery[] = + "CREATE TABLE test_attr_autocommit (id Int32, value Int32, PRIMARY KEY (id))"; + SQLCHAR upsertRollbackQuery[] = "UPSERT INTO test_attr_autocommit (id, value) VALUES (1, 100)"; + SQLCHAR upsertCommitQuery[] = "UPSERT INTO test_attr_autocommit (id, value) VALUES (1, 200)"; + SQLCHAR selectQuery[] = "SELECT value FROM test_attr_autocommit WHERE id = 1"; + + CHECK_ODBC_OK(SQLExecDirect(stmt, dropQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsertRollbackQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLEndTran(SQL_HANDLE_DBC, dbc, SQL_ROLLBACK), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, upsertCommitQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLEndTran(SQL_HANDLE_DBC, dbc, SQL_COMMIT), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 200); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0), dbc, SQL_HANDLE_DBC); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrConn, AccessModeAttr) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + + constexpr SQLUINTEGER readWriteMode = SQL_MODE_READ_WRITE; + constexpr SQLUINTEGER readOnlyMode = SQL_MODE_READ_ONLY; + SQLUINTEGER currentMode = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, ¤tMode, sizeof(currentMode), nullptr), SQL_SUCCESS); + ASSERT_EQ(readWriteMode, currentMode); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)readOnlyMode, 0), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, ¤tMode, sizeof(currentMode), nullptr), SQL_SUCCESS); + ASSERT_EQ(readOnlyMode, currentMode); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)readWriteMode, 0), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, ¤tMode, sizeof(currentMode), nullptr), SQL_SUCCESS); + ASSERT_EQ(readWriteMode, currentMode); + + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)9999, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HY024")); + + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLCHAR dropQuery[] = "DROP TABLE IF EXISTS test_attr_read_only"; + SQLCHAR createQuery[] = "CREATE TABLE test_attr_read_only (id Int32, PRIMARY KEY (id))"; + SQLCHAR selectOneQuery[] = "SELECT 1 AS value"; + SQLCHAR upsertQuery[] = "UPSERT INTO test_attr_read_only (id) VALUES (1)"; + CHECK_ODBC_OK(SQLExecDirect(stmt, dropQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)readOnlyMode, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectOneQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLExecDirect(stmt, upsertQuery, SQL_NTS), SQL_ERROR); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrConn, TxnIsolationAttr) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLCHAR selectOneQuery[] = "SELECT 1 AS value"; + + SQLUINTEGER currentIsolation = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, ¤tIsolation, sizeof(currentIsolation), nullptr), SQL_SUCCESS); + ASSERT_EQ(static_cast(SQL_TXN_SERIALIZABLE), currentIsolation); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_REPEATABLE_READ, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectOneQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_READ_COMMITTED, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HYC00")); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, ¤tIsolation, sizeof(currentIsolation), nullptr), SQL_SUCCESS); + ASSERT_EQ(static_cast(SQL_TXN_REPEATABLE_READ), currentIsolation); + + // In read-only mode all four standard levels are accepted and remain executable. + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)SQL_MODE_READ_ONLY, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_READ_UNCOMMITTED, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectOneQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_READ_COMMITTED, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_REPEATABLE_READ, 0), dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_SERIALIZABLE, 0), dbc, SQL_HANDLE_DBC); + + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)9999, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HY024")); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcAttrConn, CurrentCatalogAttr) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + constexpr const char* dbRoot = "/local"; + const std::string catA = std::string(dbRoot) + "/odbc_cat_a"; + const std::string catB = std::string(dbRoot) + "/odbc_cat_b"; + SQLCHAR dropAQuery[] = "DROP TABLE IF EXISTS `odbc_cat_a/probe`"; + SQLCHAR dropBQuery[] = "DROP TABLE IF EXISTS `odbc_cat_b/probe`"; + SQLCHAR createAQuery[] = + "CREATE TABLE `odbc_cat_a/probe` (id Int32, value Int32, PRIMARY KEY (id))"; + SQLCHAR createBQuery[] = + "CREATE TABLE `odbc_cat_b/probe` (id Int32, value Int32, PRIMARY KEY (id))"; + SQLCHAR upsertAQuery[] = "UPSERT INTO `odbc_cat_a/probe` (id, value) VALUES (1, 100)"; + SQLCHAR upsertBQuery[] = "UPSERT INTO `odbc_cat_b/probe` (id, value) VALUES (1, 200)"; + SQLCHAR selectAQuery[] = "SELECT value FROM `odbc_cat_a/probe` WHERE id = 1"; + SQLCHAR selectQuery[] = "SELECT value FROM probe WHERE id = 1"; + + char catalog[256] = {0}; + SQLINTEGER textLen = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &textLen), SQL_SUCCESS); + ASSERT_STREQ(catalog, dbRoot); + + + CHECK_ODBC_OK(SQLExecDirect(stmt, dropAQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, dropBQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createAQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createBQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsertAQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsertBQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectAQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)catA.c_str(), SQL_NTS), dbc, + SQL_HANDLE_DBC); + std::memset(catalog, 0, sizeof(catalog)); + textLen = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &textLen), SQL_SUCCESS); + ASSERT_STREQ(catalog, catA.c_str()); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 100); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + valueInt = 0; + valueInd = 0; + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)catB.c_str(), SQL_NTS), dbc, + SQL_HANDLE_DBC); + std::memset(catalog, 0, sizeof(catalog)); + textLen = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &textLen), SQL_SUCCESS); + ASSERT_STREQ(catalog, catB.c_str()); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 200); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + const std::string catWithSlashes = catB + "///"; + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)catWithSlashes.c_str(), SQL_NTS), dbc, + SQL_HANDLE_DBC); + std::memset(catalog, 0, sizeof(catalog)); + textLen = 0; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &textLen), SQL_SUCCESS); + ASSERT_STREQ(catalog, catB.c_str()); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + From 665d9dc76d9e5d6e504259e4a7d247a8b2fd8b11 Mon Sep 17 00:00:00 2001 From: Ylonies Date: Sat, 18 Apr 2026 17:53:17 +0000 Subject: [PATCH 7/9] stmt attr --- odbc/CMakeLists.txt | 2 + odbc/src/odbc_driver.cpp | 17 + odbc/src/statement.cpp | 67 ++- odbc/src/statement.h | 5 + odbc/src/statement_attr.cpp | 101 ++++ odbc/src/statement_attr.h | 39 ++ odbc/src/utils/convert.cpp | 74 +++ odbc/src/utils/cursor.cpp | 14 +- odbc/src/utils/escape.cpp | 444 ++++++++++++++++++ odbc/src/utils/escape.h | 9 + odbc/src/utils/sql_like.h | 49 ++ odbc/tests/integration/CMakeLists.txt | 5 + odbc/tests/integration/attr_it.cpp | 21 - odbc/tests/integration/env_it.cpp | 9 - odbc/tests/integration/stmt_attr_it.cpp | 334 +++++++++++++ odbc/tests/integration/test_utils.h | 24 +- odbc/tests/unit/CMakeLists.txt | 23 + odbc/tests/unit/escape_ut.cpp | 71 +++ odbc/tests/unit/sql_like_ut.cpp | 28 ++ .../unit/library/operation_id/CMakeLists.txt | 1 + 20 files changed, 1294 insertions(+), 43 deletions(-) create mode 100644 odbc/src/statement_attr.cpp create mode 100644 odbc/src/statement_attr.h create mode 100644 odbc/src/utils/escape.cpp create mode 100644 odbc/src/utils/escape.h create mode 100644 odbc/src/utils/sql_like.h create mode 100644 odbc/tests/integration/stmt_attr_it.cpp create mode 100644 odbc/tests/unit/escape_ut.cpp create mode 100644 odbc/tests/unit/sql_like_ut.cpp diff --git a/odbc/CMakeLists.txt b/odbc/CMakeLists.txt index 5071c42f85d..06386fd31dd 100644 --- a/odbc/CMakeLists.txt +++ b/odbc/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(ydb-odbc SHARED src/utils/attr.cpp + src/utils/escape.cpp src/utils/cursor.cpp src/utils/types.cpp src/utils/util.cpp @@ -8,6 +9,7 @@ add_library(ydb-odbc SHARED src/odbc_driver.cpp src/connection_attr.cpp src/connection.cpp + src/statement_attr.cpp src/statement.cpp src/environment.cpp ) diff --git a/odbc/src/odbc_driver.cpp b/odbc/src/odbc_driver.cpp index 6b516c63bb8..cba323453af 100644 --- a/odbc/src/odbc_driver.cpp +++ b/odbc/src/odbc_driver.cpp @@ -386,4 +386,21 @@ SQLRETURN SQL_API SQLNumResultCols(SQLHSTMT statementHandle, SQLSMALLINT* colCou }); } +SQLRETURN SQL_API SQLSetStmtAttr(SQLHSTMT statementHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->SetStmtAttr(attribute, value, stringLength); + }); +} + +SQLRETURN SQL_API SQLGetStmtAttr( + SQLHSTMT statementHandle, + SQLINTEGER attribute, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->GetStmtAttr(attribute, value, bufferLength, stringLengthPtr); + }); +} + } diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp index dd657f304ba..f4b04ec0be2 100644 --- a/odbc/src/statement.cpp +++ b/odbc/src/statement.cpp @@ -3,10 +3,14 @@ #include "utils/convert.h" #include "utils/types.h" #include "utils/error_manager.h" +#include "utils/escape.h" +#include "utils/sql_like.h" #include #include +#include + namespace NYdb { namespace NOdbc { @@ -15,6 +19,7 @@ TStatement::TStatement(TConnection* conn) SQLRETURN TStatement::Prepare(const std::string& statementText) { StreamFetchError_ = false; + RowsFetched_ = 0; Cursor_.reset(); PreparedQuery_ = statementText; IsPrepared_ = true; @@ -26,6 +31,7 @@ SQLRETURN TStatement::Execute() { throw TOdbcException("HY007", 0, "No prepared statement"); } StreamFetchError_ = false; + RowsFetched_ = 0; Cursor_.reset(); auto* client = Conn_->GetClient(); if (!client) { @@ -54,19 +60,27 @@ SQLRETURN TStatement::Execute() { } NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params){ - const std::string queryText = Conn_->WrapQueryForCurrentCatalog(PreparedQuery_); + const std::string sqlText = Attributes_.GetNoScanMode() == SQL_NOSCAN_ON + ? PreparedQuery_ + : RewriteOdbcEscapes(PreparedQuery_); + const std::string queryText = Conn_->WrapQueryForCurrentCatalog(sqlText); + NQuery::TExecuteQuerySettings execSettings; + const SQLUINTEGER queryTimeoutSec = Attributes_.GetQueryTimeoutSec(); + execSettings.ClientTimeout(TDuration::Seconds(queryTimeoutSec)); if (Conn_->GetAutocommit()) { const auto txSettings = Conn_->MakeTxSettings(); if (txSettings.GetMode() == NQuery::TTxSettings::TS_SERIALIZABLE_RW) { return session.StreamExecuteQuery( queryText, NQuery::TTxControl::NoTx(), - params).ExtractValueSync(); + params, + execSettings).ExtractValueSync(); } return session.StreamExecuteQuery( queryText, NQuery::TTxControl::BeginTx(txSettings).CommitTx(), - params).ExtractValueSync(); + params, + execSettings).ExtractValueSync(); } if (!Conn_->GetTx()) { auto beginTxResult = session.BeginTransaction(Conn_->MakeTxSettings()).ExtractValueSync(); @@ -76,7 +90,8 @@ NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession return session.StreamExecuteQuery( queryText, NQuery::TTxControl::Tx(*Conn_->GetTx()).CommitTx(false), - params).ExtractValueSync(); + params, + execSettings).ExtractValueSync(); } std::optional TStatement::PrefetchFirstResultPart(NQuery::TExecuteQueryIterator& iterator){ @@ -102,10 +117,15 @@ SQLRETURN TStatement::Fetch() { Cursor_.reset(); return SQL_NO_DATA; } + const SQLULEN maxRows = Attributes_.GetMaxRows(); + if (maxRows > 0 && RowsFetched_ >= maxRows) { + return SQL_NO_DATA; + } StreamFetchError_ = false; if (!Cursor_->Fetch()) { return StreamFetchError_ ? SQL_ERROR : SQL_NO_DATA; } + ++RowsFetched_; return SQL_SUCCESS; } @@ -187,6 +207,7 @@ SQLRETURN TStatement::Columns(const std::string& catalogName, const std::string& tableName, const std::string& columnName) { ClearErrors(); + RowsFetched_ = 0; Cursor_.reset(); std::vector columns = { @@ -224,18 +245,24 @@ SQLRETURN TStatement::Columns(const std::string& catalogName, continue; } - auto status = Conn_->GetTableClient()->RetryOperationSync([path = entry.Name, &table, &columnName](NTable::TSession session) -> TStatus { + auto status = Conn_->GetTableClient()->RetryOperationSync([this, path = entry.Name, &table, &columnName](NTable::TSession session) -> TStatus { auto result = session.DescribeTable(path).ExtractValueSync(); NStatusHelpers::ThrowOnError(result); auto columns = result.GetTableDescription().GetTableColumns(); - auto columnIt = std::find_if(columns.begin(), columns.end(), [&columnName](const NTable::TTableColumn& column) { - return column.Name == columnName; + auto columnIt = std::find_if(columns.begin(), columns.end(), [&](const NTable::TTableColumn& column) { + if (Attributes_.GetMetadataId() == SQL_TRUE) { + return column.Name == columnName; + } + if (columnName.empty()) { + return column.Name.empty(); + } + return SqlLikeMatch(column.Name, columnName); }); if (columnIt == columns.end()) { - return TStatus(EStatus::NOT_FOUND, { NYdb::NIssue::TIssue("Column not found") }); + throw TOdbcException("42S22", 0, "Column not found", SQL_ERROR); } auto column = *columnIt; @@ -277,6 +304,7 @@ SQLRETURN TStatement::Tables(const std::string& catalogName, const std::string& tableName, const std::string& tableType) { ClearErrors(); + RowsFetched_ = 0; Cursor_.reset(); std::vector columns = { @@ -340,7 +368,13 @@ SQLRETURN TStatement::VisitEntry(const std::string& path, const std::string& pat } bool TStatement::IsPatternMatch(const std::string& path, const std::string& pattern) { - return path.starts_with(pattern); + if (pattern.empty()) { + return true; + } + if (Attributes_.GetMetadataId() == SQL_TRUE) { + return path == pattern; + } + return SqlLikeMatch(path, pattern); } std::optional TStatement::GetTableType(NScheme::ESchemeEntryType type) { @@ -375,9 +409,15 @@ std::optional TStatement::GetTableType(NScheme::ESchemeEntryType ty return "COORDINATION_NODE"; case NScheme::ESchemeEntryType::Unknown: return "UNKNOWN"; + case NScheme::ESchemeEntryType::SysView: + return "SYSTEM VIEW"; + case NScheme::ESchemeEntryType::Transfer: + return "TRANSFER"; case NScheme::ESchemeEntryType::Directory: case NScheme::ESchemeEntryType::SubDomain: return std::nullopt; + default: + return std::nullopt; } } @@ -387,6 +427,7 @@ SQLRETURN TStatement::Close(bool force) { } Cursor_.reset(); + RowsFetched_ = 0; PreparedQuery_.clear(); IsPrepared_ = false; ClearErrors(); @@ -422,5 +463,13 @@ SQLRETURN TStatement::NumResultCols(SQLSMALLINT* colCount) { return SQL_SUCCESS; } +SQLRETURN TStatement::SetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength) { + return Attributes_.SetStmtAttr(attr, value, stringLength, *this); +} + +SQLRETURN TStatement::GetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr) { + return Attributes_.GetStmtAttr(attr, value, bufferLength, stringLengthPtr, *this); +} + } // namespace NOdbc } // namespace NYdb diff --git a/odbc/src/statement.h b/odbc/src/statement.h index f17780957bb..702fe56c71e 100644 --- a/odbc/src/statement.h +++ b/odbc/src/statement.h @@ -1,6 +1,7 @@ #pragma once #include "connection.h" +#include "statement_attr.h" #include "utils/error_manager.h" #include "utils/bindings.h" #include "utils/cursor.h" @@ -51,6 +52,8 @@ class TStatement : public TErrorManager, public IBindingFiller { SQLRETURN RowCount(SQLLEN* rowCount); SQLRETURN NumResultCols(SQLSMALLINT* colCount); + SQLRETURN SetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength); + SQLRETURN GetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); TConnection* GetConnection() { return Conn_; @@ -65,6 +68,8 @@ class TStatement : public TErrorManager, public IBindingFiller { std::vector BoundColumns_; std::vector BoundParams_; bool StreamFetchError_ = false; + SQLULEN RowsFetched_ = 0; + TStatementAttributes Attributes_; NYdb::TParams BuildParams(); diff --git a/odbc/src/statement_attr.cpp b/odbc/src/statement_attr.cpp new file mode 100644 index 00000000000..f0baad0016a --- /dev/null +++ b/odbc/src/statement_attr.cpp @@ -0,0 +1,101 @@ +#include "statement_attr.h" + +#include "utils/attr.h" +#include "utils/diag.h" + +#include + +namespace NYdb { +namespace NOdbc { + +SQLRETURN TStatementAttributes::SetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER /*stringLength*/, + TErrorManager& errors) { + switch (attr) { + case SQL_ATTR_QUERY_TIMEOUT: { + const SQLINTEGER timeout = ReadIntegerAttr(value); + if (timeout < 0) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_QUERY_TIMEOUT"); + } + QueryTimeoutSec_ = static_cast(timeout); + return SQL_SUCCESS; + } + case SQL_ATTR_MAX_ROWS: { + const SQLLEN maxRows = ReadIntegerAttr(value); + if (maxRows < 0) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_MAX_ROWS"); + } + MaxRows_ = static_cast(maxRows); + return SQL_SUCCESS; + } + case SQL_ATTR_NOSCAN: { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_NOSCAN_OFF, SQL_NOSCAN_ON}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_NOSCAN"); + } + NoScan_ = *mode; + return SQL_SUCCESS; + } + case SQL_ATTR_METADATA_ID: { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_FALSE, SQL_TRUE}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_METADATA_ID"); + } + MetadataId_ = *mode; + return SQL_SUCCESS; + } + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TStatementAttributes::GetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER /*bufferLength*/, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + if (!value) { + return Diag::AddNullPointer(errors); + } + if (stringLengthPtr) { + *stringLengthPtr = 0; + } + switch (attr) { + case SQL_ATTR_QUERY_TIMEOUT: + *reinterpret_cast(value) = QueryTimeoutSec_; + return SQL_SUCCESS; + case SQL_ATTR_MAX_ROWS: + *reinterpret_cast(value) = MaxRows_; + return SQL_SUCCESS; + case SQL_ATTR_NOSCAN: + *reinterpret_cast(value) = NoScan_; + return SQL_SUCCESS; + case SQL_ATTR_METADATA_ID: + *reinterpret_cast(value) = MetadataId_; + return SQL_SUCCESS; + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLUINTEGER TStatementAttributes::GetQueryTimeoutSec() const noexcept{ + return QueryTimeoutSec_; +} + +SQLULEN TStatementAttributes::GetMaxRows() const noexcept { + return MaxRows_; +} + +SQLULEN TStatementAttributes::GetNoScanMode() const noexcept { + return NoScan_; +} + +SQLULEN TStatementAttributes::GetMetadataId() const noexcept { + return MetadataId_; +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/statement_attr.h b/odbc/src/statement_attr.h new file mode 100644 index 00000000000..b0d6e9bd97f --- /dev/null +++ b/odbc/src/statement_attr.h @@ -0,0 +1,39 @@ +#pragma once + +#include "utils/error_manager.h" + +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TStatementAttributes { +public: + SQLRETURN SetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + TErrorManager& errors); + + SQLRETURN GetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + SQLUINTEGER GetQueryTimeoutSec() const noexcept; + SQLULEN GetMaxRows() const noexcept; + SQLULEN GetNoScanMode() const noexcept; + SQLULEN GetMetadataId() const noexcept; + +private: + SQLUINTEGER QueryTimeoutSec_ = 0; + SQLULEN MaxRows_ = 0; + SQLULEN NoScan_ = SQL_NOSCAN_OFF; + SQLULEN MetadataId_ = SQL_FALSE; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/convert.cpp b/odbc/src/utils/convert.cpp index 224f228e498..4e415c65521 100644 --- a/odbc/src/utils/convert.cpp +++ b/odbc/src/utils/convert.cpp @@ -1,7 +1,10 @@ #include "convert.h" +#include #include +#include + namespace NYdb { namespace NOdbc { @@ -311,6 +314,28 @@ SQLRETURN ConvertColumn(TValueParser& parser, SQLSMALLINT targetType, SQLPOINTER EPrimitiveType ydbType = parser.GetPrimitiveType(); switch (targetType) { + case SQL_C_SHORT: + case SQL_C_SSHORT: + { + SQLSMALLINT v = 0; + switch (ydbType) { + case EPrimitiveType::Int16: v = parser.GetInt16(); break; + case EPrimitiveType::Uint16: v = static_cast(parser.GetUint16()); break; + case EPrimitiveType::Int8: v = static_cast(parser.GetInt8()); break; + case EPrimitiveType::Uint8: v = static_cast(parser.GetUint8()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + case EPrimitiveType::Bool: v = parser.GetBool() ? 1 : 0; break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(SQLSMALLINT); + } + return SQL_SUCCESS; + } case SQL_C_SLONG: case SQL_C_LONG: { @@ -377,6 +402,55 @@ SQLRETURN ConvertColumn(TValueParser& parser, SQLSMALLINT targetType, SQLPOINTER case EPrimitiveType::String: str = parser.GetString(); break; case EPrimitiveType::Json: str = parser.GetJson(); break; case EPrimitiveType::JsonDocument: str = parser.GetJsonDocument(); break; + case EPrimitiveType::Date: { + const TString t = parser.GetDate().FormatGmTime("%Y-%m-%d"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Date32: { + const i32 days = parser.GetDate32(); + if (days < 0) { + return SQL_ERROR; + } + const TString t = TInstant::Days(static_cast(days)).FormatGmTime("%Y-%m-%d"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Datetime: { + const TString t = parser.GetDatetime().FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Datetime64: { + const std::int64_t secs = parser.GetDatetime64(); + if (secs < 0) { + return SQL_ERROR; + } + const TString t = + TInstant::Seconds(static_cast(static_cast(secs))) + .FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Timestamp: { + const TString t = parser.GetTimestamp().FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Timestamp64: { + const std::int64_t micros = parser.GetTimestamp64(); + if (micros < 0) { + return SQL_ERROR; + } + const TString t = + TInstant::MicroSeconds(static_cast(static_cast(micros))) + .FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::TzDate: str = parser.GetTzDate(); break; + case EPrimitiveType::TzDatetime: str = parser.GetTzDatetime(); break; + case EPrimitiveType::TzTimestamp: str = parser.GetTzTimestamp(); break; default: return SQL_ERROR; } SQLLEN len = str.size(); diff --git a/odbc/src/utils/cursor.cpp b/odbc/src/utils/cursor.cpp index efbcea9a419..26ad393b03a 100644 --- a/odbc/src/utils/cursor.cpp +++ b/odbc/src/utils/cursor.cpp @@ -3,6 +3,8 @@ #include "convert.h" #include "types.h" +#include + namespace NYdb { namespace NOdbc { @@ -40,7 +42,17 @@ class TExecCursor : public ICursor { return false; } if (part.HasResultSet()) { - ResultSetParser_ = std::make_unique(part.ExtractResultSet()); + TResultSet rs = part.ExtractResultSet(); + Columns_.clear(); + Columns_.reserve(rs.ColumnsCount()); + for (const auto& col : rs.GetColumnsMeta()) { + Columns_.push_back(TColumnMeta{ + col.Name, + GetTypeId(col.Type), + 0, + IsNullable(col.Type)}); + } + ResultSetParser_ = std::make_unique(rs); } } return false; diff --git a/odbc/src/utils/escape.cpp b/odbc/src/utils/escape.cpp new file mode 100644 index 00000000000..5a9c643eb7a --- /dev/null +++ b/odbc/src/utils/escape.cpp @@ -0,0 +1,444 @@ +#include "escape.h" + +#include +#include +#include +#include +#include + +namespace NYdb::NOdbc { +namespace { + +bool EqualNoCase(std::string_view lhs, std::string_view rhs) { + return lhs.size() == rhs.size() && + std::equal(lhs.begin(), lhs.end(), rhs.begin(), [](char leftCh, char rightCh) { + return std::tolower(static_cast(leftCh)) == + std::tolower(static_cast(rightCh)); + }); +} + +void SkipLeadingWhitespace(std::string_view sql, size_t& cursor) { + const auto strEnd = sql.end(); + const auto firstNonSpace = std::find_if_not( + sql.begin() + static_cast(cursor), + strEnd, + [](unsigned char byte) { + return std::isspace(byte) != 0; + }); + cursor = static_cast(firstNonSpace - sql.begin()); +} + +bool ReadIdent(std::string_view sql, size_t& cursor, std::string_view* outIdent) { + SkipLeadingWhitespace(sql, cursor); + const size_t identStart = cursor; + const auto afterIdent = std::find_if_not( + sql.begin() + static_cast(cursor), + sql.end(), + [](unsigned char byte) { + return std::isalpha(byte) != 0 || byte == '_'; + }); + cursor = static_cast(afterIdent - sql.begin()); + if (cursor == identStart) { + return false; + } + *outIdent = std::string_view(sql.data() + identStart, cursor - identStart); + return true; +} + +bool ParseSingleQuoted(std::string_view sql, size_t& cursor, std::string* outValue) { + SkipLeadingWhitespace(sql, cursor); + if (cursor >= sql.size() || sql[cursor] != '\'') { + return false; + } + ++cursor; + outValue->clear(); + while (cursor < sql.size()) { + if (sql[cursor] == '\'') { + if (cursor + 1 < sql.size() && sql[cursor + 1] == '\'') { + outValue->push_back('\''); + cursor += 2; + continue; + } + ++cursor; + return true; + } + outValue->push_back(sql[cursor++]); + } + return false; +} + +size_t FindMatchingCloseBrace(std::string_view sql, size_t openBrace) { + if (openBrace >= sql.size() || sql[openBrace] != '{') { + return std::string_view::npos; + } + int braceDepth = 1; + for (size_t idx = openBrace + 1; idx < sql.size(); ++idx) { + if (sql[idx] == '{') { + ++braceDepth; + } else if (sql[idx] == '}') { + --braceDepth; + if (braceDepth == 0) { + return idx; + } + } + } + return std::string_view::npos; +} + +std::string NormalizeOdbcTimestampLiteral(const std::string& raw) { + std::string normalized = raw; + const auto firstSpace = std::find(normalized.begin(), normalized.end(), ' '); + if (firstSpace != normalized.end()) { + *firstSpace = 'T'; + } + if (std::find(normalized.begin(), normalized.end(), 'Z') == normalized.end()) { + normalized.push_back('Z'); + } + return normalized; +} + +std::string ToUpperAscii(std::string_view sv) { + std::string upper; + upper.resize(sv.size()); + std::transform(sv.begin(), sv.end(), upper.begin(), [](unsigned char byte) { + return static_cast(std::toupper(byte)); + }); + return upper; +} + +std::string MapSqlTypeToken(std::string_view sqlType) { + static const std::unordered_map kMap = { + {"CHAR", "Utf8"}, + {"VARCHAR", "Utf8"}, + {"LONGVARCHAR", "Utf8"}, + {"WCHAR", "Utf8"}, + {"WVARCHAR", "Utf8"}, + {"WLONGVARCHAR", "Utf8"}, + {"BIT", "Bool"}, + {"TINYINT", "Int8"}, + {"SMALLINT", "Int16"}, + {"INTEGER", "Int32"}, + {"BIGINT", "Int64"}, + {"REAL", "Float"}, + {"FLOAT", "Double"}, + {"DOUBLE", "Double"}, + {"DECIMAL", "Decimal(22, 9)"}, + {"NUMERIC", "Decimal(22, 9)"}, + {"BINARY", "String"}, + {"VARBINARY", "String"}, + {"LONGVARBINARY", "String"}, + {"DATE", "Date"}, + {"TIME", "Time"}, + {"TIMESTAMP", "Datetime"}, + {"TYPE_DATE", "Date"}, + {"TYPE_TIME", "Time"}, + {"TYPE_TIMESTAMP", "Datetime"}, + }; + std::string key = ToUpperAscii(sqlType); + const std::string kSql = "SQL_"; + if (key.size() > kSql.size() && key.compare(0, kSql.size(), kSql) == 0) { + key.erase(0, kSql.size()); + } + const auto mapped = kMap.find(key); + if (mapped != kMap.end()) { + return mapped->second; + } + return key; +} + +std::string RewriteOdbcEscapesImpl(std::string_view sql); + + +enum class OdbcBraceKind { + OutputProcedureCall, // {?= call ... } + FnBody, // {fn ...} + OjBody, // {oj ...} + DateLiteral, // {d '...'} + TimeLiteral, // {t '...'} + TimestampLiteral, // {ts '...'} + ProcedureCall, // {call ...} + LikeEscape, // {escape '...'} +}; + +struct OdbcBraceParsed { + OdbcBraceKind Kind; + std::string_view RecurseTail; + std::string QuotedValue; +}; + +std::optional TryParseOutputCallBrace(std::string_view sql, size_t parsePos, size_t closeBrace) { + if (parsePos + 1 >= sql.size() || sql[parsePos] != '?' || sql[parsePos + 1] != '=') { + return std::nullopt; + } + size_t inner = parsePos + 2; + SkipLeadingWhitespace(sql, inner); + std::string_view keyword; + if (!ReadIdent(sql, inner, &keyword) || !EqualNoCase(keyword, "call")) { + return std::nullopt; + } + SkipLeadingWhitespace(sql, inner); + if (inner > closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = OdbcBraceKind::OutputProcedureCall; + parsed.RecurseTail = std::string_view(sql.data() + inner, closeBrace - inner); + return parsed; +} + +std::optional MakeRecurseTailBrace(OdbcBraceKind kind, std::string_view sql, size_t& parsePos, size_t closeBrace) { + SkipLeadingWhitespace(sql, parsePos); + if (parsePos > closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = kind; + parsed.RecurseTail = std::string_view(sql.data() + parsePos, closeBrace - parsePos); + return parsed; +} + +std::optional MakeQuotedBrace(OdbcBraceKind kind, std::string_view sql, size_t& parsePos, size_t closeBrace) { + std::string quotedLit; + if (!ParseSingleQuoted(sql, parsePos, "edLit) || parsePos > closeBrace) { + return std::nullopt; + } + SkipLeadingWhitespace(sql, parsePos); + if (parsePos != closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = kind; + parsed.QuotedValue = std::move(quotedLit); + return parsed; +} + +struct BraceKeywordSpec { + std::string_view Keyword; + OdbcBraceKind Kind; + bool IsQuotedLiteral; +}; + +static constexpr BraceKeywordSpec kBraceKeywordSpecs[] = { + {"fn", OdbcBraceKind::FnBody, false}, + {"oj", OdbcBraceKind::OjBody, false}, + {"d", OdbcBraceKind::DateLiteral, true}, + {"t", OdbcBraceKind::TimeLiteral, true}, + {"ts", OdbcBraceKind::TimestampLiteral, true}, + {"call", OdbcBraceKind::ProcedureCall, false}, + {"escape", OdbcBraceKind::LikeEscape, true}, +}; + +std::optional TryParseOdbcBrace(std::string_view sql, size_t openBrace, size_t closeBrace) { + size_t parsePos = openBrace + 1; + SkipLeadingWhitespace(sql, parsePos); + + if (std::optional outputCall = TryParseOutputCallBrace(sql, parsePos, closeBrace)) { + return outputCall; + } + if (parsePos + 1 < sql.size() && sql[parsePos] == '?' && sql[parsePos + 1] == '=') { + return std::nullopt; + } + + std::string_view token; + if (!ReadIdent(sql, parsePos, &token)) { + return std::nullopt; + } + + for (const BraceKeywordSpec& spec : kBraceKeywordSpecs) { + if (!EqualNoCase(token, spec.Keyword)) { + continue; + } + if (spec.IsQuotedLiteral) { + return MakeQuotedBrace(spec.Kind, sql, parsePos, closeBrace); + } + return MakeRecurseTailBrace(spec.Kind, sql, parsePos, closeBrace); + } + + return std::nullopt; +} + +void AppendRewrittenBrace(std::string& rewritten, const OdbcBraceParsed& parsed) { + switch (parsed.Kind) { + case OdbcBraceKind::OutputProcedureCall: + case OdbcBraceKind::ProcedureCall: + rewritten += "CALL "; + rewritten.append(RewriteOdbcEscapesImpl(parsed.RecurseTail)); + return; + case OdbcBraceKind::FnBody: + case OdbcBraceKind::OjBody: + rewritten.append(RewriteOdbcEscapesImpl(parsed.RecurseTail)); + return; + case OdbcBraceKind::DateLiteral: + rewritten += "CAST('"; + rewritten += parsed.QuotedValue; + rewritten += "' AS Date)"; + return; + case OdbcBraceKind::TimeLiteral: + rewritten += "CAST('"; + rewritten += parsed.QuotedValue; + rewritten += "' AS Time)"; + return; + case OdbcBraceKind::TimestampLiteral: { + const std::string normalizedTs = NormalizeOdbcTimestampLiteral(parsed.QuotedValue); + rewritten += "CAST('"; + rewritten += normalizedTs; + rewritten += "' AS Datetime)"; + return; + } + case OdbcBraceKind::LikeEscape: + rewritten += " ESCAPE '"; + rewritten += parsed.QuotedValue; + rewritten += '\''; + return; + } +} + +std::string RewriteOdbcEscapesImpl(std::string_view sql) { + std::string rewritten; + rewritten.reserve(sql.size()); + + for (size_t readPos = 0; readPos < sql.size();) { + if (sql[readPos] != '{') { + rewritten.push_back(sql[readPos++]); + continue; + } + + const size_t closeBrace = FindMatchingCloseBrace(sql, readPos); + if (closeBrace == std::string_view::npos) { + rewritten.push_back(sql[readPos++]); + continue; + } + + if (std::optional parsedBrace = TryParseOdbcBrace(sql, readPos, closeBrace)) { + AppendRewrittenBrace(rewritten, *parsedBrace); + readPos = closeBrace + 1; + continue; + } + + rewritten.push_back(sql[readPos++]); + } + + return rewritten; +} + +std::string RewriteOdbcConvertCalls(std::string_view sql); + +class TOdbcConvertCallRewriter { +public: + explicit TOdbcConvertCallRewriter(std::string_view sql) + : Sql_(sql) { + Rewritten_.reserve(sql.size()); + } + + std::string TakeResult() && { + return std::move(Rewritten_); + } + + void Run() { + while (SegmentStart_ < Sql_.size()) { + const std::optional convertKeywordPos = FindNextConvertKeyword(SegmentStart_); + if (!convertKeywordPos) { + Rewritten_.append(Sql_.substr(SegmentStart_)); + break; + } + Rewritten_.append(Sql_.substr(SegmentStart_, *convertKeywordPos - SegmentStart_)); + if (!TryRewriteConvertAt(*convertKeywordPos)) { + break; + } + } + } + +private: + static constexpr size_t kConvertTokenLen = 7; + + std::optional FindNextConvertKeyword(size_t from) const { + for (size_t probePos = from; probePos + kConvertTokenLen <= Sql_.size(); ++probePos) { + if (!EqualNoCase(Sql_.substr(probePos, kConvertTokenLen), "CONVERT")) { + continue; + } + size_t afterKeyword = probePos + kConvertTokenLen; + SkipLeadingWhitespace(Sql_, afterKeyword); + if (afterKeyword < Sql_.size() && Sql_[afterKeyword] == '(') { + return probePos; + } + } + return std::nullopt; + } + + bool TryRewriteConvertAt(size_t convertKeywordPos) { + size_t parsePos = convertKeywordPos + kConvertTokenLen; + SkipLeadingWhitespace(Sql_, parsePos); + if (parsePos >= Sql_.size() || Sql_[parsePos] != '(') { + Rewritten_.append(Sql_.substr(convertKeywordPos, kConvertTokenLen)); + SegmentStart_ = convertKeywordPos + kConvertTokenLen; + return true; + } + ++parsePos; + + int parenDepth = 1; + const size_t firstArgStart = parsePos; + std::optional typeCommaPos; + for (; parsePos < Sql_.size(); ++parsePos) { + if (Sql_[parsePos] == '(') { + ++parenDepth; + } else if (Sql_[parsePos] == ')') { + --parenDepth; + } else if (Sql_[parsePos] == ',' && parenDepth == 1) { + typeCommaPos = parsePos; + break; + } + } + if (!typeCommaPos) { + Rewritten_.append(Sql_.substr(convertKeywordPos)); + return false; + } + + const std::string_view firstArg(Sql_.data() + firstArgStart, *typeCommaPos - firstArgStart); + parsePos = *typeCommaPos + 1; + SkipLeadingWhitespace(Sql_, parsePos); + const size_t sqlTypeStart = parsePos; + const auto sqlTypeEnd = std::find_if_not( + Sql_.begin() + static_cast(parsePos), + Sql_.end(), + [](unsigned char byte) { + return std::isalpha(byte) != 0 || byte == '_'; + }); + parsePos = static_cast(sqlTypeEnd - Sql_.begin()); + const std::string_view sqlTypeToken(Sql_.data() + sqlTypeStart, parsePos - sqlTypeStart); + SkipLeadingWhitespace(Sql_, parsePos); + if (parsePos >= Sql_.size() || Sql_[parsePos] != ')') { + Rewritten_.append(Sql_.substr(convertKeywordPos)); + return false; + } + + const std::string yqlType = MapSqlTypeToken(sqlTypeToken); + Rewritten_ += "CAST("; + Rewritten_ += RewriteOdbcConvertCalls(RewriteOdbcEscapesImpl(firstArg)); + Rewritten_ += " AS "; + Rewritten_ += yqlType; + Rewritten_ += ')'; + SegmentStart_ = parsePos + 1; + return true; + } + + std::string_view Sql_; + std::string Rewritten_; + size_t SegmentStart_ = 0; +}; + +std::string RewriteOdbcConvertCalls(std::string_view sql) { + TOdbcConvertCallRewriter rewriter(sql); + rewriter.Run(); + return std::move(rewriter).TakeResult(); +} + +} // namespace + + + +std::string RewriteOdbcEscapes(const std::string& sql) { + std::string afterBraceRewrite = RewriteOdbcEscapesImpl(sql); + return RewriteOdbcConvertCalls(afterBraceRewrite); +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/escape.h b/odbc/src/utils/escape.h new file mode 100644 index 00000000000..7397a128450 --- /dev/null +++ b/odbc/src/utils/escape.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace NYdb::NOdbc { + +std::string RewriteOdbcEscapes(const std::string& sql); + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/sql_like.h b/odbc/src/utils/sql_like.h new file mode 100644 index 00000000000..f51c10ca28c --- /dev/null +++ b/odbc/src/utils/sql_like.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +namespace NYdb::NOdbc { + +// SQL LIKE — '%' is any substring, '_' is any single character. +inline bool SqlLikeMatch(std::string_view text, std::string_view pattern) { + size_t textPos = 0; + size_t patPos = 0; + size_t lastPercentPat = std::string_view::npos; + size_t textStartAfterPercent = 0; + + const size_t textLen = text.size(); + const size_t patLen = pattern.size(); + + while (textPos < textLen) { + const bool morePat = patPos < patLen; + const char patCh = morePat ? pattern[patPos] : '\0'; + + if (morePat && patCh != '%' && (patCh == '_' || patCh == text[textPos])) { + ++textPos; + ++patPos; + continue; + } + + if (morePat && patCh == '%') { + lastPercentPat = patPos++; + textStartAfterPercent = textPos; + continue; + } + + if (lastPercentPat != std::string_view::npos) { + patPos = lastPercentPat + 1; + ++textStartAfterPercent; + textPos = textStartAfterPercent; + continue; + } + + return false; + } + + while (patPos < patLen && pattern[patPos] == '%') { + ++patPos; + } + return patPos == patLen; +} + +} // namespace NYdb::NOdbc diff --git a/odbc/tests/integration/CMakeLists.txt b/odbc/tests/integration/CMakeLists.txt index 39128437ced..43925350b02 100644 --- a/odbc/tests/integration/CMakeLists.txt +++ b/odbc/tests/integration/CMakeLists.txt @@ -12,3 +12,8 @@ add_odbc_test(NAME odbc-attr_it SOURCES attr_it.cpp ) + +add_odbc_test(NAME odbc-stmt_attr_it + SOURCES + stmt_attr_it.cpp +) \ No newline at end of file diff --git a/odbc/tests/integration/attr_it.cpp b/odbc/tests/integration/attr_it.cpp index 514278f8e67..2dc30446498 100644 --- a/odbc/tests/integration/attr_it.cpp +++ b/odbc/tests/integration/attr_it.cpp @@ -3,26 +3,6 @@ #include #include -namespace { - -bool SqlStatePrefix(const std::string& diag, const char* state5) { - return diag.size() >= 5 && std::strncmp(diag.c_str(), state5, 5) == 0; -} - -void AllocEnv(SQLHENV* env) { - ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, env), SQL_SUCCESS); - ASSERT_EQ(SQLSetEnvAttr(*env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); -} - -void AllocEnvAndConnect(SQLHENV* env, SQLHDBC* dbc) { - AllocEnv(env); - ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, *env, dbc), SQL_SUCCESS); - SQLRETURN rc = SQLDriverConnect( - *dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); - CHECK_ODBC_OK(rc, *dbc, SQL_HANDLE_DBC); -} - -} // namespace TEST(OdbcAttrEnv, OdbcVersionAttr) { SQLHENV env; @@ -145,7 +125,6 @@ TEST(OdbcAttrConn, TxnIsolationAttr) { ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, ¤tIsolation, sizeof(currentIsolation), nullptr), SQL_SUCCESS); ASSERT_EQ(static_cast(SQL_TXN_REPEATABLE_READ), currentIsolation); - // In read-only mode all four standard levels are accepted and remain executable. CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)SQL_MODE_READ_ONLY, 0), dbc, SQL_HANDLE_DBC); CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_READ_UNCOMMITTED, 0), dbc, SQL_HANDLE_DBC); CHECK_ODBC_OK(SQLExecDirect(stmt, selectOneQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); diff --git a/odbc/tests/integration/env_it.cpp b/odbc/tests/integration/env_it.cpp index fd351d127af..952c1459ad6 100644 --- a/odbc/tests/integration/env_it.cpp +++ b/odbc/tests/integration/env_it.cpp @@ -2,15 +2,6 @@ namespace { -void AllocEnvAndConnect(SQLHENV* env, SQLHDBC* dbc) { - ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, env), SQL_SUCCESS); - ASSERT_EQ(SQLSetEnvAttr(*env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); - ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, *env, dbc), SQL_SUCCESS); - SQLRETURN rc = SQLDriverConnect( - *dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); - CHECK_ODBC_OK(rc, *dbc, SQL_HANDLE_DBC); -} - void StartManualTx(SQLHDBC dbc, SQLHSTMT* stmt) { CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc, SQL_HANDLE_DBC); ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, stmt), SQL_SUCCESS); diff --git a/odbc/tests/integration/stmt_attr_it.cpp b/odbc/tests/integration/stmt_attr_it.cpp new file mode 100644 index 00000000000..89faf9abed0 --- /dev/null +++ b/odbc/tests/integration/stmt_attr_it.cpp @@ -0,0 +1,334 @@ +#include "test_utils.h" + +#include +#include +#include +#include + +#ifndef SQL_ATTR_METADATA_ID +#define SQL_ATTR_METADATA_ID 10029 +#endif + + +TEST(OdbcStmtAttr, QueryTimeoutAttr) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLUINTEGER timeoutSec = 1; + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_QUERY_TIMEOUT, (SQLPOINTER)(uintptr_t)timeoutSec, 0), + stmt, + SQL_HANDLE_STMT); + + SQLCHAR longQuery[] = + "SELECT COUNT(*) FROM AS_TABLE(ListMap(ListFromRange(1u, 100000000u), ($x)->(AsStruct($x AS v))))"; + ASSERT_EQ(SQLExecDirect(stmt, longQuery, SQL_NTS), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "HYT00")); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, MaxRowsAttr) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR dropQuery[] = "DROP TABLE IF EXISTS test_attr_max_rows"; + SQLCHAR createQuery[] = + "CREATE TABLE test_attr_max_rows (id Int32, value Int32, PRIMARY KEY (id))"; + SQLCHAR upsert1Query[] = "UPSERT INTO test_attr_max_rows (id, value) VALUES (1, 10)"; + SQLCHAR upsert2Query[] = "UPSERT INTO test_attr_max_rows (id, value) VALUES (2, 20)"; + SQLCHAR selectQuery[] = "SELECT value FROM test_attr_max_rows ORDER BY id"; + + CHECK_ODBC_OK(SQLExecDirect(stmt, dropQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, createQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsert1Query, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, upsert2Query, SQL_NTS), stmt, SQL_HANDLE_STMT); + + const SQLULEN maxRows = 1; + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_MAX_ROWS, (SQLPOINTER)(uintptr_t)maxRows, 0), + stmt, + SQL_HANDLE_STMT); + + SQLULEN maxRowsOut = 0; + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_MAX_ROWS, &maxRowsOut, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(maxRowsOut, maxRows); + + CHECK_ODBC_OK(SQLExecDirect(stmt, selectQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, NoScanAttr) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR selectEscapeFnQuery[] = "SELECT {fn ABS(-12)} AS value"; + + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectEscapeFnQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 12); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_ON, 0), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLExecDirect(stmt, selectEscapeFnQuery, SQL_NTS), SQL_ERROR); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, OdbcEscapeSequences) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), stmt, SQL_HANDLE_STMT); + + { + SQLCHAR convertQuery[] = "SELECT {fn CONVERT(42, SQL_SMALLINT)} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, convertQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLSMALLINT valueSmall = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_SSHORT, &valueSmall, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueSmall, 42); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR convertDoubleQuery[] = "SELECT {fn CONVERT(2.5, SQL_DOUBLE)} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, convertDoubleQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + double valueDouble = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_DOUBLE, &valueDouble, 0, &valueInd), SQL_SUCCESS); + ASSERT_LT(std::fabs(valueDouble - 2.5), 1e-9); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR nestedFnQuery[] = "SELECT {fn {fn ABS(-10)}} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, nestedFnQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 10); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR asciiLowerQuery[] = "SELECT {fn String::AsciiToLower('AbC')} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, asciiLowerQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[32] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "abc"); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR dateQuery[] = "SELECT {d '2024-06-15'} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, dateQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[32] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "2024-06-15"); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + SQLCHAR tsQuery[] = "SELECT {ts '2024-06-15 14:30:00'} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, tsQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[64] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "2024-06-15 14:30:00"); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, MetadataIdSqlLikeForTableNames) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR ddl[] = R"( + DROP TABLE IF EXISTS test_odbc_meta_like_a; + DROP TABLE IF EXISTS test_odbc_meta_like_b; + CREATE TABLE test_odbc_meta_like_a (id Int32, PRIMARY KEY (id)); + CREATE TABLE test_odbc_meta_like_b (id Int32, PRIMARY KEY (id)); + )"; + CHECK_ODBC_OK(SQLExecDirect(stmt, ddl, SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLULEN metadataId = SQL_TRUE; + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_METADATA_ID, &metadataId, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(metadataId, SQL_FALSE); + + const char* likePattern = "%/test_odbc_meta_like_%"; + CHECK_ODBC_OK( + SQLTables(stmt, nullptr, 0, nullptr, 0, (SQLCHAR*)likePattern, SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + stmt, + SQL_HANDLE_STMT); + int tableRows = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++tableRows; + } + ASSERT_EQ(tableRows, 2); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_TRUE, 0), + stmt, + SQL_HANDLE_STMT); + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_METADATA_ID, &metadataId, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(metadataId, SQL_TRUE); + + ASSERT_EQ( + SQLTables(stmt, nullptr, 0, nullptr, 0, (SQLCHAR*)likePattern, SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "HYC00")); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + const std::string exactPath = "/local/test_odbc_meta_like_a"; + CHECK_ODBC_OK( + SQLTables(stmt, nullptr, 0, nullptr, 0, (SQLCHAR*)exactPath.c_str(), SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + stmt, + SQL_HANDLE_STMT); + tableRows = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++tableRows; + } + ASSERT_EQ(tableRows, 1); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_FALSE, 0), + stmt, + SQL_HANDLE_STMT); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcStmtAttr, MetadataIdSqlLikeForColumnNames) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR ddl[] = R"( + DROP TABLE IF EXISTS test_odbc_meta_col; + CREATE TABLE test_odbc_meta_col (id Int32, value_x Int32, PRIMARY KEY (id)); + )"; + CHECK_ODBC_OK(SQLExecDirect(stmt, ddl, SQL_NTS), stmt, SQL_HANDLE_STMT); + + constexpr SQLUSMALLINT kColumnNameCol = 4; + char colName[256] = {}; + SQLLEN colInd = 0; + const std::string exactTable = "/local/test_odbc_meta_col"; + + { + CHECK_ODBC_OK( + SQLColumns( + stmt, + nullptr, + 0, + nullptr, + 0, + (SQLCHAR*)"%/test_odbc_meta_col", + SQL_NTS, + (SQLCHAR*)"val%", + SQL_NTS), + stmt, + SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, kColumnNameCol, SQL_C_CHAR, colName, sizeof(colName), &colInd), SQL_SUCCESS); + ASSERT_STREQ(colName, "value_x"); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + CHECK_ODBC_OK( + SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_TRUE, 0), + stmt, + SQL_HANDLE_STMT); + + ASSERT_EQ( + SQLColumns( + stmt, + nullptr, + 0, + nullptr, + 0, + (SQLCHAR*)"%/test_odbc_meta_col", + SQL_NTS, + (SQLCHAR*)"value_x", + SQL_NTS), + SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "HYC00")); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + { + ASSERT_EQ( + SQLColumns( + stmt, + nullptr, + 0, + nullptr, + 0, + (SQLCHAR*)exactTable.c_str(), + SQL_NTS, + (SQLCHAR*)"val%", + SQL_NTS), + SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "42S22")); + ASSERT_EQ(SQLFreeStmt(stmt, SQL_CLOSE), SQL_SUCCESS); + } + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + diff --git a/odbc/tests/integration/test_utils.h b/odbc/tests/integration/test_utils.h index c43272f0f54..950ffef9508 100644 --- a/odbc/tests/integration/test_utils.h +++ b/odbc/tests/integration/test_utils.h @@ -5,11 +5,9 @@ #include #include +#include #include -#define CHECK_ODBC_OK(rc, handle, type) \ - ASSERT_TRUE((rc) == SQL_SUCCESS || (rc) == SQL_SUCCESS_WITH_INFO) << GetOdbcError(handle, type) - inline std::string GetOdbcError(SQLHANDLE handle, SQLSMALLINT type) { SQLCHAR sqlState[6] = {0}; SQLCHAR message[256] = {0}; @@ -22,4 +20,24 @@ inline std::string GetOdbcError(SQLHANDLE handle, SQLSMALLINT type) { return "Unknown ODBC error"; } +#define CHECK_ODBC_OK(rc, handle, type) \ + ASSERT_TRUE((rc) == SQL_SUCCESS || (rc) == SQL_SUCCESS_WITH_INFO) << GetOdbcError(handle, type) + inline const char* kConnStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + +inline bool SqlStatePrefix(const std::string& diag, const char* state5) { + return diag.size() >= 5 && std::strncmp(diag.c_str(), state5, 5) == 0; +} + +inline void AllocEnv(SQLHENV* env) { + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(*env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); +} + +inline void AllocEnvAndConnect(SQLHENV* env, SQLHDBC* dbc) { + AllocEnv(env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, *env, dbc), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + *dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, *dbc, SQL_HANDLE_DBC); +} diff --git a/odbc/tests/unit/CMakeLists.txt b/odbc/tests/unit/CMakeLists.txt index d1eac199615..d23e837d2f3 100644 --- a/odbc/tests/unit/CMakeLists.txt +++ b/odbc/tests/unit/CMakeLists.txt @@ -8,3 +8,26 @@ add_ydb_test(NAME odbc-convert_ut GTEST LABELS unit ) + +add_ydb_test(NAME odbc-escape_ut GTEST + SOURCES + escape_ut.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/utils/escape.cpp + INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + LINK_LIBRARIES + yutil + LABELS + unit +) + +add_ydb_test(NAME odbc-sql_like_ut GTEST + SOURCES + sql_like_ut.cpp + INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + LINK_LIBRARIES + yutil + LABELS + unit +) diff --git a/odbc/tests/unit/escape_ut.cpp b/odbc/tests/unit/escape_ut.cpp new file mode 100644 index 00000000000..60b3e582e69 --- /dev/null +++ b/odbc/tests/unit/escape_ut.cpp @@ -0,0 +1,71 @@ +#include "utils/escape.h" + +#include + +using NYdb::NOdbc::RewriteOdbcEscapes; + +TEST(OdbcEscapeRewrite, FnUnwraps) { + EXPECT_EQ(RewriteOdbcEscapes("SELECT {fn ABS(-12)} AS v"), "SELECT ABS(-12) AS v"); +} + +TEST(OdbcEscapeRewrite, FnCaseInsensitive) { + EXPECT_EQ(RewriteOdbcEscapes("{FN LOWER('A')}"), "LOWER('A')"); +} + +TEST(OdbcEscapeRewrite, OjUnwraps) { + EXPECT_EQ(RewriteOdbcEscapes("{oj LEFT OUTER JOIN t ON a=b}"), "LEFT OUTER JOIN t ON a=b"); +} + +TEST(OdbcEscapeRewrite, DateLiteral) { + EXPECT_EQ(RewriteOdbcEscapes("SELECT {d '2024-01-01'}"), "SELECT CAST('2024-01-01' AS Date)"); +} + +TEST(OdbcEscapeRewrite, TimeLiteral) { + EXPECT_EQ(RewriteOdbcEscapes("{t '14:30:00'}"), "CAST('14:30:00' AS Time)"); +} + +TEST(OdbcEscapeRewrite, TimestampLiteralNormalizesSpaceToT) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {ts '2024-06-15 14:30:00'} AS v"), + "SELECT CAST('2024-06-15T14:30:00Z' AS Datetime) AS v"); +} + +TEST(OdbcEscapeRewrite, TimestampLiteralKeepsExistingZ) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {ts '2024-06-15T14:30:00Z'} AS v"), + "SELECT CAST('2024-06-15T14:30:00Z' AS Datetime) AS v"); +} + +TEST(OdbcEscapeRewrite, Call) { + EXPECT_EQ(RewriteOdbcEscapes("{call sp_demo(1, 2)}"), "CALL sp_demo(1, 2)"); +} + +TEST(OdbcEscapeRewrite, OutputCallBecomesPlainCall) { + EXPECT_EQ(RewriteOdbcEscapes("{?= call sp(1)}"), "CALL sp(1)"); +} + +TEST(OdbcEscapeRewrite, EscapeClause) { + EXPECT_EQ(RewriteOdbcEscapes("LIKE 'a%' {escape '\\'}"), "LIKE 'a%' ESCAPE '\\'"); +} + +TEST(OdbcEscapeRewrite, ConvertOdbcToYqlCast) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {fn CONVERT(42, SQL_SMALLINT)} AS v"), + "SELECT CAST(42 AS Int16) AS v"); +} + +TEST(OdbcEscapeRewrite, ConvertNestedInFn) { + EXPECT_EQ(RewriteOdbcEscapes("{fn CONVERT(x, SQL_INTEGER)}"), "CAST(x AS Int32)"); +} + +TEST(OdbcEscapeRewrite, NestedFnEscapes) { + EXPECT_EQ(RewriteOdbcEscapes("{fn {fn ABS(1)}}"), "ABS(1)"); +} + +TEST(OdbcEscapeRewrite, UnknownBraceLeftUnchanged) { + EXPECT_EQ(RewriteOdbcEscapes("{not_a_keyword 1}"), "{not_a_keyword 1}"); +} + +TEST(OdbcEscapeRewrite, EmptyInput) { + EXPECT_EQ(RewriteOdbcEscapes(""), ""); +} diff --git a/odbc/tests/unit/sql_like_ut.cpp b/odbc/tests/unit/sql_like_ut.cpp new file mode 100644 index 00000000000..e0b8d87ee01 --- /dev/null +++ b/odbc/tests/unit/sql_like_ut.cpp @@ -0,0 +1,28 @@ +#include "utils/sql_like.h" + +#include + +using NYdb::NOdbc::SqlLikeMatch; + +TEST(SqlLikeMatch, PercentMatchesSubstring) { + EXPECT_TRUE(SqlLikeMatch("/local/foo_bar", "%foo%")); + EXPECT_TRUE(SqlLikeMatch("/local/pfx_foo_sfx", "%foo%")); + EXPECT_FALSE(SqlLikeMatch("/local/other", "%foo%")); +} + +TEST(SqlLikeMatch, UnderscoreMatchesSingleChar) { + EXPECT_TRUE(SqlLikeMatch("a_c", "a_c")); + EXPECT_TRUE(SqlLikeMatch("abc", "a_c")); + EXPECT_FALSE(SqlLikeMatch("abbc", "a_c")); +} + +TEST(SqlLikeMatch, EmptyPatternMatchesOnlyEmptyText) { + EXPECT_TRUE(SqlLikeMatch("", "")); + EXPECT_FALSE(SqlLikeMatch("anything", "")); +} + +TEST(SqlLikeMatch, PercentAtEnds) { + EXPECT_TRUE(SqlLikeMatch("hello", "%hello%")); + EXPECT_TRUE(SqlLikeMatch("hello", "hel%")); + EXPECT_TRUE(SqlLikeMatch("hello", "%llo")); +} diff --git a/tests/unit/library/operation_id/CMakeLists.txt b/tests/unit/library/operation_id/CMakeLists.txt index 86d3fd5131d..06c568af97c 100644 --- a/tests/unit/library/operation_id/CMakeLists.txt +++ b/tests/unit/library/operation_id/CMakeLists.txt @@ -5,6 +5,7 @@ add_ydb_test(NAME operation_id_ut yutil cpp-testing-unittest_main library-operation_id + lib-operation_id-protos cpp-testing-unittest LABELS unit From 67bc96404478d5edeaf2c344d8a622c808427bd1 Mon Sep 17 00:00:00 2001 From: Ylonies Date: Sat, 18 Apr 2026 21:09:41 +0000 Subject: [PATCH 8/9] retry qyery for autocommit --- odbc/src/connection.cpp | 6 ++- odbc/src/connection.h | 3 +- odbc/src/statement.cpp | 97 ++++++++++++++++++++++++++++++----------- odbc/src/statement.h | 2 +- 4 files changed, 80 insertions(+), 28 deletions(-) diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp index a52d5036f04..85724670f9d 100644 --- a/odbc/src/connection.cpp +++ b/odbc/src/connection.cpp @@ -200,10 +200,14 @@ void TConnection::SetTx(const NQuery::TTransaction& tx) { Tx_ = tx; } -void TConnection::Reset() { +void TConnection::ResetTx() { Tx_.reset(); } +void TConnection::ResetQuerySession() { + QuerySession_.reset(); +} + SQLRETURN TConnection::CommitTx() { auto status = Tx_->Commit().ExtractValueSync(); NStatusHelpers::ThrowOnError(status); diff --git a/odbc/src/connection.h b/odbc/src/connection.h index 284ac36cf65..dac7721c000 100644 --- a/odbc/src/connection.h +++ b/odbc/src/connection.h @@ -68,7 +68,8 @@ class TConnection : public TErrorManager { const std::optional& GetTx(); void SetTx(const NQuery::TTransaction& tx); - void Reset(); + void ResetTx(); + void ResetQuerySession(); SQLRETURN CommitTx(); SQLRETURN RollbackTx(); diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp index f4b04ec0be2..16efb84aa1c 100644 --- a/odbc/src/statement.cpp +++ b/odbc/src/statement.cpp @@ -8,12 +8,42 @@ #include #include +#include +#include #include +#include + namespace NYdb { namespace NOdbc { +namespace { + NYdb::TStatus StatusFrom(const NYdb::TStatus& ydb_status) { + return NYdb::TStatus(ydb_status.GetStatus(), NYdb::NIssue::TIssues(ydb_status.GetIssues())); + } + + + NYdb::TStatus PrefetchFirstPartStatus(NQuery::TExecuteQueryIterator& iterator, std::optional* prefetchedResultPart){ + prefetchedResultPart->reset(); + while (true) { + auto part = iterator.ReadNext().ExtractValueSync(); + if (part.EOS()) { + break; + } + if (!part.IsSuccess()) { + return StatusFrom(part); + + } + if (part.HasResultSet()) { + prefetchedResultPart->emplace(std::move(part)); + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + } + } + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + } +} + TStatement::TStatement(TConnection* conn) : Conn_(conn) {} @@ -39,18 +69,41 @@ SQLRETURN TStatement::Execute() { } NYdb::TParams params = BuildParams(); - if (Conn_->GetAutocommit()){ - Conn_->Reset(); - } + std::optional iterator; + std::optional prefetchedResultPart; - auto& session = Conn_->GetOrCreateQuerySession(); + if (Conn_->GetAutocommit()){ + Conn_->ResetTx(); + Conn_->ResetQuerySession(); + const NYdb::NRetry::TRetryOperationSettings retrySettings = + MakeAutocommitRetrySettings(); + + NYdb::TStatus execStatus = client->RetryQuerySync( + [this, ¶ms, &iterator, &prefetchedResultPart](NQuery::TSession session) -> NYdb::TStatus{ + auto retry_iterator = CreateExecuteIterator(session, params); + if (!retry_iterator.IsSuccess()) { + return StatusFrom(retry_iterator); + } + std::optional retry_prefetched; + const NYdb::TStatus prefetchStatus = PrefetchFirstPartStatus(retry_iterator, &retry_prefetched); + if (!prefetchStatus.IsSuccess()) { + return prefetchStatus; + } + iterator.emplace(std::move(retry_iterator)); + prefetchedResultPart = std::move(retry_prefetched); + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + }, retrySettings); - auto iterator = CreateExecuteIterator(session, params); - NStatusHelpers::ThrowOnError(iterator); + NStatusHelpers::ThrowOnError(execStatus); + } else { + NQuery::TSession& session = Conn_->GetOrCreateQuerySession(); + iterator.emplace(CreateExecuteIterator(session, params)); + NStatusHelpers::ThrowOnError(*iterator); + NStatusHelpers::ThrowOnError(PrefetchFirstPartStatus(*iterator, &prefetchedResultPart)); + } - std::optional prefetchedResultPart = PrefetchFirstResultPart(iterator); if (prefetchedResultPart) { - Cursor_ = CreateExecCursor(this, std::move(iterator), std::move(prefetchedResultPart)); + Cursor_ = CreateExecCursor(this, std::move(*iterator), std::move(prefetchedResultPart)); } else { Cursor_.reset(); } @@ -59,6 +112,16 @@ SQLRETURN TStatement::Execute() { return SQL_SUCCESS; } +NYdb::NRetry::TRetryOperationSettings TStatement::MakeAutocommitRetrySettings() { + NYdb::NRetry::TRetryOperationSettings settings; + SQLUINTEGER queryTimeoutSec = Attributes_.GetQueryTimeoutSec(); + if (queryTimeoutSec > 0) { + const TDuration deadline = TDuration::Seconds(queryTimeoutSec); + settings.MaxTimeout(deadline).GetSessionClientTimeout(deadline); + } + return settings; +} + NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params){ const std::string sqlText = Attributes_.GetNoScanMode() == SQL_NOSCAN_ON ? PreparedQuery_ @@ -94,23 +157,7 @@ NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession execSettings).ExtractValueSync(); } -std::optional TStatement::PrefetchFirstResultPart(NQuery::TExecuteQueryIterator& iterator){ - std::optional prefetchedResultPart; - while (true) { - auto part = iterator.ReadNext().ExtractValueSync(); - if (part.EOS()) { - break; - } - if (!part.IsSuccess()) { - NStatusHelpers::ThrowOnError(part); - } - if (part.HasResultSet()) { - prefetchedResultPart.emplace(std::move(part)); - break; - } - } - return prefetchedResultPart; -} + SQLRETURN TStatement::Fetch() { if (!Cursor_) { diff --git a/odbc/src/statement.h b/odbc/src/statement.h index 702fe56c71e..754250bd269 100644 --- a/odbc/src/statement.h +++ b/odbc/src/statement.h @@ -74,8 +74,8 @@ class TStatement : public TErrorManager, public IBindingFiller { NYdb::TParams BuildParams(); NQuery::TExecuteQueryIterator CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params); - std::optional PrefetchFirstResultPart(NQuery::TExecuteQueryIterator& iterator); + NYdb::NRetry::TRetryOperationSettings MakeAutocommitRetrySettings(); std::vector GetPatternEntries(const std::string& pattern); SQLRETURN VisitEntry(const std::string& path, const std::string& pattern, std::vector& resultEntries); bool IsPatternMatch(const std::string& path, const std::string& pattern); From 8d223d008238fd0f107d0ae5ecb902631953e22a Mon Sep 17 00:00:00 2001 From: Ylonies Date: Mon, 27 Apr 2026 19:42:05 +0300 Subject: [PATCH 9/9] getdiagfield + fixes --- odbc/src/odbc_driver.cpp | 12 +++- odbc/src/statement.cpp | 16 +++++ odbc/src/statement.h | 3 + odbc/src/utils/bindings.h | 3 +- odbc/src/utils/error_manager.cpp | 115 +++++++++++++++++++++++-------- odbc/src/utils/error_manager.h | 14 +++- 6 files changed, 129 insertions(+), 34 deletions(-) diff --git a/odbc/src/odbc_driver.cpp b/odbc/src/odbc_driver.cpp index cba323453af..97d9ffd8a5a 100644 --- a/odbc/src/odbc_driver.cpp +++ b/odbc/src/odbc_driver.cpp @@ -32,7 +32,9 @@ SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handleType, return NYdb::NOdbc::HandleOdbcExceptions( inputHandle, [&]() { - *outputHandle = new NYdb::NOdbc::TEnvironment(); + auto* const env = new NYdb::NOdbc::TEnvironment(); + *outputHandle = env; + env->SetLastReturnCode(SQL_SUCCESS); return SQL_SUCCESS; }, NYdb::NOdbc::ENullInputHandlePolicy::Allow); @@ -43,14 +45,18 @@ SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handleType, auto conn = std::make_unique(); conn->SetEnvironment(env); env->RegisterConnection(conn.get()); - *outputHandle = conn.release(); + auto* const raw = conn.release(); + *outputHandle = raw; + raw->SetLastReturnCode(SQL_SUCCESS); return SQL_SUCCESS; }); } case SQL_HANDLE_STMT: { return NYdb::NOdbc::HandleOdbcExceptions(inputHandle, [&](auto* conn) { auto stmt = conn->CreateStatement(); - *outputHandle = stmt.release(); + auto* const raw = stmt.release(); + *outputHandle = raw; + raw->SetLastReturnCode(SQL_SUCCESS); return SQL_SUCCESS; }); } diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp index 16efb84aa1c..5d6bb38f152 100644 --- a/odbc/src/statement.cpp +++ b/odbc/src/statement.cpp @@ -518,5 +518,21 @@ SQLRETURN TStatement::GetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER return Attributes_.GetStmtAttr(attr, value, bufferLength, stringLengthPtr, *this); } +SQLRETURN TStatement::GetDiagField( + SQLSMALLINT recNumber, + SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + if (recNumber == 0 && diagIdentifier == SQL_DIAG_ROW_COUNT) { + if (!diagInfoPtr) { + return SQL_ERROR; + } + *reinterpret_cast(diagInfoPtr) = -1; + return SQL_SUCCESS; + } + return TErrorManager::GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); +} + } // namespace NOdbc } // namespace NYdb diff --git a/odbc/src/statement.h b/odbc/src/statement.h index 754250bd269..9f2eb8ade64 100644 --- a/odbc/src/statement.h +++ b/odbc/src/statement.h @@ -55,6 +55,9 @@ class TStatement : public TErrorManager, public IBindingFiller { SQLRETURN SetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength); SQLRETURN GetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); + SQLRETURN GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) override; + TConnection* GetConnection() { return Conn_; } diff --git a/odbc/src/utils/bindings.h b/odbc/src/utils/bindings.h index 443d9787d70..2480f5367af 100644 --- a/odbc/src/utils/bindings.h +++ b/odbc/src/utils/bindings.h @@ -31,8 +31,7 @@ struct TBoundColumn { class IBindingFiller { public: virtual void FillBoundColumns() = 0; - virtual void OnStreamPartError(const TStatus& status) { - (void)status; + virtual void OnStreamPartError([[maybe_unused]] const TStatus& status) { } virtual ~IBindingFiller() = default; diff --git a/odbc/src/utils/error_manager.cpp b/odbc/src/utils/error_manager.cpp index fbb577e3824..92c8ec1750f 100644 --- a/odbc/src/utils/error_manager.cpp +++ b/odbc/src/utils/error_manager.cpp @@ -2,6 +2,8 @@ #include #include +#include +#include namespace NYdb { namespace NOdbc { @@ -56,13 +58,57 @@ namespace { } } // namespace +namespace { + +SQLRETURN WriteDiagCStr( + const std::string& str, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr, + bool sqlStateField = false) { + std::string storage; + const std::string* src = &str; + if (sqlStateField) { + storage = str; + if (storage.size() < 5) { + storage.append(5U - storage.size(), ' '); + } else { + storage.resize(5U); + } + src = &storage; + } + if (!diagInfoPtr) { + return SQL_ERROR; + } + if (bufferLength < 0) { + return SQL_ERROR; + } + const size_t fullLen = src->size(); + if (stringLengthPtr) { + *stringLengthPtr = static_cast(std::min(fullLen, 0x7FFFU)); + } + if (bufferLength == 0) { + return fullLen == 0 ? SQL_SUCCESS : SQL_SUCCESS_WITH_INFO; + } + auto* out = static_cast(diagInfoPtr); + const size_t maxData = static_cast(bufferLength - 1U); + const size_t copyLen = std::min(fullLen, maxData); + std::memcpy(out, src->data(), copyLen); + out[copyLen] = 0; + return (fullLen > maxData) ? SQL_SUCCESS_WITH_INFO : SQL_SUCCESS; +} + +} // namespace + SQLRETURN TErrorManager::AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message, SQLRETURN returnCode) { Errors_.push_back({sqlState, nativeError, message, returnCode}); + LastReturnCode_ = returnCode; return returnCode; } SQLRETURN TErrorManager::AddError(const TOdbcException& ex) { Errors_.push_back({ex.GetSqlState(), ex.GetNativeError(), ex.GetMessage(), ex.GetReturnCode()}); + LastReturnCode_ = ex.GetReturnCode(); return ex.GetReturnCode(); } @@ -73,6 +119,7 @@ SQLRETURN TErrorManager::AddError(const TStatus& status) { message += ": " + status.GetIssues().ToString(); } Errors_.push_back({mapping.sqlState, static_cast(status.GetStatus()), message, mapping.returnCode}); + LastReturnCode_ = mapping.returnCode; return mapping.returnCode; } @@ -104,19 +151,26 @@ SQLRETURN TErrorManager::GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQ return SQL_SUCCESS; } -SQLRETURN TErrorManager::GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, - SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr) { +SQLRETURN TErrorManager::GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr) { const SQLSMALLINT count = static_cast(Errors_.size()); - + if (diagInfoPtr == nullptr) { + return SQL_ERROR; + } if (recNumber == 0) { - if (diagIdentifier == SQL_DIAG_NUMBER) { - if (!diagInfoPtr) { - return SQL_ERROR; + switch (diagIdentifier) { + case SQL_DIAG_RETURNCODE: + *static_cast(diagInfoPtr) = LastReturnCode_; + return SQL_SUCCESS; + case SQL_DIAG_NUMBER: { + *static_cast(diagInfoPtr) = static_cast(count); + return SQL_SUCCESS; } - *static_cast(diagInfoPtr) = count; - return SQL_SUCCESS; + case SQL_DIAG_ROW_COUNT: + return SQL_ERROR; + default: + return SQL_ERROR; } - return SQL_NO_DATA; } if (recNumber < 1 || recNumber > count) { @@ -126,28 +180,28 @@ SQLRETURN TErrorManager::GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIde const auto& err = Errors_[recNumber - 1]; switch (diagIdentifier) { case SQL_DIAG_SQLSTATE: - if (!diagInfoPtr) { - return SQL_ERROR; - } - strncpy((char*)diagInfoPtr, err.SqlState.c_str(), 6); - return SQL_SUCCESS; - case SQL_DIAG_NATIVE: - if (!diagInfoPtr) { - return SQL_ERROR; - } + return WriteDiagCStr(err.SqlState, diagInfoPtr, bufferLength, stringLengthPtr, true); + case SQL_DIAG_NATIVE: { *static_cast(diagInfoPtr) = err.NativeError; return SQL_SUCCESS; + } case SQL_DIAG_MESSAGE_TEXT: - if (!diagInfoPtr || bufferLength <= 0) { - return SQL_ERROR; - } - strncpy((char*)diagInfoPtr, err.Message.c_str(), bufferLength); - if (stringLengthPtr) { - *stringLengthPtr = static_cast(err.Message.size()); - } + return WriteDiagCStr(err.Message, diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_CLASS_ORIGIN: + return WriteDiagCStr("ODBC 3.0", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_SUBCLASS_ORIGIN: + return WriteDiagCStr("ODBC 3.0", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_CONNECTION_NAME: + case SQL_DIAG_SERVER_NAME: + return WriteDiagCStr("", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_COLUMN_NUMBER: + *static_cast(diagInfoPtr) = SQL_COLUMN_NUMBER_UNKNOWN; + return SQL_SUCCESS; + case SQL_DIAG_ROW_NUMBER: + *static_cast(diagInfoPtr) = SQL_ROW_NUMBER_UNKNOWN; return SQL_SUCCESS; default: - return SQL_NO_DATA; + return SQL_ERROR; } } @@ -160,8 +214,15 @@ SQLRETURN HandleOdbcExceptions( } try { - return func(); + const SQLRETURN r = func(); + if (handlePtr) { + static_cast(handlePtr)->SetLastReturnCode(r); + } + return r; } catch (...) { + if (handlePtr) { + static_cast(handlePtr)->SetLastReturnCode(SQL_ERROR); + } return SQL_ERROR; } } diff --git a/odbc/src/utils/error_manager.h b/odbc/src/utils/error_manager.h index 5f72a69f563..9f91fab8a1d 100644 --- a/odbc/src/utils/error_manager.h +++ b/odbc/src/utils/error_manager.h @@ -64,13 +64,21 @@ class TErrorManager { void ClearErrors(); + void SetLastReturnCode(SQLRETURN code) { + LastReturnCode_ = code; + } + [[nodiscard]] SQLRETURN GetLastReturnCode() const { + return LastReturnCode_; + } + SQLRETURN GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength); - SQLRETURN GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, + virtual SQLRETURN GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr); private: TErrorList Errors_; + SQLRETURN LastReturnCode_ = SQL_SUCCESS; }; enum class ENullInputHandlePolicy : unsigned char { @@ -86,7 +94,9 @@ SQLRETURN HandleOdbcExceptions(SQLHANDLE handlePtr, std::function(handlePtr); try { - return func(handle); + const SQLRETURN ret = func(handle); + handle->SetLastReturnCode(ret); + return ret; } catch (const NStatusHelpers::TYdbErrorException& ex) { return handle->AddError(ex.GetStatus()); } catch (const TOdbcException& ex) {