Skip to content

Commit 8d17484

Browse files
committed
Scan improvements.
1 parent 3283ec0 commit 8d17484

File tree

6 files changed

+120
-80
lines changed

6 files changed

+120
-80
lines changed

conn.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -444,20 +444,27 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro
444444
// https://sqlite.org/c3ref/table_column_metadata.html
445445
func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) {
446446
defer c.arena.mark()()
447-
448-
var schemaPtr, columnPtr ptr_t
449-
declTypePtr := c.arena.new(ptrlen)
450-
collSeqPtr := c.arena.new(ptrlen)
451-
notNullPtr := c.arena.new(ptrlen)
452-
autoIncPtr := c.arena.new(ptrlen)
453-
primaryKeyPtr := c.arena.new(ptrlen)
447+
var (
448+
declTypePtr ptr_t
449+
collSeqPtr ptr_t
450+
notNullPtr ptr_t
451+
primaryKeyPtr ptr_t
452+
autoIncPtr ptr_t
453+
columnPtr ptr_t
454+
schemaPtr ptr_t
455+
)
456+
if column != "" {
457+
declTypePtr = c.arena.new(ptrlen)
458+
collSeqPtr = c.arena.new(ptrlen)
459+
notNullPtr = c.arena.new(ptrlen)
460+
primaryKeyPtr = c.arena.new(ptrlen)
461+
autoIncPtr = c.arena.new(ptrlen)
462+
columnPtr = c.arena.string(column)
463+
}
454464
if schema != "" {
455465
schemaPtr = c.arena.string(schema)
456466
}
457467
tablePtr := c.arena.string(table)
458-
if column != "" {
459-
columnPtr = c.arena.string(column)
460-
}
461468

462469
rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle),
463470
stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr),

driver/driver.go

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -607,14 +607,24 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
607607
type scantype byte
608608

609609
const (
610-
_ANY scantype = iota
611-
_INT scantype = scantype(sqlite3.INTEGER)
612-
_REAL scantype = scantype(sqlite3.FLOAT)
613-
_TEXT scantype = scantype(sqlite3.TEXT)
614-
_BLOB scantype = scantype(sqlite3.BLOB)
615-
_NULL scantype = scantype(sqlite3.NULL)
616-
_BOOL scantype = iota
610+
_ANY scantype = iota
611+
_INT
612+
_REAL
613+
_TEXT
614+
_BLOB
615+
_NULL
616+
_BOOL
617617
_TIME
618+
_NOT_NULL
619+
)
620+
621+
var (
622+
_ [0]struct{} = [scantype(sqlite3.INTEGER) - _INT]struct{}{}
623+
_ [0]struct{} = [scantype(sqlite3.FLOAT) - _REAL]struct{}{}
624+
_ [0]struct{} = [scantype(sqlite3.TEXT) - _TEXT]struct{}{}
625+
_ [0]struct{} = [scantype(sqlite3.BLOB) - _BLOB]struct{}{}
626+
_ [0]struct{} = [scantype(sqlite3.NULL) - _NULL]struct{}{}
627+
_ [0]struct{} = [_NOT_NULL & (_NOT_NULL - 1)]struct{}{}
618628
)
619629

