@@ -30,6 +30,7 @@ package sqlite3
3030#endif
3131#include <stdlib.h>
3232#include <string.h>
33+ #include <ctype.h>
3334
3435#ifdef __CYGWIN__
3536# include <errno.h>
@@ -90,6 +91,16 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
9091 return rv;
9192}
9293
94+ static const char *
95+ _trim_leading_spaces(const char *str) {
96+ if (str) {
97+ while (isspace(*str)) {
98+ str++;
99+ }
100+ }
101+ return str;
102+ }
103+
93104#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
94105extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
95106extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
@@ -110,7 +121,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
110121static int
111122_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
112123{
113- return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
124+ int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
125+ if (pzTail) {
126+ *pzTail = _trim_leading_spaces(*pzTail);
127+ }
128+ return rv;
114129}
115130
116131#else
@@ -133,7 +148,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
133148static int
134149_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
135150{
136- return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
151+ int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
152+ if (pzTail) {
153+ *pzTail = _trim_leading_spaces(*pzTail);
154+ }
155+ return rv;
137156}
138157#endif
139158
@@ -858,25 +877,33 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
858877}
859878
860879func (c * SQLiteConn ) exec (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
880+ pquery := C .CString (query )
881+ op := pquery // original pointer
882+ defer C .free (unsafe .Pointer (op ))
883+
884+ var stmtArgs []driver.NamedValue
885+ var tail * C.char
886+ s := new (SQLiteStmt ) // escapes to the heap so reuse it
861887 start := 0
862888 for {
863- s , err := c .prepare (ctx , query )
864- if err != nil {
865- return nil , err
889+ * s = SQLiteStmt {c : c } // reset
890+ rv := C ._sqlite3_prepare_v2_internal (c .db , pquery , C .int (- 1 ), & s .s , & tail )
891+ if rv != C .SQLITE_OK {
892+ return nil , c .lastError ()
866893 }
894+
867895 var res driver.Result
868- if s .(* SQLiteStmt ).s != nil {
869- stmtArgs := make ([]driver.NamedValue , 0 , len (args ))
896+ if s .s != nil {
870897 na := s .NumInput ()
871898 if len (args )- start < na {
872- s .Close ()
899+ s .finalize ()
873900 return nil , fmt .Errorf ("not enough args to execute query: want %d got %d" , na , len (args ))
874901 }
875902 // consume the number of arguments used in the current
876903 // statement and append all named arguments not
877904 // contained therein
878905 if len (args [start :start + na ]) > 0 {
879- stmtArgs = append (stmtArgs , args [start :start + na ]... )
906+ stmtArgs = append (stmtArgs [: 0 ] , args [start :start + na ]... )
880907 for i := range args {
881908 if (i < start || i >= na ) && args [i ].Name != "" {
882909 stmtArgs = append (stmtArgs , args [i ])
@@ -886,23 +913,23 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named
886913 stmtArgs [i ].Ordinal = i + 1
887914 }
888915 }
889- res , err = s .(* SQLiteStmt ).exec (ctx , stmtArgs )
916+ var err error
917+ res , err = s .exec (ctx , stmtArgs )
890918 if err != nil && err != driver .ErrSkip {
891- s .Close ()
919+ s .finalize ()
892920 return nil , err
893921 }
894922 start += na
895923 }
896- tail := s .(* SQLiteStmt ).t
897- s .Close ()
898- if tail == "" {
924+ s .finalize ()
925+ if tail == nil || * tail == '\000' {
899926 if res == nil {
900927 // https://github.com/mattn/go-sqlite3/issues/963
901928 res = & SQLiteResult {0 , 0 }
902929 }
903930 return res , nil
904931 }
905- query = tail
932+ pquery = tail
906933 }
907934}
908935
@@ -919,22 +946,29 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
919946}
920947
921948func (c * SQLiteConn ) query (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
949+ pquery := C .CString (query )
950+ op := pquery // original pointer
951+ defer C .free (unsafe .Pointer (op ))
952+
953+ var stmtArgs []driver.NamedValue
954+ var tail * C.char
955+ s := new (SQLiteStmt ) // escapes to the heap so reuse it
922956 start := 0
923957 for {
924- stmtArgs := make ([]driver. NamedValue , 0 , len ( args ))
925- s , err := c . prepare ( ctx , query )
926- if err != nil {
927- return nil , err
958+ * s = SQLiteStmt { c : c , cls : true } // reset
959+ rv := C . _sqlite3_prepare_v2_internal ( c . db , pquery , C . int ( - 1 ), & s . s , & tail )
960+ if rv != C . SQLITE_OK {
961+ return nil , c . lastError ()
928962 }
929- s .( * SQLiteStmt ). cls = true
963+
930964 na := s .NumInput ()
931965 if len (args )- start < na {
932966 return nil , fmt .Errorf ("not enough args to execute query: want %d got %d" , na , len (args )- start )
933967 }
934968 // consume the number of arguments used in the current
935969 // statement and append all named arguments not contained
936970 // therein
937- stmtArgs = append (stmtArgs , args [start :start + na ]... )
971+ stmtArgs = append (stmtArgs [: 0 ] , args [start :start + na ]... )
938972 for i := range args {
939973 if (i < start || i >= na ) && args [i ].Name != "" {
940974 stmtArgs = append (stmtArgs , args [i ])
@@ -943,19 +977,18 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
943977 for i := range stmtArgs {
944978 stmtArgs [i ].Ordinal = i + 1
945979 }
946- rows , err := s .( * SQLiteStmt ). query (ctx , stmtArgs )
980+ rows , err := s .query (ctx , stmtArgs )
947981 if err != nil && err != driver .ErrSkip {
948- s .Close ()
982+ s .finalize ()
949983 return rows , err
950984 }
951985 start += na
952- tail := s .(* SQLiteStmt ).t
953- if tail == "" {
986+ if tail == nil || * tail == '\000' {
954987 return rows , nil
955988 }
956989 rows .Close ()
957- s .Close ()
958- query = tail
990+ s .finalize ()
991+ pquery = tail
959992 }
960993}
961994
@@ -1818,8 +1851,11 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
18181851 return nil , c .lastError ()
18191852 }
18201853 var t string
1821- if tail != nil && * tail != '\000' {
1822- t = strings .TrimSpace (C .GoString (tail ))
1854+ if tail != nil && * tail != 0 {
1855+ n := int (uintptr (unsafe .Pointer (tail ))) - int (uintptr (unsafe .Pointer (pquery )))
1856+ if 0 <= n && n < len (query ) {
1857+ t = strings .TrimSpace (query [n :])
1858+ }
18231859 }
18241860 ss := & SQLiteStmt {c : c , s : s , t : t }
18251861 runtime .SetFinalizer (ss , (* SQLiteStmt ).Close )
@@ -1913,6 +1949,13 @@ func (s *SQLiteStmt) Close() error {
19131949 return nil
19141950}
19151951
1952+ func (s * SQLiteStmt ) finalize () {
1953+ if s .s != nil {
1954+ C .sqlite3_finalize (s .s )
1955+ s .s = nil
1956+ }
1957+ }
1958+
19161959// NumInput return a number of parameters.
19171960func (s * SQLiteStmt ) NumInput () int {
19181961 return int (C .sqlite3_bind_parameter_count (s .s ))
0 commit comments