Skip to content

Commit ed82791

Browse files
alinaliBQrscales
authored andcommitted
Extract SQLNativeSQL implementation
Co-Authored-By: rscales <[email protected]>
1 parent 398fc5a commit ed82791

File tree

3 files changed

+161
-2
lines changed

3 files changed

+161
-2
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,8 +1113,23 @@ SQLRETURN SQLNativeSql(SQLHDBC conn, SQLWCHAR* in_statement_text,
11131113
<< ", buffer_length: " << buffer_length
11141114
<< ", out_statement_text_length: "
11151115
<< static_cast<const void*>(out_statement_text_length);
1116-
// GH-47723 TODO: Implement SQLNativeSql
1117-
return SQL_INVALID_HANDLE;
1116+
1117+
using ODBC::GetAttributeSQLWCHAR;
1118+
using ODBC::ODBCConnection;
1119+
using ODBC::SqlWcharToString;
1120+
1121+
return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() {
1122+
const bool is_length_in_bytes = false;
1123+
1124+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(conn);
1125+
Diagnostics& diagnostics = connection->GetDiagnostics();
1126+
1127+
std::string in_statement_str =
1128+
SqlWcharToString(in_statement_text, in_statement_text_length);
1129+
1130+
return GetAttributeSQLWCHAR(in_statement_str, is_length_in_bytes, out_statement_text,
1131+
buffer_length, out_statement_text_length, diagnostics);
1132+
});
11181133
}
11191134

11201135
SQLRETURN SQLDescribeCol(SQLHSTMT stmt, SQLUSMALLINT column_number, SQLWCHAR* column_name,

cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_arrow_test(flight_sql_odbc_test
3535
odbc_test_suite.cc
3636
odbc_test_suite.h
3737
connection_test.cc
38+
statement_test.cc
3839
# Enable Protobuf cleanup after test execution
3940
# GH-46889: move protobuf_test_util to a more common location
4041
../../../../engine/substrait/protobuf_test_util.cc
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h"
18+
19+
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
20+
21+
#include <sql.h>
22+
#include <sqltypes.h>
23+
#include <sqlucode.h>
24+
25+
#include <limits>
26+
27+
#include <gmock/gmock.h>
28+
#include <gtest/gtest.h>
29+
30+
namespace arrow::flight::sql::odbc {
31+
32+
template <typename T>
33+
class StatementTest : public T {};
34+
35+
class StatementMockTest : public FlightSQLODBCMockTestBase {};
36+
class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {};
37+
using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
38+
TYPED_TEST_SUITE(StatementTest, TestTypes);
39+
40+
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) {
41+
SQLWCHAR buf[1024];
42+
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
43+
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
44+
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
45+
SQLINTEGER output_char_len = 0;
46+
std::wstring expected_string = std::wstring(input_str);
47+
48+
ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, input_char_len, buf,
49+
buf_char_len, &output_char_len));
50+
51+
EXPECT_EQ(input_char_len, output_char_len);
52+
53+
// returned length is in characters
54+
std::wstring returned_string(buf, buf + output_char_len);
55+
56+
EXPECT_EQ(expected_string, returned_string);
57+
}
58+
59+
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsNTSInputString) {
60+
SQLWCHAR buf[1024];
61+
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
62+
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
63+
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
64+
SQLINTEGER output_char_len = 0;
65+
std::wstring expected_string = std::wstring(input_str);
66+
67+
ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, SQL_NTS, buf, buf_char_len,
68+
&output_char_len));
69+
70+
EXPECT_EQ(input_char_len, output_char_len);
71+
72+
// returned length is in characters
73+
std::wstring returned_string(buf, buf + output_char_len);
74+
75+
EXPECT_EQ(expected_string, returned_string);
76+
}
77+
78+
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputStringLength) {
79+
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
80+
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
81+
SQLINTEGER output_char_len = 0;
82+
std::wstring expected_string = std::wstring(input_str);
83+
84+
ASSERT_EQ(SQL_SUCCESS, SQLNativeSql(this->conn, input_str, input_char_len, nullptr, 0,
85+
&output_char_len));
86+
87+
EXPECT_EQ(input_char_len, output_char_len);
88+
89+
ASSERT_EQ(SQL_SUCCESS,
90+
SQLNativeSql(this->conn, input_str, SQL_NTS, nullptr, 0, &output_char_len));
91+
92+
EXPECT_EQ(input_char_len, output_char_len);
93+
}
94+
95+
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) {
96+
const SQLINTEGER small_buf_size_in_char = 11;
97+
SQLWCHAR small_buf[small_buf_size_in_char];
98+
SQLINTEGER small_buf_char_len = sizeof(small_buf) / ODBC::GetSqlWCharSize();
99+
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
100+
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
101+
SQLINTEGER output_char_len = 0;
102+
103+
// Create expected return string based on buf size
104+
SQLWCHAR expected_string_buf[small_buf_size_in_char];
105+
wcsncpy(expected_string_buf, input_str, 10);
106+
expected_string_buf[10] = L'\0';
107+
std::wstring expected_string(expected_string_buf,
108+
expected_string_buf + small_buf_size_in_char);
109+
110+
ASSERT_EQ(SQL_SUCCESS_WITH_INFO,
111+
SQLNativeSql(this->conn, input_str, input_char_len, small_buf,
112+
small_buf_char_len, &output_char_len));
113+
VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorState01004);
114+
115+
// Returned text length represents full string char length regardless of truncation
116+
EXPECT_EQ(input_char_len, output_char_len);
117+
118+
std::wstring returned_string(small_buf, small_buf + small_buf_char_len);
119+
120+
EXPECT_EQ(expected_string, returned_string);
121+
}
122+
123+
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) {
124+
SQLWCHAR buf[1024];
125+
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
126+
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
127+
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
128+
SQLINTEGER output_char_len = 0;
129+
130+
ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, nullptr, input_char_len, buf,
131+
buf_char_len, &output_char_len));
132+
VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY009);
133+
134+
ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, nullptr, SQL_NTS, buf, buf_char_len,
135+
&output_char_len));
136+
VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY009);
137+
138+
ASSERT_EQ(SQL_ERROR, SQLNativeSql(this->conn, input_str, -100, buf, buf_char_len,
139+
&output_char_len));
140+
VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY090);
141+
}
142+
143+
} // namespace arrow::flight::sql::odbc

0 commit comments

Comments
 (0)