Skip to content

Commit 79928ff

Browse files
committed
feat: Finalize our CsvDataSource implementation.
- Introduce an Iterable type on the record batches.
- Introduce an IndexedStream to streamline functional iterators that need an index.
- Implement CSV parsing end to end.
- Introduce an ArrowFieldVector factory >:^|
- Introduce more Java-isms >:^|
- Add tests for the above.
1 parent 37be0b6 commit 79928ff

File tree

14 files changed

+856
-34
lines changed

14 files changed

+856
-34
lines changed

glint/src/main/java/co/clflushopt/glint/datasource/CsvDataSource.java

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import java.io.File;
44
import java.io.FileNotFoundException;
5+
import java.util.ArrayList;
56
import java.util.List;
67
import java.util.Optional;
78
import java.util.logging.Logger;
@@ -12,6 +13,7 @@
1213

1314
import co.clflushopt.glint.types.ArrowTypes;
1415
import co.clflushopt.glint.types.Field;
16+
import co.clflushopt.glint.types.RecordBatch;
1517
import co.clflushopt.glint.types.Schema;
1618

1719
/**
@@ -23,7 +25,7 @@
2325
* names are inferred from the header, if a header is not present then the field
2426
* names will be `field_0...field_n`.
2527
*/
26-
public class CsvDataSource {
28+
public class CsvDataSource implements DataSource {
2729
private final Schema schema;
2830
private final String filename;
2931
private final Boolean hasHeaders;
@@ -45,14 +47,19 @@ public CsvDataSource(String filename, Optional<Schema> schema, Boolean hasHeader
4547
}
4648
}
4749

48-
private Schema inferSchema() throws FileNotFoundException {
49-
logger.info("Schema inference triggered");
50-
50+
private File open() throws FileNotFoundException {
5151
var file = new File(filename);
5252
if (!file.exists()) {
5353
logger.info("File was not found");
5454
throw new FileNotFoundException("file with name " + filename + " was not found");
5555
}
56+
return file;
57+
}
58+
59+
private Schema inferSchema() throws FileNotFoundException {
60+
logger.info("Schema inference triggered");
61+
62+
var file = open();
5663

5764
var parser = getCsvParser(getCsvDefaultSettings());
5865
parser.beginParsing(file);
@@ -70,11 +77,14 @@ private Schema inferSchema() throws FileNotFoundException {
7077
parser.stopParsing();
7178

7279
if (hasHeaders) {
73-
return new Schema(List.of(headers).stream()
74-
.map(columnName -> new Field(columnName, ArrowTypes.StringType)).toList());
80+
return new Schema(Streams.mapWithIndex(List.of(headers).stream(), (columnName,
81+
columnIndex) -> new Field(columnName, (int) columnIndex, ArrowTypes.StringType))
82+
.toList());
83+
7584
} else {
76-
return new Schema(Streams.mapWithIndex(List.of(headers).stream(), (_field,
77-
index) -> new Field(String.format("field_%d", index), ArrowTypes.StringType))
85+
return new Schema(Streams.mapWithIndex(List.of(headers).stream(),
86+
(_field, index) -> new Field(String.format("field_%d", index), (int) index,
87+
ArrowTypes.StringType))
7888
.toList());
7989
}
8090

@@ -115,4 +125,32 @@ public Logger getLogger() {
115125
return logger;
116126
}
117127

128+
@Override
129+
public Iterable<RecordBatch> scan(List<String> projection) {
130+
try {
131+
var file = this.open();
132+
var schema = this.schema;
133+
var settings = this.getCsvDefaultSettings();
134+
135+
if (!projection.isEmpty()) {
136+
schema = this.schema.select(projection);
137+
settings.selectFields(projection.toArray(new String[0]));
138+
}
139+
settings.setHeaderExtractionEnabled(hasHeaders);
140+
if (!hasHeaders) {
141+
settings.setHeaders(schema.getFields().stream().map(field -> field.name()).toList()
142+
.toArray(new String[0]));
143+
}
144+
145+
var parser = getCsvParser(settings);
146+
parser.beginParsing(file);
147+
var format = parser.getDetectedFormat();
148+
logger.info(String.format("Detected format with delimiter: %s and line separator: %s",
149+
format.getDelimiterString(), format.getLineSeparator()));
150+
151+
return new CsvReaderIterable(schema, parser, this.batchSize);
152+
} catch (Exception e) {
153+
return new ArrayList<>();
154+
}
155+
}
118156
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package co.clflushopt.glint.datasource;
2+
3+
import java.util.Iterator;
4+
5+
import com.univocity.parsers.csv.CsvParser;
6+
7+
import co.clflushopt.glint.types.RecordBatch;
8+
import co.clflushopt.glint.types.Schema;
9+
10+
public class CsvReaderIterable implements Iterable<RecordBatch> {
11+
private final Schema schema;
12+
private final CsvParser parser;
13+
private final int batchSize;
14+
15+
public CsvReaderIterable(Schema schema, CsvParser parser, int batchSize) {
16+
this.schema = schema;
17+
this.parser = parser;
18+
this.batchSize = batchSize;
19+
}
20+
21+
@Override
22+
public Iterator<RecordBatch> iterator() {
23+
return new CsvReaderIterator(schema, parser, batchSize);
24+
}
25+
26+
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package co.clflushopt.glint.datasource;
2+
3+
import java.util.ArrayList;
4+
import java.util.Iterator;
5+
import java.util.NoSuchElementException;
6+
import java.util.logging.Logger;
7+
import java.util.stream.Collectors;
8+
9+
import org.apache.arrow.memory.RootAllocator;
10+
import org.apache.arrow.vector.BigIntVector;
11+
import org.apache.arrow.vector.Float4Vector;
12+
import org.apache.arrow.vector.Float8Vector;
13+
import org.apache.arrow.vector.IntVector;
14+
import org.apache.arrow.vector.SmallIntVector;
15+
import org.apache.arrow.vector.TinyIntVector;
16+
import org.apache.arrow.vector.ValueVector;
17+
import org.apache.arrow.vector.VarCharVector;
18+
import org.apache.arrow.vector.VectorSchemaRoot;
19+
20+
import com.univocity.parsers.common.record.Record;
21+
import com.univocity.parsers.csv.CsvParser;
22+
23+
import co.clflushopt.glint.types.ArrowFieldVector;
24+
import co.clflushopt.glint.types.RecordBatch;
25+
import co.clflushopt.glint.types.Schema;
26+
import co.clflushopt.glint.util.IndexedStream;
27+
28+
public class CsvReaderIterator implements Iterator<RecordBatch> {
29+
private static final Logger logger = Logger.getLogger(CsvDataSource.class.getSimpleName());
30+
31+
private final Schema schema;
32+
private final CsvParser parser;
33+
private final int batchSize;
34+
private RecordBatch next;
35+
private boolean started;
36+
37+
public CsvReaderIterator(Schema schema, CsvParser parser, int batchSize) {
38+
this.schema = schema;
39+
this.parser = parser;
40+
this.batchSize = batchSize;
41+
this.started = false;
42+
}
43+
44+
@Override
45+
public boolean hasNext() {
46+
if (!started) {
47+
started = true;
48+
next = nextBatch();
49+
}
50+
return next != null;
51+
}
52+
53+
@Override
54+
public RecordBatch next() {
55+
if (!started) {
56+
hasNext();
57+
}
58+
59+
RecordBatch out = next;
60+
next = nextBatch();
61+
62+
if (out == null) {
63+
throw new NoSuchElementException(
64+
"Cannot read past the end of " + CsvReaderIterator.class.getSimpleName());
65+
}
66+
67+
return out;
68+
}
69+
70+
private RecordBatch nextBatch() {
71+
ArrayList<Record> rows = new ArrayList<>(batchSize);
72+
73+
Record line;
74+
do {
75+
line = parser.parseNextRecord();
76+
if (line != null) {
77+
rows.add(line);
78+
}
79+
} while (line != null && rows.size() < batchSize);
80+
81+
if (rows.isEmpty()) {
82+
return null;
83+
}
84+
85+
return createBatch(rows);
86+
}
87+
88+
private RecordBatch createBatch(ArrayList<Record> rows) {
89+
VectorSchemaRoot root = VectorSchemaRoot.create(schema.toArrow(),
90+
new RootAllocator(Long.MAX_VALUE));
91+
root.getFieldVectors().forEach(v -> v.setInitialCapacity(rows.size()));
92+
root.allocateNew();
93+
94+
IndexedStream.withIndex(root.getFieldVectors()).forEach(field -> {
95+
ValueVector vector = field.getValue();
96+
if (vector instanceof VarCharVector) {
97+
VarCharVector varCharVector = (VarCharVector) vector;
98+
IndexedStream.withIndex(rows).forEach(row -> {
99+
String valueStr = row.getValue().getValue(vector.getName(), "").trim();
100+
varCharVector.setSafe(row.getIndex(), valueStr.getBytes());
101+
});
102+
} else if (vector instanceof TinyIntVector) {
103+
TinyIntVector tinyIntVector = (TinyIntVector) vector;
104+
IndexedStream.withIndex(rows).forEach(row -> {
105+
String valueStr = row.getValue().getValue(vector.getName(), "").trim();
106+
if (valueStr.isEmpty()) {
107+
tinyIntVector.setNull(row.getIndex());
108+
} else {
109+
tinyIntVector.set(row.getIndex(), Byte.parseByte(valueStr));
110+
}
111+
});
112+
} else if (vector instanceof SmallIntVector) {
113+
SmallIntVector smallIntVector = (SmallIntVector) vector;
114+
IndexedStream.withIndex(rows).forEach(row -> {
115+
String valueStr = row.getValue().getValue(vector.getName(), "").trim();
116+
if (valueStr.isEmpty()) {
117+
smallIntVector.setNull(row.getIndex());
118+
} else {
119+
smallIntVector.set(row.getIndex(), Short.parseShort(valueStr));
120+
}
121+
});
122+
} else if (vector instanceof IntVector) {
123+
IntVector intVector = (IntVector) vector;
124+
IndexedStream.withIndex(rows).forEach(row -> {
125+
String valueStr = row.getValue().getValue(vector.getName(), "").trim();
126+
if (valueStr.isEmpty()) {
127+
intVector.setNull(row.getIndex());
128+
} else {
129+
intVector.set(row.getIndex(), Integer.parseInt(valueStr));
130+
}
131+
});
132+
} else if (vector instanceof BigIntVector) {
133+
BigIntVector bigIntVector = (BigIntVector) vector;
134+
IndexedStream.withIndex(rows).forEach(row -> {
135+
String valueStr = row.getValue().getValue(vector.getName(), "").trim();
136+
if (valueStr.isEmpty()) {
137+
bigIntVector.setNull(row.getIndex());
138+
} else {
139+
bigIntVector.set(row.getIndex(), Long.parseLong(valueStr));
140+
}
141+
});
142+
} else if (vector instanceof Float4Vector) {
143+
Float4Vector float4Vector = (Float4Vector) vector;
144+
IndexedStream.withIndex(rows).forEach(row -> {
145+
String valueStr = row.getValue().getValue(vector.getName(), "").trim();
146+
if (valueStr.isEmpty()) {
147+
float4Vector.setNull(row.getIndex());
148+
} else {
149+
float4Vector.set(row.getIndex(), Float.parseFloat(valueStr));
150+
}
151+
});
152+
} else if (vector instanceof Float8Vector) {
153+
Float8Vector float8Vector = (Float8Vector) vector;
154+
IndexedStream.withIndex(rows).forEach(row -> {
155+
String valueStr = row.getValue().getValue(vector.getName(), "").trim();
156+
if (valueStr.isEmpty()) {
157+
float8Vector.setNull(row.getIndex());
158+
} else {
159+
float8Vector.set(row.getIndex(), Double.parseDouble(valueStr));
160+
}
161+
});
162+
} else {
163+
throw new IllegalStateException(
164+
"No support for reading CSV columns with data type " + vector);
165+
}
166+
vector.setValueCount(rows.size());
167+
});
168+
169+
return new RecordBatch(schema, root.getFieldVectors().stream().map(ArrowFieldVector::new)
170+
.collect(Collectors.toList()));
171+
}
172+
}

glint/src/main/java/co/clflushopt/glint/datasource/DataSource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ public interface DataSource {
3030
* @param projection
3131
* @return `RecordBatch` with only requested projections.
3232
*/
33-
public List<RecordBatch> scan(List<String> projection);
33+
public Iterable<RecordBatch> scan(List<String> projection);
3434
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package co.clflushopt.glint.types;
2+
3+
import org.apache.arrow.vector.BigIntVector;
4+
import org.apache.arrow.vector.BitVector;
5+
import org.apache.arrow.vector.FieldVector;
6+
import org.apache.arrow.vector.Float4Vector;
7+
import org.apache.arrow.vector.Float8Vector;
8+
import org.apache.arrow.vector.IntVector;
9+
import org.apache.arrow.vector.SmallIntVector;
10+
import org.apache.arrow.vector.TinyIntVector;
11+
import org.apache.arrow.vector.VarCharVector;
12+
import org.apache.arrow.vector.types.pojo.ArrowType;
13+
14+
/**
15+
* Wrapper around Arrow FieldVector
16+
*/
17+
public class ArrowFieldVector implements ColumnVector {
18+
private final FieldVector field;
19+
20+
public ArrowFieldVector(FieldVector field) {
21+
this.field = field;
22+
}
23+
24+
@Override
25+
public ArrowType getType() {
26+
if (field instanceof BitVector) {
27+
return ArrowTypes.BooleanType;
28+
} else if (field instanceof TinyIntVector) {
29+
return ArrowTypes.Int8Type;
30+
} else if (field instanceof SmallIntVector) {
31+
return ArrowTypes.Int16Type;
32+
} else if (field instanceof IntVector) {
33+
return ArrowTypes.Int32Type;
34+
} else if (field instanceof BigIntVector) {
35+
return ArrowTypes.Int64Type;
36+
} else if (field instanceof Float4Vector) {
37+
return ArrowTypes.FloatType;
38+
} else if (field instanceof Float8Vector) {
39+
return ArrowTypes.DoubleType;
40+
} else if (field instanceof VarCharVector) {
41+
return ArrowTypes.StringType;
42+
} else {
43+
throw new IllegalStateException("Unsupported field vector type: " + field.getClass());
44+
}
45+
}
46+
47+
@Override
48+
public Object getValue(int i) {
49+
if (field.isNull(i)) {
50+
return null;
51+
}
52+
53+
if (field instanceof BitVector) {
54+
return ((BitVector) field).get(i) == 1;
55+
} else if (field instanceof TinyIntVector) {
56+
return ((TinyIntVector) field).get(i);
57+
} else if (field instanceof SmallIntVector) {
58+
return ((SmallIntVector) field).get(i);
59+
} else if (field instanceof IntVector) {
60+
return ((IntVector) field).get(i);
61+
} else if (field instanceof BigIntVector) {
62+
return ((BigIntVector) field).get(i);
63+
} else if (field instanceof Float4Vector) {
64+
return ((Float4Vector) field).get(i);
65+
} else if (field instanceof Float8Vector) {
66+
return ((Float8Vector) field).get(i);
67+
} else if (field instanceof VarCharVector) {
68+
byte[] bytes = ((VarCharVector) field).get(i);
69+
return bytes == null ? null : new String(bytes);
70+
} else {
71+
throw new IllegalStateException("Unsupported field vector type: " + field.getClass());
72+
}
73+
}
74+
75+
@Override
76+
public int getSize() {
77+
return field.getValueCount();
78+
}
79+
80+
public FieldVector getField() {
81+
return field;
82+
}
83+
}

0 commit comments

Comments
 (0)