diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 27588005a..b6931df59 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,14 +25,74 @@ jobs: components: rustfmt - run: cargo fmt --all --check + build-protoc-plugin: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macOS-latest, windows-latest] + outputs: + cache-hit: ${{ steps.cache-plugin.outputs.cache-hit }} + steps: + - uses: actions/checkout@v4 + - name: Cache protoc plugin + id: cache-plugin + uses: actions/cache@v4 + with: + path: ${{ runner.temp }}/protoc-plugin + # The key changes only when plugin source files change + key: ${{ runner.os }}-protoc-plugin-${{ hashFiles('protoc-gen-rust-grpc/src/**', 'protoc-gen-rust-grpc/.bazelrc', 'protoc-gen-rust-grpc/MODULE.bazel') }} + - name: Install Bazel + if: steps.cache-plugin.outputs.cache-hit != 'true' + uses: bazel-contrib/setup-bazel@0.15.0 + with: + # Avoid downloading Bazel every time. + bazelisk-cache: true + # Store build cache per workflow. + disk-cache: ${{ github.workflow }} + # Share repository cache between workflows. + repository-cache: true + module-root: ./protoc-gen-rust-grpc + # Building the protoc plugin from scratch takes 6–14 minutes, depending on + # the OS. This delays the execution of workflows that use the plugin in + # build.rs files. We try to avoid rebuilding the plugin if it hasn't + # changed. + - name: Build protoc plugin + if: steps.cache-plugin.outputs.cache-hit != 'true' + working-directory: ./protoc-gen-rust-grpc + shell: bash + run: | + set -e + # On windows, the "//src" gets converted to "/". Disable this path + # conversion. + export MSYS_NO_PATHCONV=1 + export MSYS2_ARG_CONV_EXCL="*" + + bazel build //src:protoc-gen-rust-grpc --enable_platform_specific_config + + # The target path needs to match the cache config. + TARGET_PATH="${{ runner.temp }}/protoc-plugin" + mkdir -p "${TARGET_PATH}" + cp bazel-bin/src/protoc-gen-rust-grpc "${TARGET_PATH}" + clippy: runs-on: ubuntu-latest + needs: build-protoc-plugin steps: - uses: actions/checkout@v4 - uses: hecrj/setup-rust-action@v2 with: components: clippy - uses: taiki-e/install-action@protoc + - name: Restore protoc plugin from cache + id: cache-plugin + uses: actions/cache@v4 + with: + path: ${{ runner.temp }}/protoc-plugin + key: ${{ runner.os }}-protoc-plugin-${{ hashFiles('protoc-gen-rust-grpc/src/**', 'protoc-gen-rust-grpc/.bazelrc', 'protoc-gen-rust-grpc/MODULE.bazel') }} + - name: Add protoc plugin to PATH + shell: bash + run: | + echo "${{ runner.temp }}/protoc-plugin" >> $GITHUB_PATH - uses: Swatinem/rust-cache@v2 - run: cargo clippy --workspace --all-features --all-targets @@ -47,6 +107,7 @@ jobs: udeps: runs-on: ubuntu-latest + needs: build-protoc-plugin steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master @@ -55,6 +116,16 @@ jobs: - uses: taiki-e/install-action@cargo-hack - uses: taiki-e/install-action@cargo-udeps - uses: taiki-e/install-action@protoc + - name: Restore protoc plugin from cache + id: cache-plugin + uses: actions/cache@v4 + with: + path: ${{ runner.temp }}/protoc-plugin + key: ${{ runner.os }}-protoc-plugin-${{ hashFiles('protoc-gen-rust-grpc/src/**', 'protoc-gen-rust-grpc/.bazelrc', 'protoc-gen-rust-grpc/MODULE.bazel') }} + - name: Add protoc plugin to PATH + shell: bash + run: | + echo "${{ runner.temp }}/protoc-plugin" >> $GITHUB_PATH - uses: Swatinem/rust-cache@v2 - run: cargo hack udeps --workspace --exclude-features=_tls-any,tls,tls-aws-lc,tls-ring --each-feature - run: cargo udeps --package tonic --features tls-ring,transport @@ -66,6 +137,7 @@ jobs: check: runs-on: ${{ matrix.os }} + needs: build-protoc-plugin strategy: matrix: os: [ubuntu-latest, macOS-latest, windows-latest] @@ -76,6 +148,16 @@ jobs: - uses: hecrj/setup-rust-action@v2 - uses: taiki-e/install-action@cargo-hack - uses: taiki-e/install-action@protoc + - name: Restore protoc plugin from cache + id: cache-plugin + uses: actions/cache@v4 + with: + path: ${{ runner.temp }}/protoc-plugin + key: ${{ runner.os }}-protoc-plugin-${{ hashFiles('protoc-gen-rust-grpc/src/**', 'protoc-gen-rust-grpc/.bazelrc', 'protoc-gen-rust-grpc/MODULE.bazel') }} + - name: Add protoc plugin to PATH + shell: bash + run: | + echo "${{ runner.temp }}/protoc-plugin" >> $GITHUB_PATH - uses: Swatinem/rust-cache@v2 - name: Check features run: cargo hack check --workspace --no-private --each-feature --no-dev-deps @@ -108,6 +190,7 @@ jobs: test: runs-on: ${{ matrix.os }} + needs: build-protoc-plugin strategy: matrix: os: [ubuntu-latest, macOS-latest, windows-latest] @@ -115,6 +198,16 @@ jobs: - uses: actions/checkout@v4 - uses: hecrj/setup-rust-action@v2 - uses: taiki-e/install-action@protoc + - name: Restore protoc plugin from cache + id: cache-plugin + uses: actions/cache@v4 + with: + path: ${{ runner.temp }}/protoc-plugin + key: ${{ runner.os }}-protoc-plugin-${{ hashFiles('protoc-gen-rust-grpc/src/**', 'protoc-gen-rust-grpc/.bazelrc', 'protoc-gen-rust-grpc/MODULE.bazel') }} + - name: Add protoc plugin to PATH + shell: bash + run: | + echo "${{ runner.temp }}/protoc-plugin" >> $GITHUB_PATH - uses: taiki-e/install-action@cargo-hack - uses: taiki-e/install-action@cargo-nextest - uses: Swatinem/rust-cache@v2 @@ -134,6 +227,7 @@ jobs: interop: name: Interop Tests runs-on: ${{ matrix.os }} + needs: build-protoc-plugin strategy: matrix: os: [ubuntu-latest, macOS-latest, windows-latest] @@ -141,6 +235,16 @@ jobs: - uses: actions/checkout@v4 - uses: hecrj/setup-rust-action@v2 - uses: taiki-e/install-action@protoc + - name: Restore protoc plugin from cache + id: cache-plugin + uses: actions/cache@v4 + with: + path: ${{ runner.temp }}/protoc-plugin + key: ${{ runner.os }}-protoc-plugin-${{ hashFiles('protoc-gen-rust-grpc/src/**', 'protoc-gen-rust-grpc/.bazelrc', 'protoc-gen-rust-grpc/MODULE.bazel') }} + - name: Add protoc plugin to PATH + shell: bash + run: | + echo "${{ runner.temp }}/protoc-plugin" >> $GITHUB_PATH - uses: Swatinem/rust-cache@v2 - name: Run interop tests run: ./interop/test.sh diff --git a/Cargo.toml b/Cargo.toml index f72e4746a..3c37bad6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,8 @@ members = [ "tonic", "tonic-build", "tonic-health", + "tonic-protobuf", + "tonic-protobuf-build", "tonic-types", "tonic-reflection", "tonic-prost", diff --git a/grpc/src/lib.rs b/grpc/src/lib.rs index 45352523b..f56fd2cab 100644 --- a/grpc/src/lib.rs +++ b/grpc/src/lib.rs @@ -34,6 +34,7 @@ pub mod client; pub mod credentials; pub mod inmemory; +mod macros; pub mod rt; pub mod server; pub mod service; diff --git a/grpc/src/macros.rs b/grpc/src/macros.rs new file mode 100644 index 000000000..aaf0bb942 --- /dev/null +++ b/grpc/src/macros.rs @@ -0,0 +1,104 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + */ + +/// Includes generated proto message, client, and server code. +/// +/// You must specify the path to the `.proto` file +/// **relative to the proto root directory**, without the `.proto` extension. +/// +/// For example, if your proto directory is `path/to/protos` and it contains the +/// file `helloworld.proto`, you would write: +/// +/// ```rust,ignore +/// mod pb { +/// grpc::include_proto!("path/to/protos", "helloworld"); +/// } +/// ``` +/// +/// # Note +/// **This macro only works if the gRPC build output directory and message path +/// are unmodified.** +/// By default: +/// - The output directory is set to the [`OUT_DIR`] environment variable. +/// - The message path is set to `self`. +/// +/// If your `.proto` files are not in a subdirectory, you can omit the first +/// parameter. +/// +/// ```rust,ignore +/// mod pb { +/// grpc::include_proto!("helloworld"); +/// } +/// ``` +/// +/// If you have modified the output directory or message path, you should +/// include the generated code manually instead of using this macro. +/// +/// The following example assumes the message code is imported using `self`: +/// +/// ```rust,ignore +/// mod protos { +/// // Include message code. +/// include!("relative/protobuf/directory/generated.rs"); +/// +/// // Include service code. +/// include!("relative/protobuf/directory/helloworld_grpc.pb.rs"); +/// } +/// ``` +/// +/// If the message code and service code are in different modules, and the +/// message path specified during code generation is `super::protos`, use: +/// +/// ```rust,ignore +/// mod protos { +/// // Include message code. +/// include!("relative/protobuf/directory/generated.rs"); +/// } +/// +/// mod grpc { +/// // Include service code. +/// include!("relative/protobuf/directory/helloworld_grpc.pb.rs"); +/// } +/// ``` +/// +/// [`OUT_DIR`]: https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-build-scripts +#[macro_export] +macro_rules! include_proto { + // Assume the generated output dir is OUT_DIR. + ($proto_file:literal) => { + $crate::include_proto!("", $proto_file); + }; + + ($parent_dir:literal, $proto_file:literal) => { + include!(concat!(env!("OUT_DIR"), "/", $parent_dir, "/generated.rs")); + include!(concat!( + env!("OUT_DIR"), + "/", + $parent_dir, + "/", + $proto_file, + "_grpc.pb.rs" + )); + }; +} diff --git a/interop/Cargo.toml b/interop/Cargo.toml index ba080667d..dfc69aa38 100644 --- a/interop/Cargo.toml +++ b/interop/Cargo.toml @@ -26,6 +26,15 @@ tonic = {path = "../tonic", features = ["tls-ring"]} tonic-prost = {path = "../tonic-prost"} tower = "0.5" tracing-subscriber = {version = "0.3"} +grpc = {path = "../grpc"} +# TODO: Remove the direct protobuf dependency after updating to version 4.32, +# which includes https://github.com/protocolbuffers/protobuf/pull/22764. +# We also need the protobuf-codegen crate to support configuring the path +# to the protobuf crate used in the generated message code, instead of +# defaulting to `::protobuf`. +protobuf = { version = "4.31.1-release" } +tonic-protobuf = {path = "../tonic-protobuf"} [build-dependencies] tonic-prost-build = {path = "../tonic-prost-build"} +tonic-protobuf-build = {path = "../tonic-protobuf-build"} diff --git a/interop/build.rs b/interop/build.rs index a73f69f34..ddd903f8d 100644 --- a/interop/build.rs +++ b/interop/build.rs @@ -2,6 +2,11 @@ fn main() { let proto = "proto/grpc/testing/test.proto"; tonic_prost_build::compile_protos(proto).unwrap(); + tonic_protobuf_build::CodeGen::new() + .include("proto/grpc/testing") + .inputs(["test.proto", "empty.proto", "messages.proto"]) + .compile() + .unwrap(); // prevent needing to rebuild if files (or deps) haven't changed println!("cargo:rerun-if-changed={proto}"); diff --git a/interop/src/bin/client.rs b/interop/src/bin/client.rs index 01c279200..8a6b4af2d 100644 --- a/interop/src/bin/client.rs +++ b/interop/src/bin/client.rs @@ -1,4 +1,5 @@ -use interop::client; +use interop::client::{InteropTest, InteropTestUnimplemented}; +use interop::{client_prost, client_protobuf}; use std::{str::FromStr, time::Duration}; use tonic::transport::Endpoint; use tonic::transport::{Certificate, ClientTlsConfig}; @@ -7,6 +8,25 @@ use tonic::transport::{Certificate, ClientTlsConfig}; struct Opts { use_tls: bool, test_case: Vec, + codec: Codec, +} + +#[derive(Debug)] +enum Codec { + Prost, + Protobuf, +} + +impl FromStr for Codec { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "prost" => Ok(Codec::Prost), + "protobuf" => Ok(Codec::Protobuf), + _ => Err(format!("Invalid codec: {}", s)), + } + } } impl Opts { @@ -17,6 +37,7 @@ impl Opts { test_case: pargs.value_from_fn("--test_case", |test_case| { test_case.split(',').map(Testcase::from_str).collect() })?, + codec: pargs.value_from_str("--codec")?, }) } } @@ -48,8 +69,19 @@ async fn main() -> Result<(), Box> { let channel = endpoint.connect().await?; - let mut client = client::TestClient::new(channel.clone()); - let mut unimplemented_client = client::UnimplementedClient::new(channel); + let (mut client, mut unimplemented_client): ( + Box, + Box, + ) = match matches.codec { + Codec::Prost => ( + Box::new(client_prost::TestClient::new(channel.clone())), + Box::new(client_prost::UnimplementedClient::new(channel)), + ), + Codec::Protobuf => ( + Box::new(client_protobuf::TestClient::new(channel.clone())), + Box::new(client_protobuf::UnimplementedClient::new(channel)), + ), + }; let mut failures = Vec::new(); @@ -58,31 +90,25 @@ async fn main() -> Result<(), Box> { let mut test_results = Vec::new(); match test_case { - Testcase::EmptyUnary => client::empty_unary(&mut client, &mut test_results).await, - Testcase::LargeUnary => client::large_unary(&mut client, &mut test_results).await, - Testcase::ClientStreaming => { - client::client_streaming(&mut client, &mut test_results).await - } - Testcase::ServerStreaming => { - client::server_streaming(&mut client, &mut test_results).await - } - Testcase::PingPong => client::ping_pong(&mut client, &mut test_results).await, - Testcase::EmptyStream => client::empty_stream(&mut client, &mut test_results).await, + Testcase::EmptyUnary => client.empty_unary(&mut test_results).await, + Testcase::LargeUnary => client.large_unary(&mut test_results).await, + Testcase::ClientStreaming => client.client_streaming(&mut test_results).await, + Testcase::ServerStreaming => client.server_streaming(&mut test_results).await, + Testcase::PingPong => client.ping_pong(&mut test_results).await, + Testcase::EmptyStream => client.empty_stream(&mut test_results).await, Testcase::StatusCodeAndMessage => { - client::status_code_and_message(&mut client, &mut test_results).await + client.status_code_and_message(&mut test_results).await } Testcase::SpecialStatusMessage => { - client::special_status_message(&mut client, &mut test_results).await - } - Testcase::UnimplementedMethod => { - client::unimplemented_method(&mut client, &mut test_results).await + client.special_status_message(&mut test_results).await } + Testcase::UnimplementedMethod => client.unimplemented_method(&mut test_results).await, Testcase::UnimplementedService => { - client::unimplemented_service(&mut unimplemented_client, &mut test_results).await - } - Testcase::CustomMetadata => { - client::custom_metadata(&mut client, &mut test_results).await + unimplemented_client + .unimplemented_service(&mut test_results) + .await } + Testcase::CustomMetadata => client.custom_metadata(&mut test_results).await, _ => unimplemented!(), } diff --git a/interop/src/client.rs b/interop/src/client.rs index 389264684..1e448d652 100644 --- a/interop/src/client.rs +++ b/interop/src/client.rs @@ -1,410 +1,30 @@ -use crate::{ - pb::test_service_client::*, pb::unimplemented_service_client::*, pb::*, test_assert, - TestAssertion, -}; -use tokio::sync::mpsc; -use tokio_stream::StreamExt; -use tonic::transport::Channel; -use tonic::{metadata::MetadataValue, Code, Request, Response, Status}; +use crate::TestAssertion; +use tonic::async_trait; -pub type TestClient = TestServiceClient; -pub type UnimplementedClient = UnimplementedServiceClient; +#[async_trait] +pub trait InteropTest: Send { + async fn empty_unary(&mut self, assertions: &mut Vec); -const LARGE_REQ_SIZE: usize = 271_828; -const LARGE_RSP_SIZE: i32 = 314_159; -const REQUEST_LENGTHS: &[i32] = &[27182, 8, 1828, 45904]; -const RESPONSE_LENGTHS: &[i32] = &[31415, 9, 2653, 58979]; -const TEST_STATUS_MESSAGE: &str = "test status message"; -const SPECIAL_TEST_STATUS_MESSAGE: &str = - "\t\ntest with whitespace\r\nand Unicode BMP ☺ and non-BMP 😈\t\n"; + async fn large_unary(&mut self, assertions: &mut Vec); -pub async fn empty_unary(client: &mut TestClient, assertions: &mut Vec) { - let result = client.empty_call(Request::new(Empty {})).await; + async fn client_streaming(&mut self, assertions: &mut Vec); - assertions.push(test_assert!( - "call must be successful", - result.is_ok(), - format!("result={:?}", result) - )); + async fn server_streaming(&mut self, assertions: &mut Vec); - if let Ok(response) = result { - let body = response.into_inner(); - assertions.push(test_assert!( - "body must not be null", - body == Empty {}, - format!("body={:?}", body) - )); - } -} - -pub async fn large_unary(client: &mut TestClient, assertions: &mut Vec) { - use std::mem; - let payload = crate::client_payload(LARGE_REQ_SIZE); - let req = SimpleRequest { - response_type: PayloadType::Compressable as i32, - response_size: LARGE_RSP_SIZE, - payload: Some(payload), - ..Default::default() - }; - - let result = client.unary_call(Request::new(req)).await; - - assertions.push(test_assert!( - "call must be successful", - result.is_ok(), - format!("result={:?}", result) - )); - - if let Ok(response) = result { - let body = response.into_inner(); - let payload_len = body.payload.as_ref().map(|p| p.body.len()).unwrap_or(0); - - assertions.push(test_assert!( - "body must be 314159 bytes", - payload_len == LARGE_RSP_SIZE as usize, - format!("mem::size_of_val(&body)={:?}", mem::size_of_val(&body)) - )); - } -} - -// pub async fn cachable_unary(client: &mut Client, assertions: &mut Vec) { -// let payload = Payload { -// r#type: PayloadType::Compressable as i32, -// body: format!("{:?}", std::time::Instant::now()).into_bytes(), -// }; -// let req = SimpleRequest { -// response_type: PayloadType::Compressable as i32, -// payload: Some(payload), -// ..Default::default() -// }; - -// client. -// } - -pub async fn client_streaming(client: &mut TestClient, assertions: &mut Vec) { - let requests = REQUEST_LENGTHS.iter().map(|len| StreamingInputCallRequest { - payload: Some(crate::client_payload(*len as usize)), - ..Default::default() - }); - - let stream = tokio_stream::iter(requests); - - let result = client.streaming_input_call(Request::new(stream)).await; - - assertions.push(test_assert!( - "call must be successful", - result.is_ok(), - format!("result={:?}", result) - )); - - if let Ok(response) = result { - let body = response.into_inner(); - - assertions.push(test_assert!( - "aggregated payload size must be 74922 bytes", - body.aggregated_payload_size == 74922, - format!("aggregated_payload_size={:?}", body.aggregated_payload_size) - )); - } -} - -pub async fn server_streaming(client: &mut TestClient, assertions: &mut Vec) { - let req = StreamingOutputCallRequest { - response_parameters: RESPONSE_LENGTHS - .iter() - .map(|len| ResponseParameters::with_size(*len)) - .collect(), - ..Default::default() - }; - let req = Request::new(req); - - let result = client.streaming_output_call(req).await; - - assertions.push(test_assert!( - "call must be successful", - result.is_ok(), - format!("result={:?}", result) - )); - - if let Ok(response) = result { - let responses = response - .into_inner() - .filter_map(|m| m.ok()) - .collect::>() - .await; - let actual_response_lengths = crate::response_lengths(&responses); - let asserts = vec![ - test_assert!( - "there should be four responses", - responses.len() == 4, - format!("responses.len()={:?}", responses.len()) - ), - test_assert!( - "the response payload sizes should match input", - RESPONSE_LENGTHS == actual_response_lengths.as_slice(), - format!("{:?}={:?}", RESPONSE_LENGTHS, actual_response_lengths) - ), - ]; - - assertions.extend(asserts); - } -} - -pub async fn ping_pong(client: &mut TestClient, assertions: &mut Vec) { - let (tx, rx) = mpsc::unbounded_channel(); - tx.send(make_ping_pong_request(0)).unwrap(); - - let result = client - .full_duplex_call(Request::new( - tokio_stream::wrappers::UnboundedReceiverStream::new(rx), - )) - .await; - - assertions.push(test_assert!( - "call must be successful", - result.is_ok(), - format!("result={:?}", result) - )); - - if let Ok(mut response) = result.map(Response::into_inner) { - let mut responses = Vec::new(); - - loop { - match response.next().await { - Some(result) => { - responses.push(result.unwrap()); - if responses.len() == REQUEST_LENGTHS.len() { - drop(tx); - break; - } else { - tx.send(make_ping_pong_request(responses.len())).unwrap(); - } - } - None => { - assertions.push(TestAssertion::Failed { - description: - "server should keep the stream open until the client closes it", - expression: "Stream terminated unexpectedly early", - why: None, - }); - break; - } - } - } - - let actual_response_lengths = crate::response_lengths(&responses); - assertions.push(test_assert!( - "there should be four responses", - responses.len() == RESPONSE_LENGTHS.len(), - format!("{:?}={:?}", responses.len(), RESPONSE_LENGTHS.len()) - )); - assertions.push(test_assert!( - "the response payload sizes should match input", - RESPONSE_LENGTHS == actual_response_lengths.as_slice(), - format!("{:?}={:?}", RESPONSE_LENGTHS, actual_response_lengths) - )); - } -} - -pub async fn empty_stream(client: &mut TestClient, assertions: &mut Vec) { - let stream = tokio_stream::empty(); - let result = client.full_duplex_call(Request::new(stream)).await; - - assertions.push(test_assert!( - "call must be successful", - result.is_ok(), - format!("result={:?}", result) - )); - - if let Ok(response) = result.map(Response::into_inner) { - let responses = response.collect::>().await; - - assertions.push(test_assert!( - "there should be no responses", - responses.is_empty(), - format!("responses.len()={:?}", responses.len()) - )); - } -} - -pub async fn status_code_and_message(client: &mut TestClient, assertions: &mut Vec) { - fn validate_response(result: Result, assertions: &mut Vec) - where - T: std::fmt::Debug, - { - assertions.push(test_assert!( - "call must fail with unknown status code", - match &result { - Err(status) => status.code() == Code::Unknown, - _ => false, - }, - format!("result={:?}", result) - )); - - assertions.push(test_assert!( - "call must respsond with expected status message", - match &result { - Err(status) => status.message() == TEST_STATUS_MESSAGE, - _ => false, - }, - format!("result={:?}", result) - )); - } - - let simple_req = SimpleRequest { - response_status: Some(EchoStatus { - code: 2, - message: TEST_STATUS_MESSAGE.to_string(), - }), - ..Default::default() - }; - - let duplex_req = StreamingOutputCallRequest { - response_status: Some(EchoStatus { - code: 2, - message: TEST_STATUS_MESSAGE.to_string(), - }), - ..Default::default() - }; - - let result = client.unary_call(Request::new(simple_req)).await; - validate_response(result, assertions); - - let stream = tokio_stream::once(duplex_req); - let result = match client.full_duplex_call(Request::new(stream)).await { - Ok(response) => { - let stream = response.into_inner(); - let responses = stream.collect::>().await; - Ok(responses) - } - Err(e) => Err(e), - }; - - validate_response(result, assertions); -} - -pub async fn special_status_message(client: &mut TestClient, assertions: &mut Vec) { - let req = SimpleRequest { - response_status: Some(EchoStatus { - code: 2, - message: SPECIAL_TEST_STATUS_MESSAGE.to_string(), - }), - ..Default::default() - }; - - let result = client.unary_call(Request::new(req)).await; - - assertions.push(test_assert!( - "call must fail with unknown status code", - match &result { - Err(status) => status.code() == Code::Unknown, - _ => false, - }, - format!("result={:?}", result) - )); - - assertions.push(test_assert!( - "call must respsond with expected status message", - match &result { - Err(status) => status.message() == SPECIAL_TEST_STATUS_MESSAGE, - _ => false, - }, - format!("result={:?}", result) - )); -} - -pub async fn unimplemented_method(client: &mut TestClient, assertions: &mut Vec) { - let result = client.unimplemented_call(Request::new(Empty {})).await; - assertions.push(test_assert!( - "call must fail with unimplemented status code", - match &result { - Err(status) => status.code() == Code::Unimplemented, - _ => false, - }, - format!("result={:?}", result) - )); -} - -pub async fn unimplemented_service( - client: &mut UnimplementedClient, - assertions: &mut Vec, -) { - let result = client.unimplemented_call(Request::new(Empty {})).await; - assertions.push(test_assert!( - "call must fail with unimplemented status code", - match &result { - Err(status) => status.code() == Code::Unimplemented, - _ => false, - }, - format!("result={:?}", result) - )); -} - -pub async fn custom_metadata(client: &mut TestClient, assertions: &mut Vec) { - let key1 = "x-grpc-test-echo-initial"; - let value1: MetadataValue<_> = "test_initial_metadata_value".parse().unwrap(); - let key2 = "x-grpc-test-echo-trailing-bin"; - let value2 = MetadataValue::from_bytes(&[0xab, 0xab, 0xab]); - - let req = SimpleRequest { - response_type: PayloadType::Compressable as i32, - response_size: LARGE_RSP_SIZE, - payload: Some(crate::client_payload(LARGE_REQ_SIZE)), - ..Default::default() - }; - let mut req_unary = Request::new(req); - req_unary.metadata_mut().insert(key1, value1.clone()); - req_unary.metadata_mut().insert_bin(key2, value2.clone()); - - let stream = tokio_stream::once(make_ping_pong_request(0)); - let mut req_stream = Request::new(stream); - req_stream.metadata_mut().insert(key1, value1.clone()); - req_stream.metadata_mut().insert_bin(key2, value2.clone()); - - let response = client - .unary_call(req_unary) - .await - .expect("call should pass."); - - assertions.push(test_assert!( - "metadata string must match in unary", - response.metadata().get(key1) == Some(&value1), - format!("result={:?}", response.metadata().get(key1)) - )); - assertions.push(test_assert!( - "metadata bin must match in unary", - response.metadata().get_bin(key2) == Some(&value2), - format!("result={:?}", response.metadata().get_bin(key1)) - )); + async fn ping_pong(&mut self, assertions: &mut Vec); - let response = client - .full_duplex_call(req_stream) - .await - .expect("call should pass."); + async fn empty_stream(&mut self, assertions: &mut Vec); - assertions.push(test_assert!( - "metadata string must match in unary", - response.metadata().get(key1) == Some(&value1), - format!("result={:?}", response.metadata().get(key1)) - )); + async fn status_code_and_message(&mut self, assertions: &mut Vec); - let mut stream = response.into_inner(); + async fn special_status_message(&mut self, assertions: &mut Vec); - let trailers = stream.trailers().await.unwrap().unwrap(); + async fn unimplemented_method(&mut self, assertions: &mut Vec); - assertions.push(test_assert!( - "metadata bin must match in unary", - trailers.get_bin(key2) == Some(&value2), - format!("result={:?}", trailers.get_bin(key1)) - )); + async fn custom_metadata(&mut self, assertions: &mut Vec); } -fn make_ping_pong_request(idx: usize) -> StreamingOutputCallRequest { - let req_len = REQUEST_LENGTHS[idx]; - let resp_len = RESPONSE_LENGTHS[idx]; - StreamingOutputCallRequest { - response_parameters: vec![ResponseParameters::with_size(resp_len)], - payload: Some(crate::client_payload(req_len as usize)), - ..Default::default() - } +#[async_trait] +pub trait InteropTestUnimplemented: Send { + async fn unimplemented_service(&mut self, assertions: &mut Vec); } diff --git a/interop/src/client_prost.rs b/interop/src/client_prost.rs new file mode 100644 index 000000000..50299b8a2 --- /dev/null +++ b/interop/src/client_prost.rs @@ -0,0 +1,419 @@ +use crate::client::{InteropTest, InteropTestUnimplemented}; +use crate::{ + pb::test_service_client::*, pb::unimplemented_service_client::*, pb::*, test_assert, + TestAssertion, +}; +use tokio::sync::mpsc; +use tokio_stream::StreamExt; +use tonic::async_trait; +use tonic::transport::Channel; +use tonic::{metadata::MetadataValue, Code, Request, Response, Status}; + +pub type TestClient = TestServiceClient; +pub type UnimplementedClient = UnimplementedServiceClient; + +const LARGE_REQ_SIZE: usize = 271_828; +const LARGE_RSP_SIZE: i32 = 314_159; +const REQUEST_LENGTHS: &[i32] = &[27182, 8, 1828, 45904]; +const RESPONSE_LENGTHS: &[i32] = &[31415, 9, 2653, 58979]; +const TEST_STATUS_MESSAGE: &str = "test status message"; +const SPECIAL_TEST_STATUS_MESSAGE: &str = + "\t\ntest with whitespace\r\nand Unicode BMP ☺ and non-BMP 😈\t\n"; + +#[async_trait] +impl InteropTest for TestClient { + async fn empty_unary(&mut self, assertions: &mut Vec) { + let result = self.empty_call(Request::new(Empty {})).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let body = response.into_inner(); + assertions.push(test_assert!( + "body must not be null", + body == Empty {}, + format!("body={:?}", body) + )); + } + } + + async fn large_unary(&mut self, assertions: &mut Vec) { + use std::mem; + let payload = crate::client_payload(LARGE_REQ_SIZE); + let req = SimpleRequest { + response_type: PayloadType::Compressable as i32, + response_size: LARGE_RSP_SIZE, + payload: Some(payload), + ..Default::default() + }; + + let result = self.unary_call(Request::new(req)).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let body = response.into_inner(); + let payload_len = body.payload.as_ref().map(|p| p.body.len()).unwrap_or(0); + + assertions.push(test_assert!( + "body must be 314159 bytes", + payload_len == LARGE_RSP_SIZE as usize, + format!("mem::size_of_val(&body)={:?}", mem::size_of_val(&body)) + )); + } + } + + // async fn cachable_unary(client: &mut Client, assertions: &mut Vec) { + // let payload = Payload { + // r#type: PayloadType::Compressable as i32, + // body: format!("{:?}", std::time::Instant::now()).into_bytes(), + // }; + // let req = SimpleRequest { + // response_type: PayloadType::Compressable as i32, + // payload: Some(payload), + // ..Default::default() + // }; + + // self. + // } + + async fn client_streaming(&mut self, assertions: &mut Vec) { + let requests: Vec<_> = REQUEST_LENGTHS + .iter() + .map(make_streaming_input_request) + .collect(); + + let stream = tokio_stream::iter(requests); + + let result = self.streaming_input_call(Request::new(stream)).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let body = response.into_inner(); + + assertions.push(test_assert!( + "aggregated payload size must be 74922 bytes", + body.aggregated_payload_size == 74922, + format!("aggregated_payload_size={:?}", body.aggregated_payload_size) + )); + } + } + + async fn server_streaming(&mut self, assertions: &mut Vec) { + let req = StreamingOutputCallRequest { + response_parameters: RESPONSE_LENGTHS + .iter() + .map(|len| ResponseParameters::with_size(*len)) + .collect(), + ..Default::default() + }; + let req = Request::new(req); + + let result = self.streaming_output_call(req).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let responses = response + .into_inner() + .filter_map(|m| m.ok()) + .collect::>() + .await; + let actual_response_lengths = crate::response_lengths(&responses); + let asserts = vec![ + test_assert!( + "there should be four responses", + responses.len() == 4, + format!("responses.len()={:?}", responses.len()) + ), + test_assert!( + "the response payload sizes should match input", + RESPONSE_LENGTHS == actual_response_lengths.as_slice(), + format!("{:?}={:?}", RESPONSE_LENGTHS, actual_response_lengths) + ), + ]; + + assertions.extend(asserts); + } + } + + async fn ping_pong(&mut self, assertions: &mut Vec) { + let (tx, rx) = mpsc::unbounded_channel(); + tx.send(make_ping_pong_request(0)).unwrap(); + + let result = self + .full_duplex_call(Request::new( + tokio_stream::wrappers::UnboundedReceiverStream::new(rx), + )) + .await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(mut response) = result.map(Response::into_inner) { + let mut responses = Vec::new(); + + loop { + match response.next().await { + Some(result) => { + responses.push(result.unwrap()); + if responses.len() == REQUEST_LENGTHS.len() { + drop(tx); + break; + } else { + tx.send(make_ping_pong_request(responses.len())).unwrap(); + } + } + None => { + assertions.push(TestAssertion::Failed { + description: + "server should keep the stream open until the client closes it", + expression: "Stream terminated unexpectedly early", + why: None, + }); + break; + } + } + } + + let actual_response_lengths = crate::response_lengths(&responses); + assertions.push(test_assert!( + "there should be four responses", + responses.len() == RESPONSE_LENGTHS.len(), + format!("{:?}={:?}", responses.len(), RESPONSE_LENGTHS.len()) + )); + assertions.push(test_assert!( + "the response payload sizes should match input", + RESPONSE_LENGTHS == actual_response_lengths.as_slice(), + format!("{:?}={:?}", RESPONSE_LENGTHS, actual_response_lengths) + )); + } + } + + async fn empty_stream(&mut self, assertions: &mut Vec) { + let stream = tokio_stream::empty(); + let result = self.full_duplex_call(Request::new(stream)).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result.map(Response::into_inner) { + let responses = response.collect::>().await; + + assertions.push(test_assert!( + "there should be no responses", + responses.is_empty(), + format!("responses.len()={:?}", responses.len()) + )); + } + } + + async fn status_code_and_message(&mut self, assertions: &mut Vec) { + fn validate_response(result: Result, assertions: &mut Vec) + where + T: std::fmt::Debug, + { + assertions.push(test_assert!( + "call must fail with unknown status code", + match &result { + Err(status) => status.code() == Code::Unknown, + _ => false, + }, + format!("result={:?}", result) + )); + + assertions.push(test_assert!( + "call must respsond with expected status message", + match &result { + Err(status) => status.message() == TEST_STATUS_MESSAGE, + _ => false, + }, + format!("result={:?}", result) + )); + } + + let simple_req = SimpleRequest { + response_status: Some(EchoStatus { + code: 2, + message: TEST_STATUS_MESSAGE.to_string(), + }), + ..Default::default() + }; + + let duplex_req = StreamingOutputCallRequest { + response_status: Some(EchoStatus { + code: 2, + message: TEST_STATUS_MESSAGE.to_string(), + }), + ..Default::default() + }; + + let result = self.unary_call(Request::new(simple_req)).await; + validate_response(result, assertions); + + let stream = tokio_stream::once(duplex_req); + let result = match self.full_duplex_call(Request::new(stream)).await { + Ok(response) => { + let stream = response.into_inner(); + let responses = stream.collect::>().await; + Ok(responses) + } + Err(e) => Err(e), + }; + + validate_response(result, assertions); + } + + async fn special_status_message(&mut self, assertions: &mut Vec) { + let req = SimpleRequest { + response_status: Some(EchoStatus { + code: 2, + message: SPECIAL_TEST_STATUS_MESSAGE.to_string(), + }), + ..Default::default() + }; + + let result = self.unary_call(Request::new(req)).await; + + assertions.push(test_assert!( + "call must fail with unknown status code", + match &result { + Err(status) => status.code() == Code::Unknown, + _ => false, + }, + format!("result={:?}", result) + )); + + assertions.push(test_assert!( + "call must respsond with expected status message", + match &result { + Err(status) => status.message() == SPECIAL_TEST_STATUS_MESSAGE, + _ => false, + }, + format!("result={:?}", result) + )); + } + + async fn unimplemented_method(&mut self, assertions: &mut Vec) { + let result = self.unimplemented_call(Request::new(Empty {})).await; + assertions.push(test_assert!( + "call must fail with unimplemented status code", + match &result { + Err(status) => status.code() == Code::Unimplemented, + _ => false, + }, + format!("result={:?}", result) + )); + } + + async fn custom_metadata(&mut self, assertions: &mut Vec) { + let key1 = "x-grpc-test-echo-initial"; + let value1: MetadataValue<_> = "test_initial_metadata_value".parse().unwrap(); + let key2 = "x-grpc-test-echo-trailing-bin"; + let value2 = MetadataValue::from_bytes(&[0xab, 0xab, 0xab]); + + let req = SimpleRequest { + response_type: PayloadType::Compressable as i32, + response_size: LARGE_RSP_SIZE, + payload: Some(crate::client_payload(LARGE_REQ_SIZE)), + ..Default::default() + }; + let mut req_unary = Request::new(req); + req_unary.metadata_mut().insert(key1, value1.clone()); + req_unary.metadata_mut().insert_bin(key2, value2.clone()); + + let stream = tokio_stream::once(make_ping_pong_request(0)); + let mut req_stream = Request::new(stream); + req_stream.metadata_mut().insert(key1, value1.clone()); + req_stream.metadata_mut().insert_bin(key2, value2.clone()); + + let response = self.unary_call(req_unary).await.expect("call should pass."); + + assertions.push(test_assert!( + "metadata string must match in unary", + response.metadata().get(key1) == Some(&value1), + format!("result={:?}", response.metadata().get(key1)) + )); + assertions.push(test_assert!( + "metadata bin must match in unary", + response.metadata().get_bin(key2) == Some(&value2), + format!("result={:?}", response.metadata().get_bin(key1)) + )); + + let response = self + .full_duplex_call(req_stream) + .await + .expect("call should pass."); + + assertions.push(test_assert!( + "metadata string must match in unary", + response.metadata().get(key1) == Some(&value1), + format!("result={:?}", response.metadata().get(key1)) + )); + + let mut stream = response.into_inner(); + + let trailers = stream.trailers().await.unwrap().unwrap(); + + assertions.push(test_assert!( + "metadata bin must match in unary", + trailers.get_bin(key2) == Some(&value2), + format!("result={:?}", trailers.get_bin(key1)) + )); + } +} + +#[async_trait] +impl InteropTestUnimplemented for UnimplementedClient { + async fn unimplemented_service(&mut self, assertions: &mut Vec) { + let result = self.unimplemented_call(Request::new(Empty {})).await; + assertions.push(test_assert!( + "call must fail with unimplemented status code", + match &result { + Err(status) => status.code() == Code::Unimplemented, + _ => false, + }, + format!("result={:?}", result) + )); + } +} + +fn make_ping_pong_request(idx: usize) -> StreamingOutputCallRequest { + let req_len = REQUEST_LENGTHS[idx]; + let resp_len = RESPONSE_LENGTHS[idx]; + StreamingOutputCallRequest { + response_parameters: vec![ResponseParameters::with_size(resp_len)], + payload: Some(crate::client_payload(req_len as usize)), + ..Default::default() + } +} + +fn make_streaming_input_request(len: &i32) -> StreamingInputCallRequest { + StreamingInputCallRequest { + payload: Some(crate::client_payload(*len as usize)), + ..Default::default() + } +} diff --git a/interop/src/client_protobuf.rs b/interop/src/client_protobuf.rs new file mode 100644 index 000000000..9210ec830 --- /dev/null +++ b/interop/src/client_protobuf.rs @@ -0,0 +1,429 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + */ + +use crate::client::{InteropTest, InteropTestUnimplemented}; +use crate::{ + grpc_pb::test_service_client::*, grpc_pb::unimplemented_service_client::*, grpc_pb::*, + test_assert, TestAssertion, +}; +use tokio::sync::mpsc; +use tokio_stream::StreamExt; +use tonic::async_trait; +use tonic::transport::Channel; +use tonic::{metadata::MetadataValue, Code, Request, Response, Status}; +use tonic_protobuf::protobuf::__internal::MatcherEq; +use tonic_protobuf::protobuf::proto; + +pub type TestClient = TestServiceClient; +pub type UnimplementedClient = UnimplementedServiceClient; + +const LARGE_REQ_SIZE: usize = 271_828; +const LARGE_RSP_SIZE: i32 = 314_159; +const REQUEST_LENGTHS: &[i32] = &[27182, 8, 1828, 45904]; +const RESPONSE_LENGTHS: &[i32] = &[31415, 9, 2653, 58979]; +const TEST_STATUS_MESSAGE: &str = "test status message"; +const SPECIAL_TEST_STATUS_MESSAGE: &str = + "\t\ntest with whitespace\r\nand Unicode BMP ☺ and non-BMP 😈\t\n"; + +#[async_trait] +impl InteropTest for TestClient { + async fn empty_unary(&mut self, assertions: &mut Vec) { + let result = self.empty_call(Request::new(Empty::default())).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let body = response.into_inner(); + assertions.push(test_assert!( + "body must not be null", + body.matches(&Empty::default()), + format!("body={:?}", body) + )); + } + } + + async fn large_unary(&mut self, assertions: &mut Vec) { + use std::mem; + let payload = crate::grpc_utils::client_payload(LARGE_REQ_SIZE); + let req = proto!(SimpleRequest { + response_type: PayloadType::Compressable, + response_size: LARGE_RSP_SIZE, + payload: payload, + }); + + let result = self.unary_call(Request::new(req)).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let body = response.into_inner(); + let payload_len = body.payload().body().len(); + + assertions.push(test_assert!( + "body must be 314159 bytes", + payload_len == LARGE_RSP_SIZE as usize, + format!("mem::size_of_val(&body)={:?}", mem::size_of_val(&body)) + )); + } + } + + async fn client_streaming(&mut self, assertions: &mut Vec) { + let requests: Vec<_> = REQUEST_LENGTHS + .iter() + .map(make_streaming_input_request) + .collect(); + + let stream = tokio_stream::iter(requests); + + let result = self.streaming_input_call(Request::new(stream)).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let body = response.into_inner(); + + assertions.push(test_assert!( + "aggregated payload size must be 74922 bytes", + body.aggregated_payload_size() == 74922, + format!( + "aggregated_payload_size={:?}", + body.aggregated_payload_size() + ) + )); + } + } + + async fn server_streaming(&mut self, assertions: &mut Vec) { + let req = proto!(StreamingOutputCallRequest { + response_parameters: RESPONSE_LENGTHS + .iter() + .map(|len| ResponseParameters::with_size(*len)), + }); + let req = Request::new(req); + + let result = self.streaming_output_call(req).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let responses = response + .into_inner() + .filter_map(|m| m.ok()) + .collect::>() + .await; + let actual_response_lengths = crate::grpc_utils::response_lengths(&responses); + let asserts = vec![ + test_assert!( + "there should be four responses", + responses.len() == 4, + format!("responses.len()={:?}", responses.len()) + ), + test_assert!( + "the response payload sizes should match input", + RESPONSE_LENGTHS == actual_response_lengths.as_slice(), + format!("{:?}={:?}", RESPONSE_LENGTHS, actual_response_lengths) + ), + ]; + + assertions.extend(asserts); + } + } + + async fn ping_pong(&mut self, assertions: &mut Vec) { + let (tx, rx) = mpsc::unbounded_channel(); + tx.send(make_ping_pong_request(0)).unwrap(); + + let result = self + .full_duplex_call(Request::new( + tokio_stream::wrappers::UnboundedReceiverStream::new(rx), + )) + .await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(mut response) = result.map(Response::into_inner) { + let mut responses = Vec::new(); + + loop { + match response.next().await { + Some(result) => { + responses.push(result.unwrap()); + if responses.len() == REQUEST_LENGTHS.len() { + drop(tx); + break; + } else { + tx.send(make_ping_pong_request(responses.len())).unwrap(); + } + } + None => { + assertions.push(TestAssertion::Failed { + description: + "server should keep the stream open until the client closes it", + expression: "Stream terminated unexpectedly early", + why: None, + }); + break; + } + } + } + + let actual_response_lengths = crate::grpc_utils::response_lengths(&responses); + assertions.push(test_assert!( + "there should be four responses", + responses.len() == RESPONSE_LENGTHS.len(), + format!("{:?}={:?}", responses.len(), RESPONSE_LENGTHS.len()) + )); + assertions.push(test_assert!( + "the response payload sizes should match input", + RESPONSE_LENGTHS == actual_response_lengths.as_slice(), + format!("{:?}={:?}", RESPONSE_LENGTHS, actual_response_lengths) + )); + } + } + + async fn empty_stream(&mut self, assertions: &mut Vec) { + let stream = tokio_stream::empty(); + let result = self.full_duplex_call(Request::new(stream)).await; + + assertions.push(test_assert!( + "call must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result.map(Response::into_inner) { + let responses = response.collect::>().await; + + assertions.push(test_assert!( + "there should be no responses", + responses.is_empty(), + format!("responses.len()={:?}", responses.len()) + )); + } + } + + async fn status_code_and_message(&mut self, assertions: &mut Vec) { + fn validate_response(result: Result, assertions: &mut Vec) + where + T: std::fmt::Debug, + { + assertions.push(test_assert!( + "call must fail with unknown status code", + match &result { + Err(status) => status.code() == Code::Unknown, + _ => false, + }, + format!("result={:?}", result) + )); + + assertions.push(test_assert!( + "call must respsond with expected status message", + match &result { + Err(status) => status.message() == TEST_STATUS_MESSAGE, + _ => false, + }, + format!("result={:?}", result) + )); + } + + let simple_req = proto!(SimpleRequest { + response_status: EchoStatus { + code: 2, + message: TEST_STATUS_MESSAGE.to_string(), + }, + }); + + let duplex_req = proto!(StreamingOutputCallRequest { + response_status: EchoStatus { + code: 2, + message: TEST_STATUS_MESSAGE.to_string(), + }, + }); + + let result = self.unary_call(Request::new(simple_req)).await; + validate_response(result, assertions); + + let stream = tokio_stream::once(duplex_req); + let result = match self.full_duplex_call(Request::new(stream)).await { + Ok(response) => { + let stream = response.into_inner(); + let responses = stream.collect::>().await; + Ok(responses) + } + Err(e) => Err(e), + }; + + validate_response(result, assertions); + } + + async fn special_status_message(&mut self, assertions: &mut Vec) { + let req = proto!(SimpleRequest { + response_status: EchoStatus { + code: 2, + message: SPECIAL_TEST_STATUS_MESSAGE.to_string(), + }, + }); + + let result = self.unary_call(Request::new(req)).await; + + assertions.push(test_assert!( + "call must fail with unknown status code", + match &result { + Err(status) => status.code() == Code::Unknown, + _ => false, + }, + format!("result={:?}", result) + )); + + assertions.push(test_assert!( + "call must respsond with expected status message", + match &result { + Err(status) => status.message() == SPECIAL_TEST_STATUS_MESSAGE, + _ => false, + }, + format!("result={:?}", result) + )); + } + + async fn unimplemented_method(&mut self, assertions: &mut Vec) { + let result = self + .unimplemented_call(Request::new(Empty::default())) + .await; + assertions.push(test_assert!( + "call must fail with unimplemented status code", + match &result { + Err(status) => status.code() == Code::Unimplemented, + _ => false, + }, + format!("result={:?}", result) + )); + } + + async fn custom_metadata(&mut self, assertions: &mut Vec) { + let key1 = "x-grpc-test-echo-initial"; + let value1: MetadataValue<_> = "test_initial_metadata_value".parse().unwrap(); + let key2 = "x-grpc-test-echo-trailing-bin"; + let value2 = MetadataValue::from_bytes(&[0xab, 0xab, 0xab]); + + let req = proto!(SimpleRequest { + response_type: PayloadType::Compressable, + response_size: LARGE_RSP_SIZE, + payload: crate::grpc_utils::client_payload(LARGE_REQ_SIZE), + }); + let mut req_unary = Request::new(req); + req_unary.metadata_mut().insert(key1, value1.clone()); + req_unary.metadata_mut().insert_bin(key2, value2.clone()); + + let stream = tokio_stream::once(make_ping_pong_request(0)); + let mut req_stream = Request::new(stream); + req_stream.metadata_mut().insert(key1, value1.clone()); + req_stream.metadata_mut().insert_bin(key2, value2.clone()); + + let response = self.unary_call(req_unary).await.expect("call should pass."); + + assertions.push(test_assert!( + "metadata string must match in unary", + response.metadata().get(key1) == Some(&value1), + format!("result={:?}", response.metadata().get(key1)) + )); + assertions.push(test_assert!( + "metadata bin must match in unary", + response.metadata().get_bin(key2) == Some(&value2), + format!("result={:?}", response.metadata().get_bin(key1)) + )); + + let response = self + .full_duplex_call(req_stream) + .await + .expect("call should pass."); + + assertions.push(test_assert!( + "metadata string must match in unary", + response.metadata().get(key1) == Some(&value1), + format!("result={:?}", response.metadata().get(key1)) + )); + + let mut stream = response.into_inner(); + + let trailers = stream.trailers().await.unwrap().unwrap(); + + assertions.push(test_assert!( + "metadata bin must match in unary", + trailers.get_bin(key2) == Some(&value2), + format!("result={:?}", trailers.get_bin(key1)) + )); + } +} + +#[async_trait] +impl InteropTestUnimplemented for UnimplementedClient { + async fn unimplemented_service(&mut self, assertions: &mut Vec) { + let result = self + .unimplemented_call(Request::new(Empty::default())) + .await; + assertions.push(test_assert!( + "call must fail with unimplemented status code", + match &result { + Err(status) => status.code() == Code::Unimplemented, + _ => false, + }, + format!("result={:?}", result) + )); + } +} + +fn make_ping_pong_request(idx: usize) -> StreamingOutputCallRequest { + let req_len = REQUEST_LENGTHS[idx]; + let resp_len = RESPONSE_LENGTHS[idx]; + proto!(StreamingOutputCallRequest { + response_parameters: std::iter::once(ResponseParameters::with_size(resp_len)), + payload: crate::grpc_utils::client_payload(req_len as usize), + }) +} + +fn make_streaming_input_request(len: &i32) -> StreamingInputCallRequest { + proto!(StreamingInputCallRequest { + payload: crate::grpc_utils::client_payload(*len as usize), + }) +} diff --git a/interop/src/lib.rs b/interop/src/lib.rs index 961e0fdf7..239512534 100644 --- a/interop/src/lib.rs +++ b/interop/src/lib.rs @@ -1,6 +1,8 @@ #![recursion_limit = "256"] pub mod client; +pub mod client_prost; +pub mod client_protobuf; pub mod server; pub mod pb { @@ -9,6 +11,10 @@ pub mod pb { include!(concat!(env!("OUT_DIR"), "/grpc.testing.rs")); } +pub mod grpc_pb { + grpc::include_proto!("test"); +} + use std::{default, fmt, iter}; pub fn trace_init() { @@ -49,6 +55,32 @@ fn response_lengths(responses: &[pb::StreamingOutputCallResponse]) -> Vec { responses.iter().map(&response_length).collect() } +mod grpc_utils { + use super::grpc_pb; + use protobuf::proto; + use std::iter; + + pub(crate) fn client_payload(size: usize) -> grpc_pb::Payload { + proto!(grpc_pb::Payload { + body: iter::repeat_n(0u8, size).collect::>(), + }) + } + + impl grpc_pb::ResponseParameters { + pub(crate) fn with_size(size: i32) -> Self { + proto!(grpc_pb::ResponseParameters { size: size }) + } + } + + pub(crate) fn response_length(response: &grpc_pb::StreamingOutputCallResponse) -> i32 { + response.payload().body().len() as i32 + } + + pub(crate) fn response_lengths(responses: &[grpc_pb::StreamingOutputCallResponse]) -> Vec { + responses.iter().map(&response_length).collect() + } +} + #[derive(Debug)] pub enum TestAssertion { Passed { diff --git a/interop/test.sh b/interop/test.sh index c4628d164..4e5ff5933 100755 --- a/interop/test.sh +++ b/interop/test.sh @@ -57,7 +57,10 @@ trap 'echo ":; killing test server"; kill ${SERVER_PID};' EXIT sleep 1 -./target/debug/client --test_case="${JOINED_TEST_CASES}" "${ARG}" +./target/debug/client --codec=prost --test_case="${JOINED_TEST_CASES}" "${ARG}" + +# Test a grpc rust client against a Go server. +./target/debug/client --codec=protobuf --test_case="${JOINED_TEST_CASES}" ${ARG} echo ":; killing test server"; kill "${SERVER_PID}"; @@ -72,7 +75,7 @@ trap 'echo ":; killing test server"; kill ${SERVER_PID};' EXIT sleep 1 -./target/debug/client --test_case="${JOINED_TEST_CASES}" "${ARG}" +./target/debug/client --codec=prost --test_case="${JOINED_TEST_CASES}" "${ARG}" # Run client test cases if [ -n "${ARG:-}" ]; then diff --git a/protoc-gen-rust-grpc/.bazelrc b/protoc-gen-rust-grpc/.bazelrc new file mode 100644 index 000000000..441f80f3f --- /dev/null +++ b/protoc-gen-rust-grpc/.bazelrc @@ -0,0 +1,13 @@ +# Define a custom config for common Unix-like flags +build:unix --cxxopt=-std=c++17 +build:unix --host_cxxopt=-std=c++17 + +# Inherit the common 'unix' flags for both macOS and Linux +build:macos --config=unix +build:linux --config=unix + +# Windows flags remain as they are +build:windows --cxxopt=/std:c++17 +build:windows --host_cxxopt=/std:c++17 +build:windows --define=protobuf_allow_msvc=true + diff --git a/protoc-gen-rust-grpc/.gitignore b/protoc-gen-rust-grpc/.gitignore new file mode 100644 index 000000000..6a0537992 --- /dev/null +++ b/protoc-gen-rust-grpc/.gitignore @@ -0,0 +1,7 @@ +# Bazel +bazel-bin +bazel-genfiles +bazel-out +bazel-protoc-gen-rust-grpc +bazel-testlogs +MODULE.bazel.lock diff --git a/protoc-gen-rust-grpc/MODULE.bazel b/protoc-gen-rust-grpc/MODULE.bazel new file mode 100644 index 000000000..a4ed1e2b8 --- /dev/null +++ b/protoc-gen-rust-grpc/MODULE.bazel @@ -0,0 +1,36 @@ +# Copyright 2025 gRPC authors. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +bazel_dep(name = "protobuf", repo_name = "com_google_protobuf", version = "31.1") + +# Hedron's Compile Commands Extractor for bazel +# This is used to generate a compile_commands.json file which can be used by +# LSP servers like clangd. +# https://github.com/hedronvision/bazel-compile-commands-extractor +bazel_dep(name = "hedron_compile_commands", dev_dependency = True) +git_override( + module_name = "hedron_compile_commands", + # Using a commit from a fork to workaround failures while using absl. + # TODO: replace with a commit on the official repo once the following PR is + # merged: + # https://github.com/hedronvision/bazel-compile-commands-extractor/pull/219 + remote = "https://github.com/mikael-s-persson/bazel-compile-commands-extractor", + commit = "f5fbd4cee671d8d908f37c83abaf70fba5928fc7" +) diff --git a/protoc-gen-rust-grpc/README.md b/protoc-gen-rust-grpc/README.md new file mode 100644 index 000000000..d7ddf785f --- /dev/null +++ b/protoc-gen-rust-grpc/README.md @@ -0,0 +1,70 @@ +## Build + +To build the Rust gRPC code generator plugin: + +```sh +bazel build //src:protoc-gen-rust-grpc +``` + + +## Usage Example + +**Note:** It's generally recommended to use `tonic_protobuf_build::CodeGen` +and/or `protobuf_codegen::CodeGen` instead of invoking `protoc` directly. Direct +usage of `protoc` and checking in the generated code can lead to stale output if +the `protobuf` dependencies are upgraded later. Using the codegen APIs ensures +consistency with your dependency versions and simplifies regeneration. + +```sh +# Build the plugin +bazel build //src:protoc-gen-rust-grpc + +# Set the plugin path +PLUGIN_PATH="$(pwd)/bazel-bin/src/protoc-gen-rust-grpc" + +# Run protoc with the Rust gRPC plugin +protoc \ + --plugin=protoc-gen-rust-grpc="$PLUGIN_PATH" \ + --rust_opt="experimental-codegen=enabled,kernel=upb" \ + --rust_out=./generated \ + --rust-grpc_out=./generated \ + routeguide.proto + +# Optionally, you can add the plugin to the PATH and omit the --plugin flag. +export PATH="$(pwd)/bazel-bin/src/:$PATH" +``` + +## Available Options + +These options are specific to the Rust gRPC plugin: + +* `message_module_path=PATH` (optional): Specifies the Rust path to the module +where Protobuf messages are defined. Use this when you plan to place the +generated message code in a different module than the service code. + + * Default: `self` + * Example: If your messages are in `crate::pb::routeguide`, use + `message_module_path=crate::pb::routeguide`. +* `crate_mapping=PATH` (optional): Specifies the path to a crate mapping file + generated by Bazel or another build system. You must pass the same mapping + file to `rust_opt` and `rust-grpc_opt`. The file contains: + ``` + \n + \n + \n + ... + \n + ``` + + +## Language Server Support + +To enable IDE features like code navigation and IntelliSense, generate +`compile_commands.json` using [Hedron Compile Commands](https://github.com/hedronvision/bazel-compile-commands-extractor): + +```sh +bazel run @hedron_compile_commands//:refresh_all +``` + +Then configure your C++ language server to use the generated +`compile_commands.json`. diff --git a/protoc-gen-rust-grpc/src/BUILD b/protoc-gen-rust-grpc/src/BUILD new file mode 100644 index 000000000..1a3553829 --- /dev/null +++ b/protoc-gen-rust-grpc/src/BUILD @@ -0,0 +1,32 @@ +# Copyright 2025 gRPC authors. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +cc_binary( + name = "protoc-gen-rust-grpc", + srcs = [ + "grpc_rust_plugin.cc", + "grpc_rust_generator.h", + "grpc_rust_generator.cc", + ], + visibility = ["//visibility:public"], + deps = [ + "@com_google_protobuf//:protoc_lib", + ], +) diff --git a/protoc-gen-rust-grpc/src/grpc_rust_generator.cc b/protoc-gen-rust-grpc/src/grpc_rust_generator.cc new file mode 100644 index 000000000..d1c0561ec --- /dev/null +++ b/protoc-gen-rust-grpc/src/grpc_rust_generator.cc @@ -0,0 +1,528 @@ +// Copyright 2025 gRPC authors. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. + +#include "src/grpc_rust_generator.h" + +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/compiler/rust/naming.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor.pb.h" + +namespace rust_grpc_generator { +namespace protobuf = google::protobuf; +namespace rust = protobuf::compiler::rust; + +using protobuf::Descriptor; +using protobuf::FileDescriptor; +using protobuf::MethodDescriptor; +using protobuf::ServiceDescriptor; +using protobuf::SourceLocation; +using protobuf::io::Printer; + +namespace { +template +std::string GrpcGetCommentsForDescriptor(const DescriptorType *descriptor) { + SourceLocation location; + if (descriptor->GetSourceLocation(&location)) { + return location.leading_comments.empty() ? location.trailing_comments + : location.leading_comments; + } + return ""; +} + +std::string RustModuleForContainingType(const GrpcOpts &opts, + const Descriptor *containing_type, + const FileDescriptor &file) { + std::vector modules; + // Innermost to outermost order. + const Descriptor *parent = containing_type; + while (parent != nullptr) { + modules.push_back(rust::RsSafeName(rust::CamelToSnakeCase(parent->name()))); + parent = parent->containing_type(); + } + + // Reverse the vector to get submodules in outer-to-inner order). + std::reverse(modules.begin(), modules.end()); + + // If there are any modules at all, push an empty string on the end so that + // we get the trailing :: + if (!modules.empty()) { + modules.push_back(""); + } + + std::string crate_relative = absl::StrJoin(modules, "::"); + + if (opts.IsFileInCurrentCrate(file)) { + return crate_relative; + } + std::string crate_name = + absl::StrCat("::", rust::RsSafeName(opts.GetCrateName(file.name()))); + + return absl::StrCat(crate_name, "::", crate_relative); +} + +std::string RsTypePathWithinMessageModule(const GrpcOpts &opts, + const Descriptor &msg) { + return absl::StrCat( + RustModuleForContainingType(opts, msg.containing_type(), *msg.file()), + rust::RsSafeName(msg.name())); +} + +std::string RsTypePath(const Descriptor &msg, const GrpcOpts &opts, int depth) { + std::string path_within_module = RsTypePathWithinMessageModule(opts, msg); + if (!opts.IsFileInCurrentCrate(*msg.file())) { + return path_within_module; + } + std::string path_to_message_module = opts.GetMessageModulePath() + "::"; + if (path_to_message_module == "self::") { + path_to_message_module = ""; + } + + // If the path to the message module is defined from the crate or global + // root, we don't need to add a prefix of "super::"s. + if (absl::StartsWith(path_to_message_module, "crate::") || + absl::StartsWith(path_to_message_module, "::")) { + depth = 0; + } + std::string prefix = ""; + for (int i = 0; i < depth; ++i) { + prefix += "super::"; + } + return prefix + path_to_message_module + std::string(path_within_module); +} + +absl::Status ReadFileToString(const absl::string_view name, std::string *output, + bool text_mode) { + char buffer[1024]; + FILE *file = fopen(name.data(), text_mode ? "rt" : "rb"); + if (file == nullptr) + return absl::NotFoundError("Could not open file"); + + while (true) { + size_t n = fread(buffer, 1, sizeof(buffer), file); + if (n <= 0) + break; + output->append(buffer, n); + } + + int error = ferror(file); + if (fclose(file) != 0) + return absl::InternalError("Failed to close file"); + if (error != 0) { + return absl::ErrnoToStatus(error, + absl::StrCat("Failed to read the file ", name, + ". Error code: ", error)); + } + return absl::OkStatus(); +} +} // namespace + +absl::StatusOr> +GetImportPathToCrateNameMap(const absl::string_view mapping_file_path) { + absl::flat_hash_map mapping; + std::string mapping_contents; + absl::Status status = + ReadFileToString(mapping_file_path, &mapping_contents, true); + if (!status.ok()) { + return status; + } + + std::vector lines = + absl::StrSplit(mapping_contents, '\n', absl::SkipEmpty()); + size_t len = lines.size(); + + size_t idx = 0; + while (idx < len) { + absl::string_view crate_name = lines[idx++]; + size_t files_cnt; + if (!absl::SimpleAtoi(lines[idx++], &files_cnt)) { + return absl::InvalidArgumentError( + "Couldn't parse number of import paths in mapping file"); + } + for (size_t i = 0; i < files_cnt; ++i) { + mapping.insert({std::string(lines[idx++]), std::string(crate_name)}); + } + } + return mapping; +} + +// Method generation abstraction. +// +// Each service contains a set of generic methods that will be used by codegen +// to generate abstraction implementations for the provided methods. +class Method { +public: + Method() = delete; + + explicit Method(const MethodDescriptor *method) : method_(method) {} + + // The name of the method in Rust style. + std::string Name() const { + return rust::RsSafeName(rust::CamelToSnakeCase(method_->name())); + }; + + // The fully-qualified name of the method, scope delimited by periods. + absl::string_view FullName() const { return method_->full_name(); } + + // The name of the method as it appears in the .proto file. + absl::string_view ProtoFieldName() const { return method_->name(); }; + + // Checks if the method is streamed by the client. + bool IsClientStreaming() const { return method_->client_streaming(); }; + + // Checks if the method is streamed by the server. + bool IsServerStreaming() const { return method_->server_streaming(); }; + + // Get comments about this method. + std::string Comment() const { return GrpcGetCommentsForDescriptor(method_); }; + + // Checks if the method is deprecated. Default is false. + bool IsDeprecated() const { return method_->options().deprecated(); } + + // Returns the Rust type name of request message. + std::string RequestName(const GrpcOpts &opts, int depth) const { + const Descriptor *input = method_->input_type(); + return RsTypePath(*input, opts, depth); + }; + + // Returns the Rust type name of response message. + std::string ResponseName(const GrpcOpts &opts, int depth) const { + const Descriptor *output = method_->output_type(); + return RsTypePath(*output, opts, depth); + }; + +private: + const MethodDescriptor *method_; +}; + +// Service generation abstraction. +// +// This class is an interface that can be implemented and consumed +// by client and server generators to allow any codegen module +// to generate service abstractions. +class Service { +public: + Service() = delete; + + explicit Service(const ServiceDescriptor *service) : service_(service) {} + + // The name of the service, not including its containing scope. + std::string Name() const { + return rust::RsSafeName(rust::SnakeToUpperCamelCase(service_->name())); + }; + + // The fully-qualified name of the service, scope delimited by periods. + absl::string_view FullName() const { return service_->full_name(); }; + + // Returns a list of Methods provided by the service. + std::vector Methods() const { + std::vector ret; + int methods_count = service_->method_count(); + ret.reserve(methods_count); + for (int i = 0; i < methods_count; ++i) { + ret.push_back(Method(service_->method(i))); + } + return ret; + }; + + // Get comments about this service. + virtual std::string Comment() const { + return GrpcGetCommentsForDescriptor(service_); + }; + +private: + const ServiceDescriptor *service_; +}; + +// Formats the full path for a method call. Returns the formatted method path +// (e.g., "/package.MyService/MyMethod") +static std::string FormatMethodPath(const Service &service, + const Method &method) { + return absl::StrFormat("/%s/%s", service.FullName(), method.ProtoFieldName()); +} + +static std::string SanitizeForRustDoc(absl::string_view raw_comment) { + // 1. Escape the escape character itself first. + std::string sanitized = absl::StrReplaceAll(raw_comment, {{"\\", "\\\\"}}); + + // 2. Escape Markdown and Rustdoc special characters. + sanitized = absl::StrReplaceAll(sanitized, { + {"`", "\\`"}, + {"*", "\\*"}, + {"_", "\\_"}, + {"[", "\\["}, + {"]", "\\]"}, + {"#", "\\#"}, + {"<", "\\<"}, + {">", "\\>"}, + }); + + return sanitized; +} + +static std::string ProtoCommentToRustDoc(absl::string_view proto_comment) { + std::string rust_doc; + std::vector lines = absl::StrSplit(proto_comment, '\n'); + for (const absl::string_view &line : lines) { + // Preserve empty lines. + if (line.empty()) { + rust_doc += ("///\n"); + } else { + rust_doc += absl::StrFormat("/// %s\n", SanitizeForRustDoc(line)); + } + } + return rust_doc; +} + +static void GenerateDeprecated(Printer &ctx) { ctx.Emit("#[deprecated]\n"); } + +namespace client { + +static void GenerateMethods(Printer &printer, const Service &service, + const GrpcOpts &opts) { + static const std::string unary_format = R"rs( + pub async fn $ident$( + &mut self, + request: impl tonic::IntoRequest<$request$>, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; + let codec = $codec_name$::default(); + let path = http::uri::PathAndQuery::from_static("$path$"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new("$service_name$", "$method_name$")); + self.inner.unary(req, path, codec).await + } + )rs"; + + static const std::string server_streaming_format = R"rs( + pub async fn $ident$( + &mut self, + request: impl tonic::IntoRequest<$request$>, + ) -> std::result::Result>, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; + let codec = $codec_name$::default(); + let path = http::uri::PathAndQuery::from_static("$path$"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new("$service_name$", "$method_name$")); + self.inner.server_streaming(req, path, codec).await + } + )rs"; + + static const std::string client_streaming_format = R"rs( + pub async fn $ident$( + &mut self, + request: impl tonic::IntoStreamingRequest + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; + let codec = $codec_name$::default(); + let path = http::uri::PathAndQuery::from_static("$path$"); + let mut req = request.into_streaming_request(); + req.extensions_mut().insert(GrpcMethod::new("$service_name$", "$method_name$")); + self.inner.client_streaming(req, path, codec).await + } + )rs"; + + static const std::string streaming_format = R"rs( + pub async fn $ident$( + &mut self, + request: impl tonic::IntoStreamingRequest + ) -> std::result::Result>, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; + let codec = $codec_name$::default(); + let path = http::uri::PathAndQuery::from_static("$path$"); + let mut req = request.into_streaming_request(); + req.extensions_mut().insert(GrpcMethod::new("$service_name$", "$method_name$")); + self.inner.streaming(req, path, codec).await + } + )rs"; + + const std::vector methods = service.Methods(); + for (const Method &method : methods) { + printer.Emit(ProtoCommentToRustDoc(method.Comment())); + if (method.IsDeprecated()) { + GenerateDeprecated(printer); + } + const std::string request_type = method.RequestName(opts, 1); + const std::string response_type = method.ResponseName(opts, 1); + { + auto vars = + printer.WithVars({{"codec_name", "tonic_protobuf::ProtoCodec"}, + {"ident", method.Name()}, + {"request", request_type}, + {"response", response_type}, + {"service_name", service.FullName()}, + {"path", FormatMethodPath(service, method)}, + {"method_name", method.ProtoFieldName()}}); + + if (!method.IsClientStreaming() && !method.IsServerStreaming()) { + printer.Emit(unary_format); + } else if (!method.IsClientStreaming() && method.IsServerStreaming()) { + printer.Emit(server_streaming_format); + } else if (method.IsClientStreaming() && !method.IsServerStreaming()) { + printer.Emit(client_streaming_format); + } else { + printer.Emit(streaming_format); + } + if (&method != &methods.back()) { + printer.Emit("\n"); + } + } + } +} + +static void GenerateClient(const Service &service, Printer &printer, + const GrpcOpts &opts) { + std::string service_ident = absl::StrFormat("%sClient", service.Name()); + std::string client_mod = + absl::StrFormat("%s_client", rust::CamelToSnakeCase(service.Name())); + printer.Emit( + { + {"client_mod", client_mod}, + {"service_ident", service_ident}, + {"service_doc", + [&] { printer.Emit(ProtoCommentToRustDoc(service.Comment())); }}, + {"methods", [&] { GenerateMethods(printer, service, opts); }}, + }, + R"rs( + /// Generated client implementations. + pub mod $client_mod$ { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + // will trigger if compression is disabled + clippy::let_unit_value, + )] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + + $service_doc$ + #[derive(Debug, Clone)] + pub struct $service_ident$ { + inner: tonic::client::Grpc, + } + + impl $service_ident$ + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + + 'static, ::Error: Into + + std::marker::Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + + pub fn with_interceptor(inner: T, interceptor: F) -> + $service_ident$> where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response<>::ResponseBody> + >, + >>::Error: + Into + std::marker::Send + std::marker::Sync, + { + $service_ident$::new(InterceptedService::new(inner, interceptor)) + } + + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) + -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: + CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> + Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> + Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + + $methods$ + } + })rs"); +} + +} // namespace client + +void GenerateService(protobuf::io::Printer &printer, + const ServiceDescriptor *service_desc, + const GrpcOpts &opts) { + client::GenerateClient(Service(service_desc), printer, opts); +} + +std::string GetRsGrpcFile(const protobuf::FileDescriptor &file) { + absl::string_view basename = absl::StripSuffix(file.name(), ".proto"); + return absl::StrCat(basename, "_grpc.pb.rs"); +} + +} // namespace rust_grpc_generator diff --git a/protoc-gen-rust-grpc/src/grpc_rust_generator.h b/protoc-gen-rust-grpc/src/grpc_rust_generator.h new file mode 100644 index 000000000..93787f7a8 --- /dev/null +++ b/protoc-gen-rust-grpc/src/grpc_rust_generator.h @@ -0,0 +1,102 @@ +// Copyright 2025 gRPC authors. +// +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. + +#ifndef PROTOC_GEN_RUST_GRPC_GRPC_RUST_GENERATOR_H_ +#define PROTOC_GEN_RUST_GRPC_GRPC_RUST_GENERATOR_H_ + +#include "absl/log/absl_log.h" +#include "google/protobuf/descriptor.h" + +namespace rust_grpc_generator { + +class GrpcOpts { +public: + void SetMessageModulePath(const std::string path) { + message_module_path_ = std::move(path); + } + + const std::string &GetMessageModulePath() const { + return message_module_path_; + } + + void SetImportPathToCrateName( + const absl::flat_hash_map mapping) { + import_path_to_crate_name_ = std::move(mapping); + } + + void SetFilesInCurrentCrate( + const std::vector files) { + files_in_current_crate_ = std::move(files); + } + + absl::string_view GetCrateName(absl::string_view import_path) const { + auto it = import_path_to_crate_name_.find(import_path); + if (it == import_path_to_crate_name_.end()) { + ABSL_LOG(ERROR) << "Path " << import_path + << " not found in crate mapping. Crate mapping contains " + << import_path_to_crate_name_.size() << " entries:"; + for (const auto &entry : import_path_to_crate_name_) { + ABSL_LOG(ERROR) << " " << entry.first << " : " << entry.second << "\n"; + } + ABSL_LOG(FATAL) << "Cannot continue with missing crate mapping."; + } + return it->second; + } + + bool IsFileInCurrentCrate(const google::protobuf::FileDescriptor &f) const { + return std::find(files_in_current_crate_.begin(), + files_in_current_crate_.end(), + &f) != files_in_current_crate_.end(); + } + +private: + // Path of the module containing the generated message code. Defaults to + // "self", i.e. the message code and service code are present in the same + // module. + std::string message_module_path_ = "self"; + absl::flat_hash_map import_path_to_crate_name_ = {}; + std::vector + files_in_current_crate_ = {}; +}; + +// Writes the generated service interface into the given ZeroCopyOutputStream +void GenerateService(google::protobuf::io::Printer &printer, + const google::protobuf::ServiceDescriptor *service, + const GrpcOpts &opts); + +std::string GetRsGrpcFile(const google::protobuf::FileDescriptor &file); + +// Returns a map from import path of a .proto file to the name of the crate +// covering that file. +// +// This function parses a .rust_crate_mapping file generated by a build system. +// The file contains: +// +// \n +// \n +// > +GetImportPathToCrateNameMap(const absl::string_view mapping_file_path); +} // namespace rust_grpc_generator + +#endif // PROTOC_GEN_RUST_GRPC_GRPC_RUST_GENERATOR_H_ diff --git a/protoc-gen-rust-grpc/src/grpc_rust_plugin.cc b/protoc-gen-rust-grpc/src/grpc_rust_plugin.cc new file mode 100644 index 000000000..5fb0a6921 --- /dev/null +++ b/protoc-gen-rust-grpc/src/grpc_rust_plugin.cc @@ -0,0 +1,98 @@ +// Copyright 2025 gRPC authors. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. + +#include + +#include "google/protobuf/compiler/code_generator.h" +#include "google/protobuf/compiler/plugin.h" +#include "google/protobuf/io/printer.h" + +#include "grpc_rust_generator.h" + +namespace protobuf = google::protobuf; + +class RustGrpcGenerator : public protobuf::compiler::CodeGenerator { +public: + // Protobuf 5.27 released edition 2023. +#if GOOGLE_PROTOBUF_VERSION >= 5027000 + uint64_t GetSupportedFeatures() const override { + return Feature::FEATURE_PROTO3_OPTIONAL | + Feature::FEATURE_SUPPORTS_EDITIONS; + } + protobuf::Edition GetMinimumEdition() const override { + return protobuf::Edition::EDITION_PROTO2; + } + protobuf::Edition GetMaximumEdition() const override { + return protobuf::Edition::EDITION_2023; + } +#else + uint64_t GetSupportedFeatures() const override { + return Feature::FEATURE_PROTO3_OPTIONAL; + } +#endif + + bool Generate(const protobuf::FileDescriptor *file, + const std::string ¶meter, + protobuf::compiler::GeneratorContext *context, + std::string *error) const override { + // Return early to avoid creating an empty output file. + if (file->service_count() == 0) { + return true; + } + std::vector> options; + protobuf::compiler::ParseGeneratorParameter(parameter, &options); + + rust_grpc_generator::GrpcOpts grpc_opts; + for (auto opt : options) { + if (opt.first == "message_module_path") { + grpc_opts.SetMessageModulePath(opt.second); + } else if (opt.first == "crate_mapping") { + absl::StatusOr> + crate_map = + rust_grpc_generator::GetImportPathToCrateNameMap(opt.second); + if (crate_map.ok()) { + grpc_opts.SetImportPathToCrateName(std::move(*crate_map)); + } else { + *error = std::string(crate_map.status().message()); + return false; + } + } + } + + std::vector files; + context->ListParsedFiles(&files); + grpc_opts.SetFilesInCurrentCrate(std::move(files)); + + auto outfile = absl::WrapUnique( + context->Open(rust_grpc_generator::GetRsGrpcFile(*file))); + protobuf::io::Printer printer(outfile.get()); + + for (int i = 0; i < file->service_count(); ++i) { + const protobuf::ServiceDescriptor *service = file->service(i); + rust_grpc_generator::GenerateService(printer, service, grpc_opts); + } + return true; + } +}; + +int main(int argc, char *argv[]) { + RustGrpcGenerator generator; + return protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/tonic-protobuf-build/Cargo.toml b/tonic-protobuf-build/Cargo.toml new file mode 100644 index 000000000..744da8d61 --- /dev/null +++ b/tonic-protobuf-build/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "tonic-protobuf-build" +version = "0.14.0" +edition = "2021" +authors = ["gRPC Authors"] +license = "MIT" +publish = false + +[dependencies] +prettyplease = "0.2.35" +protobuf-codegen = { version = "4.31.1-release" } +syn = "2.0.104" diff --git a/tonic-protobuf-build/README.md b/tonic-protobuf-build/README.md new file mode 100644 index 000000000..60ee13d77 --- /dev/null +++ b/tonic-protobuf-build/README.md @@ -0,0 +1,98 @@ +# tonic-protobuf-build + +Compiles proto files via protobuf rust and generates service stubs and proto +definitions for use with tonic. + +## Features + +Required dependencies + +```toml +[dependencies] +tonic = "" +protobuf = "" +tonic-protobuf = "" + +[build-dependencies] +tonic-protobuf-build = "" +``` + +You must ensure you have the following programs in your PATH: +1. protoc +1. protoc-gen-rust-grpc + +## Getting Started + +`tonic-protobuf-build` works by being included as a [`build.rs` file](https://doc.rust-lang.org/cargo/reference/build-scripts.html) at the root of the binary/library. + +You can rely on the defaults via + +```rust,no_run +fn main() -> Result<(), Box> { + tonic_protobuf_build::CodeGen::new() + .include("proto") + .inputs(["service.proto"]) + .compile()?; + Ok(()) +} +``` + +Or configure the generated code deeper via + +```rust,no_run +fn main() -> Result<(), Box> { + let dependency = tonic_protobuf_build::Dependency::builder() + .crate_name("external_protos".to_string()) + .proto_import_paths(vec![PathBuf::from("external/message.proto")]) + .proto_files(vec!["message.proto".to_string()]) + .build()?; + + tonic_protobuf_build::CodeGen::new() + .generate_message_code(false) + .inputs(["proto/helloworld/helloworld.proto"]) + .include("external") + .message_module_path("super::proto") + .dependencies(vec![dependency]) + //.out_dir("src/generated") // you can change the generated code's location + .compile()?; + Ok(()) +} +``` + +Then you can reference the generated Rust like this this in your code: +```rust,ignore +mod protos { + // Include message code. + include!(concat!(env!("OUT_DIR"), "proto/helloworld/generated.rs")); +} + +mod grpc { + // Include service code. + include!(concat!(env!("OUT_DIR"), "proto/helloworld/helloworld_grpc.pb.rs")); +} +``` + +If you don't modify the `message_module_path`, you can use the `include_proto` +macro to simplify the import code. +```rust,ignore +pub mod grpc_pb { + grpc::include_proto!("proto/helloworld", "helloworld"); +} +``` + +Or if you want to save the generated code in your own code base, +you can uncomment the line `.output_dir(...)` above, and in your lib file +config a mod like this: +```rust,ignore +pub mod generated { + pub mod helloworld { + pub mod proto { + include!("helloworld/generated.rs"); + } + + pub mod grpc { + include!("helloworld/test_grpc.pb.rs"); + } + } +} +``` diff --git a/tonic-protobuf-build/src/lib.rs b/tonic-protobuf-build/src/lib.rs new file mode 100644 index 000000000..b3a54b69f --- /dev/null +++ b/tonic-protobuf-build/src/lib.rs @@ -0,0 +1,316 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + */ + +use std::fs::{self, read_to_string}; +use std::io::Write; +use std::path::{Path, PathBuf}; + +use syn::parse_file; + +/// Details about a crate containing proto files with symbols referenced in +/// the file being compiled currently. +#[derive(Debug, Clone)] +pub struct Dependency { + crate_name: String, + proto_import_paths: Vec, + proto_files: Vec, +} + +impl Dependency { + pub fn builder() -> DependencyBuilder { + DependencyBuilder::default() + } +} + +#[derive(Default, Debug)] +pub struct DependencyBuilder { + crate_name: Option, + proto_import_paths: Vec, + proto_files: Vec, +} + +impl DependencyBuilder { + /// Name of the external crate. + pub fn crate_name(mut self, name: impl Into) -> Self { + self.crate_name = Some(name.into()); + self + } + + /// List of paths .proto files whose codegen is present in the crate. This + /// is used to re-run the build command if required. + pub fn proto_import_path(mut self, path: impl Into) -> Self { + self.proto_import_paths.push(path.into()); + self + } + + /// List of .proto file names whose codegen is present in the crate. + pub fn proto_import_paths(mut self, paths: Vec) -> Self { + self.proto_import_paths = paths; + self + } + + pub fn proto_file(mut self, file: impl Into) -> Self { + self.proto_files.push(file.into()); + self + } + + pub fn proto_files(mut self, files: Vec) -> Self { + self.proto_files = files; + self + } + + pub fn build(self) -> Result { + let crate_name = self.crate_name.ok_or("crate_name is required")?; + Ok(Dependency { + crate_name, + proto_import_paths: self.proto_import_paths, + proto_files: self.proto_files, + }) + } +} + +impl From<&Dependency> for protobuf_codegen::Dependency { + fn from(val: &Dependency) -> Self { + protobuf_codegen::Dependency { + crate_name: val.crate_name.clone(), + proto_import_paths: val.proto_import_paths.clone(), + // The following field is not used by protobuf codegen. + c_include_paths: Vec::new(), + proto_files: val.proto_files.clone(), + } + } +} + +/// Service generator builder. +#[derive(Debug, Clone)] +pub struct CodeGen { + inputs: Vec, + output_dir: PathBuf, + includes: Vec, + dependencies: Vec, + message_module_path: Option, + // Whether to generate message code, defaults to true. + generate_message_code: bool, + should_format_code: bool, +} + +impl CodeGen { + pub fn new() -> Self { + Self { + inputs: Vec::new(), + output_dir: PathBuf::from(std::env::var("OUT_DIR").unwrap()), + includes: Vec::new(), + dependencies: Vec::new(), + message_module_path: None, + generate_message_code: true, + should_format_code: true, + } + } + + /// Sets whether to generate the message code. This can be disabled if the + /// message code is being generated independently. + pub fn generate_message_code(&mut self, enable: bool) -> &mut Self { + self.generate_message_code = enable; + self + } + + /// Adds a proto file to compile. + pub fn input(&mut self, input: impl AsRef) -> &mut Self { + self.inputs.push(input.as_ref().to_owned()); + self + } + + /// Adds a proto file to compile. + pub fn inputs(&mut self, inputs: impl IntoIterator>) -> &mut Self { + self.inputs + .extend(inputs.into_iter().map(|input| input.as_ref().to_owned())); + self + } + + /// Enables or disables formatting of generated code. + pub fn should_format_code(&mut self, enable: bool) -> &mut Self { + self.should_format_code = enable; + self + } + + /// Sets the directory for the files generated by protoc. The generated code + /// will be present in a subdirectory corresponding to the path of the + /// proto file withing the included directories. + pub fn output_dir(&mut self, output_dir: impl AsRef) -> &mut Self { + self.output_dir = output_dir.as_ref().to_owned(); + self + } + + /// Add a directory for protoc to scan for .proto files. + pub fn include(&mut self, include: impl AsRef) -> &mut Self { + self.includes.push(include.as_ref().to_owned()); + self + } + + /// Add a directory for protoc to scan for .proto files. + pub fn includes(&mut self, includes: impl Iterator>) -> &mut Self { + self.includes.extend( + includes + .into_iter() + .map(|include| include.as_ref().to_owned()), + ); + self + } + + /// Adds a list of Rust crates along with the proto files whose generated + /// messages they contains. + pub fn dependencies(&mut self, deps: Vec) -> &mut Self { + self.dependencies.extend(deps); + self + } + + /// Sets path of the module containing the generated message code. This is + /// "self" by default, i.e. the service code expects the message structs to + /// be present in the same module. Set this if the message and service + /// codegen needs to live in separate modules. + pub fn message_module_path(&mut self, message_path: &str) -> &mut Self { + self.message_module_path = Some(message_path.to_string()); + self + } + + pub fn compile(&self) -> Result<(), String> { + // Generate the message code. + if self.generate_message_code { + protobuf_codegen::CodeGen::new() + .inputs(self.inputs.clone()) + .output_dir(self.output_dir.clone()) + .includes(self.includes.iter()) + .dependency(self.dependencies.iter().map(|d| d.into()).collect()) + .generate_and_compile() + .unwrap(); + } + let crate_mapping_path = if self.generate_message_code { + self.output_dir.join("crate_mapping.txt") + } else { + self.generate_crate_mapping_file() + }; + + // Generate the service code. + let mut cmd = std::process::Command::new("protoc"); + for input in &self.inputs { + cmd.arg(input); + } + if !self.output_dir.exists() { + // Attempt to make the directory if it doesn't exist + let _ = std::fs::create_dir(&self.output_dir); + } + + if !self.generate_message_code { + for include in &self.includes { + println!("cargo:rerun-if-changed={}", include.display()); + } + for dep in &self.dependencies { + for path in &dep.proto_import_paths { + println!("cargo:rerun-if-changed={}", path.display()); + } + } + } + + cmd.arg(format!("--rust-grpc_out={}", self.output_dir.display())); + cmd.arg(format!( + "--rust-grpc_opt=crate_mapping={}", + crate_mapping_path.display() + )); + if let Some(message_path) = &self.message_module_path { + cmd.arg(format!( + "--rust-grpc_opt=message_module_path={message_path}", + )); + } + + for include in &self.includes { + cmd.arg(format!("--proto_path={}", include.display())); + } + for dep in &self.dependencies { + for path in &dep.proto_import_paths { + cmd.arg(format!("--proto_path={}", path.display())); + } + } + + let output = cmd + .output() + .map_err(|e| format!("failed to run protoc: {e}"))?; + println!("{}", std::str::from_utf8(&output.stdout).unwrap()); + eprintln!("{}", std::str::from_utf8(&output.stderr).unwrap()); + assert!(output.status.success()); + + if self.should_format_code { + self.format_code(); + } + Ok(()) + } + + fn format_code(&self) { + let mut generated_file_paths = Vec::new(); + let output_dir = &self.output_dir; + if self.generate_message_code { + generated_file_paths.push(output_dir.join("generated.rs")); + } + for proto_path in &self.inputs { + let Some(stem) = proto_path.file_stem().and_then(|s| s.to_str()) else { + continue; + }; + generated_file_paths.push(output_dir.join(format!("{}_grpc.pb.rs", stem))); + if self.generate_message_code { + generated_file_paths.push(output_dir.join(format!("{}.u.pb.rs", stem))); + } + } + + for path in &generated_file_paths { + // The path may not exist if there are no services present in the + // proto file. + if path.exists() { + let src = read_to_string(path).expect("Failed to read generated file"); + let syntax = parse_file(&src).unwrap(); + let formatted = prettyplease::unparse(&syntax); + fs::write(path, formatted).unwrap(); + } + } + } + + fn generate_crate_mapping_file(&self) -> PathBuf { + let crate_mapping_path = self.output_dir.join("crate_mapping.txt"); + let mut file = fs::File::create(crate_mapping_path.clone()).unwrap(); + for dep in &self.dependencies { + file.write_all(format!("{}\n", dep.crate_name).as_bytes()) + .unwrap(); + file.write_all(format!("{}\n", dep.proto_files.len()).as_bytes()) + .unwrap(); + for f in &dep.proto_files { + file.write_all(format!("{f}\n").as_bytes()).unwrap(); + } + } + crate_mapping_path + } +} + +impl Default for CodeGen { + fn default() -> Self { + Self::new() + } +} diff --git a/tonic-protobuf/Cargo.toml b/tonic-protobuf/Cargo.toml new file mode 100644 index 000000000..c573f1062 --- /dev/null +++ b/tonic-protobuf/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "tonic-protobuf" +version = "0.14.0" +edition = "2021" +authors = ["gRPC Authors"] +license = "MIT" +publish = false + +[dependencies] +tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["codegen"] } +bytes = "1" +protobuf = { version = "4.31.1-release" } + +[package.metadata.cargo_check_external_types] +allowed_external_types = [ + "tonic::*", + "protobuf::codegen_traits::Message", +] diff --git a/tonic-protobuf/src/lib.rs b/tonic-protobuf/src/lib.rs new file mode 100644 index 000000000..f2e434d1b --- /dev/null +++ b/tonic-protobuf/src/lib.rs @@ -0,0 +1,125 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + */ + +use bytes::{Buf, BufMut}; +use std::marker::PhantomData; +use tonic::{ + codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder}, + Status, +}; + +pub use protobuf; +use protobuf::Message; + +/// A [`Codec`] that implements `application/grpc+proto` via the protobuf +/// library. +#[derive(Debug, Clone)] +pub struct ProtoCodec { + _pd: PhantomData<(T, U)>, +} + +impl Default for ProtoCodec { + fn default() -> Self { + Self { _pd: PhantomData } + } +} + +impl Codec for ProtoCodec +where + T: Message + Send + 'static, + U: Message + Default + Send + 'static, +{ + type Encode = T; + type Decode = U; + + type Encoder = ProtoEncoder; + type Decoder = ProtoDecoder; + + fn encoder(&mut self) -> Self::Encoder { + ProtoEncoder { _pd: PhantomData } + } + + fn decoder(&mut self) -> Self::Decoder { + ProtoDecoder { _pd: PhantomData } + } +} + +/// A [`Encoder`] that knows how to encode `T`. +#[derive(Debug, Clone, Default)] +pub struct ProtoEncoder { + _pd: PhantomData, +} + +impl ProtoEncoder { + /// Get a new encoder with explicit buffer settings + pub fn new() -> Self { + Self { _pd: PhantomData } + } +} + +impl Encoder for ProtoEncoder { + type Item = T; + type Error = Status; + + fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { + // The protobuf library doesn't support serializing into a user-provided + // buffer. Instead, it allocates its own buffer, resulting in an extra + // copy and allocation. + // TODO: #2345 - Find a way to avoid this extra copy. + let serialized = item.serialize().map_err(from_decode_error)?; + buf.put_slice(serialized.as_slice()); + Ok(()) + } +} + +/// A [`Decoder`] that knows how to decode `U`. +#[derive(Debug, Clone, Default)] +pub struct ProtoDecoder { + _pd: PhantomData, +} + +impl ProtoDecoder { + /// Get a new decoder. + pub fn new() -> Self { + Self { _pd: PhantomData } + } +} + +impl Decoder for ProtoDecoder { + type Item = U; + type Error = Status; + + fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result, Self::Error> { + let slice = buf.chunk(); + let item = U::parse(slice).map_err(from_decode_error)?; + buf.advance(slice.len()); + Ok(Some(item)) + } +} + +fn from_decode_error(error: impl std::error::Error) -> tonic::Status { + // Map Protobuf parse errors to an INTERNAL status code, as per + // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md + Status::internal(error.to_string()) +}