Skip to content

Commit fdb3fb3

Browse files
authored
allow extracting unerlying db specific value from AnyValueRef (#44)
Implement TryFrom for AnyValueRef to specific database value references and add a test for Postgres conversion
1 parent 5f57b85 commit fdb3fb3

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

sqlx-core/src/any/value.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,44 @@ impl<'r> ValueRef<'r> for AnyValueRef<'r> {
176176
}
177177
}
178178
}
179+
180+
macro_rules! impl_try_from_any_value_ref {
181+
($(
182+
#[cfg(feature = $feature:literal)]
183+
$db_name:ident => $value_ref:ty,
184+
)*) => {
185+
$(
186+
#[cfg(feature = $feature)]
187+
impl<'r> TryFrom<AnyValueRef<'r>> for $value_ref {
188+
type Error = crate::error::Error;
189+
190+
fn try_from(value: AnyValueRef<'r>) -> Result<Self, Self::Error> {
191+
#[allow(unreachable_patterns)]
192+
match value.kind {
193+
AnyValueRefKind::$db_name(value) => Ok(value),
194+
_ => Err(crate::error::Error::Configuration(
195+
format!("Expected {}, got {:?}", stringify!($value_ref), value.type_info()).into(),
196+
)),
197+
}
198+
}
199+
}
200+
)*
201+
};
202+
}
203+
204+
impl_try_from_any_value_ref! {
205+
#[cfg(feature = "postgres")]
206+
Postgres => PgValueRef<'r>,
207+
208+
#[cfg(feature = "mysql")]
209+
MySql => MySqlValueRef<'r>,
210+
211+
#[cfg(feature = "sqlite")]
212+
Sqlite => SqliteValueRef<'r>,
213+
214+
#[cfg(feature = "mssql")]
215+
Mssql => MssqlValueRef<'r>,
216+
217+
#[cfg(feature = "odbc")]
218+
Odbc => OdbcValueRef<'r>,
219+
}

tests/any/any.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,23 @@ async fn it_fails_to_prepare_invalid_statements() -> anyhow::Result<()> {
337337

338338
Ok(())
339339
}
340+
341+
#[sqlx_macros::test]
342+
#[cfg(feature = "postgres")]
343+
async fn it_converts_any_value_ref_to_specific_postgres_types() -> anyhow::Result<()> {
344+
let mut conn = new::<Any>().await?;
345+
let dbms_name = conn.dbms_name().await.unwrap_or_default().to_lowercase();
346+
347+
if !dbms_name.contains("postgres") {
348+
return Ok(());
349+
}
350+
351+
use sqlx_oldapi::postgres::{types::PgRange, PgValueRef, Postgres};
352+
use std::ops::Bound;
353+
let row = conn.fetch_one("SELECT int4range(1, 10)").await?;
354+
let pgrange: PgValueRef<'_> = row.try_get_raw(0)?.try_into()?;
355+
let decoded: PgRange<i32> = <PgRange<i32> as Decode<'_, Postgres>>::decode(pgrange).unwrap();
356+
assert_eq!(decoded.start, Bound::Included(1));
357+
assert_eq!(decoded.end, Bound::Excluded(10));
358+
Ok(())
359+
}

0 commit comments

Comments
 (0)