diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index 01780f0efe2..27379943f97 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -838,8 +838,33 @@ SQLRETURN SQLExtendedFetch(SQLHSTMT stmt, SQLUSMALLINT fetch_orientation, << ", row_count_ptr: " << static_cast(row_count_ptr) << ", row_status_array: " << static_cast(row_status_array); - // GH-47714 TODO: Implement SQLExtendedFetch - return SQL_INVALID_HANDLE; + + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + // Only SQL_FETCH_NEXT forward-only fetching orientation is supported, + // meaning the behavior of SQLExtendedFetch is same as SQLFetch. + if (fetch_orientation != SQL_FETCH_NEXT) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + // Ignore fetch_offset as it's not applicable to SQL_FETCH_NEXT + ARROW_UNUSED(fetch_offset); + + ODBCStatement* statement = reinterpret_cast(stmt); + + // The SQL_ROWSET_SIZE statement attribute specifies the number of rows in the + // rowset. Retrieve it from GetRowsetSize. + SQLULEN row_set_size = statement->GetRowsetSize(); + ARROW_LOG(DEBUG) << "SQL_ROWSET_SIZE value for SQLExtendedFetch: " << row_set_size; + + if (statement->Fetch(static_cast(row_set_size), row_count_ptr, + row_status_array)) { + return SQL_SUCCESS; + } else { + // Reached the end of rowset + return SQL_NO_DATA; + } + }); } SQLRETURN SQLFetchScroll(SQLHSTMT stmt, SQLSMALLINT fetch_orientation, diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc index d452e77db1d..b380180ce58 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc @@ -306,7 +306,8 @@ void ODBCStatement::ExecuteDirect(const std::string& query) { is_prepared_ = false; } -bool ODBCStatement::Fetch(size_t rows) { +bool ODBCStatement::Fetch(size_t rows, SQLULEN* row_count_ptr, + SQLUSMALLINT* row_status_array) { if (has_reached_end_of_result_) { ird_->SetRowsProcessed(0); return false; @@ -339,11 +340,24 @@ bool ODBCStatement::Fetch(size_t rows) { current_ard_->NotifyBindingsHavePropagated(); } - size_t rows_fetched = current_result_->Move(rows, current_ard_->GetBindOffset(), - current_ard_->GetBoundStructOffset(), - ird_->GetArrayStatusPtr()); + uint16_t* array_status_ptr; + if (row_status_array) { + // For SQLExtendedFetch only + array_status_ptr = row_status_array; + } else { + array_status_ptr = ird_->GetArrayStatusPtr(); + } + + size_t rows_fetched = + current_result_->Move(rows, current_ard_->GetBindOffset(), + current_ard_->GetBoundStructOffset(), array_status_ptr); ird_->SetRowsProcessed(static_cast(rows_fetched)); + if (row_count_ptr) { + // For SQLExtendedFetch only + *row_count_ptr = rows_fetched; + } + row_number_ += rows_fetched; has_reached_end_of_result_ = rows_fetched != rows; return rows_fetched != 0; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h index 8e128db1bda..f8349d83ba3 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h @@ -58,10 +58,11 @@ class ODBCStatement : public ODBCHandle { void ExecutePrepared(); void ExecuteDirect(const std::string& query); - /** - * @brief Returns true if the number of rows fetch was greater than zero. - */ - bool Fetch(size_t rows); + /// \brief Return true if the number of rows fetch was greater than zero. + /// + /// row_count_ptr and row_status_array are optional arguments, they are only needed for + /// SQLExtendedFetch + bool Fetch(size_t rows, SQLULEN* row_count_ptr = 0, SQLUSMALLINT* row_status_array = 0); bool IsPrepared() const; void GetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER output, diff --git a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt index 4bc240637e7..cf3e15451d9 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt @@ -35,6 +35,7 @@ add_arrow_test(flight_sql_odbc_test odbc_test_suite.cc odbc_test_suite.h connection_test.cc + statement_test.cc # Enable Protobuf cleanup after test execution # GH-46889: move protobuf_test_util to a more common location ../../../../engine/substrait/protobuf_test_util.cc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc new file mode 100644 index 00000000000..4c99475fcc4 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +#include +#include + +namespace arrow::flight::sql::odbc { + +template +class StatementTest : public T {}; + +class StatementMockTest : public FlightSQLODBCMockTestBase {}; +class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {}; +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(StatementTest, TestTypes); + +TYPED_TEST(StatementTest, TestSQLExtendedFetchRowFetching) { + // Set SQL_ROWSET_SIZE to fetch 3 rows at once + + constexpr SQLULEN rows = 3; + SQLINTEGER val[rows]; + SQLLEN buf_len = sizeof(val); + SQLLEN ind[rows]; + + // Same variable will be used for column 1, the value of `val` + // should be updated after every SQLFetch call. + ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, val, buf_len, ind)); + + ASSERT_EQ(SQL_SUCCESS, SQLSetStmtAttr(this->stmt, SQL_ROWSET_SIZE, + reinterpret_cast(rows), 0)); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + // Fetch row 1-3. + SQLULEN row_count; + SQLUSMALLINT row_status[rows]; + + ASSERT_EQ(SQL_SUCCESS, + SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count, row_status)); + EXPECT_EQ(3, row_count); + + for (int i = 0; i < rows; i++) { + EXPECT_EQ(SQL_SUCCESS, row_status[i]); + } + + // Verify 1 is returned for row 1 + EXPECT_EQ(1, val[0]); + // Verify 2 is returned for row 2 + EXPECT_EQ(2, val[1]); + // Verify 3 is returned for row 3 + EXPECT_EQ(3, val[2]); + + // Verify result set has no more data beyond row 3 + SQLULEN row_count2; + SQLUSMALLINT row_status2[rows]; + EXPECT_EQ(SQL_NO_DATA, + SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count2, row_status2)); +} + +TEST_F(StatementRemoteTest, DISABLED_TestSQLExtendedFetchQueryNullIndicator) { + // GH-47110: SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002 + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + SQLINTEGER val; + + ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, 0)); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ASSERT_EQ(SQL_SUCCESS, + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size()))); + + SQLULEN row_count1; + SQLUSMALLINT row_status1[1]; + + // SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002 state + ASSERT_EQ(SQL_SUCCESS_WITH_INFO, + SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count1, row_status1)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState22002); +} + +} // namespace arrow::flight::sql::odbc