Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ToTypeString;
import io.substrait.plan.Plan;
import io.substrait.relation.AbstractWriteRel;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
Expand All @@ -25,6 +26,8 @@
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NamedUpdate;
import io.substrait.relation.NamedWrite;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
Expand Down Expand Up @@ -317,6 +320,92 @@ public EmptyScan emptyScan() {
.build();
}

/**
 * Creates a {@link NamedWrite} over {@code input} with no output remap.
 *
 * <p>Forwards to the private overload with an empty {@link Optional} in the remap position.
 *
 * @param tableName name parts of the table being written
 * @param columnNames column names combined with the record type of {@code input} to form the
 *     table schema
 * @param op write operation set on the relation
 * @param createMode create mode set on the relation
 * @param outputMode output mode set on the relation
 * @param input relation used as the write input
 * @return the built {@link NamedWrite}
 */
public NamedWrite namedWrite(
    Iterable<String> tableName,
    Iterable<String> columnNames,
    AbstractWriteRel.WriteOp op,
    AbstractWriteRel.CreateMode createMode,
    AbstractWriteRel.OutputMode outputMode,
    Rel input) {
  Optional<Rel.Remap> noRemap = Optional.empty();
  return namedWrite(tableName, columnNames, op, createMode, outputMode, input, noRemap);
}

/**
 * Creates a {@link NamedWrite} over {@code input} with an explicit output remap.
 *
 * <p>The remap is wrapped with {@link Optional#of(Object)} and forwarded to the private
 * overload.
 *
 * @param tableName name parts of the table being written
 * @param columnNames column names combined with the record type of {@code input} to form the
 *     table schema
 * @param op write operation set on the relation
 * @param createMode create mode set on the relation
 * @param outputMode output mode set on the relation
 * @param input relation used as the write input
 * @param remap remap set on the relation; must not be {@code null}
 * @return the built {@link NamedWrite}
 * @throws NullPointerException if {@code remap} is {@code null}
 */
public NamedWrite namedWrite(
    Iterable<String> tableName,
    Iterable<String> columnNames,
    AbstractWriteRel.WriteOp op,
    AbstractWriteRel.CreateMode createMode,
    AbstractWriteRel.OutputMode outputMode,
    Rel input,
    Rel.Remap remap) {
  Optional<Rel.Remap> wrappedRemap = Optional.of(remap);
  return namedWrite(tableName, columnNames, op, createMode, outputMode, input, wrappedRemap);
}

/**
 * Shared implementation behind the public {@code namedWrite} overloads.
 *
 * <p>The table schema is not passed in: it is assembled from {@code columnNames} and the record
 * type reported by {@code input}.
 *
 * @param tableName name parts of the table being written
 * @param columnNames column names for the table schema
 * @param op write operation set on the relation
 * @param createMode create mode set on the relation
 * @param outputMode output mode set on the relation
 * @param input relation used as the write input and as the source of the schema's struct type
 * @param remap optional remap set on the relation
 * @return the built {@link NamedWrite}
 */
private NamedWrite namedWrite(
    Iterable<String> tableName,
    Iterable<String> columnNames,
    AbstractWriteRel.WriteOp op,
    AbstractWriteRel.CreateMode createMode,
    AbstractWriteRel.OutputMode outputMode,
    Rel input,
    Optional<Rel.Remap> remap) {
  // Schema = caller-supplied column names over the input's record type.
  NamedStruct tableSchema = NamedStruct.of(columnNames, input.getRecordType());

  return NamedWrite.builder()
      .input(input)
      .names(tableName)
      .tableSchema(tableSchema)
      .operation(op)
      .createMode(createMode)
      .outputMode(outputMode)
      .remap(remap)
      .build();
}

