diff --git a/.clippy.toml b/.clippy.toml index b22159d2ec..70820aab2e 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -2,3 +2,4 @@ allow-mixed-uninlined-format-args = false disallowed-types = [ { path = "tower::util::BoxCloneService", reason = "Use tower::util::BoxCloneSyncService instead" }, ] +doc-valid-idents = ["WebSocket", "WebSockets", ".."] diff --git a/Cargo.toml b/Cargo.toml index a9eeafcfd3..6482b32038 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,21 @@ uninlined_format_args = "warn" unnested_or_patterns = "warn" unused_self = "warn" verbose_file_reads = "warn" +must_use_candidate = "warn" +use_self = "warn" +redundant_else = "warn" +unnecessary_semicolon = "warn" +manual_let_else = "warn" +option_if_let_else = "warn" +missing_const_for_fn = "warn" +if_not_else = "warn" +map_unwrap_or = "warn" +semicolon_if_nothing_returned = "warn" +manual_assert = "warn" +ignored_unit_patterns = "warn" +wildcard_imports = "warn" +return_self_not_must_use = "warn" +doc_markdown = "warn" # configuration for https://github.com/crate-ci/typos [workspace.metadata.typos.default.extend-identifiers] diff --git a/axum-core/src/body.rs b/axum-core/src/body.rs index 6c49970b96..0ae8729655 100644 --- a/axum-core/src/body.rs +++ b/axum-core/src/body.rs @@ -26,11 +26,10 @@ where K: Send + 'static, { let mut k = Some(k); - if let Some(k) = ::downcast_mut::>(&mut k) { - Ok(k.take().unwrap()) - } else { - Err(k.unwrap()) - } + + ::downcast_mut::>(&mut k) + .and_then(Option::take) + .map_or_else(|| Err(k.unwrap()), Ok) } /// The body type used in axum requests and responses. @@ -48,6 +47,7 @@ impl Body { } /// Create an empty body. + #[must_use] pub fn empty() -> Self { Self::new(http_body_util::Empty::new()) } @@ -72,7 +72,8 @@ impl Body { /// you need a [`Stream`] of all frame types. /// /// [`http_body_util::BodyStream`]: https://docs.rs/http-body-util/latest/http_body_util/struct.BodyStream.html - pub fn into_data_stream(self) -> BodyDataStream { + #[must_use] + pub const fn into_data_stream(self) -> BodyDataStream { BodyDataStream { inner: self } } } @@ -84,7 +85,7 @@ impl Default for Body { } impl From<()> for Body { - fn from(_: ()) -> Self { + fn from((): ()) -> Self { Self::empty() } } diff --git a/axum-core/src/error.rs b/axum-core/src/error.rs index 8c522c72b2..e77340e3cf 100644 --- a/axum-core/src/error.rs +++ b/axum-core/src/error.rs @@ -16,6 +16,7 @@ impl Error { } /// Convert an `Error` back into the underlying boxed trait object. + #[must_use] pub fn into_inner(self) -> BoxError { self.inner } diff --git a/axum-core/src/ext_traits/mod.rs b/axum-core/src/ext_traits/mod.rs index 951a12d70c..b7a21b085e 100644 --- a/axum-core/src/ext_traits/mod.rs +++ b/axum-core/src/ext_traits/mod.rs @@ -1,5 +1,8 @@ -pub(crate) mod request; -pub(crate) mod request_parts; +mod request; +mod request_parts; + +pub use request::RequestExt; +pub use request_parts::RequestPartsExt; #[cfg(test)] mod tests { diff --git a/axum-core/src/ext_traits/request_parts.rs b/axum-core/src/ext_traits/request_parts.rs index 9e1a3d1c16..528f0bd63d 100644 --- a/axum-core/src/ext_traits/request_parts.rs +++ b/axum-core/src/ext_traits/request_parts.rs @@ -147,7 +147,7 @@ mod tests { #[tokio::test] async fn extract_without_state() { - let (mut parts, _) = Request::new(()).into_parts(); + let (mut parts, ()) = Request::new(()).into_parts(); let method: Method = parts.extract().await.unwrap(); @@ -156,7 +156,7 @@ mod tests { #[tokio::test] async fn extract_with_state() { - let (mut parts, _) = Request::new(()).into_parts(); + let (mut parts, ()) = Request::new(()).into_parts(); let state = "state".to_owned(); diff --git a/axum-core/src/extract/option.rs b/axum-core/src/extract/option.rs index c537e72187..ca4be459c8 100644 --- a/axum-core/src/extract/option.rs +++ b/axum-core/src/extract/option.rs @@ -45,7 +45,7 @@ where fn from_request_parts( parts: &mut Parts, state: &S, - ) -> impl Future, Self::Rejection>> { + ) -> impl Future> { T::from_request_parts(parts, state) } } @@ -57,7 +57,7 @@ where { type Rejection = T::Rejection; - async fn from_request(req: Request, state: &S) -> Result, Self::Rejection> { + async fn from_request(req: Request, state: &S) -> Result { T::from_request(req, state).await } } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index b2a5c14aa8..c3536526d4 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -1,4 +1,7 @@ -use super::{rejection::*, FromRequest, FromRequestParts, Request}; +use super::{ + rejection::{BytesRejection, FailedToBufferBody, InvalidUtf8, StringRejection}, + FromRequest, FromRequestParts, Request, +}; use crate::{body::Body, RequestExt}; use bytes::{BufMut, Bytes, BytesMut}; use http::{request::Parts, Extensions, HeaderMap, Method, Uri, Version}; @@ -73,7 +76,7 @@ where async fn from_request(req: Request, _: &S) -> Result { let mut body = req.into_limited_body(); - let mut bytes = BytesMut::new(); + let mut bytes = Self::new(); body_to_bytes_mut(&mut body, &mut bytes).await?; Ok(bytes) } @@ -128,7 +131,7 @@ where } })?; - let string = String::from_utf8(bytes.into()).map_err(InvalidUtf8::from_err)?; + let string = Self::from_utf8(bytes.into()).map_err(InvalidUtf8::from_err)?; Ok(string) } diff --git a/axum-core/src/lib.rs b/axum-core/src/lib.rs index 1fd4534710..39b1712467 100644 --- a/axum-core/src/lib.rs +++ b/axum-core/src/lib.rs @@ -30,7 +30,7 @@ pub mod response; /// Alias for a type-erased error type. pub type BoxError = Box; -pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt}; +pub use self::ext_traits::{RequestExt, RequestPartsExt}; #[cfg(test)] use axum_macros::__private_axum_test as test; diff --git a/axum-core/src/macros.rs b/axum-core/src/macros.rs index 8f2762486e..b21b0203e3 100644 --- a/axum-core/src/macros.rs +++ b/axum-core/src/macros.rs @@ -54,7 +54,7 @@ macro_rules! __define_rejection { } /// Get the status code used for this rejection. - pub fn status(&self) -> http::StatusCode { + pub const fn status(&self) -> http::StatusCode { http::StatusCode::$status } } @@ -106,12 +106,14 @@ macro_rules! __define_rejection { } /// Get the response body text used for this rejection. + #[must_use] pub fn body_text(&self) -> String { self.to_string() } /// Get the status code used for this rejection. - pub fn status(&self) -> http::StatusCode { + #[must_use] + pub const fn status(&self) -> http::StatusCode { http::StatusCode::$status } } @@ -179,6 +181,7 @@ macro_rules! __composite_rejection { impl $name { /// Get the response body text used for this rejection. + #[must_use] pub fn body_text(&self) -> String { match self { $( @@ -188,7 +191,8 @@ macro_rules! __composite_rejection { } /// Get the status code used for this rejection. - pub fn status(&self) -> http::StatusCode { + #[must_use] + pub const fn status(&self) -> http::StatusCode { match self { $( Self::$variant(inner) => inner.status(), diff --git a/axum-core/src/response/into_response_parts.rs b/axum-core/src/response/into_response_parts.rs index 955648238d..5fb56acbe1 100644 --- a/axum-core/src/response/into_response_parts.rs +++ b/axum-core/src/response/into_response_parts.rs @@ -165,13 +165,13 @@ pub struct TryIntoHeaderError { } impl TryIntoHeaderError { - pub(super) fn key(err: K) -> Self { + pub(super) const fn key(err: K) -> Self { Self { kind: TryIntoHeaderErrorKind::Key(err), } } - pub(super) fn value(err: V) -> Self { + pub(super) const fn value(err: V) -> Self { Self { kind: TryIntoHeaderErrorKind::Value(err), } diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index 5d0e33eb7c..9f426c951d 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -283,8 +283,8 @@ where fn layer(&self, inner: S) -> Self::Service { match self { - Either::E1(layer) => Either::E1(layer.layer(inner)), - Either::E2(layer) => Either::E2(layer.layer(inner)), + Self::E1(layer) => Either::E1(layer.layer(inner)), + Self::E2(layer) => Either::E2(layer.layer(inner)), } } } @@ -300,15 +300,15 @@ where fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { match self { - Either::E1(inner) => inner.poll_ready(cx), - Either::E2(inner) => inner.poll_ready(cx), + Self::E1(inner) => inner.poll_ready(cx), + Self::E2(inner) => inner.poll_ready(cx), } } fn call(&mut self, req: R) -> Self::Future { match self { - Either::E1(inner) => futures_util::future::Either::Left(inner.call(req)), - Either::E2(inner) => futures_util::future::Either::Right(inner.call(req)), + Self::E1(inner) => futures_util::future::Either::Left(inner.call(req)), + Self::E2(inner) => futures_util::future::Either::Right(inner.call(req)), } } } diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 64b4c3056f..d0f44e446f 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -133,7 +133,7 @@ mod tests { } } - let (mut parts, _) = Request::new(()).into_parts(); + let (mut parts, ()) = Request::new(()).into_parts(); let first = Cached::::from_request_parts(&mut parts, &()) .await diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 50fa6031ac..e99fa65eb0 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -118,6 +118,7 @@ impl CookieJar { /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts + #[must_use] pub fn from_headers(headers: &HeaderMap) -> Self { let mut jar = cookie::CookieJar::new(); for cookie in cookies_from_request(headers) { @@ -135,6 +136,7 @@ impl CookieJar { /// CookieJar`. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts + #[must_use] pub fn new() -> Self { Self::default() } @@ -153,6 +155,7 @@ impl CookieJar { /// .map(|cookie| cookie.value().to_owned()); /// } /// ``` + #[must_use] pub fn get(&self, name: &str) -> Option<&Cookie<'static>> { self.jar.get(name) } @@ -321,13 +324,13 @@ mod tests { } impl FromRef for Key { - fn from_ref(state: &AppState) -> Key { + fn from_ref(state: &AppState) -> Self { state.key.clone() } } impl FromRef for CustomKey { - fn from_ref(state: &AppState) -> CustomKey { + fn from_ref(state: &AppState) -> Self { state.custom_key.clone() } } diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index f852b8c4ba..09ac5d7ec1 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -136,7 +136,7 @@ where key, _marker: _, } = PrivateCookieJar::from_headers(&parts.headers, key); - Ok(PrivateCookieJar { + Ok(Self { jar, key, _marker: PhantomData, @@ -153,6 +153,7 @@ impl PrivateCookieJar { /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts + #[must_use] pub fn from_headers(headers: &HeaderMap, key: Key) -> Self { let mut jar = cookie::CookieJar::new(); let mut private_jar = jar.private_mut(&key); @@ -175,6 +176,7 @@ impl PrivateCookieJar { /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts + #[must_use] pub fn new(key: Key) -> Self { Self { jar: Default::default(), @@ -201,6 +203,7 @@ impl PrivateCookieJar { /// .map(|cookie| cookie.value().to_owned()); /// } /// ``` + #[must_use] pub fn get(&self, name: &str) -> Option> { self.private_jar().get(name) } @@ -246,6 +249,7 @@ impl PrivateCookieJar { /// Authenticates and decrypts `cookie`, returning the plaintext version if decryption succeeds /// or `None` otherwise. + #[must_use] pub fn decrypt(&self, cookie: Cookie<'static>) -> Option> { self.private_jar().decrypt(cookie) } diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index 92bf917145..1c2c97f8ea 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -153,7 +153,7 @@ where key, _marker: _, } = SignedCookieJar::from_headers(&parts.headers, key); - Ok(SignedCookieJar { + Ok(Self { jar, key, _marker: PhantomData, @@ -170,6 +170,7 @@ impl SignedCookieJar { /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts + #[must_use] pub fn from_headers(headers: &HeaderMap, key: Key) -> Self { let mut jar = cookie::CookieJar::new(); let mut signed_jar = jar.signed_mut(&key); @@ -192,6 +193,7 @@ impl SignedCookieJar { /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts + #[must_use] pub fn new(key: Key) -> Self { Self { jar: Default::default(), @@ -219,6 +221,7 @@ impl SignedCookieJar { /// .map(|cookie| cookie.value().to_owned()); /// } /// ``` + #[must_use] pub fn get(&self, name: &str) -> Option> { self.signed_jar().get(name) } @@ -264,6 +267,7 @@ impl SignedCookieJar { /// Verifies the authenticity and integrity of `cookie`, returning the plaintext version if /// verification succeeds or `None` otherwise. + #[must_use] pub fn verify(&self, cookie: Cookie<'static>) -> Option> { self.signed_jar().verify(cookie) } diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs index e9eb91c5be..0f33a4ab02 100644 --- a/axum-extra/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -36,7 +36,7 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { parts - .extract::>() + .extract::>() .await .ok() .flatten() @@ -55,7 +55,7 @@ where _state: &S, ) -> Result, Self::Rejection> { if let Some(host) = parse_forwarded(&parts.headers) { - return Ok(Some(Host(host.to_owned()))); + return Ok(Some(Self(host.to_owned()))); } if let Some(host) = parts @@ -63,7 +63,7 @@ where .get(X_FORWARDED_HOST_HEADER_KEY) .and_then(|host| host.to_str().ok()) { - return Ok(Some(Host(host.to_owned()))); + return Ok(Some(Self(host.to_owned()))); } if let Some(host) = parts @@ -71,11 +71,11 @@ where .get(http::header::HOST) .and_then(|host| host.to_str().ok()) { - return Ok(Some(Host(host.to_owned()))); + return Ok(Some(Self(host.to_owned()))); } if let Some(authority) = parts.uri.authority() { - return Ok(Some(Host(parse_authority(authority).to_owned()))); + return Ok(Some(Self(parse_authority(authority).to_owned()))); } Ok(None) diff --git a/axum-extra/src/extract/json_deserializer.rs b/axum-extra/src/extract/json_deserializer.rs index 051ab0f1bd..8d53101999 100644 --- a/axum-extra/src/extract/json_deserializer.rs +++ b/axum-extra/src/extract/json_deserializer.rs @@ -183,21 +183,15 @@ composite_rejection! { } fn json_content_type(headers: &HeaderMap) -> bool { - let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { - content_type - } else { + let Some(content_type) = headers.get(header::CONTENT_TYPE) else { return false; }; - let content_type = if let Ok(content_type) = content_type.to_str() { - content_type - } else { + let Ok(content_type) = content_type.to_str() else { return false; }; - let mime = if let Ok(mime) = content_type.parse::() { - mime - } else { + let Ok(mime) = content_type.parse::() else { return false; }; diff --git a/axum-extra/src/extract/multipart.rs b/axum-extra/src/extract/multipart.rs index e92dc1788e..f2f9892b7a 100644 --- a/axum-extra/src/extract/multipart.rs +++ b/axum-extra/src/extract/multipart.rs @@ -114,11 +114,7 @@ impl Multipart { .await .map_err(MultipartError::from_multer)?; - if let Some(field) = field { - Ok(Some(Field { inner: field })) - } else { - Ok(None) - } + field.map_or_else(|| Ok(None), |field| Ok(Some(Field { inner: field }))) } /// Convert the `Multipart` into a stream of its fields. @@ -150,6 +146,7 @@ impl Field { /// The field name found in the /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) /// header. + #[must_use] pub fn name(&self) -> Option<&str> { self.inner.name() } @@ -157,16 +154,19 @@ impl Field { /// The file name found in the /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) /// header. + #[must_use] pub fn file_name(&self) -> Option<&str> { self.inner.file_name() } /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field. + #[must_use] pub fn content_type(&self) -> Option<&str> { self.inner.content_type().map(|m| m.as_ref()) } /// Get a map of headers as [`HeaderMap`]. + #[must_use] pub fn headers(&self) -> &HeaderMap { self.inner.headers() } @@ -237,7 +237,7 @@ pub struct MultipartError { } impl MultipartError { - fn from_multer(multer: multer::Error) -> Self { + const fn from_multer(multer: multer::Error) -> Self { Self { source: multer } } @@ -253,6 +253,7 @@ impl MultipartError { } /// Get the status code used for this rejection. + #[must_use] pub fn status(&self) -> http::StatusCode { status_code_from_multer_error(&self.source) } diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index 72bf2f4703..39eab260ca 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -91,7 +91,7 @@ where serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes())); let value = serde_path_to_error::deserialize(deserializer) .map_err(FailedToDeserializeQueryString::from_err)?; - Ok(Query(value)) + Ok(Self(value)) } } @@ -170,9 +170,9 @@ where serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes())); let value = serde_path_to_error::deserialize(deserializer) .map_err(FailedToDeserializeQueryString::from_err)?; - Ok(OptionalQuery(Some(value))) + Ok(Self(Some(value))) } else { - Ok(OptionalQuery(None)) + Ok(Self(None)) } } } @@ -269,8 +269,7 @@ mod tests { let app = Router::new().route( "/", post(|OptionalQuery(data): OptionalQuery| async move { - data.map(|Data { values }| values.join(",")) - .unwrap_or("None".to_owned()) + data.map_or("None".to_owned(), |Data { values }| values.join(",")) }), ); diff --git a/axum-extra/src/extract/scheme.rs b/axum-extra/src/extract/scheme.rs index b20e9cf205..3634bb262f 100644 --- a/axum-extra/src/extract/scheme.rs +++ b/axum-extra/src/extract/scheme.rs @@ -38,7 +38,7 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { // Within Forwarded header if let Some(scheme) = parse_forwarded(&parts.headers) { - return Ok(Scheme(scheme.to_owned())); + return Ok(Self(scheme.to_owned())); } // X-Forwarded-Proto @@ -47,12 +47,12 @@ where .get(X_FORWARDED_PROTO_HEADER_KEY) .and_then(|scheme| scheme.to_str().ok()) { - return Ok(Scheme(scheme.to_owned())); + return Ok(Self(scheme.to_owned())); } // From parts of an HTTP/2 request if let Some(scheme) = parts.uri.scheme_str() { - return Ok(Scheme(scheme.to_owned())); + return Ok(Self(scheme.to_owned())); } Err(SchemeMissing) diff --git a/axum-extra/src/extract/with_rejection.rs b/axum-extra/src/extract/with_rejection.rs index c093f6fa47..1084801cb3 100644 --- a/axum-extra/src/extract/with_rejection.rs +++ b/axum-extra/src/extract/with_rejection.rs @@ -119,7 +119,7 @@ where async fn from_request(req: Request, state: &S) -> Result { let extractor = E::from_request(req, state).await?; - Ok(WithRejection(extractor, PhantomData)) + Ok(Self(extractor, PhantomData)) } } @@ -133,7 +133,7 @@ where async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let extractor = E::from_request_parts(parts, state).await?; - Ok(WithRejection(extractor, PhantomData)) + Ok(Self(extractor, PhantomData)) } } @@ -187,8 +187,8 @@ mod tests { } impl From<()> for TestRejection { - fn from(_: ()) -> Self { - TestRejection + fn from((): ()) -> Self { + Self } } @@ -196,7 +196,7 @@ mod tests { let result = WithRejection::::from_request(req, &()).await; assert!(matches!(result, Err(TestRejection))); - let (mut parts, _) = Request::new(()).into_parts(); + let (mut parts, ()) = Request::new(()).into_parts(); let result = WithRejection::::from_request_parts(&mut parts, &()) .await; diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index 9aea94dc83..638ce53228 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -91,7 +91,7 @@ pub struct AsResponse; impl JsonLines { /// Create a new `JsonLines` from a stream of items. - pub fn new(stream: S) -> Self { + pub const fn new(stream: S) -> Self { Self { inner: Inner::Response { stream }, _marker: PhantomData, diff --git a/axum-extra/src/middleware.rs b/axum-extra/src/middleware.rs index 0303d484fd..de296959c4 100644 --- a/axum-extra/src/middleware.rs +++ b/axum-extra/src/middleware.rs @@ -38,7 +38,5 @@ use tower_layer::Identity; /// [`HandleErrorLayer`]: axum::error_handling::HandleErrorLayer /// [`Infallible`]: std::convert::Infallible pub fn option_layer(layer: Option) -> Either { - layer - .map(Either::E1) - .unwrap_or_else(|| Either::E2(Identity::new())) + layer.map_or_else(|| Either::E2(Identity::new()), Either::E1) } diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs index fb63c7a41c..260de04114 100644 --- a/axum-extra/src/protobuf.rs +++ b/axum-extra/src/protobuf.rs @@ -109,7 +109,7 @@ where .aggregate(); match T::decode(&mut buf) { - Ok(value) => Ok(Protobuf(value)), + Ok(value) => Ok(Self(value)), Err(err) => Err(ProtobufDecodeError::from_err(err).into()), } } diff --git a/axum-extra/src/response/attachment.rs b/axum-extra/src/response/attachment.rs index 2063d30f05..64f647d1e2 100644 --- a/axum-extra/src/response/attachment.rs +++ b/axum-extra/src/response/attachment.rs @@ -43,7 +43,7 @@ pub struct Attachment { impl Attachment { /// Creates a new [`Attachment`]. - pub fn new(inner: T) -> Self { + pub const fn new(inner: T) -> Self { Self { inner, filename: None, @@ -55,12 +55,13 @@ impl Attachment { /// /// This updates the `Content-Disposition` header to add a filename. pub fn filename>(mut self, value: H) -> Self { - self.filename = if let Ok(filename) = value.try_into() { - Some(filename) - } else { - error!("Attachment filename contains invalid characters"); - None - }; + self.filename = value.try_into().map_or_else( + |_| { + error!("Attachment filename contains invalid characters"); + None + }, + Some, + ); self } @@ -86,15 +87,17 @@ where headers.append(header::CONTENT_TYPE, content_type); } - let content_disposition = if let Some(filename) = self.filename { - let mut bytes = b"attachment; filename=\"".to_vec(); - bytes.extend_from_slice(filename.as_bytes()); - bytes.push(b'\"'); + let content_disposition = self.filename.map_or_else( + || HeaderValue::from_static("attachment"), + |filename| { + let mut bytes = b"attachment; filename=\"".to_vec(); + bytes.extend_from_slice(filename.as_bytes()); + bytes.push(b'\"'); - HeaderValue::from_bytes(&bytes).expect("This was a HeaderValue so this can not fail") - } else { - HeaderValue::from_static("attachment") - }; + HeaderValue::from_bytes(&bytes) + .expect("This was a HeaderValue so this can not fail") + }, + ); headers.append(header::CONTENT_DISPOSITION, content_disposition); diff --git a/axum-extra/src/response/file_stream.rs b/axum-extra/src/response/file_stream.rs index ed1afdff65..7b0e512442 100644 --- a/axum-extra/src/response/file_stream.rs +++ b/axum-extra/src/response/file_stream.rs @@ -61,7 +61,7 @@ where S::Error: Into, { /// Create a new [`FileStream`] - pub fn new(stream: S) -> Self { + pub const fn new(stream: S) -> Self { Self { stream, file_name: None, @@ -118,20 +118,24 @@ where /// Set the file name of the [`FileStream`]. /// /// This adds the attachment `Content-Disposition` header with the given `file_name`. + #[must_use] pub fn file_name(mut self, file_name: impl Into) -> Self { self.file_name = Some(file_name.into()); self } /// Set the size of the file. - pub fn content_size(mut self, len: u64) -> Self { + #[must_use] + pub const fn content_size(mut self, len: u64) -> Self { self.content_size = Some(len); self } /// Return a range response. /// + /// ```rust /// range: (start, end, total_size) + /// ``` /// /// # Examples /// diff --git a/axum-extra/src/response/multiple.rs b/axum-extra/src/response/multiple.rs index 390ef3e726..aece900909 100644 --- a/axum-extra/src/response/multiple.rs +++ b/axum-extra/src/response/multiple.rs @@ -24,8 +24,9 @@ impl MultipartForm { /// let parts: Vec = vec![Part::text("foo".to_string(), "abc"), Part::text("bar".to_string(), "def")]; /// let form = MultipartForm::with_parts(parts); /// ``` - pub fn with_parts(parts: Vec) -> Self { - MultipartForm { parts } + #[must_use] + pub const fn with_parts(parts: Vec) -> Self { + Self { parts } } } @@ -103,6 +104,7 @@ impl Part { /// let parts: Vec = vec![Part::text("foo".to_string(), "abc")]; /// let form = MultipartForm::from_iter(parts); /// ``` + #[must_use] pub fn text(name: String, contents: &str) -> Self { Self { name, @@ -127,6 +129,7 @@ impl Part { /// let parts: Vec = vec![Part::file("foo", "foo.txt", vec![0x68, 0x68, 0x20, 0x6d, 0x6f, 0x6d])]; /// let form = MultipartForm::from_iter(parts); /// ``` + #[must_use] pub fn file(field_name: &str, file_name: &str, contents: Vec) -> Self { Self { name: field_name.to_owned(), diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index cf85dc532a..bcac7d3c8a 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -28,13 +28,13 @@ pub use self::typed::{SecondElementIs, TypedPath}; // Validates a path at compile time, used with the vpath macro. #[rustversion::since(1.80)] #[doc(hidden)] +#[must_use] pub const fn __private_validate_static_path(path: &'static str) -> &'static str { - if path.is_empty() { - panic!("Paths must start with a `/`. Use \"/\" for root routes") - } - if path.as_bytes()[0] != b'/' { - panic!("Paths must start with /"); - } + assert!( + !path.is_empty(), + "Paths must start with a `/`. Use \"/\" for root routes" + ); + assert!(path.as_bytes()[0] == b'/', "Paths must start with /"); path } @@ -84,6 +84,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_get(self, handler: H) -> Self where H: axum::handler::Handler, @@ -97,6 +98,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_delete(self, handler: H) -> Self where H: axum::handler::Handler, @@ -110,6 +112,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_head(self, handler: H) -> Self where H: axum::handler::Handler, @@ -123,6 +126,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_options(self, handler: H) -> Self where H: axum::handler::Handler, @@ -136,6 +140,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_patch(self, handler: H) -> Self where H: axum::handler::Handler, @@ -149,6 +154,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_post(self, handler: H) -> Self where H: axum::handler::Handler, @@ -162,6 +168,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_put(self, handler: H) -> Self where H: axum::handler::Handler, @@ -175,6 +182,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_trace(self, handler: H) -> Self where H: axum::handler::Handler, @@ -188,6 +196,7 @@ pub trait RouterExt: sealed::Sealed { /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] + #[must_use] fn typed_connect(self, handler: H) -> Self where H: axum::handler::Handler, @@ -219,6 +228,7 @@ pub trait RouterExt: sealed::Sealed { /// .route_with_tsr("/bar/", get(|| async {})); /// # let _: Router = app; /// ``` + #[must_use] fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self where Self: Sized; @@ -226,6 +236,7 @@ pub trait RouterExt: sealed::Sealed { /// Add another route to the router with an additional "trailing slash redirect" route. /// /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`]. + #[must_use] fn route_service_with_tsr(self, path: &str, service: T) -> Self where T: Service + Clone + Send + Sync + 'static, @@ -354,9 +365,10 @@ where #[track_caller] fn validate_tsr_path(path: &str) { - if path == "/" { - panic!("Cannot add a trailing slash redirect route for `/`") - } + assert!( + path != "/", + "Cannot add a trailing slash redirect route for `/`" + ); } fn add_tsr_redirect_route(router: Router, path: &str) -> Router @@ -366,15 +378,13 @@ where async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response { let new_uri = map_path(uri, |path| { path.strip_suffix('/') - .map(Cow::Borrowed) - .unwrap_or_else(|| Cow::Owned(format!("{path}/"))) + .map_or_else(|| Cow::Owned(format!("{path}/")), Cow::Borrowed) }); - if let Some(new_uri) = new_uri { - Redirect::permanent(&new_uri.to_string()).into_response() - } else { - StatusCode::BAD_REQUEST.into_response() - } + new_uri.map_or_else( + || StatusCode::BAD_REQUEST.into_response(), + |new_uri| Redirect::permanent(&new_uri.to_string()).into_response(), + ) } if let Some(path_without_trailing_slash) = path.strip_suffix('/') { diff --git a/axum-extra/src/typed_header.rs b/axum-extra/src/typed_header.rs index 7c08be9e38..0844f4a72a 100644 --- a/axum-extra/src/typed_header.rs +++ b/axum-extra/src/typed_header.rs @@ -137,12 +137,14 @@ pub struct TypedHeaderRejection { impl TypedHeaderRejection { /// Name of the header that caused the rejection - pub fn name(&self) -> &http::header::HeaderName { + #[must_use] + pub const fn name(&self) -> &http::header::HeaderName { self.name } /// Reason why the header extraction has failed - pub fn reason(&self) -> &TypedHeaderRejectionReason { + #[must_use] + pub const fn reason(&self) -> &TypedHeaderRejectionReason { &self.reason } @@ -150,7 +152,7 @@ impl TypedHeaderRejection { /// /// [`Missing`]: TypedHeaderRejectionReason::Missing #[must_use] - pub fn is_missing(&self) -> bool { + pub const fn is_missing(&self) -> bool { self.reason.is_missing() } } @@ -171,7 +173,7 @@ impl TypedHeaderRejectionReason { /// /// [`Missing`]: TypedHeaderRejectionReason::Missing #[must_use] - pub fn is_missing(&self) -> bool { + pub const fn is_missing(&self) -> bool { matches!(self, Self::Missing) } } diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 73dd9f23d9..e2b28fc5ac 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -52,20 +52,23 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn, kind: FunctionKind) -> TokenS let check_future_send = check_future_send(&item_fn, kind); - if let Some(check_input_order) = check_input_order(&item_fn, kind) { - quote! { - #check_input_order - #check_future_send - } - } else { - let check_inputs_impls_from_request = - check_inputs_impls_from_request(&item_fn, state_ty, kind); - - quote! { - #check_inputs_impls_from_request - #check_future_send - } - } + check_input_order(&item_fn, kind).map_or_else( + || { + let check_inputs_impls_from_request = + check_inputs_impls_from_request(&item_fn, state_ty, kind); + + quote! { + #check_inputs_impls_from_request + #check_future_send + } + }, + |check_input_order| { + quote! { + #check_input_order + #check_future_send + } + }, + ) }) } else { syn::Error::new_spanned( @@ -97,17 +100,17 @@ pub(crate) enum FunctionKind { impl fmt::Display for FunctionKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - FunctionKind::Handler => f.write_str("handler"), - FunctionKind::Middleware => f.write_str("middleware"), + Self::Handler => f.write_str("handler"), + Self::Middleware => f.write_str("middleware"), } } } impl FunctionKind { - fn name_uppercase_plural(&self) -> &'static str { + const fn name_uppercase_plural(&self) -> &'static str { match self { - FunctionKind::Handler => "Handlers", - FunctionKind::Middleware => "Middleware", + Self::Handler => "Handlers", + Self::Middleware => "Middleware", } } } @@ -362,8 +365,9 @@ fn check_output_tuples(item_fn: &ItemFn) -> TokenStream { Position::First(ty) => match extract_clean_typename(ty).as_deref() { Some("StatusCode" | "Response") => quote! {}, Some("Parts") => check_is_response_parts(ty, handler_ident, idx), - Some(_) | None => { - if let Some(tn) = well_known_last_response_type(ty) { + Some(_) | None => well_known_last_response_type(ty).map_or_else( + || check_into_response_parts(ty, handler_ident, idx), + |tn| { syn::Error::new_spanned( ty, format!( @@ -372,22 +376,19 @@ fn check_output_tuples(item_fn: &ItemFn) -> TokenStream { ), ) .to_compile_error() - } else { - check_into_response_parts(ty, handler_ident, idx) - } - } + }, + ), }, - Position::Middle(ty) => { - if let Some(tn) = well_known_last_response_type(ty) { + Position::Middle(ty) => well_known_last_response_type(ty).map_or_else( + || check_into_response_parts(ty, handler_ident, idx), + |tn| { syn::Error::new_spanned( ty, format!("`{tn}` must be the last element in a response tuple"), ) .to_compile_error() - } else { - check_into_response_parts(ty, handler_ident, idx) - } - } + }, + ), Position::Last(ty) | Position::Only(ty) => check_into_response(handler_ident, ty), }) .collect::(), @@ -522,7 +523,7 @@ fn check_input_order(item_fn: &ItemFn, kind: FunctionKind) -> Option Option compile_error!(#error); }); - } else { - return None; } + return None; } if types_that_consume_the_request.len() == 2 { @@ -673,39 +673,42 @@ fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream { let name = format_ident!("__axum_macros_check_{}_into_response", item_fn.sig.ident); - if let Some(receiver) = self_receiver(item_fn) { - quote_spanned! {span=> - #make + self_receiver(item_fn).map_or_else( + || { + quote_spanned! {span=> + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + async fn #name() { + #make - #[allow(warnings)] - #[allow(unreachable_code)] - #[doc(hidden)] - async fn #name() { - fn check(_: T) + fn check(_: T) where T: ::axum::response::IntoResponse - {} - let value = #receiver #make_value_name().await; - check(value); - } - } - } else { - quote_spanned! {span=> - #[allow(warnings)] - #[allow(unreachable_code)] - #[doc(hidden)] - async fn #name() { - #make + {} - fn check(_: T) - where T: ::axum::response::IntoResponse - {} + let value = #make_value_name().await; - let value = #make_value_name().await; + check(value); + } + } + }, + |receiver| { + quote_spanned! {span=> + #make - check(value); + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + async fn #name() { + fn check(_: T) + where T: ::axum::response::IntoResponse + {} + let value = #receiver #make_value_name().await; + check(value); + } } - } - } + }, + ) } fn check_future_send(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream { diff --git a/axum-macros/src/from_ref.rs b/axum-macros/src/from_ref.rs index 9cf317da97..99e543509c 100644 --- a/axum-macros/src/from_ref.rs +++ b/axum-macros/src/from_ref.rs @@ -39,19 +39,22 @@ fn expand_field(state: &Ident, idx: usize, field: &Field) -> TokenStream { let field_ty = &field.ty; let span = field.ty.span(); - let body = if let Some(field_ident) = &field.ident { - if matches!(field_ty, Type::Reference(_)) { - quote_spanned! {span=> state.#field_ident } - } else { - quote_spanned! {span=> state.#field_ident.clone() } - } - } else { - let idx = syn::Index { - index: idx as _, - span: field.span(), - }; - quote_spanned! {span=> state.#idx.clone() } - }; + let body = field.ident.as_ref().map_or_else( + || { + let idx = syn::Index { + index: idx as _, + span: field.span(), + }; + quote_spanned! {span=> state.#idx.clone() } + }, + |field_ident| { + if matches!(field_ty, Type::Reference(_)) { + quote_spanned! {span=> state.#field_ident } + } else { + quote_spanned! {span=> state.#field_ident.clone() } + } + }, + ); quote_spanned! {span=> #[allow(clippy::clone_on_copy, clippy::clone_on_ref_ptr)] diff --git a/axum-macros/src/from_request/mod.rs b/axum-macros/src/from_request/mod.rs index 3838636597..e8b093e4dd 100644 --- a/axum-macros/src/from_request/mod.rs +++ b/axum-macros/src/from_request/mod.rs @@ -21,8 +21,8 @@ pub(crate) enum Trait { impl Trait { fn via_marker_type(&self) -> Option { match self { - Trait::FromRequest => Some(parse_quote!(M)), - Trait::FromRequestParts => None, + Self::FromRequest => Some(parse_quote!(M)), + Self::FromRequestParts => None, } } } @@ -30,8 +30,8 @@ impl Trait { impl fmt::Display for Trait { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Trait::FromRequest => f.write_str("FromRequest"), - Trait::FromRequestParts => f.write_str("FromRequestParts"), + Self::FromRequest => f.write_str("FromRequest"), + Self::FromRequestParts => f.write_str("FromRequestParts"), } } } @@ -50,9 +50,9 @@ impl State { /// ``` fn impl_generics(&self) -> impl Iterator { match self { - State::Default(inner) => Some(inner.clone()), - State::Custom(_) => None, - State::CannotInfer => Some(parse_quote!(S)), + Self::Default(inner) => Some(inner.clone()), + Self::Custom(_) => None, + Self::CannotInfer => Some(parse_quote!(S)), } .into_iter() } @@ -63,18 +63,18 @@ impl State { /// ``` fn trait_generics(&self) -> impl Iterator { match self { - State::Default(inner) | State::Custom(inner) => iter::once(inner.clone()), - State::CannotInfer => iter::once(parse_quote!(S)), + Self::Default(inner) | Self::Custom(inner) => iter::once(inner.clone()), + Self::CannotInfer => iter::once(parse_quote!(S)), } } fn bounds(&self) -> TokenStream { match self { - State::Custom(_) => quote! {}, - State::Default(inner) => quote! { + Self::Custom(_) => quote! {}, + Self::Default(inner) => quote! { #inner: ::std::marker::Send + ::std::marker::Sync, }, - State::CannotInfer => quote! { + Self::CannotInfer => quote! { S: ::std::marker::Send + ::std::marker::Sync, }, } @@ -84,8 +84,8 @@ impl State { impl ToTokens for State { fn to_tokens(&self, tokens: &mut TokenStream) { match self { - State::Custom(inner) | State::Default(inner) => inner.to_tokens(tokens), - State::CannotInfer => quote! { S }.to_tokens(tokens), + Self::Custom(inner) | Self::Default(inner) => inner.to_tokens(tokens), + Self::CannotInfer => quote! { S }.to_tokens(tokens), } } } @@ -316,16 +316,14 @@ fn parse_single_generic_type_on_struct( } fn error_on_generic_ident(generic_ident: Option, tr: Trait) -> syn::Result<()> { - if let Some(generic_ident) = generic_ident { + generic_ident.map_or(Ok(()), |generic_ident| { Err(syn::Error::new_spanned( generic_ident, format_args!( "#[derive({tr})] only supports generics when used with #[from_request(via)]" ), )) - } else { - Ok(()) - } + }) } fn impl_struct_by_extracting_each_field( @@ -335,28 +333,30 @@ fn impl_struct_by_extracting_each_field( state: &State, tr: Trait, ) -> syn::Result { - let trait_fn_body = match state { - State::CannotInfer => quote! { + let trait_fn_body = if let State::CannotInfer = state { + quote! { ::std::unimplemented!() - }, - _ => { - let extract_fields = extract_fields(&fields, &rejection, tr)?; - quote! { - ::std::result::Result::Ok(Self { - #(#extract_fields)* - }) - } } - }; - - let rejection_ident = if let Some(rejection) = rejection { - quote!(#rejection) - } else if has_no_fields(&fields) { - quote!(::std::convert::Infallible) } else { - quote!(::axum::response::Response) + let extract_fields = extract_fields(&fields, &rejection, tr)?; + quote! { + ::std::result::Result::Ok(Self { + #(#extract_fields)* + }) + } }; + let rejection_ident = rejection.map_or_else( + || { + if has_no_fields(&fields) { + quote!(::std::convert::Infallible) + } else { + quote!(::axum::response::Response) + } + }, + |rejection| quote!(#rejection), + ); + let impl_generics = state .impl_generics() .collect::>(); @@ -417,16 +417,16 @@ fn extract_fields( tr: Trait, ) -> syn::Result> { fn member(field: &syn::Field, index: usize) -> TokenStream { - match &field.ident { - Some(ident) => quote! { #ident }, - _ => { + field.ident.as_ref().map_or_else( + || { let member = syn::Member::Unnamed(syn::Index { index: index as u32, span: field.span(), }); quote! { #member } - } - } + }, + |ident| quote! { #ident }, + ) } fn into_inner(via: &Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream { @@ -547,11 +547,7 @@ fn extract_fields( Ok(tokens) } else { let field_ty = into_outer(&via,ty_span,&field.ty); - let map_err = if let Some(rejection) = rejection { - quote! { <#rejection as ::std::convert::From<_>>::from } - } else { - quote! { ::axum::response::IntoResponse::into_response } - }; + let map_err = rejection.as_ref().map_or_else(|| quote! { ::axum::response::IntoResponse::into_response }, |rejection| quote! { <#rejection as ::std::convert::From<_>>::from }); let tokens = match tr { Trait::FromRequest => { @@ -619,11 +615,10 @@ fn extract_fields( } } else { let field_ty = into_outer(&via, ty_span, &field.ty); - let map_err = if let Some(rejection) = rejection { - quote! { <#rejection as ::std::convert::From<_>>::from } - } else { - quote! { ::axum::response::IntoResponse::into_response } - }; + let map_err = rejection.as_ref().map_or_else( + || quote! { ::axum::response::IntoResponse::into_response }, + |rejection| quote! { <#rejection as ::std::convert::From<_>>::from }, + ); quote_spanned! {ty_span=> #member: { @@ -642,9 +637,7 @@ fn extract_fields( } fn peel_option(ty: &syn::Type) -> Option<&syn::Type> { - let type_path = if let syn::Type::Path(type_path) = ty { - type_path - } else { + let syn::Type::Path(type_path) = ty else { return None; }; @@ -673,9 +666,7 @@ fn peel_option(ty: &syn::Type) -> Option<&syn::Type> { } fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> { - let type_path = if let syn::Type::Path(type_path) = ty { - type_path - } else { + let syn::Type::Path(type_path) = ty else { return None; }; @@ -732,17 +723,20 @@ fn impl_struct_by_extracting_all_at_once( let path_span = via_path.span(); - let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection { - let rejection = quote! { #rejection }; - let map_err = quote! { ::std::convert::From::from }; - (rejection, map_err) - } else { - let rejection = quote! { - ::axum::response::Response - }; - let map_err = quote! { ::axum::response::IntoResponse::into_response }; - (rejection, map_err) - }; + let (associated_rejection_type, map_err) = rejection.as_ref().map_or_else( + || { + let rejection = quote! { + ::axum::response::Response + }; + let map_err = quote! { ::axum::response::IntoResponse::into_response }; + (rejection, map_err) + }, + |rejection| { + let rejection = quote! { #rejection }; + let map_err = quote! { ::std::convert::From::from }; + (rejection, map_err) + }, + ); // for something like // @@ -909,17 +903,20 @@ fn impl_enum_by_extracting_all_at_once( } } - let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection { - let rejection = quote! { #rejection }; - let map_err = quote! { ::std::convert::From::from }; - (rejection, map_err) - } else { - let rejection = quote! { - ::axum::response::Response - }; - let map_err = quote! { ::axum::response::IntoResponse::into_response }; - (rejection, map_err) - }; + let (associated_rejection_type, map_err) = rejection.as_ref().map_or_else( + || { + let rejection = quote! { + ::axum::response::Response + }; + let map_err = quote! { ::axum::response::IntoResponse::into_response }; + (rejection, map_err) + }, + |rejection| { + let rejection = quote! { #rejection }; + let map_err = quote! { ::std::convert::From::from }; + (rejection, map_err) + }, + ); let path_span = path.span(); @@ -1040,11 +1037,9 @@ fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator bool { - if let Some(last_segment) = path.segments.last() { - last_segment.ident == "State" - } else { - false - } + path.segments + .last() + .is_some_and(|last_segment| last_segment.ident == "State") } fn state_from_via(ident: &Ident, via: &Path) -> Option { @@ -1066,4 +1061,4 @@ fn ui() { /// } /// ``` #[allow(dead_code)] -fn test_field_doesnt_impl_from_request() {} +const fn test_field_doesnt_impl_from_request() {} diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index 397cc94c0d..85079693b2 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -304,20 +304,22 @@ fn expand_unit_fields( } }; - let rejection_assoc_type = if let Some(rejection) = &rejection { - quote! { #rejection } - } else { - quote! { ::axum::http::StatusCode } - }; - let create_rejection = if let Some(rejection) = &rejection { - quote! { - Err(<#rejection as ::std::default::Default>::default()) - } - } else { - quote! { - Err(::axum::http::StatusCode::NOT_FOUND) - } - }; + let rejection_assoc_type = rejection.as_ref().map_or_else( + || quote! { ::axum::http::StatusCode }, + |rejection| quote! { #rejection }, + ); + let create_rejection = rejection.as_ref().map_or_else( + || { + quote! { + Err(::axum::http::StatusCode::NOT_FOUND) + } + }, + |rejection| { + quote! { + Err(<#rejection as ::std::default::Default>::default()) + } + }, + ); let from_request_impl = quote! { #[automatically_derived] @@ -382,18 +384,17 @@ fn parse_path(path: &LitStr) -> syn::Result> { path.value() .split('/') .map(|segment| { - if let Some(capture) = segment + segment .strip_prefix('{') .and_then(|segment| segment.strip_suffix('}')) .and_then(|segment| { (!segment.starts_with('{') && !segment.ends_with('}')).then_some(segment) }) .map(|capture| capture.strip_prefix('*').unwrap_or(capture)) - { - Ok(Segment::Capture(capture.to_owned(), path.span())) - } else { - Ok(Segment::Static(segment.to_owned())) - } + .map_or_else( + || Ok(Segment::Static(segment.to_owned())), + |capture| Ok(Segment::Capture(capture.to_owned(), path.span())), + ) }) .collect() } @@ -410,10 +411,9 @@ fn path_rejection() -> TokenStream { } fn rejection_assoc_type(rejection: &Option) -> TokenStream { - match rejection { - Some(rejection) => quote! { #rejection }, - None => path_rejection(), - } + rejection + .as_ref() + .map_or_else(path_rejection, |rejection| quote! { #rejection }) } fn map_err_rejection(rejection: &Option) -> TokenStream { diff --git a/axum-macros/src/with_position.rs b/axum-macros/src/with_position.rs index e064a3f01e..770efa5eaa 100644 --- a/axum-macros/src/with_position.rs +++ b/axum-macros/src/with_position.rs @@ -40,8 +40,8 @@ impl WithPosition where I: Iterator, { - pub(crate) fn new(iter: impl IntoIterator) -> WithPosition { - WithPosition { + pub(crate) fn new(iter: impl IntoIterator) -> Self { + Self { handled_first: false, peekable: iter.into_iter().fuse().peekable(), } @@ -72,7 +72,7 @@ pub(crate) enum Position { impl Position { pub(crate) fn into_inner(self) -> T { match self { - Position::First(x) | Position::Middle(x) | Position::Last(x) | Position::Only(x) => x, + Self::First(x) | Self::Middle(x) | Self::Last(x) | Self::Only(x) => x, } } } @@ -83,7 +83,14 @@ impl Iterator for WithPosition { fn next(&mut self) -> Option { match self.peekable.next() { Some(item) => { - if !self.handled_first { + if self.handled_first { + // Have seen the first item, and there's something left. + // Peek to see if this is the last item. + match self.peekable.peek() { + Some(_) => Some(Position::Middle(item)), + None => Some(Position::Last(item)), + } + } else { // Haven't seen the first item yet, and there is one to give. self.handled_first = true; // Peek to see if this is also the last item, @@ -92,13 +99,6 @@ impl Iterator for WithPosition { Some(_) => Some(Position::First(item)), None => Some(Position::Only(item)), } - } else { - // Have seen the first item, and there's something left. - // Peek to see if this is the last item. - match self.peekable.peek() { - Some(_) => Some(Position::Middle(item)), - None => Some(Position::Last(item)), - } } } // Iterator is finished. diff --git a/axum/benches/benches.rs b/axum/benches/benches.rs index c38ef91830..c4376afd43 100644 --- a/axum/benches/benches.rs +++ b/axum/benches/benches.rs @@ -102,7 +102,7 @@ struct Payload { b: bool, } -fn benchmark(name: &'static str) -> BenchmarkBuilder { +const fn benchmark(name: &'static str) -> BenchmarkBuilder { BenchmarkBuilder { name, path: None, @@ -122,7 +122,7 @@ struct BenchmarkBuilder { macro_rules! config_method { ($name:ident, $ty:ty) => { - fn $name(mut self, $name: $ty) -> Self { + const fn $name(mut self, $name: $ty) -> Self { self.$name = Some($name); self } @@ -230,9 +230,7 @@ fn install_rewrk() { let status = cmd .status() .unwrap_or_else(|_| panic!("failed to install rewrk")); - if !status.success() { - panic!("failed to install rewrk"); - } + assert!(status.success(), "failed to install rewrk"); } fn ensure_rewrk_is_installed() { diff --git a/axum/src/extension.rs b/axum/src/extension.rs index da3b7b0ddc..bbdd05d015 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -1,4 +1,7 @@ -use crate::{extract::rejection::*, response::IntoResponseParts}; +use crate::{ + extract::rejection::{ExtensionRejection, MissingExtension}, + response::IntoResponseParts, +}; use axum_core::extract::OptionalFromRequestParts; use axum_core::{ extract::FromRequestParts, @@ -204,7 +207,7 @@ mod tests { } async fn optional_foo(extension: Option>) -> String { - extension.map(|foo| foo.0 .0).unwrap_or("none".to_owned()) + extension.map_or("none".to_owned(), |foo| foo.0 .0) } async fn requires_bar(Extension(bar): Extension) -> String { @@ -212,7 +215,7 @@ mod tests { } async fn optional_bar(extension: Option>) -> String { - extension.map(|bar| bar.0 .0).unwrap_or("none".to_owned()) + extension.map_or("none".to_owned(), |bar| bar.0 .0) } let app = Router::new() diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 2c7866f9f6..c0d471d6f9 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -105,8 +105,8 @@ const _: () = { } }; -impl Connected for SocketAddr { - fn connect_info(remote_addr: SocketAddr) -> Self { +impl Connected for SocketAddr { + fn connect_info(remote_addr: Self) -> Self { remote_addr } } diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 0f5efba326..55857996d9 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -1,4 +1,7 @@ -use super::{rejection::*, FromRequestParts}; +use super::{ + rejection::{MatchedPathMissing, MatchedPathRejection}, + FromRequestParts, +}; use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE}; use axum_core::extract::OptionalFromRequestParts; use http::request::Parts; @@ -58,6 +61,7 @@ pub struct MatchedPath(pub(crate) Arc); impl MatchedPath { /// Returns a `str` representation of the path. + #[must_use] pub fn as_str(&self) -> &str { &self.0 } @@ -102,9 +106,7 @@ pub(crate) fn set_matched_path_for_request( route_id_to_path: &HashMap>, extensions: &mut http::Extensions, ) { - let matched_path = if let Some(matched_path) = route_id_to_path.get(&id) { - matched_path - } else { + let Some(matched_path) = route_id_to_path.get(&id) else { #[cfg(debug_assertions)] panic!("should always have a matched path for a route id"); #[cfg(not(debug_assertions))] @@ -124,20 +126,21 @@ pub(crate) fn set_matched_path_for_request( // a previous `MatchedPath` might exist if we're inside a nested Router fn append_nested_matched_path(matched_path: &Arc, extensions: &http::Extensions) -> Arc { - if let Some(previous) = extensions + extensions .get::() .map(|matched_path| matched_path.as_str()) .or_else(|| Some(&extensions.get::()?.0)) - { - let previous = previous - .strip_suffix(NEST_TAIL_PARAM_CAPTURE) - .unwrap_or(previous); - - let matched_path = format!("{previous}{matched_path}"); - matched_path.into() - } else { - Arc::clone(matched_path) - } + .map_or_else( + || Arc::clone(matched_path), + |previous| { + let previous = previous + .strip_suffix(NEST_TAIL_PARAM_CAPTURE) + .unwrap_or(previous); + + let matched_path = format!("{previous}{matched_path}"); + matched_path.into() + }, + ) } #[cfg(test)] diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 8a0af3da49..d2b19b684d 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -81,15 +81,11 @@ pub use self::ws::WebSocketUpgrade; // this is duplicated in `axum-extra/src/extract/form.rs` pub(super) fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool { - let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { - content_type - } else { + let Some(content_type) = headers.get(header::CONTENT_TYPE) else { return false; }; - let content_type = if let Ok(content_type) = content_type.to_str() { - content_type - } else { + let Ok(content_type) = content_type.to_str() else { return false; }; diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 22ab7fa47c..2e5db67f5c 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -112,14 +112,15 @@ impl Multipart { .await .map_err(MultipartError::from_multer)?; - if let Some(field) = field { - Ok(Some(Field { - inner: field, - _multipart: self, - })) - } else { - Ok(None) - } + field.map_or_else( + || Ok(None), + |field| { + Ok(Some(Field { + inner: field, + _multipart: self, + })) + }, + ) } } @@ -146,6 +147,7 @@ impl Field<'_> { /// The field name found in the /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) /// header. + #[must_use] pub fn name(&self) -> Option<&str> { self.inner.name() } @@ -153,16 +155,19 @@ impl Field<'_> { /// The file name found in the /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) /// header. + #[must_use] pub fn file_name(&self) -> Option<&str> { self.inner.file_name() } /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field. + #[must_use] pub fn content_type(&self) -> Option<&str> { self.inner.content_type().map(|m| m.as_ref()) } /// Get a map of headers as [`HeaderMap`]. + #[must_use] pub fn headers(&self) -> &HeaderMap { self.inner.headers() } @@ -233,16 +238,18 @@ pub struct MultipartError { } impl MultipartError { - fn from_multer(multer: multer::Error) -> Self { + const fn from_multer(multer: multer::Error) -> Self { Self { source: multer } } /// Get the response body text used for this rejection. + #[must_use] pub fn body_text(&self) -> String { self.source.to_string() } /// Get the status code used for this rejection. + #[must_use] pub fn status(&self) -> http::StatusCode { status_code_from_multer_error(&self.source) } diff --git a/axum/src/extract/nested_path.rs b/axum/src/extract/nested_path.rs index 2e58d0e826..14c87e6c4f 100644 --- a/axum/src/extract/nested_path.rs +++ b/axum/src/extract/nested_path.rs @@ -41,6 +41,7 @@ pub struct NestedPath(Arc); impl NestedPath { /// Returns a `str` representation of the path. + #[must_use] pub fn as_str(&self) -> &str { &self.0 } @@ -53,10 +54,12 @@ where type Rejection = NestedPathRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - match parts.extensions.get::() { - Some(nested_path) => Ok(nested_path.clone()), - None => Err(NestedPathRejection), - } + parts + .extensions + .get::() + .map_or(Err(NestedPathRejection), |nested_path| { + Ok(nested_path.clone()) + }) } } @@ -100,7 +103,7 @@ where } else { req.extensions_mut() .insert(NestedPath(Arc::clone(&self.path))); - }; + } self.inner.call(req) } diff --git a/axum/src/extract/original_uri.rs b/axum/src/extract/original_uri.rs index 35364281ba..b5dbb58347 100644 --- a/axum/src/extract/original_uri.rs +++ b/axum/src/extract/original_uri.rs @@ -76,7 +76,7 @@ where async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let uri = Extension::::from_request_parts(parts, state) .await - .unwrap_or_else(|_| Extension(OriginalUri(parts.uri.clone()))) + .unwrap_or_else(|_| Extension(Self(parts.uri.clone()))) .0; Ok(uri) } diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index ca78bb9e23..f171717b88 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -48,7 +48,7 @@ pub(crate) struct PathDeserializer<'de> { impl<'de> PathDeserializer<'de> { #[inline] - pub(crate) fn new(url_params: &'de [(Arc, PercentDecodedStr)]) -> Self { + pub(crate) const fn new(url_params: &'de [(Arc, PercentDecodedStr)]) -> Self { PathDeserializer { url_params } } } @@ -460,7 +460,7 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { )); } None => {} - }; + } self.value .take() @@ -637,7 +637,7 @@ enum KeyOrIdx<'de> { } impl<'de> KeyOrIdx<'de> { - fn key(&self) -> &'de str { + const fn key(&self) -> &'de str { match &self { Self::Key(key) => key, Self::Idx { key, .. } => key, diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index f37fff4dfb..a5a40cd3fd 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -4,7 +4,10 @@ mod de; use crate::{ - extract::{rejection::*, FromRequestParts}, + extract::{ + rejection::{MissingPathParams, PathRejection, RawPathParamsRejection}, + FromRequestParts, + }, routing::url_params::UrlParams, util::PercentDecodedStr, }; @@ -178,12 +181,12 @@ where } } - fn failed_to_deserialize_path_params(err: PathDeserializationError) -> PathRejection { + const fn failed_to_deserialize_path_params(err: PathDeserializationError) -> PathRejection { PathRejection::FailedToDeserializePathParams(FailedToDeserializePathParams(err)) } match T::deserialize(de::PathDeserializer::new(get_params(parts)?)) { - Ok(val) => Ok(Path(val)), + Ok(val) => Ok(Self(val)), Err(e) => Err(failed_to_deserialize_path_params(e)), } } @@ -220,16 +223,16 @@ pub(crate) struct PathDeserializationError { } impl PathDeserializationError { - pub(super) fn new(kind: ErrorKind) -> Self { + pub(super) const fn new(kind: ErrorKind) -> Self { Self { kind } } - pub(super) fn wrong_number_of_parameters() -> WrongNumberOfParameters<()> { + pub(super) const fn wrong_number_of_parameters() -> WrongNumberOfParameters<()> { WrongNumberOfParameters { got: () } } #[track_caller] - pub(super) fn unsupported_type(name: &'static str) -> Self { + pub(super) const fn unsupported_type(name: &'static str) -> Self { Self::new(ErrorKind::UnsupportedType { name }) } } @@ -246,7 +249,7 @@ impl WrongNumberOfParameters { } impl WrongNumberOfParameters { - pub(super) fn expected(self, expected: usize) -> PathDeserializationError { + pub(super) const fn expected(self, expected: usize) -> PathDeserializationError { PathDeserializationError::new(ErrorKind::WrongNumberOfParameters { got: self.got, expected, @@ -356,9 +359,9 @@ pub enum ErrorKind { impl fmt::Display for ErrorKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ErrorKind::Message(error) => error.fmt(f), - ErrorKind::InvalidUtf8InPathParam { key } => write!(f, "Invalid UTF-8 in `{key}`"), - ErrorKind::WrongNumberOfParameters { got, expected } => { + Self::Message(error) => error.fmt(f), + Self::InvalidUtf8InPathParam { key } => write!(f, "Invalid UTF-8 in `{key}`"), + Self::WrongNumberOfParameters { got, expected } => { write!( f, "Wrong number of path arguments for `Path`. Expected {expected} but got {got}" @@ -370,8 +373,8 @@ impl fmt::Display for ErrorKind { Ok(()) } - ErrorKind::UnsupportedType { name } => write!(f, "Unsupported type `{name}`"), - ErrorKind::ParseErrorAtKey { + Self::UnsupportedType { name } => write!(f, "Unsupported type `{name}`"), + Self::ParseErrorAtKey { key, value, expected_type, @@ -379,11 +382,11 @@ impl fmt::Display for ErrorKind { f, "Cannot parse `{key}` with value `{value}` to a `{expected_type}`" ), - ErrorKind::ParseError { + Self::ParseError { value, expected_type, } => write!(f, "Cannot parse `{value}` to a `{expected_type}`"), - ErrorKind::ParseErrorAtIndex { + Self::ParseErrorAtIndex { index, value, expected_type, @@ -391,7 +394,7 @@ impl fmt::Display for ErrorKind { f, "Cannot parse value at index {index} with value `{value}` to a `{expected_type}`" ), - ErrorKind::DeserializeError { + Self::DeserializeError { key, value, message, @@ -407,16 +410,19 @@ pub struct FailedToDeserializePathParams(PathDeserializationError); impl FailedToDeserializePathParams { /// Get a reference to the underlying error kind. - pub fn kind(&self) -> &ErrorKind { + #[must_use] + pub const fn kind(&self) -> &ErrorKind { &self.0.kind } /// Convert this error into the underlying error kind. + #[must_use] pub fn into_kind(self) -> ErrorKind { self.0.kind } /// Get the response body text used for this rejection. + #[must_use] pub fn body_text(&self) -> String { match self.0.kind { ErrorKind::Message(_) @@ -432,7 +438,8 @@ impl FailedToDeserializePathParams { } /// Get the status code used for this rejection. - pub fn status(&self) -> StatusCode { + #[must_use] + pub const fn status(&self) -> StatusCode { match self.0.kind { ErrorKind::Message(_) | ErrorKind::DeserializeError { .. } @@ -523,6 +530,7 @@ where impl RawPathParams { /// Get an iterator over the path parameters. + #[must_use] pub fn iter(&self) -> RawPathParamsIter<'_> { self.into_iter() } @@ -561,12 +569,14 @@ pub struct InvalidUtf8InPathParam { impl InvalidUtf8InPathParam { /// Get the response body text used for this rejection. + #[must_use] pub fn body_text(&self) -> String { self.to_string() } /// Get the status code used for this rejection. - pub fn status(&self) -> StatusCode { + #[must_use] + pub const fn status(&self) -> StatusCode { StatusCode::BAD_REQUEST } } @@ -766,7 +776,7 @@ mod tests { D: serde::Deserializer<'de>, { let s = <&str as serde::Deserialize>::deserialize(deserializer)?; - Ok(Param(s.to_owned())) + Ok(Self(s.to_owned())) } } diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index 58b7d366e8..ddc9425a46 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -1,4 +1,7 @@ -use super::{rejection::*, FromRequestParts}; +use super::{ + rejection::{FailedToDeserializeQueryString, QueryRejection}, + FromRequestParts, +}; use http::{request::Parts, Uri}; use serde::de::DeserializeOwned; @@ -91,7 +94,7 @@ where serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes())); let params = serde_path_to_error::deserialize(deserializer) .map_err(FailedToDeserializeQueryString::from_err)?; - Ok(Query(params)) + Ok(Self(params)) } } diff --git a/axum/src/extract/raw_form.rs b/axum/src/extract/raw_form.rs index 29cb4c6dd3..cbb71585d7 100644 --- a/axum/src/extract/raw_form.rs +++ b/axum/src/extract/raw_form.rs @@ -99,6 +99,6 @@ mod tests { assert!(matches!( RawForm::from_request(req, &()).await.unwrap_err(), RawFormRejection::InvalidFormContentType(InvalidFormContentType) - )) + )); } } diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 94a6703903..12fe6fd3b3 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -90,7 +90,13 @@ //! //! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split -use self::rejection::*; +use self::rejection::{ + ConnectionNotUpgradable, InvalidConnectionHeader, + InvalidUpgradeHeader, InvalidWebSocketVersionHeader, MethodNotConnect, MethodNotGet, + WebSocketKeyHeaderMissing, WebSocketUpgradeRejection, +}; +#[cfg(feature = "http2")] +use self::rejection::InvalidProtocolPseudoheader; use super::FromRequestParts; use crate::{body::Bytes, response::Response, Error}; use axum_core::body::Body; @@ -152,7 +158,8 @@ impl std::fmt::Debug for WebSocketUpgrade { impl WebSocketUpgrade { /// Read buffer capacity. The default value is 128KiB - pub fn read_buffer_size(mut self, size: usize) -> Self { + #[must_use] + pub const fn read_buffer_size(mut self, size: usize) -> Self { self.config.read_buffer_size = size; self } @@ -166,7 +173,8 @@ impl WebSocketUpgrade { /// It is often more optimal to allow them to buffer a little, hence the default value. /// /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless. - pub fn write_buffer_size(mut self, size: usize) -> Self { + #[must_use] + pub const fn write_buffer_size(mut self, size: usize) -> Self { self.config.write_buffer_size = size; self } @@ -182,25 +190,29 @@ impl WebSocketUpgrade { /// /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size) /// and probably a little more depending on error handling strategy. - pub fn max_write_buffer_size(mut self, max: usize) -> Self { + #[must_use] + pub const fn max_write_buffer_size(mut self, max: usize) -> Self { self.config.max_write_buffer_size = max; self } /// Set the maximum message size (defaults to 64 megabytes) - pub fn max_message_size(mut self, max: usize) -> Self { + #[must_use] + pub const fn max_message_size(mut self, max: usize) -> Self { self.config.max_message_size = Some(max); self } /// Set the maximum frame size (defaults to 16 megabytes) - pub fn max_frame_size(mut self, max: usize) -> Self { + #[must_use] + pub const fn max_frame_size(mut self, max: usize) -> Self { self.config.max_frame_size = Some(max); self } /// Allow server to accept unmasked frames (defaults to false) - pub fn accept_unmasked_frames(mut self, accept: bool) -> Self { + #[must_use] + pub const fn accept_unmasked_frames(mut self, accept: bool) -> Self { self.config.accept_unmasked_frames = accept; self } @@ -235,6 +247,7 @@ impl WebSocketUpgrade { /// } /// # let _: Router = app; /// ``` + #[must_use] pub fn protocols(mut self, protocols: I) -> Self where I: IntoIterator, @@ -269,7 +282,7 @@ impl WebSocketUpgrade { /// If [`protocols()`][Self::protocols] has been called and a matching /// protocol has been selected, the return value will be `Some` containing /// said protocol. Otherwise, it will be `None`. - pub fn selected_protocol(&self) -> Option<&HeaderValue> { + pub const fn selected_protocol(&self) -> Option<&HeaderValue> { self.protocol.as_ref() } @@ -346,30 +359,28 @@ impl WebSocketUpgrade { callback(socket).await; }); - let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key { - // If `sec_websocket_key` was `Some`, we are using HTTP/1.1. - - #[allow(clippy::declare_interior_mutable_const)] - const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); - #[allow(clippy::declare_interior_mutable_const)] - const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); - - Response::builder() - .status(StatusCode::SWITCHING_PROTOCOLS) - .header(header::CONNECTION, UPGRADE) - .header(header::UPGRADE, WEBSOCKET) - .header( - header::SEC_WEBSOCKET_ACCEPT, - sign(sec_websocket_key.as_bytes()), - ) - .body(Body::empty()) - .unwrap() - } else { - // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond - // with a 2XX with an empty body: - // . - Response::new(Body::empty()) - }; + let mut response = self.sec_websocket_key.as_ref().map_or_else( + || Response::new(Body::empty()), + |sec_websocket_key| { + // If `sec_websocket_key` was `Some`, we are using HTTP/1.1. + + #[allow(clippy::declare_interior_mutable_const)] + const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); + #[allow(clippy::declare_interior_mutable_const)] + const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + + Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, UPGRADE) + .header(header::UPGRADE, WEBSOCKET) + .header( + header::SEC_WEBSOCKET_ACCEPT, + sign(sec_websocket_key.as_bytes()), + ) + .body(Body::empty()) + .unwrap() + }, + ); if let Some(protocol) = self.protocol { response @@ -394,7 +405,7 @@ where F: FnOnce(Error) + Send + 'static, { fn call(self, error: Error) { - self(error) + self(error); } } @@ -479,25 +490,18 @@ where } fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { - if let Some(header) = headers.get(&key) { - header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) - } else { - false - } + headers + .get(&key) + .is_some_and(|header| header.as_bytes().eq_ignore_ascii_case(value.as_bytes())) } fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { - let header = if let Some(header) = headers.get(&key) { - header - } else { + let Some(header) = headers.get(&key) else { return false; }; - if let Ok(header) = std::str::from_utf8(header.as_bytes()) { - header.to_ascii_lowercase().contains(value) - } else { - false - } + std::str::from_utf8(header.as_bytes()) + .is_ok_and(|header| header.to_ascii_lowercase().contains(value)) } /// A stream of WebSocket messages. @@ -526,7 +530,7 @@ impl WebSocket { } /// Return the selected WebSocket subprotocol, if one has been chosen. - pub fn protocol(&self) -> Option<&HeaderValue> { + pub const fn protocol(&self) -> Option<&HeaderValue> { self.protocol.as_ref() } } @@ -573,13 +577,14 @@ impl Sink for WebSocket { /// UTF-8 wrapper for [Bytes]. /// -/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8. +/// An [`Utf8Bytes`] is always guaranteed to contain valid UTF-8. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct Utf8Bytes(ts::Utf8Bytes); impl Utf8Bytes { /// Creates from a static str. #[inline] + #[must_use] pub const fn from_static(str: &'static str) -> Self { Self(ts::Utf8Bytes::from_static(str)) } @@ -806,7 +811,7 @@ impl Message { } } - /// Attempt to consume the WebSocket message and convert it to a Utf8Bytes. + /// Attempt to consume the WebSocket message and convert it to a [`Utf8Bytes`]. pub fn into_text(self) -> Result { match self { Self::Text(string) => Ok(string), @@ -832,49 +837,49 @@ impl Message { } /// Create a new text WebSocket message from a stringable. - pub fn text(string: S) -> Message + pub fn text(string: S) -> Self where S: Into, { - Message::Text(string.into()) + Self::Text(string.into()) } /// Create a new binary WebSocket message by converting to `Bytes`. - pub fn binary(bin: B) -> Message + pub fn binary(bin: B) -> Self where B: Into, { - Message::Binary(bin.into()) + Self::Binary(bin.into()) } } impl From for Message { fn from(string: String) -> Self { - Message::Text(string.into()) + Self::Text(string.into()) } } impl<'s> From<&'s str> for Message { fn from(string: &'s str) -> Self { - Message::Text(string.into()) + Self::Text(string.into()) } } impl<'b> From<&'b [u8]> for Message { fn from(data: &'b [u8]) -> Self { - Message::Binary(Bytes::copy_from_slice(data)) + Self::Binary(Bytes::copy_from_slice(data)) } } impl From for Message { fn from(data: Bytes) -> Self { - Message::Binary(data) + Self::Binary(data) } } impl From> for Message { fn from(data: Vec) -> Self { - Message::Binary(data.into()) + Self::Binary(data.into()) } } diff --git a/axum/src/form.rs b/axum/src/form.rs index dabfb65332..6d4b4f4698 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -1,5 +1,10 @@ use crate::extract::Request; -use crate::extract::{rejection::*, FromRequest, RawForm}; +use crate::extract::{ + rejection::{ + FailedToDeserializeForm, FailedToDeserializeFormBody, FormRejection, RawFormRejection, + }, + FromRequest, RawForm, +}; use axum_core::response::{IntoResponse, Response}; use axum_core::RequestExt; use http::header::CONTENT_TYPE; @@ -95,7 +100,7 @@ where } }, )?; - Ok(Form(value)) + Ok(Self(value)) } Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)), Err(RawFormRejection::InvalidFormContentType(r)) => { @@ -130,6 +135,7 @@ axum_core::__impl_deref!(Form); #[cfg(test)] mod tests { use crate::{ + extract::rejection::InvalidFormContentType, routing::{on, MethodFilter}, test_helpers::TestClient, Router, diff --git a/axum/src/handler/service.rs b/axum/src/handler/service.rs index 2090051978..b026e67f81 100644 --- a/axum/src/handler/service.rs +++ b/axum/src/handler/service.rs @@ -27,7 +27,7 @@ pub struct HandlerService { impl HandlerService { /// Get a reference to the state. - pub fn state(&self) -> &S { + pub const fn state(&self) -> &S { &self.state } @@ -60,7 +60,7 @@ impl HandlerService { /// ``` /// /// [`MakeService`]: tower::make::MakeService - pub fn into_make_service(self) -> IntoMakeService> { + pub const fn into_make_service(self) -> IntoMakeService { IntoMakeService::new(self) } @@ -101,9 +101,7 @@ impl HandlerService { /// [`MakeService`]: tower::make::MakeService /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info #[cfg(feature = "tokio")] - pub fn into_make_service_with_connect_info( - self, - ) -> IntoMakeServiceWithConnectInfo, C> { + pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { IntoMakeServiceWithConnectInfo::new(self) } } diff --git a/axum/src/json.rs b/axum/src/json.rs index 1e662cc6d1..54411f25e0 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -1,5 +1,8 @@ use crate::extract::Request; -use crate::extract::{rejection::*, FromRequest}; +use crate::extract::{ + rejection::{JsonDataError, JsonRejection, JsonSyntaxError, MissingJsonContentType}, + FromRequest, +}; use axum_core::extract::OptionalFromRequest; use axum_core::response::{IntoResponse, Response}; use bytes::{BufMut, Bytes, BytesMut}; @@ -184,7 +187,7 @@ where let deserializer = &mut serde_json::Deserializer::from_slice(bytes); match serde_path_to_error::deserialize(deserializer) { - Ok(value) => Ok(Json(value)), + Ok(value) => Ok(Self(value)), Err(err) => Err(make_rejection(err)), } } diff --git a/axum/src/macros.rs b/axum/src/macros.rs index 37b8fc3b26..09370b863e 100644 --- a/axum/src/macros.rs +++ b/axum/src/macros.rs @@ -17,7 +17,7 @@ macro_rules! opaque_future { } impl<$($param),*> $name<$($param),*> { - pub(crate) fn new(future: $actual) -> Self { + pub(crate) const fn new(future: $actual) -> Self { Self { future } } } diff --git a/axum/src/middleware/map_request.rs b/axum/src/middleware/map_request.rs index 56f250bc4d..3a90bc65ac 100644 --- a/axum/src/middleware/map_request.rs +++ b/axum/src/middleware/map_request.rs @@ -380,7 +380,7 @@ where } impl IntoMapRequestResult for Request { - fn into_map_request_result(self) -> Result, Response> { + fn into_map_request_result(self) -> Result { Ok(self) } } diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index 70be745200..45f501a3ac 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -252,6 +252,6 @@ mod tests { assert_eq!( super::NoContent.into_response().status(), StatusCode::NO_CONTENT, - ) + ); } } diff --git a/axum/src/response/redirect.rs b/axum/src/response/redirect.rs index 4113c124e0..528639e819 100644 --- a/axum/src/response/redirect.rs +++ b/axum/src/response/redirect.rs @@ -56,11 +56,13 @@ impl Redirect { } /// Returns the HTTP status code of the `Redirect`. - pub fn status_code(&self) -> StatusCode { + #[must_use] + pub const fn status_code(&self) -> StatusCode { self.status_code } /// Returns the `Redirect`'s URI. + #[must_use] pub fn location(&self) -> &str { &self.location } @@ -123,7 +125,7 @@ mod tests { fn correct_location() { assert_eq!(EXAMPLE_URL, Redirect::permanent(EXAMPLE_URL).location()); - assert_eq!("/redirect", Redirect::permanent("/redirect").location()) + assert_eq!("/redirect", Redirect::permanent("/redirect").location()); } #[test] diff --git a/axum/src/response/sse.rs b/axum/src/response/sse.rs index 933f115e6f..50f76ddcd5 100644 --- a/axum/src/response/sse.rs +++ b/axum/src/response/sse.rs @@ -58,12 +58,12 @@ impl Sse { /// [`Event`]s. /// /// See the [module docs](self) for more details. - pub fn new(stream: S) -> Self + pub const fn new(stream: S) -> Self where S: TryStream + Send + 'static, S::Error: Into, { - Sse { stream } + Self { stream } } /// Configure the interval between keep-alive messages. @@ -154,12 +154,12 @@ impl Buffer { /// a new active buffer with the previous contents. fn as_mut(&mut self) -> &mut BytesMut { match self { - Buffer::Active(bytes_mut) => bytes_mut, - Buffer::Finalized(bytes) => { - *self = Buffer::Active(BytesMut::from(mem::take(bytes))); + Self::Active(bytes_mut) => bytes_mut, + Self::Finalized(bytes) => { + *self = Self::Active(BytesMut::from(mem::take(bytes))); match self { - Buffer::Active(bytes_mut) => bytes_mut, - Buffer::Finalized(_) => unreachable!(), + Self::Active(bytes_mut) => bytes_mut, + Self::Finalized(_) => unreachable!(), } } } @@ -199,13 +199,14 @@ impl Event { /// - Panics if `data` or `json_data` have already been called. /// /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data - pub fn data(mut self, data: T) -> Event + pub fn data(mut self, data: T) -> Self where T: AsRef, { - if self.flags.contains(EventFlags::HAS_DATA) { - panic!("Called `Event::data` multiple times"); - } + assert!( + !self.flags.contains(EventFlags::HAS_DATA), + "Called `Event::data` multiple times" + ); for line in memchr_split(b'\n', data.as_ref().as_bytes()) { self.field("data", line); @@ -226,7 +227,7 @@ impl Event { /// /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data #[cfg(feature = "json")] - pub fn json_data(mut self, data: T) -> Result + pub fn json_data(mut self, data: T) -> Result where T: serde::Serialize, { @@ -246,9 +247,10 @@ impl Event { self.0.flush() } } - if self.flags.contains(EventFlags::HAS_DATA) { - panic!("Called `Event::json_data` multiple times"); - } + assert!( + !self.flags.contains(EventFlags::HAS_DATA), + "Called `Event::json_data` multiple times" + ); let buffer = self.buffer.as_mut(); buffer.extend_from_slice(b"data: "); @@ -271,7 +273,7 @@ impl Event { /// /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in /// comments. - pub fn comment(mut self, comment: T) -> Event + pub fn comment(mut self, comment: T) -> Self where T: AsRef, { @@ -293,13 +295,14 @@ impl Event { /// /// - Panics if `event` contains any newlines or carriage returns. /// - Panics if this function has already been called on this event. - pub fn event(mut self, event: T) -> Event + pub fn event(mut self, event: T) -> Self where T: AsRef, { - if self.flags.contains(EventFlags::HAS_EVENT) { - panic!("Called `Event::event` multiple times"); - } + assert!( + !self.flags.contains(EventFlags::HAS_EVENT), + "Called `Event::event` multiple times" + ); self.flags.insert(EventFlags::HAS_EVENT); self.field("event", event.as_ref()); @@ -316,10 +319,11 @@ impl Event { /// # Panics /// /// Panics if this function has already been called on this event. - pub fn retry(mut self, duration: Duration) -> Event { - if self.flags.contains(EventFlags::HAS_RETRY) { - panic!("Called `Event::retry` multiple times"); - } + pub fn retry(mut self, duration: Duration) -> Self { + assert!( + !self.flags.contains(EventFlags::HAS_RETRY), + "Called `Event::retry` multiple times" + ); self.flags.insert(EventFlags::HAS_RETRY); let buffer = self.buffer.as_mut(); @@ -360,13 +364,14 @@ impl Event { /// /// - Panics if `id` contains any newlines, carriage returns or null characters. /// - Panics if this function has already been called on this event. - pub fn id(mut self, id: T) -> Event + pub fn id(mut self, id: T) -> Self where T: AsRef, { - if self.flags.contains(EventFlags::HAS_ID) { - panic!("Called `Event::id` multiple times"); - } + assert!( + !self.flags.contains(EventFlags::HAS_ID), + "Called `Event::id` multiple times" + ); self.flags.insert(EventFlags::HAS_ID); let id = id.as_ref().as_bytes(); @@ -453,7 +458,7 @@ pub struct KeepAlive { impl KeepAlive { /// Create a new `KeepAlive`. - pub fn new() -> Self { + pub const fn new() -> Self { Self { event: Event::DEFAULT_KEEP_ALIVE, max_interval: Duration::from_secs(15), @@ -463,7 +468,7 @@ impl KeepAlive { /// Customize the interval between keep-alive messages. /// /// Default is 15 seconds. - pub fn interval(mut self, time: Duration) -> Self { + pub const fn interval(mut self, time: Duration) -> Self { self.max_interval = time; self } @@ -566,7 +571,7 @@ where } } -fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> { +const fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> { MemchrSplit { needle, haystack: Some(haystack), diff --git a/axum/src/routing/into_make_service.rs b/axum/src/routing/into_make_service.rs index 36da73a21e..73d002e7bf 100644 --- a/axum/src/routing/into_make_service.rs +++ b/axum/src/routing/into_make_service.rs @@ -14,7 +14,7 @@ pub struct IntoMakeService { } impl IntoMakeService { - pub(crate) fn new(svc: S) -> Self { + pub(crate) const fn new(svc: S) -> Self { Self { svc } } } diff --git a/axum/src/routing/method_filter.rs b/axum/src/routing/method_filter.rs index 040783ec33..2cb2b7e1a6 100644 --- a/axum/src/routing/method_filter.rs +++ b/axum/src/routing/method_filter.rs @@ -58,6 +58,7 @@ impl MethodFilter { } /// Performs the OR operation between the [`MethodFilter`] in `self` with `other`. + #[must_use] pub const fn or(self, other: Self) -> Self { Self(self.0 | other.0) } @@ -71,7 +72,7 @@ pub struct NoMatchingMethodFilter { impl NoMatchingMethodFilter { /// Get the [`Method`] that couldn't be converted to a [`MethodFilter`]. - pub fn method(&self) -> &Method { + pub const fn method(&self) -> &Method { &self.method } } @@ -89,15 +90,15 @@ impl TryFrom for MethodFilter { fn try_from(m: Method) -> Result { match m { - Method::CONNECT => Ok(MethodFilter::CONNECT), - Method::DELETE => Ok(MethodFilter::DELETE), - Method::GET => Ok(MethodFilter::GET), - Method::HEAD => Ok(MethodFilter::HEAD), - Method::OPTIONS => Ok(MethodFilter::OPTIONS), - Method::PATCH => Ok(MethodFilter::PATCH), - Method::POST => Ok(MethodFilter::POST), - Method::PUT => Ok(MethodFilter::PUT), - Method::TRACE => Ok(MethodFilter::TRACE), + Method::CONNECT => Ok(Self::CONNECT), + Method::DELETE => Ok(Self::DELETE), + Method::GET => Ok(Self::GET), + Method::HEAD => Ok(Self::HEAD), + Method::OPTIONS => Ok(Self::OPTIONS), + Method::PATCH => Ok(Self::PATCH), + Method::POST => Ok(Self::POST), + Method::PUT => Ok(Self::PUT), + Method::TRACE => Ok(Self::TRACE), other => Err(NoMatchingMethodFilter { method: other }), } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 42e46612eb..d0718feffa 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -571,14 +571,14 @@ enum AllowHeader { impl AllowHeader { fn merge(self, other: Self) -> Self { match (self, other) { - (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip, - (AllowHeader::None, AllowHeader::None) => AllowHeader::None, - (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick), - (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick), - (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => { + (Self::Skip, _) | (_, Self::Skip) => Self::Skip, + (Self::None, Self::None) => Self::None, + (Self::None, Self::Bytes(pick)) => Self::Bytes(pick), + (Self::Bytes(pick), Self::None) => Self::Bytes(pick), + (Self::Bytes(mut a), Self::Bytes(b)) => { a.extend_from_slice(b","); a.extend_from_slice(&b); - AllowHeader::Bytes(a) + Self::Bytes(a) } } } @@ -703,6 +703,7 @@ impl MethodRouter<(), Infallible> { /// ``` /// /// [`MakeService`]: tower::make::MakeService + #[must_use] pub fn into_make_service(self) -> IntoMakeService { IntoMakeService::new(self.with_state(())) } @@ -736,6 +737,7 @@ impl MethodRouter<(), Infallible> { /// [`MakeService`]: tower::make::MakeService /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info #[cfg(feature = "tokio")] + #[must_use] pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { IntoMakeServiceWithConnectInfo::new(self.with_state(())) } @@ -834,12 +836,11 @@ where S: Clone, { if endpoint_filter.contains(filter) { - if out.is_some() { - panic!( - "Overlapping method route. Cannot add two method routes that both handle \ + assert!( + !out.is_some(), + "Overlapping method route. Cannot add two method routes that both handle \ `{method_name}`", - ) - } + ); *out = endpoint.clone(); for method in methods { append_allow_header(allow_header, method); @@ -992,7 +993,7 @@ where #[doc = include_str!("../docs/method_routing/route_layer.md")] #[track_caller] - pub fn route_layer(mut self, layer: L) -> MethodRouter + pub fn route_layer(mut self, layer: L) -> Self where L: Layer> + Clone + Send + Sync + 'static, L::Service: Service + Clone + Send + Sync + 'static, @@ -1035,7 +1036,7 @@ where pub(crate) fn merge_for_path( mut self, path: Option<&str>, - other: MethodRouter, + other: Self, ) -> Result> { // written using inner functions to generate less IR fn merge_inner( @@ -1047,20 +1048,21 @@ where match (first, second) { (MethodEndpoint::None, MethodEndpoint::None) => Ok(MethodEndpoint::None), (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => Ok(pick), - _ => { - if let Some(path) = path { + _ => path.map_or_else( + || { Err(format!( - "Overlapping method route. Handler for `{name} {path}` already exists" + "Overlapping method route. Cannot merge two method routes that both \ + define `{name}`" ) .into()) - } else { + }, + |path| { Err(format!( - "Overlapping method route. Cannot merge two method routes that both \ - define `{name}`" + "Overlapping method route. Handler for `{name} {path}` already exists" ) .into()) - } - } + }, + ), } } @@ -1086,7 +1088,7 @@ where #[doc = include_str!("../docs/method_routing/merge.md")] #[track_caller] - pub fn merge(self, other: MethodRouter) -> Self { + pub fn merge(self, other: Self) -> Self { match self.merge_for_path(None, other) { Ok(t) => t, // not using unwrap or unwrap_or_else to get a clean panic message + the right location @@ -1230,11 +1232,11 @@ impl MethodEndpoint where S: Clone, { - fn is_some(&self) -> bool { + const fn is_some(&self) -> bool { matches!(self, Self::Route(_) | Self::BoxedHandler(_)) } - fn is_none(&self) -> bool { + const fn is_none(&self) -> bool { matches!(self, Self::None) } @@ -1254,11 +1256,9 @@ where fn with_state(self, state: &S) -> MethodEndpoint { match self { - MethodEndpoint::None => MethodEndpoint::None, - MethodEndpoint::Route(route) => MethodEndpoint::Route(route), - MethodEndpoint::BoxedHandler(handler) => { - MethodEndpoint::Route(handler.into_route(state.clone())) - } + Self::None => MethodEndpoint::None, + Self::Route(route) => MethodEndpoint::Route(route), + Self::BoxedHandler(handler) => MethodEndpoint::Route(handler.into_route(state.clone())), } } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index e54a00bb28..2cd8f3860a 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -187,14 +187,11 @@ where T::Response: IntoResponse, T::Future: Send + 'static, { - let service = match try_downcast::, _>(service) { - Ok(_) => { - panic!( - "Invalid route: `Router::route_service` cannot be used with `Router`s. \ - Use `Router::nest` instead" - ); - } - Err(service) => service, + let Err(service) = try_downcast::(service) else { + panic!( + "Invalid route: `Router::route_service` cannot be used with `Router`s. \ + Use `Router::nest` instead" + ); }; tap_inner!(self, mut this => { @@ -205,10 +202,11 @@ where #[doc = include_str!("../docs/routing/nest.md")] #[doc(alias = "scope")] // Some web frameworks like actix-web use this term #[track_caller] - pub fn nest(self, path: &str, router: Router) -> Self { - if path.is_empty() || path == "/" { - panic!("Nesting at the root is no longer supported. Use merge instead."); - } + pub fn nest(self, path: &str, router: Self) -> Self { + assert!( + !(path.is_empty() || path == "/"), + "Nesting at the root is no longer supported. Use merge instead." + ); let RouterInner { path_router, @@ -232,9 +230,10 @@ where T::Response: IntoResponse, T::Future: Send + 'static, { - if path.is_empty() || path == "/" { - panic!("Nesting at the root is no longer supported. Use fallback_service instead."); - } + assert!( + !(path.is_empty() || path == "/"), + "Nesting at the root is no longer supported. Use fallback_service instead." + ); tap_inner!(self, mut this => { panic_on_err!(this.path_router.nest_service(path, service)); @@ -245,9 +244,9 @@ where #[track_caller] pub fn merge(self, other: R) -> Self where - R: Into>, + R: Into, { - let other: Router = other.into(); + let other: Self = other.into(); let RouterInner { path_router, default_fallback, @@ -270,7 +269,7 @@ where (false, false) => { panic!("Cannot merge two `Router`s that both have a fallback") } - }; + } panic_on_err!(this.path_router.merge(path_router)); @@ -284,7 +283,7 @@ where } #[doc = include_str!("../docs/routing/layer.md")] - pub fn layer(self, layer: L) -> Router + pub fn layer(self, layer: L) -> Self where L: Layer + Clone + Send + Sync + 'static, L::Service: Service + Clone + Send + Sync + 'static, @@ -317,6 +316,7 @@ where } /// True if the router currently has at least one route added. + #[must_use] pub fn has_routes(&self) -> bool { self.inner.path_router.has_routes() } @@ -513,7 +513,8 @@ where /// /// This is the same as [`Router::as_service`] instead it returns an owned [`Service`]. See /// that method for more details. - pub fn into_service(self) -> RouterIntoService { + #[must_use] + pub const fn into_service(self) -> RouterIntoService { RouterIntoService { router: self, _marker: PhantomData, @@ -540,6 +541,7 @@ impl Router { /// ``` /// /// [`MakeService`]: tower::make::MakeService + #[must_use] pub fn into_make_service(self) -> IntoMakeService { // call `Router::with_state` such that everything is turned into `Route` eagerly // rather than doing that per request @@ -548,6 +550,7 @@ impl Router { #[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")] #[cfg(feature = "tokio")] + #[must_use] pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { // call `Router::with_state` such that everything is turned into `Route` eagerly // rather than doing that per request @@ -725,16 +728,16 @@ where fn with_state(self, state: S) -> Fallback { match self { - Fallback::Default(route) => Fallback::Default(route), - Fallback::Service(route) => Fallback::Service(route), - Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)), + Self::Default(route) => Fallback::Default(route), + Self::Service(route) => Fallback::Service(route), + Self::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)), } } fn call_with_state(self, req: Request, state: S) -> RouteFuture { match self { - Fallback::Default(route) | Fallback::Service(route) => route.oneshot_inner_owned(req), - Fallback::BoxedHandler(handler) => { + Self::Default(route) | Self::Service(route) => route.oneshot_inner_owned(req), + Self::BoxedHandler(handler) => { let route = handler.clone().into_route(state); route.oneshot_inner_owned(req) } @@ -772,7 +775,7 @@ impl Endpoint where S: Clone + Send + Sync + 'static, { - fn layer(self, layer: L) -> Endpoint + fn layer(self, layer: L) -> Self where L: Layer + Clone + Send + Sync + 'static, L::Service: Service + Clone + Send + Sync + 'static, @@ -781,10 +784,8 @@ where >::Future: Send + 'static, { match self { - Endpoint::MethodRouter(method_router) => { - Endpoint::MethodRouter(method_router.layer(layer)) - } - Endpoint::Route(route) => Endpoint::Route(route.layer(layer)), + Self::MethodRouter(method_router) => Self::MethodRouter(method_router.layer(layer)), + Self::Route(route) => Self::Route(route.layer(layer)), } } } diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 263cc032f7..2e85974e36 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -143,8 +143,8 @@ where .map_err(|err| format!("Invalid route {path:?}: {err}")) } - pub(super) fn merge(&mut self, other: PathRouter) -> Result<(), Cow<'static, str>> { - let PathRouter { + pub(super) fn merge(&mut self, other: Self) -> Result<(), Cow<'static, str>> { + let Self { routes, node, prev_route_id: _, @@ -172,11 +172,11 @@ where pub(super) fn nest( &mut self, path_to_nest_at: &str, - router: PathRouter, + router: Self, ) -> Result<(), Cow<'static, str>> { let prefix = validate_nest_path(self.v7_checks, path_to_nest_at); - let PathRouter { + let Self { routes, node, prev_route_id: _, @@ -248,7 +248,7 @@ where Ok(()) } - pub(super) fn layer(self, layer: L) -> PathRouter + pub(super) fn layer(self, layer: L) -> Self where L: Layer + Clone + Send + Sync + 'static, L::Service: Service + Clone + Send + Sync + 'static, @@ -265,7 +265,7 @@ where }) .collect(); - PathRouter { + Self { routes, node: self.node, prev_route_id: self.prev_route_id, @@ -282,12 +282,11 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - if self.routes.is_empty() { - panic!( - "Adding a route_layer before any routes is a no-op. \ + assert!( + !self.routes.is_empty(), + "Adding a route_layer before any routes is a no-op. \ Add the routes you want the layer to apply to first." - ); - } + ); let routes = self .routes @@ -298,7 +297,7 @@ where }) .collect(); - PathRouter { + Self { routes, node: self.node, prev_route_id: self.prev_route_id, diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 6cdc58a617..adfa989aa4 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -59,7 +59,7 @@ impl Route { pub(crate) fn layer(self, layer: L) -> Route where - L: Layer> + Clone + Send + 'static, + L: Layer + Clone + Send + 'static, L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, @@ -117,7 +117,7 @@ pin_project! { } impl RouteFuture { - fn new( + const fn new( method: Method, inner: Oneshot, Request>, ) -> Self { @@ -134,7 +134,7 @@ impl RouteFuture { self } - pub(crate) fn not_top_level(mut self) -> Self { + pub(crate) const fn not_top_level(mut self) -> Self { self.top_level = false; self } @@ -216,7 +216,7 @@ pin_project! { } impl InfallibleRouteFuture { - pub(crate) fn new(future: RouteFuture) -> Self { + pub(crate) const fn new(future: RouteFuture) -> Self { Self { future } } } diff --git a/axum/src/routing/tests/handle_error.rs b/axum/src/routing/tests/handle_error.rs index a2fd2e6828..530fa84f4e 100644 --- a/axum/src/routing/tests/handle_error.rs +++ b/axum/src/routing/tests/handle_error.rs @@ -4,6 +4,8 @@ use tower::timeout::TimeoutLayer; async fn unit() {} +// Using a semicolon causes an error. +#[allow(clippy::semicolon_if_nothing_returned)] async fn forever() { pending().await } @@ -83,7 +85,7 @@ async fn handler_multiple_methods_last() { async fn handler_service_ext() { let fallible_service = tower::service_fn(|_| async { Err::<(), ()>(()) }); let handle_error_service = - fallible_service.handle_error(|_| async { StatusCode::INTERNAL_SERVER_ERROR }); + fallible_service.handle_error(|()| async { StatusCode::INTERNAL_SERVER_ERROR }); let app = Router::new().route("/", get_service(handle_error_service)); diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 969710627e..00b84a9fd8 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1090,7 +1090,7 @@ async fn logging_rejections() { level: "TRACE".to_owned(), }, ]) - ) + ); } // https://github.com/tokio-rs/axum/issues/1955 diff --git a/axum/src/routing/url_params.rs b/axum/src/routing/url_params.rs index 1649a9e4cf..467ef0ca0f 100644 --- a/axum/src/routing/url_params.rs +++ b/axum/src/routing/url_params.rs @@ -22,11 +22,8 @@ pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params<'_, .filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM)) .filter(|(key, _)| !key.starts_with(super::FALLBACK_PARAM)) .map(|(k, v)| { - if let Some(decoded) = PercentDecodedStr::new(v) { - Ok((Arc::from(k), decoded)) - } else { - Err(Arc::from(k)) - } + PercentDecodedStr::new(v) + .map_or_else(|| Err(Arc::from(k)), |decoded| Ok((Arc::from(k), decoded))) }) .collect::, _>>(); diff --git a/axum/src/serve/mod.rs b/axum/src/serve/mod.rs index 3a470af2a7..311e5834db 100644 --- a/axum/src/serve/mod.rs +++ b/axum/src/serve/mod.rs @@ -96,7 +96,7 @@ pub use self::listener::{Listener, ListenerExt, TapIo}; /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub fn serve(listener: L, make_service: M) -> Serve +pub const fn serve(listener: L, make_service: M) -> Serve where L: Listener, M: for<'a> Service, Error = Infallible, Response = S>, @@ -298,7 +298,7 @@ where loop { let (io, remote_addr) = tokio::select! { conn = listener.accept() => conn, - _ = signal_tx.closed() => { + () = signal_tx.closed() => { trace!("signal received, not accepting new connections"); break; } @@ -424,7 +424,7 @@ async fn handle_connection( } break; } - _ = &mut signal_closed => { + () = &mut signal_closed => { trace!("signal received in task, starting graceful shutdown"); conn.as_mut().graceful_shutdown(); } @@ -459,7 +459,7 @@ where } /// Returns the remote address that this stream is bound to. - pub fn remote_addr(&self) -> &L::Addr { + pub const fn remote_addr(&self) -> &L::Addr { &self.remote_addr } } diff --git a/axum/src/test_helpers/counting_cloneable_state.rs b/axum/src/test_helpers/counting_cloneable_state.rs index 762d5ce972..f428e737f6 100644 --- a/axum/src/test_helpers/counting_cloneable_state.rs +++ b/axum/src/test_helpers/counting_cloneable_state.rs @@ -18,7 +18,7 @@ impl CountingCloneableState { setup_done: AtomicBool::new(false), count: AtomicUsize::new(0), }; - CountingCloneableState { + Self { state: Arc::new(inner_state), } } @@ -47,6 +47,6 @@ impl Clone for CountingCloneableState { state.count.fetch_add(1, Ordering::SeqCst); } - CountingCloneableState { state } + Self { state } } } diff --git a/axum/src/test_helpers/mod.rs b/axum/src/test_helpers/mod.rs index 06f9101c3f..3f03caa084 100644 --- a/axum/src/test_helpers/mod.rs +++ b/axum/src/test_helpers/mod.rs @@ -12,9 +12,9 @@ pub(crate) mod tracing_helpers; pub(crate) mod counting_cloneable_state; #[cfg(test)] -pub(crate) fn assert_send() {} +pub(crate) const fn assert_send() {} #[cfg(test)] -pub(crate) fn assert_sync() {} +pub(crate) const fn assert_sync() {} #[allow(dead_code)] pub(crate) struct NotSendSync(*const ()); diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index e7ba36e95e..d23ae5b2a8 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -23,7 +23,7 @@ where tokio::spawn(async move { serve(listener, Shared::new(svc)) .await - .expect("server error") + .expect("server error"); }); addr @@ -47,21 +47,24 @@ impl TestClient { .build() .unwrap(); - TestClient { client, addr } + Self { client, addr } } + #[must_use] pub fn get(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.get(format!("http://{}{url}", self.addr)), } } + #[must_use] pub fn head(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.head(format!("http://{}{url}", self.addr)), } } + #[must_use] pub fn post(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.post(format!("http://{}{url}", self.addr)), @@ -69,6 +72,7 @@ impl TestClient { } #[allow(dead_code)] + #[must_use] pub fn put(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.put(format!("http://{}{url}", self.addr)), @@ -76,6 +80,7 @@ impl TestClient { } #[allow(dead_code)] + #[must_use] pub fn patch(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.patch(format!("http://{}{url}", self.addr)), @@ -83,7 +88,8 @@ impl TestClient { } #[allow(dead_code)] - pub fn server_port(&self) -> u16 { + #[must_use] + pub const fn server_port(&self) -> u16 { self.addr.port() } } @@ -93,11 +99,13 @@ pub struct RequestBuilder { } impl RequestBuilder { + #[must_use] pub fn body(mut self, body: impl Into) -> Self { self.builder = self.builder.body(body); self } + #[must_use] pub fn json(mut self, json: &T) -> Self where T: serde::Serialize, @@ -106,6 +114,7 @@ impl RequestBuilder { self } + #[must_use] pub fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, @@ -117,6 +126,7 @@ impl RequestBuilder { self } + #[must_use] #[allow(dead_code)] pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self { self.builder = self.builder.multipart(form); diff --git a/axum/src/test_helpers/tracing_helpers.rs b/axum/src/test_helpers/tracing_helpers.rs index adf4331fd0..93f7efee26 100644 --- a/axum/src/test_helpers/tracing_helpers.rs +++ b/axum/src/test_helpers/tracing_helpers.rs @@ -114,14 +114,14 @@ struct Writer<'a>(&'a TestMakeWriter); impl io::Write for Writer<'_> { fn write(&mut self, buf: &[u8]) -> io::Result { - match &mut *self.0.write.lock().unwrap() { - Some(vec) => { + (*self.0.write.lock().unwrap()).as_mut().map_or_else( + || Err(io::Error::other("inner writer has been taken")), + |vec| { let len = buf.len(); vec.extend(buf); Ok(len) - } - None => Err(io::Error::other("inner writer has been taken")), - } + }, + ) } fn flush(&mut self) -> io::Result<()> { diff --git a/axum/src/util.rs b/axum/src/util.rs index e4014c5916..8ab3d53230 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -51,7 +51,7 @@ pub(crate) struct MapIntoResponse { } impl MapIntoResponse { - pub(crate) fn new(inner: S) -> Self { + pub(crate) const fn new(inner: S) -> Self { Self { inner } } } @@ -102,11 +102,9 @@ where K: Send + 'static, { let mut k = Some(k); - if let Some(k) = ::downcast_mut::>(&mut k) { - Ok(k.take().unwrap()) - } else { - Err(k.unwrap()) - } + ::downcast_mut::>(&mut k) + .and_then(Option::take) + .map_or_else(|| Err(k.unwrap()), Ok) } #[test]