From d41b9a304e0eb6b8794087f7d0cd2d06151201fe Mon Sep 17 00:00:00 2001 From: hulxv Date: Sun, 1 Mar 2026 01:51:56 +0200 Subject: [PATCH 1/7] feat: implement mill-rpc framework --- Cargo.toml | 8 +- README.md | 16 + mill-rpc/Cargo.toml | 54 +++ mill-rpc/README.md | 194 +++++++++++ mill-rpc/examples/calculator_client.rs | 66 ++++ mill-rpc/examples/calculator_server.rs | 62 ++++ mill-rpc/examples/concurrent_clients.rs | 148 ++++++++ mill-rpc/examples/echo_client.rs | 60 ++++ mill-rpc/examples/echo_server.rs | 72 ++++ mill-rpc/examples/kv_client.rs | 107 ++++++ mill-rpc/examples/kv_server.rs | 97 ++++++ mill-rpc/examples/multi_service_client.rs | 102 ++++++ mill-rpc/examples/multi_service_server.rs | 150 +++++++++ mill-rpc/mill-rpc-core/Cargo.toml | 10 + mill-rpc/mill-rpc-core/src/codec.rs | 48 +++ mill-rpc/mill-rpc-core/src/context.rs | 32 ++ mill-rpc/mill-rpc-core/src/error.rs | 111 ++++++ mill-rpc/mill-rpc-core/src/lib.rs | 35 ++ mill-rpc/mill-rpc-core/src/protocol.rs | 391 ++++++++++++++++++++++ mill-rpc/mill-rpc-macros/Cargo.toml | 13 + mill-rpc/mill-rpc-macros/src/lib.rs | 310 +++++++++++++++++ mill-rpc/src/client.rs | 257 ++++++++++++++ mill-rpc/src/lib.rs | 32 ++ mill-rpc/src/prelude.rs | 8 + mill-rpc/src/server.rs | 265 +++++++++++++++ mill-rpc/tests/integration_test.rs | 104 ++++++ 26 files changed, 2751 insertions(+), 1 deletion(-) create mode 100644 mill-rpc/Cargo.toml create mode 100644 mill-rpc/README.md create mode 100644 mill-rpc/examples/calculator_client.rs create mode 100644 mill-rpc/examples/calculator_server.rs create mode 100644 mill-rpc/examples/concurrent_clients.rs create mode 100644 mill-rpc/examples/echo_client.rs create mode 100644 mill-rpc/examples/echo_server.rs create mode 100644 mill-rpc/examples/kv_client.rs create mode 100644 mill-rpc/examples/kv_server.rs create mode 100644 mill-rpc/examples/multi_service_client.rs create mode 100644 mill-rpc/examples/multi_service_server.rs create mode 100644 mill-rpc/mill-rpc-core/Cargo.toml create mode 100644 mill-rpc/mill-rpc-core/src/codec.rs create mode 100644 mill-rpc/mill-rpc-core/src/context.rs create mode 100644 mill-rpc/mill-rpc-core/src/error.rs create mode 100644 mill-rpc/mill-rpc-core/src/lib.rs create mode 100644 mill-rpc/mill-rpc-core/src/protocol.rs create mode 100644 mill-rpc/mill-rpc-macros/Cargo.toml create mode 100644 mill-rpc/mill-rpc-macros/src/lib.rs create mode 100644 mill-rpc/src/client.rs create mode 100644 mill-rpc/src/lib.rs create mode 100644 mill-rpc/src/prelude.rs create mode 100644 mill-rpc/src/server.rs create mode 100644 mill-rpc/tests/integration_test.rs diff --git a/Cargo.toml b/Cargo.toml index fa25026..9087d3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,11 @@ [workspace] -members = ["mill-io", "mill-net"] +members = [ + "mill-io", + "mill-net", + "mill-rpc", + "mill-rpc/mill-rpc-core", + "mill-rpc/mill-rpc-macros", +] resolver = "2" [workspace.package] diff --git a/README.md b/README.md index f93ebe8..a4403d4 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,18 @@ A lightweight, production-ready event loop library for Rust that provides effici - **Thread pool integration**: Configurable worker threads for handling I/O events - **Compute pool**: Dedicated priority-based thread pool for CPU-intensive tasks - **High-level networking**: High-level server/client components based on mill-io +- **RPC framework**: macro-driven RPC with type-safe clients and servers - **Object pooling**: Reduces allocation overhead for frequent operations - **Clean API**: Simple registration and handler interface +## Crates + +| Crate | Description | +| ------------------------ | -------------------------------------------- | +| [`mill-io`](./mill-io) | Core reactor-based event loop | +| [`mill-net`](./mill-net) | High-level TCP server/client networking | +| [`mill-rpc`](./mill-rpc) | Macro-driven RPC framework (server + client) | + ## Installation For the core event loop only: @@ -28,6 +37,13 @@ For high-level networking (includes mill-io as dependency): mill-net = "2.0.1" ``` +For the RPC framework (includes mill-io and mill-net): + +```toml +[dependencies] +mill-rpc = { path = "mill-rpc" } +``` + ## Architecture For detailed architectural documentation, see [Architecture Guide](./docs/Arch.md). diff --git a/mill-rpc/Cargo.toml b/mill-rpc/Cargo.toml new file mode 100644 index 0000000..abb2b8e --- /dev/null +++ b/mill-rpc/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "mill-rpc" +version = "0.1.0" +edition = "2021" +description = "An Axum-inspired RPC framework built on Mill-IO" + +[dependencies] +mill-rpc-core = { path = "mill-rpc-core" } +mill-rpc-macros = { path = "mill-rpc-macros" } +mill-io = { path = "../mill-io" } +mill-net = { path = "../mill-net" } +serde = { version = "1", features = ["derive"] } +bincode = "1" +log = "0.4" +mio = { version = "1", features = ["os-poll", "net"] } + +[dev-dependencies] +env_logger = "0.11" + +[[example]] +name = "calculator_server" +path = "examples/calculator_server.rs" + +[[example]] +name = "calculator_client" +path = "examples/calculator_client.rs" + +[[example]] +name = "echo_server" +path = "examples/echo_server.rs" + +[[example]] +name = "echo_client" +path = "examples/echo_client.rs" + +[[example]] +name = "kv_server" +path = "examples/kv_server.rs" + +[[example]] +name = "kv_client" +path = "examples/kv_client.rs" + +[[example]] +name = "multi_service_server" +path = "examples/multi_service_server.rs" + +[[example]] +name = "multi_service_client" +path = "examples/multi_service_client.rs" + +[[example]] +name = "concurrent_clients" +path = "examples/concurrent_clients.rs" diff --git a/mill-rpc/README.md b/mill-rpc/README.md new file mode 100644 index 0000000..e0db420 --- /dev/null +++ b/mill-rpc/README.md @@ -0,0 +1,194 @@ +# mill-rpc + +An RPC framework built on top of [`mill-io`](../mill-io) and [`mill-net`](../mill-net). Define services as Rust traits, get type-safe clients and servers for free - no async runtime required. + +## Features + +- **Zero async** - Handlers are plain synchronous functions, no `async/await` needed +- **Macro-driven** - `#[mill_rpc::service]` generates server traits, client structs, and dispatch logic from a single trait definition +- **Type-safe** - Compile-time checked request/response types and method signatures +- **Multi-service** - Host multiple services on a single server with automatic routing +- **Pluggable codecs** - Bincode by default, extensible to JSON, MessagePack, CBOR, etc. +- **Binary wire protocol** - Efficient framing with support for one-way calls, ping/pong, and request cancellation + +## Installation + +```toml +[dependencies] +mill-rpc = { path = "../mill-rpc" } +``` + +## Quick Start + +### 1. Define a service + +```rust +use mill_rpc::prelude::*; + +#[mill_rpc::service] +trait Calculator { + fn add(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; +} +``` + +This single trait generates: +- `CalculatorServer` - trait you implement on the server +- `CalculatorClient` - struct with typed RPC methods +- `CalculatorDispatcher` - wrapper that implements `ServiceDispatch` +- Per-method request/response types with serde derives +- `calculator_methods` module with method ID constants + +### 2. Implement the server + +```rust +struct MyCalculator; + +impl CalculatorServer for MyCalculator { + fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { + a + b + } + + fn multiply(&self, _ctx: &RpcContext, a: i64, b: i64) -> i64 { + a * b + } +} +``` + +### 3. Start the server + +```rust +use mill_io::EventLoop; +use std::sync::Arc; + +fn main() { + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let _server = RpcServer::builder() + .bind("127.0.0.1:9001".parse().unwrap()) + .service(CalculatorDispatcher(MyCalculator)) + .build(&event_loop) + .expect("Failed to start server"); + + event_loop.run().unwrap(); +} +``` + +### 4. Call from a client + +```rust +use mill_io::EventLoop; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +fn main() { + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + // Run event loop in background + let el = event_loop.clone(); + thread::spawn(move || el.run().unwrap()); + thread::sleep(Duration::from_millis(50)); + + let transport = mill_rpc::RpcClient::connect( + "127.0.0.1:9001".parse().unwrap(), + &event_loop, + Codec::bincode(), + ).unwrap(); + + let client = CalculatorClient::new(transport, Codec::bincode(), 0); + + let sum = client.add(10, 25).unwrap(); + println!("10 + 25 = {}", sum); // 35 + + let product = client.multiply(7, 8).unwrap(); + println!("7 * 8 = {}", product); // 56 + + event_loop.stop(); +} +``` + +## Multi-Service Server + +Register multiple services on a single port. Each service gets an auto-assigned service ID. + +```rust +#[mill_rpc::service] +trait MathService { + fn factorial(n: u64) -> u64; +} + +#[mill_rpc::service] +trait StringService { + fn reverse(s: String) -> String; +} + +// Server +let _server = RpcServer::builder() + .bind(addr) + .service(MathServiceDispatcher(MathImpl)) // service_id = 0 + .service(StringServiceDispatcher(StringImpl)) // service_id = 1 + .build(&event_loop)?; + +// Client - both share a single TCP connection +let math = MathServiceClient::new(transport.clone(), Codec::bincode(), 0); +let strings = StringServiceClient::new(transport, Codec::bincode(), 1); + +math.factorial(10)?; // 3628800 +strings.reverse("hello")?; // "olleh" +``` + +## Wire Protocol + +Mill-RPC uses a compact binary frame format: + +```text ++--------+--------+-------+--------+-----------+---------+ +| Magic | Version| Flags | MsgType| PayloadLen| Payload | +| 2B | 1B | 1B | 1B | 4B (LE) | N bytes | ++--------+--------+-------+--------+-----------+---------+ +``` + +Request payloads carry routing info: + +```text ++------------+-----------+-----------+---------+ +| RequestID | ServiceID | MethodID | Args | +| 8B (LE) | 2B (LE) | 2B (LE) | N bytes | ++------------+-----------+-----------+---------+ +``` + +**Message types:** Request, Response, Error, Ping, Pong, Cancel + +**Flags:** Compressed payload, One-way (fire-and-forget) + +## Examples + +Run any example pair (server first, then client): + +```bash +# Terminal 1 +cargo run --example calculator_server + +# Terminal 2 +cargo run --example calculator_client +``` + +you will find all examples [here](./examples/). + +## Error Handling + +Mill-RPC uses structured errors with gRPC-style status codes: + +| Code | Status | Description | +| ---- | ----------------- | --------------------------- | +| 0 | OK | Success | +| 2 | INVALID_ARGUMENT | Bad request parameters | +| 3 | NOT_FOUND | Service or method not found | +| 8 | INTERNAL | Server-side error | +| 9 | UNAVAILABLE | Connection failure | +| 10 | DEADLINE_EXCEEDED | Request timeout | + +## License + +Licensed under the Apache License, Version 2.0. See [LICENSE](../LICENSE) for details. diff --git a/mill-rpc/examples/calculator_client.rs b/mill-rpc/examples/calculator_client.rs new file mode 100644 index 0000000..74fe80b --- /dev/null +++ b/mill-rpc/examples/calculator_client.rs @@ -0,0 +1,66 @@ +//! Calculator RPC client. +//! +//! Run the server first: cargo run --example calculator_server +//! Then run: cargo run --example calculator_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +#[mill_rpc::service] +trait Calculator { + fn add(a: i32, b: i32) -> i32; + fn subtract(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; + fn divide(a: f64, b: f64) -> f64; + fn negate(x: i32) -> i32; +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + // Run event loop in background + let el = event_loop.clone(); + let handle = thread::spawn(move || { + el.run().unwrap(); + }); + + // Give event loop a moment to start + thread::sleep(Duration::from_millis(50)); + + let addr = "127.0.0.1:9001".parse().unwrap(); + let transport = mill_rpc::RpcClient::connect(addr, &event_loop, Codec::bincode()) + .expect("Failed to connect to calculator server"); + + let client = CalculatorClient::new(transport, Codec::bincode(), 0); + + println!("Connected to calculator server"); + + // Basic arithmetic + let sum = client.add(10, 25).unwrap(); + println!("10 + 25 = {}", sum); + + let diff = client.subtract(100, 42).unwrap(); + println!("100 - 42 = {}", diff); + + let product = client.multiply(7, 8).unwrap(); + println!("7 * 8 = {}", product); + + let quotient = client.divide(355.0, 113.0).unwrap(); + println!("355 / 113 = {:.6}", quotient); + + let neg = client.negate(42).unwrap(); + println!("negate(42) = {}", neg); + + // Division by zero + let nan = client.divide(1.0, 0.0).unwrap(); + println!("1.0 / 0.0 = {} (NaN: {})", nan, nan.is_nan()); + + println!("\nAll calculations completed!"); + + event_loop.stop(); + let _ = handle.join(); +} diff --git a/mill-rpc/examples/calculator_server.rs b/mill-rpc/examples/calculator_server.rs new file mode 100644 index 0000000..c08973d --- /dev/null +++ b/mill-rpc/examples/calculator_server.rs @@ -0,0 +1,62 @@ +//! Basic calculator RPC server. +//! +//! Run with: cargo run --example calculator_server +//! Then connect with: cargo run --example calculator_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; + +#[mill_rpc::service] +trait Calculator { + fn add(a: i32, b: i32) -> i32; + fn subtract(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; + fn divide(a: f64, b: f64) -> f64; + fn negate(x: i32) -> i32; +} + +struct CalculatorImpl; + +impl CalculatorServer for CalculatorImpl { + fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { + println!(" add({}, {}) = {}", a, b, a + b); + a + b + } + + fn subtract(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { + println!(" subtract({}, {}) = {}", a, b, a - b); + a - b + } + + fn multiply(&self, _ctx: &RpcContext, a: i64, b: i64) -> i64 { + println!(" multiply({}, {}) = {}", a, b, a * b); + a * b + } + + fn divide(&self, _ctx: &RpcContext, a: f64, b: f64) -> f64 { + let result = if b == 0.0 { f64::NAN } else { a / b }; + println!(" divide({}, {}) = {}", a, b, result); + result + } + + fn negate(&self, _ctx: &RpcContext, x: i32) -> i32 { + println!(" negate({}) = {}", x, -x); + -x + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let addr = "127.0.0.1:9001".parse().unwrap(); + let _server = RpcServer::builder() + .bind(addr) + .service(CalculatorDispatcher(CalculatorImpl)) + .build(&event_loop) + .expect("Failed to start calculator server"); + + println!("Calculator server listening on {}", addr); + event_loop.run().unwrap(); +} diff --git a/mill-rpc/examples/concurrent_clients.rs b/mill-rpc/examples/concurrent_clients.rs new file mode 100644 index 0000000..ef29126 --- /dev/null +++ b/mill-rpc/examples/concurrent_clients.rs @@ -0,0 +1,148 @@ +//! Concurrent clients stress test - multiple threads sending RPC calls simultaneously. +//! +//! This example starts an embedded server and spawns N client threads, +//! each making M requests. Demonstrates thread-safety and concurrent dispatch. +//! +//! Run with: cargo run --example concurrent_clients + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::{Duration, Instant}; + +#[mill_rpc::service] +trait Counter { + fn increment() -> u64; + fn get() -> u64; + fn add(n: u64) -> u64; +} + +struct AtomicCounter { + value: AtomicU64, +} + +impl AtomicCounter { + fn new() -> Self { + Self { + value: AtomicU64::new(0), + } + } +} + +impl CounterServer for AtomicCounter { + fn increment(&self, _ctx: &RpcContext) -> u64 { + self.value.fetch_add(1, Ordering::SeqCst) + 1 + } + + fn get(&self, _ctx: &RpcContext) -> u64 { + self.value.load(Ordering::SeqCst) + } + + fn add(&self, _ctx: &RpcContext, n: u64) -> u64 { + self.value.fetch_add(n, Ordering::SeqCst) + n + } +} + +fn main() { + env_logger::init(); + + let num_clients = 4; + let requests_per_client = 100; + + // --- Start embedded server --- + let server_el = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + let addr = "127.0.0.1:9005".parse().unwrap(); + + let _server = RpcServer::builder() + .bind(addr) + .service(CounterDispatcher(AtomicCounter::new())) + .build(&server_el) + .expect("Failed to start server"); + + let sel = server_el.clone(); + let server_thread = thread::spawn(move || { + sel.run().unwrap(); + }); + + thread::sleep(Duration::from_millis(100)); + println!( + "Server started. Spawning {} clients, {} requests each...\n", + num_clients, requests_per_client + ); + + // --- Spawn client threads --- + let start = Instant::now(); + let mut handles = Vec::new(); + + for client_id in 0..num_clients { + let handle = thread::spawn(move || { + let client_el = Arc::new(EventLoop::new(1, 256, 50).unwrap()); + + let cel = client_el.clone(); + let el_thread = thread::spawn(move || { + cel.run().unwrap(); + }); + + thread::sleep(Duration::from_millis(20)); + + let transport = + mill_rpc::RpcClient::connect(addr, &client_el, Codec::bincode()).unwrap(); + let counter = CounterClient::new(transport, Codec::bincode(), 0); + + let mut local_results = Vec::new(); + for _ in 0..requests_per_client { + let val = counter.increment().unwrap(); + local_results.push(val); + } + + println!( + " Client {} finished: first={}, last={}", + client_id, + local_results.first().unwrap(), + local_results.last().unwrap() + ); + + client_el.stop(); + let _ = el_thread.join(); + + local_results + }); + handles.push(handle); + } + + // --- Collect results --- + let mut all_values: Vec = Vec::new(); + for handle in handles { + let values = handle.join().unwrap(); + all_values.extend(values); + } + + let elapsed = start.elapsed(); + + // Every increment should have returned a unique value + all_values.sort(); + all_values.dedup(); + + let total_expected = (num_clients * requests_per_client) as usize; + println!("\n--- Results ---"); + println!("Total requests: {}", total_expected); + println!("Unique values: {}", all_values.len()); + println!("Time elapsed: {:?}", elapsed); + println!( + "Throughput: {:.0} req/s", + total_expected as f64 / elapsed.as_secs_f64() + ); + + assert_eq!( + all_values.len(), + total_expected, + "Expected all increments to produce unique values (no lost updates)" + ); + + println!("\nConcurrency test passed! No lost updates."); + + server_el.stop(); + let _ = server_thread.join(); +} diff --git a/mill-rpc/examples/echo_client.rs b/mill-rpc/examples/echo_client.rs new file mode 100644 index 0000000..f57c69d --- /dev/null +++ b/mill-rpc/examples/echo_client.rs @@ -0,0 +1,60 @@ +//! Echo RPC client. +//! +//! Run the server first: cargo run --example echo_server +//! Then run: cargo run --example echo_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +#[mill_rpc::service] +trait Echo { + fn echo(message: String) -> String; + fn echo_uppercase(message: String) -> String; + fn echo_repeat(message: String, times: u32) -> String; + fn request_count() -> u64; +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + let el = event_loop.clone(); + let handle = thread::spawn(move || { + el.run().unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let addr = "127.0.0.1:9002".parse().unwrap(); + let transport = mill_rpc::RpcClient::connect(addr, &event_loop, Codec::bincode()) + .expect("Failed to connect to echo server"); + + let client = EchoClient::new(transport, Codec::bincode(), 0); + + // Basic echo + let reply = client.echo("Hello, Mill-RPC!".into()).unwrap(); + println!("echo: {}", reply); + assert_eq!(reply, "Hello, Mill-RPC!"); + + // Uppercase + let reply = client.echo_uppercase("hello world".into()).unwrap(); + println!("uppercase: {}", reply); + assert_eq!(reply, "HELLO WORLD"); + + // Repeat + let reply = client.echo_repeat("ha".into(), 3).unwrap(); + println!("repeat: {}", reply); + assert_eq!(reply, "hahaha"); + + // Request count + let count = client.request_count().unwrap(); + println!("server handled {} requests", count); + assert_eq!(count, 3); // echo + uppercase + repeat + + println!("\nAll echo tests passed!"); + + event_loop.stop(); + let _ = handle.join(); +} diff --git a/mill-rpc/examples/echo_server.rs b/mill-rpc/examples/echo_server.rs new file mode 100644 index 0000000..d3e41ce --- /dev/null +++ b/mill-rpc/examples/echo_server.rs @@ -0,0 +1,72 @@ +//! Echo RPC server - returns whatever you send. +//! +//! Run with: cargo run --example echo_server +//! Then connect with: cargo run --example echo_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +#[mill_rpc::service] +trait Echo { + fn echo(message: String) -> String; + fn echo_uppercase(message: String) -> String; + fn echo_repeat(message: String, times: u32) -> String; + fn request_count() -> u64; +} + +struct EchoImpl { + counter: AtomicU64, +} + +impl EchoImpl { + fn new() -> Self { + Self { + counter: AtomicU64::new(0), + } + } +} + +impl EchoServer for EchoImpl { + fn echo(&self, _ctx: &RpcContext, message: String) -> String { + self.counter.fetch_add(1, Ordering::Relaxed); + println!(" echo: {:?}", message); + message + } + + fn echo_uppercase(&self, _ctx: &RpcContext, message: String) -> String { + self.counter.fetch_add(1, Ordering::Relaxed); + let upper = message.to_uppercase(); + println!(" echo_uppercase: {:?} -> {:?}", message, upper); + upper + } + + fn echo_repeat(&self, _ctx: &RpcContext, message: String, times: u32) -> String { + self.counter.fetch_add(1, Ordering::Relaxed); + let result = message.repeat(times as usize); + println!(" echo_repeat({:?}, {}) -> {:?}", message, times, result); + result + } + + fn request_count(&self, _ctx: &RpcContext) -> u64 { + let count = self.counter.load(Ordering::Relaxed); + println!(" request_count -> {}", count); + count + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let addr = "127.0.0.1:9002".parse().unwrap(); + let _server = RpcServer::builder() + .bind(addr) + .service(EchoDispatcher(EchoImpl::new())) + .build(&event_loop) + .expect("Failed to start echo server"); + + println!("Echo server listening on {}", addr); + event_loop.run().unwrap(); +} diff --git a/mill-rpc/examples/kv_client.rs b/mill-rpc/examples/kv_client.rs new file mode 100644 index 0000000..74c1940 --- /dev/null +++ b/mill-rpc/examples/kv_client.rs @@ -0,0 +1,107 @@ +//! Key-value store RPC client. +//! +//! Run the server first: cargo run --example kv_server +//! Then run: cargo run --example kv_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +#[mill_rpc::service] +trait KeyValue { + fn get(key: String) -> Option; + fn set(key: String, value: String) -> Option; + fn delete(key: String) -> bool; + fn keys() -> Vec; + fn len() -> u64; + fn clear() -> u64; +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + let el = event_loop.clone(); + let handle = thread::spawn(move || { + el.run().unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let addr = "127.0.0.1:9003".parse().unwrap(); + let transport = mill_rpc::RpcClient::connect(addr, &event_loop, Codec::bincode()) + .expect("Failed to connect to KV server"); + + let kv = KeyValueClient::new(transport, Codec::bincode(), 0); + + println!("=== Key-Value Store Client ===\n"); + + // Initially empty + let len = kv.len().unwrap(); + println!("Initial store size: {}", len); + assert_eq!(len, 0); + + // Set some keys + let old = kv.set("name".into(), "Alice".into()).unwrap(); + println!("SET name=Alice (old: {:?})", old); + assert!(old.is_none()); + + let old = kv.set("city".into(), "Berlin".into()).unwrap(); + println!("SET city=Berlin (old: {:?})", old); + + let old = kv.set("lang".into(), "Rust".into()).unwrap(); + println!("SET lang=Rust (old: {:?})", old); + + // Get values + let val = kv.get("name".into()).unwrap(); + println!("GET name -> {:?}", val); + assert_eq!(val, Some("Alice".to_string())); + + let val = kv.get("missing".into()).unwrap(); + println!("GET missing -> {:?}", val); + assert_eq!(val, None); + + // List keys + let mut keys = kv.keys().unwrap(); + keys.sort(); + println!("KEYS -> {:?}", keys); + assert_eq!(keys.len(), 3); + + // Overwrite + let old = kv.set("name".into(), "Bob".into()).unwrap(); + println!("SET name=Bob (old: {:?})", old); + assert_eq!(old, Some("Alice".to_string())); + + let val = kv.get("name".into()).unwrap(); + println!("GET name -> {:?}", val); + assert_eq!(val, Some("Bob".to_string())); + + // Delete + let existed = kv.delete("city".into()).unwrap(); + println!("DEL city -> existed: {}", existed); + assert!(existed); + + let existed = kv.delete("city".into()).unwrap(); + println!("DEL city -> existed: {}", existed); + assert!(!existed); + + // Final size + let len = kv.len().unwrap(); + println!("Store size: {}", len); + assert_eq!(len, 2); + + // Clear + let removed = kv.clear().unwrap(); + println!("CLEAR -> removed {} entries", removed); + assert_eq!(removed, 2); + + let len = kv.len().unwrap(); + println!("Store size after clear: {}", len); + assert_eq!(len, 0); + + println!("\nAll KV tests passed!"); + + event_loop.stop(); + let _ = handle.join(); +} diff --git a/mill-rpc/examples/kv_server.rs b/mill-rpc/examples/kv_server.rs new file mode 100644 index 0000000..d2d08d4 --- /dev/null +++ b/mill-rpc/examples/kv_server.rs @@ -0,0 +1,97 @@ +//! In-memory key-value store RPC server. +//! +//! Demonstrates shared mutable state (HashMap behind RwLock) accessed +//! concurrently from the thread pool. +//! +//! Run with: cargo run --example kv_server +//! Then connect with: cargo run --example kv_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +/// We use Option to represent "key not found" as None. +/// In a real app you'd use Result or a custom enum with #[derive(RpcError)]. +#[mill_rpc::service] +trait KeyValue { + fn get(key: String) -> Option; + fn set(key: String, value: String) -> Option; + fn delete(key: String) -> bool; + fn keys() -> Vec; + fn len() -> u64; + fn clear() -> u64; +} + +struct KvStore { + data: RwLock>, +} + +impl KvStore { + fn new() -> Self { + Self { + data: RwLock::new(HashMap::new()), + } + } +} + +impl KeyValueServer for KvStore { + fn get(&self, _ctx: &RpcContext, key: String) -> Option { + let data = self.data.read().unwrap(); + let result = data.get(&key).cloned(); + println!(" GET {:?} -> {:?}", key, result); + result + } + + fn set(&self, _ctx: &RpcContext, key: String, value: String) -> Option { + let mut data = self.data.write().unwrap(); + let old = data.insert(key.clone(), value.clone()); + println!(" SET {:?} = {:?} (old: {:?})", key, value, old); + old + } + + fn delete(&self, _ctx: &RpcContext, key: String) -> bool { + let mut data = self.data.write().unwrap(); + let existed = data.remove(&key).is_some(); + println!(" DEL {:?} -> existed: {}", key, existed); + existed + } + + fn keys(&self, _ctx: &RpcContext) -> Vec { + let data = self.data.read().unwrap(); + let keys: Vec = data.keys().cloned().collect(); + println!(" KEYS -> {:?}", keys); + keys + } + + fn len(&self, _ctx: &RpcContext) -> u64 { + let data = self.data.read().unwrap(); + let len = data.len() as u64; + println!(" LEN -> {}", len); + len + } + + fn clear(&self, _ctx: &RpcContext) -> u64 { + let mut data = self.data.write().unwrap(); + let count = data.len() as u64; + data.clear(); + println!(" CLEAR -> removed {} entries", count); + count + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let addr = "127.0.0.1:9003".parse().unwrap(); + let _server = RpcServer::builder() + .bind(addr) + .service(KeyValueDispatcher(KvStore::new())) + .build(&event_loop) + .expect("Failed to start KV server"); + + println!("Key-Value server listening on {}", addr); + println!("Supports: GET, SET, DEL, KEYS, LEN, CLEAR"); + event_loop.run().unwrap(); +} diff --git a/mill-rpc/examples/multi_service_client.rs b/mill-rpc/examples/multi_service_client.rs new file mode 100644 index 0000000..28dab68 --- /dev/null +++ b/mill-rpc/examples/multi_service_client.rs @@ -0,0 +1,102 @@ +//! Multi-service RPC client - calls two services on one server. +//! +//! Run the server first: cargo run --example multi_service_server +//! Then run: cargo run --example multi_service_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +// Service definitions must match the server +#[mill_rpc::service] +trait MathService { + fn factorial(n: u64) -> u64; + fn fibonacci(n: u32) -> u64; + fn is_prime(n: u64) -> bool; + fn gcd(a: u64, b: u64) -> u64; +} + +#[mill_rpc::service] +trait StringService { + fn reverse(s: String) -> String; + fn word_count(s: String) -> u32; + fn contains(haystack: String, needle: String) -> bool; + fn trim(s: String) -> String; + fn replace(s: String, from: String, to: String) -> String; +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + let el = event_loop.clone(); + let handle = thread::spawn(move || { + el.run().unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let addr = "127.0.0.1:9004".parse().unwrap(); + let transport = mill_rpc::RpcClient::connect(addr, &event_loop, Codec::bincode()) + .expect("Failed to connect"); + + // Both clients share the same transport (single TCP connection) + let math = MathServiceClient::new(transport.clone(), Codec::bincode(), 0); + let strings = StringServiceClient::new(transport, Codec::bincode(), 1); + + // ---- Math Service ---- + println!("=== Math Service ===\n"); + + let f = math.factorial(10).unwrap(); + println!("10! = {}", f); + assert_eq!(f, 3628800); + + let fib = math.fibonacci(20).unwrap(); + println!("fib(20) = {}", fib); + assert_eq!(fib, 6765); + + for n in [2, 7, 15, 17, 100] { + let prime = math.is_prime(n).unwrap(); + println!("is_prime({}) = {}", n, prime); + } + + let g = math.gcd(48, 18).unwrap(); + println!("gcd(48, 18) = {}", g); + assert_eq!(g, 6); + + // ---- String Service ---- + println!("\n=== String Service ===\n"); + + let rev = strings.reverse("Hello, World!".into()).unwrap(); + println!("reverse(\"Hello, World!\") = {:?}", rev); + assert_eq!(rev, "!dlroW ,olleH"); + + let wc = strings + .word_count("The quick brown fox jumps".into()) + .unwrap(); + println!("word_count(\"The quick brown fox jumps\") = {}", wc); + assert_eq!(wc, 5); + + let has = strings.contains("rustacean".into(), "rust".into()).unwrap(); + println!("contains(\"rustacean\", \"rust\") = {}", has); + assert!(has); + + let trimmed = strings.trim(" hello ".into()).unwrap(); + println!("trim(\" hello \") = {:?}", trimmed); + assert_eq!(trimmed, "hello"); + + let replaced = strings + .replace("foo bar foo baz".into(), "foo".into(), "qux".into()) + .unwrap(); + println!( + "replace(\"foo bar foo baz\", \"foo\", \"qux\") = {:?}", + replaced + ); + assert_eq!(replaced, "qux bar qux baz"); + + println!("\nAll multi-service tests passed!"); + + event_loop.stop(); + let _ = handle.join(); +} diff --git a/mill-rpc/examples/multi_service_server.rs b/mill-rpc/examples/multi_service_server.rs new file mode 100644 index 0000000..0b4109b --- /dev/null +++ b/mill-rpc/examples/multi_service_server.rs @@ -0,0 +1,150 @@ +//! Multi-service RPC server - hosts multiple services on a single port. +//! +//! Demonstrates service composition: a math service (service_id=0) and +//! a string utility service (service_id=1) on the same server. +//! +//! Run with: cargo run --example multi_service_server +//! Then connect with: cargo run --example multi_service_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; + +// ---------- Service 1: MathService ---------- + +#[mill_rpc::service] +trait MathService { + fn factorial(n: u64) -> u64; + fn fibonacci(n: u32) -> u64; + fn is_prime(n: u64) -> bool; + fn gcd(a: u64, b: u64) -> u64; +} + +struct MathImpl; + +impl MathServiceServer for MathImpl { + fn factorial(&self, _ctx: &RpcContext, n: u64) -> u64 { + let result = (1..=n).product(); + println!(" [math] factorial({}) = {}", n, result); + result + } + + fn fibonacci(&self, _ctx: &RpcContext, n: u32) -> u64 { + let result = match n { + 0 => 0, + 1 => 1, + _ => { + let (mut a, mut b) = (0u64, 1u64); + for _ in 2..=n { + let tmp = a + b; + a = b; + b = tmp; + } + b + } + }; + println!(" [math] fibonacci({}) = {}", n, result); + result + } + + fn is_prime(&self, _ctx: &RpcContext, n: u64) -> bool { + let result = if n < 2 { + false + } else if n < 4 { + true + } else if n % 2 == 0 { + false + } else { + let mut i = 3; + while i * i <= n { + if n % i == 0 { + return false; + } + i += 2; + } + true + }; + println!(" [math] is_prime({}) = {}", n, result); + result + } + + fn gcd(&self, _ctx: &RpcContext, mut a: u64, mut b: u64) -> u64 { + let (orig_a, orig_b) = (a, b); + while b != 0 { + let tmp = b; + b = a % b; + a = tmp; + } + println!(" [math] gcd({}, {}) = {}", orig_a, orig_b, a); + a + } +} + +// ---------- Service 2: StringService ---------- + +#[mill_rpc::service] +trait StringService { + fn reverse(s: String) -> String; + fn word_count(s: String) -> u32; + fn contains(haystack: String, needle: String) -> bool; + fn trim(s: String) -> String; + fn replace(s: String, from: String, to: String) -> String; +} + +struct StringImpl; + +impl StringServiceServer for StringImpl { + fn reverse(&self, _ctx: &RpcContext, s: String) -> String { + let result: String = s.chars().rev().collect(); + println!(" [str] reverse({:?}) = {:?}", s, result); + result + } + + fn word_count(&self, _ctx: &RpcContext, s: String) -> u32 { + let count = s.split_whitespace().count() as u32; + println!(" [str] word_count({:?}) = {}", s, count); + count + } + + fn contains(&self, _ctx: &RpcContext, haystack: String, needle: String) -> bool { + let result = haystack.contains(&needle); + println!( + " [str] contains({:?}, {:?}) = {}", + haystack, needle, result + ); + result + } + + fn trim(&self, _ctx: &RpcContext, s: String) -> String { + let result = s.trim().to_string(); + println!(" [str] trim({:?}) = {:?}", s, result); + result + } + + fn replace(&self, _ctx: &RpcContext, s: String, from: String, to: String) -> String { + let result = s.replace(&from, &to); + println!( + " [str] replace({:?}, {:?}, {:?}) = {:?}", + s, from, to, result + ); + result + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let addr = "127.0.0.1:9004".parse().unwrap(); + let _server = RpcServer::builder() + .bind(addr) + .service(MathServiceDispatcher(MathImpl)) // service_id = 0 + .service(StringServiceDispatcher(StringImpl)) // service_id = 1 + .build(&event_loop) + .expect("Failed to start multi-service server"); + + println!("Multi-service server listening on {}", addr); + println!(" Service 0: MathService (factorial, fibonacci, is_prime, gcd)"); + println!(" Service 1: StringService (reverse, word_count, contains, trim, replace)"); + event_loop.run().unwrap(); +} diff --git a/mill-rpc/mill-rpc-core/Cargo.toml b/mill-rpc/mill-rpc-core/Cargo.toml new file mode 100644 index 0000000..c00fc50 --- /dev/null +++ b/mill-rpc/mill-rpc-core/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "mill-rpc-core" +version = "0.1.0" +edition = "2021" +description = "Core types, wire protocol, and codec traits for Mill-RPC" + +[dependencies] +serde = { version = "1", features = ["derive"] } +bincode = "1" +thiserror = "2" diff --git a/mill-rpc/mill-rpc-core/src/codec.rs b/mill-rpc/mill-rpc-core/src/codec.rs new file mode 100644 index 0000000..54d3994 --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/codec.rs @@ -0,0 +1,48 @@ +use crate::error::RpcError; +use serde::{de::DeserializeOwned, Serialize}; + +/// Supported codec types. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CodecType { + Bincode, +} + +/// Codec for serializing/deserializing RPC payloads. +#[derive(Debug, Clone)] +pub struct Codec { + codec_type: CodecType, +} + +impl Codec { + pub fn bincode() -> Self { + Self { + codec_type: CodecType::Bincode, + } + } + + pub fn codec_type(&self) -> CodecType { + self.codec_type + } + + /// Serialize a value to bytes. + pub fn serialize(&self, value: &T) -> Result, RpcError> { + match self.codec_type { + CodecType::Bincode => bincode::serialize(value) + .map_err(|e| RpcError::codec_error(format!("serialize: {}", e))), + } + } + + /// Deserialize bytes to a value. + pub fn deserialize(&self, data: &[u8]) -> Result { + match self.codec_type { + CodecType::Bincode => bincode::deserialize(data) + .map_err(|e| RpcError::codec_error(format!("deserialize: {}", e))), + } + } +} + +impl Default for Codec { + fn default() -> Self { + Self::bincode() + } +} diff --git a/mill-rpc/mill-rpc-core/src/context.rs b/mill-rpc/mill-rpc-core/src/context.rs new file mode 100644 index 0000000..4aca239 --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/context.rs @@ -0,0 +1,32 @@ +use std::net::SocketAddr; + +/// Context available to RPC handlers during request processing. +/// +/// Provides metadata about the current request and the connection it arrived on. +#[derive(Debug, Clone)] +pub struct RpcContext { + /// Unique ID for this request. + pub request_id: u64, + /// Peer address of the client. + pub peer_addr: Option, + /// Service ID being called. + pub service_id: u16, + /// Method ID being called. + pub method_id: u16, +} + +impl RpcContext { + pub fn new(request_id: u64, service_id: u16, method_id: u16) -> Self { + Self { + request_id, + peer_addr: None, + service_id, + method_id, + } + } + + pub fn with_peer_addr(mut self, addr: SocketAddr) -> Self { + self.peer_addr = Some(addr); + self + } +} diff --git a/mill-rpc/mill-rpc-core/src/error.rs b/mill-rpc/mill-rpc-core/src/error.rs new file mode 100644 index 0000000..fb8815a --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/error.rs @@ -0,0 +1,111 @@ +use serde::{Deserialize, Serialize}; +use std::fmt; + +/// RPC status codes (inspired by gRPC). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[repr(u16)] +pub enum RpcStatus { + Ok = 0, + Cancelled = 1, + InvalidArgument = 2, + NotFound = 3, + AlreadyExists = 4, + PermissionDenied = 5, + Unauthenticated = 6, + ResourceExhausted = 7, + Internal = 8, + Unavailable = 9, + DeadlineExceeded = 10, +} + +impl fmt::Display for RpcStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RpcStatus::Ok => write!(f, "OK"), + RpcStatus::Cancelled => write!(f, "CANCELLED"), + RpcStatus::InvalidArgument => write!(f, "INVALID_ARGUMENT"), + RpcStatus::NotFound => write!(f, "NOT_FOUND"), + RpcStatus::AlreadyExists => write!(f, "ALREADY_EXISTS"), + RpcStatus::PermissionDenied => write!(f, "PERMISSION_DENIED"), + RpcStatus::Unauthenticated => write!(f, "UNAUTHENTICATED"), + RpcStatus::ResourceExhausted => write!(f, "RESOURCE_EXHAUSTED"), + RpcStatus::Internal => write!(f, "INTERNAL"), + RpcStatus::Unavailable => write!(f, "UNAVAILABLE"), + RpcStatus::DeadlineExceeded => write!(f, "DEADLINE_EXCEEDED"), + } + } +} + +/// Structured RPC error with status code and message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RpcError { + pub status: RpcStatus, + pub message: String, +} + +impl RpcError { + pub fn new(status: RpcStatus, message: impl Into) -> Self { + Self { + status, + message: message.into(), + } + } + + pub fn internal(message: impl Into) -> Self { + Self::new(RpcStatus::Internal, message) + } + + pub fn invalid_argument(message: impl Into) -> Self { + Self::new(RpcStatus::InvalidArgument, message) + } + + pub fn not_found(message: impl Into) -> Self { + Self::new(RpcStatus::NotFound, message) + } + + pub fn method_not_found(method_id: u16) -> Self { + Self::new( + RpcStatus::NotFound, + format!("Method not found: {}", method_id), + ) + } + + pub fn service_not_found(service_id: u16) -> Self { + Self::new( + RpcStatus::NotFound, + format!("Service not found: {}", service_id), + ) + } + + pub fn codec_error(message: impl Into) -> Self { + Self::new(RpcStatus::Internal, message) + } + + pub fn unavailable(message: impl Into) -> Self { + Self::new(RpcStatus::Unavailable, message) + } + + pub fn deadline_exceeded(message: impl Into) -> Self { + Self::new(RpcStatus::DeadlineExceeded, message) + } +} + +impl fmt::Display for RpcError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[{}] {}", self.status, self.message) + } +} + +impl std::error::Error for RpcError {} + +impl From for RpcError { + fn from(err: std::io::Error) -> Self { + RpcError::internal(err.to_string()) + } +} + +impl From for RpcError { + fn from(err: bincode::Error) -> Self { + RpcError::codec_error(format!("bincode: {}", err)) + } +} diff --git a/mill-rpc/mill-rpc-core/src/lib.rs b/mill-rpc/mill-rpc-core/src/lib.rs new file mode 100644 index 0000000..060abc1 --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/lib.rs @@ -0,0 +1,35 @@ +//! Core types for Mill-RPC: wire protocol, codecs, errors, and dispatcher traits. + +pub mod codec; +pub mod context; +pub mod error; +pub mod protocol; + +pub use codec::{Codec, CodecType}; +pub use context::RpcContext; +pub use error::{RpcError, RpcStatus}; +pub use protocol::{Flags, Frame, FrameHeader, MessageType}; + +/// Trait for dispatching RPC calls to handler methods. +/// +/// This is auto-implemented by the `#[mill_rpc::service]` macro for any type +/// that implements the generated `{Service}Server` trait. +pub trait ServiceDispatch: Send + Sync + 'static { + fn dispatch( + &self, + ctx: &RpcContext, + method_id: u16, + args: &[u8], + codec: &Codec, + ) -> Result, RpcError>; +} + +/// Trait for client-side RPC transport. +/// +/// Abstracts the mechanism of sending a request and receiving a response. +/// The main `mill-rpc` crate provides an implementation built on `mill-net`. +pub trait RpcTransport: Send + Sync + 'static { + /// Send a request and wait for a response. + /// Returns the raw response payload bytes. + fn call(&self, service_id: u16, method_id: u16, payload: Vec) -> Result, RpcError>; +} diff --git a/mill-rpc/mill-rpc-core/src/protocol.rs b/mill-rpc/mill-rpc-core/src/protocol.rs new file mode 100644 index 0000000..d40ce1d --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/protocol.rs @@ -0,0 +1,391 @@ +//! Wire protocol: frame format for Mill-RPC. +//! +//! ```text +//! +--------+--------+-------+--------+-----------+---------+ +//! | Magic | Version| Flags | MsgType| PayloadLen| Payload | +//! | 2B | 1B | 1B | 1B | 4B (LE) | N bytes | +//! +--------+--------+-------+--------+-----------+---------+ +//! ``` +//! +//! Request payload: +//! ```text +//! +------------+-----------+-----------+---------+ +//! | RequestID | ServiceID | MethodID | Args | +//! | 8B (LE) | 2B (LE) | 2B (LE) | N bytes | +//! +------------+-----------+-----------+---------+ +//! ``` + +use crate::error::RpcError; +use serde::{Deserialize, Serialize}; + +/// Magic bytes identifying Mill-RPC frames. +pub const MAGIC: [u8; 2] = [0x4D, 0x52]; // "MR" + +/// Current protocol version. +pub const VERSION: u8 = 1; + +/// Header size in bytes (magic:2 + version:1 + flags:1 + msg_type:1 + payload_len:4 = 9). +pub const HEADER_SIZE: usize = 9; + +/// Maximum payload size (16 MB). +pub const MAX_PAYLOAD_SIZE: u32 = 16 * 1024 * 1024; + +/// Message types in the wire protocol. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[repr(u8)] +pub enum MessageType { + Request = 0x01, + Response = 0x02, + Error = 0x03, + Ping = 0x04, + Pong = 0x05, + Cancel = 0x06, +} + +impl MessageType { + pub fn from_u8(v: u8) -> Result { + match v { + 0x01 => Ok(MessageType::Request), + 0x02 => Ok(MessageType::Response), + 0x03 => Ok(MessageType::Error), + 0x04 => Ok(MessageType::Ping), + 0x05 => Ok(MessageType::Pong), + 0x06 => Ok(MessageType::Cancel), + _ => Err(RpcError::invalid_argument(format!( + "Unknown message type: 0x{:02X}", + v + ))), + } + } +} + +/// Bit flags for frame options. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Flags(pub u8); + +impl Flags { + pub const NONE: Flags = Flags(0); + pub const COMPRESSED: Flags = Flags(1 << 0); + pub const ONE_WAY: Flags = Flags(1 << 1); + + pub fn is_one_way(self) -> bool { + self.0 & Self::ONE_WAY.0 != 0 + } + + pub fn is_compressed(self) -> bool { + self.0 & Self::COMPRESSED.0 != 0 + } +} + +/// Parsed frame header. +#[derive(Debug, Clone)] +pub struct FrameHeader { + pub version: u8, + pub flags: Flags, + pub message_type: MessageType, + pub payload_len: u32, +} + +impl FrameHeader { + /// Encode the header into a 9-byte array. + pub fn encode(&self) -> [u8; HEADER_SIZE] { + let mut buf = [0u8; HEADER_SIZE]; + buf[0] = MAGIC[0]; + buf[1] = MAGIC[1]; + buf[2] = self.version; + buf[3] = self.flags.0; + buf[4] = self.message_type as u8; + buf[5..9].copy_from_slice(&self.payload_len.to_le_bytes()); + buf + } + + /// Decode a header from a 9-byte slice. + pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result { + if buf[0] != MAGIC[0] || buf[1] != MAGIC[1] { + return Err(RpcError::invalid_argument(format!( + "Invalid magic: [{:#04X}, {:#04X}]", + buf[0], buf[1] + ))); + } + + let version = buf[2]; + if version != VERSION { + return Err(RpcError::invalid_argument(format!( + "Unsupported version: {}", + version + ))); + } + + let flags = Flags(buf[3]); + let message_type = MessageType::from_u8(buf[4])?; + let payload_len = u32::from_le_bytes([buf[5], buf[6], buf[7], buf[8]]); + + if payload_len > MAX_PAYLOAD_SIZE { + return Err(RpcError::invalid_argument(format!( + "Payload too large: {} > {}", + payload_len, MAX_PAYLOAD_SIZE + ))); + } + + Ok(Self { + version, + flags, + message_type, + payload_len, + }) + } +} + +/// A complete frame (header + payload). +#[derive(Debug, Clone)] +pub struct Frame { + pub header: FrameHeader, + pub payload: Vec, +} + +impl Frame { + /// Create a new request frame. + pub fn request( + request_id: u64, + service_id: u16, + method_id: u16, + args: Vec, + one_way: bool, + ) -> Self { + let mut payload = Vec::with_capacity(12 + args.len()); + payload.extend_from_slice(&request_id.to_le_bytes()); + payload.extend_from_slice(&service_id.to_le_bytes()); + payload.extend_from_slice(&method_id.to_le_bytes()); + payload.extend_from_slice(&args); + + let flags = if one_way { Flags::ONE_WAY } else { Flags::NONE }; + + Frame { + header: FrameHeader { + version: VERSION, + flags, + message_type: MessageType::Request, + payload_len: payload.len() as u32, + }, + payload, + } + } + + /// Create a response frame. + pub fn response(request_id: u64, data: Vec) -> Self { + let mut payload = Vec::with_capacity(8 + data.len()); + payload.extend_from_slice(&request_id.to_le_bytes()); + payload.extend_from_slice(&data); + + Frame { + header: FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Response, + payload_len: payload.len() as u32, + }, + payload, + } + } + + /// Create an error frame. + pub fn error(request_id: u64, error_data: Vec) -> Self { + let mut payload = Vec::with_capacity(8 + error_data.len()); + payload.extend_from_slice(&request_id.to_le_bytes()); + payload.extend_from_slice(&error_data); + + Frame { + header: FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Error, + payload_len: payload.len() as u32, + }, + payload, + } + } + + /// Create a ping frame. + pub fn ping() -> Self { + Frame { + header: FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Ping, + payload_len: 0, + }, + payload: Vec::new(), + } + } + + /// Create a pong frame. + pub fn pong() -> Self { + Frame { + header: FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Pong, + payload_len: 0, + }, + payload: Vec::new(), + } + } + + /// Encode the full frame to bytes. + pub fn encode(&self) -> Vec { + let header_bytes = self.header.encode(); + let mut buf = Vec::with_capacity(HEADER_SIZE + self.payload.len()); + buf.extend_from_slice(&header_bytes); + buf.extend_from_slice(&self.payload); + buf + } + + /// Parse request payload fields (request_id, service_id, method_id, args). + pub fn parse_request_payload(&self) -> Result<(u64, u16, u16, &[u8]), RpcError> { + if self.payload.len() < 12 { + return Err(RpcError::invalid_argument("Request payload too short")); + } + let request_id = u64::from_le_bytes(self.payload[0..8].try_into().unwrap()); + let service_id = u16::from_le_bytes(self.payload[8..10].try_into().unwrap()); + let method_id = u16::from_le_bytes(self.payload[10..12].try_into().unwrap()); + let args = &self.payload[12..]; + Ok((request_id, service_id, method_id, args)) + } + + /// Parse response payload fields (request_id, data). + pub fn parse_response_payload(&self) -> Result<(u64, &[u8]), RpcError> { + if self.payload.len() < 8 { + return Err(RpcError::invalid_argument("Response payload too short")); + } + let request_id = u64::from_le_bytes(self.payload[0..8].try_into().unwrap()); + let data = &self.payload[8..]; + Ok((request_id, data)) + } +} + +/// Reads frames from a byte buffer. Returns parsed frames and the number of bytes consumed. +/// +/// This is a streaming parser: it handles partial frames by returning only +/// complete frames and leaving remaining bytes unconsumed. +pub fn parse_frames(buf: &[u8]) -> Result<(Vec, usize), RpcError> { + let mut frames = Vec::new(); + let mut offset = 0; + + while offset + HEADER_SIZE <= buf.len() { + let header_bytes: &[u8; HEADER_SIZE] = buf[offset..offset + HEADER_SIZE] + .try_into() + .map_err(|_| RpcError::internal("Header slice conversion failed"))?; + + let header = FrameHeader::decode(header_bytes)?; + let total_frame_size = HEADER_SIZE + header.payload_len as usize; + + if offset + total_frame_size > buf.len() { + // Incomplete frame, wait for more data. + break; + } + + let payload = buf[offset + HEADER_SIZE..offset + total_frame_size].to_vec(); + frames.push(Frame { header, payload }); + offset += total_frame_size; + } + + Ok((frames, offset)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_header_roundtrip() { + let header = FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Request, + payload_len: 42, + }; + let encoded = header.encode(); + let decoded = FrameHeader::decode(&encoded).unwrap(); + assert_eq!(decoded.version, VERSION); + assert_eq!(decoded.flags, Flags::NONE); + assert_eq!(decoded.message_type, MessageType::Request); + assert_eq!(decoded.payload_len, 42); + } + + #[test] + fn test_request_frame_roundtrip() { + let frame = Frame::request(123, 1, 2, vec![10, 20, 30], false); + let bytes = frame.encode(); + let (frames, consumed) = parse_frames(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_eq!(frames.len(), 1); + + let (req_id, svc_id, method_id, args) = frames[0].parse_request_payload().unwrap(); + assert_eq!(req_id, 123); + assert_eq!(svc_id, 1); + assert_eq!(method_id, 2); + assert_eq!(args, &[10, 20, 30]); + } + + #[test] + fn test_response_frame_roundtrip() { + let frame = Frame::response(456, vec![1, 2, 3]); + let bytes = frame.encode(); + let (frames, _) = parse_frames(&bytes).unwrap(); + let (req_id, data) = frames[0].parse_response_payload().unwrap(); + assert_eq!(req_id, 456); + assert_eq!(data, &[1, 2, 3]); + } + + #[test] + fn test_multiple_frames() { + let f1 = Frame::request(1, 0, 0, vec![0xAA], false); + let f2 = Frame::response(1, vec![0xBB]); + let mut bytes = f1.encode(); + bytes.extend_from_slice(&f2.encode()); + + let (frames, consumed) = parse_frames(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_eq!(frames.len(), 2); + assert_eq!(frames[0].header.message_type, MessageType::Request); + assert_eq!(frames[1].header.message_type, MessageType::Response); + } + + #[test] + fn test_partial_frame() { + let frame = Frame::request(1, 0, 0, vec![0xAA; 100], false); + let bytes = frame.encode(); + // Only give half the bytes + let partial = &bytes[..bytes.len() / 2]; + let (frames, consumed) = parse_frames(partial).unwrap(); + assert_eq!(frames.len(), 0); + assert_eq!(consumed, 0); + } + + #[test] + fn test_ping_pong() { + let ping = Frame::ping(); + let pong = Frame::pong(); + assert_eq!(ping.header.message_type, MessageType::Ping); + assert_eq!(pong.header.message_type, MessageType::Pong); + assert_eq!(ping.payload.len(), 0); + assert_eq!(pong.payload.len(), 0); + } + + #[test] + fn test_invalid_magic() { + let mut buf = [0u8; HEADER_SIZE]; + buf[0] = 0xFF; + buf[1] = 0xFF; + assert!(FrameHeader::decode(&buf).is_err()); + } + + #[test] + fn test_one_way_flag() { + let frame = Frame::request(1, 0, 0, vec![], true); + assert!(frame.header.flags.is_one_way()); + + let frame = Frame::request(1, 0, 0, vec![], false); + assert!(!frame.header.flags.is_one_way()); + } +} diff --git a/mill-rpc/mill-rpc-macros/Cargo.toml b/mill-rpc/mill-rpc-macros/Cargo.toml new file mode 100644 index 0000000..2a05b5d --- /dev/null +++ b/mill-rpc/mill-rpc-macros/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "mill-rpc-macros" +version = "0.1.0" +edition = "2021" +description = "Proc macros for Mill-RPC service definitions" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2", features = ["full"] } +quote = "1" +proc-macro2 = "1" diff --git a/mill-rpc/mill-rpc-macros/src/lib.rs b/mill-rpc/mill-rpc-macros/src/lib.rs new file mode 100644 index 0000000..24e139f --- /dev/null +++ b/mill-rpc/mill-rpc-macros/src/lib.rs @@ -0,0 +1,310 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, FnArg, Ident, ItemTrait, Pat, ReturnType, TraitItem, Type}; + +/// Attribute macro for defining an RPC service. +/// +/// Applied to a trait, it generates: +/// - `{Name}Server` trait with `&self, ctx: &RpcContext` prepended to each method +/// - `{Name}Client` struct with typed RPC call methods +/// - `{Name}Dispatcher` wrapper that implements `ServiceDispatch` +/// - Per-method request/response types with serde derives +/// +/// # Example +/// +/// ```ignore +/// #[mill_rpc::service] +/// trait Calculator { +/// fn add(a: i32, b: i32) -> i32; +/// fn divide(a: f64, b: f64) -> Result; +/// } +/// ``` +#[proc_macro_attribute] +pub fn service(_attr: TokenStream, item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as ItemTrait); + match generate_service(input) { + Ok(tokens) => tokens.into(), + Err(err) => err.to_compile_error().into(), + } +} + +struct MethodInfo { + name: Ident, + method_id: u16, + args: Vec<(Ident, Type)>, + return_type: Type, +} + +fn generate_service(input: ItemTrait) -> syn::Result { + let trait_name = &input.ident; + let server_trait_name = format_ident!("{}Server", trait_name); + let client_struct_name = format_ident!("{}Client", trait_name); + let dispatcher_name = format_ident!("{}Dispatcher", trait_name); + let methods_mod_name = format_ident!("{}_methods", to_snake_case(&trait_name.to_string())); + + // Parse methods from the trait + let mut methods = Vec::new(); + for (idx, item) in input.items.iter().enumerate() { + let method = match item { + TraitItem::Fn(m) => m, + _ => continue, + }; + + let name = method.sig.ident.clone(); + + // Extract arguments (skip &self if present, though we don't expect it) + let mut args = Vec::new(); + for arg in &method.sig.inputs { + match arg { + FnArg::Typed(pat_type) => { + let pat = &*pat_type.pat; + let ident = match pat { + Pat::Ident(pi) => pi.ident.clone(), + _ => { + return Err(syn::Error::new_spanned( + pat, + "Expected a simple identifier pattern for argument", + )) + } + }; + let ty = (*pat_type.ty).clone(); + args.push((ident, ty)); + } + FnArg::Receiver(_) => { + return Err(syn::Error::new_spanned( + arg, + "#[mill_rpc::service] trait methods should not have `self` parameter", + )); + } + } + } + + // Extract return type + let return_type = match &method.sig.output { + ReturnType::Default => syn::parse_quote!(()), + ReturnType::Type(_, ty) => (**ty).clone(), + }; + + methods.push(MethodInfo { + name, + method_id: idx as u16, + args, + return_type, + }); + } + + // Generate method ID constants + let method_consts: Vec<_> = methods + .iter() + .map(|m| { + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let id = m.method_id; + quote! { pub const #const_name: u16 = #id; } + }) + .collect(); + + // Generate per-method request/response structs + let mut type_defs = Vec::new(); + for m in &methods { + let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); + let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + let ret_ty = &m.return_type; + + let field_names: Vec<_> = m.args.iter().map(|(name, _)| name).collect(); + let field_types: Vec<_> = m.args.iter().map(|(_, ty)| ty).collect(); + + let req_struct = if m.args.is_empty() { + quote! { + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub struct #req_name; + } + } else { + quote! { + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub struct #req_name { + #( pub #field_names: #field_types, )* + } + } + }; + + type_defs.push(quote! { + #req_struct + + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub struct #resp_name(pub #ret_ty); + }); + } + + // Generate Server trait + let server_methods: Vec<_> = methods + .iter() + .map(|m| { + let name = &m.name; + let ret_ty = &m.return_type; + let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); + + quote! { + fn #name(&self, ctx: &::mill_rpc_core::RpcContext, #( #arg_names: #arg_types ),*) -> #ret_ty; + } + }) + .collect(); + + // Generate dispatcher match arms + let dispatch_arms: Vec<_> = methods + .iter() + .map(|m| { + let name = &m.name; + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); + let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + + let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + + let call_args = if m.args.is_empty() { + quote! {} + } else { + let args: Vec<_> = field_names.iter().map(|n| quote! { req.#n }).collect(); + quote! { , #( #args ),* } + }; + + quote! { + #methods_mod_name::#const_name => { + let req: #req_name = codec.deserialize(args)?; + let result = self.0.#name(ctx #call_args); + codec.serialize(&#resp_name(result)) + } + } + }) + .collect(); + + // Generate client methods + let client_methods: Vec<_> = methods + .iter() + .map(|m| { + let name = &m.name; + let ret_ty = &m.return_type; + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); + let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + + let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); + + let req_construct = if m.args.is_empty() { + quote! { #req_name } + } else { + quote! { #req_name { #( #arg_names: #arg_names, )* } } + }; + + quote! { + pub fn #name(&self, #( #arg_names: #arg_types ),*) -> Result<#ret_ty, ::mill_rpc_core::RpcError> { + let req = #req_construct; + let payload = self.codec.serialize(&req)?; + let resp_bytes = self.transport.call( + self.service_id, + #methods_mod_name::#const_name, + payload, + )?; + let resp: #resp_name = self.codec.deserialize(&resp_bytes)?; + Ok(resp.0) + } + } + }) + .collect(); + + // Count methods for service registration + let method_count = methods.len() as u16; + let service_name_str = trait_name.to_string(); + + let output = quote! { + /// Method ID constants for this service. + pub mod #methods_mod_name { + #( #method_consts )* + } + + #( #type_defs )* + + /// Server trait - implement this to handle RPC calls. + pub trait #server_trait_name: Send + Sync + 'static { + #( #server_methods )* + } + + /// Service descriptor for registration. + pub struct #trait_name; + + impl #trait_name { + pub const SERVICE_NAME: &'static str = #service_name_str; + pub const METHOD_COUNT: u16 = #method_count; + } + + /// Wrapper that adapts a `{Name}Server` impl into a `ServiceDispatch`. + /// + /// This avoids orphan-rule violations by being a local concrete type. + pub struct #dispatcher_name(pub T); + + impl ::mill_rpc_core::ServiceDispatch for #dispatcher_name { + fn dispatch( + &self, + ctx: &::mill_rpc_core::RpcContext, + method_id: u16, + args: &[u8], + codec: &::mill_rpc_core::Codec, + ) -> Result, ::mill_rpc_core::RpcError> { + match method_id { + #( #dispatch_arms, )* + _ => Err(::mill_rpc_core::RpcError::method_not_found(method_id)), + } + } + } + + /// Generated client for calling this service remotely. + pub struct #client_struct_name { + transport: ::std::sync::Arc, + codec: ::mill_rpc_core::Codec, + service_id: u16, + } + + impl #client_struct_name { + /// Create a new client from a transport, codec, and service ID. + pub fn new( + transport: ::std::sync::Arc, + codec: ::mill_rpc_core::Codec, + service_id: u16, + ) -> Self { + Self { transport, codec, service_id } + } + + #( #client_methods )* + } + }; + + Ok(output) +} + +fn to_snake_case(s: &str) -> String { + let mut result = String::new(); + for (i, ch) in s.chars().enumerate() { + if ch.is_uppercase() { + if i > 0 { + result.push('_'); + } + result.push(ch.to_lowercase().next().unwrap()); + } else { + result.push(ch); + } + } + result +} + +fn to_pascal_case(s: &str) -> String { + s.split('_') + .map(|part| { + let mut chars = part.chars(); + match chars.next() { + None => String::new(), + Some(c) => c.to_uppercase().to_string() + chars.as_str(), + } + }) + .collect() +} diff --git a/mill-rpc/src/client.rs b/mill-rpc/src/client.rs new file mode 100644 index 0000000..f13d43a --- /dev/null +++ b/mill-rpc/src/client.rs @@ -0,0 +1,257 @@ +//! RPC Client built on mill-net's TcpClient. +//! +//! Connects to an RPC server, sends request frames, and waits for responses. + +use crate::{Codec, RpcError, RpcStatus, RpcTransport}; +use mill_io::EventLoop; +use mill_net::tcp::traits::{ConnectionId, NetworkHandler}; +use mill_net::tcp::{ServerContext, TcpClient}; +use mill_rpc_core::protocol::{self, Frame, MessageType, HEADER_SIZE}; +use mio::Token; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Condvar, Mutex}; +use std::time::Duration; + +/// Pending request waiting for a response. +struct PendingRequest { + completed: bool, + result: Option, RpcError>>, +} + +/// Shared state between the client and its network handler. +struct ClientShared { + /// Map of request_id -> pending request. + pending: Mutex>, + /// Condvar to notify waiting callers when a response arrives. + notify: Condvar, + /// Receive buffer for partial frame parsing. + recv_buf: Mutex>, +} + +/// RPC client that connects to a Mill-RPC server. +pub struct RpcClient { + tcp_client: Mutex>, + shared: Arc, + next_request_id: AtomicU64, + codec: Codec, + timeout: Duration, +} + +impl RpcClient { + /// Connect to an RPC server. + pub fn connect( + addr: SocketAddr, + event_loop: &Arc, + codec: Codec, + ) -> Result, RpcError> { + let shared = Arc::new(ClientShared { + pending: Mutex::new(HashMap::new()), + notify: Condvar::new(), + recv_buf: Mutex::new(Vec::new()), + }); + + let handler = RpcClientHandler { + shared: shared.clone(), + }; + + let mut tcp_client = TcpClient::connect(addr, handler) + .map_err(|e| RpcError::unavailable(format!("Connect failed: {}", e)))?; + + tcp_client + .start(event_loop, Token(usize::MAX - 1)) + .map_err(|e| RpcError::unavailable(format!("Client start failed: {}", e)))?; + + Ok(Arc::new(Self { + tcp_client: Mutex::new(tcp_client), + shared, + next_request_id: AtomicU64::new(1), + codec, + timeout: Duration::from_secs(30), + })) + } + + /// Set the default timeout for RPC calls. + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = timeout; + } + + /// Send a request and wait for a response (blocking). + fn call_raw( + &self, + service_id: u16, + method_id: u16, + payload: Vec, + ) -> Result, RpcError> { + let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); + + // Register the pending request before sending. + { + let mut pending = self.shared.pending.lock().unwrap(); + pending.insert( + request_id, + PendingRequest { + completed: false, + result: None, + }, + ); + } + + // Build and send the request frame. + let frame = Frame::request(request_id, service_id, method_id, payload, false); + { + let client = self.tcp_client.lock().unwrap(); + client + .send(&frame.encode()) + .map_err(|e| RpcError::unavailable(format!("Send failed: {}", e)))?; + } + + // Wait for the response with timeout. + let mut pending = self.shared.pending.lock().unwrap(); + let deadline = std::time::Instant::now() + self.timeout; + + loop { + if let Some(req) = pending.get(&request_id) { + if req.completed { + let req = pending.remove(&request_id).unwrap(); + return req.result.unwrap(); + } + } else { + return Err(RpcError::internal("Pending request disappeared")); + } + + let remaining = deadline.saturating_duration_since(std::time::Instant::now()); + if remaining.is_zero() { + pending.remove(&request_id); + return Err(RpcError::deadline_exceeded(format!( + "Request {} timed out after {:?}", + request_id, self.timeout + ))); + } + + let (guard, timeout_result) = + self.shared.notify.wait_timeout(pending, remaining).unwrap(); + pending = guard; + + if timeout_result.timed_out() { + pending.remove(&request_id); + return Err(RpcError::deadline_exceeded(format!( + "Request {} timed out after {:?}", + request_id, self.timeout + ))); + } + } + } + + /// Send a one-way request (fire-and-forget). + pub fn call_oneway( + &self, + service_id: u16, + method_id: u16, + payload: Vec, + ) -> Result<(), RpcError> { + let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); + let frame = Frame::request(request_id, service_id, method_id, payload, true); + + let client = self.tcp_client.lock().unwrap(); + client + .send(&frame.encode()) + .map_err(|e| RpcError::unavailable(format!("Send failed: {}", e)))?; + Ok(()) + } +} + +impl RpcTransport for RpcClient { + fn call(&self, service_id: u16, method_id: u16, payload: Vec) -> Result, RpcError> { + self.call_raw(service_id, method_id, payload) + } +} + +/// Network handler for the RPC client - receives response frames. +struct RpcClientHandler { + shared: Arc, +} + +impl NetworkHandler for RpcClientHandler { + fn on_data( + &self, + _ctx: &ServerContext, + _conn_id: ConnectionId, + data: &[u8], + ) -> mill_net::errors::Result<()> { + let mut recv_buf = self.shared.recv_buf.lock().unwrap(); + recv_buf.extend_from_slice(data); + + let (frames, consumed) = match protocol::parse_frames(&recv_buf) { + Ok(r) => r, + Err(e) => { + log::error!("Client frame parse error: {}", e); + recv_buf.clear(); + return Ok(()); + } + }; + + if consumed > 0 { + recv_buf.drain(..consumed); + } + + drop(recv_buf); + + for frame in frames { + self.handle_frame(frame); + } + + Ok(()) + } +} + +impl RpcClientHandler { + fn handle_frame(&self, frame: Frame) { + match frame.header.message_type { + MessageType::Response => { + let (request_id, data) = match frame.parse_response_payload() { + Ok(r) => r, + Err(e) => { + log::error!("Invalid response payload: {}", e); + return; + } + }; + + let mut pending = self.shared.pending.lock().unwrap(); + if let Some(req) = pending.get_mut(&request_id) { + req.completed = true; + req.result = Some(Ok(data.to_vec())); + } + self.shared.notify.notify_all(); + } + MessageType::Error => { + let (request_id, err_data) = match frame.parse_response_payload() { + Ok(r) => r, + Err(e) => { + log::error!("Invalid error payload: {}", e); + return; + } + }; + + let rpc_err: RpcError = match bincode::deserialize(err_data) { + Ok(e) => e, + Err(_) => RpcError::internal(String::from_utf8_lossy(err_data).to_string()), + }; + + let mut pending = self.shared.pending.lock().unwrap(); + if let Some(req) = pending.get_mut(&request_id) { + req.completed = true; + req.result = Some(Err(rpc_err)); + } + self.shared.notify.notify_all(); + } + MessageType::Pong => { + log::debug!("Received pong from server"); + } + other => { + log::warn!("Unexpected message type from server: {:?}", other); + } + } + } +} diff --git a/mill-rpc/src/lib.rs b/mill-rpc/src/lib.rs new file mode 100644 index 0000000..68c7e44 --- /dev/null +++ b/mill-rpc/src/lib.rs @@ -0,0 +1,32 @@ +//! Mill-RPC: An Axum-inspired RPC framework built on Mill-IO. +//! +//! # Quick Start +//! +//! ```ignore +//! use mill_rpc::prelude::*; +//! +//! #[mill_rpc::service] +//! trait Calculator { +//! fn add(a: i32, b: i32) -> i32; +//! } +//! +//! struct MyCalc; +//! impl CalculatorServer for MyCalc { +//! fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { a + b } +//! } +//! ``` + +pub mod client; +pub mod server; + +pub mod prelude; + +// Re-exports +pub use mill_rpc_core::{ + Codec, CodecType, Flags, Frame, FrameHeader, MessageType, RpcContext, RpcError, RpcStatus, + RpcTransport, ServiceDispatch, +}; +pub use mill_rpc_macros::service; + +pub use client::RpcClient; +pub use server::RpcServer; diff --git a/mill-rpc/src/prelude.rs b/mill-rpc/src/prelude.rs new file mode 100644 index 0000000..73a22e9 --- /dev/null +++ b/mill-rpc/src/prelude.rs @@ -0,0 +1,8 @@ +//! Convenient re-exports for Mill-RPC users. + +pub use crate::client::RpcClient; +pub use crate::server::RpcServer; +pub use crate::{Codec, CodecType, RpcContext, RpcError, RpcStatus, RpcTransport, ServiceDispatch}; +pub use mill_rpc_macros::service; + +pub use serde::{Deserialize, Serialize}; diff --git a/mill-rpc/src/server.rs b/mill-rpc/src/server.rs new file mode 100644 index 0000000..4b1f2ee --- /dev/null +++ b/mill-rpc/src/server.rs @@ -0,0 +1,265 @@ +//! RPC Server built on mill-net's TcpServer. +//! +//! The server accepts TCP connections, parses Mill-RPC frames, dispatches +//! requests to registered services, and sends back responses. + +use crate::{Codec, RpcContext, RpcError, ServiceDispatch}; +use mill_io::EventLoop; +use mill_net::errors::{NetworkError, Result}; +use mill_net::tcp::config::TcpServerConfig; +use mill_net::tcp::traits::{ConnectionId, NetworkHandler}; +use mill_net::tcp::{ServerContext, TcpServer}; +use mill_rpc_core::protocol::{self, Frame, FrameHeader, MessageType, HEADER_SIZE}; +use mio::Token; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex, RwLock}; + +/// A registered service with its dispatch implementation. +struct RegisteredService { + dispatcher: Box, +} + +/// RPC Server that hosts one or more services. +pub struct RpcServer { + _tcp_server: Arc>, +} + +impl RpcServer { + /// Create a builder for configuring the RPC server. + pub fn builder() -> RpcServerBuilder { + RpcServerBuilder::new() + } +} + +/// Builder for constructing an RpcServer. +pub struct RpcServerBuilder { + address: Option, + codec: Codec, + max_connections: Option, + services: Vec<(u16, Box)>, + next_service_id: u16, +} + +impl RpcServerBuilder { + fn new() -> Self { + Self { + address: None, + codec: Codec::bincode(), + max_connections: None, + services: Vec::new(), + next_service_id: 0, + } + } + + /// Set the address to bind to. + pub fn bind(mut self, addr: SocketAddr) -> Self { + self.address = Some(addr); + self + } + + /// Set the codec for serialization. + pub fn codec(mut self, codec: Codec) -> Self { + self.codec = codec; + self + } + + /// Set the maximum number of connections. + pub fn max_connections(mut self, max: usize) -> Self { + self.max_connections = Some(max); + self + } + + /// Register a service implementation wrapped in its dispatcher. + /// + /// The service must implement `ServiceDispatch` - typically via the + /// generated `{Name}Dispatcher` wrapper: + /// + /// ```ignore + /// .service(CalculatorDispatcher(MyCalculator)) + /// ``` + pub fn service(mut self, svc: S) -> Self { + let id = self.next_service_id; + self.next_service_id += 1; + self.services.push((id, Box::new(svc))); + self + } + + /// Register a service with an explicit service ID. + pub fn service_with_id(mut self, id: u16, svc: S) -> Self { + self.services.push((id, Box::new(svc))); + self + } + + /// Build and start the RPC server on the given event loop. + pub fn build(self, event_loop: &Arc) -> Result { + let addr = self + .address + .unwrap_or_else(|| "127.0.0.1:9000".parse().unwrap()); + + let mut service_map = HashMap::new(); + for (id, dispatcher) in self.services { + service_map.insert(id, RegisteredService { dispatcher }); + } + + let handler = RpcServerHandler { + services: Arc::new(RwLock::new(service_map)), + codec: self.codec, + conn_buffers: Arc::new(Mutex::new(HashMap::new())), + }; + + let mut config_builder = TcpServerConfig::builder().address(addr); + if let Some(max) = self.max_connections { + config_builder = config_builder.max_connections(max); + } + let config = config_builder.build(); + + let tcp_server = Arc::new(TcpServer::new(config, handler)?); + tcp_server.clone().start(event_loop, Token(0))?; + + log::info!("Mill-RPC server listening on {}", addr); + + Ok(RpcServer { + _tcp_server: tcp_server, + }) + } +} + +/// Internal handler that bridges mill-net's NetworkHandler to RPC dispatch. +struct RpcServerHandler { + services: Arc>>, + codec: Codec, + /// Per-connection receive buffers for handling partial frames. + conn_buffers: Arc>>>, +} + +impl NetworkHandler for RpcServerHandler { + fn on_connect(&self, _ctx: &ServerContext, conn_id: ConnectionId) -> Result<()> { + log::debug!("RPC connection established: {:?}", conn_id); + self.conn_buffers + .lock() + .unwrap() + .insert(conn_id.as_u64(), Vec::new()); + Ok(()) + } + + fn on_data(&self, ctx: &ServerContext, conn_id: ConnectionId, data: &[u8]) -> Result<()> { + // Append incoming data to the connection's buffer. + let mut buffers = self.conn_buffers.lock().unwrap(); + let buf = buffers.entry(conn_id.as_u64()).or_default(); + buf.extend_from_slice(data); + + // Try to parse complete frames. + let (frames, consumed) = match protocol::parse_frames(buf) { + Ok(result) => result, + Err(e) => { + log::error!("Frame parse error from {:?}: {}", conn_id, e); + // Clear buffer on parse error - connection is likely corrupted. + buf.clear(); + return Ok(()); + } + }; + + // Remove consumed bytes from the buffer. + if consumed > 0 { + buf.drain(..consumed); + } + + // Drop the lock before processing frames (handlers may be slow). + drop(buffers); + + // Process each complete frame. + for frame in frames { + self.handle_frame(ctx, conn_id, frame); + } + + Ok(()) + } + + fn on_disconnect(&self, _ctx: &ServerContext, conn_id: ConnectionId) -> Result<()> { + log::debug!("RPC connection closed: {:?}", conn_id); + self.conn_buffers.lock().unwrap().remove(&conn_id.as_u64()); + Ok(()) + } + + fn on_error(&self, _ctx: &ServerContext, conn_id: Option, error: NetworkError) { + log::error!("RPC network error (conn={:?}): {}", conn_id, error); + } +} + +impl RpcServerHandler { + fn handle_frame(&self, ctx: &ServerContext, conn_id: ConnectionId, frame: Frame) { + match frame.header.message_type { + MessageType::Request => self.handle_request(ctx, conn_id, frame), + MessageType::Ping => { + let pong = Frame::pong(); + if let Err(e) = ctx.send_to(conn_id, &pong.encode()) { + log::error!("Failed to send pong to {:?}: {}", conn_id, e); + } + } + MessageType::Cancel => { + log::debug!("Cancel received from {:?} (not yet supported)", conn_id); + } + other => { + log::warn!("Unexpected message type {:?} from {:?}", other, conn_id); + } + } + } + + fn handle_request(&self, ctx: &ServerContext, conn_id: ConnectionId, frame: Frame) { + let (request_id, service_id, method_id, args) = match frame.parse_request_payload() { + Ok(parsed) => parsed, + Err(e) => { + log::error!("Invalid request payload from {:?}: {}", conn_id, e); + return; + } + }; + + let is_one_way = frame.header.flags.is_one_way(); + + let rpc_ctx = RpcContext::new(request_id, service_id, method_id); + + // Look up the service and dispatch. + let result = { + let services = self.services.read().unwrap(); + match services.get(&service_id) { + Some(svc) => svc + .dispatcher + .dispatch(&rpc_ctx, method_id, args, &self.codec), + None => Err(RpcError::service_not_found(service_id)), + } + }; + + // Don't send a response for one-way calls. + if is_one_way { + if let Err(e) = &result { + log::error!("One-way request {}.{} failed: {}", service_id, method_id, e); + } + return; + } + + // Send response or error frame. + let response_frame = match result { + Ok(resp_bytes) => Frame::response(request_id, resp_bytes), + Err(rpc_err) => { + let err_bytes = match self.codec.serialize(&rpc_err) { + Ok(b) => b, + Err(e) => { + log::error!("Failed to serialize error: {}", e); + return; + } + }; + Frame::error(request_id, err_bytes) + } + }; + + if let Err(e) = ctx.send_to(conn_id, &response_frame.encode()) { + log::error!( + "Failed to send response for request {} to {:?}: {}", + request_id, + conn_id, + e + ); + } + } +} diff --git a/mill-rpc/tests/integration_test.rs b/mill-rpc/tests/integration_test.rs new file mode 100644 index 0000000..91876d0 --- /dev/null +++ b/mill-rpc/tests/integration_test.rs @@ -0,0 +1,104 @@ +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +// ---- Service Definition ---- + +#[mill_rpc::service] +trait Calculator { + fn add(a: i32, b: i32) -> i32; + fn multiply(a: f64, b: f64) -> f64; + fn echo(msg: String) -> String; +} + +// ---- Server Implementation ---- + +struct MyCalculator; + +impl CalculatorServer for MyCalculator { + fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { + a + b + } + + fn multiply(&self, _ctx: &RpcContext, a: f64, b: f64) -> f64 { + a * b + } + + fn echo(&self, _ctx: &RpcContext, msg: String) -> String { + format!("echo: {}", msg) + } +} + +#[test] +fn test_rpc_roundtrip() { + let _ = env_logger::builder().is_test(true).try_init(); + + let codec = Codec::bincode(); + let ctx = RpcContext::new(1, 0, 0); + let dispatcher = CalculatorDispatcher(MyCalculator); + + // Test add + let args = codec.serialize(&AddRequest { a: 2, b: 3 }).unwrap(); + let result_bytes = dispatcher + .dispatch(&ctx, calculator_methods::ADD, &args, &codec) + .unwrap(); + let result: AddResponse = codec.deserialize(&result_bytes).unwrap(); + assert_eq!(result.0, 5); + + // Test multiply + let args = codec + .serialize(&MultiplyRequest { a: 3.0, b: 4.0 }) + .unwrap(); + let result_bytes = dispatcher + .dispatch(&ctx, calculator_methods::MULTIPLY, &args, &codec) + .unwrap(); + let result: MultiplyResponse = codec.deserialize(&result_bytes).unwrap(); + assert!((result.0 - 12.0).abs() < f64::EPSILON); + + // Test echo + let args = codec + .serialize(&EchoRequest { + msg: "hello".to_string(), + }) + .unwrap(); + let result_bytes = dispatcher + .dispatch(&ctx, calculator_methods::ECHO, &args, &codec) + .unwrap(); + let result: EchoResponse = codec.deserialize(&result_bytes).unwrap(); + assert_eq!(result.0, "echo: hello"); + + // Test method not found + let result = dispatcher.dispatch(&ctx, 999, &[], &codec); + assert!(result.is_err()); +} + +#[test] +fn test_rpc_server_client_integration() { + let _ = env_logger::builder().is_test(true).try_init(); + + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let server = RpcServer::builder() + .bind("127.0.0.1:0".parse().unwrap()) + .service(CalculatorDispatcher(MyCalculator)) + .build(&event_loop); + + match server { + Ok(_server) => { + let el = event_loop.clone(); + let handle = thread::spawn(move || { + let _ = el.run(); + }); + + thread::sleep(Duration::from_millis(100)); + + event_loop.stop(); + let _ = handle.join(); + } + Err(e) => { + eprintln!("Server build failed (non-fatal in test): {}", e); + } + } +} From e5163db1af5b03aff29dbe024f7afc699dd9b28f Mon Sep 17 00:00:00 2001 From: hulxv Date: Sun, 1 Mar 2026 02:52:27 +0200 Subject: [PATCH 2/7] enhance service macro --- mill-rpc/README.md | 208 ++++----- mill-rpc/examples/calculator_client.rs | 26 +- mill-rpc/examples/calculator_server.rs | 23 +- mill-rpc/examples/concurrent_clients.rs | 107 ++--- mill-rpc/examples/echo_client.rs | 18 +- mill-rpc/examples/echo_server.rs | 20 +- mill-rpc/examples/kv_client.rs | 63 +-- mill-rpc/examples/kv_server.rs | 28 +- mill-rpc/examples/multi_service_client.rs | 67 +-- mill-rpc/examples/multi_service_server.rs | 126 ++--- mill-rpc/mill-rpc-core/Cargo.toml | 1 - mill-rpc/mill-rpc-macros/src/lib.rs | 531 +++++++++++++--------- mill-rpc/src/client.rs | 9 +- mill-rpc/src/lib.rs | 29 +- mill-rpc/src/prelude.rs | 6 +- mill-rpc/src/server.rs | 2 +- mill-rpc/tests/integration_test.rs | 101 ++-- 17 files changed, 652 insertions(+), 713 deletions(-) diff --git a/mill-rpc/README.md b/mill-rpc/README.md index e0db420..662ae16 100644 --- a/mill-rpc/README.md +++ b/mill-rpc/README.md @@ -1,194 +1,138 @@ # mill-rpc -An RPC framework built on top of [`mill-io`](../mill-io) and [`mill-net`](../mill-net). Define services as Rust traits, get type-safe clients and servers for free - no async runtime required. +An RPC framework built on [`mill-io`](../mill-io) and [`mill-net`](../mill-net). Define services declaratively, get type-safe clients and servers — no async runtime required. ## Features -- **Zero async** - Handlers are plain synchronous functions, no `async/await` needed -- **Macro-driven** - `#[mill_rpc::service]` generates server traits, client structs, and dispatch logic from a single trait definition -- **Type-safe** - Compile-time checked request/response types and method signatures -- **Multi-service** - Host multiple services on a single server with automatic routing -- **Pluggable codecs** - Bincode by default, extensible to JSON, MessagePack, CBOR, etc. -- **Binary wire protocol** - Efficient framing with support for one-way calls, ping/pong, and request cancellation - -## Installation - -```toml -[dependencies] -mill-rpc = { path = "../mill-rpc" } -``` +- **Zero async**: Handlers are plain synchronous functions +- **Macro-driven**: `mill_rpc::service!` generates a module with server trait, client struct, and dispatch logic +- **Selective generation**: Use `#[server]`, `#[client]`, or both (default) +- **Multi-service**: Host multiple services on a single port with automatic routing +- **Pluggable codecs**: Bincode by default, extensible +- **Binary wire protocol**: Efficient framing with one-way calls and ping/pong ## Quick Start -### 1. Define a service +### Define a service ```rust -use mill_rpc::prelude::*; - -#[mill_rpc::service] -trait Calculator { - fn add(a: i32, b: i32) -> i32; - fn multiply(a: i64, b: i64) -> i64; +mill_rpc::service! { + service Calculator { + fn add(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; + } } ``` -This single trait generates: -- `CalculatorServer` - trait you implement on the server -- `CalculatorClient` - struct with typed RPC methods -- `CalculatorDispatcher` - wrapper that implements `ServiceDispatch` -- Per-method request/response types with serde derives -- `calculator_methods` module with method ID constants +This generates a `calculator` module containing: +- `calculator::Service`: trait to implement on the server +- `calculator::server(impl)`: wraps your impl for registration +- `calculator::Client`: struct with typed RPC methods +- `calculator::methods`: method ID constants -### 2. Implement the server +### Server ```rust -struct MyCalculator; +struct MyCalc; -impl CalculatorServer for MyCalculator { - fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { - a + b - } - - fn multiply(&self, _ctx: &RpcContext, a: i64, b: i64) -> i64 { - a * b - } +impl calculator::Service for MyCalc { + fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { a + b } + fn multiply(&self, _ctx: &RpcContext, a: i64, b: i64) -> i64 { a * b } } -``` - -### 3. Start the server - -```rust -use mill_io::EventLoop; -use std::sync::Arc; fn main() { let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); let _server = RpcServer::builder() .bind("127.0.0.1:9001".parse().unwrap()) - .service(CalculatorDispatcher(MyCalculator)) + .service(calculator::server(MyCalc)) .build(&event_loop) - .expect("Failed to start server"); + .unwrap(); event_loop.run().unwrap(); } ``` -### 4. Call from a client +### Client ```rust -use mill_io::EventLoop; -use std::sync::Arc; -use std::thread; -use std::time::Duration; +let transport = RpcClient::connect(addr, &event_loop, Codec::bincode()).unwrap(); +let client = calculator::Client::new(transport, Codec::bincode(), 0); -fn main() { - let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); - - // Run event loop in background - let el = event_loop.clone(); - thread::spawn(move || el.run().unwrap()); - thread::sleep(Duration::from_millis(50)); +let sum = client.add(10, 25).unwrap(); // 35 +let prod = client.multiply(7, 8).unwrap(); // 56 +``` - let transport = mill_rpc::RpcClient::connect( - "127.0.0.1:9001".parse().unwrap(), - &event_loop, - Codec::bincode(), - ).unwrap(); +## Selective Generation - let client = CalculatorClient::new(transport, Codec::bincode(), 0); +Generate only what you need: - let sum = client.add(10, 25).unwrap(); - println!("10 + 25 = {}", sum); // 35 +```rust +// Server crate: no client code generated +mill_rpc::service! { + #[server] + service Calculator { + fn add(a: i32, b: i32) -> i32; + } +} - let product = client.multiply(7, 8).unwrap(); - println!("7 * 8 = {}", product); // 56 +// Client crate: no server code generated +mill_rpc::service! { + #[client] + service Calculator { + fn add(a: i32, b: i32) -> i32; + } +} - event_loop.stop(); +// Both (default): for tests, examples, or single-binary apps +mill_rpc::service! { + service Calculator { + fn add(a: i32, b: i32) -> i32; + } } ``` ## Multi-Service Server -Register multiple services on a single port. Each service gets an auto-assigned service ID. - ```rust -#[mill_rpc::service] -trait MathService { - fn factorial(n: u64) -> u64; +mill_rpc::service! { + #[server] + service MathService { + fn factorial(n: u64) -> u64; + } } -#[mill_rpc::service] -trait StringService { - fn reverse(s: String) -> String; +mill_rpc::service! { + #[server] + service StringService { + fn reverse(s: String) -> String; + } } -// Server let _server = RpcServer::builder() .bind(addr) - .service(MathServiceDispatcher(MathImpl)) // service_id = 0 - .service(StringServiceDispatcher(StringImpl)) // service_id = 1 + .service(math_service::server(MathImpl)) // service_id = 0 + .service(string_service::server(StringImpl)) // service_id = 1 .build(&event_loop)?; -// Client - both share a single TCP connection -let math = MathServiceClient::new(transport.clone(), Codec::bincode(), 0); -let strings = StringServiceClient::new(transport, Codec::bincode(), 1); - -math.factorial(10)?; // 3628800 -strings.reverse("hello")?; // "olleh" -``` - -## Wire Protocol - -Mill-RPC uses a compact binary frame format: - -```text -+--------+--------+-------+--------+-----------+---------+ -| Magic | Version| Flags | MsgType| PayloadLen| Payload | -| 2B | 1B | 1B | 1B | 4B (LE) | N bytes | -+--------+--------+-------+--------+-----------+---------+ -``` - -Request payloads carry routing info: - -```text -+------------+-----------+-----------+---------+ -| RequestID | ServiceID | MethodID | Args | -| 8B (LE) | 2B (LE) | 2B (LE) | N bytes | -+------------+-----------+-----------+---------+ +// Client side: share one connection +let math = math_service::Client::new(transport.clone(), codec, 0); +let strings = string_service::Client::new(transport, codec, 1); ``` -**Message types:** Request, Response, Error, Ping, Pong, Cancel - -**Flags:** Compressed payload, One-way (fire-and-forget) - ## Examples -Run any example pair (server first, then client): - ```bash -# Terminal 1 -cargo run --example calculator_server - -# Terminal 2 -cargo run --example calculator_client +# Terminal 1 # Terminal 2 +cargo run --example calculator_server cargo run --example calculator_client +cargo run --example echo_server cargo run --example echo_client +cargo run --example kv_server cargo run --example kv_client +cargo run --example multi_service_server cargo run --example multi_service_client + +# Self-contained stress test +cargo run --example concurrent_clients ``` -you will find all examples [here](./examples/). - -## Error Handling - -Mill-RPC uses structured errors with gRPC-style status codes: - -| Code | Status | Description | -| ---- | ----------------- | --------------------------- | -| 0 | OK | Success | -| 2 | INVALID_ARGUMENT | Bad request parameters | -| 3 | NOT_FOUND | Service or method not found | -| 8 | INTERNAL | Server-side error | -| 9 | UNAVAILABLE | Connection failure | -| 10 | DEADLINE_EXCEEDED | Request timeout | - ## License Licensed under the Apache License, Version 2.0. See [LICENSE](../LICENSE) for details. diff --git a/mill-rpc/examples/calculator_client.rs b/mill-rpc/examples/calculator_client.rs index 74fe80b..1d02e38 100644 --- a/mill-rpc/examples/calculator_client.rs +++ b/mill-rpc/examples/calculator_client.rs @@ -9,37 +9,36 @@ use std::sync::Arc; use std::thread; use std::time::Duration; -#[mill_rpc::service] -trait Calculator { - fn add(a: i32, b: i32) -> i32; - fn subtract(a: i32, b: i32) -> i32; - fn multiply(a: i64, b: i64) -> i64; - fn divide(a: f64, b: f64) -> f64; - fn negate(x: i32) -> i32; +// Define the service — only generate client side +mill_rpc::service! { + #[client] + service Calculator { + fn add(a: i32, b: i32) -> i32; + fn subtract(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; + fn divide(a: f64, b: f64) -> f64; + fn negate(x: i32) -> i32; + } } fn main() { env_logger::init(); let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); - // Run event loop in background let el = event_loop.clone(); let handle = thread::spawn(move || { el.run().unwrap(); }); - - // Give event loop a moment to start thread::sleep(Duration::from_millis(50)); let addr = "127.0.0.1:9001".parse().unwrap(); - let transport = mill_rpc::RpcClient::connect(addr, &event_loop, Codec::bincode()) + let transport = RpcClient::connect(addr, &event_loop) .expect("Failed to connect to calculator server"); - let client = CalculatorClient::new(transport, Codec::bincode(), 0); + let client = calculator::Client::new(transport, Codec::bincode(), 0); println!("Connected to calculator server"); - // Basic arithmetic let sum = client.add(10, 25).unwrap(); println!("10 + 25 = {}", sum); @@ -55,7 +54,6 @@ fn main() { let neg = client.negate(42).unwrap(); println!("negate(42) = {}", neg); - // Division by zero let nan = client.divide(1.0, 0.0).unwrap(); println!("1.0 / 0.0 = {} (NaN: {})", nan, nan.is_nan()); diff --git a/mill-rpc/examples/calculator_server.rs b/mill-rpc/examples/calculator_server.rs index c08973d..42f2362 100644 --- a/mill-rpc/examples/calculator_server.rs +++ b/mill-rpc/examples/calculator_server.rs @@ -7,18 +7,21 @@ use mill_io::EventLoop; use mill_rpc::prelude::*; use std::sync::Arc; -#[mill_rpc::service] -trait Calculator { - fn add(a: i32, b: i32) -> i32; - fn subtract(a: i32, b: i32) -> i32; - fn multiply(a: i64, b: i64) -> i64; - fn divide(a: f64, b: f64) -> f64; - fn negate(x: i32) -> i32; +// Define the service — only generate server side +mill_rpc::service! { + #[server] + service Calculator { + fn add(a: i32, b: i32) -> i32; + fn subtract(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; + fn divide(a: f64, b: f64) -> f64; + fn negate(x: i32) -> i32; + } } -struct CalculatorImpl; +struct MyCalculator; -impl CalculatorServer for CalculatorImpl { +impl calculator::Service for MyCalculator { fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { println!(" add({}, {}) = {}", a, b, a + b); a + b @@ -53,7 +56,7 @@ fn main() { let addr = "127.0.0.1:9001".parse().unwrap(); let _server = RpcServer::builder() .bind(addr) - .service(CalculatorDispatcher(CalculatorImpl)) + .service(calculator::server(MyCalculator)) .build(&event_loop) .expect("Failed to start calculator server"); diff --git a/mill-rpc/examples/concurrent_clients.rs b/mill-rpc/examples/concurrent_clients.rs index ef29126..7574bad 100644 --- a/mill-rpc/examples/concurrent_clients.rs +++ b/mill-rpc/examples/concurrent_clients.rs @@ -1,7 +1,4 @@ -//! Concurrent clients stress test - multiple threads sending RPC calls simultaneously. -//! -//! This example starts an embedded server and spawns N client threads, -//! each making M requests. Demonstrates thread-safety and concurrent dispatch. +//! Concurrent clients stress test. //! //! Run with: cargo run --example concurrent_clients @@ -12,26 +9,19 @@ use std::sync::Arc; use std::thread; use std::time::{Duration, Instant}; -#[mill_rpc::service] -trait Counter { - fn increment() -> u64; - fn get() -> u64; - fn add(n: u64) -> u64; +// Both sides in one binary +mill_rpc::service! { + service Counter { + fn increment() -> u64; + fn get() -> u64; + } } struct AtomicCounter { value: AtomicU64, } -impl AtomicCounter { - fn new() -> Self { - Self { - value: AtomicU64::new(0), - } - } -} - -impl CounterServer for AtomicCounter { +impl counter::Service for AtomicCounter { fn increment(&self, _ctx: &RpcContext) -> u64 { self.value.fetch_add(1, Ordering::SeqCst) + 1 } @@ -39,10 +29,6 @@ impl CounterServer for AtomicCounter { fn get(&self, _ctx: &RpcContext) -> u64 { self.value.load(Ordering::SeqCst) } - - fn add(&self, _ctx: &RpcContext, n: u64) -> u64 { - self.value.fetch_add(n, Ordering::SeqCst) + n - } } fn main() { @@ -51,97 +37,76 @@ fn main() { let num_clients = 4; let requests_per_client = 100; - // --- Start embedded server --- let server_el = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); let addr = "127.0.0.1:9005".parse().unwrap(); let _server = RpcServer::builder() .bind(addr) - .service(CounterDispatcher(AtomicCounter::new())) + .service(counter::server(AtomicCounter { + value: AtomicU64::new(0), + })) .build(&server_el) .expect("Failed to start server"); let sel = server_el.clone(); - let server_thread = thread::spawn(move || { - sel.run().unwrap(); - }); - + let server_thread = thread::spawn(move || { sel.run().unwrap(); }); thread::sleep(Duration::from_millis(100)); + println!( - "Server started. Spawning {} clients, {} requests each...\n", + "Spawning {} clients, {} requests each...\n", num_clients, requests_per_client ); - // --- Spawn client threads --- let start = Instant::now(); let mut handles = Vec::new(); for client_id in 0..num_clients { let handle = thread::spawn(move || { let client_el = Arc::new(EventLoop::new(1, 256, 50).unwrap()); - let cel = client_el.clone(); - let el_thread = thread::spawn(move || { - cel.run().unwrap(); - }); - + let el_thread = thread::spawn(move || { cel.run().unwrap(); }); thread::sleep(Duration::from_millis(20)); - let transport = - mill_rpc::RpcClient::connect(addr, &client_el, Codec::bincode()).unwrap(); - let counter = CounterClient::new(transport, Codec::bincode(), 0); + let transport = RpcClient::connect(addr, &client_el).unwrap(); + let client = counter::Client::new(transport, Codec::bincode(), 0); - let mut local_results = Vec::new(); + let mut results = Vec::new(); for _ in 0..requests_per_client { - let val = counter.increment().unwrap(); - local_results.push(val); + results.push(client.increment().unwrap()); } println!( - " Client {} finished: first={}, last={}", + " Client {} done: first={}, last={}", client_id, - local_results.first().unwrap(), - local_results.last().unwrap() + results.first().unwrap(), + results.last().unwrap() ); client_el.stop(); let _ = el_thread.join(); - - local_results + results }); handles.push(handle); } - // --- Collect results --- - let mut all_values: Vec = Vec::new(); - for handle in handles { - let values = handle.join().unwrap(); - all_values.extend(values); - } + let mut all: Vec = handles + .into_iter() + .flat_map(|h| h.join().unwrap()) + .collect(); let elapsed = start.elapsed(); + all.sort(); + all.dedup(); - // Every increment should have returned a unique value - all_values.sort(); - all_values.dedup(); - - let total_expected = (num_clients * requests_per_client) as usize; + let total = (num_clients * requests_per_client) as usize; println!("\n--- Results ---"); - println!("Total requests: {}", total_expected); - println!("Unique values: {}", all_values.len()); - println!("Time elapsed: {:?}", elapsed); - println!( - "Throughput: {:.0} req/s", - total_expected as f64 / elapsed.as_secs_f64() - ); - - assert_eq!( - all_values.len(), - total_expected, - "Expected all increments to produce unique values (no lost updates)" - ); + println!("Total requests: {}", total); + println!("Unique values: {}", all.len()); + println!("Time: {:?}", elapsed); + println!("Throughput: {:.0} req/s", total as f64 / elapsed.as_secs_f64()); - println!("\nConcurrency test passed! No lost updates."); + assert_eq!(all.len(), total, "No lost updates"); + println!("\nConcurrency test passed!"); server_el.stop(); let _ = server_thread.join(); diff --git a/mill-rpc/examples/echo_client.rs b/mill-rpc/examples/echo_client.rs index f57c69d..e40eabf 100644 --- a/mill-rpc/examples/echo_client.rs +++ b/mill-rpc/examples/echo_client.rs @@ -9,12 +9,14 @@ use std::sync::Arc; use std::thread; use std::time::Duration; -#[mill_rpc::service] -trait Echo { - fn echo(message: String) -> String; - fn echo_uppercase(message: String) -> String; - fn echo_repeat(message: String, times: u32) -> String; - fn request_count() -> u64; +mill_rpc::service! { + #[client] + service Echo { + fn echo(message: String) -> String; + fn echo_uppercase(message: String) -> String; + fn echo_repeat(message: String, times: u32) -> String; + fn request_count() -> u64; + } } fn main() { @@ -28,10 +30,10 @@ fn main() { thread::sleep(Duration::from_millis(50)); let addr = "127.0.0.1:9002".parse().unwrap(); - let transport = mill_rpc::RpcClient::connect(addr, &event_loop, Codec::bincode()) + let transport = RpcClient::connect(addr, &event_loop) .expect("Failed to connect to echo server"); - let client = EchoClient::new(transport, Codec::bincode(), 0); + let client = echo::Client::new(transport, Codec::bincode(), 0); // Basic echo let reply = client.echo("Hello, Mill-RPC!".into()).unwrap(); diff --git a/mill-rpc/examples/echo_server.rs b/mill-rpc/examples/echo_server.rs index d3e41ce..306d7cc 100644 --- a/mill-rpc/examples/echo_server.rs +++ b/mill-rpc/examples/echo_server.rs @@ -1,4 +1,4 @@ -//! Echo RPC server - returns whatever you send. +//! Echo RPC server. //! //! Run with: cargo run --example echo_server //! Then connect with: cargo run --example echo_client @@ -8,12 +8,14 @@ use mill_rpc::prelude::*; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -#[mill_rpc::service] -trait Echo { - fn echo(message: String) -> String; - fn echo_uppercase(message: String) -> String; - fn echo_repeat(message: String, times: u32) -> String; - fn request_count() -> u64; +mill_rpc::service! { + #[server] + service Echo { + fn echo(message: String) -> String; + fn echo_uppercase(message: String) -> String; + fn echo_repeat(message: String, times: u32) -> String; + fn request_count() -> u64; + } } struct EchoImpl { @@ -28,7 +30,7 @@ impl EchoImpl { } } -impl EchoServer for EchoImpl { +impl echo::Service for EchoImpl { fn echo(&self, _ctx: &RpcContext, message: String) -> String { self.counter.fetch_add(1, Ordering::Relaxed); println!(" echo: {:?}", message); @@ -63,7 +65,7 @@ fn main() { let addr = "127.0.0.1:9002".parse().unwrap(); let _server = RpcServer::builder() .bind(addr) - .service(EchoDispatcher(EchoImpl::new())) + .service(echo::server(EchoImpl::new())) .build(&event_loop) .expect("Failed to start echo server"); diff --git a/mill-rpc/examples/kv_client.rs b/mill-rpc/examples/kv_client.rs index 74c1940..d08d3ba 100644 --- a/mill-rpc/examples/kv_client.rs +++ b/mill-rpc/examples/kv_client.rs @@ -9,14 +9,16 @@ use std::sync::Arc; use std::thread; use std::time::Duration; -#[mill_rpc::service] -trait KeyValue { - fn get(key: String) -> Option; - fn set(key: String, value: String) -> Option; - fn delete(key: String) -> bool; - fn keys() -> Vec; - fn len() -> u64; - fn clear() -> u64; +mill_rpc::service! { + #[client] + service KeyValue { + fn get(key: String) -> Option; + fn set(key: String, value: String) -> Option; + fn delete(key: String) -> bool; + fn keys() -> Vec; + fn len() -> u64; + fn clear() -> u64; + } } fn main() { @@ -30,75 +32,46 @@ fn main() { thread::sleep(Duration::from_millis(50)); let addr = "127.0.0.1:9003".parse().unwrap(); - let transport = mill_rpc::RpcClient::connect(addr, &event_loop, Codec::bincode()) + let transport = RpcClient::connect(addr, &event_loop) .expect("Failed to connect to KV server"); - let kv = KeyValueClient::new(transport, Codec::bincode(), 0); + let kv = key_value::Client::new(transport, Codec::bincode(), 0); println!("=== Key-Value Store Client ===\n"); - // Initially empty let len = kv.len().unwrap(); println!("Initial store size: {}", len); - assert_eq!(len, 0); - // Set some keys - let old = kv.set("name".into(), "Alice".into()).unwrap(); - println!("SET name=Alice (old: {:?})", old); - assert!(old.is_none()); + kv.set("name".into(), "Alice".into()).unwrap(); + println!("SET name=Alice"); - let old = kv.set("city".into(), "Berlin".into()).unwrap(); - println!("SET city=Berlin (old: {:?})", old); + kv.set("city".into(), "Berlin".into()).unwrap(); + println!("SET city=Berlin"); - let old = kv.set("lang".into(), "Rust".into()).unwrap(); - println!("SET lang=Rust (old: {:?})", old); + kv.set("lang".into(), "Rust".into()).unwrap(); + println!("SET lang=Rust"); - // Get values let val = kv.get("name".into()).unwrap(); println!("GET name -> {:?}", val); - assert_eq!(val, Some("Alice".to_string())); let val = kv.get("missing".into()).unwrap(); println!("GET missing -> {:?}", val); - assert_eq!(val, None); - // List keys let mut keys = kv.keys().unwrap(); keys.sort(); println!("KEYS -> {:?}", keys); - assert_eq!(keys.len(), 3); - // Overwrite let old = kv.set("name".into(), "Bob".into()).unwrap(); println!("SET name=Bob (old: {:?})", old); - assert_eq!(old, Some("Alice".to_string())); - - let val = kv.get("name".into()).unwrap(); - println!("GET name -> {:?}", val); - assert_eq!(val, Some("Bob".to_string())); - - // Delete - let existed = kv.delete("city".into()).unwrap(); - println!("DEL city -> existed: {}", existed); - assert!(existed); let existed = kv.delete("city".into()).unwrap(); println!("DEL city -> existed: {}", existed); - assert!(!existed); - // Final size let len = kv.len().unwrap(); println!("Store size: {}", len); - assert_eq!(len, 2); - // Clear let removed = kv.clear().unwrap(); println!("CLEAR -> removed {} entries", removed); - assert_eq!(removed, 2); - - let len = kv.len().unwrap(); - println!("Store size after clear: {}", len); - assert_eq!(len, 0); println!("\nAll KV tests passed!"); diff --git a/mill-rpc/examples/kv_server.rs b/mill-rpc/examples/kv_server.rs index d2d08d4..80a4e21 100644 --- a/mill-rpc/examples/kv_server.rs +++ b/mill-rpc/examples/kv_server.rs @@ -1,8 +1,5 @@ //! In-memory key-value store RPC server. //! -//! Demonstrates shared mutable state (HashMap behind RwLock) accessed -//! concurrently from the thread pool. -//! //! Run with: cargo run --example kv_server //! Then connect with: cargo run --example kv_client @@ -11,16 +8,16 @@ use mill_rpc::prelude::*; use std::collections::HashMap; use std::sync::{Arc, RwLock}; -/// We use Option to represent "key not found" as None. -/// In a real app you'd use Result or a custom enum with #[derive(RpcError)]. -#[mill_rpc::service] -trait KeyValue { - fn get(key: String) -> Option; - fn set(key: String, value: String) -> Option; - fn delete(key: String) -> bool; - fn keys() -> Vec; - fn len() -> u64; - fn clear() -> u64; +mill_rpc::service! { + #[server] + service KeyValue { + fn get(key: String) -> Option; + fn set(key: String, value: String) -> Option; + fn delete(key: String) -> bool; + fn keys() -> Vec; + fn len() -> u64; + fn clear() -> u64; + } } struct KvStore { @@ -35,7 +32,7 @@ impl KvStore { } } -impl KeyValueServer for KvStore { +impl key_value::Service for KvStore { fn get(&self, _ctx: &RpcContext, key: String) -> Option { let data = self.data.read().unwrap(); let result = data.get(&key).cloned(); @@ -87,11 +84,10 @@ fn main() { let addr = "127.0.0.1:9003".parse().unwrap(); let _server = RpcServer::builder() .bind(addr) - .service(KeyValueDispatcher(KvStore::new())) + .service(key_value::server(KvStore::new())) .build(&event_loop) .expect("Failed to start KV server"); println!("Key-Value server listening on {}", addr); - println!("Supports: GET, SET, DEL, KEYS, LEN, CLEAR"); event_loop.run().unwrap(); } diff --git a/mill-rpc/examples/multi_service_client.rs b/mill-rpc/examples/multi_service_client.rs index 28dab68..0d4fa8c 100644 --- a/mill-rpc/examples/multi_service_client.rs +++ b/mill-rpc/examples/multi_service_client.rs @@ -1,4 +1,4 @@ -//! Multi-service RPC client - calls two services on one server. +//! Multi-service RPC client — calls two services on one server. //! //! Run the server first: cargo run --example multi_service_server //! Then run: cargo run --example multi_service_client @@ -9,22 +9,23 @@ use std::sync::Arc; use std::thread; use std::time::Duration; -// Service definitions must match the server -#[mill_rpc::service] -trait MathService { - fn factorial(n: u64) -> u64; - fn fibonacci(n: u32) -> u64; - fn is_prime(n: u64) -> bool; - fn gcd(a: u64, b: u64) -> u64; +mill_rpc::service! { + #[client] + service MathService { + fn factorial(n: u64) -> u64; + fn fibonacci(n: u32) -> u64; + fn is_prime(n: u64) -> bool; + fn gcd(a: u64, b: u64) -> u64; + } } -#[mill_rpc::service] -trait StringService { - fn reverse(s: String) -> String; - fn word_count(s: String) -> u32; - fn contains(haystack: String, needle: String) -> bool; - fn trim(s: String) -> String; - fn replace(s: String, from: String, to: String) -> String; +mill_rpc::service! { + #[client] + service StringService { + fn reverse(s: String) -> String; + fn word_count(s: String) -> u32; + fn contains(haystack: String, needle: String) -> bool; + } } fn main() { @@ -38,62 +39,38 @@ fn main() { thread::sleep(Duration::from_millis(50)); let addr = "127.0.0.1:9004".parse().unwrap(); - let transport = mill_rpc::RpcClient::connect(addr, &event_loop, Codec::bincode()) + let transport = RpcClient::connect(addr, &event_loop) .expect("Failed to connect"); - // Both clients share the same transport (single TCP connection) - let math = MathServiceClient::new(transport.clone(), Codec::bincode(), 0); - let strings = StringServiceClient::new(transport, Codec::bincode(), 1); + let math = math_service::Client::new(transport.clone(), Codec::bincode(), 0); + let strings = string_service::Client::new(transport, Codec::bincode(), 1); - // ---- Math Service ---- println!("=== Math Service ===\n"); let f = math.factorial(10).unwrap(); println!("10! = {}", f); - assert_eq!(f, 3628800); let fib = math.fibonacci(20).unwrap(); println!("fib(20) = {}", fib); - assert_eq!(fib, 6765); - for n in [2, 7, 15, 17, 100] { + for n in [2, 7, 15, 17] { let prime = math.is_prime(n).unwrap(); println!("is_prime({}) = {}", n, prime); } let g = math.gcd(48, 18).unwrap(); println!("gcd(48, 18) = {}", g); - assert_eq!(g, 6); - // ---- String Service ---- println!("\n=== String Service ===\n"); let rev = strings.reverse("Hello, World!".into()).unwrap(); println!("reverse(\"Hello, World!\") = {:?}", rev); - assert_eq!(rev, "!dlroW ,olleH"); - let wc = strings - .word_count("The quick brown fox jumps".into()) - .unwrap(); - println!("word_count(\"The quick brown fox jumps\") = {}", wc); - assert_eq!(wc, 5); + let wc = strings.word_count("The quick brown fox".into()).unwrap(); + println!("word_count = {}", wc); let has = strings.contains("rustacean".into(), "rust".into()).unwrap(); println!("contains(\"rustacean\", \"rust\") = {}", has); - assert!(has); - - let trimmed = strings.trim(" hello ".into()).unwrap(); - println!("trim(\" hello \") = {:?}", trimmed); - assert_eq!(trimmed, "hello"); - - let replaced = strings - .replace("foo bar foo baz".into(), "foo".into(), "qux".into()) - .unwrap(); - println!( - "replace(\"foo bar foo baz\", \"foo\", \"qux\") = {:?}", - replaced - ); - assert_eq!(replaced, "qux bar qux baz"); println!("\nAll multi-service tests passed!"); diff --git a/mill-rpc/examples/multi_service_server.rs b/mill-rpc/examples/multi_service_server.rs index 0b4109b..36110e5 100644 --- a/mill-rpc/examples/multi_service_server.rs +++ b/mill-rpc/examples/multi_service_server.rs @@ -1,7 +1,4 @@ -//! Multi-service RPC server - hosts multiple services on a single port. -//! -//! Demonstrates service composition: a math service (service_id=0) and -//! a string utility service (service_id=1) on the same server. +//! Multi-service RPC server — two services on one port. //! //! Run with: cargo run --example multi_service_server //! Then connect with: cargo run --example multi_service_client @@ -10,27 +7,34 @@ use mill_io::EventLoop; use mill_rpc::prelude::*; use std::sync::Arc; -// ---------- Service 1: MathService ---------- +mill_rpc::service! { + #[server] + service MathService { + fn factorial(n: u64) -> u64; + fn fibonacci(n: u32) -> u64; + fn is_prime(n: u64) -> bool; + fn gcd(a: u64, b: u64) -> u64; + } +} -#[mill_rpc::service] -trait MathService { - fn factorial(n: u64) -> u64; - fn fibonacci(n: u32) -> u64; - fn is_prime(n: u64) -> bool; - fn gcd(a: u64, b: u64) -> u64; +mill_rpc::service! { + #[server] + service StringService { + fn reverse(s: String) -> String; + fn word_count(s: String) -> u32; + fn contains(haystack: String, needle: String) -> bool; + } } struct MathImpl; -impl MathServiceServer for MathImpl { +impl math_service::Service for MathImpl { fn factorial(&self, _ctx: &RpcContext, n: u64) -> u64 { - let result = (1..=n).product(); - println!(" [math] factorial({}) = {}", n, result); - result + (1..=n).product() } fn fibonacci(&self, _ctx: &RpcContext, n: u32) -> u64 { - let result = match n { + match n { 0 => 0, 1 => 1, _ => { @@ -42,92 +46,52 @@ impl MathServiceServer for MathImpl { } b } - }; - println!(" [math] fibonacci({}) = {}", n, result); - result + } } fn is_prime(&self, _ctx: &RpcContext, n: u64) -> bool { - let result = if n < 2 { - false - } else if n < 4 { - true - } else if n % 2 == 0 { - false - } else { - let mut i = 3; - while i * i <= n { - if n % i == 0 { - return false; - } - i += 2; + if n < 2 { + return false; + } + if n < 4 { + return true; + } + if n % 2 == 0 { + return false; + } + let mut i = 3; + while i * i <= n { + if n % i == 0 { + return false; } - true - }; - println!(" [math] is_prime({}) = {}", n, result); - result + i += 2; + } + true } fn gcd(&self, _ctx: &RpcContext, mut a: u64, mut b: u64) -> u64 { - let (orig_a, orig_b) = (a, b); while b != 0 { let tmp = b; b = a % b; a = tmp; } - println!(" [math] gcd({}, {}) = {}", orig_a, orig_b, a); a } } -// ---------- Service 2: StringService ---------- - -#[mill_rpc::service] -trait StringService { - fn reverse(s: String) -> String; - fn word_count(s: String) -> u32; - fn contains(haystack: String, needle: String) -> bool; - fn trim(s: String) -> String; - fn replace(s: String, from: String, to: String) -> String; -} - struct StringImpl; -impl StringServiceServer for StringImpl { +impl string_service::Service for StringImpl { fn reverse(&self, _ctx: &RpcContext, s: String) -> String { - let result: String = s.chars().rev().collect(); - println!(" [str] reverse({:?}) = {:?}", s, result); - result + s.chars().rev().collect() } fn word_count(&self, _ctx: &RpcContext, s: String) -> u32 { - let count = s.split_whitespace().count() as u32; - println!(" [str] word_count({:?}) = {}", s, count); - count + s.split_whitespace().count() as u32 } fn contains(&self, _ctx: &RpcContext, haystack: String, needle: String) -> bool { - let result = haystack.contains(&needle); - println!( - " [str] contains({:?}, {:?}) = {}", - haystack, needle, result - ); - result - } - - fn trim(&self, _ctx: &RpcContext, s: String) -> String { - let result = s.trim().to_string(); - println!(" [str] trim({:?}) = {:?}", s, result); - result - } - - fn replace(&self, _ctx: &RpcContext, s: String, from: String, to: String) -> String { - let result = s.replace(&from, &to); - println!( - " [str] replace({:?}, {:?}, {:?}) = {:?}", - s, from, to, result - ); - result + haystack.contains(&needle) } } @@ -138,13 +102,13 @@ fn main() { let addr = "127.0.0.1:9004".parse().unwrap(); let _server = RpcServer::builder() .bind(addr) - .service(MathServiceDispatcher(MathImpl)) // service_id = 0 - .service(StringServiceDispatcher(StringImpl)) // service_id = 1 + .service(math_service::server(MathImpl)) // service_id = 0 + .service(string_service::server(StringImpl)) // service_id = 1 .build(&event_loop) .expect("Failed to start multi-service server"); println!("Multi-service server listening on {}", addr); - println!(" Service 0: MathService (factorial, fibonacci, is_prime, gcd)"); - println!(" Service 1: StringService (reverse, word_count, contains, trim, replace)"); + println!(" Service 0: MathService"); + println!(" Service 1: StringService"); event_loop.run().unwrap(); } diff --git a/mill-rpc/mill-rpc-core/Cargo.toml b/mill-rpc/mill-rpc-core/Cargo.toml index c00fc50..b6f8ce1 100644 --- a/mill-rpc/mill-rpc-core/Cargo.toml +++ b/mill-rpc/mill-rpc-core/Cargo.toml @@ -7,4 +7,3 @@ description = "Core types, wire protocol, and codec traits for Mill-RPC" [dependencies] serde = { version = "1", features = ["derive"] } bincode = "1" -thiserror = "2" diff --git a/mill-rpc/mill-rpc-macros/src/lib.rs b/mill-rpc/mill-rpc-macros/src/lib.rs index 24e139f..b642f08 100644 --- a/mill-rpc/mill-rpc-macros/src/lib.rs +++ b/mill-rpc/mill-rpc-macros/src/lib.rs @@ -1,287 +1,396 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; -use syn::{parse_macro_input, FnArg, Ident, ItemTrait, Pat, ReturnType, TraitItem, Type}; +use syn::{ + braced, + parse::{Parse, ParseStream}, + parse_macro_input, FnArg, Ident, Pat, ReturnType, Token, TraitItemFn, Type, +}; -/// Attribute macro for defining an RPC service. +/// Module-level macro for defining an RPC service. /// -/// Applied to a trait, it generates: -/// - `{Name}Server` trait with `&self, ctx: &RpcContext` prepended to each method -/// - `{Name}Client` struct with typed RPC call methods -/// - `{Name}Dispatcher` wrapper that implements `ServiceDispatch` -/// - Per-method request/response types with serde derives +/// Generates a module containing `Service` trait, `Client` struct, +/// `server()` wrapper function, and all request/response types. /// -/// # Example +/// By default, both server and client code are generated. +/// Use `#[server]` or `#[client]` to generate only one side. +/// +/// # Examples /// /// ```ignore -/// #[mill_rpc::service] -/// trait Calculator { -/// fn add(a: i32, b: i32) -> i32; -/// fn divide(a: f64, b: f64) -> Result; +/// // Generate both server and client (default) +/// mill_rpc::service! { +/// service Calculator { +/// fn add(a: i32, b: i32) -> i32; +/// fn divide(a: f64, b: f64) -> f64; +/// } +/// } +/// +/// // Server only +/// mill_rpc::service! { +/// #[server] +/// service Calculator { +/// fn add(a: i32, b: i32) -> i32; +/// } +/// } +/// +/// // Client only (e.g. in a separate client crate) +/// mill_rpc::service! { +/// #[client] +/// service Calculator { +/// fn add(a: i32, b: i32) -> i32; +/// } /// } /// ``` -#[proc_macro_attribute] -pub fn service(_attr: TokenStream, item: TokenStream) -> TokenStream { - let input = parse_macro_input!(item as ItemTrait); - match generate_service(input) { +#[proc_macro] +pub fn service(input: TokenStream) -> TokenStream { + let def = parse_macro_input!(input as ServiceDef); + match generate_service_module(def) { Ok(tokens) => tokens.into(), Err(err) => err.to_compile_error().into(), } } -struct MethodInfo { +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum GenerateMode { + Both, + ServerOnly, + ClientOnly, +} + +struct ServiceDef { + mode: GenerateMode, + name: Ident, + methods: Vec, +} + +struct MethodDef { name: Ident, - method_id: u16, args: Vec<(Ident, Type)>, return_type: Type, } -fn generate_service(input: ItemTrait) -> syn::Result { - let trait_name = &input.ident; - let server_trait_name = format_ident!("{}Server", trait_name); - let client_struct_name = format_ident!("{}Client", trait_name); - let dispatcher_name = format_ident!("{}Dispatcher", trait_name); - let methods_mod_name = format_ident!("{}_methods", to_snake_case(&trait_name.to_string())); - - // Parse methods from the trait - let mut methods = Vec::new(); - for (idx, item) in input.items.iter().enumerate() { - let method = match item { - TraitItem::Fn(m) => m, - _ => continue, - }; - - let name = method.sig.ident.clone(); - - // Extract arguments (skip &self if present, though we don't expect it) - let mut args = Vec::new(); - for arg in &method.sig.inputs { - match arg { - FnArg::Typed(pat_type) => { - let pat = &*pat_type.pat; - let ident = match pat { - Pat::Ident(pi) => pi.ident.clone(), - _ => { - return Err(syn::Error::new_spanned( - pat, - "Expected a simple identifier pattern for argument", - )) - } - }; - let ty = (*pat_type.ty).clone(); - args.push((ident, ty)); - } - FnArg::Receiver(_) => { +impl Parse for ServiceDef { + fn parse(input: ParseStream) -> syn::Result { + // Parse optional #[server] or #[client] + let mode = if input.peek(Token![#]) { + input.parse::()?; + let content; + syn::bracketed!(content in input); + let attr_name: Ident = content.parse()?; + match attr_name.to_string().as_str() { + "server" => GenerateMode::ServerOnly, + "client" => GenerateMode::ClientOnly, + other => { return Err(syn::Error::new_spanned( - arg, - "#[mill_rpc::service] trait methods should not have `self` parameter", - )); + attr_name, + format!( + "Unknown attribute `{}`, expected `server` or `client`", + other + ), + )) } } - } - - // Extract return type - let return_type = match &method.sig.output { - ReturnType::Default => syn::parse_quote!(()), - ReturnType::Type(_, ty) => (**ty).clone(), + } else { + GenerateMode::Both }; - methods.push(MethodInfo { - name, - method_id: idx as u16, - args, - return_type, - }); - } - - // Generate method ID constants - let method_consts: Vec<_> = methods - .iter() - .map(|m| { - let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); - let id = m.method_id; - quote! { pub const #const_name: u16 = #id; } - }) - .collect(); - - // Generate per-method request/response structs - let mut type_defs = Vec::new(); - for m in &methods { - let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); - let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); - let ret_ty = &m.return_type; - - let field_names: Vec<_> = m.args.iter().map(|(name, _)| name).collect(); - let field_types: Vec<_> = m.args.iter().map(|(_, ty)| ty).collect(); + // Parse `service Name { ... }` + let service_kw: Ident = input.parse()?; + if service_kw != "service" { + return Err(syn::Error::new_spanned(service_kw, "Expected `service`")); + } - let req_struct = if m.args.is_empty() { - quote! { - #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] - pub struct #req_name; - } - } else { - quote! { - #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] - pub struct #req_name { - #( pub #field_names: #field_types, )* + let name: Ident = input.parse()?; + + let content; + braced!(content in input); + + let mut methods = Vec::new(); + while !content.is_empty() { + let method: TraitItemFn = content.parse()?; + + let method_name = method.sig.ident.clone(); + + let mut args = Vec::new(); + for arg in &method.sig.inputs { + match arg { + FnArg::Typed(pat_type) => { + let ident = match &*pat_type.pat { + Pat::Ident(pi) => pi.ident.clone(), + other => { + return Err(syn::Error::new_spanned( + other, + "Expected a simple identifier for argument name", + )) + } + }; + args.push((ident, (*pat_type.ty).clone())); + } + FnArg::Receiver(_) => { + return Err(syn::Error::new_spanned( + arg, + "Service methods should not have `self` parameter", + )) + } } } - }; - type_defs.push(quote! { - #req_struct + let return_type = match &method.sig.output { + ReturnType::Default => syn::parse_quote!(()), + ReturnType::Type(_, ty) => (**ty).clone(), + }; + + methods.push(MethodDef { + name: method_name, + args, + return_type, + }); + } - #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] - pub struct #resp_name(pub #ret_ty); - }); + Ok(ServiceDef { + mode, + name, + methods, + }) } +} - // Generate Server trait - let server_methods: Vec<_> = methods - .iter() - .map(|m| { - let name = &m.name; - let ret_ty = &m.return_type; - let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); - let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); +fn generate_service_module(def: ServiceDef) -> syn::Result { + let mod_name = format_ident!("{}", to_snake_case(&def.name.to_string())); + let service_name_str = def.name.to_string(); + let method_count = def.methods.len() as u16; - quote! { - fn #name(&self, ctx: &::mill_rpc_core::RpcContext, #( #arg_names: #arg_types ),*) -> #ret_ty; - } + let gen_server = def.mode != GenerateMode::ClientOnly; + let gen_client = def.mode != GenerateMode::ServerOnly; + + let method_consts: Vec<_> = def + .methods + .iter() + .enumerate() + .map(|(idx, m)| { + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let id = idx as u16; + quote! { pub const #const_name: u16 = #id; } }) .collect(); - // Generate dispatcher match arms - let dispatch_arms: Vec<_> = methods + // Request / Response types + let type_defs: Vec<_> = def + .methods .iter() .map(|m| { - let name = &m.name; - let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + let ret_ty = &m.return_type; let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let field_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); - let call_args = if m.args.is_empty() { - quote! {} + let req_struct = if m.args.is_empty() { + quote! { + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub(super) struct #req_name; + } } else { - let args: Vec<_> = field_names.iter().map(|n| quote! { req.#n }).collect(); - quote! { , #( #args ),* } + quote! { + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub(super) struct #req_name { + #( pub #field_names: #field_types, )* + } + } }; quote! { - #methods_mod_name::#const_name => { - let req: #req_name = codec.deserialize(args)?; - let result = self.0.#name(ctx #call_args); - codec.serialize(&#resp_name(result)) - } + #req_struct + + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub(super) struct #resp_name(pub #ret_ty); } }) .collect(); - // Generate client methods - let client_methods: Vec<_> = methods - .iter() - .map(|m| { - let name = &m.name; - let ret_ty = &m.return_type; - let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); - let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); - let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); - - let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); - let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); + let server_trait = if gen_server { + let trait_methods: Vec<_> = def + .methods + .iter() + .map(|m| { + let name = &m.name; + let ret_ty = &m.return_type; + let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); + quote! { + fn #name(&self, ctx: &::mill_rpc_core::RpcContext, #( #arg_names: #arg_types ),*) -> #ret_ty; + } + }) + .collect(); + + let dispatch_arms: Vec<_> = def + .methods + .iter() + .enumerate() + .map(|(idx, m)| { + let name = &m.name; + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); + let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + + let call_args = if m.args.is_empty() { + quote! {} + } else { + let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let args: Vec<_> = field_names.iter().map(|n| quote! { req.#n }).collect(); + quote! { , #( #args ),* } + }; + + quote! { + methods::#const_name => { + let req: types::#req_name = codec.deserialize(args)?; + let result = svc.#name(ctx #call_args); + codec.serialize(&types::#resp_name(result)) + } + } + }) + .collect(); - let req_construct = if m.args.is_empty() { - quote! { #req_name } - } else { - quote! { #req_name { #( #arg_names: #arg_names, )* } } - }; + quote! { + /// Server trait — implement this to handle RPC calls for this service. + pub trait Service: Send + Sync + 'static { + #( #trait_methods )* + } - quote! { - pub fn #name(&self, #( #arg_names: #arg_types ),*) -> Result<#ret_ty, ::mill_rpc_core::RpcError> { - let req = #req_construct; - let payload = self.codec.serialize(&req)?; - let resp_bytes = self.transport.call( - self.service_id, - #methods_mod_name::#const_name, - payload, - )?; - let resp: #resp_name = self.codec.deserialize(&resp_bytes)?; - Ok(resp.0) + /// Internal dispatcher that bridges `Service` impl to `ServiceDispatch`. + struct Dispatcher(T); + + impl ::mill_rpc_core::ServiceDispatch for Dispatcher { + fn dispatch( + &self, + ctx: &::mill_rpc_core::RpcContext, + method_id: u16, + args: &[u8], + codec: &::mill_rpc_core::Codec, + ) -> Result, ::mill_rpc_core::RpcError> { + let svc = &self.0; + match method_id { + #( #dispatch_arms, )* + _ => Err(::mill_rpc_core::RpcError::method_not_found(method_id)), + } } } - }) - .collect(); - - // Count methods for service registration - let method_count = methods.len() as u16; - let service_name_str = trait_name.to_string(); - let output = quote! { - /// Method ID constants for this service. - pub mod #methods_mod_name { - #( #method_consts )* + /// Wrap a `Service` implementation for server registration. + /// + /// # Example + /// ```ignore + /// RpcServer::builder() + /// .service(calculator::server(MyCalc)) + /// .build(&event_loop)?; + /// ``` + pub fn server(implementation: T) -> impl ::mill_rpc_core::ServiceDispatch { + Dispatcher(implementation) + } } + } else { + quote! {} + }; - #( #type_defs )* + let client_code = if gen_client { + let client_methods: Vec<_> = def + .methods + .iter() + .map(|m| { + let name = &m.name; + let ret_ty = &m.return_type; + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); + let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + + let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); + + let req_construct = if m.args.is_empty() { + quote! { types::#req_name } + } else { + let fields: Vec<_> = arg_names.iter().map(|n| quote! { #n: #n }).collect(); + quote! { types::#req_name { #( #fields, )* } } + }; + + quote! { + pub fn #name(&self, #( #arg_names: #arg_types ),*) -> Result<#ret_ty, ::mill_rpc_core::RpcError> { + let req = #req_construct; + let payload = self.codec.serialize(&req)?; + let resp_bytes = self.transport.call( + self.service_id, + methods::#const_name, + payload, + )?; + let resp: types::#resp_name = self.codec.deserialize(&resp_bytes)?; + Ok(resp.0) + } + } + }) + .collect(); - /// Server trait - implement this to handle RPC calls. - pub trait #server_trait_name: Send + Sync + 'static { - #( #server_methods )* - } + quote! { + /// Generated RPC client for this service. + pub struct Client { + transport: ::std::sync::Arc, + codec: ::mill_rpc_core::Codec, + service_id: u16, + } - /// Service descriptor for registration. - pub struct #trait_name; + impl Client { + /// Create a new client. + /// + /// - `transport`: the RPC transport (typically an `RpcClient`) + /// - `codec`: serialization codec (must match the server) + /// - `service_id`: the ID assigned to this service on the server + /// (matches registration order, starting from 0) + pub fn new( + transport: ::std::sync::Arc, + codec: ::mill_rpc_core::Codec, + service_id: u16, + ) -> Self { + Self { transport, codec, service_id } + } - impl #trait_name { - pub const SERVICE_NAME: &'static str = #service_name_str; - pub const METHOD_COUNT: u16 = #method_count; + #( #client_methods )* + } } + } else { + quote! {} + }; - /// Wrapper that adapts a `{Name}Server` impl into a `ServiceDispatch`. - /// - /// This avoids orphan-rule violations by being a local concrete type. - pub struct #dispatcher_name(pub T); - - impl ::mill_rpc_core::ServiceDispatch for #dispatcher_name { - fn dispatch( - &self, - ctx: &::mill_rpc_core::RpcContext, - method_id: u16, - args: &[u8], - codec: &::mill_rpc_core::Codec, - ) -> Result, ::mill_rpc_core::RpcError> { - match method_id { - #( #dispatch_arms, )* - _ => Err(::mill_rpc_core::RpcError::method_not_found(method_id)), - } + let output = quote! { + pub mod #mod_name { + //! Auto-generated RPC module for the **#service_name_str** service. + + /// Method ID constants. + pub mod methods { + #( #method_consts )* } - } - /// Generated client for calling this service remotely. - pub struct #client_struct_name { - transport: ::std::sync::Arc, - codec: ::mill_rpc_core::Codec, - service_id: u16, - } + /// Service metadata. + pub const SERVICE_NAME: &str = #service_name_str; + pub const METHOD_COUNT: u16 = #method_count; - impl #client_struct_name { - /// Create a new client from a transport, codec, and service ID. - pub fn new( - transport: ::std::sync::Arc, - codec: ::mill_rpc_core::Codec, - service_id: u16, - ) -> Self { - Self { transport, codec, service_id } + /// Internal request/response types (not part of the public API). + mod types { + #( #type_defs )* } - #( #client_methods )* + #server_trait + + #client_code } }; Ok(output) } +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + fn to_snake_case(s: &str) -> String { let mut result = String::new(); for (i, ch) in s.chars().enumerate() { diff --git a/mill-rpc/src/client.rs b/mill-rpc/src/client.rs index f13d43a..2aedf67 100644 --- a/mill-rpc/src/client.rs +++ b/mill-rpc/src/client.rs @@ -2,11 +2,11 @@ //! //! Connects to an RPC server, sends request frames, and waits for responses. -use crate::{Codec, RpcError, RpcStatus, RpcTransport}; +use crate::{RpcError, RpcTransport}; use mill_io::EventLoop; use mill_net::tcp::traits::{ConnectionId, NetworkHandler}; use mill_net::tcp::{ServerContext, TcpClient}; -use mill_rpc_core::protocol::{self, Frame, MessageType, HEADER_SIZE}; +use mill_rpc_core::protocol::{self, Frame, MessageType}; use mio::Token; use std::collections::HashMap; use std::net::SocketAddr; @@ -35,7 +35,6 @@ pub struct RpcClient { tcp_client: Mutex>, shared: Arc, next_request_id: AtomicU64, - codec: Codec, timeout: Duration, } @@ -44,7 +43,6 @@ impl RpcClient { pub fn connect( addr: SocketAddr, event_loop: &Arc, - codec: Codec, ) -> Result, RpcError> { let shared = Arc::new(ClientShared { pending: Mutex::new(HashMap::new()), @@ -67,7 +65,6 @@ impl RpcClient { tcp_client: Mutex::new(tcp_client), shared, next_request_id: AtomicU64::new(1), - codec, timeout: Duration::from_secs(30), })) } @@ -98,7 +95,6 @@ impl RpcClient { ); } - // Build and send the request frame. let frame = Frame::request(request_id, service_id, method_id, payload, false); { let client = self.tcp_client.lock().unwrap(); @@ -107,7 +103,6 @@ impl RpcClient { .map_err(|e| RpcError::unavailable(format!("Send failed: {}", e)))?; } - // Wait for the response with timeout. let mut pending = self.shared.pending.lock().unwrap(); let deadline = std::time::Instant::now() + self.timeout; diff --git a/mill-rpc/src/lib.rs b/mill-rpc/src/lib.rs index 68c7e44..c6692f8 100644 --- a/mill-rpc/src/lib.rs +++ b/mill-rpc/src/lib.rs @@ -3,17 +3,27 @@ //! # Quick Start //! //! ```ignore -//! use mill_rpc::prelude::*; -//! -//! #[mill_rpc::service] -//! trait Calculator { -//! fn add(a: i32, b: i32) -> i32; +//! // Define a service — generates a `calculator` module +//! mill_rpc::service! { +//! service Calculator { +//! fn add(a: i32, b: i32) -> i32; +//! } //! } //! +//! // Server side //! struct MyCalc; -//! impl CalculatorServer for MyCalc { +//! impl calculator::Service for MyCalc { //! fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { a + b } //! } +//! +//! // Register +//! RpcServer::builder() +//! .service(calculator::server(MyCalc)) +//! .build(&event_loop)?; +//! +//! // Client side +//! let client = calculator::Client::new(transport, codec, 0); +//! client.add(2, 3)?; //! ``` pub mod client; @@ -21,11 +31,14 @@ pub mod server; pub mod prelude; -// Re-exports +// Re-exports from core pub use mill_rpc_core::{ - Codec, CodecType, Flags, Frame, FrameHeader, MessageType, RpcContext, RpcError, RpcStatus, + Codec, CodecType, Frame, Flags, FrameHeader, MessageType, + RpcContext, RpcError, RpcStatus, RpcTransport, ServiceDispatch, }; + +// Re-export the service! macro pub use mill_rpc_macros::service; pub use client::RpcClient; diff --git a/mill-rpc/src/prelude.rs b/mill-rpc/src/prelude.rs index 73a22e9..52f69af 100644 --- a/mill-rpc/src/prelude.rs +++ b/mill-rpc/src/prelude.rs @@ -1,8 +1,10 @@ //! Convenient re-exports for Mill-RPC users. +pub use crate::{ + Codec, CodecType, RpcContext, RpcError, RpcStatus, + RpcTransport, ServiceDispatch, +}; pub use crate::client::RpcClient; pub use crate::server::RpcServer; -pub use crate::{Codec, CodecType, RpcContext, RpcError, RpcStatus, RpcTransport, ServiceDispatch}; -pub use mill_rpc_macros::service; pub use serde::{Deserialize, Serialize}; diff --git a/mill-rpc/src/server.rs b/mill-rpc/src/server.rs index 4b1f2ee..b2bf5bb 100644 --- a/mill-rpc/src/server.rs +++ b/mill-rpc/src/server.rs @@ -9,7 +9,7 @@ use mill_net::errors::{NetworkError, Result}; use mill_net::tcp::config::TcpServerConfig; use mill_net::tcp::traits::{ConnectionId, NetworkHandler}; use mill_net::tcp::{ServerContext, TcpServer}; -use mill_rpc_core::protocol::{self, Frame, FrameHeader, MessageType, HEADER_SIZE}; +use mill_rpc_core::protocol::{self, Frame, MessageType}; use mio::Token; use std::collections::HashMap; use std::net::SocketAddr; diff --git a/mill-rpc/tests/integration_test.rs b/mill-rpc/tests/integration_test.rs index 91876d0..c41cbac 100644 --- a/mill-rpc/tests/integration_test.rs +++ b/mill-rpc/tests/integration_test.rs @@ -4,20 +4,18 @@ use std::sync::Arc; use std::thread; use std::time::Duration; -// ---- Service Definition ---- - -#[mill_rpc::service] -trait Calculator { - fn add(a: i32, b: i32) -> i32; - fn multiply(a: f64, b: f64) -> f64; - fn echo(msg: String) -> String; +// Generate both server and client for testing +mill_rpc::service! { + service Calculator { + fn add(a: i32, b: i32) -> i32; + fn multiply(a: f64, b: f64) -> f64; + fn echo(msg: String) -> String; + } } -// ---- Server Implementation ---- - struct MyCalculator; -impl CalculatorServer for MyCalculator { +impl calculator::Service for MyCalculator { fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { a + b } @@ -32,70 +30,69 @@ impl CalculatorServer for MyCalculator { } #[test] -fn test_rpc_roundtrip() { - let _ = env_logger::builder().is_test(true).try_init(); - +fn test_dispatch_add() { let codec = Codec::bincode(); let ctx = RpcContext::new(1, 0, 0); - let dispatcher = CalculatorDispatcher(MyCalculator); + let dispatcher = calculator::server(MyCalculator); - // Test add - let args = codec.serialize(&AddRequest { a: 2, b: 3 }).unwrap(); - let result_bytes = dispatcher - .dispatch(&ctx, calculator_methods::ADD, &args, &codec) - .unwrap(); - let result: AddResponse = codec.deserialize(&result_bytes).unwrap(); - assert_eq!(result.0, 5); + // bincode serialization of AddRequest { a: 2, b: 3 }: two i32 LE + let mut payload = Vec::new(); + payload.extend_from_slice(&2i32.to_le_bytes()); + payload.extend_from_slice(&3i32.to_le_bytes()); - // Test multiply - let args = codec - .serialize(&MultiplyRequest { a: 3.0, b: 4.0 }) - .unwrap(); let result_bytes = dispatcher - .dispatch(&ctx, calculator_methods::MULTIPLY, &args, &codec) - .unwrap(); - let result: MultiplyResponse = codec.deserialize(&result_bytes).unwrap(); - assert!((result.0 - 12.0).abs() < f64::EPSILON); - - // Test echo - let args = codec - .serialize(&EchoRequest { - msg: "hello".to_string(), - }) + .dispatch(&ctx, calculator::methods::ADD, &payload, &codec) .unwrap(); + let result: i32 = codec.deserialize(&result_bytes).unwrap(); + assert_eq!(result, 5); +} + +#[test] +fn test_dispatch_multiply() { + let codec = Codec::bincode(); + let ctx = RpcContext::new(1, 0, 0); + let dispatcher = calculator::server(MyCalculator); + + // bincode serialization of MultiplyRequest { a: 3.0, b: 4.0 }: two f64 LE + let mut payload = Vec::new(); + payload.extend_from_slice(&3.0f64.to_le_bytes()); + payload.extend_from_slice(&4.0f64.to_le_bytes()); + let result_bytes = dispatcher - .dispatch(&ctx, calculator_methods::ECHO, &args, &codec) + .dispatch(&ctx, calculator::methods::MULTIPLY, &payload, &codec) .unwrap(); - let result: EchoResponse = codec.deserialize(&result_bytes).unwrap(); - assert_eq!(result.0, "echo: hello"); - - // Test method not found - let result = dispatcher.dispatch(&ctx, 999, &[], &codec); - assert!(result.is_err()); + let result: f64 = codec.deserialize(&result_bytes).unwrap(); + assert!((result - 12.0).abs() < f64::EPSILON); } #[test] -fn test_rpc_server_client_integration() { - let _ = env_logger::builder().is_test(true).try_init(); +fn test_dispatch_method_not_found() { + let codec = Codec::bincode(); + let ctx = RpcContext::new(1, 0, 0); + let dispatcher = calculator::server(MyCalculator); - let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + let err = dispatcher.dispatch(&ctx, 999, &[], &codec); + assert!(err.is_err()); +} + +#[test] +fn test_server_builds_and_stops() { + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); let server = RpcServer::builder() .bind("127.0.0.1:0".parse().unwrap()) - .service(CalculatorDispatcher(MyCalculator)) + .service(calculator::server(MyCalculator)) .build(&event_loop); match server { - Ok(_server) => { + Ok(_s) => { let el = event_loop.clone(); - let handle = thread::spawn(move || { + let h = thread::spawn(move || { let _ = el.run(); }); - - thread::sleep(Duration::from_millis(100)); - + thread::sleep(Duration::from_millis(50)); event_loop.stop(); - let _ = handle.join(); + let _ = h.join(); } Err(e) => { eprintln!("Server build failed (non-fatal in test): {}", e); From f1333910167a585a5271594b4b0420bebe13ff74 Mon Sep 17 00:00:00 2001 From: hulxv Date: Sun, 1 Mar 2026 02:53:09 +0200 Subject: [PATCH 3/7] fix: rustfmt --- mill-rpc/examples/calculator_client.rs | 4 ++-- mill-rpc/examples/concurrent_clients.rs | 13 ++++++++++--- mill-rpc/examples/echo_client.rs | 4 ++-- mill-rpc/examples/kv_client.rs | 3 +-- mill-rpc/examples/multi_service_client.rs | 3 +-- mill-rpc/src/client.rs | 5 +---- mill-rpc/src/lib.rs | 3 +-- mill-rpc/src/prelude.rs | 5 +---- 8 files changed, 19 insertions(+), 21 deletions(-) diff --git a/mill-rpc/examples/calculator_client.rs b/mill-rpc/examples/calculator_client.rs index 1d02e38..0f645f3 100644 --- a/mill-rpc/examples/calculator_client.rs +++ b/mill-rpc/examples/calculator_client.rs @@ -32,8 +32,8 @@ fn main() { thread::sleep(Duration::from_millis(50)); let addr = "127.0.0.1:9001".parse().unwrap(); - let transport = RpcClient::connect(addr, &event_loop) - .expect("Failed to connect to calculator server"); + let transport = + RpcClient::connect(addr, &event_loop).expect("Failed to connect to calculator server"); let client = calculator::Client::new(transport, Codec::bincode(), 0); diff --git a/mill-rpc/examples/concurrent_clients.rs b/mill-rpc/examples/concurrent_clients.rs index 7574bad..d386379 100644 --- a/mill-rpc/examples/concurrent_clients.rs +++ b/mill-rpc/examples/concurrent_clients.rs @@ -49,7 +49,9 @@ fn main() { .expect("Failed to start server"); let sel = server_el.clone(); - let server_thread = thread::spawn(move || { sel.run().unwrap(); }); + let server_thread = thread::spawn(move || { + sel.run().unwrap(); + }); thread::sleep(Duration::from_millis(100)); println!( @@ -64,7 +66,9 @@ fn main() { let handle = thread::spawn(move || { let client_el = Arc::new(EventLoop::new(1, 256, 50).unwrap()); let cel = client_el.clone(); - let el_thread = thread::spawn(move || { cel.run().unwrap(); }); + let el_thread = thread::spawn(move || { + cel.run().unwrap(); + }); thread::sleep(Duration::from_millis(20)); let transport = RpcClient::connect(addr, &client_el).unwrap(); @@ -103,7 +107,10 @@ fn main() { println!("Total requests: {}", total); println!("Unique values: {}", all.len()); println!("Time: {:?}", elapsed); - println!("Throughput: {:.0} req/s", total as f64 / elapsed.as_secs_f64()); + println!( + "Throughput: {:.0} req/s", + total as f64 / elapsed.as_secs_f64() + ); assert_eq!(all.len(), total, "No lost updates"); println!("\nConcurrency test passed!"); diff --git a/mill-rpc/examples/echo_client.rs b/mill-rpc/examples/echo_client.rs index e40eabf..9f6d333 100644 --- a/mill-rpc/examples/echo_client.rs +++ b/mill-rpc/examples/echo_client.rs @@ -30,8 +30,8 @@ fn main() { thread::sleep(Duration::from_millis(50)); let addr = "127.0.0.1:9002".parse().unwrap(); - let transport = RpcClient::connect(addr, &event_loop) - .expect("Failed to connect to echo server"); + let transport = + RpcClient::connect(addr, &event_loop).expect("Failed to connect to echo server"); let client = echo::Client::new(transport, Codec::bincode(), 0); diff --git a/mill-rpc/examples/kv_client.rs b/mill-rpc/examples/kv_client.rs index d08d3ba..16e60b0 100644 --- a/mill-rpc/examples/kv_client.rs +++ b/mill-rpc/examples/kv_client.rs @@ -32,8 +32,7 @@ fn main() { thread::sleep(Duration::from_millis(50)); let addr = "127.0.0.1:9003".parse().unwrap(); - let transport = RpcClient::connect(addr, &event_loop) - .expect("Failed to connect to KV server"); + let transport = RpcClient::connect(addr, &event_loop).expect("Failed to connect to KV server"); let kv = key_value::Client::new(transport, Codec::bincode(), 0); diff --git a/mill-rpc/examples/multi_service_client.rs b/mill-rpc/examples/multi_service_client.rs index 0d4fa8c..4055233 100644 --- a/mill-rpc/examples/multi_service_client.rs +++ b/mill-rpc/examples/multi_service_client.rs @@ -39,8 +39,7 @@ fn main() { thread::sleep(Duration::from_millis(50)); let addr = "127.0.0.1:9004".parse().unwrap(); - let transport = RpcClient::connect(addr, &event_loop) - .expect("Failed to connect"); + let transport = RpcClient::connect(addr, &event_loop).expect("Failed to connect"); let math = math_service::Client::new(transport.clone(), Codec::bincode(), 0); let strings = string_service::Client::new(transport, Codec::bincode(), 1); diff --git a/mill-rpc/src/client.rs b/mill-rpc/src/client.rs index 2aedf67..7fd3df3 100644 --- a/mill-rpc/src/client.rs +++ b/mill-rpc/src/client.rs @@ -40,10 +40,7 @@ pub struct RpcClient { impl RpcClient { /// Connect to an RPC server. - pub fn connect( - addr: SocketAddr, - event_loop: &Arc, - ) -> Result, RpcError> { + pub fn connect(addr: SocketAddr, event_loop: &Arc) -> Result, RpcError> { let shared = Arc::new(ClientShared { pending: Mutex::new(HashMap::new()), notify: Condvar::new(), diff --git a/mill-rpc/src/lib.rs b/mill-rpc/src/lib.rs index c6692f8..432d100 100644 --- a/mill-rpc/src/lib.rs +++ b/mill-rpc/src/lib.rs @@ -33,8 +33,7 @@ pub mod prelude; // Re-exports from core pub use mill_rpc_core::{ - Codec, CodecType, Frame, Flags, FrameHeader, MessageType, - RpcContext, RpcError, RpcStatus, + Codec, CodecType, Flags, Frame, FrameHeader, MessageType, RpcContext, RpcError, RpcStatus, RpcTransport, ServiceDispatch, }; diff --git a/mill-rpc/src/prelude.rs b/mill-rpc/src/prelude.rs index 52f69af..2bac0d0 100644 --- a/mill-rpc/src/prelude.rs +++ b/mill-rpc/src/prelude.rs @@ -1,10 +1,7 @@ //! Convenient re-exports for Mill-RPC users. -pub use crate::{ - Codec, CodecType, RpcContext, RpcError, RpcStatus, - RpcTransport, ServiceDispatch, -}; pub use crate::client::RpcClient; pub use crate::server::RpcServer; +pub use crate::{Codec, CodecType, RpcContext, RpcError, RpcStatus, RpcTransport, ServiceDispatch}; pub use serde::{Deserialize, Serialize}; From 9e1074d2b523933dfd9dfdd4ceaf0a901551ff06 Mon Sep 17 00:00:00 2001 From: hulxv Date: Sun, 1 Mar 2026 02:53:56 +0200 Subject: [PATCH 4/7] fix: clippy --- mill-rpc/examples/kv_client.rs | 1 + mill-rpc/examples/kv_server.rs | 5 +++++ mill-rpc/examples/multi_service_server.rs | 4 ++-- mill-rpc/mill-rpc-macros/src/lib.rs | 3 +-- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mill-rpc/examples/kv_client.rs b/mill-rpc/examples/kv_client.rs index 16e60b0..31e5179 100644 --- a/mill-rpc/examples/kv_client.rs +++ b/mill-rpc/examples/kv_client.rs @@ -17,6 +17,7 @@ mill_rpc::service! { fn delete(key: String) -> bool; fn keys() -> Vec; fn len() -> u64; + fn is_empty() -> bool; fn clear() -> u64; } } diff --git a/mill-rpc/examples/kv_server.rs b/mill-rpc/examples/kv_server.rs index 80a4e21..2fec3b4 100644 --- a/mill-rpc/examples/kv_server.rs +++ b/mill-rpc/examples/kv_server.rs @@ -16,6 +16,7 @@ mill_rpc::service! { fn delete(key: String) -> bool; fn keys() -> Vec; fn len() -> u64; + fn is_empty() -> bool; fn clear() -> u64; } } @@ -68,6 +69,10 @@ impl key_value::Service for KvStore { len } + fn is_empty(&self, _ctx: &RpcContext) -> bool { + self.data.read().unwrap().is_empty() + } + fn clear(&self, _ctx: &RpcContext) -> u64 { let mut data = self.data.write().unwrap(); let count = data.len() as u64; diff --git a/mill-rpc/examples/multi_service_server.rs b/mill-rpc/examples/multi_service_server.rs index 36110e5..b27b2a6 100644 --- a/mill-rpc/examples/multi_service_server.rs +++ b/mill-rpc/examples/multi_service_server.rs @@ -56,12 +56,12 @@ impl math_service::Service for MathImpl { if n < 4 { return true; } - if n % 2 == 0 { + if n.is_multiple_of(2) { return false; } let mut i = 3; while i * i <= n { - if n % i == 0 { + if n.is_multiple_of(i) { return false; } i += 2; diff --git a/mill-rpc/mill-rpc-macros/src/lib.rs b/mill-rpc/mill-rpc-macros/src/lib.rs index b642f08..1d5ef7b 100644 --- a/mill-rpc/mill-rpc-macros/src/lib.rs +++ b/mill-rpc/mill-rpc-macros/src/lib.rs @@ -227,8 +227,7 @@ fn generate_service_module(def: ServiceDef) -> syn::Result = def .methods .iter() - .enumerate() - .map(|(idx, m)| { + .map(|m| { let name = &m.name; let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); From 1f7c77d9831c3272ba163b4e512baeca854e4ca6 Mon Sep 17 00:00:00 2001 From: hulxv Date: Sun, 1 Mar 2026 03:10:33 +0200 Subject: [PATCH 5/7] cargo: use workspace metadata in mill-rpc --- mill-rpc/Cargo.toml | 7 +++++-- mill-rpc/mill-rpc-core/Cargo.toml | 7 +++++-- mill-rpc/mill-rpc-macros/Cargo.toml | 7 +++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mill-rpc/Cargo.toml b/mill-rpc/Cargo.toml index abb2b8e..c1daa2f 100644 --- a/mill-rpc/Cargo.toml +++ b/mill-rpc/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "mill-rpc" -version = "0.1.0" -edition = "2021" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true description = "An Axum-inspired RPC framework built on Mill-IO" [dependencies] diff --git a/mill-rpc/mill-rpc-core/Cargo.toml b/mill-rpc/mill-rpc-core/Cargo.toml index b6f8ce1..53ff741 100644 --- a/mill-rpc/mill-rpc-core/Cargo.toml +++ b/mill-rpc/mill-rpc-core/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "mill-rpc-core" -version = "0.1.0" -edition = "2021" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true description = "Core types, wire protocol, and codec traits for Mill-RPC" [dependencies] diff --git a/mill-rpc/mill-rpc-macros/Cargo.toml b/mill-rpc/mill-rpc-macros/Cargo.toml index 2a05b5d..8cfe928 100644 --- a/mill-rpc/mill-rpc-macros/Cargo.toml +++ b/mill-rpc/mill-rpc-macros/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "mill-rpc-macros" -version = "0.1.0" -edition = "2021" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true description = "Proc macros for Mill-RPC service definitions" [lib] From 3354381abcbc3e37e80b81c0873257187805cb54 Mon Sep 17 00:00:00 2001 From: hulxv Date: Thu, 12 Mar 2026 22:27:35 +0200 Subject: [PATCH 6/7] bench: coinswap legacy rpc server vs mill-rpc --- mill-rpc/Cargo.toml | 6 + mill-rpc/README.md | 2 +- mill-rpc/benches/rpc_comparison.rs | 535 ++++++++++++++++++++++++++++ mill-rpc/mill-rpc-macros/src/lib.rs | 4 +- 4 files changed, 545 insertions(+), 2 deletions(-) create mode 100644 mill-rpc/benches/rpc_comparison.rs diff --git a/mill-rpc/Cargo.toml b/mill-rpc/Cargo.toml index c1daa2f..fe6ae21 100644 --- a/mill-rpc/Cargo.toml +++ b/mill-rpc/Cargo.toml @@ -19,6 +19,8 @@ mio = { version = "1", features = ["os-poll", "net"] } [dev-dependencies] env_logger = "0.11" +criterion = { version = "0.5", features = ["html_reports"] } +serde_cbor = "0.11" [[example]] name = "calculator_server" @@ -55,3 +57,7 @@ path = "examples/multi_service_client.rs" [[example]] name = "concurrent_clients" path = "examples/concurrent_clients.rs" + +[[bench]] +name = "rpc_comparison" +harness = false diff --git a/mill-rpc/README.md b/mill-rpc/README.md index 662ae16..ed827b6 100644 --- a/mill-rpc/README.md +++ b/mill-rpc/README.md @@ -56,7 +56,7 @@ fn main() { ### Client ```rust -let transport = RpcClient::connect(addr, &event_loop, Codec::bincode()).unwrap(); +let transport = RpcClient::connect(addr, &event_loop).unwrap(); let client = calculator::Client::new(transport, Codec::bincode(), 0); let sum = client.add(10, 25).unwrap(); // 35 diff --git a/mill-rpc/benches/rpc_comparison.rs b/mill-rpc/benches/rpc_comparison.rs new file mode 100644 index 0000000..30f2a3d --- /dev/null +++ b/mill-rpc/benches/rpc_comparison.rs @@ -0,0 +1,535 @@ +//! Benchmark comparison: Legacy hand-rolled RPC vs Mill-RPC. +//! +//! The legacy approach mirrors how coinswap's maker RPC works: +//! - `TcpListener` with `set_nonblocking(true)` +//! - Busy-poll loop with `sleep(HEART_BEAT_INTERVAL)` +//! - Manual `serde_cbor` serialization of request/response enums +//! - Single-threaded, one request at a time +//! +//! Mill-RPC uses: +//! - mill-net's reactor-based TcpServer +//! - Auto-generated dispatch from `mill_rpc::service!` +//! - Thread-pool for concurrent request handling +//! - Binary framing protocol with bincode + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use serde::{Deserialize, Serialize}; +use std::io::{Read, Write}; +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::Duration; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct UtxoEntry { + txid: String, + vout: u32, + value: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct BalanceInfo { + confirmed: u64, + unconfirmed: u64, + total: u64, +} + +/// Simulated state that both servers share. +struct SharedState { + utxos: Vec, + balances: BalanceInfo, + counter: RwLock, +} + +impl SharedState { + fn new() -> Self { + let utxos: Vec = (0..20) + .map(|i| UtxoEntry { + txid: format!( + "abcdef{:04x}abcdef{:04x}abcdef{:04x}abcdef{:04x}", + i, i, i, i + ), + vout: i, + value: (i as u64 + 1) * 100_000, + }) + .collect(); + + Self { + utxos, + balances: BalanceInfo { + confirmed: 5_000_000, + unconfirmed: 200_000, + total: 5_200_000, + }, + counter: RwLock::new(0), + } + } +} + +/// Legacy RPC (mirrors coinswap's approach) +mod legacy { + use super::*; + + #[derive(Debug, Serialize, Deserialize)] + pub enum RpcMsgReq { + Ping, + GetUtxos, + GetBalances, + Increment, + Echo(String), + } + + #[derive(Debug, Serialize, Deserialize)] + pub enum RpcMsgResp { + Pong, + UtxoResp { utxos: Vec }, + BalancesResp(BalanceInfo), + IncrementResp(u64), + EchoResp(String), + ServerError(String), + } + + /// Length-prefixed message: 4-byte LE length + cbor payload (mirrors coinswap's read_message/send_message). + pub fn send_message(stream: &mut TcpStream, msg: &RpcMsgResp) -> std::io::Result<()> { + let data = serde_cbor::to_vec(msg).unwrap(); + let len = (data.len() as u32).to_le_bytes(); + stream.write_all(&len)?; + stream.write_all(&data)?; + stream.flush()?; + Ok(()) + } + + pub fn read_message(stream: &mut TcpStream) -> std::io::Result> { + let mut len_buf = [0u8; 4]; + stream.read_exact(&mut len_buf)?; + let len = u32::from_le_bytes(len_buf) as usize; + let mut buf = vec![0u8; len]; + stream.read_exact(&mut buf)?; + Ok(buf) + } + + pub fn send_request(stream: &mut TcpStream, req: &RpcMsgReq) -> std::io::Result<()> { + let data = serde_cbor::to_vec(req).unwrap(); + let len = (data.len() as u32).to_le_bytes(); + stream.write_all(&len)?; + stream.write_all(&data)?; + stream.flush()?; + Ok(()) + } + + pub fn read_response(stream: &mut TcpStream) -> std::io::Result { + let data = read_message(stream)?; + Ok(serde_cbor::from_slice(&data).unwrap()) + } + + fn handle_request(state: &Arc, socket: &mut TcpStream) -> std::io::Result<()> { + let msg_bytes = read_message(socket)?; + let rpc_request: RpcMsgReq = serde_cbor::from_slice(&msg_bytes).unwrap(); + + let resp = match rpc_request { + RpcMsgReq::Ping => RpcMsgResp::Pong, + RpcMsgReq::GetUtxos => RpcMsgResp::UtxoResp { + utxos: state.utxos.clone(), + }, + RpcMsgReq::GetBalances => RpcMsgResp::BalancesResp(state.balances.clone()), + RpcMsgReq::Increment => { + let mut counter = state.counter.write().unwrap(); + *counter += 1; + RpcMsgResp::IncrementResp(*counter) + } + RpcMsgReq::Echo(msg) => RpcMsgResp::EchoResp(msg), + }; + + send_message(socket, &resp)?; + Ok(()) + } + + /// Start the legacy server and return (addr, shutdown_flag). + pub fn start_server(state: Arc) -> (SocketAddr, Arc) { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + let shutdown = Arc::new(AtomicBool::new(false)); + let shutdown_clone = shutdown.clone(); + + listener.set_nonblocking(true).unwrap(); + + thread::spawn(move || { + while !shutdown_clone.load(Ordering::Relaxed) { + match listener.accept() { + Ok((mut stream, _)) => { + stream + .set_read_timeout(Some(Duration::from_secs(20))) + .unwrap(); + stream + .set_write_timeout(Some(Duration::from_secs(20))) + .unwrap(); + if let Err(e) = handle_request(&state, &mut stream) { + let _ = send_message( + &mut stream, + &RpcMsgResp::ServerError(format!("{e:?}")), + ); + } + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {} + Err(_) => {} + } + // Mirrors HEART_BEAT_INTERVAL in coinswap (typically 3 seconds, + // but we use 1ms here so the benchmark isn't dominated by sleep) + thread::sleep(Duration::from_millis(1)); + } + }); + + // Wait for server to be ready + thread::sleep(Duration::from_millis(50)); + (addr, shutdown) + } + + /// Make a single request-response roundtrip (one TCP connection per call, + /// same as coinswap's client pattern). + pub fn call(addr: SocketAddr, req: &RpcMsgReq) -> RpcMsgResp { + let mut stream = TcpStream::connect(addr).unwrap(); + stream + .set_read_timeout(Some(Duration::from_secs(5))) + .unwrap(); + stream + .set_write_timeout(Some(Duration::from_secs(5))) + .unwrap(); + send_request(&mut stream, req).unwrap(); + read_response(&mut stream).unwrap() + } +} + +/// Mill-RPC alternative +mod mill { + use super::*; + use mill_io::EventLoop; + use mill_rpc::prelude::*; + + mill_rpc::service! { + service BenchService { + fn ping() -> (); + fn get_utxos() -> Vec; + fn get_balances() -> BalanceInfo; + fn increment() -> u64; + fn echo(msg: String) -> String; + } + } + + pub struct BenchServiceImpl { + pub state: Arc, + } + + impl bench_service::Service for BenchServiceImpl { + fn ping(&self, _ctx: &RpcContext) {} + + fn get_utxos(&self, _ctx: &RpcContext) -> Vec { + self.state.utxos.clone() + } + + fn get_balances(&self, _ctx: &RpcContext) -> BalanceInfo { + self.state.balances.clone() + } + + fn increment(&self, _ctx: &RpcContext) -> u64 { + let mut counter = self.state.counter.write().unwrap(); + *counter += 1; + *counter + } + + fn echo(&self, _ctx: &RpcContext, msg: String) -> String { + msg + } + } + + /// Make a call using mill-rpc client. + pub fn make_client(addr: SocketAddr, event_loop: &Arc) -> bench_service::Client { + let transport = RpcClient::connect(addr, event_loop).unwrap(); + bench_service::Client::new(transport, Codec::bincode(), 0) + } +} + +fn bench_ping(c: &mut Criterion) { + let mut group = c.benchmark_group("ping_roundtrip"); + group.throughput(Throughput::Elements(1)); + + let state = Arc::new(SharedState::new()); + + // --- Legacy --- + let (legacy_addr, legacy_shutdown) = legacy::start_server(state.clone()); + + group.bench_function("legacy", |b| { + b.iter(|| { + let resp = legacy::call(legacy_addr, &legacy::RpcMsgReq::Ping); + black_box(resp); + }); + }); + + legacy_shutdown.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(50)); + + // --- Mill-RPC --- + let mill_state = Arc::new(SharedState::new()); + let mill_el = Arc::new(mill_io::EventLoop::new(4, 1024, 100).unwrap()); + + let svc = mill::BenchServiceImpl { state: mill_state }; + + let mill_addr: SocketAddr = "127.0.0.1:19876".parse().unwrap(); + let _mill_server = mill_rpc::RpcServer::builder() + .bind(mill_addr) + .service(mill::bench_service::server(svc)) + .build(&mill_el); + + match _mill_server { + Ok(_server) => { + let el = mill_el.clone(); + thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(100)); + + // Create a persistent client + let client_el = Arc::new(mill_io::EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + thread::spawn(move || { + let _ = cel.run(); + }); + thread::sleep(Duration::from_millis(50)); + + let client = mill::make_client(mill_addr, &client_el); + + group.bench_function("mill_rpc", |b| { + b.iter(|| { + let resp = client.ping(); + black_box(resp).unwrap(); + }); + }); + + client_el.stop(); + mill_el.stop(); + } + Err(e) => { + eprintln!("Mill-RPC server failed to start (skipping): {}", e); + } + } + + group.finish(); +} + +fn bench_get_utxos(c: &mut Criterion) { + let mut group = c.benchmark_group("get_utxos_roundtrip"); + group.throughput(Throughput::Elements(1)); + + let state = Arc::new(SharedState::new()); + + // --- Legacy --- + let (legacy_addr, legacy_shutdown) = legacy::start_server(state.clone()); + + group.bench_function("legacy", |b| { + b.iter(|| { + let resp = legacy::call(legacy_addr, &legacy::RpcMsgReq::GetUtxos); + black_box(resp); + }); + }); + + legacy_shutdown.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(50)); + + // --- Mill-RPC --- + let mill_state = Arc::new(SharedState::new()); + let mill_el = Arc::new(mill_io::EventLoop::new(4, 1024, 100).unwrap()); + + let mill_addr: SocketAddr = "127.0.0.1:19877".parse().unwrap(); + let svc = mill::BenchServiceImpl { state: mill_state }; + let _mill_server = mill_rpc::RpcServer::builder() + .bind(mill_addr) + .service(mill::bench_service::server(svc)) + .build(&mill_el); + + match _mill_server { + Ok(_server) => { + let el = mill_el.clone(); + thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(100)); + + let client_el = Arc::new(mill_io::EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + thread::spawn(move || { + let _ = cel.run(); + }); + thread::sleep(Duration::from_millis(50)); + + let client = mill::make_client(mill_addr, &client_el); + + group.bench_function("mill_rpc", |b| { + b.iter(|| { + let resp = client.get_utxos(); + black_box(resp).unwrap(); + }); + }); + + client_el.stop(); + mill_el.stop(); + } + Err(e) => { + eprintln!("Mill-RPC server failed to start (skipping): {}", e); + } + } + + group.finish(); +} + +fn bench_echo(c: &mut Criterion) { + let mut group = c.benchmark_group("echo_roundtrip"); + + let state = Arc::new(SharedState::new()); + + for size in [16, 256, 4096] { + let msg = "x".repeat(size); + group.throughput(Throughput::Bytes(size as u64)); + + // --- Legacy --- + let (legacy_addr, legacy_shutdown) = legacy::start_server(state.clone()); + + group.bench_with_input(BenchmarkId::new("legacy", size), &msg, |b, msg| { + b.iter(|| { + let resp = legacy::call(legacy_addr, &legacy::RpcMsgReq::Echo(msg.clone())); + black_box(resp); + }); + }); + + legacy_shutdown.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(50)); + } + + // Mill-RPC echo with different sizes + for size in [16, 256, 4096] { + let msg = "x".repeat(size); + + let mill_state = Arc::new(SharedState::new()); + let mill_el = Arc::new(mill_io::EventLoop::new(4, 1024, 100).unwrap()); + + let port = 19878 + size as u16; + let mill_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap(); + let svc = mill::BenchServiceImpl { state: mill_state }; + let _mill_server = mill_rpc::RpcServer::builder() + .bind(mill_addr) + .service(mill::bench_service::server(svc)) + .build(&mill_el); + + match _mill_server { + Ok(_server) => { + let el = mill_el.clone(); + thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(100)); + + let client_el = Arc::new(mill_io::EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + thread::spawn(move || { + let _ = cel.run(); + }); + thread::sleep(Duration::from_millis(50)); + + let client = mill::make_client(mill_addr, &client_el); + + group.bench_with_input(BenchmarkId::new("mill_rpc", size), &msg, |b, msg| { + b.iter(|| { + let resp = client.echo(msg.clone()); + black_box(resp).unwrap(); + }); + }); + + client_el.stop(); + mill_el.stop(); + } + Err(e) => { + eprintln!("Mill-RPC echo server failed (skipping): {}", e); + } + } + } + + group.finish(); +} + +fn bench_sequential_burst(c: &mut Criterion) { + let mut group = c.benchmark_group("sequential_burst"); + let burst_size = 50u64; + group.throughput(Throughput::Elements(burst_size)); + + let state = Arc::new(SharedState::new()); + + // --- Legacy: each call opens a new TCP connection (coinswap pattern) --- + let (legacy_addr, legacy_shutdown) = legacy::start_server(state.clone()); + + group.bench_function("legacy_new_conn_per_call", |b| { + b.iter(|| { + for _ in 0..burst_size { + let resp = legacy::call(legacy_addr, &legacy::RpcMsgReq::Ping); + black_box(resp); + } + }); + }); + + legacy_shutdown.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(50)); + + // --- Mill-RPC: persistent connection, multiplexed --- + let mill_state = Arc::new(SharedState::new()); + let mill_el = Arc::new(mill_io::EventLoop::new(4, 1024, 100).unwrap()); + + let mill_addr: SocketAddr = "127.0.0.1:19890".parse().unwrap(); + let svc = mill::BenchServiceImpl { state: mill_state }; + let _mill_server = mill_rpc::RpcServer::builder() + .bind(mill_addr) + .service(mill::bench_service::server(svc)) + .build(&mill_el); + + match _mill_server { + Ok(_server) => { + let el = mill_el.clone(); + thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(100)); + + let client_el = Arc::new(mill_io::EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + thread::spawn(move || { + let _ = cel.run(); + }); + thread::sleep(Duration::from_millis(50)); + + let client = mill::make_client(mill_addr, &client_el); + + group.bench_function("mill_rpc_persistent_conn", |b| { + b.iter(|| { + for _ in 0..burst_size { + let resp = client.ping(); + black_box(resp).unwrap(); + } + }); + }); + + client_el.stop(); + mill_el.stop(); + } + Err(e) => { + eprintln!("Mill-RPC burst server failed (skipping): {}", e); + } + } + + group.finish(); +} + +criterion_group!( + benches, + bench_ping, + bench_get_utxos, + bench_echo, + bench_sequential_burst, +); +criterion_main!(benches); diff --git a/mill-rpc/mill-rpc-macros/src/lib.rs b/mill-rpc/mill-rpc-macros/src/lib.rs index 1d5ef7b..9007182 100644 --- a/mill-rpc/mill-rpc-macros/src/lib.rs +++ b/mill-rpc/mill-rpc-macros/src/lib.rs @@ -361,7 +361,8 @@ fn generate_service_module(def: ServiceDef) -> syn::Result syn::Result Date: Thu, 12 Mar 2026 23:06:36 +0200 Subject: [PATCH 7/7] comments --- mill-rpc/examples/multi_service_server.rs | 3 +++ mill-rpc/src/client.rs | 26 +++++++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/mill-rpc/examples/multi_service_server.rs b/mill-rpc/examples/multi_service_server.rs index b27b2a6..8fdff20 100644 --- a/mill-rpc/examples/multi_service_server.rs +++ b/mill-rpc/examples/multi_service_server.rs @@ -56,11 +56,14 @@ impl math_service::Service for MathImpl { if n < 4 { return true; } + + #[allow(clippy::incompatible_msrv)] if n.is_multiple_of(2) { return false; } let mut i = 3; while i * i <= n { + #[allow(clippy::incompatible_msrv)] if n.is_multiple_of(i) { return false; } diff --git a/mill-rpc/src/client.rs b/mill-rpc/src/client.rs index 7fd3df3..6e8fe0e 100644 --- a/mill-rpc/src/client.rs +++ b/mill-rpc/src/client.rs @@ -35,7 +35,7 @@ pub struct RpcClient { tcp_client: Mutex>, shared: Arc, next_request_id: AtomicU64, - timeout: Duration, + timeout: AtomicU64, } impl RpcClient { @@ -62,13 +62,18 @@ impl RpcClient { tcp_client: Mutex::new(tcp_client), shared, next_request_id: AtomicU64::new(1), - timeout: Duration::from_secs(30), + timeout: AtomicU64::new(30 * 1000), })) } /// Set the default timeout for RPC calls. - pub fn set_timeout(&mut self, timeout: Duration) { - self.timeout = timeout; + pub fn set_timeout(&self, timeout: Duration) { + self.timeout + .store(timeout.as_millis() as u64, Ordering::SeqCst); + } + + fn timeout(&self) -> Duration { + Duration::from_millis(self.timeout.load(Ordering::SeqCst)) } /// Send a request and wait for a response (blocking). @@ -93,15 +98,18 @@ impl RpcClient { } let frame = Frame::request(request_id, service_id, method_id, payload, false); - { + let send_result = { let client = self.tcp_client.lock().unwrap(); - client - .send(&frame.encode()) - .map_err(|e| RpcError::unavailable(format!("Send failed: {}", e)))?; + client.send(&frame.encode()) + }; + if let Err(e) = send_result { + let mut pending = self.shared.pending.lock().unwrap(); + pending.remove(&request_id); + return Err(RpcError::unavailable(format!("Send failed: {}", e))); } let mut pending = self.shared.pending.lock().unwrap(); - let deadline = std::time::Instant::now() + self.timeout; + let deadline = std::time::Instant::now() + self.timeout(); loop { if let Some(req) = pending.get(&request_id) {