Skip to content

Commit 2ef4807

Browse files
committed
Use the AST to implement aggregate functions.
1 parent 63ebb08 commit 2ef4807

File tree

8 files changed

+152
-112
lines changed

8 files changed

+152
-112
lines changed

aggregate.go

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package duckql
22

33
import (
44
"reflect"
5-
"strconv"
6-
"strings"
75
)
86

97
type AggregateFunctionColumn struct {
@@ -185,37 +183,3 @@ func sumOfColumn(c *AggregateFunctionColumn, rows ResultRows) ResultRows {
185183
sumRow,
186184
}
187185
}
188-
189-
func ParseAggregateFunction(text string) *AggregateFunctionColumn {
190-
var column AggregateFunctionColumn
191-
var current strings.Builder
192-
193-
for _, r := range text {
194-
switch r {
195-
case '(':
196-
functionName := current.String()
197-
if f, ok := functionMap[functionName]; ok {
198-
column.Function = f
199-
} else {
200-
return nil
201-
}
202-
203-
current.Reset()
204-
case ')':
205-
var err error
206-
column.UnderlyingColumn, err = strconv.Unquote(current.String())
207-
if err != nil {
208-
column.UnderlyingColumn = current.String()
209-
}
210-
current.Reset()
211-
default:
212-
current.WriteRune(r)
213-
}
214-
}
215-
216-
if current.String() == "" {
217-
return &column
218-
}
219-
220-
return nil
221-
}

filter_backing.go

