diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index aed25b4748c..196786b617c 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -310,7 +310,8 @@ jobs: ARROW_DATASET: ON ARROW_FLIGHT: ON ARROW_FLIGHT_SQL: ON - ARROW_FLIGHT_SQL_ODBC: OFF + ARROW_FLIGHT_SQL_ODBC: ON + ARROW_FLIGHT_SQL_ODBC_INSTALLER: ON ARROW_GANDIVA: ON ARROW_GCS: ON ARROW_HDFS: OFF diff --git a/.gitignore b/.gitignore index 8354aa8f816..64c713a74ed 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ dependency-reduced-pom.xml MANIFEST compile_commands.json build.ninja +build*/ # Generated Visual Studio files *.vcxproj @@ -107,3 +108,6 @@ java/.mvn/.develocity/ # rat filtered_rat.txt rat.txt + +# rc +*.rc diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 2f02f8c1496..57594b124e1 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -213,6 +213,7 @@ else -DARROW_FLIGHT=${ARROW_FLIGHT:-OFF} \ -DARROW_FLIGHT_SQL=${ARROW_FLIGHT_SQL:-OFF} \ -DARROW_FLIGHT_SQL_ODBC=${ARROW_FLIGHT_SQL_ODBC:-OFF} \ + -DARROW_FLIGHT_SQL_ODBC_INSTALLER=${ARROW_FLIGHT_SQL_ODBC_INSTALLER:-OFF} \ -DARROW_FUZZING=${ARROW_FUZZING:-OFF} \ -DARROW_GANDIVA_PC_CXX_FLAGS=${ARROW_GANDIVA_PC_CXX_FLAGS:-} \ -DARROW_GANDIVA=${ARROW_GANDIVA:-OFF} \ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7286616c4fb..eea0580605a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -720,9 +720,13 @@ endif() install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/../LICENSE.txt ${CMAKE_CURRENT_SOURCE_DIR}/../NOTICE.txt - ${CMAKE_CURRENT_SOURCE_DIR}/README.md DESTINATION "${ARROW_DOC_DIR}") + ${CMAKE_CURRENT_SOURCE_DIR}/README.md + DESTINATION "${ARROW_DOC_DIR}" + COMPONENT arrow_doc) -install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/gdb_arrow.py DESTINATION "${ARROW_GDB_DIR}") +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/gdb_arrow.py + DESTINATION "${ARROW_GDB_DIR}" + COMPONENT arrow_gdb) # # Validate and print out Arrow configuration options diff --git a/cpp/CMakePresets.json b/cpp/CMakePresets.json index 8c29c9a2672..ea1ce44e062 100644 --- a/cpp/CMakePresets.json +++ b/cpp/CMakePresets.json @@ -180,6 +180,7 @@ "ARROW_BUILD_EXAMPLES": "ON", "ARROW_BUILD_UTILITIES": "ON", "ARROW_FLIGHT_SQL_ODBC": "ON", + "ARROW_FLIGHT_SQL_ODBC_INSTALLER": "ON", "ARROW_TENSORFLOW": "ON", "PARQUET_BUILD_EXAMPLES": "ON", "PARQUET_BUILD_EXECUTABLES": "ON" diff --git a/cpp/cmake_modules/BuildUtils.cmake b/cpp/cmake_modules/BuildUtils.cmake index db760400f7c..305546572c4 100644 --- a/cpp/cmake_modules/BuildUtils.cmake +++ b/cpp/cmake_modules/BuildUtils.cmake @@ -178,10 +178,12 @@ function(arrow_install_cmake_package PACKAGE_NAME EXPORT_NAME) write_basic_package_version_file("${BUILT_CONFIG_VERSION_CMAKE}" COMPATIBILITY SameMajorVersion) install(FILES "${BUILT_CONFIG_CMAKE}" "${BUILT_CONFIG_VERSION_CMAKE}" - DESTINATION "${ARROW_CMAKE_DIR}/${PACKAGE_NAME}") + DESTINATION "${ARROW_CMAKE_DIR}/${PACKAGE_NAME}" + COMPONENT config_cmake_file) set(TARGETS_CMAKE "${PACKAGE_NAME}Targets.cmake") install(EXPORT ${EXPORT_NAME} DESTINATION "${ARROW_CMAKE_DIR}/${PACKAGE_NAME}" + COMPONENT config_cmake_export NAMESPACE "${PACKAGE_NAME}::" FILE "${TARGETS_CMAKE}") endfunction() @@ -403,8 +405,11 @@ function(ADD_ARROW_LIB LIB_NAME) install(TARGETS ${LIB_NAME}_shared ${INSTALL_IS_OPTIONAL} EXPORT ${LIB_NAME}_targets ARCHIVE DESTINATION ${INSTALL_ARCHIVE_DIR} + COMPONENT ${LIB_NAME}_shared_archive LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} + COMPONENT ${LIB_NAME}_shared_library RUNTIME DESTINATION ${INSTALL_RUNTIME_DIR} + COMPONENT ${LIB_NAME}_shared_runtime INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) endif() @@ -471,8 +476,11 @@ function(ADD_ARROW_LIB LIB_NAME) install(TARGETS ${LIB_NAME}_static ${INSTALL_IS_OPTIONAL} EXPORT ${LIB_NAME}_targets ARCHIVE DESTINATION ${INSTALL_ARCHIVE_DIR} + COMPONENT ${LIB_NAME}_static_library LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} + COMPONENT ${LIB_NAME}_static_library RUNTIME DESTINATION ${INSTALL_RUNTIME_DIR} + COMPONENT ${LIB_NAME}_static_library INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) endif() @@ -934,7 +942,9 @@ function(ARROW_INSTALL_ALL_HEADERS PATH) endif() list(APPEND PUBLIC_HEADERS ${HEADER}) endforeach() - install(FILES ${PUBLIC_HEADERS} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${PATH}") + install(FILES ${PUBLIC_HEADERS} + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${PATH}" + COMPONENT ${HEADER}_header) endfunction() function(ARROW_ADD_PKG_CONFIG MODULE) @@ -944,7 +954,8 @@ function(ARROW_ADD_PKG_CONFIG MODULE) OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/$/${MODULE}.pc" INPUT "${CMAKE_CURRENT_BINARY_DIR}/${MODULE}.pc.generate.in") install(FILES "${CMAKE_CURRENT_BINARY_DIR}/$/${MODULE}.pc" - DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig/") + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig/" + COMPONENT ${MODULE}_pkg_config) endfunction() # Implementations of lisp "car" and "cdr" functions diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 835baec87ba..d1a6632ba45 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -236,7 +236,8 @@ function(provide_cmake_module MODULE_NAME ARROW_CMAKE_PACKAGE_NAME) message(STATUS "Providing CMake module for ${MODULE_NAME} as part of ${ARROW_CMAKE_PACKAGE_NAME} CMake package" ) install(FILES "${module}" - DESTINATION "${ARROW_CMAKE_DIR}/${ARROW_CMAKE_PACKAGE_NAME}") + DESTINATION "${ARROW_CMAKE_DIR}/${ARROW_CMAKE_PACKAGE_NAME}" + COMPONENT ${MODULE_NAME}_module) endif() endfunction() @@ -2397,20 +2398,22 @@ function(build_gtest) endforeach() install(DIRECTORY "${googletest_SOURCE_DIR}/googlemock/include/" "${googletest_SOURCE_DIR}/googletest/include/" - DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}") + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" + COMPONENT gtest_dir) add_library(arrow::GTest::gtest_headers INTERFACE IMPORTED) target_include_directories(arrow::GTest::gtest_headers INTERFACE "${googletest_SOURCE_DIR}/googlemock/include/" "${googletest_SOURCE_DIR}/googletest/include/") install(TARGETS gmock gmock_main gtest gtest_main EXPORT arrow_testing_targets - RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}" - ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" - LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}") + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}" COMPONENT gtest_runtime + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" COMPONENT gtest_archive + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" COMPONENT gtest_library) if(MSVC) install(FILES $ $ $ $ DESTINATION "${CMAKE_INSTALL_BINDIR}" + COMPONENT gtest_pdb OPTIONAL) endif() add_library(arrow::GTest::gmock ALIAS gmock) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index ec8b6c1b32f..f1bf71b284f 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -349,7 +349,8 @@ string(REPLACE "${CMAKE_BINARY_DIR}" "" REDACTED_CXX_FLAGS configure_file("util/config.h.cmake" "util/config.h" ESCAPE_QUOTES) configure_file("util/config_internal.h.cmake" "util/config_internal.h" ESCAPE_QUOTES) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/util/config.h" - DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/util") + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/util" + COMPONENT arrow_config) set(ARROW_SRCS builder.cc @@ -1044,7 +1045,8 @@ if(ARROW_BUILD_BUNDLED_DEPENDENCIES) get_target_property(arrow_bundled_dependencies_path arrow_bundled_dependencies IMPORTED_LOCATION) install(FILES ${arrow_bundled_dependencies_path} ${INSTALL_IS_OPTIONAL} - DESTINATION ${CMAKE_INSTALL_LIBDIR}) + DESTINATION ${CMAKE_INSTALL_LIBDIR} + COMPONENT arrow_bundled_dependencies) string(PREPEND ARROW_PC_LIBS_PRIVATE " -larrow_bundled_dependencies") list(INSERT ARROW_STATIC_INSTALL_INTERFACE_LIBS 0 "Arrow::arrow_bundled_dependencies") endif() @@ -1161,6 +1163,7 @@ if(ARROW_BUILD_SHARED AND NOT WIN32) if(ARROW_GDB_AUTO_LOAD_LIBARROW_GDB_INSTALL) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/libarrow_gdb.py" DESTINATION "${ARROW_GDB_AUTO_LOAD_LIBARROW_GDB_DIR}" + COMPONENT arrow_gdb RENAME "$-gdb.py") endif() endif() @@ -1224,11 +1227,13 @@ arrow_install_all_headers("arrow") config_summary_cmake_setters("${CMAKE_CURRENT_BINARY_DIR}/ArrowOptions.cmake") install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ArrowOptions.cmake - DESTINATION "${ARROW_CMAKE_DIR}/Arrow") + DESTINATION "${ARROW_CMAKE_DIR}/Arrow" + COMPONENT arrow_options_cmake) # For backward compatibility for find_package(arrow) install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/arrow-config.cmake - DESTINATION "${ARROW_CMAKE_DIR}/Arrow") + DESTINATION "${ARROW_CMAKE_DIR}/Arrow" + COMPONENT arrow_config_cmake) # # Unit tests diff --git a/cpp/src/arrow/flight/sql/column_metadata.cc b/cpp/src/arrow/flight/sql/column_metadata.cc index 30f557084b2..8d2d2b4ddca 100644 --- a/cpp/src/arrow/flight/sql/column_metadata.cc +++ b/cpp/src/arrow/flight/sql/column_metadata.cc @@ -58,8 +58,15 @@ const char* ColumnMetadata::kIsSearchable = "ARROW:FLIGHT:SQL:IS_SEARCHABLE"; const char* ColumnMetadata::kRemarks = "ARROW:FLIGHT:SQL:REMARKS"; ColumnMetadata::ColumnMetadata( - std::shared_ptr metadata_map) - : metadata_map_(std::move(metadata_map)) {} + std::shared_ptr metadata_map) { + if (metadata_map) { + metadata_map_ = std::move(metadata_map); + } else { + std::shared_ptr empty_metadata_map( + new arrow::KeyValueMetadata); + metadata_map_ = std::move(empty_metadata_map); + } +} arrow::Result ColumnMetadata::GetCatalogName() const { return metadata_map_->Get(kCatalogName); diff --git a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt index ac18a9bc7cd..5a43a1bfdb0 100644 --- a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt @@ -38,10 +38,29 @@ add_subdirectory(tests) arrow_install_all_headers("arrow/flight/sql/odbc") -set(ARROW_FLIGHT_SQL_ODBC_SRCS entry_points.cc odbc_api.cc) +# ODBC Release information +set(ODBC_PACKAGE_VERSION_MAJOR "1") +set(ODBC_PACKAGE_VERSION_MINOR "0") +set(ODBC_PACKAGE_VERSION_PATCH "0") +set(ODBC_PACKAGE_NAME "Apache Arrow Flight SQL ODBC") +set(ODBC_PACKAGE_VENDOR "Apache Arrow") + +# Compile entry_points.cc before odbc_api.cc due to conflict from sql.h and flight/types.h +set(ARROW_FLIGHT_SQL_ODBC_SRCS odbc_api.cc entry_points.cc) if(WIN32) - list(APPEND ARROW_FLIGHT_SQL_ODBC_SRCS odbc.def) + set(VER_FILEVERSION + "${ODBC_PACKAGE_VERSION_MAJOR},${ODBC_PACKAGE_VERSION_MINOR},${ODBC_PACKAGE_VERSION_PATCH},0" + ) + set(VER_FILEVERSION_STR + ${ODBC_PACKAGE_VERSION_MAJOR}.${ODBC_PACKAGE_VERSION_MINOR}.${ODBC_PACKAGE_VERSION_PATCH} + ) + set(VER_COMPANYNAME_STR ${ODBC_PACKAGE_VENDOR}) + set(VER_PRODUCTNAME_STR ${ODBC_PACKAGE_NAME}) + + configure_file("install/versioninfo.rc.in" "install/versioninfo.rc" @ONLY) + + list(APPEND ARROW_FLIGHT_SQL_ODBC_SRCS odbc.def install/versioninfo.rc) endif() add_arrow_lib(arrow_flight_sql_odbc @@ -75,3 +94,76 @@ add_arrow_lib(arrow_flight_sql_odbc foreach(LIB_TARGET ${ARROW_FLIGHT_SQL_ODBC_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_FLIGHT_SQL_ODBC_EXPORTING) endforeach() + +# Construct ODBC Windows installer. Only Release installer is supported +if(ARROW_FLIGHT_SQL_ODBC_INSTALLER) + + include(InstallRequiredSystemLibraries) + + set(CPACK_RESOURCE_FILE_LICENSE + "${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../LICENSE.txt") + # Tentative version 1.0.0 + set(CPACK_PACKAGE_VERSION_MAJOR ${ODBC_PACKAGE_VERSION_MAJOR}) + set(CPACK_PACKAGE_VERSION_MINOR ${ODBC_PACKAGE_VERSION_MINOR}) + set(CPACK_PACKAGE_VERSION_PATCH ${ODBC_PACKAGE_VERSION_PATCH}) + + set(CPACK_PACKAGE_NAME ${ODBC_PACKAGE_NAME}) + set(CPACK_PACKAGE_VENDOR ${ODBC_PACKAGE_VENDOR}) + set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "Apache Arrow Flight SQL ODBC Driver") + set(CPACK_PACKAGE_CONTACT "#GH-47787 TODO arrow maintainers") + + # GH-47876 TODO: set up `flight_sql_odbc_lib` component for macOS Installer + # GH-47877 TODO: set up `flight_sql_odbc_lib` component for Linux Installer + if(WIN32) + install(DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}${CMAKE_BUILD_TYPE}/" + DESTINATION bin + COMPONENT flight_sql_odbc_lib + FILES_MATCHING + # Use regex for dll name patterns with versions + PATTERN "abseil_dll.dll" + PATTERN "arrow.dll" + PATTERN "arrow_compute.dll" + PATTERN "arrow_flight.dll" + PATTERN "arrow_flight_sql.dll" + PATTERN "arrow_flight_sql_odbc.dll" + PATTERN "boost_locale*.dll" + PATTERN "cares.dll" + PATTERN "libcrypto*.dll" + PATTERN "libprotobuf.dll" + PATTERN "libssl*.dll" + PATTERN "re2.dll" + PATTERN "utf8proc.dll" + PATTERN "zlib1.dll") + + set(CPACK_WIX_EXTRA_SOURCES + "${CMAKE_CURRENT_SOURCE_DIR}/install/arrow-flight-sql-odbc.wxs") + set(CPACK_WIX_PATCH_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/install/arrow-flight-sql-odbc-patch.xml") + + set(CPACK_WIX_UI_BANNER "${CMAKE_CURRENT_SOURCE_DIR}/install/arrow-wix-banner.bmp") + endif() + + get_cmake_property(CPACK_COMPONENTS_ALL COMPONENTS) + set(CPACK_COMPONENTS_ALL Unspecified) + list(APPEND CPACK_COMPONENTS_ALL "flight_sql_odbc_lib") + + if(WIN32) + # WiX msi installer on Windows + # CPack is compatible with WiX V.5 and V.6 + set(CPACK_GENERATOR "WIX") + set(CPACK_WIX_VERSION 4) + + # Upgrade GUID is required to be unchanged for ODBC installer to upgrade + set(CPACK_WIX_UPGRADE_GUID "DBF27A18-F8BF-423F-9E3A-957414D52C4B") + set(CPACK_WIX_PRODUCT_GUID "279D087B-93B5-4DC3-BA69-BCF485022A26") + endif() + # GH-47876 TODO: create macOS Installer using cpack + # GH-47877 TODO: create Linux Installer using cpack + + # Load CPack after all CPACK* variables are set + include(CPack) + cpack_add_component(flight_sql_odbc_lib + DISPLAY_NAME "ODBC library" + DESCRIPTION "ODBC library bin, required to install" + REQUIRED) +endif() diff --git a/cpp/src/arrow/flight/sql/odbc/README b/cpp/src/arrow/flight/sql/odbc/README new file mode 100644 index 00000000000..7cd53778337 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/README @@ -0,0 +1,62 @@ + + +## Steps to Register the 64-bit Apache Arrow ODBC driver on Windows + +After the build succeeds, the ODBC DLL will be located in +`build\debug\Debug` for a debug build and `build\release\Release` for a release build. + +1. Open Power Shell as administrator. + +2. Register your ODBC DLL: + Need to replace with actual path to repository in the commands. + + i. `cd to repo.` + ii. `cd ` + iii. Run script to register your ODBC DLL as Apache Arrow Flight SQL ODBC Driver + `.\cpp\src\arrow\flight\sql\odbc\install\install_amd64.cmd \cpp\build\< release | debug >\< Release | Debug>\arrow_flight_sql_odbc.dll` + Example command for reference: + `.\cpp\src\arrow\flight\sql\odbc\install\install_amd64.cmd C:\path\to\arrow\cpp\build\release\Release\arrow_flight_sql_odbc.dll` + +If the registration is successful, then Apache Arrow Flight SQL ODBC Driver +should show as an available ODBC driver in the x64 ODBC Driver Manager. + +## Steps to Generate Windows Installer +1. Build with `ARROW_FLIGHT_SQL_ODBC=ON` and `ARROW_FLIGHT_SQL_ODBC_INSTALLER=ON`. +2. `cd` to `build` folder. +3. Run `cpack`. + +If the generation is successful, you will find `Apache Arrow Flight SQL ODBC--win64.msi` generated under the `build` folder. + + +## Steps to Enable Logging +Arrow Flight SQL ODBC driver uses Arrow's internal logging framework. By default, the log messages are printed to the terminal. +1. Set environment variable `ARROW_ODBC_LOG_LEVEL` to any of the following valid values to enable logging. If `ARROW_ODBC_LOG_LEVEL` is set to a non-empty string that does not match any of the following values, `DEBUG` level is used by default. + +The characters are case-insensitive. +- TRACE +- DEBUG +- INFO +- WARNING +- ERROR +- FATAL + +The Windows ODBC driver currently does not support writing log files. `ARROW_USE_GLOG` is required to write log files, and `ARROW_USE_GLOG` is disabled on Windows platform since plasma using `glog` is not fully tested on windows. + +Note: GH-47670 running more than 1 tests with logging enabled is not fully supported. diff --git a/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc-patch.xml b/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc-patch.xml new file mode 100644 index 00000000000..f1a63ce5d3b --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc-patch.xml @@ -0,0 +1,22 @@ + + + + + + + diff --git a/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc.wxs b/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc.wxs new file mode 100644 index 00000000000..bd0216aa766 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc.wxs @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/cpp/src/arrow/flight/sql/odbc/install/arrow-wix-banner.bmp b/cpp/src/arrow/flight/sql/odbc/install/arrow-wix-banner.bmp new file mode 100644 index 00000000000..0c82036f4ec Binary files /dev/null and b/cpp/src/arrow/flight/sql/odbc/install/arrow-wix-banner.bmp differ diff --git a/cpp/src/arrow/flight/sql/odbc/install/versioninfo.rc.in b/cpp/src/arrow/flight/sql/odbc/install/versioninfo.rc.in new file mode 100644 index 00000000000..13024a7a50b --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/install/versioninfo.rc.in @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#define VER_FILEVERSION @VER_FILEVERSION@ +#define VER_FILEVERSION_STR "@VER_FILEVERSION_STR@\0" + +#define VER_PRODUCTVERSION @VER_FILEVERSION@ +#define VER_PRODUCTVERSION_STR "@VER_FILEVERSION_STR@\0" + +#define VER_COMPANYNAME_STR "@VER_COMPANYNAME_STR@\0" +#define VER_PRODUCTNAME_STR "@VER_PRODUCTNAME_STR@\0" + +1 VERSIONINFO +FILEVERSION VER_FILEVERSION +PRODUCTVERSION VER_PRODUCTVERSION +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "040904E4" + BEGIN + VALUE "CompanyName", VER_COMPANYNAME_STR + VALUE "FileVersion", VER_FILEVERSION_STR + VALUE "ProductName", VER_PRODUCTNAME_STR + VALUE "ProductVersion", VER_PRODUCTVERSION_STR + END + END + + BLOCK "VarFileInfo" + BEGIN + /* The following line should only be modified for localized versions. */ + /* It consists of any number of WORD,WORD pairs, with each pair */ + /* describing a language,codepage combination supported by the file. */ + /* */ + /* For example, a file might have values "0x409,1252" indicating that it */ + /* supports English language (0x409) in the Windows ANSI codepage (1252). */ + + VALUE "Translation", 0x409, 1252 + + END +END diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index a028e063b34..328d2485237 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -41,14 +41,6 @@ SQLRETURN SQLAllocHandle(SQLSMALLINT type, SQLHANDLE parent, SQLHANDLE* result) ARROW_LOG(DEBUG) << "SQLAllocHandle called with type: " << type << ", parent: " << parent << ", result: " << static_cast(result); - // GH-47706 TODO: Add tests for SQLAllocStmt, pre-requisite requires - // SQLDriverConnect implementation - - // GH-47707 TODO: Add tests for SQL_HANDLE_DESC implementation for - // descriptor handle, pre-requisite requires SQLAllocStmt - - *result = nullptr; - switch (type) { case SQL_HANDLE_ENV: { using ODBC::ODBCEnvironment; @@ -144,11 +136,6 @@ SQLRETURN SQLAllocHandle(SQLSMALLINT type, SQLHANDLE parent, SQLHANDLE* result) SQLRETURN SQLFreeHandle(SQLSMALLINT type, SQLHANDLE handle) { ARROW_LOG(DEBUG) << "SQLFreeHandle called with type: " << type << ", handle: " << handle; - // GH-47706 TODO: Add tests for SQLFreeStmt, pre-requisite requires - // SQLAllocStmt tests - - // GH-47707 TODO: Add tests for SQL_HANDLE_DESC implementation for - // descriptor handle switch (type) { case SQL_HANDLE_ENV: { using ODBC::ODBCEnvironment; @@ -219,8 +206,44 @@ SQLRETURN SQLFreeHandle(SQLSMALLINT type, SQLHANDLE handle) { SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) { ARROW_LOG(DEBUG) << "SQLFreeStmt called with handle: " << handle << ", option: " << option; - // GH-47706 TODO: Implement SQLFreeStmt - return SQL_INVALID_HANDLE; + + switch (option) { + case SQL_CLOSE: { + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(handle, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(handle); + + // Close cursor with suppressErrors set to true + statement->CloseCursor(true); + + return SQL_SUCCESS; + }); + } + + case SQL_DROP: { + return SQLFreeHandle(SQL_HANDLE_STMT, handle); + } + + case SQL_UNBIND: { + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(handle, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(handle); + ODBCDescriptor* ard = statement->GetARD(); + // Unbind columns + ard->SetHeaderField(SQL_DESC_COUNT, (void*)0, 0); + return SQL_SUCCESS; + }); + } + + // SQLBindParameter is not supported + case SQL_RESET_PARAMS: { + return SQL_SUCCESS; + } + } + + return SQL_ERROR; } inline bool IsValidStringFieldArgs(SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length, @@ -248,7 +271,6 @@ SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle, << ", diag_info_ptr: " << diag_info_ptr << ", buffer_length: " << buffer_length << ", string_length_ptr: " << static_cast(string_length_ptr); - // GH-46575 TODO: Add tests for SQLGetDiagField using ODBC::GetStringAttribute; using ODBC::ODBCConnection; using ODBC::ODBCDescriptor; @@ -513,8 +535,6 @@ SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT r << ", message_text: " << static_cast(message_text) << ", buffer_length: " << buffer_length << ", text_length_ptr: " << static_cast(text_length_ptr); - // GH-46575 TODO: Add tests for SQLGetDiagRec - using arrow::flight::sql::odbc::Diagnostics; using ODBC::GetStringAttribute; using ODBC::ODBCConnection; using ODBC::ODBCDescriptor; @@ -714,8 +734,15 @@ SQLRETURN SQLGetConnectAttr(SQLHDBC conn, SQLINTEGER attribute, SQLPOINTER value << ", attribute: " << attribute << ", value_ptr: " << value_ptr << ", buffer_length: " << buffer_length << ", string_length_ptr: " << static_cast(string_length_ptr); - // GH-47708 TODO: Implement SQLGetConnectAttr - return SQL_INVALID_HANDLE; + + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + const bool is_unicode = true; + ODBCConnection* connection = reinterpret_cast(conn); + return connection->GetConnectAttr(attribute, value_ptr, buffer_length, + string_length_ptr, is_unicode); + }); } SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value_ptr, @@ -738,7 +765,7 @@ SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value_ptr, // entries in the properties. void LoadPropertiesFromDSN(const std::string& dsn, Connection::ConnPropertyMap& properties) { - arrow::flight::sql::odbc::config::Configuration config; + config::Configuration config; config.LoadDsn(dsn); Connection::ConnPropertyMap dsn_properties = config.GetProperties(); for (auto& [key, value] : dsn_properties) { @@ -796,7 +823,7 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle, // Load the DSN window according to driver_completion if (driver_completion == SQL_DRIVER_PROMPT) { // Load DSN window before first attempt to connect - arrow::flight::sql::odbc::config::Configuration config; + config::Configuration config; if (!DisplayConnectionWindow(window_handle, config, properties)) { return static_cast(SQL_NO_DATA); } @@ -809,7 +836,7 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle, // If first connection fails due to missing attributes, load // the DSN window and try to connect again if (!missing_properties.empty()) { - arrow::flight::sql::odbc::config::Configuration config; + config::Configuration config; missing_properties.clear(); if (!DisplayConnectionWindow(window_handle, config, properties)) { @@ -892,7 +919,7 @@ SQLRETURN SQLDisconnect(SQLHDBC conn) { SQLRETURN SQLGetInfo(SQLHDBC conn, SQLUSMALLINT info_type, SQLPOINTER info_value_ptr, SQLSMALLINT buf_len, SQLSMALLINT* string_length_ptr) { - ARROW_LOG(DEBUG) << "SQLGetInfoW called with conn: " << conn + ARROW_LOG(DEBUG) << "SQLGetInfo called with conn: " << conn << ", info_type: " << info_type << ", info_value_ptr: " << info_value_ptr << ", buf_len: " << buf_len << ", string_length_ptr: " @@ -923,8 +950,19 @@ SQLRETURN SQLGetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER value_p << ", attribute: " << attribute << ", value_ptr: " << value_ptr << ", buffer_length: " << buffer_length << ", string_length_ptr: " << static_cast(string_length_ptr); - // GH-47710 TODO: Implement SQLGetStmtAttr - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + bool is_unicode = true; + + statement->GetStmtAttr(attribute, value_ptr, buffer_length, string_length_ptr, + is_unicode); + + return SQL_SUCCESS; + }); } SQLRETURN SQLSetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER value_ptr, @@ -932,42 +970,95 @@ SQLRETURN SQLSetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER value_p ARROW_LOG(DEBUG) << "SQLSetStmtAttrW called with stmt: " << stmt << ", attribute: " << attribute << ", value_ptr: " << value_ptr << ", string_length: " << string_length; - // GH-47710 TODO: Implement SQLSetStmtAttr - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + bool is_unicode = true; + + statement->SetStmtAttr(attribute, value_ptr, string_length, is_unicode); + + return SQL_SUCCESS; + }); } SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_length) { ARROW_LOG(DEBUG) << "SQLExecDirectW called with stmt: " << stmt << ", query_text: " << static_cast(query_text) << ", text_length: " << text_length; - // GH-47711 TODO: Implement SQLExecDirect - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + std::string query = ODBC::SqlWcharToString(query_text, text_length); + + statement->Prepare(query); + statement->ExecutePrepared(); + + return SQL_SUCCESS; + }); } SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_length) { ARROW_LOG(DEBUG) << "SQLPrepareW called with stmt: " << stmt << ", query_text: " << static_cast(query_text) << ", text_length: " << text_length; - // GH-47712 TODO: Implement SQLPrepare - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + std::string query = ODBC::SqlWcharToString(query_text, text_length); + + statement->Prepare(query); + + return SQL_SUCCESS; + }); } SQLRETURN SQLExecute(SQLHSTMT stmt) { ARROW_LOG(DEBUG) << "SQLExecute called with stmt: " << stmt; - // GH-47712 TODO: Implement SQLExecute - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + statement->ExecutePrepared(); + + return SQL_SUCCESS; + }); } SQLRETURN SQLFetch(SQLHSTMT stmt) { ARROW_LOG(DEBUG) << "SQLFetch called with stmt: " << stmt; - // GH-47713 TODO: Implement SQLFetch - return SQL_INVALID_HANDLE; + + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + // The SQL_ATTR_ROW_ARRAY_SIZE statement attribute specifies the number of rows in the + // rowset. + ODBCDescriptor* ard = statement->GetARD(); + size_t rows = static_cast(ard->GetArraySize()); + + if (statement->Fetch(rows)) { + return SQL_SUCCESS; + } else { + // Reached the end of rowset + return SQL_NO_DATA; + } + }); } SQLRETURN SQLExtendedFetch(SQLHSTMT stmt, SQLUSMALLINT fetch_orientation, SQLLEN fetch_offset, SQLULEN* row_count_ptr, SQLUSMALLINT* row_status_array) { - // GH-47110 TODO: SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for certain diag + // GH-47110: SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for certain diag // states ARROW_LOG(DEBUG) << "SQLExtendedFetch called with stmt: " << stmt << ", fetch_orientation: " << fetch_orientation @@ -975,8 +1066,30 @@ SQLRETURN SQLExtendedFetch(SQLHSTMT stmt, SQLUSMALLINT fetch_orientation, << ", row_count_ptr: " << static_cast(row_count_ptr) << ", row_status_array: " << static_cast(row_status_array); - // GH-47714 TODO: Implement SQLExtendedFetch - return SQL_INVALID_HANDLE; + + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + if (fetch_orientation != SQL_FETCH_NEXT) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + // fetch_offset is ignored as only SQL_FETCH_NEXT is supported + + ODBCStatement* statement = reinterpret_cast(stmt); + + // The SQL_ROWSET_SIZE statement attribute specifies the number of rows in the + // rowset. + SQLULEN row_set_size = statement->GetRowsetSize(); + ARROW_LOG(DEBUG) << "SQL_ROWSET_SIZE value for SQLExtendedFetch: " << row_set_size; + + if (statement->Fetch(static_cast(row_set_size), row_count_ptr, + row_status_array)) { + return SQL_SUCCESS; + } else { + // Reached the end of rowset + return SQL_NO_DATA; + } + }); } SQLRETURN SQLFetchScroll(SQLHSTMT stmt, SQLSMALLINT fetch_orientation, @@ -984,8 +1097,28 @@ SQLRETURN SQLFetchScroll(SQLHSTMT stmt, SQLSMALLINT fetch_orientation, ARROW_LOG(DEBUG) << "SQLFetchScroll called with stmt: " << stmt << ", fetch_orientation: " << fetch_orientation << ", fetch_offset: " << fetch_offset; - // GH-47715 TODO: Implement SQLFetchScroll - return SQL_INVALID_HANDLE; + + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + if (fetch_orientation != SQL_FETCH_NEXT) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + // fetch_offset is ignored as only SQL_FETCH_NEXT is supported + + ODBCStatement* statement = reinterpret_cast(stmt); + + // The SQL_ATTR_ROW_ARRAY_SIZE statement attribute specifies the number of rows in the + // rowset. + ODBCDescriptor* ard = statement->GetARD(); + size_t rows = static_cast(ard->GetArraySize()); + if (statement->Fetch(rows)) { + return SQL_SUCCESS; + } else { + // Reached the end of rowset + return SQL_NO_DATA; + } + }); } SQLRETURN SQLBindCol(SQLHSTMT stmt, SQLUSMALLINT record_number, SQLSMALLINT c_type, @@ -994,48 +1127,85 @@ SQLRETURN SQLBindCol(SQLHSTMT stmt, SQLUSMALLINT record_number, SQLSMALLINT c_ty << ", record_number: " << record_number << ", c_type: " << c_type << ", data_ptr: " << data_ptr << ", buffer_length: " << buffer_length << ", indicator_ptr: " << static_cast(indicator_ptr); - // GH-47716 TODO: Implement SQLBindCol - return SQL_INVALID_HANDLE; + + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + // GH-47021: implement driver to return indicator value when data pointer is null + ODBCStatement* statement = reinterpret_cast(stmt); + ODBCDescriptor* ard = statement->GetARD(); + ard->BindCol(record_number, c_type, data_ptr, buffer_length, indicator_ptr); + return SQL_SUCCESS; + }); } SQLRETURN SQLCloseCursor(SQLHSTMT stmt) { ARROW_LOG(DEBUG) << "SQLCloseCursor called with stmt: " << stmt; - // GH-47717 TODO: Implement SQLCloseCursor - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + // Close cursor with suppressErrors set to false + statement->CloseCursor(false); + + return SQL_SUCCESS; + }); } SQLRETURN SQLGetData(SQLHSTMT stmt, SQLUSMALLINT record_number, SQLSMALLINT c_type, SQLPOINTER data_ptr, SQLLEN buffer_length, SQLLEN* indicator_ptr) { - // GH-46979 TODO: support SQL_C_GUID data type - // GH-46980 TODO: support Interval data types - // GH-46985 TODO: return warning message instead of error on float truncation case + // GH-46979: support SQL_C_GUID data type + // GH-46980: support Interval data types + // GH-46985: return warning message instead of error on float truncation case ARROW_LOG(DEBUG) << "SQLGetData called with stmt: " << stmt << ", record_number: " << record_number << ", c_type: " << c_type << ", data_ptr: " << data_ptr << ", buffer_length: " << buffer_length << ", indicator_ptr: " << static_cast(indicator_ptr); - // GH-47713 TODO: Implement SQLGetData - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + return statement->GetData(record_number, c_type, data_ptr, buffer_length, + indicator_ptr); + }); } SQLRETURN SQLMoreResults(SQLHSTMT stmt) { ARROW_LOG(DEBUG) << "SQLMoreResults called with stmt: " << stmt; - // GH-47713 TODO: Implement SQLMoreResults - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + // Multiple result sets not supported. Return SQL_NO_DATA by default. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + return statement->GetMoreResults(); + }); } SQLRETURN SQLNumResultCols(SQLHSTMT stmt, SQLSMALLINT* column_count_ptr) { ARROW_LOG(DEBUG) << "SQLNumResultCols called with stmt: " << stmt << ", column_count_ptr: " << static_cast(column_count_ptr); - // GH-47713 TODO: Implement SQLNumResultCols - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + statement->GetColumnCount(column_count_ptr); + return SQL_SUCCESS; + }); } SQLRETURN SQLRowCount(SQLHSTMT stmt, SQLLEN* row_count_ptr) { ARROW_LOG(DEBUG) << "SQLRowCount called with stmt: " << stmt << ", column_count_ptr: " << static_cast(row_count_ptr); - // GH-47713 TODO: Implement SQLRowCount - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + statement->GetRowCount(row_count_ptr); + return SQL_SUCCESS; + }); } SQLRETURN SQLTables(SQLHSTMT stmt, SQLWCHAR* catalog_name, @@ -1043,7 +1213,7 @@ SQLRETURN SQLTables(SQLHSTMT stmt, SQLWCHAR* catalog_name, SQLSMALLINT schema_name_length, SQLWCHAR* table_name, SQLSMALLINT table_name_length, SQLWCHAR* table_type, SQLSMALLINT table_type_length) { - ARROW_LOG(DEBUG) << "SQLTablesW called with stmt: " << stmt + ARROW_LOG(DEBUG) << "SQLTables called with stmt: " << stmt << ", catalog_name: " << static_cast(catalog_name) << ", catalog_name_length: " << catalog_name_length << ", schema_name: " << static_cast(schema_name) @@ -1052,8 +1222,24 @@ SQLRETURN SQLTables(SQLHSTMT stmt, SQLWCHAR* catalog_name, << ", table_name_length: " << table_name_length << ", table_type: " << static_cast(table_type) << ", table_type_length: " << table_type_length; - // GH-47719 TODO: Implement SQLTables - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + using ODBC::SqlWcharToString; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + std::string catalog = SqlWcharToString(catalog_name, catalog_name_length); + std::string schema = SqlWcharToString(schema_name, schema_name_length); + std::string table = SqlWcharToString(table_name, table_name_length); + std::string type = SqlWcharToString(table_type, table_type_length); + + statement->GetTables(catalog_name ? &catalog : nullptr, + schema_name ? &schema : nullptr, table_name ? &table : nullptr, + table_type ? &type : nullptr); + + return SQL_SUCCESS; + }); } SQLRETURN SQLColumns(SQLHSTMT stmt, SQLWCHAR* catalog_name, @@ -1061,7 +1247,7 @@ SQLRETURN SQLColumns(SQLHSTMT stmt, SQLWCHAR* catalog_name, SQLSMALLINT schema_name_length, SQLWCHAR* table_name, SQLSMALLINT table_name_length, SQLWCHAR* column_name, SQLSMALLINT column_name_length) { - // GH-47159 TODO: Return NUM_PREC_RADIX based on whether COLUMN_SIZE contains number of + // GH-47159: Return NUM_PREC_RADIX based on whether COLUMN_SIZE contains number of // digits or bits ARROW_LOG(DEBUG) << "SQLColumnsW called with stmt: " << stmt << ", catalog_name: " << static_cast(catalog_name) @@ -1072,8 +1258,24 @@ SQLRETURN SQLColumns(SQLHSTMT stmt, SQLWCHAR* catalog_name, << ", table_name_length: " << table_name_length << ", column_name: " << static_cast(column_name) << ", column_name_length: " << column_name_length; - // GH-47720 TODO: Implement SQLColumns - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + using ODBC::SqlWcharToString; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + std::string catalog = SqlWcharToString(catalog_name, catalog_name_length); + std::string schema = SqlWcharToString(schema_name, schema_name_length); + std::string table = SqlWcharToString(table_name, table_name_length); + std::string column = SqlWcharToString(column_name, column_name_length); + + statement->GetColumns(catalog_name ? &catalog : nullptr, + schema_name ? &schema : nullptr, table_name ? &table : nullptr, + column_name ? &column : nullptr); + + return SQL_SUCCESS; + }); } SQLRETURN SQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT record_number, @@ -1088,17 +1290,151 @@ SQLRETURN SQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT record_number, << ", output_length: " << static_cast(output_length) << ", numeric_attribute_ptr: " << static_cast(numeric_attribute_ptr); - // GH-47721 TODO: Implement SQLColAttribute, pre-requisite requires SQLColumns - return SQL_INVALID_HANDLE; + + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + ODBCDescriptor* ird = statement->GetIRD(); + SQLINTEGER output_length_int; + switch (field_identifier) { + // Numeric attributes + // internal is SQLLEN, no conversion is needed + case SQL_DESC_DISPLAY_SIZE: + case SQL_DESC_OCTET_LENGTH: { + ird->GetField(record_number, field_identifier, numeric_attribute_ptr, + buffer_length, &output_length_int); + break; + } + // internal is SQLULEN, conversion is needed. + case SQL_COLUMN_LENGTH: // ODBC 2.0 + case SQL_DESC_LENGTH: { + SQLULEN temp; + ird->GetField(record_number, field_identifier, &temp, buffer_length, + &output_length_int); + if (numeric_attribute_ptr) { + *numeric_attribute_ptr = static_cast(temp); + } + break; + } + // internal is SQLINTEGER, conversion is needed. + case SQL_DESC_AUTO_UNIQUE_VALUE: + case SQL_DESC_CASE_SENSITIVE: + case SQL_DESC_NUM_PREC_RADIX: { + SQLINTEGER temp; + ird->GetField(record_number, field_identifier, &temp, buffer_length, + &output_length_int); + if (numeric_attribute_ptr) { + *numeric_attribute_ptr = static_cast(temp); + } + break; + } + // internal is SQLSMALLINT, conversion is needed. + case SQL_DESC_CONCISE_TYPE: + case SQL_DESC_COUNT: + case SQL_DESC_FIXED_PREC_SCALE: + case SQL_DESC_TYPE: + case SQL_DESC_NULLABLE: + case SQL_COLUMN_PRECISION: // ODBC 2.0 + case SQL_DESC_PRECISION: + case SQL_COLUMN_SCALE: // ODBC 2.0 + case SQL_DESC_SCALE: + case SQL_DESC_SEARCHABLE: + case SQL_DESC_UNNAMED: + case SQL_DESC_UNSIGNED: + case SQL_DESC_UPDATABLE: { + SQLSMALLINT temp; + ird->GetField(record_number, field_identifier, &temp, buffer_length, + &output_length_int); + if (numeric_attribute_ptr) { + *numeric_attribute_ptr = static_cast(temp); + } + break; + } + // Character attributes + case SQL_DESC_BASE_COLUMN_NAME: + case SQL_DESC_BASE_TABLE_NAME: + case SQL_DESC_CATALOG_NAME: + case SQL_DESC_LABEL: + case SQL_DESC_LITERAL_PREFIX: + case SQL_DESC_LITERAL_SUFFIX: + case SQL_DESC_LOCAL_TYPE_NAME: + case SQL_DESC_NAME: + case SQL_DESC_SCHEMA_NAME: + case SQL_DESC_TABLE_NAME: + case SQL_DESC_TYPE_NAME: + ird->GetField(record_number, field_identifier, character_attribute_ptr, + buffer_length, &output_length_int); + break; + default: + throw DriverException("Invalid descriptor field", "HY091"); + } + if (output_length) { + *output_length = static_cast(output_length_int); + } + return SQL_SUCCESS; + }); } SQLRETURN SQLGetTypeInfo(SQLHSTMT stmt, SQLSMALLINT data_type) { - // GH-47237 TODO: return SQL_PRED_CHAR and SQL_PRED_BASIC for + // GH-47237 return SQL_PRED_CHAR and SQL_PRED_BASIC for // appropriate data types in `SEARCHABLE` field ARROW_LOG(DEBUG) << "SQLGetTypeInfoW called with stmt: " << stmt << " data_type: " << data_type; - // GH-47722 TODO: Implement SQLGetTypeInfo - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + return ODBC::ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + switch (data_type) { + case SQL_ALL_TYPES: + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_BIT: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + case SQL_TINYINT: + case SQL_SMALLINT: + case SQL_INTEGER: + case SQL_BIGINT: + case SQL_NUMERIC: + case SQL_DECIMAL: + case SQL_FLOAT: + case SQL_REAL: + case SQL_DOUBLE: + case SQL_GUID: + case SQL_DATE: + case SQL_TYPE_DATE: + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_INTERVAL_DAY: + case SQL_INTERVAL_DAY_TO_HOUR: + case SQL_INTERVAL_DAY_TO_MINUTE: + case SQL_INTERVAL_DAY_TO_SECOND: + case SQL_INTERVAL_HOUR: + case SQL_INTERVAL_HOUR_TO_MINUTE: + case SQL_INTERVAL_HOUR_TO_SECOND: + case SQL_INTERVAL_MINUTE: + case SQL_INTERVAL_MINUTE_TO_SECOND: + case SQL_INTERVAL_SECOND: + case SQL_INTERVAL_YEAR: + case SQL_INTERVAL_YEAR_TO_MONTH: + case SQL_INTERVAL_MONTH: + statement->GetTypeInfo(data_type); + break; + default: + throw DriverException("Invalid SQL data type", "HY004"); + } + + return SQL_SUCCESS; + }); } SQLRETURN SQLNativeSql(SQLHDBC conn, SQLWCHAR* in_statement_text, @@ -1113,8 +1449,23 @@ SQLRETURN SQLNativeSql(SQLHDBC conn, SQLWCHAR* in_statement_text, << ", buffer_length: " << buffer_length << ", out_statement_text_length: " << static_cast(out_statement_text_length); - // GH-47723 TODO: Implement SQLNativeSql - return SQL_INVALID_HANDLE; + + using ODBC::GetAttributeSQLWCHAR; + using ODBC::ODBCConnection; + using ODBC::SqlWcharToString; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + const bool is_length_in_bytes = false; + + ODBCConnection* connection = reinterpret_cast(conn); + Diagnostics& diagnostics = connection->GetDiagnostics(); + + std::string in_statement_str = + SqlWcharToString(in_statement_text, in_statement_text_length); + + return GetAttributeSQLWCHAR(in_statement_str, is_length_in_bytes, out_statement_text, + buffer_length, out_statement_text_length, diagnostics); + }); } SQLRETURN SQLDescribeCol(SQLHSTMT stmt, SQLUSMALLINT column_number, SQLWCHAR* column_name, @@ -1131,8 +1482,110 @@ SQLRETURN SQLDescribeCol(SQLHSTMT stmt, SQLUSMALLINT column_number, SQLWCHAR* co << ", decimal_digits_ptr: " << static_cast(decimal_digits_ptr) << ", nullable_ptr: " << static_cast(nullable_ptr); - // GH-47724 TODO: Implement SQLDescribeCol - return SQL_INVALID_HANDLE; + + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + ODBCDescriptor* ird = statement->GetIRD(); + SQLINTEGER output_length_int; + SQLSMALLINT sql_type; + + // Column SQL Type + ird->GetField(column_number, SQL_DESC_CONCISE_TYPE, &sql_type, sizeof(SQLSMALLINT), + nullptr); + if (data_type_ptr) { + *data_type_ptr = sql_type; + } + + // Column Name + if (column_name || name_length_ptr) { + ird->GetField(column_number, SQL_DESC_NAME, column_name, buffer_length, + &output_length_int); + if (name_length_ptr) { + // returned length should be in characters + *name_length_ptr = + static_cast(output_length_int / GetSqlWCharSize()); + } + } + + // Column Size + if (column_size_ptr) { + switch (sql_type) { + // All numeric types + case SQL_DECIMAL: + case SQL_NUMERIC: + case SQL_TINYINT: + case SQL_SMALLINT: + case SQL_INTEGER: + case SQL_BIGINT: + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: { + ird->GetField(column_number, SQL_DESC_PRECISION, column_size_ptr, + sizeof(SQLULEN), nullptr); + break; + } + + default: { + ird->GetField(column_number, SQL_DESC_LENGTH, column_size_ptr, sizeof(SQLULEN), + nullptr); + } + } + } + + // Column Decimal Digits + if (decimal_digits_ptr) { + switch (sql_type) { + // All exact numeric types + case SQL_TINYINT: + case SQL_SMALLINT: + case SQL_INTEGER: + case SQL_BIGINT: + case SQL_DECIMAL: + case SQL_NUMERIC: { + ird->GetField(column_number, SQL_DESC_SCALE, decimal_digits_ptr, + sizeof(SQLULEN), nullptr); + break; + } + + // All datetime types (ODBC2) + case SQL_DATE: + case SQL_TIME: + case SQL_TIMESTAMP: + // All datetime types (ODBC3) + case SQL_TYPE_DATE: + case SQL_TYPE_TIME: + case SQL_TYPE_TIMESTAMP: + // All interval types with a seconds component + case SQL_INTERVAL_SECOND: + case SQL_INTERVAL_MINUTE_TO_SECOND: + case SQL_INTERVAL_HOUR_TO_SECOND: + case SQL_INTERVAL_DAY_TO_SECOND: { + ird->GetField(column_number, SQL_DESC_PRECISION, decimal_digits_ptr, + sizeof(SQLULEN), nullptr); + break; + } + + default: { + // All character and binary types + // SQL_BIT + // All approximate numeric types + // All interval types with no seconds component + *decimal_digits_ptr = static_cast(0); + } + } + } + + // Column Nullable + if (nullable_ptr) { + ird->GetField(column_number, SQL_DESC_NULLABLE, nullable_ptr, sizeof(SQLSMALLINT), + nullptr); + } + + return SQL_SUCCESS; + }); } } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt index 8f09fccd71d..ff716660a31 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt @@ -45,10 +45,10 @@ add_library(arrow_odbc_spi_impl config/connection_string_parser.h diagnostics.cc diagnostics.h - error_codes.h encoding.cc encoding.h encoding_utils.h + error_codes.h exceptions.cc exceptions.h flight_sql_auth_method.cc @@ -164,9 +164,11 @@ add_arrow_test(odbc_spi_impl_test accessors/time_array_accessor_test.cc accessors/timestamp_array_accessor_test.cc flight_sql_connection_test.cc + flight_sql_stream_chunk_buffer_test.cc parse_table_types_test.cc json_converter_test.cc record_batch_transformer_test.cc util_test.cc EXTRA_LINK_LIBS - arrow_odbc_spi_impl) + arrow_odbc_spi_impl + arrow_flight_testing_shared) diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor.cc index e4625ace370..939264bedb7 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor.cc @@ -19,6 +19,7 @@ #include #include +#include #include "arrow/array.h" namespace arrow::flight::sql::odbc { @@ -39,7 +40,7 @@ inline RowStatus MoveSingleCellToBinaryBuffer(ColumnBinding* binding, BinaryArra auto* byte_buffer = static_cast(binding->buffer) + i * binding->buffer_length; - memcpy(byte_buffer, ((char*)value) + value_offset, value_length); + std::memcpy(byte_buffer, ((char*)value) + value_offset, value_length); if (remaining_length > binding->buffer_length) { result = RowStatus_SUCCESS_WITH_INFO; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor_test.cc index 39d03692da6..423870eb3be 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor_test.cc @@ -18,11 +18,12 @@ #include "arrow/flight/sql/odbc/odbc_impl/accessors/binary_array_accessor.h" #include "arrow/testing/builder.h" #include "arrow/testing/gtest_util.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { -TEST(BinaryArrayAccessor, Test_CDataType_BINARY_Basic) { +TEST(BinaryArrayAccessor, TestCDataTypeBinaryBasic) { std::vector values = {"foo", "barx", "baz123"}; std::shared_ptr array; ArrayFromVector(values, &array); @@ -53,7 +54,7 @@ TEST(BinaryArrayAccessor, Test_CDataType_BINARY_Basic) { } } -TEST(BinaryArrayAccessor, Test_CDataType_BINARY_Truncation) { +TEST(BinaryArrayAccessor, TestCDataTypeBinaryTruncation) { std::vector values = {"ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEF"}; std::shared_ptr array; ArrayFromVector(values, &array); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/boolean_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/boolean_array_accessor_test.cc index 9b17a904598..b3f402dd7c1 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/boolean_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/boolean_array_accessor_test.cc @@ -17,11 +17,12 @@ #include "arrow/flight/sql/odbc/odbc_impl/accessors/boolean_array_accessor.h" #include "arrow/testing/builder.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { -TEST(BooleanArrayFlightSqlAccessor, Test_BooleanArray_CDataType_BIT) { +TEST(BooleanArrayFlightSqlAccessor, TestBooleanArrayCDataTypeBit) { const std::vector values = {true, false, true}; std::shared_ptr array; ArrayFromVector(values, &array); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/common.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/common.h index 0a79bc39dfb..45f88b50fb8 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/common.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/common.h @@ -19,6 +19,7 @@ #include #include +#include #include "arrow/array.h" #include "arrow/flight/sql/odbc/odbc_impl/accessors/types.h" #include "arrow/flight/sql/odbc/odbc_impl/diagnostics.h" @@ -42,7 +43,7 @@ inline size_t CopyFromArrayValuesToBinding(ARRAY_TYPE* array, ColumnBinding* bin } } } else { - // Duplicate this loop to avoid null checks within the loop. + // Duplicate above for-loop to exit early when null value is found for (int64_t i = starting_row; i < starting_row + cells; ++i) { if (array->IsNull(i)) { throw NullWithoutIndicatorException(); @@ -54,7 +55,7 @@ inline size_t CopyFromArrayValuesToBinding(ARRAY_TYPE* array, ColumnBinding* bin // Note that the array should already have been sliced down to the same number // of elements in the ODBC data array by the point in which this function is called. const auto* values = array->raw_values(); - memcpy(binding->buffer, &values[starting_row], element_size * cells); + std::memcpy(binding->buffer, &values[starting_row], element_size * cells); return cells; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/date_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/date_array_accessor_test.cc index 03716e2477a..a482a838101 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/date_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/date_array_accessor_test.cc @@ -20,11 +20,12 @@ #include "arrow/flight/sql/odbc/odbc_impl/accessors/boolean_array_accessor.h" #include "arrow/flight/sql/odbc/odbc_impl/calendar_utils.h" #include "arrow/testing/builder.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { -TEST(DateArrayAccessor, Test_Date32Array_CDataType_DATE) { +TEST(DateArrayAccessor, TestDate32ArrayCDataTypeDate) { std::vector values = {7589, 12320, 18980, 19095, -1, 0}; std::vector expected = { {1990, 10, 12}, {2003, 9, 25}, {2021, 12, 19}, @@ -57,7 +58,7 @@ TEST(DateArrayAccessor, Test_Date32Array_CDataType_DATE) { } } -TEST(DateArrayAccessor, Test_Date64Array_CDataType_DATE) { +TEST(DateArrayAccessor, TestDate64ArrayCDataTypeDate) { std::vector values = { 86400000, 172800000, 259200000, 1649793238110, 0, 345600000, 432000000, 518400000, -86400000, -17987443200000, -24268068949000}; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/decimal_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/decimal_array_accessor_test.cc index b2eb9450c2f..6664b2d6e60 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/decimal_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/decimal_array_accessor_test.cc @@ -19,7 +19,8 @@ #include "arrow/builder.h" #include "arrow/testing/builder.h" #include "arrow/util/decimal.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { namespace { @@ -93,7 +94,7 @@ void AssertNumericOutput(int input_precision, int input_scale, } } -TEST(DecimalArrayFlightSqlAccessor, Test_Decimal128Array_CDataType_NUMERIC_SameScale) { +TEST(DecimalArrayFlightSqlAccessor, TestDecimal128ArrayCDataTypeNumericSameScale) { const std::vector& input_values = {"25.212", "-25.212", "-123456789.123", "123456789.123"}; const std::vector& output_values = diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/primitive_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/primitive_array_accessor_test.cc index 2f04b7324a5..a5ce05fb717 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/primitive_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/primitive_array_accessor_test.cc @@ -19,7 +19,8 @@ #include "arrow/flight/sql/odbc/odbc_impl/diagnostics.h" #include "arrow/testing/builder.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { @@ -52,43 +53,43 @@ void TestPrimitiveArraySqlAccessor() { } } -TEST(PrimitiveArrayFlightSqlAccessor, Test_Int64Array_CDataType_SBIGINT) { +TEST(PrimitiveArrayFlightSqlAccessor, TestInt64ArrayCDataTypeSbigint) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_Int32Array_CDataType_SLONG) { +TEST(PrimitiveArrayFlightSqlAccessor, TestInt32ArrayCDataTypeSlong) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_Int16Array_CDataType_SSHORT) { +TEST(PrimitiveArrayFlightSqlAccessor, TestInt16ArrayCDataTypeSshort) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_Int8Array_CDataType_STINYINT) { +TEST(PrimitiveArrayFlightSqlAccessor, TestInt8ArrayCDataTypeStinyint) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt64Array_CDataType_UBIGINT) { +TEST(PrimitiveArrayFlightSqlAccessor, TestUInt64ArrayCDataTypeUbigint) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt32Array_CDataType_ULONG) { +TEST(PrimitiveArrayFlightSqlAccessor, TestUInt32ArrayCDataTypeUlong) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt16Array_CDataType_USHORT) { +TEST(PrimitiveArrayFlightSqlAccessor, TestUInt16ArrayCDataTypeUshort) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt8Array_CDataType_UTINYINT) { +TEST(PrimitiveArrayFlightSqlAccessor, TestUInt8ArrayCDataTypeUtinyint) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_FloatArray_CDataType_FLOAT) { +TEST(PrimitiveArrayFlightSqlAccessor, TestFloatArrayCDataTypeFloat) { TestPrimitiveArraySqlAccessor(); } -TEST(PrimitiveArrayFlightSqlAccessor, Test_DoubleArray_CDataType_DOUBLE) { +TEST(PrimitiveArrayFlightSqlAccessor, TestDoubleArrayCDataTypeDouble) { TestPrimitiveArraySqlAccessor(); } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor.cc index 69b3b304945..441b2a3394e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor.cc @@ -18,6 +18,7 @@ #include "arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor.h" #include +#include #include "arrow/array.h" #include "arrow/flight/sql/odbc/odbc_impl/encoding.h" @@ -79,7 +80,7 @@ inline RowStatus MoveSingleCellToCharBuffer( auto* byte_buffer = static_cast(binding->buffer) + i * binding->buffer_length; auto* char_buffer = (CHAR_TYPE*)byte_buffer; - memcpy(char_buffer, ((char*)value) + value_offset, value_length); + std::memcpy(char_buffer, ((char*)value) + value_offset, value_length); // Write a NUL terminator if (binding->buffer_length >= remaining_length + sizeof(CHAR_TYPE)) { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor_test.cc index eb7a9c88b3b..4d0e1393407 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/string_array_accessor_test.cc @@ -19,11 +19,12 @@ #include "arrow/flight/sql/odbc/odbc_impl/encoding.h" #include "arrow/testing/builder.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { -TEST(StringArrayAccessor, Test_CDataType_CHAR_Basic) { +TEST(StringArrayAccessor, TestCDataTypeCharBasic) { std::vector values = {"foo", "barx", "baz123"}; std::shared_ptr array; ArrayFromVector(values, &array); @@ -49,7 +50,7 @@ TEST(StringArrayAccessor, Test_CDataType_CHAR_Basic) { } } -TEST(StringArrayAccessor, Test_CDataType_CHAR_Truncation) { +TEST(StringArrayAccessor, TestCDataTypeCharTruncation) { std::vector values = {"ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEF"}; std::shared_ptr array; ArrayFromVector(values, &array); @@ -82,7 +83,7 @@ TEST(StringArrayAccessor, Test_CDataType_CHAR_Truncation) { ASSERT_EQ(values[0], ss.str()); } -TEST(StringArrayAccessor, Test_CDataType_WCHAR_Basic) { +TEST(StringArrayAccessor, TestCDataTypeWcharBasic) { std::vector values = {"foo", "barx", "baz123"}; std::shared_ptr array; ArrayFromVector(values, &array); @@ -112,7 +113,7 @@ TEST(StringArrayAccessor, Test_CDataType_WCHAR_Basic) { } } -TEST(StringArrayAccessor, Test_CDataType_WCHAR_Truncation) { +TEST(StringArrayAccessor, TestCDataTypeWcharTruncation) { std::vector values = {"ABCDEFA"}; std::shared_ptr array; ArrayFromVector(values, &array); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/time_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/time_array_accessor_test.cc index eb49e4078cd..41bd0d73ea7 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/time_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/time_array_accessor_test.cc @@ -20,11 +20,12 @@ #include "arrow/flight/sql/odbc/odbc_impl/calendar_utils.h" #include "arrow/flight/sql/odbc/odbc_impl/util.h" #include "arrow/testing/builder.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { -TEST(TEST_TIME32, TIME_WITH_SECONDS) { +TEST(TestTime32, TimeWithSeconds) { auto value_field = field("f0", time32(TimeUnit::SECOND)); std::vector t32_values = {14896, 14897, 14892, 85400, 14893, 14895}; @@ -58,7 +59,7 @@ TEST(TEST_TIME32, TIME_WITH_SECONDS) { } } -TEST(TEST_TIME32, TIME_WITH_MILLI) { +TEST(TestTime32, TimeWithMilli) { auto value_field = field("f0", time32(TimeUnit::MILLI)); std::vector t32_values = {14896000, 14897000, 14892000, 85400000, 14893000, 14895000}; @@ -94,7 +95,7 @@ TEST(TEST_TIME32, TIME_WITH_MILLI) { } } -TEST(TEST_TIME64, TIME_WITH_MICRO) { +TEST(TestTime32, TimeWithMicro) { auto value_field = field("f0", time64(TimeUnit::MICRO)); std::vector t64_values = {14896000, 14897000, 14892000, @@ -131,7 +132,7 @@ TEST(TEST_TIME64, TIME_WITH_MICRO) { } } -TEST(TEST_TIME64, TIME_WITH_NANO) { +TEST(TestTime32, TimeWithNano) { auto value_field = field("f0", time64(TimeUnit::NANO)); std::vector t64_values = {14896000000, 14897000000, 14892000000, 85400000000, 14893000000, 14895000000}; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/timestamp_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/timestamp_array_accessor.cc index 37f14ebd9c5..93af9465576 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/timestamp_array_accessor.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/timestamp_array_accessor.cc @@ -19,12 +19,12 @@ #include "arrow/flight/sql/odbc/odbc_impl/calendar_utils.h" -using arrow::TimeUnit; +#include +#include namespace arrow::flight::sql::odbc { namespace { - -int64_t GetConversionToSecondsDivisor(TimeUnit::type unit) { +inline int64_t GetConversionToSecondsDivisor(TimeUnit::type unit) { int64_t divisor = 1; switch (unit) { case TimeUnit::SECOND: @@ -79,6 +79,10 @@ template RowStatus TimestampArrayFlightSqlAccessor::MoveSingleCellImpl( ColumnBinding* binding, int64_t arrow_row, int64_t cell_counter, int64_t& value_offset, bool update_value_offset, Diagnostics& diagnostics) { + // Times less than the minimum integer number of seconds that can be represented + // for each time unit will not convert correctly. This is mostly interesting for + // nanoseconds as timestamps in other units are outside of the accepted range of + // Gregorian dates. auto* buffer = static_cast(binding->buffer); int64_t value = this->GetArray()->Value(arrow_row); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/timestamp_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/timestamp_array_accessor_test.cc index 393dd98501d..dd4917b0e37 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/timestamp_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/timestamp_array_accessor_test.cc @@ -20,11 +20,12 @@ #include "arrow/flight/sql/odbc/odbc_impl/calendar_utils.h" #include "arrow/flight/sql/odbc/odbc_impl/util.h" #include "arrow/testing/builder.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { -TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_MILLI) { +TEST(TestTimestamp, TimestampWithMilli) { std::vector values = {86400370, 172800000, 259200000, 1649793238110LL, 345600000, 432000000, 518400000, -86399000, 0, @@ -88,7 +89,7 @@ TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_MILLI) { } } -TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_SECONDS) { +TEST(TestTimestamp, TimestampWithSeconds) { std::vector values = {86400, 172800, 259200, 1649793238, 345600, 432000, 518400}; @@ -130,7 +131,7 @@ TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_SECONDS) { } } -TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_MICRO) { +TEST(TestTimestamp, TimestampWithMicro) { std::vector values = {86400000000, 1649793238000000}; std::shared_ptr timestamp_array; @@ -174,7 +175,7 @@ TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_MICRO) { } } -TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_NANO) { +TEST(TestTimestamp, TimestampWithNano) { std::vector values = {86400000010000, 1649793238000000000}; std::shared_ptr timestamp_array; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/types.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/types.h index ca33d872fa7..c0084a5ab15 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/types.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/accessors/types.h @@ -102,7 +102,7 @@ class FlightSqlAccessor : public Accessor { throw NullWithoutIndicatorException(); } } else { - // TODO: Optimize this by creating different versions of MoveSingleCell + // GH-47849 TODO: Optimize this by creating different versions of MoveSingleCell // depending on if str_len_buffer is null. auto row_status = MoveSingleCell(binding, current_arrow_row, i, value_offset, update_value_offset, diagnostics); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/address_info.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/address_info.cc index 5ee6674c3c2..5dfc85a58f7 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/address_info.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/address_info.cc @@ -16,8 +16,9 @@ // under the License. #include "arrow/flight/sql/odbc/odbc_impl/address_info.h" +#include -namespace driver { +namespace arrow::flight::sql::odbc { bool AddressInfo::GetAddressInfo(const std::string& host, char* host_name_info, int64_t max_host) { @@ -34,7 +35,7 @@ bool AddressInfo::GetAddressInfo(const std::string& host, char* host_name_info, } error = getnameinfo(addrinfo_result_->ai_addr, addrinfo_result_->ai_addrlen, - host_name_info, static_cast(max_host), NULL, 0, 0); + host_name_info, static_cast(max_host), NULL, 0, 0); return error == 0; } @@ -46,4 +47,5 @@ AddressInfo::~AddressInfo() { } AddressInfo::AddressInfo() : addrinfo_result_(nullptr) {} -} // namespace driver + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/address_info.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/address_info.h index c127c0f4a29..7d538f912e0 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/address_info.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/address_info.h @@ -27,7 +27,7 @@ # include #endif -namespace driver { +namespace arrow::flight::sql::odbc { class AddressInfo { private: @@ -40,4 +40,5 @@ class AddressInfo { bool GetAddressInfo(const std::string& host, char* host_name_info, int64_t max_host); }; -} // namespace driver + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h index 8c5eae59f7e..9f3c6e60555 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h @@ -29,9 +29,7 @@ // GH-48083 TODO: replace `namespace ODBC` with `namespace arrow::flight::sql::odbc` namespace ODBC { - using arrow::flight::sql::odbc::Diagnostics; -using arrow::flight::sql::odbc::DriverException; using arrow::flight::sql::odbc::WcsToUtf8; template diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/blocking_queue.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/blocking_queue.h index 5c9e6609d58..e52c305e461 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/blocking_queue.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/blocking_queue.h @@ -18,9 +18,9 @@ #pragma once #include -#include #include #include +#include #include #include @@ -43,7 +43,7 @@ class BlockingQueue { std::atomic closed_{false}; public: - typedef std::function(void)> Supplier; + typedef std::function(void)> Supplier; explicit BlockingQueue(size_t capacity) : capacity_(capacity), buffer_(capacity) {} @@ -58,7 +58,7 @@ class BlockingQueue { // Only one thread at a time be notified and call supplier auto item = supplier(); - if (!item) break; + if (!item.has_value()) break; Push(*item); } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/calendar_utils.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/calendar_utils.cc index 1dddae2a7c7..b47b33f2d93 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/calendar_utils.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/calendar_utils.cc @@ -40,6 +40,9 @@ int64_t GetTodayTimeFromEpoch() { #endif } +// GH-47631: add support for non-UTC time zone data. +// Read the time zone value from Arrow::Timestamp, and use the time zone value to convert +// seconds_since_epoch instead of converting to UTC time zone by default void GetTimeForSecondsSinceEpoch(const int64_t seconds_since_epoch, std::tm& out_tm) { std::memset(&out_tm, 0, sizeof(std::tm)); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc index df61f1247c7..75498710d23 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc @@ -16,6 +16,7 @@ // under the License. #include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h" + #include "arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h" #include "arrow/flight/sql/odbc/odbc_impl/util.h" #include "arrow/result.h" @@ -186,15 +187,14 @@ const Connection::ConnPropertyMap& Configuration::GetProperties() const { return this->properties_; } -std::vector Configuration::GetCustomKeys() const { +std::vector Configuration::GetCustomKeys() const { Connection::ConnPropertyMap copy_props(properties_); for (auto& key : FlightSqlConnection::ALL_KEYS) { copy_props.erase(std::string(key)); } - std::vector keys; + std::vector keys; boost::copy(copy_props | boost::adaptors::map_keys, std::back_inserter(keys)); return keys; } - } // namespace config } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.h index 0390a57e52f..9a3350a316e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.h @@ -22,8 +22,10 @@ #include "arrow/flight/sql/odbc/odbc_impl/platform.h" #include "arrow/flight/sql/odbc/odbc_impl/spi/connection.h" +#if defined _WIN32 || defined _WIN64 // winuser.h needs to be included after windows.h, which is defined in platform.h -#include +# include +#endif namespace arrow::flight::sql::odbc { namespace config { @@ -60,7 +62,7 @@ class Configuration { */ const Connection::ConnPropertyMap& GetProperties() const; - std::vector GetCustomKeys() const; + std::vector GetCustomKeys() const; private: Connection::ConnPropertyMap properties_; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h index 66e5c3bf0d8..5e3a4ecbdae 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h @@ -31,7 +31,6 @@ #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING namespace ODBC { - using arrow::flight::sql::odbc::DriverException; using arrow::flight::sql::odbc::GetSqlWCharSize; using arrow::flight::sql::odbc::Utf8ToWcs; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc index bdf7f71589c..5da662e7ade 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc @@ -24,6 +24,7 @@ #include "arrow/result.h" #include "arrow/status.h" +#include #include namespace arrow::flight::sql::odbc { @@ -36,6 +37,9 @@ class NoOpAuthMethod : public FlightSqlAuthMethod { void Authenticate(FlightSqlConnection& connection, FlightCallOptions& call_options) override { // Do nothing + + // GH-46733 TODO: implement NoOpAuthMethod to validate server address. + // Can use NoOpClientAuthHandler. } }; @@ -44,8 +48,8 @@ class NoOpClientAuthHandler : public ClientAuthHandler { NoOpClientAuthHandler() {} Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override { - // Write a blank string. The server should ignore this and just accept any Handshake - // request. + // The server should ignore this and just accept any Handshake + // request. Some servers do not allow authentication with no handshakes. return outgoing->Write(std::string()); } @@ -63,12 +67,12 @@ class UserPasswordAuthMethod : public FlightSqlAuthMethod { void Authenticate(FlightSqlConnection& connection, FlightCallOptions& call_options) override { FlightCallOptions auth_call_options; - const boost::optional& login_timeout = + const std::optional& login_timeout = connection.GetAttribute(Connection::LOGIN_TIMEOUT); - if (login_timeout && boost::get(*login_timeout) > 0) { + if (login_timeout && std::get(*login_timeout) > 0) { // ODBC's LOGIN_TIMEOUT attribute and FlightCallOptions.timeout use // seconds as time unit. - double timeout_seconds = static_cast(boost::get(*login_timeout)); + double timeout_seconds = static_cast(std::get(*login_timeout)); if (timeout_seconds > 0) { auth_call_options.timeout = TimeoutDuration{timeout_seconds}; } @@ -93,7 +97,9 @@ class UserPasswordAuthMethod : public FlightSqlAuthMethod { throw DriverException(bearer_result.status().message()); } - call_options.headers.push_back(bearer_result.ValueOrDie()); + // call_options may have already been populated with data from the connection string + // or DSN. Ensure auth-generated headers are placed at the front of the header list. + call_options.headers.insert(call_options.headers.begin(), bearer_result.ValueOrDie()); } std::string GetUser() override { return user_; } @@ -115,10 +121,11 @@ class TokenAuthMethod : public FlightSqlAuthMethod { void Authenticate(FlightSqlConnection& connection, FlightCallOptions& call_options) override { - // add the token to the headers + // add the token to the front of the headers. For consistency auth headers should be + // at the front. const std::pair token_header("authorization", "Bearer " + token_); - call_options.headers.push_back(token_header); + call_options.headers.insert(call_options.headers.begin(), token_header); const Status status = client_.Authenticate( call_options, std::unique_ptr(new NoOpClientAuthHandler())); @@ -142,22 +149,22 @@ std::unique_ptr FlightSqlAuthMethod::FromProperties( const std::unique_ptr& client, const Connection::ConnPropertyMap& properties) { // Check if should use user-password authentication - auto it_user = properties.find(FlightSqlConnection::USER); + auto it_user = properties.find(std::string(FlightSqlConnection::USER)); if (it_user == properties.end()) { // The Microsoft OLE DB to ODBC bridge provider (MSDASQL) will write // "User ID" and "Password" properties instead of mapping // to ODBC compliant UID/PWD keys. - it_user = properties.find(FlightSqlConnection::USER_ID); + it_user = properties.find(std::string(FlightSqlConnection::USER_ID)); } - auto it_password = properties.find(FlightSqlConnection::PASSWORD); - auto it_token = properties.find(FlightSqlConnection::TOKEN); + auto it_password = properties.find(std::string(FlightSqlConnection::PASSWORD)); + auto it_token = properties.find(std::string(FlightSqlConnection::TOKEN)); if (it_user == properties.end() || it_password == properties.end()) { // Accept UID/PWD as aliases for User/Password. These are suggested as // standard properties in the documentation for SQLDriverConnect. - it_user = properties.find(FlightSqlConnection::UID); - it_password = properties.find(FlightSqlConnection::PWD); + it_user = properties.find(std::string(FlightSqlConnection::UID)); + it_password = properties.find(std::string(FlightSqlConnection::PWD)); } if (it_user != properties.end() || it_password != properties.end()) { const std::string& user = it_user != properties.end() ? it_user->second : ""; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc index e18a58d069f..91eb07bb277 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc @@ -31,7 +31,6 @@ #include #include #include -#include #include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" #include @@ -107,6 +106,8 @@ struct CaseInsensitiveComparatorStrView { }; const std::set BUILT_IN_PROPERTIES = { + FlightSqlConnection::DRIVER, + FlightSqlConnection::DSN, FlightSqlConnection::HOST, FlightSqlConnection::PORT, FlightSqlConnection::USER, @@ -125,7 +126,7 @@ const std::set BUILT_IN_PROP Connection::ConnPropertyMap::const_iterator TrackMissingRequiredProperty( std::string_view property, const Connection::ConnPropertyMap& properties, std::vector& missing_attr) { - auto prop_iter = properties.find(property); + auto prop_iter = properties.find(std::string(property)); if (properties.end() == prop_iter) { missing_attr.push_back(property); } @@ -144,6 +145,7 @@ std::shared_ptr LoadFlightSslConfigs( AsBool(conn_property_map, FlightSqlConnection::USE_SYSTEM_TRUST_STORE) .value_or(SYSTEM_TRUST_STORE_DEFAULT); + // GH-47630: find co-located TLS certificate if `trusted certs` path is not specified auto trusted_certs_iterator = conn_property_map.find(std::string(FlightSqlConnection::TRUSTED_CERTS)); auto trusted_certs = trusted_certs_iterator != conn_property_map.end() @@ -160,14 +162,15 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties, auto flight_ssl_configs = LoadFlightSslConfigs(properties); Location location = BuildLocation(properties, missing_attr, flight_ssl_configs); - FlightClientOptions client_options = + client_options_ = BuildFlightClientOptions(properties, missing_attr, flight_ssl_configs); const std::shared_ptr& cookie_factory = GetCookieFactory(); - client_options.middleware.push_back(cookie_factory); + client_options_.middleware.push_back(cookie_factory); std::unique_ptr flight_client; - ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client)); + ThrowIfNotOK(FlightClient::Connect(location, client_options_).Value(&flight_client)); + PopulateMetadataSettings(properties); PopulateCallOptions(properties); @@ -199,7 +202,7 @@ void FlightSqlConnection::PopulateMetadataSettings( metadata_settings_.chunk_buffer_capacity = GetChunkBufferCapacity(conn_property_map); } -boost::optional FlightSqlConnection::GetStringColumnLength( +std::optional FlightSqlConnection::GetStringColumnLength( const Connection::ConnPropertyMap& conn_property_map) { const int32_t min_string_column_length = 1; @@ -214,7 +217,7 @@ boost::optional FlightSqlConnection::GetStringColumnLength( "01000", ODBCErrorCodes_GENERAL_WARNING); } - return boost::none; + return std::nullopt; } bool FlightSqlConnection::GetUseWideChar(const ConnPropertyMap& conn_property_map) { @@ -250,11 +253,11 @@ const FlightCallOptions& FlightSqlConnection::PopulateCallOptions( const ConnPropertyMap& props) { // Set CONNECTION_TIMEOUT attribute or LOGIN_TIMEOUT depending on if this // is the first request. - const boost::optional& connection_timeout = + const std::optional& connection_timeout = closed_ ? GetAttribute(LOGIN_TIMEOUT) : GetAttribute(CONNECTION_TIMEOUT); - if (connection_timeout && boost::get(*connection_timeout) > 0) { + if (connection_timeout && std::get(*connection_timeout) > 0) { call_options_.timeout = - TimeoutDuration{static_cast(boost::get(*connection_timeout))}; + TimeoutDuration{static_cast(std::get(*connection_timeout))}; } for (auto prop : props) { @@ -329,7 +332,7 @@ Location FlightSqlConnection::BuildLocation( Location location; if (ssl_config->UseEncryption()) { - driver::AddressInfo address_info; + AddressInfo address_info; char host_name_info[NI_MAXHOST] = ""; bool operation_result = false; @@ -343,7 +346,7 @@ Location FlightSqlConnection::BuildLocation( ThrowIfNotOK(Location::ForGrpcTls(host_name_info, port).Value(&location)); return location; } - // TODO: We should log that we could not convert an IP to hostname here. + // GH-47852 TODO: We should log that we could not convert an IP to hostname here. } } catch (...) { // This is expected. The Host attribute can be an IP or name, but make_address will @@ -370,7 +373,7 @@ void FlightSqlConnection::Close() { std::shared_ptr FlightSqlConnection::CreateStatement() { return std::shared_ptr(new FlightSqlStatement( - diagnostics_, *sql_client_, call_options_, metadata_settings_)); + diagnostics_, *sql_client_, client_options_, call_options_, metadata_settings_)); } bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute, @@ -388,17 +391,21 @@ bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute, } } -boost::optional FlightSqlConnection::GetAttribute( +std::optional FlightSqlConnection::GetAttribute( Connection::AttributeId attribute) { switch (attribute) { case ACCESS_MODE: // FlightSQL does not provide this metadata. - return boost::make_optional(Attribute(static_cast(SQL_MODE_READ_WRITE))); + return std::make_optional(Attribute(static_cast(SQL_MODE_READ_WRITE))); case PACKET_SIZE: - return boost::make_optional(Attribute(static_cast(0))); + return std::make_optional(Attribute(static_cast(0))); default: const auto& it = attribute_.find(attribute); - return boost::make_optional(it != attribute_.end(), it->second); + if (it != attribute_.end()) { + return std::make_optional(it->second); + } else { + return std::nullopt; + } } } @@ -407,7 +414,7 @@ Connection::Info FlightSqlConnection::GetInfo(uint16_t info_type) { if (info_type == SQL_DBMS_NAME || info_type == SQL_SERVER_NAME) { // Update the database component reported in error messages. // We do this lazily for performance reasons. - diagnostics_.SetDataSourceComponent(boost::get(result)); + diagnostics_.SetDataSourceComponent(std::get(result)); } return result; } @@ -416,7 +423,7 @@ FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version, const std::string& driver_version) : diagnostics_("Apache Arrow", "Flight SQL", odbc_version), odbc_version_(odbc_version), - info_(call_options_, sql_client_, driver_version), + info_(client_options_, call_options_, sql_client_, driver_version), closed_(true) { attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); attribute_[LOGIN_TIMEOUT] = static_cast(0); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h index 6219bb287e4..625f8d75c9d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h @@ -19,6 +19,7 @@ #include "arrow/flight/sql/odbc/odbc_impl/spi/connection.h" +#include #include #include "arrow/flight/api.h" #include "arrow/flight/sql/api.h" @@ -28,12 +29,19 @@ namespace arrow::flight::sql::odbc { +/// \brief Case insensitive comparator that takes string_view +struct CaseInsensitiveComparatorStrView { + bool operator()(const std::string_view& s1, const std::string_view& s2) const { + return boost::lexicographical_compare(s1, s2, boost::is_iless()); + } +}; + class FlightSqlSslConfig; /// \brief Create an instance of the FlightSqlSslConfig class, from the properties passed /// into the map. /// \param conn_property_map the map with the Connection properties. -/// \return An instance of the FlightSqlSslConfig. +/// \return An instance of the FlightSqlSslConfig. std::shared_ptr LoadFlightSslConfigs( const Connection::ConnPropertyMap& conn_property_map); @@ -84,7 +92,7 @@ class FlightSqlConnection : public Connection { bool SetAttribute(AttributeId attribute, const Attribute& value) override; - boost::optional GetAttribute( + std::optional GetAttribute( Connection::AttributeId attribute) override; Info GetInfo(uint16_t info_type) override; @@ -111,8 +119,7 @@ class FlightSqlConnection : public Connection { /// \note Visible for testing void SetClosed(bool is_closed); - boost::optional GetStringColumnLength( - const ConnPropertyMap& conn_property_map); + std::optional GetStringColumnLength(const ConnPropertyMap& conn_property_map); bool GetUseWideChar(const ConnPropertyMap& conn_property_map); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection_test.cc index 9c9b0f8f3c1..bc26a113d77 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection_test.cc @@ -19,7 +19,10 @@ #include "arrow/flight/sql/odbc/odbc_impl/platform.h" #include "arrow/flight/types.h" -#include "gtest/gtest.h" + +#include + +#include namespace arrow::flight::sql::odbc { @@ -28,20 +31,19 @@ TEST(AttributeTests, SetAndGetAttribute) { connection.SetClosed(false); connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(200)); - const boost::optional first_value = + const std::optional first_value = connection.GetAttribute(Connection::CONNECTION_TIMEOUT); - EXPECT_TRUE(first_value); - - EXPECT_EQ(static_cast(200), boost::get(*first_value)); + ASSERT_TRUE(first_value); + ASSERT_EQ(static_cast(200), std::get(*first_value)); connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(300)); - const boost::optional change_value = + const std::optional change_value = connection.GetAttribute(Connection::CONNECTION_TIMEOUT); - EXPECT_TRUE(change_value); - EXPECT_EQ(static_cast(300), boost::get(*change_value)); + ASSERT_TRUE(change_value); + ASSERT_EQ(static_cast(300), std::get(*change_value)); connection.Close(); } @@ -49,11 +51,11 @@ TEST(AttributeTests, SetAndGetAttribute) { TEST(AttributeTests, GetAttributeWithoutSetting) { FlightSqlConnection connection(OdbcVersion::V_3); - const boost::optional optional = + const std::optional optional = connection.GetAttribute(Connection::CONNECTION_TIMEOUT); connection.SetClosed(false); - EXPECT_EQ(0, boost::get(*optional)); + EXPECT_EQ(0, std::get(*optional)); connection.Close(); } @@ -72,11 +74,11 @@ TEST(MetadataSettingsTest, StringColumnLengthTest) { std::to_string(expected_string_column_length)}, }; - const boost::optional actual_string_column_length = + const std::optional actual_string_column_length = connection.GetStringColumnLength(properties); - EXPECT_TRUE(actual_string_column_length); - EXPECT_EQ(expected_string_column_length, *actual_string_column_length); + ASSERT_TRUE(actual_string_column_length); + ASSERT_EQ(expected_string_column_length, *actual_string_column_length); connection.Close(); } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc index 5fe6069648f..61c0f3dd3f9 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc @@ -77,9 +77,9 @@ std::shared_ptr GetTablesReader::GetSchema() { const arrow::Result>& result = arrow::ipc::ReadSchema(&dataset_schema_reader, &in_memo); if (!result.ok()) { - // TODO: Ignoring this error until we fix the problem on Dremio server - // The problem is that complex types columns are being returned without the children - // types. + // GH-46561 TODO: Test and build the driver against a server that returns + // complex types columns with the children + // types and handle the failure properly return nullptr; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc index 19149b3c48d..56e5bb973f7 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc @@ -17,6 +17,8 @@ #include "arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h" +#include + #include #include "arrow/flight/types.h" #include "arrow/scalar.h" @@ -29,12 +31,12 @@ namespace arrow::flight::sql::odbc { FlightSqlResultSet::FlightSqlResultSet( - FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options, - const std::shared_ptr& flight_info, + FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, + const FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, Diagnostics& diagnostics, const MetadataSettings& metadata_settings) : metadata_settings_(metadata_settings), - chunk_buffer_(flight_sql_client, call_options, flight_info, + chunk_buffer_(flight_sql_client, client_options, call_options, flight_info, metadata_settings_.chunk_buffer_capacity), transformer_(transformer), metadata_(transformer @@ -212,14 +214,14 @@ void FlightSqlResultSet::Cancel() { current_chunk_.data = nullptr; } -bool FlightSqlResultSet::GetData(int column_n, int16_t target_type, int precision, - int scale, void* buffer, size_t buffer_length, - ssize_t* str_len_buffer) { +SQLRETURN FlightSqlResultSet::GetData(int column_n, int16_t target_type, int precision, + int scale, void* buffer, size_t buffer_length, + ssize_t* str_len_buffer) { reset_get_data_ = true; // Check if the offset is already at the end. int64_t& value_offset = get_data_offsets_[column_n - 1]; if (value_offset == -1) { - return false; + return SQL_NO_DATA; } ColumnBinding binding(util::ConvertCDataTypeFromV2ToV3(target_type), precision, scale, @@ -235,7 +237,11 @@ bool FlightSqlResultSet::GetData(int column_n, int16_t target_type, int precisio diagnostics_, nullptr); // If there was truncation, the converter would have reported it to the diagnostics. - return diagnostics_.HasWarning(); + if (diagnostics_.HasWarning()) { + return SQL_SUCCESS_WITH_INFO; + } else { + return SQL_SUCCESS; + } } std::shared_ptr FlightSqlResultSet::GetMetadata() { return metadata_; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h index 6083b332824..4f19b05bf2c 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h @@ -51,6 +51,7 @@ class FlightSqlResultSet : public ResultSet { ~FlightSqlResultSet() override; FlightSqlResultSet(FlightSqlClient& flight_sql_client, + const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, @@ -60,8 +61,8 @@ class FlightSqlResultSet : public ResultSet { void Cancel() override; - bool GetData(int column_n, int16_t target_type, int precision, int scale, void* buffer, - size_t buffer_length, ssize_t* str_len_buffer) override; + SQLRETURN GetData(int column_n, int16_t target_type, int precision, int scale, + void* buffer, size_t buffer_length, ssize_t* str_len_buffer) override; size_t Move(size_t rows, size_t bind_offset, size_t bind_type, uint16_t* row_status_array) override; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_column.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_column.h index ede53038f1a..d09e550900b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_column.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_column.h @@ -70,4 +70,5 @@ class FlightSqlResultSetColumn { } } }; + } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.cc index 8ac3c7ed752..52bfcfa4901 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.cc @@ -20,6 +20,7 @@ #include "arrow/flight/sql/column_metadata.h" #include "arrow/flight/sql/odbc/odbc_impl/platform.h" #include "arrow/flight/sql/odbc/odbc_impl/util.h" +#include "arrow/util/key_value_metadata.h" #include #include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" @@ -40,12 +41,8 @@ constexpr int32_t DefaultDecimalPrecision = 38; constexpr int32_t DefaultLengthForVariableLengthColumns = 1024; namespace { -std::shared_ptr empty_metadata_map(new KeyValueMetadata); - inline ColumnMetadata GetMetadata(const std::shared_ptr& field) { - const auto& metadata_map = field->metadata(); - - ColumnMetadata metadata(metadata_map ? metadata_map : empty_metadata_map); + ColumnMetadata metadata(field->metadata()); return metadata; } @@ -160,19 +157,27 @@ size_t FlightSqlResultSetMetadata::GetLength(int column_position) { } std::string FlightSqlResultSetMetadata::GetLiteralPrefix(int column_position) { - // TODO: Flight SQL column metadata does not have this, should we add to the spec? + // GH-47853 TODO: use `ColumnMetadata` to get literal prefix after Flight SQL protocol + // adds support for it + + // Flight SQL column metadata does not have literal prefix, empty string is returned return ""; } std::string FlightSqlResultSetMetadata::GetLiteralSuffix(int column_position) { - // TODO: Flight SQL column metadata does not have this, should we add to the spec? + // GH-47853 TODO: use `ColumnMetadata` to get literal suffix after Flight SQL protocol + // adds support for it + + // Flight SQL column metadata does not have literal suffix, empty string is returned return ""; } std::string FlightSqlResultSetMetadata::GetLocalTypeName(int column_position) { ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); - // TODO: Is local type name the same as type name? + // Local type name is for display purpose only. + // Return type name as local type name as Flight SQL protocol doesn't have support for + // local type name. return metadata.GetTypeName().ValueOrElse([] { return ""; }); } @@ -195,7 +200,7 @@ size_t FlightSqlResultSetMetadata::GetOctetLength(int column_position) { // Workaround to get the precision for Decimal and Numeric types, since server doesn't // return it currently. - // TODO: Use the server precision when its fixed. + // GH-47854 TODO: Use the server precision when its fixed. std::shared_ptr arrow_type = field->type(); if (arrow_type->id() == Type::DECIMAL128) { int32_t precision = util::GetDecimalTypePrecision(arrow_type); @@ -207,10 +212,13 @@ size_t FlightSqlResultSetMetadata::GetOctetLength(int column_position) { .value_or(DefaultLengthForVariableLengthColumns); } -std::string FlightSqlResultSetMetadata::GetTypeName(int column_position) { +std::string FlightSqlResultSetMetadata::GetTypeName(int column_position, int data_type) { ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); - return metadata.GetTypeName().ValueOrElse([] { return ""; }); + return metadata.GetTypeName().ValueOrElse([data_type] { + // If we get an empty type name, figure out the type name from the data_type. + return util::GetTypeNameFromSqlDataType(data_type); + }); } Updatability FlightSqlResultSetMetadata::GetUpdatable(int column_position) { @@ -220,7 +228,6 @@ Updatability FlightSqlResultSetMetadata::GetUpdatable(int column_position) { bool FlightSqlResultSetMetadata::IsAutoUnique(int column_position) { ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); - // TODO: Is AutoUnique equivalent to AutoIncrement? return metadata.GetIsAutoIncrement().ValueOrElse([] { return false; }); } @@ -241,18 +248,29 @@ bool FlightSqlResultSetMetadata::IsUnsigned(int column_position) { const std::shared_ptr& field = schema_->field(column_position - 1); switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::DOUBLE: + case Type::FLOAT: + case Type::HALF_FLOAT: + case Type::DECIMAL32: + case Type::DECIMAL64: + case Type::DECIMAL128: + case Type::DECIMAL256: + return false; case Type::UINT8: case Type::UINT16: case Type::UINT32: case Type::UINT64: - return true; default: - return false; + return true; } } bool FlightSqlResultSetMetadata::IsFixedPrecScale(int column_position) { - // TODO: Flight SQL column metadata does not have this, should we add to the spec? + // Precision for Arrow data types are modifiable by the user return false; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.h index 0d141a4bb9c..11b1678c24d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.h @@ -77,7 +77,7 @@ class FlightSqlResultSetMetadata : public ResultSetMetadata { size_t GetOctetLength(int column_position) override; - std::string GetTypeName(int column_position) override; + std::string GetTypeName(int column_position, int data_type) override; Updatability GetUpdatable(int column_position) override; @@ -87,6 +87,7 @@ class FlightSqlResultSetMetadata : public ResultSetMetadata { Searchability IsSearchable(int column_position) override; + /// \brief Returns true if the column is unsigned (not numeric) bool IsUnsigned(int column_position) override; bool IsFixedPrecScale(int column_position) override; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_ssl_config.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_ssl_config.h index 10b12149712..c2d1423e97e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_ssl_config.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_ssl_config.h @@ -17,9 +17,9 @@ #pragma once -#include -#include #include +#include "arrow/flight/types.h" +#include "arrow/status.h" namespace arrow::flight::sql::odbc { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc index 30eb1fdf44a..7d13082a219 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc @@ -29,7 +29,7 @@ #include "arrow/flight/sql/odbc/odbc_impl/util.h" #include "arrow/io/memory.h" -#include +#include #include #include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" @@ -41,9 +41,10 @@ using util::ThrowIfNotOK; namespace { -void ClosePreparedStatementIfAny(std::shared_ptr& prepared_statement) { +void ClosePreparedStatementIfAny(std::shared_ptr& prepared_statement, + const FlightCallOptions& options) { if (prepared_statement != nullptr) { - ThrowIfNotOK(prepared_statement->Close()); + ThrowIfNotOK(prepared_statement->Close(options)); prepared_statement.reset(); } } @@ -52,11 +53,13 @@ void ClosePreparedStatementIfAny(std::shared_ptr& prepared_st FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, + FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings) : diagnostics_("Apache Arrow", diagnostics.GetDataSourceComponent(), diagnostics.GetOdbcVersion()), sql_client_(sql_client), + client_options_(std::move(client_options)), call_options_(std::move(call_options)), metadata_settings_(metadata_settings) { attribute_[METADATA_ID] = static_cast(SQL_FALSE); @@ -66,6 +69,10 @@ FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics, call_options_.timeout = TimeoutDuration{-1}; } +FlightSqlStatement::~FlightSqlStatement() { + ClosePreparedStatementIfAny(prepared_statement_, call_options_); +} + bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute, const Attribute& value) { switch (attribute) { @@ -76,9 +83,9 @@ bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute, case MAX_LENGTH: return CheckIfSetToOnlyValidValue(value, static_cast(0)); case QUERY_TIMEOUT: - if (boost::get(value) > 0) { + if (std::get(value) > 0) { call_options_.timeout = - TimeoutDuration{static_cast(boost::get(value))}; + TimeoutDuration{static_cast(std::get(value))}; } else { call_options_.timeout = TimeoutDuration{-1}; // Intentional fall-through. @@ -89,15 +96,19 @@ bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute, } } -boost::optional FlightSqlStatement::GetAttribute( +std::optional FlightSqlStatement::GetAttribute( StatementAttributeId attribute) { const auto& it = attribute_.find(attribute); - return boost::make_optional(it != attribute_.end(), it->second); + if (it != attribute_.end()) { + return std::make_optional(it->second); + } else { + return std::nullopt; + } } -boost::optional> FlightSqlStatement::Prepare( +std::optional> FlightSqlStatement::Prepare( const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Prepare(call_options_, query); @@ -107,31 +118,33 @@ boost::optional> FlightSqlStatement::Prepare( const auto& result_set_metadata = std::make_shared( prepared_statement_->dataset_schema(), metadata_settings_); - return boost::optional>(result_set_metadata); + return std::optional>(result_set_metadata); } bool FlightSqlStatement::ExecutePrepared() { assert(prepared_statement_.get() != nullptr); - Result> result = prepared_statement_->Execute(); + Result> result = + prepared_statement_->Execute(call_options_); + ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } bool FlightSqlStatement::Execute(const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Execute(call_options_, query); ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } @@ -146,33 +159,35 @@ std::shared_ptr FlightSqlStatement::GetTables( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* table_type, const ColumnNames& column_names) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); std::vector table_types; if ((catalog_name && *catalog_name == "%") && (schema_name && schema_name->empty()) && (table_name && table_name->empty())) { - current_result_set_ = GetTablesForSQLAllCatalogs( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllCatalogs(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && *schema_name == "%") && (table_name && table_name->empty())) { - current_result_set_ = - GetTablesForSQLAllDbSchemas(column_names, call_options_, sql_client_, schema_name, - diagnostics_, metadata_settings_); + current_result_set_ = GetTablesForSQLAllDbSchemas( + column_names, client_options_, call_options_, sql_client_, schema_name, + diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && schema_name->empty()) && (table_name && table_name->empty()) && (table_type && *table_type == "%")) { - current_result_set_ = GetTablesForSQLAllTableTypes( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllTableTypes(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else { if (table_type) { ParseTableTypes(*table_type, table_types); } current_result_set_ = GetTablesForGenericUse( - column_names, call_options_, sql_client_, catalog_name, schema_name, table_name, - table_types, diagnostics_, metadata_settings_); + column_names, client_options_, call_options_, sql_client_, catalog_name, + schema_name, table_name, table_types, diagnostics_, metadata_settings_); } return current_result_set_; @@ -199,7 +214,7 @@ std::shared_ptr FlightSqlStatement::GetTables_V3( std::shared_ptr FlightSqlStatement::GetColumns_V2( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -210,9 +225,9 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_2, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } @@ -220,7 +235,7 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( std::shared_ptr FlightSqlStatement::GetColumns_V3( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -231,15 +246,15 @@ std::shared_ptr FlightSqlStatement::GetColumns_V3( auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_3, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -249,15 +264,15 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_2, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -267,9 +282,9 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_3, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h index 36dc245c1d7..f9b9cd75611 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h @@ -26,12 +26,15 @@ #include "arrow/flight/sql/api.h" #include "arrow/flight/types.h" +#include + namespace arrow::flight::sql::odbc { class FlightSqlStatement : public Statement { private: Diagnostics diagnostics_; std::map attribute_; + FlightClientOptions client_options_; FlightCallOptions call_options_; FlightSqlClient& sql_client_; std::shared_ptr current_result_set_; @@ -46,14 +49,15 @@ class FlightSqlStatement : public Statement { public: FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, - FlightCallOptions call_options, + FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings); + ~FlightSqlStatement(); bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override; - boost::optional GetAttribute(StatementAttributeId attribute) override; + std::optional GetAttribute(StatementAttributeId attribute) override; - boost::optional> Prepare( + std::optional> Prepare( const std::string& query) override; bool ExecutePrepared() override; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_columns.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_columns.cc index 914cd8fa452..a6200c0b1c1 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_columns.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_columns.cc @@ -26,6 +26,8 @@ namespace arrow::flight::sql::odbc { using arrow::Result; +using arrow::Schema; +using arrow::flight::sql::ColumnMetadata; using util::AppendToBuilder; using std::make_optional; @@ -99,10 +101,9 @@ Result> TransformInner( const auto& table_name = reader.GetTableName(); const std::shared_ptr& schema = reader.GetSchema(); if (schema == nullptr) { - // TODO: Remove this if after fixing TODO on GetTablesReader::GetSchema() - // This is because of a problem on Dremio server, where complex types columns - // are being returned without the children types, so we are simply ignoring - // it by now. + // GH-46561 TODO: Test and build the driver against a server that returns + // complex types columns with the children + // types and handle the failure properly. continue; } for (int i = 0; i < schema->num_fields(); ++i) { @@ -126,8 +127,8 @@ Result> TransformInner( ? data_type_v3 : util::ConvertSqlDataTypeFromV3ToV2(data_type_v3); - // TODO: Use `metadata.GetTypeName()` when ARROW-16064 is merged. - const auto& type_name_result = field->metadata()->Get("ARROW:FLIGHT:SQL:TYPE_NAME"); + const auto& type_name_result = metadata.GetTypeName(); + data.type_name = type_name_result.ok() ? type_name_result.ValueOrDie() : util::GetTypeNameFromSqlDataType(data_type_v3); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc index 1af2ab42bff..87c7ac0f532 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc @@ -66,9 +66,9 @@ void ParseTableTypes(const std::string& table_type, } std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetCatalogs(call_options); std::shared_ptr schema; @@ -86,13 +86,15 @@ std::shared_ptr GetTablesForSQLAllCatalogs( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, Diagnostics& diagnostics, + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetDbSchemas(call_options, nullptr, schema_name); @@ -112,14 +114,15 @@ std::shared_ptr GetTablesForSQLAllDbSchemas( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetTableTypes(call_options); std::shared_ptr schema; @@ -137,16 +140,17 @@ std::shared_ptr GetTablesForSQLAllTableTypes( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForGenericUse( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetTables( call_options, catalog_name, schema_name, table_name, false, &table_types); @@ -165,8 +169,9 @@ std::shared_ptr GetTablesForGenericUse( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h index 31abab91cb5..5687134f1eb 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h @@ -40,25 +40,25 @@ void ParseTableTypes(const std::string& table_type, std::vector& table_types); std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, Diagnostics& diagnostics, + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForGenericUse( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); - + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_type_info.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_type_info.cc index e94378b7e04..95c0753603b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_type_info.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_type_info.cc @@ -108,7 +108,7 @@ Result> TransformInner( data.literal_suffix = reader.GetLiteralSuffix(); const auto& create_params = reader.GetCreateParams(); - if (create_params) { + if (create_params && !create_params->empty()) { data.create_params = boost::algorithm::join(*create_params, ","); } else { data.create_params = nullopt; @@ -116,6 +116,8 @@ Result> TransformInner( data.nullable = reader.GetNullable() ? NULLABILITY_NULLABLE : NULLABILITY_NO_NULLS; data.case_sensitive = reader.GetCaseSensitive(); + // GH-47237 return SEARCHABILITY_LIKE_ONLY and SEARCHABILITY_ALL_EXPECT_LIKE for + // appropriate data types data.searchable = reader.GetSearchable() ? SEARCHABILITY_ALL : SEARCHABILITY_NONE; data.unsigned_attribute = reader.GetUnsignedAttribute(); data.fixed_prec_scale = reader.GetFixedPrecScale(); @@ -123,9 +125,9 @@ Result> TransformInner( data.local_type_name = reader.GetLocalTypeName(); data.minimum_scale = reader.GetMinimumScale(); data.maximum_scale = reader.GetMaximumScale(); - data.sql_data_type = + data.sql_data_type = util::GetNonConciseDataType( EnsureRightSqlCharType(static_cast(reader.GetSqlDataType()), - metadata_settings_.use_wide_char); + metadata_settings_.use_wide_char)); data.sql_datetime_sub = util::GetSqlDateTimeSubCode(static_cast(data.data_type)); data.num_prec_radix = reader.GetNumPrecRadix(); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc index 25bf04ea507..ee413d926bf 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc @@ -20,37 +20,71 @@ namespace arrow::flight::sql::odbc { -using arrow::Result; - FlightStreamChunkBuffer::FlightStreamChunkBuffer( - FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options, - const std::shared_ptr& flight_info, size_t queue_capacity) + FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, + const FlightCallOptions& call_options, const std::shared_ptr& flight_info, + size_t queue_capacity) : queue_(queue_capacity) { - // FIXME: Endpoint iteration should consider endpoints may be at different hosts for (const auto& endpoint : flight_info->endpoints()) { const Ticket& ticket = endpoint.ticket; - auto result = flight_sql_client.DoGet(call_options, ticket); + arrow::Result> result; + std::shared_ptr temp_flight_sql_client; + auto endpoint_locations = endpoint.locations; + if (endpoint_locations.empty()) { + // list of Locations needs to be empty to proceed + result = flight_sql_client.DoGet(call_options, ticket); + } else { + // If it is non-empty, the driver should create a FlightSqlClient to connect to one + // of the specified Locations directly. + + // GH-47117: Currently a new FlightClient will be made for each partition that + // returns a non-empty Location, which is then disposed of. It may be better to + // cache clients because a server may report the same Locations. It would also be + // good to identify when the reported Location is the same as the original + // connection's Location and skip creating a FlightClient in that scenario. + + std::unique_ptr temp_flight_client; + util::ThrowIfNotOK(FlightClient::Connect(endpoint_locations[0], client_options) + .Value(&temp_flight_client)); + temp_flight_sql_client.reset(new FlightSqlClient(std::move(temp_flight_client))); + + result = temp_flight_sql_client->DoGet(call_options, ticket); + } + util::ThrowIfNotOK(result.status()); std::shared_ptr stream_reader_ptr(std::move(result.ValueOrDie())); - BlockingQueue>::Supplier supplier = [=] { + BlockingQueue, + std::shared_ptr>>::Supplier supplier = [=] { auto result = stream_reader_ptr->Next(); bool is_not_ok = !result.ok(); bool is_not_empty = result.ok() && (result.ValueOrDie().data != nullptr); - return boost::make_optional(is_not_ok || is_not_empty, std::move(result)); + // If result is valid, save the temp Flight SQL Client for future stream reader + // call. temp_flight_sql_client is intentionally null if the list of endpoint + // Locations is empty. + // After all data is fetched from reader, the temp client is closed. + if (is_not_ok || is_not_empty) { + return std::make_optional( + std::make_pair(std::move(result), temp_flight_sql_client)); + } else { + return std::optional< + std::pair, std::shared_ptr>>{}; + } }; queue_.AddProducer(std::move(supplier)); } } bool FlightStreamChunkBuffer::GetNext(FlightStreamChunk* chunk) { - Result result; - if (!queue_.Pop(&result)) { + std::pair, std::shared_ptr> + closeable_endpoint_stream_pair; + if (!queue_.Pop(&closeable_endpoint_stream_pair)) { return false; } + Result result = closeable_endpoint_stream_pair.first; if (!result.status().ok()) { Close(); throw DriverException(result.status().message()); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h index f59336c984d..772a854eb59 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h @@ -23,11 +23,15 @@ namespace arrow::flight::sql::odbc { +using arrow::Result; + class FlightStreamChunkBuffer { - BlockingQueue> queue_; + BlockingQueue, std::shared_ptr>> + queue_; public: FlightStreamChunkBuffer(FlightSqlClient& flight_sql_client, + const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, size_t queue_capacity = 5); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc new file mode 100644 index 00000000000..a3f23ecaaf9 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/array.h" + +#include "arrow/testing/gtest_util.h" + +#include "arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h" +#include "arrow/flight/sql/odbc/odbc_impl/json_converter.h" +#include "arrow/flight/test_flight_server.h" +#include "arrow/flight/test_util.h" + +#include + +namespace arrow::flight::sql::odbc { + +using arrow::Array; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightDescriptor; +using arrow::flight::FlightEndpoint; +using arrow::flight::Location; +using arrow::flight::Ticket; +using arrow::flight::sql::FlightSqlClient; + +class FlightStreamChunkBufferTest : public ::testing::Test { + // Sets up two mock servers for each test case. + // This is for testing endpoint iteration only. + + protected: + void SetUp() override { + // Set up server 1 + server1 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options1(location1); + ASSERT_OK(server1->Init(options1)); + ASSERT_OK_AND_ASSIGN(server_location1, + Location::ForGrpcTcp("localhost", server1->port())); + + // Set up server 2 + server2 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options2(location2); + ASSERT_OK(server2->Init(options2)); + ASSERT_OK_AND_ASSIGN(server_location2, + Location::ForGrpcTcp("localhost", server2->port())); + + // Make SQL Client that is connected to server 1 + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location1)); + sql_client.reset(new FlightSqlClient(std::move(client))); + } + + void TearDown() override { + ASSERT_OK(server1->Shutdown()); + ASSERT_OK(server2->Shutdown()); + } + + public: + arrow::flight::Location server_location1; + std::shared_ptr server1; + arrow::flight::Location server_location2; + std::shared_ptr server2; + std::shared_ptr sql_client; +}; + +FlightInfo MultipleEndpointsFlightInfo(Location location1, Location location2) { + // Sever will generate random data for `ticket-ints-1` + FlightEndpoint endpoint1({Ticket{"ticket-ints-1"}, {location1}, std::nullopt, {}}); + FlightEndpoint endpoint2({Ticket{"ticket-ints-1"}, {location2}, std::nullopt, {}}); + + FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}}; + + auto schema1 = arrow::flight::ExampleIntSchema(); + + return arrow::flight::MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, + 100000, false, ""); +} + +void VerifyArraysContainIntsOnly(std::shared_ptr intArray) { + for (int64_t i = 0; i < intArray->length(); ++i) { + // null values are accepted + if (!intArray->IsNull(i)) { + auto scalar_data = intArray->GetScalar(i).ValueOrDie(); + std::string scalar_str = ConvertToJson(*scalar_data); + ASSERT_TRUE(std::all_of(scalar_str.begin(), scalar_str.end(), ::isdigit)); + } + } +} + +TEST_F(FlightStreamChunkBufferTest, TestMultipleEndpointsInt) { + FlightClientOptions client_options = FlightClientOptions::Defaults(); + FlightCallOptions options; + FlightInfo info = MultipleEndpointsFlightInfo(server_location1, server_location2); + std::shared_ptr info_ptr = std::make_shared(info); + + FlightStreamChunkBuffer chunk_buffer(*sql_client, client_options, options, info_ptr); + + FlightStreamChunk current_chunk; + + // Server returns 5 batch of results from each endpoints. + // Each batch contains 8 columns + int num_chunks = 0; + while (chunk_buffer.GetNext(¤t_chunk)) { + num_chunks++; + + int num_cols = current_chunk.data->num_columns(); + ASSERT_EQ(8, num_cols); + + for (int i = 0; i < num_cols; i++) { + auto array = current_chunk.data->column(i); + // Each array has random length + ASSERT_GT(array->length(), 0); + + VerifyArraysContainIntsOnly(array); + } + } + + // Verify 5 batches of data is returned by each of the two endpoints. + // In total 10 batches should be returned. + ASSERT_EQ(10, num_chunks); +} +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc index bf2f6b6eca2..cd637be165e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc @@ -199,10 +199,14 @@ inline void SetDefaultIfMissing(std::unordered_map& } // namespace -GetInfoCache::GetInfoCache(FlightCallOptions& call_options, +GetInfoCache::GetInfoCache(FlightClientOptions& client_options, + FlightCallOptions& call_options, std::unique_ptr& client, const std::string& driver_version) - : call_options_(call_options), sql_client_(client), has_server_info_(false) { + : client_options_(client_options), + call_options_(call_options), + sql_client_(client), + has_server_info_(false) { info_[SQL_DRIVER_NAME] = "Arrow Flight ODBC Driver"; info_[SQL_DRIVER_VER] = util::ConvertToDBMSVer(driver_version); @@ -283,7 +287,8 @@ bool GetInfoCache::LoadInfoFromServer() { arrow::Result> result = sql_client_->GetSqlInfo(call_options_, {}); util::ThrowIfNotOK(result.status()); - FlightStreamChunkBuffer chunk_iter(*sql_client_, call_options_, result.ValueOrDie()); + FlightStreamChunkBuffer chunk_iter(*sql_client_, client_options_, call_options_, + result.ValueOrDie()); FlightStreamChunk chunk; bool supports_correlation_name = false; @@ -311,8 +316,8 @@ bool GetInfoCache::LoadInfoFromServer() { std::string server_name( reinterpret_cast(scalar->child_value().get())->view()); - // TODO: Consider creating different properties in GetSqlInfo. - // TODO: Investigate if SQL_SERVER_NAME should just be the host + // GH-47855 TODO: Consider creating different properties in GetSqlInfo. + // GH-47856 TODO: Investigate if SQL_SERVER_NAME should just be the host // address as well. In JDBC, FLIGHT_SQL_SERVER_NAME is only used for // the DatabaseProductName. info_[SQL_SERVER_NAME] = server_name; @@ -913,21 +918,21 @@ bool GetInfoCache::LoadInfoFromServer() { break; } case SqlInfoOptions::SQL_SUPPORTED_RESULT_SET_TYPES: - // Ignored. Warpdrive supports forward-only only. + // Ignored. Arrow ODBC supports forward-only only. break; case SqlInfoOptions::SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED: - // Ignored. Warpdrive supports forward-only only. + // Ignored. Arrow ODBC supports forward-only only. break; case SqlInfoOptions::SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY: - // Ignored. Warpdrive supports forward-only only. + // Ignored. Arrow ODBC supports forward-only only. break; case SqlInfoOptions:: SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE: - // Ignored. Warpdrive supports forward-only only. + // Ignored. Arrow ODBC supports forward-only only. break; case SqlInfoOptions:: SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE: - // Ignored. Warpdrive supports forward-only only. + // Ignored. Arrow ODBC supports forward-only only. break; // List properties @@ -1127,6 +1132,7 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SetDefaultIfMissing(info_, SQL_CONVERT_DECIMAL, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_DOUBLE, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_FLOAT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_FUNCTIONS, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_GUID, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_INTEGER, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_INTERVAL_YEAR_MONTH, static_cast(0)); @@ -1205,6 +1211,7 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_ORDER_BY, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_SELECT, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_TABLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_CONCURRENT_ACTIVITIES, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_CURSOR_NAME_LEN, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_DRIVER_CONNECTIONS, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_IDENTIFIER_LEN, static_cast(65535)); @@ -1224,6 +1231,7 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SetDefaultIfMissing(info_, SQL_OJ_CAPABILITIES, static_cast(SQL_OJ_LEFT | SQL_OJ_RIGHT | SQL_OJ_FULL)); SetDefaultIfMissing(info_, SQL_ORDER_BY_COLUMNS_IN_SELECT, "Y"); + SetDefaultIfMissing(info_, SQL_OUTER_JOINS, "N"); SetDefaultIfMissing(info_, SQL_PROCEDURE_TERM, ""); SetDefaultIfMissing(info_, SQL_PROCEDURES, "N"); SetDefaultIfMissing(info_, SQL_QUOTED_IDENTIFIER_CASE, @@ -1232,6 +1240,7 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SetDefaultIfMissing(info_, SQL_SCHEMA_USAGE, static_cast(SQL_SU_DML_STATEMENTS)); SetDefaultIfMissing(info_, SQL_SEARCH_PATTERN_ESCAPE, "\\"); + SetDefaultIfMissing(info_, SQL_SPECIAL_CHARACTERS, ""); SetDefaultIfMissing( info_, SQL_SERVER_NAME, "Arrow Flight SQL Server"); // This might actually need to be the hostname. @@ -1286,6 +1295,16 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | SQL_FN_TSI_MINUTE | SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); + SetDefaultIfMissing( + info_, SQL_TIMEDATE_FUNCTIONS, + static_cast( + SQL_FN_TD_CURRENT_DATE | SQL_FN_TD_CURRENT_TIME | SQL_FN_TD_CURRENT_TIMESTAMP | + SQL_FN_TD_CURDATE | SQL_FN_TD_CURTIME | SQL_FN_TD_DAYNAME | + SQL_FN_TD_DAYOFMONTH | SQL_FN_TD_DAYOFWEEK | SQL_FN_TD_DAYOFYEAR | + SQL_FN_TD_EXTRACT | SQL_FN_TD_HOUR | SQL_FN_TD_MINUTE | SQL_FN_TD_MONTH | + SQL_FN_TD_MONTHNAME | SQL_FN_TD_NOW | SQL_FN_TD_QUARTER | SQL_FN_TD_SECOND | + SQL_FN_TD_TIMESTAMPADD | SQL_FN_TD_TIMESTAMPDIFF | SQL_FN_TD_WEEK | + SQL_FN_TD_YEAR)); SetDefaultIfMissing(info_, SQL_UNION, static_cast(SQL_U_UNION | SQL_U_UNION_ALL)); SetDefaultIfMissing(info_, SQL_XOPEN_CLI_YEAR, "1995"); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h index d0e0efd159f..693ee000de5 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h @@ -17,26 +17,27 @@ #pragma once -#include "arrow/flight/sql/client.h" -#include "arrow/flight/sql/odbc/odbc_impl/spi/connection.h" - #include #include #include #include +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/odbc/odbc_impl/spi/connection.h" namespace arrow::flight::sql::odbc { class GetInfoCache { private: std::unordered_map info_; + FlightClientOptions& client_options_; FlightCallOptions& call_options_; std::unique_ptr& sql_client_; std::mutex mutex_; std::atomic has_server_info_; public: - GetInfoCache(FlightCallOptions& call_options, std::unique_ptr& client, + GetInfoCache(FlightClientOptions& client_options, FlightCallOptions& call_options, + std::unique_ptr& client, const std::string& driver_version); void SetProperty(uint16_t property, Connection::Info value); Connection::Info GetInfo(uint16_t info_type); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter.cc index db6170f3102..acd68f0eef3 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter.cc @@ -221,7 +221,7 @@ class ScalarToJson : public ScalarVisitor { } Status Visit(const DurationScalar& scalar) override { - // TODO: Append TimeUnit on conversion + // GH-47857 TODO: Append TimeUnit on conversion return ConvertScalarToStringAndWrite(scalar, writer_); } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter.h index 9c0b42748a8..44321398b6f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter.h @@ -17,8 +17,8 @@ #pragma once -#include #include +#include "arrow/type_fwd.h" namespace arrow::flight::sql::odbc { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter_test.cc index a3c3275affd..d6d7a3ed506 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/json_converter_test.cc @@ -19,7 +19,8 @@ #include "arrow/scalar.h" #include "arrow/testing/builder.h" #include "arrow/type.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { @@ -163,7 +164,7 @@ TEST(ConvertToJson, MonthInterval) { } TEST(ConvertToJson, Duration) { - // TODO: Append TimeUnit on conversion + // GH-47857 TODO: Append TimeUnit on conversion ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::SECOND))); ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::MILLI))); ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::MICRO))); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.cc index ead2beada4b..755151fcc0b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.cc @@ -21,6 +21,7 @@ #include "arrow/util/utf8.h" #include "arrow/flight/sql/odbc/odbc_impl/attribute_utils.h" +#include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h" #include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" #include "arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.h" #include "arrow/flight/sql/odbc/odbc_impl/odbc_environment.h" @@ -36,6 +37,7 @@ #include #include #include +#include #include using ODBC::ODBCConnection; @@ -87,160 +89,152 @@ void ODBCConnection::Connect(std::string dsn, attribute_tracking_statement_ = std::make_shared(*this, spi_statement); } -void ODBCConnection::GetInfo(SQLUSMALLINT info_type, SQLPOINTER value, - SQLSMALLINT buffer_length, SQLSMALLINT* output_length, - bool is_unicode) { +SQLRETURN ODBCConnection::GetInfo(SQLUSMALLINT info_type, SQLPOINTER value, + SQLSMALLINT buffer_length, SQLSMALLINT* output_length, + bool is_unicode) { switch (info_type) { case SQL_ACTIVE_ENVIRONMENTS: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; #ifdef SQL_ASYNC_DBC_FUNCTIONS case SQL_ASYNC_DBC_FUNCTIONS: GetAttribute(static_cast(SQL_ASYNC_DBC_NOT_CAPABLE), value, buffer_length, output_length); - break; + return SQL_SUCCESS; #endif case SQL_ASYNC_MODE: GetAttribute(static_cast(SQL_AM_NONE), value, buffer_length, output_length); - break; + return SQL_SUCCESS; #ifdef SQL_ASYNC_NOTIFICATION case SQL_ASYNC_NOTIFICATION: GetAttribute(static_cast(SQL_ASYNC_NOTIFICATION_NOT_CAPABLE), value, buffer_length, output_length); - break; + return SQL_SUCCESS; #endif case SQL_BATCH_ROW_COUNT: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_BATCH_SUPPORT: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_DATA_SOURCE_NAME: - GetStringAttribute(is_unicode, dsn_, true, value, buffer_length, output_length, - GetDiagnostics()); - break; + return GetStringAttribute(is_unicode, dsn_, true, value, buffer_length, + output_length, GetDiagnostics()); case SQL_DRIVER_ODBC_VER: - GetStringAttribute(is_unicode, "03.80", true, value, buffer_length, output_length, - GetDiagnostics()); - break; + return GetStringAttribute(is_unicode, "03.80", true, value, buffer_length, + output_length, GetDiagnostics()); case SQL_DYNAMIC_CURSOR_ATTRIBUTES1: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_DYNAMIC_CURSOR_ATTRIBUTES2: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1: GetAttribute(static_cast(SQL_CA1_NEXT), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES2: GetAttribute(static_cast(SQL_CA2_READ_ONLY_CONCURRENCY), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_FILE_USAGE: GetAttribute(static_cast(SQL_FILE_NOT_SUPPORTED), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_KEYSET_CURSOR_ATTRIBUTES1: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_KEYSET_CURSOR_ATTRIBUTES2: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_MAX_ASYNC_CONCURRENT_STATEMENTS: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_ODBC_INTERFACE_CONFORMANCE: GetAttribute(static_cast(SQL_OIC_CORE), value, buffer_length, output_length); - break; + return SQL_SUCCESS; // case SQL_ODBC_STANDARD_CLI_CONFORMANCE: - mentioned in SQLGetInfo spec with no // description and there is no constant for this. case SQL_PARAM_ARRAY_ROW_COUNTS: GetAttribute(static_cast(SQL_PARC_NO_BATCH), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_PARAM_ARRAY_SELECTS: GetAttribute(static_cast(SQL_PAS_NO_SELECT), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_ROW_UPDATES: - GetStringAttribute(is_unicode, "N", true, value, buffer_length, output_length, - GetDiagnostics()); - break; + return GetStringAttribute(is_unicode, "N", true, value, buffer_length, + output_length, GetDiagnostics()); case SQL_SCROLL_OPTIONS: GetAttribute(static_cast(SQL_SO_FORWARD_ONLY), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_STATIC_CURSOR_ATTRIBUTES1: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_STATIC_CURSOR_ATTRIBUTES2: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_BOOKMARK_PERSISTENCE: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_DESCRIBE_PARAMETER: - GetStringAttribute(is_unicode, "N", true, value, buffer_length, output_length, - GetDiagnostics()); - break; + return GetStringAttribute(is_unicode, "N", true, value, buffer_length, + output_length, GetDiagnostics()); case SQL_MULT_RESULT_SETS: - GetStringAttribute(is_unicode, "N", true, value, buffer_length, output_length, - GetDiagnostics()); - break; + return GetStringAttribute(is_unicode, "N", true, value, buffer_length, + output_length, GetDiagnostics()); case SQL_MULTIPLE_ACTIVE_TXN: - GetStringAttribute(is_unicode, "N", true, value, buffer_length, output_length, - GetDiagnostics()); - break; + return GetStringAttribute(is_unicode, "N", true, value, buffer_length, + output_length, GetDiagnostics()); case SQL_NEED_LONG_DATA_LEN: - GetStringAttribute(is_unicode, "N", true, value, buffer_length, output_length, - GetDiagnostics()); - break; + return GetStringAttribute(is_unicode, "N", true, value, buffer_length, + output_length, GetDiagnostics()); case SQL_TXN_CAPABLE: GetAttribute(static_cast(SQL_TC_NONE), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_TXN_ISOLATION_OPTION: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_TABLE_TERM: - GetStringAttribute(is_unicode, "table", true, value, buffer_length, output_length, - GetDiagnostics()); - break; + return GetStringAttribute(is_unicode, "table", true, value, buffer_length, + output_length, GetDiagnostics()); // Deprecated ODBC 2.x fields required for backwards compatibility. case SQL_ODBC_API_CONFORMANCE: GetAttribute(static_cast(SQL_OAC_LEVEL1), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_FETCH_DIRECTION: GetAttribute(static_cast(SQL_FETCH_NEXT), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_LOCK_TYPES: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_POS_OPERATIONS: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_POSITIONED_STATEMENTS: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_SCROLL_CONCURRENCY: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; case SQL_STATIC_SENSITIVITY: GetAttribute(static_cast(0), value, buffer_length, output_length); - break; + return SQL_SUCCESS; // Driver-level string properties. case SQL_USER_NAME: case SQL_COLUMN_ALIAS: case SQL_DBMS_NAME: case SQL_DBMS_VER: - case SQL_DRIVER_NAME: // TODO: This should be the driver's filename and shouldn't - // come from the SPI. + case SQL_DRIVER_NAME: // GH-47858 TODO: This should be the driver's filename and + // shouldn't come from the SPI. case SQL_DRIVER_VER: case SQL_SEARCH_PATTERN_ESCAPE: case SQL_SERVER_NAME: @@ -266,10 +260,9 @@ void ODBCConnection::GetInfo(SQLUSMALLINT info_type, SQLPOINTER value, case SQL_SPECIAL_CHARACTERS: case SQL_XOPEN_CLI_YEAR: { const auto& info = spi_connection_->GetInfo(info_type); - const std::string& info_value = boost::get(info); - GetStringAttribute(is_unicode, info_value, true, value, buffer_length, - output_length, GetDiagnostics()); - break; + const std::string& info_value = std::get(info); + return GetStringAttribute(is_unicode, info_value, true, value, buffer_length, + output_length, GetDiagnostics()); } // Driver-level 32-bit integer properties. @@ -357,9 +350,9 @@ void ODBCConnection::GetInfo(SQLUSMALLINT info_type, SQLPOINTER value, case SQL_SQL92_VALUE_EXPRESSIONS: case SQL_STANDARD_CLI_CONFORMANCE: { const auto& info = spi_connection_->GetInfo(info_type); - uint32_t info_value = boost::get(info); + uint32_t info_value = std::get(info); GetAttribute(info_value, value, buffer_length, output_length); - break; + return SQL_SUCCESS; } // Driver-level 16-bit integer properties. @@ -392,9 +385,9 @@ void ODBCConnection::GetInfo(SQLUSMALLINT info_type, SQLPOINTER value, case SQL_ODBC_SQL_CONFORMANCE: case SQL_ODBC_SAG_CLI_CONFORMANCE: { const auto& info = spi_connection_->GetInfo(info_type); - uint16_t info_value = boost::get(info); + uint16_t info_value = std::get(info); GetAttribute(info_value, value, buffer_length, output_length); - break; + return SQL_SUCCESS; } // Special case - SQL_DATABASE_NAME is an alias for SQL_ATTR_CURRENT_CATALOG. @@ -403,14 +396,16 @@ void ODBCConnection::GetInfo(SQLUSMALLINT info_type, SQLPOINTER value, if (!attr) { throw DriverException("Optional feature not supported.", "HYC00"); } - const std::string& info_value = boost::get(*attr); - GetStringAttribute(is_unicode, info_value, true, value, buffer_length, - output_length, GetDiagnostics()); - break; + const std::string& info_value = std::get(*attr); + return GetStringAttribute(is_unicode, info_value, true, value, buffer_length, + output_length, GetDiagnostics()); } default: - throw DriverException("Unknown SQLGetInfo type: " + std::to_string(info_type)); + throw DriverException("Unknown SQLGetInfo type: " + std::to_string(info_type), + "HY096"); } + + return SQL_ERROR; } void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, @@ -419,7 +414,7 @@ void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, bool successfully_written = false; switch (attribute) { // Internal connection attributes -#ifdef SQL_ATR_ASYNC_DBC_EVENT +#ifdef SQL_ATTR_ASYNC_DBC_EVENT case SQL_ATTR_ASYNC_DBC_EVENT: throw DriverException("Optional feature not supported.", "HYC00"); #endif @@ -427,7 +422,7 @@ void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, case SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE: throw DriverException("Optional feature not supported.", "HYC00"); #endif -#ifdef SQL_ATTR_ASYNC_PCALLBACK +#ifdef SQL_ATTR_ASYNC_DBC_PCALLBACK case SQL_ATTR_ASYNC_DBC_PCALLBACK: throw DriverException("Optional feature not supported.", "HYC00"); #endif @@ -455,7 +450,7 @@ void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, throw DriverException("Cannot set read-only attribute", "HY092"); case SQL_ATTR_TRACE: // DM-only throw DriverException("Cannot set read-only attribute", "HY092"); - case SQL_ATTR_TRACEFILE: + case SQL_ATTR_TRACEFILE: // DM-only throw DriverException("Optional feature not supported.", "HYC00"); case SQL_ATTR_TRANSLATE_LIB: throw DriverException("Optional feature not supported.", "HYC00"); @@ -528,58 +523,58 @@ void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, } } -void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, - SQLINTEGER buffer_length, SQLINTEGER* output_length, - bool is_unicode) { - boost::optional spi_attribute; +SQLRETURN ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, + SQLINTEGER buffer_length, + SQLINTEGER* output_length, bool is_unicode) { + std::optional spi_attribute; switch (attribute) { // Internal connection attributes -#ifdef SQL_ATR_ASYNC_DBC_EVENT +#ifdef SQL_ATTR_ASYNC_DBC_EVENT case SQL_ATTR_ASYNC_DBC_EVENT: GetAttribute(static_cast(NULL), value, buffer_length, output_length); - return; + return SQL_SUCCESS; #endif #ifdef SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE case SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE: GetAttribute(static_cast(SQL_ASYNC_DBC_ENABLE_OFF), value, buffer_length, output_length); - return; + return SQL_SUCCESS; #endif -#ifdef SQL_ATTR_ASYNC_PCALLBACK +#ifdef SQL_ATTR_ASYNC_DBC_PCALLBACK case SQL_ATTR_ASYNC_DBC_PCALLBACK: GetAttribute(static_cast(NULL), value, buffer_length, output_length); - return; + return SQL_SUCCESS; #endif #ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT case SQL_ATTR_ASYNC_DBC_PCONTEXT: GetAttribute(static_cast(NULL), value, buffer_length, output_length); - return; + return SQL_SUCCESS; #endif case SQL_ATTR_ASYNC_ENABLE: GetAttribute(static_cast(SQL_ASYNC_ENABLE_OFF), value, buffer_length, output_length); - return; + return SQL_SUCCESS; case SQL_ATTR_AUTO_IPD: GetAttribute(static_cast(SQL_FALSE), value, buffer_length, output_length); - return; + return SQL_SUCCESS; case SQL_ATTR_AUTOCOMMIT: GetAttribute(static_cast(SQL_AUTOCOMMIT_ON), value, buffer_length, output_length); - return; + return SQL_SUCCESS; #ifdef SQL_ATTR_DBC_INFO_TOKEN case SQL_ATTR_DBC_INFO_TOKEN: throw DriverException("Cannot read set-only attribute", "HY092"); #endif case SQL_ATTR_ENLIST_IN_DTC: GetAttribute(static_cast(NULL), value, buffer_length, output_length); - return; + return SQL_SUCCESS; case SQL_ATTR_ODBC_CURSORS: // DM-only. throw DriverException("Invalid attribute", "HY092"); case SQL_ATTR_QUIET_MODE: GetAttribute(static_cast(NULL), value, buffer_length, output_length); - return; + return SQL_SUCCESS; case SQL_ATTR_TRACE: // DM-only throw DriverException("Invalid attribute", "HY092"); case SQL_ATTR_TRACEFILE: @@ -589,17 +584,16 @@ void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, case SQL_ATTR_TRANSLATE_OPTION: throw DriverException("Optional feature not supported.", "HYC00"); case SQL_ATTR_TXN_ISOLATION: - throw DriverException("Optional feature not supported.", "HCY00"); + throw DriverException("Optional feature not supported.", "HYC00"); case SQL_ATTR_CURRENT_CATALOG: { const auto& catalog = spi_connection_->GetAttribute(Connection::CURRENT_CATALOG); if (!catalog) { throw DriverException("Optional feature not supported.", "HYC00"); } - const std::string& info_value = boost::get(*catalog); - GetStringAttribute(is_unicode, info_value, true, value, buffer_length, - output_length, GetDiagnostics()); - return; + const std::string& info_value = std::get(*catalog); + return GetStringAttribute(is_unicode, info_value, true, value, buffer_length, + output_length, GetDiagnostics()); } // These all are uint32_t attributes. @@ -626,8 +620,9 @@ void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, throw DriverException("Invalid attribute", "HY092"); } - GetAttribute(static_cast(boost::get(*spi_attribute)), value, + GetAttribute(static_cast(std::get(*spi_attribute)), value, buffer_length, output_length); + return SQL_SUCCESS; } void ODBCConnection::Disconnect() { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.h index 2e5ab57ad49..5b80473c55e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.h @@ -17,8 +17,8 @@ #pragma once -#include #include "arrow/flight/sql/odbc/odbc_impl/odbc_handle.h" +#include "arrow/flight/sql/odbc/odbc_impl/spi/connection.h" #include #include @@ -42,6 +42,9 @@ class ODBCConnection : public ODBCHandle { ODBCConnection(const ODBCConnection&) = delete; ODBCConnection& operator=(const ODBCConnection&) = delete; + /// \brief Constructor for ODBCConnection. + /// \param[in] environment the parent environment. + /// \param[in] spi_connection the underlying spi connection. ODBCConnection(ODBCEnvironment& environment, std::shared_ptr spi_connection); @@ -49,16 +52,22 @@ class ODBCConnection : public ODBCHandle { const std::string& GetDSN() const; bool IsConnected() const; + + /// \brief Connect to Arrow Flight SQL server. + /// \param[in] dsn the dsn name. + /// \param[in] properties the connection property map extracted from connection string. + /// \param[out] missing_properties report the properties that are missing void Connect(std::string dsn, const arrow::flight::sql::odbc::Connection::ConnPropertyMap& properties, std::vector& missing_properties); - void GetInfo(SQLUSMALLINT info_type, SQLPOINTER value, SQLSMALLINT buffer_length, - SQLSMALLINT* output_length, bool is_unicode); + SQLRETURN GetInfo(SQLUSMALLINT info_type, SQLPOINTER value, SQLSMALLINT buffer_length, + SQLSMALLINT* output_length, bool is_unicode); void SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER string_length, - bool isUnicode); - void GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER buffer_length, - SQLINTEGER* output_length, bool is_unicode); + bool is_unicode); + SQLRETURN GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, + SQLINTEGER buffer_length, SQLINTEGER* output_length, + bool is_unicode); ~ODBCConnection() = default; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.cc index d2b7f8865ca..8c856fdbd6b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.cc @@ -62,7 +62,7 @@ ODBCDescriptor::ODBCDescriptor(Diagnostics& base_diagnostics, ODBCConnection* co parent_statement_(stmt), array_status_ptr_(nullptr), bind_offset_ptr_(nullptr), - rows_processed_ptr_(nullptr), + rows_proccessed_ptr_(nullptr), array_size_(1), bind_type_(SQL_BIND_BY_COLUMN), highest_one_based_bound_record_(0), @@ -109,7 +109,7 @@ void ODBCDescriptor::SetHeaderField(SQLSMALLINT field_identifier, SQLPOINTER val has_bindings_changed_ = true; break; case SQL_DESC_ROWS_PROCESSED_PTR: - SetPointerAttribute(value, rows_processed_ptr_); + SetPointerAttribute(value, rows_proccessed_ptr_); has_bindings_changed_ = true; break; case SQL_DESC_COUNT: { @@ -273,10 +273,12 @@ void ODBCDescriptor::GetHeaderField(SQLSMALLINT field_identifier, SQLPOINTER val GetAttribute(bind_type_, value, buffer_length, output_length); break; case SQL_DESC_ROWS_PROCESSED_PTR: - GetAttribute(rows_processed_ptr_, value, buffer_length, output_length); + GetAttribute(rows_proccessed_ptr_, value, buffer_length, output_length); break; case SQL_DESC_COUNT: { - GetAttribute(highest_one_based_bound_record_, value, buffer_length, output_length); + // highest_one_based_bound_record_ equals number of records + 1 + GetAttribute(static_cast(highest_one_based_bound_record_ - 1), value, + buffer_length, output_length); break; } default: @@ -310,54 +312,55 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident throw DriverException("Invalid descriptor index", "07009"); } - // TODO: Restrict fields based on AppDescriptor IPD, and IRD. + // GH-47867 TODO: Restrict fields based on AppDescriptor IPD, and IRD. + bool length_in_bytes = true; SQLSMALLINT zero_based_record = record_number - 1; const DescriptorRecord& record = records_[zero_based_record]; switch (field_identifier) { case SQL_DESC_BASE_COLUMN_NAME: - GetAttributeUTF8(record.base_column_name, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.base_column_name, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_BASE_TABLE_NAME: - GetAttributeUTF8(record.base_table_name, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.base_table_name, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_CATALOG_NAME: - GetAttributeUTF8(record.catalog_name, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.catalog_name, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_LABEL: - GetAttributeUTF8(record.label, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.label, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_LITERAL_PREFIX: - GetAttributeUTF8(record.literal_prefix, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.literal_prefix, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_LITERAL_SUFFIX: - GetAttributeUTF8(record.literal_suffix, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.literal_suffix, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_LOCAL_TYPE_NAME: - GetAttributeUTF8(record.local_type_name, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.local_type_name, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_NAME: - GetAttributeUTF8(record.name, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.name, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_SCHEMA_NAME: - GetAttributeUTF8(record.schema_name, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.schema_name, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_TABLE_NAME: - GetAttributeUTF8(record.table_name, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.table_name, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_TYPE_NAME: - GetAttributeUTF8(record.type_name, value, buffer_length, output_length, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.type_name, length_in_bytes, value, buffer_length, + output_length, GetDiagnostics()); break; case SQL_DESC_DATA_PTR: @@ -367,7 +370,7 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident case SQL_DESC_OCTET_LENGTH_PTR: GetAttribute(record.indicator_ptr, value, buffer_length, output_length); break; - + case SQL_COLUMN_LENGTH: // ODBC 2.0 case SQL_DESC_LENGTH: GetAttribute(record.length, value, buffer_length, output_length); break; @@ -407,12 +410,14 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident case SQL_DESC_PARAMETER_TYPE: GetAttribute(record.param_type, value, buffer_length, output_length); break; + case SQL_COLUMN_PRECISION: // ODBC 2.0 case SQL_DESC_PRECISION: GetAttribute(record.precision, value, buffer_length, output_length); break; case SQL_DESC_ROWVER: GetAttribute(record.row_ver, value, buffer_length, output_length); break; + case SQL_COLUMN_SCALE: // ODBC 2.0 case SQL_DESC_SCALE: GetAttribute(record.scale, value, buffer_length, output_length); break; @@ -479,6 +484,8 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) { for (size_t i = 0; i < records_.size(); ++i) { size_t one_based_index = i + 1; + int16_t concise_type = rsmd->GetConciseType(one_based_index); + records_[i].base_column_name = rsmd->GetBaseColumnName(one_based_index); records_[i].base_table_name = rsmd->GetBaseTableName(one_based_index); records_[i].catalog_name = rsmd->GetCatalogName(one_based_index); @@ -489,9 +496,8 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) { records_[i].name = rsmd->GetName(one_based_index); records_[i].schema_name = rsmd->GetSchemaName(one_based_index); records_[i].table_name = rsmd->GetTableName(one_based_index); - records_[i].type_name = rsmd->GetTypeName(one_based_index); - records_[i].concise_type = GetSqlTypeForODBCVersion( - rsmd->GetConciseType(one_based_index), is_2x_connection_); + records_[i].type_name = rsmd->GetTypeName(one_based_index, concise_type); + records_[i].concise_type = GetSqlTypeForODBCVersion(concise_type, is_2x_connection_); records_[i].data_ptr = nullptr; records_[i].indicator_ptr = nullptr; records_[i].display_size = rsmd->GetColumnDisplaySize(one_based_index); @@ -501,10 +507,12 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) { rsmd->IsAutoUnique(one_based_index) ? SQL_TRUE : SQL_FALSE; records_[i].case_sensitive = rsmd->IsCaseSensitive(one_based_index) ? SQL_TRUE : SQL_FALSE; - records_[i].datetime_interval_precision; // TODO - update when rsmd adds this + records_[i].datetime_interval_precision; // GH-47869 TODO implement + // `SQL_DESC_DATETIME_INTERVAL_PRECISION` SQLINTEGER num_prec_radix = rsmd->GetNumPrecRadix(one_based_index); records_[i].num_prec_radix = num_prec_radix > 0 ? num_prec_radix : 0; - records_[i].datetime_interval_code; // TODO + records_[i].datetime_interval_code; // GH-47868 TODO implement + // `SQL_DESC_DATETIME_INTERVAL_CODE` records_[i].fixed_prec_scale = rsmd->IsFixedPrecScale(one_based_index) ? SQL_TRUE : SQL_FALSE; records_[i].nullable = rsmd->IsNullable(one_based_index); @@ -573,5 +581,5 @@ void ODBCDescriptor::SetDataPtrOnRecord(SQLPOINTER data_ptr, SQLSMALLINT record_ } void DescriptorRecord::CheckConsistency() { - // TODO + // GH-47870 TODO implement } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.h index d28adbc91d2..8a6cab82be0 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.h @@ -126,8 +126,8 @@ class ODBCDescriptor : public ODBCHandle { inline SQLUSMALLINT* GetArrayStatusPtr() { return array_status_ptr_; } inline void SetRowsProcessed(SQLULEN rows) { - if (rows_processed_ptr_) { - *rows_processed_ptr_ = rows; + if (rows_proccessed_ptr_) { + *rows_proccessed_ptr_ = rows; } } @@ -144,7 +144,7 @@ class ODBCDescriptor : public ODBCHandle { ODBCStatement* parent_statement_; SQLUSMALLINT* array_status_ptr_; SQLULEN* bind_offset_ptr_; - SQLULEN* rows_processed_ptr_; + SQLULEN* rows_proccessed_ptr_; SQLULEN array_size_; SQLINTEGER bind_type_; SQLSMALLINT highest_one_based_bound_record_; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_handle.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_handle.h index b3fd6e371a2..4e46c7be4e2 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_handle.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_handle.h @@ -17,13 +17,14 @@ #pragma once -#include -#include +// platform.h includes windows.h, so it needs to be included first +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" #include #include #include #include +#include "arrow/flight/sql/odbc/odbc_impl/diagnostics.h" /** * @brief An abstraction over a generic ODBC handle. diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc index d452e77db1d..6ae9cde051f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc @@ -29,8 +29,8 @@ #include #include #include -#include #include +#include #include using ODBC::DescriptorRecord; @@ -129,6 +129,9 @@ SQLSMALLINT getc_typeForSQLType(const DescriptorRecord& record) { case SQL_WLONGVARCHAR: return SQL_C_WCHAR; + case SQL_BIT: + return SQL_C_BIT; + case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: @@ -146,13 +149,20 @@ SQLSMALLINT getc_typeForSQLType(const DescriptorRecord& record) { case SQL_BIGINT: return record.is_unsigned ? SQL_C_UBIGINT : SQL_C_SBIGINT; + case SQL_NUMERIC: + case SQL_DECIMAL: + return SQL_C_NUMERIC; + + case SQL_FLOAT: case SQL_REAL: return SQL_C_FLOAT; - case SQL_FLOAT: case SQL_DOUBLE: return SQL_C_DOUBLE; + case SQL_GUID: + return SQL_C_GUID; + case SQL_DATE: case SQL_TYPE_DATE: return SQL_C_TYPE_DATE; @@ -165,32 +175,32 @@ SQLSMALLINT getc_typeForSQLType(const DescriptorRecord& record) { case SQL_TYPE_TIMESTAMP: return SQL_C_TYPE_TIMESTAMP; - case SQL_C_INTERVAL_DAY: - return SQL_INTERVAL_DAY; - case SQL_C_INTERVAL_DAY_TO_HOUR: - return SQL_INTERVAL_DAY_TO_HOUR; - case SQL_C_INTERVAL_DAY_TO_MINUTE: - return SQL_INTERVAL_DAY_TO_MINUTE; - case SQL_C_INTERVAL_DAY_TO_SECOND: - return SQL_INTERVAL_DAY_TO_SECOND; - case SQL_C_INTERVAL_HOUR: - return SQL_INTERVAL_HOUR; - case SQL_C_INTERVAL_HOUR_TO_MINUTE: - return SQL_INTERVAL_HOUR_TO_MINUTE; - case SQL_C_INTERVAL_HOUR_TO_SECOND: - return SQL_INTERVAL_HOUR_TO_SECOND; - case SQL_C_INTERVAL_MINUTE: - return SQL_INTERVAL_MINUTE; - case SQL_C_INTERVAL_MINUTE_TO_SECOND: - return SQL_INTERVAL_MINUTE_TO_SECOND; - case SQL_C_INTERVAL_SECOND: - return SQL_INTERVAL_SECOND; - case SQL_C_INTERVAL_YEAR: - return SQL_INTERVAL_YEAR; - case SQL_C_INTERVAL_YEAR_TO_MONTH: - return SQL_INTERVAL_YEAR_TO_MONTH; - case SQL_C_INTERVAL_MONTH: - return SQL_INTERVAL_MONTH; + case SQL_INTERVAL_DAY: + return SQL_C_INTERVAL_DAY; + case SQL_INTERVAL_DAY_TO_HOUR: + return SQL_C_INTERVAL_DAY_TO_HOUR; + case SQL_INTERVAL_DAY_TO_MINUTE: + return SQL_C_INTERVAL_DAY_TO_MINUTE; + case SQL_INTERVAL_DAY_TO_SECOND: + return SQL_C_INTERVAL_DAY_TO_SECOND; + case SQL_INTERVAL_HOUR: + return SQL_C_INTERVAL_HOUR; + case SQL_INTERVAL_HOUR_TO_MINUTE: + return SQL_C_INTERVAL_HOUR_TO_MINUTE; + case SQL_INTERVAL_HOUR_TO_SECOND: + return SQL_C_INTERVAL_HOUR_TO_SECOND; + case SQL_INTERVAL_MINUTE: + return SQL_C_INTERVAL_MINUTE; + case SQL_INTERVAL_MINUTE_TO_SECOND: + return SQL_C_INTERVAL_MINUTE_TO_SECOND; + case SQL_INTERVAL_SECOND: + return SQL_C_INTERVAL_SECOND; + case SQL_INTERVAL_YEAR: + return SQL_C_INTERVAL_YEAR; + case SQL_INTERVAL_YEAR_TO_MONTH: + return SQL_C_INTERVAL_YEAR_TO_MONTH; + case SQL_INTERVAL_MONTH: + return SQL_C_INTERVAL_MONTH; default: throw DriverException("Unknown SQL type: " + std::to_string(record.concise_type), @@ -240,7 +250,7 @@ void ODBCStatement::CopyAttributesFromConnection(ODBCConnection& connection) { ODBCStatement& tracking_statement = connection.GetTrackingStatement(); // Get abstraction attributes and copy to this spi_statement_. - // Possible ODBC attributes are below, but many of these are not supported by warpdrive + // Possible ODBC attributes are below, but many of these are not supported by Arrow ODBC // or ODBCAbstaction: // SQL_ATTR_ASYNC_ENABLE: // SQL_ATTR_METADATA_ID: @@ -273,7 +283,7 @@ void ODBCStatement::CopyAttributesFromConnection(ODBCConnection& connection) { bool ODBCStatement::IsPrepared() const { return is_prepared_; } void ODBCStatement::Prepare(const std::string& query) { - boost::optional > metadata = + std::optional > metadata = spi_statement_->Prepare(query); if (metadata) { @@ -306,7 +316,8 @@ void ODBCStatement::ExecuteDirect(const std::string& query) { is_prepared_ = false; } -bool ODBCStatement::Fetch(size_t rows) { +bool ODBCStatement::Fetch(size_t rows, SQLULEN* row_count_ptr, + SQLUSMALLINT* row_status_array) { if (has_reached_end_of_result_) { ird_->SetRowsProcessed(0); return false; @@ -317,7 +328,7 @@ bool ODBCStatement::Fetch(size_t rows) { } if (current_ard_->HaveBindingsChanged()) { - // TODO: Deal handle when offset != buffer_length. + // GH-47871 TODO: handle when offset != buffer_length. // Wipe out all bindings in the ResultSet. // Note that the number of ARD records can both be more or less @@ -339,11 +350,24 @@ bool ODBCStatement::Fetch(size_t rows) { current_ard_->NotifyBindingsHavePropagated(); } - size_t rows_fetched = current_result_->Move(rows, current_ard_->GetBindOffset(), - current_ard_->GetBoundStructOffset(), - ird_->GetArrayStatusPtr()); + uint16_t* array_status_ptr; + if (row_status_array) { + // For SQLExtendedFetch only + array_status_ptr = row_status_array; + } else { + array_status_ptr = ird_->GetArrayStatusPtr(); + } + + size_t rows_fetched = + current_result_->Move(rows, current_ard_->GetBindOffset(), + current_ard_->GetBoundStructOffset(), array_status_ptr); ird_->SetRowsProcessed(static_cast(rows_fetched)); + if (row_count_ptr) { + // For SQLExtendedFetch only + *row_count_ptr = rows_fetched; + } + row_number_ += rows_fetched; has_reached_end_of_result_ = rows_fetched != rows; return rows_fetched != 0; @@ -352,7 +376,7 @@ bool ODBCStatement::Fetch(size_t rows) { void ODBCStatement::GetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER output, SQLINTEGER buffer_size, SQLINTEGER* str_len_ptr, bool is_unicode) { - boost::optional spi_attribute; + std::optional spi_attribute; switch (statement_attribute) { // Descriptor accessor attributes case SQL_ATTR_APP_PARAM_DESC: @@ -375,6 +399,14 @@ void ODBCStatement::GetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER outpu return; case SQL_ATTR_PARAM_BIND_TYPE: current_apd_->GetHeaderField(SQL_DESC_BIND_TYPE, output, buffer_size, str_len_ptr); + if (output) { + // Convert SQLINTEGER output to SQLULEN, since SQL_DESC_BIND_TYPE is SQLINTEGER + // and SQL_ATTR_PARAM_BIND_TYPE is SQLULEN + SQLINTEGER* output_int_ptr = reinterpret_cast(output); + SQLINTEGER output_int = *output_int_ptr; + SQLULEN* typed_output = reinterpret_cast(output); + *typed_output = static_cast(output_int); + } return; case SQL_ATTR_PARAM_OPERATION_PTR: current_apd_->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, buffer_size, @@ -398,6 +430,14 @@ void ODBCStatement::GetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER outpu return; case SQL_ATTR_ROW_BIND_TYPE: current_ard_->GetHeaderField(SQL_DESC_BIND_TYPE, output, buffer_size, str_len_ptr); + if (output) { + // Convert SQLINTEGER output to SQLULEN, since SQL_DESC_BIND_TYPE is SQLINTEGER + // and SQL_ATTR_ROW_BIND_TYPE is SQLULEN + SQLINTEGER* output_int_ptr = reinterpret_cast(output); + SQLINTEGER output_int = *output_int_ptr; + SQLULEN* typed_output = reinterpret_cast(output); + *typed_output = static_cast(output_int); + } return; case SQL_ATTR_ROW_OPERATION_PTR: current_ard_->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, buffer_size, @@ -496,7 +536,7 @@ void ODBCStatement::GetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER outpu } if (spi_attribute) { - GetAttribute(static_cast(boost::get(*spi_attribute)), output, + GetAttribute(static_cast(std::get(*spi_attribute)), output, buffer_size, str_len_ptr); return; } @@ -580,6 +620,7 @@ void ODBCStatement::SetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER value return; case SQL_ATTR_ASYNC_ENABLE: + throw DriverException("Unsupported attribute", "HYC00"); #ifdef SQL_ATTR_ASYNC_STMT_EVENT case SQL_ATTR_ASYNC_STMT_EVENT: throw DriverException("Unsupported attribute", "HYC00"); @@ -627,7 +668,7 @@ void ODBCStatement::SetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER value CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_UB_OFF)); return; case SQL_ATTR_RETRIEVE_DATA: - CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_TRUE)); + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_RD_ON)); return; case SQL_ROWSET_SIZE: SetAttribute(value, rowset_size_); @@ -677,7 +718,7 @@ void ODBCStatement::RevertAppDescriptor(bool isApd) { void ODBCStatement::CloseCursor(bool suppress_errors) { if (!suppress_errors && !current_result_) { - throw DriverException("Invalid cursor state", "28000"); + throw DriverException("Invalid cursor state", "24000"); } if (current_result_) { @@ -691,9 +732,9 @@ void ODBCStatement::CloseCursor(bool suppress_errors) { has_reached_end_of_result_ = false; } -bool ODBCStatement::GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, - SQLPOINTER data_ptr, SQLLEN buffer_length, - SQLLEN* indicator_ptr) { +SQLRETURN ODBCStatement::GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, + SQLPOINTER data_ptr, SQLLEN buffer_length, + SQLLEN* indicator_ptr) { if (record_number == 0) { throw DriverException("Bookmarks are not supported", "07009"); } else if (record_number > ird_->GetRecords().size()) { @@ -703,7 +744,7 @@ bool ODBCStatement::GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, SQLSMALLINT evaluated_c_type = c_type; - // TODO: Get proper default precision and scale from abstraction. + // GH-47872 TODO: Get proper default precision and scale from abstraction. int precision = 38; // arrow::Decimal128Type::kMaxPrecision; int scale = 0; @@ -735,6 +776,35 @@ bool ODBCStatement::GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, data_ptr, buffer_length, indicator_ptr); } +SQLRETURN ODBCStatement::GetMoreResults() { + // Multiple result sets are not supported. + if (current_result_) { + return SQL_NO_DATA; + } else { + throw DriverException("Function sequence error", "HY010"); + } +} + +void ODBCStatement::GetColumnCount(SQLSMALLINT* column_count_ptr) { + if (!column_count_ptr) { + // columnCountPtr is not valid, do nothing as ODBC spec does not mention this as an + // error + return; + } + size_t column_count = ird_->GetRecords().size(); + *column_count_ptr = static_cast(column_count); +} + +void ODBCStatement::GetRowCount(SQLLEN* row_count_ptr) { + if (!row_count_ptr) { + // row_count_ptr is not valid, do nothing as ODBC spec does not mention this as an + // error + return; + } + // Will always be -1 (number of rows unknown) if only SELECT is supported + *row_count_ptr = -1; +} + void ODBCStatement::ReleaseStatement() { CloseCursor(true); connection_.DropStatement(this); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h index 8e128db1bda..d4e37858f23 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h @@ -17,9 +17,11 @@ #pragma once +// platform.h platform.h includes windows.h so it needs to be included first +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + #include "arrow/flight/sql/odbc/odbc_impl/odbc_handle.h" -#include #include #include #include @@ -60,8 +62,10 @@ class ODBCStatement : public ODBCHandle { /** * @brief Returns true if the number of rows fetch was greater than zero. + * row_count_ptr and row_status_array are optional arguments, they are only needed for + * SQLExtendedFetch */ - bool Fetch(size_t rows); + bool Fetch(size_t rows, SQLULEN* row_count_ptr = 0, SQLUSMALLINT* row_status_array = 0); bool IsPrepared() const; void GetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER output, @@ -69,6 +73,11 @@ class ODBCStatement : public ODBCHandle { void SetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER value, SQLINTEGER buffer_size, bool is_unicode); + /** + * @brief Revert back to implicitly allocated internal descriptors. + * isApd as True indicates APD descritor is to be reverted. + * isApd as False indicates ARD descritor is to be reverted. + */ void RevertAppDescriptor(bool is_apd); inline ODBCDescriptor* GetIRD() { return ird_.get(); } @@ -77,8 +86,20 @@ class ODBCStatement : public ODBCHandle { inline SQLULEN GetRowsetSize() { return rowset_size_; } - bool GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, SQLPOINTER data_ptr, - SQLLEN buffer_length, SQLLEN* indicator_ptr); + SQLRETURN GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, SQLPOINTER data_ptr, + SQLLEN buffer_length, SQLLEN* indicator_ptr); + + SQLRETURN GetMoreResults(); + + /** + * @brief Get number of columns from data set + */ + void GetColumnCount(SQLSMALLINT* column_count_ptr); + + /** + * @brief Get number of rows affected by an UPDATE, INSERT, or DELETE statement + */ + void GetRowCount(SQLLEN* row_count_ptr); /** * @brief Closes the cursor. This does _not_ un-prepare the statement or change diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/parse_table_types_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/parse_table_types_test.cc index cf1e5930a82..3749c276d4f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/parse_table_types_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/parse_table_types_test.cc @@ -18,7 +18,8 @@ #include "arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h" #include "arrow/flight/sql/odbc/odbc_impl/platform.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/record_batch_transformer.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/record_batch_transformer.h index 539583aac29..15548b52c63 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/record_batch_transformer.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/record_batch_transformer.h @@ -17,9 +17,9 @@ #pragma once -#include -#include #include +#include "arrow/flight/client.h" +#include "arrow/type.h" namespace arrow::flight::sql::odbc { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/record_batch_transformer_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/record_batch_transformer_test.cc index 9727167a500..a5e094317ea 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/record_batch_transformer_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/record_batch_transformer_test.cc @@ -20,7 +20,8 @@ #include "arrow/flight/sql/odbc/odbc_impl/platform.h" #include "arrow/record_batch.h" #include "arrow/testing/builder.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { namespace { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/scalar_function_reporter.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/scalar_function_reporter.h index f4855812bf9..e9cf18dc55a 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/scalar_function_reporter.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/scalar_function_reporter.h @@ -17,7 +17,7 @@ #pragma once -#include +#include "arrow/type.h" namespace arrow::flight::sql::odbc { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/connection.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/connection.h index 7a8243e7859..fdfb2d2ea8b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/connection.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/connection.h @@ -18,11 +18,12 @@ #pragma once #include -#include #include #include #include +#include #include +#include #include #include "arrow/flight/sql/odbc/odbc_impl/diagnostics.h" @@ -63,8 +64,8 @@ class Connection { PACKET_SIZE, // uint32_t - The Packet Size }; - typedef boost::variant Attribute; - typedef boost::variant Info; + typedef std::variant Attribute; + typedef std::variant Info; typedef PropertyMap ConnPropertyMap; /// \brief Establish the connection. @@ -88,7 +89,7 @@ class Connection { /// \brief Retrieve a connection attribute /// \param attribute [in] Attribute to be retrieved. - virtual boost::optional GetAttribute( + virtual std::optional GetAttribute( Connection::AttributeId attribute) = 0; /// \brief Retrieves info from the database (see ODBC's SQLGetInfo). diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set.h index a273d62f63d..550ca357258 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set.h @@ -24,6 +24,8 @@ #include "arrow/flight/sql/odbc/odbc_impl/types.h" +#include + namespace arrow::flight::sql::odbc { class ResultSetMetadata; @@ -87,10 +89,10 @@ class ResultSet { /// \param buffer Target buffer to be populated. /// \param buffer_length Target buffer length. /// \param strlen_buffer Buffer that holds the length of value being fetched. - /// \returns true if there is more data to fetch from the current cell; - /// false if the whole value was already fetched. - virtual bool GetData(int column, int16_t target_type, int precision, int scale, - void* buffer, size_t buffer_length, ssize_t* strlen_buffer) = 0; + /// \returns SQLRETURN for SQLGetData. + virtual SQLRETURN GetData(int column, int16_t target_type, int precision, int scale, + void* buffer, size_t buffer_length, + ssize_t* strlen_buffer) = 0; }; } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set_metadata.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set_metadata.h index 38f81fc9c3e..a33784cc79b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set_metadata.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set_metadata.h @@ -17,9 +17,8 @@ #pragma once -#include "arrow/flight/sql/odbc/odbc_impl/types.h" - #include +#include "arrow/flight/sql/odbc/odbc_impl/types.h" namespace arrow::flight::sql::odbc { @@ -143,8 +142,9 @@ class ResultSetMetadata { /// \brief It returns the data type as a string. /// \param column_position [in] the position of the column, starting from 1. + /// \param data_type [in] the data type of the column. /// \return the data type string. - virtual std::string GetTypeName(int column_position) = 0; + virtual std::string GetTypeName(int column_position, int data_type) = 0; /// \brief It returns a numeric values indicate the updatability of the /// column. diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h index 970e447dfdc..60c68c1138e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h @@ -17,14 +17,14 @@ #pragma once -#include -#include #include +#include +#include #include namespace arrow::flight::sql::odbc { -using boost::optional; +using std::optional; class ResultSet; @@ -39,7 +39,7 @@ class Statement { virtual ~Statement() = default; /// \brief Statement attributes that can be called at anytime. - ////TODO: Document attributes + /// GH-47850 TODO: Document attributes enum StatementAttributeId { MAX_LENGTH, // size_t - The maximum length when retrieving variable length data. 0 // means no limit. @@ -51,7 +51,7 @@ class Statement { // have no timeout. }; - typedef boost::variant Attribute; + typedef std::variant Attribute; /// \brief Set a statement attribute (may be called at any time) /// @@ -74,9 +74,9 @@ class Statement { /// \brief Prepares the statement. /// Returns ResultSetMetadata if query returns a result set, - /// otherwise it returns `boost::none`. + /// otherwise it returns `std::nullopt`. /// \param query The SQL query to prepare. - virtual boost::optional> Prepare( + virtual std::optional> Prepare( const std::string& query) = 0; /// \brief Execute the prepared statement. diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/types.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/types.h index 7a91221cd44..e9817fa2387 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/types.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/types.h @@ -17,10 +17,10 @@ #pragma once +#include +#include #include "arrow/flight/sql/odbc/odbc_impl/platform.h" -#include - namespace arrow::flight::sql::odbc { /// \brief Supported ODBC versions. @@ -172,7 +172,7 @@ enum RowStatus : uint16_t { }; struct MetadataSettings { - boost::optional string_column_length{boost::none}; + std::optional string_column_length{std::nullopt}; size_t chunk_buffer_capacity; bool use_wide_char; }; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/custom_window.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/custom_window.cc index 179303b68e3..0d24f0cca82 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/custom_window.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/custom_window.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.cc index 0432836a16f..252475c0e44 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.cc @@ -29,6 +29,7 @@ #include #include #include +#include "arrow/flight/sql/odbc/odbc_impl/util.h" #define COMMON_TAB 0 #define ADVANCED_TAB 1 @@ -44,9 +45,9 @@ std::string TestConnection(const config::Configuration& config) { // This should have been checked before enabling the Test button. assert(missing_properties.empty()); std::string server_name = - boost::get(flight_sql_conn->GetInfo(SQL_SERVER_NAME)); + std::get(flight_sql_conn->GetInfo(SQL_SERVER_NAME)); std::string server_version = - boost::get(flight_sql_conn->GetInfo(SQL_DBMS_VER)); + std::get(flight_sql_conn->GetInfo(SQL_DBMS_VER)); return "Server Name: " + server_name + "\n" + "Server Version: " + server_version; } } // namespace @@ -565,7 +566,7 @@ bool DsnConfigurationWindow::OnMessage(UINT msg, WPARAM wparam, LPARAM lparam) { open_file_name.lpstrFile = file_name; open_file_name.lpstrFile[0] = '\0'; open_file_name.nMaxFile = FILENAME_MAX; - // TODO: What type should this be? + // GH-47851 TODO: Update `lpstrFilter` to correct value open_file_name.lpstrFilter = L"All\0*.*"; open_file_name.nFilterIndex = 1; open_file_name.lpstrFileTitle = NULL; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/window.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/window.cc index f21329977ba..ce10ddd3bf9 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/window.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/window.cc @@ -36,7 +36,6 @@ HINSTANCE GetHInstance() { TCHAR sz_file_name[MAX_PATH]; GetModuleFileName(NULL, sz_file_name, MAX_PATH); - // TODO: This needs to be the module name. HINSTANCE h_instance = GetModuleHandle(sz_file_name); if (h_instance == NULL) { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc index 59ee7dda565..38081d19354 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc @@ -56,6 +56,9 @@ SqlDataType GetDefaultSqlCharType(bool use_wide_char) { SqlDataType GetDefaultSqlVarcharType(bool use_wide_char) { return use_wide_char ? SqlDataType_WVARCHAR : SqlDataType_VARCHAR; } +SqlDataType GetDefaultSqlLongVarcharType(bool use_wide_char) { + return use_wide_char ? SqlDataType_WLONGVARCHAR : SqlDataType_LONGVARCHAR; +} CDataType GetDefaultCCharType(bool use_wide_char) { return use_wide_char ? CDataType_WCHAR : CDataType_CHAR; } @@ -113,12 +116,13 @@ SqlDataType GetDataTypeFromArrowFieldV3(const std::shared_ptr& field, case Type::TIME64: return SqlDataType_TYPE_TIME; case Type::INTERVAL_MONTHS: - return SqlDataType_INTERVAL_MONTH; // TODO: maybe - // SqlDataType_INTERVAL_YEAR_TO_MONTH + return SqlDataType_INTERVAL_MONTH; // GH-47873 TODO: check and update to + // SqlDataType_INTERVAL_YEAR_TO_MONTH if it is + // more appropriate case Type::INTERVAL_DAY_TIME: return SqlDataType_INTERVAL_DAY; - // TODO: Handle remaining types. + // GH-47873 TODO: Handle remaining types. case Type::INTERVAL_MONTH_DAY_NANO: case Type::LIST: case Type::STRUCT: @@ -146,6 +150,9 @@ SqlDataType EnsureRightSqlCharType(SqlDataType data_type, bool use_wide_char) { case SqlDataType_VARCHAR: case SqlDataType_WVARCHAR: return GetDefaultSqlVarcharType(use_wide_char); + case SqlDataType_LONGVARCHAR: + case SqlDataType_WLONGVARCHAR: + return GetDefaultSqlLongVarcharType(use_wide_char); default: return data_type; } @@ -663,7 +670,7 @@ optional GetDisplaySize(SqlDataType data_type, case SqlDataType_INTERVAL_HOUR_TO_MINUTE: case SqlDataType_INTERVAL_HOUR_TO_SECOND: case SqlDataType_INTERVAL_MINUTE_TO_SECOND: - return nullopt; // TODO: Implement for INTERVAL types + return nullopt; // GH-47874 TODO: Implement for INTERVAL types case SqlDataType_GUID: return 36; default: @@ -747,10 +754,12 @@ bool NeedArrayConversion(Type::type original_type_id, CDataType data_type) { return data_type != CDataType_BINARY; case Type::DECIMAL128: return data_type != CDataType_NUMERIC; + case Type::DURATION: case Type::LIST: case Type::LARGE_LIST: case Type::FIXED_SIZE_LIST: case Type::MAP: + case Type::STRING_VIEW: case Type::STRUCT: return data_type == CDataType_CHAR || data_type == CDataType_WCHAR; default: @@ -926,9 +935,9 @@ ArrayConvertTask GetConverter(Type::type original_type_id, CDataType target_type auto seconds_from_epoch = GetTodayTimeFromEpoch(); - auto third_converted_array = CheckConversion( - arrow::compute::Add(second_converted_array, - std::make_shared(seconds_from_epoch * 1000))); + auto third_converted_array = CheckConversion(arrow::compute::Add( + second_converted_array, + std::make_shared(seconds_from_epoch * 1000))); arrow::compute::CastOptions cast_options_2; cast_options_2.to_type = arrow::timestamp(TimeUnit::MILLI); @@ -947,7 +956,7 @@ ArrayConvertTask GetConverter(Type::type original_type_id, CDataType target_type auto second_converted_array = CheckConversion(arrow::compute::Add( first_converted_array, - std::make_shared(seconds_from_epoch * 1000000000))); + std::make_shared(seconds_from_epoch * 1000000000))); arrow::compute::CastOptions cast_options_2; cast_options_2.to_type = arrow::timestamp(TimeUnit::NANO); @@ -971,7 +980,7 @@ ArrayConvertTask GetConverter(Type::type original_type_id, CDataType target_type } else if (original_type_id == Type::DECIMAL128 && (target_type == CDataType_CHAR || target_type == CDataType_WCHAR)) { return [=](const std::shared_ptr& original_array) { - StringBuilder builder; + arrow::StringBuilder builder; int64_t length = original_array->length(); ThrowIfNotOK(builder.ReserveData(length)); @@ -1097,30 +1106,30 @@ int32_t GetDecimalTypePrecision(const std::shared_ptr& decimal_type) { return decimal128_type->precision(); } -boost::optional AsBool(const std::string& value) { +std::optional AsBool(const std::string& value) { if (boost::iequals(value, "true") || boost::iequals(value, "1")) { return true; } else if (boost::iequals(value, "false") || boost::iequals(value, "0")) { return false; } else { - return boost::none; + return std::nullopt; } } -boost::optional AsBool(const Connection::ConnPropertyMap& conn_property_map, - std::string_view property_name) { +std::optional AsBool(const Connection::ConnPropertyMap& conn_property_map, + std::string_view property_name) { auto extracted_property = conn_property_map.find(property_name); if (extracted_property != conn_property_map.end()) { return AsBool(extracted_property->second); } - return boost::none; + return std::nullopt; } -boost::optional AsInt32(int32_t min_value, - const Connection::ConnPropertyMap& conn_property_map, - std::string_view property_name) { +std::optional AsInt32(int32_t min_value, + const Connection::ConnPropertyMap& conn_property_map, + std::string_view property_name) { auto extracted_property = conn_property_map.find(property_name); if (extracted_property != conn_property_map.end()) { @@ -1130,7 +1139,7 @@ boost::optional AsInt32(int32_t min_value, return string_column_length; } } - return boost::none; + return std::nullopt; } } // namespace util diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h index c17e77e7de8..1bca0164166 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h @@ -17,17 +17,31 @@ #pragma once -#include "arrow/util/utf8.h" +#include +#include +#include +#include +#include #include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" #include "arrow/flight/sql/odbc/odbc_impl/spi/connection.h" #include "arrow/flight/sql/odbc/odbc_impl/types.h" #include "arrow/flight/types.h" +#include "arrow/util/utf8.h" -#include -#include -#include -#include +#define CONVERT_WIDE_STR(wstring_var, utf8_target) \ + wstring_var = [&] { \ + arrow::Result res = arrow::util::UTF8ToWideString(utf8_target); \ + arrow::flight::sql::odbc::util::ThrowIfNotOK(res.status()); \ + return res.ValueOrDie(); \ + }() + +#define CONVERT_UTF8_STR(string_var, wide_str_target) \ + string_var = [&] { \ + arrow::Result res = arrow::util::WideStringToUTF8(wide_str_target); \ + arrow::flight::sql::odbc::util::ThrowIfNotOK(res.status()); \ + return res.ValueOrDie(); \ + }() #define CONVERT_WIDE_STR(wstring_var, utf8_target) \ wstring_var = [&] { \ @@ -59,7 +73,7 @@ inline void ThrowIfNotOK(const Status& status) { template inline bool CheckIfSetToOnlyValidValue(const AttributeTypeT& value, T allowed_value) { - return boost::get(value) == allowed_value; + return std::get(value) == allowed_value; } template @@ -136,15 +150,15 @@ int32_t GetDecimalTypePrecision(const std::shared_ptr& decimal_type); /// Parse a string value to a boolean. /// \param value the value to be parsed. /// \return the parsed valued. -boost::optional AsBool(const std::string& value); +std::optional AsBool(const std::string& value); /// Looks up for a value inside the ConnPropertyMap and then try to parse it. /// In case it does not find or it cannot parse, the default value will be returned. /// \param conn_property_map the map with the connection properties. /// \param property_name the name of the property that will be looked up. /// \return the parsed valued. -boost::optional AsBool(const Connection::ConnPropertyMap& conn_property_map, - std::string_view property_name); +std::optional AsBool(const Connection::ConnPropertyMap& conn_property_map, + std::string_view property_name); /// Looks up for a value inside the ConnPropertyMap and then try to parse it. /// In case it does not find or it cannot parse, the default value will be returned. @@ -154,9 +168,9 @@ boost::optional AsBool(const Connection::ConnPropertyMap& conn_property_ma /// looked up. \return the parsed valued. \exception /// std::invalid_argument exception from std::stoi \exception /// std::out_of_range exception from std::stoi -boost::optional AsInt32(int32_t min_value, - const Connection::ConnPropertyMap& conn_property_map, - std::string_view property_name); +std::optional AsInt32(int32_t min_value, + const Connection::ConnPropertyMap& conn_property_map, + std::string_view property_name); } // namespace util } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util_test.cc index bfcec15b4da..4946355ff20 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util_test.cc @@ -23,13 +23,21 @@ #include "arrow/testing/builder.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" -#include "gtest/gtest.h" + +#include namespace arrow::flight::sql::odbc { using util::ConvertSqlPatternToRegexString; using util::ConvertToDBMSVer; +class UtilTestsWithCompute : public ::testing::Test { + public: + // This must be done before using the compute kernels in order to + // register them to the FunctionRegistry. + void SetUp() override { ASSERT_OK(arrow::compute::Initialize()); } +}; + // A global test "environment", to ensure Arrow compute kernel functions are registered class ComputeKernelEnvironment : public ::testing::Environment { @@ -48,7 +56,7 @@ void AssertConvertedArray(const std::shared_ptr& expected_array, ASSERT_EQ(expected_array->ToString(), converted_array->ToString()); } -std::shared_ptr convertArray(const std::shared_ptr& original_array, +std::shared_ptr ConvertArray(const std::shared_ptr& original_array, CDataType c_type) { auto converter = util::GetConverter(original_array->type_id(), c_type); return converter(original_array); @@ -60,7 +68,7 @@ void TestArrayConversion(const std::vector& input, std::shared_ptr original_array; ArrayFromVector(input, &original_array); - auto converted_array = convertArray(original_array, c_type); + auto converted_array = ConvertArray(original_array, c_type); AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); } @@ -71,7 +79,7 @@ void TestTime32ArrayConversion(const std::vector& input, std::shared_ptr original_array; ArrayFromVector(time32(TimeUnit::MILLI), input, &original_array); - auto converted_array = convertArray(original_array, c_type); + auto converted_array = ConvertArray(original_array, c_type); AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); } @@ -82,12 +90,12 @@ void TestTime64ArrayConversion(const std::vector& input, std::shared_ptr original_array; ArrayFromVector(time64(TimeUnit::NANO), input, &original_array); - auto converted_array = convertArray(original_array, c_type); + auto converted_array = ConvertArray(original_array, c_type); AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); } -TEST(Utils, Time32ToTimeStampArray) { +TEST_F(UtilTestsWithCompute, Time32ToTimeStampArray) { std::vector input_data = {14896, 17820}; const auto seconds_from_epoch = GetTodayTimeFromEpoch(); @@ -106,7 +114,7 @@ TEST(Utils, Time32ToTimeStampArray) { TestTime32ArrayConversion(input_data, expected, CDataType_TIMESTAMP, Type::TIMESTAMP); } -TEST(Utils, Time64ToTimeStampArray) { +TEST_F(UtilTestsWithCompute, Time64ToTimeStampArray) { std::vector input_data = {1579489200000, 1646881200000}; const auto seconds_from_epoch = GetTodayTimeFromEpoch(); @@ -125,7 +133,7 @@ TEST(Utils, Time64ToTimeStampArray) { TestTime64ArrayConversion(input_data, expected, CDataType_TIMESTAMP, Type::TIMESTAMP); } -TEST(Utils, StringToDateArray) { +TEST_F(UtilTestsWithCompute, StringToDateArray) { std::shared_ptr expected; ArrayFromVector({1579489200000, 1646881200000}, &expected); @@ -133,7 +141,7 @@ TEST(Utils, StringToDateArray) { Type::DATE64); } -TEST(Utils, StringToTimeArray) { +TEST_F(UtilTestsWithCompute, StringToTimeArray) { std::shared_ptr expected; ArrayFromVector(time64(TimeUnit::MICRO), {36000000000, 43200000000}, &expected); diff --git a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt index 4bc240637e7..6a40c220eb8 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt @@ -17,7 +17,6 @@ add_custom_target(tests) -find_package(ODBC REQUIRED) include_directories(${ODBC_INCLUDE_DIRS}) find_package(SQLite3Alt REQUIRED) @@ -32,9 +31,19 @@ set(ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS add_arrow_test(flight_sql_odbc_test SOURCES + columns_test.cc + connection_attr_test.cc + connection_info_test.cc + errors_test.cc + get_functions_test.cc + statement_attr_test.cc + statement_test.cc + tables_test.cc + type_info_test.cc + # Move connection_test to last to prevent segfault errors + connection_test.cc odbc_test_suite.cc odbc_test_suite.h - connection_test.cc # Enable Protobuf cleanup after test execution # GH-46889: move protobuf_test_util to a more common location ../../../../engine/substrait/protobuf_test_util.cc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/README b/cpp/src/arrow/flight/sql/odbc/tests/README new file mode 100644 index 00000000000..fe74c98b72f --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/README @@ -0,0 +1,23 @@ + + +Prior to running the tests, set environment variable `ARROW_FLIGHT_SQL_ODBC_CONN` +to a valid connection string. +A valid connection string looks like: +driver={Apache Arrow Flight SQL ODBC Driver};HOST=localhost;port=32010;pwd=myPassword;uid=myName;useEncryption=false; diff --git a/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc new file mode 100644 index 00000000000..77534d16ac2 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc @@ -0,0 +1,2716 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +namespace arrow::flight::sql::odbc { + +template +class ColumnsTest : public T {}; + +class ColumnsMockTest : public FlightSQLODBCMockTestBase {}; +class ColumnsRemoteTest : public FlightSQLODBCRemoteTestBase {}; +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(ColumnsTest, TestTypes); + +template +class ColumnsOdbcV2Test : public T {}; + +class ColumnsOdbcV2MockTest : public FlightSQLOdbcV2MockTestBase {}; +class ColumnsOdbcV2RemoteTest : public FlightSQLOdbcV2RemoteTestBase {}; +using TestTypesOdbcV2 = ::testing::Types; +TYPED_TEST_SUITE(ColumnsOdbcV2Test, TestTypesOdbcV2); + +namespace { +// Helper functions +void CheckSQLColumns( + SQLHSTMT stmt, const std::wstring& expected_table, + const std::wstring& expected_column, const SQLINTEGER& expected_data_type, + const std::wstring& expected_type_name, const SQLINTEGER& expected_column_size, + const SQLINTEGER& expected_buffer_length, const SQLSMALLINT& expected_decimal_digits, + const SQLSMALLINT& expected_num_prec_radix, const SQLSMALLINT& expected_nullable, + const SQLSMALLINT& expected_sql_data_type, const SQLSMALLINT& expected_date_time_sub, + const SQLINTEGER& expected_octet_char_length, + const SQLINTEGER& expected_ordinal_position, + const std::wstring& expected_is_nullable) { + CheckStringColumnW(stmt, 3, expected_table); // table name + CheckStringColumnW(stmt, 4, expected_column); // column name + + CheckIntColumn(stmt, 5, expected_data_type); // data type + + CheckStringColumnW(stmt, 6, expected_type_name); // type name + + CheckIntColumn(stmt, 7, expected_column_size); // column size + CheckIntColumn(stmt, 8, expected_buffer_length); // buffer length + + CheckSmallIntColumn(stmt, 9, expected_decimal_digits); // decimal digits + CheckSmallIntColumn(stmt, 10, expected_num_prec_radix); // num prec radix + CheckSmallIntColumn(stmt, 11, + expected_nullable); // nullable + + CheckNullColumnW(stmt, 12); // remarks + CheckNullColumnW(stmt, 13); // column def + + CheckSmallIntColumn(stmt, 14, expected_sql_data_type); // sql data type + CheckSmallIntColumn(stmt, 15, expected_date_time_sub); // sql date type sub + CheckIntColumn(stmt, 16, expected_octet_char_length); // char octet length + CheckIntColumn(stmt, 17, + expected_ordinal_position); // oridinal position + + CheckStringColumnW(stmt, 18, expected_is_nullable); // is nullable +} + +void CheckMockSQLColumns( + SQLHSTMT stmt, const std::wstring& expected_catalog, + const std::wstring& expected_table, const std::wstring& expected_column, + const SQLINTEGER& expected_data_type, const std::wstring& expected_type_name, + const SQLINTEGER& expected_column_size, const SQLINTEGER& expected_buffer_length, + const SQLSMALLINT& expected_decimal_digits, + const SQLSMALLINT& expected_num_prec_radix, const SQLSMALLINT& expected_nullable, + const SQLSMALLINT& expected_sql_data_type, const SQLSMALLINT& expected_date_time_sub, + const SQLINTEGER& expected_octet_char_length, + const SQLINTEGER& expected_ordinal_position, + const std::wstring& expected_is_nullable) { + CheckStringColumnW(stmt, 1, expected_catalog); // catalog + CheckNullColumnW(stmt, 2); // schema + + CheckSQLColumns(stmt, expected_table, expected_column, expected_data_type, + expected_type_name, expected_column_size, expected_buffer_length, + expected_decimal_digits, expected_num_prec_radix, expected_nullable, + expected_sql_data_type, expected_date_time_sub, + expected_octet_char_length, expected_ordinal_position, + expected_is_nullable); +} + +void CheckRemoteSQLColumns( + SQLHSTMT stmt, const std::wstring& expected_schema, + const std::wstring& expected_table, const std::wstring& expected_column, + const SQLINTEGER& expected_data_type, const std::wstring& expected_type_name, + const SQLINTEGER& expected_column_size, const SQLINTEGER& expected_buffer_length, + const SQLSMALLINT& expected_decimal_digits, + const SQLSMALLINT& expected_num_prec_radix, const SQLSMALLINT& expected_nullable, + const SQLSMALLINT& expected_sql_data_type, const SQLSMALLINT& expected_date_time_sub, + const SQLINTEGER& expected_octet_char_length, + const SQLINTEGER& expected_ordinal_position, + const std::wstring& expected_is_nullable) { + CheckNullColumnW(stmt, 1); // catalog + CheckStringColumnW(stmt, 2, expected_schema); // schema + CheckSQLColumns(stmt, expected_table, expected_column, expected_data_type, + expected_type_name, expected_column_size, expected_buffer_length, + expected_decimal_digits, expected_num_prec_radix, expected_nullable, + expected_sql_data_type, expected_date_time_sub, + expected_octet_char_length, expected_ordinal_position, + expected_is_nullable); +} + +void CheckSQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT idx, + const std::wstring& expected_column_name, + SQLLEN expected_data_type, SQLLEN expected_concise_type, + SQLLEN expected_display_size, SQLLEN expected_prec_scale, + SQLLEN expected_length, + const std::wstring& expected_literal_prefix, + const std::wstring& expected_literal_suffix, + SQLLEN expected_column_size, SQLLEN expected_column_scale, + SQLLEN expected_column_nullability, + SQLLEN expected_num_prec_radix, SQLLEN expected_octet_length, + SQLLEN expected_searchable, SQLLEN expected_unsigned_column) { + std::vector name(kOdbcBufferSize); + SQLSMALLINT name_len = 0; + std::vector base_column_name(kOdbcBufferSize); + SQLSMALLINT column_name_len = 0; + std::vector label(kOdbcBufferSize); + SQLSMALLINT label_len = 0; + std::vector prefix(kOdbcBufferSize); + SQLSMALLINT prefix_len = 0; + std::vector suffix(kOdbcBufferSize); + SQLSMALLINT suffix_len = 0; + SQLLEN data_type = 0; + SQLLEN concise_type = 0; + SQLLEN display_size = 0; + SQLLEN prec_scale = 0; + SQLLEN length = 0; + SQLLEN size = 0; + SQLLEN scale = 0; + SQLLEN nullability = 0; + SQLLEN num_prec_radix = 0; + SQLLEN octet_length = 0; + SQLLEN searchable = 0; + SQLLEN unsigned_col = 0; + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_NAME, &name[0], + (SQLSMALLINT)name.size(), &name_len, 0)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_DESC_BASE_COLUMN_NAME, &base_column_name[0], + (SQLSMALLINT)base_column_name.size(), &column_name_len, 0)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_LABEL, &label[0], + (SQLSMALLINT)label.size(), &label_len, 0)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_TYPE, 0, 0, 0, &data_type)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_DESC_CONCISE_TYPE, 0, 0, 0, &concise_type)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_DESC_DISPLAY_SIZE, 0, 0, 0, &display_size)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_DESC_FIXED_PREC_SCALE, 0, 0, 0, &prec_scale)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_LENGTH, 0, 0, 0, &length)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_LITERAL_PREFIX, &prefix[0], + (SQLSMALLINT)prefix.size(), &prefix_len, 0)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_LITERAL_SUFFIX, &suffix[0], + (SQLSMALLINT)suffix.size(), &suffix_len, 0)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_PRECISION, 0, 0, 0, &size)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_SCALE, 0, 0, 0, &scale)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_DESC_NULLABLE, 0, 0, 0, &nullability)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, SQL_DESC_NUM_PREC_RADIX, 0, 0, 0, + &num_prec_radix)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_DESC_OCTET_LENGTH, 0, 0, 0, &octet_length)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_DESC_SEARCHABLE, 0, 0, 0, &searchable)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_DESC_UNSIGNED, 0, 0, 0, &unsigned_col)); + + std::wstring name_str = ConvertToWString(name, name_len); + std::wstring base_column_name_str = ConvertToWString(base_column_name, column_name_len); + std::wstring label_str = ConvertToWString(label, label_len); + std::wstring prefixStr = ConvertToWString(prefix, prefix_len); + + // Assume column name, base column name, and label are equivalent in the result set + EXPECT_EQ(expected_column_name, name_str); + EXPECT_EQ(expected_column_name, base_column_name_str); + EXPECT_EQ(expected_column_name, label_str); + EXPECT_EQ(expected_data_type, data_type); + EXPECT_EQ(expected_concise_type, concise_type); + EXPECT_EQ(expected_display_size, display_size); + EXPECT_EQ(expected_prec_scale, prec_scale); + EXPECT_EQ(expected_length, length); + EXPECT_EQ(expected_literal_prefix, prefixStr); + EXPECT_EQ(expected_column_size, size); + EXPECT_EQ(expected_column_scale, scale); + EXPECT_EQ(expected_column_nullability, nullability); + EXPECT_EQ(expected_num_prec_radix, num_prec_radix); + EXPECT_EQ(expected_octet_length, octet_length); + EXPECT_EQ(expected_searchable, searchable); + EXPECT_EQ(expected_unsigned_column, unsigned_col); +} + +void CheckSQLColAttributes(SQLHSTMT stmt, SQLUSMALLINT idx, + const std::wstring& expected_column_name, + SQLLEN expected_data_type, SQLLEN expected_display_size, + SQLLEN expected_prec_scale, SQLLEN expected_length, + SQLLEN expected_column_size, SQLLEN expected_column_scale, + SQLLEN expected_column_nullability, SQLLEN expected_searchable, + SQLLEN expected_unsigned_column) { + std::vector name(kOdbcBufferSize); + SQLSMALLINT name_len = 0; + std::vector label(kOdbcBufferSize); + SQLSMALLINT label_len = 0; + SQLLEN data_type = 0; + SQLLEN display_size = 0; + SQLLEN prec_scale = 0; + SQLLEN length = 0; + SQLLEN size = 0; + SQLLEN scale = 0; + SQLLEN nullability = 0; + SQLLEN searchable = 0; + SQLLEN unsigned_col = 0; + + EXPECT_EQ(SQL_SUCCESS, SQLColAttributes(stmt, idx, SQL_COLUMN_NAME, &name[0], + (SQLSMALLINT)name.size(), &name_len, 0)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttributes(stmt, idx, SQL_COLUMN_LABEL, &label[0], + (SQLSMALLINT)label.size(), &label_len, 0)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttributes(stmt, idx, SQL_COLUMN_TYPE, 0, 0, 0, &data_type)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttributes(stmt, idx, SQL_COLUMN_DISPLAY_SIZE, 0, 0, 0, &display_size)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(stmt, idx, SQL_COLUMN_MONEY, 0, 0, 0, &prec_scale)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttributes(stmt, idx, SQL_COLUMN_LENGTH, 0, 0, 0, &length)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttributes(stmt, idx, SQL_COLUMN_PRECISION, 0, 0, 0, &size)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttributes(stmt, idx, SQL_COLUMN_SCALE, 0, 0, 0, &scale)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttributes(stmt, idx, SQL_COLUMN_NULLABLE, 0, 0, 0, &nullability)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttributes(stmt, idx, SQL_COLUMN_SEARCHABLE, 0, 0, 0, &searchable)); + + EXPECT_EQ(SQL_SUCCESS, + SQLColAttributes(stmt, idx, SQL_COLUMN_UNSIGNED, 0, 0, 0, &unsigned_col)); + + std::wstring name_str = ConvertToWString(name, name_len); + std::wstring label_str = ConvertToWString(label, label_len); + + EXPECT_EQ(expected_column_name, name_str); + EXPECT_EQ(expected_column_name, label_str); + EXPECT_EQ(expected_data_type, data_type); + EXPECT_EQ(expected_display_size, display_size); + EXPECT_EQ(expected_length, length); + EXPECT_EQ(expected_column_size, size); + EXPECT_EQ(expected_column_scale, scale); + EXPECT_EQ(expected_column_nullability, nullability); + EXPECT_EQ(expected_searchable, searchable); + EXPECT_EQ(expected_unsigned_column, unsigned_col); +} + +void GetSQLColAttributeString(SQLHSTMT stmt, const std::wstring& wsql, SQLUSMALLINT idx, + SQLUSMALLINT field_identifier, std::wstring& value) { + if (!wsql.empty()) { + // Execute query + std::vector sql0(wsql.begin(), wsql.end()); + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(stmt)); + } + + // check SQLColAttribute string attribute + std::vector str_val(kOdbcBufferSize); + SQLSMALLINT str_len = 0; + + ASSERT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, field_identifier, &str_val[0], + (SQLSMALLINT)str_val.size(), &str_len, 0)); + + value = ConvertToWString(str_val, str_len); +} + +void GetSQLColAttributesString(SQLHSTMT stmt, const std::wstring& wsql, SQLUSMALLINT idx, + SQLUSMALLINT field_identifier, std::wstring& value) { + if (!wsql.empty()) { + // Execute query + std::vector sql0(wsql.begin(), wsql.end()); + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(stmt)); + } + + // check SQLColAttribute string attribute + std::vector str_val(kOdbcBufferSize); + SQLSMALLINT str_len = 0; + + ASSERT_EQ(SQL_SUCCESS, SQLColAttributes(stmt, idx, field_identifier, &str_val[0], + (SQLSMALLINT)str_val.size(), &str_len, 0)); + + value = ConvertToWString(str_val, str_len); +} + +void GetSQLColAttributeNumeric(SQLHSTMT stmt, const std::wstring& wsql, SQLUSMALLINT idx, + SQLUSMALLINT field_identifier, SQLLEN* value) { + // Execute query and check SQLColAttribute numeric attribute + std::vector sql0(wsql.begin(), wsql.end()); + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(stmt)); + + SQLLEN num_val = 0; + ASSERT_EQ(SQL_SUCCESS, SQLColAttribute(stmt, idx, field_identifier, 0, 0, 0, value)); +} + +void GetSQLColAttributesNumeric(SQLHSTMT stmt, const std::wstring& wsql, SQLUSMALLINT idx, + SQLUSMALLINT field_identifier, SQLLEN* value) { + // Execute query and check SQLColAttribute numeric attribute + std::vector sql0(wsql.begin(), wsql.end()); + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(stmt)); + + SQLLEN num_val = 0; + ASSERT_EQ(SQL_SUCCESS, SQLColAttributes(stmt, idx, field_identifier, 0, 0, 0, value)); +} +} // namespace + +TYPED_TEST(ColumnsTest, SQLColumnsTestInputData) { + SQLWCHAR catalog_name[] = L""; + SQLWCHAR schema_name[] = L""; + SQLWCHAR table_name[] = L""; + SQLWCHAR column_name[] = L""; + + // All values populated + EXPECT_EQ(SQL_SUCCESS, + SQLColumns(this->stmt, catalog_name, sizeof(catalog_name), schema_name, + sizeof(schema_name), table_name, sizeof(table_name), column_name, + sizeof(column_name))); + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Sizes are nulls + EXPECT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, catalog_name, 0, schema_name, 0, + table_name, 0, column_name, 0)); + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Values are nulls + EXPECT_EQ(SQL_SUCCESS, + SQLColumns(this->stmt, 0, sizeof(catalog_name), 0, sizeof(schema_name), 0, + sizeof(table_name), 0, sizeof(column_name))); + ValidateFetch(this->stmt, SQL_SUCCESS); + // Close statement cursor to avoid leaving in an invalid state + SQLFreeStmt(this->stmt, SQL_CLOSE); + + // All values and sizes are nulls + EXPECT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, 0, 0, 0, 0, 0, 0, 0, 0)); + ValidateFetch(this->stmt, SQL_SUCCESS); +} + +TEST_F(ColumnsMockTest, TestSQLColumnsAllColumns) { + // Check table pattern and column pattern returns all columns + + // Attempt to get all columns + SQLWCHAR table_pattern[] = L"%"; + SQLWCHAR column_pattern[] = L"%"; + + ASSERT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_pattern, SQL_NTS, column_pattern, SQL_NTS)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // mock limitation: SQLite mock server returns 10 for bigint size when spec indicates + // should be 19 + // DECIMAL_DIGITS should be 0 for bigint type since it is exact + // mock limitation: SQLite mock server returns 10 for bigint decimal digits when spec + // indicates should be 0 + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"foreignTable"), // expected_table + std::wstring(L"id"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock returns 10 instead of 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock returns 15 instead of 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 2nd Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"foreignTable"), // expected_table + std::wstring(L"foreignName"), // expected_column + SQL_WVARCHAR, // expected_data_type + std::wstring(L"WVARCHAR"), // expected_type_name + 0, // expected_column_size (mock server limitation: returns 0 for + // varchar(100), the ODBC spec expects 100) + 0, // expected_buffer_length + 15, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_WVARCHAR, // expected_sql_data_type + NULL, // expected_date_time_sub + 0, // expected_octet_char_length + 2, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 3rd Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"foreignTable"), // expected_table + std::wstring(L"value"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock returns 10 instead of 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock returns 15 instead of 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 3, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 4th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"intTable"), // expected_table + std::wstring(L"id"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock returns 10 instead of 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock returns 15 instead of 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 5th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"intTable"), // expected_table + std::wstring(L"keyName"), // expected_column + SQL_WVARCHAR, // expected_data_type + std::wstring(L"WVARCHAR"), // expected_type_name + 0, // expected_column_size (mock server limitation: returns 0 for + // varchar(100), the ODBC spec expects 100) + 0, // expected_buffer_length + 15, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_WVARCHAR, // expected_sql_data_type + NULL, // expected_date_time_sub + 0, // expected_octet_char_length + 2, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 6th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"intTable"), // expected_table + std::wstring(L"value"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock returns 10 instead of 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock returns 15 instead of 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 3, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 7th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"intTable"), // expected_table + std::wstring(L"foreignId"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock returns 10 instead of 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock returns 15 instead of 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 4, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable +} + +TEST_F(ColumnsMockTest, TestSQLColumnsAllTypes) { + // Limitation: Mock server returns incorrect values for column size for some columns. + // For character and binary type columns, the driver calculates buffer length and char + // octet length from column size. + + // Checks filtering table with table name pattern + this->CreateTableAllDataType(); + + // Attempt to get all columns from AllTypesTable + SQLWCHAR table_pattern[] = L"AllTypesTable"; + SQLWCHAR column_pattern[] = L"%"; + + ASSERT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_pattern, SQL_NTS, column_pattern, SQL_NTS)); + + // Fetch SQLColumn data for 1st column in AllTypesTable + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"AllTypesTable"), // expected_table + std::wstring(L"bigint_col"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock server limitation: returns 10, + // the ODBC spec expects 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock server limitation: returns 15, + // the ODBC spec expects 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check SQLColumn data for 2nd column in AllTypesTable + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"AllTypesTable"), // expected_table + std::wstring(L"char_col"), // expected_column + SQL_WVARCHAR, // expected_data_type + std::wstring(L"WVARCHAR"), // expected_type_name + 0, // expected_column_size (mock server limitation: returns 0 for + // varchar(100), the ODBC spec expects 100) + 0, // expected_buffer_length + 15, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_WVARCHAR, // expected_sql_data_type + NULL, // expected_date_time_sub + 0, // expected_octet_char_length + 2, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check SQLColumn data for 3rd column in AllTypesTable + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"AllTypesTable"), // expected_table + std::wstring(L"varbinary_col"), // expected_column + SQL_BINARY, // expected_data_type + std::wstring(L"BINARY"), // expected_type_name + 0, // expected_column_size (mock server limitation: returns 0 for + // BLOB column, spec expects binary data limit) + 0, // expected_buffer_length + 15, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BINARY, // expected_sql_data_type + NULL, // expected_date_time_sub + 0, // expected_octet_char_length + 3, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check SQLColumn data for 4th column in AllTypesTable + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"AllTypesTable"), // expected_table + std::wstring(L"double_col"), // expected_column + SQL_DOUBLE, // expected_data_type + std::wstring(L"DOUBLE"), // expected_type_name + 15, // expected_column_size + 8, // expected_buffer_length + 15, // expected_decimal_digits + 2, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 4, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // There should be no more column data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(ColumnsMockTest, TestSQLColumnsUnicode) { + // Limitation: Mock server returns incorrect values for column size for some columns. + // For character and binary type columns, the driver calculates buffer length and char + // octet length from column size. + this->CreateUnicodeTable(); + + // Attempt to get all columns + SQLWCHAR table_pattern[] = L"数据"; + SQLWCHAR column_pattern[] = L"%"; + + ASSERT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_pattern, SQL_NTS, column_pattern, SQL_NTS)); + + // Check SQLColumn data for 1st column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"数据"), // expected_table + std::wstring(L"资料"), // expected_column + SQL_WVARCHAR, // expected_data_type + std::wstring(L"WVARCHAR"), // expected_type_name + 0, // expected_column_size (mock server limitation: returns 0 for + // varchar(100), spec expects 100) + 0, // expected_buffer_length + 15, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_WVARCHAR, // expected_sql_data_type + NULL, // expected_date_time_sub + 0, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // There should be no more column data + EXPECT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(ColumnsRemoteTest, TestSQLColumnsAllTypes) { + // GH-47159: Return NUM_PREC_RADIX based on whether COLUMN_SIZE contains number of + // digits or bits + + SQLWCHAR table_pattern[] = L"ODBCTest"; + SQLWCHAR column_pattern[] = L"%"; + + ASSERT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_pattern, SQL_NTS, column_pattern, SQL_NTS)); + + // Check 1st Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"sinteger_max"), // expected_column + SQL_INTEGER, // expected_data_type + std::wstring(L"INTEGER"), // expected_type_name + 32, // expected_column_size (remote server returns number of bits) + 4, // expected_buffer_length + 0, // expected_decimal_digits + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_INTEGER, // expected_sql_data_type + NULL, // expected_date_time_sub + 4, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 2nd Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"sbigint_max"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 64, // expected_column_size (remote server returns number of bits) + 8, // expected_buffer_length + 0, // expected_decimal_digits + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 2, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 3rd Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"decimal_positive"), // expected_column + SQL_DECIMAL, // expected_data_type + std::wstring(L"DECIMAL"), // expected_type_name + 38, // expected_column_size + 19, // expected_buffer_length + 0, // expected_decimal_digits + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DECIMAL, // expected_sql_data_type + NULL, // expected_date_time_sub + 2, // expected_octet_char_length + 3, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 4th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"float_max"), // expected_column + SQL_FLOAT, // expected_data_type + std::wstring(L"FLOAT"), // expected_type_name + 24, // expected_column_size (precision bits from IEEE 754) + 8, // expected_buffer_length + 0, // expected_decimal_digits + 2, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_FLOAT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 4, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 5th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"double_max"), // expected_column + SQL_DOUBLE, // expected_data_type + std::wstring(L"DOUBLE"), // expected_type_name + 53, // expected_column_size (precision bits from IEEE 754) + 8, // expected_buffer_length + 0, // expected_decimal_digits + 2, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 5, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 6th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"bit_true"), // expected_column + SQL_BIT, // expected_data_type + std::wstring(L"BOOLEAN"), // expected_type_name + 0, // expected_column_size (limitation: remote server remote + // server returns 0, should be 1) + 1, // expected_buffer_length + 0, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIT, // expected_sql_data_type + NULL, // expected_date_time_sub + 1, // expected_octet_char_length + 6, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // ODBC ver 3 returns SQL_TYPE_DATE, SQL_TYPE_TIME, and SQL_TYPE_TIMESTAMP in the + // DATA_TYPE field + + // Check 7th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"date_max"), // expected_column + SQL_TYPE_DATE, // expected_data_type + std::wstring(L"DATE"), // expected_type_name + 0, // expected_column_size (limitation: remote server returns 0, should be 10) + 10, // expected_buffer_length + 0, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_DATE, // expected_date_time_sub + 6, // expected_octet_char_length + 7, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 8th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"time_max"), // expected_column + SQL_TYPE_TIME, // expected_data_type + std::wstring(L"TIME"), // expected_type_name + 3, // expected_column_size (limitation: should be 9+fractional digits) + 12, // expected_buffer_length + 0, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIME, // expected_date_time_sub + 6, // expected_octet_char_length + 8, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 9th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"timestamp_max"), // expected_column + SQL_TYPE_TIMESTAMP, // expected_data_type + std::wstring(L"TIMESTAMP"), // expected_type_name + 3, // expected_column_size (limitation: should be 20+fractional digits) + 23, // expected_buffer_length + 0, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIMESTAMP, // expected_date_time_sub + 16, // expected_octet_char_length + 9, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // There is no more column + EXPECT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(ColumnsOdbcV2RemoteTest, TestSQLColumnsAllTypes) { + // GH-47159: Return NUM_PREC_RADIX based on whether COLUMN_SIZE contains number of + // digits or bits + + SQLWCHAR table_pattern[] = L"ODBCTest"; + SQLWCHAR column_pattern[] = L"%"; + + ASSERT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_pattern, SQL_NTS, column_pattern, SQL_NTS)); + + // Check 1st Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"sinteger_max"), // expected_column + SQL_INTEGER, // expected_data_type + std::wstring(L"INTEGER"), // expected_type_name + 32, // expected_column_size (remote server returns number of bits) + 4, // expected_buffer_length + 0, // expected_decimal_digits + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_INTEGER, // expected_sql_data_type + NULL, // expected_date_time_sub + 4, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 2nd Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"sbigint_max"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 64, // expected_column_size (remote server returns number of bits) + 8, // expected_buffer_length + 0, // expected_decimal_digits + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 2, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 3rd Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"decimal_positive"), // expected_column + SQL_DECIMAL, // expected_data_type + std::wstring(L"DECIMAL"), // expected_type_name + 38, // expected_column_size + 19, // expected_buffer_length + 0, // expected_decimal_digits + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DECIMAL, // expected_sql_data_type + NULL, // expected_date_time_sub + 2, // expected_octet_char_length + 3, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 4th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"float_max"), // expected_column + SQL_FLOAT, // expected_data_type + std::wstring(L"FLOAT"), // expected_type_name + 24, // expected_column_size (precision bits from IEEE 754) + 8, // expected_buffer_length + 0, // expected_decimal_digits + 2, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_FLOAT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 4, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 5th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"double_max"), // expected_column + SQL_DOUBLE, // expected_data_type + std::wstring(L"DOUBLE"), // expected_type_name + 53, // expected_column_size (precision bits from IEEE 754) + 8, // expected_buffer_length + 0, // expected_decimal_digits + 2, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 5, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 6th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"bit_true"), // expected_column + SQL_BIT, // expected_data_type + std::wstring(L"BOOLEAN"), // expected_type_name + 0, // expected_column_size (limitation: remote server remote + // server returns 0, should be 1) + 1, // expected_buffer_length + 0, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIT, // expected_sql_data_type + NULL, // expected_date_time_sub + 1, // expected_octet_char_length + 6, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // ODBC ver 2 returns SQL_DATE, SQL_TIME, and SQL_TIMESTAMP in the DATA_TYPE field + + // Check 7th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"date_max"), // expected_column + SQL_DATE, // expected_data_type + std::wstring(L"DATE"), // expected_type_name + 0, // expected_column_size (limitation: remote server returns 0, should be 10) + 10, // expected_buffer_length + 0, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_DATE, // expected_date_time_sub + 6, // expected_octet_char_length + 7, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 8th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"time_max"), // expected_column + SQL_TIME, // expected_data_type + std::wstring(L"TIME"), // expected_type_name + 3, // expected_column_size (limitation: should be 9+fractional digits) + 12, // expected_buffer_length + 0, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIME, // expected_date_time_sub + 6, // expected_octet_char_length + 8, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 9th Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expected_schema + std::wstring(L"ODBCTest"), // expected_table + std::wstring(L"timestamp_max"), // expected_column + SQL_TIMESTAMP, // expected_data_type + std::wstring(L"TIMESTAMP"), // expected_type_name + 3, // expected_column_size (limitation: should be 20+fractional digits) + 23, // expected_buffer_length + 0, // expected_decimal_digits + 0, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIMESTAMP, // expected_date_time_sub + 16, // expected_octet_char_length + 9, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // There is no more column + EXPECT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(ColumnsMockTest, TestSQLColumnscolumn_pattern) { + // Checks filtering table with column name pattern. + // Only check table and column name + + SQLWCHAR table_pattern[] = L"%"; + SQLWCHAR column_pattern[] = L"id"; + + EXPECT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_pattern, SQL_NTS, column_pattern, SQL_NTS)); + + // Check 1st Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"foreignTable"), // expected_table + std::wstring(L"id"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock returns 10 instead of 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock returns 15 instead of 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // Check 2nd Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"intTable"), // expected_table + std::wstring(L"id"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock returns 10 instead of 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock returns 15 instead of 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // There is no more column + EXPECT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(ColumnsMockTest, TestSQLColumnsTablecolumn_pattern) { + // Checks filtering table with table and column name pattern. + // Only check table and column name + + SQLWCHAR table_pattern[] = L"foreignTable"; + SQLWCHAR column_pattern[] = L"id"; + + ASSERT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_pattern, SQL_NTS, column_pattern, SQL_NTS)); + + // Check 1st Column + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckMockSQLColumns(this->stmt, + std::wstring(L"main"), // expected_catalog + std::wstring(L"foreignTable"), // expected_table + std::wstring(L"id"), // expected_column + SQL_BIGINT, // expected_data_type + std::wstring(L"BIGINT"), // expected_type_name + 10, // expected_column_size (mock returns 10 instead of 19) + 8, // expected_buffer_length + 15, // expected_decimal_digits (mock returns 15 instead of 0) + 10, // expected_num_prec_radix + SQL_NULLABLE, // expected_nullable + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_date_time_sub + 8, // expected_octet_char_length + 1, // expected_ordinal_position + std::wstring(L"YES")); // expected_is_nullable + + // There is no more column + EXPECT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(ColumnsMockTest, TestSQLColumnsInvalidtable_pattern) { + SQLWCHAR table_pattern[] = L"non-existent-table"; + SQLWCHAR column_pattern[] = L"%"; + + ASSERT_EQ(SQL_SUCCESS, SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_pattern, SQL_NTS, column_pattern, SQL_NTS)); + + // There is no column from filter + EXPECT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TYPED_TEST(ColumnsTest, SQLColAttributeTestInputData) { + std::wstring wsql = L"SELECT 1 as col1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLUSMALLINT idx = 1; + std::vector character_attr(kOdbcBufferSize); + SQLSMALLINT character_attr_len = 0; + SQLLEN numeric_attr = 0; + + // All character values populated + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(this->stmt, idx, SQL_DESC_NAME, &character_attr[0], + (SQLSMALLINT)character_attr.size(), &character_attr_len, 0)); + + // All numeric values populated + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(this->stmt, idx, SQL_DESC_COUNT, 0, 0, 0, &numeric_attr)); + + // Pass null values, driver should not throw error + EXPECT_EQ(SQL_SUCCESS, + SQLColAttribute(this->stmt, idx, SQL_COLUMN_TABLE_NAME, 0, 0, 0, 0)); + + EXPECT_EQ(SQL_SUCCESS, SQLColAttribute(this->stmt, idx, SQL_DESC_COUNT, 0, 0, 0, 0)); +} + +TYPED_TEST(ColumnsTest, SQLColAttributeGetCharacterLen) { + std::wstring wsql = L"SELECT 1 as col1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLSMALLINT character_attr_len = 0; + + // Check length of character attribute + ASSERT_EQ(SQL_SUCCESS, SQLColAttribute(this->stmt, 1, SQL_DESC_BASE_COLUMN_NAME, 0, 0, + &character_attr_len, 0)); + EXPECT_EQ(4 * ODBC::GetSqlWCharSize(), character_attr_len); +} + +TYPED_TEST(ColumnsTest, SQLColAttributeInvalidFieldId) { + std::wstring wsql = L"SELECT 1 as col1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLUSMALLINT invalid_field_id = -100; + SQLUSMALLINT idx = 1; + std::vector character_attr(kOdbcBufferSize); + SQLSMALLINT character_attr_len = 0; + SQLLEN numeric_attr = 0; + + ASSERT_EQ(SQL_ERROR, + SQLColAttribute(this->stmt, idx, invalid_field_id, &character_attr[0], + (SQLSMALLINT)character_attr.size(), &character_attr_len, 0)); + // Verify invalid descriptor field identifier error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY091); +} + +TYPED_TEST(ColumnsTest, SQLColAttributeInvalidColId) { + std::wstring wsql = L"SELECT 1 as col1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLUSMALLINT invalid_col_id = 2; + std::vector character_attr(kOdbcBufferSize); + SQLSMALLINT character_attr_len = 0; + + ASSERT_EQ(SQL_ERROR, + SQLColAttribute(this->stmt, invalid_col_id, SQL_DESC_BASE_COLUMN_NAME, + &character_attr[0], (SQLSMALLINT)character_attr.size(), + &character_attr_len, 0)); + // Verify invalid descriptor index error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState07009); +} + +TEST_F(ColumnsMockTest, TestSQLColAttributeAllTypes) { + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLColAttribute(this->stmt, 1, + std::wstring(L"bigint_col"), // expected_column_name + SQL_BIGINT, // expected_data_type + SQL_BIGINT, // expected_concise_type + 20, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 10, // expected_num_prec_radix + 8, // expected_octet_length + SQL_PRED_NONE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 2, + std::wstring(L"char_col"), // expected_column_name + SQL_WVARCHAR, // expected_data_type + SQL_WVARCHAR, // expected_concise_type + 0, // expected_display_size + SQL_FALSE, // expected_prec_scale + 0, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 0, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 0, // expected_octet_length + SQL_PRED_NONE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 3, + std::wstring(L"varbinary_col"), // expected_column_name + SQL_BINARY, // expected_data_type + SQL_BINARY, // expected_concise_type + 0, // expected_display_size + SQL_FALSE, // expected_prec_scale + 0, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 0, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 0, // expected_octet_length + SQL_PRED_NONE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 4, + std::wstring(L"double_col"), // expected_column_name + SQL_DOUBLE, // expected_data_type + SQL_DOUBLE, // expected_concise_type + 24, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 2, // expected_num_prec_radix + 8, // expected_octet_length + SQL_PRED_NONE, // expected_searchable + SQL_FALSE); // expected_unsigned_column +} + +TEST_F(ColumnsOdbcV2MockTest, TestSQLColAttributesAllTypes) { + // Tests ODBC 2.0 API SQLColAttributes + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + CheckSQLColAttributes(this->stmt, 1, + std::wstring(L"bigint_col"), // expected_column_name + SQL_BIGINT, // expected_data_type + 20, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_PRED_NONE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 2, + std::wstring(L"char_col"), // expected_column_name + SQL_WVARCHAR, // expected_data_type + 0, // expected_display_size + SQL_FALSE, // expected_prec_scale + 0, // expected_length + 0, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_PRED_NONE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 3, + std::wstring(L"varbinary_col"), // expected_column_name + SQL_BINARY, // expected_data_type + 0, // expected_display_size + SQL_FALSE, // expected_prec_scale + 0, // expected_length + 0, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_PRED_NONE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 4, + std::wstring(L"double_col"), // expected_column_name + SQL_DOUBLE, // expected_data_type + 24, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_PRED_NONE, // expected_searchable + SQL_FALSE); // expected_unsigned_column +} + +TEST_F(ColumnsRemoteTest, TestSQLColAttributeAllTypes) { + // Test assumes there is a table $scratch.ODBCTest in remote server + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLColAttribute(this->stmt, 1, + std::wstring(L"sinteger_max"), // expected_column_name + SQL_INTEGER, // expected_data_type + SQL_INTEGER, // expected_concise_type + 11, // expected_display_size + SQL_FALSE, // expected_prec_scale + 4, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 4, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 10, // expected_num_prec_radix + 4, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 2, + std::wstring(L"sbigint_max"), // expected_column_name + SQL_BIGINT, // expected_data_type + SQL_BIGINT, // expected_concise_type + 20, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 10, // expected_num_prec_radix + 8, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 3, + std::wstring(L"decimal_positive"), // expected_column_name + SQL_DECIMAL, // expected_data_type + SQL_DECIMAL, // expected_concise_type + 40, // expected_display_size + SQL_FALSE, // expected_prec_scale + 19, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 19, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 10, // expected_num_prec_radix + 40, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 4, + std::wstring(L"float_max"), // expected_column_name + SQL_FLOAT, // expected_data_type + SQL_FLOAT, // expected_concise_type + 24, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 2, // expected_num_prec_radix + 8, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 5, + std::wstring(L"double_max"), // expected_column_name + SQL_DOUBLE, // expected_data_type + SQL_DOUBLE, // expected_concise_type + 24, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 2, // expected_num_prec_radix + 8, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 6, + std::wstring(L"bit_true"), // expected_column_name + SQL_BIT, // expected_data_type + SQL_BIT, // expected_concise_type + 1, // expected_display_size + SQL_FALSE, // expected_prec_scale + 1, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 1, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 1, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 7, + std::wstring(L"date_max"), // expected_column_name + SQL_DATETIME, // expected_data_type + SQL_TYPE_DATE, // expected_concise_type + 10, // expected_display_size + SQL_FALSE, // expected_prec_scale + 10, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 10, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 6, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 8, + std::wstring(L"time_max"), // expected_column_name + SQL_DATETIME, // expected_data_type + SQL_TYPE_TIME, // expected_concise_type + 12, // expected_display_size + SQL_FALSE, // expected_prec_scale + 12, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 12, // expected_column_size + 3, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 6, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 9, + std::wstring(L"timestamp_max"), // expected_column_name + SQL_DATETIME, // expected_data_type + SQL_TYPE_TIMESTAMP, // expected_concise_type + 23, // expected_display_size + SQL_FALSE, // expected_prec_scale + 23, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 23, // expected_column_size + 3, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 16, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column +} + +TEST_F(ColumnsOdbcV2RemoteTest, TestSQLColAttributeAllTypes) { + // Test assumes there is a table $scratch.ODBCTest in remote server + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLColAttribute(this->stmt, 1, + std::wstring(L"sinteger_max"), // expected_column_name + SQL_INTEGER, // expected_data_type + SQL_INTEGER, // expected_concise_type + 11, // expected_display_size + SQL_FALSE, // expected_prec_scale + 4, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 4, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 10, // expected_num_prec_radix + 4, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 2, + std::wstring(L"sbigint_max"), // expected_column_name + SQL_BIGINT, // expected_data_type + SQL_BIGINT, // expected_concise_type + 20, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 10, // expected_num_prec_radix + 8, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 3, + std::wstring(L"decimal_positive"), // expected_column_name + SQL_DECIMAL, // expected_data_type + SQL_DECIMAL, // expected_concise_type + 40, // expected_display_size + SQL_FALSE, // expected_prec_scale + 19, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 19, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 10, // expected_num_prec_radix + 40, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 4, + std::wstring(L"float_max"), // expected_column_name + SQL_FLOAT, // expected_data_type + SQL_FLOAT, // expected_concise_type + 24, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 2, // expected_num_prec_radix + 8, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 5, + std::wstring(L"double_max"), // expected_column_name + SQL_DOUBLE, // expected_data_type + SQL_DOUBLE, // expected_concise_type + 24, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 2, // expected_num_prec_radix + 8, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 6, + std::wstring(L"bit_true"), // expected_column_name + SQL_BIT, // expected_data_type + SQL_BIT, // expected_concise_type + 1, // expected_display_size + SQL_FALSE, // expected_prec_scale + 1, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 1, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 1, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 7, + std::wstring(L"date_max"), // expected_column_name + SQL_DATETIME, // expected_data_type + SQL_DATE, // expected_concise_type + 10, // expected_display_size + SQL_FALSE, // expected_prec_scale + 10, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 10, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 6, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 8, + std::wstring(L"time_max"), // expected_column_name + SQL_DATETIME, // expected_data_type + SQL_TIME, // expected_concise_type + 12, // expected_display_size + SQL_FALSE, // expected_prec_scale + 12, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 12, // expected_column_size + 3, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 6, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttribute(this->stmt, 9, + std::wstring(L"timestamp_max"), // expected_column_name + SQL_DATETIME, // expected_data_type + SQL_TIMESTAMP, // expected_concise_type + 23, // expected_display_size + SQL_FALSE, // expected_prec_scale + 23, // expected_length + std::wstring(L""), // expected_literal_prefix + std::wstring(L""), // expected_literal_suffix + 23, // expected_column_size + 3, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + 0, // expected_num_prec_radix + 16, // expected_octet_length + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column +} + +TEST_F(ColumnsOdbcV2RemoteTest, TestSQLColAttributesAllTypes) { + // Tests ODBC 2.0 API SQLColAttributes + // Test assumes there is a table $scratch.ODBCTest in remote server + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLColAttributes(this->stmt, 1, + std::wstring(L"sinteger_max"), // expected_column_name + SQL_INTEGER, // expected_data_type + 11, // expected_display_size + SQL_FALSE, // expected_prec_scale + 4, // expected_length + 4, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 2, + std::wstring(L"sbigint_max"), // expected_column_name + SQL_BIGINT, // expected_data_type + 20, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 3, + std::wstring(L"decimal_positive"), // expected_column_name + SQL_DECIMAL, // expected_data_type + 40, // expected_display_size + SQL_FALSE, // expected_prec_scale + 19, // expected_length + 19, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 4, + std::wstring(L"float_max"), // expected_column_name + SQL_FLOAT, // expected_data_type + 24, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 5, + std::wstring(L"double_max"), // expected_column_name + SQL_DOUBLE, // expected_data_type + 24, // expected_display_size + SQL_FALSE, // expected_prec_scale + 8, // expected_length + 8, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 6, + std::wstring(L"bit_true"), // expected_column_name + SQL_BIT, // expected_data_type + 1, // expected_display_size + SQL_FALSE, // expected_prec_scale + 1, // expected_length + 1, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 7, + std::wstring(L"date_max"), // expected_column_name + SQL_DATE, // expected_data_type + 10, // expected_display_size + SQL_FALSE, // expected_prec_scale + 10, // expected_length + 10, // expected_column_size + 0, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 8, + std::wstring(L"time_max"), // expected_column_name + SQL_TIME, // expected_data_type + 12, // expected_display_size + SQL_FALSE, // expected_prec_scale + 12, // expected_length + 12, // expected_column_size + 3, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column + + CheckSQLColAttributes(this->stmt, 9, + std::wstring(L"timestamp_max"), // expected_column_name + SQL_TIMESTAMP, // expected_data_type + 23, // expected_display_size + SQL_FALSE, // expected_prec_scale + 23, // expected_length + 23, // expected_column_size + 3, // expected_column_scale + SQL_NULLABLE, // expected_column_nullability + SQL_SEARCHABLE, // expected_searchable + SQL_TRUE); // expected_unsigned_column +} + +TYPED_TEST(ColumnsTest, TestSQLColAttributeCaseSensitive) { + // Arrow limitation: returns SQL_FALSE for case sensitive column + + std::wstring wsql = this->GetQueryAllDataTypes(); + // Int column + SQLLEN value; + GetSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_DESC_CASE_SENSITIVE, &value); + ASSERT_EQ(SQL_FALSE, value); + SQLFreeStmt(this->stmt, SQL_CLOSE); + // Varchar column + GetSQLColAttributeNumeric(this->stmt, wsql, 28, SQL_DESC_CASE_SENSITIVE, &value); + ASSERT_EQ(SQL_FALSE, value); +} + +TYPED_TEST(ColumnsOdbcV2Test, TestSQLColAttributesCaseSensitive) { + // Arrow limitation: returns SQL_FALSE for case sensitive column + // Tests ODBC 2.0 API SQLColAttributes + + std::wstring wsql = this->GetQueryAllDataTypes(); + // Int column + SQLLEN value; + GetSQLColAttributesNumeric(this->stmt, wsql, 1, SQL_COLUMN_CASE_SENSITIVE, &value); + ASSERT_EQ(SQL_FALSE, value); + SQLFreeStmt(this->stmt, SQL_CLOSE); + // Varchar column + GetSQLColAttributesNumeric(this->stmt, wsql, 28, SQL_COLUMN_CASE_SENSITIVE, &value); + ASSERT_EQ(SQL_FALSE, value); +} + +TEST_F(ColumnsMockTest, TestSQLColAttributeUniqueValue) { + // Mock server limitation: returns false for auto-increment column + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + SQLLEN value; + GetSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_DESC_AUTO_UNIQUE_VALUE, &value); + ASSERT_EQ(SQL_FALSE, value); +} + +TEST_F(ColumnsOdbcV2MockTest, TestSQLColAttributesAutoIncrement) { + // Tests ODBC 2.0 API SQLColAttributes + // Mock server limitation: returns false for auto-increment column + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + SQLLEN value; + GetSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_COLUMN_AUTO_INCREMENT, &value); + ASSERT_EQ(SQL_FALSE, value); +} + +TEST_F(ColumnsMockTest, TestSQLColAttributeBaseTableName) { + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_BASE_TABLE_NAME, value); + ASSERT_EQ(std::wstring(L"AllTypesTable"), value); +} + +TEST_F(ColumnsOdbcV2MockTest, TestSQLColAttributesTableName) { + // Tests ODBC 2.0 API SQLColAttributes + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::wstring value; + GetSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_TABLE_NAME, value); + ASSERT_EQ(std::wstring(L"AllTypesTable"), value); +} + +TEST_F(ColumnsMockTest, TestSQLColAttributeCatalogName) { + // Mock server limitattion: mock doesn't return catalog for result metadata, + // and the defautl catalog should be 'main' + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_CATALOG_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TEST_F(ColumnsRemoteTest, TestSQLColAttributeCatalogName) { + // Remote server does not have catalogs + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_CATALOG_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TEST_F(ColumnsOdbcV2MockTest, TestSQLColAttributesQualifierName) { + // Mock server limitattion: mock doesn't return catalog for result metadata, + // and the defautl catalog should be 'main' + // Tests ODBC 2.0 API SQLColAttributes + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_COLUMN_QUALIFIER_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TEST_F(ColumnsOdbcV2RemoteTest, TestSQLColAttributesQualifierName) { + // Remote server does not have catalogs + // Tests ODBC 2.0 API SQLColAttributes + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_COLUMN_QUALIFIER_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TYPED_TEST(ColumnsTest, TestSQLColAttributeCount) { + std::wstring wsql = this->GetQueryAllDataTypes(); + // Pass 0 as column number, driver should ignore it + SQLLEN value; + GetSQLColAttributeNumeric(this->stmt, wsql, 0, SQL_DESC_COUNT, &value); + ASSERT_EQ(32, value); +} + +TEST_F(ColumnsMockTest, TestSQLColAttributeLocalTypeName) { + std::wstring wsql = this->GetQueryAllDataTypes(); + // Mock server doesn't have local type name + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_LOCAL_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TEST_F(ColumnsRemoteTest, TestSQLColAttributeLocalTypeName) { + std::wstring wsql = this->GetQueryAllDataTypes(); + std::wstring value; + GetSQLColAttributesString(this->stmt, wsql, 1, SQL_DESC_LOCAL_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"INTEGER"), value); +} + +TEST_F(ColumnsMockTest, TestSQLColAttributeSchemaName) { + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + // Mock server doesn't have schemas + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_SCHEMA_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TEST_F(ColumnsRemoteTest, TestSQLColAttributeSchemaName) { + // Test assumes there is a table $scratch.ODBCTest in remote server + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + // Remote server limitation: doesn't return schema name, expected schema name is + // $scratch + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_SCHEMA_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TEST_F(ColumnsOdbcV2MockTest, TestSQLColAttributesOwnerName) { + // Tests ODBC 2.0 API SQLColAttributes + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + // Mock server doesn't have schemas + std::wstring value; + GetSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_OWNER_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TEST_F(ColumnsOdbcV2RemoteTest, TestSQLColAttributesOwnerName) { + // Test assumes there is a table $scratch.ODBCTest in remote server + // Tests ODBC 2.0 API SQLColAttributes + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + // Remote server limitation: doesn't return schema name, expected schema name is + // $scratch + std::wstring value; + GetSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_OWNER_NAME, value); + ASSERT_EQ(std::wstring(L""), value); +} + +TEST_F(ColumnsMockTest, TestSQLColAttributeTableName) { + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_TABLE_NAME, value); + ASSERT_EQ(std::wstring(L"AllTypesTable"), value); +} + +TEST_F(ColumnsMockTest, TestSQLColAttributeTypeName) { + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"BIGINT"), value); + GetSQLColAttributeString(this->stmt, L"", 2, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"WVARCHAR"), value); + GetSQLColAttributeString(this->stmt, L"", 3, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"BINARY"), value); + GetSQLColAttributeString(this->stmt, L"", 4, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"DOUBLE"), value); +} + +TEST_F(ColumnsRemoteTest, TestSQLColAttributeTypeName) { + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::wstring value; + GetSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"INTEGER"), value); + GetSQLColAttributeString(this->stmt, L"", 2, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"BIGINT"), value); + GetSQLColAttributeString(this->stmt, L"", 3, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"DECIMAL"), value); + GetSQLColAttributeString(this->stmt, L"", 4, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"FLOAT"), value); + GetSQLColAttributeString(this->stmt, L"", 5, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"DOUBLE"), value); + GetSQLColAttributeString(this->stmt, L"", 6, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"BOOLEAN"), value); + GetSQLColAttributeString(this->stmt, L"", 7, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"DATE"), value); + GetSQLColAttributeString(this->stmt, L"", 8, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"TIME"), value); + GetSQLColAttributeString(this->stmt, L"", 9, SQL_DESC_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"TIMESTAMP"), value); +} + +TEST_F(ColumnsOdbcV2MockTest, TestSQLColAttributesTypeName) { + // Tests ODBC 2.0 API SQLColAttributes + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + // Mock server doesn't return data source-dependent data type name + std::wstring value; + GetSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"BIGINT"), value); + GetSQLColAttributesString(this->stmt, L"", 2, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"WVARCHAR"), value); + GetSQLColAttributesString(this->stmt, L"", 3, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"BINARY"), value); + GetSQLColAttributesString(this->stmt, L"", 4, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"DOUBLE"), value); +} + +TEST_F(ColumnsOdbcV2RemoteTest, TestSQLColAttributesTypeName) { + // Tests ODBC 2.0 API SQLColAttributes + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::wstring value; + GetSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"INTEGER"), value); + GetSQLColAttributesString(this->stmt, L"", 2, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"BIGINT"), value); + GetSQLColAttributesString(this->stmt, L"", 3, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"DECIMAL"), value); + GetSQLColAttributesString(this->stmt, L"", 4, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"FLOAT"), value); + GetSQLColAttributesString(this->stmt, L"", 5, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"DOUBLE"), value); + GetSQLColAttributesString(this->stmt, L"", 6, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"BOOLEAN"), value); + GetSQLColAttributesString(this->stmt, L"", 7, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"DATE"), value); + GetSQLColAttributesString(this->stmt, L"", 8, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"TIME"), value); + GetSQLColAttributesString(this->stmt, L"", 9, SQL_COLUMN_TYPE_NAME, value); + ASSERT_EQ(std::wstring(L"TIMESTAMP"), value); +} + +TYPED_TEST(ColumnsTest, TestSQLColAttributeUnnamed) { + std::wstring wsql = this->GetQueryAllDataTypes(); + SQLLEN value; + GetSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_DESC_UNNAMED, &value); + ASSERT_EQ(SQL_NAMED, value); +} + +TYPED_TEST(ColumnsTest, TestSQLColAttributeUpdatable) { + std::wstring wsql = this->GetQueryAllDataTypes(); + // Mock server and remote server do not return updatable information + SQLLEN value; + GetSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_DESC_UPDATABLE, &value); + ASSERT_EQ(SQL_ATTR_READWRITE_UNKNOWN, value); +} + +TYPED_TEST(ColumnsOdbcV2Test, TestSQLColAttributesUpdatable) { + // Tests ODBC 2.0 API SQLColAttributes + std::wstring wsql = this->GetQueryAllDataTypes(); + // Mock server and remote server do not return updatable information + SQLLEN value; + GetSQLColAttributesNumeric(this->stmt, wsql, 1, SQL_COLUMN_UPDATABLE, &value); + ASSERT_EQ(SQL_ATTR_READWRITE_UNKNOWN, value); +} + +TEST_F(ColumnsMockTest, SQLDescribeColValidateInput) { + this->CreateTestTables(); + + SQLWCHAR sql_query[] = L"SELECT * FROM TestTable LIMIT 1;"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + SQLUSMALLINT bookmark_column = 0; + SQLUSMALLINT out_of_range_column = 4; + SQLUSMALLINT negative_column = -1; + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Invalid descriptor index - Bookmarks are not supported + EXPECT_EQ(SQL_ERROR, SQLDescribeCol(this->stmt, bookmark_column, column_name, + buf_char_len, &name_length, &data_type, + &column_size, &decimal_digits, &nullable)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState07009); + + // Invalid descriptor index - index out of range + EXPECT_EQ(SQL_ERROR, SQLDescribeCol(this->stmt, out_of_range_column, column_name, + buf_char_len, &name_length, &data_type, + &column_size, &decimal_digits, &nullable)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState07009); + + // Invalid descriptor index - index out of range + EXPECT_EQ(SQL_ERROR, SQLDescribeCol(this->stmt, negative_column, column_name, + buf_char_len, &name_length, &data_type, + &column_size, &decimal_digits, &nullable)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState07009); +} + +TEST_F(ColumnsMockTest, SQLDescribeColQueryAllDataTypesMetadata) { + // Mock server has a limitation where only SQL_WVARCHAR column type values are returned + // from SELECT AS queries + + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + std::wstring wsql = this->GetQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + const SQLWCHAR* column_names[] = {static_cast(L"stiny_int_min"), + static_cast(L"stiny_int_max"), + static_cast(L"utiny_int_min"), + static_cast(L"utiny_int_max"), + static_cast(L"ssmall_int_min"), + static_cast(L"ssmall_int_max"), + static_cast(L"usmall_int_min"), + static_cast(L"usmall_int_max"), + static_cast(L"sinteger_min"), + static_cast(L"sinteger_max"), + static_cast(L"uinteger_min"), + static_cast(L"uinteger_max"), + static_cast(L"sbigint_min"), + static_cast(L"sbigint_max"), + static_cast(L"ubigint_min"), + static_cast(L"ubigint_max"), + static_cast(L"decimal_negative"), + static_cast(L"decimal_positive"), + static_cast(L"float_min"), + static_cast(L"float_max"), + static_cast(L"double_min"), + static_cast(L"double_max"), + static_cast(L"bit_false"), + static_cast(L"bit_true"), + static_cast(L"c_char"), + static_cast(L"c_wchar"), + static_cast(L"c_wvarchar"), + static_cast(L"c_varchar"), + static_cast(L"date_min"), + static_cast(L"date_max"), + static_cast(L"timestamp_min"), + static_cast(L"timestamp_max")}; + SQLSMALLINT column_data_types[] = { + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR}; + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(1024, column_size); + EXPECT_EQ(0, decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} + +TEST_F(ColumnsRemoteTest, SQLDescribeColQueryAllDataTypesMetadata) { + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + std::wstring wsql = this->GetQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + const SQLWCHAR* column_names[] = {static_cast(L"stiny_int_min"), + static_cast(L"stiny_int_max"), + static_cast(L"utiny_int_min"), + static_cast(L"utiny_int_max"), + static_cast(L"ssmall_int_min"), + static_cast(L"ssmall_int_max"), + static_cast(L"usmall_int_min"), + static_cast(L"usmall_int_max"), + static_cast(L"sinteger_min"), + static_cast(L"sinteger_max"), + static_cast(L"uinteger_min"), + static_cast(L"uinteger_max"), + static_cast(L"sbigint_min"), + static_cast(L"sbigint_max"), + static_cast(L"ubigint_min"), + static_cast(L"ubigint_max"), + static_cast(L"decimal_negative"), + static_cast(L"decimal_positive"), + static_cast(L"float_min"), + static_cast(L"float_max"), + static_cast(L"double_min"), + static_cast(L"double_max"), + static_cast(L"bit_false"), + static_cast(L"bit_true"), + static_cast(L"c_char"), + static_cast(L"c_wchar"), + static_cast(L"c_wvarchar"), + static_cast(L"c_varchar"), + static_cast(L"date_min"), + static_cast(L"date_max"), + static_cast(L"timestamp_min"), + static_cast(L"timestamp_max")}; + SQLSMALLINT column_data_types[] = { + SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, + SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, + SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, + SQL_WVARCHAR, SQL_DECIMAL, SQL_DECIMAL, SQL_FLOAT, SQL_FLOAT, + SQL_DOUBLE, SQL_DOUBLE, SQL_BIT, SQL_BIT, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_TYPE_DATE, SQL_TYPE_DATE, + SQL_TYPE_TIMESTAMP, SQL_TYPE_TIMESTAMP}; + SQLULEN column_sizes[] = {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, + 8, 8, 8, 8, 65536, 19, 19, 8, 8, 8, 8, + 1, 1, 65536, 65536, 65536, 65536, 10, 10, 23, 23}; + SQLULEN column_decimal_digits[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 23, 23}; + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(column_sizes[i], column_size); + EXPECT_EQ(column_decimal_digits[i], decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} + +TEST_F(ColumnsRemoteTest, SQLDescribeColODBCTestTableMetadata) { + // Test assumes there is a table $scratch.ODBCTest in remote server + + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + SQLWCHAR sql_query[] = L"SELECT * from $scratch.ODBCTest LIMIT 1;"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + const SQLWCHAR* column_names[] = {static_cast(L"sinteger_max"), + static_cast(L"sbigint_max"), + static_cast(L"decimal_positive"), + static_cast(L"float_max"), + static_cast(L"double_max"), + static_cast(L"bit_true"), + static_cast(L"date_max"), + static_cast(L"time_max"), + static_cast(L"timestamp_max")}; + SQLSMALLINT column_data_types[] = {SQL_INTEGER, SQL_BIGINT, SQL_DECIMAL, + SQL_FLOAT, SQL_DOUBLE, SQL_BIT, + SQL_TYPE_DATE, SQL_TYPE_TIME, SQL_TYPE_TIMESTAMP}; + SQLULEN column_sizes[] = {4, 8, 19, 8, 8, 1, 10, 12, 23}; + SQLULEN columndecimal_digits[] = {0, 0, 0, 0, 0, 0, 10, 12, 23}; + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(column_sizes[i], column_size); + EXPECT_EQ(columndecimal_digits[i], decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} + +TEST_F(ColumnsOdbcV2RemoteTest, SQLDescribeColODBCTestTableMetadataODBC2) { + // Test assumes there is a table $scratch.ODBCTest in remote server + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + SQLWCHAR sql_query[] = L"SELECT * from $scratch.ODBCTest LIMIT 1;"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + const SQLWCHAR* column_names[] = {static_cast(L"sinteger_max"), + static_cast(L"sbigint_max"), + static_cast(L"decimal_positive"), + static_cast(L"float_max"), + static_cast(L"double_max"), + static_cast(L"bit_true"), + static_cast(L"date_max"), + static_cast(L"time_max"), + static_cast(L"timestamp_max")}; + SQLSMALLINT column_data_types[] = {SQL_INTEGER, SQL_BIGINT, SQL_DECIMAL, + SQL_FLOAT, SQL_DOUBLE, SQL_BIT, + SQL_DATE, SQL_TIME, SQL_TIMESTAMP}; + SQLULEN column_sizes[] = {4, 8, 19, 8, 8, 1, 10, 12, 23}; + SQLULEN columndecimal_digits[] = {0, 0, 0, 0, 0, 0, 10, 12, 23}; + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(column_sizes[i], column_size); + EXPECT_EQ(columndecimal_digits[i], decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} + +TEST_F(ColumnsMockTest, SQLDescribeColAllTypesTableMetadata) { + this->CreateTableAllDataType(); + + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + SQLWCHAR sql_query[] = L"SELECT * from AllTypesTable LIMIT 1;"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + const SQLWCHAR* column_names[] = {static_cast(L"bigint_col"), + static_cast(L"char_col"), + static_cast(L"varbinary_col"), + static_cast(L"double_col")}; + SQLSMALLINT column_data_types[] = {SQL_BIGINT, SQL_WVARCHAR, SQL_BINARY, SQL_DOUBLE}; + SQLULEN column_sizes[] = {8, 0, 0, 8}; + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(column_sizes[i], column_size); + EXPECT_EQ(0, decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} + +TEST_F(ColumnsMockTest, SQLDescribeColUnicodeTableMetadata) { + this->CreateUnicodeTable(); + + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 1; + + SQLWCHAR sql_query[] = L"SELECT * from 数据 LIMIT 1;"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + SQLWCHAR expected_column_name[] = L"资料"; + SQLSMALLINT expected_column_data_type = SQL_WVARCHAR; + SQLULEN expected_column_size = 0; + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(name_length, wcslen(expected_column_name)); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(returned, expected_column_name); + EXPECT_EQ(column_data_type, expected_column_data_type); + EXPECT_EQ(column_size, expected_column_size); + EXPECT_EQ(0, decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); +} + +TYPED_TEST(ColumnsTest, SQLColumnsGetMetadataBySQLDescribeCol) { + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + const SQLWCHAR* column_names[] = {static_cast(L"TABLE_CAT"), + static_cast(L"TABLE_SCHEM"), + static_cast(L"TABLE_NAME"), + static_cast(L"COLUMN_NAME"), + static_cast(L"DATA_TYPE"), + static_cast(L"TYPE_NAME"), + static_cast(L"COLUMN_SIZE"), + static_cast(L"BUFFER_LENGTH"), + static_cast(L"DECIMAL_DIGITS"), + static_cast(L"NUM_PREC_RADIX"), + static_cast(L"NULLABLE"), + static_cast(L"REMARKS"), + static_cast(L"COLUMN_DEF"), + static_cast(L"SQL_DATA_TYPE"), + static_cast(L"SQL_DATETIME_SUB"), + static_cast(L"CHAR_OCTET_LENGTH"), + static_cast(L"ORDINAL_POSITION"), + static_cast(L"IS_NULLABLE")}; + SQLSMALLINT column_data_types[] = { + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_SMALLINT, SQL_WVARCHAR, + SQL_INTEGER, SQL_INTEGER, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, SQL_INTEGER, SQL_INTEGER, SQL_WVARCHAR}; + SQLULEN column_sizes[] = {1024, 1024, 1024, 1024, 2, 1024, 4, 4, 2, + 2, 2, 1024, 1024, 2, 2, 4, 4, 1024}; + + ASSERT_EQ(SQL_SUCCESS, + SQLColumns(this->stmt, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(column_sizes[i], column_size); + EXPECT_EQ(0, decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} + +TYPED_TEST(ColumnsOdbcV2Test, SQLColumnsGetMetadataBySQLDescribeColODBC2) { + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + const SQLWCHAR* column_names[] = {static_cast(L"TABLE_QUALIFIER"), + static_cast(L"TABLE_OWNER"), + static_cast(L"TABLE_NAME"), + static_cast(L"COLUMN_NAME"), + static_cast(L"DATA_TYPE"), + static_cast(L"TYPE_NAME"), + static_cast(L"PRECISION"), + static_cast(L"LENGTH"), + static_cast(L"SCALE"), + static_cast(L"RADIX"), + static_cast(L"NULLABLE"), + static_cast(L"REMARKS"), + static_cast(L"COLUMN_DEF"), + static_cast(L"SQL_DATA_TYPE"), + static_cast(L"SQL_DATETIME_SUB"), + static_cast(L"CHAR_OCTET_LENGTH"), + static_cast(L"ORDINAL_POSITION"), + static_cast(L"IS_NULLABLE")}; + SQLSMALLINT column_data_types[] = { + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_SMALLINT, SQL_WVARCHAR, + SQL_INTEGER, SQL_INTEGER, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, SQL_INTEGER, SQL_INTEGER, SQL_WVARCHAR}; + SQLULEN column_sizes[] = {1024, 1024, 1024, 1024, 2, 1024, 4, 4, 2, + 2, 2, 1024, 1024, 2, 2, 4, 4, 1024}; + + ASSERT_EQ(SQL_SUCCESS, + SQLColumns(this->stmt, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(column_sizes[i], column_size); + EXPECT_EQ(0, decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc new file mode 100644 index 00000000000..b7019cdda42 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc @@ -0,0 +1,362 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +namespace arrow::flight::sql::odbc { + +template +class ConnectionAttributeTest : public T {}; + +using TestTypes = + ::testing::Types; +TYPED_TEST_SUITE(ConnectionAttributeTest, TestTypes); + +#ifdef SQL_ATTR_ASYNC_DBC_EVENT +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrAsyncDbcEventUnsupported) { + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_EVENT, 0, 0)); + // Driver Manager on Windows returns error code HY118 + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY118); +} +#endif + +#ifdef SQL_ATTR_ASYNC_ENABLE +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrAyncEnableUnsupported) { + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_ASYNC_ENABLE, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} +#endif + +#ifdef SQL_ATTR_ASYNC_DBC_PCALLBACK +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrAyncDbcPcCallbackUnsupported) { + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_PCALLBACK, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} +#endif + +#ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrAyncDbcPcContextUnsupported) { + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_PCONTEXT, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} +#endif + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrAutoIpdReadOnly) { + // Verify read-only attribute cannot be set + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_AUTO_IPD, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY092); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrConnectionDeadReadOnly) { + // Verify read-only attribute cannot be set + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_CONNECTION_DEAD, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY092); +} + +#ifdef SQL_ATTR_DBC_INFO_TOKEN +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrDbcInfoTokenUnsupported) { + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_DBC_INFO_TOKEN, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} +#endif + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrEnlistInDtcUnsupported) { + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_ENLIST_IN_DTC, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrOdbcCursorsDMOnly) { + this->AllocEnvConnHandles(); + + // Verify DM-only attribute is settable via Driver Manager + ASSERT_EQ(SQL_SUCCESS, + SQLSetConnectAttr(this->conn, SQL_ATTR_ODBC_CURSORS, + reinterpret_cast(SQL_CUR_USE_DRIVER), 0)); + + std::string connect_str = this->GetConnectionString(); + this->ConnectWithString(connect_str); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrQuietModeReadOnly) { + // Verify read-only attribute cannot be set + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_QUIET_MODE, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY092); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrTraceDMOnly) { + // Verify DM-only attribute is settable via Driver Manager + ASSERT_EQ(SQL_SUCCESS, + SQLSetConnectAttr(this->conn, SQL_ATTR_TRACE, + reinterpret_cast(SQL_OPT_TRACE_OFF), 0)); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrTracefileDMOnly) { + // Verify DM-only attribute is handled by Driver Manager + + // Use placeholder value as we want the call to fail, or else + // the driver manager will produce a trace file. + std::wstring trace_file = L"invalid/file/path"; + std::vector trace_file0(trace_file.begin(), trace_file.end()); + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_TRACEFILE, &trace_file0[0], + static_cast(trace_file0.size()))); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY000); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrTranslateLabDMOnly) { + // Verify DM-only attribute is handled by Driver Manager + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_TRANSLATE_LIB, 0, 0)); + // Checks for invalid argument return error + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY024); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrTranslateOptionUnsupported) { + ASSERT_EQ(SQL_ERROR, SQLSetConnectAttr(this->conn, SQL_ATTR_TRANSLATE_OPTION, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrTxnIsolationUnsupported) { + ASSERT_EQ(SQL_ERROR, + SQLSetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, + reinterpret_cast(SQL_TXN_READ_UNCOMMITTED), 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} + +#ifdef SQL_ATTR_DBC_INFO_TOKEN +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrDbcInfoTokenSetOnly) { + // Verify that set-only attribute cannot be read + SQLPOINTER ptr = NULL; + ASSERT_EQ(SQL_ERROR, SQLGetConnectAttr(this->conn, SQL_ATTR_DBC_INFO_TOKEN, ptr, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY092); +} +#endif + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrOdbcCursorsDMOnly) { + // Verify that DM-only attribute is handled by driver manager + SQLULEN cursor_attr; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_ODBC_CURSORS, &cursor_attr, 0, 0)); + EXPECT_EQ(SQL_CUR_USE_DRIVER, cursor_attr); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrTraceDMOnly) { + // Verify that DM-only attribute is handled by driver manager + SQLUINTEGER trace; + ASSERT_EQ(SQL_SUCCESS, SQLGetConnectAttr(this->conn, SQL_ATTR_TRACE, &trace, 0, 0)); + EXPECT_EQ(SQL_OPT_TRACE_OFF, trace); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrTraceFileDMOnly) { + // Verify that DM-only attribute is handled by driver manager + SQLWCHAR out_str[kOdbcBufferSize]; + SQLINTEGER out_str_len; + ASSERT_EQ(SQL_SUCCESS, SQLGetConnectAttr(this->conn, SQL_ATTR_TRACEFILE, out_str, + kOdbcBufferSize, &out_str_len)); + // Length is returned in bytes for SQLGetConnectAttr, + // we want the number of characters + out_str_len /= arrow::flight::sql::odbc::GetSqlWCharSize(); + std::string out_connection_string = + ODBC::SqlWcharToString(out_str, static_cast(out_str_len)); + EXPECT_TRUE(!out_connection_string.empty()); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrTranslateLibUnsupported) { + SQLWCHAR out_str[kOdbcBufferSize]; + SQLINTEGER out_str_len; + ASSERT_EQ(SQL_ERROR, SQLGetConnectAttr(this->conn, SQL_ATTR_TRANSLATE_LIB, out_str, + kOdbcBufferSize, &out_str_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrTranslateOptionUnsupported) { + SQLINTEGER option; + ASSERT_EQ(SQL_ERROR, + SQLGetConnectAttr(this->conn, SQL_ATTR_TRANSLATE_OPTION, &option, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrTxnIsolationUnsupported) { + SQLINTEGER isolation; + ASSERT_EQ(SQL_ERROR, + SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, &isolation, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHYC00); +} + +#ifdef SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE +TYPED_TEST(ConnectionAttributeTest, + TestSQLGetConnectAttrAsyncDbcFunctionsEnableUnsupported) { + // Verifies that the Windows driver manager returns HY114 for unsupported functionality + SQLUINTEGER enable; + ASSERT_EQ(SQL_ERROR, SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE, + &enable, 0, 0)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY114); +} +#endif + +// Tests for supported attributes + +#ifdef SQL_ATTR_ASYNC_DBC_EVENT +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrAsyncDbcEventDefault) { + SQLPOINTER ptr = NULL; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_EVENT, ptr, 0, 0)); + EXPECT_EQ(reinterpret_cast(NULL), ptr); +} +#endif + +#ifdef SQL_ATTR_ASYNC_DBC_PCALLBACK +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrAsyncDbcPcallbackDefault) { + SQLPOINTER ptr = NULL; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_PCALLBACK, ptr, 0, 0)); + EXPECT_EQ(reinterpret_cast(NULL), ptr); +} +#endif + +#ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrAsyncDbcPcontextDefault) { + SQLPOINTER ptr = NULL; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_PCONTEXT, ptr, 0, 0)); + EXPECT_EQ(reinterpret_cast(NULL), ptr); +} +#endif + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrAsyncEnableDefault) { + SQLULEN enable; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_ENABLE, &enable, 0, 0)); + EXPECT_EQ(SQL_ASYNC_ENABLE_OFF, enable); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrAutoIpdDefault) { + SQLUINTEGER ipd; + ASSERT_EQ(SQL_SUCCESS, SQLGetConnectAttr(this->conn, SQL_ATTR_AUTO_IPD, &ipd, 0, 0)); + EXPECT_EQ(static_cast(SQL_FALSE), ipd); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrAutocommitDefault) { + SQLUINTEGER auto_commit; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_AUTOCOMMIT, &auto_commit, 0, 0)); + EXPECT_EQ(SQL_AUTOCOMMIT_ON, auto_commit); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrEnlistInDtcDefault) { + SQLPOINTER ptr = NULL; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_ENLIST_IN_DTC, ptr, 0, 0)); + EXPECT_EQ(reinterpret_cast(NULL), ptr); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLGetConnectAttrQuietModeDefault) { + HWND ptr = NULL; + ASSERT_EQ(SQL_SUCCESS, SQLGetConnectAttr(this->conn, SQL_ATTR_QUIET_MODE, ptr, 0, 0)); + EXPECT_EQ(reinterpret_cast(NULL), ptr); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrAccessModeValid) { + // The driver always returns SQL_MODE_READ_WRITE + + // Check default value first + SQLUINTEGER mode = -1; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_ACCESS_MODE, &mode, 0, 0)); + EXPECT_EQ(SQL_MODE_READ_WRITE, mode); + + ASSERT_EQ(SQL_SUCCESS, + SQLSetConnectAttr(this->conn, SQL_ATTR_ACCESS_MODE, + reinterpret_cast(SQL_MODE_READ_WRITE), 0)); + + mode = -1; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_ACCESS_MODE, &mode, 0, 0)); + EXPECT_EQ(SQL_MODE_READ_WRITE, mode); + + // Attempt to set to SQL_MODE_READ_ONLY, driver should return warning and not error + EXPECT_EQ(SQL_SUCCESS_WITH_INFO, + SQLSetConnectAttr(this->conn, SQL_ATTR_ACCESS_MODE, + reinterpret_cast(SQL_MODE_READ_ONLY), 0)); + + // Verify warning status + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorState01S02); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrConnectionTimeoutValid) { + // Check default value first + SQLUINTEGER timeout = -1; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_CONNECTION_TIMEOUT, &timeout, 0, 0)); + EXPECT_EQ(0, timeout); + + ASSERT_EQ(SQL_SUCCESS, SQLSetConnectAttr(this->conn, SQL_ATTR_CONNECTION_TIMEOUT, + reinterpret_cast(42), 0)); + + timeout = -1; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_CONNECTION_TIMEOUT, &timeout, 0, 0)); + EXPECT_EQ(42, timeout); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrLoginTimeoutValid) { + // Check default value first + SQLUINTEGER timeout = -1; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_LOGIN_TIMEOUT, &timeout, 0, 0)); + EXPECT_EQ(0, timeout); + + ASSERT_EQ(SQL_SUCCESS, SQLSetConnectAttr(this->conn, SQL_ATTR_LOGIN_TIMEOUT, + reinterpret_cast(42), 0)); + + timeout = -1; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_LOGIN_TIMEOUT, &timeout, 0, 0)); + EXPECT_EQ(42, timeout); +} + +TYPED_TEST(ConnectionAttributeTest, TestSQLSetConnectAttrPacketSizeValid) { + // The driver always returns 0. PACKET_SIZE value is unused by the driver. + + // Check default value first + SQLUINTEGER size = -1; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_PACKET_SIZE, &size, 0, 0)); + EXPECT_EQ(0, size); + + ASSERT_EQ(SQL_SUCCESS, SQLSetConnectAttr(this->conn, SQL_ATTR_PACKET_SIZE, + reinterpret_cast(0), 0)); + + size = -1; + ASSERT_EQ(SQL_SUCCESS, + SQLGetConnectAttr(this->conn, SQL_ATTR_PACKET_SIZE, &size, 0, 0)); + EXPECT_EQ(0, size); + + // Attempt to set to non-zero value, driver should return warning and not error + EXPECT_EQ(SQL_SUCCESS_WITH_INFO, SQLSetConnectAttr(this->conn, SQL_ATTR_PACKET_SIZE, + reinterpret_cast(2), 0)); + + // Verify warning status + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorState01S02); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc new file mode 100644 index 00000000000..878d8d3d508 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc @@ -0,0 +1,1225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +namespace arrow::flight::sql::odbc { + +template +class ConnectionInfoTest : public T {}; + +class ConnectionInfoMockTest : public FlightSQLODBCMockTestBase {}; +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(ConnectionInfoTest, TestTypes); + +namespace { +// Helper Functions + +// Get SQLUSMALLINT return value +void GetInfo(SQLHDBC connection, SQLUSMALLINT info_type, SQLUSMALLINT* value) { + SQLSMALLINT message_length; + + ASSERT_EQ(SQL_SUCCESS, SQLGetInfo(connection, info_type, value, 0, &message_length)); +} + +// Get SQLUINTEGER return value +void GetInfo(SQLHDBC connection, SQLUSMALLINT info_type, SQLUINTEGER* value) { + SQLSMALLINT message_length; + + ASSERT_EQ(SQL_SUCCESS, SQLGetInfo(connection, info_type, value, 0, &message_length)); +} + +// Get SQLULEN return value +void GetInfo(SQLHDBC connection, SQLUSMALLINT info_type, SQLULEN* value) { + SQLSMALLINT message_length; + + ASSERT_EQ(SQL_SUCCESS, SQLGetInfo(connection, info_type, value, 0, &message_length)); +} + +// Get SQLWCHAR return value +void GetInfo(SQLHDBC connection, SQLUSMALLINT info_type, SQLWCHAR* value, + SQLSMALLINT buf_len = kOdbcBufferSize) { + SQLSMALLINT message_length; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetInfo(connection, info_type, value, buf_len, &message_length)); +} +} // namespace + +// Driver Information + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoActiveEnvironments) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_ACTIVE_ENVIRONMENTS, &value); + + EXPECT_EQ(static_cast(0), value); +} + +#ifdef SQL_ASYNC_DBC_FUNCTIONS +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAsyncDbcFunctions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_ASYNC_DBC_FUNCTIONS, &value); + + EXPECT_EQ(static_cast(SQL_ASYNC_DBC_NOT_CAPABLE), value); +} +#endif + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAsyncMode) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_ASYNC_MODE, &value); + + EXPECT_EQ(static_cast(SQL_AM_NONE), value); +} + +#ifdef SQL_ASYNC_NOTIFICATION +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAsyncNotification) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_ASYNC_NOTIFICATION, &value); + + EXPECT_EQ(static_cast(SQL_ASYNC_NOTIFICATION_NOT_CAPABLE), value); +} +#endif + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoBatchRowCount) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_BATCH_ROW_COUNT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoBatchSupport) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_BATCH_SUPPORT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDataSourceName) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DATA_SOURCE_NAME, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +#ifdef SQL_DRIVER_AWARE_POOLING_SUPPORTED +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverAwarePoolingSupported) { + // According to Microsoft documentation, ODBC driver does not need to implement + // SQL_DRIVER_AWARE_POOLING_SUPPORTED and the Driver Manager will ignore the + // driver's return value for it. + + SQLUINTEGER value; + GetInfo(this->conn, SQL_DRIVER_AWARE_POOLING_SUPPORTED, &value); + + EXPECT_EQ(static_cast(SQL_DRIVER_AWARE_POOLING_NOT_CAPABLE), value); +} +#endif + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverHdbc) { + // Value returned from driver manager is the connection address + SQLULEN value; + GetInfo(this->conn, SQL_DRIVER_HDBC, &value); + + EXPECT_GT(value, static_cast(0)); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverHdesc) { + SQLHDESC descriptor; + + // Allocate a descriptor using alloc handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &descriptor)); + + // Value returned from driver manager is the desc address + SQLHDESC local_desc = descriptor; + EXPECT_EQ(SQL_SUCCESS, SQLGetInfo(this->conn, SQL_HANDLE_DESC, &local_desc, 0, 0)); + EXPECT_GT(local_desc, static_cast(0)); + + // Free descriptor handle + ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, descriptor)); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverHenv) { + // Value returned from driver manager is the env address + SQLULEN value; + GetInfo(this->conn, SQL_DRIVER_HENV, &value); + + EXPECT_GT(value, static_cast(0)); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverHlib) { + SQLULEN value; + GetInfo(this->conn, SQL_DRIVER_HLIB, &value); + + EXPECT_GT(value, static_cast(0)); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverHstmt) { + // Value returned from driver manager is the stmt address + SQLHSTMT local_stmt = this->stmt; + ASSERT_EQ(SQL_SUCCESS, SQLGetInfo(this->conn, SQL_DRIVER_HSTMT, &local_stmt, 0, 0)); + EXPECT_GT(local_stmt, static_cast(0)); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverName) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DRIVER_NAME, value); + + EXPECT_STREQ(static_cast(L"Arrow Flight ODBC Driver"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverOdbcVer) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DRIVER_ODBC_VER, value); + + EXPECT_STREQ(static_cast(L"03.80"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDriverVer) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DRIVER_VER, value); + + EXPECT_STREQ(static_cast(L"00.09.0000.0"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDynamicCursorAttributes1) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DYNAMIC_CURSOR_ATTRIBUTES1, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDynamicCursorAttributes2) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DYNAMIC_CURSOR_ATTRIBUTES2, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoForwardOnlyCursorAttributes1) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1, &value); + + EXPECT_EQ(static_cast(SQL_CA1_NEXT), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoForwardOnlyCursorAttributes2) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES2, &value); + + EXPECT_EQ(static_cast(SQL_CA2_READ_ONLY_CONCURRENCY), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoFileUsage) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_FILE_USAGE, &value); + + EXPECT_EQ(static_cast(SQL_FILE_NOT_SUPPORTED), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoGetDataExtensions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_GETDATA_EXTENSIONS, &value); + + EXPECT_EQ(static_cast(SQL_GD_ANY_COLUMN | SQL_GD_ANY_ORDER), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoSchemaViews) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_INFO_SCHEMA_VIEWS, &value); + + EXPECT_EQ(static_cast(SQL_ISV_TABLES | SQL_ISV_COLUMNS | SQL_ISV_VIEWS), + value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoKeysetCursorAttributes1) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_KEYSET_CURSOR_ATTRIBUTES1, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoKeysetCursorAttributes2) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_KEYSET_CURSOR_ATTRIBUTES2, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxAsyncConcurrentStatements) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_MAX_ASYNC_CONCURRENT_STATEMENTS, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxConcurrentActivities) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_CONCURRENT_ACTIVITIES, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxDriverConnections) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_DRIVER_CONNECTIONS, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoOdbcInterfaceConformance) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_ODBC_INTERFACE_CONFORMANCE, &value); + + EXPECT_EQ(static_cast(SQL_OIC_CORE), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoOdbcVer) { + // This is implemented only in the Driver Manager. + + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_ODBC_VER, value); + + EXPECT_STREQ(static_cast(L"03.80.0000"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoParamArrayRowCounts) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_PARAM_ARRAY_ROW_COUNTS, &value); + + EXPECT_EQ(static_cast(SQL_PARC_NO_BATCH), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoParamArraySelects) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_PARAM_ARRAY_SELECTS, &value); + + EXPECT_EQ(static_cast(SQL_PAS_NO_SELECT), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoRowUpdates) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_ROW_UPDATES, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoSearchPatternEscape) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_SEARCH_PATTERN_ESCAPE, value); + + EXPECT_STREQ(static_cast(L"\\"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoServerName) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_SERVER_NAME, value); + + EXPECT_GT(wcslen(value), 0); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoStaticCursorAttributes1) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_STATIC_CURSOR_ATTRIBUTES1, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoStaticCursorAttributes2) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_STATIC_CURSOR_ATTRIBUTES2, &value); + + EXPECT_EQ(static_cast(0), value); +} + +// DBMS Product Information + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDatabaseName) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DATABASE_NAME, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDbmsName) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DBMS_NAME, value); + + EXPECT_GT(wcslen(value), 0); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDbmsVer) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DBMS_VER, value); + + EXPECT_GT(wcslen(value), 0); +} + +// Data Source Information + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAccessibleProcedures) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_ACCESSIBLE_PROCEDURES, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAccessibleTables) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_ACCESSIBLE_TABLES, value); + + EXPECT_STREQ(static_cast(L"Y"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoBookmarkPersistence) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_BOOKMARK_PERSISTENCE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCatalogTerm) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_CATALOG_TERM, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCollationSeq) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_COLLATION_SEQ, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConcatNullBehavior) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_CONCAT_NULL_BEHAVIOR, &value); + + EXPECT_EQ(static_cast(SQL_CB_NULL), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCursorCommitBehavior) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_CURSOR_COMMIT_BEHAVIOR, &value); + + EXPECT_EQ(static_cast(SQL_CB_CLOSE), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCursorRollbackBehavior) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_CURSOR_ROLLBACK_BEHAVIOR, &value); + + EXPECT_EQ(static_cast(SQL_CB_CLOSE), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCursorSensitivity) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CURSOR_SENSITIVITY, &value); + + EXPECT_EQ(static_cast(SQL_UNSPECIFIED), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDataSourceReadOnly) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DATA_SOURCE_READ_ONLY, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDefaultTxnIsolation) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DEFAULT_TXN_ISOLATION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDescribeParameter) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_DESCRIBE_PARAMETER, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMultResultSets) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_MULT_RESULT_SETS, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMultipleActiveTxn) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_MULTIPLE_ACTIVE_TXN, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoNeedLongDataLen) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_NEED_LONG_DATA_LEN, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoNullCollation) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_NULL_COLLATION, &value); + EXPECT_EQ(static_cast(SQL_NC_START), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoProcedureTerm) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_PROCEDURE_TERM, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoSchemaTerm) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_SCHEMA_TERM, value); + + EXPECT_STREQ(static_cast(L"schema"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoScrollOptions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_SCROLL_OPTIONS, &value); + + EXPECT_EQ(static_cast(SQL_SO_FORWARD_ONLY), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoTableTerm) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_TABLE_TERM, value); + + EXPECT_STREQ(static_cast(L"table"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoTxnCapable) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_TXN_CAPABLE, &value); + + EXPECT_EQ(static_cast(SQL_TC_NONE), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoTxnIsolationOption) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_TXN_ISOLATION_OPTION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoUserName) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_USER_NAME, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +// Supported SQL + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAggregateFunctions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_AGGREGATE_FUNCTIONS, &value); + + EXPECT_EQ(value, static_cast(SQL_AF_ALL | SQL_AF_AVG | SQL_AF_COUNT | + SQL_AF_DISTINCT | SQL_AF_MAX | SQL_AF_MIN | + SQL_AF_SUM)); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAlterDomain) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_ALTER_DOMAIN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAlterTable) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_ALTER_TABLE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCatalogLocation) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_CATALOG_LOCATION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCatalogName) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_CATALOG_NAME, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCatalogNameSeparator) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_CATALOG_NAME_SEPARATOR, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoCatalogUsage) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CATALOG_USAGE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoColumnAlias) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_COLUMN_ALIAS, value); + + EXPECT_STREQ(static_cast(L"Y"), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoCorrelationName) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_CORRELATION_NAME, &value); + + EXPECT_EQ(static_cast(SQL_CN_NONE), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCreateAssertion) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CREATE_ASSERTION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCreateCharacterSet) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CREATE_CHARACTER_SET, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCreateCollation) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CREATE_COLLATION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCreateDomain) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CREATE_DOMAIN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoCreateSchema) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CREATE_SCHEMA, &value); + + EXPECT_EQ(static_cast(1), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoCreateTable) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CREATE_TABLE, &value); + + EXPECT_EQ(static_cast(1), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCreateTranslation) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CREATE_TRANSLATION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDdlIndex) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DDL_INDEX, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropAssertion) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DROP_ASSERTION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropCharacterSet) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DROP_CHARACTER_SET, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropCollation) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DROP_COLLATION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropDomain) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DROP_DOMAIN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropSchema) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DROP_SCHEMA, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropTable) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DROP_TABLE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropTranslation) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DROP_TRANSLATION, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropView) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_DROP_VIEW, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoExpressionsInOrderby) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_EXPRESSIONS_IN_ORDERBY, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoGroupBy) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_GROUP_BY, &value); + + EXPECT_EQ(static_cast(SQL_GB_GROUP_BY_CONTAINS_SELECT), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoIdentifierCase) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_IDENTIFIER_CASE, &value); + + EXPECT_EQ(static_cast(SQL_IC_MIXED), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoIdentifierQuoteChar) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_IDENTIFIER_QUOTE_CHAR, value); + + EXPECT_STREQ(static_cast(L"\""), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoIndexKeywords) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_INDEX_KEYWORDS, &value); + + EXPECT_EQ(static_cast(SQL_IK_NONE), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoInsertStatement) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_INSERT_STATEMENT, &value); + + EXPECT_EQ(value, static_cast(SQL_IS_INSERT_LITERALS | + SQL_IS_INSERT_SEARCHED | SQL_IS_SELECT_INTO)); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoIntegrity) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_INTEGRITY, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoKeywords) { + // Keyword strings can require 5000 buffer length + static constexpr int info_len = kOdbcBufferSize * 5; + SQLWCHAR value[info_len] = L""; + GetInfo(this->conn, SQL_KEYWORDS, value, info_len); + + EXPECT_GT(wcslen(value), 0); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoLikeEscapeClause) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_LIKE_ESCAPE_CLAUSE, value); + + EXPECT_STREQ(static_cast(L"Y"), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoNonNullableColumns) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_NON_NULLABLE_COLUMNS, &value); + + EXPECT_EQ(static_cast(SQL_NNC_NULL), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoOjCapabilities) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_OJ_CAPABILITIES, &value); + + EXPECT_EQ(static_cast(SQL_OJ_LEFT | SQL_OJ_RIGHT | SQL_OJ_FULL), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoOrderByColumnsInSelect) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_ORDER_BY_COLUMNS_IN_SELECT, value); + + EXPECT_STREQ(static_cast(L"Y"), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoOuterJoins) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_OUTER_JOINS, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoProcedures) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_PROCEDURES, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoQuotedIdentifierCase) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_QUOTED_IDENTIFIER_CASE, &value); + + EXPECT_EQ(static_cast(SQL_IC_MIXED), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoSchemaUsage) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_SCHEMA_USAGE, &value); + + EXPECT_EQ(static_cast(SQL_SU_DML_STATEMENTS), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoSpecialCharacters) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_SPECIAL_CHARACTERS, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoSqlConformance) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_SQL_CONFORMANCE, &value); + + EXPECT_EQ(static_cast(SQL_SC_SQL92_ENTRY), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoSubqueries) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_SUBQUERIES, &value); + + EXPECT_EQ(value, + static_cast(SQL_SQ_CORRELATED_SUBQUERIES | SQL_SQ_COMPARISON | + SQL_SQ_EXISTS | SQL_SQ_IN | SQL_SQ_QUANTIFIED)); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoUnion) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_UNION, &value); + + EXPECT_EQ(static_cast(SQL_U_UNION | SQL_U_UNION_ALL), value); +} + +// SQL Limits + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxBinaryLiteralLen) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_MAX_BINARY_LITERAL_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoMaxCatalogNameLen) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_CATALOG_NAME_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxCharLiteralLen) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_MAX_CHAR_LITERAL_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoMaxColumnNameLen) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_COLUMN_NAME_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxColumnsInGroupBy) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_COLUMNS_IN_GROUP_BY, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxColumnsInIndex) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_COLUMNS_IN_INDEX, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxColumnsInOrderBy) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_COLUMNS_IN_ORDER_BY, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxColumnsInSelect) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_COLUMNS_IN_SELECT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxColumnsInTable) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_COLUMNS_IN_TABLE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoMaxCursorNameLen) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_CURSOR_NAME_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxIdentifierLen) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_IDENTIFIER_LEN, &value); + + EXPECT_EQ(static_cast(65535), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxIndexSize) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_MAX_INDEX_SIZE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxProcedureNameLen) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_PROCEDURE_NAME_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxRowSize) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_MAX_ROW_SIZE, value); + + EXPECT_STREQ(static_cast(L""), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoMaxRowSizeIncludesLong) { + SQLWCHAR value[kOdbcBufferSize] = L""; + GetInfo(this->conn, SQL_MAX_ROW_SIZE_INCLUDES_LONG, value); + + EXPECT_STREQ(static_cast(L"N"), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoMaxSchemaNameLen) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_SCHEMA_NAME_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxStatementLen) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_MAX_STATEMENT_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoMaxTableNameLen) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_TABLE_NAME_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoMaxTablesInSelect) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_TABLES_IN_SELECT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoMaxUserNameLen) { + SQLUSMALLINT value; + GetInfo(this->conn, SQL_MAX_USER_NAME_LEN, &value); + + EXPECT_EQ(static_cast(0), value); +} + +// Scalar Function Information + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertFunctions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_FUNCTIONS, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoNumericFunctions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_NUMERIC_FUNCTIONS, &value); + + EXPECT_EQ(static_cast(4058942), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoStringFunctions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_STRING_FUNCTIONS, &value); + + EXPECT_EQ(value, static_cast(SQL_FN_STR_LTRIM | SQL_FN_STR_LENGTH | + SQL_FN_STR_REPLACE | SQL_FN_STR_RTRIM)); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoSystemFunctions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_SYSTEM_FUNCTIONS, &value); + + EXPECT_EQ(static_cast(SQL_FN_SYS_IFNULL | SQL_FN_SYS_USERNAME), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoTimedateAddIntervals) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_TIMEDATE_ADD_INTERVALS, &value); + + EXPECT_EQ(value, static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | SQL_FN_TSI_MINUTE | + SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | SQL_FN_TSI_WEEK | + SQL_FN_TSI_MONTH | SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoTimedateDiffIntervals) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_TIMEDATE_DIFF_INTERVALS, &value); + + EXPECT_EQ(value, static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | SQL_FN_TSI_MINUTE | + SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | SQL_FN_TSI_WEEK | + SQL_FN_TSI_MONTH | SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoTimedateFunctions) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_TIMEDATE_FUNCTIONS, &value); + + EXPECT_EQ(value, + static_cast( + SQL_FN_TD_CURRENT_DATE | SQL_FN_TD_CURRENT_TIME | + SQL_FN_TD_CURRENT_TIMESTAMP | SQL_FN_TD_CURDATE | SQL_FN_TD_CURTIME | + SQL_FN_TD_DAYNAME | SQL_FN_TD_DAYOFMONTH | SQL_FN_TD_DAYOFWEEK | + SQL_FN_TD_DAYOFYEAR | SQL_FN_TD_EXTRACT | SQL_FN_TD_HOUR | + SQL_FN_TD_MINUTE | SQL_FN_TD_MONTH | SQL_FN_TD_MONTHNAME | SQL_FN_TD_NOW | + SQL_FN_TD_QUARTER | SQL_FN_TD_SECOND | SQL_FN_TD_TIMESTAMPADD | + SQL_FN_TD_TIMESTAMPDIFF | SQL_FN_TD_WEEK | SQL_FN_TD_YEAR)); +} + +// Conversion Information + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertBigint) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_BIGINT, &value); + + EXPECT_EQ(static_cast(8), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertBinary) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_BINARY, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertBit) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_BIT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertChar) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_CHAR, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertDate) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_DATE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertDecimal) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_DECIMAL, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertDouble) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_DOUBLE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertFloat) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_FLOAT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertInteger) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_INTEGER, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertIntervalDayTime) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_INTERVAL_DAY_TIME, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertIntervalYearMonth) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_INTERVAL_YEAR_MONTH, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertLongvarbinary) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_LONGVARBINARY, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertLongvarchar) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_LONGVARCHAR, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TEST_F(ConnectionInfoMockTest, TestSQLGetInfoConvertNumeric) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_NUMERIC, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertReal) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_REAL, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertSmallint) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_SMALLINT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertTime) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_TIME, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertTimestamp) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_TIMESTAMP, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertTinyint) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_TINYINT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertVarbinary) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_VARBINARY, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoConvertVarchar) { + SQLUINTEGER value; + GetInfo(this->conn, SQL_CONVERT_VARCHAR, &value); + + EXPECT_EQ(static_cast(0), value); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc index 531250b69b8..91ce8f45a08 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc @@ -468,4 +468,114 @@ TYPED_TEST(ConnectionHandleTest, TestSQLDisconnectWithoutConnection) { TYPED_TEST(ConnectionTest, TestConnect) { // Verifies connect and disconnect works on its own } + +TYPED_TEST(ConnectionTest, TestSQLAllocFreeStmt) { + SQLHSTMT statement; + + // Allocate a statement using alloc statement + ASSERT_EQ(SQL_SUCCESS, SQLAllocStmt(this->conn, &statement)); + + SQLWCHAR sql_buffer[kOdbcBufferSize] = L"SELECT 1"; + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(statement, sql_buffer, SQL_NTS)); + + // Close statement handle + ASSERT_EQ(SQL_SUCCESS, SQLFreeStmt(statement, SQL_CLOSE)); + + // Free statement handle + ASSERT_EQ(SQL_SUCCESS, SQLFreeStmt(statement, SQL_DROP)); +} + +TYPED_TEST(ConnectionHandleTest, TestCloseConnectionWithOpenStatement) { + SQLHSTMT statement; + + // Connect string + std::string connect_str = this->GetConnectionString(); + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize] = L""; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_SUCCESS, + SQLDriverConnect(this->conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); + + // Allocate a statement using alloc statement + ASSERT_EQ(SQL_SUCCESS, SQLAllocStmt(this->conn, &statement)); + + // Disconnect from ODBC without closing the statement first + ASSERT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)); +} + +TYPED_TEST(ConnectionTest, TestSQLAllocFreeDesc) { + SQLHDESC descriptor; + + // Allocate a descriptor using alloc handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &descriptor)); + + // Free descriptor handle + ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, descriptor)); +} + +TYPED_TEST(ConnectionTest, TestSQLSetStmtAttrDescriptor) { + SQLHDESC apd_descriptor, ard_descriptor; + + // Allocate an APD descriptor using alloc handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &apd_descriptor)); + + // Allocate an ARD descriptor using alloc handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &ard_descriptor)); + + // Save implicitly allocated internal APD and ARD descriptor pointers + SQLPOINTER internal_apd, internal_ard = nullptr; + + EXPECT_EQ(SQL_SUCCESS, SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, + &internal_apd, sizeof(internal_apd), 0)); + + EXPECT_EQ(SQL_SUCCESS, SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &internal_ard, + sizeof(internal_ard), 0)); + + // Set APD descriptor to explicitly allocated handle + EXPECT_EQ(SQL_SUCCESS, SQLSetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, + reinterpret_cast(apd_descriptor), 0)); + + // Set ARD descriptor to explicitly allocated handle + EXPECT_EQ(SQL_SUCCESS, SQLSetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, + reinterpret_cast(ard_descriptor), 0)); + + // Verify APD and ARD descriptors are set to explicitly allocated pointers + SQLPOINTER value = nullptr; + EXPECT_EQ(SQL_SUCCESS, SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, &value, + sizeof(value), 0)); + + EXPECT_EQ(apd_descriptor, value); + + EXPECT_EQ(SQL_SUCCESS, + SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &value, sizeof(value), 0)); + + EXPECT_EQ(ard_descriptor, value); + + // Free explicitly allocated APD and ARD descriptor handles + ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, apd_descriptor)); + + ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, ard_descriptor)); + + // Verify APD and ARD descriptors has been reverted to implicit descriptors + value = nullptr; + + EXPECT_EQ(SQL_SUCCESS, SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, &value, + sizeof(value), 0)); + + EXPECT_EQ(internal_apd, value); + + EXPECT_EQ(SQL_SUCCESS, + SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &value, sizeof(value), 0)); + + EXPECT_EQ(internal_ard, value); +} + } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/dremio/docker-compose.yml b/cpp/src/arrow/flight/sql/odbc/tests/dremio/docker-compose.yml new file mode 100644 index 00000000000..eaab4d02b73 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/dremio/docker-compose.yml @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# GH-48068 TODO: run remote ODBC tests on Linux + +services: + dremio: + platform: linux/x86_64 + image: dremio/dremio-oss:latest + ports: + - 9047:9047 # REST API + - 31010:31010 # JDBC/ODBC + - 32010:32010 + container_name: dremio_container + environment: + - DREMIO_JAVA_SERVER_EXTRA_OPTS=-Dsaffron.default.charset=UTF-8 -Dsaffron.default.nationalcharset=UTF-8 -Dsaffron.default.collation.name=UTF-8$$en_US + healthcheck: + test: curl --fail http://localhost:9047 || exit 1 + interval: 10s + timeout: 5s + retries: 30 diff --git a/cpp/src/arrow/flight/sql/odbc/tests/dremio/set_up_dremio_instance.sh b/cpp/src/arrow/flight/sql/odbc/tests/dremio/set_up_dremio_instance.sh new file mode 100644 index 00000000000..8d632bb2c3e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/dremio/set_up_dremio_instance.sh @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# GH-48068 TODO: run remote ODBC tests on Linux + +#!/bin/bash +set -e + +HOST_URL="http://localhost:9047" +NEW_USER_URL="$HOST_URL/apiv2/bootstrap/firstuser" +LOGIN_URL="$HOST_URL/apiv2/login" +SQL_URL="$HOST_URL/api/v3/sql" + +ADMIN_USER="admin" +ADMIN_PASSWORD="admin2025" + +# Wait for Dremio to be available. +until curl -s "$NEW_USER_URL"; do + echo 'Waiting for Dremio to start...' + sleep 5 +done + +echo "" +echo 'Creating admin user...' + +# Create new admin account. +curl -X PUT "$NEW_USER_URL" \ + -H "Content-Type: application/json" \ + -d "{ \"userName\": \"$ADMIN_USER\", \"password\": \"$ADMIN_PASSWORD\" }" + +echo "" +echo "Created admin user." + +# Use admin account to login and acquire a token. +TOKEN=$(curl -s -X POST "$LOGIN_URL" \ + -H "Content-Type: application/json" \ + -d "{ \"userName\": \"$ADMIN_USER\", \"password\": \"$ADMIN_PASSWORD\" }" \ + | grep -oP '(?<="token":")[^"]+') + +SQL_QUERY="Create Table \$scratch.ODBCTest As SELECT CAST(2147483647 AS INTEGER) AS sinteger_max, CAST(9223372036854775807 AS BIGINT) AS sbigint_max, CAST(999999999 AS DECIMAL(38,0)) AS decimal_positive, CAST(3.40282347E38 AS FLOAT) AS float_max, CAST(1.7976931348623157E308 AS DOUBLE) AS double_max, CAST(true AS BOOLEAN) AS bit_true, CAST(DATE '9999-12-31' AS DATE) AS date_max, CAST(TIME '23:59:59' AS TIME) AS time_max, CAST(TIMESTAMP '9999-12-31 23:59:59' AS TIMESTAMP) AS timestamp_max;" +ESCAPED_QUERY=$(printf '%s' "$SQL_QUERY" | sed 's/"/\\"/g') + +echo "Creating \$scratch.ODBCTest table." + +# Create a new table by sending a SQL query. +curl -i -X POST "$SQL_URL" \ + -H "Authorization: _dremio$TOKEN" \ + -H "Content-Type: application/json" \ + -d "{\"sql\": \"$ESCAPED_QUERY\"}" + +echo "" +echo "Finished setting up dremio docker instance." diff --git a/cpp/src/arrow/flight/sql/odbc/tests/errors_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/errors_test.cc new file mode 100644 index 00000000000..01c861913d5 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/errors_test.cc @@ -0,0 +1,560 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +namespace arrow::flight::sql::odbc { + +template +class ErrorsTest : public T {}; + +using TestTypes = + ::testing::Types; +TYPED_TEST_SUITE(ErrorsTest, TestTypes); + +template +class ErrorsOdbcV2Test : public T {}; + +using TestTypesOdbcV2 = + ::testing::Types; +TYPED_TEST_SUITE(ErrorsOdbcV2Test, TestTypesOdbcV2); + +template +class ErrorsHandleTest : public T {}; + +using TestTypesHandle = + ::testing::Types; +TYPED_TEST_SUITE(ErrorsHandleTest, TestTypesHandle); + +TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagFieldWForConnectFailure) { + // Invalid connect string + std::string connect_str = this->GetInvalidConnectionString(); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize]; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_ERROR, + SQLDriverConnect(this->conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)); + + // Retrieve all supported header level and record level data + SQLSMALLINT HEADER_LEVEL = 0; + SQLSMALLINT RECORD_1 = 1; + + // SQL_DIAG_NUMBER + SQLINTEGER diag_number; + SQLSMALLINT diag_number_length; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DBC, this->conn, HEADER_LEVEL, SQL_DIAG_NUMBER, + &diag_number, sizeof(SQLINTEGER), &diag_number_length)); + + EXPECT_EQ(1, diag_number); + + // SQL_DIAG_SERVER_NAME + SQLWCHAR server_name[kOdbcBufferSize]; + SQLSMALLINT server_name_length; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_SERVER_NAME, + server_name, kOdbcBufferSize, &server_name_length)); + + // SQL_DIAG_MESSAGE_TEXT + SQLWCHAR message_text[kOdbcBufferSize]; + SQLSMALLINT message_text_length; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT, + message_text, kOdbcBufferSize, &message_text_length)); + + EXPECT_GT(message_text_length, 100); + + // SQL_DIAG_NATIVE + SQLINTEGER diag_native; + SQLSMALLINT diag_native_length; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_NATIVE, + &diag_native, sizeof(diag_native), &diag_native_length)); + + EXPECT_EQ(200, diag_native); + + // SQL_DIAG_SQLSTATE + const SQLSMALLINT sql_state_size = 6; + SQLWCHAR sql_state[sql_state_size]; + SQLSMALLINT sql_state_length; + + EXPECT_EQ( + SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_SQLSTATE, sql_state, + sql_state_size * arrow::flight::sql::odbc::GetSqlWCharSize(), + &sql_state_length)); + + EXPECT_EQ(std::wstring(L"28000"), std::wstring(sql_state)); +} + +TYPED_TEST(ErrorsHandleTest, DISABLED_TestSQLGetDiagFieldWForConnectFailureNTS) { + // Test is disabled because driver manager on Windows does not pass through SQL_NTS + // This test case can be potentially used on macOS/Linux + + // Invalid connect string + std::string connect_str = this->GetInvalidConnectionString(); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize]; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_ERROR, + SQLDriverConnect(this->conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)); + + // Retrieve all supported header level and record level data + SQLSMALLINT RECORD_1 = 1; + + // SQL_DIAG_MESSAGE_TEXT SQL_NTS + SQLWCHAR message_text[kOdbcBufferSize]; + SQLSMALLINT message_text_length; + + message_text[kOdbcBufferSize - 1] = '\0'; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT, + message_text, SQL_NTS, &message_text_length)); + + EXPECT_GT(message_text_length, 100); +} + +TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForDescriptorFailureFromDriverManager) { + SQLHDESC descriptor; + + // Allocate a descriptor using alloc handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &descriptor)); + + EXPECT_EQ(SQL_ERROR, + SQLGetDescField(descriptor, 1, SQL_DESC_DATETIME_INTERVAL_CODE, 0, 0, 0)); + + // Retrieve all supported header level and record level data + SQLSMALLINT HEADER_LEVEL = 0; + SQLSMALLINT RECORD_1 = 1; + + // SQL_DIAG_NUMBER + SQLINTEGER diag_number; + SQLSMALLINT diag_number_length; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DESC, descriptor, HEADER_LEVEL, SQL_DIAG_NUMBER, + &diag_number, sizeof(SQLINTEGER), &diag_number_length)); + + EXPECT_EQ(1, diag_number); + + // SQL_DIAG_SERVER_NAME + SQLWCHAR server_name[kOdbcBufferSize]; + SQLSMALLINT server_name_length; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_SERVER_NAME, + server_name, kOdbcBufferSize, &server_name_length)); + + // SQL_DIAG_MESSAGE_TEXT + SQLWCHAR message_text[kOdbcBufferSize]; + SQLSMALLINT message_text_length; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_MESSAGE_TEXT, + message_text, kOdbcBufferSize, &message_text_length)); + + EXPECT_GT(message_text_length, 100); + + // SQL_DIAG_NATIVE + SQLINTEGER diag_native; + SQLSMALLINT diag_native_length; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_NATIVE, + &diag_native, sizeof(diag_native), &diag_native_length)); + + EXPECT_EQ(0, diag_native); + + // SQL_DIAG_SQLSTATE + const SQLSMALLINT sql_state_size = 6; + SQLWCHAR sql_state[sql_state_size]; + SQLSMALLINT sql_state_length; + EXPECT_EQ( + SQL_SUCCESS, + SQLGetDiagField(SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_SQLSTATE, sql_state, + sql_state_size * GetSqlWCharSize(), &sql_state_length)); + + EXPECT_EQ(std::wstring(L"IM001"), std::wstring(sql_state)); + + // Free descriptor handle + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, descriptor)); +} + +TYPED_TEST(ErrorsTest, TestSQLGetDiagRecForDescriptorFailureFromDriverManager) { + SQLHDESC descriptor; + + // Allocate a descriptor using alloc handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &descriptor)); + + EXPECT_EQ(SQL_ERROR, + SQLGetDescField(descriptor, 1, SQL_DESC_DATETIME_INTERVAL_CODE, 0, 0, 0)); + + SQLWCHAR sql_state[6]; + SQLINTEGER native_error; + SQLWCHAR message[kOdbcBufferSize]; + SQLSMALLINT message_length; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetDiagRec(SQL_HANDLE_DESC, descriptor, 1, sql_state, &native_error, + message, kOdbcBufferSize, &message_length)); + + EXPECT_GT(message_length, 60); + + EXPECT_EQ(0, native_error); + + // API not implemented error from driver manager + EXPECT_EQ(std::wstring(L"IM001"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); + + // Free descriptor handle + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, descriptor)); +} + +TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagRecForConnectFailure) { + // Invalid connect string + std::string connect_str = this->GetInvalidConnectionString(); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize]; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_ERROR, + SQLDriverConnect(this->conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)); + + SQLWCHAR sql_state[6]; + SQLINTEGER native_error; + SQLWCHAR message[kOdbcBufferSize]; + SQLSMALLINT message_length; + ASSERT_EQ(SQL_SUCCESS, + SQLGetDiagRec(SQL_HANDLE_DBC, this->conn, 1, sql_state, &native_error, + message, kOdbcBufferSize, &message_length)); + + EXPECT_GT(message_length, 120); + + EXPECT_EQ(200, native_error); + + EXPECT_EQ(std::wstring(L"28000"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(ErrorsTest, TestSQLGetDiagRecInputData) { + // SQLGetDiagRec does not post diagnostic records for itself. + + SQLWCHAR sql_state[6]; + SQLINTEGER native_error; + SQLWCHAR message[kOdbcBufferSize]; + SQLSMALLINT message_length; + + // Pass invalid record number + EXPECT_EQ(SQL_ERROR, + SQLGetDiagRec(SQL_HANDLE_DBC, this->conn, 0, sql_state, &native_error, + message, kOdbcBufferSize, &message_length)); + + // Pass valid record number with null inputs + EXPECT_EQ(SQL_NO_DATA, SQLGetDiagRec(SQL_HANDLE_DBC, this->conn, 1, 0, 0, 0, 0, 0)); + + // Invalid handle + EXPECT_EQ(SQL_INVALID_HANDLE, SQLGetDiagRec(0, 0, 0, 0, 0, 0, 0, 0)); +} + +TYPED_TEST(ErrorsTest, TestSQLErrorInputData) { + // Test ODBC 2.0 API SQLError. Driver manager maps SQLError to SQLGetDiagRec. + // SQLError does not post diagnostic records for itself. + + // Pass valid handles with null inputs + EXPECT_EQ(SQL_NO_DATA, SQLError(this->env, 0, 0, 0, 0, 0, 0, 0)); + + EXPECT_EQ(SQL_NO_DATA, SQLError(0, this->conn, 0, 0, 0, 0, 0, 0)); + + EXPECT_EQ(SQL_NO_DATA, SQLError(0, 0, this->stmt, 0, 0, 0, 0, 0)); + + // Invalid handle + EXPECT_EQ(SQL_INVALID_HANDLE, SQLError(0, 0, 0, 0, 0, 0, 0, 0)); +} + +TYPED_TEST(ErrorsTest, TestSQLErrorEnvErrorFromDriverManager) { + // Test ODBC 2.0 API SQLError. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + + // Attempt to set environment attribute after connection handle allocation + ASSERT_EQ(SQL_ERROR, SQLSetEnvAttr(this->env, SQL_ATTR_ODBC_VERSION, + reinterpret_cast(SQL_OV_ODBC2), 0)); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ASSERT_EQ(SQL_SUCCESS, SQLError(this->env, 0, 0, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length)); + + EXPECT_GT(message_length, 50); + + EXPECT_EQ(0, native_error); + + // Function sequence error state from driver manager + EXPECT_EQ(std::wstring(L"HY010"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(ErrorsTest, TestSQLErrorConnError) { + // Test ODBC 2.0 API SQLError. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + + // Attempt to set unsupported attribute + ASSERT_EQ(SQL_ERROR, SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, 0, 0, 0)); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ASSERT_EQ(SQL_SUCCESS, SQLError(0, this->conn, 0, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length)); + + EXPECT_GT(message_length, 60); + + EXPECT_EQ(100, native_error); + + // optional feature not supported error state + EXPECT_EQ(std::wstring(L"HYC00"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(ErrorsTest, TestSQLErrorStmtError) { + // Test ODBC 2.0 API SQLError. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + + std::wstring wsql = L"1"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_ERROR, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ASSERT_EQ(SQL_SUCCESS, SQLError(0, 0, this->stmt, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length)); + + EXPECT_GT(message_length, 70); + + EXPECT_EQ(100, native_error); + + EXPECT_EQ(std::wstring(L"HY000"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(ErrorsTest, TestSQLErrorStmtWarning) { + // Test ODBC 2.0 API SQLError. + + std::wstring wsql = L"SELECT 'VERY LONG STRING here' AS string_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + const int len = 17; + SQLCHAR char_val[len]; + SQLLEN buf_len = sizeof(SQLCHAR) * len; + SQLLEN ind; + + EXPECT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val, buf_len, &ind)); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ASSERT_EQ(SQL_SUCCESS, SQLError(0, 0, this->stmt, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length)); + + EXPECT_GT(message_length, 50); + + EXPECT_EQ(1000100, native_error); + + // Verify string truncation warning is reported + EXPECT_EQ(std::wstring(L"01004"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorEnvErrorFromDriverManager) { + // Test ODBC 2.0 API SQLError with ODBC ver 2. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + + // Attempt to set environment attribute after connection handle allocation + ASSERT_EQ(SQL_ERROR, SQLSetEnvAttr(this->env, SQL_ATTR_ODBC_VERSION, + reinterpret_cast(SQL_OV_ODBC2), 0)); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ASSERT_EQ(SQL_SUCCESS, SQLError(this->env, 0, 0, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length)); + + EXPECT_GT(message_length, 50); + + EXPECT_EQ(0, native_error); + + // Function sequence error state from driver manager + EXPECT_EQ(std::wstring(L"S1010"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorConnError) { + // Test ODBC 2.0 API SQLError with ODBC ver 2. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + + // Attempt to set unsupported attribute + ASSERT_EQ(SQL_ERROR, SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, 0, 0, 0)); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ASSERT_EQ(SQL_SUCCESS, SQLError(0, this->conn, 0, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length)); + + EXPECT_GT(message_length, 60); + + EXPECT_EQ(100, native_error); + + // optional feature not supported error state. Driver Manager maps state to S1C00 + EXPECT_EQ(std::wstring(L"S1C00"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorStmtError) { + // Test ODBC 2.0 API SQLError with ODBC ver 2. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + + std::wstring wsql = L"1"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_ERROR, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ASSERT_EQ(SQL_SUCCESS, SQLError(0, 0, this->stmt, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length)); + + EXPECT_GT(message_length, 70); + + EXPECT_EQ(100, native_error); + + // Driver Manager maps error state to S1000 + EXPECT_EQ(std::wstring(L"S1000"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(ErrorsOdbcV2Test, TestSQLErrorStmtWarning) { + // Test ODBC 2.0 API SQLError. + + std::wstring wsql = L"SELECT 'VERY LONG STRING here' AS string_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + const int len = 17; + SQLCHAR char_val[len]; + SQLLEN buf_len = sizeof(SQLCHAR) * len; + SQLLEN ind; + + EXPECT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val, buf_len, &ind)); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ASSERT_EQ(SQL_SUCCESS, SQLError(0, 0, this->stmt, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length)); + + EXPECT_GT(message_length, 50); + + EXPECT_EQ(1000100, native_error); + + // Verify string truncation warning is reported + EXPECT_EQ(std::wstring(L"01004"), std::wstring(sql_state)); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/get_functions_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/get_functions_test.cc new file mode 100644 index 00000000000..3b47b80cf05 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/get_functions_test.cc @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +namespace arrow::flight::sql::odbc { + +template +class GetFunctionsTest : public T {}; + +using TestTypes = + ::testing::Types; +TYPED_TEST_SUITE(GetFunctionsTest, TestTypes); + +template +class GetFunctionsOdbcV2Test : public T {}; + +using TestTypesOdbcV2 = + ::testing::Types; +TYPED_TEST_SUITE(GetFunctionsOdbcV2Test, TestTypesOdbcV2); + +TYPED_TEST(GetFunctionsTest, TestSQLGetFunctionsAllFunctions) { + // Verify driver manager return values for SQLGetFunctions + + SQLUSMALLINT api_exists[SQL_API_ODBC3_ALL_FUNCTIONS_SIZE]; + const std::vector supported_functions = { + SQL_API_SQLALLOCHANDLE, SQL_API_SQLBINDCOL, SQL_API_SQLGETDIAGFIELD, + SQL_API_SQLCANCEL, SQL_API_SQLCLOSECURSOR, SQL_API_SQLGETDIAGREC, + SQL_API_SQLCOLATTRIBUTE, SQL_API_SQLGETENVATTR, SQL_API_SQLCONNECT, + SQL_API_SQLGETINFO, SQL_API_SQLGETSTMTATTR, SQL_API_SQLDESCRIBECOL, + SQL_API_SQLGETTYPEINFO, SQL_API_SQLDISCONNECT, SQL_API_SQLNUMRESULTCOLS, + SQL_API_SQLPREPARE, SQL_API_SQLEXECDIRECT, SQL_API_SQLEXECUTE, SQL_API_SQLROWCOUNT, + SQL_API_SQLFETCH, SQL_API_SQLSETCONNECTATTR, SQL_API_SQLFETCHSCROLL, + SQL_API_SQLFREEHANDLE, SQL_API_SQLFREESTMT, SQL_API_SQLGETCONNECTATTR, + SQL_API_SQLSETENVATTR, SQL_API_SQLSETSTMTATTR, SQL_API_SQLGETDATA, + SQL_API_SQLCOLUMNS, SQL_API_SQLTABLES, SQL_API_SQLNATIVESQL, + SQL_API_SQLDRIVERCONNECT, SQL_API_SQLMORERESULTS, SQL_API_SQLPRIMARYKEYS, + SQL_API_SQLFOREIGNKEYS, + + // ODBC 2.0 APIs + SQL_API_SQLSETSTMTOPTION, SQL_API_SQLGETSTMTOPTION, SQL_API_SQLSETCONNECTOPTION, + SQL_API_SQLGETCONNECTOPTION, SQL_API_SQLALLOCCONNECT, SQL_API_SQLALLOCENV, + SQL_API_SQLALLOCSTMT, SQL_API_SQLFREEENV, SQL_API_SQLFREECONNECT, + + // Driver Manager APIs + SQL_API_SQLGETFUNCTIONS, SQL_API_SQLDRIVERS, SQL_API_SQLDATASOURCES}; + const std::vector unsupported_functions = { + SQL_API_SQLPUTDATA, SQL_API_SQLGETDESCFIELD, SQL_API_SQLGETDESCREC, + SQL_API_SQLCOPYDESC, SQL_API_SQLPARAMDATA, SQL_API_SQLENDTRAN, + SQL_API_SQLSETCURSORNAME, SQL_API_SQLSETDESCFIELD, SQL_API_SQLSETDESCREC, + SQL_API_SQLGETCURSORNAME, SQL_API_SQLSTATISTICS, SQL_API_SQLSPECIALCOLUMNS, + SQL_API_SQLBINDPARAMETER, SQL_API_SQLBROWSECONNECT, SQL_API_SQLNUMPARAMS, + SQL_API_SQLBULKOPERATIONS, SQL_API_SQLCOLUMNPRIVILEGES, SQL_API_SQLPROCEDURECOLUMNS, + SQL_API_SQLDESCRIBEPARAM, SQL_API_SQLPROCEDURES, SQL_API_SQLSETPOS, + SQL_API_SQLTABLEPRIVILEGES}; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetFunctions(this->conn, SQL_API_ODBC3_ALL_FUNCTIONS, api_exists)); + + for (int api : supported_functions) { + EXPECT_EQ(SQL_TRUE, SQL_FUNC_EXISTS(api_exists, api)); + } + + for (int api : unsupported_functions) { + EXPECT_EQ(SQL_FALSE, SQL_FUNC_EXISTS(api_exists, api)); + } +} + +TYPED_TEST(GetFunctionsOdbcV2Test, TestSQLGetFunctionsAllFunctions) { + // Verify driver manager return values for SQLGetFunctions + + // ODBC 2.0 SQLGetFunctions returns 100 elements according to spec + SQLUSMALLINT api_exists[100]; + const std::vector supported_functions = { + SQL_API_SQLCONNECT, SQL_API_SQLGETINFO, SQL_API_SQLDESCRIBECOL, + SQL_API_SQLGETTYPEINFO, SQL_API_SQLDISCONNECT, SQL_API_SQLNUMRESULTCOLS, + SQL_API_SQLPREPARE, SQL_API_SQLEXECDIRECT, SQL_API_SQLEXECUTE, SQL_API_SQLROWCOUNT, + SQL_API_SQLFETCH, SQL_API_SQLFREESTMT, SQL_API_SQLGETDATA, SQL_API_SQLCOLUMNS, + SQL_API_SQLTABLES, SQL_API_SQLNATIVESQL, SQL_API_SQLDRIVERCONNECT, + SQL_API_SQLMORERESULTS, SQL_API_SQLSETSTMTOPTION, SQL_API_SQLGETSTMTOPTION, + SQL_API_SQLSETCONNECTOPTION, SQL_API_SQLGETCONNECTOPTION, SQL_API_SQLALLOCCONNECT, + SQL_API_SQLALLOCENV, SQL_API_SQLALLOCSTMT, SQL_API_SQLFREEENV, + SQL_API_SQLFREECONNECT, SQL_API_SQLPRIMARYKEYS, SQL_API_SQLFOREIGNKEYS, + + // Driver Manager APIs + SQL_API_SQLGETFUNCTIONS, SQL_API_SQLDRIVERS, SQL_API_SQLDATASOURCES}; + const std::vector unsupported_functions = { + SQL_API_SQLPUTDATA, SQL_API_SQLPARAMDATA, SQL_API_SQLSETCURSORNAME, + SQL_API_SQLGETCURSORNAME, SQL_API_SQLSTATISTICS, SQL_API_SQLSPECIALCOLUMNS, + SQL_API_SQLBINDPARAMETER, SQL_API_SQLBROWSECONNECT, SQL_API_SQLNUMPARAMS, + SQL_API_SQLBULKOPERATIONS, SQL_API_SQLCOLUMNPRIVILEGES, SQL_API_SQLPROCEDURECOLUMNS, + SQL_API_SQLDESCRIBEPARAM, SQL_API_SQLPROCEDURES, SQL_API_SQLSETPOS, + SQL_API_SQLTABLEPRIVILEGES}; + ASSERT_EQ(SQL_SUCCESS, SQLGetFunctions(this->conn, SQL_API_ALL_FUNCTIONS, api_exists)); + + for (int api : supported_functions) { + EXPECT_EQ(SQL_TRUE, api_exists[api]); + } + + for (int api : unsupported_functions) { + EXPECT_EQ(SQL_FALSE, api_exists[api]); + } +} + +TYPED_TEST(GetFunctionsTest, TestSQLGetFunctionsSupportedSingleAPI) { + const std::vector supported_functions = { + SQL_API_SQLALLOCHANDLE, SQL_API_SQLBINDCOL, SQL_API_SQLGETDIAGFIELD, + SQL_API_SQLCANCEL, SQL_API_SQLCLOSECURSOR, SQL_API_SQLGETDIAGREC, + SQL_API_SQLCOLATTRIBUTE, SQL_API_SQLGETENVATTR, SQL_API_SQLCONNECT, + SQL_API_SQLGETINFO, SQL_API_SQLGETSTMTATTR, SQL_API_SQLDESCRIBECOL, + SQL_API_SQLGETTYPEINFO, SQL_API_SQLDISCONNECT, SQL_API_SQLNUMRESULTCOLS, + SQL_API_SQLPREPARE, SQL_API_SQLEXECDIRECT, SQL_API_SQLEXECUTE, SQL_API_SQLROWCOUNT, + SQL_API_SQLFETCH, SQL_API_SQLSETCONNECTATTR, SQL_API_SQLFETCHSCROLL, + SQL_API_SQLFREEHANDLE, SQL_API_SQLFREESTMT, SQL_API_SQLGETCONNECTATTR, + SQL_API_SQLSETENVATTR, SQL_API_SQLSETSTMTATTR, SQL_API_SQLGETDATA, + SQL_API_SQLCOLUMNS, SQL_API_SQLTABLES, SQL_API_SQLNATIVESQL, + SQL_API_SQLDRIVERCONNECT, SQL_API_SQLMORERESULTS, SQL_API_SQLPRIMARYKEYS, + SQL_API_SQLFOREIGNKEYS, + + // ODBC 2.0 APIs + SQL_API_SQLSETSTMTOPTION, SQL_API_SQLGETSTMTOPTION, SQL_API_SQLSETCONNECTOPTION, + SQL_API_SQLGETCONNECTOPTION, SQL_API_SQLALLOCCONNECT, SQL_API_SQLALLOCENV, + SQL_API_SQLALLOCSTMT, SQL_API_SQLFREEENV, SQL_API_SQLFREECONNECT, + + // Driver Manager APIs + SQL_API_SQLGETFUNCTIONS, SQL_API_SQLDRIVERS, SQL_API_SQLDATASOURCES}; + SQLUSMALLINT api_exists; + for (SQLUSMALLINT api : supported_functions) { + ASSERT_EQ(SQL_SUCCESS, SQLGetFunctions(this->conn, api, &api_exists)); + + EXPECT_EQ(SQL_TRUE, api_exists); + + api_exists = -1; + } +} + +TYPED_TEST(GetFunctionsTest, TestSQLGetFunctionsUnsupportedSingleAPI) { + const std::vector unsupported_functions = { + SQL_API_SQLPUTDATA, SQL_API_SQLGETDESCFIELD, SQL_API_SQLGETDESCREC, + SQL_API_SQLCOPYDESC, SQL_API_SQLPARAMDATA, SQL_API_SQLENDTRAN, + SQL_API_SQLSETCURSORNAME, SQL_API_SQLSETDESCFIELD, SQL_API_SQLSETDESCREC, + SQL_API_SQLGETCURSORNAME, SQL_API_SQLSTATISTICS, SQL_API_SQLSPECIALCOLUMNS, + SQL_API_SQLBINDPARAMETER, SQL_API_SQLBROWSECONNECT, SQL_API_SQLNUMPARAMS, + SQL_API_SQLBULKOPERATIONS, SQL_API_SQLCOLUMNPRIVILEGES, SQL_API_SQLPROCEDURECOLUMNS, + SQL_API_SQLDESCRIBEPARAM, SQL_API_SQLPROCEDURES, SQL_API_SQLSETPOS, + SQL_API_SQLTABLEPRIVILEGES}; + SQLUSMALLINT api_exists; + for (SQLUSMALLINT api : unsupported_functions) { + ASSERT_EQ(SQL_SUCCESS, SQLGetFunctions(this->conn, api, &api_exists)); + + EXPECT_EQ(SQL_FALSE, api_exists); + + api_exists = -1; + } +} + +TYPED_TEST(GetFunctionsOdbcV2Test, TestSQLGetFunctionsSupportedSingleAPI) { + const std::vector supported_functions = { + SQL_API_SQLCONNECT, SQL_API_SQLGETINFO, SQL_API_SQLDESCRIBECOL, + SQL_API_SQLGETTYPEINFO, SQL_API_SQLDISCONNECT, SQL_API_SQLNUMRESULTCOLS, + SQL_API_SQLPREPARE, SQL_API_SQLEXECDIRECT, SQL_API_SQLEXECUTE, SQL_API_SQLROWCOUNT, + SQL_API_SQLFETCH, SQL_API_SQLFREESTMT, SQL_API_SQLGETDATA, SQL_API_SQLCOLUMNS, + SQL_API_SQLTABLES, SQL_API_SQLNATIVESQL, SQL_API_SQLDRIVERCONNECT, + SQL_API_SQLMORERESULTS, SQL_API_SQLSETSTMTOPTION, SQL_API_SQLGETSTMTOPTION, + SQL_API_SQLSETCONNECTOPTION, SQL_API_SQLGETCONNECTOPTION, SQL_API_SQLALLOCCONNECT, + SQL_API_SQLALLOCENV, SQL_API_SQLALLOCSTMT, SQL_API_SQLFREEENV, + SQL_API_SQLFREECONNECT, SQL_API_SQLPRIMARYKEYS, SQL_API_SQLFOREIGNKEYS, + + // Driver Manager APIs + SQL_API_SQLGETFUNCTIONS, SQL_API_SQLDRIVERS, SQL_API_SQLDATASOURCES}; + SQLUSMALLINT api_exists; + for (SQLUSMALLINT api : supported_functions) { + ASSERT_EQ(SQL_SUCCESS, SQLGetFunctions(this->conn, api, &api_exists)); + + EXPECT_EQ(SQL_TRUE, api_exists); + + api_exists = -1; + } +} + +TYPED_TEST(GetFunctionsOdbcV2Test, TestSQLGetFunctionsUnsupportedSingleAPI) { + const std::vector unsupported_functions = { + SQL_API_SQLPUTDATA, SQL_API_SQLPARAMDATA, SQL_API_SQLSETCURSORNAME, + SQL_API_SQLGETCURSORNAME, SQL_API_SQLSTATISTICS, SQL_API_SQLSPECIALCOLUMNS, + SQL_API_SQLBINDPARAMETER, SQL_API_SQLBROWSECONNECT, SQL_API_SQLNUMPARAMS, + SQL_API_SQLBULKOPERATIONS, SQL_API_SQLCOLUMNPRIVILEGES, SQL_API_SQLPROCEDURECOLUMNS, + SQL_API_SQLDESCRIBEPARAM, SQL_API_SQLPROCEDURES, SQL_API_SQLSETPOS, + SQL_API_SQLTABLEPRIVILEGES}; + SQLUSMALLINT api_exists; + for (SQLUSMALLINT api : unsupported_functions) { + ASSERT_EQ(SQL_SUCCESS, SQLGetFunctions(this->conn, api, &api_exists)); + + EXPECT_EQ(SQL_FALSE, api_exists); + + api_exists = -1; + } +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc index eb6c60b9762..7508632119e 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc @@ -61,13 +61,12 @@ void ODBCRemoteTestBase::ConnectWithString(std::string connect_str) { kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)) << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); - // GH-47710: TODO Allocate a statement using alloc handle - // ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt)); + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt)); } void ODBCRemoteTestBase::Disconnect() { - // GH-47710: TODO Close statement - // EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt)); + // Close statement + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt)); // Disconnect from ODBC EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(conn)) diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_attr_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/statement_attr_test.cc new file mode 100644 index 00000000000..680349ac140 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_attr_test.cc @@ -0,0 +1,657 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/odbc_statement.h" +#include "arrow/flight/sql/odbc/odbc_impl/spi/statement.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +namespace arrow::flight::sql::odbc { + +template +class StatementAttributeTest : public T {}; + +using TestTypes = + ::testing::Types; +TYPED_TEST_SUITE(StatementAttributeTest, TestTypes); + +namespace { +// Helper Functions + +// Get SQLULEN return value +void GetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLULEN* value) { + SQLINTEGER string_length = 0; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetStmtAttr(statement, attribute, value, sizeof(*value), &string_length)); +} + +// Get SQLLEN return value +void GetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLLEN* value) { + SQLINTEGER string_length = 0; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetStmtAttr(statement, attribute, value, sizeof(*value), &string_length)); +} + +// Get SQLPOINTER return value +void GetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLPOINTER* value) { + SQLINTEGER string_length = 0; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetStmtAttr(statement, attribute, value, SQL_IS_POINTER, &string_length)); +} + +// Validate error return value and code +void ValidateGetStmtAttrErrorCode(SQLHSTMT statement, SQLINTEGER attribute, + std::string_view error_code) { + SQLULEN value = 0; + SQLINTEGER string_length_ptr; + + ASSERT_EQ(SQL_ERROR, + SQLGetStmtAttr(statement, attribute, &value, 0, &string_length_ptr)); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, statement, error_code); +} + +// Validate return value for call to SQLSetStmtAttr with SQLULEN +void ValidateSetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLULEN new_value) { + SQLINTEGER string_length_ptr = sizeof(SQLULEN); + + EXPECT_EQ(SQL_SUCCESS, + SQLSetStmtAttr(statement, attribute, reinterpret_cast(new_value), + string_length_ptr)); +} + +// Validate return value for call to SQLSetStmtAttr with SQLLEN +void ValidateSetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLLEN new_value) { + SQLINTEGER string_length_ptr = sizeof(SQLLEN); + + EXPECT_EQ(SQL_SUCCESS, + SQLSetStmtAttr(statement, attribute, reinterpret_cast(new_value), + string_length_ptr)); +} + +// Validate return value for call to SQLSetStmtAttr with SQLPOINTER +void ValidateSetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLPOINTER value) { + EXPECT_EQ(SQL_SUCCESS, SQLSetStmtAttr(statement, attribute, value, 0)); +} + +// Validate error return value and code +void ValidateSetStmtAttrErrorCode(SQLHSTMT statement, SQLINTEGER attribute, + SQLULEN new_value, std::string_view error_code) { + SQLINTEGER string_length_ptr = sizeof(SQLULEN); + + ASSERT_EQ(SQL_ERROR, + SQLSetStmtAttr(statement, attribute, reinterpret_cast(new_value), + string_length_ptr)); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, statement, error_code); +} +} // namespace + +// Test Cases + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrAppParamDesc) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, &value); + + EXPECT_GT(value, static_cast(0)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrAppRowDesc) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &value); + + EXPECT_GT(value, static_cast(0)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrAsyncEnable) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_ASYNC_ENABLE, &value); + + EXPECT_EQ(static_cast(SQL_ASYNC_ENABLE_OFF), value); +} + +#ifdef SQL_ATTR_ASYNC_STMT_EVENT +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrAsyncStmtEventUnsupported) { + // Optional feature not implemented + ValidateGetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_EVENT, kErrorStateHYC00); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_PCALLBACK +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrAsyncStmtPCCallbackUnsupported) { + // Optional feature not implemented + ValidateGetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_PCALLBACK, + kErrorStateHYC00); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_PCONTEXT +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrAsyncStmtPCContextUnsupported) { + // Optional feature not implemented + ValidateGetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_PCONTEXT, + kErrorStateHYC00); +} +#endif + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrConcurrency) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_CONCURRENCY, &value); + + EXPECT_EQ(static_cast(SQL_CONCUR_READ_ONLY), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrCursorScrollable) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_CURSOR_SCROLLABLE, &value); + + EXPECT_EQ(static_cast(SQL_NONSCROLLABLE), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrCursorSensitivity) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_CURSOR_SENSITIVITY, &value); + + EXPECT_EQ(static_cast(SQL_UNSPECIFIED), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrCursorType) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_CURSOR_TYPE, &value); + + EXPECT_EQ(static_cast(SQL_CURSOR_FORWARD_ONLY), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrEnableAutoIPD) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_ENABLE_AUTO_IPD, &value); + + EXPECT_EQ(static_cast(SQL_FALSE), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrFetchBookmarkPointer) { + SQLLEN value; + GetStmtAttr(this->stmt, SQL_ATTR_FETCH_BOOKMARK_PTR, &value); + + EXPECT_EQ(static_cast(NULL), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrIMPParamDesc) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_IMP_PARAM_DESC, &value); + + EXPECT_GT(value, static_cast(0)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrIMPRowDesc) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_IMP_ROW_DESC, &value); + + EXPECT_GT(value, static_cast(0)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrKeysetSize) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_KEYSET_SIZE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrMaxLength) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_MAX_LENGTH, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrMaxRows) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_MAX_ROWS, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrMetadataID) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_METADATA_ID, &value); + + EXPECT_EQ(static_cast(SQL_FALSE), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrNoscan) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_NOSCAN, &value); + + EXPECT_EQ(static_cast(SQL_NOSCAN_OFF), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrParamBindOffsetPtr) { + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_OFFSET_PTR, &value); + + EXPECT_EQ(static_cast(nullptr), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrParamBindType) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_TYPE, &value); + + EXPECT_EQ(static_cast(SQL_PARAM_BIND_BY_COLUMN), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrParamOperationPtr) { + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_PARAM_OPERATION_PTR, &value); + + EXPECT_EQ(static_cast(nullptr), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrParamStatusPtr) { + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_PARAM_STATUS_PTR, &value); + + EXPECT_EQ(static_cast(nullptr), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrParamsProcessedPtr) { + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_PARAMS_PROCESSED_PTR, &value); + + EXPECT_EQ(static_cast(nullptr), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrParamsetSize) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_PARAMSET_SIZE, &value); + + EXPECT_EQ(static_cast(1), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrQueryTimeout) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_QUERY_TIMEOUT, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRetrieveData) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_RETRIEVE_DATA, &value); + + EXPECT_EQ(static_cast(SQL_RD_ON), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRowArraySize) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_ARRAY_SIZE, &value); + + EXPECT_EQ(static_cast(1), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRowBindOffsetPtr) { + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_OFFSET_PTR, &value); + + EXPECT_EQ(static_cast(nullptr), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRowBindType) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_TYPE, &value); + + EXPECT_EQ(static_cast(0), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRowNumber) { + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_NUMBER, &value); + + EXPECT_EQ(static_cast(1), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRowOperationPtr) { + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_OPERATION_PTR, &value); + + EXPECT_EQ(static_cast(nullptr), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRowStatusPtr) { + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_STATUS_PTR, &value); + + EXPECT_EQ(static_cast(nullptr), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRowsFetchedPtr) { + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, &value); + + EXPECT_EQ(static_cast(nullptr), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrSimulateCursor) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_SIMULATE_CURSOR, &value); + + EXPECT_EQ(static_cast(SQL_SC_UNIQUE), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrUseBookmarks) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ATTR_USE_BOOKMARKS, &value); + + EXPECT_EQ(static_cast(SQL_UB_OFF), value); +} + +// This is a pre ODBC 3 attribute +TYPED_TEST(StatementAttributeTest, TestSQLGetStmtAttrRowsetSize) { + SQLULEN value; + GetStmtAttr(this->stmt, SQL_ROWSET_SIZE, &value); + + EXPECT_EQ(static_cast(1), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrAppParamDesc) { + SQLULEN app_param_desc = 0; + SQLINTEGER string_length_ptr; + + ASSERT_EQ(SQL_SUCCESS, SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, + &app_param_desc, 0, &string_length_ptr)); + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, static_cast(0)); + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, + static_cast(app_param_desc)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrAppRowDesc) { + SQLULEN app_row_desc = 0; + SQLINTEGER string_length_ptr; + + ASSERT_EQ(SQL_SUCCESS, SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &app_row_desc, + 0, &string_length_ptr)); + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, static_cast(0)); + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, + static_cast(app_row_desc)); +} + +#ifdef SQL_ATTR_ASYNC_ENABLE +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrAsyncEnableUnsupported) { + // Optional feature not implemented + ValidateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_ENABLE, SQL_ASYNC_ENABLE_OFF, + kErrorStateHYC00); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_EVENT +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrAsyncStmtEventUnsupported) { + // Driver does not support asynchronous notification + ValidateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_EVENT, 0, + kErrorStateHY118); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_PCALLBACK +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrAsyncStmtPCCallbackUnsupported) { + ValidateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_PCALLBACK, 0, + kErrorStateHYC00); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_PCONTEXT +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrAsyncStmtPCContextUnsupported) { + // Optional feature not implemented + ValidateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_PCONTEXT, 0, + kErrorStateHYC00); +} +#endif + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrConcurrency) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_CONCURRENCY, + static_cast(SQL_CONCUR_READ_ONLY)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrCursorScrollable) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_CURSOR_SCROLLABLE, + static_cast(SQL_NONSCROLLABLE)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrCursorSensitivity) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_CURSOR_SENSITIVITY, + static_cast(SQL_UNSPECIFIED)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrCursorType) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_CURSOR_TYPE, + static_cast(SQL_CURSOR_FORWARD_ONLY)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrEnableAutoIPD) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_ENABLE_AUTO_IPD, + static_cast(SQL_FALSE)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrFetchBookmarkPointer) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_FETCH_BOOKMARK_PTR, static_cast(NULL)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrIMPParamDesc) { + // Invalid use of an automatically allocated descriptor handle + ValidateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_IMP_PARAM_DESC, + static_cast(0), kErrorStateHY017); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrIMPRowDesc) { + // Invalid use of an automatically allocated descriptor handle + ValidateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_IMP_ROW_DESC, static_cast(0), + kErrorStateHY017); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrKeysetSizeUnsupported) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_KEYSET_SIZE, static_cast(0)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrMaxLength) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_MAX_LENGTH, static_cast(0)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrMaxRows) { + // Cannot set read-only attribute + ValidateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_MAX_ROWS, static_cast(0), + kErrorStateHY092); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrMetadataID) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_METADATA_ID, static_cast(SQL_FALSE)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrNoscan) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_NOSCAN, static_cast(SQL_NOSCAN_OFF)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrParamBindOffsetPtr) { + SQLULEN offset = 1000; + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_OFFSET_PTR, + static_cast(&offset)); + + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_OFFSET_PTR, &value); + + EXPECT_EQ(static_cast(&offset), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrParamBindType) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_TYPE, + static_cast(SQL_PARAM_BIND_BY_COLUMN)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrParamOperationPtr) { + constexpr SQLULEN param_set_size = 4; + SQLUSMALLINT param_operations[param_set_size] = {SQL_PARAM_PROCEED, SQL_PARAM_IGNORE, + SQL_PARAM_PROCEED, SQL_PARAM_IGNORE}; + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_PARAM_OPERATION_PTR, + static_cast(param_operations)); + + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_PARAM_OPERATION_PTR, &value); + + EXPECT_EQ(static_cast(param_operations), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrParamStatusPtr) { + // Driver does not support parameters, so just check array can be saved/retrieved + constexpr SQLULEN param_status_size = 4; + SQLUSMALLINT param_status[param_status_size] = {SQL_PARAM_PROCEED, SQL_PARAM_IGNORE, + SQL_PARAM_PROCEED, SQL_PARAM_IGNORE}; + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_PARAM_STATUS_PTR, + static_cast(param_status)); + + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_PARAM_STATUS_PTR, &value); + + EXPECT_EQ(static_cast(param_status), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrParamsProcessedPtr) { + SQLULEN processed_count = 0; + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_PARAMS_PROCESSED_PTR, + static_cast(&processed_count)); + + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_PARAMS_PROCESSED_PTR, &value); + + EXPECT_EQ(static_cast(&processed_count), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrParamsetSize) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_PARAMSET_SIZE, static_cast(1)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrQueryTimeout) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_QUERY_TIMEOUT, static_cast(1)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRetrieveData) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_RETRIEVE_DATA, + static_cast(SQL_RD_ON)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRowArraySize) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_ROW_ARRAY_SIZE, static_cast(1)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRowBindOffsetPtr) { + SQLULEN offset = 1000; + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_OFFSET_PTR, + static_cast(&offset)); + + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_OFFSET_PTR, &value); + + EXPECT_EQ(static_cast(&offset), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRowBindType) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_TYPE, static_cast(0)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRowNumber) { + // Cannot set read-only attribute + ValidateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ROW_NUMBER, static_cast(0), + kErrorStateHY092); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRowOperationPtr) { + constexpr SQLULEN param_set_size = 4; + SQLUSMALLINT row_operations[param_set_size] = {SQL_ROW_PROCEED, SQL_ROW_IGNORE, + SQL_ROW_PROCEED, SQL_ROW_IGNORE}; + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_ROW_OPERATION_PTR, + static_cast(row_operations)); + + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_OPERATION_PTR, &value); + + EXPECT_EQ(static_cast(row_operations), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRowStatusPtr) { + constexpr SQLULEN row_status_size = 4; + SQLUSMALLINT values[4] = {0, 0, 0, 0}; + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_ROW_STATUS_PTR, + static_cast(values)); + + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_ROW_STATUS_PTR, &value); + + EXPECT_EQ(static_cast(values), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRowsFetchedPtr) { + SQLULEN rows_fetched = 1; + + ValidateSetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, + static_cast(&rows_fetched)); + + SQLPOINTER value = nullptr; + GetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, &value); + + EXPECT_EQ(static_cast(&rows_fetched), value); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrSimulateCursor) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_SIMULATE_CURSOR, + static_cast(SQL_SC_UNIQUE)); +} + +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrUseBookmarks) { + ValidateSetStmtAttr(this->stmt, SQL_ATTR_USE_BOOKMARKS, + static_cast(SQL_UB_OFF)); +} + +// This is a pre ODBC 3 attribute +TYPED_TEST(StatementAttributeTest, TestSQLSetStmtAttrRowsetSize) { + ValidateSetStmtAttr(this->stmt, SQL_ROWSET_SIZE, static_cast(1)); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc new file mode 100644 index 00000000000..fe8817df559 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc @@ -0,0 +1,2139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +#include +#include + +namespace arrow::flight::sql::odbc { + +template +class StatementTest : public T {}; + +class StatementMockTest : public FlightSQLODBCMockTestBase {}; +class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {}; +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(StatementTest, TestTypes); + +TYPED_TEST(StatementTest, TestSQLExecDirectSimpleQuery) { + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLINTEGER val; + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + // Verify 1 is returned + EXPECT_EQ(1, val); + + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); + + ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000); +} + +TYPED_TEST(StatementTest, TestSQLExecDirectInvalidQuery) { + std::wstring wsql = L"SELECT;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_ERROR, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + // ODBC provides generic error code HY000 to all statement errors + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000); +} + +TYPED_TEST(StatementTest, TestSQLExecuteSimpleQuery) { + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLPrepare(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLExecute(this->stmt)); + + // Fetch data + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLINTEGER val; + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + + // Verify 1 is returned + EXPECT_EQ(1, val); + + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); + + ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000); +} + +TYPED_TEST(StatementTest, TestSQLPrepareInvalidQuery) { + std::wstring wsql = L"SELECT;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_ERROR, + SQLPrepare(this->stmt, &sql0[0], static_cast(sql0.size()))); + // ODBC provides generic error code HY000 to all statement errors + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000); + + ASSERT_EQ(SQL_ERROR, SQLExecute(this->stmt)); + // Verify function sequence error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010); +} + +TYPED_TEST(StatementTest, TestSQLExecDirectDataQuery) { + std::wstring wsql = this->GetQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Numeric Types + + // Signed Tiny Int + int8_t stiny_int_val; + SQLLEN buf_len = sizeof(stiny_int_val); + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 1, SQL_C_STINYINT, &stiny_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), stiny_int_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 2, SQL_C_STINYINT, &stiny_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), stiny_int_val); + + // Unsigned Tiny Int + uint8_t utiny_int_val; + buf_len = sizeof(utiny_int_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 3, SQL_C_UTINYINT, &utiny_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), utiny_int_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 4, SQL_C_UTINYINT, &utiny_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), utiny_int_val); + + // Signed Small Int + int16_t ssmall_int_val; + buf_len = sizeof(ssmall_int_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 5, SQL_C_SSHORT, &ssmall_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), ssmall_int_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 6, SQL_C_SSHORT, &ssmall_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), ssmall_int_val); + + // Unsigned Small Int + uint16_t usmall_int_val; + buf_len = sizeof(usmall_int_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 7, SQL_C_USHORT, &usmall_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), usmall_int_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 8, SQL_C_USHORT, &usmall_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), usmall_int_val); + + // Signed Integer + SQLINTEGER slong_val; + buf_len = sizeof(slong_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 9, SQL_C_SLONG, &slong_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), slong_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 10, SQL_C_SLONG, &slong_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), slong_val); + + // Unsigned Integer + SQLUINTEGER ulong_val; + buf_len = sizeof(ulong_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 11, SQL_C_ULONG, &ulong_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), ulong_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 12, SQL_C_ULONG, &ulong_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), ulong_val); + + // Signed Big Int + SQLBIGINT sbig_int_val; + buf_len = sizeof(sbig_int_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 13, SQL_C_SBIGINT, &sbig_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), sbig_int_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 14, SQL_C_SBIGINT, &sbig_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), sbig_int_val); + + // Unsigned Big Int + SQLUBIGINT ubig_int_val; + buf_len = sizeof(ubig_int_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 15, SQL_C_UBIGINT, &ubig_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), ubig_int_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 16, SQL_C_UBIGINT, &ubig_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), ubig_int_val); + + // Decimal + SQL_NUMERIC_STRUCT decimal_val; + memset(&decimal_val, 0, sizeof(decimal_val)); + buf_len = sizeof(SQL_NUMERIC_STRUCT); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 17, SQL_C_NUMERIC, &decimal_val, buf_len, &ind)); + // Check for negative decimal_val value + EXPECT_EQ(0, decimal_val.sign); + EXPECT_EQ(0, decimal_val.scale); + EXPECT_EQ(38, decimal_val.precision); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + memset(&decimal_val, 0, sizeof(decimal_val)); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 18, SQL_C_NUMERIC, &decimal_val, buf_len, &ind)); + // Check for positive decimal_val value + EXPECT_EQ(1, decimal_val.sign); + EXPECT_EQ(0, decimal_val.scale); + EXPECT_EQ(38, decimal_val.precision); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + // Float + float float_val; + buf_len = sizeof(float_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 19, SQL_C_FLOAT, &float_val, buf_len, &ind)); + // Get minimum negative float value + EXPECT_EQ(-std::numeric_limits::max(), float_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 20, SQL_C_FLOAT, &float_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), float_val); + + // Double + SQLDOUBLE double_val; + buf_len = sizeof(double_val); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 21, SQL_C_DOUBLE, &double_val, buf_len, &ind)); + // Get minimum negative double value + EXPECT_EQ(-std::numeric_limits::max(), double_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 22, SQL_C_DOUBLE, &double_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), double_val); + + // Bit + bool bit_val; + buf_len = sizeof(bit_val); + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 23, SQL_C_BIT, &bit_val, buf_len, &ind)); + EXPECT_EQ(false, bit_val); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 24, SQL_C_BIT, &bit_val, buf_len, &ind)); + EXPECT_EQ(true, bit_val); + + // Characters + + // Char + SQLCHAR char_val[2]; + buf_len = sizeof(SQLCHAR) * 2; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 25, SQL_C_CHAR, &char_val, buf_len, &ind)); + EXPECT_EQ('Z', char_val[0]); + + // WChar + SQLWCHAR wchar_val[2]; + size_t wchar_size = arrow::flight::sql::odbc::GetSqlWCharSize(); + buf_len = wchar_size * 2; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 26, SQL_C_WCHAR, &wchar_val, buf_len, &ind)); + EXPECT_EQ(L'你', wchar_val[0]); + + // WVarchar + SQLWCHAR wvarchar_val[3]; + buf_len = wchar_size * 3; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 27, SQL_C_WCHAR, &wvarchar_val, buf_len, &ind)); + EXPECT_EQ(L'你', wvarchar_val[0]); + EXPECT_EQ(L'好', wvarchar_val[1]); + + // varchar + SQLCHAR varchar_val[4]; + buf_len = sizeof(SQLCHAR) * 4; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 28, SQL_C_CHAR, &varchar_val, buf_len, &ind)); + EXPECT_EQ('X', varchar_val[0]); + EXPECT_EQ('Y', varchar_val[1]); + EXPECT_EQ('Z', varchar_val[2]); + + // Date and Timestamp + + // Date + SQL_DATE_STRUCT date_var{}; + buf_len = sizeof(date_var); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 29, SQL_C_TYPE_DATE, &date_var, buf_len, &ind)); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(1, date_var.day); + EXPECT_EQ(1, date_var.month); + EXPECT_EQ(1400, date_var.year); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 30, SQL_C_TYPE_DATE, &date_var, buf_len, &ind)); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(31, date_var.day); + EXPECT_EQ(12, date_var.month); + EXPECT_EQ(9999, date_var.year); + + // Timestamp + SQL_TIMESTAMP_STRUCT timestamp_var{}; + buf_len = sizeof(timestamp_var); + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 31, SQL_C_TYPE_TIMESTAMP, ×tamp_var, + buf_len, &ind)); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(1, timestamp_var.day); + EXPECT_EQ(1, timestamp_var.month); + EXPECT_EQ(1400, timestamp_var.year); + EXPECT_EQ(0, timestamp_var.hour); + EXPECT_EQ(0, timestamp_var.minute); + EXPECT_EQ(0, timestamp_var.second); + EXPECT_EQ(0, timestamp_var.fraction); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 32, SQL_C_TYPE_TIMESTAMP, ×tamp_var, + buf_len, &ind)); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(31, timestamp_var.day); + EXPECT_EQ(12, timestamp_var.month); + EXPECT_EQ(9999, timestamp_var.year); + EXPECT_EQ(23, timestamp_var.hour); + EXPECT_EQ(59, timestamp_var.minute); + EXPECT_EQ(59, timestamp_var.second); + EXPECT_EQ(0, timestamp_var.fraction); +} + +TEST_F(StatementRemoteTest, TestSQLExecDirectTimeQuery) { + // Mock server test is skipped due to limitation on the mock server. + // Time type from mock server does not include the fraction + + std::wstring wsql = + LR"( + SELECT CAST(TIME '00:00:00' AS TIME) AS time_min, + CAST(TIME '23:59:59' AS TIME) AS time_max; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQL_TIME_STRUCT time_var{}; + SQLLEN buf_len = sizeof(time_var); + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 1, SQL_C_TYPE_TIME, &time_var, buf_len, &ind)); + // Check min values for time. + EXPECT_EQ(0, time_var.hour); + EXPECT_EQ(0, time_var.minute); + EXPECT_EQ(0, time_var.second); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 2, SQL_C_TYPE_TIME, &time_var, buf_len, &ind)); + // Check max values for time. + EXPECT_EQ(23, time_var.hour); + EXPECT_EQ(59, time_var.minute); + EXPECT_EQ(59, time_var.second); +} + +TEST_F(StatementMockTest, TestSQLExecDirectVarbinaryQuery) { + // Have binary test on mock test base as remote test servers tend to have different + // formats for binary data + + std::wstring wsql = L"SELECT X'ABCDEF' AS c_varbinary;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // varbinary + std::vector varbinary_val(3); + SQLLEN buf_len = varbinary_val.size(); + SQLLEN ind; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 1, SQL_C_BINARY, &varbinary_val[0], buf_len, &ind)); + EXPECT_EQ('\xAB', varbinary_val[0]); + EXPECT_EQ('\xCD', varbinary_val[1]); + EXPECT_EQ('\xEF', varbinary_val[2]); +} + +// Tests with SQL_C_DEFAULT as the target type + +TEST_F(StatementRemoteTest, TestSQLExecDirectDataQueryDefaultType) { + // Test with default types. Only testing target types supported by server. + + std::wstring wsql = this->GetQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Numeric Types + // Signed Integer + SQLINTEGER slong_val; + SQLLEN buf_len = sizeof(slong_val); + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 9, SQL_C_DEFAULT, &slong_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), slong_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 10, SQL_C_DEFAULT, &slong_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), slong_val); + + // Signed Big Int + SQLBIGINT sbig_int_val; + buf_len = sizeof(sbig_int_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 13, SQL_C_DEFAULT, &sbig_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), sbig_int_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 14, SQL_C_DEFAULT, &sbig_int_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), sbig_int_val); + + // Decimal + SQL_NUMERIC_STRUCT decimal_val; + memset(&decimal_val, 0, sizeof(decimal_val)); + buf_len = sizeof(SQL_NUMERIC_STRUCT); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 17, SQL_C_DEFAULT, &decimal_val, buf_len, &ind)); + // Check for negative decimal_val value + EXPECT_EQ(0, decimal_val.sign); + EXPECT_EQ(0, decimal_val.scale); + EXPECT_EQ(38, decimal_val.precision); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + memset(&decimal_val, 0, sizeof(decimal_val)); + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 18, SQL_C_DEFAULT, &decimal_val, buf_len, &ind)); + // Check for positive decimal_val value + EXPECT_EQ(1, decimal_val.sign); + EXPECT_EQ(0, decimal_val.scale); + EXPECT_EQ(38, decimal_val.precision); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + // Float + float float_val; + buf_len = sizeof(float_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 19, SQL_C_DEFAULT, &float_val, buf_len, &ind)); + // Get minimum negative float value + EXPECT_EQ(-std::numeric_limits::max(), float_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 20, SQL_C_DEFAULT, &float_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), float_val); + + // Double + SQLDOUBLE double_val; + buf_len = sizeof(double_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 21, SQL_C_DEFAULT, &double_val, buf_len, &ind)); + // Get minimum negative double value + EXPECT_EQ(-std::numeric_limits::max(), double_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 22, SQL_C_DEFAULT, &double_val, buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), double_val); + + // Bit + bool bit_val; + buf_len = sizeof(bit_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 23, SQL_C_DEFAULT, &bit_val, buf_len, &ind)); + EXPECT_EQ(false, bit_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 24, SQL_C_DEFAULT, &bit_val, buf_len, &ind)); + EXPECT_EQ(true, bit_val); + + // Characters + + // Char will be fetched as wchar by default + SQLWCHAR wchar_val[2]; + size_t wchar_size = arrow::flight::sql::odbc::GetSqlWCharSize(); + buf_len = wchar_size * 2; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 25, SQL_C_DEFAULT, &wchar_val, buf_len, &ind)); + EXPECT_EQ(L'Z', wchar_val[0]); + + // WChar + SQLWCHAR wchar_val2[2]; + buf_len = wchar_size * 2; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 26, SQL_C_DEFAULT, &wchar_val2, buf_len, &ind)); + EXPECT_EQ(L'你', wchar_val2[0]); + + // WVarchar + SQLWCHAR wvarchar_val[3]; + buf_len = wchar_size * 3; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 27, SQL_C_DEFAULT, &wvarchar_val, buf_len, &ind)); + EXPECT_EQ(L'你', wvarchar_val[0]); + EXPECT_EQ(L'好', wvarchar_val[1]); + + // Varchar will be fetched as WVarchar by default + SQLWCHAR wvarchar_val2[4]; + buf_len = wchar_size * 4; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 28, SQL_C_DEFAULT, &wvarchar_val2, buf_len, &ind)); + EXPECT_EQ(L'X', wvarchar_val2[0]); + EXPECT_EQ(L'Y', wvarchar_val2[1]); + EXPECT_EQ(L'Z', wvarchar_val2[2]); + + // Date and Timestamp + + // Date + SQL_DATE_STRUCT date_var{}; + buf_len = sizeof(date_var); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 29, SQL_C_DEFAULT, &date_var, buf_len, &ind)); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(1, date_var.day); + EXPECT_EQ(1, date_var.month); + EXPECT_EQ(1400, date_var.year); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 30, SQL_C_DEFAULT, &date_var, buf_len, &ind)); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(31, date_var.day); + EXPECT_EQ(12, date_var.month); + EXPECT_EQ(9999, date_var.year); + + // Timestamp + SQL_TIMESTAMP_STRUCT timestamp_var{}; + buf_len = sizeof(timestamp_var); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 31, SQL_C_DEFAULT, ×tamp_var, buf_len, &ind)); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(1, timestamp_var.day); + EXPECT_EQ(1, timestamp_var.month); + EXPECT_EQ(1400, timestamp_var.year); + EXPECT_EQ(0, timestamp_var.hour); + EXPECT_EQ(0, timestamp_var.minute); + EXPECT_EQ(0, timestamp_var.second); + EXPECT_EQ(0, timestamp_var.fraction); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 32, SQL_C_DEFAULT, ×tamp_var, buf_len, &ind)); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(31, timestamp_var.day); + EXPECT_EQ(12, timestamp_var.month); + EXPECT_EQ(9999, timestamp_var.year); + EXPECT_EQ(23, timestamp_var.hour); + EXPECT_EQ(59, timestamp_var.minute); + EXPECT_EQ(59, timestamp_var.second); + EXPECT_EQ(0, timestamp_var.fraction); +} + +TEST_F(StatementRemoteTest, TestSQLExecDirectTimeQueryDefaultType) { + // Mock server test is skipped due to limitation on the mock server. + // Time type from mock server does not include the fraction + + std::wstring wsql = + LR"( + SELECT CAST(TIME '00:00:00' AS TIME) AS time_min, + CAST(TIME '23:59:59' AS TIME) AS time_max; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQL_TIME_STRUCT time_var{}; + SQLLEN buf_len = sizeof(time_var); + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 1, SQL_C_DEFAULT, &time_var, buf_len, &ind)); + // Check min values for time. + EXPECT_EQ(0, time_var.hour); + EXPECT_EQ(0, time_var.minute); + EXPECT_EQ(0, time_var.second); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 2, SQL_C_DEFAULT, &time_var, buf_len, &ind)); + // Check max values for time. + EXPECT_EQ(23, time_var.hour); + EXPECT_EQ(59, time_var.minute); + EXPECT_EQ(59, time_var.second); +} + +TEST_F(StatementRemoteTest, TestSQLExecDirectVarbinaryQueryDefaultType) { + // Limitation on mock test server prevents SQL_C_DEFAULT from working properly. + // Mock server has type `DENSE_UNION` for varbinary. + // Note that not all remote servers support "from_hex" function + + std::wstring wsql = L"SELECT from_hex('ABCDEF') AS c_varbinary;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // varbinary + std::vector varbinary_val(3); + SQLLEN buf_len = varbinary_val.size(); + SQLLEN ind; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 1, SQL_C_DEFAULT, &varbinary_val[0], buf_len, &ind)); + EXPECT_EQ('\xAB', varbinary_val[0]); + EXPECT_EQ('\xCD', varbinary_val[1]); + EXPECT_EQ('\xEF', varbinary_val[2]); +} + +TYPED_TEST(StatementTest, TestSQLExecDirectGuidQueryUnsupported) { + // Query GUID as string as SQLite does not support GUID + std::wstring wsql = L"SELECT 'C77313CF-4E08-47CE-B6DF-94DD2FCF3541' AS guid;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLGUID guid_var; + SQLLEN buf_len = sizeof(guid_var); + SQLLEN ind; + ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_GUID, &guid_var, buf_len, &ind)); + // GUID is not supported by ODBC + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000); +} + +TYPED_TEST(StatementTest, TestSQLExecDirectRowFetching) { + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + // Fetch row 1 + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLINTEGER val; + SQLLEN buf_len = sizeof(val); + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind)); + + // Verify 1 is returned + EXPECT_EQ(1, val); + + // Fetch row 2 + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind)); + + // Verify 2 is returned + EXPECT_EQ(2, val); + + // Fetch row 3 + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind)); + + // Verify 3 is returned + EXPECT_EQ(3, val); + + // Verify result set has no more data beyond row 3 + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); + + ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, &ind)); + + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000); +} + +TYPED_TEST(StatementTest, TestSQLFetchScrollRowFetching) { + SQLLEN rows_fetched; + SQLSetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, &rows_fetched, 0); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + // Fetch row 1 + ASSERT_EQ(SQL_SUCCESS, SQLFetchScroll(this->stmt, SQL_FETCH_NEXT, 0)); + + SQLINTEGER val; + SQLLEN buf_len = sizeof(val); + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind)); + // Verify 1 is returned + EXPECT_EQ(1, val); + // Verify 1 row is fetched + EXPECT_EQ(1, rows_fetched); + + // Fetch row 2 + ASSERT_EQ(SQL_SUCCESS, SQLFetchScroll(this->stmt, SQL_FETCH_NEXT, 0)); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind)); + + // Verify 2 is returned + EXPECT_EQ(2, val); + // Verify 1 row is fetched in the last SQLFetchScroll call + EXPECT_EQ(1, rows_fetched); + + // Fetch row 3 + ASSERT_EQ(SQL_SUCCESS, SQLFetchScroll(this->stmt, SQL_FETCH_NEXT, 0)); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind)); + + // Verify 3 is returned + EXPECT_EQ(3, val); + // Verify 1 row is fetched in the last SQLFetchScroll call + EXPECT_EQ(1, rows_fetched); + + // Verify result set has no more data beyond row 3 + ASSERT_EQ(SQL_NO_DATA, SQLFetchScroll(this->stmt, SQL_FETCH_NEXT, 0)); + + ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, &ind)); + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000); +} + +TYPED_TEST(StatementTest, TestSQLFetchScrollUnsupportedOrientation) { + // SQL_FETCH_PRIOR is the only supported fetch orientation. + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_ERROR, SQLFetchScroll(this->stmt, SQL_FETCH_PRIOR, 0)); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHYC00); + + SQLLEN fetch_offset = 1; + ASSERT_EQ(SQL_ERROR, SQLFetchScroll(this->stmt, SQL_FETCH_RELATIVE, fetch_offset)); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHYC00); + + ASSERT_EQ(SQL_ERROR, SQLFetchScroll(this->stmt, SQL_FETCH_ABSOLUTE, fetch_offset)); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHYC00); + + ASSERT_EQ(SQL_ERROR, SQLFetchScroll(this->stmt, SQL_FETCH_FIRST, 0)); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHYC00); + + ASSERT_EQ(SQL_ERROR, SQLFetchScroll(this->stmt, SQL_FETCH_LAST, 0)); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHYC00); + + ASSERT_EQ(SQL_ERROR, SQLFetchScroll(this->stmt, SQL_FETCH_BOOKMARK, fetch_offset)); + + // DM returns state HY106 for SQL_FETCH_BOOKMARK + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY106); +} + +TYPED_TEST(StatementTest, TestSQLExecDirectVarcharTruncation) { + std::wstring wsql = L"SELECT 'VERY LONG STRING here' AS string_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + const int len = 17; + SQLCHAR char_val[len]; + SQLLEN buf_len = sizeof(SQLCHAR) * len; + SQLLEN ind; + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val, buf_len, &ind)); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01004); + + EXPECT_EQ(std::string("VERY LONG STRING"), ODBC::SqlStringToString(char_val)); + EXPECT_EQ(21, ind); + + // Fetch same column 2nd time + const int len2 = 2; + SQLCHAR char_val2[len2]; + buf_len = sizeof(SQLCHAR) * len2; + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val2, buf_len, &ind)); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01004); + + EXPECT_EQ(std::string(" "), ODBC::SqlStringToString(char_val2)); + EXPECT_EQ(5, ind); + + // Fetch same column 3rd time + const int len3 = 5; + SQLCHAR char_val3[len3]; + buf_len = sizeof(SQLCHAR) * len3; + + // Verify that there is no more truncation reports. The full string has been fetched. + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val3, buf_len, &ind)); + + EXPECT_EQ(std::string("here"), ODBC::SqlStringToString(char_val3)); + EXPECT_EQ(4, ind); + + // Attempt to fetch data 4th time + SQLCHAR char_val4[len]; + // Verify SQL_NO_DATA is returned + ASSERT_EQ(SQL_NO_DATA, SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val4, 0, &ind)); +} + +TYPED_TEST(StatementTest, TestSQLExecDirectWVarcharTruncation) { + std::wstring wsql = L"SELECT 'VERY LONG Unicode STRING 句子 here' AS wstring_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + const int len = 28; + SQLWCHAR wchar_val[len]; + size_t wchar_size = arrow::flight::sql::odbc::GetSqlWCharSize(); + SQLLEN buf_len = wchar_size * len; + SQLLEN ind; + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 1, SQL_C_WCHAR, &wchar_val, buf_len, &ind)); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01004); + + EXPECT_EQ(std::wstring(L"VERY LONG Unicode STRING 句子"), std::wstring(wchar_val)); + EXPECT_EQ(32 * wchar_size, ind); + + // Fetch same column 2nd time + const int len2 = 2; + SQLWCHAR wchar_val2[len2]; + buf_len = wchar_size * len2; + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 1, SQL_C_WCHAR, &wchar_val2, buf_len, &ind)); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01004); + + EXPECT_EQ(std::wstring(L" "), std::wstring(wchar_val2)); + EXPECT_EQ(5 * wchar_size, ind); + + // Fetch same column 3rd time + const int len3 = 5; + SQLWCHAR wchar_val3[len3]; + buf_len = wchar_size * len3; + + // Verify that there is no more truncation reports. The full string has been fetched. + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 1, SQL_C_WCHAR, &wchar_val3, buf_len, &ind)); + + EXPECT_EQ(std::wstring(L"here"), std::wstring(wchar_val3)); + EXPECT_EQ(4 * wchar_size, ind); + + // Attempt to fetch data 4th time + SQLWCHAR wchar_val4[len]; + // Verify SQL_NO_DATA is returned + ASSERT_EQ(SQL_NO_DATA, SQLGetData(this->stmt, 1, SQL_C_WCHAR, &wchar_val4, 0, &ind)); +} + +TEST_F(StatementMockTest, TestSQLExecDirectVarbinaryTruncation) { + // Have binary test on mock test base as remote test servers tend to have different + // formats for binary data + + std::wstring wsql = L"SELECT X'ABCDEFAB' AS c_varbinary;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // varbinary + std::vector varbinary_val(3); + SQLLEN buf_len = varbinary_val.size(); + SQLLEN ind; + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 1, SQL_C_BINARY, &varbinary_val[0], buf_len, &ind)); + // Verify binary truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01004); + EXPECT_EQ('\xAB', varbinary_val[0]); + EXPECT_EQ('\xCD', varbinary_val[1]); + EXPECT_EQ('\xEF', varbinary_val[2]); + EXPECT_EQ(4, ind); + + // Fetch same column 2nd time + std::vector varbinary_val2(1); + buf_len = varbinary_val2.size(); + + // Verify that there is no more truncation reports. The full binary has been fetched. + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 1, SQL_C_BINARY, &varbinary_val2[0], buf_len, &ind)); + + EXPECT_EQ('\xAB', varbinary_val[0]); + EXPECT_EQ(1, ind); + + // Attempt to fetch data 3rd time + std::vector varbinary_val3(1); + buf_len = varbinary_val3.size(); + // Verify SQL_NO_DATA is returned + ASSERT_EQ(SQL_NO_DATA, + SQLGetData(this->stmt, 1, SQL_C_BINARY, &varbinary_val3[0], buf_len, &ind)); +} + +TYPED_TEST(StatementTest, DISABLED_TestSQLExecDirectFloatTruncation) { + // Test is disabled until float truncation is supported. + // GH-46985: return warning message instead of error on float truncation case + std::wstring wsql; + if constexpr (std::is_same_v) { + wsql = std::wstring(L"SELECT CAST(1.234 AS REAL) AS float_val"); + } else if constexpr (std::is_same_v) { + wsql = std::wstring(L"SELECT CAST(1.234 AS FLOAT) AS float_val"); + } + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + int16_t ssmall_int_val; + + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 1, SQL_C_SSHORT, &ssmall_int_val, 0, 0)); + // Verify float truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01S07); + + EXPECT_EQ(1, ssmall_int_val); +} + +TEST_F(StatementRemoteTest, TestSQLExecDirectNullQuery) { + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLINTEGER val; + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, &ind)); + + // Verify SQL_NULL_DATA is returned for indicator + EXPECT_EQ(SQL_NULL_DATA, ind); +} + +TEST_F(StatementMockTest, TestSQLExecDirectTruncationQueryNullIndicator) { + // Driver should not error out when indicator is null if the cell is non-null + // Have binary test on mock test base as remote test servers tend to have different + // formats for binary data + + std::wstring wsql = + LR"( + SELECT 1, + 'VERY LONG STRING here' AS string_col, + 'VERY LONG Unicode STRING 句子 here' AS wstring_col, + X'ABCDEFAB' AS c_varbinary; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLINTEGER val; + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + // Verify 1 is returned for non-truncation case. + EXPECT_EQ(1, val); + + // Char + const int len = 17; + SQLCHAR char_val[len]; + SQLLEN buf_len = sizeof(SQLCHAR) * len; + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 2, SQL_C_CHAR, &char_val, buf_len, 0)); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01004); + + // WChar + const int len2 = 28; + SQLWCHAR wchar_val[len2]; + size_t wchar_size = arrow::flight::sql::odbc::GetSqlWCharSize(); + buf_len = wchar_size * len2; + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 3, SQL_C_WCHAR, &wchar_val, buf_len, 0)); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01004); + + // varbinary + std::vector varbinary_val(3); + buf_len = varbinary_val.size(); + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLGetData(this->stmt, 4, SQL_C_BINARY, &varbinary_val[0], buf_len, 0)); + // Verify binary truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState01004); +} + +TEST_F(StatementRemoteTest, TestSQLExecDirectNullQueryNullIndicator) { + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + SQLINTEGER val; + + ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + // Verify invalid null indicator is reported, as it is required + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState22002); +} + +TYPED_TEST(StatementTest, TestSQLExecDirectIgnoreInvalidBufLen) { + // Verify the driver ignores invalid buffer length for fixed data types + + std::wstring wsql = this->GetQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Numeric Types + + // Signed Tiny Int + int8_t stiny_int_val; + SQLLEN invalid_buf_len = -1; + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_STINYINT, &stiny_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), stiny_int_val); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 2, SQL_C_STINYINT, &stiny_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), stiny_int_val); + + // Unsigned Tiny Int + uint8_t utiny_int_val; + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 3, SQL_C_UTINYINT, &utiny_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), utiny_int_val); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 4, SQL_C_UTINYINT, &utiny_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), utiny_int_val); + + // Signed Small Int + int16_t ssmall_int_val; + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 5, SQL_C_SSHORT, &ssmall_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), ssmall_int_val); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 6, SQL_C_SSHORT, &ssmall_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), ssmall_int_val); + + // Unsigned Small Int + uint16_t usmall_int_val; + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 7, SQL_C_USHORT, &usmall_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), usmall_int_val); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 8, SQL_C_USHORT, &usmall_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), usmall_int_val); + + // Signed Integer + SQLINTEGER slong_val; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 9, SQL_C_SLONG, &slong_val, invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), slong_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 10, SQL_C_SLONG, &slong_val, invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), slong_val); + + // Unsigned Integer + SQLUINTEGER ulong_val; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 11, SQL_C_ULONG, &ulong_val, invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), ulong_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 12, SQL_C_ULONG, &ulong_val, invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), ulong_val); + + // Signed Big Int + SQLBIGINT sbig_int_val; + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 13, SQL_C_SBIGINT, &sbig_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), sbig_int_val); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 14, SQL_C_SBIGINT, &sbig_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), sbig_int_val); + + // Unsigned Big Int + SQLUBIGINT ubig_int_val; + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 15, SQL_C_UBIGINT, &ubig_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::min(), ubig_int_val); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 16, SQL_C_UBIGINT, &ubig_int_val, + invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), ubig_int_val); + + // Decimal + SQL_NUMERIC_STRUCT decimal_val; + memset(&decimal_val, 0, sizeof(decimal_val)); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 17, SQL_C_NUMERIC, &decimal_val, + invalid_buf_len, &ind)); + // Check for negative decimal_val value + EXPECT_EQ(0, decimal_val.sign); + EXPECT_EQ(0, decimal_val.scale); + EXPECT_EQ(38, decimal_val.precision); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + memset(&decimal_val, 0, sizeof(decimal_val)); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 18, SQL_C_NUMERIC, &decimal_val, + invalid_buf_len, &ind)); + // Check for positive decimal_val value + EXPECT_EQ(1, decimal_val.sign); + EXPECT_EQ(0, decimal_val.scale); + EXPECT_EQ(38, decimal_val.precision); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + // Float + float float_val; + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 19, SQL_C_FLOAT, &float_val, invalid_buf_len, &ind)); + // Get minimum negative float value + EXPECT_EQ(-std::numeric_limits::max(), float_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 20, SQL_C_FLOAT, &float_val, invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), float_val); + + // Double + SQLDOUBLE double_val; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 21, SQL_C_DOUBLE, &double_val, invalid_buf_len, &ind)); + // Get minimum negative double value + EXPECT_EQ(-std::numeric_limits::max(), double_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 22, SQL_C_DOUBLE, &double_val, invalid_buf_len, &ind)); + EXPECT_EQ(std::numeric_limits::max(), double_val); + + // Bit + bool bit_val; + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 23, SQL_C_BIT, &bit_val, invalid_buf_len, &ind)); + EXPECT_EQ(false, bit_val); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(this->stmt, 24, SQL_C_BIT, &bit_val, invalid_buf_len, &ind)); + EXPECT_EQ(true, bit_val); + + // Date and Timestamp + + // Date + SQL_DATE_STRUCT date_var{}; + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 29, SQL_C_TYPE_DATE, &date_var, + invalid_buf_len, &ind)); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(1, date_var.day); + EXPECT_EQ(1, date_var.month); + EXPECT_EQ(1400, date_var.year); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 30, SQL_C_TYPE_DATE, &date_var, + invalid_buf_len, &ind)); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(31, date_var.day); + EXPECT_EQ(12, date_var.month); + EXPECT_EQ(9999, date_var.year); + + // Timestamp + SQL_TIMESTAMP_STRUCT timestamp_var{}; + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 31, SQL_C_TYPE_TIMESTAMP, ×tamp_var, + invalid_buf_len, &ind)); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(1, timestamp_var.day); + EXPECT_EQ(1, timestamp_var.month); + EXPECT_EQ(1400, timestamp_var.year); + EXPECT_EQ(0, timestamp_var.hour); + EXPECT_EQ(0, timestamp_var.minute); + EXPECT_EQ(0, timestamp_var.second); + EXPECT_EQ(0, timestamp_var.fraction); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 32, SQL_C_TYPE_TIMESTAMP, ×tamp_var, + invalid_buf_len, &ind)); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(31, timestamp_var.day); + EXPECT_EQ(12, timestamp_var.month); + EXPECT_EQ(9999, timestamp_var.year); + EXPECT_EQ(23, timestamp_var.hour); + EXPECT_EQ(59, timestamp_var.minute); + EXPECT_EQ(59, timestamp_var.second); + EXPECT_EQ(0, timestamp_var.fraction); +} + +TYPED_TEST(StatementTest, TestSQLBindColDataQuery) { + // Numeric Types + + // Signed Tiny Int + int8_t stiny_int_val_min; + int8_t stiny_int_val_max; + SQLLEN buf_len = 0; + SQLLEN ind; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 1, SQL_C_STINYINT, &stiny_int_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 2, SQL_C_STINYINT, &stiny_int_val_max, buf_len, &ind)); + + // Unsigned Tiny Int + uint8_t utiny_int_val_min; + uint8_t utiny_int_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 3, SQL_C_UTINYINT, &utiny_int_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 4, SQL_C_UTINYINT, &utiny_int_val_max, buf_len, &ind)); + + // Signed Small Int + int16_t ssmall_int_val_min; + int16_t ssmall_int_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 5, SQL_C_SSHORT, &ssmall_int_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 6, SQL_C_SSHORT, &ssmall_int_val_max, buf_len, &ind)); + + // Unsigned Small Int + uint16_t usmall_int_val_min; + uint16_t usmall_int_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 7, SQL_C_USHORT, &usmall_int_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 8, SQL_C_USHORT, &usmall_int_val_max, buf_len, &ind)); + + // Signed Integer + SQLINTEGER slong_val_min; + SQLINTEGER slong_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 9, SQL_C_SLONG, &slong_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 10, SQL_C_SLONG, &slong_val_max, buf_len, &ind)); + + // Unsigned Integer + SQLUINTEGER ulong_val_min; + SQLUINTEGER ulong_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 11, SQL_C_ULONG, &ulong_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 12, SQL_C_ULONG, &ulong_val_max, buf_len, &ind)); + + // Signed Big Int + SQLBIGINT sbig_int_val_min; + SQLBIGINT sbig_int_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 13, SQL_C_SBIGINT, &sbig_int_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 14, SQL_C_SBIGINT, &sbig_int_val_max, buf_len, &ind)); + + // Unsigned Big Int + SQLUBIGINT ubig_int_val_min; + SQLUBIGINT ubig_int_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 15, SQL_C_UBIGINT, &ubig_int_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 16, SQL_C_UBIGINT, &ubig_int_val_max, buf_len, &ind)); + + // Decimal + SQL_NUMERIC_STRUCT decimal_val_neg; + SQL_NUMERIC_STRUCT decimal_val_pos; + memset(&decimal_val_neg, 0, sizeof(decimal_val_neg)); + memset(&decimal_val_pos, 0, sizeof(decimal_val_pos)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 17, SQL_C_NUMERIC, &decimal_val_neg, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 18, SQL_C_NUMERIC, &decimal_val_pos, buf_len, &ind)); + + // Float + float float_val_min; + float float_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 19, SQL_C_FLOAT, &float_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 20, SQL_C_FLOAT, &float_val_max, buf_len, &ind)); + + // Double + SQLDOUBLE double_val_min; + SQLDOUBLE double_val_max; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 21, SQL_C_DOUBLE, &double_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 22, SQL_C_DOUBLE, &double_val_max, buf_len, &ind)); + + // Bit + bool bit_val_false; + bool bit_val_true; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 23, SQL_C_BIT, &bit_val_false, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 24, SQL_C_BIT, &bit_val_true, buf_len, &ind)); + + // Characters + SQLCHAR char_val[2]; + buf_len = sizeof(SQLCHAR) * 2; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 25, SQL_C_CHAR, &char_val, buf_len, &ind)); + + SQLWCHAR wchar_val[2]; + size_t wchar_size = arrow::flight::sql::odbc::GetSqlWCharSize(); + buf_len = wchar_size * 2; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 26, SQL_C_WCHAR, &wchar_val, buf_len, &ind)); + + SQLWCHAR wvarchar_val[3]; + buf_len = wchar_size * 3; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 27, SQL_C_WCHAR, &wvarchar_val, buf_len, &ind)); + + SQLCHAR varchar_val[4]; + buf_len = sizeof(SQLCHAR) * 4; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 28, SQL_C_CHAR, &varchar_val, buf_len, &ind)); + + // Date and Timestamp + SQL_DATE_STRUCT date_val_min{}, date_val_max{}; + buf_len = 0; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 29, SQL_C_TYPE_DATE, &date_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 30, SQL_C_TYPE_DATE, &date_val_max, buf_len, &ind)); + + SQL_TIMESTAMP_STRUCT timestamp_val_min{}, timestamp_val_max{}; + + EXPECT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 31, SQL_C_TYPE_TIMESTAMP, + ×tamp_val_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 32, SQL_C_TYPE_TIMESTAMP, + ×tamp_val_max, buf_len, &ind)); + + // Execute query and fetch data once since there is only 1 row. + std::wstring wsql = this->GetQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Data verification + + // Signed Tiny Int + EXPECT_EQ(std::numeric_limits::min(), stiny_int_val_min); + EXPECT_EQ(std::numeric_limits::max(), stiny_int_val_max); + + // Unsigned Tiny Int + EXPECT_EQ(std::numeric_limits::min(), utiny_int_val_min); + EXPECT_EQ(std::numeric_limits::max(), utiny_int_val_max); + + // Signed Small Int + EXPECT_EQ(std::numeric_limits::min(), ssmall_int_val_min); + EXPECT_EQ(std::numeric_limits::max(), ssmall_int_val_max); + + // Unsigned Small Int + EXPECT_EQ(std::numeric_limits::min(), usmall_int_val_min); + EXPECT_EQ(std::numeric_limits::max(), usmall_int_val_max); + + // Signed Long + EXPECT_EQ(std::numeric_limits::min(), slong_val_min); + EXPECT_EQ(std::numeric_limits::max(), slong_val_max); + + // Unsigned Long + EXPECT_EQ(std::numeric_limits::min(), ulong_val_min); + EXPECT_EQ(std::numeric_limits::max(), ulong_val_max); + + // Signed Big Int + EXPECT_EQ(std::numeric_limits::min(), sbig_int_val_min); + EXPECT_EQ(std::numeric_limits::max(), sbig_int_val_max); + + // Unsigned Big Int + EXPECT_EQ(std::numeric_limits::min(), ubig_int_val_min); + EXPECT_EQ(std::numeric_limits::max(), ubig_int_val_max); + + // Decimal + EXPECT_EQ(0, decimal_val_neg.sign); + EXPECT_EQ(0, decimal_val_neg.scale); + EXPECT_EQ(38, decimal_val_neg.precision); + EXPECT_THAT(decimal_val_neg.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0)); + + EXPECT_EQ(1, decimal_val_pos.sign); + EXPECT_EQ(0, decimal_val_pos.scale); + EXPECT_EQ(38, decimal_val_pos.precision); + EXPECT_THAT(decimal_val_pos.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0)); + + // Float + EXPECT_EQ(-std::numeric_limits::max(), float_val_min); + EXPECT_EQ(std::numeric_limits::max(), float_val_max); + + // Double + EXPECT_EQ(-std::numeric_limits::max(), double_val_min); + EXPECT_EQ(std::numeric_limits::max(), double_val_max); + + // Bit + EXPECT_EQ(false, bit_val_false); + EXPECT_EQ(true, bit_val_true); + + // Characters + EXPECT_EQ('Z', char_val[0]); + EXPECT_EQ(L'你', wchar_val[0]); + EXPECT_EQ(L'你', wvarchar_val[0]); + EXPECT_EQ(L'好', wvarchar_val[1]); + + EXPECT_EQ('X', varchar_val[0]); + EXPECT_EQ('Y', varchar_val[1]); + EXPECT_EQ('Z', varchar_val[2]); + + // Date + EXPECT_EQ(1, date_val_min.day); + EXPECT_EQ(1, date_val_min.month); + EXPECT_EQ(1400, date_val_min.year); + + EXPECT_EQ(31, date_val_max.day); + EXPECT_EQ(12, date_val_max.month); + EXPECT_EQ(9999, date_val_max.year); + + // Timestamp + EXPECT_EQ(1, timestamp_val_min.day); + EXPECT_EQ(1, timestamp_val_min.month); + EXPECT_EQ(1400, timestamp_val_min.year); + EXPECT_EQ(0, timestamp_val_min.hour); + EXPECT_EQ(0, timestamp_val_min.minute); + EXPECT_EQ(0, timestamp_val_min.second); + EXPECT_EQ(0, timestamp_val_min.fraction); + + EXPECT_EQ(31, timestamp_val_max.day); + EXPECT_EQ(12, timestamp_val_max.month); + EXPECT_EQ(9999, timestamp_val_max.year); + EXPECT_EQ(23, timestamp_val_max.hour); + EXPECT_EQ(59, timestamp_val_max.minute); + EXPECT_EQ(59, timestamp_val_max.second); + EXPECT_EQ(0, timestamp_val_max.fraction); +} + +TEST_F(StatementRemoteTest, TestSQLBindColTimeQuery) { + // Mock server test is skipped due to limitation on the mock server. + // Time type from mock server does not include the fraction + + SQL_TIME_STRUCT time_var_min{}; + SQL_TIME_STRUCT time_var_max{}; + SQLLEN buf_len = sizeof(time_var_min); + SQLLEN ind; + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 1, SQL_C_TYPE_TIME, &time_var_min, buf_len, &ind)); + + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 2, SQL_C_TYPE_TIME, &time_var_max, buf_len, &ind)); + + std::wstring wsql = + LR"( + SELECT CAST(TIME '00:00:00' AS TIME) AS time_min, + CAST(TIME '23:59:59' AS TIME) AS time_max; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Check min values for time. + EXPECT_EQ(0, time_var_min.hour); + EXPECT_EQ(0, time_var_min.minute); + EXPECT_EQ(0, time_var_min.second); + + // Check max values for time. + EXPECT_EQ(23, time_var_max.hour); + EXPECT_EQ(59, time_var_max.minute); + EXPECT_EQ(59, time_var_max.second); +} + +TEST_F(StatementMockTest, TestSQLBindColVarbinaryQuery) { + // Have binary test on mock test base as remote test servers tend to have different + // formats for binary data + + // varbinary + std::vector varbinary_val(3); + SQLLEN buf_len = varbinary_val.size(); + SQLLEN ind; + ASSERT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 1, SQL_C_BINARY, &varbinary_val[0], buf_len, &ind)); + + std::wstring wsql = L"SELECT X'ABCDEF' AS c_varbinary;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Check varbinary values + EXPECT_EQ('\xAB', varbinary_val[0]); + EXPECT_EQ('\xCD', varbinary_val[1]); + EXPECT_EQ('\xEF', varbinary_val[2]); +} + +TEST_F(StatementRemoteTest, TestSQLBindColNullQuery) { + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + + SQLINTEGER val; + SQLLEN ind; + + ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, &ind)); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Verify SQL_NULL_DATA is returned for indicator + EXPECT_EQ(SQL_NULL_DATA, ind); +} + +TEST_F(StatementRemoteTest, TestSQLBindColNullQueryNullIndicator) { + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + + SQLINTEGER val; + + ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_ERROR, SQLFetch(this->stmt)); + // Verify invalid null indicator is reported, as it is required + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState22002); +} + +TYPED_TEST(StatementTest, TestSQLBindColRowFetching) { + SQLINTEGER val; + SQLLEN buf_len = sizeof(val); + SQLLEN ind; + + // Same variable will be used for column 1, the value of `val` + // should be updated after every SQLFetch call. + ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind)); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + // Fetch row 1 + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Verify 1 is returned + EXPECT_EQ(1, val); + + // Fetch row 2 + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Verify 2 is returned + EXPECT_EQ(2, val); + + // Fetch row 3 + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Verify 3 is returned + EXPECT_EQ(3, val); + + // Verify result set has no more data beyond row 3 + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TYPED_TEST(StatementTest, TestSQLBindColRowArraySize) { + // Set SQL_ATTR_ROW_ARRAY_SIZE to fetch 3 rows at once + + constexpr SQLULEN rows = 3; + SQLINTEGER val[rows]; + SQLLEN buf_len = sizeof(val); + SQLLEN ind[rows]; + + // Same variable will be used for column 1, the value of `val` + // should be updated after every SQLFetch call. + ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, val, buf_len, ind)); + + SQLLEN rows_fetched; + ASSERT_EQ(SQL_SUCCESS, + SQLSetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, &rows_fetched, 0)); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLSetStmtAttr(this->stmt, SQL_ATTR_ROW_ARRAY_SIZE, + reinterpret_cast(rows), 0)); + + // Fetch 3 rows at once + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Verify 3 rows are fetched + EXPECT_EQ(3, rows_fetched); + + // Verify 1 is returned + EXPECT_EQ(1, val[0]); + // Verify 2 is returned + EXPECT_EQ(2, val[1]); + // Verify 3 is returned + EXPECT_EQ(3, val[2]); + + // Verify result set has no more data beyond row 3 + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TYPED_TEST(StatementTest, DISABLED_TestSQLBindColIndicatorOnly) { + // GH-47021: implement driver to return indicator value when data pointer is null + + // Verify driver supports null data pointer with valid indicator pointer + + // Numeric Types + + // Signed Tiny Int + SQLLEN stiny_int_ind; + EXPECT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_STINYINT, 0, 0, &stiny_int_ind)); + + // Characters + SQLLEN buf_len = sizeof(SQLCHAR) * 2; + SQLLEN char_val_ind; + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 25, SQL_C_CHAR, 0, buf_len, &char_val_ind)); + + // Execute query and fetch data once since there is only 1 row. + std::wstring wsql = this->GetQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Verify values for indicator pointer + // Signed Tiny Int + EXPECT_EQ(1, stiny_int_ind); + + // Char array + EXPECT_EQ(1, char_val_ind); +} + +TYPED_TEST(StatementTest, TestSQLBindColIndicatorOnlySQLUnbind) { + // Verify driver supports valid indicator pointer after unbinding all columns + + // Numeric Types + + // Signed Tiny Int + int8_t stiny_int_val; + SQLLEN stiny_int_ind; + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 1, SQL_C_STINYINT, &stiny_int_val, 0, &stiny_int_ind)); + + // Characters + SQLCHAR char_val[2]; + SQLLEN buf_len = sizeof(SQLCHAR) * 2; + SQLLEN char_val_ind; + EXPECT_EQ(SQL_SUCCESS, + SQLBindCol(this->stmt, 25, SQL_C_CHAR, &char_val, buf_len, &char_val_ind)); + + // Driver should still be able to execute queries after unbinding columns + EXPECT_EQ(SQL_SUCCESS, SQLFreeStmt(this->stmt, SQL_UNBIND)); + + // Execute query and fetch data once since there is only 1 row. + std::wstring wsql = this->GetQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // GH-47021: implement driver to return indicator value when data pointer is null and + // uncomment the checks Verify values for indicator pointer Signed Tiny Int + // EXPECT_EQ(1, stiny_int_ind); + + // Char array + // EXPECT_EQ(1, char_val_ind); +} + +TYPED_TEST(StatementTest, TestSQLExtendedFetchRowFetching) { + // Set SQL_ROWSET_SIZE to fetch 3 rows at once + + constexpr SQLULEN rows = 3; + SQLINTEGER val[rows]; + SQLLEN buf_len = sizeof(val); + SQLLEN ind[rows]; + + // Same variable will be used for column 1, the value of `val` + // should be updated after every SQLFetch call. + ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, val, buf_len, ind)); + + ASSERT_EQ(SQL_SUCCESS, SQLSetStmtAttr(this->stmt, SQL_ROWSET_SIZE, + reinterpret_cast(rows), 0)); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + // Fetch row 1-3. + SQLULEN row_count; + SQLUSMALLINT row_status[rows]; + + ASSERT_EQ(SQL_SUCCESS, + SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count, row_status)); + EXPECT_EQ(3, row_count); + + for (int i = 0; i < rows; i++) { + EXPECT_EQ(SQL_SUCCESS, row_status[i]); + } + + // Verify 1 is returned for row 1 + EXPECT_EQ(1, val[0]); + // Verify 2 is returned for row 2 + EXPECT_EQ(2, val[1]); + // Verify 3 is returned for row 3 + EXPECT_EQ(3, val[2]); + + // Verify result set has no more data beyond row 3 + SQLULEN row_count2; + SQLUSMALLINT row_status2[rows]; + EXPECT_EQ(SQL_NO_DATA, + SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count2, row_status2)); +} + +TEST_F(StatementRemoteTest, DISABLED_TestSQLExtendedFetchQueryNullIndicator) { + // GH-47110: SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002 + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + SQLINTEGER val; + + ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + SQLULEN row_count1; + SQLUSMALLINT row_status1[1]; + + // SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002 state + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count1, row_status1)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState22002); +} + +TYPED_TEST(StatementTest, TestSQLMoreResultsNoData) { + // Verify SQLMoreResults is stubbed to return SQL_NO_DATA + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_NO_DATA, SQLMoreResults(this->stmt)); +} + +TYPED_TEST(StatementTest, TestSQLMoreResultsInvalidFunctionSequence) { + // Verify function sequence error state is reported when SQLMoreResults is called + // without executing any queries + ASSERT_EQ(SQL_ERROR, SQLMoreResults(this->stmt)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) { + SQLWCHAR buf[1024]; + SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + std::wstring expected_string = std::wstring(input_str); + + ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, input_char_len, buf, + buf_char_len, &output_char_len)); + + EXPECT_EQ(input_char_len, output_char_len); + + // returned length is in characters + std::wstring returned_string(buf, buf + output_char_len); + + EXPECT_EQ(expected_string, returned_string); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsNTSInputString) { + SQLWCHAR buf[1024]; + SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + std::wstring expected_string = std::wstring(input_str); + + ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, SQL_NTS, buf, buf_char_len, + &output_char_len)); + + EXPECT_EQ(input_char_len, output_char_len); + + // returned length is in characters + std::wstring returned_string(buf, buf + output_char_len); + + EXPECT_EQ(expected_string, returned_string); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputStringLength) { + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + std::wstring expected_string = std::wstring(input_str); + + ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, input_char_len, nullptr, 0, + &output_char_len)); + + EXPECT_EQ(input_char_len, output_char_len); + + ASSERT_EQ(SQL_SUCCESS, + SQLNativeSql(this->conn, input_str, SQL_NTS, nullptr, 0, &output_char_len)); + + EXPECT_EQ(input_char_len, output_char_len); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) { + const SQLINTEGER small_buf_size_in_char = 11; + SQLWCHAR small_buf[small_buf_size_in_char]; + SQLINTEGER small_buf_char_len = sizeof(small_buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + + // Create expected return string based on buf size + SQLWCHAR expected_string_buf[small_buf_size_in_char]; + wcsncpy(expected_string_buf, input_str, 10); + expected_string_buf[10] = L'\0'; + std::wstring expected_string(expected_string_buf, + expected_string_buf + small_buf_size_in_char); + + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLNativeSql(this->conn, input_str, input_char_len, small_buf, + small_buf_char_len, &output_char_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorState01004); + + // Returned text length represents full string char length regardless of truncation + EXPECT_EQ(input_char_len, output_char_len); + + std::wstring returned_string(small_buf, small_buf + small_buf_char_len); + + EXPECT_EQ(expected_string, returned_string); +} + +TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) { + SQLWCHAR buf[1024]; + SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER input_char_len = static_cast(wcslen(input_str)); + SQLINTEGER output_char_len = 0; + + ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, nullptr, input_char_len, buf, + buf_char_len, &output_char_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY009); + + ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, nullptr, SQL_NTS, buf, buf_char_len, + &output_char_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY009); + + ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, input_str, -100, buf, buf_char_len, + &output_char_len)); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY090); +} + +TYPED_TEST(StatementTest, SQLNumResultColsReturnsColumnsOnSelect) { + SQLSMALLINT column_count = 0; + SQLSMALLINT expected_value = 3; + SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ASSERT_EQ(SQL_SUCCESS, SQLNumResultCols(this->stmt, &column_count)); + + EXPECT_EQ(expected_value, column_count); +} + +TYPED_TEST(StatementTest, SQLNumResultColsReturnsSuccessOnNullptr) { + SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ASSERT_EQ(SQL_SUCCESS, SQLNumResultCols(this->stmt, nullptr)); +} + +TYPED_TEST(StatementTest, SQLNumResultColsFunctionSequenceErrorOnNoQuery) { + SQLSMALLINT column_count = 0; + SQLSMALLINT expected_value = 0; + + ASSERT_EQ(SQL_ERROR, SQLNumResultCols(this->stmt, &column_count)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010); + + EXPECT_EQ(expected_value, column_count); +} + +TYPED_TEST(StatementTest, SQLRowCountReturnsNegativeOneOnSelect) { + SQLLEN row_count = 0; + SQLLEN expected_value = -1; + SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ASSERT_EQ(SQL_SUCCESS, SQLRowCount(this->stmt, &row_count)); + + EXPECT_EQ(expected_value, row_count); +} + +TYPED_TEST(StatementTest, SQLRowCountReturnsSuccessOnNullptr) { + SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ASSERT_EQ(SQL_SUCCESS, SQLRowCount(this->stmt, 0)); +} + +TYPED_TEST(StatementTest, SQLRowCountFunctionSequenceErrorOnNoQuery) { + SQLLEN row_count = 0; + SQLLEN expected_value = 0; + + ASSERT_EQ(SQL_ERROR, SQLRowCount(this->stmt, &row_count)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010); + + EXPECT_EQ(expected_value, row_count); +} + +TYPED_TEST(StatementTest, TestSQLFreeStmtSQLClose) { + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLFreeStmt(this->stmt, SQL_CLOSE)); +} + +TYPED_TEST(StatementTest, TestSQLCloseCursor) { + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + ASSERT_EQ(SQL_SUCCESS, SQLCloseCursor(this->stmt)); +} + +TYPED_TEST(StatementTest, TestSQLFreeStmtSQLCloseWithoutCursor) { + // SQLFreeStmt(SQL_CLOSE) does not throw error with invalid cursor + + ASSERT_EQ(SQL_SUCCESS, SQLFreeStmt(this->stmt, SQL_CLOSE)); +} + +TYPED_TEST(StatementTest, TestSQLCloseCursorWithoutCursor) { + ASSERT_EQ(SQL_ERROR, SQLCloseCursor(this->stmt)); + + // Verify invalid cursor error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc new file mode 100644 index 00000000000..91d079eb844 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc @@ -0,0 +1,602 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +namespace arrow::flight::sql::odbc { + +template +class TablesTest : public T {}; + +class TablesMockTest : public FlightSQLODBCMockTestBase {}; +class TablesRemoteTest : public FlightSQLODBCRemoteTestBase {}; +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(TablesTest, TestTypes); + +template +class TablesOdbcV2Test : public T {}; + +using TestTypesOdbcV2 = + ::testing::Types; +TYPED_TEST_SUITE(TablesOdbcV2Test, TestTypesOdbcV2); + +namespace { +// Helper Functions + +std::wstring GetStringColumnW(SQLHSTMT stmt, int colId) { + SQLWCHAR buf[1024]; + SQLLEN len_indicator = 0; + + EXPECT_EQ(SQL_SUCCESS, + SQLGetData(stmt, colId, SQL_C_WCHAR, buf, sizeof(buf), &len_indicator)); + + if (len_indicator == SQL_NULL_DATA) { + return L""; + } + + // indicator is in bytes, so convert to character count + size_t char_count = static_cast(len_indicator) / ODBC::GetSqlWCharSize(); + return std::wstring(buf, buf + char_count); +} +} // namespace + +// Test Cases + +TYPED_TEST(TablesTest, SQLTablesTestInputData) { + SQLWCHAR catalog_name[] = L""; + SQLWCHAR schema_name[] = L""; + SQLWCHAR table_name[] = L""; + SQLWCHAR table_type[] = L""; + + // All values populated + EXPECT_EQ(SQL_SUCCESS, SQLTables(this->stmt, catalog_name, sizeof(catalog_name), + schema_name, sizeof(schema_name), table_name, + sizeof(table_name), table_type, sizeof(table_type))); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Sizes are nulls + EXPECT_EQ(SQL_SUCCESS, SQLTables(this->stmt, catalog_name, 0, schema_name, 0, + table_name, 0, table_type, 0)); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Values are nulls + EXPECT_EQ(SQL_SUCCESS, + SQLTables(this->stmt, 0, sizeof(catalog_name), 0, sizeof(schema_name), 0, + sizeof(table_name), 0, sizeof(table_type))); + + ValidateFetch(this->stmt, SQL_SUCCESS); + // Close statement cursor to avoid leaving in an invalid state + SQLFreeStmt(this->stmt, SQL_CLOSE); + + // All values and sizes are nulls + EXPECT_EQ(SQL_SUCCESS, SQLTables(this->stmt, 0, 0, 0, 0, 0, 0, 0, 0)); + + ValidateFetch(this->stmt, SQL_SUCCESS); +} + +TEST_F(TablesMockTest, SQLTablesTestGetMetadataForAllCatalogs) { + SQLWCHAR empty[] = L""; + SQLWCHAR SQL_ALL_CATALOGS_W[] = L"%"; + std::wstring expected_catalog_name = std::wstring(L"main"); + + // Get Catalog metadata + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, SQL_ALL_CATALOGS_W, SQL_NTS, empty, + SQL_NTS, empty, SQL_NTS, empty, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expected_catalog_name); + CheckNullColumnW(this->stmt, 2); + CheckNullColumnW(this->stmt, 3); + CheckNullColumnW(this->stmt, 4); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesMockTest, SQLTablesTestGetMetadataForNamedCatalog) { + this->CreateTestTables(); + + SQLWCHAR catalog_name[] = L"main"; + const SQLWCHAR* table_names[] = {static_cast(L"TestTable"), + static_cast(L"foreignTable"), + static_cast(L"intTable"), + static_cast(L"sqlite_sequence")}; + std::wstring expected_catalog_name = std::wstring(catalog_name); + std::wstring expected_table_type = std::wstring(L"table"); + + // Get named Catalog metadata - Mock server returns the system table sqlite_sequence as + // type "table" + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, catalog_name, SQL_NTS, nullptr, SQL_NTS, + nullptr, SQL_NTS, nullptr, SQL_NTS)); + + for (size_t i = 0; i < sizeof(table_names) / sizeof(*table_names); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expected_catalog_name); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, table_names[i]); + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesMockTest, SQLTablesTestGetSchemaHasNoData) { + SQLWCHAR SQL_ALL_SCHEMAS_W[] = L"%"; + + // Validate that no schema data is available for Mock server + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, SQL_ALL_SCHEMAS_W, + SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesRemoteTest, SQLTablesTestGetMetadataForAllSchemas) { + SQLWCHAR empty[] = L""; + SQLWCHAR SQL_ALL_SCHEMAS_W[] = L"%"; + std::set actual_schemas; + std::set expected_schemas = {L"$scratch", L"INFORMATION_SCHEMA", L"sys", + L"sys.cache"}; + + // Return is unordered and contains user specific schemas, so collect schema names for + // comparison with a known list + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, empty, SQL_NTS, SQL_ALL_SCHEMAS_W, SQL_NTS, + empty, SQL_NTS, empty, SQL_NTS)); + + while (true) { + SQLRETURN ret = SQLFetch(this->stmt); + if (ret == SQL_NO_DATA) break; + ASSERT_EQ(SQL_SUCCESS, ret); + + CheckNullColumnW(this->stmt, 1); + std::wstring schema = GetStringColumnW(this->stmt, 2); + CheckNullColumnW(this->stmt, 3); + CheckNullColumnW(this->stmt, 4); + CheckNullColumnW(this->stmt, 5); + + // Skip user-specific schemas like "@UserName" + if (!schema.empty() && schema[0] != L'@') { + actual_schemas.insert(schema); + } + } + + EXPECT_EQ(actual_schemas, expected_schemas); +} + +TEST_F(TablesRemoteTest, SQLTablesTestFilterByAllSchema) { + // Requires creation of user table named ODBCTest using schema $scratch in remote server + SQLWCHAR SQL_ALL_SCHEMAS_W[] = L"%"; + const SQLWCHAR* schema_names[] = {static_cast(L"INFORMATION_SCHEMA"), + static_cast(L"INFORMATION_SCHEMA"), + static_cast(L"INFORMATION_SCHEMA"), + static_cast(L"INFORMATION_SCHEMA"), + static_cast(L"INFORMATION_SCHEMA"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys"), + static_cast(L"sys.cache"), + static_cast(L"sys.cache"), + static_cast(L"sys.cache"), + static_cast(L"sys.cache"), + static_cast(L"$scratch")}; + std::wstring expected_system_table_type = std::wstring(L"SYSTEM_TABLE"); + std::wstring expected_user_table_type = std::wstring(L"TABLE"); + + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, SQL_ALL_SCHEMAS_W, + SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS)); + + for (size_t i = 0; i < sizeof(schema_names) / sizeof(*schema_names); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + const std::wstring& expected_table_type = + (std::wstring(schema_names[i]).rfind(L"sys", 0) == 0 || + std::wstring(schema_names[i]) == L"INFORMATION_SCHEMA") + ? expected_system_table_type + : expected_user_table_type; + + CheckNullColumnW(this->stmt, 1); + CheckStringColumnW(this->stmt, 2, schema_names[i]); + // Ignore table name + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesRemoteTest, SQLTablesGetMetadataForNamedSchema) { + // Requires creation of user table named ODBCTest using schema $scratch in remote server + SQLWCHAR schema_name[] = L"$scratch"; + std::wstring expected_schema_name = std::wstring(schema_name); + std::wstring expected_table_name = std::wstring(L"ODBCTest"); + std::wstring expected_table_type = std::wstring(L"TABLE"); + + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, schema_name, SQL_NTS, + nullptr, SQL_NTS, nullptr, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + CheckStringColumnW(this->stmt, 2, expected_schema_name); + // Ignore table name + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesMockTest, SQLTablesTestGetMetadataForAllTables) { + this->CreateTestTables(); + + SQLWCHAR SQL_ALL_TABLES_W[] = L"%"; + const SQLWCHAR* table_names[] = {static_cast(L"TestTable"), + static_cast(L"foreignTable"), + static_cast(L"intTable"), + static_cast(L"sqlite_sequence")}; + std::wstring expected_catalog_name = std::wstring(L"main"); + std::wstring expected_table_type = std::wstring(L"table"); + + // Get all Table metadata - Mock server returns the system table sqlite_sequence as type + // "table" + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + SQL_ALL_TABLES_W, SQL_NTS, nullptr, SQL_NTS)); + + for (size_t i = 0; i < sizeof(table_names) / sizeof(*table_names); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expected_catalog_name); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, table_names[i]); + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesMockTest, SQLTablesTestGetMetadataForTableName) { + this->CreateTestTables(); + + // Use mutable arrays to pass SQLWCHAR parameters to SQLTables + SQLWCHAR test_table[] = L"TestTable"; + SQLWCHAR foreign_table[] = L"foreignTable"; + SQLWCHAR int_table[] = L"intTable"; + SQLWCHAR sqlite_sequence[] = L"sqlite_sequence"; + + SQLWCHAR* table_names[] = {test_table, foreign_table, int_table, sqlite_sequence}; + + std::wstring expected_catalog_name = std::wstring(L"main"); + std::wstring expected_table_type = std::wstring(L"table"); + + for (size_t i = 0; i < sizeof(table_names) / sizeof(*table_names); ++i) { + // Get specific Table metadata + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_names[i], SQL_NTS, nullptr, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expected_catalog_name); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, table_names[i]); + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); + } +} + +TEST_F(TablesMockTest, SQLTablesTestGetMetadataForUnicodeTableByTableName) { + this->CreateUnicodeTable(); + + SQLWCHAR unicodetable_name[] = L"数据"; + std::wstring expected_catalog_name = std::wstring(L"main"); + std::wstring expected_table_name = std::wstring(unicodetable_name); + std::wstring expected_table_type = std::wstring(L"table"); + + // Get specific Table metadata + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + unicodetable_name, SQL_NTS, nullptr, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expected_catalog_name); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, expected_table_name); + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesMockTest, SQLTablesTestGetMetadataForInvalidTableNameNoData) { + this->CreateTestTables(); + + SQLWCHAR invalid_table_name[] = L"NonExistanttable_name"; + + // Try to get metadata for a non-existant table name + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + invalid_table_name, SQL_NTS, nullptr, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesMockTest, SQLTablesGetMetadataForTableType) { + // Mock server only supports table type "table" in lowercase + this->CreateTestTables(); + + SQLWCHAR table_type_table_lowercase[] = L"table"; + SQLWCHAR table_type_table_uppercase[] = L"TABLE"; + SQLWCHAR table_type_view[] = L"VIEW"; + SQLWCHAR table_type_table_view[] = L"TABLE,VIEW"; + const SQLWCHAR* table_names[] = {static_cast(L"TestTable"), + static_cast(L"foreignTable"), + static_cast(L"intTable"), + static_cast(L"sqlite_sequence")}; + std::wstring expected_catalog_name = std::wstring(L"main"); + std::wstring expected_table_name = std::wstring(L"TestTable"); + std::wstring expected_table_type = std::wstring(table_type_table_lowercase); + + EXPECT_EQ(SQL_SUCCESS, + SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_type_table_uppercase, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + EXPECT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + nullptr, SQL_NTS, table_type_view, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + EXPECT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + nullptr, SQL_NTS, table_type_table_view, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Returns user table as well as system tables, even though only type table requested + EXPECT_EQ(SQL_SUCCESS, + SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS, + table_type_table_lowercase, SQL_NTS)); + + for (size_t i = 0; i < sizeof(table_names) / sizeof(*table_names); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expected_catalog_name); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, table_names[i]); + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesRemoteTest, SQLTablesGetMetadataForTableTypeTable) { + // Requires creation of user table named ODBCTest using schema $scratch in remote server + + // Use mutable arrays to pass SQLWCHAR parameters to SQLTables + SQLWCHAR table[] = L"TABLE"; + SQLWCHAR table_view[] = L"TABLE,VIEW"; + + SQLWCHAR* type_list[] = {table, table_view}; + + std::wstring expected_schema_name = std::wstring(L"$scratch"); + std::wstring expected_table_name = std::wstring(L"ODBCTest"); + std::wstring expected_table_type = std::wstring(L"TABLE"); + + for (size_t i = 0; i < sizeof(type_list) / sizeof(*type_list); ++i) { + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + nullptr, SQL_NTS, type_list[i], SQL_NTS)); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + CheckStringColumnW(this->stmt, 2, expected_schema_name); + CheckStringColumnW(this->stmt, 3, expected_table_name); + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); + } +} + +TEST_F(TablesRemoteTest, SQLTablesGetMetadataForTableTypeViewHasNoData) { + SQLWCHAR empty[] = L""; + SQLWCHAR type_view[] = L"VIEW"; + + EXPECT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, empty, + SQL_NTS, type_view, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + EXPECT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + nullptr, SQL_NTS, type_view, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesMockTest, SQLTablesGetSupportedTableTypes) { + SQLWCHAR empty[] = L""; + SQLWCHAR SQL_ALL_TABLE_TYPES_W[] = L"%"; + std::wstring expected_table_type = std::wstring(L"table"); + + // Mock server returns lower case for supported type of "table" + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, empty, SQL_NTS, empty, SQL_NTS, empty, + SQL_NTS, SQL_ALL_TABLE_TYPES_W, SQL_NTS)); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + CheckNullColumnW(this->stmt, 2); + CheckNullColumnW(this->stmt, 3); + CheckStringColumnW(this->stmt, 4, expected_table_type); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TEST_F(TablesRemoteTest, SQLTablesGetSupportedTableTypes) { + SQLWCHAR empty[] = L""; + SQLWCHAR SQL_ALL_TABLE_TYPES_W[] = L"%"; + const SQLWCHAR* type_lists[] = {static_cast(L"TABLE"), + static_cast(L"SYSTEM_TABLE"), + static_cast(L"VIEW")}; + + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, empty, SQL_NTS, empty, SQL_NTS, empty, + SQL_NTS, SQL_ALL_TABLE_TYPES_W, SQL_NTS)); + + for (size_t i = 0; i < sizeof(type_lists) / sizeof(*type_lists); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + CheckNullColumnW(this->stmt, 2); + CheckNullColumnW(this->stmt, 3); + CheckStringColumnW(this->stmt, 4, type_lists[i]); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); +} + +TYPED_TEST(TablesTest, SQLTablesGetMetadataBySQLDescribeCol) { + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + const SQLWCHAR* column_names[] = {static_cast(L"TABLE_CAT"), + static_cast(L"TABLE_SCHEM"), + static_cast(L"TABLE_NAME"), + static_cast(L"TABLE_TYPE"), + static_cast(L"REMARKS")}; + SQLSMALLINT column_data_types[] = {SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR}; + SQLULEN column_sizes[] = {1024, 1024, 1024, 1024, 1024}; + + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + nullptr, SQL_NTS, nullptr, SQL_NTS)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(column_sizes[i], column_size); + EXPECT_EQ(0, decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} + +TYPED_TEST(TablesOdbcV2Test, SQLTablesGetMetadataBySQLDescribeColODBC2) { + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + size_t column_index = 0; + + const SQLWCHAR* column_names[] = {static_cast(L"TABLE_QUALIFIER"), + static_cast(L"TABLE_OWNER"), + static_cast(L"TABLE_NAME"), + static_cast(L"TABLE_TYPE"), + static_cast(L"REMARKS")}; + SQLSMALLINT column_data_types[] = {SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR}; + SQLULEN column_sizes[] = {1024, 1024, 1024, 1024, 1024}; + + ASSERT_EQ(SQL_SUCCESS, SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + nullptr, SQL_NTS, nullptr, SQL_NTS)); + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + column_index = i + 1; + + ASSERT_EQ(SQL_SUCCESS, SQLDescribeCol(this->stmt, column_index, column_name, + buf_char_len, &name_length, &column_data_type, + &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(wcslen(column_names[i]), name_length); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(column_names[i], returned); + EXPECT_EQ(column_data_types[i], column_data_type); + EXPECT_EQ(column_sizes[i], column_size); + EXPECT_EQ(0, decimal_digits); + EXPECT_EQ(SQL_NULLABLE, nullable); + + name_length = 0; + column_data_type = 0; + column_size = 0; + decimal_digits = 0; + nullable = 0; + } +} +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/type_info_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/type_info_test.cc new file mode 100644 index 00000000000..b665477b5c3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/type_info_test.cc @@ -0,0 +1,1897 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +namespace arrow::flight::sql::odbc { + +using std::optional; + +template +class TypeInfoTest : public T {}; + +class TypeInfoMockTest : public FlightSQLODBCMockTestBase {}; +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(TypeInfoTest, TestTypes); + +class TypeInfoOdbcV2MockTest : public FlightSQLOdbcV2MockTestBase {}; + +namespace { +// Helper Functions + +void CheckSQLDescribeCol(SQLHSTMT stmt, const SQLUSMALLINT column_index, + const std::wstring& expected_name, + const SQLSMALLINT& expected_data_type, + const SQLULEN& expected_column_size, + const SQLSMALLINT& expected_decimal_digits, + const SQLSMALLINT& expected_nullable) { + SQLWCHAR column_name[1024]; + SQLSMALLINT buf_char_len = + static_cast(sizeof(column_name) / ODBC::GetSqlWCharSize()); + SQLSMALLINT name_length = 0; + SQLSMALLINT column_data_type = 0; + SQLULEN column_size = 0; + SQLSMALLINT decimal_digits = 0; + SQLSMALLINT nullable = 0; + + ASSERT_EQ(SQL_SUCCESS, + SQLDescribeCol(stmt, column_index, column_name, buf_char_len, &name_length, + &column_data_type, &column_size, &decimal_digits, &nullable)); + + EXPECT_EQ(name_length, expected_name.size()); + + std::wstring returned(column_name, column_name + name_length); + EXPECT_EQ(expected_name, returned); + EXPECT_EQ(expected_data_type, column_data_type); + EXPECT_EQ(expected_column_size, column_size); + EXPECT_EQ(expected_decimal_digits, decimal_digits); + EXPECT_EQ(expected_nullable, nullable); +} + +void CheckSQLDescribeColODBC2(SQLHSTMT stmt) { + const SQLWCHAR* column_names[] = {static_cast(L"TYPE_NAME"), + static_cast(L"DATA_TYPE"), + static_cast(L"PRECISION"), + static_cast(L"LITERAL_PREFIX"), + static_cast(L"LITERAL_SUFFIX"), + static_cast(L"CREATE_PARAMS"), + static_cast(L"NULLABLE"), + static_cast(L"CASE_SENSITIVE"), + static_cast(L"SEARCHABLE"), + static_cast(L"UNSIGNED_ATTRIBUTE"), + static_cast(L"MONEY"), + static_cast(L"AUTO_INCREMENT"), + static_cast(L"LOCAL_TYPE_NAME"), + static_cast(L"MINIMUM_SCALE"), + static_cast(L"MAXIMUM_SCALE"), + static_cast(L"SQL_DATA_TYPE"), + static_cast(L"SQL_DATETIME_SUB"), + static_cast(L"NUM_PREC_RADIX"), + static_cast(L"INTERVAL_PRECISION")}; + SQLSMALLINT column_data_types[] = { + SQL_WVARCHAR, SQL_SMALLINT, SQL_INTEGER, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, + SQL_SMALLINT, SQL_SMALLINT, SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, + SQL_SMALLINT, SQL_SMALLINT, SQL_INTEGER, SQL_SMALLINT}; + SQLULEN column_sizes[] = {1024, 2, 4, 1024, 1024, 1024, 2, 2, 2, 2, + 2, 2, 1024, 2, 2, 2, 2, 4, 2}; + SQLSMALLINT column_decimal_digits[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + SQLSMALLINT column_nullable[] = {SQL_NO_NULLS, SQL_NO_NULLS, SQL_NULLABLE, SQL_NULLABLE, + SQL_NULLABLE, SQL_NULLABLE, SQL_NO_NULLS, SQL_NO_NULLS, + SQL_NO_NULLS, SQL_NULLABLE, SQL_NO_NULLS, SQL_NULLABLE, + SQL_NULLABLE, SQL_NULLABLE, SQL_NULLABLE, SQL_NO_NULLS, + SQL_NULLABLE, SQL_NULLABLE, SQL_NULLABLE}; + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + SQLUSMALLINT column_index = i + 1; + CheckSQLDescribeCol(stmt, column_index, column_names[i], column_data_types[i], + column_sizes[i], column_decimal_digits[i], column_nullable[i]); + } +} + +void CheckSQLDescribeColODBC3(SQLHSTMT stmt) { + const SQLWCHAR* column_names[] = {static_cast(L"TYPE_NAME"), + static_cast(L"DATA_TYPE"), + static_cast(L"COLUMN_SIZE"), + static_cast(L"LITERAL_PREFIX"), + static_cast(L"LITERAL_SUFFIX"), + static_cast(L"CREATE_PARAMS"), + static_cast(L"NULLABLE"), + static_cast(L"CASE_SENSITIVE"), + static_cast(L"SEARCHABLE"), + static_cast(L"UNSIGNED_ATTRIBUTE"), + static_cast(L"FIXED_PREC_SCALE"), + static_cast(L"AUTO_UNIQUE_VALUE"), + static_cast(L"LOCAL_TYPE_NAME"), + static_cast(L"MINIMUM_SCALE"), + static_cast(L"MAXIMUM_SCALE"), + static_cast(L"SQL_DATA_TYPE"), + static_cast(L"SQL_DATETIME_SUB"), + static_cast(L"NUM_PREC_RADIX"), + static_cast(L"INTERVAL_PRECISION")}; + SQLSMALLINT column_data_types[] = { + SQL_WVARCHAR, SQL_SMALLINT, SQL_INTEGER, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, + SQL_SMALLINT, SQL_SMALLINT, SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, + SQL_SMALLINT, SQL_SMALLINT, SQL_INTEGER, SQL_SMALLINT}; + SQLULEN column_sizes[] = {1024, 2, 4, 1024, 1024, 1024, 2, 2, 2, 2, + 2, 2, 1024, 2, 2, 2, 2, 4, 2}; + SQLSMALLINT column_decimal_digits[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + SQLSMALLINT column_nullable[] = {SQL_NO_NULLS, SQL_NO_NULLS, SQL_NULLABLE, SQL_NULLABLE, + SQL_NULLABLE, SQL_NULLABLE, SQL_NO_NULLS, SQL_NO_NULLS, + SQL_NO_NULLS, SQL_NULLABLE, SQL_NO_NULLS, SQL_NULLABLE, + SQL_NULLABLE, SQL_NULLABLE, SQL_NULLABLE, SQL_NO_NULLS, + SQL_NULLABLE, SQL_NULLABLE, SQL_NULLABLE}; + + for (size_t i = 0; i < sizeof(column_names) / sizeof(*column_names); ++i) { + SQLUSMALLINT column_index = i + 1; + CheckSQLDescribeCol(stmt, column_index, column_names[i], column_data_types[i], + column_sizes[i], column_decimal_digits[i], column_nullable[i]); + } +} + +void CheckSQLGetTypeInfo( + SQLHSTMT stmt, const std::wstring& expected_type_name, + const SQLSMALLINT& expected_data_type, const SQLINTEGER& expected_column_size, + const optional& expected_literal_prefix, + const optional& expected_literal_suffix, + const optional& expected_create_params, + const SQLSMALLINT& expected_nullable, const SQLSMALLINT& expected_case_sensitive, + const SQLSMALLINT& expected_searchable, const SQLSMALLINT& expected_unsigned_attr, + const SQLSMALLINT& expected_fixed_prec_scale, + const SQLSMALLINT& expected_auto_unique_value, + const std::wstring& expected_local_type_name, const SQLSMALLINT& expected_min_scale, + const SQLSMALLINT& expected_max_scale, const SQLSMALLINT& expected_sql_data_type, + const SQLSMALLINT& expected_sql_datetime_sub, + const SQLINTEGER& expected_num_prec_radix, const SQLINTEGER& expected_interval_prec) { + CheckStringColumnW(stmt, 1, expected_type_name); // type name + CheckSmallIntColumn(stmt, 2, expected_data_type); // data type + CheckIntColumn(stmt, 3, expected_column_size); // column size + + if (expected_literal_prefix) { // literal prefix + CheckStringColumnW(stmt, 4, *expected_literal_prefix); + } else { + CheckNullColumnW(stmt, 4); + } + + if (expected_literal_suffix) { // literal suffix + CheckStringColumnW(stmt, 5, *expected_literal_suffix); + } else { + CheckNullColumnW(stmt, 5); + } + + if (expected_create_params) { // create params + CheckStringColumnW(stmt, 6, *expected_create_params); + } else { + CheckNullColumnW(stmt, 6); + } + + CheckSmallIntColumn(stmt, 7, expected_nullable); // nullable + CheckSmallIntColumn(stmt, 8, expected_case_sensitive); // case sensitive + CheckSmallIntColumn(stmt, 9, expected_searchable); // searchable + CheckSmallIntColumn(stmt, 10, expected_unsigned_attr); // unsigned attr + CheckSmallIntColumn(stmt, 11, expected_fixed_prec_scale); // fixed prec scale + CheckSmallIntColumn(stmt, 12, expected_auto_unique_value); // auto unique value + CheckStringColumnW(stmt, 13, expected_local_type_name); // local type name + CheckSmallIntColumn(stmt, 14, expected_min_scale); // min scale + CheckSmallIntColumn(stmt, 15, expected_max_scale); // max scale + CheckSmallIntColumn(stmt, 16, expected_sql_data_type); // sql data type + CheckSmallIntColumn(stmt, 17, expected_sql_datetime_sub); // sql datetime sub + CheckIntColumn(stmt, 18, expected_num_prec_radix); // num prec radix + CheckIntColumn(stmt, 19, expected_interval_prec); // interval prec +} +} // namespace + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoAllTypes) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_ALL_TYPES)); + + // Check bit data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"bit"), // expected_type_name + SQL_BIT, // expected_data_type + 1, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"bit"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_BIT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check tinyint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"tinyint"), // expected_type_name + SQL_TINYINT, // expected_data_type + 3, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"tinyint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_TINYINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check bigint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"bigint"), // expected_type_name + SQL_BIGINT, // expected_data_type + 19, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"bigint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check longvarbinary data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarbinary"), // expected_type_name + SQL_LONGVARBINARY, // expected_data_type + 65536, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"longvarbinary"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_LONGVARBINARY, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check varbinary data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"varbinary"), // expected_type_name + SQL_VARBINARY, // expected_data_type + 255, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"varbinary"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_VARBINARY, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check text data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WLONGVARCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"text"), // expected_type_name + SQL_WLONGVARCHAR, // expected_data_type + 65536, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"text"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WLONGVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check longvarchar data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarchar"), // expected_type_name + SQL_WLONGVARCHAR, // expected_data_type + 65536, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"longvarchar"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WLONGVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check char data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"char"), // expected_type_name + SQL_WCHAR, // expected_data_type + 255, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"char"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check integer data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"integer"), // expected_type_name + SQL_INTEGER, // expected_data_type + 9, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"integer"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_INTEGER, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check smallint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"smallint"), // expected_type_name + SQL_SMALLINT, // expected_data_type + 5, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"smallint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_SMALLINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check float data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"float"), // expected_type_name + SQL_FLOAT, // expected_data_type + 7, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"float"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_FLOAT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check double data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"double"), // expected_type_name + SQL_DOUBLE, // expected_data_type + 15, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"double"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check numeric data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Mock server treats numeric data type as a double type + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"numeric"), // expected_type_name + SQL_DOUBLE, // expected_data_type + 15, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"numeric"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check varchar data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WVARCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"varchar"), // expected_type_name + SQL_WVARCHAR, // expected_data_type + 255, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"varchar"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check date data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expected_type_name + SQL_TYPE_DATE, // expected_data_type + 10, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"date"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_DATE, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check time data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expected_type_name + SQL_TYPE_TIME, // expected_data_type + 8, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"time"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIME, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check timestamp data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expected_type_name + SQL_TYPE_TIMESTAMP, // expected_data_type + 32, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"timestamp"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIMESTAMP, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); +} + +TEST_F(TypeInfoOdbcV2MockTest, TestSQLGetTypeInfoAllTypes) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_ALL_TYPES)); + + // Check bit data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"bit"), // expected_type_name + SQL_BIT, // expected_data_type + 1, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"bit"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_BIT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check tinyint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"tinyint"), // expected_type_name + SQL_TINYINT, // expected_data_type + 3, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"tinyint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_TINYINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check bigint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"bigint"), // expected_type_name + SQL_BIGINT, // expected_data_type + 19, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"bigint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check longvarbinary data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarbinary"), // expected_type_name + SQL_LONGVARBINARY, // expected_data_type + 65536, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"longvarbinary"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_LONGVARBINARY, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check varbinary data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"varbinary"), // expected_type_name + SQL_VARBINARY, // expected_data_type + 255, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"varbinary"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_VARBINARY, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check text data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WLONGVARCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"text"), // expected_type_name + SQL_WLONGVARCHAR, // expected_data_type + 65536, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"text"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WLONGVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check longvarchar data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarchar"), // expected_type_name + SQL_WLONGVARCHAR, // expected_data_type + 65536, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"longvarchar"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WLONGVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check char data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"char"), // expected_type_name + SQL_WCHAR, // expected_data_type + 255, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"char"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check integer data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"integer"), // expected_type_name + SQL_INTEGER, // expected_data_type + 9, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"integer"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_INTEGER, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check smallint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"smallint"), // expected_type_name + SQL_SMALLINT, // expected_data_type + 5, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"smallint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_SMALLINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check float data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"float"), // expected_type_name + SQL_FLOAT, // expected_data_type + 7, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"float"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_FLOAT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check double data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"double"), // expected_type_name + SQL_DOUBLE, // expected_data_type + 15, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"double"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check numeric data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Mock server treats numeric data type as a double type + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"numeric"), // expected_type_name + SQL_DOUBLE, // expected_data_type + 15, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"numeric"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check varchar data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WVARCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"varchar"), // expected_type_name + SQL_WVARCHAR, // expected_data_type + 255, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"varchar"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check date data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expected_type_name + SQL_DATE, // expected_data_type + 10, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"date"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + NULL, // expected_sql_datetime_sub, driver returns NULL for Ver2 + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check time data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expected_type_name + SQL_TIME, // expected_data_type + 8, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"time"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + NULL, // expected_sql_datetime_sub, driver returns NULL for Ver2 + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // Check timestamp data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expected_type_name + SQL_TIMESTAMP, // expected_data_type + 32, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"timestamp"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + NULL, // expected_sql_datetime_sub, driver returns NULL for Ver2 + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoBit) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_BIT)); + + // Check bit data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"bit"), // expected_type_name + SQL_BIT, // expected_data_type + 1, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"bit"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_BIT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoTinyInt) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_TINYINT)); + + // Check tinyint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"tinyint"), // expected_type_name + SQL_TINYINT, // expected_data_type + 3, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"tinyint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_TINYINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoBigInt) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_BIGINT)); + + // Check bigint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"bigint"), // expected_type_name + SQL_BIGINT, // expected_data_type + 19, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"bigint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_BIGINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoLongVarbinary) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_LONGVARBINARY)); + + // Check longvarbinary data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarbinary"), // expected_type_name + SQL_LONGVARBINARY, // expected_data_type + 65536, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"longvarbinary"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_LONGVARBINARY, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoVarbinary) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_VARBINARY)); + + // Check varbinary data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"varbinary"), // expected_type_name + SQL_VARBINARY, // expected_data_type + 255, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"varbinary"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_VARBINARY, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoLongVarchar) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_WLONGVARCHAR)); + + // Check text data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WLONGVARCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"text"), // expected_type_name + SQL_WLONGVARCHAR, // expected_data_type + 65536, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"text"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WLONGVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check longvarchar data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarchar"), // expected_type_name + SQL_WLONGVARCHAR, // expected_data_type + 65536, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"longvarchar"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WLONGVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoChar) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_WCHAR)); + + // Check char data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"char"), // expected_type_name + SQL_WCHAR, // expected_data_type + 255, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + NULL, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"char"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoInteger) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_INTEGER)); + + // Check integer data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"integer"), // expected_type_name + SQL_INTEGER, // expected_data_type + 9, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"integer"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_INTEGER, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoSmallInt) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_SMALLINT)); + + // Check smallint data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"smallint"), // expected_type_name + SQL_SMALLINT, // expected_data_type + 5, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"smallint"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_SMALLINT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoFloat) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_FLOAT)); + + // Check float data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"float"), // expected_type_name + SQL_FLOAT, // expected_data_type + 7, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"float"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_FLOAT, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoDouble) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_DOUBLE)); + + // Check double data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"double"), // expected_type_name + SQL_DOUBLE, // expected_data_type + 15, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"double"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // Check numeric data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Mock server treats numeric data type as a double type + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"numeric"), // expected_type_name + SQL_DOUBLE, // expected_data_type + 15, // expected_column_size + std::nullopt, // expected_literal_prefix + std::nullopt, // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"numeric"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DOUBLE, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoVarchar) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_WVARCHAR)); + + // Check varchar data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + // Driver returns SQL_WVARCHAR since unicode is enabled + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"varchar"), // expected_type_name + SQL_WVARCHAR, // expected_data_type + 255, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::wstring(L"length"), // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"varchar"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_WVARCHAR, // expected_sql_data_type + NULL, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoSQLTypeDate) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_TYPE_DATE)); + + // Check date data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expected_type_name + SQL_TYPE_DATE, // expected_data_type + 10, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"date"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_DATE, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoSQLDate) { + // Pass ODBC Ver 2 data type + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_DATE)); + + // Check date data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expected_type_name + SQL_TYPE_DATE, // expected_data_type + 10, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"date"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_DATE, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoOdbcV2MockTest, TestSQLGetTypeInfoDate) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_DATE)); + + // Check date data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expected_type_name + SQL_DATE, // expected_data_type + 10, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"date"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + NULL, // expected_sql_datetime_sub, driver returns NULL for Ver2 + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoOdbcV2MockTest, TestSQLGetTypeInfoSQLTypeDate) { + // Pass ODBC Ver 3 data type + ASSERT_EQ(SQL_ERROR, SQLGetTypeInfo(this->stmt, SQL_TYPE_DATE)); + + // Driver manager returns SQL data type out of range error state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateS1004); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoSQLTypeTime) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_TYPE_TIME)); + + // Check time data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expected_type_name + SQL_TYPE_TIME, // expected_data_type + 8, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"time"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIME, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoSQLTime) { + // Pass ODBC Ver 2 data type + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_TIME)); + + // Check time data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expected_type_name + SQL_TYPE_TIME, // expected_data_type + 8, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"time"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIME, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoOdbcV2MockTest, TestSQLGetTypeInfoTime) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_TIME)); + + // Check time data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expected_type_name + SQL_TIME, // expected_data_type + 8, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"time"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + NULL, // expected_sql_datetime_sub, driver returns NULL for Ver2 + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoOdbcV2MockTest, TestSQLGetTypeInfoSQLTypeTime) { + // Pass ODBC Ver 3 data type + ASSERT_EQ(SQL_ERROR, SQLGetTypeInfo(this->stmt, SQL_TYPE_TIME)); + + // Driver manager returns SQL data type out of range error state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateS1004); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoSQLTypeTimestamp) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_TYPE_TIMESTAMP)); + + // Check timestamp data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expected_type_name + SQL_TYPE_TIMESTAMP, // expected_data_type + 32, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"timestamp"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIMESTAMP, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoSQLTimestamp) { + // Pass ODBC Ver 2 data type + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_TIMESTAMP)); + + // Check timestamp data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expected_type_name + SQL_TYPE_TIMESTAMP, // expected_data_type + 32, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"timestamp"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + SQL_CODE_TIMESTAMP, // expected_sql_datetime_sub + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC3(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoOdbcV2MockTest, TestSQLGetTypeInfoSQLTimestamp) { + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_TIMESTAMP)); + + // Check timestamp data type + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expected_type_name + SQL_TIMESTAMP, // expected_data_type + 32, // expected_column_size + std::wstring(L"'"), // expected_literal_prefix + std::wstring(L"'"), // expected_literal_suffix + std::nullopt, // expected_create_params + SQL_NULLABLE, // expected_nullable + SQL_FALSE, // expected_case_sensitive + SQL_SEARCHABLE, // expected_searchable + SQL_FALSE, // expected_unsigned_attr + SQL_FALSE, // expected_fixed_prec_scale + NULL, // expected_auto_unique_value + std::wstring(L"timestamp"), // expected_local_type_name + NULL, // expected_min_scale + NULL, // expected_max_scale + SQL_DATETIME, // expected_sql_data_type + NULL, // expected_sql_datetime_sub, driver returns NULL for Ver2 + NULL, // expected_num_prec_radix + NULL); // expected_interval_prec + + CheckSQLDescribeColODBC2(this->stmt); + + // No more data + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +TEST_F(TypeInfoOdbcV2MockTest, TestSQLGetTypeInfoSQLTypeTimestamp) { + // Pass ODBC Ver 3 data type + ASSERT_EQ(SQL_ERROR, SQLGetTypeInfo(this->stmt, SQL_TYPE_TIMESTAMP)); + + // Driver manager returns SQL data type out of range error state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateS1004); +} + +TEST_F(TypeInfoMockTest, TestSQLGetTypeInfoInvalidDataType) { + SQLSMALLINT invalid_data_type = -114; + ASSERT_EQ(SQL_ERROR, SQLGetTypeInfo(this->stmt, invalid_data_type)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY004); +} + +TYPED_TEST(TypeInfoTest, TestSQLGetTypeInfoUnsupportedDataType) { + // Assumes mock and remote server don't support GUID data type + + ASSERT_EQ(SQL_SUCCESS, SQLGetTypeInfo(this->stmt, SQL_GUID)); + + // Result set is empty with valid data type that is unsupported by the server + ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt)); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/vendored/whereami/whereami.cc b/cpp/src/arrow/vendored/whereami/whereami.cc index 945226193f9..94437361ec0 100644 --- a/cpp/src/arrow/vendored/whereami/whereami.cc +++ b/cpp/src/arrow/vendored/whereami/whereami.cc @@ -159,7 +159,7 @@ WAI_NOINLINE WAI_FUNCSPEC int WAI_PREFIX(getModulePath)(char* out, int capacity, return length; } -#elif defined(__linux__) || defined(__CYGWIN__) || defined(__sun) || \ +#elif defined(__APPLE__) || defined(__linux__) || defined(__CYGWIN__) || defined(__sun) || \ defined(WAI_USE_PROC_SELF_EXE) # include diff --git a/cpp/vcpkg.json b/cpp/vcpkg.json index 41c40fcc85f..7e03a515a8f 100644 --- a/cpp/vcpkg.json +++ b/cpp/vcpkg.json @@ -21,10 +21,8 @@ "boost-filesystem", "boost-locale", "boost-multiprecision", - "boost-optional", "boost-process", "boost-system", - "boost-variant", "boost-xpressive", "brotli", "bzip2",