Skip to content

Commit 034024a

Browse files
committed
feat(isthmus): extend Schema collector for dml
1 parent 0ad0570 commit 034024a

File tree

3 files changed

+206
-11
lines changed

3 files changed

+206
-11
lines changed

core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import io.substrait.extension.SimpleExtension;
1818
import io.substrait.function.ToTypeString;
1919
import io.substrait.plan.Plan;
20+
import io.substrait.relation.AbstractWriteRel;
2021
import io.substrait.relation.Aggregate;
2122
import io.substrait.relation.Cross;
2223
import io.substrait.relation.EmptyScan;
@@ -25,6 +26,8 @@
2526
import io.substrait.relation.Filter;
2627
import io.substrait.relation.Join;
2728
import io.substrait.relation.NamedScan;
29+
import io.substrait.relation.NamedUpdate;
30+
import io.substrait.relation.NamedWrite;
2831
import io.substrait.relation.Project;
2932
import io.substrait.relation.Rel;
3033
import io.substrait.relation.Set;
@@ -317,6 +320,92 @@ public EmptyScan emptyScan() {
317320
.build();
318321
}
319322

323+
public NamedWrite namedWrite(
324+
Iterable<String> tableName,
325+
Iterable<String> columnNames,
326+
AbstractWriteRel.WriteOp op,
327+
AbstractWriteRel.CreateMode createMode,
328+
AbstractWriteRel.OutputMode outputMode,
329+
Rel input) {
330+
return namedWrite(tableName, columnNames, op, createMode, outputMode, input, Optional.empty());
331+
}
332+
333+
public NamedWrite namedWrite(
334+
Iterable<String> tableName,
335+
Iterable<String> columnNames,
336+
AbstractWriteRel.WriteOp op,
337+
AbstractWriteRel.CreateMode createMode,
338+
AbstractWriteRel.OutputMode outputMode,
339+
Rel input,
340+
Rel.Remap remap) {
341+
return namedWrite(
342+
tableName, columnNames, op, createMode, outputMode, input, Optional.of(remap));
343+
}
344+
345+
private NamedWrite namedWrite(
346+
Iterable<String> tableName,
347+
Iterable<String> columnNames,
348+
AbstractWriteRel.WriteOp op,
349+
AbstractWriteRel.CreateMode createMode,
350+
AbstractWriteRel.OutputMode outputMode,
351+
Rel input,
352+
Optional<Rel.Remap> remap) {
353+
Type.Struct struct = input.getRecordType();
354+
NamedStruct namedStruct = NamedStruct.of(columnNames, struct);
355+
return NamedWrite.builder()
356+
.names(tableName)
357+
.tableSchema(namedStruct)
358+
.operation(op)
359+
.createMode(createMode)
360+
.outputMode(outputMode)
361+
.input(input)
362+
.remap(remap)
363+
.build();
364+
}
365+
366+
public NamedUpdate namedUpdate(
367+
Iterable<String> tableName,
368+
Iterable<String> columnNames,
369+
List<NamedUpdate.TransformExpression> transformations,
370+
Expression condition,
371+
boolean nullable) {
372+
return namedUpdate(
373+
tableName, columnNames, transformations, condition, nullable, Optional.empty());
374+
}
375+
376+
public NamedUpdate namedUpdate(
377+
Iterable<String> tableName,
378+
Iterable<String> columnNames,
379+
List<NamedUpdate.TransformExpression> transformations,
380+
Expression condition,
381+
boolean nullable,
382+
Rel.Remap remap) {
383+
return namedUpdate(
384+
tableName, columnNames, transformations, condition, nullable, Optional.of(remap));
385+
}
386+
387+
private NamedUpdate namedUpdate(
388+
Iterable<String> tableName,
389+
Iterable<String> columnNames,
390+
List<NamedUpdate.TransformExpression> transformations,
391+
Expression condition,
392+
boolean nullable,
393+
Optional<Rel.Remap> remap) {
394+
List<Type> types =
395+
transformations.stream()
396+
.map(t -> t.getTransformation().getType())
397+
.collect(Collectors.toList());
398+
Type.Struct struct = Type.Struct.builder().fields(types).nullable(nullable).build();
399+
NamedStruct namedStruct = NamedStruct.of(columnNames, struct);
400+
return NamedUpdate.builder()
401+
.names(tableName)
402+
.tableSchema(namedStruct)
403+
.transformations(transformations)
404+
.condition(condition)
405+
.remap(remap)
406+
.build();
407+
}
408+
320409
public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
321410
return project(expressionsFn, Optional.empty(), input);
322411
}

isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import io.substrait.isthmus.calcite.SubstraitSchema;
44
import io.substrait.isthmus.calcite.SubstraitTable;
55
import io.substrait.relation.NamedScan;
6+
import io.substrait.relation.NamedUpdate;
7+
import io.substrait.relation.NamedWrite;
68
import io.substrait.relation.Rel;
79
import io.substrait.relation.RelCopyOnWriteVisitor;
810
import io.substrait.type.NamedStruct;
@@ -72,7 +74,8 @@ private TableGatherer() {
7274
}
7375

7476
/**
75-
* Gathers all tables defined in {@link NamedScan}s under the given {@link Rel}
77+
* Gathers all tables defined in {@link NamedScan}s, {@link NamedWrite}s and
78+
* {@link NamedUpdate}s under the given {@link Rel}
7679
*
7780
* @param rootRel under which to search for {@link NamedScan}s, {@link NamedWrite}s and {@link NamedUpdate}s
7881
* @return a map of qualified table names to their associated Substrait schemas
@@ -100,5 +103,42 @@ public Optional<Rel> visit(NamedScan namedScan, EmptyVisitationContext context)
100103

101104
return Optional.empty();
102105
}
106+
107+
@Override
108+
public Optional<Rel> visit(NamedWrite namedWrite, EmptyVisitationContext context) {
109+
super.visit(namedWrite, context);
110+
List<String> tableName = namedWrite.getNames();
111+
112+
if (tableMap.containsKey(tableName)) {
113+
NamedStruct existingSchema = tableMap.get(tableName);
114+
if (!existingSchema.equals(namedWrite.getTableSchema())) {
115+
throw new IllegalArgumentException(
116+
String.format(
117+
"NamedWrite for %s is present multiple times with different schemas", tableName));
118+
}
119+
}
120+
tableMap.put(tableName, namedWrite.getTableSchema());
121+
122+
return Optional.empty();
123+
}
124+
125+
@Override
126+
public Optional<Rel> visit(NamedUpdate namedUpdate, EmptyVisitationContext context) {
127+
super.visit(namedUpdate, context);
128+
List<String> tableName = namedUpdate.getNames();
129+
130+
if (tableMap.containsKey(tableName)) {
131+
NamedStruct existingSchema = tableMap.get(tableName);
132+
if (!existingSchema.equals(namedUpdate.getTableSchema())) {
133+
throw new IllegalArgumentException(
134+
String.format(
135+
"NamedUpdate for %s is present multiple times with different schemas",
136+
tableName));
137+
}
138+
}
139+
tableMap.put(tableName, namedUpdate.getTableSchema());
140+
141+
return Optional.empty();
142+
}
103143
}
104144
}

isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
import static org.junit.jupiter.api.Assertions.assertThrows;
66

77
import io.substrait.dsl.SubstraitBuilder;
8+
import io.substrait.expression.Expression;
9+
import io.substrait.expression.ExpressionCreator;
10+
import io.substrait.expression.FieldReference;
11+
import io.substrait.expression.ImmutableFieldReference;
12+
import io.substrait.extension.DefaultExtensionCatalog;
13+
import io.substrait.relation.AbstractWriteRel;
14+
import io.substrait.relation.NamedUpdate;
815
import io.substrait.relation.Rel;
16+
import io.substrait.type.TypeCreator;
17+
import java.util.Arrays;
918
import java.util.List;
1019
import org.apache.calcite.jdbc.CalciteSchema;
1120
import org.junit.jupiter.api.Test;
@@ -42,17 +51,23 @@ void canCollectTables() {
4251
@Test
4352
void canCollectTablesInSchemas() {
4453
Rel rel =
45-
b.cross(
54+
b.namedWrite(
55+
List.of("schema3", "table4"),
56+
List.of("col1", "col2", "col3", "col4", "col5", "col6"),
57+
AbstractWriteRel.WriteOp.UPDATE,
58+
AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS,
59+
AbstractWriteRel.OutputMode.MODIFIED_RECORDS,
4660
b.cross(
47-
b.namedScan(
48-
List.of("schema1", "table1"),
49-
List.of("col1", "col2", "col3"),
50-
List.of(N.I64, N.FP64, N.STRING)),
51-
b.namedScan(
52-
List.of("schema1", "table2"),
53-
List.of("col4", "col5"),
54-
List.of(N.BOOLEAN, N.I32))),
55-
b.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64)));
61+
b.cross(
62+
b.namedScan(
63+
List.of("schema1", "table1"),
64+
List.of("col1", "col2", "col3"),
65+
List.of(N.I64, N.FP64, N.STRING)),
66+
b.namedScan(
67+
List.of("schema1", "table2"),
68+
List.of("col4", "col5"),
69+
List.of(N.BOOLEAN, N.I32))),
70+
b.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64))));
5671
CalciteSchema calciteSchema = schemaCollector.toSchema(rel);
5772

5873
CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false);
@@ -61,6 +76,57 @@ void canCollectTablesInSchemas() {
6176

6277
CalciteSchema schema2 = calciteSchema.getSubSchema("schema2", false);
6378
hasTable(schema2, "table3", "RecordType(BIGINT col6) NOT NULL");
79+
80+
CalciteSchema schema3 = calciteSchema.getSubSchema("schema3", false);
81+
hasTable(
82+
schema3,
83+
"table4",
84+
"RecordType(BIGINT col1, DOUBLE col2, VARCHAR col3, BOOLEAN col4, INTEGER col5, BIGINT col6) NOT NULL");
85+
}
86+
87+
private static Expression.ScalarFunctionInvocation fnAdd(int value) {
88+
return DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions().stream()
89+
.filter(s -> s.name().equalsIgnoreCase("add"))
90+
.findFirst()
91+
.map(
92+
declaration ->
93+
ExpressionCreator.scalarFunction(
94+
declaration,
95+
TypeCreator.REQUIRED.BOOLEAN,
96+
ImmutableFieldReference.builder()
97+
.addSegments(FieldReference.StructField.of(0))
98+
.type(TypeCreator.REQUIRED.I64)
99+
.build(),
100+
ExpressionCreator.i32(false, value)))
101+
.get();
102+
}
103+
104+
@Test
105+
void testUpdate() {
106+
107+
List<NamedUpdate.TransformExpression> transformations =
108+
Arrays.asList(
109+
NamedUpdate.TransformExpression.builder()
110+
.columnTarget(0)
111+
.transformation(fnAdd(1))
112+
.build());
113+
Expression condition = ExpressionCreator.bool(false, true);
114+
115+
Rel rel =
116+
b.namedWrite(
117+
List.of("schema1", "table2"),
118+
List.of("col1"),
119+
AbstractWriteRel.WriteOp.INSERT,
120+
AbstractWriteRel.CreateMode.APPEND_IF_EXISTS,
121+
AbstractWriteRel.OutputMode.NO_OUTPUT,
122+
b.namedUpdate(
123+
List.of("schema1", "table1"), List.of("col1"), transformations, condition, true));
124+
125+
CalciteSchema calciteSchema = schemaCollector.toSchema(rel);
126+
CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false);
127+
hasTable(schema1, "table1", "RecordType(BOOLEAN col1)");
128+
129+
hasTable(schema1, "table2", "RecordType(BOOLEAN col1)");
64130
}
65131

66132
@Test

0 commit comments

Comments
 (0)