Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/twirp-build/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,3 @@ prost-build = "0.13"
prettyplease = { version = "0.2" }
quote = "1.0"
syn = "2.0"
proc-macro2 = "1.0"
14 changes: 12 additions & 2 deletions crates/twirp-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,19 +166,29 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
let mut client_methods = Vec::with_capacity(service.methods.len());
for m in &service.methods {
let name = &m.name;
let build_name = format_ident!("build_{}", name);
let input_type = &m.input_type;
let output_type = &m.output_type;
let request_path = format!("{}/{}", service.fqn, m.proto_name);

client_trait_methods.push(quote! {
async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError>;
});
client_trait_methods.push(quote! {
fn #build_name(&self, req: #input_type) -> Result<twirp::RequestBuilder<#input_type, #output_type>, twirp::ClientError>;
});

client_methods.push(quote! {
fn #build_name(&self, req: #input_type) -> Result<twirp::RequestBuilder<#input_type, #output_type>, twirp::ClientError> {
self.build_request(#request_path, req)
}
});
client_methods.push(quote! {
async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError> {
self.request(#request_path, req).await
let builder = self.#build_name(req)?;
self.request(builder).await
}
})
});
}
let client_trait = quote! {
#[twirp::async_trait::async_trait]
Expand Down
53 changes: 48 additions & 5 deletions crates/twirp/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;
use std::vec;

use async_trait::async_trait;
use http::{HeaderName, HeaderValue};
use reqwest::header::{InvalidHeaderValue, CONTENT_TYPE};
use reqwest::StatusCode;
use thiserror::Error;
Expand Down Expand Up @@ -145,6 +146,8 @@ impl Client {
&self.inner.base_url
}

// TODO: Move this to the `ClientBuilder`
//
/// Creates a new `twirp::Client` with the same configuration as the current
/// one, but with a different host in the base URL.
pub fn with_host(&self, host: &str) -> Self {
Expand All @@ -155,8 +158,7 @@ impl Client {
}
}

/// Make an HTTP twirp request.
pub async fn request<I, O>(&self, path: &str, body: I) -> Result<O>
pub fn build_request<I, O>(&self, path: &str, body: I) -> Result<RequestBuilder<I, O>>
where
I: prost::Message,
O: prost::Message + Default,
Expand All @@ -165,13 +167,23 @@ impl Client {
if let Some(host) = &self.host {
url.set_host(Some(host))?
};
let path = url.path().to_string();

let req = self
.http_client
.post(url)
.header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
.body(serialize_proto_message(body))
.build()?;
.body(serialize_proto_message(body));
Ok(RequestBuilder::new(req))
}

/// Make an HTTP twirp request.
pub async fn request<I, O>(&self, builder: RequestBuilder<I, O>) -> Result<O>
where
I: prost::Message,
O: prost::Message + Default,
{
let req = builder.build()?;
let path = req.url().path().to_string();

// Create and execute the middleware handlers
let next = Next::new(&self.http_client, &self.inner.middlewares);
Expand Down Expand Up @@ -206,6 +218,37 @@ impl Client {
}
}

pub struct RequestBuilder<I, O> {
inner: reqwest::RequestBuilder,
_input: std::marker::PhantomData<I>,
_output: std::marker::PhantomData<O>,
}

impl<I, O> RequestBuilder<I, O> {
pub fn new(inner: reqwest::RequestBuilder) -> Self {
Self {
inner,
_input: std::marker::PhantomData,
_output: std::marker::PhantomData,
}
}

pub fn header<K, V>(self, key: K, value: V) -> RequestBuilder<I, O>
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
RequestBuilder::new(self.inner.header(key, value))
}

/// Builds the request.
pub fn build(self) -> Result<reqwest::Request, reqwest::Error> {
self.inner.build()
}
}

// This concept of reqwest middleware is taken pretty much directly from:
// https://github.com/TrueLayer/reqwest-middleware, but simplified for the
// specific needs of this twirp client.
Expand Down
2 changes: 1 addition & 1 deletion crates/twirp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod test;
#[doc(hidden)]
pub mod details;

pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result};
pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, RequestBuilder, Result};
pub use context::Context;
pub use error::*; // many constructors like `invalid_argument()`
pub use http::Extensions;
Expand Down
3 changes: 2 additions & 1 deletion crates/twirp/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ pub trait TestApiClient {
#[async_trait]
impl TestApiClient for Client {
async fn ping(&self, req: PingRequest) -> Result<PingResponse> {
self.request("test.TestAPI/Ping", req).await
let req = self.build_request("test.TestAPI/Ping", req)?;
self.request(req).await
}

async fn boom(&self, _req: PingRequest) -> Result<PingResponse> {
Expand Down
29 changes: 26 additions & 3 deletions example/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ pub async fn main() -> Result<(), GenericError> {
.await;
eprintln!("{:?}", resp);

// TODO: Figure out where `with_host` goes in all this...
let req = client
.with_host("localhost")
.build_make_hat(MakeHatRequest { inches: 1 })?
.header("x-custom-header", "a");
// Make a request with context
let resp = client.request(req).await?;
eprintln!("{:?}", resp);

Ok(())
}

Expand Down Expand Up @@ -69,23 +78,37 @@ impl Middleware for PrintResponseHeaders {
}
}

// NOTE: This is just to demonstrate manually implementing the client trait. You don't need to do this as this code will
// be generated for you by twirp-build.
//
// This is here so that we can visualize changes to the generated client code
#[allow(dead_code)]
#[derive(Debug)]
struct MockHaberdasherApiClient;

#[async_trait]
impl HaberdasherApiClient for MockHaberdasherApiClient {
async fn make_hat(
fn build_make_hat(
&self,
_req: MakeHatRequest,
) -> Result<MakeHatResponse, twirp::client::ClientError> {
) -> Result<twirp::RequestBuilder<MakeHatRequest, MakeHatResponse>, twirp::ClientError> {
todo!()
}
async fn make_hat(&self, _req: MakeHatRequest) -> Result<MakeHatResponse, twirp::ClientError> {
todo!()
}

fn build_get_status(
&self,
_req: GetStatusRequest,
) -> Result<twirp::RequestBuilder<GetStatusRequest, GetStatusResponse>, twirp::ClientError>
{
todo!()
}
async fn get_status(
&self,
_req: GetStatusRequest,
) -> Result<GetStatusResponse, twirp::client::ClientError> {
) -> Result<GetStatusResponse, twirp::ClientError> {
todo!()
}
}