From fe4d98f010024f3e6185d1165f4a3cd7bbc056ce Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Thu, 23 Oct 2025 20:10:18 +0200 Subject: [PATCH 1/4] fix(core): close AdvancedExtension serde gaps Signed-off-by: Niels Pardon --- .../io/substrait/plan/PlanProtoConverter.java | 4 +- .../io/substrait/plan/ProtoPlanConverter.java | 2 +- .../io/substrait/relation/AbstractDdlRel.java | 2 +- .../substrait/relation/AbstractWriteRel.java | 2 +- .../io/substrait/relation/ExtensionWrite.java | 2 +- .../io/substrait/relation/NamedWrite.java | 2 +- .../substrait/relation/ProtoRelConverter.java | 95 +++++++++++-------- .../substrait/relation/RelProtoConverter.java | 31 ++++-- .../io/substrait/plan/PlanConverterTest.java | 36 +++++++ 9 files changed, 125 insertions(+), 51 deletions(-) diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index 6476176c6..ea1c39118 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -65,7 +65,9 @@ public Plan toProto(io.substrait.plan.Plan plan) { List planRels = new ArrayList<>(); ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection); for (io.substrait.plan.Plan.Root root : plan.getRoots()) { - Rel input = new RelProtoConverter(functionCollector).toProto(root.getInput()); + Rel input = + new RelProtoConverter(functionCollector, extensionProtoConverter) + .toProto(root.getInput()); planRels.add( PlanRel.newBuilder() .setRoot( diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index 0de8476db..3a98438ed 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -63,7 +63,7 @@ public ProtoPlanConverter( /** Override hook for providing custom {@link ProtoRelConverter} implementations */ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) { - return new ProtoRelConverter(functionLookup, this.extensionCollection); + return new ProtoRelConverter(functionLookup, this.extensionCollection, protoExtensionConverter); } public Plan from(io.substrait.proto.Plan plan) { diff --git a/core/src/main/java/io/substrait/relation/AbstractDdlRel.java b/core/src/main/java/io/substrait/relation/AbstractDdlRel.java index 8d7947096..54f2504bc 100644 --- a/core/src/main/java/io/substrait/relation/AbstractDdlRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractDdlRel.java @@ -6,7 +6,7 @@ import io.substrait.type.Type; import java.util.Optional; -public abstract class AbstractDdlRel extends ZeroInputRel { +public abstract class AbstractDdlRel extends ZeroInputRel implements HasExtension { public abstract NamedStruct getTableSchema(); public abstract Expression.StructLiteral getTableDefaults(); diff --git a/core/src/main/java/io/substrait/relation/AbstractWriteRel.java b/core/src/main/java/io/substrait/relation/AbstractWriteRel.java index a43db1cbf..d035e62c8 100644 --- a/core/src/main/java/io/substrait/relation/AbstractWriteRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractWriteRel.java @@ -4,7 +4,7 @@ import io.substrait.type.NamedStruct; import io.substrait.type.Type; -public abstract class AbstractWriteRel extends SingleInputRel { +public abstract class AbstractWriteRel extends SingleInputRel implements HasExtension { public abstract NamedStruct getTableSchema(); diff --git a/core/src/main/java/io/substrait/relation/ExtensionWrite.java b/core/src/main/java/io/substrait/relation/ExtensionWrite.java index db591453b..72e39e84e 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionWrite.java +++ b/core/src/main/java/io/substrait/relation/ExtensionWrite.java @@ -4,7 +4,7 @@ import org.immutables.value.Value; @Value.Immutable -public abstract class ExtensionWrite extends AbstractWriteRel implements HasExtension { +public abstract class ExtensionWrite extends AbstractWriteRel { public abstract Extension.WriteExtensionObject getDetail(); @Override diff --git a/core/src/main/java/io/substrait/relation/NamedWrite.java b/core/src/main/java/io/substrait/relation/NamedWrite.java index e46f9b3cb..bfe087d0f 100644 --- a/core/src/main/java/io/substrait/relation/NamedWrite.java +++ b/core/src/main/java/io/substrait/relation/NamedWrite.java @@ -5,7 +5,7 @@ import org.immutables.value.Value; @Value.Immutable -public abstract class NamedWrite extends AbstractWriteRel implements HasExtension { +public abstract class NamedWrite extends AbstractWriteRel { public abstract List getNames(); @Override diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 85310cb88..7ef943598 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -34,6 +34,7 @@ import io.substrait.proto.SortRel; import io.substrait.proto.UpdateRel; import io.substrait.proto.WriteRel; +import io.substrait.relation.ImmutableExtensionDdl.Builder; import io.substrait.relation.extensions.EmptyDetail; import io.substrait.relation.files.FileFormat; import io.substrait.relation.files.FileOrFiles; @@ -175,8 +176,8 @@ protected Rel newRead(ReadRel rel) { } } - protected Rel newWrite(WriteRel rel) { - WriteRel.WriteTypeCase relType = rel.getWriteTypeCase(); + protected Rel newWrite(final WriteRel rel) { + final WriteRel.WriteTypeCase relType = rel.getWriteTypeCase(); switch (relType) { case NAMED_TABLE: return newNamedWrite(rel); @@ -187,9 +188,9 @@ protected Rel newWrite(WriteRel rel) { } } - protected NamedWrite newNamedWrite(WriteRel rel) { - Rel input = from(rel.getInput()); - ImmutableNamedWrite.Builder builder = + protected NamedWrite newNamedWrite(final WriteRel rel) { + final Rel input = from(rel.getInput()); + final ImmutableNamedWrite.Builder builder = NamedWrite.builder() .input(input) .names(rel.getNamedTable().getNamesList()) @@ -202,14 +203,18 @@ protected NamedWrite newNamedWrite(WriteRel rel) { .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) .hint(optionalHint(rel.getCommon())); + + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } return builder.build(); } - protected Rel newExtensionWrite(WriteRel rel) { - Rel input = from(rel.getInput()); - Extension.WriteExtensionObject detail = + protected Rel newExtensionWrite(final WriteRel rel) { + final Rel input = from(rel.getInput()); + final Extension.WriteExtensionObject detail = detailFromWriteExtensionObject(rel.getExtensionTable().getDetail()); - ImmutableExtensionWrite.Builder builder = + final ImmutableExtensionWrite.Builder builder = ExtensionWrite.builder() .input(input) .detail(detail) @@ -222,11 +227,15 @@ protected Rel newExtensionWrite(WriteRel rel) { .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) .hint(optionalHint(rel.getCommon())); + + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } return builder.build(); } - protected Rel newDdl(DdlRel rel) { - DdlRel.WriteTypeCase relType = rel.getWriteTypeCase(); + protected Rel newDdl(final DdlRel rel) { + final DdlRel.WriteTypeCase relType = rel.getWriteTypeCase(); switch (relType) { case NAMED_OBJECT: return newNamedDdl(rel); @@ -238,35 +247,47 @@ protected Rel newDdl(DdlRel rel) { } protected NamedDdl newNamedDdl(DdlRel rel) { - NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); - return NamedDdl.builder() - .names(rel.getNamedObject().getNamesList()) - .tableSchema(tableSchema) - .tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema)) - .operation(NamedDdl.DdlOp.fromProto(rel.getOp())) - .object(NamedDdl.DdlObject.fromProto(rel.getObject())) - .viewDefinition(optionalViewDefinition(rel)) - .commonExtension(optionalAdvancedExtension(rel.getCommon())) - .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())) - .build(); + final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); + final ImmutableNamedDdl.Builder builder = + NamedDdl.builder() + .names(rel.getNamedObject().getNamesList()) + .tableSchema(tableSchema) + .tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema)) + .operation(NamedDdl.DdlOp.fromProto(rel.getOp())) + .object(NamedDdl.DdlObject.fromProto(rel.getObject())) + .viewDefinition(optionalViewDefinition(rel)) + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } + + return builder.build(); } - protected ExtensionDdl newExtensionDdl(DdlRel rel) { - Extension.DdlExtensionObject detail = + protected ExtensionDdl newExtensionDdl(final DdlRel rel) { + final Extension.DdlExtensionObject detail = detailFromDdlExtensionObject(rel.getExtensionObject().getDetail()); - NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); - return ExtensionDdl.builder() - .detail(detail) - .tableSchema(newNamedStruct(rel.getTableSchema())) - .tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema)) - .operation(ExtensionDdl.DdlOp.fromProto(rel.getOp())) - .object(ExtensionDdl.DdlObject.fromProto(rel.getObject())) - .viewDefinition(optionalViewDefinition(rel)) - .commonExtension(optionalAdvancedExtension(rel.getCommon())) - .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())) - .build(); + final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); + final Builder builder = + ExtensionDdl.builder() + .detail(detail) + .tableSchema(newNamedStruct(rel.getTableSchema())) + .tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema)) + .operation(ExtensionDdl.DdlOp.fromProto(rel.getOp())) + .object(ExtensionDdl.DdlObject.fromProto(rel.getObject())) + .viewDefinition(optionalViewDefinition(rel)) + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } + + return builder.build(); } protected Optional optionalViewDefinition(DdlRel rel) { diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 91a45e798..05858d831 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -202,14 +202,15 @@ private AggregateRel.Grouping toProto(Aggregate.Grouping grouping) { @Override public Rel visit(EmptyScan emptyScan, EmptyVisitationContext context) throws RuntimeException { - return Rel.newBuilder() - .setRead( - ReadRel.newBuilder() - .setCommon(common(emptyScan)) - .setVirtualTable(ReadRel.VirtualTable.newBuilder().build()) - .setBaseSchema(emptyScan.getInitialSchema().toProto(typeProtoConverter)) - .build()) - .build(); + ReadRel.Builder builder = + ReadRel.newBuilder() + .setCommon(common(emptyScan)) + .setVirtualTable(ReadRel.VirtualTable.newBuilder().build()) + .setBaseSchema(emptyScan.getInitialSchema().toProto(typeProtoConverter)); + emptyScan + .getExtension() + .ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae))); + return Rel.newBuilder().setRead(builder.build()).build(); } @Override @@ -435,6 +436,10 @@ public Rel visit(NamedWrite write, EmptyVisitationContext context) throws Runtim .setCreateMode(write.getCreateMode().toProto()) .setOutput(write.getOutputMode().toProto()); + write + .getExtension() + .ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae))); + return Rel.newBuilder().setWrite(builder).build(); } @@ -451,6 +456,10 @@ public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws Ru .setCreateMode(write.getCreateMode().toProto()) .setOutput(write.getOutputMode().toProto()); + write + .getExtension() + .ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae))); + return Rel.newBuilder().setWrite(builder).build(); } @@ -468,6 +477,9 @@ public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeExc builder.setViewDefinition(toProto(ddl.getViewDefinition().get())); } + ddl.getExtension() + .ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae))); + return Rel.newBuilder().setDdl(builder).build(); } @@ -486,6 +498,9 @@ public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws Runtim builder.setViewDefinition(toProto(ddl.getViewDefinition().get())); } + ddl.getExtension() + .ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae))); + return Rel.newBuilder().setDdl(builder).build(); } diff --git a/core/src/test/java/io/substrait/plan/PlanConverterTest.java b/core/src/test/java/io/substrait/plan/PlanConverterTest.java index 082187448..dd49cf207 100644 --- a/core/src/test/java/io/substrait/plan/PlanConverterTest.java +++ b/core/src/test/java/io/substrait/plan/PlanConverterTest.java @@ -4,6 +4,10 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.extension.AdvancedExtension; +import io.substrait.plan.Plan.Root; +import io.substrait.relation.EmptyScan; +import io.substrait.type.NamedStruct; +import io.substrait.type.TypeCreator; import io.substrait.utils.StringHolder; import io.substrait.utils.StringHolderHandlingExtensionProtoConverter; import io.substrait.utils.StringHolderHandlingProtoExtensionConverter; @@ -153,4 +157,36 @@ void advancedExtensionWithEnhancementAndOptimization() { assertEquals(plan, plan2); } + + @Test + void planIncludingRelationWithAdvancedExtension() { + final StringHolder enhanced = new StringHolder("ENHANCED"); + final StringHolder optimized = new StringHolder("OPTIMIZED"); + + final Plan plan = + Plan.builder() + .addRoots( + Root.builder() + .input( + EmptyScan.builder() + .initialSchema( + NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build()) + .extension( + AdvancedExtension.builder() + .enhancement(enhanced) + .addOptimizations(optimized) + .build()) + .build()) + .build()) + .build(); + final PlanProtoConverter toProtoConverter = + new PlanProtoConverter(new StringHolderHandlingExtensionProtoConverter()); + final io.substrait.proto.Plan protoPlan = toProtoConverter.toProto(plan); + + final ProtoPlanConverter fromProtoConverter = + new ProtoPlanConverter(new StringHolderHandlingProtoExtensionConverter()); + final Plan plan2 = fromProtoConverter.from(protoPlan); + + assertEquals(plan, plan2); + } } From ccb6091692fa9f034d494fa7f483a242108dc647 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Fri, 24 Oct 2025 11:54:05 +0200 Subject: [PATCH 2/4] fix: add test cases Signed-off-by: Niels Pardon --- .../io/substrait/dsl/SubstraitBuilder.java | 20 + .../substrait/relation/ProtoRelConverter.java | 16 +- ...vancedExtensionRelProtoConversionTest.java | 402 ++++++++++++++++++ 3 files changed, 437 insertions(+), 1 deletion(-) create mode 100644 core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index f211d8237..d47d075d6 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -19,6 +19,7 @@ import io.substrait.plan.Plan; import io.substrait.relation.AbstractWriteRel; import io.substrait.relation.Aggregate; +import io.substrait.relation.Aggregate.Measure; import io.substrait.relation.Cross; import io.substrait.relation.EmptyScan; import io.substrait.relation.Expand; @@ -604,6 +605,25 @@ public Aggregate.Measure count(Rel input, int field) { .build()); } + /** + * Returns a {@link Measure} representing the equivalent of a {@code COUNT(*)} SQL aggregation. + * + * @return the {@link Measure} representing {@code COUNT(*)} + */ + public Measure countStar() { + SimpleExtension.AggregateFunctionVariant declaration = + extensions.getAggregateFunction( + SimpleExtension.FunctionAnchor.of( + DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:")); + return measure( + AggregateFunctionInvocation.builder() + .outputType(R.I64) + .declaration(declaration) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()); + } + public Aggregate.Measure min(Rel input, int field) { return min(fieldReference(input, field)); } diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 7ef943598..3b76aaf71 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -78,6 +78,18 @@ public ProtoRelConverter( this(lookup, extensions, new ProtoExtensionConverter()); } + /** + * Constructor with custom {@link ExtensionLookup} and {@link ProtoExtensionConverter}. + * + * @param lookup custom {@link ExtensionLookup} to use, must not be null + * @param protoExtensionConverter custom {@link ProtoExtensionConverter} to use, must not be null + */ + public ProtoRelConverter( + @NonNull final ExtensionLookup lookup, + @NonNull final ProtoExtensionConverter protoExtensionConverter) { + this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION, protoExtensionConverter); + } + /** * Constructor with custom {@link ExtensionLookup}, {@link ExtensionCollection} and {@link * ProtoExtensionConverter}. @@ -475,9 +487,11 @@ protected NamedScan newNamedScan(ReadRel rel) { } protected ExtensionTable newExtensionTable(ReadRel rel) { + NamedStruct namedStruct = newNamedStruct(rel); Extension.ExtensionTableDetail detail = detailFromExtensionTable(rel.getExtensionTable().getDetail()); - ImmutableExtensionTable.Builder builder = ExtensionTable.from(detail); + ImmutableExtensionTable.Builder builder = + ExtensionTable.from(detail).initialSchema(namedStruct); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) diff --git a/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java b/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java new file mode 100644 index 000000000..58d5e9728 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java @@ -0,0 +1,402 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression.SortDirection; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableExpression.BoolLiteral; +import io.substrait.expression.ImmutableExpression.SortField; +import io.substrait.expression.ImmutableExpression.StructLiteral; +import io.substrait.relation.AbstractDdlRel.DdlObject; +import io.substrait.relation.AbstractDdlRel.DdlOp; +import io.substrait.relation.AbstractUpdate.TransformExpression; +import io.substrait.relation.AbstractWriteRel.CreateMode; +import io.substrait.relation.AbstractWriteRel.OutputMode; +import io.substrait.relation.AbstractWriteRel.WriteOp; +import io.substrait.relation.Aggregate; +import io.substrait.relation.Cross; +import io.substrait.relation.EmptyScan; +import io.substrait.relation.ExtensionDdl; +import io.substrait.relation.ExtensionTable; +import io.substrait.relation.ExtensionWrite; +import io.substrait.relation.Fetch; +import io.substrait.relation.Filter; +import io.substrait.relation.Join; +import io.substrait.relation.Join.JoinType; +import io.substrait.relation.LocalFiles; +import io.substrait.relation.NamedDdl; +import io.substrait.relation.NamedScan; +import io.substrait.relation.NamedUpdate; +import io.substrait.relation.NamedWrite; +import io.substrait.relation.Project; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; +import io.substrait.relation.Set; +import io.substrait.relation.Set.SetOp; +import io.substrait.relation.Sort; +import io.substrait.relation.VirtualTableScan; +import io.substrait.relation.extensions.EmptyDetail; +import io.substrait.type.NamedStruct; +import io.substrait.type.TypeCreator; +import io.substrait.util.EmptyVisitationContext; +import io.substrait.utils.StringHolder; +import io.substrait.utils.StringHolderHandlingExtensionProtoConverter; +import io.substrait.utils.StringHolderHandlingProtoExtensionConverter; +import org.junit.jupiter.api.Test; + +class AdvancedExtensionRelProtoConversionTest { + final SubstraitBuilder builder = new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION); + + final StringHolder enhanced = new StringHolder("ENHANCED"); + final StringHolder optimized = new StringHolder("OPTIMIZED"); + final AdvancedExtension extension = + AdvancedExtension.builder().enhancement(enhanced).addOptimizations(optimized).build(); + + private void assertRoundTrip(final Rel rel) { + // assert AdvancedExtension serialization to proto + final ExtensionCollector functionCollector = new ExtensionCollector(); + final RelProtoConverter relProtoConverter = + new RelProtoConverter(functionCollector, new StringHolderHandlingExtensionProtoConverter()); + final io.substrait.proto.Rel protoRel = + rel.accept(relProtoConverter, EmptyVisitationContext.INSTANCE); + + // assert AdvancedExtension deserialization from proto + final ProtoRelConverter protoRelConverter = + new ProtoRelConverter(functionCollector, new StringHolderHandlingProtoExtensionConverter()); + final Rel rel2 = protoRelConverter.from(protoRel); + + assertEquals(rel, rel2); + } + + @Test + void testVirtualTableConversionRoundtrip() throws Exception { + final VirtualTableScan rel = + VirtualTableScan.builder() + .extension(extension) + .addRows( + StructLiteral.builder() + .addFields(BoolLiteral.builder().value(true).build()) + .build()) + .initialSchema( + NamedStruct.builder() + .addNames("IS_TRUE") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build()) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testLocalFilesConversionRoundtrip() throws Exception { + final LocalFiles rel = + LocalFiles.builder() + .initialSchema(NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build()) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testNamedScanConversionRoundtrip() throws Exception { + final NamedScan rel = + NamedScan.builder() + .addNames("CUSTOMER") + .initialSchema(NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build()) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testExtensionTableConversionRoundtrip() throws Exception { + final ExtensionTable rel = + ExtensionTable.builder() + .detail(new EmptyDetail()) + .initialSchema(NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build()) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testFilterRelConversionRoundtrip() throws Exception { + final Filter rel = + Filter.builder() + .input( + EmptyScan.builder() + .initialSchema( + NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build()) + .build()) + .condition(ExpressionCreator.bool(false, true)) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testFetchRelConversionRoundtrip() throws Exception { + final Fetch rel = + Fetch.builder() + .input( + EmptyScan.builder() + .initialSchema( + NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build()) + .build()) + .offset(0) + .count(10) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testAggregateRelConversionRoundtrip() throws Exception { + final Aggregate rel = + Aggregate.builder() + .input( + EmptyScan.builder() + .initialSchema( + NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build()) + .build()) + .addMeasures(builder.countStar()) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testSortRelConversionRoundtrip() throws Exception { + final EmptyScan scan = + EmptyScan.builder() + .initialSchema( + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build()) + .build(); + final Sort rel = + Sort.builder() + .input(scan) + .addSortFields( + SortField.builder() + .direction(SortDirection.ASC_NULLS_FIRST) + .expr(builder.fieldReference(scan, 0)) + .build()) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testJoinRelConversionRoundtrip() throws Exception { + final EmptyScan scan = + EmptyScan.builder() + .initialSchema( + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build()) + .build(); + final Join rel = + Join.builder() + .left(scan) + .right(scan) + .joinType(JoinType.INNER) + .condition( + builder.equal(builder.fieldReference(scan, 0), builder.fieldReference(scan, 0))) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testProjectRelConversionRoundtrip() throws Exception { + final EmptyScan scan = + EmptyScan.builder() + .initialSchema( + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build()) + .build(); + final Project rel = + Project.builder() + .input(scan) + .addExpressions(builder.fieldReference(scan, 0)) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testSetRelConversionRoundtrip() throws Exception { + final EmptyScan scan = + EmptyScan.builder() + .initialSchema( + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build()) + .build(); + final Set rel = + Set.builder() + .addInputs(scan, scan) + .setOp(SetOp.UNION_DISTINCT) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testCrossRelConversionRoundtrip() throws Exception { + final EmptyScan scan = + EmptyScan.builder() + .initialSchema( + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build()) + .build(); + final Cross rel = Cross.builder().left(scan).right(scan).extension(extension).build(); + + assertRoundTrip(rel); + } + + @Test + void testNamedObjectWriteConversionRoundtrip() throws Exception { + final NamedStruct schema = + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build(); + final EmptyScan scan = EmptyScan.builder().initialSchema(schema).build(); + final NamedWrite rel = + NamedWrite.builder() + .createMode(CreateMode.REPLACE_IF_EXISTS) + .operation(WriteOp.INSERT) + .outputMode(OutputMode.NO_OUTPUT) + .tableSchema(schema) + .input(scan) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testExtensionWriteRelConversionRoundtrip() throws Exception { + final NamedStruct schema = + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build(); + final EmptyScan scan = EmptyScan.builder().initialSchema(schema).build(); + final ExtensionWrite rel = + ExtensionWrite.builder() + .createMode(CreateMode.REPLACE_IF_EXISTS) + .operation(WriteOp.INSERT) + .outputMode(OutputMode.NO_OUTPUT) + .tableSchema(schema) + .detail(new EmptyDetail()) + .input(scan) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testNamedDdlRelConversionRoundtrip() throws Exception { + final NamedStruct schema = + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build(); + final EmptyScan scan = EmptyScan.builder().initialSchema(schema).build(); + final NamedDdl rel = + NamedDdl.builder() + .addNames("CUSTOMER") + .operation(DdlOp.CREATE) + .tableSchema(schema) + .viewDefinition(scan) + .tableDefaults( + StructLiteral.builder() + .nullable(false) + .addFields(ExpressionCreator.bool(false, false)) + .build()) + .object(DdlObject.VIEW) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testExtensionDdlRelConversionRoundtrip() throws Exception { + final NamedStruct schema = + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build(); + final EmptyScan scan = EmptyScan.builder().initialSchema(schema).build(); + final ExtensionDdl rel = + ExtensionDdl.builder() + .detail(new EmptyDetail()) + .operation(DdlOp.CREATE) + .tableSchema(schema) + .viewDefinition(scan) + .tableDefaults( + StructLiteral.builder() + .nullable(false) + .addFields(ExpressionCreator.bool(false, false)) + .build()) + .object(DdlObject.VIEW) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } + + @Test + void testNamedUpdateRelConversionRoundtrip() throws Exception { + final NamedStruct schema = + NamedStruct.builder() + .addNames("KEY") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.BOOLEAN)) + .build(); + final NamedUpdate rel = + NamedUpdate.builder() + .addNames("CUSTOMER") + .tableSchema(schema) + .condition( + builder.equal( + FieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.BOOLEAN) + .build(), + builder.bool(true))) + .addTransformations( + TransformExpression.builder() + .columnTarget(0) + .transformation(ExpressionCreator.bool(false, false)) + .build()) + .extension(extension) + .build(); + + assertRoundTrip(rel); + } +} From 6361a404f3f5fed99288e43b8f4748bb8ab47ddd Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Mon, 27 Oct 2025 15:57:29 +0100 Subject: [PATCH 3/4] fix: address comments Signed-off-by: Niels Pardon --- .../src/main/java/io/substrait/relation/ProtoRelConverter.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 3b76aaf71..c3f7a4767 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -34,7 +34,6 @@ import io.substrait.proto.SortRel; import io.substrait.proto.UpdateRel; import io.substrait.proto.WriteRel; -import io.substrait.relation.ImmutableExtensionDdl.Builder; import io.substrait.relation.extensions.EmptyDetail; import io.substrait.relation.files.FileFormat; import io.substrait.relation.files.FileOrFiles; @@ -283,7 +282,7 @@ protected ExtensionDdl newExtensionDdl(final DdlRel rel) { final Extension.DdlExtensionObject detail = detailFromDdlExtensionObject(rel.getExtensionObject().getDetail()); final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); - final Builder builder = + final ImmutableExtensionDdl.Builder builder = ExtensionDdl.builder() .detail(detail) .tableSchema(newNamedStruct(rel.getTableSchema())) From b6631126a5187fde4fbe0713b9ccaf53b3b1a82e Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Mon, 27 Oct 2025 16:44:24 +0100 Subject: [PATCH 4/4] fix: more consistent final declaration in changed code Signed-off-by: Niels Pardon --- .../java/io/substrait/dsl/SubstraitBuilder.java | 2 +- .../java/io/substrait/plan/PlanProtoConverter.java | 14 +++++++------- .../java/io/substrait/plan/ProtoPlanConverter.java | 2 +- .../io/substrait/relation/ProtoRelConverter.java | 10 +++++----- .../io/substrait/relation/RelProtoConverter.java | 5 +++-- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index d47d075d6..4ac5af818 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -611,7 +611,7 @@ public Aggregate.Measure count(Rel input, int field) { * @return the {@link Measure} representing {@code COUNT(*)} */ public Measure countStar() { - SimpleExtension.AggregateFunctionVariant declaration = + final SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:")); diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index ea1c39118..03b2a4227 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -61,11 +61,11 @@ public PlanProtoConverter( this.extensionProtoConverter = extensionProtoConverter; } - public Plan toProto(io.substrait.plan.Plan plan) { - List planRels = new ArrayList<>(); - ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection); - for (io.substrait.plan.Plan.Root root : plan.getRoots()) { - Rel input = + public Plan toProto(final io.substrait.plan.Plan plan) { + final List planRels = new ArrayList<>(); + final ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection); + for (final io.substrait.plan.Plan.Root root : plan.getRoots()) { + final Rel input = new RelProtoConverter(functionCollector, extensionProtoConverter) .toProto(root.getInput()); planRels.add( @@ -76,7 +76,7 @@ public Plan toProto(io.substrait.plan.Plan plan) { .addAllNames(root.getNames())) .build()); } - Plan.Builder builder = + final Plan.Builder builder = Plan.newBuilder() .addAllRelations(planRels) .addAllExpectedTypeUrls(plan.getExpectedTypeUrls()); @@ -86,7 +86,7 @@ public Plan toProto(io.substrait.plan.Plan plan) { extensionProtoConverter.toProto(plan.getAdvancedExtension().get())); } - Version.Builder versionBuilder = + final Version.Builder versionBuilder = Version.newBuilder() .setMajorNumber(plan.getVersion().getMajor()) .setMinorNumber(plan.getVersion().getMinor()) diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index 3a98438ed..12501c1ac 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -62,7 +62,7 @@ public ProtoPlanConverter( } /** Override hook for providing custom {@link ProtoRelConverter} implementations */ - protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) { + protected ProtoRelConverter getProtoRelConverter(final ExtensionLookup functionLookup) { return new ProtoRelConverter(functionLookup, this.extensionCollection, protoExtensionConverter); } diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index c3f7a4767..ce86f1de2 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -257,7 +257,7 @@ protected Rel newDdl(final DdlRel rel) { } } - protected NamedDdl newNamedDdl(DdlRel rel) { + protected NamedDdl newNamedDdl(final DdlRel rel) { final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); final ImmutableNamedDdl.Builder builder = NamedDdl.builder() @@ -485,11 +485,11 @@ protected NamedScan newNamedScan(ReadRel rel) { return builder.build(); } - protected ExtensionTable newExtensionTable(ReadRel rel) { - NamedStruct namedStruct = newNamedStruct(rel); - Extension.ExtensionTableDetail detail = + protected ExtensionTable newExtensionTable(final ReadRel rel) { + final NamedStruct namedStruct = newNamedStruct(rel); + final Extension.ExtensionTableDetail detail = detailFromExtensionTable(rel.getExtensionTable().getDetail()); - ImmutableExtensionTable.Builder builder = + final ImmutableExtensionTable.Builder builder = ExtensionTable.from(detail).initialSchema(namedStruct); builder diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 05858d831..1b3611388 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -201,8 +201,9 @@ private AggregateRel.Grouping toProto(Aggregate.Grouping grouping) { } @Override - public Rel visit(EmptyScan emptyScan, EmptyVisitationContext context) throws RuntimeException { - ReadRel.Builder builder = + public Rel visit(final EmptyScan emptyScan, EmptyVisitationContext context) + throws RuntimeException { + final ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(emptyScan)) .setVirtualTable(ReadRel.VirtualTable.newBuilder().build())