diff --git a/src/wiremock/builder.rs b/src/wiremock/builder.rs index 0465e41..85e27d2 100644 --- a/src/wiremock/builder.rs +++ b/src/wiremock/builder.rs @@ -1,4 +1,6 @@ use crate::wiremock::grpc_server::{GrpcServer, RuleItem}; +use tonic::codegen::http::header::IntoHeaderName; +use tonic::codegen::http::{request, HeaderMap, HeaderValue}; pub trait Then { fn return_status(self, status: tonic::Code) -> Self; @@ -7,6 +9,12 @@ pub trait Then { where F: Fn() -> T, T: prost::Message; + + fn return_header(self, key: K, value: V) -> Self + where + K: IntoHeaderName, + V: TryInto, + >::Error: std::fmt::Debug; } pub trait Mountable { @@ -19,25 +27,41 @@ pub struct MockBuilder { pub(crate) path: String, pub(crate) status_code: Option, pub(crate) result: Option>, + pub(crate) request_headers: HeaderMap, + pub(crate) response_headers: HeaderMap, } #[derive(Clone)] pub struct WhenBuilder { path: Option, + headers: HeaderMap, } impl WhenBuilder { pub fn path(&self, p: &str) -> Self { Self { path: Some(p.into()), + headers: self.headers.clone(), } } + pub fn header(mut self, key: K, value: V) -> Self + where + K: IntoHeaderName, + V: TryInto, + >::Error: std::fmt::Debug, + { + self.headers.insert(key, value.try_into().unwrap()); + self + } + pub fn then(&self) -> ThenBuilder { self.validate(); ThenBuilder { path: self.path.clone().unwrap(), status_code: None, result: None, + request_headers: self.headers.clone(), + response_headers: HeaderMap::new(), } } @@ -53,6 +77,8 @@ pub struct ThenBuilder { pub(crate) path: String, pub(crate) status_code: Option, pub(crate) result: Option>, + pub(crate) request_headers: HeaderMap, + pub(crate) response_headers: HeaderMap, } impl MockBuilder { @@ -61,11 +87,40 @@ impl MockBuilder { path: path.into(), result: None, status_code: None, + request_headers: HeaderMap::new(), + response_headers: HeaderMap::new(), } } pub fn when() -> WhenBuilder { - WhenBuilder { path: None } + WhenBuilder { + path: None, + headers: HeaderMap::new(), + } + } + + pub(crate) fn matches( + &self, + req: &request::Request, + ) -> bool { + if self.path != req.uri().path() { + return false; + } + + for (key, value) in &self.request_headers { + if !req.headers().contains_key(key.as_str()) { + return false; + } + let Some(mock_value) = req.headers().get(key.as_str()) else { + return false; + }; + + if mock_value != value { + return false; + } + } + + true } } @@ -108,6 +163,16 @@ impl Then for MockBuilder { ..self } } + + fn return_header(mut self, key: K, value: V) -> Self + where + K: IntoHeaderName, + V: TryInto, + >::Error: std::fmt::Debug, + { + self.response_headers.insert(key, value.try_into().unwrap()); + self + } } impl Then for ThenBuilder { @@ -135,6 +200,16 @@ impl Then for ThenBuilder { ..self } } + + fn return_header(mut self, key: K, value: V) -> Self + where + K: IntoHeaderName, + V: TryInto, + >::Error: std::fmt::Debug, + { + self.response_headers.insert(key, value.try_into().unwrap()); + self + } } #[allow(clippy::from_over_into)] @@ -144,6 +219,8 @@ impl Into for ThenBuilder { path: self.path, status_code: self.status_code, result: self.result, + request_headers: self.request_headers, + response_headers: self.response_headers, } } } @@ -151,7 +228,6 @@ impl Into for ThenBuilder { impl Mountable for ThenBuilder { fn mount(self, s: &mut GrpcServer) { let rb: MockBuilder = self.into(); - rb.mount(s); } } diff --git a/src/wiremock/grpc_server.rs b/src/wiremock/grpc_server.rs index 9cf3f9b..516876b 100644 --- a/src/wiremock/grpc_server.rs +++ b/src/wiremock/grpc_server.rs @@ -173,14 +173,14 @@ impl GrpcServer { { info!("Request to {}", req.uri().path()); - let path = req.uri().path(); let mut inner = self.rules.write().unwrap(); - if let Some(item) = inner.iter_mut().find(|x| x.rule.path == path) { + if let Some(item) = inner.iter_mut().find(|x| x.rule.matches(&req)) { info!("Matched rule {:?}", item); item.record_request(&req); let code = item.rule.status_code.unwrap_or(Code::Ok); + let return_headers = item.rule.response_headers.clone(); if let Some(body) = &item.rule.result { debug!("Returning body ({} bytes)", body.len()); let body = body.clone(); @@ -191,10 +191,19 @@ impl GrpcServer { let mut grpc = tonic::server::Grpc::new(codec); let mut result = grpc.unary(method, req).await; - result.headers_mut().append( + + let headers = result.headers_mut(); + headers.append( "grpc-status", HeaderValue::from_str(format!("{}", code as u32).as_str()).unwrap(), ); + + for (name, value) in return_headers { + if let Some(name) = name { + headers.insert(name, value); + } + } + Ok(result) }; return Box::pin(fut); diff --git a/tests/features_test.rs b/tests/features_test.rs index 7178374..33f9e5a 100644 --- a/tests/features_test.rs +++ b/tests/features_test.rs @@ -10,7 +10,7 @@ use hello::{ greeter_client::GreeterClient, HelloReply, HelloRequest, WeatherReply, WeatherRequest, }; use std::net::TcpStream; -use tonic::{transport::Channel, Code}; +use tonic::{transport::Channel, Code, Request}; use wiremock_gen::*; use wiremock_grpc::*; @@ -74,6 +74,37 @@ async fn default() { ); } +#[tokio::test] +async fn mocked_header_return() { + let (mut server, mut client) = create().await; + + // Setup + let mock = server.setup( + MockBuilder::given("/hello.Greeter/SayHello") + .return_header("X-RateLimit-Remaining", "100") + .return_body(|| HelloReply { + message: "Hello to you too!".into(), + }), + ); + + // Act + let response = client + .say_hello(HelloRequest { + name: "Yo yo".into(), + }) + .await + .unwrap(); + + assert_eq!( + response + .metadata() + .get("X-RateLimit-Remaining") + .expect("header should be set"), + "100" + ); + let _ = server.find_one(&mock); +} + #[tokio::test] async fn handled_when_mock_set_with_different_status_code() { // client & server @@ -158,6 +189,78 @@ async fn multiple_mocks() { assert_eq!(2, server.find_request_count()); } +#[tokio::test] +async fn header_discriminated_mocks() { + let (mut server, mut client) = create().await; + + let mock1_session_id = "mock1"; + let mock2_session_id = "mock2"; + // setup + let mock1 = server.setup( + MockBuilder::when() + .path("/hello.Greeter/SayHello") + .header("session-id", mock1_session_id) + .then() + .return_body(|| HelloReply { + message: "Hello to you too!".into(), + }), + ); + + let mock2 = server.setup( + MockBuilder::when() + .path("/hello.Greeter/SayHello") + .header("session-id", mock2_session_id) + .then() + .return_body(|| HelloReply { + message: "Hello to you two!".into(), + }), + ); + + // Act + let mut request1 = Request::new(HelloRequest { + name: "Mustakim".into(), + }); + + request1 + .metadata_mut() + .insert("session-id", mock1_session_id.parse().unwrap()); + let response1 = client.say_hello(request1).await.unwrap(); + + assert_eq!("Hello to you too!", response1.into_inner().message); + + let mut request2 = Request::new(HelloRequest { name: "Zak".into() }); + request2 + .metadata_mut() + .insert("session-id", mock2_session_id.parse().unwrap()); + let response2 = client.say_hello(request2).await.unwrap(); + assert_eq!("Hello to you two!", response2.into_inner().message); + + // single request + let tracked_response_1 = server.find_one(&mock1); + let tracked_response_2 = server.find_one(&mock2); + + assert_eq!( + tracked_response_1 + .headers + .get("session-id") + .expect("header set") + .to_str() + .unwrap(), + mock1_session_id + ); + assert_eq!( + tracked_response_2 + .headers + .get("session-id") + .expect("header set") + .to_str() + .unwrap(), + mock2_session_id + ); + + assert_eq!(2, server.find_request_count()); +} + #[tokio::test] #[should_panic(expected = "Server terminated with unmatched rules: \n/hello.Greeter/SayHello")] async fn unmatched_request_panics() {