Skip to content

fix(sqlite): query macro argument lifetime use inconsistent with other db platforms #3957

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
12 changes: 6 additions & 6 deletions sqlx-core/src/any/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Default for AnyArguments<'_> {

impl<'q> AnyArguments<'q> {
#[doc(hidden)]
pub fn convert_to<'a, A: Arguments<'a>>(&'a self) -> Result<A, BoxDynError>
pub fn convert_into<'a, A: Arguments<'a>>(self) -> Result<A, BoxDynError>
where
'q: 'a,
Option<i32>: Type<A::Database> + Encode<'a, A::Database>,
Expand All @@ -60,12 +60,12 @@ impl<'q> AnyArguments<'q> {
i64: Type<A::Database> + Encode<'a, A::Database>,
f32: Type<A::Database> + Encode<'a, A::Database>,
f64: Type<A::Database> + Encode<'a, A::Database>,
&'a str: Type<A::Database> + Encode<'a, A::Database>,
&'a [u8]: Type<A::Database> + Encode<'a, A::Database>,
String: Type<A::Database> + Encode<'a, A::Database>,
Vec<u8>: Type<A::Database> + Encode<'a, A::Database>,
{
let mut out = A::default();

for arg in &self.values.0 {
for arg in self.values.0 {
match arg {
AnyValueKind::Null(AnyTypeInfoKind::Null) => out.add(Option::<i32>::None),
AnyValueKind::Null(AnyTypeInfoKind::Bool) => out.add(Option::<bool>::None),
Expand All @@ -82,8 +82,8 @@ impl<'q> AnyArguments<'q> {
AnyValueKind::BigInt(i) => out.add(i),
AnyValueKind::Real(r) => out.add(r),
AnyValueKind::Double(d) => out.add(d),
AnyValueKind::Text(t) => out.add(&**t),
AnyValueKind::Blob(b) => out.add(&**b),
AnyValueKind::Text(t) => out.add(String::from(t)),
AnyValueKind::Blob(b) => out.add(Vec::from(b)),
}?
}
Ok(out)
Expand Down
5 changes: 2 additions & 3 deletions sqlx-mysql/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl AnyConnectionBackend for MySqlConnection {
arguments: Option<AnyArguments<'q>>,
) -> BoxStream<'q, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
let persistent = persistent && arguments.is_some();
let arguments = match arguments.as_ref().map(AnyArguments::convert_to).transpose() {
let arguments = match arguments.map(AnyArguments::convert_into).transpose() {
Ok(arguments) => arguments,
Err(error) => {
return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed()
Expand All @@ -111,8 +111,7 @@ impl AnyConnectionBackend for MySqlConnection {
) -> BoxFuture<'q, sqlx_core::Result<Option<AnyRow>>> {
let persistent = persistent && arguments.is_some();
let arguments = arguments
.as_ref()
.map(AnyArguments::convert_to)
.map(AnyArguments::convert_into)
.transpose()
.map_err(sqlx_core::Error::Encode);

Expand Down
5 changes: 2 additions & 3 deletions sqlx-postgres/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl AnyConnectionBackend for PgConnection {
arguments: Option<AnyArguments<'q>>,
) -> BoxStream<'q, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
let persistent = persistent && arguments.is_some();
let arguments = match arguments.as_ref().map(AnyArguments::convert_to).transpose() {
let arguments = match arguments.map(AnyArguments::convert_into).transpose() {
Ok(arguments) => arguments,
Err(error) => {
return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed()
Expand All @@ -113,8 +113,7 @@ impl AnyConnectionBackend for PgConnection {
) -> BoxFuture<'q, sqlx_core::Result<Option<AnyRow>>> {
let persistent = persistent && arguments.is_some();
let arguments = arguments
.as_ref()
.map(AnyArguments::convert_to)
.map(AnyArguments::convert_into)
.transpose()
.map_err(sqlx_core::Error::Encode);

Expand Down
44 changes: 24 additions & 20 deletions sqlx-sqlite/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ use sqlx_core::any::{
};
use sqlx_core::sql_str::SqlStr;

use crate::arguments::SqliteArgumentsBuffer;
use crate::type_info::DataType;
use sqlx_core::connection::{ConnectOptions, Connection};
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
use sqlx_core::transaction::TransactionManager;
use std::pin::pin;
use std::sync::Arc;

sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite);

Expand Down Expand Up @@ -203,27 +205,29 @@ impl<'a> TryFrom<&'a AnyConnectOptions> for SqliteConnectOptions {
}
}

/// Instead of `AnyArguments::convert_into()`, we can do a direct mapping and preserve the lifetime.
fn map_arguments(args: AnyArguments<'_>) -> SqliteArguments<'_> {
// Infallible alternative to AnyArguments::convert_into()
fn map_arguments(args: AnyArguments<'_>) -> SqliteArguments {
let values = args
.values
.0
.into_iter()
.map(|val| match val {
AnyValueKind::Null(_) => SqliteArgumentValue::Null,
AnyValueKind::Bool(b) => SqliteArgumentValue::Int(b as i32),
AnyValueKind::SmallInt(i) => SqliteArgumentValue::Int(i as i32),
AnyValueKind::Integer(i) => SqliteArgumentValue::Int(i),
AnyValueKind::BigInt(i) => SqliteArgumentValue::Int64(i),
AnyValueKind::Real(r) => SqliteArgumentValue::Double(r as f64),
AnyValueKind::Double(d) => SqliteArgumentValue::Double(d),
AnyValueKind::Text(t) => SqliteArgumentValue::Text(Arc::new(t.to_string())),
AnyValueKind::Blob(b) => SqliteArgumentValue::Blob(Arc::new(b.to_vec())),
// AnyValueKind is `#[non_exhaustive]` but we should have covered everything
_ => unreachable!("BUG: missing mapping for {val:?}"),
})
.collect();

SqliteArguments {
values: args
.values
.0
.into_iter()
.map(|val| match val {
AnyValueKind::Null(_) => SqliteArgumentValue::Null,
AnyValueKind::Bool(b) => SqliteArgumentValue::Int(b as i32),
AnyValueKind::SmallInt(i) => SqliteArgumentValue::Int(i as i32),
AnyValueKind::Integer(i) => SqliteArgumentValue::Int(i),
AnyValueKind::BigInt(i) => SqliteArgumentValue::Int64(i),
AnyValueKind::Real(r) => SqliteArgumentValue::Double(r as f64),
AnyValueKind::Double(d) => SqliteArgumentValue::Double(d),
AnyValueKind::Text(t) => SqliteArgumentValue::Text(t),
AnyValueKind::Blob(b) => SqliteArgumentValue::Blob(b),
// AnyValueKind is `#[non_exhaustive]` but we should have covered everything
_ => unreachable!("BUG: missing mapping for {val:?}"),
})
.collect(),
values: SqliteArgumentsBuffer::new(values),
}
}

Expand Down
67 changes: 30 additions & 37 deletions sqlx-sqlite/src/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,62 +4,56 @@ use crate::statement::StatementHandle;
use crate::Sqlite;
use atoi::atoi;
use libsqlite3_sys::SQLITE_OK;
use std::borrow::Cow;
use std::sync::Arc;

pub(crate) use sqlx_core::arguments::*;
use sqlx_core::error::BoxDynError;

#[derive(Debug, Clone)]
pub enum SqliteArgumentValue<'q> {
pub enum SqliteArgumentValue {
Null,
Text(Cow<'q, str>),
Blob(Cow<'q, [u8]>),
Text(Arc<String>),
TextSlice(Arc<str>),
Blob(Arc<Vec<u8>>),
Double(f64),
Int(i32),
Int64(i64),
}

#[derive(Default, Debug, Clone)]
pub struct SqliteArguments<'q> {
pub(crate) values: Vec<SqliteArgumentValue<'q>>,
pub struct SqliteArguments {
pub(crate) values: SqliteArgumentsBuffer,
}

impl<'q> SqliteArguments<'q> {
#[derive(Default, Debug, Clone)]
pub struct SqliteArgumentsBuffer(Vec<SqliteArgumentValue>);

impl<'q> SqliteArguments {
pub(crate) fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
where
T: Encode<'q, Sqlite>,
{
let value_length_before_encoding = self.values.len();
let value_length_before_encoding = self.values.0.len();

match value.encode(&mut self.values) {
Ok(IsNull::Yes) => self.values.push(SqliteArgumentValue::Null),
Ok(IsNull::Yes) => self.values.0.push(SqliteArgumentValue::Null),
Ok(IsNull::No) => {}
Err(error) => {
// reset the value buffer to its previous value if encoding failed so we don't leave a half-encoded value behind
self.values.truncate(value_length_before_encoding);
self.values.0.truncate(value_length_before_encoding);
return Err(error);
}
};

Ok(())
}

pub(crate) fn into_static(self) -> SqliteArguments<'static> {
SqliteArguments {
values: self
.values
.into_iter()
.map(SqliteArgumentValue::into_static)
.collect(),
}
}
}

impl<'q> Arguments<'q> for SqliteArguments<'q> {
impl<'q> Arguments<'q> for SqliteArguments {
type Database = Sqlite;

fn reserve(&mut self, len: usize, _size_hint: usize) {
self.values.reserve(len);
self.values.0.reserve(len);
}

fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
Expand All @@ -70,11 +64,11 @@ impl<'q> Arguments<'q> for SqliteArguments<'q> {
}

fn len(&self) -> usize {
self.values.len()
self.values.0.len()
}
}

impl SqliteArguments<'_> {
impl SqliteArguments {
pub(super) fn bind(&self, handle: &mut StatementHandle, offset: usize) -> Result<usize, Error> {
let mut arg_i = offset;
// for handle in &statement.handles {
Expand Down Expand Up @@ -103,7 +97,7 @@ impl SqliteArguments<'_> {
arg_i
};

if n > self.values.len() {
if n > self.values.0.len() {
// SQLite treats unbound variables as NULL
// we reproduce this here
// If you are reading this and think this should be an error, open an issue and we can
Expand All @@ -113,32 +107,31 @@ impl SqliteArguments<'_> {
break;
}

self.values[n - 1].bind(handle, param_i)?;
self.values.0[n - 1].bind(handle, param_i)?;
}

Ok(arg_i - offset)
}
}

impl SqliteArgumentValue<'_> {
fn into_static(self) -> SqliteArgumentValue<'static> {
use SqliteArgumentValue::*;
impl SqliteArgumentsBuffer {
#[allow(dead_code)] // clippy incorrectly reports this as unused
pub(crate) fn new(values: Vec<SqliteArgumentValue>) -> SqliteArgumentsBuffer {
Self(values)
}

match self {
Null => Null,
Text(text) => Text(text.into_owned().into()),
Blob(blob) => Blob(blob.into_owned().into()),
Int(v) => Int(v),
Int64(v) => Int64(v),
Double(v) => Double(v),
}
pub(crate) fn push(&mut self, value: SqliteArgumentValue) {
self.0.push(value);
}
}

impl SqliteArgumentValue {
fn bind(&self, handle: &mut StatementHandle, i: usize) -> Result<(), Error> {
use SqliteArgumentValue::*;

let status = match self {
Text(v) => handle.bind_text(i, v),
TextSlice(v) => handle.bind_text(i, v),
Blob(v) => handle.bind_blob(i, v),
Int(v) => handle.bind_int(i, *v),
Int64(v) => handle.bind_int64(i, *v),
Expand Down
14 changes: 7 additions & 7 deletions sqlx-sqlite/src/connection/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct ExecuteIter<'a> {
handle: &'a mut ConnectionHandle,
statement: &'a mut VirtualStatement,
logger: QueryLogger,
args: Option<SqliteArguments<'a>>,
args: Option<SqliteArguments>,

/// since a `VirtualStatement` can encompass multiple actual statements,
/// this keeps track of the number of arguments so far
Expand All @@ -19,12 +19,12 @@ pub struct ExecuteIter<'a> {
goto_next: bool,
}

pub(crate) fn iter<'a>(
conn: &'a mut ConnectionState,
pub(crate) fn iter(
conn: &mut ConnectionState,
query: impl SqlSafeStr,
args: Option<SqliteArguments<'a>>,
args: Option<SqliteArguments>,
persistent: bool,
) -> Result<ExecuteIter<'a>, Error> {
) -> Result<ExecuteIter<'_>, Error> {
let query = query.into_sql_str();
// fetch the cached statement or allocate a new one
let statement = conn.statements.get(query.as_str(), persistent)?;
Expand All @@ -43,7 +43,7 @@ pub(crate) fn iter<'a>(

fn bind(
statement: &mut StatementHandle,
arguments: &Option<SqliteArguments<'_>>,
arguments: &Option<SqliteArguments>,
offset: usize,
) -> Result<usize, Error> {
let mut n = 0;
Expand All @@ -56,7 +56,7 @@ fn bind(
}

impl ExecuteIter<'_> {
pub fn finish(&mut self) -> Result<(), Error> {
pub fn finish(self) -> Result<(), Error> {
for res in self {
let _ = res?;
}
Expand Down
6 changes: 3 additions & 3 deletions sqlx-sqlite/src/connection/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ enum Command {
},
Execute {
query: SqlStr,
arguments: Option<SqliteArguments<'static>>,
arguments: Option<SqliteArguments>,
persistent: bool,
tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
limit: Option<usize>,
Expand Down Expand Up @@ -353,7 +353,7 @@ impl ConnectionWorker {
pub(crate) async fn execute(
&mut self,
query: SqlStr,
args: Option<SqliteArguments<'_>>,
args: Option<SqliteArguments>,
chan_size: usize,
persistent: bool,
limit: Option<usize>,
Expand All @@ -364,7 +364,7 @@ impl ConnectionWorker {
.send_async((
Command::Execute {
query,
arguments: args.map(SqliteArguments::into_static),
arguments: args,
persistent,
tx,
limit,
Expand Down
10 changes: 5 additions & 5 deletions sqlx-sqlite/src/database.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub(crate) use sqlx_core::database::{Database, HasStatementCache};

use crate::arguments::SqliteArgumentsBuffer;
use crate::{
SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnection, SqliteQueryResult,
SqliteRow, SqliteStatement, SqliteTransactionManager, SqliteTypeInfo, SqliteValue,
SqliteValueRef,
SqliteArguments, SqliteColumn, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement,
SqliteTransactionManager, SqliteTypeInfo, SqliteValue, SqliteValueRef,
};

/// Sqlite database driver.
Expand All @@ -26,8 +26,8 @@ impl Database for Sqlite {
type Value = SqliteValue;
type ValueRef<'r> = SqliteValueRef<'r>;

type Arguments<'q> = SqliteArguments<'q>;
type ArgumentBuffer<'q> = Vec<SqliteArgumentValue<'q>>;
type Arguments<'q> = SqliteArguments;
type ArgumentBuffer<'q> = SqliteArgumentsBuffer;

type Statement = SqliteStatement;

Expand Down
4 changes: 2 additions & 2 deletions sqlx-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ extern crate sqlx_core;

use std::sync::atomic::AtomicBool;

pub use arguments::{SqliteArgumentValue, SqliteArguments};
pub use arguments::{SqliteArgumentValue, SqliteArguments, SqliteArgumentsBuffer};
pub use column::SqliteColumn;
#[cfg(feature = "deserialize")]
#[cfg_attr(docsrs, doc(cfg(feature = "deserialize")))]
Expand Down Expand Up @@ -147,7 +147,7 @@ impl<'c, T: Executor<'c, Database = Sqlite>> SqliteExecutor<'c> for T {}
pub type SqliteTransaction<'c> = sqlx_core::transaction::Transaction<'c, Sqlite>;

// NOTE: required due to the lack of lazy normalization
impl_into_arguments_for_arguments!(SqliteArguments<'q>);
impl_into_arguments_for_arguments!(SqliteArguments);
impl_column_index_for_row!(SqliteRow);
impl_column_index_for_statement!(SqliteStatement);
impl_acquire!(Sqlite, SqliteConnection);
Expand Down
Loading
Loading