/**
 * Creates a {@link NamedUpdate} with no output remap.
 *
 * <p>Forwards to the private overload with an empty {@link Optional} in the remap position.
 *
 * @param tableName name parts of the table being updated
 * @param columnNames column names for the table schema
 * @param transformations transform expressions set on the relation; their expression types
 *     become the field types of the table schema
 * @param condition condition expression set on the relation
 * @param nullable nullability of the struct used for the table schema
 * @return the built {@link NamedUpdate}
 */
public NamedUpdate namedUpdate(
    Iterable<String> tableName,
    Iterable<String> columnNames,
    List<NamedUpdate.TransformExpression> transformations,
    Expression condition,
    boolean nullable) {
  Optional<Rel.Remap> noRemap = Optional.empty();
  return namedUpdate(tableName, columnNames, transformations, condition, nullable, noRemap);
}

/**
 * Creates a {@link NamedUpdate} with an explicit output remap.
 *
 * <p>The remap is wrapped with {@link Optional#of(Object)} and forwarded to the private
 * overload.
 *
 * @param tableName name parts of the table being updated
 * @param columnNames column names for the table schema
 * @param transformations transform expressions set on the relation; their expression types
 *     become the field types of the table schema
 * @param condition condition expression set on the relation
 * @param nullable nullability of the struct used for the table schema
 * @param remap remap set on the relation; must not be {@code null}
 * @return the built {@link NamedUpdate}
 * @throws NullPointerException if {@code remap} is {@code null}
 */
public NamedUpdate namedUpdate(
    Iterable<String> tableName,
    Iterable<String> columnNames,
    List<NamedUpdate.TransformExpression> transformations,
    Expression condition,
    boolean nullable,
    Rel.Remap remap) {
  Optional<Rel.Remap> wrappedRemap = Optional.of(remap);
  return namedUpdate(tableName, columnNames, transformations, condition, nullable, wrappedRemap);
}

/**
 * Shared implementation behind the public {@code namedUpdate} overloads.
 *
 * <p>The table schema is derived rather than supplied: its field types are the types of the
 * transformation expressions, in list order, wrapped in a struct with the requested nullability
 * and paired with {@code columnNames}.
 *
 * @param tableName name parts of the table being updated
 * @param columnNames column names for the table schema
 * @param transformations transform expressions set on the relation
 * @param condition condition expression set on the relation
 * @param nullable nullability of the struct used for the table schema
 * @param remap optional remap set on the relation
 * @return the built {@link NamedUpdate}
 */
private NamedUpdate namedUpdate(
    Iterable<String> tableName,
    Iterable<String> columnNames,
    List<NamedUpdate.TransformExpression> transformations,
    Expression condition,
    boolean nullable,
    Optional<Rel.Remap> remap) {
  // One schema field per transformation, typed by the transformation's expression.
  List<Type> fieldTypes =
      transformations.stream()
          .map(NamedUpdate.TransformExpression::getTransformation)
          .map(Expression::getType)
          .collect(Collectors.toList());

  NamedStruct tableSchema =
      NamedStruct.of(
          columnNames, Type.Struct.builder().nullable(nullable).fields(fieldTypes).build());

  return NamedUpdate.builder()
      .names(tableName)
      .tableSchema(tableSchema)
      .condition(condition)
      .transformations(transformations)
      .remap(remap)
      .build();
}

/**
 * Creates a {@link Project} over {@code input}.
 *
 * <p>Delegates to the three-argument {@code project} overload, passing an empty {@link Optional}
 * as the middle argument (presumably "no output remap", matching the other builders in this
 * class; confirm against that overload).
 *
 * @param expressionsFn function that receives a {@link Rel} and yields the expressions to project
 * @param input relation to project over
 * @return the {@link Project} returned by the delegated overload
 */
public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
return project(expressionsFn, Optional.empty(), input);
}
Expand Down
42 changes: 41 additions & 1 deletion isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import io.substrait.isthmus.calcite.SubstraitSchema;
import io.substrait.isthmus.calcite.SubstraitTable;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NamedUpdate;
import io.substrait.relation.NamedWrite;
import io.substrait.relation.Rel;
import io.substrait.relation.RelCopyOnWriteVisitor;
import io.substrait.type.NamedStruct;
Expand Down Expand Up @@ -72,7 +74,8 @@ private TableGatherer() {
}

