diff --git a/sqlite3.go b/sqlite3.go index 3025a500..aca480a5 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -137,6 +137,61 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ } #endif +static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, int *oBytes) { + const char *tail; + int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail); + if (rv != SQLITE_OK) { + return rv; + } + if (tail) { + // Set oBytes to the number of bytes consumed instead of using the + // **pzTail out param since that requires storing a Go pointer in + // a C pointer, which is not allowed by CGO and will cause + // runtime.cgoCheckPointer to fail. + *oBytes = tail - zSql; + } else { + // NB: this should not happen, but if it does advance oBytes to the + // end of the string so that we do not loop infinitely. + *oBytes = nBytes; + } + return SQLITE_OK; +} + +// _sqlite3_exec_no_args executes all of the statements in zSql. None of the +// statements are allowed to have positional arguments. +int _sqlite3_exec_no_args(sqlite3 *db, const char *zSql, int nBytes, int64_t *rowid, int64_t *changes) { + while (*zSql && nBytes > 0) { + sqlite3_stmt *stmt; + const char *tail; + int rv = sqlite3_prepare_v2(db, zSql, nBytes, &stmt, &tail); + if (rv != SQLITE_OK) { + return rv; + } + + // Process statement + do { + rv = _sqlite3_step_internal(stmt); + } while (rv == SQLITE_ROW); + + // Only record the number of changes made by the last statement. + *changes = sqlite3_changes64(db); + *rowid = sqlite3_last_insert_rowid(db); + + if (rv != SQLITE_OK && rv != SQLITE_DONE) { + sqlite3_finalize(stmt); + return rv; + } + rv = sqlite3_finalize(stmt); + if (rv != SQLITE_OK) { + return rv; + } + + nBytes -= tail - zSql; + zSql = tail; + } + return SQLITE_OK; +} + void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { sqlite3_result_text(ctx, s, -1, &free); } @@ -858,54 +913,119 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err } func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - start := 0 + // Trim the query. This is mostly important for getting rid + // of any trailing space. + query = strings.TrimSpace(query) + if len(args) > 0 { + return c.execArgs(ctx, query, args) + } + return c.execNoArgs(ctx, query) +} + +func (c *SQLiteConn) execArgs(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + var ( + stmtArgs []driver.NamedValue + start int + s SQLiteStmt // escapes to the heap so reuse it + sz C.int // number of query bytes consumed: escapes to the heap + ) for { - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err + s = SQLiteStmt{c: c} // reset + sz = 0 + rv := C._sqlite3_prepare_v2(c.db, (*C.char)(unsafe.Pointer(stringData(query))), + C.int(len(query)), &s.s, &sz) + if rv != C.SQLITE_OK { + return nil, c.lastError() } + query = strings.TrimSpace(query[sz:]) + var res driver.Result - if s.(*SQLiteStmt).s != nil { - stmtArgs := make([]driver.NamedValue, 0, len(args)) + if s.s != nil { na := s.NumInput() if len(args)-start < na { - s.Close() + s.finalize() return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } // consume the number of arguments used in the current // statement and append all named arguments not // contained therein - if len(args[start:start+na]) > 0 { - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 + stmtArgs = append(stmtArgs[:0], args[start:start+na]...) + for i := range args { + if (i < start || i >= na) && args[i].Name != "" { + stmtArgs = append(stmtArgs, args[i]) } } - res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) + for i := range stmtArgs { + stmtArgs[i].Ordinal = i + 1 + } + var err error + res, err = s.exec(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { - s.Close() + s.finalize() return nil, err } start += na } - tail := s.(*SQLiteStmt).t - s.Close() - if tail == "" { + s.finalize() + if len(query) == 0 { if res == nil { // https://github.com/mattn/go-sqlite3/issues/963 res = &SQLiteResult{0, 0} } return res, nil } - query = tail } } +// execNoArgsSync processes every SQL statement in query. All processing occurs +// in C code, which reduces the overhead of CGO calls. +func (c *SQLiteConn) execNoArgsSync(query string) (_ driver.Result, err error) { + var rowid, changes C.int64_t + rv := C._sqlite3_exec_no_args(c.db, (*C.char)(unsafe.Pointer(stringData(query))), + C.int(len(query)), &rowid, &changes) + if rv != C.SQLITE_OK { + err = c.lastError() + } + return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, err +} + +func (c *SQLiteConn) execNoArgs(ctx context.Context, query string) (driver.Result, error) { + done := ctx.Done() + if done == nil { + return c.execNoArgsSync(query) + } + + // Fast check if the Context is cancelled + if err := ctx.Err(); err != nil { + return nil, err + } + + ch := make(chan struct{}) + defer close(ch) + go func() { + select { + case <-done: + C.sqlite3_interrupt(c.db) + // Wait until signaled. We need to ensure that this goroutine + // will not call interrupt after this method returns, which is + // why we can't check if only done is closed when waiting below. + <-ch + case <-ch: + } + }() + + res, err := c.execNoArgsSync(query) + + // Stop the goroutine and make sure we're at a point where + // sqlite3_interrupt cannot be called again. + ch <- struct{}{} + + if isInterruptErr(err) { + err = ctx.Err() + } + return res, err +} + // Query implements Queryer. func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { list := make([]driver.NamedValue, len(args)) @@ -1914,6 +2034,13 @@ func (s *SQLiteStmt) Close() error { return nil } +func (s *SQLiteStmt) finalize() { + if s.s != nil { + C.sqlite3_finalize(s.s) + s.s = nil + } +} + // NumInput return a number of parameters. func (s *SQLiteStmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(s.s)) diff --git a/sqlite3_test.go b/sqlite3_test.go index 94de7386..67aa6ba4 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -10,6 +10,7 @@ package sqlite3 import ( "bytes" + "context" "database/sql" "database/sql/driver" "errors" @@ -1090,6 +1091,67 @@ func TestExecer(t *testing.T) { } } +func TestExecDriverResult(t *testing.T) { + setup := func(t *testing.T) *sql.DB { + db, err := sql.Open("sqlite3", t.TempDir()+"/test.sqlite3") + if err != nil { + t.Fatal("Failed to open database:", err) + } + if _, err := db.Exec(`CREATE TABLE foo (id INTEGER PRIMARY KEY);`); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) + return db + } + + test := func(t *testing.T, execStmt string, args ...any) { + db := setup(t) + res, err := db.Exec(execStmt, args...) + if err != nil { + t.Fatal(err) + } + rows, err := res.RowsAffected() + if err != nil { + t.Fatal(err) + } + // We only return the changes from the last statement. + if rows != 1 { + t.Errorf("RowsAffected got: %d want: %d", rows, 1) + } + id, err := res.LastInsertId() + if err != nil { + t.Fatal(err) + } + if id != 3 { + t.Errorf("LastInsertId got: %d want: %d", id, 3) + } + var count int64 + err = db.QueryRow(`SELECT COUNT(*) FROM foo WHERE id IN (1, 2, 3);`).Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 3 { + t.Errorf("Expected count to be %d got: %d", 3, count) + } + } + + t.Run("NoArgs", func(t *testing.T) { + const stmt = ` + INSERT INTO foo(id) VALUES(1); + INSERT INTO foo(id) VALUES(2); + INSERT INTO foo(id) VALUES(3);` + test(t, stmt) + }) + + t.Run("WithArgs", func(t *testing.T) { + const stmt = ` + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?);` + test(t, stmt, 1, 2, 3) + }) +} + func TestQueryer(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) @@ -2106,6 +2168,10 @@ var tests = []testing.InternalTest{ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkExec", F: benchmarkExec}, + {Name: "BenchmarkExecContext", F: benchmarkExecContext}, + {Name: "BenchmarkExecStep", F: benchmarkExecStep}, + {Name: "BenchmarkExecContextStep", F: benchmarkExecContextStep}, + {Name: "BenchmarkExecTx", F: benchmarkExecTx}, {Name: "BenchmarkQuery", F: benchmarkQuery}, {Name: "BenchmarkParams", F: benchmarkParams}, {Name: "BenchmarkStmt", F: benchmarkStmt}, @@ -2459,13 +2525,78 @@ func testExecEmptyQuery(t *testing.T) { // benchmarkExec is benchmark for exec func benchmarkExec(b *testing.B) { + b.Run("Params", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select ?;", int64(1)); err != nil { + panic(err) + } + } + }) + b.Run("NoParams", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select 1;"); err != nil { + panic(err) + } + } + }) +} + +func benchmarkExecContext(b *testing.B) { + b.Run("Params", func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for i := 0; i < b.N; i++ { + if _, err := db.ExecContext(ctx, "select ?;", int64(1)); err != nil { + panic(err) + } + } + }) + b.Run("NoParams", func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for i := 0; i < b.N; i++ { + if _, err := db.ExecContext(ctx, "select 1;"); err != nil { + panic(err) + } + } + }) +} + +func benchmarkExecTx(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := db.Exec("select 1"); err != nil { + tx, err := db.Begin() + if err != nil { + panic(err) + } + if _, err := tx.Exec("select 1;"); err != nil { + panic(err) + } + if err := tx.Commit(); err != nil { panic(err) } } } +var largeSelectStmt = strings.Repeat("select 1;\n", 1_000) + +func benchmarkExecStep(b *testing.B) { + for n := 0; n < b.N; n++ { + if _, err := db.Exec(largeSelectStmt); err != nil { + b.Fatal(err) + } + } +} + +func benchmarkExecContextStep(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for n := 0; n < b.N; n++ { + if _, err := db.ExecContext(ctx, largeSelectStmt); err != nil { + b.Fatal(err) + } + } +} + // benchmarkQuery is benchmark for query func benchmarkQuery(b *testing.B) { for i := 0; i < b.N; i++ { diff --git a/unsafe_go120.go b/unsafe_go120.go new file mode 100644 index 00000000..95d673ed --- /dev/null +++ b/unsafe_go120.go @@ -0,0 +1,17 @@ +//go:build !go1.21 +// +build !go1.21 + +package sqlite3 + +import "unsafe" + +// stringData is a safe version of unsafe.StringData that handles empty strings. +func stringData(s string) *byte { + if len(s) != 0 { + b := *(*[]byte)(unsafe.Pointer(&s)) + return &b[0] + } + // The return value of unsafe.StringData + // is unspecified if the string is empty. + return &placeHolder[0] +} diff --git a/unsafe_go121.go b/unsafe_go121.go new file mode 100644 index 00000000..b9c00a12 --- /dev/null +++ b/unsafe_go121.go @@ -0,0 +1,23 @@ +//go:build go1.21 +// +build go1.21 + +// The unsafe.StringData function was made available in Go 1.20 but it +// was not until Go 1.21 that Go was changed to interpret the Go version +// in go.mod (1.19 as of writing this) as the minimum version required +// instead of the exact version. +// +// See: https://github.com/golang/go/issues/59033 + +package sqlite3 + +import "unsafe" + +// stringData is a safe version of unsafe.StringData that handles empty strings. +func stringData(s string) *byte { + if len(s) != 0 { + return unsafe.StringData(s) + } + // The return value of unsafe.StringData + // is unspecified if the string is empty. + return &placeHolder[0] +}