Skip to content

Commit 7813813

Browse files
committed
Move virtual table helpers from builder pkg to literals pkg
1 parent ed27133 commit 7813813

File tree

4 files changed

+501
-434
lines changed

4 files changed

+501
-434
lines changed

literal/go_types.go

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
package literal
2+
3+
import (
4+
"fmt"
5+
6+
substraitgo "github.com/substrait-io/substrait-go/v6"
7+
"github.com/substrait-io/substrait-go/v6/expr"
8+
"github.com/substrait-io/substrait-go/v6/types"
9+
)
10+
11+
// VirtualTableFromGoTypes converts Go values to VirtualTable expressions and infers their types.
12+
// It accepts field names, tuples of polymorphic Go values, and optional nullability configuration.
13+
// Returns the converted expressions, inferred column types, and any error encountered.
14+
// If nullableColumns is nil, all columns default to non-nullable (required).
15+
func VirtualTableFromGoTypes(fieldNames []string, tuples [][]any, nullableColumns []bool) ([]expr.VirtualTableExpressionValue, []types.Type, error) {
16+
// Need at least one tuple to infer column types from Go values.
17+
// Empty virtual tables are valid, but require explicit type specification via VirtualTableFromExpr.
18+
if len(tuples) == 0 {
19+
return nil, nil, fmt.Errorf("%w: must provide at least one tuple for virtual table", substraitgo.ErrInvalidRel)
20+
}
21+
22+
nfields := len(fieldNames)
23+
if nfields == 0 {
24+
return nil, nil, fmt.Errorf("%w: must provide at least one field name", substraitgo.ErrInvalidRel)
25+
}
26+
27+
for i, tuple := range tuples {
28+
if len(tuple) != nfields {
29+
return nil, nil, fmt.Errorf("%w: tuple %d has %d values, expected %d", substraitgo.ErrInvalidRel, i, len(tuple), nfields)
30+
}
31+
}
32+
33+
if nullableColumns == nil {
34+
// default behavior is that none of the columns are nullable
35+
nullableColumns = make([]bool, nfields)
36+
} else if len(nullableColumns) != nfields {
37+
return nil, nil, fmt.Errorf("%w: nullableColumns length (%d) must match fieldNames length (%d) or be nil", substraitgo.ErrInvalidRel, len(nullableColumns), nfields)
38+
}
39+
40+
columnTypes, err := inferColumnTypesFromGoTypes(tuples, fieldNames, nullableColumns)
41+
if err != nil {
42+
return nil, nil, err
43+
}
44+
45+
if err := validateColumnTypesFromGoTypes(tuples, fieldNames, nullableColumns, columnTypes); err != nil {
46+
return nil, nil, err
47+
}
48+
49+
values, err := convertGoTuplesToExpressions(tuples, fieldNames, columnTypes)
50+
if err != nil {
51+
return nil, nil, err
52+
}
53+
54+
return values, columnTypes, nil
55+
}
56+
57+
func GoTypeToSubstraitType(val any, nullable bool) (types.Type, error) {
58+
nullability := types.NullabilityRequired
59+
if nullable {
60+
nullability = types.NullabilityNullable
61+
}
62+
63+
switch val.(type) {
64+
case bool:
65+
return &types.BooleanType{Nullability: nullability}, nil
66+
case int8:
67+
return &types.Int8Type{Nullability: nullability}, nil
68+
case int16:
69+
return &types.Int16Type{Nullability: nullability}, nil
70+
case int32:
71+
return &types.Int32Type{Nullability: nullability}, nil
72+
case int:
73+
return &types.Int64Type{Nullability: nullability}, nil
74+
case int64:
75+
return &types.Int64Type{Nullability: nullability}, nil
76+
case float32:
77+
return &types.Float32Type{Nullability: nullability}, nil
78+
case float64:
79+
return &types.Float64Type{Nullability: nullability}, nil
80+
case string:
81+
return &types.StringType{Nullability: nullability}, nil
82+
default:
83+
return nil, fmt.Errorf("unsupported Go type: %T", val)
84+
}
85+
}
86+
87+
func GoValueToExpression(val any, expectedType types.Type) (expr.Expression, error) {
88+
actualType, err := GoTypeToSubstraitType(val, false)
89+
if err != nil {
90+
return nil, err
91+
}
92+
93+
// Compare base types (ignore nullability for this check)
94+
actualBase := actualType.WithNullability(types.NullabilityRequired)
95+
expectedBase := expectedType.WithNullability(types.NullabilityRequired)
96+
97+
if !actualBase.Equals(expectedBase) {
98+
return nil, fmt.Errorf("type mismatch: got %T, expected type compatible with %s", val, expectedType)
99+
}
100+
101+
switch v := val.(type) {
102+
case bool:
103+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
104+
case int8:
105+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
106+
case int16:
107+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
108+
case int32:
109+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
110+
case int:
111+
return expr.NewPrimitiveLiteral(int64(v), expectedType.GetNullability() == types.NullabilityNullable), nil
112+
case int64:
113+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
114+
case float32:
115+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
116+
case float64:
117+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
118+
case string:
119+
return expr.NewPrimitiveLiteral(v, expectedType.GetNullability() == types.NullabilityNullable), nil
120+
default:
121+
return nil, fmt.Errorf("unsupported value type: %T", val)
122+
}
123+
}
124+
125+
// inferColumnTypesFromGoTypes infers Substrait types from the first non-null value in each column
126+
func inferColumnTypesFromGoTypes(tuples [][]any, fieldNames []string, nullableColumns []bool) ([]types.Type, error) {
127+
nfields := len(fieldNames)
128+
columnTypes := make([]types.Type, nfields)
129+
130+
for colIdx := range nfields {
131+
var foundType types.Type
132+
for rowIdx := range len(tuples) {
133+
val := tuples[rowIdx][colIdx]
134+
if val != nil {
135+
var err error
136+
foundType, err = GoTypeToSubstraitType(val, nullableColumns[colIdx])
137+
if err != nil {
138+
return nil, fmt.Errorf("failed to infer type for column %d (%s): %w", colIdx, fieldNames[colIdx], err)
139+
}
140+
break
141+
}
142+
}
143+
144+
if foundType == nil {
145+
return nil, fmt.Errorf("%w: column %d (%s) contains only null values, cannot infer type", substraitgo.ErrInvalidRel, colIdx, fieldNames[colIdx])
146+
}
147+
148+
columnTypes[colIdx] = foundType
149+
}
150+
return columnTypes, nil
151+
}
152+
153+
// validateColumnTypesFromGoTypes validates that the values in each column of every row conform to the type specified in columnTypes
154+
func validateColumnTypesFromGoTypes(tuples [][]any, fieldNames []string, nullableColumns []bool, columnTypes []types.Type) error {
155+
nfields := len(fieldNames)
156+
157+
for colIdx := range nfields {
158+
expectedType := columnTypes[colIdx]
159+
160+
for rowIdx := range len(tuples) {
161+
val := tuples[rowIdx][colIdx]
162+
if val != nil {
163+
currentType, err := GoTypeToSubstraitType(val, nullableColumns[colIdx])
164+
if err != nil {
165+
return fmt.Errorf("invalid type in row %d, col %d (%s): %w", rowIdx, colIdx, fieldNames[colIdx], err)
166+
}
167+
168+
// Compare base types (ignore nullability for this check)
169+
expectedBase := expectedType.WithNullability(types.NullabilityRequired)
170+
currentBase := currentType.WithNullability(types.NullabilityRequired)
171+
172+
if !expectedBase.Equals(currentBase) {
173+
return fmt.Errorf("%w: type mismatch in column %d (%s): found %T in row %d, expected type compatible with %s",
174+
substraitgo.ErrInvalidRel, colIdx, fieldNames[colIdx], val, rowIdx, expectedType)
175+
}
176+
}
177+
}
178+
}
179+
return nil
180+
}
181+
182+
func convertGoTuplesToExpressions(tuples [][]any, fieldNames []string, columnTypes []types.Type) ([]expr.VirtualTableExpressionValue, error) {
183+
nfields := len(fieldNames)
184+
values := make([]expr.VirtualTableExpressionValue, len(tuples))
185+
186+
for rowIdx, tuple := range tuples {
187+
row := make(expr.VirtualTableExpressionValue, nfields)
188+
189+
for colIdx, val := range tuple {
190+
expectedType := columnTypes[colIdx]
191+
192+
if val == nil {
193+
row[colIdx] = expr.NewNullLiteral(expectedType)
194+
} else {
195+
exprVal, err := GoValueToExpression(val, expectedType)
196+
if err != nil {
197+
return nil, fmt.Errorf("failed to convert value at row %d, col %d (%s): %w", rowIdx, colIdx, fieldNames[colIdx], err)
198+
}
199+
row[colIdx] = exprVal
200+
}
201+
}
202+
203+
values[rowIdx] = row
204+
}
205+
206+
return values, nil
207+
}

0 commit comments

Comments
 (0)