Lines changed: 96 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -259,60 +259,124 @@ func (f *SliceFilter) Rows() ResultRows {
259259
var r ResultRows
260260

261261
source := f.intermediate.Result()
262-
263262
source = source.Filter(f.filter)
264263

264+
if len(source.Rows) == 0 {
265+
return r
266+
}
267+
265268
// Transform our intermediate columns into a lookup table
266269
lookup := make(map[string]int)
267270
for idx, column := range source.Columns {
268271
lookup[column] = idx
269272
}
270273

271-
for _, row := range source.Rows {
272-
var narrow ResultRow
274+
// Find column positions to narrow
275+
var narrowColumns []int
276+
for _, column := range f.resultColumns {
277+
if column.Star.Line > 0 {
278+
for idx, _ := range source.Rows[0] {
279+
narrowColumns = append(narrowColumns, idx)
280+
}
281+
break
282+
}
273283

274-
for _, column := range f.resultColumns {
275-
if column.Star.Line > 0 {
276-
narrow = row
284+
switch t := column.Expr.(type) {
285+
case *sql.Ident:
286+
index, ok := lookup[t.Name]
287+
if !ok {
288+
// FIXME: There should be a better way
289+
parts := strings.Split(source.Columns[0], ".")
290+
if len(parts) > 1 {
291+
index = lookup[parts[0]+"."+t.Name]
292+
}
277293
}
278294

279-
switch t := column.Expr.(type) {
280-
case *sql.Ident:
281-
index, ok := lookup[t.Name]
295+
narrowColumns = append(narrowColumns, index)
296+
case *sql.QualifiedRef:
297+
if t.Star.Line != 0 {
298+
// FIXME: Implement
299+
continue
300+
}
301+
302+
lh := t.Table.Name
303+
rh := t.Column.Name
304+
305+
var index int
306+
var ok bool
307+
if source.Source != nil {
308+
index, ok = lookup[rh]
309+
} else {
310+
index, ok = lookup[lh+"."+rh]
282311
if !ok {
312+
index = lookup[source.Aliases[lh]+"."+rh]
313+
}
314+
}
315+
316+
narrowColumns = append(narrowColumns, index)
317+
default:
318+
narrowColumns = append(narrowColumns, 0)
319+
}
320+
}
321+
322+
// Find aggregations
323+
var aggregations []AggregateFunctionColumn
324+
for idx, column := range f.resultColumns {
325+
switch t := column.Expr.(type) {
326+
case *sql.Call:
327+
var underlying string
328+
if t.Star.Line == 0 {
329+
if len(t.Args) != 1 {
330+
panic("unexpected number of args to function")
331+
}
332+
333+
switch arg := t.Args[0].(type) {
334+
case *sql.Ident:
335+
// Validate?
336+
index, ok := lookup[arg.Name]
337+
if ok {
338+
underlying = arg.Name
339+
break
340+
}
341+
283342
// FIXME: There should be a better way
284343
parts := strings.Split(source.Columns[0], ".")
285344
if len(parts) > 1 {
286-
index = lookup[parts[0]+"."+t.Name]
345+
index, ok = lookup[parts[0]+"."+arg.Name]
346+
if ok {
347+
underlying = parts[0] + "." + arg.Name
348+
}
287349
}
288-
}
289350

290-
narrow = append(narrow, row[index])
291-
case *sql.QualifiedRef:
292-
if t.Star.Line != 0 {
293-
// FIXME: Implement
294-
continue
351+
narrowColumns[idx] = index
295352
}
353+
} else {
354+
narrowColumns[idx] = 0
355+
}
296356

297-
lh := t.Table.Name
298-
rh := t.Column.Name
299-
300-
var index int
301-
var ok bool
302-
if source.Source != nil {
303-
index, ok = lookup[rh]
304-
} else {
305-
index, ok = lookup[lh+"."+rh]
306-
if !ok {
307-
index = lookup[source.Aliases[lh]+"."+rh]
308-
}
309-
}
357+
aggregations = append(aggregations, AggregateFunctionColumn{
358+
UnderlyingColumn: underlying,
359+
ResultPosition: idx,
360+
Function: functionMap[t.Name.Name],
361+
})
362+
}
363+
}
310364

311-
narrow = append(narrow, row[index])
312-
}
365+
for _, row := range source.Rows {
366+
if len(row) == len(narrowColumns) {
367+
r = append(r, row)
368+
continue
369+
}
370+
371+
var newRow ResultRow
372+
for _, column := range narrowColumns {
373+
newRow = append(newRow, row[column])
313374
}
375+
r = append(r, newRow)
376+
}
314377

315-
r = append(r, narrow)
378+
for _, aggregation := range aggregations {
379+
r = aggregation.Call(r)
316380
}
317381

318382
f.resultColumns = []*sql.ResultColumn{}

sqlite_backing.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ type SQLiteBacking struct {
1818

1919
// New creates a new SQLiteBacking with the given SQLite database connection
2020
func NewSQLiteBacking(db *gosql.DB, s *SQLizer) *SQLiteBacking {
21-
s.HandleAggregateFunctions = false
2221
return &SQLiteBacking{
2322
sqlizer: s,
2423
db: db,

sqlizer.go

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@ const (
2828
)
2929

3030
type SQLizer struct {
31-
Tables map[string]*Table
32-
Permissions uint
33-
Backing BackingStore
34-
AggregateFunctions []*AggregateFunctionColumn
35-
HandleAggregateFunctions bool
31+
Tables map[string]*Table
32+
Permissions uint
33+
Backing BackingStore
3634
}
3735

3836
func (s *SQLizer) SetPermissions(permissions uint) {
@@ -75,22 +73,7 @@ func (s *SQLizer) Execute(statement string) (ResultRows, error) {
7573
return nil, err
7674
}
7775

78-
rows := s.Backing.Rows()
79-
80-
if len(rows) > 0 && len(s.AggregateFunctions) > 0 && s.HandleAggregateFunctions {
81-
var result ResultRows
82-
for idx, aggregate := range s.AggregateFunctions {
83-
r := aggregate.Call(rows)
84-
if idx == 0 {
85-
result = append(result, r[0])
86-
} else {
87-
result[0][aggregate.ResultPosition] = r[0][aggregate.ResultPosition]
88-
}
89-
}
90-
rows = result
91-
}
92-
93-
return rows, nil
76+
return s.Backing.Rows(), nil
9477
}
9578

9679
return nil, nil
@@ -204,8 +187,6 @@ func Initialize(structs ...any) *SQLizer {
204187
sql.addStructTable(s)
205188
}
206189

207-
sql.HandleAggregateFunctions = true
208-
209190
return &sql
210191
}
211192

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
.section = DDL
2+
---
3+
CREATE TABLE users
4+
(
5+
id INTEGER,
6+
name TEXT,
7+
email TEXT
8+
)
9+
10+
---
11+
.section = Result
12+
---
13+
John Doe
14+
Jane Smith
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
.section = DDL
2+
---
3+
CREATE TABLE users
4+
(
5+
id INTEGER,
6+
name TEXT,
7+
email TEXT
8+
)
9+
10+
---
11+
.section = Result
12+
---
13+
John Doe
14+
Jane Smith
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
.section = DDL
2+
---
3+
CREATE TABLE accounts
4+
(
5+
id INTEGER,
6+
username TEXT,
7+
email TEXT,
8+
organization_id INTEGER
9+
)
10+
11+
CREATE TABLE organizations
12+
(
13+
id INTEGER,
14+
name TEXT
15+
)
16+
17+
---
18+
.section = Result
19+
---
20+
user1|Initech

validate.go

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,32 +35,16 @@ func (v *Validator) Visit(n sql.Node) (sql.Visitor, sql.Node, error) {
3535
case *sql.Ident:
3636
v.columns = append(v.columns, e.Name)
3737
case *sql.Call:
38-
var aggregate AggregateFunctionColumn
39-
var star sql.Pos
38+
var underlyingColumn string
4039

4140
// FIXME
4241
if len(e.Args) > 0 {
43-
aggregate.UnderlyingColumn = e.Args[0].(*sql.Ident).Name
42+
underlyingColumn = e.Args[0].(*sql.Ident).Name
4443
} else {
45-
aggregate.UnderlyingColumn = "*"
46-
star.Line = 1
44+
underlyingColumn = "*"
4745
}
4846

49-
aggregate.ResultPosition = len(v.columns)
50-
aggregate.Function = functionMap[e.Name.Name]
51-
52-
v.s.AggregateFunctions = append(v.s.AggregateFunctions, &aggregate)
53-
// FIXME
54-
if v.s.HandleAggregateFunctions {
55-
n = &sql.ResultColumn{
56-
Star: star,
57-
Expr: &sql.Ident{
58-
Name: aggregate.UnderlyingColumn,
59-
Quoted: false,
60-
},
61-
}
62-
}
63-
v.columns = append(v.columns, aggregate.UnderlyingColumn)
47+
v.columns = append(v.columns, underlyingColumn)
6448
}
6549

6650
case *sql.QualifiedTableName:

0 commit comments

Comments
 (0)