diff --git a/Cargo.toml b/Cargo.toml index fa25026..9087d3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,11 @@ [workspace] -members = ["mill-io", "mill-net"] +members = [ + "mill-io", + "mill-net", + "mill-rpc", + "mill-rpc/mill-rpc-core", + "mill-rpc/mill-rpc-macros", +] resolver = "2" [workspace.package] diff --git a/README.md b/README.md index f93ebe8..a4403d4 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,18 @@ A lightweight, production-ready event loop library for Rust that provides effici - **Thread pool integration**: Configurable worker threads for handling I/O events - **Compute pool**: Dedicated priority-based thread pool for CPU-intensive tasks - **High-level networking**: High-level server/client components based on mill-io +- **RPC framework**: macro-driven RPC with type-safe clients and servers - **Object pooling**: Reduces allocation overhead for frequent operations - **Clean API**: Simple registration and handler interface +## Crates + +| Crate | Description | +| ------------------------ | -------------------------------------------- | +| [`mill-io`](./mill-io) | Core reactor-based event loop | +| [`mill-net`](./mill-net) | High-level TCP server/client networking | +| [`mill-rpc`](./mill-rpc) | Macro-driven RPC framework (server + client) | + ## Installation For the core event loop only: @@ -28,6 +37,13 @@ For high-level networking (includes mill-io as dependency): mill-net = "2.0.1" ``` +For the RPC framework (includes mill-io and mill-net): + +```toml +[dependencies] +mill-rpc = { path = "mill-rpc" } +``` + ## Architecture For detailed architectural documentation, see [Architecture Guide](./docs/Arch.md). 
diff --git a/mill-rpc/Cargo.toml b/mill-rpc/Cargo.toml new file mode 100644 index 0000000..fe6ae21 --- /dev/null +++ b/mill-rpc/Cargo.toml @@ -0,0 +1,63 @@ +[package] +name = "mill-rpc" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true +description = "An Axum-inspired RPC framework built on Mill-IO" + +[dependencies] +mill-rpc-core = { path = "mill-rpc-core" } +mill-rpc-macros = { path = "mill-rpc-macros" } +mill-io = { path = "../mill-io" } +mill-net = { path = "../mill-net" } +serde = { version = "1", features = ["derive"] } +bincode = "1" +log = "0.4" +mio = { version = "1", features = ["os-poll", "net"] } + +[dev-dependencies] +env_logger = "0.11" +criterion = { version = "0.5", features = ["html_reports"] } +serde_cbor = "0.11" + +[[example]] +name = "calculator_server" +path = "examples/calculator_server.rs" + +[[example]] +name = "calculator_client" +path = "examples/calculator_client.rs" + +[[example]] +name = "echo_server" +path = "examples/echo_server.rs" + +[[example]] +name = "echo_client" +path = "examples/echo_client.rs" + +[[example]] +name = "kv_server" +path = "examples/kv_server.rs" + +[[example]] +name = "kv_client" +path = "examples/kv_client.rs" + +[[example]] +name = "multi_service_server" +path = "examples/multi_service_server.rs" + +[[example]] +name = "multi_service_client" +path = "examples/multi_service_client.rs" + +[[example]] +name = "concurrent_clients" +path = "examples/concurrent_clients.rs" + +[[bench]] +name = "rpc_comparison" +harness = false diff --git a/mill-rpc/README.md b/mill-rpc/README.md new file mode 100644 index 0000000..ed827b6 --- /dev/null +++ b/mill-rpc/README.md @@ -0,0 +1,138 @@ +# mill-rpc + +An RPC framework built on [`mill-io`](../mill-io) and [`mill-net`](../mill-net). Define services declaratively, get type-safe clients and servers — no async runtime required. 
+ +## Features + +- **Zero async**: Handlers are plain synchronous functions +- **Macro-driven**: `mill_rpc::service!` generates a module with server trait, client struct, and dispatch logic +- **Selective generation**: Use `#[server]`, `#[client]`, or both (default) +- **Multi-service**: Host multiple services on a single port with automatic routing +- **Pluggable codecs**: Bincode by default, extensible +- **Binary wire protocol**: Efficient framing with one-way calls and ping/pong + +## Quick Start + +### Define a service + +```rust +mill_rpc::service! { + service Calculator { + fn add(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; + } +} +``` + +This generates a `calculator` module containing: +- `calculator::Service`: trait to implement on the server +- `calculator::server(impl)`: wraps your impl for registration +- `calculator::Client`: struct with typed RPC methods +- `calculator::methods`: method ID constants + +### Server + +```rust +struct MyCalc; + +impl calculator::Service for MyCalc { + fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { a + b } + fn multiply(&self, _ctx: &RpcContext, a: i64, b: i64) -> i64 { a * b } +} + +fn main() { + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let _server = RpcServer::builder() + .bind("127.0.0.1:9001".parse().unwrap()) + .service(calculator::server(MyCalc)) + .build(&event_loop) + .unwrap(); + + event_loop.run().unwrap(); +} +``` + +### Client + +```rust +let transport = RpcClient::connect(addr, &event_loop).unwrap(); +let client = calculator::Client::new(transport, Codec::bincode(), 0); + +let sum = client.add(10, 25).unwrap(); // 35 +let prod = client.multiply(7, 8).unwrap(); // 56 +``` + +## Selective Generation + +Generate only what you need: + +```rust +// Server crate: no client code generated +mill_rpc::service! { + #[server] + service Calculator { + fn add(a: i32, b: i32) -> i32; + } +} + +// Client crate: no server code generated +mill_rpc::service! 
{ + #[client] + service Calculator { + fn add(a: i32, b: i32) -> i32; + } +} + +// Both (default): for tests, examples, or single-binary apps +mill_rpc::service! { + service Calculator { + fn add(a: i32, b: i32) -> i32; + } +} +``` + +## Multi-Service Server + +```rust +mill_rpc::service! { + #[server] + service MathService { + fn factorial(n: u64) -> u64; + } +} + +mill_rpc::service! { + #[server] + service StringService { + fn reverse(s: String) -> String; + } +} + +let _server = RpcServer::builder() + .bind(addr) + .service(math_service::server(MathImpl)) // service_id = 0 + .service(string_service::server(StringImpl)) // service_id = 1 + .build(&event_loop)?; + +// Client side: share one connection +let math = math_service::Client::new(transport.clone(), codec, 0); +let strings = string_service::Client::new(transport, codec, 1); +``` + +## Examples + +```bash +# Terminal 1 # Terminal 2 +cargo run --example calculator_server cargo run --example calculator_client +cargo run --example echo_server cargo run --example echo_client +cargo run --example kv_server cargo run --example kv_client +cargo run --example multi_service_server cargo run --example multi_service_client + +# Self-contained stress test +cargo run --example concurrent_clients +``` + +## License + +Licensed under the Apache License, Version 2.0. See [LICENSE](../LICENSE) for details. diff --git a/mill-rpc/benches/rpc_comparison.rs b/mill-rpc/benches/rpc_comparison.rs new file mode 100644 index 0000000..30f2a3d --- /dev/null +++ b/mill-rpc/benches/rpc_comparison.rs @@ -0,0 +1,535 @@ +//! Benchmark comparison: Legacy hand-rolled RPC vs Mill-RPC. +//! +//! The legacy approach mirrors how coinswap's maker RPC works: +//! - `TcpListener` with `set_nonblocking(true)` +//! - Busy-poll loop with `sleep(HEART_BEAT_INTERVAL)` +//! - Manual `serde_cbor` serialization of request/response enums +//! - Single-threaded, one request at a time +//! +//! Mill-RPC uses: +//! - mill-net's reactor-based TcpServer +//! 
- Auto-generated dispatch from `mill_rpc::service!` +//! - Thread-pool for concurrent request handling +//! - Binary framing protocol with bincode + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use serde::{Deserialize, Serialize}; +use std::io::{Read, Write}; +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::Duration; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct UtxoEntry { + txid: String, + vout: u32, + value: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct BalanceInfo { + confirmed: u64, + unconfirmed: u64, + total: u64, +} + +/// Simulated state that both servers share. +struct SharedState { + utxos: Vec<UtxoEntry>, + balances: BalanceInfo, + counter: RwLock<u64>, +} + +impl SharedState { + fn new() -> Self { + let utxos: Vec<UtxoEntry> = (0..20) + .map(|i| UtxoEntry { + txid: format!( + "abcdef{:04x}abcdef{:04x}abcdef{:04x}abcdef{:04x}", + i, i, i, i + ), + vout: i, + value: (i as u64 + 1) * 100_000, + }) + .collect(); + + Self { + utxos, + balances: BalanceInfo { + confirmed: 5_000_000, + unconfirmed: 200_000, + total: 5_200_000, + }, + counter: RwLock::new(0), + } + } +} + +/// Legacy RPC (mirrors coinswap's approach) +mod legacy { + use super::*; + + #[derive(Debug, Serialize, Deserialize)] + pub enum RpcMsgReq { + Ping, + GetUtxos, + GetBalances, + Increment, + Echo(String), + } + + #[derive(Debug, Serialize, Deserialize)] + pub enum RpcMsgResp { + Pong, + UtxoResp { utxos: Vec<UtxoEntry> }, + BalancesResp(BalanceInfo), + IncrementResp(u64), + EchoResp(String), + ServerError(String), + } + + /// Length-prefixed message: 4-byte LE length + cbor payload (mirrors coinswap's read_message/send_message).
+ pub fn send_message(stream: &mut TcpStream, msg: &RpcMsgResp) -> std::io::Result<()> { + let data = serde_cbor::to_vec(msg).unwrap(); + let len = (data.len() as u32).to_le_bytes(); + stream.write_all(&len)?; + stream.write_all(&data)?; + stream.flush()?; + Ok(()) + } + + pub fn read_message(stream: &mut TcpStream) -> std::io::Result<Vec<u8>> { + let mut len_buf = [0u8; 4]; + stream.read_exact(&mut len_buf)?; + let len = u32::from_le_bytes(len_buf) as usize; + let mut buf = vec![0u8; len]; + stream.read_exact(&mut buf)?; + Ok(buf) + } + + pub fn send_request(stream: &mut TcpStream, req: &RpcMsgReq) -> std::io::Result<()> { + let data = serde_cbor::to_vec(req).unwrap(); + let len = (data.len() as u32).to_le_bytes(); + stream.write_all(&len)?; + stream.write_all(&data)?; + stream.flush()?; + Ok(()) + } + + pub fn read_response(stream: &mut TcpStream) -> std::io::Result<RpcMsgResp> { + let data = read_message(stream)?; + Ok(serde_cbor::from_slice(&data).unwrap()) + } + + fn handle_request(state: &Arc<SharedState>, socket: &mut TcpStream) -> std::io::Result<()> { + let msg_bytes = read_message(socket)?; + let rpc_request: RpcMsgReq = serde_cbor::from_slice(&msg_bytes).unwrap(); + + let resp = match rpc_request { + RpcMsgReq::Ping => RpcMsgResp::Pong, + RpcMsgReq::GetUtxos => RpcMsgResp::UtxoResp { + utxos: state.utxos.clone(), + }, + RpcMsgReq::GetBalances => RpcMsgResp::BalancesResp(state.balances.clone()), + RpcMsgReq::Increment => { + let mut counter = state.counter.write().unwrap(); + *counter += 1; + RpcMsgResp::IncrementResp(*counter) + } + RpcMsgReq::Echo(msg) => RpcMsgResp::EchoResp(msg), + }; + + send_message(socket, &resp)?; + Ok(()) + } + + /// Start the legacy server and return (addr, shutdown_flag).
+ pub fn start_server(state: Arc<SharedState>) -> (SocketAddr, Arc<AtomicBool>) { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + let shutdown = Arc::new(AtomicBool::new(false)); + let shutdown_clone = shutdown.clone(); + + listener.set_nonblocking(true).unwrap(); + + thread::spawn(move || { + while !shutdown_clone.load(Ordering::Relaxed) { + match listener.accept() { + Ok((mut stream, _)) => { + stream + .set_read_timeout(Some(Duration::from_secs(20))) + .unwrap(); + stream + .set_write_timeout(Some(Duration::from_secs(20))) + .unwrap(); + if let Err(e) = handle_request(&state, &mut stream) { + let _ = send_message( + &mut stream, + &RpcMsgResp::ServerError(format!("{e:?}")), + ); + } + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {} + Err(_) => {} + } + // Mirrors HEART_BEAT_INTERVAL in coinswap (typically 3 seconds, + // but we use 1ms here so the benchmark isn't dominated by sleep) + thread::sleep(Duration::from_millis(1)); + } + }); + + // Wait for server to be ready + thread::sleep(Duration::from_millis(50)); + (addr, shutdown) + } + + /// Make a single request-response roundtrip (one TCP connection per call, + /// same as coinswap's client pattern). + pub fn call(addr: SocketAddr, req: &RpcMsgReq) -> RpcMsgResp { + let mut stream = TcpStream::connect(addr).unwrap(); + stream + .set_read_timeout(Some(Duration::from_secs(5))) + .unwrap(); + stream + .set_write_timeout(Some(Duration::from_secs(5))) + .unwrap(); + send_request(&mut stream, req).unwrap(); + read_response(&mut stream).unwrap() + } +} + +/// Mill-RPC alternative +mod mill { + use super::*; + use mill_io::EventLoop; + use mill_rpc::prelude::*; + + mill_rpc::service!
{ + service BenchService { + fn ping() -> (); + fn get_utxos() -> Vec<UtxoEntry>; + fn get_balances() -> BalanceInfo; + fn increment() -> u64; + fn echo(msg: String) -> String; + } + } + + pub struct BenchServiceImpl { + pub state: Arc<SharedState>, + } + + impl bench_service::Service for BenchServiceImpl { + fn ping(&self, _ctx: &RpcContext) {} + + fn get_utxos(&self, _ctx: &RpcContext) -> Vec<UtxoEntry> { + self.state.utxos.clone() + } + + fn get_balances(&self, _ctx: &RpcContext) -> BalanceInfo { + self.state.balances.clone() + } + + fn increment(&self, _ctx: &RpcContext) -> u64 { + let mut counter = self.state.counter.write().unwrap(); + *counter += 1; + *counter + } + + fn echo(&self, _ctx: &RpcContext, msg: String) -> String { + msg + } + } + + /// Make a call using mill-rpc client. + pub fn make_client(addr: SocketAddr, event_loop: &Arc<EventLoop>) -> bench_service::Client { + let transport = RpcClient::connect(addr, event_loop).unwrap(); + bench_service::Client::new(transport, Codec::bincode(), 0) + } +} + +fn bench_ping(c: &mut Criterion) { + let mut group = c.benchmark_group("ping_roundtrip"); + group.throughput(Throughput::Elements(1)); + + let state = Arc::new(SharedState::new()); + + // --- Legacy --- + let (legacy_addr, legacy_shutdown) = legacy::start_server(state.clone()); + + group.bench_function("legacy", |b| { + b.iter(|| { + let resp = legacy::call(legacy_addr, &legacy::RpcMsgReq::Ping); + black_box(resp); + }); + }); + + legacy_shutdown.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(50)); + + // --- Mill-RPC --- + let mill_state = Arc::new(SharedState::new()); + let mill_el = Arc::new(mill_io::EventLoop::new(4, 1024, 100).unwrap()); + + let svc = mill::BenchServiceImpl { state: mill_state }; + + let mill_addr: SocketAddr = "127.0.0.1:19876".parse().unwrap(); + let _mill_server = mill_rpc::RpcServer::builder() + .bind(mill_addr) + .service(mill::bench_service::server(svc)) + .build(&mill_el); + + match _mill_server { + Ok(_server) => { + let el = mill_el.clone(); +
thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(100)); + + // Create a persistent client + let client_el = Arc::new(mill_io::EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + thread::spawn(move || { + let _ = cel.run(); + }); + thread::sleep(Duration::from_millis(50)); + + let client = mill::make_client(mill_addr, &client_el); + + group.bench_function("mill_rpc", |b| { + b.iter(|| { + let resp = client.ping(); + black_box(resp).unwrap(); + }); + }); + + client_el.stop(); + mill_el.stop(); + } + Err(e) => { + eprintln!("Mill-RPC server failed to start (skipping): {}", e); + } + } + + group.finish(); +} + +fn bench_get_utxos(c: &mut Criterion) { + let mut group = c.benchmark_group("get_utxos_roundtrip"); + group.throughput(Throughput::Elements(1)); + + let state = Arc::new(SharedState::new()); + + // --- Legacy --- + let (legacy_addr, legacy_shutdown) = legacy::start_server(state.clone()); + + group.bench_function("legacy", |b| { + b.iter(|| { + let resp = legacy::call(legacy_addr, &legacy::RpcMsgReq::GetUtxos); + black_box(resp); + }); + }); + + legacy_shutdown.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(50)); + + // --- Mill-RPC --- + let mill_state = Arc::new(SharedState::new()); + let mill_el = Arc::new(mill_io::EventLoop::new(4, 1024, 100).unwrap()); + + let mill_addr: SocketAddr = "127.0.0.1:19877".parse().unwrap(); + let svc = mill::BenchServiceImpl { state: mill_state }; + let _mill_server = mill_rpc::RpcServer::builder() + .bind(mill_addr) + .service(mill::bench_service::server(svc)) + .build(&mill_el); + + match _mill_server { + Ok(_server) => { + let el = mill_el.clone(); + thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(100)); + + let client_el = Arc::new(mill_io::EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + thread::spawn(move || { + let _ = cel.run(); + }); + thread::sleep(Duration::from_millis(50)); + + 
let client = mill::make_client(mill_addr, &client_el); + + group.bench_function("mill_rpc", |b| { + b.iter(|| { + let resp = client.get_utxos(); + black_box(resp).unwrap(); + }); + }); + + client_el.stop(); + mill_el.stop(); + } + Err(e) => { + eprintln!("Mill-RPC server failed to start (skipping): {}", e); + } + } + + group.finish(); +} + +fn bench_echo(c: &mut Criterion) { + let mut group = c.benchmark_group("echo_roundtrip"); + + let state = Arc::new(SharedState::new()); + + for size in [16, 256, 4096] { + let msg = "x".repeat(size); + group.throughput(Throughput::Bytes(size as u64)); + + // --- Legacy --- + let (legacy_addr, legacy_shutdown) = legacy::start_server(state.clone()); + + group.bench_with_input(BenchmarkId::new("legacy", size), &msg, |b, msg| { + b.iter(|| { + let resp = legacy::call(legacy_addr, &legacy::RpcMsgReq::Echo(msg.clone())); + black_box(resp); + }); + }); + + legacy_shutdown.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(50)); + } + + // Mill-RPC echo with different sizes + for size in [16, 256, 4096] { + let msg = "x".repeat(size); + + let mill_state = Arc::new(SharedState::new()); + let mill_el = Arc::new(mill_io::EventLoop::new(4, 1024, 100).unwrap()); + + let port = 19878 + size as u16; + let mill_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap(); + let svc = mill::BenchServiceImpl { state: mill_state }; + let _mill_server = mill_rpc::RpcServer::builder() + .bind(mill_addr) + .service(mill::bench_service::server(svc)) + .build(&mill_el); + + match _mill_server { + Ok(_server) => { + let el = mill_el.clone(); + thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(100)); + + let client_el = Arc::new(mill_io::EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + thread::spawn(move || { + let _ = cel.run(); + }); + thread::sleep(Duration::from_millis(50)); + + let client = mill::make_client(mill_addr, &client_el); + + 
group.bench_with_input(BenchmarkId::new("mill_rpc", size), &msg, |b, msg| { + b.iter(|| { + let resp = client.echo(msg.clone()); + black_box(resp).unwrap(); + }); + }); + + client_el.stop(); + mill_el.stop(); + } + Err(e) => { + eprintln!("Mill-RPC echo server failed (skipping): {}", e); + } + } + } + + group.finish(); +} + +fn bench_sequential_burst(c: &mut Criterion) { + let mut group = c.benchmark_group("sequential_burst"); + let burst_size = 50u64; + group.throughput(Throughput::Elements(burst_size)); + + let state = Arc::new(SharedState::new()); + + // --- Legacy: each call opens a new TCP connection (coinswap pattern) --- + let (legacy_addr, legacy_shutdown) = legacy::start_server(state.clone()); + + group.bench_function("legacy_new_conn_per_call", |b| { + b.iter(|| { + for _ in 0..burst_size { + let resp = legacy::call(legacy_addr, &legacy::RpcMsgReq::Ping); + black_box(resp); + } + }); + }); + + legacy_shutdown.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(50)); + + // --- Mill-RPC: persistent connection, multiplexed --- + let mill_state = Arc::new(SharedState::new()); + let mill_el = Arc::new(mill_io::EventLoop::new(4, 1024, 100).unwrap()); + + let mill_addr: SocketAddr = "127.0.0.1:19890".parse().unwrap(); + let svc = mill::BenchServiceImpl { state: mill_state }; + let _mill_server = mill_rpc::RpcServer::builder() + .bind(mill_addr) + .service(mill::bench_service::server(svc)) + .build(&mill_el); + + match _mill_server { + Ok(_server) => { + let el = mill_el.clone(); + thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(100)); + + let client_el = Arc::new(mill_io::EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + thread::spawn(move || { + let _ = cel.run(); + }); + thread::sleep(Duration::from_millis(50)); + + let client = mill::make_client(mill_addr, &client_el); + + group.bench_function("mill_rpc_persistent_conn", |b| { + b.iter(|| { + for _ in 0..burst_size { + let resp = 
client.ping(); + black_box(resp).unwrap(); + } + }); + }); + + client_el.stop(); + mill_el.stop(); + } + Err(e) => { + eprintln!("Mill-RPC burst server failed (skipping): {}", e); + } + } + + group.finish(); +} + +criterion_group!( + benches, + bench_ping, + bench_get_utxos, + bench_echo, + bench_sequential_burst, +); +criterion_main!(benches); diff --git a/mill-rpc/examples/calculator_client.rs b/mill-rpc/examples/calculator_client.rs new file mode 100644 index 0000000..0f645f3 --- /dev/null +++ b/mill-rpc/examples/calculator_client.rs @@ -0,0 +1,64 @@ +//! Calculator RPC client. +//! +//! Run the server first: cargo run --example calculator_server +//! Then run: cargo run --example calculator_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +// Define the service — only generate client side +mill_rpc::service! { + #[client] + service Calculator { + fn add(a: i32, b: i32) -> i32; + fn subtract(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; + fn divide(a: f64, b: f64) -> f64; + fn negate(x: i32) -> i32; + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + let el = event_loop.clone(); + let handle = thread::spawn(move || { + el.run().unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let addr = "127.0.0.1:9001".parse().unwrap(); + let transport = + RpcClient::connect(addr, &event_loop).expect("Failed to connect to calculator server"); + + let client = calculator::Client::new(transport, Codec::bincode(), 0); + + println!("Connected to calculator server"); + + let sum = client.add(10, 25).unwrap(); + println!("10 + 25 = {}", sum); + + let diff = client.subtract(100, 42).unwrap(); + println!("100 - 42 = {}", diff); + + let product = client.multiply(7, 8).unwrap(); + println!("7 * 8 = {}", product); + + let quotient = client.divide(355.0, 113.0).unwrap(); + println!("355 / 113 = {:.6}", quotient); + + 
let neg = client.negate(42).unwrap(); + println!("negate(42) = {}", neg); + + let nan = client.divide(1.0, 0.0).unwrap(); + println!("1.0 / 0.0 = {} (NaN: {})", nan, nan.is_nan()); + + println!("\nAll calculations completed!"); + + event_loop.stop(); + let _ = handle.join(); +} diff --git a/mill-rpc/examples/calculator_server.rs b/mill-rpc/examples/calculator_server.rs new file mode 100644 index 0000000..42f2362 --- /dev/null +++ b/mill-rpc/examples/calculator_server.rs @@ -0,0 +1,65 @@ +//! Basic calculator RPC server. +//! +//! Run with: cargo run --example calculator_server +//! Then connect with: cargo run --example calculator_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; + +// Define the service — only generate server side +mill_rpc::service! { + #[server] + service Calculator { + fn add(a: i32, b: i32) -> i32; + fn subtract(a: i32, b: i32) -> i32; + fn multiply(a: i64, b: i64) -> i64; + fn divide(a: f64, b: f64) -> f64; + fn negate(x: i32) -> i32; + } +} + +struct MyCalculator; + +impl calculator::Service for MyCalculator { + fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { + println!(" add({}, {}) = {}", a, b, a + b); + a + b + } + + fn subtract(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { + println!(" subtract({}, {}) = {}", a, b, a - b); + a - b + } + + fn multiply(&self, _ctx: &RpcContext, a: i64, b: i64) -> i64 { + println!(" multiply({}, {}) = {}", a, b, a * b); + a * b + } + + fn divide(&self, _ctx: &RpcContext, a: f64, b: f64) -> f64 { + let result = if b == 0.0 { f64::NAN } else { a / b }; + println!(" divide({}, {}) = {}", a, b, result); + result + } + + fn negate(&self, _ctx: &RpcContext, x: i32) -> i32 { + println!(" negate({}) = {}", x, -x); + -x + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let addr = "127.0.0.1:9001".parse().unwrap(); + let _server = RpcServer::builder() + .bind(addr) + 
.service(calculator::server(MyCalculator)) + .build(&event_loop) + .expect("Failed to start calculator server"); + + println!("Calculator server listening on {}", addr); + event_loop.run().unwrap(); +} diff --git a/mill-rpc/examples/concurrent_clients.rs b/mill-rpc/examples/concurrent_clients.rs new file mode 100644 index 0000000..d386379 --- /dev/null +++ b/mill-rpc/examples/concurrent_clients.rs @@ -0,0 +1,120 @@ +//! Concurrent clients stress test. +//! +//! Run with: cargo run --example concurrent_clients + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::{Duration, Instant}; + +// Both sides in one binary +mill_rpc::service! { + service Counter { + fn increment() -> u64; + fn get() -> u64; + } +} + +struct AtomicCounter { + value: AtomicU64, +} + +impl counter::Service for AtomicCounter { + fn increment(&self, _ctx: &RpcContext) -> u64 { + self.value.fetch_add(1, Ordering::SeqCst) + 1 + } + + fn get(&self, _ctx: &RpcContext) -> u64 { + self.value.load(Ordering::SeqCst) + } +} + +fn main() { + env_logger::init(); + + let num_clients = 4; + let requests_per_client = 100; + + let server_el = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + let addr = "127.0.0.1:9005".parse().unwrap(); + + let _server = RpcServer::builder() + .bind(addr) + .service(counter::server(AtomicCounter { + value: AtomicU64::new(0), + })) + .build(&server_el) + .expect("Failed to start server"); + + let sel = server_el.clone(); + let server_thread = thread::spawn(move || { + sel.run().unwrap(); + }); + thread::sleep(Duration::from_millis(100)); + + println!( + "Spawning {} clients, {} requests each...\n", + num_clients, requests_per_client + ); + + let start = Instant::now(); + let mut handles = Vec::new(); + + for client_id in 0..num_clients { + let handle = thread::spawn(move || { + let client_el = Arc::new(EventLoop::new(1, 256, 50).unwrap()); + let cel = client_el.clone(); + let 
el_thread = thread::spawn(move || { + cel.run().unwrap(); + }); + thread::sleep(Duration::from_millis(20)); + + let transport = RpcClient::connect(addr, &client_el).unwrap(); + let client = counter::Client::new(transport, Codec::bincode(), 0); + + let mut results = Vec::new(); + for _ in 0..requests_per_client { + results.push(client.increment().unwrap()); + } + + println!( + " Client {} done: first={}, last={}", + client_id, + results.first().unwrap(), + results.last().unwrap() + ); + + client_el.stop(); + let _ = el_thread.join(); + results + }); + handles.push(handle); + } + + let mut all: Vec<u64> = handles + .into_iter() + .flat_map(|h| h.join().unwrap()) + .collect(); + + let elapsed = start.elapsed(); + all.sort(); + all.dedup(); + + let total = (num_clients * requests_per_client) as usize; + println!("\n--- Results ---"); + println!("Total requests: {}", total); + println!("Unique values: {}", all.len()); + println!("Time: {:?}", elapsed); + println!( + "Throughput: {:.0} req/s", + total as f64 / elapsed.as_secs_f64() + ); + + assert_eq!(all.len(), total, "No lost updates"); + println!("\nConcurrency test passed!"); + + server_el.stop(); + let _ = server_thread.join(); +} diff --git a/mill-rpc/examples/echo_client.rs b/mill-rpc/examples/echo_client.rs new file mode 100644 index 0000000..9f6d333 --- /dev/null +++ b/mill-rpc/examples/echo_client.rs @@ -0,0 +1,62 @@ +//! Echo RPC client. +//! +//! Run the server first: cargo run --example echo_server +//! Then run: cargo run --example echo_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +mill_rpc::service!
{ + #[client] + service Echo { + fn echo(message: String) -> String; + fn echo_uppercase(message: String) -> String; + fn echo_repeat(message: String, times: u32) -> String; + fn request_count() -> u64; + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + let el = event_loop.clone(); + let handle = thread::spawn(move || { + el.run().unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let addr = "127.0.0.1:9002".parse().unwrap(); + let transport = + RpcClient::connect(addr, &event_loop).expect("Failed to connect to echo server"); + + let client = echo::Client::new(transport, Codec::bincode(), 0); + + // Basic echo + let reply = client.echo("Hello, Mill-RPC!".into()).unwrap(); + println!("echo: {}", reply); + assert_eq!(reply, "Hello, Mill-RPC!"); + + // Uppercase + let reply = client.echo_uppercase("hello world".into()).unwrap(); + println!("uppercase: {}", reply); + assert_eq!(reply, "HELLO WORLD"); + + // Repeat + let reply = client.echo_repeat("ha".into(), 3).unwrap(); + println!("repeat: {}", reply); + assert_eq!(reply, "hahaha"); + + // Request count + let count = client.request_count().unwrap(); + println!("server handled {} requests", count); + assert_eq!(count, 3); // echo + uppercase + repeat + + println!("\nAll echo tests passed!"); + + event_loop.stop(); + let _ = handle.join(); +} diff --git a/mill-rpc/examples/echo_server.rs b/mill-rpc/examples/echo_server.rs new file mode 100644 index 0000000..306d7cc --- /dev/null +++ b/mill-rpc/examples/echo_server.rs @@ -0,0 +1,74 @@ +//! Echo RPC server. +//! +//! Run with: cargo run --example echo_server +//! Then connect with: cargo run --example echo_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +mill_rpc::service! 
{ + #[server] + service Echo { + fn echo(message: String) -> String; + fn echo_uppercase(message: String) -> String; + fn echo_repeat(message: String, times: u32) -> String; + fn request_count() -> u64; + } +} + +struct EchoImpl { + counter: AtomicU64, +} + +impl EchoImpl { + fn new() -> Self { + Self { + counter: AtomicU64::new(0), + } + } +} + +impl echo::Service for EchoImpl { + fn echo(&self, _ctx: &RpcContext, message: String) -> String { + self.counter.fetch_add(1, Ordering::Relaxed); + println!(" echo: {:?}", message); + message + } + + fn echo_uppercase(&self, _ctx: &RpcContext, message: String) -> String { + self.counter.fetch_add(1, Ordering::Relaxed); + let upper = message.to_uppercase(); + println!(" echo_uppercase: {:?} -> {:?}", message, upper); + upper + } + + fn echo_repeat(&self, _ctx: &RpcContext, message: String, times: u32) -> String { + self.counter.fetch_add(1, Ordering::Relaxed); + let result = message.repeat(times as usize); + println!(" echo_repeat({:?}, {}) -> {:?}", message, times, result); + result + } + + fn request_count(&self, _ctx: &RpcContext) -> u64 { + let count = self.counter.load(Ordering::Relaxed); + println!(" request_count -> {}", count); + count + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let addr = "127.0.0.1:9002".parse().unwrap(); + let _server = RpcServer::builder() + .bind(addr) + .service(echo::server(EchoImpl::new())) + .build(&event_loop) + .expect("Failed to start echo server"); + + println!("Echo server listening on {}", addr); + event_loop.run().unwrap(); +} diff --git a/mill-rpc/examples/kv_client.rs b/mill-rpc/examples/kv_client.rs new file mode 100644 index 0000000..31e5179 --- /dev/null +++ b/mill-rpc/examples/kv_client.rs @@ -0,0 +1,80 @@ +//! Key-value store RPC client. +//! +//! Run the server first: cargo run --example kv_server +//! 
Then run: cargo run --example kv_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +mill_rpc::service! { + #[client] + service KeyValue { + fn get(key: String) -> Option; + fn set(key: String, value: String) -> Option; + fn delete(key: String) -> bool; + fn keys() -> Vec; + fn len() -> u64; + fn is_empty() -> bool; + fn clear() -> u64; + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + let el = event_loop.clone(); + let handle = thread::spawn(move || { + el.run().unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let addr = "127.0.0.1:9003".parse().unwrap(); + let transport = RpcClient::connect(addr, &event_loop).expect("Failed to connect to KV server"); + + let kv = key_value::Client::new(transport, Codec::bincode(), 0); + + println!("=== Key-Value Store Client ===\n"); + + let len = kv.len().unwrap(); + println!("Initial store size: {}", len); + + kv.set("name".into(), "Alice".into()).unwrap(); + println!("SET name=Alice"); + + kv.set("city".into(), "Berlin".into()).unwrap(); + println!("SET city=Berlin"); + + kv.set("lang".into(), "Rust".into()).unwrap(); + println!("SET lang=Rust"); + + let val = kv.get("name".into()).unwrap(); + println!("GET name -> {:?}", val); + + let val = kv.get("missing".into()).unwrap(); + println!("GET missing -> {:?}", val); + + let mut keys = kv.keys().unwrap(); + keys.sort(); + println!("KEYS -> {:?}", keys); + + let old = kv.set("name".into(), "Bob".into()).unwrap(); + println!("SET name=Bob (old: {:?})", old); + + let existed = kv.delete("city".into()).unwrap(); + println!("DEL city -> existed: {}", existed); + + let len = kv.len().unwrap(); + println!("Store size: {}", len); + + let removed = kv.clear().unwrap(); + println!("CLEAR -> removed {} entries", removed); + + println!("\nAll KV tests passed!"); + + event_loop.stop(); + let _ = handle.join(); +} diff --git 
a/mill-rpc/examples/kv_server.rs b/mill-rpc/examples/kv_server.rs new file mode 100644 index 0000000..2fec3b4 --- /dev/null +++ b/mill-rpc/examples/kv_server.rs @@ -0,0 +1,98 @@ +//! In-memory key-value store RPC server. +//! +//! Run with: cargo run --example kv_server +//! Then connect with: cargo run --example kv_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +mill_rpc::service! { + #[server] + service KeyValue { + fn get(key: String) -> Option; + fn set(key: String, value: String) -> Option; + fn delete(key: String) -> bool; + fn keys() -> Vec; + fn len() -> u64; + fn is_empty() -> bool; + fn clear() -> u64; + } +} + +struct KvStore { + data: RwLock>, +} + +impl KvStore { + fn new() -> Self { + Self { + data: RwLock::new(HashMap::new()), + } + } +} + +impl key_value::Service for KvStore { + fn get(&self, _ctx: &RpcContext, key: String) -> Option { + let data = self.data.read().unwrap(); + let result = data.get(&key).cloned(); + println!(" GET {:?} -> {:?}", key, result); + result + } + + fn set(&self, _ctx: &RpcContext, key: String, value: String) -> Option { + let mut data = self.data.write().unwrap(); + let old = data.insert(key.clone(), value.clone()); + println!(" SET {:?} = {:?} (old: {:?})", key, value, old); + old + } + + fn delete(&self, _ctx: &RpcContext, key: String) -> bool { + let mut data = self.data.write().unwrap(); + let existed = data.remove(&key).is_some(); + println!(" DEL {:?} -> existed: {}", key, existed); + existed + } + + fn keys(&self, _ctx: &RpcContext) -> Vec { + let data = self.data.read().unwrap(); + let keys: Vec = data.keys().cloned().collect(); + println!(" KEYS -> {:?}", keys); + keys + } + + fn len(&self, _ctx: &RpcContext) -> u64 { + let data = self.data.read().unwrap(); + let len = data.len() as u64; + println!(" LEN -> {}", len); + len + } + + fn is_empty(&self, _ctx: &RpcContext) -> bool { + self.data.read().unwrap().is_empty() + } + + fn 
clear(&self, _ctx: &RpcContext) -> u64 { + let mut data = self.data.write().unwrap(); + let count = data.len() as u64; + data.clear(); + println!(" CLEAR -> removed {} entries", count); + count + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let addr = "127.0.0.1:9003".parse().unwrap(); + let _server = RpcServer::builder() + .bind(addr) + .service(key_value::server(KvStore::new())) + .build(&event_loop) + .expect("Failed to start KV server"); + + println!("Key-Value server listening on {}", addr); + event_loop.run().unwrap(); +} diff --git a/mill-rpc/examples/multi_service_client.rs b/mill-rpc/examples/multi_service_client.rs new file mode 100644 index 0000000..4055233 --- /dev/null +++ b/mill-rpc/examples/multi_service_client.rs @@ -0,0 +1,78 @@ +//! Multi-service RPC client — calls two services on one server. +//! +//! Run the server first: cargo run --example multi_service_server +//! Then run: cargo run --example multi_service_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +mill_rpc::service! { + #[client] + service MathService { + fn factorial(n: u64) -> u64; + fn fibonacci(n: u32) -> u64; + fn is_prime(n: u64) -> bool; + fn gcd(a: u64, b: u64) -> u64; + } +} + +mill_rpc::service! 
{ + #[client] + service StringService { + fn reverse(s: String) -> String; + fn word_count(s: String) -> u32; + fn contains(haystack: String, needle: String) -> bool; + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(2, 1024, 100).unwrap()); + + let el = event_loop.clone(); + let handle = thread::spawn(move || { + el.run().unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let addr = "127.0.0.1:9004".parse().unwrap(); + let transport = RpcClient::connect(addr, &event_loop).expect("Failed to connect"); + + let math = math_service::Client::new(transport.clone(), Codec::bincode(), 0); + let strings = string_service::Client::new(transport, Codec::bincode(), 1); + + println!("=== Math Service ===\n"); + + let f = math.factorial(10).unwrap(); + println!("10! = {}", f); + + let fib = math.fibonacci(20).unwrap(); + println!("fib(20) = {}", fib); + + for n in [2, 7, 15, 17] { + let prime = math.is_prime(n).unwrap(); + println!("is_prime({}) = {}", n, prime); + } + + let g = math.gcd(48, 18).unwrap(); + println!("gcd(48, 18) = {}", g); + + println!("\n=== String Service ===\n"); + + let rev = strings.reverse("Hello, World!".into()).unwrap(); + println!("reverse(\"Hello, World!\") = {:?}", rev); + + let wc = strings.word_count("The quick brown fox".into()).unwrap(); + println!("word_count = {}", wc); + + let has = strings.contains("rustacean".into(), "rust".into()).unwrap(); + println!("contains(\"rustacean\", \"rust\") = {}", has); + + println!("\nAll multi-service tests passed!"); + + event_loop.stop(); + let _ = handle.join(); +} diff --git a/mill-rpc/examples/multi_service_server.rs b/mill-rpc/examples/multi_service_server.rs new file mode 100644 index 0000000..8fdff20 --- /dev/null +++ b/mill-rpc/examples/multi_service_server.rs @@ -0,0 +1,117 @@ +//! Multi-service RPC server — two services on one port. +//! +//! Run with: cargo run --example multi_service_server +//! 
Then connect with: cargo run --example multi_service_client + +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; + +mill_rpc::service! { + #[server] + service MathService { + fn factorial(n: u64) -> u64; + fn fibonacci(n: u32) -> u64; + fn is_prime(n: u64) -> bool; + fn gcd(a: u64, b: u64) -> u64; + } +} + +mill_rpc::service! { + #[server] + service StringService { + fn reverse(s: String) -> String; + fn word_count(s: String) -> u32; + fn contains(haystack: String, needle: String) -> bool; + } +} + +struct MathImpl; + +impl math_service::Service for MathImpl { + fn factorial(&self, _ctx: &RpcContext, n: u64) -> u64 { + (1..=n).product() + } + + fn fibonacci(&self, _ctx: &RpcContext, n: u32) -> u64 { + match n { + 0 => 0, + 1 => 1, + _ => { + let (mut a, mut b) = (0u64, 1u64); + for _ in 2..=n { + let tmp = a + b; + a = b; + b = tmp; + } + b + } + } + } + + fn is_prime(&self, _ctx: &RpcContext, n: u64) -> bool { + if n < 2 { + return false; + } + if n < 4 { + return true; + } + + #[allow(clippy::incompatible_msrv)] + if n.is_multiple_of(2) { + return false; + } + let mut i = 3; + while i * i <= n { + #[allow(clippy::incompatible_msrv)] + if n.is_multiple_of(i) { + return false; + } + i += 2; + } + true + } + + fn gcd(&self, _ctx: &RpcContext, mut a: u64, mut b: u64) -> u64 { + while b != 0 { + let tmp = b; + b = a % b; + a = tmp; + } + a + } +} + +struct StringImpl; + +impl string_service::Service for StringImpl { + fn reverse(&self, _ctx: &RpcContext, s: String) -> String { + s.chars().rev().collect() + } + + fn word_count(&self, _ctx: &RpcContext, s: String) -> u32 { + s.split_whitespace().count() as u32 + } + + fn contains(&self, _ctx: &RpcContext, haystack: String, needle: String) -> bool { + haystack.contains(&needle) + } +} + +fn main() { + env_logger::init(); + let event_loop = Arc::new(EventLoop::new(4, 1024, 100).unwrap()); + + let addr = "127.0.0.1:9004".parse().unwrap(); + let _server = RpcServer::builder() + .bind(addr) + 
.service(math_service::server(MathImpl)) // service_id = 0 + .service(string_service::server(StringImpl)) // service_id = 1 + .build(&event_loop) + .expect("Failed to start multi-service server"); + + println!("Multi-service server listening on {}", addr); + println!(" Service 0: MathService"); + println!(" Service 1: StringService"); + event_loop.run().unwrap(); +} diff --git a/mill-rpc/mill-rpc-core/Cargo.toml b/mill-rpc/mill-rpc-core/Cargo.toml new file mode 100644 index 0000000..53ff741 --- /dev/null +++ b/mill-rpc/mill-rpc-core/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "mill-rpc-core" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true +description = "Core types, wire protocol, and codec traits for Mill-RPC" + +[dependencies] +serde = { version = "1", features = ["derive"] } +bincode = "1" diff --git a/mill-rpc/mill-rpc-core/src/codec.rs b/mill-rpc/mill-rpc-core/src/codec.rs new file mode 100644 index 0000000..54d3994 --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/codec.rs @@ -0,0 +1,48 @@ +use crate::error::RpcError; +use serde::{de::DeserializeOwned, Serialize}; + +/// Supported codec types. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CodecType { + Bincode, +} + +/// Codec for serializing/deserializing RPC payloads. +#[derive(Debug, Clone)] +pub struct Codec { + codec_type: CodecType, +} + +impl Codec { + pub fn bincode() -> Self { + Self { + codec_type: CodecType::Bincode, + } + } + + pub fn codec_type(&self) -> CodecType { + self.codec_type + } + + /// Serialize a value to bytes. + pub fn serialize(&self, value: &T) -> Result, RpcError> { + match self.codec_type { + CodecType::Bincode => bincode::serialize(value) + .map_err(|e| RpcError::codec_error(format!("serialize: {}", e))), + } + } + + /// Deserialize bytes to a value. 
+ pub fn deserialize(&self, data: &[u8]) -> Result { + match self.codec_type { + CodecType::Bincode => bincode::deserialize(data) + .map_err(|e| RpcError::codec_error(format!("deserialize: {}", e))), + } + } +} + +impl Default for Codec { + fn default() -> Self { + Self::bincode() + } +} diff --git a/mill-rpc/mill-rpc-core/src/context.rs b/mill-rpc/mill-rpc-core/src/context.rs new file mode 100644 index 0000000..4aca239 --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/context.rs @@ -0,0 +1,32 @@ +use std::net::SocketAddr; + +/// Context available to RPC handlers during request processing. +/// +/// Provides metadata about the current request and the connection it arrived on. +#[derive(Debug, Clone)] +pub struct RpcContext { + /// Unique ID for this request. + pub request_id: u64, + /// Peer address of the client. + pub peer_addr: Option, + /// Service ID being called. + pub service_id: u16, + /// Method ID being called. + pub method_id: u16, +} + +impl RpcContext { + pub fn new(request_id: u64, service_id: u16, method_id: u16) -> Self { + Self { + request_id, + peer_addr: None, + service_id, + method_id, + } + } + + pub fn with_peer_addr(mut self, addr: SocketAddr) -> Self { + self.peer_addr = Some(addr); + self + } +} diff --git a/mill-rpc/mill-rpc-core/src/error.rs b/mill-rpc/mill-rpc-core/src/error.rs new file mode 100644 index 0000000..fb8815a --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/error.rs @@ -0,0 +1,111 @@ +use serde::{Deserialize, Serialize}; +use std::fmt; + +/// RPC status codes (inspired by gRPC). 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[repr(u16)] +pub enum RpcStatus { + Ok = 0, + Cancelled = 1, + InvalidArgument = 2, + NotFound = 3, + AlreadyExists = 4, + PermissionDenied = 5, + Unauthenticated = 6, + ResourceExhausted = 7, + Internal = 8, + Unavailable = 9, + DeadlineExceeded = 10, +} + +impl fmt::Display for RpcStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RpcStatus::Ok => write!(f, "OK"), + RpcStatus::Cancelled => write!(f, "CANCELLED"), + RpcStatus::InvalidArgument => write!(f, "INVALID_ARGUMENT"), + RpcStatus::NotFound => write!(f, "NOT_FOUND"), + RpcStatus::AlreadyExists => write!(f, "ALREADY_EXISTS"), + RpcStatus::PermissionDenied => write!(f, "PERMISSION_DENIED"), + RpcStatus::Unauthenticated => write!(f, "UNAUTHENTICATED"), + RpcStatus::ResourceExhausted => write!(f, "RESOURCE_EXHAUSTED"), + RpcStatus::Internal => write!(f, "INTERNAL"), + RpcStatus::Unavailable => write!(f, "UNAVAILABLE"), + RpcStatus::DeadlineExceeded => write!(f, "DEADLINE_EXCEEDED"), + } + } +} + +/// Structured RPC error with status code and message. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RpcError { + pub status: RpcStatus, + pub message: String, +} + +impl RpcError { + pub fn new(status: RpcStatus, message: impl Into) -> Self { + Self { + status, + message: message.into(), + } + } + + pub fn internal(message: impl Into) -> Self { + Self::new(RpcStatus::Internal, message) + } + + pub fn invalid_argument(message: impl Into) -> Self { + Self::new(RpcStatus::InvalidArgument, message) + } + + pub fn not_found(message: impl Into) -> Self { + Self::new(RpcStatus::NotFound, message) + } + + pub fn method_not_found(method_id: u16) -> Self { + Self::new( + RpcStatus::NotFound, + format!("Method not found: {}", method_id), + ) + } + + pub fn service_not_found(service_id: u16) -> Self { + Self::new( + RpcStatus::NotFound, + format!("Service not found: {}", service_id), + ) + } + + pub fn codec_error(message: impl Into) -> Self { + Self::new(RpcStatus::Internal, message) + } + + pub fn unavailable(message: impl Into) -> Self { + Self::new(RpcStatus::Unavailable, message) + } + + pub fn deadline_exceeded(message: impl Into) -> Self { + Self::new(RpcStatus::DeadlineExceeded, message) + } +} + +impl fmt::Display for RpcError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[{}] {}", self.status, self.message) + } +} + +impl std::error::Error for RpcError {} + +impl From for RpcError { + fn from(err: std::io::Error) -> Self { + RpcError::internal(err.to_string()) + } +} + +impl From for RpcError { + fn from(err: bincode::Error) -> Self { + RpcError::codec_error(format!("bincode: {}", err)) + } +} diff --git a/mill-rpc/mill-rpc-core/src/lib.rs b/mill-rpc/mill-rpc-core/src/lib.rs new file mode 100644 index 0000000..060abc1 --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/lib.rs @@ -0,0 +1,35 @@ +//! Core types for Mill-RPC: wire protocol, codecs, errors, and dispatcher traits. 
+ +pub mod codec; +pub mod context; +pub mod error; +pub mod protocol; + +pub use codec::{Codec, CodecType}; +pub use context::RpcContext; +pub use error::{RpcError, RpcStatus}; +pub use protocol::{Flags, Frame, FrameHeader, MessageType}; + +/// Trait for dispatching RPC calls to handler methods. +/// +/// This is auto-implemented by the `#[mill_rpc::service]` macro for any type +/// that implements the generated `{Service}Server` trait. +pub trait ServiceDispatch: Send + Sync + 'static { + fn dispatch( + &self, + ctx: &RpcContext, + method_id: u16, + args: &[u8], + codec: &Codec, + ) -> Result, RpcError>; +} + +/// Trait for client-side RPC transport. +/// +/// Abstracts the mechanism of sending a request and receiving a response. +/// The main `mill-rpc` crate provides an implementation built on `mill-net`. +pub trait RpcTransport: Send + Sync + 'static { + /// Send a request and wait for a response. + /// Returns the raw response payload bytes. + fn call(&self, service_id: u16, method_id: u16, payload: Vec) -> Result, RpcError>; +} diff --git a/mill-rpc/mill-rpc-core/src/protocol.rs b/mill-rpc/mill-rpc-core/src/protocol.rs new file mode 100644 index 0000000..d40ce1d --- /dev/null +++ b/mill-rpc/mill-rpc-core/src/protocol.rs @@ -0,0 +1,391 @@ +//! Wire protocol: frame format for Mill-RPC. +//! +//! ```text +//! +--------+--------+-------+--------+-----------+---------+ +//! | Magic | Version| Flags | MsgType| PayloadLen| Payload | +//! | 2B | 1B | 1B | 1B | 4B (LE) | N bytes | +//! +--------+--------+-------+--------+-----------+---------+ +//! ``` +//! +//! Request payload: +//! ```text +//! +------------+-----------+-----------+---------+ +//! | RequestID | ServiceID | MethodID | Args | +//! | 8B (LE) | 2B (LE) | 2B (LE) | N bytes | +//! +------------+-----------+-----------+---------+ +//! ``` + +use crate::error::RpcError; +use serde::{Deserialize, Serialize}; + +/// Magic bytes identifying Mill-RPC frames. 
+pub const MAGIC: [u8; 2] = [0x4D, 0x52]; // "MR" + +/// Current protocol version. +pub const VERSION: u8 = 1; + +/// Header size in bytes (magic:2 + version:1 + flags:1 + msg_type:1 + payload_len:4 = 9). +pub const HEADER_SIZE: usize = 9; + +/// Maximum payload size (16 MB). +pub const MAX_PAYLOAD_SIZE: u32 = 16 * 1024 * 1024; + +/// Message types in the wire protocol. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[repr(u8)] +pub enum MessageType { + Request = 0x01, + Response = 0x02, + Error = 0x03, + Ping = 0x04, + Pong = 0x05, + Cancel = 0x06, +} + +impl MessageType { + pub fn from_u8(v: u8) -> Result { + match v { + 0x01 => Ok(MessageType::Request), + 0x02 => Ok(MessageType::Response), + 0x03 => Ok(MessageType::Error), + 0x04 => Ok(MessageType::Ping), + 0x05 => Ok(MessageType::Pong), + 0x06 => Ok(MessageType::Cancel), + _ => Err(RpcError::invalid_argument(format!( + "Unknown message type: 0x{:02X}", + v + ))), + } + } +} + +/// Bit flags for frame options. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Flags(pub u8); + +impl Flags { + pub const NONE: Flags = Flags(0); + pub const COMPRESSED: Flags = Flags(1 << 0); + pub const ONE_WAY: Flags = Flags(1 << 1); + + pub fn is_one_way(self) -> bool { + self.0 & Self::ONE_WAY.0 != 0 + } + + pub fn is_compressed(self) -> bool { + self.0 & Self::COMPRESSED.0 != 0 + } +} + +/// Parsed frame header. +#[derive(Debug, Clone)] +pub struct FrameHeader { + pub version: u8, + pub flags: Flags, + pub message_type: MessageType, + pub payload_len: u32, +} + +impl FrameHeader { + /// Encode the header into a 9-byte array. + pub fn encode(&self) -> [u8; HEADER_SIZE] { + let mut buf = [0u8; HEADER_SIZE]; + buf[0] = MAGIC[0]; + buf[1] = MAGIC[1]; + buf[2] = self.version; + buf[3] = self.flags.0; + buf[4] = self.message_type as u8; + buf[5..9].copy_from_slice(&self.payload_len.to_le_bytes()); + buf + } + + /// Decode a header from a 9-byte slice. 
+ pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result { + if buf[0] != MAGIC[0] || buf[1] != MAGIC[1] { + return Err(RpcError::invalid_argument(format!( + "Invalid magic: [{:#04X}, {:#04X}]", + buf[0], buf[1] + ))); + } + + let version = buf[2]; + if version != VERSION { + return Err(RpcError::invalid_argument(format!( + "Unsupported version: {}", + version + ))); + } + + let flags = Flags(buf[3]); + let message_type = MessageType::from_u8(buf[4])?; + let payload_len = u32::from_le_bytes([buf[5], buf[6], buf[7], buf[8]]); + + if payload_len > MAX_PAYLOAD_SIZE { + return Err(RpcError::invalid_argument(format!( + "Payload too large: {} > {}", + payload_len, MAX_PAYLOAD_SIZE + ))); + } + + Ok(Self { + version, + flags, + message_type, + payload_len, + }) + } +} + +/// A complete frame (header + payload). +#[derive(Debug, Clone)] +pub struct Frame { + pub header: FrameHeader, + pub payload: Vec, +} + +impl Frame { + /// Create a new request frame. + pub fn request( + request_id: u64, + service_id: u16, + method_id: u16, + args: Vec, + one_way: bool, + ) -> Self { + let mut payload = Vec::with_capacity(12 + args.len()); + payload.extend_from_slice(&request_id.to_le_bytes()); + payload.extend_from_slice(&service_id.to_le_bytes()); + payload.extend_from_slice(&method_id.to_le_bytes()); + payload.extend_from_slice(&args); + + let flags = if one_way { Flags::ONE_WAY } else { Flags::NONE }; + + Frame { + header: FrameHeader { + version: VERSION, + flags, + message_type: MessageType::Request, + payload_len: payload.len() as u32, + }, + payload, + } + } + + /// Create a response frame. 
+ pub fn response(request_id: u64, data: Vec) -> Self { + let mut payload = Vec::with_capacity(8 + data.len()); + payload.extend_from_slice(&request_id.to_le_bytes()); + payload.extend_from_slice(&data); + + Frame { + header: FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Response, + payload_len: payload.len() as u32, + }, + payload, + } + } + + /// Create an error frame. + pub fn error(request_id: u64, error_data: Vec) -> Self { + let mut payload = Vec::with_capacity(8 + error_data.len()); + payload.extend_from_slice(&request_id.to_le_bytes()); + payload.extend_from_slice(&error_data); + + Frame { + header: FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Error, + payload_len: payload.len() as u32, + }, + payload, + } + } + + /// Create a ping frame. + pub fn ping() -> Self { + Frame { + header: FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Ping, + payload_len: 0, + }, + payload: Vec::new(), + } + } + + /// Create a pong frame. + pub fn pong() -> Self { + Frame { + header: FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Pong, + payload_len: 0, + }, + payload: Vec::new(), + } + } + + /// Encode the full frame to bytes. + pub fn encode(&self) -> Vec { + let header_bytes = self.header.encode(); + let mut buf = Vec::with_capacity(HEADER_SIZE + self.payload.len()); + buf.extend_from_slice(&header_bytes); + buf.extend_from_slice(&self.payload); + buf + } + + /// Parse request payload fields (request_id, service_id, method_id, args). 
+ pub fn parse_request_payload(&self) -> Result<(u64, u16, u16, &[u8]), RpcError> { + if self.payload.len() < 12 { + return Err(RpcError::invalid_argument("Request payload too short")); + } + let request_id = u64::from_le_bytes(self.payload[0..8].try_into().unwrap()); + let service_id = u16::from_le_bytes(self.payload[8..10].try_into().unwrap()); + let method_id = u16::from_le_bytes(self.payload[10..12].try_into().unwrap()); + let args = &self.payload[12..]; + Ok((request_id, service_id, method_id, args)) + } + + /// Parse response payload fields (request_id, data). + pub fn parse_response_payload(&self) -> Result<(u64, &[u8]), RpcError> { + if self.payload.len() < 8 { + return Err(RpcError::invalid_argument("Response payload too short")); + } + let request_id = u64::from_le_bytes(self.payload[0..8].try_into().unwrap()); + let data = &self.payload[8..]; + Ok((request_id, data)) + } +} + +/// Reads frames from a byte buffer. Returns parsed frames and the number of bytes consumed. +/// +/// This is a streaming parser: it handles partial frames by returning only +/// complete frames and leaving remaining bytes unconsumed. +pub fn parse_frames(buf: &[u8]) -> Result<(Vec, usize), RpcError> { + let mut frames = Vec::new(); + let mut offset = 0; + + while offset + HEADER_SIZE <= buf.len() { + let header_bytes: &[u8; HEADER_SIZE] = buf[offset..offset + HEADER_SIZE] + .try_into() + .map_err(|_| RpcError::internal("Header slice conversion failed"))?; + + let header = FrameHeader::decode(header_bytes)?; + let total_frame_size = HEADER_SIZE + header.payload_len as usize; + + if offset + total_frame_size > buf.len() { + // Incomplete frame, wait for more data. 
+ break; + } + + let payload = buf[offset + HEADER_SIZE..offset + total_frame_size].to_vec(); + frames.push(Frame { header, payload }); + offset += total_frame_size; + } + + Ok((frames, offset)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_header_roundtrip() { + let header = FrameHeader { + version: VERSION, + flags: Flags::NONE, + message_type: MessageType::Request, + payload_len: 42, + }; + let encoded = header.encode(); + let decoded = FrameHeader::decode(&encoded).unwrap(); + assert_eq!(decoded.version, VERSION); + assert_eq!(decoded.flags, Flags::NONE); + assert_eq!(decoded.message_type, MessageType::Request); + assert_eq!(decoded.payload_len, 42); + } + + #[test] + fn test_request_frame_roundtrip() { + let frame = Frame::request(123, 1, 2, vec![10, 20, 30], false); + let bytes = frame.encode(); + let (frames, consumed) = parse_frames(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_eq!(frames.len(), 1); + + let (req_id, svc_id, method_id, args) = frames[0].parse_request_payload().unwrap(); + assert_eq!(req_id, 123); + assert_eq!(svc_id, 1); + assert_eq!(method_id, 2); + assert_eq!(args, &[10, 20, 30]); + } + + #[test] + fn test_response_frame_roundtrip() { + let frame = Frame::response(456, vec![1, 2, 3]); + let bytes = frame.encode(); + let (frames, _) = parse_frames(&bytes).unwrap(); + let (req_id, data) = frames[0].parse_response_payload().unwrap(); + assert_eq!(req_id, 456); + assert_eq!(data, &[1, 2, 3]); + } + + #[test] + fn test_multiple_frames() { + let f1 = Frame::request(1, 0, 0, vec![0xAA], false); + let f2 = Frame::response(1, vec![0xBB]); + let mut bytes = f1.encode(); + bytes.extend_from_slice(&f2.encode()); + + let (frames, consumed) = parse_frames(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_eq!(frames.len(), 2); + assert_eq!(frames[0].header.message_type, MessageType::Request); + assert_eq!(frames[1].header.message_type, MessageType::Response); + } + + #[test] + fn 
test_partial_frame() { + let frame = Frame::request(1, 0, 0, vec![0xAA; 100], false); + let bytes = frame.encode(); + // Only give half the bytes + let partial = &bytes[..bytes.len() / 2]; + let (frames, consumed) = parse_frames(partial).unwrap(); + assert_eq!(frames.len(), 0); + assert_eq!(consumed, 0); + } + + #[test] + fn test_ping_pong() { + let ping = Frame::ping(); + let pong = Frame::pong(); + assert_eq!(ping.header.message_type, MessageType::Ping); + assert_eq!(pong.header.message_type, MessageType::Pong); + assert_eq!(ping.payload.len(), 0); + assert_eq!(pong.payload.len(), 0); + } + + #[test] + fn test_invalid_magic() { + let mut buf = [0u8; HEADER_SIZE]; + buf[0] = 0xFF; + buf[1] = 0xFF; + assert!(FrameHeader::decode(&buf).is_err()); + } + + #[test] + fn test_one_way_flag() { + let frame = Frame::request(1, 0, 0, vec![], true); + assert!(frame.header.flags.is_one_way()); + + let frame = Frame::request(1, 0, 0, vec![], false); + assert!(!frame.header.flags.is_one_way()); + } +} diff --git a/mill-rpc/mill-rpc-macros/Cargo.toml b/mill-rpc/mill-rpc-macros/Cargo.toml new file mode 100644 index 0000000..8cfe928 --- /dev/null +++ b/mill-rpc/mill-rpc-macros/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "mill-rpc-macros" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true +description = "Proc macros for Mill-RPC service definitions" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2", features = ["full"] } +quote = "1" +proc-macro2 = "1" diff --git a/mill-rpc/mill-rpc-macros/src/lib.rs b/mill-rpc/mill-rpc-macros/src/lib.rs new file mode 100644 index 0000000..9007182 --- /dev/null +++ b/mill-rpc/mill-rpc-macros/src/lib.rs @@ -0,0 +1,420 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{ + braced, + parse::{Parse, ParseStream}, + parse_macro_input, FnArg, Ident, Pat, ReturnType, Token, TraitItemFn, Type, +}; + +/// Module-level macro for 
defining an RPC service. +/// +/// Generates a module containing `Service` trait, `Client` struct, +/// `server()` wrapper function, and all request/response types. +/// +/// By default, both server and client code are generated. +/// Use `#[server]` or `#[client]` to generate only one side. +/// +/// # Examples +/// +/// ```ignore +/// // Generate both server and client (default) +/// mill_rpc::service! { +/// service Calculator { +/// fn add(a: i32, b: i32) -> i32; +/// fn divide(a: f64, b: f64) -> f64; +/// } +/// } +/// +/// // Server only +/// mill_rpc::service! { +/// #[server] +/// service Calculator { +/// fn add(a: i32, b: i32) -> i32; +/// } +/// } +/// +/// // Client only (e.g. in a separate client crate) +/// mill_rpc::service! { +/// #[client] +/// service Calculator { +/// fn add(a: i32, b: i32) -> i32; +/// } +/// } +/// ``` +#[proc_macro] +pub fn service(input: TokenStream) -> TokenStream { + let def = parse_macro_input!(input as ServiceDef); + match generate_service_module(def) { + Ok(tokens) => tokens.into(), + Err(err) => err.to_compile_error().into(), + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum GenerateMode { + Both, + ServerOnly, + ClientOnly, +} + +struct ServiceDef { + mode: GenerateMode, + name: Ident, + methods: Vec, +} + +struct MethodDef { + name: Ident, + args: Vec<(Ident, Type)>, + return_type: Type, +} + +impl Parse for ServiceDef { + fn parse(input: ParseStream) -> syn::Result { + // Parse optional #[server] or #[client] + let mode = if input.peek(Token![#]) { + input.parse::()?; + let content; + syn::bracketed!(content in input); + let attr_name: Ident = content.parse()?; + match attr_name.to_string().as_str() { + "server" => GenerateMode::ServerOnly, + "client" => GenerateMode::ClientOnly, + other => { + return Err(syn::Error::new_spanned( + attr_name, + format!( + "Unknown attribute `{}`, expected `server` or `client`", + other + ), + )) + } + } + } else { + GenerateMode::Both + }; + + // Parse `service Name { ... 
}` + let service_kw: Ident = input.parse()?; + if service_kw != "service" { + return Err(syn::Error::new_spanned(service_kw, "Expected `service`")); + } + + let name: Ident = input.parse()?; + + let content; + braced!(content in input); + + let mut methods = Vec::new(); + while !content.is_empty() { + let method: TraitItemFn = content.parse()?; + + let method_name = method.sig.ident.clone(); + + let mut args = Vec::new(); + for arg in &method.sig.inputs { + match arg { + FnArg::Typed(pat_type) => { + let ident = match &*pat_type.pat { + Pat::Ident(pi) => pi.ident.clone(), + other => { + return Err(syn::Error::new_spanned( + other, + "Expected a simple identifier for argument name", + )) + } + }; + args.push((ident, (*pat_type.ty).clone())); + } + FnArg::Receiver(_) => { + return Err(syn::Error::new_spanned( + arg, + "Service methods should not have `self` parameter", + )) + } + } + } + + let return_type = match &method.sig.output { + ReturnType::Default => syn::parse_quote!(()), + ReturnType::Type(_, ty) => (**ty).clone(), + }; + + methods.push(MethodDef { + name: method_name, + args, + return_type, + }); + } + + Ok(ServiceDef { + mode, + name, + methods, + }) + } +} + +fn generate_service_module(def: ServiceDef) -> syn::Result { + let mod_name = format_ident!("{}", to_snake_case(&def.name.to_string())); + let service_name_str = def.name.to_string(); + let method_count = def.methods.len() as u16; + + let gen_server = def.mode != GenerateMode::ClientOnly; + let gen_client = def.mode != GenerateMode::ServerOnly; + + let method_consts: Vec<_> = def + .methods + .iter() + .enumerate() + .map(|(idx, m)| { + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let id = idx as u16; + quote! 
{ pub const #const_name: u16 = #id; } + }) + .collect(); + + // Request / Response types + let type_defs: Vec<_> = def + .methods + .iter() + .map(|m| { + let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); + let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + let ret_ty = &m.return_type; + + let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let field_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); + + let req_struct = if m.args.is_empty() { + quote! { + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub(super) struct #req_name; + } + } else { + quote! { + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub(super) struct #req_name { + #( pub #field_names: #field_types, )* + } + } + }; + + quote! { + #req_struct + + #[derive(::serde::Serialize, ::serde::Deserialize, Debug)] + pub(super) struct #resp_name(pub #ret_ty); + } + }) + .collect(); + + let server_trait = if gen_server { + let trait_methods: Vec<_> = def + .methods + .iter() + .map(|m| { + let name = &m.name; + let ret_ty = &m.return_type; + let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); + quote! { + fn #name(&self, ctx: &::mill_rpc_core::RpcContext, #( #arg_names: #arg_types ),*) -> #ret_ty; + } + }) + .collect(); + + let dispatch_arms: Vec<_> = def + .methods + .iter() + .map(|m| { + let name = &m.name; + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); + let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + + let call_args = if m.args.is_empty() { + quote! {} + } else { + let field_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let args: Vec<_> = field_names.iter().map(|n| quote! { req.#n }).collect(); + quote! { , #( #args ),* } + }; + + quote! 
{ + methods::#const_name => { + let req: types::#req_name = codec.deserialize(args)?; + let result = svc.#name(ctx #call_args); + codec.serialize(&types::#resp_name(result)) + } + } + }) + .collect(); + + quote! { + /// Server trait — implement this to handle RPC calls for this service. + pub trait Service: Send + Sync + 'static { + #( #trait_methods )* + } + + /// Internal dispatcher that bridges `Service` impl to `ServiceDispatch`. + struct Dispatcher(T); + + impl ::mill_rpc_core::ServiceDispatch for Dispatcher { + fn dispatch( + &self, + ctx: &::mill_rpc_core::RpcContext, + method_id: u16, + args: &[u8], + codec: &::mill_rpc_core::Codec, + ) -> Result, ::mill_rpc_core::RpcError> { + let svc = &self.0; + match method_id { + #( #dispatch_arms, )* + _ => Err(::mill_rpc_core::RpcError::method_not_found(method_id)), + } + } + } + + /// Wrap a `Service` implementation for server registration. + /// + /// # Example + /// ```ignore + /// RpcServer::builder() + /// .service(calculator::server(MyCalc)) + /// .build(&event_loop)?; + /// ``` + pub fn server(implementation: T) -> impl ::mill_rpc_core::ServiceDispatch { + Dispatcher(implementation) + } + } + } else { + quote! {} + }; + + let client_code = if gen_client { + let client_methods: Vec<_> = def + .methods + .iter() + .map(|m| { + let name = &m.name; + let ret_ty = &m.return_type; + let const_name = format_ident!("{}", m.name.to_string().to_uppercase()); + let req_name = format_ident!("{}Request", to_pascal_case(&m.name.to_string())); + let resp_name = format_ident!("{}Response", to_pascal_case(&m.name.to_string())); + + let arg_names: Vec<_> = m.args.iter().map(|(n, _)| n).collect(); + let arg_types: Vec<_> = m.args.iter().map(|(_, t)| t).collect(); + + let req_construct = if m.args.is_empty() { + quote! { types::#req_name } + } else { + let fields: Vec<_> = arg_names.iter().map(|n| quote! { #n: #n }).collect(); + quote! { types::#req_name { #( #fields, )* } } + }; + + quote! 
{ + pub fn #name(&self, #( #arg_names: #arg_types ),*) -> Result<#ret_ty, ::mill_rpc_core::RpcError> { + let req = #req_construct; + let payload = self.codec.serialize(&req)?; + let resp_bytes = self.transport.call( + self.service_id, + methods::#const_name, + payload, + )?; + let resp: types::#resp_name = self.codec.deserialize(&resp_bytes)?; + Ok(resp.0) + } + } + }) + .collect(); + + quote! { + /// Generated RPC client for this service. + pub struct Client { + transport: ::std::sync::Arc, + codec: ::mill_rpc_core::Codec, + service_id: u16, + } + + impl Client { + /// Create a new client. + /// + /// - `transport`: the RPC transport (typically an `RpcClient`) + /// - `codec`: serialization codec (must match the server) + /// - `service_id`: the ID assigned to this service on the server + /// (matches registration order, starting from 0) + pub fn new( + transport: ::std::sync::Arc, + codec: ::mill_rpc_core::Codec, + service_id: u16, + ) -> Self { + Self { transport, codec, service_id } + } + + #( #client_methods )* + } + } + } else { + quote! {} + }; + + let output = quote! { + pub mod #mod_name { + #![allow(unused_imports)] + use super::*; + + /// Method ID constants. + pub mod methods { + #( #method_consts )* + } + + /// Service metadata. + pub const SERVICE_NAME: &str = #service_name_str; + pub const METHOD_COUNT: u16 = #method_count; + + /// Internal request/response types (not part of the public API). 
+ mod types { + use super::super::*; + #( #type_defs )* + } + + #server_trait + + #client_code + } + }; + + Ok(output) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn to_snake_case(s: &str) -> String { + let mut result = String::new(); + for (i, ch) in s.chars().enumerate() { + if ch.is_uppercase() { + if i > 0 { + result.push('_'); + } + result.push(ch.to_lowercase().next().unwrap()); + } else { + result.push(ch); + } + } + result +} + +fn to_pascal_case(s: &str) -> String { + s.split('_') + .map(|part| { + let mut chars = part.chars(); + match chars.next() { + None => String::new(), + Some(c) => c.to_uppercase().to_string() + chars.as_str(), + } + }) + .collect() +} diff --git a/mill-rpc/src/client.rs b/mill-rpc/src/client.rs new file mode 100644 index 0000000..6e8fe0e --- /dev/null +++ b/mill-rpc/src/client.rs @@ -0,0 +1,257 @@ +//! RPC Client built on mill-net's TcpClient. +//! +//! Connects to an RPC server, sends request frames, and waits for responses. + +use crate::{RpcError, RpcTransport}; +use mill_io::EventLoop; +use mill_net::tcp::traits::{ConnectionId, NetworkHandler}; +use mill_net::tcp::{ServerContext, TcpClient}; +use mill_rpc_core::protocol::{self, Frame, MessageType}; +use mio::Token; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Condvar, Mutex}; +use std::time::Duration; + +/// Pending request waiting for a response. +struct PendingRequest { + completed: bool, + result: Option, RpcError>>, +} + +/// Shared state between the client and its network handler. +struct ClientShared { + /// Map of request_id -> pending request. + pending: Mutex>, + /// Condvar to notify waiting callers when a response arrives. + notify: Condvar, + /// Receive buffer for partial frame parsing. 
+ recv_buf: Mutex>, +} + +/// RPC client that connects to a Mill-RPC server. +pub struct RpcClient { + tcp_client: Mutex>, + shared: Arc, + next_request_id: AtomicU64, + timeout: AtomicU64, +} + +impl RpcClient { + /// Connect to an RPC server. + pub fn connect(addr: SocketAddr, event_loop: &Arc) -> Result, RpcError> { + let shared = Arc::new(ClientShared { + pending: Mutex::new(HashMap::new()), + notify: Condvar::new(), + recv_buf: Mutex::new(Vec::new()), + }); + + let handler = RpcClientHandler { + shared: shared.clone(), + }; + + let mut tcp_client = TcpClient::connect(addr, handler) + .map_err(|e| RpcError::unavailable(format!("Connect failed: {}", e)))?; + + tcp_client + .start(event_loop, Token(usize::MAX - 1)) + .map_err(|e| RpcError::unavailable(format!("Client start failed: {}", e)))?; + + Ok(Arc::new(Self { + tcp_client: Mutex::new(tcp_client), + shared, + next_request_id: AtomicU64::new(1), + timeout: AtomicU64::new(30 * 1000), + })) + } + + /// Set the default timeout for RPC calls. + pub fn set_timeout(&self, timeout: Duration) { + self.timeout + .store(timeout.as_millis() as u64, Ordering::SeqCst); + } + + fn timeout(&self) -> Duration { + Duration::from_millis(self.timeout.load(Ordering::SeqCst)) + } + + /// Send a request and wait for a response (blocking). + fn call_raw( + &self, + service_id: u16, + method_id: u16, + payload: Vec, + ) -> Result, RpcError> { + let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); + + // Register the pending request before sending. 
+ { + let mut pending = self.shared.pending.lock().unwrap(); + pending.insert( + request_id, + PendingRequest { + completed: false, + result: None, + }, + ); + } + + let frame = Frame::request(request_id, service_id, method_id, payload, false); + let send_result = { + let client = self.tcp_client.lock().unwrap(); + client.send(&frame.encode()) + }; + if let Err(e) = send_result { + let mut pending = self.shared.pending.lock().unwrap(); + pending.remove(&request_id); + return Err(RpcError::unavailable(format!("Send failed: {}", e))); + } + + let mut pending = self.shared.pending.lock().unwrap(); + let deadline = std::time::Instant::now() + self.timeout(); + + loop { + if let Some(req) = pending.get(&request_id) { + if req.completed { + let req = pending.remove(&request_id).unwrap(); + return req.result.unwrap(); + } + } else { + return Err(RpcError::internal("Pending request disappeared")); + } + + let remaining = deadline.saturating_duration_since(std::time::Instant::now()); + if remaining.is_zero() { + pending.remove(&request_id); + return Err(RpcError::deadline_exceeded(format!( + "Request {} timed out after {:?}", + request_id, self.timeout + ))); + } + + let (guard, timeout_result) = + self.shared.notify.wait_timeout(pending, remaining).unwrap(); + pending = guard; + + if timeout_result.timed_out() { + pending.remove(&request_id); + return Err(RpcError::deadline_exceeded(format!( + "Request {} timed out after {:?}", + request_id, self.timeout + ))); + } + } + } + + /// Send a one-way request (fire-and-forget). 
+ pub fn call_oneway( + &self, + service_id: u16, + method_id: u16, + payload: Vec, + ) -> Result<(), RpcError> { + let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); + let frame = Frame::request(request_id, service_id, method_id, payload, true); + + let client = self.tcp_client.lock().unwrap(); + client + .send(&frame.encode()) + .map_err(|e| RpcError::unavailable(format!("Send failed: {}", e)))?; + Ok(()) + } +} + +impl RpcTransport for RpcClient { + fn call(&self, service_id: u16, method_id: u16, payload: Vec) -> Result, RpcError> { + self.call_raw(service_id, method_id, payload) + } +} + +/// Network handler for the RPC client - receives response frames. +struct RpcClientHandler { + shared: Arc, +} + +impl NetworkHandler for RpcClientHandler { + fn on_data( + &self, + _ctx: &ServerContext, + _conn_id: ConnectionId, + data: &[u8], + ) -> mill_net::errors::Result<()> { + let mut recv_buf = self.shared.recv_buf.lock().unwrap(); + recv_buf.extend_from_slice(data); + + let (frames, consumed) = match protocol::parse_frames(&recv_buf) { + Ok(r) => r, + Err(e) => { + log::error!("Client frame parse error: {}", e); + recv_buf.clear(); + return Ok(()); + } + }; + + if consumed > 0 { + recv_buf.drain(..consumed); + } + + drop(recv_buf); + + for frame in frames { + self.handle_frame(frame); + } + + Ok(()) + } +} + +impl RpcClientHandler { + fn handle_frame(&self, frame: Frame) { + match frame.header.message_type { + MessageType::Response => { + let (request_id, data) = match frame.parse_response_payload() { + Ok(r) => r, + Err(e) => { + log::error!("Invalid response payload: {}", e); + return; + } + }; + + let mut pending = self.shared.pending.lock().unwrap(); + if let Some(req) = pending.get_mut(&request_id) { + req.completed = true; + req.result = Some(Ok(data.to_vec())); + } + self.shared.notify.notify_all(); + } + MessageType::Error => { + let (request_id, err_data) = match frame.parse_response_payload() { + Ok(r) => r, + Err(e) => { + 
log::error!("Invalid error payload: {}", e); + return; + } + }; + + let rpc_err: RpcError = match bincode::deserialize(err_data) { + Ok(e) => e, + Err(_) => RpcError::internal(String::from_utf8_lossy(err_data).to_string()), + }; + + let mut pending = self.shared.pending.lock().unwrap(); + if let Some(req) = pending.get_mut(&request_id) { + req.completed = true; + req.result = Some(Err(rpc_err)); + } + self.shared.notify.notify_all(); + } + MessageType::Pong => { + log::debug!("Received pong from server"); + } + other => { + log::warn!("Unexpected message type from server: {:?}", other); + } + } + } +} diff --git a/mill-rpc/src/lib.rs b/mill-rpc/src/lib.rs new file mode 100644 index 0000000..432d100 --- /dev/null +++ b/mill-rpc/src/lib.rs @@ -0,0 +1,44 @@ +//! Mill-RPC: An Axum-inspired RPC framework built on Mill-IO. +//! +//! # Quick Start +//! +//! ```ignore +//! // Define a service — generates a `calculator` module +//! mill_rpc::service! { +//! service Calculator { +//! fn add(a: i32, b: i32) -> i32; +//! } +//! } +//! +//! // Server side +//! struct MyCalc; +//! impl calculator::Service for MyCalc { +//! fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { a + b } +//! } +//! +//! // Register +//! RpcServer::builder() +//! .service(calculator::server(MyCalc)) +//! .build(&event_loop)?; +//! +//! // Client side +//! let client = calculator::Client::new(transport, codec, 0); +//! client.add(2, 3)?; +//! ``` + +pub mod client; +pub mod server; + +pub mod prelude; + +// Re-exports from core +pub use mill_rpc_core::{ + Codec, CodecType, Flags, Frame, FrameHeader, MessageType, RpcContext, RpcError, RpcStatus, + RpcTransport, ServiceDispatch, +}; + +// Re-export the service! macro +pub use mill_rpc_macros::service; + +pub use client::RpcClient; +pub use server::RpcServer; diff --git a/mill-rpc/src/prelude.rs b/mill-rpc/src/prelude.rs new file mode 100644 index 0000000..2bac0d0 --- /dev/null +++ b/mill-rpc/src/prelude.rs @@ -0,0 +1,7 @@ +//! 
Convenient re-exports for Mill-RPC users. + +pub use crate::client::RpcClient; +pub use crate::server::RpcServer; +pub use crate::{Codec, CodecType, RpcContext, RpcError, RpcStatus, RpcTransport, ServiceDispatch}; + +pub use serde::{Deserialize, Serialize}; diff --git a/mill-rpc/src/server.rs b/mill-rpc/src/server.rs new file mode 100644 index 0000000..b2bf5bb --- /dev/null +++ b/mill-rpc/src/server.rs @@ -0,0 +1,265 @@ +//! RPC Server built on mill-net's TcpServer. +//! +//! The server accepts TCP connections, parses Mill-RPC frames, dispatches +//! requests to registered services, and sends back responses. + +use crate::{Codec, RpcContext, RpcError, ServiceDispatch}; +use mill_io::EventLoop; +use mill_net::errors::{NetworkError, Result}; +use mill_net::tcp::config::TcpServerConfig; +use mill_net::tcp::traits::{ConnectionId, NetworkHandler}; +use mill_net::tcp::{ServerContext, TcpServer}; +use mill_rpc_core::protocol::{self, Frame, MessageType}; +use mio::Token; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex, RwLock}; + +/// A registered service with its dispatch implementation. +struct RegisteredService { + dispatcher: Box, +} + +/// RPC Server that hosts one or more services. +pub struct RpcServer { + _tcp_server: Arc>, +} + +impl RpcServer { + /// Create a builder for configuring the RPC server. + pub fn builder() -> RpcServerBuilder { + RpcServerBuilder::new() + } +} + +/// Builder for constructing an RpcServer. +pub struct RpcServerBuilder { + address: Option, + codec: Codec, + max_connections: Option, + services: Vec<(u16, Box)>, + next_service_id: u16, +} + +impl RpcServerBuilder { + fn new() -> Self { + Self { + address: None, + codec: Codec::bincode(), + max_connections: None, + services: Vec::new(), + next_service_id: 0, + } + } + + /// Set the address to bind to. + pub fn bind(mut self, addr: SocketAddr) -> Self { + self.address = Some(addr); + self + } + + /// Set the codec for serialization. 
+ pub fn codec(mut self, codec: Codec) -> Self { + self.codec = codec; + self + } + + /// Set the maximum number of connections. + pub fn max_connections(mut self, max: usize) -> Self { + self.max_connections = Some(max); + self + } + + /// Register a service implementation wrapped in its dispatcher. + /// + /// The service must implement `ServiceDispatch` - typically via the + /// generated `{Name}Dispatcher` wrapper: + /// + /// ```ignore + /// .service(CalculatorDispatcher(MyCalculator)) + /// ``` + pub fn service(mut self, svc: S) -> Self { + let id = self.next_service_id; + self.next_service_id += 1; + self.services.push((id, Box::new(svc))); + self + } + + /// Register a service with an explicit service ID. + pub fn service_with_id(mut self, id: u16, svc: S) -> Self { + self.services.push((id, Box::new(svc))); + self + } + + /// Build and start the RPC server on the given event loop. + pub fn build(self, event_loop: &Arc) -> Result { + let addr = self + .address + .unwrap_or_else(|| "127.0.0.1:9000".parse().unwrap()); + + let mut service_map = HashMap::new(); + for (id, dispatcher) in self.services { + service_map.insert(id, RegisteredService { dispatcher }); + } + + let handler = RpcServerHandler { + services: Arc::new(RwLock::new(service_map)), + codec: self.codec, + conn_buffers: Arc::new(Mutex::new(HashMap::new())), + }; + + let mut config_builder = TcpServerConfig::builder().address(addr); + if let Some(max) = self.max_connections { + config_builder = config_builder.max_connections(max); + } + let config = config_builder.build(); + + let tcp_server = Arc::new(TcpServer::new(config, handler)?); + tcp_server.clone().start(event_loop, Token(0))?; + + log::info!("Mill-RPC server listening on {}", addr); + + Ok(RpcServer { + _tcp_server: tcp_server, + }) + } +} + +/// Internal handler that bridges mill-net's NetworkHandler to RPC dispatch. 
+struct RpcServerHandler { + services: Arc>>, + codec: Codec, + /// Per-connection receive buffers for handling partial frames. + conn_buffers: Arc>>>, +} + +impl NetworkHandler for RpcServerHandler { + fn on_connect(&self, _ctx: &ServerContext, conn_id: ConnectionId) -> Result<()> { + log::debug!("RPC connection established: {:?}", conn_id); + self.conn_buffers + .lock() + .unwrap() + .insert(conn_id.as_u64(), Vec::new()); + Ok(()) + } + + fn on_data(&self, ctx: &ServerContext, conn_id: ConnectionId, data: &[u8]) -> Result<()> { + // Append incoming data to the connection's buffer. + let mut buffers = self.conn_buffers.lock().unwrap(); + let buf = buffers.entry(conn_id.as_u64()).or_default(); + buf.extend_from_slice(data); + + // Try to parse complete frames. + let (frames, consumed) = match protocol::parse_frames(buf) { + Ok(result) => result, + Err(e) => { + log::error!("Frame parse error from {:?}: {}", conn_id, e); + // Clear buffer on parse error - connection is likely corrupted. + buf.clear(); + return Ok(()); + } + }; + + // Remove consumed bytes from the buffer. + if consumed > 0 { + buf.drain(..consumed); + } + + // Drop the lock before processing frames (handlers may be slow). + drop(buffers); + + // Process each complete frame. 
+ for frame in frames { + self.handle_frame(ctx, conn_id, frame); + } + + Ok(()) + } + + fn on_disconnect(&self, _ctx: &ServerContext, conn_id: ConnectionId) -> Result<()> { + log::debug!("RPC connection closed: {:?}", conn_id); + self.conn_buffers.lock().unwrap().remove(&conn_id.as_u64()); + Ok(()) + } + + fn on_error(&self, _ctx: &ServerContext, conn_id: Option, error: NetworkError) { + log::error!("RPC network error (conn={:?}): {}", conn_id, error); + } +} + +impl RpcServerHandler { + fn handle_frame(&self, ctx: &ServerContext, conn_id: ConnectionId, frame: Frame) { + match frame.header.message_type { + MessageType::Request => self.handle_request(ctx, conn_id, frame), + MessageType::Ping => { + let pong = Frame::pong(); + if let Err(e) = ctx.send_to(conn_id, &pong.encode()) { + log::error!("Failed to send pong to {:?}: {}", conn_id, e); + } + } + MessageType::Cancel => { + log::debug!("Cancel received from {:?} (not yet supported)", conn_id); + } + other => { + log::warn!("Unexpected message type {:?} from {:?}", other, conn_id); + } + } + } + + fn handle_request(&self, ctx: &ServerContext, conn_id: ConnectionId, frame: Frame) { + let (request_id, service_id, method_id, args) = match frame.parse_request_payload() { + Ok(parsed) => parsed, + Err(e) => { + log::error!("Invalid request payload from {:?}: {}", conn_id, e); + return; + } + }; + + let is_one_way = frame.header.flags.is_one_way(); + + let rpc_ctx = RpcContext::new(request_id, service_id, method_id); + + // Look up the service and dispatch. + let result = { + let services = self.services.read().unwrap(); + match services.get(&service_id) { + Some(svc) => svc + .dispatcher + .dispatch(&rpc_ctx, method_id, args, &self.codec), + None => Err(RpcError::service_not_found(service_id)), + } + }; + + // Don't send a response for one-way calls. 
+ if is_one_way { + if let Err(e) = &result { + log::error!("One-way request {}.{} failed: {}", service_id, method_id, e); + } + return; + } + + // Send response or error frame. + let response_frame = match result { + Ok(resp_bytes) => Frame::response(request_id, resp_bytes), + Err(rpc_err) => { + let err_bytes = match self.codec.serialize(&rpc_err) { + Ok(b) => b, + Err(e) => { + log::error!("Failed to serialize error: {}", e); + return; + } + }; + Frame::error(request_id, err_bytes) + } + }; + + if let Err(e) = ctx.send_to(conn_id, &response_frame.encode()) { + log::error!( + "Failed to send response for request {} to {:?}: {}", + request_id, + conn_id, + e + ); + } + } +} diff --git a/mill-rpc/tests/integration_test.rs b/mill-rpc/tests/integration_test.rs new file mode 100644 index 0000000..c41cbac --- /dev/null +++ b/mill-rpc/tests/integration_test.rs @@ -0,0 +1,101 @@ +use mill_io::EventLoop; +use mill_rpc::prelude::*; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +// Generate both server and client for testing +mill_rpc::service! 
{ + service Calculator { + fn add(a: i32, b: i32) -> i32; + fn multiply(a: f64, b: f64) -> f64; + fn echo(msg: String) -> String; + } +} + +struct MyCalculator; + +impl calculator::Service for MyCalculator { + fn add(&self, _ctx: &RpcContext, a: i32, b: i32) -> i32 { + a + b + } + + fn multiply(&self, _ctx: &RpcContext, a: f64, b: f64) -> f64 { + a * b + } + + fn echo(&self, _ctx: &RpcContext, msg: String) -> String { + format!("echo: {}", msg) + } +} + +#[test] +fn test_dispatch_add() { + let codec = Codec::bincode(); + let ctx = RpcContext::new(1, 0, 0); + let dispatcher = calculator::server(MyCalculator); + + // bincode serialization of AddRequest { a: 2, b: 3 }: two i32 LE + let mut payload = Vec::new(); + payload.extend_from_slice(&2i32.to_le_bytes()); + payload.extend_from_slice(&3i32.to_le_bytes()); + + let result_bytes = dispatcher + .dispatch(&ctx, calculator::methods::ADD, &payload, &codec) + .unwrap(); + let result: i32 = codec.deserialize(&result_bytes).unwrap(); + assert_eq!(result, 5); +} + +#[test] +fn test_dispatch_multiply() { + let codec = Codec::bincode(); + let ctx = RpcContext::new(1, 0, 0); + let dispatcher = calculator::server(MyCalculator); + + // bincode serialization of MultiplyRequest { a: 3.0, b: 4.0 }: two f64 LE + let mut payload = Vec::new(); + payload.extend_from_slice(&3.0f64.to_le_bytes()); + payload.extend_from_slice(&4.0f64.to_le_bytes()); + + let result_bytes = dispatcher + .dispatch(&ctx, calculator::methods::MULTIPLY, &payload, &codec) + .unwrap(); + let result: f64 = codec.deserialize(&result_bytes).unwrap(); + assert!((result - 12.0).abs() < f64::EPSILON); +} + +#[test] +fn test_dispatch_method_not_found() { + let codec = Codec::bincode(); + let ctx = RpcContext::new(1, 0, 0); + let dispatcher = calculator::server(MyCalculator); + + let err = dispatcher.dispatch(&ctx, 999, &[], &codec); + assert!(err.is_err()); +} + +#[test] +fn test_server_builds_and_stops() { + let event_loop = Arc::new(EventLoop::new(2, 1024, 
100).unwrap()); + + let server = RpcServer::builder() + .bind("127.0.0.1:0".parse().unwrap()) + .service(calculator::server(MyCalculator)) + .build(&event_loop); + + match server { + Ok(_s) => { + let el = event_loop.clone(); + let h = thread::spawn(move || { + let _ = el.run(); + }); + thread::sleep(Duration::from_millis(50)); + event_loop.stop(); + let _ = h.join(); + } + Err(e) => { + eprintln!("Server build failed (non-fatal in test): {}", e); + } + } +}