Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions cpp/src/arrow/flight/sql/odbc/odbc_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,33 @@ SQLRETURN SQLExtendedFetch(SQLHSTMT stmt, SQLUSMALLINT fetch_orientation,
<< ", row_count_ptr: " << static_cast<const void*>(row_count_ptr)
<< ", row_status_array: "
<< static_cast<const void*>(row_status_array);
// GH-47714 TODO: Implement SQLExtendedFetch
return SQL_INVALID_HANDLE;

using ODBC::ODBCDescriptor;
using ODBC::ODBCStatement;
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
// Only SQL_FETCH_NEXT forward-only fetching orientation is supported,
// meaning the behavior of SQLExtendedFetch is same as SQLFetch.
if (fetch_orientation != SQL_FETCH_NEXT) {
throw DriverException("Optional feature not supported.", "HYC00");
}
// Ignore fetch_offset as it's not applicable to SQL_FETCH_NEXT
ARROW_UNUSED(fetch_offset);

ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);

// The SQL_ROWSET_SIZE statement attribute specifies the number of rows in the
// rowset. Retrieve it from GetRowsetSize.
SQLULEN row_set_size = statement->GetRowsetSize();
ARROW_LOG(DEBUG) << "SQL_ROWSET_SIZE value for SQLExtendedFetch: " << row_set_size;

if (statement->Fetch(static_cast<size_t>(row_set_size), row_count_ptr,
row_status_array)) {
return SQL_SUCCESS;
} else {
// Reached the end of rowset
return SQL_NO_DATA;
}
});
}

SQLRETURN SQLFetchScroll(SQLHSTMT stmt, SQLSMALLINT fetch_orientation,
Expand Down
22 changes: 18 additions & 4 deletions cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ void ODBCStatement::ExecuteDirect(const std::string& query) {
is_prepared_ = false;
}

bool ODBCStatement::Fetch(size_t rows) {
bool ODBCStatement::Fetch(size_t rows, SQLULEN* row_count_ptr,
SQLUSMALLINT* row_status_array) {
if (has_reached_end_of_result_) {
ird_->SetRowsProcessed(0);
return false;
Expand Down Expand Up @@ -339,11 +340,24 @@ bool ODBCStatement::Fetch(size_t rows) {
current_ard_->NotifyBindingsHavePropagated();
}

size_t rows_fetched = current_result_->Move(rows, current_ard_->GetBindOffset(),
current_ard_->GetBoundStructOffset(),
ird_->GetArrayStatusPtr());
uint16_t* array_status_ptr;
if (row_status_array) {
// For SQLExtendedFetch only
array_status_ptr = row_status_array;
} else {
array_status_ptr = ird_->GetArrayStatusPtr();
}

size_t rows_fetched =
current_result_->Move(rows, current_ard_->GetBindOffset(),
current_ard_->GetBoundStructOffset(), array_status_ptr);
ird_->SetRowsProcessed(static_cast<SQLULEN>(rows_fetched));

if (row_count_ptr) {
// For SQLExtendedFetch only
*row_count_ptr = rows_fetched;
}

row_number_ += rows_fetched;
has_reached_end_of_result_ = rows_fetched != rows;
return rows_fetched != 0;
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ class ODBCStatement : public ODBCHandle<ODBCStatement> {
void ExecutePrepared();
void ExecuteDirect(const std::string& query);

/**
* @brief Returns true if the number of rows fetch was greater than zero.
*/
bool Fetch(size_t rows);
/// \brief Return true if the number of rows fetch was greater than zero.
///
/// row_count_ptr and row_status_array are optional arguments, they are only needed for
/// SQLExtendedFetch
bool Fetch(size_t rows, SQLULEN* row_count_ptr = 0, SQLUSMALLINT* row_status_array = 0);
bool IsPrepared() const;

void GetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER output,
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_arrow_test(flight_sql_odbc_test
odbc_test_suite.cc
odbc_test_suite.h
connection_test.cc
statement_test.cc
# Enable Protobuf cleanup after test execution
# GH-46889: move protobuf_test_util to a more common location
../../../../engine/substrait/protobuf_test_util.cc
Expand Down
117 changes: 117 additions & 0 deletions cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h"

#include "arrow/flight/sql/odbc/odbc_impl/platform.h"

#include <sql.h>
#include <sqltypes.h>
#include <sqlucode.h>

#include <limits>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

namespace arrow::flight::sql::odbc {

template <typename T>
class StatementTest : public T {};

class StatementMockTest : public FlightSQLODBCMockTestBase {};
class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {};
using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
TYPED_TEST_SUITE(StatementTest, TestTypes);

TYPED_TEST(StatementTest, TestSQLExtendedFetchRowFetching) {
// Set SQL_ROWSET_SIZE to fetch 3 rows at once

constexpr SQLULEN rows = 3;
SQLINTEGER val[rows];
SQLLEN buf_len = sizeof(val);
SQLLEN ind[rows];

// Same variable will be used for column 1, the value of `val`
// should be updated after every SQLFetch call.
ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, val, buf_len, ind));

ASSERT_EQ(SQL_SUCCESS, SQLSetStmtAttr(this->stmt, SQL_ROWSET_SIZE,
reinterpret_cast<SQLPOINTER>(rows), 0));

std::wstring wsql =
LR"(
SELECT 1 AS small_table
UNION ALL
SELECT 2
UNION ALL
SELECT 3;
)";
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());

ASSERT_EQ(SQL_SUCCESS,
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));

// Fetch row 1-3.
SQLULEN row_count;
SQLUSMALLINT row_status[rows];

ASSERT_EQ(SQL_SUCCESS,
SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count, row_status));
EXPECT_EQ(3, row_count);

for (int i = 0; i < rows; i++) {
EXPECT_EQ(SQL_SUCCESS, row_status[i]);
}

// Verify 1 is returned for row 1
EXPECT_EQ(1, val[0]);
// Verify 2 is returned for row 2
EXPECT_EQ(2, val[1]);
// Verify 3 is returned for row 3
EXPECT_EQ(3, val[2]);

// Verify result set has no more data beyond row 3
SQLULEN row_count2;
SQLUSMALLINT row_status2[rows];
EXPECT_EQ(SQL_NO_DATA,
SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count2, row_status2));
}

TEST_F(StatementRemoteTest, DISABLED_TestSQLExtendedFetchQueryNullIndicator) {
// GH-47110: SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002
// Limitation on mock test server prevents null from working properly, so use remote
// server instead. Mock server has type `DENSE_UNION` for null column data.
SQLINTEGER val;

ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, 0));

std::wstring wsql = L"SELECT null as null_col;";
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());

ASSERT_EQ(SQL_SUCCESS,
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));

SQLULEN row_count1;
SQLUSMALLINT row_status1[1];

// SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002 state
ASSERT_EQ(SQL_SUCCESS_WITH_INFO,
SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count1, row_status1));
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState22002);
}

} // namespace arrow::flight::sql::odbc