Skip to content

Commit 71a742c

Browse files
authored
fix(core): close AdvancedExtension serde gaps (#569)
* fix(core): close AdvancedExtension serde gaps Signed-off-by: Niels Pardon <[email protected]> * fix: add test cases Signed-off-by: Niels Pardon <[email protected]> * fix: address comments Signed-off-by: Niels Pardon <[email protected]> * fix: more consistent final declaration in changed code Signed-off-by: Niels Pardon <[email protected]> --------- Signed-off-by: Niels Pardon <[email protected]>
1 parent e50ecd2 commit 71a742c

File tree

11 files changed

+573
-63
lines changed

11 files changed

+573
-63
lines changed

core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io.substrait.plan.Plan;
2020
import io.substrait.relation.AbstractWriteRel;
2121
import io.substrait.relation.Aggregate;
22+
import io.substrait.relation.Aggregate.Measure;
2223
import io.substrait.relation.Cross;
2324
import io.substrait.relation.EmptyScan;
2425
import io.substrait.relation.Expand;
@@ -604,6 +605,25 @@ public Aggregate.Measure count(Rel input, int field) {
604605
.build());
605606
}
606607

608+
/**
609+
* Returns a {@link Measure} representing the equivalent of a {@code COUNT(*)} SQL aggregation.
610+
*
611+
* @return the {@link Measure} representing {@code COUNT(*)}
612+
*/
613+
public Measure countStar() {
614+
final SimpleExtension.AggregateFunctionVariant declaration =
615+
extensions.getAggregateFunction(
616+
SimpleExtension.FunctionAnchor.of(
617+
DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:"));
618+
return measure(
619+
AggregateFunctionInvocation.builder()
620+
.outputType(R.I64)
621+
.declaration(declaration)
622+
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
623+
.invocation(Expression.AggregationInvocation.ALL)
624+
.build());
625+
}
626+
607627
public Aggregate.Measure min(Rel input, int field) {
608628
return min(fieldReference(input, field));
609629
}

core/src/main/java/io/substrait/plan/PlanProtoConverter.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,13 @@ public PlanProtoConverter(
6161
this.extensionProtoConverter = extensionProtoConverter;
6262
}
6363

64-
public Plan toProto(io.substrait.plan.Plan plan) {
65-
List<PlanRel> planRels = new ArrayList<>();
66-
ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection);
67-
for (io.substrait.plan.Plan.Root root : plan.getRoots()) {
68-
Rel input = new RelProtoConverter(functionCollector).toProto(root.getInput());
64+
public Plan toProto(final io.substrait.plan.Plan plan) {
65+
final List<PlanRel> planRels = new ArrayList<>();
66+
final ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection);
67+
for (final io.substrait.plan.Plan.Root root : plan.getRoots()) {
68+
final Rel input =
69+
new RelProtoConverter(functionCollector, extensionProtoConverter)
70+
.toProto(root.getInput());
6971
planRels.add(
7072
PlanRel.newBuilder()
7173
.setRoot(
@@ -74,7 +76,7 @@ public Plan toProto(io.substrait.plan.Plan plan) {
7476
.addAllNames(root.getNames()))
7577
.build());
7678
}
77-
Plan.Builder builder =
79+
final Plan.Builder builder =
7880
Plan.newBuilder()
7981
.addAllRelations(planRels)
8082
.addAllExpectedTypeUrls(plan.getExpectedTypeUrls());
@@ -84,7 +86,7 @@ public Plan toProto(io.substrait.plan.Plan plan) {
8486
extensionProtoConverter.toProto(plan.getAdvancedExtension().get()));
8587
}
8688

87-
Version.Builder versionBuilder =
89+
final Version.Builder versionBuilder =
8890
Version.newBuilder()
8991
.setMajorNumber(plan.getVersion().getMajor())
9092
.setMinorNumber(plan.getVersion().getMinor())

core/src/main/java/io/substrait/plan/ProtoPlanConverter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ public ProtoPlanConverter(
6262
}
6363

6464
/** Override hook for providing custom {@link ProtoRelConverter} implementations */
65-
protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) {
66-
return new ProtoRelConverter(functionLookup, this.extensionCollection);
65+
protected ProtoRelConverter getProtoRelConverter(final ExtensionLookup functionLookup) {
66+
return new ProtoRelConverter(functionLookup, this.extensionCollection, protoExtensionConverter);
6767
}
6868

6969
public Plan from(io.substrait.proto.Plan plan) {

core/src/main/java/io/substrait/relation/AbstractDdlRel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import io.substrait.type.Type;
77
import java.util.Optional;
88

9-
public abstract class AbstractDdlRel extends ZeroInputRel {
9+
public abstract class AbstractDdlRel extends ZeroInputRel implements HasExtension {
1010
public abstract NamedStruct getTableSchema();
1111

1212
public abstract Expression.StructLiteral getTableDefaults();

core/src/main/java/io/substrait/relation/AbstractWriteRel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import io.substrait.type.NamedStruct;
55
import io.substrait.type.Type;
66

7-
public abstract class AbstractWriteRel extends SingleInputRel {
7+
public abstract class AbstractWriteRel extends SingleInputRel implements HasExtension {
88

99
public abstract NamedStruct getTableSchema();
1010

core/src/main/java/io/substrait/relation/ExtensionWrite.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import org.immutables.value.Value;
55

66
@Value.Immutable
7-
public abstract class ExtensionWrite extends AbstractWriteRel implements HasExtension {
7+
public abstract class ExtensionWrite extends AbstractWriteRel {
88
public abstract Extension.WriteExtensionObject getDetail();
99

1010
@Override

core/src/main/java/io/substrait/relation/NamedWrite.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import org.immutables.value.Value;
66

77
@Value.Immutable
8-
public abstract class NamedWrite extends AbstractWriteRel implements HasExtension {
8+
public abstract class NamedWrite extends AbstractWriteRel {
99
public abstract List<String> getNames();
1010

1111
@Override

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 75 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,18 @@ public ProtoRelConverter(
7777
this(lookup, extensions, new ProtoExtensionConverter());
7878
}
7979

80+
/**
81+
* Constructor with custom {@link ExtensionLookup} and {@link ProtoExtensionConverter}.
82+
*
83+
* @param lookup custom {@link ExtensionLookup} to use, must not be null
84+
* @param protoExtensionConverter custom {@link ProtoExtensionConverter} to use, must not be null
85+
*/
86+
public ProtoRelConverter(
87+
@NonNull final ExtensionLookup lookup,
88+
@NonNull final ProtoExtensionConverter protoExtensionConverter) {
89+
this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION, protoExtensionConverter);
90+
}
91+
8092
/**
8193
* Constructor with custom {@link ExtensionLookup}, {@link ExtensionCollection} and {@link
8294
* ProtoExtensionConverter}.
@@ -175,8 +187,8 @@ protected Rel newRead(ReadRel rel) {
175187
}
176188
}
177189

178-
protected Rel newWrite(WriteRel rel) {
179-
WriteRel.WriteTypeCase relType = rel.getWriteTypeCase();
190+
protected Rel newWrite(final WriteRel rel) {
191+
final WriteRel.WriteTypeCase relType = rel.getWriteTypeCase();
180192
switch (relType) {
181193
case NAMED_TABLE:
182194
return newNamedWrite(rel);
@@ -187,9 +199,9 @@ protected Rel newWrite(WriteRel rel) {
187199
}
188200
}
189201

190-
protected NamedWrite newNamedWrite(WriteRel rel) {
191-
Rel input = from(rel.getInput());
192-
ImmutableNamedWrite.Builder builder =
202+
protected NamedWrite newNamedWrite(final WriteRel rel) {
203+
final Rel input = from(rel.getInput());
204+
final ImmutableNamedWrite.Builder builder =
193205
NamedWrite.builder()
194206
.input(input)
195207
.names(rel.getNamedTable().getNamesList())
@@ -202,14 +214,18 @@ protected NamedWrite newNamedWrite(WriteRel rel) {
202214
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
203215
.remap(optionalRelmap(rel.getCommon()))
204216
.hint(optionalHint(rel.getCommon()));
217+
218+
if (rel.hasAdvancedExtension()) {
219+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
220+
}
205221
return builder.build();
206222
}
207223

208-
protected Rel newExtensionWrite(WriteRel rel) {
209-
Rel input = from(rel.getInput());
210-
Extension.WriteExtensionObject detail =
224+
protected Rel newExtensionWrite(final WriteRel rel) {
225+
final Rel input = from(rel.getInput());
226+
final Extension.WriteExtensionObject detail =
211227
detailFromWriteExtensionObject(rel.getExtensionTable().getDetail());
212-
ImmutableExtensionWrite.Builder builder =
228+
final ImmutableExtensionWrite.Builder builder =
213229
ExtensionWrite.builder()
214230
.input(input)
215231
.detail(detail)
@@ -222,11 +238,15 @@ protected Rel newExtensionWrite(WriteRel rel) {
222238
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
223239
.remap(optionalRelmap(rel.getCommon()))
224240
.hint(optionalHint(rel.getCommon()));
241+
242+
if (rel.hasAdvancedExtension()) {
243+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
244+
}
225245
return builder.build();
226246
}
227247

228-
protected Rel newDdl(DdlRel rel) {
229-
DdlRel.WriteTypeCase relType = rel.getWriteTypeCase();
248+
protected Rel newDdl(final DdlRel rel) {
249+
final DdlRel.WriteTypeCase relType = rel.getWriteTypeCase();
230250
switch (relType) {
231251
case NAMED_OBJECT:
232252
return newNamedDdl(rel);
@@ -237,36 +257,48 @@ protected Rel newDdl(DdlRel rel) {
237257
}
238258
}
239259

240-
protected NamedDdl newNamedDdl(DdlRel rel) {
241-
NamedStruct tableSchema = newNamedStruct(rel.getTableSchema());
242-
return NamedDdl.builder()
243-
.names(rel.getNamedObject().getNamesList())
244-
.tableSchema(tableSchema)
245-
.tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema))
246-
.operation(NamedDdl.DdlOp.fromProto(rel.getOp()))
247-
.object(NamedDdl.DdlObject.fromProto(rel.getObject()))
248-
.viewDefinition(optionalViewDefinition(rel))
249-
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
250-
.remap(optionalRelmap(rel.getCommon()))
251-
.hint(optionalHint(rel.getCommon()))
252-
.build();
260+
protected NamedDdl newNamedDdl(final DdlRel rel) {
261+
final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema());
262+
final ImmutableNamedDdl.Builder builder =
263+
NamedDdl.builder()
264+
.names(rel.getNamedObject().getNamesList())
265+
.tableSchema(tableSchema)
266+
.tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema))
267+
.operation(NamedDdl.DdlOp.fromProto(rel.getOp()))
268+
.object(NamedDdl.DdlObject.fromProto(rel.getObject()))
269+
.viewDefinition(optionalViewDefinition(rel))
270+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
271+
.remap(optionalRelmap(rel.getCommon()))
272+
.hint(optionalHint(rel.getCommon()));
273+
274+
if (rel.hasAdvancedExtension()) {
275+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
276+
}
277+
278+
return builder.build();
253279
}
254280

255-
protected ExtensionDdl newExtensionDdl(DdlRel rel) {
256-
Extension.DdlExtensionObject detail =
281+
protected ExtensionDdl newExtensionDdl(final DdlRel rel) {
282+
final Extension.DdlExtensionObject detail =
257283
detailFromDdlExtensionObject(rel.getExtensionObject().getDetail());
258-
NamedStruct tableSchema = newNamedStruct(rel.getTableSchema());
259-
return ExtensionDdl.builder()
260-
.detail(detail)
261-
.tableSchema(newNamedStruct(rel.getTableSchema()))
262-
.tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema))
263-
.operation(ExtensionDdl.DdlOp.fromProto(rel.getOp()))
264-
.object(ExtensionDdl.DdlObject.fromProto(rel.getObject()))
265-
.viewDefinition(optionalViewDefinition(rel))
266-
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
267-
.remap(optionalRelmap(rel.getCommon()))
268-
.hint(optionalHint(rel.getCommon()))
269-
.build();
284+
final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema());
285+
final ImmutableExtensionDdl.Builder builder =
286+
ExtensionDdl.builder()
287+
.detail(detail)
288+
.tableSchema(newNamedStruct(rel.getTableSchema()))
289+
.tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema))
290+
.operation(ExtensionDdl.DdlOp.fromProto(rel.getOp()))
291+
.object(ExtensionDdl.DdlObject.fromProto(rel.getObject()))
292+
.viewDefinition(optionalViewDefinition(rel))
293+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
294+
.remap(optionalRelmap(rel.getCommon()))
295+
.hint(optionalHint(rel.getCommon()));
296+
297+
if (rel.hasAdvancedExtension()) {
298+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
299+
}
300+
301+
return builder.build();
270302
}
271303

272304
protected Optional<Rel> optionalViewDefinition(DdlRel rel) {
@@ -453,10 +485,12 @@ protected NamedScan newNamedScan(ReadRel rel) {
453485
return builder.build();
454486
}
455487

456-
protected ExtensionTable newExtensionTable(ReadRel rel) {
457-
Extension.ExtensionTableDetail detail =
488+
protected ExtensionTable newExtensionTable(final ReadRel rel) {
489+
final NamedStruct namedStruct = newNamedStruct(rel);
490+
final Extension.ExtensionTableDetail detail =
458491
detailFromExtensionTable(rel.getExtensionTable().getDetail());
459-
ImmutableExtensionTable.Builder builder = ExtensionTable.from(detail);
492+
final ImmutableExtensionTable.Builder builder =
493+
ExtensionTable.from(detail).initialSchema(namedStruct);
460494

461495
builder
462496
.commonExtension(optionalAdvancedExtension(rel.getCommon()))

core/src/main/java/io/substrait/relation/RelProtoConverter.java

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,15 +201,17 @@ private AggregateRel.Grouping toProto(Aggregate.Grouping grouping) {
201201
}
202202

203203
@Override
204-
public Rel visit(EmptyScan emptyScan, EmptyVisitationContext context) throws RuntimeException {
205-
return Rel.newBuilder()
206-
.setRead(
207-
ReadRel.newBuilder()
208-
.setCommon(common(emptyScan))
209-
.setVirtualTable(ReadRel.VirtualTable.newBuilder().build())
210-
.setBaseSchema(emptyScan.getInitialSchema().toProto(typeProtoConverter))
211-
.build())
212-
.build();
204+
public Rel visit(final EmptyScan emptyScan, EmptyVisitationContext context)
205+
throws RuntimeException {
206+
final ReadRel.Builder builder =
207+
ReadRel.newBuilder()
208+
.setCommon(common(emptyScan))
209+
.setVirtualTable(ReadRel.VirtualTable.newBuilder().build())
210+
.setBaseSchema(emptyScan.getInitialSchema().toProto(typeProtoConverter));
211+
emptyScan
212+
.getExtension()
213+
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));
214+
return Rel.newBuilder().setRead(builder.build()).build();
213215
}
214216

215217
@Override
@@ -435,6 +437,10 @@ public Rel visit(NamedWrite write, EmptyVisitationContext context) throws Runtim
435437
.setCreateMode(write.getCreateMode().toProto())
436438
.setOutput(write.getOutputMode().toProto());
437439

440+
write
441+
.getExtension()
442+
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));
443+
438444
return Rel.newBuilder().setWrite(builder).build();
439445
}
440446

@@ -451,6 +457,10 @@ public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws Ru
451457
.setCreateMode(write.getCreateMode().toProto())
452458
.setOutput(write.getOutputMode().toProto());
453459

460+
write
461+
.getExtension()
462+
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));
463+
454464
return Rel.newBuilder().setWrite(builder).build();
455465
}
456466

@@ -468,6 +478,9 @@ public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeExc
468478
builder.setViewDefinition(toProto(ddl.getViewDefinition().get()));
469479
}
470480

481+
ddl.getExtension()
482+
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));
483+
471484
return Rel.newBuilder().setDdl(builder).build();
472485
}
473486

@@ -486,6 +499,9 @@ public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws Runtim
486499
builder.setViewDefinition(toProto(ddl.getViewDefinition().get()));
487500
}
488501

502+
ddl.getExtension()
503+
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));
504+
489505
return Rel.newBuilder().setDdl(builder).build();
490506
}
491507

0 commit comments

Comments
 (0)