diff --git a/arrow.go b/arrow.go
index 9dc3a690..65cada68 100644
--- a/arrow.go
+++ b/arrow.go
@@ -301,3 +301,15 @@ func (r *recordReader) Err() error {
 	defer r.mu.Unlock()
 	return r.err
 }
+
+// DataChunkFromArrow moves a record batch into an existing DuckDB DataChunk.
+// Useful for implementing table functions that read from Arrow sources.
+func (a *Arrow) DataChunkFromArrow(rec arrow.RecordBatch, chunk DataChunk) error {
+	s, ed := arrowmapping.SchemaFromArrow(a.conn.conn, rec.Schema())
+	if err := errorDataError(ed); err != nil {
+		return fmt.Errorf("failed to convert arrow schema to duckdb schema: %w", err)
+	}
+	defer arrowmapping.DestroyArrowConvertedSchema(&s)
+	ed = arrowmapping.DataChunkFromArrow(a.conn.conn, rec, s, chunk.chunk)
+	return errorDataError(ed)
+}
diff --git a/arrow_test.go b/arrow_test.go
index 6b623bab..87cdeafd 100644
--- a/arrow_test.go
+++ b/arrow_test.go
@@ -6,6 +6,7 @@ import (
 	"context"
 	"database/sql"
 	"database/sql/driver"
+	"fmt"
 	"sync"
 	"testing"
 
@@ -315,3 +316,94 @@ func TestArrowClosedConn(t *testing.T) {
 	})
 	require.Error(t, err)
 }
+
+func TestArrowTableUDF(t *testing.T) {
+	db := openDbWrapper(t, ``)
+	defer closeDbWrapper(t, db)
+
+	conn := openConnWrapper(t, db, context.Background())
+	defer closeConnWrapper(t, conn)
+
+	c := newConnectorWrapper(t, ``, nil)
+	defer closeConnectorWrapper(t, c)
+
+	innerConn := openDriverConnWrapper(t, c)
+	defer closeDriverConnWrapper(t, &innerConn)
+
+	ar, err := NewArrowFromConn(innerConn)
+	require.NoError(t, err)
+
+	// Create an arrow array of type Float64 buffered in memory
+	schema := arrow.NewSchema([]arrow.Field{
+		{Name: "col0", Type: arrow.PrimitiveTypes.Float64},
+	}, nil)
+	alloc := memory.NewGoAllocator()
+	builder := array.NewFloat64Builder(alloc)
+	defer builder.Release()
+
+	// Add values > data chunk size to test multiple chunks
+	for range 10000 {
+		builder.Append(float64(0.5))
+	}
+
+	arr := builder.NewArray()
+	rb := array.NewRecordBatch(schema, []arrow.Array{arr}, int64(arr.Len()))
+	tbl := array.NewTableFromRecords(schema, []arrow.RecordBatch{rb})
+
+	RegisterTableUDF(conn, "get_arrow", ChunkTableFunction{
+		BindArguments: func(named map[string]any, args ...any) (ChunkTableSource, error) {
+			return &arrowTableUdf{tbl: tbl, ar: ar}, nil
+		},
+	})
+
+	res, err := db.QueryContext(context.Background(), `SELECT * FROM get_arrow()`)
+	require.NoError(t, err)
+	defer closeRowsWrapper(t, res)
+
+	var rowCount int
+	for res.Next() {
+		var val float64
+		require.NoError(t, res.Scan(&val))
+		require.Equal(t, 0.5, val)
+		rowCount++
+	}
+	require.Equal(t, 10000, rowCount)
+}
+
+// Define a table UDF
+type arrowTableUdf struct {
+	ar  *Arrow
+	tbl arrow.Table
+	rdr *array.TableReader
+}
+
+func (u *arrowTableUdf) Init() {
+	u.rdr = array.NewTableReader(u.tbl, int64(GetDataChunkCapacity()))
+}
+
+func (u *arrowTableUdf) ColumnInfos() []ColumnInfo {
+	t, _ := NewTypeInfo(TYPE_DOUBLE)
+	return []ColumnInfo{{
+		Name: "col0",
+		T:    t,
+	}}
+}
+
+func (u *arrowTableUdf) Cardinality() *CardinalityInfo {
+	return &CardinalityInfo{
+		Cardinality: uint(u.tbl.NumRows()),
+		Exact:       true,
+	}
+}
+
+func (u *arrowTableUdf) FillChunk(chunk DataChunk) error {
+	if u.rdr.Next() {
+		b := u.rdr.RecordBatch()
+		defer b.Release()
+		if err := u.ar.DataChunkFromArrow(b, chunk); err != nil {
+			return fmt.Errorf("failed to move arrow to data chunk: %w", err)
+		}
+	}
+
+	return nil
+}