620630
func scanFromDecl(decl string) scantype {
@@ -644,8 +654,8 @@ type rows struct {
644654
*stmt
645655
names []string
646656
types []string
647-
nulls []bool
648657
scans []scantype
658+
dest []driver.Value
649659
}
650660

651661
var (
@@ -675,34 +685,36 @@ func (r *rows) Columns() []string {
675685

676686
func (r *rows) scanType(index int) scantype {
677687
if r.scans == nil {
678-
count := r.Stmt.ColumnCount()
688+
count := len(r.names)
679689
scans := make([]scantype, count)
680690
for i := range scans {
681691
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
682692
}
683693
r.scans = scans
684694
}
685-
return r.scans[index]
695+
return r.scans[index] &^ _NOT_NULL
686696
}
687697

688698
func (r *rows) loadColumnMetadata() {
689-
if r.nulls == nil {
699+
if r.types == nil {
690700
c := r.Stmt.Conn()
691-
count := r.Stmt.ColumnCount()
692-
nulls := make([]bool, count)
701+
count := len(r.names)
693702
types := make([]string, count)
694703
scans := make([]scantype, count)
695-
for i := range nulls {
704+
for i := range types {
705+
var notnull bool
696706
if col := r.Stmt.ColumnOriginName(i); col != "" {
697-
types[i], _, nulls[i], _, _, _ = c.TableColumnMetadata(
707+
types[i], _, notnull, _, _, _ = c.TableColumnMetadata(
698708
r.Stmt.ColumnDatabaseName(i),
699709
r.Stmt.ColumnTableName(i),
700710
col)
701711
types[i] = strings.ToUpper(types[i])
702712
scans[i] = scanFromDecl(types[i])
713+
if notnull {
714+
scans[i] |= _NOT_NULL
715+
}
703716
}
704717
}
705-
r.nulls = nulls
706718
r.types = types
707719
r.scans = scans
708720
}
@@ -721,15 +733,13 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
721733

722734
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
723735
r.loadColumnMetadata()
724-
if r.nulls[index] {
725-
return false, true
726-
}
727-
return true, false
736+
nullable = r.scans[index]&^_NOT_NULL == 0
737+
return nullable, !nullable
728738
}
729739

730740
func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
731741
r.loadColumnMetadata()
732-
scan := r.scans[index]
742+
scan := r.scans[index] &^ _NOT_NULL
733743

734744
if r.Stmt.Busy() {
735745
// SQLite is dynamically typed and we now have a row.
@@ -772,6 +782,7 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
772782
}
773783

774784
func (r *rows) Next(dest []driver.Value) error {
785+
r.dest = nil
775786
c := r.Stmt.Conn()
776787
if old := c.SetInterrupt(r.ctx); old != r.ctx {
777788
defer c.SetInterrupt(old)
@@ -829,10 +840,11 @@ func (r *rows) Next(dest []driver.Value) error {
829840
}
830841
}
831842
}
843+
r.dest = dest
832844
return nil
833845
}
834846

835-
func (r *rows) ScanColumn(dest any, index int) error {
847+
func (r *rows) ScanColumn(dest any, index int) (err error) {
836848
// notest // Go 1.26
837849
var tm *time.Time
838850
var ok *bool
@@ -848,10 +860,13 @@ func (r *rows) ScanColumn(dest any, index int) error {
848860
default:
849861
return driver.ErrSkip
850862
}
851-
*tm = r.Stmt.ColumnTime(index, r.tmRead)
852-
err := r.Stmt.Err()
853-
if ok != nil && err == nil {
854-
*ok = r.stmt.ColumnType(index) != sqlite3.NULL
863+
value := r.dest[index]
864+
*tm, err = r.tmRead.Decode(value)
865+
if ok != nil {
866+
*ok = err == nil
867+
if value == nil {
868+
return nil
869+
}
855870
}
856871
return err
857872
}

driver/driver_test.go

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"math"
99
"net/url"
1010
"reflect"
11+
"strings"
1112
"testing"
1213
"time"
1314

@@ -33,7 +34,7 @@ func Test_Open_error(t *testing.T) {
3334
func Test_Open_dir(t *testing.T) {
3435
t.Parallel()
3536

36-
db, err := sql.Open("sqlite3", ".")
37+
db, err := Open(".")
3738
if err != nil {
3839
t.Fatal(err)
3940
}
@@ -54,7 +55,7 @@ func Test_Open_pragma(t *testing.T) {
5455
"_pragma": {"busy_timeout(1000)"},
5556
})
5657

57-
db, err := sql.Open("sqlite3", tmp)
58+
db, err := Open(tmp)
5859
if err != nil {
5960
t.Fatal(err)
6061
}
@@ -76,7 +77,7 @@ func Test_Open_pragma_invalid(t *testing.T) {
7677
"_pragma": {"busy_timeout 1000"},
7778
})
7879

79-
db, err := sql.Open("sqlite3", tmp)
80+
db, err := Open(tmp)
8081
if err != nil {
8182
t.Fatal(err)
8283
}
@@ -105,7 +106,7 @@ func Test_Open_txLock(t *testing.T) {
105106
"_pragma": {"busy_timeout(1000)"},
106107
})
107108

108-
db, err := sql.Open("sqlite3", tmp)
109+
db, err := Open(tmp)
109110
if err != nil {
110111
t.Fatal(err)
111112
}
@@ -140,7 +141,7 @@ func Test_Open_txLock_invalid(t *testing.T) {
140141
"_txlock": {"xclusive"},
141142
})
142143

143-
_, err := sql.Open("sqlite3", tmp+"_txlock=xclusive")
144+
_, err := Open(tmp)
144145
if err == nil {
145146
t.Fatal("want error")
146147
}
@@ -156,7 +157,7 @@ func Test_BeginTx(t *testing.T) {
156157
"_pragma": {"busy_timeout(0)"},
157158
})
158159

159-
db, err := sql.Open("sqlite3", tmp)
160+
db, err := Open(tmp)
160161
if err != nil {
161162
t.Fatal(err)
162163
}
@@ -200,7 +201,7 @@ func Test_nested_context(t *testing.T) {
200201
t.Parallel()
201202
tmp := memdb.TestDB(t)
202203

203-
db, err := sql.Open("sqlite3", tmp)
204+
db, err := Open(tmp)
204205
if err != nil {
205206
t.Fatal(err)
206207
}
@@ -258,7 +259,7 @@ func Test_Prepare(t *testing.T) {
258259
t.Parallel()
259260
tmp := memdb.TestDB(t)
260261

261-
db, err := sql.Open("sqlite3", tmp)
262+
db, err := Open(tmp)
262263
if err != nil {
263264
t.Fatal(err)
264265
}
@@ -299,7 +300,7 @@ func Test_QueryRow_named(t *testing.T) {
299300
t.Parallel()
300301
tmp := memdb.TestDB(t)
301302

302-
db, err := sql.Open("sqlite3", tmp)
303+
db, err := Open(tmp)
303304
if err != nil {
304305
t.Fatal(err)
305306
}
@@ -349,7 +350,7 @@ func Test_QueryRow_blob_null(t *testing.T) {
349350
t.Parallel()
350351
tmp := memdb.TestDB(t)
351352

352-
db, err := sql.Open("sqlite3", tmp)
353+
db, err := Open(tmp)
353354
if err != nil {
354355
t.Fatal(err)
355356
}
@@ -388,7 +389,7 @@ func Test_time(t *testing.T) {
388389
"_timefmt": {fmt},
389390
})
390391

391-
db, err := sql.Open("sqlite3", tmp)
392+
db, err := Open(tmp)
392393
if err != nil {
393394
t.Fatal(err)
394395
}
@@ -433,7 +434,7 @@ func Test_ColumnType_ScanType(t *testing.T) {
433434
t.Parallel()
434435
tmp := memdb.TestDB(t)
435436

436-
db, err := sql.Open("sqlite3", tmp)
437+
db, err := Open(tmp)
437438
if err != nil {
438439
t.Fatal(err)
439440
}
@@ -520,6 +521,39 @@ func Test_ColumnType_ScanType(t *testing.T) {
520521
}
521522
}
522523

524+
func Test_rows_ScanColumn(t *testing.T) {
525+
t.Parallel()
526+
tmp := memdb.TestDB(t)
527+
528+
db, err := Open(tmp)
529+
if err != nil {
530+
t.Fatal(err)
531+
}
532+
defer db.Close()
533+
534+
var tm time.Time
535+
err = db.QueryRow(`SELECT NULL`).Scan(&tm)
536+
if err == nil {
537+
t.Error("want error")
538+
}
539+
// Go 1.26
540+
err = db.QueryRow(`SELECT datetime()`).Scan(&tm)
541+
if err != nil && !strings.HasPrefix(err.Error(), "sql: Scan error") {
542+
t.Error(err)
543+
}
544+
545+
var nt sql.NullTime
546+
err = db.QueryRow(`SELECT NULL`).Scan(&nt)
547+
if err != nil {
548+
t.Error(err)
549+
}
550+
// Go 1.26
551+
err = db.QueryRow(`SELECT datetime()`).Scan(&nt)
552+
if err != nil && !strings.HasPrefix(err.Error(), "sql: Scan error") {
553+
t.Error(err)
554+
}
555+
}
556+
523557
func Benchmark_loop(b *testing.B) {
524558
db, err := Open(":memory:")
525559
if err != nil {
@@ -533,8 +567,7 @@ func Benchmark_loop(b *testing.B) {
533567
b.Fatal(err)
534568
}
535569

536-
b.ResetTimer()
537-
for range b.N {
570+
for b.Loop() {
538571
_, err := db.ExecContext(b.Context(),
539572
`WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x+1 FROM c WHERE x < 1000000) SELECT x FROM c;`)
540573
if err != nil {

0 commit comments

Comments
 (0)