diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..8af59dd --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[env] +RUST_TEST_THREADS = "1" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ca8335f..f1b240f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -57,30 +57,23 @@ jobs: toolchain: ${{ matrix.rust }} override: true - - name: check avoid-dev-deps - uses: actions-rs/cargo@v1 - if: matrix.rust == 'nightly' - with: - command: check - args: --all -Z avoid-dev-deps - - name: pg tests uses: actions-rs/cargo@v1 with: command: test - args: --all --features pg,async_std -- --test-threads=1 + args: --all --features pg -- --test-threads=1 - name: sqlite tests uses: actions-rs/cargo@v1 with: command: test - args: --all --features sqlite,async_std + args: --all --features sqlite - name: mysql tests uses: actions-rs/cargo@v1 with: command: test - args: --all --features mysql,async_std -- --test-threads=1 + args: --all --features mysql -- --test-threads=1 check_fmt_and_docs: name: Checking fmt, clippy, and docs diff --git a/Cargo.toml b/Cargo.toml index b452a5d..3a21b0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,25 +12,29 @@ keywords = ["sessions", "sqlx", "sqlite", "postgres", "mysql"] categories = ["web-programming::http-server", "web-programming", "database"] [package.metadata.docs.rs] -features = ["pg", "sqlite", "mysql", "async_std"] +features = ["pg", "sqlite", "mysql"] [features] -default = ["native-tls"] +default = [] sqlite = ["sqlx/sqlite"] pg = ["sqlx/postgres", "sqlx/json"] native-tls = ["sqlx/runtime-async-std-native-tls"] rustls = ["sqlx/runtime-async-std-rustls"] -async_std = ["async-std"] mysql = ["sqlx/mysql", "sqlx/json"] [dependencies] -async-session = "3.0.0" -sqlx = { version = "0.6.2", features = ["chrono"] } -async-std = { version = "1.12.0", optional = true } +async-session = { git = "https://github.com/http-rs/async-session", branch = "overhaul-session-and-session-store", default-features = false } +sqlx = { version = "0.6.2", features = ["time"] } +log = "0.4.17" +serde_json = "1.0.93" +serde = "1.0.152" +thiserror = "1.0.38" +time = "0.3.18" +base64 = "0.21.0" [dev-dependencies] async-std = { version = "1.12.0", features = ["attributes"] } [dev-dependencies.sqlx] version = "0.6.2" -features = ["chrono", "runtime-async-std-native-tls"] +features = ["runtime-async-std-native-tls"] diff --git a/README.md b/README.md index 11a16c0..c1952b8 100644 --- a/README.md +++ b/README.md @@ -34,18 +34,6 @@ async-sqlx-session = { version = "0.4.0", features = ["pg"] } async-sqlx-session = { version = "0.4.0", features = ["mysql"] } ``` -### Optional `async_std` feature - -To use the `spawn_cleanup_task` function on the async-std runtime, -enable the `async_std` feature, which can be combined with any of the -above datastores. - -```toml -async-sqlx-session = { version = "0.4.0", features = ["pg", "async_std"] } -``` - -## Cargo Features: - ## Safety This crate uses ``#![deny(unsafe_code)]`` to ensure everything is implemented in 100% Safe Rust. @@ -58,7 +46,7 @@ Licensed under either of Apache License, Version
- +p Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in this crate by you, as defined in the Apache-2.0 license, shall diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..ffadb8a --- /dev/null +++ b/src/error.rs @@ -0,0 +1,17 @@ +/// Errors that can arise in the operation of the session stores +/// included in this crate +#[derive(thiserror::Error, Debug)] +#[non_exhaustive] +pub enum Error { + /// an error that comes from sqlx + #[error(transparent)] + SqlxError(#[from] sqlx::Error), + + /// an error that comes from serde_json + #[error(transparent)] + SerdeJson(#[from] serde_json::Error), + + /// an error that comes from base64 + #[error(transparent)] + Base64(#[from] base64::DecodeError), +} diff --git a/src/lib.rs b/src/lib.rs index 3fff7c2..ed73a55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,3 +54,6 @@ pub use pg::PostgresSessionStore; mod mysql; #[cfg(feature = "mysql")] pub use mysql::MySqlSessionStore; + +mod error; +pub use error::Error; diff --git a/src/mysql.rs b/src/mysql.rs index 0edf82c..5f221b8 100644 --- a/src/mysql.rs +++ b/src/mysql.rs @@ -1,25 +1,24 @@ -use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore}; +use async_session::{async_trait, Session, SessionStore}; use sqlx::{pool::PoolConnection, Executor, MySql, MySqlPool}; +use time::OffsetDateTime; /// sqlx mysql session store for async-sessions /// /// ```rust -/// use async_sqlx_session::MySqlSessionStore; +/// use async_sqlx_session::{MySqlSessionStore, Error}; /// use async_session::{Session, SessionStore}; /// use std::time::Duration; /// -/// # fn main() -> async_session::Result { async_std::task::block_on(async { +/// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?; /// store.migrate().await?; /// # store.clear_store().await?; -/// # #[cfg(feature = "async_std")] -/// store.spawn_cleanup_task(Duration::from_secs(60 * 60)); /// /// let mut session = Session::new(); /// session.insert("key", vec![1,2,3]); /// -/// let cookie_value = store.store_session(session).await?.unwrap(); -/// let session = store.load_session(cookie_value).await?.unwrap(); +/// let cookie_value = store.store_session(&mut session).await?.unwrap(); +/// let session = store.load_session(&cookie_value).await?.unwrap(); /// assert_eq!(session.get::>("key").unwrap(), vec![1,2,3]); /// # Ok(()) }) } /// @@ -36,9 +35,8 @@ impl MySqlSessionStore { /// with [`with_table_name`](crate::MySqlSessionStore::with_table_name). /// /// ```rust - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{MySqlSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let pool = sqlx::MySqlPool::connect(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await.unwrap(); /// let store = MySqlSessionStore::from_client(pool) /// .with_table_name("custom_table_name"); @@ -60,9 +58,8 @@ impl MySqlSessionStore { /// [`new_with_table_name`](crate::MySqlSessionStore::new_with_table_name) /// /// ```rust - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{MySqlSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?; /// store.migrate().await; /// # Ok(()) }) } @@ -80,9 +77,8 @@ impl MySqlSessionStore { /// [`new_with_table_name`](crate::MySqlSessionStore::new_with_table_name) /// /// ```rust - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{MySqlSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = MySqlSessionStore::new_with_table_name(&std::env::var("MYSQL_TEST_DB_URL").unwrap(), "custom_table_name").await?; /// store.migrate().await; /// # Ok(()) }) } @@ -94,9 +90,8 @@ impl MySqlSessionStore { /// Chainable method to add a custom table name. This will panic /// if the table name is not `[a-zA-Z0-9_-]`. /// ```rust - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{MySqlSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await? /// .with_table_name("custom_name"); /// store.migrate().await; @@ -104,9 +99,8 @@ impl MySqlSessionStore { /// ``` /// /// ```should_panic - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{MySqlSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await? /// .with_table_name("johnny (); drop users;"); /// # Ok(()) }) } @@ -134,13 +128,13 @@ impl MySqlSessionStore { /// exactly-once modifications to the schema of the session table /// on breaking releases. /// ```rust - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{MySqlSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?; /// # store.clear_store().await?; /// store.migrate().await?; - /// store.store_session(Session::new()).await?; + /// store.store_session(&mut Session::new()).await?; /// store.migrate().await?; // calling it a second time is safe /// assert_eq!(store.count().await?, 1); /// # Ok(()) }) } @@ -176,56 +170,20 @@ impl MySqlSessionStore { self.client.acquire().await } - /// Spawns an async_std::task that clears out stale (expired) - /// sessions on a periodic basis. Only available with the - /// async_std feature enabled. - /// - /// ```rust,no_run - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # use std::time::Duration; - /// # fn main() -> Result { async_std::task::block_on(async { - /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?; - /// store.migrate().await?; - /// # let join_handle = - /// store.spawn_cleanup_task(Duration::from_secs(1)); - /// let mut session = Session::new(); - /// session.expire_in(Duration::from_secs(0)); - /// store.store_session(session).await?; - /// assert_eq!(store.count().await?, 1); - /// async_std::task::sleep(Duration::from_secs(2)).await; - /// assert_eq!(store.count().await?, 0); - /// # join_handle.cancel().await; - /// # Ok(()) }) } - /// ``` - #[cfg(feature = "async_std")] - pub fn spawn_cleanup_task( - &self, - period: std::time::Duration, - ) -> async_std::task::JoinHandle<()> { - let store = self.clone(); - async_std::task::spawn(async move { - loop { - async_std::task::sleep(period).await; - if let Err(error) = store.cleanup().await { - log::error!("cleanup error: {}", error); - } - } - }) - } - /// Performs a one-time cleanup task that clears out stale /// (expired) sessions. You may want to call this from cron. /// ```rust - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session}; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{MySqlSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; + /// # use std::time::Duration; + /// # use time::OffsetDateTime; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?; /// store.migrate().await?; /// # store.clear_store().await?; /// let mut session = Session::new(); - /// session.set_expiry(Utc::now() - Duration::seconds(5)); - /// store.store_session(session).await?; + /// session.set_expiry(OffsetDateTime::now_utc() - Duration::from_secs(5)); + /// store.store_session(&mut session).await?; /// assert_eq!(store.count().await?, 1); /// store.cleanup().await?; /// assert_eq!(store.count().await?, 0); @@ -234,7 +192,7 @@ impl MySqlSessionStore { pub async fn cleanup(&self) -> sqlx::Result<()> { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE expires < ?")) - .bind(Utc::now()) + .bind(OffsetDateTime::now_utc()) .execute(&mut connection) .await?; @@ -245,15 +203,15 @@ impl MySqlSessionStore { /// expired sessions /// /// ```rust - /// # use async_sqlx_session::MySqlSessionStore; - /// # use async_session::{Result, SessionStore, Session}; + /// # use async_sqlx_session::{MySqlSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; /// # use std::time::Duration; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?; /// store.migrate().await?; /// # store.clear_store().await?; /// assert_eq!(store.count().await?, 0); - /// store.store_session(Session::new()).await?; + /// store.store_session(&mut Session::new()).await?; /// assert_eq!(store.count().await?, 1); /// # Ok(()) }) } /// ``` @@ -270,26 +228,34 @@ impl MySqlSessionStore { #[async_trait] impl SessionStore for MySqlSessionStore { - async fn load_session(&self, cookie_value: String) -> Result> { - let id = Session::id_from_cookie_value(&cookie_value)?; + type Error = crate::Error; + + async fn load_session(&self, cookie_value: &str) -> Result, Self::Error> { + let id = Session::id_from_cookie_value(cookie_value)?; let mut connection = self.connection().await?; - let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name( - "SELECT session FROM %%TABLE_NAME%% WHERE id = ? AND (expires IS NULL OR expires > ?)", + let result: Option<(String, Option)> = sqlx::query_as(&self.substitute_table_name( + "SELECT session, expires FROM %%TABLE_NAME%% WHERE id = ? AND (expires IS NULL OR expires > ?)", )) .bind(&id) - .bind(Utc::now()) + .bind(OffsetDateTime::now_utc()) .fetch_optional(&mut connection) .await?; - Ok(result - .map(|(session,)| serde_json::from_str(&session)) - .transpose()?) + if let Some((data, expiry)) = result { + Ok(Some(Session::from_parts( + id, + serde_json::from_str(&data)?, + expiry, + ))) + } else { + Ok(None) + } } - async fn store_session(&self, session: Session) -> Result> { + async fn store_session(&self, session: &mut Session) -> Result, Self::Error> { let id = session.id(); - let string = serde_json::to_string(&session)?; + let string = serde_json::to_string(session.data())?; let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( @@ -301,27 +267,27 @@ impl SessionStore for MySqlSessionStore { session = VALUES(session) "#, )) - .bind(&id) + .bind(id) .bind(&string) - .bind(&session.expiry()) + .bind(session.expiry()) .execute(&mut connection) .await?; - Ok(session.into_cookie_value()) + Ok(session.take_cookie_value()) } - async fn destroy_session(&self, session: Session) -> Result { + async fn destroy_session(&self, session: &mut Session) -> Result<(), Self::Error> { let id = session.id(); let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE id = ?")) - .bind(&id) + .bind(id) .execute(&mut connection) .await?; Ok(()) } - async fn clear_store(&self) -> Result { + async fn clear_store(&self) -> Result<(), Self::Error> { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name("TRUNCATE %%TABLE_NAME%%")) .execute(&mut connection) @@ -334,8 +300,9 @@ impl SessionStore for MySqlSessionStore { #[cfg(test)] mod tests { use super::*; - use async_session::chrono::DateTime; - use std::time::Duration; + use serde_json::{json, Value}; + use std::{collections::HashMap, time::Duration}; + use time::OffsetDateTime; async fn test_store() -> MySqlSessionStore { let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()) @@ -353,14 +320,14 @@ mod tests { } #[async_std::test] - async fn creating_a_new_session_with_no_expiry() -> Result { + async fn creating_a_new_session_with_no_expiry() -> Result<(), crate::Error> { let store = test_store().await; let mut session = Session::new(); session.insert("key", "value")?; let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); - let (id, expires, serialized, count): (String, Option>, String, i64) = + let (id, expires, serialized, count): (String, Option, String, i64) = sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; @@ -369,11 +336,13 @@ mod tests { assert_eq!(id, cloned.id()); assert_eq!(expires, None); - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + let deserialized_session: HashMap = serde_json::from_str(&serialized)?; + assert_eq!( + &json!("value"), + deserialized_session.get(&String::from("key")).unwrap() + ); - let loaded_session = store.load_session(cookie_value).await?.unwrap(); + let loaded_session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(cloned.id(), loaded_session.id()); assert_eq!("value", &loaded_session.get::("key").unwrap()); @@ -382,19 +351,19 @@ mod tests { } #[async_std::test] - async fn updating_a_session() -> Result { + async fn updating_a_session() -> Result<(), crate::Error> { let store = test_store().await; let mut session = Session::new(); let original_id = session.id().to_owned(); session.insert("key", "value")?; - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + let mut session = store.load_session(&cookie_value).await?.unwrap(); session.insert("key", "other value")?; - assert_eq!(None, store.store_session(session).await?); + assert_eq!(None, store.store_session(&mut session).await?); - let session = store.load_session(cookie_value.clone()).await?.unwrap(); + let session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(session.get::("key").unwrap(), "other value"); let (id, count): (String, i64) = @@ -409,95 +378,103 @@ mod tests { } #[async_std::test] - async fn updating_a_session_extending_expiry() -> Result { + async fn updating_a_session_extending_expiry() -> Result<(), crate::Error> { let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(10)); let original_id = session.id().to_owned(); - let original_expires = session.expiry().unwrap().clone(); - let cookie_value = store.store_session(session).await?.unwrap(); - - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &original_expires); + let original_expires = *session.expiry().unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); + + let mut session = store.load_session(&cookie_value).await?.unwrap(); + assert_eq!( + session.expiry().unwrap().unix_timestamp(), + original_expires.unix_timestamp() + ); session.expire_in(Duration::from_secs(20)); - let new_expires = session.expiry().unwrap().clone(); - store.store_session(session).await?; + let new_expires = *session.expiry().unwrap(); + store.store_session(&mut session).await?; - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &new_expires); + let session = store.load_session(&cookie_value).await?.unwrap(); + assert_eq!( + session.expiry().unwrap().unix_timestamp(), + new_expires.unix_timestamp() + ); - let (id, expires, count): (String, DateTime, i64) = sqlx::query_as( + let (id, expires, count): (String, Option, i64) = sqlx::query_as( "select id, expires, (select count(*) from async_sessions) from async_sessions", ) .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); - assert_eq!(expires.timestamp_millis(), new_expires.timestamp_millis()); + assert_eq!( + expires.unwrap().unix_timestamp(), + new_expires.unix_timestamp() + ); assert_eq!(original_id, id); Ok(()) } #[async_std::test] - async fn creating_a_new_session_with_expiry() -> Result { + async fn creating_a_new_session_with_expiry() -> Result<(), crate::Error> { let store = test_store().await; let mut session = Session::new(); - session.expire_in(Duration::from_secs(1)); + session.expire_in(std::time::Duration::from_secs(1)); session.insert("key", "value")?; let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); - let (id, expires, serialized, count): (String, Option>, String, i64) = + let (id, expires, serialized, count): (String, Option, String, i64) = sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); assert_eq!(id, cloned.id()); - assert!(expires.unwrap() > Utc::now()); + assert!(expires.unwrap() > OffsetDateTime::now_utc()); - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + let deserialized_session: HashMap = serde_json::from_str(&serialized)?; + assert_eq!(&json!("value"), deserialized_session.get("key").unwrap()); - let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); + let loaded_session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(cloned.id(), loaded_session.id()); assert_eq!("value", &loaded_session.get::("key").unwrap()); assert!(!loaded_session.is_expired()); - async_std::task::sleep(Duration::from_secs(1)).await; - assert_eq!(None, store.load_session(cookie_value).await?); + async_std::task::sleep(std::time::Duration::from_secs(1)).await; + assert_eq!(None, store.load_session(&cookie_value).await?); Ok(()) } #[async_std::test] - async fn destroying_a_single_session() -> Result { + async fn destroying_a_single_session() -> Result<(), crate::Error> { let store = test_store().await; for _ in 0..3i8 { - store.store_session(Session::new()).await?; + store.store_session(&mut Session::new()).await?; } - let cookie = store.store_session(Session::new()).await?.unwrap(); + let cookie = store.store_session(&mut Session::new()).await?.unwrap(); assert_eq!(4, store.count().await?); - let session = store.load_session(cookie.clone()).await?.unwrap(); - store.destroy_session(session.clone()).await.unwrap(); - assert_eq!(None, store.load_session(cookie).await?); + let mut session = store.load_session(&cookie).await?.unwrap(); + store.destroy_session(&mut session).await.unwrap(); + assert_eq!(None, store.load_session(&cookie).await?); assert_eq!(3, store.count().await?); // // attempting to destroy the session again is not an error - assert!(store.destroy_session(session).await.is_ok()); + assert!(store.destroy_session(&mut session).await.is_ok()); Ok(()) } #[async_std::test] - async fn clearing_the_whole_store() -> Result { + async fn clearing_the_whole_store() -> Result<(), crate::Error> { let store = test_store().await; for _ in 0..3i8 { - store.store_session(Session::new()).await?; + store.store_session(&mut Session::new()).await?; } assert_eq!(3, store.count().await?); diff --git a/src/pg.rs b/src/pg.rs index 6cdf4a6..f0435d9 100644 --- a/src/pg.rs +++ b/src/pg.rs @@ -1,26 +1,24 @@ -use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore}; +use async_session::{async_trait, Session, SessionStore}; use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres}; +use time::OffsetDateTime; /// sqlx postgres session store for async-sessions /// /// ```rust -/// use async_sqlx_session::PostgresSessionStore; +/// use async_sqlx_session::{PostgresSessionStore, Error}; /// use async_session::{Session, SessionStore}; /// use std::time::Duration; /// -/// # fn main() -> async_session::Result { async_std::task::block_on(async { +/// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?; /// store.migrate().await?; /// # store.clear_store().await?; -/// # #[cfg(feature = "async_std")] { -/// store.spawn_cleanup_task(Duration::from_secs(60 * 60)); -/// # } /// /// let mut session = Session::new(); /// session.insert("key", vec![1,2,3]); /// -/// let cookie_value = store.store_session(session).await?.unwrap(); -/// let session = store.load_session(cookie_value).await?.unwrap(); +/// let cookie_value = store.store_session(&mut session).await?.unwrap(); +/// let session = store.load_session(&cookie_value).await?.unwrap(); /// assert_eq!(session.get::>("key").unwrap(), vec![1,2,3]); /// # Ok(()) }) } /// @@ -37,9 +35,8 @@ impl PostgresSessionStore { /// with [`with_table_name`](crate::PostgresSessionStore::with_table_name). /// /// ```rust - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{PostgresSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let pool = sqlx::PgPool::connect(&std::env::var("PG_TEST_DB_URL").unwrap()).await.unwrap(); /// let store = PostgresSessionStore::from_client(pool) /// .with_table_name("custom_table_name"); @@ -61,9 +58,8 @@ impl PostgresSessionStore { /// [`new_with_table_name`](crate::PostgresSessionStore::new_with_table_name) /// /// ```rust - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{PostgresSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?; /// store.migrate().await; /// # Ok(()) }) } @@ -81,9 +77,8 @@ impl PostgresSessionStore { /// [`new_with_table_name`](crate::PostgresSessionStore::new_with_table_name) /// /// ```rust - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{PostgresSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = PostgresSessionStore::new_with_table_name(&std::env::var("PG_TEST_DB_URL").unwrap(), "custom_table_name").await?; /// store.migrate().await; /// # Ok(()) }) } @@ -95,9 +90,8 @@ impl PostgresSessionStore { /// Chainable method to add a custom table name. This will panic /// if the table name is not `[a-zA-Z0-9_-]+`. /// ```rust - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{PostgresSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await? /// .with_table_name("custom_name"); /// store.migrate().await; @@ -105,9 +99,8 @@ impl PostgresSessionStore { /// ``` /// /// ```should_panic - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{PostgresSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await? /// .with_table_name("johnny (); drop users;"); /// # Ok(()) }) } @@ -135,13 +128,13 @@ impl PostgresSessionStore { /// exactly-once modifications to the schema of the session table /// on breaking releases. /// ```rust - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{PostgresSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?; /// # store.clear_store().await?; /// store.migrate().await?; - /// store.store_session(Session::new()).await?; + /// store.store_session(&mut Session::new()).await?; /// store.migrate().await?; // calling it a second time is safe /// assert_eq!(store.count().await?, 1); /// # Ok(()) }) } @@ -173,57 +166,19 @@ impl PostgresSessionStore { self.client.acquire().await } - /// Spawns an async_std::task that clears out stale (expired) - /// sessions on a periodic basis. Only available with the - /// async_std feature enabled. - /// - /// ```rust,no_run - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # use std::time::Duration; - /// # fn main() -> Result { async_std::task::block_on(async { - /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?; - /// store.migrate().await?; - /// # let join_handle = - /// store.spawn_cleanup_task(Duration::from_secs(1)); - /// let mut session = Session::new(); - /// session.expire_in(Duration::from_secs(0)); - /// store.store_session(session).await?; - /// assert_eq!(store.count().await?, 1); - /// async_std::task::sleep(Duration::from_secs(2)).await; - /// assert_eq!(store.count().await?, 0); - /// # join_handle.cancel().await; - /// # Ok(()) }) } - /// ``` - #[cfg(feature = "async_std")] - pub fn spawn_cleanup_task( - &self, - period: std::time::Duration, - ) -> async_std::task::JoinHandle<()> { - use async_std::task; - let store = self.clone(); - task::spawn(async move { - loop { - task::sleep(period).await; - if let Err(error) = store.cleanup().await { - log::error!("cleanup error: {}", error); - } - } - }) - } - /// Performs a one-time cleanup task that clears out stale /// (expired) sessions. You may want to call this from cron. /// ```rust - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session}; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{PostgresSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; + /// # use time::{OffsetDateTime, Duration}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?; /// store.migrate().await?; /// # store.clear_store().await?; /// let mut session = Session::new(); - /// session.set_expiry(Utc::now() - Duration::seconds(5)); - /// store.store_session(session).await?; + /// session.set_expiry(OffsetDateTime::now_utc() - Duration::seconds(5)); + /// store.store_session(&mut session).await?; /// assert_eq!(store.count().await?, 1); /// store.cleanup().await?; /// assert_eq!(store.count().await?, 0); @@ -232,7 +187,7 @@ impl PostgresSessionStore { pub async fn cleanup(&self) -> sqlx::Result<()> { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE expires < $1")) - .bind(Utc::now()) + .bind(OffsetDateTime::now_utc()) .execute(&mut connection) .await?; @@ -243,15 +198,15 @@ impl PostgresSessionStore { /// expired sessions /// /// ```rust - /// # use async_sqlx_session::PostgresSessionStore; - /// # use async_session::{Result, SessionStore, Session}; + /// # use async_sqlx_session::{PostgresSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; /// # use std::time::Duration; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?; /// store.migrate().await?; /// # store.clear_store().await?; /// assert_eq!(store.count().await?, 0); - /// store.store_session(Session::new()).await?; + /// store.store_session(&mut Session::new()).await?; /// assert_eq!(store.count().await?, 1); /// # Ok(()) }) } /// ``` @@ -268,26 +223,33 @@ impl PostgresSessionStore { #[async_trait] impl SessionStore for PostgresSessionStore { - async fn load_session(&self, cookie_value: String) -> Result> { - let id = Session::id_from_cookie_value(&cookie_value)?; + type Error = crate::Error; + + async fn load_session(&self, cookie_value: &str) -> Result, Self::Error> { + let id = Session::id_from_cookie_value(cookie_value)?; let mut connection = self.connection().await?; - let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name( - "SELECT session FROM %%TABLE_NAME%% WHERE id = $1 AND (expires IS NULL OR expires > $2)" + let result: Option<(String, Option)> = sqlx::query_as(&self.substitute_table_name( + "SELECT session, expires FROM %%TABLE_NAME%% WHERE id = $1 AND (expires IS NULL OR expires > $2)" )) .bind(&id) - .bind(Utc::now()) + .bind(OffsetDateTime::now_utc()) .fetch_optional(&mut connection) .await?; - Ok(result - .map(|(session,)| serde_json::from_str(&session)) - .transpose()?) + match result { + Some((data, expiry)) => Ok(Some(Session::from_parts( + id, + serde_json::from_str(&data)?, + expiry, + ))), + None => Ok(None), + } } - async fn store_session(&self, session: Session) -> Result> { + async fn store_session(&self, session: &mut Session) -> Result, Self::Error> { let id = session.id(); - let string = serde_json::to_string(&session)?; + let string = serde_json::to_string(session.data())?; let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( @@ -299,27 +261,27 @@ impl SessionStore for PostgresSessionStore { session = EXCLUDED.session "#, )) - .bind(&id) + .bind(id) .bind(&string) - .bind(&session.expiry()) + .bind(session.expiry()) .execute(&mut connection) .await?; - Ok(session.into_cookie_value()) + Ok(session.take_cookie_value()) } - async fn destroy_session(&self, session: Session) -> Result { + async fn destroy_session(&self, session: &mut Session) -> Result<(), Self::Error> { let id = session.id(); let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE id = $1")) - .bind(&id) + .bind(id) .execute(&mut connection) .await?; Ok(()) } - async fn clear_store(&self) -> Result { + async fn clear_store(&self) -> Result<(), Self::Error> { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name("TRUNCATE %%TABLE_NAME%%")) .execute(&mut connection) @@ -331,8 +293,12 @@ impl SessionStore for PostgresSessionStore { #[cfg(test)] mod tests { + use serde_json::Value; + use std::collections::HashMap; + + use crate::Error; + use super::*; - use async_session::chrono::DateTime; use std::time::Duration; async fn test_store() -> PostgresSessionStore { @@ -351,14 +317,14 @@ mod tests { } #[async_std::test] - async fn creating_a_new_session_with_no_expiry() -> Result { + async fn creating_a_new_session_with_no_expiry() -> Result<(), Error> { let store = test_store().await; let mut session = Session::new(); session.insert("key", "value")?; let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); - let (id, expires, serialized, count): (String, Option>, String, i64) = + let (id, expires, serialized, count): (String, Option, String, i64) = sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; @@ -367,11 +333,13 @@ mod tests { assert_eq!(id, cloned.id()); assert_eq!(expires, None); - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + let deserialized_session: HashMap = serde_json::from_str(&serialized)?; + assert_eq!( + "\"value\"", + &deserialized_session.get("key").unwrap().to_string() + ); - let loaded_session = store.load_session(cookie_value).await?.unwrap(); + let loaded_session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(cloned.id(), loaded_session.id()); assert_eq!("value", &loaded_session.get::("key").unwrap()); @@ -380,19 +348,19 @@ mod tests { } #[async_std::test] - async fn updating_a_session() -> Result { + async fn updating_a_session() -> Result<(), Error> { let store = test_store().await; let mut session = Session::new(); let original_id = session.id().to_owned(); session.insert("key", "value")?; - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + let mut session = store.load_session(&cookie_value).await?.unwrap(); session.insert("key", "other value")?; - assert_eq!(None, store.store_session(session).await?); + assert_eq!(None, store.store_session(&mut session).await?); - let session = store.load_session(cookie_value.clone()).await?.unwrap(); + let session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(session.get::("key").unwrap(), "other value"); let (id, count): (String, i64) = @@ -407,95 +375,110 @@ mod tests { } #[async_std::test] - async fn updating_a_session_extending_expiry() -> Result { + async fn updating_a_session_extending_expiry() -> Result<(), Error> { let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(10)); let original_id = session.id().to_owned(); - let original_expires = session.expiry().unwrap().clone(); - let cookie_value = store.store_session(session).await?.unwrap(); + let original_expires = *session.expiry().unwrap(); + + let cookie_value = store.store_session(&mut session).await?.unwrap(); - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &original_expires); + let mut session = store.load_session(&cookie_value).await?.unwrap(); + assert_eq!( + session.expiry().unwrap().unix_timestamp(), + original_expires.unix_timestamp() + ); session.expire_in(Duration::from_secs(20)); - let new_expires = session.expiry().unwrap().clone(); - store.store_session(session).await?; + let new_expires = *session.expiry().unwrap(); + store.store_session(&mut session).await?; - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &new_expires); + let session = store.load_session(&cookie_value).await?.unwrap(); + assert_eq!( + session.expiry().unwrap().unix_timestamp(), + new_expires.unix_timestamp() + ); - let (id, expires, count): (String, DateTime, i64) = sqlx::query_as( + let (id, expires, count): (String, Option, i64) = sqlx::query_as( "select id, expires, (select count(*) from async_sessions) from async_sessions", ) .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); - assert_eq!(expires.timestamp_millis(), new_expires.timestamp_millis()); + assert_eq!( + expires.unwrap().unix_timestamp(), + new_expires.unix_timestamp() + ); assert_eq!(original_id, id); Ok(()) } #[async_std::test] - async fn creating_a_new_session_with_expiry() -> Result { + async fn creating_a_new_session_with_expiry() -> Result<(), Error> { let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(1)); session.insert("key", "value")?; let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); - let (id, expires, serialized, count): (String, Option>, String, i64) = + let (id, expires, serialized, count): (String, Option, String, i64) = sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); assert_eq!(id, cloned.id()); - assert!(expires.unwrap() > Utc::now()); - - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); - - let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); + assert!(expires.unwrap() > OffsetDateTime::now_utc()); + + let deserialized_session: HashMap = serde_json::from_str(&serialized)?; + assert_eq!( + "\"value\"", + deserialized_session + .get(&String::from("key")) + .unwrap() + .to_string() + ); + + let loaded_session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(cloned.id(), loaded_session.id()); assert_eq!("value", &loaded_session.get::("key").unwrap()); assert!(!loaded_session.is_expired()); async_std::task::sleep(Duration::from_secs(1)).await; - assert_eq!(None, store.load_session(cookie_value).await?); + assert_eq!(None, store.load_session(&cookie_value).await?); Ok(()) } #[async_std::test] - async fn destroying_a_single_session() -> Result { + async fn destroying_a_single_session() -> Result<(), Error> { let store = test_store().await; for _ in 0..3i8 { - store.store_session(Session::new()).await?; + store.store_session(&mut Session::new()).await?; } - let cookie = store.store_session(Session::new()).await?.unwrap(); + let cookie = store.store_session(&mut Session::new()).await?.unwrap(); assert_eq!(4, store.count().await?); - let session = store.load_session(cookie.clone()).await?.unwrap(); - store.destroy_session(session.clone()).await.unwrap(); - assert_eq!(None, store.load_session(cookie).await?); + let mut session = store.load_session(&cookie).await?.unwrap(); + store.destroy_session(&mut session).await.unwrap(); + assert_eq!(None, store.load_session(&cookie).await?); assert_eq!(3, store.count().await?); // // attempting to destroy the session again is not an error - assert!(store.destroy_session(session).await.is_ok()); + assert!(store.destroy_session(&mut session).await.is_ok()); Ok(()) } #[async_std::test] - async fn clearing_the_whole_store() -> Result { + async fn clearing_the_whole_store() -> Result<(), Error> { let store = test_store().await; for _ in 0..3i8 { - store.store_session(Session::new()).await?; + store.store_session(&mut Session::new()).await?; } assert_eq!(3, store.count().await?); diff --git a/src/sqlite.rs b/src/sqlite.rs index 77e973f..755c14e 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -1,24 +1,24 @@ -use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore}; +use crate::Error; +use async_session::{async_trait, Session, SessionStore}; use sqlx::{pool::PoolConnection, sqlite::SqlitePool, Sqlite}; +use time::OffsetDateTime; /// sqlx sqlite session store for async-sessions /// /// ```rust -/// use async_sqlx_session::SqliteSessionStore; +/// use async_sqlx_session::{SqliteSessionStore, Error}; /// use async_session::{Session, SessionStore}; /// use std::time::Duration; /// -/// # fn main() -> async_session::Result { async_std::task::block_on(async { +/// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = SqliteSessionStore::new("sqlite::memory:").await?; /// store.migrate().await?; -/// # #[cfg(feature = "async_std")] -/// store.spawn_cleanup_task(Duration::from_secs(60 * 60)); /// /// let mut session = Session::new(); /// session.insert("key", vec![1,2,3]); /// -/// let cookie_value = store.store_session(session).await?.unwrap(); -/// let session = store.load_session(cookie_value).await?.unwrap(); +/// let cookie_value = store.store_session(&mut session).await?.unwrap(); +/// let session = store.load_session(&cookie_value).await?.unwrap(); /// assert_eq!(session.get::>("key").unwrap(), vec![1,2,3]); /// # Ok(()) }) } /// @@ -35,9 +35,8 @@ impl SqliteSessionStore { /// with [`with_table_name`](crate::SqliteSessionStore::with_table_name). /// /// ```rust - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{SqliteSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); /// let store = SqliteSessionStore::from_client(pool) /// .with_table_name("custom_table_name"); @@ -62,9 +61,8 @@ impl SqliteSessionStore { /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) /// /// ```rust - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{SqliteSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = SqliteSessionStore::new("sqlite::memory:").await?; /// store.migrate().await; /// # Ok(()) }) } @@ -81,9 +79,8 @@ impl SqliteSessionStore { /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name) /// /// ```rust - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{SqliteSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?; /// store.migrate().await; /// # Ok(()) }) } @@ -95,9 +92,9 @@ impl SqliteSessionStore { /// Chainable method to add a custom table name. This will panic /// if the table name is not `[a-zA-Z0-9_-]+`. /// ```rust - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{SqliteSessionStore, Error}; + + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = SqliteSessionStore::new("sqlite::memory:").await? /// .with_table_name("custom_name"); /// store.migrate().await; @@ -105,9 +102,8 @@ impl SqliteSessionStore { /// ``` /// /// ```should_panic - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::Result; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{SqliteSessionStore, Error}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = SqliteSessionStore::new("sqlite::memory:").await? /// .with_table_name("johnny (); drop users;"); /// # Ok(()) }) } @@ -135,13 +131,13 @@ impl SqliteSessionStore { /// exactly-once modifications to the schema of the session table /// on breaking releases. /// ```rust - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{SqliteSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = SqliteSessionStore::new("sqlite::memory:").await?; /// assert!(store.count().await.is_err()); /// store.migrate().await?; - /// store.store_session(Session::new()).await?; + /// store.store_session(&mut Session::new()).await?; /// store.migrate().await?; // calling it a second time is safe /// assert_eq!(store.count().await?, 1); /// # Ok(()) }) } @@ -175,55 +171,18 @@ impl SqliteSessionStore { self.client.acquire().await } - /// Spawns an async_std::task that clears out stale (expired) - /// sessions on a periodic basis. Only available with the - /// async_std feature enabled. - /// - /// ```rust,no_run - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; - /// # use std::time::Duration; - /// # fn main() -> Result { async_std::task::block_on(async { - /// let store = SqliteSessionStore::new("sqlite::memory:").await?; - /// store.migrate().await?; - /// # let join_handle = - /// store.spawn_cleanup_task(Duration::from_secs(1)); - /// let mut session = Session::new(); - /// session.expire_in(Duration::from_secs(0)); - /// store.store_session(session).await?; - /// assert_eq!(store.count().await?, 1); - /// async_std::task::sleep(Duration::from_secs(2)).await; - /// assert_eq!(store.count().await?, 0); - /// # join_handle.cancel().await; - /// # Ok(()) }) } - /// ``` - #[cfg(feature = "async_std")] - pub fn spawn_cleanup_task( - &self, - period: std::time::Duration, - ) -> async_std::task::JoinHandle<()> { - let store = self.clone(); - async_std::task::spawn(async move { - loop { - async_std::task::sleep(period).await; - if let Err(error) = store.cleanup().await { - log::error!("cleanup error: {}", error); - } - } - }) - } - /// Performs a one-time cleanup task that clears out stale /// (expired) sessions. You may want to call this from cron. /// ```rust - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session}; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # use async_sqlx_session::{SqliteSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; + /// # use time::{OffsetDateTime, Duration}; + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = SqliteSessionStore::new("sqlite::memory:").await?; /// store.migrate().await?; /// let mut session = Session::new(); - /// session.set_expiry(Utc::now() - Duration::seconds(5)); - /// store.store_session(session).await?; + /// session.set_expiry(OffsetDateTime::now_utc() - Duration::seconds(5)); + /// store.store_session(&mut session).await?; /// assert_eq!(store.count().await?, 1); /// store.cleanup().await?; /// assert_eq!(store.count().await?, 0); @@ -237,7 +196,7 @@ impl SqliteSessionStore { WHERE expires < ? "#, )) - .bind(Utc::now().timestamp()) + .bind(OffsetDateTime::now_utc().unix_timestamp()) .execute(&mut connection) .await?; @@ -248,14 +207,14 @@ impl SqliteSessionStore { /// expired sessions /// /// ```rust - /// # use async_sqlx_session::SqliteSessionStore; - /// # use async_session::{Result, SessionStore, Session}; + /// # use async_sqlx_session::{SqliteSessionStore, Error}; + /// # use async_session::{SessionStore, Session}; /// # use std::time::Duration; - /// # fn main() -> Result { async_std::task::block_on(async { + /// # fn main() -> Result<(), Error> { async_std::task::block_on(async { /// let store = SqliteSessionStore::new("sqlite::memory:").await?; /// store.migrate().await?; /// assert_eq!(store.count().await?, 0); - /// store.store_session(Session::new()).await?; + /// store.store_session(&mut Session::new()).await?; /// assert_eq!(store.count().await?, 1); /// # Ok(()) }) } /// ``` @@ -272,29 +231,38 @@ impl SqliteSessionStore { #[async_trait] impl SessionStore for SqliteSessionStore { - async fn load_session(&self, cookie_value: String) -> Result> { - let id = Session::id_from_cookie_value(&cookie_value)?; + type Error = Error; + + async fn load_session(&self, cookie_value: &str) -> Result, Self::Error> { + let id = Session::id_from_cookie_value(cookie_value)?; let mut connection = self.connection().await?; - let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name( - r#" - SELECT session FROM %%TABLE_NAME%% + let result: Option<(String, Option)> = + sqlx::query_as(&self.substitute_table_name( + r#" + SELECT session, expires FROM %%TABLE_NAME%% WHERE id = ? AND (expires IS NULL OR expires > ?) "#, - )) - .bind(&id) - .bind(Utc::now().timestamp()) - .fetch_optional(&mut connection) - .await?; + )) + .bind(&id) + .bind(OffsetDateTime::now_utc().unix_timestamp()) + .fetch_optional(&mut connection) + .await?; - Ok(result - .map(|(session,)| serde_json::from_str(&session)) - .transpose()?) + if let Some((data, expiry)) = result { + Ok(Some(Session::from_parts( + id, + serde_json::from_str(&data)?, + expiry, + ))) + } else { + Ok(None) + } } - async fn store_session(&self, session: Session) -> Result> { + async fn store_session(&self, session: &mut Session) -> Result, Self::Error> { let id = session.id(); - let string = serde_json::to_string(&session)?; + let string = serde_json::to_string(session.data())?; let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( @@ -306,16 +274,16 @@ impl SessionStore for SqliteSessionStore { session = excluded.session "#, )) - .bind(&id) + .bind(id) .bind(&string) - .bind(&session.expiry().map(|expiry| expiry.timestamp())) + .bind(session.expiry().map(|expiry| expiry.unix_timestamp())) .execute(&mut connection) .await?; - Ok(session.into_cookie_value()) + Ok(session.take_cookie_value()) } - async fn destroy_session(&self, session: Session) -> Result { + async fn destroy_session(&self, session: &mut Session) -> Result<(), Self::Error> { let id = session.id(); let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( @@ -323,14 +291,14 @@ impl SessionStore for SqliteSessionStore { DELETE FROM %%TABLE_NAME%% WHERE id = ? "#, )) - .bind(&id) + .bind(id) .execute(&mut connection) .await?; Ok(()) } - async fn clear_store(&self) -> Result { + async fn clear_store(&self) -> Result<(), Self::Error> { let mut connection = self.connection().await?; sqlx::query(&self.substitute_table_name( r#" @@ -346,8 +314,10 @@ impl SessionStore for SqliteSessionStore { #[cfg(test)] mod tests { + use serde_json::{json, Value}; + use super::*; - use std::time::Duration; + use std::{collections::HashMap, time::Duration}; async fn test_store() -> SqliteSessionStore { let store = SqliteSessionStore::new("sqlite::memory:") @@ -361,12 +331,12 @@ mod tests { } #[async_std::test] - async fn creating_a_new_session_with_no_expiry() -> Result { + async fn creating_a_new_session_with_no_expiry() -> Result<(), Error> { let store = test_store().await; let mut session = Session::new(); session.insert("key", "value")?; let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); let (id, expires, serialized, count): (String, Option, String, i64) = sqlx::query_as("select id, expires, session, count(*) from async_sessions") @@ -377,11 +347,10 @@ mod tests { assert_eq!(id, cloned.id()); assert_eq!(expires, None); - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + let deserialized_session: HashMap = serde_json::from_str(&serialized)?; + assert_eq!(&json!("value"), deserialized_session.get("key").unwrap()); - let loaded_session = store.load_session(cookie_value).await?.unwrap(); + let loaded_session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(cloned.id(), loaded_session.id()); assert_eq!("value", &loaded_session.get::("key").unwrap()); @@ -390,19 +359,19 @@ mod tests { } #[async_std::test] - async fn updating_a_session() -> Result { + async fn updating_a_session() -> Result<(), Error> { let store = test_store().await; let mut session = Session::new(); let original_id = session.id().to_owned(); session.insert("key", "value")?; - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); + let mut session = store.load_session(&cookie_value).await?.unwrap(); session.insert("key", "other value")?; - assert_eq!(None, store.store_session(session).await?); + assert_eq!(None, store.store_session(&mut session).await?); - let session = store.load_session(cookie_value.clone()).await?.unwrap(); + let session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(session.get::("key").unwrap(), "other value"); let (id, count): (String, i64) = sqlx::query_as("select id, count(*) from async_sessions") @@ -416,44 +385,50 @@ mod tests { } #[async_std::test] - async fn updating_a_session_extending_expiry() -> Result { + async fn updating_a_session_extending_expiry() -> Result<(), Error> { let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(10)); let original_id = session.id().to_owned(); - let original_expires = session.expiry().unwrap().clone(); - let cookie_value = store.store_session(session).await?.unwrap(); - - let mut session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &original_expires); + let original_expires = *session.expiry().unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); + + let mut session = store.load_session(&cookie_value).await?.unwrap(); + assert_eq!( + session.expiry().unwrap(), + &original_expires.replace_millisecond(0).unwrap() + ); session.expire_in(Duration::from_secs(20)); - let new_expires = session.expiry().unwrap().clone(); - store.store_session(session).await?; + let new_expires = *session.expiry().unwrap(); + store.store_session(&mut session).await?; - let session = store.load_session(cookie_value.clone()).await?.unwrap(); - assert_eq!(session.expiry().unwrap(), &new_expires); + let session = store.load_session(&cookie_value).await?.unwrap(); + assert_eq!( + session.expiry().unwrap(), + &new_expires.replace_millisecond(0).unwrap() + ); - let (id, expires, count): (String, i64, i64) = + let (id, expires, count): (String, Option, i64) = sqlx::query_as("select id, expires, count(*) from async_sessions") .fetch_one(&mut store.connection().await?) .await?; assert_eq!(1, count); - assert_eq!(expires, new_expires.timestamp()); + assert_eq!(expires.unwrap(), new_expires.unix_timestamp()); assert_eq!(original_id, id); Ok(()) } #[async_std::test] - async fn creating_a_new_session_with_expiry() -> Result { + async fn creating_a_new_session_with_expiry() -> Result<(), Error> { let store = test_store().await; let mut session = Session::new(); session.expire_in(Duration::from_secs(1)); session.insert("key", "value")?; let cloned = session.clone(); - let cookie_value = store.store_session(session).await?.unwrap(); + let cookie_value = store.store_session(&mut session).await?.unwrap(); let (id, expires, serialized, count): (String, Option, String, i64) = sqlx::query_as("select id, expires, session, count(*) from async_sessions") @@ -462,36 +437,35 @@ mod tests { assert_eq!(1, count); assert_eq!(id, cloned.id()); - assert!(expires.unwrap() > Utc::now().timestamp()); + assert!(expires.unwrap() > OffsetDateTime::now_utc().unix_timestamp()); - let deserialized_session: Session = serde_json::from_str(&serialized)?; - assert_eq!(cloned.id(), deserialized_session.id()); - assert_eq!("value", &deserialized_session.get::("key").unwrap()); + let deserialized_session: HashMap = serde_json::from_str(&serialized)?; + assert_eq!(&json!("value"), deserialized_session.get("key").unwrap()); - let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap(); + let loaded_session = store.load_session(&cookie_value).await?.unwrap(); assert_eq!(cloned.id(), loaded_session.id()); assert_eq!("value", &loaded_session.get::("key").unwrap()); assert!(!loaded_session.is_expired()); async_std::task::sleep(Duration::from_secs(1)).await; - assert_eq!(None, store.load_session(cookie_value).await?); + assert_eq!(None, store.load_session(&cookie_value).await?); Ok(()) } #[async_std::test] - async fn destroying_a_single_session() -> Result { + async fn destroying_a_single_session() -> Result<(), Error> { let store = test_store().await; for _ in 0..3i8 { - store.store_session(Session::new()).await?; + store.store_session(&mut Session::new()).await?; } - let cookie = store.store_session(Session::new()).await?.unwrap(); + let cookie = store.store_session(&mut Session::new()).await?.unwrap(); assert_eq!(4, store.count().await?); - let session = store.load_session(cookie.clone()).await?.unwrap(); - store.destroy_session(session.clone()).await.unwrap(); - assert_eq!(None, store.load_session(cookie).await?); + let mut session = store.load_session(&cookie).await?.unwrap(); + store.destroy_session(&mut session).await.unwrap(); + assert_eq!(None, store.load_session(&cookie).await?); assert_eq!(3, store.count().await?); // // attempting to destroy the session again is not an error @@ -500,10 +474,10 @@ mod tests { } #[async_std::test] - async fn clearing_the_whole_store() -> Result { + async fn clearing_the_whole_store() -> Result<(), Error> { let store = test_store().await; for _ in 0..3i8 { - store.store_session(Session::new()).await?; + store.store_session(&mut Session::new()).await?; } assert_eq!(3, store.count().await?);