/**
* Gathers all tables defined in {@link NamedScan}s under the given {@link Rel}
* Gathers all tables defined in {@link NamedScan}s and {@link NamedWrite}s under the given
* {@link Rel}
*
* @param rootRel under which to search for {@link NamedScan}s
* @return a map of qualified table names to their associated Substrait schemas
Expand Down Expand Up @@ -100,5 +103,42 @@ public Optional<Rel> visit(NamedScan namedScan, EmptyVisitationContext context)

return Optional.empty();
}

/**
 * Records the table targeted by a {@link NamedWrite} in {@code tableMap}.
 *
 * <p>The superclass visit runs first (its result is discarded). The write's table name and
 * schema are then registered; registering the same name again with an equal schema is allowed.
 *
 * @param namedWrite the write relation being visited
 * @param context visitation context, passed through to the superclass
 * @return always {@link Optional#empty()}
 * @throws IllegalArgumentException if the table name is already registered with a different
 *     schema
 */
@Override
public Optional<Rel> visit(NamedWrite namedWrite, EmptyVisitationContext context) {
  super.visit(namedWrite, context);

  List<String> qualifiedName = namedWrite.getNames();
  NamedStruct writeSchema = namedWrite.getTableSchema();

  // A repeated table is only a problem when the schemas disagree.
  boolean schemaConflict =
      tableMap.containsKey(qualifiedName) && !tableMap.get(qualifiedName).equals(writeSchema);
  if (schemaConflict) {
    throw new IllegalArgumentException(
        String.format(
            "NamedWrite for %s is present multiple times with different schemas", qualifiedName));
  }

  tableMap.put(qualifiedName, writeSchema);
  return Optional.empty();
}

/**
 * Records the table targeted by a {@link NamedUpdate} in {@code tableMap}.
 *
 * <p>The superclass visit runs first (its result is discarded). The update's table name and
 * schema are then registered; registering the same name again with an equal schema is allowed.
 *
 * @param namedUpdate the update relation being visited
 * @param context visitation context, passed through to the superclass
 * @return always {@link Optional#empty()}
 * @throws IllegalArgumentException if the table name is already registered with a different
 *     schema
 */
@Override
public Optional<Rel> visit(NamedUpdate namedUpdate, EmptyVisitationContext context) {
  super.visit(namedUpdate, context);

  List<String> qualifiedName = namedUpdate.getNames();
  NamedStruct updateSchema = namedUpdate.getTableSchema();

  // A repeated table is only a problem when the schemas disagree.
  boolean schemaConflict =
      tableMap.containsKey(qualifiedName) && !tableMap.get(qualifiedName).equals(updateSchema);
  if (schemaConflict) {
    throw new IllegalArgumentException(
        String.format(
            "NamedUpdate for %s is present multiple times with different schemas",
            qualifiedName));
  }

  tableMap.put(qualifiedName, updateSchema);
  return Optional.empty();
}
}
}
86 changes: 76 additions & 10 deletions isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.relation.AbstractWriteRel;
import io.substrait.relation.NamedUpdate;
import io.substrait.relation.Rel;
import io.substrait.type.TypeCreator;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.jdbc.CalciteSchema;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -42,17 +51,23 @@ void canCollectTables() {
@Test
void canCollectTablesInSchemas() {
Rel rel =
b.cross(
b.namedWrite(
List.of("schema3", "table4"),
List.of("col1", "col2", "col3", "col4", "col5", "col6"),
AbstractWriteRel.WriteOp.UPDATE,
AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS,
AbstractWriteRel.OutputMode.MODIFIED_RECORDS,
b.cross(
b.namedScan(
List.of("schema1", "table1"),
List.of("col1", "col2", "col3"),
List.of(N.I64, N.FP64, N.STRING)),
b.namedScan(
List.of("schema1", "table2"),
List.of("col4", "col5"),
List.of(N.BOOLEAN, N.I32))),
b.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64)));
b.cross(
b.namedScan(
List.of("schema1", "table1"),
List.of("col1", "col2", "col3"),
List.of(N.I64, N.FP64, N.STRING)),
b.namedScan(
List.of("schema1", "table2"),
List.of("col4", "col5"),
List.of(N.BOOLEAN, N.I32))),
b.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64))));
CalciteSchema calciteSchema = schemaCollector.toSchema(rel);

CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false);
Expand All @@ -61,6 +76,57 @@ void canCollectTablesInSchemas() {

CalciteSchema schema2 = calciteSchema.getSubSchema("schema2", false);
hasTable(schema2, "table3", "RecordType(BIGINT col6) NOT NULL");

CalciteSchema schema3 = calciteSchema.getSubSchema("schema3", false);
hasTable(
schema3,
"table4",
"RecordType(BIGINT col1, DOUBLE col2, VARCHAR col3, BOOLEAN col4, INTEGER col5, BIGINT col6) NOT NULL");
}

/**
 * Builds a scalar function invocation from the first function named "add" (case-insensitive) in
 * the default extension collection.
 *
 * <p>The invocation's arguments are a reference to struct field 0 typed as required I64 and a
 * non-nullable i32 literal holding {@code value}. The declared output type is required BOOLEAN;
 * {@code testUpdate} asserts BOOLEAN columns for schemas derived from this expression, so the
 * two must change together.
 *
 * @param value literal used as the second argument
 * @return the scalar function invocation
 * @throws IllegalStateException if the default extension collection has no function named "add"
 */
private static Expression.ScalarFunctionInvocation fnAdd(int value) {
  return DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions().stream()
      .filter(s -> s.name().equalsIgnoreCase("add"))
      .findFirst()
      .map(
          declaration ->
              ExpressionCreator.scalarFunction(
                  declaration,
                  TypeCreator.REQUIRED.BOOLEAN,
                  ImmutableFieldReference.builder()
                      .addSegments(FieldReference.StructField.of(0))
                      .type(TypeCreator.REQUIRED.I64)
                      .build(),
                  ExpressionCreator.i32(false, value)))
      // Fail with a message naming the missing function instead of a bare Optional.get() error.
      .orElseThrow(
          () ->
              new IllegalStateException(
                  "No scalar function named \"add\" found in the default extension collection"));
}

/**
 * Checks that the schema collector registers both the table named by a {@link NamedUpdate} and
 * the table named by the write relation that takes that update as input.
 */
@Test
void testUpdate() {
  NamedUpdate.TransformExpression addOneToFirstColumn =
      NamedUpdate.TransformExpression.builder().transformation(fnAdd(1)).columnTarget(0).build();
  List<NamedUpdate.TransformExpression> transformations = Arrays.asList(addOneToFirstColumn);
  Expression updateCondition = ExpressionCreator.bool(false, true);

  // Inner relation: update on schema1.table1.
  Rel update =
      b.namedUpdate(
          List.of("schema1", "table1"), List.of("col1"), transformations, updateCondition, true);

  // Outer relation: write into schema1.table2, fed by the update.
  Rel write =
      b.namedWrite(
          List.of("schema1", "table2"),
          List.of("col1"),
          AbstractWriteRel.WriteOp.INSERT,
          AbstractWriteRel.CreateMode.APPEND_IF_EXISTS,
          AbstractWriteRel.OutputMode.NO_OUTPUT,
          update);

  CalciteSchema collected = schemaCollector.toSchema(write);
  CalciteSchema schema1 = collected.getSubSchema("schema1", false);

  hasTable(schema1, "table1", "RecordType(BOOLEAN col1)");
  hasTable(schema1, "table2", "RecordType(BOOLEAN col1)");
}

@Test
Expand Down