1414
1515#include < gtest/gtest.h>
1616#include " presto_cpp/main/thrift/ProtocolToThrift.h"
17+ #include " presto_cpp/main/thrift/ThriftIO.h"
1718#include " presto_cpp/main/common/tests/test_json.h"
1819#include " presto_cpp/main/connectors/PrestoToVeloxConnector.h"
1920
@@ -100,7 +101,7 @@ TEST_F(TaskUpdateRequestTest, mapOutputBuffers) {
100101 ASSERT_EQ (outputBuffers.buffers [" 2" ], 20 );
101102}
102103
103- TEST_F (TaskUpdateRequestTest, binarySplitFromThrift ) {
104+ TEST_F (TaskUpdateRequestTest, binaryHiveSplitFromThrift ) {
104105 thrift::Split thriftSplit;
105106 thriftSplit.connectorId ()->catalogName_ref () = " hive" ;
106107 thriftSplit.transactionHandle ()->jsonValue_ref () = R"( {
@@ -127,14 +128,89 @@ TEST_F(TaskUpdateRequestTest, binarySplitFromThrift) {
127128 protocol::NodeSelectionStrategy::NO_PREFERENCE);
128129}
129130
130- TEST_F (TaskUpdateRequestTest, binaryTableWriteInfo) {
131- std::string str = slurp (getDataPath (BASE_DATA_PATH, " TableWriteInfo.json" ));
132- protocol::TableWriteInfo tableWriteInfo;
131+ TEST_F (TaskUpdateRequestTest, binaryRemoteSplitFromThrift) {
132+ thrift::Split thriftSplit;
133+ thrift::RemoteTransactionHandle thriftTransactionHandle;
134+ thrift::RemoteSplit thriftRemoteSplit;
135+
136+ thriftSplit.connectorId ()->catalogName_ref () = " $remote" ;
137+ thriftSplit.transactionHandle ()->customSerializedValue_ref () =
138+ thriftWrite (thriftTransactionHandle);
139+
140+ thriftRemoteSplit.location ()->location_ref () = " /test_location" ;
141+ thriftRemoteSplit.remoteSourceTaskId ()->id_ref () = 100 ;
142+ thriftRemoteSplit.remoteSourceTaskId ()->attemptNumber_ref () = 200 ;
143+ thriftRemoteSplit.remoteSourceTaskId ()->stageExecutionId ()->id_ref () = 300 ;
144+ thriftRemoteSplit.remoteSourceTaskId ()->stageExecutionId ()->stageId ()->id_ref () = 400 ;
145+ thriftRemoteSplit.remoteSourceTaskId ()->stageExecutionId ()->stageId ()->queryId_ref () = " test_query_id" ;
146+
147+ thriftSplit.connectorSplit ()->connectorId_ref () = " $remote" ;
148+ thriftSplit.connectorSplit ()->customSerializedValue_ref () =
149+ thriftWrite (thriftRemoteSplit);
150+
151+ protocol::Split split;
152+ thrift::fromThrift (thriftSplit, split);
153+
154+ // Verify that connector specific fields are set correctly with thrift codec
155+ auto remoteSplit = std::dynamic_pointer_cast<protocol::RemoteSplit>(
156+ split.connectorSplit );
157+ ASSERT_EQ ((remoteSplit->location ).location , " /test_location" );
158+ ASSERT_EQ (remoteSplit->remoteSourceTaskId , " test_query_id.400.300.100.200" );
159+ }
160+
161+ TEST_F (TaskUpdateRequestTest, unionExecutionWriterTargetFromThrift) {
162+ // Construct ExecutionWriterTarget with CreateHandle
163+ thrift::CreateHandle thriftCreateHandle;
164+ thrift::ExecutionWriterTargetUnion thriftWriterTarget;
165+ thriftCreateHandle.schemaTableName ()->schema_ref () = " test_schema" ;
166+ thriftCreateHandle.schemaTableName ()->table_ref () = " test_table" ;
167+ thriftCreateHandle.handle ()->connectorId ()->catalogName_ref () = " hive" ;
168+ thriftCreateHandle.handle ()->transactionHandle ()->jsonValue_ref () = R"( {
169+ "@type": "hive",
170+ "uuid": "8a4d6c83-60ee-46de-9715-bc91755619fa"
171+ })" ;
172+ thriftCreateHandle.handle ()->connectorHandle ()->jsonValue_ref () = slurp (getDataPath (BASE_DATA_PATH, " HiveOutputTableHandle.json" ));;
173+ thriftWriterTarget.set_createHandle (std::move (thriftCreateHandle));
174+
175+ // Convert from thrift to protocol and verify fields
176+ auto writerTarget = std::make_shared<protocol::ExecutionWriterTarget>();
177+ thrift::fromThrift (thriftWriterTarget, writerTarget);
178+
179+ ASSERT_EQ (writerTarget->_type , " CreateHandle" );
180+ auto createHandle = std::dynamic_pointer_cast<protocol::CreateHandle>(writerTarget);
181+ ASSERT_NE (createHandle, nullptr );
182+ ASSERT_EQ (createHandle->schemaTableName .schema , " test_schema" );
183+ ASSERT_EQ (createHandle->schemaTableName .table , " test_table" );
184+
185+ auto * hiveTxnHandle = dynamic_cast <protocol::hive::HiveTransactionHandle*>(createHandle->handle .transactionHandle .get ());
186+ ASSERT_NE (hiveTxnHandle, nullptr );
187+ ASSERT_EQ (hiveTxnHandle->uuid , " 8a4d6c83-60ee-46de-9715-bc91755619fa" );
188+
189+ auto * hiveOutputTableHandle = dynamic_cast <protocol::hive::HiveOutputTableHandle*>(createHandle->handle .connectorHandle .get ());
190+ ASSERT_NE (hiveOutputTableHandle, nullptr );
191+ ASSERT_EQ (hiveOutputTableHandle->schemaName , " test_schema" );
192+ ASSERT_EQ (hiveOutputTableHandle->tableName , " test_table" );
193+ ASSERT_EQ (hiveOutputTableHandle->tableStorageFormat , protocol::hive::HiveStorageFormat::ORC);
194+ ASSERT_EQ (hiveOutputTableHandle->locationHandle .targetPath , " /path/to/target" );
195+ }
196+
197+ TEST_F (TaskUpdateRequestTest, unionExecutionWriterTargetToThrift) {
198+ // Construct thrift ExecutionWriterTarget with CreateHandle
199+ auto createHandle = std::make_shared<protocol::CreateHandle>();
200+ createHandle->schemaTableName .schema = " test_schema" ;
201+ createHandle->schemaTableName .table = " test_table" ;
202+
203+ auto writerTarget = std::make_shared<protocol::ExecutionWriterTarget>();
204+ writerTarget->_type = " CreateHandle" ;
205+ writerTarget = createHandle;
133206
134- thrift::fromThrift (str, tableWriteInfo);
135- auto hiveTableHandle = std::dynamic_pointer_cast<protocol::hive::HiveTableHandle>((*tableWriteInfo.analyzeTableHandle ).connectorHandle );
136- ASSERT_EQ (hiveTableHandle->tableName , " test_table" );
137- ASSERT_EQ (hiveTableHandle->analyzePartitionValues ->size (), 2 );
207+ // Convert to thrift and verify fields. Note that toThrift functions for connector fields are not implemented.
208+ thrift::ExecutionWriterTargetUnion thriftWriterTarget;
209+ thrift::toThrift (writerTarget, thriftWriterTarget);
210+ ASSERT_TRUE (thriftWriterTarget.createHandle_ref ().has_value ());
211+ const auto & thriftCreateHandle = thriftWriterTarget.createHandle_ref ().value ();
212+ ASSERT_EQ (thriftCreateHandle.schemaTableName ()->schema_ref ().value (), " test_schema" );
213+ ASSERT_EQ (thriftCreateHandle.schemaTableName ()->table_ref ().value (), " test_table" );
138214}
139215
140216TEST_F (TaskUpdateRequestTest, fragment) {
0 commit comments