Skip to content

Commit a940fec

Browse files
committed
Extract SQLNumResultCols implementation
Co-Authored-By: alinalibq <[email protected]>
1 parent 42f27ab commit a940fec

File tree

5 files changed

+107
-2
lines changed

5 files changed

+107
-2
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -890,8 +890,13 @@ SQLRETURN SQLNumResultCols(SQLHSTMT stmt, SQLSMALLINT* column_count_ptr) {
890890
ARROW_LOG(DEBUG) << "SQLNumResultCols called with stmt: " << stmt
891891
<< ", column_count_ptr: "
892892
<< static_cast<const void*>(column_count_ptr);
893-
// GH-47713 TODO: Implement SQLNumResultCols
894-
return SQL_INVALID_HANDLE;
893+
894+
using ODBC::ODBCStatement;
895+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
896+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
897+
statement->GetColumnCount(column_count_ptr);
898+
return SQL_SUCCESS;
899+
});
895900
}
896901

897902
SQLRETURN SQLRowCount(SQLHSTMT stmt, SQLLEN* row_count_ptr) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,16 @@ bool ODBCStatement::GetData(SQLSMALLINT record_number, SQLSMALLINT c_type,
735735
data_ptr, buffer_length, indicator_ptr);
736736
}
737737

738+
void ODBCStatement::GetColumnCount(SQLSMALLINT* column_count_ptr) {
739+
if (!column_count_ptr) {
740+
// columnCountPtr is not valid, do nothing as ODBC spec does not mention this as an
741+
// error
742+
return;
743+
}
744+
size_t column_count = ird_->GetRecords().size();
745+
*column_count_ptr = static_cast<SQLSMALLINT>(column_count);
746+
}
747+
738748
void ODBCStatement::ReleaseStatement() {
739749
CloseCursor(true);
740750
connection_.DropStatement(this);

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ class ODBCStatement : public ODBCHandle<ODBCStatement> {
8080
bool GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, SQLPOINTER data_ptr,
8181
SQLLEN buffer_length, SQLLEN* indicator_ptr);
8282

83+
/**
84+
* @brief Get number of columns from data set
85+
*/
86+
void GetColumnCount(SQLSMALLINT* column_count_ptr);
87+
8388
/**
8489
* @brief Closes the cursor. This does _not_ un-prepare the statement or change
8590
* bindings.

cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_arrow_test(flight_sql_odbc_test
3535
odbc_test_suite.cc
3636
odbc_test_suite.h
3737
connection_test.cc
38+
statement_test.cc
3839
# Enable Protobuf cleanup after test execution
3940
# GH-46889: move protobuf_test_util to a more common location
4041
../../../../engine/substrait/protobuf_test_util.cc
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h"
18+
19+
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
20+
21+
#include <sql.h>
22+
#include <sqltypes.h>
23+
#include <sqlucode.h>
24+
25+
#include <limits>
26+
27+
#include <gmock/gmock.h>
28+
#include <gtest/gtest.h>
29+
30+
namespace arrow::flight::sql::odbc {
31+
32+
template <typename T>
33+
class StatementTest : public T {};
34+
35+
class StatementMockTest : public FlightSQLODBCMockTestBase {};
36+
class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {};
37+
using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
38+
TYPED_TEST_SUITE(StatementTest, TestTypes);
39+
40+
TYPED_TEST(StatementTest, SQLNumResultColsReturnsColumnsOnSelect) {
41+
SQLSMALLINT column_count = 0;
42+
SQLSMALLINT expected_value = 3;
43+
SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3";
44+
SQLINTEGER query_length = static_cast<SQLINTEGER>(wcslen(sql_query));
45+
46+
ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length));
47+
48+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
49+
50+
CheckIntColumn(this->stmt, 1, 1);
51+
CheckStringColumnW(this->stmt, 2, L"One");
52+
CheckIntColumn(this->stmt, 3, 3);
53+
54+
ASSERT_EQ(SQL_SUCCESS, SQLNumResultCols(this->stmt, &column_count));
55+
56+
EXPECT_EQ(expected_value, column_count);
57+
}
58+
59+
TYPED_TEST(StatementTest, SQLNumResultColsReturnsSuccessOnNullptr) {
60+
SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3";
61+
SQLINTEGER query_length = static_cast<SQLINTEGER>(wcslen(sql_query));
62+
63+
ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length));
64+
65+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
66+
67+
CheckIntColumn(this->stmt, 1, 1);
68+
CheckStringColumnW(this->stmt, 2, L"One");
69+
CheckIntColumn(this->stmt, 3, 3);
70+
71+
ASSERT_EQ(SQL_SUCCESS, SQLNumResultCols(this->stmt, nullptr));
72+
}
73+
74+
TYPED_TEST(StatementTest, SQLNumResultColsFunctionSequenceErrorOnNoQuery) {
75+
SQLSMALLINT column_count = 0;
76+
SQLSMALLINT expected_value = 0;
77+
78+
ASSERT_EQ(SQL_ERROR, SQLNumResultCols(this->stmt, &column_count));
79+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
80+
81+
EXPECT_EQ(expected_value, column_count);
82+
}
83+
84+
} // namespace arrow::flight::sql::odbc

0 commit comments

Comments
 (0)