Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions src/wiremock/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::wiremock::grpc_server::{GrpcServer, RuleItem};
use tonic::codegen::http::header::IntoHeaderName;
use tonic::codegen::http::{request, HeaderMap, HeaderValue};

pub trait Then {
fn return_status(self, status: tonic::Code) -> Self;
Expand All @@ -7,6 +9,12 @@ pub trait Then {
where
F: Fn() -> T,
T: prost::Message;

fn return_header<K, V>(self, key: K, value: V) -> Self
where
K: IntoHeaderName,
V: TryInto<HeaderValue>,
<V as TryInto<HeaderValue>>::Error: std::fmt::Debug;
}

pub trait Mountable {
Expand All @@ -19,25 +27,41 @@ pub struct MockBuilder {
pub(crate) path: String,
pub(crate) status_code: Option<tonic::Code>,
pub(crate) result: Option<Vec<u8>>,
pub(crate) request_headers: HeaderMap,
pub(crate) response_headers: HeaderMap,
}

#[derive(Clone)]
pub struct WhenBuilder {
path: Option<String>,
headers: HeaderMap,
}
impl WhenBuilder {
pub fn path(&self, p: &str) -> Self {
Self {
path: Some(p.into()),
headers: self.headers.clone(),
}
}

pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
K: IntoHeaderName,
V: TryInto<HeaderValue>,
<V as TryInto<HeaderValue>>::Error: std::fmt::Debug,
{
self.headers.insert(key, value.try_into().unwrap());
self
}

pub fn then(&self) -> ThenBuilder {
self.validate();
ThenBuilder {
path: self.path.clone().unwrap(),
status_code: None,
result: None,
request_headers: self.headers.clone(),
response_headers: HeaderMap::new(),
}
}

Expand All @@ -53,6 +77,8 @@ pub struct ThenBuilder {
pub(crate) path: String,
pub(crate) status_code: Option<tonic::Code>,
pub(crate) result: Option<Vec<u8>>,
pub(crate) request_headers: HeaderMap,
pub(crate) response_headers: HeaderMap,
}

impl MockBuilder {
Expand All @@ -61,11 +87,40 @@ impl MockBuilder {
path: path.into(),
result: None,
status_code: None,
request_headers: HeaderMap::new(),
response_headers: HeaderMap::new(),
}
}

pub fn when() -> WhenBuilder {
WhenBuilder { path: None }
WhenBuilder {
path: None,
headers: HeaderMap::new(),
}
}

pub(crate) fn matches<B: http_body::Body + Send + 'static>(
&self,
req: &request::Request<B>,
) -> bool {
if self.path != req.uri().path() {
return false;
}

for (key, value) in &self.request_headers {
if !req.headers().contains_key(key.as_str()) {
return false;
}
let Some(mock_value) = req.headers().get(key.as_str()) else {
return false;
};

if mock_value != value {
return false;
}
}

true
}
}

Expand Down Expand Up @@ -108,6 +163,16 @@ impl Then for MockBuilder {
..self
}
}

fn return_header<K, V>(mut self, key: K, value: V) -> Self
where
K: IntoHeaderName,
V: TryInto<HeaderValue>,
<V as TryInto<HeaderValue>>::Error: std::fmt::Debug,
{
self.response_headers.insert(key, value.try_into().unwrap());
self
}
}

impl Then for ThenBuilder {
Expand Down Expand Up @@ -135,6 +200,16 @@ impl Then for ThenBuilder {
..self
}
}

fn return_header<K, V>(mut self, key: K, value: V) -> Self
where
K: IntoHeaderName,
V: TryInto<HeaderValue>,
<V as TryInto<HeaderValue>>::Error: std::fmt::Debug,
{
self.response_headers.insert(key, value.try_into().unwrap());
self
}
}

#[allow(clippy::from_over_into)]
Expand All @@ -144,14 +219,15 @@ impl Into<MockBuilder> for ThenBuilder {
path: self.path,
status_code: self.status_code,
result: self.result,
request_headers: self.request_headers,
response_headers: self.response_headers,
}
}
}

