diff --git a/Cargo.toml b/Cargo.toml index a9b4fad..94e38e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,16 +16,18 @@ edition = "2021" maintenance = { status = "actively-developed" } [dependencies] -axum = "0.6.16" -axum-core = "0.3.4" -axum-sessions = "0.5.0" -base64 = "0.21.0" +axum = "0.7.3" +base64 = "0.21.5" rand = "0.8.5" thiserror = "1.0.40" -tokio = { version = "1.27.0", features = ["macros", "rt", "rt-multi-thread"] } -tower = "0.4.13" -tracing = "0.1.37" +tower-cookies = "0.10.0" +tower-layer = "0.3.2" +tower-service = "0.3.2" +tower-sessions = "0.9.1" +tracing = "0.1.40" [dev-dependencies] +tokio = { version = "1.27.0", features = ["macros", "rt", "rt-multi-thread"] } tokio-test = "0.4.2" -tower-http = { version = "0.4.0", features = ["cors"] } +tower = "0.4.13" +tower-http = { version = "0.5.0", features = ["cors"] } diff --git a/README.md b/README.md index 897d8f4..b998e32 100644 --- a/README.md +++ b/README.md @@ -28,11 +28,9 @@ Consider as well to use the [crate unit tests](https://github.com/LeoniePhiline/ This middleware implements token transfer via [custom request headers](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#use-of-custom-request-headers). -The middleware requires and is built upon [`axum_sessions`](https://docs.rs/axum-sessions/), which in turn uses [`async_session`](https://docs.rs/async-session/). +The middleware requires and is built upon [`tower_sessions`](https://docs.rs/tower-sessions/). -The current version is built for and works with `axum 0.6.x`, `axum-sessions 0.5.x` and `async_session 3.x`. - -There will be support for `axum 0.7` and later versions. +The current version is built for and works with `axum 0.7.x`, `tower-sessions 0.9.x`. The [Same Origin Policy](https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy) prevents the custom request header to be set by foreign scripts. @@ -67,7 +65,7 @@ See ["Our RNGs"](https://rust-random.github.io/book/guide-rngs.html#cryptographi The security of the underlying session is paramount - the CSRF prevention methods applied can only be as secure as the session carrying the server-side token. -- When creating your [SessionLayer](https://docs.rs/axum-sessions/latest/axum_sessions/struct.SessionLayer.html), make sure to use at least 64 bytes of cryptographically secure randomness. +- When creating your [SessionManagerLayer](https://docs.rs/tower-sessions/latest/tower_sessions/struct.SessionManagerLayer.html) - Do not lower the secure defaults: Keep the session cookie's `secure` flag **on**. - Use the strictest possible same-site policy. @@ -105,16 +103,12 @@ Configure your session and CSRF protection layer in your backend application: ```rust use axum::{ - body::Body, - http::StatusCode, - routing::{get, Router}, + http::{header, StatusCode}, + response::IntoResponse, + routing::get, }; -use axum_csrf_sync_pattern::{CsrfLayer, RegenerateToken}; -use axum_sessions::{async_session::MemoryStore, SessionLayer}; -use rand::RngCore; - -let mut secret = [0; 64]; -rand::thread_rng().try_fill_bytes(&mut secret).unwrap(); +use axum_csrf_sync_pattern::CsrfLayer; +use tower_sessions::{MemoryStore, SessionManagerLayer}; async fn handler() -> StatusCode { StatusCode::OK @@ -136,7 +130,7 @@ let app = Router::new() // Default: "_csrf_token" .session_key("_custom_session_key") ) - .layer(SessionLayer::new(MemoryStore::new(), &secret)); + .layer(SessionManagerLayer::new(MemoryStore::default())); // Use hyper to run `app` as service and expose on a local port or socket. ``` @@ -175,37 +169,33 @@ Configure your CORS layer, session and CSRF protection layer in your backend app ```rust use axum::{ - body::Body, http::{header, Method, StatusCode}, + response::IntoResponse, routing::{get, Router}, }; -use axum_csrf_sync_pattern::{CsrfLayer, RegenerateToken}; -use axum_sessions::{async_session::MemoryStore, SessionLayer}; -use rand::RngCore; +use axum_csrf_sync_pattern::CsrfLayer; use tower_http::cors::{AllowOrigin, CorsLayer}; - -let mut secret = [0; 64]; -rand::thread_rng().try_fill_bytes(&mut secret).unwrap(); +use tower_sessions::{MemoryStore, SessionManagerLayer}; async fn handler() -> StatusCode { StatusCode::OK } let app = Router::new() - .route("/", get(handler).post(handler)) - .layer( - // See example above for custom layer configuration. - CsrfLayer::new() - ) - .layer(SessionLayer::new(MemoryStore::new(), &secret)) - .layer( - CorsLayer::new() - .allow_origin(AllowOrigin::list(["https://www.example.com".parse().unwrap()])) - .allow_methods([Method::GET, Method::POST]) - .allow_headers([header::CONTENT_TYPE, "X-CSRF-TOKEN".parse().unwrap()]) - .allow_credentials(true) - .expose_headers(["X-CSRF-TOKEN".parse().unwrap()]), -); + .route("/", get(handler).post(handler)) + .layer( + // See example above for custom layer configuration. + CsrfLayer::new() + ) + .layer(SessionManagerLayer::new(MemoryStore::default())) + .layer( + CorsLayer::new() + .allow_origin(AllowOrigin::list(["https://www.example.com".parse().rap()])) + .allow_methods([Method::GET, Method::POST]) + .allow_headers([header::CONTENT_TYPE, "X-CSRF-TOKEN".parse().unwrap()]) + .allow_credentials(true) + .expose_headers(["X-CSRF-TOKEN".parse().unwrap()]), + ); // Use hyper to run `app` as service and expose on a local port or socket. ``` diff --git a/examples/cross-site/Cargo.toml b/examples/cross-site/Cargo.toml index 292a783..1265440 100644 --- a/examples/cross-site/Cargo.toml +++ b/examples/cross-site/Cargo.toml @@ -6,12 +6,10 @@ edition = "2021" publish = false [dependencies] -axum = "0.6.16" +axum = "0.7.3" axum-csrf-sync-pattern = { path = "../../" } -axum-sessions = "0.5.0" color-eyre = "0.6.2" -rand = "0.8.5" tokio = { version = "1.27.0", features = ["macros", "rt", "rt-multi-thread"] } -tower = "0.4.13" -tower-http = { version = "0.4.0", features = ["cors"] } -tracing-subscriber = "0.3.16" +tower-http = { version = "0.5.0", features = ["cors"] } +tower-sessions = "0.9.1" +tracing-subscriber = "0.3.18" diff --git a/examples/cross-site/src/main.rs b/examples/cross-site/src/main.rs index 1af3c22..597a872 100644 --- a/examples/cross-site/src/main.rs +++ b/examples/cross-site/src/main.rs @@ -4,13 +4,11 @@ use axum::{ http::{header, Method, StatusCode}, response::IntoResponse, routing::{get, Router}, - Server, }; use axum_csrf_sync_pattern::CsrfLayer; -use axum_sessions::{async_session::MemoryStore, SessionLayer}; use color_eyre::eyre::{self, eyre, WrapErr}; -use rand::RngCore; use tower_http::cors::{AllowOrigin, CorsLayer}; +use tower_sessions::{MemoryStore, SessionManagerLayer}; #[tokio::main] async fn main() -> eyre::Result<()> { @@ -33,15 +31,10 @@ async fn main() -> eyre::Result<()> { }; let backend = async { - let mut secret = [0; 64]; - rand::thread_rng() - .try_fill_bytes(&mut secret) - .wrap_err("Failed to generate session seed.")?; - let app = Router::new() .route("/", get(get_token).post(post_handler)) .layer(CsrfLayer::new()) - .layer(SessionLayer::new(MemoryStore::new(), &secret)) + .layer(SessionManagerLayer::new(MemoryStore::default())) .layer( CorsLayer::new() .allow_origin(AllowOrigin::list([ @@ -81,9 +74,10 @@ async fn main() -> eyre::Result<()> { async fn serve(app: Router, port: u16) -> eyre::Result<()> { let addr = SocketAddr::from(([127, 0, 0, 1], port)); - Server::try_bind(&addr) - .wrap_err("Could not bind to network address.")? - .serve(app.into_make_service()) + let listener = tokio::net::TcpListener::bind(addr) + .await + .wrap_err("Could not bind to network address.")?; + axum::serve(listener, app) .await .wrap_err("Failed to serve the app.")?; diff --git a/examples/same-site/Cargo.toml b/examples/same-site/Cargo.toml index bfd218e..25bb695 100644 --- a/examples/same-site/Cargo.toml +++ b/examples/same-site/Cargo.toml @@ -6,11 +6,9 @@ edition = "2021" publish = false [dependencies] -axum = "0.6.16" +axum = "0.7.3" axum-csrf-sync-pattern = { path = "../../" } -axum-sessions = "0.5.0" color-eyre = "0.6.2" -rand = "0.8.5" tokio = { version = "1.27.0", features = ["macros", "rt", "rt-multi-thread"] } -tower = "0.4.13" -tracing-subscriber = "0.3.16" +tower-sessions = "0.9.1" +tracing-subscriber = "0.3.18" diff --git a/examples/same-site/src/main.rs b/examples/same-site/src/main.rs index c35ae49..7eb62d3 100644 --- a/examples/same-site/src/main.rs +++ b/examples/same-site/src/main.rs @@ -2,12 +2,10 @@ use axum::{ http::{header, StatusCode}, response::IntoResponse, routing::get, - Server, }; use axum_csrf_sync_pattern::CsrfLayer; -use axum_sessions::{async_session::MemoryStore, SessionLayer}; use color_eyre::eyre::{self, eyre, WrapErr}; -use rand::RngCore; +use tower_sessions::{MemoryStore, SessionManagerLayer}; #[tokio::main] async fn main() -> eyre::Result<()> { @@ -20,26 +18,18 @@ async fn main() -> eyre::Result<()> { .map_err(|e| eyre!(e)) .wrap_err("Failed to initialize tracing-subscriber.")?; - let mut secret = [0; 64]; - rand::thread_rng() - .try_fill_bytes(&mut secret) - .wrap_err("Failed to generate session seed.")?; - let app = axum::Router::new() .route("/", get(index).post(handler)) .layer(CsrfLayer::new()) - .layer(SessionLayer::new(MemoryStore::new(), &secret)); + .layer(SessionManagerLayer::new(MemoryStore::default())); // Visit "http://127.0.0.1:3000/" in your browser. - Server::try_bind( - &"0.0.0.0:3000" - .parse() - .wrap_err("Failed to parse socket address.")?, - ) - .wrap_err("Could not bind to network address.")? - .serve(app.into_make_service()) - .await - .wrap_err("Failed to serve the app.")?; + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") + .await + .wrap_err("Could not bind to network address.")?; + axum::serve(listener, app) + .await + .wrap_err("Failed to serve the app.")?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index d351a14..116c9a1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ //! //! This middleware implements token transfer via [custom request headers](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#use-of-custom-request-headers). //! -//! The middleware requires and is built upon [`axum_sessions`](https://docs.rs/axum-sessions/), which in turn uses [`async_session`](https://docs.rs/async-session/). +//! The middleware requires and is built upon [`tower_sessions`](https://docs.rs/tower-sessions/). //! //! The [Same Origin Policy](https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy) prevents the custom request header to be set by foreign scripts. //! @@ -43,7 +43,7 @@ //! //! The security of the underlying session is paramount - the CSRF prevention methods applied can only be as secure as the session carrying the server-side token. //! -//! - When creating your [SessionLayer](https://docs.rs/axum-sessions/latest/axum_sessions/struct.SessionLayer.html), make sure to use at least 64 bytes of cryptographically secure randomness. +//! - When creating your [SessionManagerLayer](https://docs.rs/tower-sessions/latest/tower_sessions/struct.SessionManagerLayer.html) //! - Do not lower the secure defaults: Keep the session cookie's `secure` flag **on**. //! - Use the strictest possible same-site policy. //! @@ -79,16 +79,14 @@ //! //! ```rust //! use axum::{ +//! BoxError, //! body::Body, //! http::StatusCode, //! routing::{get, Router}, +//! error_handling::HandleErrorLayer, //! }; //! use axum_csrf_sync_pattern::{CsrfLayer, RegenerateToken}; -//! use axum_sessions::{async_session::MemoryStore, SessionLayer}; -//! use rand::RngCore; -//! -//! let mut secret = [0; 64]; -//! rand::thread_rng().try_fill_bytes(&mut secret).unwrap(); +//! use tower_sessions::{MemoryStore, SessionManagerLayer}; //! //! async fn handler() -> StatusCode { //! StatusCode::OK @@ -110,7 +108,7 @@ //! // Default: "_csrf_token" //! .session_key("_custom_session_key") //! ) -//! .layer(SessionLayer::new(MemoryStore::new(), &secret)); +//! .layer(SessionManagerLayer::new(MemoryStore::default())); //! //! // Use hyper to run `app` as service and expose on a local port or socket. //! @@ -150,18 +148,17 @@ //! //! ```rust //! use axum::{ +//! BoxError, //! body::Body, //! http::{header, Method, StatusCode}, //! routing::{get, Router}, +//! error_handling::HandleErrorLayer, //! }; +//! use tower::ServiceBuilder; //! use axum_csrf_sync_pattern::{CsrfLayer, RegenerateToken}; -//! use axum_sessions::{async_session::MemoryStore, SessionLayer}; -//! use rand::RngCore; +//! use tower_sessions::{MemoryStore, SessionManagerLayer}; //! use tower_http::cors::{AllowOrigin, CorsLayer}; //! -//! let mut secret = [0; 64]; -//! rand::thread_rng().try_fill_bytes(&mut secret).unwrap(); -//! //! async fn handler() -> StatusCode { //! StatusCode::OK //! } @@ -172,7 +169,7 @@ //! // See example above for custom layer configuration. //! CsrfLayer::new() //! ) -//! .layer(SessionLayer::new(MemoryStore::new(), &secret)) +//! .layer(SessionManagerLayer::new(MemoryStore::default())) //! .layer( //! CorsLayer::new() //! .allow_origin(AllowOrigin::list(["https://www.example.com".parse().unwrap()])) @@ -180,7 +177,8 @@ //! .allow_headers([header::CONTENT_TYPE, "X-CSRF-TOKEN".parse().unwrap()]) //! .allow_credentials(true) //! .expose_headers(["X-CSRF-TOKEN".parse().unwrap()]), -//! ); +//! ); +//! //! //! // Use hyper to run `app` as service and expose on a local port or socket. //! @@ -236,13 +234,14 @@ use std::{ task::{Context, Poll}, }; -use axum::http::{self, HeaderValue, Request, StatusCode}; -use axum_core::response::{IntoResponse, Response}; -use axum_sessions::{async_session::Session, SessionHandle}; +use axum::{ + http::{self, HeaderValue, Request, StatusCode}, + response::{IntoResponse, Response}, +}; use base64::prelude::*; use rand::RngCore; -use tokio::sync::RwLockWriteGuard; -use tower::Layer; +use tower_layer::Layer; +use tower_sessions::Session; /// Use `CsrfLayer::new()` to provide the middleware and configuration to axum's service stack. /// @@ -315,14 +314,11 @@ impl CsrfLayer { self } - fn regenerate_token( - &self, - session_write: &mut RwLockWriteGuard, - ) -> Result { + async fn regenerate_token(&self, session: &Session) -> Result { let mut buf = [0; 32]; rand::thread_rng().try_fill_bytes(&mut buf)?; let token = BASE64_STANDARD.encode(buf); - session_write.insert(self.session_key, &token)?; + session.insert(self.session_key, &token).await?; Ok(token) } @@ -382,10 +378,10 @@ enum Error { #[error("Random number generator error")] Rng(#[from] rand::Error), - #[error("Serde JSON error")] - Serde(#[from] axum_sessions::async_session::serde_json::Error), + #[error("Session error")] + Session(#[from] tower_sessions::session::Error), - #[error("Session extension missing. Is `axum_sessions::SessionLayer` installed and layered around the `axum_csrf_sync_pattern::CsrfLayer`?")] + #[error("Session extension missing. Is `tower_sessions::SessionLayer` installed and layered around the `axum_csrf_sync_pattern::CsrfLayer`?")] SessionLayerMissing, #[error("Incoming CSRF token header was not valid ASCII")] @@ -431,8 +427,8 @@ pub struct CsrfMiddleware { } impl CsrfMiddleware { - /// Create a new middleware from an inner [`tower::Service`] (axum-specific bounds, such as `Infallible` errors apply!) and a [`CsrfLayer`]. - /// Commonly, the middleware is created by the [`tower::Layer`] - and never manually. + /// Create a new middleware from an inner [`tower_service::Service`] (axum-specific bounds, such as `Infallible` errors apply!) and a [`CsrfLayer`]. + /// Commonly, the middleware is created by the [`tower_layer::Layer`] - and never manually. pub fn new(inner: S, layer: CsrfLayer) -> Self { CsrfMiddleware { inner, layer } } @@ -444,9 +440,12 @@ impl CsrfMiddleware { } } -impl tower::Service> for CsrfMiddleware +impl tower_service::Service> for CsrfMiddleware where - S: tower::Service, Response = Response, Error = Infallible> + Send + Clone + 'static, + S: tower_service::Service, Response = Response, Error = Infallible> + + Send + + Clone + + 'static, S::Future: Send, { type Response = S::Response; @@ -463,9 +462,9 @@ where let mut inner = std::mem::replace(&mut self.inner, clone); let layer = self.layer; Box::pin(async move { - let session_handle = match req + let session = match req .extensions() - .get::() + .get::() .ok_or(Error::SessionLayerMissing) { Ok(session_handle) => session_handle, @@ -474,10 +473,14 @@ where // Extract the CSRF server side token from the session; create a new one if none has been set yet. // If the regeneration option is set to "per request", then regenerate the token even if present in the session. - let mut session_write = session_handle.write().await; - let mut server_token = match session_write.get::(layer.session_key) { + let mut server_token = match session + .get::(layer.session_key) + .await + .ok() + .flatten() + { Some(token) => token, - None => match layer.regenerate_token(&mut session_write) { + None => match layer.regenerate_token(&session).await { Ok(token) => token, Err(error) => return Ok(error.into_response()), }, @@ -518,7 +521,7 @@ where if layer.regenerate_token == RegenerateToken::PerRequest || (!req.method().is_safe() && layer.regenerate_token == RegenerateToken::PerUse) { - server_token = match layer.regenerate_token(&mut session_write) { + server_token = match layer.regenerate_token(&session).await { Ok(token) => token, Err(error) => { return Ok(layer.response_with_token(error.into_response(), &server_token)) @@ -526,8 +529,6 @@ where }; } - drop(session_write); - let response = inner.call(req).await.into_response(); // Add X-CSRF-TOKEN response header. @@ -540,14 +541,18 @@ where mod tests { use std::convert::Infallible; - use axum::{body::Body, routing::get, Router}; - use axum_core::response::{IntoResponse, Response}; - use axum_sessions::{async_session::MemoryStore, extractors::ReadableSession, SessionLayer}; + use axum::{ + body::Body, + response::{IntoResponse, Response}, + routing::get, + Router, + }; use http::{ header::{COOKIE, SET_COOKIE}, Method, Request, StatusCode, }; use tower::{Service, ServiceExt}; + use tower_sessions::{MemoryStore, SessionManagerLayer}; use super::*; @@ -559,10 +564,8 @@ mod tests { .into_response()) } - fn session_layer() -> SessionLayer { - let mut secret = [0; 64]; - rand::thread_rng().try_fill_bytes(&mut secret).unwrap(); - SessionLayer::new(MemoryStore::new(), &secret) + fn session_layer() -> SessionManagerLayer { + SessionManagerLayer::new(MemoryStore::default()) } fn app(csrf_layer: CsrfLayer) -> Router { @@ -608,6 +611,7 @@ mod tests { // Get CSRF token let response = app + .as_service() .ready() .await .unwrap() @@ -628,6 +632,7 @@ mod tests { // Use CSRF token for POST request let response = app + .as_service() .ready() .await .unwrap() @@ -650,6 +655,7 @@ mod tests { // Attempt token re-use for a second POST request let response = app + .as_service() .ready() .await .unwrap() @@ -677,6 +683,7 @@ mod tests { // Get single-use CSRF token let response = app + .as_service() .ready() .await .unwrap() @@ -697,6 +704,7 @@ mod tests { // Use CSRF token for POST request let response = app + .as_service() .ready() .await .unwrap() @@ -719,6 +727,7 @@ mod tests { // Attempt token re-use for a second POST request let response = app + .as_service() .ready() .await .unwrap() @@ -746,6 +755,7 @@ mod tests { // Get single-use CSRF token let response = app + .as_service() .ready() .await .unwrap() @@ -766,6 +776,7 @@ mod tests { // Perform another GET request let response = app + .as_service() .ready() .await .unwrap() @@ -787,6 +798,7 @@ mod tests { // Attempt using single-request token for POST request let response = app + .as_service() .ready() .await .unwrap() @@ -814,6 +826,7 @@ mod tests { // Get CSRF token let response = app + .as_service() .ready() .await .unwrap() @@ -831,6 +844,7 @@ mod tests { // Use CSRF token for POST request let response = app + .as_service() .ready() .await .unwrap() @@ -869,8 +883,9 @@ mod tests { async fn uses_custom_session_key() { // Custom handler asserting the layer's configured session key is set, // and its value looks like a CSRF token. - async fn extract_session(session: ReadableSession) -> StatusCode { - let session_csrf_token: String = session.get("custom_session_key").unwrap(); + async fn extract_session(session: Session) -> StatusCode { + let session_csrf_token: String = + session.get("custom_session_key").await.unwrap().unwrap(); assert_eq!( BASE64_STANDARD.decode(session_csrf_token).unwrap().len(), @@ -911,7 +926,7 @@ mod tests { let layer = CsrfLayer::new(); let response = Response::builder() .status(StatusCode::OK) - .body(axum::body::boxed(Body::empty())) + .body(Body::empty()) .unwrap(); let response = layer.response_with_token(response, "\n");