Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.substrait.plan.Plan;
import io.substrait.relation.AbstractWriteRel;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Aggregate.Measure;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Expand;
Expand Down Expand Up @@ -604,6 +605,25 @@ public Aggregate.Measure count(Rel input, int field) {
.build());
}

/**
* Returns a {@link Measure} representing the equivalent of a {@code COUNT(*)} SQL aggregation.
*
* @return the {@link Measure} representing {@code COUNT(*)}
*/
public Measure countStar() {
final SimpleExtension.AggregateFunctionVariant declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:"));
return measure(
AggregateFunctionInvocation.builder()
.outputType(R.I64)
.declaration(declaration)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build());
}

public Aggregate.Measure min(Rel input, int field) {
return min(fieldReference(input, field));
}
Expand Down
16 changes: 9 additions & 7 deletions core/src/main/java/io/substrait/plan/PlanProtoConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ public PlanProtoConverter(
this.extensionProtoConverter = extensionProtoConverter;
}

public Plan toProto(io.substrait.plan.Plan plan) {
List<PlanRel> planRels = new ArrayList<>();
ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection);
for (io.substrait.plan.Plan.Root root : plan.getRoots()) {
Rel input = new RelProtoConverter(functionCollector).toProto(root.getInput());
public Plan toProto(final io.substrait.plan.Plan plan) {
final List<PlanRel> planRels = new ArrayList<>();
final ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection);
for (final io.substrait.plan.Plan.Root root : plan.getRoots()) {
final Rel input =
new RelProtoConverter(functionCollector, extensionProtoConverter)
.toProto(root.getInput());
planRels.add(
PlanRel.newBuilder()
.setRoot(
Expand All @@ -74,7 +76,7 @@ public Plan toProto(io.substrait.plan.Plan plan) {
.addAllNames(root.getNames()))
.build());
}
Plan.Builder builder =
final Plan.Builder builder =
Plan.newBuilder()
.addAllRelations(planRels)
.addAllExpectedTypeUrls(plan.getExpectedTypeUrls());
Expand All @@ -84,7 +86,7 @@ public Plan toProto(io.substrait.plan.Plan plan) {
extensionProtoConverter.toProto(plan.getAdvancedExtension().get()));
}

Version.Builder versionBuilder =
final Version.Builder versionBuilder =
Version.newBuilder()
.setMajorNumber(plan.getVersion().getMajor())
.setMinorNumber(plan.getVersion().getMinor())
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/io/substrait/plan/ProtoPlanConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ public ProtoPlanConverter(
}

/** Override hook for providing custom {@link ProtoRelConverter} implementations */
protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) {
return new ProtoRelConverter(functionLookup, this.extensionCollection);
protected ProtoRelConverter getProtoRelConverter(final ExtensionLookup functionLookup) {
return new ProtoRelConverter(functionLookup, this.extensionCollection, protoExtensionConverter);
}

public Plan from(io.substrait.proto.Plan plan) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import io.substrait.type.Type;
import java.util.Optional;

public abstract class AbstractDdlRel extends ZeroInputRel {
public abstract class AbstractDdlRel extends ZeroInputRel implements HasExtension {
public abstract NamedStruct getTableSchema();

public abstract Expression.StructLiteral getTableDefaults();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;

public abstract class AbstractWriteRel extends SingleInputRel {
public abstract class AbstractWriteRel extends SingleInputRel implements HasExtension {

public abstract NamedStruct getTableSchema();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import org.immutables.value.Value;

@Value.Immutable
public abstract class ExtensionWrite extends AbstractWriteRel implements HasExtension {
public abstract class ExtensionWrite extends AbstractWriteRel {
public abstract Extension.WriteExtensionObject getDetail();

@Override
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/relation/NamedWrite.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import org.immutables.value.Value;

@Value.Immutable
public abstract class NamedWrite extends AbstractWriteRel implements HasExtension {
public abstract class NamedWrite extends AbstractWriteRel {
public abstract List<String> getNames();

@Override
Expand Down
116 changes: 75 additions & 41 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ public ProtoRelConverter(
this(lookup, extensions, new ProtoExtensionConverter());
}

/**
* Constructor with custom {@link ExtensionLookup} and {@link ProtoExtensionConverter}.
*
* @param lookup custom {@link ExtensionLookup} to use, must not be null
* @param protoExtensionConverter custom {@link ProtoExtensionConverter} to use, must not be null
*/
public ProtoRelConverter(
@NonNull final ExtensionLookup lookup,
@NonNull final ProtoExtensionConverter protoExtensionConverter) {
this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION, protoExtensionConverter);
}

/**
* Constructor with custom {@link ExtensionLookup}, {@link ExtensionCollection} and {@link
* ProtoExtensionConverter}.
Expand Down Expand Up @@ -175,8 +187,8 @@ protected Rel newRead(ReadRel rel) {
}
}

protected Rel newWrite(WriteRel rel) {
WriteRel.WriteTypeCase relType = rel.getWriteTypeCase();
protected Rel newWrite(final WriteRel rel) {
final WriteRel.WriteTypeCase relType = rel.getWriteTypeCase();
Comment on lines +190 to +191
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, why did you add the final modifier to the Write/Update/DDL relations, but not all the other relations that use this same code pattern?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mainly did it for the code I thought I touched to not introduce too many unrelated changes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean, let me update the PR to be a little more consistent

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be more consistent now

switch (relType) {
case NAMED_TABLE:
return newNamedWrite(rel);
Expand All @@ -187,9 +199,9 @@ protected Rel newWrite(WriteRel rel) {
}
}

protected NamedWrite newNamedWrite(WriteRel rel) {
Rel input = from(rel.getInput());
ImmutableNamedWrite.Builder builder =
protected NamedWrite newNamedWrite(final WriteRel rel) {
final Rel input = from(rel.getInput());
final ImmutableNamedWrite.Builder builder =
NamedWrite.builder()
.input(input)
.names(rel.getNamedTable().getNamesList())
Expand All @@ -202,14 +214,18 @@ protected NamedWrite newNamedWrite(WriteRel rel) {
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()));

if (rel.hasAdvancedExtension()) {
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
}
return builder.build();
}

protected Rel newExtensionWrite(WriteRel rel) {
Rel input = from(rel.getInput());
Extension.WriteExtensionObject detail =
protected Rel newExtensionWrite(final WriteRel rel) {
final Rel input = from(rel.getInput());
final Extension.WriteExtensionObject detail =
detailFromWriteExtensionObject(rel.getExtensionTable().getDetail());
ImmutableExtensionWrite.Builder builder =
final ImmutableExtensionWrite.Builder builder =
ExtensionWrite.builder()
.input(input)
.detail(detail)
Expand All @@ -222,11 +238,15 @@ protected Rel newExtensionWrite(WriteRel rel) {
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()));

if (rel.hasAdvancedExtension()) {
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
}
return builder.build();
}

protected Rel newDdl(DdlRel rel) {
DdlRel.WriteTypeCase relType = rel.getWriteTypeCase();
protected Rel newDdl(final DdlRel rel) {
final DdlRel.WriteTypeCase relType = rel.getWriteTypeCase();
switch (relType) {
case NAMED_OBJECT:
return newNamedDdl(rel);
Expand All @@ -237,36 +257,48 @@ protected Rel newDdl(DdlRel rel) {
}
}

protected NamedDdl newNamedDdl(DdlRel rel) {
NamedStruct tableSchema = newNamedStruct(rel.getTableSchema());
return NamedDdl.builder()
.names(rel.getNamedObject().getNamesList())
.tableSchema(tableSchema)
.tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema))
.operation(NamedDdl.DdlOp.fromProto(rel.getOp()))
.object(NamedDdl.DdlObject.fromProto(rel.getObject()))
.viewDefinition(optionalViewDefinition(rel))
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()))
.build();
protected NamedDdl newNamedDdl(final DdlRel rel) {
final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema());
final ImmutableNamedDdl.Builder builder =
NamedDdl.builder()
.names(rel.getNamedObject().getNamesList())
.tableSchema(tableSchema)
.tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema))
.operation(NamedDdl.DdlOp.fromProto(rel.getOp()))
.object(NamedDdl.DdlObject.fromProto(rel.getObject()))
.viewDefinition(optionalViewDefinition(rel))
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()));

if (rel.hasAdvancedExtension()) {
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
}

return builder.build();
}

protected ExtensionDdl newExtensionDdl(DdlRel rel) {
Extension.DdlExtensionObject detail =
protected ExtensionDdl newExtensionDdl(final DdlRel rel) {
final Extension.DdlExtensionObject detail =
detailFromDdlExtensionObject(rel.getExtensionObject().getDetail());
NamedStruct tableSchema = newNamedStruct(rel.getTableSchema());
return ExtensionDdl.builder()
.detail(detail)
.tableSchema(newNamedStruct(rel.getTableSchema()))
.tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema))
.operation(ExtensionDdl.DdlOp.fromProto(rel.getOp()))
.object(ExtensionDdl.DdlObject.fromProto(rel.getObject()))
.viewDefinition(optionalViewDefinition(rel))
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()))
.build();
final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema());
final ImmutableExtensionDdl.Builder builder =
ExtensionDdl.builder()
.detail(detail)
.tableSchema(newNamedStruct(rel.getTableSchema()))
.tableDefaults(tableDefaults(rel.getTableDefaults(), tableSchema))
.operation(ExtensionDdl.DdlOp.fromProto(rel.getOp()))
.object(ExtensionDdl.DdlObject.fromProto(rel.getObject()))
.viewDefinition(optionalViewDefinition(rel))
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()));

if (rel.hasAdvancedExtension()) {
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
}

return builder.build();
}

protected Optional<Rel> optionalViewDefinition(DdlRel rel) {
Expand Down Expand Up @@ -453,10 +485,12 @@ protected NamedScan newNamedScan(ReadRel rel) {
return builder.build();
}

protected ExtensionTable newExtensionTable(ReadRel rel) {
Extension.ExtensionTableDetail detail =
protected ExtensionTable newExtensionTable(final ReadRel rel) {
final NamedStruct namedStruct = newNamedStruct(rel);
final Extension.ExtensionTableDetail detail =
detailFromExtensionTable(rel.getExtensionTable().getDetail());
ImmutableExtensionTable.Builder builder = ExtensionTable.from(detail);
final ImmutableExtensionTable.Builder builder =
ExtensionTable.from(detail).initialSchema(namedStruct);

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
Expand Down
34 changes: 25 additions & 9 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,17 @@ private AggregateRel.Grouping toProto(Aggregate.Grouping grouping) {
}

@Override
public Rel visit(EmptyScan emptyScan, EmptyVisitationContext context) throws RuntimeException {
return Rel.newBuilder()
.setRead(
ReadRel.newBuilder()
.setCommon(common(emptyScan))
.setVirtualTable(ReadRel.VirtualTable.newBuilder().build())
.setBaseSchema(emptyScan.getInitialSchema().toProto(typeProtoConverter))
.build())
.build();
public Rel visit(final EmptyScan emptyScan, EmptyVisitationContext context)
throws RuntimeException {
final ReadRel.Builder builder =
ReadRel.newBuilder()
.setCommon(common(emptyScan))
.setVirtualTable(ReadRel.VirtualTable.newBuilder().build())
.setBaseSchema(emptyScan.getInitialSchema().toProto(typeProtoConverter));
emptyScan
.getExtension()
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));
return Rel.newBuilder().setRead(builder.build()).build();
}

@Override
Expand Down Expand Up @@ -435,6 +437,10 @@ public Rel visit(NamedWrite write, EmptyVisitationContext context) throws Runtim
.setCreateMode(write.getCreateMode().toProto())
.setOutput(write.getOutputMode().toProto());

write
.getExtension()
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));

return Rel.newBuilder().setWrite(builder).build();
}

Expand All @@ -451,6 +457,10 @@ public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws Ru
.setCreateMode(write.getCreateMode().toProto())
.setOutput(write.getOutputMode().toProto());

write
.getExtension()
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));

return Rel.newBuilder().setWrite(builder).build();
}

Expand All @@ -468,6 +478,9 @@ public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeExc
builder.setViewDefinition(toProto(ddl.getViewDefinition().get()));
}

ddl.getExtension()
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));

return Rel.newBuilder().setDdl(builder).build();
}

Expand All @@ -486,6 +499,9 @@ public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws Runtim
builder.setViewDefinition(toProto(ddl.getViewDefinition().get()));
}

ddl.getExtension()
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));

return Rel.newBuilder().setDdl(builder).build();
}

Expand Down
Loading