impl Mountable for ThenBuilder {
fn mount(self, s: &mut GrpcServer) {
let rb: MockBuilder = self.into();

rb.mount(s);
}
}
15 changes: 12 additions & 3 deletions src/wiremock/grpc_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,14 @@ impl GrpcServer {
{
info!("Request to {}", req.uri().path());

let path = req.uri().path();
let mut inner = self.rules.write().unwrap();

if let Some(item) = inner.iter_mut().find(|x| x.rule.path == path) {
if let Some(item) = inner.iter_mut().find(|x| x.rule.matches(&req)) {
info!("Matched rule {:?}", item);
item.record_request(&req);

let code = item.rule.status_code.unwrap_or(Code::Ok);
let return_headers = item.rule.response_headers.clone();
if let Some(body) = &item.rule.result {
debug!("Returning body ({} bytes)", body.len());
let body = body.clone();
Expand All @@ -191,10 +191,19 @@ impl GrpcServer {

let mut grpc = tonic::server::Grpc::new(codec);
let mut result = grpc.unary(method, req).await;
result.headers_mut().append(

let headers = result.headers_mut();
headers.append(
"grpc-status",
HeaderValue::from_str(format!("{}", code as u32).as_str()).unwrap(),
);

for (name, value) in return_headers {
if let Some(name) = name {
headers.insert(name, value);
}
}

Ok(result)
};
return Box::pin(fut);
Expand Down
105 changes: 104 additions & 1 deletion tests/features_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use hello::{
greeter_client::GreeterClient, HelloReply, HelloRequest, WeatherReply, WeatherRequest,
};
use std::net::TcpStream;
use tonic::{transport::Channel, Code};
use tonic::{transport::Channel, Code, Request};
use wiremock_gen::*;
use wiremock_grpc::*;

Expand Down Expand Up @@ -74,6 +74,37 @@ async fn default() {
);
}

#[tokio::test]
async fn mocked_header_return() {
let (mut server, mut client) = create().await;

// Setup
let mock = server.setup(
MockBuilder::given("/hello.Greeter/SayHello")
.return_header("X-RateLimit-Remaining", "100")
.return_body(|| HelloReply {
message: "Hello to you too!".into(),
}),
);

// Act
let response = client
.say_hello(HelloRequest {
name: "Yo yo".into(),
})
.await
.unwrap();

assert_eq!(
response
.metadata()
.get("X-RateLimit-Remaining")
.expect("header should be set"),
"100"
);
let _ = server.find_one(&mock);
}

#[tokio::test]
async fn handled_when_mock_set_with_different_status_code() {
// client & server
Expand Down Expand Up @@ -158,6 +189,78 @@ async fn multiple_mocks() {
assert_eq!(2, server.find_request_count());
}

#[tokio::test]
async fn header_discriminated_mocks() {
let (mut server, mut client) = create().await;

let mock1_session_id = "mock1";
let mock2_session_id = "mock2";
// setup
let mock1 = server.setup(
MockBuilder::when()
.path("/hello.Greeter/SayHello")
.header("session-id", mock1_session_id)
.then()
.return_body(|| HelloReply {
message: "Hello to you too!".into(),
}),
);

let mock2 = server.setup(
MockBuilder::when()
.path("/hello.Greeter/SayHello")
.header("session-id", mock2_session_id)
.then()
.return_body(|| HelloReply {
message: "Hello to you two!".into(),
}),
);

// Act
let mut request1 = Request::new(HelloRequest {
name: "Mustakim".into(),
});

request1
.metadata_mut()
.insert("session-id", mock1_session_id.parse().unwrap());
let response1 = client.say_hello(request1).await.unwrap();

assert_eq!("Hello to you too!", response1.into_inner().message);

let mut request2 = Request::new(HelloRequest { name: "Zak".into() });
request2
.metadata_mut()
.insert("session-id", mock2_session_id.parse().unwrap());
let response2 = client.say_hello(request2).await.unwrap();
assert_eq!("Hello to you two!", response2.into_inner().message);

// single request
let tracked_response_1 = server.find_one(&mock1);
let tracked_response_2 = server.find_one(&mock2);

assert_eq!(
tracked_response_1
.headers
.get("session-id")
.expect("header set")
.to_str()
.unwrap(),
mock1_session_id
);
assert_eq!(
tracked_response_2
.headers
.get("session-id")
.expect("header set")
.to_str()
.unwrap(),
mock2_session_id
);

assert_eq!(2, server.find_request_count());
}

#[tokio::test]
#[should_panic(expected = "Server terminated with unmatched rules: \n/hello.Greeter/SayHello")]
async fn unmatched_request_panics() {
Expand Down