diff --git a/sqlx-core/src/any/value.rs b/sqlx-core/src/any/value.rs index bc95ffa977..0d27cdcefe 100644 --- a/sqlx-core/src/any/value.rs +++ b/sqlx-core/src/any/value.rs @@ -176,3 +176,44 @@ impl<'r> ValueRef<'r> for AnyValueRef<'r> { } } } + +macro_rules! impl_try_from_any_value_ref { + ($( + #[cfg(feature = $feature:literal)] + $db_name:ident => $value_ref:ty, + )*) => { + $( + #[cfg(feature = $feature)] + impl<'r> TryFrom> for $value_ref { + type Error = crate::error::Error; + + fn try_from(value: AnyValueRef<'r>) -> Result { + #[allow(unreachable_patterns)] + match value.kind { + AnyValueRefKind::$db_name(value) => Ok(value), + _ => Err(crate::error::Error::Configuration( + format!("Expected {}, got {:?}", stringify!($value_ref), value.type_info()).into(), + )), + } + } + } + )* + }; +} + +impl_try_from_any_value_ref! { + #[cfg(feature = "postgres")] + Postgres => PgValueRef<'r>, + + #[cfg(feature = "mysql")] + MySql => MySqlValueRef<'r>, + + #[cfg(feature = "sqlite")] + Sqlite => SqliteValueRef<'r>, + + #[cfg(feature = "mssql")] + Mssql => MssqlValueRef<'r>, + + #[cfg(feature = "odbc")] + Odbc => OdbcValueRef<'r>, +} diff --git a/tests/any/any.rs b/tests/any/any.rs index ca2f7ffb9a..d827e145bb 100644 --- a/tests/any/any.rs +++ b/tests/any/any.rs @@ -337,3 +337,23 @@ async fn it_fails_to_prepare_invalid_statements() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +#[cfg(feature = "postgres")] +async fn it_converts_any_value_ref_to_specific_postgres_types() -> anyhow::Result<()> { + let mut conn = new::().await?; + let dbms_name = conn.dbms_name().await.unwrap_or_default().to_lowercase(); + + if !dbms_name.contains("postgres") { + return Ok(()); + } + + use sqlx_oldapi::postgres::{types::PgRange, PgValueRef, Postgres}; + use std::ops::Bound; + let row = conn.fetch_one("SELECT int4range(1, 10)").await?; + let pgrange: PgValueRef<'_> = row.try_get_raw(0)?.try_into()?; + let decoded: PgRange = as Decode<'_, Postgres>>::decode(pgrange).unwrap(); + assert_eq!(decoded.start, Bound::Included(1)); + assert_eq!(decoded.end, Bound::Excluded(10)); + Ok(()) +}