Skip to content

Commit ed27133

Browse files
committed
Add helper fn to construct VirtualTable from go primitives
Closes #162
1 parent 1803339 commit ed27133

File tree

2 files changed

+436
-0
lines changed

2 files changed

+436
-0
lines changed

plan/builders.go

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ type Builder interface {
124124
// Deprecated: Use VirtualTableFromExpr(...).Remap() instead.
125125
VirtualTableFromExprRemap(fieldNames []string, remap []int32, values ...expr.VirtualTableExpressionValue) (*VirtualTableReadRel, error)
126126
VirtualTableFromExpr(fieldNames []string, values ...expr.VirtualTableExpressionValue) (*VirtualTableReadRel, error)
127+
// VirtualTableFromGoTypes constructs a VirtualTableReadRel from native Go types.
128+
// It accepts field names, tuples of polymorphic Go values, and optional nullability
129+
// configuration. The function automatically maps Go types to appropriate Substrait types
130+
// and handles nil values by converting them to typed null literals.
131+
// If nullableColumns is nil, all columns default to non-nullable (required).
132+
VirtualTableFromGoTypes(fieldNames []string, tuples [][]any, nullableColumns []bool) (*VirtualTableReadRel, error)
127133
IcebergTableFromMetadataFile(metadataURI string, snapshot IcebergSnapshot, schema types.NamedStruct) (*IcebergTableReadRel, error)
128134
// Deprecated: Use Sort(...).Remap() instead.
129135
SortRemap(input Rel, remap []int32, sorts ...expr.SortField) (*SortRel, error)
@@ -619,6 +625,214 @@ func (b *builder) VirtualTable(fields []string, values ...expr.StructLiteralValu
619625
return b.VirtualTableRemap(fields, nil, values...)
620626
}
621627

628+
func (b *builder) VirtualTableFromGoTypes(fieldNames []string, tuples [][]any, nullableColumns []bool) (*VirtualTableReadRel, error) {
629+
// Need at least one tuple to infer column types from Go values.
630+
// Empty virtual tables are valid, but require explicit type specification via VirtualTableFromExpr.
631+
if len(tuples) == 0 {
632+
return nil, fmt.Errorf("%w: must provide at least one tuple for virtual table", substraitgo.ErrInvalidRel)
633+
}
634+
635+
nfields := len(fieldNames)
636+
if nfields == 0 {
637+
return nil, fmt.Errorf("%w: must provide at least one field name", substraitgo.ErrInvalidRel)
638+
}
639+
640+
for i, tuple := range tuples {
641+
if len(tuple) != nfields {
642+
return nil, fmt.Errorf("%w: tuple %d has %d values, expected %d", substraitgo.ErrInvalidRel, i, len(tuple), nfields)
643+
}
644+
}
645+
646+
if nullableColumns == nil {
647+
// default behavior is that none of the columns are nullable
648+
nullableColumns = make([]bool, nfields)
649+
} else if len(nullableColumns) != nfields {
650+
return nil, fmt.Errorf("%w: nullableColumns length (%d) must match fieldNames length (%d) or be nil", substraitgo.ErrInvalidRel, len(nullableColumns), nfields)
651+
}
652+
653+
columnTypes, err := inferColumnTypesFromGoTypes(b, tuples, fieldNames, nullableColumns)
654+
if err != nil {
655+
return nil, err
656+
}
657+
658+
if err := validateColumnTypesFromGoTypes(b, tuples, fieldNames, nullableColumns, columnTypes); err != nil {
659+
return nil, err
660+
}
661+
662+
values, err := convertGoTuplesToExpressions(b, tuples, fieldNames, columnTypes)
663+
if err != nil {
664+
return nil, err
665+
}
666+
667+
baseSchema := types.NamedStruct{
668+
Names: fieldNames,
669+
Struct: types.StructType{
670+
Nullability: types.NullabilityRequired,
671+
Types: columnTypes,
672+
},
673+
}
674+
675+
return &VirtualTableReadRel{
676+
baseReadRel: baseReadRel{
677+
RelCommon: RelCommon{},
678+
baseSchema: baseSchema,
679+
},
680+
values: values,
681+
}, nil
682+
}
683+
684+
func (b *builder) goTypeToSubstraitType(val any, nullable bool) (types.Type, error) {
685+
nullability := types.NullabilityRequired
686+
if nullable {
687+
nullability = types.NullabilityNullable
688+
}
689+
690+
switch val.(type) {
691+
case bool:
692+
return &types.BooleanType{Nullability: nullability}, nil
693+
case int8:
694+
return &types.Int8Type{Nullability: nullability}, nil
695+
case int16:
696+
return &types.Int16Type{Nullability: nullability}, nil
697+
case int32:
698+
return &types.Int32Type{Nullability: nullability}, nil
699+
case int:
700+
return &types.Int64Type{Nullability: nullability}, nil
701+
case int64:
702+
return &types.Int64Type{Nullability: nullability}, nil
703+
case float32:
704+
return &types.Float32Type{Nullability: nullability}, nil
705+
case float64:
706+
return &types.Float64Type{Nullability: nullability}, nil
707+
case string:
708+
return &types.StringType{Nullability: nullability}, nil
709+
default:
710+
return nil, fmt.Errorf("unsupported Go type: %T", val)
711+
}
712+
}
713+
714+
func (b *builder) goValueToExpression(val any, expectedType types.Type) (expr.Expression, error) {
715+
actualType, err := b.goTypeToSubstraitType(val, false)
716+
if err != nil {
717+
return nil, err
718+
}
719+
720+
// Compare base types (ignore nullability for this check)
721+
actualBase := actualType.WithNullability(types.NullabilityRequired)
722+
expectedBase := expectedType.WithNullability(types.NullabilityRequired)
723+
724+
if !actualBase.Equals(expectedBase) {
725+
return nil, fmt.Errorf("type mismatch: got %T, expected type compatible with %s", val, expectedType)
726+
}
727+
728+
switch v := val.(type) {
729+
case bool:
730+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
731+
case int8:
732+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
733+
case int16:
734+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
735+
case int32:
736+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
737+
case int:
738+
return expr.NewPrimitiveLiteral(int64(v), expectedType.GetNullability() == types.NullabilityNullable), nil
739+
case int64:
740+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
741+
case float32:
742+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
743+
case float64:
744+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
745+
case string:
746+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
747+
default:
748+
return nil, fmt.Errorf("unsupported value type: %T", val)
749+
}
750+
}
751+
752+
// inferColumnTypesFromGoTypes infers Substrait types from the first non-null value in each column
753+
func inferColumnTypesFromGoTypes(b *builder, tuples [][]any, fieldNames []string, nullableColumns []bool) ([]types.Type, error) {
754+
nfields := len(fieldNames)
755+
columnTypes := make([]types.Type, nfields)
756+
757+
for colIdx := range nfields {
758+
var foundType types.Type
759+
for rowIdx := range len(tuples) {
760+
val := tuples[rowIdx][colIdx]
761+
if val != nil {
762+
var err error
763+
foundType, err = b.goTypeToSubstraitType(val, nullableColumns[colIdx])
764+
if err != nil {
765+
return nil, fmt.Errorf("failed to infer type for column %d (%s): %w", colIdx, fieldNames[colIdx], err)
766+
}
767+
break
768+
}
769+
}
770+
771+
if foundType == nil {
772+
return nil, fmt.Errorf("%w: column %d (%s) contains only null values, cannot infer type", substraitgo.ErrInvalidRel, colIdx, fieldNames[colIdx])
773+
}
774+
775+
columnTypes[colIdx] = foundType
776+
}
777+
return columnTypes, nil
778+
}
779+
780+
// Validate that the values in each column of every row conform to the type specified in columnTypes
781+
func validateColumnTypesFromGoTypes(b *builder, tuples [][]any, fieldNames []string, nullableColumns []bool, columnTypes []types.Type) error {
782+
nfields := len(fieldNames)
783+
784+
for colIdx := range nfields {
785+
expectedType := columnTypes[colIdx]
786+
787+
for rowIdx := range len(tuples) {
788+
val := tuples[rowIdx][colIdx]
789+
if val != nil {
790+
currentType, err := b.goTypeToSubstraitType(val, nullableColumns[colIdx])
791+
if err != nil {
792+
return fmt.Errorf("invalid type in row %d, col %d (%s): %w", rowIdx, colIdx, fieldNames[colIdx], err)
793+
}
794+
795+
// Compare base types (ignore nullability for this check)
796+
expectedBase := expectedType.WithNullability(types.NullabilityRequired)
797+
currentBase := currentType.WithNullability(types.NullabilityRequired)
798+
799+
if !expectedBase.Equals(currentBase) {
800+
return fmt.Errorf("%w: type mismatch in column %d (%s): found %T in row %d, expected type compatible with %s",
801+
substraitgo.ErrInvalidRel, colIdx, fieldNames[colIdx], val, rowIdx, expectedType)
802+
}
803+
}
804+
}
805+
}
806+
return nil
807+
}
808+
809+
func convertGoTuplesToExpressions(b *builder, tuples [][]any, fieldNames []string, columnTypes []types.Type) ([]expr.VirtualTableExpressionValue, error) {
810+
nfields := len(fieldNames)
811+
values := make([]expr.VirtualTableExpressionValue, len(tuples))
812+
813+
for rowIdx, tuple := range tuples {
814+
row := make(expr.VirtualTableExpressionValue, nfields)
815+
816+
for colIdx, val := range tuple {
817+
expectedType := columnTypes[colIdx]
818+
819+
if val == nil {
820+
row[colIdx] = expr.NewNullLiteral(expectedType)
821+
} else {
822+
exprVal, err := b.goValueToExpression(val, expectedType)
823+
if err != nil {
824+
return nil, fmt.Errorf("failed to convert value at row %d, col %d (%s): %w", rowIdx, colIdx, fieldNames[colIdx], err)
825+
}
826+
row[colIdx] = exprVal
827+
}
828+
}
829+
830+
values[rowIdx] = row
831+
}
832+
833+
return values, nil
834+
}
835+
622836
func (b *builder) IcebergTableFromMetadataFile(metadataURI string, snapshot IcebergSnapshot, schema types.NamedStruct) (*IcebergTableReadRel, error) {
623837
tableType := &Direct{}
624838
tableType.MetadataUri = metadataURI

0 commit comments

Comments
 (0)