@@ -31,6 +31,7 @@ package sqlite3
3131#endif
3232#include <stdlib.h>
3333#include <string.h>
34+ #include <ctype.h>
3435
3536#ifdef __CYGWIN__
3637# include <errno.h>
@@ -79,6 +80,16 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
7980 return rv;
8081}
8182
83+ static const char *
84+ _trim_leading_spaces(const char *str) {
85+ if (str) {
86+ while (isspace(*str)) {
87+ str++;
88+ }
89+ }
90+ return str;
91+ }
92+
8293#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
8394extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
8495extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
@@ -99,7 +110,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
99110static int
100111_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
101112{
102- return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
113+ int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
114+ if (pzTail) {
115+ *pzTail = _trim_leading_spaces(*pzTail);
116+ }
117+ return rv;
103118}
104119
105120#else
@@ -122,7 +137,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
122137static int
123138_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
124139{
125- return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
140+ int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
141+ if (pzTail) {
142+ *pzTail = _trim_leading_spaces(*pzTail);
143+ }
144+ return rv;
126145}
127146#endif
128147
@@ -848,24 +867,32 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
848867}
849868
850869func (c * SQLiteConn ) exec (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
870+ pquery := C .CString (query )
871+ op := pquery // original pointer
872+ defer C .free (unsafe .Pointer (op ))
873+
874+ var stmtArgs []driver.NamedValue
875+ var tail * C.char
876+ s := new (SQLiteStmt ) // escapes to the heap so reuse it
851877 start := 0
852878 for {
853- s , err := c .prepare (ctx , query )
854- if err != nil {
855- return nil , err
879+ * s = SQLiteStmt {c : c } // reset
880+ rv := C ._sqlite3_prepare_v2_internal (c .db , pquery , C .int (- 1 ), & s .s , & tail )
881+ if rv != C .SQLITE_OK {
882+ return nil , c .lastError ()
856883 }
884+
857885 var res driver.Result
858- if s .(* SQLiteStmt ).s != nil {
859- stmtArgs := make ([]driver.NamedValue , 0 , len (args ))
886+ if s .s != nil {
860887 na := s .NumInput ()
861888 if len (args )- start < na {
862- s .Close ()
889+ s .finalize ()
863890 return nil , fmt .Errorf ("not enough args to execute query: want %d got %d" , na , len (args ))
864891 }
865892 // consume the number of arguments used in the current
866893 // statement and append all named arguments not
867894 // contained therein
868- stmtArgs = append (stmtArgs , args [start :start + na ]... )
895+ stmtArgs = append (stmtArgs [: 0 ] , args [start :start + na ]... )
869896 for i := range args {
870897 if (i < start || i >= na ) && args [i ].Name != "" {
871898 stmtArgs = append (stmtArgs , args [i ])
@@ -874,23 +901,23 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named
874901 for i := range stmtArgs {
875902 stmtArgs [i ].Ordinal = i + 1
876903 }
877- res , err = s .(* SQLiteStmt ).exec (ctx , stmtArgs )
904+ var err error
905+ res , err = s .exec (ctx , stmtArgs )
878906 if err != nil && err != driver .ErrSkip {
879- s .Close ()
907+ s .finalize ()
880908 return nil , err
881909 }
882910 start += na
883911 }
884- tail := s .(* SQLiteStmt ).t
885- s .Close ()
886- if tail == "" {
912+ s .finalize ()
913+ if tail == nil || * tail == '\000' {
887914 if res == nil {
888915 // https://github.com/mattn/go-sqlite3/issues/963
889916 res = & SQLiteResult {0 , 0 }
890917 }
891918 return res , nil
892919 }
893- query = tail
920+ pquery = tail
894921 }
895922}
896923
@@ -907,22 +934,29 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
907934}
908935
909936func (c * SQLiteConn ) query (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
937+ pquery := C .CString (query )
938+ op := pquery // original pointer
939+ defer C .free (unsafe .Pointer (op ))
940+
941+ var stmtArgs []driver.NamedValue
942+ var tail * C.char
943+ s := new (SQLiteStmt ) // escapes to the heap so reuse it
910944 start := 0
911945 for {
912- stmtArgs := make ([]driver. NamedValue , 0 , len ( args ))
913- s , err := c . prepare ( ctx , query )
914- if err != nil {
915- return nil , err
946+ * s = SQLiteStmt { c : c , cls : true } // reset
947+ rv := C . _sqlite3_prepare_v2_internal ( c . db , pquery , C . int ( - 1 ), & s . s , & tail )
948+ if rv != C . SQLITE_OK {
949+ return nil , c . lastError ()
916950 }
917- s .( * SQLiteStmt ). cls = true
951+
918952 na := s .NumInput ()
919953 if len (args )- start < na {
920954 return nil , fmt .Errorf ("not enough args to execute query: want %d got %d" , na , len (args )- start )
921955 }
922956 // consume the number of arguments used in the current
923957 // statement and append all named arguments not contained
924958 // therein
925- stmtArgs = append (stmtArgs , args [start :start + na ]... )
959+ stmtArgs = append (stmtArgs [: 0 ] , args [start :start + na ]... )
926960 for i := range args {
927961 if (i < start || i >= na ) && args [i ].Name != "" {
928962 stmtArgs = append (stmtArgs , args [i ])
@@ -931,19 +965,18 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
931965 for i := range stmtArgs {
932966 stmtArgs [i ].Ordinal = i + 1
933967 }
934- rows , err := s .( * SQLiteStmt ). query (ctx , stmtArgs )
968+ rows , err := s .query (ctx , stmtArgs )
935969 if err != nil && err != driver .ErrSkip {
936- s .Close ()
970+ s .finalize ()
937971 return rows , err
938972 }
939973 start += na
940- tail := s .(* SQLiteStmt ).t
941- if tail == "" {
974+ if tail == nil || * tail == '\000' {
942975 return rows , nil
943976 }
944977 rows .Close ()
945- s .Close ()
946- query = tail
978+ s .finalize ()
979+ pquery = tail
947980 }
948981}
949982
@@ -1805,8 +1838,11 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
18051838 return nil , c .lastError ()
18061839 }
18071840 var t string
1808- if tail != nil && * tail != '\000' {
1809- t = strings .TrimSpace (C .GoString (tail ))
1841+ if tail != nil && * tail != 0 {
1842+ n := int (uintptr (unsafe .Pointer (tail ))) - int (uintptr (unsafe .Pointer (pquery )))
1843+ if 0 <= n && n < len (query ) {
1844+ t = strings .TrimSpace (query [n :])
1845+ }
18101846 }
18111847 ss := & SQLiteStmt {c : c , s : s , t : t }
18121848 runtime .SetFinalizer (ss , (* SQLiteStmt ).Close )
@@ -1899,6 +1935,13 @@ func (s *SQLiteStmt) Close() error {
18991935 return nil
19001936}
19011937
1938+ func (s * SQLiteStmt ) finalize () {
1939+ if s .s != nil {
1940+ C .sqlite3_finalize (s .s )
1941+ s .s = nil
1942+ }
1943+ }
1944+
19021945// NumInput return a number of parameters.
19031946func (s * SQLiteStmt ) NumInput () int {
19041947 return int (C .sqlite3_bind_parameter_count (s .s ))
0 commit comments