diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc index dbf2fbb74f8..467447bdc9c 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc @@ -87,15 +87,18 @@ TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrEnlistInDtcUnsupported) } TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrOdbcCursorsDMOnly) { - this->AllocEnvConnHandles(); + SQLHENV test_env = SQL_NULL_HENV; + SQLHDBC test_conn = SQL_NULL_HDBC; + this->AllocEnvConnHandles(test_env, test_conn); // Verify DM-only attribute is settable via Driver Manager ASSERT_EQ(SQL_SUCCESS, - SQLSetConnectAttr(conn, SQL_ATTR_ODBC_CURSORS, + SQLSetConnectAttr(test_conn, SQL_ATTR_ODBC_CURSORS, reinterpret_cast(SQL_CUR_USE_DRIVER), 0)); std::string connect_str = this->GetConnectionString(); - this->ConnectWithString(connect_str); + this->ConnectWithString(connect_str, test_conn); + this->Disconnect(test_env, test_conn); } TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrQuietModeReadOnly) { diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc index e39433fa979..ba2f68f0af8 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc @@ -592,7 +592,7 @@ TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAlterTable) { TYPED_TEST(ConnectionInfoHandleTest, TestSQLGetInfoCatalogLocation) { // GH-49482 TODO: resolve inconsitent return value for SQL_CATALOG_LOCATION and change // test type to `ConnectionInfoTest` - this->ConnectWithString(this->GetConnectionString()); + this->ConnectWithString(this->GetConnectionString(), conn); SQLUSMALLINT value; GetInfo(conn, SQL_CATALOG_LOCATION, &value); @@ -725,7 +725,7 @@ TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropDomain) { TYPED_TEST(ConnectionInfoHandleTest, TestSQLGetInfoDropSchema) { // GH-49482 TODO: resolve inconsitent return value for SQL_DROP_SCHEMA and change test // type to `ConnectionInfoTest` - this->ConnectWithString(this->GetConnectionString()); + this->ConnectWithString(this->GetConnectionString(), conn); SQLUINTEGER value; GetInfo(conn, SQL_DROP_SCHEMA, &value); @@ -739,7 +739,7 @@ TYPED_TEST(ConnectionInfoHandleTest, TestSQLGetInfoDropSchema) { TYPED_TEST(ConnectionInfoHandleTest, TestSQLGetInfoDropTable) { // GH-49482 TODO: resolve inconsitent return value for SQL_DROP_TABLE and change test // type to `ConnectionInfoTest` - this->ConnectWithString(this->GetConnectionString()); + this->ConnectWithString(this->GetConnectionString(), conn); SQLUINTEGER value; GetInfo(conn, SQL_DROP_TABLE, &value); diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc index 3fc48c263ec..b85a26ce537 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc @@ -19,6 +19,7 @@ // with windows.h #include "arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h" +#include "arrow/compute/api.h" #include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" // For DSN registration @@ -51,28 +52,80 @@ class MockServerEnvironment : public ::testing::Environment { } }; -::testing::Environment* mock_env = +bool RunningRemoteTests() { return !remote_test_connect_str.empty(); } + +class OdbcTestEnvironment : public ::testing::Environment { + public: + void SetUp() override { + remote_test_connect_str = ODBCTestBase::GetConnectionString(); + if (RunningRemoteTests()) { + ODBCTestBase::Connect(remote_test_connect_str, remote_odbcv3_handles.env, + remote_odbcv3_handles.conn, SQL_OV_ODBC3); + ODBCTestBase::Connect(remote_test_connect_str, remote_odbcv2_handles.env, + remote_odbcv2_handles.conn, SQL_OV_ODBC2); + } + + std::string mock_test_connect_str = ODBCMockTestBase::GetConnectionString(); + ODBCMockTestBase::Connect(mock_test_connect_str, mock_odbcv3_handles.env, + mock_odbcv3_handles.conn, SQL_OV_ODBC3); + ODBCMockTestBase::Connect(mock_test_connect_str, mock_odbcv2_handles.env, + mock_odbcv2_handles.conn, SQL_OV_ODBC2); + } + + void TearDown() override { + if (RunningRemoteTests()) { + ODBCTestBase::Disconnect(remote_odbcv3_handles.env, remote_odbcv3_handles.conn); + ODBCTestBase::Disconnect(remote_odbcv2_handles.env, remote_odbcv2_handles.conn); + } + + ODBCTestBase::Disconnect(mock_odbcv3_handles.env, mock_odbcv3_handles.conn); + ODBCTestBase::Disconnect(mock_odbcv2_handles.env, mock_odbcv2_handles.conn); + } +}; + +#ifdef _WIN32 +// A global test "environment", to ensure Arrow compute kernel functions are registered +class ComputeKernelEnvironment : public ::testing::Environment { + public: + void SetUp() override { ASSERT_OK(arrow::compute::Initialize()); } +}; + +::testing::Environment* compute_kernel_env = + ::testing::AddGlobalTestEnvironment(new ComputeKernelEnvironment); +#endif // _WIN32 + +::testing::Environment* mock_server_env = ::testing::AddGlobalTestEnvironment(new MockServerEnvironment); -void ODBCTestBase::AllocEnvConnHandles(SQLINTEGER odbc_ver) { +::testing::Environment* odbc_test_env = + ::testing::AddGlobalTestEnvironment(new OdbcTestEnvironment); + +SQLHENV FlightSQLOdbcEnvConnHandleRemoteTestBase::env_h = SQL_NULL_HENV; +SQLHDBC FlightSQLOdbcEnvConnHandleRemoteTestBase::conn_h = SQL_NULL_HDBC; +SQLHENV FlightSQLOdbcEnvConnHandleMockTestBase::env_h = SQL_NULL_HENV; +SQLHDBC FlightSQLOdbcEnvConnHandleMockTestBase::conn_h = SQL_NULL_HDBC; + +void ODBCTestBase::AllocEnvConnHandles(SQLHENV& env_handle, SQLHDBC& conn_handle, + SQLINTEGER odbc_ver) { // Allocate an environment handle - ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); + ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env_handle)); ASSERT_EQ( SQL_SUCCESS, - SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, + SQLSetEnvAttr(env_handle, SQL_ATTR_ODBC_VERSION, reinterpret_cast(static_cast(odbc_ver)), 0)); // Allocate a connection using alloc handle - ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env, &conn)); + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env_handle, &conn_handle)); } -void ODBCTestBase::Connect(std::string connect_str, SQLINTEGER odbc_ver) { - ASSERT_NO_FATAL_FAILURE(AllocEnvConnHandles(odbc_ver)); - ASSERT_NO_FATAL_FAILURE(ConnectWithString(connect_str)); +void ODBCTestBase::Connect(std::string connect_str, SQLHENV& env_handle, + SQLHDBC& conn_handle, SQLINTEGER odbc_ver) { + ASSERT_NO_FATAL_FAILURE(AllocEnvConnHandles(env_handle, conn_handle, odbc_ver)); + ASSERT_NO_FATAL_FAILURE(ConnectWithString(connect_str, conn_handle)); } -void ODBCTestBase::ConnectWithString(std::string connect_str) { +void ODBCTestBase::ConnectWithString(std::string connect_str, SQLHDBC& conn_handle) { // Connect string std::vector connect_str0(connect_str.begin(), connect_str.end()); @@ -81,31 +134,39 @@ void ODBCTestBase::ConnectWithString(std::string connect_str) { // Connecting to ODBC server. ASSERT_EQ(SQL_SUCCESS, - SQLDriverConnect(conn, NULL, &connect_str0[0], + SQLDriverConnect(conn_handle, NULL, &connect_str0[0], static_cast(connect_str0.size()), out_str, kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)) - << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); + << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn_handle); } -void ODBCTestBase::Disconnect() { +void ODBCTestBase::Disconnect(SQLHENV& env_handle, SQLHDBC& conn_handle) { // Disconnect from ODBC - EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(conn)) - << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); + if (conn_handle != SQL_NULL_HDBC) { + EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(conn_handle)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn_handle); + } - FreeEnvConnHandles(); + FreeEnvConnHandles(env_handle, conn_handle); } -void ODBCTestBase::FreeEnvConnHandles() { +void ODBCTestBase::FreeEnvConnHandles(SQLHENV& env_handle, SQLHDBC& conn_handle) { // Free connection handle - EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn)); + if (conn_handle != SQL_NULL_HDBC) { + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn_handle)); + conn_handle = SQL_NULL_HDBC; + } // Free environment handle - EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env)); + if (env_handle != SQL_NULL_HENV) { + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env_handle)); + env_handle = SQL_NULL_HENV; + } } std::string ODBCTestBase::GetConnectionString() { std::string connect_str = - arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOrDie(); + arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOr(""); return connect_str; } @@ -168,68 +229,57 @@ std::wstring ODBCTestBase::GetQueryAllDataTypes() { } void ODBCTestBase::SetUp() { - if (connected) { - ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt)); - } + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt)); } void ODBCTestBase::TearDown() { - if (connected) { - ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt)); - } -} - -void ODBCTestBase::TearDownTestSuite() { - if (connected) { - Disconnect(); - connected = false; - } -} - -void FlightSQLODBCRemoteTestBase::CheckForRemoteTest() { - if (arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOr("").empty()) { - skipping_test = true; - GTEST_SKIP() << "Skipping test: kTestConnectStr not set"; - } + ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt)); } void FlightSQLODBCRemoteTestBase::SetUpTestSuite() { - CheckForRemoteTest(); - if (skipping_test) { + if (!RunningRemoteTests()) { + GTEST_SKIP() << "Skipping Test Suite: Environment Variable " << kTestConnectStr.data() + << " is not set"; return; } - std::string connect_str = GetConnectionString(); - Connect(connect_str, SQL_OV_ODBC3); - connected = true; + env = remote_odbcv3_handles.env; + conn = remote_odbcv3_handles.conn; + stmt = remote_odbcv3_handles.stmt; } void FlightSQLOdbcV2RemoteTestBase::SetUpTestSuite() { - CheckForRemoteTest(); - if (skipping_test) { + if (!RunningRemoteTests()) { + GTEST_SKIP() << "Skipping Test Suite: Environment Variable " << kTestConnectStr.data() + << " is not set"; return; } - std::string connect_str = GetConnectionString(); - Connect(connect_str, SQL_OV_ODBC2); - connected = true; + env = remote_odbcv2_handles.env; + conn = remote_odbcv2_handles.conn; + stmt = remote_odbcv2_handles.stmt; } void FlightSQLOdbcEnvConnHandleRemoteTestBase::SetUpTestSuite() { - CheckForRemoteTest(); - if (skipping_test) { + if (!RunningRemoteTests()) { + GTEST_SKIP() << "Skipping Test Suite: Environment Variable " << kTestConnectStr.data() + << " is not set"; return; } - AllocEnvConnHandles(); + env_h = SQL_NULL_HENV; + conn_h = SQL_NULL_HDBC; + AllocEnvConnHandles(env_h, conn_h); + env = env_h; + conn = conn_h; } void FlightSQLOdbcEnvConnHandleRemoteTestBase::TearDownTestSuite() { - if (skipping_test) { + if (!RunningRemoteTests()) { return; } - FreeEnvConnHandles(); + FreeEnvConnHandles(env_h, conn_h); } std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers) { @@ -400,20 +450,28 @@ void ODBCMockTestBase::DropUnicodeTable() { } void FlightSQLODBCMockTestBase::SetUpTestSuite() { - std::string connect_str = GetConnectionString(); - Connect(connect_str, SQL_OV_ODBC3); - connected = true; + env = mock_odbcv3_handles.env; + conn = mock_odbcv3_handles.conn; + stmt = mock_odbcv3_handles.stmt; } void FlightSQLOdbcV2MockTestBase::SetUpTestSuite() { - std::string connect_str = GetConnectionString(); - Connect(connect_str, SQL_OV_ODBC2); - connected = true; + env = mock_odbcv2_handles.env; + conn = mock_odbcv2_handles.conn; + stmt = mock_odbcv2_handles.stmt; } -void FlightSQLOdbcEnvConnHandleMockTestBase::SetUpTestSuite() { AllocEnvConnHandles(); } +void FlightSQLOdbcEnvConnHandleMockTestBase::SetUpTestSuite() { + env_h = SQL_NULL_HENV; + conn_h = SQL_NULL_HDBC; + AllocEnvConnHandles(env_h, conn_h); + env = env_h; + conn = conn_h; +} -void FlightSQLOdbcEnvConnHandleMockTestBase::TearDownTestSuite() { FreeEnvConnHandles(); } +void FlightSQLOdbcEnvConnHandleMockTestBase::TearDownTestSuite() { + FreeEnvConnHandles(env_h, conn_h); +} bool CompareConnPropertyMap(Connection::ConnPropertyMap map1, Connection::ConnPropertyMap map2) { diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h index a4e8665c973..f2a51ce9611 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h @@ -41,12 +41,26 @@ static constexpr std::string_view kTestConnectStr = "ARROW_FLIGHT_SQL_ODBC_CONN"; static constexpr std::string_view kTestDsn = "Apache Arrow Flight SQL Test DSN"; -inline SQLHENV env = 0; -inline SQLHDBC conn = 0; -inline SQLHSTMT stmt = 0; +inline std::string remote_test_connect_str = ""; -inline bool skipping_test = false; -inline bool connected = false; +struct OdbcHandles { + SQLHENV env = SQL_NULL_HENV; + SQLHDBC conn = SQL_NULL_HDBC; + SQLHSTMT stmt = SQL_NULL_HSTMT; +}; + +inline OdbcHandles remote_odbcv3_handles; +inline OdbcHandles remote_odbcv2_handles; +inline OdbcHandles remote_non_connection_handles; +inline OdbcHandles mock_odbcv3_handles; +inline OdbcHandles mock_odbcv2_handles; +inline OdbcHandles mock_non_connection_handles; + +// These handles are meant to point to the relevant handle above +// depending on the test fixture. +inline SQLHENV env = SQL_NULL_HENV; +inline SQLHDBC conn = SQL_NULL_HDBC; +inline SQLHSTMT stmt = SQL_NULL_HSTMT; inline std::shared_ptr mock_server; inline int mock_server_port = 0; @@ -61,17 +75,19 @@ namespace arrow::flight::sql::odbc { class ODBCTestBase : public ::testing::Test { public: /// \brief Allocate environment and connection handles - static void AllocEnvConnHandles(SQLINTEGER odbc_ver = SQL_OV_ODBC3); + static void AllocEnvConnHandles(SQLHENV& env_handle, SQLHDBC& conn_handle, + SQLINTEGER odbc_ver = SQL_OV_ODBC3); /// \brief Free environment and connection handles - static void FreeEnvConnHandles(); + static void FreeEnvConnHandles(SQLHENV& env_handle, SQLHDBC& conn_handle); /// \brief Connect to Arrow Flight SQL server using connection string defined in /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN", allocate statement handle. /// Connects using ODBC Ver 3 by default - static void Connect(std::string connect_str, SQLINTEGER odbc_ver = SQL_OV_ODBC3); + static void Connect(std::string connect_str, SQLHENV& env_handle, SQLHDBC& conn_handle, + SQLINTEGER odbc_ver = SQL_OV_ODBC3); /// \brief Connect to Arrow Flight SQL server using connection string - static void ConnectWithString(std::string connection_str); + static void ConnectWithString(std::string connect_str, SQLHDBC& conn_handle); /// \brief Disconnect from server - static void Disconnect(); + static void Disconnect(SQLHENV& env_handle, SQLHDBC& conn_handle); /// \brief Get connection string from environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" static std::string GetConnectionString(); /// \brief Get invalid connection string based on connection string defined in @@ -83,7 +99,6 @@ class ODBCTestBase : public ::testing::Test { protected: void SetUp() override; void TearDown() override; - static void TearDownTestSuite(); }; /// \brief Base test fixture for running tests against a remote server. @@ -92,9 +107,6 @@ class ODBCTestBase : public ::testing::Test { /// The connection string for connecting to this server is defined /// in the ARROW_FLIGHT_SQL_ODBC_CONN environment variable. class FlightSQLODBCRemoteTestBase : public ODBCTestBase { - public: - static void CheckForRemoteTest(); - protected: static void SetUpTestSuite(); }; @@ -111,6 +123,12 @@ class FlightSQLOdbcEnvConnHandleRemoteTestBase : public FlightSQLODBCRemoteTestB protected: static void SetUpTestSuite(); static void TearDownTestSuite(); + void SetUp() override {} + void TearDown() override {} + + private: + static SQLHENV env_h; + static SQLHDBC conn_h; }; static constexpr std::string_view kAuthorizationHeader = "authorization"; @@ -200,6 +218,12 @@ class FlightSQLOdbcEnvConnHandleMockTestBase : public FlightSQLODBCMockTestBase protected: static void SetUpTestSuite(); static void TearDownTestSuite(); + void SetUp() override {} + void TearDown() override {} + + private: + static SQLHENV env_h; + static SQLHDBC conn_h; }; /** ODBC read buffer size. */