Skip to content

Commit 67334a6

Browse files
committed
GH-891: Add ExtensionTypeWriterFactory to TransferPair
1 parent aee8a10 commit 67334a6

File tree

7 files changed

+219
-4
lines changed

7 files changed

+219
-4
lines changed

vector/src/main/codegen/templates/ComplexCopier.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension
6363
writer.startList();
6464
while (reader.next()) {
6565
FieldReader childReader = reader.reader();
66-
FieldWriter childWriter = getListWriterForReader(childReader, writer);
66+
FieldWriter childWriter = getListWriterForReader(childReader, writer, extensionTypeWriterFactory);
6767
if (childReader.isSet()) {
6868
writeValue(childReader, childWriter, extensionTypeWriterFactory);
6969
} else {
@@ -189,6 +189,10 @@ private static FieldWriter getStructWriterForReader(FieldReader reader, StructWr
189189
}
190190
191191
private static FieldWriter getListWriterForReader(FieldReader reader, ListWriter writer) {
192+
return getListWriterForReader(reader, writer, null);
193+
}
194+
195+
private static FieldWriter getListWriterForReader(FieldReader reader, ListWriter writer, ExtensionTypeWriterFactory extensionTypeWriterFactory) {
192196
switch (reader.getMinorType()) {
193197
<#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first />
194198
<#assign fields = minor.fields!type.fields />
@@ -209,6 +213,9 @@ private static FieldWriter getListWriterForReader(FieldReader reader, ListWriter
209213
return (FieldWriter) writer.listView();
210214
case EXTENSIONTYPE:
211215
ExtensionWriter extensionWriter = writer.extension(reader.getField().getType());
216+
if (extensionTypeWriterFactory != null) {
217+
extensionWriter.addExtensionTypeWriterFactory(extensionTypeWriterFactory);
218+
}
212219
return (FieldWriter) extensionWriter;
213220
default:
214221
throw new UnsupportedOperationException(reader.getMinorType().toString());

vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import org.apache.arrow.vector.util.OversizedAllocationException;
6464
import org.apache.arrow.vector.util.SchemaChangeRuntimeException;
6565
import org.apache.arrow.vector.util.TransferPair;
66+
import org.apache.arrow.vector.util.TransferPairWithExtendedType;
6667

6768
/**
6869
* A list vector contains lists of a specific type of elements. Its structure contains 3 elements.
@@ -648,7 +649,7 @@ public UnionVector promoteToUnion() {
648649
return vector;
649650
}
650651

651-
private class TransferImpl implements TransferPair {
652+
private class TransferImpl implements TransferPairWithExtendedType {
652653

653654
LargeListVector to;
654655
TransferPair dataTransferPair;
@@ -731,6 +732,12 @@ public ValueVector getTo() {
731732
public void copyValueSafe(int from, int to) {
732733
this.to.copyFrom(from, to, LargeListVector.this);
733734
}
735+
736+
@Override
737+
public void copyValueSafe(
738+
int from, int to, ExtensionTypeWriterFactory extensionTypeWriterFactory) {
739+
this.to.copyFrom(from, to, LargeListVector.this, extensionTypeWriterFactory);
740+
}
734741
}
735742

736743
@Override

vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.apache.arrow.vector.util.JsonStringArrayList;
5757
import org.apache.arrow.vector.util.OversizedAllocationException;
5858
import org.apache.arrow.vector.util.TransferPair;
59+
import org.apache.arrow.vector.util.TransferPairWithExtendedType;
5960

6061
/**
6162
* A list vector contains lists of a specific type of elements. Its structure contains 3 elements.
@@ -528,7 +529,7 @@ public <OUT, IN> OUT accept(VectorVisitor<OUT, IN> visitor, IN value) {
528529
return visitor.visit(this, value);
529530
}
530531

531-
private class TransferImpl implements TransferPair {
532+
private class TransferImpl implements TransferPairWithExtendedType {
532533

533534
ListVector to;
534535
TransferPair dataTransferPair;
@@ -612,6 +613,12 @@ public ValueVector getTo() {
612613
public void copyValueSafe(int from, int to) {
613614
this.to.copyFrom(from, to, ListVector.this);
614615
}
616+
617+
@Override
618+
public void copyValueSafe(
619+
int from, int to, ExtensionTypeWriterFactory extensionTypeWriterFactory) {
620+
this.to.copyFrom(from, to, ListVector.this, extensionTypeWriterFactory);
621+
}
615622
}
616623

617624
@Override

vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.apache.arrow.vector.util.JsonStringArrayList;
5757
import org.apache.arrow.vector.util.OversizedAllocationException;
5858
import org.apache.arrow.vector.util.TransferPair;
59+
import org.apache.arrow.vector.util.TransferPairWithExtendedType;
5960

6061
/**
6162
* A list view vector contains lists of a specific type of elements. Its structure contains four
@@ -466,7 +467,7 @@ public int hashCode(int index, ArrowBufHasher hasher) {
466467
return hash;
467468
}
468469

469-
private class TransferImpl implements TransferPair {
470+
private class TransferImpl implements TransferPairWithExtendedType {
470471

471472
ListViewVector to;
472473
TransferPair dataTransferPair;
@@ -557,6 +558,11 @@ public ValueVector getTo() {
557558
public void copyValueSafe(int from, int to) {
558559
this.to.copyFrom(from, to, ListViewVector.this);
559560
}
561+
562+
@Override
563+
public void copyValueSafe(int from, int to, ExtensionTypeWriterFactory writerFactory) {
564+
this.to.copyFrom(from, to, ListViewVector.this, writerFactory);
565+
}
560566
}
561567

562568
@Override
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.arrow.vector.util;
18+
19+
import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory;
20+
21+
public interface TransferPairWithExtendedType extends TransferPair {
22+
void copyValueSafe(int from, int to, ExtensionTypeWriterFactory extensionTypeWriterFactory);
23+
}

vector/src/test/java/org/apache/arrow/vector/TestListVector.java

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,84 @@ public void testCopyFromForExtensionType() throws Exception {
13141314
}
13151315
}
13161316

1317+
@Test
1318+
public void testCopyValueSafeForExtensionType() throws Exception {
1319+
try (ListVector inVector = ListVector.empty("input", allocator);
1320+
ListVector outVector = ListVector.empty("output", allocator)) {
1321+
UnionListWriter writer = inVector.getWriter();
1322+
writer.allocate();
1323+
1324+
// Create first list with UUIDs
1325+
writer.setPosition(0);
1326+
UUID u1 = UUID.randomUUID();
1327+
UUID u2 = UUID.randomUUID();
1328+
writer.startList();
1329+
ExtensionWriter extensionWriter = writer.extension(new UuidType());
1330+
extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory());
1331+
extensionWriter.writeExtension(u1);
1332+
extensionWriter.writeExtension(u2);
1333+
writer.endList();
1334+
1335+
// Create second list with UUIDs
1336+
writer.setPosition(1);
1337+
UUID u3 = UUID.randomUUID();
1338+
UUID u4 = UUID.randomUUID();
1339+
writer.startList();
1340+
extensionWriter = writer.extension(new UuidType());
1341+
extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory());
1342+
extensionWriter.writeExtension(u3);
1343+
extensionWriter.writeExtension(u4);
1344+
extensionWriter.writeNull();
1345+
1346+
writer.endList();
1347+
writer.setValueCount(2);
1348+
1349+
// Use copyFromSafe with ExtensionTypeWriterFactory
1350+
// This internally calls TransferImpl.copyValueSafe with ExtensionTypeWriterFactory
1351+
outVector.allocateNew();
1352+
outVector.copyFromSafe(0, 0, inVector, new UuidWriterFactory());
1353+
outVector.copyFromSafe(1, 1, inVector, new UuidWriterFactory());
1354+
outVector.setValueCount(2);
1355+
1356+
// Verify first list
1357+
UnionListReader reader = outVector.getReader();
1358+
reader.setPosition(0);
1359+
assertTrue(reader.isSet(), "first list shouldn't be null");
1360+
reader.next();
1361+
FieldReader uuidReader = reader.reader();
1362+
UuidHolder holder = new UuidHolder();
1363+
uuidReader.read(holder);
1364+
ByteBuffer bb = ByteBuffer.wrap(holder.value);
1365+
UUID actualUuid = new UUID(bb.getLong(), bb.getLong());
1366+
assertEquals(u1, actualUuid);
1367+
reader.next();
1368+
uuidReader = reader.reader();
1369+
uuidReader.read(holder);
1370+
bb = ByteBuffer.wrap(holder.value);
1371+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1372+
assertEquals(u2, actualUuid);
1373+
1374+
// Verify second list
1375+
reader.setPosition(1);
1376+
assertTrue(reader.isSet(), "second list shouldn't be null");
1377+
reader.next();
1378+
uuidReader = reader.reader();
1379+
uuidReader.read(holder);
1380+
bb = ByteBuffer.wrap(holder.value);
1381+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1382+
assertEquals(u3, actualUuid);
1383+
reader.next();
1384+
uuidReader = reader.reader();
1385+
uuidReader.read(holder);
1386+
bb = ByteBuffer.wrap(holder.value);
1387+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1388+
assertEquals(u4, actualUuid);
1389+
reader.next();
1390+
uuidReader = reader.reader();
1391+
assertFalse(uuidReader.isSet(), "third element should be null");
1392+
}
1393+
}
1394+
13171395
private void writeIntValues(UnionListWriter writer, int[] values) {
13181396
writer.startList();
13191397
for (int v : values) {

vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,32 @@
2222
import static org.junit.jupiter.api.Assertions.assertThrows;
2323
import static org.junit.jupiter.api.Assertions.assertTrue;
2424

25+
import java.nio.ByteBuffer;
2526
import java.util.ArrayList;
2627
import java.util.Arrays;
2728
import java.util.Collections;
2829
import java.util.List;
30+
import java.util.UUID;
2931
import org.apache.arrow.memory.ArrowBuf;
3032
import org.apache.arrow.memory.BufferAllocator;
3133
import org.apache.arrow.vector.complex.BaseRepeatedValueVector;
3234
import org.apache.arrow.vector.complex.BaseRepeatedValueViewVector;
3335
import org.apache.arrow.vector.complex.ListVector;
3436
import org.apache.arrow.vector.complex.ListViewVector;
37+
import org.apache.arrow.vector.complex.impl.UnionListViewReader;
3538
import org.apache.arrow.vector.complex.impl.UnionListViewWriter;
39+
import org.apache.arrow.vector.complex.impl.UuidWriterFactory;
40+
import org.apache.arrow.vector.complex.reader.FieldReader;
41+
import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter;
42+
import org.apache.arrow.vector.holder.UuidHolder;
3643
import org.apache.arrow.vector.holders.DurationHolder;
3744
import org.apache.arrow.vector.holders.TimeStampMilliTZHolder;
3845
import org.apache.arrow.vector.types.TimeUnit;
3946
import org.apache.arrow.vector.types.Types.MinorType;
4047
import org.apache.arrow.vector.types.pojo.ArrowType;
4148
import org.apache.arrow.vector.types.pojo.Field;
4249
import org.apache.arrow.vector.types.pojo.FieldType;
50+
import org.apache.arrow.vector.types.pojo.UuidType;
4351
import org.apache.arrow.vector.util.TransferPair;
4452
import org.junit.jupiter.api.AfterEach;
4553
import org.junit.jupiter.api.BeforeEach;
@@ -2217,6 +2225,85 @@ public void testRangeChildVector2() {
22172225
}
22182226
}
22192227

2228+
@Test
2229+
public void testCopyValueSafeForExtensionType() throws Exception {
2230+
try (ListViewVector inVector = ListViewVector.empty("input", allocator);
2231+
ListViewVector outVector = ListViewVector.empty("output", allocator)) {
2232+
UnionListViewWriter writer = inVector.getWriter();
2233+
writer.allocate();
2234+
2235+
// Create first list with UUIDs
2236+
writer.setPosition(0);
2237+
UUID u1 = UUID.randomUUID();
2238+
UUID u2 = UUID.randomUUID();
2239+
writer.startListView();
2240+
ExtensionWriter extensionWriter = writer.extension(new UuidType());
2241+
extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory());
2242+
extensionWriter.writeExtension(u1);
2243+
extensionWriter.writeExtension(u2);
2244+
writer.endListView();
2245+
2246+
// Create second list with UUIDs
2247+
writer.setPosition(1);
2248+
UUID u3 = UUID.randomUUID();
2249+
UUID u4 = UUID.randomUUID();
2250+
writer.startListView();
2251+
extensionWriter = writer.extension(new UuidType());
2252+
extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory());
2253+
extensionWriter.writeExtension(u3);
2254+
extensionWriter.writeExtension(u4);
2255+
extensionWriter.writeNull();
2256+
2257+
writer.endListView();
2258+
writer.setValueCount(2);
2259+
2260+
// Use copyFromSafe with ExtensionTypeWriterFactory
2261+
// This internally calls TransferImpl.copyValueSafe with ExtensionTypeWriterFactory
2262+
UnionListViewWriter outWriter = outVector.getWriter();
2263+
outWriter.allocate();
2264+
outVector.copyFromSafe(0, 0, inVector, new UuidWriterFactory());
2265+
outVector.copyFromSafe(1, 1, inVector, new UuidWriterFactory());
2266+
outVector.setValueCount(2);
2267+
2268+
// Verify first list
2269+
UnionListViewReader reader = outVector.getReader();
2270+
reader.setPosition(0);
2271+
assertTrue(reader.isSet(), "first list shouldn't be null");
2272+
reader.next();
2273+
FieldReader uuidReader = reader.reader();
2274+
UuidHolder holder = new UuidHolder();
2275+
uuidReader.read(holder);
2276+
ByteBuffer bb = ByteBuffer.wrap(holder.value);
2277+
UUID actualUuid = new UUID(bb.getLong(), bb.getLong());
2278+
assertEquals(u1, actualUuid);
2279+
reader.next();
2280+
uuidReader = reader.reader();
2281+
uuidReader.read(holder);
2282+
bb = ByteBuffer.wrap(holder.value);
2283+
actualUuid = new UUID(bb.getLong(), bb.getLong());
2284+
assertEquals(u2, actualUuid);
2285+
2286+
// Verify second list
2287+
reader.setPosition(1);
2288+
assertTrue(reader.isSet(), "second list shouldn't be null");
2289+
reader.next();
2290+
uuidReader = reader.reader();
2291+
uuidReader.read(holder);
2292+
bb = ByteBuffer.wrap(holder.value);
2293+
actualUuid = new UUID(bb.getLong(), bb.getLong());
2294+
assertEquals(u3, actualUuid);
2295+
reader.next();
2296+
uuidReader = reader.reader();
2297+
uuidReader.read(holder);
2298+
bb = ByteBuffer.wrap(holder.value);
2299+
actualUuid = new UUID(bb.getLong(), bb.getLong());
2300+
assertEquals(u4, actualUuid);
2301+
reader.next();
2302+
uuidReader = reader.reader();
2303+
assertFalse(uuidReader.isSet(), "third element should be null");
2304+
}
2305+
}
2306+
22202307
private void writeIntValues(UnionListViewWriter writer, int[] values) {
22212308
writer.startListView();
22222309
for (int v : values) {

0 commit comments

Comments
 (0)