diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index ca6cd4ce3..f211d8237 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -17,6 +17,7 @@ import io.substrait.extension.SimpleExtension; import io.substrait.function.ToTypeString; import io.substrait.plan.Plan; +import io.substrait.relation.AbstractWriteRel; import io.substrait.relation.Aggregate; import io.substrait.relation.Cross; import io.substrait.relation.EmptyScan; @@ -25,6 +26,8 @@ import io.substrait.relation.Filter; import io.substrait.relation.Join; import io.substrait.relation.NamedScan; +import io.substrait.relation.NamedUpdate; +import io.substrait.relation.NamedWrite; import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.relation.Set; @@ -317,6 +320,92 @@ public EmptyScan emptyScan() { .build(); } + public NamedWrite namedWrite( + Iterable tableName, + Iterable columnNames, + AbstractWriteRel.WriteOp op, + AbstractWriteRel.CreateMode createMode, + AbstractWriteRel.OutputMode outputMode, + Rel input) { + return namedWrite(tableName, columnNames, op, createMode, outputMode, input, Optional.empty()); + } + + public NamedWrite namedWrite( + Iterable tableName, + Iterable columnNames, + AbstractWriteRel.WriteOp op, + AbstractWriteRel.CreateMode createMode, + AbstractWriteRel.OutputMode outputMode, + Rel input, + Rel.Remap remap) { + return namedWrite( + tableName, columnNames, op, createMode, outputMode, input, Optional.of(remap)); + } + + private NamedWrite namedWrite( + Iterable tableName, + Iterable columnNames, + AbstractWriteRel.WriteOp op, + AbstractWriteRel.CreateMode createMode, + AbstractWriteRel.OutputMode outputMode, + Rel input, + Optional remap) { + Type.Struct struct = input.getRecordType(); + NamedStruct namedStruct = NamedStruct.of(columnNames, struct); + return NamedWrite.builder() + 
.names(tableName) + .tableSchema(namedStruct) + .operation(op) + .createMode(createMode) + .outputMode(outputMode) + .input(input) + .remap(remap) + .build(); + } + + public NamedUpdate namedUpdate( + Iterable tableName, + Iterable columnNames, + List transformations, + Expression condition, + boolean nullable) { + return namedUpdate( + tableName, columnNames, transformations, condition, nullable, Optional.empty()); + } + + public NamedUpdate namedUpdate( + Iterable tableName, + Iterable columnNames, + List transformations, + Expression condition, + boolean nullable, + Rel.Remap remap) { + return namedUpdate( + tableName, columnNames, transformations, condition, nullable, Optional.of(remap)); + } + + private NamedUpdate namedUpdate( + Iterable tableName, + Iterable columnNames, + List transformations, + Expression condition, + boolean nullable, + Optional remap) { + List types = + transformations.stream() + .map(t -> t.getTransformation().getType()) + .collect(Collectors.toList()); + Type.Struct struct = Type.Struct.builder().fields(types).nullable(nullable).build(); + NamedStruct namedStruct = NamedStruct.of(columnNames, struct); + return NamedUpdate.builder() + .names(tableName) + .tableSchema(namedStruct) + .transformations(transformations) + .condition(condition) + .remap(remap) + .build(); + } + public Project project(Function> expressionsFn, Rel input) { return project(expressionsFn, Optional.empty(), input); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java index 1e8bc3f79..ec05ce1c1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java @@ -3,6 +3,8 @@ import io.substrait.isthmus.calcite.SubstraitSchema; import io.substrait.isthmus.calcite.SubstraitTable; import io.substrait.relation.NamedScan; +import io.substrait.relation.NamedUpdate; +import 
io.substrait.relation.NamedWrite; import io.substrait.relation.Rel; import io.substrait.relation.RelCopyOnWriteVisitor; import io.substrait.type.NamedStruct; @@ -72,7 +74,8 @@ private TableGatherer() { } /** - * Gathers all tables defined in {@link NamedScan}s under the given {@link Rel} + * Gathers all tables defined in {@link NamedScan}s, {@link NamedWrite}s and + * {@link NamedUpdate}s under the given {@link Rel} * * @param rootRel under which to search for {@link NamedScan}s * @return a map of qualified table names to their associated Substrait schemas @@ -100,5 +103,42 @@ public Optional visit(NamedScan namedScan, EmptyVisitationContext context) return Optional.empty(); } + + @Override + public Optional visit(NamedWrite namedWrite, EmptyVisitationContext context) { + super.visit(namedWrite, context); + List tableName = namedWrite.getNames(); + + if (tableMap.containsKey(tableName)) { + NamedStruct existingSchema = tableMap.get(tableName); + if (!existingSchema.equals(namedWrite.getTableSchema())) { + throw new IllegalArgumentException( + String.format( + "NamedWrite for %s is present multiple times with different schemas", tableName)); + } + } + tableMap.put(tableName, namedWrite.getTableSchema()); + + return Optional.empty(); + } + + @Override + public Optional visit(NamedUpdate namedUpdate, EmptyVisitationContext context) { + super.visit(namedUpdate, context); + List tableName = namedUpdate.getNames(); + + if (tableMap.containsKey(tableName)) { + NamedStruct existingSchema = tableMap.get(tableName); + if (!existingSchema.equals(namedUpdate.getTableSchema())) { + throw new IllegalArgumentException( + String.format( + "NamedUpdate for %s is present multiple times with different schemas", + tableName)); + } + } + tableMap.put(tableName, namedUpdate.getTableSchema()); + + return Optional.empty(); + } } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java b/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java index 0524b0674..f094fabfe
100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java @@ -5,7 +5,16 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.relation.AbstractWriteRel; +import io.substrait.relation.NamedUpdate; import io.substrait.relation.Rel; +import io.substrait.type.TypeCreator; +import java.util.Arrays; import java.util.List; import org.apache.calcite.jdbc.CalciteSchema; import org.junit.jupiter.api.Test; @@ -42,17 +51,23 @@ void canCollectTables() { @Test void canCollectTablesInSchemas() { Rel rel = - b.cross( + b.namedWrite( + List.of("schema3", "table4"), + List.of("col1", "col2", "col3", "col4", "col5", "col6"), + AbstractWriteRel.WriteOp.UPDATE, + AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS, + AbstractWriteRel.OutputMode.MODIFIED_RECORDS, b.cross( - b.namedScan( - List.of("schema1", "table1"), - List.of("col1", "col2", "col3"), - List.of(N.I64, N.FP64, N.STRING)), - b.namedScan( - List.of("schema1", "table2"), - List.of("col4", "col5"), - List.of(N.BOOLEAN, N.I32))), - b.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64))); + b.cross( + b.namedScan( + List.of("schema1", "table1"), + List.of("col1", "col2", "col3"), + List.of(N.I64, N.FP64, N.STRING)), + b.namedScan( + List.of("schema1", "table2"), + List.of("col4", "col5"), + List.of(N.BOOLEAN, N.I32))), + b.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64)))); CalciteSchema calciteSchema = schemaCollector.toSchema(rel); CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false); @@ -61,6 +76,57 @@ void canCollectTablesInSchemas() { CalciteSchema 
schema2 = calciteSchema.getSubSchema("schema2", false); hasTable(schema2, "table3", "RecordType(BIGINT col6) NOT NULL"); + + CalciteSchema schema3 = calciteSchema.getSubSchema("schema3", false); + hasTable( + schema3, + "table4", + "RecordType(BIGINT col1, DOUBLE col2, VARCHAR col3, BOOLEAN col4, INTEGER col5, BIGINT col6) NOT NULL"); + } + + private static Expression.ScalarFunctionInvocation fnAdd(int value) { + return DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions().stream() + .filter(s -> s.name().equalsIgnoreCase("add")) + .findFirst() + .map( + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I64) + .build(), + ExpressionCreator.i32(false, value))) + .get(); + } + + @Test + void testUpdate() { + + List transformations = + Arrays.asList( + NamedUpdate.TransformExpression.builder() + .columnTarget(0) + .transformation(fnAdd(1)) + .build()); + Expression condition = ExpressionCreator.bool(false, true); + + Rel rel = + b.namedWrite( + List.of("schema1", "table2"), + List.of("col1"), + AbstractWriteRel.WriteOp.INSERT, + AbstractWriteRel.CreateMode.APPEND_IF_EXISTS, + AbstractWriteRel.OutputMode.NO_OUTPUT, + b.namedUpdate( + List.of("schema1", "table1"), List.of("col1"), transformations, condition, true)); + + CalciteSchema calciteSchema = schemaCollector.toSchema(rel); + CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false); + hasTable(schema1, "table1", "RecordType(BOOLEAN col1)"); + + hasTable(schema1, "table2", "RecordType(BOOLEAN col1)"); } @Test