From b542d9ebb7883cdd26ad7fc9778191004ab05cc2 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Fri, 11 Jul 2025 09:14:52 -0700 Subject: [PATCH 1/5] move work in fork to tonic/grpc --- grpc/Cargo.toml | 30 +- grpc/examples/inmemory.rs | 74 +++ grpc/examples/multiaddr.rs | 101 ++++ grpc/src/attributes.rs | 2 +- grpc/src/client/channel.rs | 539 ++++++++++++++++- .../client/load_balancing/child_manager.rs | 127 ++-- grpc/src/client/load_balancing/mod.rs | 347 +++++++++-- grpc/src/client/load_balancing/pick_first.rs | 116 ++++ grpc/src/client/load_balancing/registry.rs | 42 ++ grpc/src/client/mod.rs | 6 +- grpc/src/client/name_resolution/mod.rs | 56 +- grpc/src/client/service.rs | 29 - grpc/src/client/service_config.rs | 26 + grpc/src/client/subchannel.rs | 550 ++++++++++++++++++ grpc/src/client/transport/mod.rs | 16 + grpc/src/client/transport/registry.rs | 62 ++ grpc/src/credentials/mod.rs | 1 + grpc/src/inmemory/mod.rs | 177 ++++++ grpc/src/lib.rs | 8 +- grpc/src/rt/mod.rs | 12 +- grpc/src/rt/tokio/mod.rs | 5 +- grpc/src/server/mod.rs | 41 ++ 22 files changed, 2203 insertions(+), 164 deletions(-) create mode 100644 grpc/examples/inmemory.rs create mode 100644 grpc/examples/multiaddr.rs create mode 100644 grpc/src/client/load_balancing/pick_first.rs create mode 100644 grpc/src/client/load_balancing/registry.rs delete mode 100644 grpc/src/client/service.rs create mode 100644 grpc/src/client/subchannel.rs create mode 100644 grpc/src/client/transport/mod.rs create mode 100644 grpc/src/client/transport/registry.rs create mode 100644 grpc/src/credentials/mod.rs create mode 100644 grpc/src/inmemory/mod.rs create mode 100644 grpc/src/server/mod.rs diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index 5438e7176..95414a591 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -6,19 +6,33 @@ authors = ["gRPC Authors"] license = "MIT" [dependencies] -url = "2.5.0" -tokio = { version = "1.37.0", features = ["sync", "rt", "net", "time", "macros"] } -tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["codegen"] } +bytes = "1.10.1" futures-core = "0.3.31" -serde_json = "1.0.140" -serde = "1.0.219" +futures-util = "0.3.31" hickory-resolver = { version = "0.25.1", optional = true } -rand = "0.9" +http = "1.1.0" +http-body = "1.0.1" +hyper = { version = "1.6.0", features = ["client", "http2"] } +hyper-util = "0.1.14" +once_cell = "1.19.0" parking_lot = "0.12.4" -bytes = "1.10.1" +pin-project-lite = "0.2.16" +rand = "0.9" +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" +socket2 = "0.5.10" +tokio = { version = "1.37.0", features = ["sync", "rt", "net", "time", "macros"] } +tokio-stream = "0.1.17" +tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["codegen", "transport"] } +tower = "0.5.2" +tower-service = "0.3.3" +url = "2.5.0" [dev-dependencies] +async-stream = "0.3.6" +tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["prost", "server", "router"] } hickory-server = "0.25.2" +prost = "0.13.5" [features] default = ["dns"] @@ -28,4 +42,4 @@ dns = ["dep:hickory-resolver"] allowed_external_types = [ "tonic::*", "futures_core::stream::Stream", -] \ No newline at end of file +] diff --git a/grpc/examples/inmemory.rs b/grpc/examples/inmemory.rs new file mode 100644 index 000000000..919f484b3 --- /dev/null +++ b/grpc/examples/inmemory.rs @@ -0,0 +1,74 @@ +use std::any::Any; + +use futures_util::stream::StreamExt; +use grpc::service::{Message, Request, Response, Service}; +use grpc::{client::ChannelOptions, inmemory}; +use tonic::async_trait; + +struct Handler {} + +#[derive(Debug)] +struct MyReqMessage(String); + +impl Message for MyReqMessage {} + +#[derive(Debug)] +struct MyResMessage(String); +impl Message for MyResMessage {} + +#[async_trait] +impl Service for Handler { + async fn call(&self, method: String, request: Request) -> Response { + let mut stream = request.into_inner(); + let output = async_stream::try_stream! { + while let Some(req) = stream.next().await { + yield Box::new(MyResMessage(format!( + "Server: responding to: {}; msg: {}", + method, (req as Box).downcast_ref::().unwrap().0, + ))) as Box; + } + }; + + Response::new(Box::pin(output)) + } +} + +#[tokio::main] +async fn main() { + inmemory::reg(); + + // Spawn the server. + let lis = inmemory::Listener::new(); + let mut srv = grpc::server::Server::new(); + srv.set_handler(Handler {}); + let lis_clone = lis.clone(); + tokio::task::spawn(async move { + srv.serve(&lis_clone).await; + println!("serve returned for listener 1!"); + }); + + println!("Creating channel for {}", lis.target()); + let chan_opts = ChannelOptions::default(); + let chan = grpc::client::Channel::new(lis.target().as_str(), None, chan_opts); + + let outbound = async_stream::stream! { + yield Box::new(MyReqMessage("My Request 1".to_string())) as Box; + yield Box::new(MyReqMessage("My Request 2".to_string())); + yield Box::new(MyReqMessage("My Request 3".to_string())); + }; + + let req = Request::new(Box::pin(outbound)); + let res = chan.call("/some/method".to_string(), req).await; + let mut res = res.into_inner(); + + while let Some(resp) = res.next().await { + println!( + "CALL RESPONSE: {}", + (resp.unwrap() as Box) + .downcast_ref::() + .unwrap() + .0, + ); + } + lis.close().await; +} diff --git a/grpc/examples/multiaddr.rs b/grpc/examples/multiaddr.rs new file mode 100644 index 000000000..2a48ca3e5 --- /dev/null +++ b/grpc/examples/multiaddr.rs @@ -0,0 +1,101 @@ +use std::any::Any; + +use futures_util::StreamExt; +use grpc::service::{Message, Request, Response, Service}; +use grpc::{client::ChannelOptions, inmemory}; +use tonic::async_trait; + +struct Handler { + id: String, +} + +#[derive(Debug)] +struct MyReqMessage(String); + +impl Message for MyReqMessage {} + +#[derive(Debug)] +struct MyResMessage(String); +impl Message for MyResMessage {} + +#[async_trait] +impl Service for Handler { + async fn call(&self, method: String, request: Request) -> Response { + let id = self.id.clone(); + let mut stream = request.into_inner(); + let output = async_stream::try_stream! { + while let Some(req) = stream.next().await { + yield Box::new(MyResMessage(format!( + "Server {}: responding to: {}; msg: {}", + id, method, (req as Box).downcast_ref::().unwrap().0, + ))) as Box; + } + }; + + Response::new(Box::pin(output)) + } +} + +#[tokio::main] +async fn main() { + inmemory::reg(); + + // Spawn the first server. + let lis1 = inmemory::Listener::new(); + let mut srv = grpc::server::Server::new(); + srv.set_handler(Handler { id: lis1.id() }); + let lis1_clone = lis1.clone(); + tokio::task::spawn(async move { + srv.serve(&lis1_clone).await; + println!("serve returned for listener 1!"); + }); + + // Spawn the second server. + let lis2 = inmemory::Listener::new(); + let mut srv = grpc::server::Server::new(); + srv.set_handler(Handler { id: lis2.id() }); + let lis2_clone = lis2.clone(); + tokio::task::spawn(async move { + srv.serve(&lis2_clone).await; + println!("serve returned for listener 2!"); + }); + + // Spawn the third server. + let lis3 = inmemory::Listener::new(); + let mut srv = grpc::server::Server::new(); + srv.set_handler(Handler { id: lis3.id() }); + let lis3_clone = lis3.clone(); + tokio::task::spawn(async move { + srv.serve(&lis3_clone).await; + println!("serve returned for listener 3!"); + }); + + let target = String::from("inmemory:///dummy"); + println!("Creating channel for {target}"); + let chan_opts = ChannelOptions::default(); + let chan = grpc::client::Channel::new(target.as_str(), None, chan_opts); + + let outbound = async_stream::stream! { + yield Box::new(MyReqMessage("My Request 1".to_string())) as Box; + yield Box::new(MyReqMessage("My Request 2".to_string())); + yield Box::new(MyReqMessage("My Request 3".to_string())); + }; + + let req = Request::new(Box::pin(outbound)); + let res = chan.call("/some/method".to_string(), req).await; + let mut res = res.into_inner(); + + while let Some(resp) = res.next().await { + println!( + "CALL RESPONSE: {}", + (resp.unwrap() as Box) + .downcast_ref::() + .unwrap() + .0, + ); + } + + lis1.close().await; + lis2.close().await; + lis3.close().await; +} diff --git a/grpc/src/attributes.rs b/grpc/src/attributes.rs index 01397f902..3d490266f 100644 --- a/grpc/src/attributes.rs +++ b/grpc/src/attributes.rs @@ -24,5 +24,5 @@ /// A key-value store for arbitrary configuration data between multiple /// pluggable components. -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord)] pub struct Attributes; diff --git a/grpc/src/client/channel.rs b/grpc/src/client/channel.rs index 4ffde2cf5..6e41c6759 100644 --- a/grpc/src/client/channel.rs +++ b/grpc/src/client/channel.rs @@ -22,19 +22,119 @@ * */ -use std::{any::Any, str::FromStr, sync::Arc}; +use core::panic; +use std::{ + any::Any, + collections::HashMap, + error::Error, + fmt::Display, + mem, + ops::Add, + str::FromStr, + sync::{Arc, Mutex, Weak}, + time::{Duration, Instant}, + vec, +}; -use url::Url; +use tokio::sync::{mpsc, oneshot, watch, Notify}; +use tokio::task::AbortHandle; -use crate::service::{Request, Response}; +use serde_json::json; +use tonic::async_trait; +use url::Url; // NOTE: http::Uri requires non-empty authority portion of URI -use super::ConnectivityState; +use crate::credentials::Credentials; +use crate::rt; +use crate::service::{Request, Response, Service}; +use crate::{attributes::Attributes, rt::tokio::TokioRuntime}; +use crate::{client::ConnectivityState, rt::Runtime}; -#[derive(Default)] -pub struct ChannelOptions {} +use super::service_config::ServiceConfig; +use super::transport::{TransportRegistry, GLOBAL_TRANSPORT_REGISTRY}; +use super::{ + load_balancing::{ + self, pick_first, ExternalSubchannel, LbPolicy, LbPolicyBuilder, LbPolicyOptions, + LbPolicyRegistry, LbState, ParsedJsonLbConfig, PickResult, Picker, Subchannel, + SubchannelState, WorkScheduler, GLOBAL_LB_REGISTRY, + }, + subchannel::{ + InternalSubchannel, InternalSubchannelPool, NopBackoff, SubchannelKey, + SubchannelStateWatcher, + }, +}; +use super::{ + name_resolution::{ + self, global_registry, Address, ResolverBuilder, ResolverOptions, ResolverUpdate, + }, + subchannel, +}; -impl ChannelOptions {} +#[non_exhaustive] +pub struct ChannelOptions { + pub transport_options: Attributes, // ? + pub override_authority: Option, + pub connection_backoff: Option, + pub default_service_config: Option, + pub disable_proxy: bool, + pub disable_service_config_lookup: bool, + pub disable_health_checks: bool, + pub max_retry_memory: u32, // ? + pub idle_timeout: Duration, + // TODO: pub transport_registry: Option, + // TODO: pub name_resolver_registry: Option, + // TODO: pub lb_policy_registry: Option, + // Typically we allow settings at the channel level that impact all RPCs, + // but can also be set per-RPC. E.g.s: + // + // - interceptors + // - user-agent string override + // - max message sizes + // - max retry/hedged attempts + // - disable retry + // + // In gRPC-Go, we can express CallOptions as DialOptions, which is a nice + // pattern: https://pkg.go.dev/google.golang.org/grpc#WithDefaultCallOptions + // + // To do this in rust, all optional behavior for a request would need to be + // expressed through a trait that applies a mutation to a request. We'd + // apply all those mutations before the user's options so the user's options + // would override the defaults, or so the defaults would occur first. + pub default_request_extensions: Vec>, // ?? +} + +impl Default for ChannelOptions { + fn default() -> Self { + Self { + transport_options: Attributes {}, + override_authority: None, + connection_backoff: None, + default_service_config: None, + disable_proxy: false, + disable_service_config_lookup: false, + disable_health_checks: false, + max_retry_memory: 8 * 1024 * 1024, // 8MB -- ??? + idle_timeout: Duration::from_secs(30 * 60), + default_request_extensions: vec![], + } + } +} + +impl ChannelOptions { + pub fn transport_options(self, transport_options: TODO) -> Self { + todo!(); // add to existing options. + } + pub fn override_authority(self, authority: String) -> Self { + Self { + override_authority: Some(authority), + ..self + } + } + // etc +} + +// All of Channel needs to be thread-safe. Arc? Or give out +// Arc from constructor? #[derive(Clone)] pub struct Channel { inner: Arc, @@ -48,31 +148,66 @@ impl Channel { // TODO: should this return a Result instead? pub fn new( target: &str, - credentials: Option>, // TODO: Credentials trait - runtime: Option>, // TODO: Runtime trait + credentials: Option>, options: ChannelOptions, ) -> Self { + pick_first::reg(); Self { inner: Arc::new(PersistentChannel::new( target, credentials, - runtime, + Arc::new(rt::tokio::TokioRuntime {}), options, )), } } + // TODO: enter_idle(&self) and graceful_stop()? + /// Returns the current state of the channel. - // TODO: replace with a watcher that provides state change updates? - pub fn state(&mut self, _connect: bool) -> ConnectivityState { + pub fn state(&mut self, connect: bool) -> ConnectivityState { + let ac = if !connect { + // If !connect and we have no active channel already, return idle. + let ac = self.inner.active_channel.lock().unwrap(); + if ac.is_none() { + return ConnectivityState::Idle; + } + ac.as_ref().unwrap().clone() + } else { + // Otherwise, get or create the active channel. + self.get_or_create_active_channel() + }; + if let Some(s) = ac.connectivity_state.cur() { + return s; + } + ConnectivityState::Idle + } + + /// Waits for the state of the channel to change from source. Times out and + /// returns an error after the deadline. + pub async fn wait_for_state_change( + &self, + source: ConnectivityState, + deadline: Instant, + ) -> Result<(), Box> { todo!() } - /// Performs an RPC on the channel. Response will contain any response - /// messages from the server and/or errors returned by the server or - /// generated locally. - pub async fn call(&self, _request: Request) -> Response { - todo!() // create the active channel if necessary and call it. + fn get_or_create_active_channel(&self) -> Arc { + let mut s = self.inner.active_channel.lock().unwrap(); + if s.is_none() { + *s = Some(ActiveChannel::new( + self.inner.target.clone(), + &self.inner.options, + self.inner.runtime.clone(), + )); + } + s.clone().unwrap() + } + + pub async fn call(&self, method: String, request: Request) -> Response { + let ac = self.get_or_create_active_channel(); + ac.call(method, request).await } } @@ -83,7 +218,8 @@ impl Channel { struct PersistentChannel { target: Url, options: ChannelOptions, - // TODO: active_channel: Mutex>>, + active_channel: Mutex>>, + runtime: Arc, } impl PersistentChannel { @@ -91,13 +227,376 @@ impl PersistentChannel { // are not in ChannelOptions. fn new( target: &str, - _credentials: Option>, - _runtime: Option>, + _credentials: Option>, + runtime: Arc, options: ChannelOptions, ) -> Self { Self { target: Url::from_str(target).unwrap(), // TODO handle err + active_channel: Mutex::default(), options, + runtime, + } + } +} + +struct ActiveChannel { + cur_state: Mutex, + abort_handle: Box, + picker: Arc>>, + connectivity_state: Arc>, + runtime: Arc, +} + +impl ActiveChannel { + fn new(target: Url, options: &ChannelOptions, runtime: Arc) -> Arc { + let (tx, mut rx) = mpsc::unbounded_channel::(); + let transport_registry = GLOBAL_TRANSPORT_REGISTRY.clone(); + + let resolve_now = Arc::new(Notify::new()); + let connectivity_state = Arc::new(Watcher::new()); + let picker = Arc::new(Watcher::new()); + let mut channel_controller = InternalChannelController::new( + transport_registry, + resolve_now.clone(), + tx.clone(), + picker.clone(), + connectivity_state.clone(), + ); + + let resolver_helper = Box::new(tx.clone()); + + // TODO(arjan-bal): Return error here instead of panicking. + let rb = global_registry().get(target.scheme()).unwrap(); + let target = name_resolution::Target::from(target); + let authority = target.authority_host_port(); + let authority = if authority.is_empty() { + rb.default_authority(&target).to_owned() + } else { + authority + }; + let work_scheduler = Arc::new(ResolverWorkScheduler { wqtx: tx }); + let resolver_opts = name_resolution::ResolverOptions { + authority, + work_scheduler, + runtime: Arc::new(TokioRuntime {}), + }; + let resolver = rb.build(&target, resolver_opts); + + let jh = runtime.spawn(Box::pin(async move { + let mut resolver = resolver; + while let Some(w) = rx.recv().await { + match w { + WorkQueueItem::Closure(func) => func(&mut channel_controller), + WorkQueueItem::ScheduleResolver => resolver.work(&mut channel_controller), + } + } + })); + + Arc::new(Self { + cur_state: Mutex::new(ConnectivityState::Connecting), + abort_handle: jh, + picker: picker.clone(), + connectivity_state: connectivity_state.clone(), + runtime, + }) + } + + async fn call(&self, method: String, request: Request) -> Response { + // TODO: pre-pick tasks (e.g. deadlines, interceptors, retry) + let mut i = self.picker.iter(); + loop { + if let Some(p) = i.next().await { + let result = &p.pick(&request); + // TODO: handle picker errors (queue or fail RPC) + match result { + PickResult::Pick(pr) => { + if let Some(sc) = (pr.subchannel.as_ref() as &dyn Any) + .downcast_ref::() + { + return sc.isc.as_ref().unwrap().call(method, request).await; + } else { + panic!("picked subchannel is not an implementation provided by the channel"); + } + } + PickResult::Queue => { + // Continue and retry the RPC with the next picker. + } + PickResult::Fail(status) => { + panic!("failed pick: {}", status); + } + PickResult::Drop(status) => { + panic!("dropped pick: {}", status); + } + } + } + } + } +} + +impl Drop for ActiveChannel { + fn drop(&mut self) { + self.abort_handle.abort(); + } +} + +struct ResolverWorkScheduler { + wqtx: WorkQueueTx, +} + +pub(super) type WorkQueueTx = mpsc::UnboundedSender; + +impl name_resolution::WorkScheduler for ResolverWorkScheduler { + fn schedule_work(&self) { + let _ = self.wqtx.send(WorkQueueItem::ScheduleResolver); + } +} + +pub(crate) struct InternalChannelController { + pub(super) lb: Arc, // called and passes mutable parent to it, so must be Arc. + transport_registry: TransportRegistry, + pub(super) subchannel_pool: Arc, + resolve_now: Arc, + wqtx: WorkQueueTx, + picker: Arc>>, + connectivity_state: Arc>, +} + +impl InternalChannelController { + fn new( + transport_registry: TransportRegistry, + resolve_now: Arc, + wqtx: WorkQueueTx, + picker: Arc>>, + connectivity_state: Arc>, + ) -> Self { + let lb = Arc::new(GracefulSwitchBalancer::new(wqtx.clone())); + + Self { + lb, + transport_registry, + subchannel_pool: Arc::new(InternalSubchannelPool::new()), + resolve_now, + wqtx, + picker, + connectivity_state, + } + } + + fn new_esc_for_isc(&self, isc: Arc) -> Arc { + let sc = Arc::new(ExternalSubchannel::new(isc.clone(), self.wqtx.clone())); + let watcher = Arc::new(SubchannelStateWatcher::new(sc.clone(), self.wqtx.clone())); + sc.set_watcher(watcher.clone()); + isc.register_connectivity_state_watcher(watcher.clone()); + sc + } +} + +impl name_resolution::ChannelController for InternalChannelController { + fn update(&mut self, update: ResolverUpdate) -> Result<(), String> { + let lb = self.lb.clone(); + lb.handle_resolver_update(update, self) + .map_err(|err| err.to_string()) + } + + fn parse_service_config(&self, config: &str) -> Result { + Err("service configs not supported".to_string()) + } +} + +impl load_balancing::ChannelController for InternalChannelController { + fn new_subchannel(&mut self, address: &Address) -> Arc { + let key = SubchannelKey::new(address.clone()); + if let Some(isc) = self.subchannel_pool.lookup_subchannel(&key) { + return self.new_esc_for_isc(isc); + } + + // If we get here, it means one of two things: + // 1. provided key is not found in the map + // 2. provided key points to an unpromotable value, which can occur if + // its internal subchannel has been dropped but hasn't been + // unregistered yet. + + let transport = self + .transport_registry + .get_transport(address.network_type) + .unwrap(); + let scp = self.subchannel_pool.clone(); + let isc = InternalSubchannel::new( + key.clone(), + transport, + Arc::new(NopBackoff {}), + Box::new(move |k: SubchannelKey| { + scp.unregister_subchannel(&k); + }), + ); + let _ = self.subchannel_pool.register_subchannel(&key, isc.clone()); + self.new_esc_for_isc(isc) + } + + fn update_picker(&mut self, update: LbState) { + println!( + "update picker called with state: {:?}", + update.connectivity_state + ); + self.picker.update(update.picker); + self.connectivity_state.update(update.connectivity_state); + } + + fn request_resolution(&mut self) { + self.resolve_now.notify_one(); + } +} + +// A channel that is not idle (connecting, ready, or erroring). +pub(super) struct GracefulSwitchBalancer { + pub(super) policy: Mutex>>, + policy_builder: Mutex>>, + work_scheduler: WorkQueueTx, + pending: Mutex, +} + +impl WorkScheduler for GracefulSwitchBalancer { + fn schedule_work(&self) { + if mem::replace(&mut *self.pending.lock().unwrap(), true) { + // Already had a pending call scheduled. + return; + } + let _ = self.work_scheduler.send(WorkQueueItem::Closure(Box::new( + |c: &mut InternalChannelController| { + *c.lb.pending.lock().unwrap() = false; + c.lb.clone() + .policy + .lock() + .unwrap() + .as_mut() + .unwrap() + .work(c); + }, + ))); + } +} + +impl GracefulSwitchBalancer { + fn new(work_scheduler: WorkQueueTx) -> Self { + Self { + policy_builder: Mutex::default(), + policy: Mutex::default(), // new(None::>), + work_scheduler, + pending: Mutex::default(), + } + } + + fn handle_resolver_update( + self: &Arc, + update: ResolverUpdate, + controller: &mut InternalChannelController, + ) -> Result<(), Box> { + if update.service_config.as_ref().is_ok_and(|sc| sc.is_some()) { + return Err("can't do service configs yet".into()); + } + let policy_name = pick_first::POLICY_NAME; + let mut p = self.policy.lock().unwrap(); + if p.is_none() { + let builder = GLOBAL_LB_REGISTRY.get_policy(policy_name).unwrap(); + let newpol = builder.build(LbPolicyOptions { + work_scheduler: self.clone(), + }); + *self.policy_builder.lock().unwrap() = Some(builder); + *p = Some(newpol); + } + + // TODO: config should come from ServiceConfig. + let builder = self.policy_builder.lock().unwrap(); + let config = match builder + .as_ref() + .unwrap() + .parse_config(&ParsedJsonLbConfig::from_value( + json!({"shuffleAddressList": true, "unknown_field": false}), + )) { + Ok(cfg) => cfg, + Err(e) => { + return Err(e); + } + }; + + p.as_mut() + .unwrap() + .resolver_update(update, config.as_ref(), controller) + + // TODO: close old LB policy gracefully vs. drop? + } + pub(super) fn subchannel_update( + &self, + subchannel: Arc, + state: &SubchannelState, + channel_controller: &mut dyn load_balancing::ChannelController, + ) { + let mut p = self.policy.lock().unwrap(); + + p.as_mut() + .unwrap() + .subchannel_update(subchannel, state, channel_controller); + } +} + +pub(super) enum WorkQueueItem { + // Execute the closure. + Closure(Box), + // Call the resolver to do work. + ScheduleResolver, +} + +pub struct TODO; + +// Enables multiple receivers to view data output from a single producer. +// Producer calls update. Consumers call iter() and call next() until they find +// a good value or encounter None. +pub(crate) struct Watcher { + tx: watch::Sender>, + rx: watch::Receiver>, +} + +impl Watcher { + fn new() -> Self { + let (tx, rx) = watch::channel(None); + Self { tx, rx } + } + + pub(crate) fn iter(&self) -> WatcherIter { + let mut rx = self.rx.clone(); + rx.mark_changed(); + WatcherIter { rx } + } + + pub(crate) fn cur(&self) -> Option { + let mut rx = self.rx.clone(); + rx.mark_changed(); + let c = rx.borrow(); + c.clone() + } + + fn update(&self, item: T) { + self.tx.send(Some(item)).unwrap(); + } +} + +pub(crate) struct WatcherIter { + rx: watch::Receiver>, +} +// TODO: Use an arc_swap::ArcSwap instead that contains T and a channel closed +// when T is updated. Even if the channel needs a lock, the fast path becomes +// lock-free. + +impl WatcherIter { + /// Returns the next unseen value + pub(crate) async fn next(&mut self) -> Option { + loop { + self.rx.changed().await.ok()?; + let x = self.rx.borrow_and_update(); + if x.is_some() { + return x.clone(); + } } } } diff --git a/grpc/src/client/load_balancing/child_manager.rs b/grpc/src/client/load_balancing/child_manager.rs index 0d1c880f6..0d4af6542 100644 --- a/grpc/src/client/load_balancing/child_manager.rs +++ b/grpc/src/client/load_balancing/child_manager.rs @@ -29,10 +29,13 @@ // policy in use. Complete tests must be written before it can be used in // production. Also, support for the work scheduler is missing. +use std::collections::HashSet; +use std::sync::Mutex; use std::{collections::HashMap, error::Error, hash::Hash, mem, sync::Arc}; use crate::client::load_balancing::{ - ChannelController, LbConfig, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState, WorkScheduler, + ChannelController, LbConfig, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState, + WeakSubchannel, WorkScheduler, }; use crate::client::name_resolution::{Address, ResolverUpdate}; @@ -40,15 +43,17 @@ use super::{Subchannel, SubchannelState}; // An LbPolicy implementation that manages multiple children. pub struct ChildManager { - subchannel_child_map: HashMap, + subchannel_child_map: HashMap, children: Vec>, - shard_update: Box>, + update_sharder: Box>, + pending_work: Arc>>, } struct Child { identifier: T, policy: Box, state: LbState, + work_scheduler: Arc, } /// A collection of data sent to a child of the ChildManager. @@ -57,28 +62,31 @@ pub struct ChildUpdate { pub child_identifier: T, /// The builder the ChildManager should use to create this child if it does /// not exist. - pub child_policy_builder: Box, + pub child_policy_builder: Arc, /// The relevant ResolverUpdate to send to this child. pub child_update: ResolverUpdate, } -// TODO: convert to a trait? -/// Performs the operation of sharding an aggregate ResolverUpdate into one or -/// more ChildUpdates. Called automatically by the ChildManager when its -/// resolver_update method is called. -pub type ResolverUpdateSharder = - fn( - ResolverUpdate, +pub trait ResolverUpdateSharder: Send { + /// Performs the operation of sharding an aggregate ResolverUpdate into one + /// or more ChildUpdates. Called automatically by the ChildManager when its + /// resolver_update method is called. The key in the returned map is the + /// identifier the ChildManager should use for this child. + fn shard_update( + &self, + resolver_update: ResolverUpdate, ) -> Result>>, Box>; +} -impl ChildManager { +impl ChildManager { /// Creates a new ChildManager LB policy. shard_update is called whenever a /// resolver_update operation occurs. - pub fn new(shard_update: Box>) -> Self { + pub fn new(update_sharder: Box>) -> Self { Self { - subchannel_child_map: HashMap::default(), - children: Vec::default(), - shard_update, + update_sharder, + subchannel_child_map: Default::default(), + children: Default::default(), + pending_work: Default::default(), } } @@ -103,7 +111,7 @@ impl ChildManager { ) { // Add all created subchannels into the subchannel_child_map. for csc in channel_controller.created_subchannels { - self.subchannel_child_map.insert(csc, child_idx); + self.subchannel_child_map.insert(csc.into(), child_idx); } // Update the tracked state if the child produced an update. if let Some(state) = channel_controller.picker_update { @@ -112,7 +120,7 @@ impl ChildManager { } } -impl LbPolicy for ChildManager { +impl LbPolicy for ChildManager { fn resolver_update( &mut self, resolver_update: ResolverUpdate, @@ -120,20 +128,25 @@ impl LbPolicy for ChildManager { channel_controller: &mut dyn ChannelController, ) -> Result<(), Box> { // First determine if the incoming update is valid. - let child_updates = (self.shard_update)(resolver_update)?; + let child_updates = self.update_sharder.shard_update(resolver_update)?; + + // Hold the lock to prevent new work requests during this operation and + // rewrite the indices. + let mut pending_work = self.pending_work.lock().unwrap(); + + // Reset pending work; we will re-add any entries it contains with the + // right index later. + let old_pending_work = mem::take(&mut *pending_work); // Replace self.children with an empty vec. - let mut old_children = vec![]; - mem::swap(&mut self.children, &mut old_children); + let old_children = mem::take(&mut self.children); // Replace the subchannel map with an empty map. - let mut old_subchannel_child_map = HashMap::new(); - mem::swap( - &mut self.subchannel_child_map, - &mut old_subchannel_child_map, - ); + let old_subchannel_child_map = mem::take(&mut self.subchannel_child_map); + // Reverse the old subchannel map. - let mut old_child_subchannels_map: HashMap> = HashMap::new(); + let mut old_child_subchannels_map: HashMap> = HashMap::new(); + for (subchannel, child_idx) in old_subchannel_child_map { old_child_subchannels_map .entry(child_idx) @@ -145,7 +158,7 @@ impl LbPolicy for ChildManager { let old_children = old_children .into_iter() .enumerate() - .map(|(old_idx, e)| (e.identifier, (e.policy, e.state, old_idx))); + .map(|(old_idx, e)| (e.identifier, (e.policy, e.state, old_idx, e.work_scheduler))); let mut old_children: HashMap = old_children.collect(); // Split the child updates into the IDs and builders, and the @@ -158,7 +171,8 @@ impl LbPolicy for ChildManager { // update, and create new children. Add entries back into the // subchannel map. for (new_idx, (identifier, builder)) in ids_builders.into_iter().enumerate() { - if let Some((policy, state, old_idx)) = old_children.remove(&identifier) { + if let Some((policy, state, old_idx, work_scheduler)) = old_children.remove(&identifier) + { for subchannel in old_child_subchannels_map .remove(&old_idx) .into_iter() @@ -166,24 +180,43 @@ impl LbPolicy for ChildManager { { self.subchannel_child_map.insert(subchannel, new_idx); } + if old_pending_work.contains(&old_idx) { + pending_work.insert(new_idx); + } + *work_scheduler.idx.lock().unwrap() = Some(new_idx); self.children.push(Child { identifier, state, policy, + work_scheduler, }); } else { + let work_scheduler = Arc::new(ChildWorkScheduler { + pending_work: self.pending_work.clone(), + idx: Mutex::new(Some(new_idx)), + }); let policy = builder.build(LbPolicyOptions { - work_scheduler: Arc::new(UnimplWorkScheduler {}), + work_scheduler: work_scheduler.clone(), }); let state = LbState::initial(); self.children.push(Child { identifier, state, policy, + work_scheduler, }); }; } + // Invalidate all deleted children's work_schedulers. + for (_, (_, _, _, work_scheduler)) in old_children { + *work_scheduler.idx.lock().unwrap() = None; + } + + // Release the pending_work mutex before calling into the children to + // allow their work scheduler calls to unblock. + drop(pending_work); + // Anything left in old_children will just be Dropped and cleaned up. // Call resolver_update on all children. @@ -202,12 +235,15 @@ impl LbPolicy for ChildManager { fn subchannel_update( &mut self, - subchannel: &Subchannel, + subchannel: Arc, state: &SubchannelState, channel_controller: &mut dyn ChannelController, ) { // Determine which child created this subchannel. - let child_idx = *self.subchannel_child_map.get(subchannel).unwrap(); + let child_idx = *self + .subchannel_child_map + .get(&WeakSubchannel::new(&subchannel)) + .unwrap(); let policy = &mut self.children[child_idx].policy; // Wrap the channel_controller to track the child's operations. let mut channel_controller = WrappedController::new(channel_controller); @@ -216,14 +252,21 @@ impl LbPolicy for ChildManager { self.resolve_child_controller(channel_controller, child_idx); } - fn work(&mut self, _channel_controller: &mut dyn ChannelController) { - todo!(); + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + let child_idxes = mem::take(&mut *self.pending_work.lock().unwrap()); + for child_idx in child_idxes { + let mut channel_controller = WrappedController::new(channel_controller); + self.children[child_idx] + .policy + .work(&mut channel_controller); + self.resolve_child_controller(channel_controller, child_idx); + } } } struct WrappedController<'a> { channel_controller: &'a mut dyn ChannelController, - created_subchannels: Vec, + created_subchannels: Vec>, picker_update: Option, } @@ -238,7 +281,7 @@ impl<'a> WrappedController<'a> { } impl ChannelController for WrappedController<'_> { - fn new_subchannel(&mut self, address: &Address) -> Subchannel { + fn new_subchannel(&mut self, address: &Address) -> Arc { let subchannel = self.channel_controller.new_subchannel(address); self.created_subchannels.push(subchannel.clone()); subchannel @@ -253,10 +296,16 @@ impl ChannelController for WrappedController<'_> { } } -pub struct UnimplWorkScheduler; +struct ChildWorkScheduler { + pending_work: Arc>>, // Must be taken first for correctness + idx: Mutex>, // None if the child is deleted. +} -impl WorkScheduler for UnimplWorkScheduler { +impl WorkScheduler for ChildWorkScheduler { fn schedule_work(&self) { - todo!(); + let mut pending_work = self.pending_work.lock().unwrap(); + if let Some(idx) = *self.idx.lock().unwrap() { + pending_work.insert(idx); + } } } diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs index 0f1808a59..16b4cafbe 100644 --- a/grpc/src/client/load_balancing/mod.rs +++ b/grpc/src/client/load_balancing/mod.rs @@ -22,20 +22,42 @@ * */ -pub mod child_manager; - -use std::{any::Any, error::Error, hash::Hash, sync::Arc}; - +use core::panic; +use serde::de; +use std::{ + any::Any, + collections::HashMap, + error::Error, + fmt::{Debug, Display}, + hash::{Hash, Hasher}, + ops::{Add, Sub}, + sync::{ + atomic::{AtomicI64, Ordering::Relaxed}, + Arc, Mutex, Weak, + }, +}; +use tokio::sync::{mpsc::Sender, Notify}; use tonic::{metadata::MetadataMap, Status}; use crate::{ - client::{ - name_resolution::{Address, ResolverUpdate}, - ConnectivityState, - }, - service::Request, + client::channel::WorkQueueTx, + service::{Request, Response, Service}, +}; + +use crate::client::{ + channel::{InternalChannelController, WorkQueueItem}, + name_resolution::{Address, ResolverUpdate}, + subchannel::InternalSubchannel, + ConnectivityState, }; +pub mod child_manager; +pub mod pick_first; + +pub(crate) mod registry; +use super::{service_config::LbConfig, subchannel::SubchannelStateWatcher}; +pub(crate) use registry::{LbPolicyRegistry, GLOBAL_LB_REGISTRY}; + /// A collection of data configured on the channel that is constructing this /// LbPolicy. pub struct LbPolicyOptions { @@ -71,6 +93,10 @@ impl ParsedJsonLbConfig { } } + pub(crate) fn from_value(value: serde_json::Value) -> Self { + Self { value } + } + /// Converts the JSON configuration into a concrete type that represents the /// configuration of an LB policy. /// @@ -91,7 +117,7 @@ impl ParsedJsonLbConfig { /// An LB policy factory that produces LbPolicy instances used by the channel /// to manage connections and pick connections for RPCs. -pub trait LbPolicyBuilder: Send + Sync { +pub(crate) trait LbPolicyBuilder: Send + Sync { /// Builds and returns a new LB policy instance. /// /// Note that build must not fail. Any optional configuration is delivered @@ -135,7 +161,7 @@ pub trait LbPolicy: Send { /// changes state. fn subchannel_update( &mut self, - subchannel: &Subchannel, + subchannel: Arc, state: &SubchannelState, channel_controller: &mut dyn ChannelController, ); @@ -148,7 +174,7 @@ pub trait LbPolicy: Send { /// Controls channel behaviors. pub trait ChannelController: Send + Sync { /// Creates a new subchannel in IDLE state. - fn new_subchannel(&mut self, address: &Address) -> Subchannel; + fn new_subchannel(&mut self, address: &Address) -> Arc; /// Provides a new snapshot of the LB policy's state to the channel. fn update_picker(&mut self, update: LbState); @@ -160,7 +186,7 @@ pub trait ChannelController: Send + Sync { } /// Represents the current state of a Subchannel. -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct SubchannelState { /// The connectivity state of the subchannel. See SubChannel for a /// description of the various states and their valid transitions. @@ -170,20 +196,22 @@ pub struct SubchannelState { pub last_connection_error: Option>, } -/// A convenience wrapper for an LB policy's configuration object. -pub struct LbConfig { - config: Box, -} - -impl LbConfig { - /// Create a new LbConfig wrapper containing the provided config. - pub fn new(config: Box) -> Self { - LbConfig { config } +impl Default for SubchannelState { + fn default() -> Self { + Self { + connectivity_state: ConnectivityState::Idle, + last_connection_error: None, + } } +} - /// Converts the wrapped configuration into the type used by the LbPolicy. - pub fn into(&self) -> Option<&T> { - self.config.downcast_ref::() +impl Display for SubchannelState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "connectivity_state: {}", self.connectivity_state)?; + if let Some(err) = &self.last_connection_error { + write!(f, ", last_connection_error: {}", err)?; + } + Ok(()) } } @@ -239,6 +267,45 @@ pub enum PickResult { Drop(Status), } +impl PickResult { + pub fn unwrap_pick(self) -> Pick { + let PickResult::Pick(pick) = self else { + panic!("Called `PickResult::unwrap_pick` on a `Queue` or `Err` value"); + }; + pick + } +} + +impl PartialEq for PickResult { + fn eq(&self, other: &Self) -> bool { + match self { + PickResult::Pick(pick) => match other { + PickResult::Pick(other_pick) => pick.subchannel == other_pick.subchannel.clone(), + _ => false, + }, + PickResult::Queue => matches!(other, PickResult::Queue), + PickResult::Fail(status) => { + // TODO: implement me. + false + } + PickResult::Drop(status) => { + // TODO: implement me. + false + } + } + } +} + +impl Display for PickResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Pick(_) => write!(f, "Pick"), + Self::Queue => write!(f, "Queue"), + Self::Fail(st) => write!(f, "Fail({})", st), + Self::Drop(st) => write!(f, "Drop({})", st), + } + } +} /// Data provided by the LB policy. #[derive(Clone)] pub struct LbState { @@ -260,11 +327,42 @@ impl LbState { /// A collection of data used by the channel for routing a request. pub struct Pick { /// The Subchannel for the request. - pub subchannel: Subchannel, + pub subchannel: Arc, // Metadata to be added to existing outgoing metadata. pub metadata: MetadataMap, + // Callback to be invoked once the RPC completes. + pub on_complete: Option>, +} + +pub trait DynHash { + fn dyn_hash(&self, state: &mut Box<&mut dyn Hasher>); +} + +impl DynHash for T { + fn dyn_hash(&self, state: &mut Box<&mut dyn Hasher>) { + self.hash(state); + } } +pub trait DynPartialEq { + fn dyn_eq(&self, other: &&dyn Any) -> bool; +} + +impl DynPartialEq for T { + fn dyn_eq(&self, other: &&dyn Any) -> bool { + let Some(other) = other.downcast_ref::() else { + return false; + }; + self.eq(other) + } +} + +mod private { + pub trait Sealed {} +} + +pub trait SealedSubchannel: private::Sealed {} + /// A Subchannel represents a method of communicating with a server which may be /// connected or disconnected many times across its lifetime. /// @@ -280,24 +378,189 @@ pub struct Pick { /// expired. This timer scales exponentially and is reset when the subchannel /// becomes READY. /// -/// When a Subchannel is dropped, it is disconnected, and no subsequent state -/// updates will be provided for it to the LB policy. -#[derive(Clone, Debug)] -pub struct Subchannel; +/// When a Subchannel is dropped, it is disconnected automatically, and no +/// subsequent state updates will be provided for it to the LB policy. +pub trait Subchannel: SealedSubchannel + DynHash + DynPartialEq + Any + Send + Sync { + /// Returns the address of the Subchannel. + /// TODO: Consider whether this should really be public. + fn address(&self) -> Address; + + /// Notifies the Subchannel to connect. + fn connect(&self); +} + +impl dyn Subchannel { + pub fn downcast_ref(&self) -> Option<&T> + where + T: 'static, + { + (self as &dyn Any).downcast_ref() + } +} + +impl Hash for dyn Subchannel { + fn hash(&self, state: &mut H) { + self.dyn_hash(&mut Box::new(state as &mut dyn Hasher)); + } +} + +impl PartialEq for dyn Subchannel { + fn eq(&self, other: &Self) -> bool { + self.dyn_eq(&Box::new(other as &dyn Any)) + } +} + +impl Eq for dyn Subchannel {} + +impl Debug for dyn Subchannel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Subchannel: {}", self.address()) + } +} + +impl Display for dyn Subchannel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Subchannel: {}", self.address()) + } +} + +struct WeakSubchannel(Weak); + +impl From> for WeakSubchannel { + fn from(subchannel: Arc) -> Self { + WeakSubchannel(Arc::downgrade(&subchannel)) + } +} + +impl WeakSubchannel { + pub fn new(subchannel: &Arc) -> Self { + WeakSubchannel(Arc::downgrade(subchannel)) + } + + pub fn upgrade(&self) -> Option> { + self.0.upgrade() + } +} + +impl Hash for WeakSubchannel { + fn hash(&self, state: &mut H) { + if let Some(strong) = self.upgrade() { + return strong.dyn_hash(&mut Box::new(state as &mut dyn Hasher)); + } + panic!("WeakSubchannel is not valid"); + } +} -impl Hash for Subchannel { - fn hash(&self, _state: &mut H) { - todo!() +impl PartialEq for WeakSubchannel { + fn eq(&self, other: &Self) -> bool { + if let Some(strong) = self.upgrade() { + return strong.dyn_eq(&Box::new(other as &dyn Any)); + } + false } } -impl PartialEq for Subchannel { - fn eq(&self, _other: &Self) -> bool { - todo!() +impl Eq for WeakSubchannel {} + +pub(crate) struct ExternalSubchannel { + pub(crate) isc: Option>, + work_scheduler: WorkQueueTx, + watcher: Mutex>>, +} + +impl ExternalSubchannel { + pub(super) fn new(isc: Arc, work_scheduler: WorkQueueTx) -> Self { + ExternalSubchannel { + isc: Some(isc), + work_scheduler, + watcher: Mutex::default(), + } + } + + pub(super) fn set_watcher(&self, watcher: Arc) { + self.watcher.lock().unwrap().replace(watcher); } } -impl Eq for Subchannel {} +impl Hash for ExternalSubchannel { + fn hash(&self, state: &mut H) { + self.address().hash(state); + } +} + +impl PartialEq for ExternalSubchannel { + fn eq(&self, other: &Self) -> bool { + self.address() == other.address() + } +} + +impl Eq for ExternalSubchannel {} + +impl Subchannel for ExternalSubchannel { + fn address(&self) -> Address { + self.isc.as_ref().unwrap().address() + } + + fn connect(&self) { + println!("connect called for subchannel: {}", self); + self.isc.as_ref().unwrap().connect(false); + } +} + +impl SealedSubchannel for ExternalSubchannel {} +impl private::Sealed for ExternalSubchannel {} + +impl Drop for ExternalSubchannel { + fn drop(&mut self) { + let watcher = self.watcher.lock().unwrap().take(); + let address = self.address().address.clone(); + let isc = self.isc.take(); + let _ = self.work_scheduler.send(WorkQueueItem::Closure(Box::new( + move |c: &mut InternalChannelController| { + println!("unregistering connectivity state watcher for {:?}", address); + isc.as_ref() + .unwrap() + .unregister_connectivity_state_watcher(watcher.unwrap()); + }, + // The internal subchannel is dropped from here (i.e., from inside + // the work serializer), if this is the last reference to it. + ))); + } +} + +impl Debug for ExternalSubchannel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Subchannel {}", self.address()) + } +} + +impl Display for ExternalSubchannel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Subchannel {}", self.address()) + } +} + +pub trait ForwardingSubchannel: DynHash + DynPartialEq + Any + Send + Sync { + fn delegate(&self) -> Arc; + + fn address(&self) -> Address { + self.delegate().address() + } + fn connect(&self) { + self.delegate().connect() + } +} + +impl Subchannel for T { + fn address(&self) -> Address { + self.address() + } + fn connect(&self) { + self.connect() + } +} +impl SealedSubchannel for T {} +impl private::Sealed for T {} /// QueuingPicker always returns Queue. LB policies that are not actively /// Connecting should not use this picker. @@ -308,3 +571,13 @@ impl Picker for QueuingPicker { PickResult::Queue } } + +pub struct Failing { + pub error: String, +} + +impl Picker for Failing { + fn pick(&self, _: &Request) -> PickResult { + PickResult::Fail(Status::unavailable(self.error.clone())) + } +} diff --git a/grpc/src/client/load_balancing/pick_first.rs b/grpc/src/client/load_balancing/pick_first.rs new file mode 100644 index 000000000..ed7ae76f6 --- /dev/null +++ b/grpc/src/client/load_balancing/pick_first.rs @@ -0,0 +1,116 @@ +use std::{ + error::Error, + sync::{Arc, Mutex}, + time::Duration, +}; + +use tokio::time::sleep; +use tonic::metadata::MetadataMap; + +use crate::{ + client::{ + load_balancing::{LbPolicy, LbPolicyBuilder, LbState}, + name_resolution::{Address, ResolverUpdate}, + subchannel, ConnectivityState, + }, + service::Request, +}; + +use super::{ + ChannelController, LbConfig, LbPolicyOptions, Pick, PickResult, Picker, Subchannel, + SubchannelState, WorkScheduler, +}; + +pub static POLICY_NAME: &str = "pick_first"; + +struct Builder {} + +impl LbPolicyBuilder for Builder { + fn build(&self, options: LbPolicyOptions) -> Box { + Box::new(PickFirstPolicy { + work_scheduler: options.work_scheduler, + subchannel: None, + next_addresses: Vec::default(), + }) + } + + fn name(&self) -> &'static str { + POLICY_NAME + } +} + +pub fn reg() { + super::GLOBAL_LB_REGISTRY.add_builder(Builder {}) +} + +struct PickFirstPolicy { + work_scheduler: Arc, + subchannel: Option>, + next_addresses: Vec
, +} + +impl LbPolicy for PickFirstPolicy { + fn resolver_update( + &mut self, + update: ResolverUpdate, + config: Option<&LbConfig>, + channel_controller: &mut dyn ChannelController, + ) -> Result<(), Box> { + let mut addresses = update + .endpoints + .unwrap() + .into_iter() + .next() + .ok_or("no endpoints")? + .addresses; + + let address = addresses.pop().ok_or("no addresses")?; + + let sc = channel_controller.new_subchannel(&address); + sc.connect(); + self.subchannel = Some(sc); + + self.next_addresses = addresses; + let work_scheduler = self.work_scheduler.clone(); + // TODO: Implement Drop that cancels this task. + tokio::task::spawn(async move { + sleep(Duration::from_millis(200)).await; + work_scheduler.schedule_work(); + }); + // TODO: return a picker that queues RPCs. + Ok(()) + } + + fn subchannel_update( + &mut self, + subchannel: Arc, + state: &SubchannelState, + channel_controller: &mut dyn ChannelController, + ) { + // Assume the update is for our subchannel. + if state.connectivity_state == ConnectivityState::Ready { + channel_controller.update_picker(LbState { + connectivity_state: ConnectivityState::Ready, + picker: Arc::new(OneSubchannelPicker { + sc: self.subchannel.as_ref().unwrap().clone(), + }), + }); + } + } + + fn work(&mut self, channel_controller: &mut dyn ChannelController) {} +} + +struct OneSubchannelPicker { + sc: Arc, +} + +impl Picker for OneSubchannelPicker { + fn pick(&self, request: &Request) -> PickResult { + PickResult::Pick(Pick { + subchannel: self.sc.clone(), + on_complete: None, + metadata: MetadataMap::new(), + }) + } +} diff --git a/grpc/src/client/load_balancing/registry.rs b/grpc/src/client/load_balancing/registry.rs new file mode 100644 index 000000000..de7a575d5 --- /dev/null +++ b/grpc/src/client/load_balancing/registry.rs @@ -0,0 +1,42 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use once_cell::sync::Lazy; + +use super::LbPolicyBuilder; + +/// A registry to store and retrieve LB policies. LB policies are indexed by +/// their names. +pub struct LbPolicyRegistry { + m: Arc>>>, +} + +impl LbPolicyRegistry { + /// Construct an empty LB policy registry. + pub fn new() -> Self { + Self { m: Arc::default() } + } + /// Add a LB policy into the registry. + pub(crate) fn add_builder(&self, builder: impl LbPolicyBuilder + 'static) { + self.m + .lock() + .unwrap() + .insert(builder.name().to_string(), Arc::new(builder)); + } + /// Retrieve a LB policy from the registry, or None if not found. + pub(crate) fn get_policy(&self, name: &str) -> Option> { + self.m.lock().unwrap().get(name).cloned() + } +} + +impl Default for LbPolicyRegistry { + fn default() -> Self { + Self::new() + } +} + +/// The registry used if a local registry is not provided to a channel or if it +/// does not exist in the local registry. +pub static GLOBAL_LB_REGISTRY: Lazy = Lazy::new(LbPolicyRegistry::new); diff --git a/grpc/src/client/mod.rs b/grpc/src/client/mod.rs index a7b722556..66c809e62 100644 --- a/grpc/src/client/mod.rs +++ b/grpc/src/client/mod.rs @@ -27,8 +27,12 @@ use std::fmt::Display; pub mod channel; pub(crate) mod load_balancing; pub(crate) mod name_resolution; -pub mod service; pub mod service_config; +pub mod transport; + +mod subchannel; +pub use channel::Channel; +pub use channel::ChannelOptions; /// A representation of the current state of a gRPC channel, also used for the /// state of subchannels (individual connections within the channel). diff --git a/grpc/src/client/name_resolution/mod.rs b/grpc/src/client/name_resolution/mod.rs index 69fbbd018..d6b4383c2 100644 --- a/grpc/src/client/name_resolution/mod.rs +++ b/grpc/src/client/name_resolution/mod.rs @@ -70,6 +70,22 @@ impl FromStr for Target { } } +impl From for Target { + fn from(url: url::Url) -> Self { + Target { url } + } +} + +/// Target represents a target for gRPC, as specified in: +/// https://github.com/grpc/grpc/blob/master/doc/naming.md. +/// It is parsed from the target string that gets passed during channel creation +/// by the user. gRPC passes it to the resolver and the balancer. +/// +/// If the target follows the naming spec, and the parsed scheme is registered +/// with gRPC, we will parse the target string according to the spec. If the +/// target does not contain a scheme or if the parsed scheme is not registered +/// (i.e. no corresponding resolver available to resolve the endpoint), we will +/// apply the default scheme, and will attempt to reparse it. impl Target { pub fn scheme(&self) -> &str { self.url.scheme() @@ -97,7 +113,7 @@ impl Target { } } - /// Return the path for this target URL, as a percent-encoded ASCII string. + /// Retrieves endpoint from `Url.path()`. pub fn path(&self) -> &str { self.url.path() } @@ -125,7 +141,7 @@ pub trait ResolverBuilder: Send + Sync { fn build(&self, target: &Target, options: ResolverOptions) -> Box; /// Reports the URI scheme handled by this name resolver. - fn scheme(&self) -> &'static str; + fn scheme(&self) -> &str; /// Returns the default authority for a channel using this name resolver /// and target. This refers to the *dataplane authority* — the value used @@ -264,9 +280,15 @@ pub struct Endpoint { pub attributes: Attributes, } +impl Hash for Endpoint { + fn hash(&self, state: &mut H) { + self.addresses.hash(state); + } +} + /// An Address is an identifier that indicates how to connect to a server. #[non_exhaustive] -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Ord, PartialOrd)] pub struct Address { /// The network type is used to identify what kind of transport to create /// when connecting to this address. Typically TCP_IP_ADDRESS_TYPE. @@ -281,38 +303,24 @@ pub struct Address { pub attributes: Attributes, } -impl Eq for Endpoint {} - -impl PartialEq for Endpoint { - fn eq(&self, _other: &Self) -> bool { - todo!() - } -} - -impl Hash for Endpoint { - fn hash(&self, _state: &mut H) { - todo!() - } -} - impl Eq for Address {} impl PartialEq for Address { - fn eq(&self, _other: &Self) -> bool { - todo!() + fn eq(&self, other: &Self) -> bool { + self.network_type == other.network_type && self.address == other.address } } impl Hash for Address { - fn hash(&self, _state: &mut H) { - todo!() + fn hash(&self, state: &mut H) { + self.network_type.hash(state); + self.address.hash(state); } } impl Display for Address { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let addr: &str = &self.address; - write!(f, "{}:{}", self.network_type, addr) + write!(f, "{}:{}", self.network_type, self.address.to_string()) } } @@ -320,7 +328,7 @@ impl Display for Address { /// via TCP/IP. pub static TCP_IP_NETWORK_TYPE: &str = "tcp"; -// A resolver that returns the same result every time it's work method is called. +// A resolver that returns the same result every time its work method is called. // It can be used to return an error to the channel when a resolver fails to // build. struct NopResolver { diff --git a/grpc/src/client/service.rs b/grpc/src/client/service.rs deleted file mode 100644 index 839702f80..000000000 --- a/grpc/src/client/service.rs +++ /dev/null @@ -1,29 +0,0 @@ -/* - * - * Copyright 2025 gRPC authors. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - * - */ - -/// A gRPC Request. -pub struct Request; - -/// A gRPC Response. -pub struct Response; diff --git a/grpc/src/client/service_config.rs b/grpc/src/client/service_config.rs index 5b306bd0f..da268ca33 100644 --- a/grpc/src/client/service_config.rs +++ b/grpc/src/client/service_config.rs @@ -21,8 +21,34 @@ * IN THE SOFTWARE. * */ +use std::{any::Any, error::Error, sync::Arc}; /// An in-memory representation of a service config, usually provided to gRPC as /// a JSON object. #[derive(Debug, Default, Clone)] pub(crate) struct ServiceConfig; + +/// A convenience wrapper for an LB policy's configuration object. +#[derive(Debug)] +pub(crate) struct LbConfig { + config: Arc, +} + +impl LbConfig { + /// Create a new LbConfig wrapper containing the provided config. + pub fn new(config: T) -> Self { + LbConfig { + config: Arc::new(config), + } + } + + /// Convenience method to extract the LB policy's configuration object. + pub fn convert_to( + &self, + ) -> Result, Box> { + match self.config.clone().downcast::() { + Ok(c) => Ok(c), + Err(e) => Err("failed to downcast to config type".into()), + } + } +} diff --git a/grpc/src/client/subchannel.rs b/grpc/src/client/subchannel.rs new file mode 100644 index 000000000..d325257e1 --- /dev/null +++ b/grpc/src/client/subchannel.rs @@ -0,0 +1,550 @@ +use super::{ + channel::{InternalChannelController, WorkQueueTx}, + load_balancing::{self, ExternalSubchannel, Picker, Subchannel, SubchannelState}, + name_resolution::Address, + transport::{self, ConnectedTransport, Transport, TransportRegistry}, + ConnectivityState, +}; +use crate::{ + client::{channel::WorkQueueItem, subchannel}, + service::{Request, Response, Service}, +}; +use core::panic; +use std::{ + collections::BTreeMap, + error::Error, + fmt::{Debug, Display}, + ops::Sub, + sync::{Arc, Mutex, RwLock, Weak}, +}; +use tokio::{ + sync::{mpsc, watch, Notify}, + task::{AbortHandle, JoinHandle}, + time::{Duration, Instant}, +}; +use tonic::async_trait; + +type SharedService = Arc; + +pub trait Backoff: Send + Sync { + fn backoff_until(&self) -> Instant; + fn reset(&self); + fn min_connect_timeout(&self) -> Duration; +} + +// TODO(easwars): Move this somewhere else, where appropriate. +pub(crate) struct NopBackoff {} +impl Backoff for NopBackoff { + fn backoff_until(&self) -> Instant { + Instant::now() + } + fn reset(&self) {} + fn min_connect_timeout(&self) -> Duration { + Duration::from_secs(20) + } +} + +enum InternalSubchannelState { + Idle, + Connecting(InternalSubchannelConnectingState), + Ready(InternalSubchannelReadyState), + TransientFailure(InternalSubchannelTransientFailureState), +} + +struct InternalSubchannelConnectingState { + abort_handle: Option, +} + +struct InternalSubchannelReadyState { + abort_handle: Option, + svc: SharedService, +} + +struct InternalSubchannelTransientFailureState { + abort_handle: Option, + error: String, +} + +impl InternalSubchannelState { + fn connected_transport(&self) -> Option { + match self { + Self::Ready(st) => Some(st.svc.clone()), + _ => None, + } + } + + fn to_subchannel_state(&self) -> SubchannelState { + match self { + Self::Idle => SubchannelState { + connectivity_state: ConnectivityState::Idle, + last_connection_error: None, + }, + Self::Connecting(_) => SubchannelState { + connectivity_state: ConnectivityState::Connecting, + last_connection_error: None, + }, + Self::Ready(_) => SubchannelState { + connectivity_state: ConnectivityState::Ready, + last_connection_error: None, + }, + Self::TransientFailure(st) => { + let arc_err: Arc = Arc::from(Box::from(st.error.clone())); + SubchannelState { + connectivity_state: ConnectivityState::TransientFailure, + last_connection_error: Some(arc_err), + } + } + } + } +} + +impl Display for InternalSubchannelState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Idle => write!(f, "Idle"), + Self::Connecting(_) => write!(f, "Connecting"), + Self::Ready(_) => write!(f, "Ready"), + Self::TransientFailure(_) => write!(f, "TransientFailure"), + } + } +} + +impl Debug for InternalSubchannelState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Idle => write!(f, "Idle"), + Self::Connecting(_) => write!(f, "Connecting"), + Self::Ready(_) => write!(f, "Ready"), + Self::TransientFailure(_) => write!(f, "TransientFailure"), + } + } +} + +impl PartialEq for InternalSubchannelState { + fn eq(&self, other: &Self) -> bool { + match &self { + Self::Idle => { + if let Self::Idle = other { + return true; + } + } + Self::Connecting(_) => { + if let Self::Connecting(_) = other { + return true; + } + } + Self::Ready(_) => { + if let Self::Ready(_) = other { + return true; + } + } + Self::TransientFailure(_) => { + if let Self::TransientFailure(_) = other { + return true; + } + } + } + false + } +} + +impl Drop for InternalSubchannelState { + fn drop(&mut self) { + match &self { + Self::Idle => {} + Self::Connecting(st) => { + if let Some(ah) = &st.abort_handle { + ah.abort(); + } + } + Self::Ready(st) => { + if let Some(ah) = &st.abort_handle { + ah.abort(); + } + } + Self::TransientFailure(st) => { + if let Some(ah) = &st.abort_handle { + ah.abort(); + } + } + } + } +} + +pub(crate) struct InternalSubchannel { + key: SubchannelKey, + transport: Arc, + backoff: Arc, + unregister_fn: Option>, + state_machine_event_sender: mpsc::UnboundedSender, + inner: Mutex, +} + +struct InnerSubchannel { + state: InternalSubchannelState, + watchers: Vec>, // TODO(easwars): Revisit the choice for this data structure. + backoff_task: Option>, + disconnect_task: Option>, +} + +#[async_trait] +impl Service for InternalSubchannel { + async fn call(&self, method: String, request: Request) -> Response { + let svc = self.inner.lock().unwrap().state.connected_transport(); + if svc.is_none() { + // TODO(easwars): Change the signature of this method to return a + // Result + panic!("todo: handle !ready"); + } + + let svc = svc.unwrap().clone(); + return svc.call(method, request).await; + } +} + +enum SubchannelStateMachineEvent { + ConnectionRequested, + ConnectionSucceeded(SharedService), + ConnectionTimedOut, + ConnectionFailed(String), + ConnectionTerminated, + BackoffExpired, +} +impl Debug for SubchannelStateMachineEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConnectionRequested => write!(f, "ConnectionRequested"), + Self::ConnectionSucceeded(_) => write!(f, "ConnectionSucceeded"), + Self::ConnectionTimedOut => write!(f, "ConnectionTimedOut"), + Self::ConnectionFailed(_) => write!(f, "ConnectionFailed"), + Self::ConnectionTerminated => write!(f, "ConnectionTerminated"), + Self::BackoffExpired => write!(f, "BackoffExpired"), + } + } +} + +impl InternalSubchannel { + pub(super) fn new( + key: SubchannelKey, + transport: Arc, + backoff: Arc, + unregister_fn: Box, + ) -> Arc { + println!("creating new internal subchannel for: {:?}", &key); + let (tx, mut rx) = mpsc::unbounded_channel::(); + let isc = Arc::new(Self { + key: key.clone(), + transport, + backoff: backoff.clone(), + unregister_fn: Some(unregister_fn), + state_machine_event_sender: tx, + inner: Mutex::new(InnerSubchannel { + state: InternalSubchannelState::Idle, + watchers: Vec::new(), + backoff_task: None, + disconnect_task: None, + }), + }); + + // This long running task implements the subchannel state machine. When + // the subchannel is dropped, the channel from which this task reads is + // closed, and therefore this task exits because rx.recv() returns None + // in that case. + let arc_to_self = Arc::clone(&isc); + tokio::task::spawn(async move { + println!("starting subchannel state machine for: {:?}", &key); + while let Some(m) = rx.recv().await { + println!("subchannel {:?} received event {:?}", &key, &m); + match m { + SubchannelStateMachineEvent::ConnectionRequested => { + arc_to_self.move_to_connecting(); + } + SubchannelStateMachineEvent::ConnectionSucceeded(svc) => { + arc_to_self.move_to_ready(svc); + } + SubchannelStateMachineEvent::ConnectionTimedOut => { + arc_to_self.move_to_transient_failure("connect timeout expired".into()); + } + SubchannelStateMachineEvent::ConnectionFailed(err) => { + arc_to_self.move_to_transient_failure(err); + } + SubchannelStateMachineEvent::ConnectionTerminated => { + arc_to_self.move_to_idle(); + } + SubchannelStateMachineEvent::BackoffExpired => { + arc_to_self.move_to_idle(); + } + } + } + println!("exiting work queue task in subchannel"); + }); + isc + } + + pub(super) fn address(&self) -> Address { + self.key.address.clone() + } + + /// Begins connecting the subchannel asynchronously. If now is set, does + /// not wait for any pending connection backoff to complete. + pub(super) fn connect(&self, now: bool) { + let state = &self.inner.lock().unwrap().state; + if let InternalSubchannelState::Idle = state { + let _ = self + .state_machine_event_sender + .send(SubchannelStateMachineEvent::ConnectionRequested); + } + } + + pub(super) fn register_connectivity_state_watcher(&self, watcher: Arc) { + let mut inner = self.inner.lock().unwrap(); + inner.watchers.push(watcher.clone()); + let state = inner.state.to_subchannel_state().clone(); + watcher.on_state_change(state); + } + + pub(super) fn unregister_connectivity_state_watcher( + &self, + watcher: Arc, + ) { + self.inner + .lock() + .unwrap() + .watchers + .retain(|x| !Arc::ptr_eq(x, &watcher)); + } + + fn notify_watchers(&self, state: SubchannelState) { + let mut inner = self.inner.lock().unwrap(); + inner.state = InternalSubchannelState::Idle; + for w in &inner.watchers { + w.on_state_change(state.clone()); + } + } + + fn move_to_idle(&self) { + self.notify_watchers(SubchannelState { + connectivity_state: ConnectivityState::Idle, + last_connection_error: None, + }); + } + + fn move_to_connecting(&self) { + { + let mut inner = self.inner.lock().unwrap(); + inner.state = InternalSubchannelState::Connecting(InternalSubchannelConnectingState { + abort_handle: None, + }); + } + self.notify_watchers(SubchannelState { + connectivity_state: ConnectivityState::Connecting, + last_connection_error: None, + }); + + let min_connect_timeout = self.backoff.min_connect_timeout(); + let transport = self.transport.clone(); + let address = self.address().address; + let state_machine_tx = self.state_machine_event_sender.clone(); + let connect_task = tokio::task::spawn(async move { + tokio::select! { + _ = tokio::time::sleep(min_connect_timeout) => { + let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionTimedOut); + } + result = transport.connect(address.to_string().clone()) => { + match result { + Ok(s) => { + let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionSucceeded(Arc::from(s))); + } + Err(e) => { + let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionFailed(e)); + } + } + }, + } + }); + let mut inner = self.inner.lock().unwrap(); + inner.state = InternalSubchannelState::Connecting(InternalSubchannelConnectingState { + abort_handle: Some(connect_task.abort_handle()), + }); + } + + fn move_to_ready(&self, svc: SharedService) { + let svc2 = svc.clone(); + { + let mut inner = self.inner.lock().unwrap(); + inner.state = InternalSubchannelState::Ready(InternalSubchannelReadyState { + abort_handle: None, + svc: svc2.clone(), + }); + } + self.notify_watchers(SubchannelState { + connectivity_state: ConnectivityState::Ready, + last_connection_error: None, + }); + + let state_machine_tx = self.state_machine_event_sender.clone(); + let disconnect_task = tokio::task::spawn(async move { + // TODO(easwars): Does it make sense for disconnected() to return an + // error string containing information about why the connection + // terminated? But what can we do with that error other than logging + // it, which the transport can do as well? + svc.disconnected().await; + let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionTerminated); + }); + let mut inner = self.inner.lock().unwrap(); + inner.state = InternalSubchannelState::Ready(InternalSubchannelReadyState { + abort_handle: Some(disconnect_task.abort_handle()), + svc: svc2.clone(), + }); + } + + fn move_to_transient_failure(&self, err: String) { + { + let mut inner = self.inner.lock().unwrap(); + inner.state = InternalSubchannelState::TransientFailure( + InternalSubchannelTransientFailureState { + abort_handle: None, + error: err.clone(), + }, + ); + } + + let arc_err: Arc = Arc::from(Box::from(err.clone())); + self.notify_watchers(SubchannelState { + connectivity_state: ConnectivityState::TransientFailure, + last_connection_error: Some(arc_err.clone()), + }); + + let backoff_interval = self.backoff.backoff_until(); + let state_machine_tx = self.state_machine_event_sender.clone(); + let backoff_task = tokio::task::spawn(async move { + tokio::time::sleep_until(backoff_interval).await; + let _ = state_machine_tx.send(SubchannelStateMachineEvent::BackoffExpired); + }); + let mut inner = self.inner.lock().unwrap(); + inner.state = + InternalSubchannelState::TransientFailure(InternalSubchannelTransientFailureState { + abort_handle: Some(backoff_task.abort_handle()), + error: err.clone(), + }); + } + + /// Wait for any in-flight RPCs to terminate and then close the connection + /// and destroy the Subchannel. + async fn drain(self) {} +} + +impl Drop for InternalSubchannel { + fn drop(&mut self) { + println!("dropping internal subchannel {:?}", self.key); + let unregister_fn = self.unregister_fn.take(); + unregister_fn.unwrap()(self.key.clone()); + } +} + +// SubchannelKey uniiquely identifies a subchannel in the pool. +#[derive(PartialEq, PartialOrd, Eq, Ord, Clone)] + +pub(crate) struct SubchannelKey { + address: Address, +} + +impl SubchannelKey { + pub(crate) fn new(address: Address) -> Self { + Self { address } + } +} + +impl Display for SubchannelKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.address.address.to_string()) + } +} + +impl Debug for SubchannelKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.address) + } +} + +pub(super) struct InternalSubchannelPool { + subchannels: RwLock>>, +} + +impl InternalSubchannelPool { + pub(super) fn new() -> Self { + Self { + subchannels: RwLock::new(BTreeMap::new()), + } + } + + pub(super) fn lookup_subchannel(&self, key: &SubchannelKey) -> Option> { + println!("looking up subchannel for: {:?} in the pool", key); + if let Some(weak_isc) = self.subchannels.read().unwrap().get(key) { + if let Some(isc) = weak_isc.upgrade() { + return Some(isc); + } + } + None + } + + pub(super) fn register_subchannel( + &self, + key: &SubchannelKey, + isc: Arc, + ) -> Arc { + println!("registering subchannel for: {:?} with the pool", key); + self.subchannels + .write() + .unwrap() + .insert(key.clone(), Arc::downgrade(&isc)); + isc + } + + pub(super) fn unregister_subchannel(&self, key: &SubchannelKey) { + let mut subchannels = self.subchannels.write().unwrap(); + if let Some(weak_isc) = subchannels.get(key) { + if let Some(isc) = weak_isc.upgrade() { + return; + } + println!("removing subchannel for: {:?} from the pool", key); + subchannels.remove(key); + return; + } + panic!("attempt to unregister subchannel for unknown key {:?}", key); + } +} + +#[derive(Clone)] +pub(super) struct SubchannelStateWatcher { + subchannel: Weak, + work_scheduler: WorkQueueTx, +} + +impl SubchannelStateWatcher { + pub(super) fn new(sc: Arc, work_scheduler: WorkQueueTx) -> Self { + Self { + subchannel: Arc::downgrade(&sc), + work_scheduler, + } + } + + fn on_state_change(&self, state: SubchannelState) { + // Ignore internal subchannel state changes if the external subchannel + // was dropped but its state watcher is still pending unregistration; + // such updates are inconsequential. + if let Some(sc) = self.subchannel.upgrade() { + let _ = self.work_scheduler.send(WorkQueueItem::Closure(Box::new( + move |c: &mut InternalChannelController| { + c.lb.clone() + .policy + .lock() + .unwrap() + .as_mut() + .unwrap() + .subchannel_update(sc, &state, c); + }, + ))); + } + } +} diff --git a/grpc/src/client/transport/mod.rs b/grpc/src/client/transport/mod.rs new file mode 100644 index 000000000..4c5b021b8 --- /dev/null +++ b/grpc/src/client/transport/mod.rs @@ -0,0 +1,16 @@ +use crate::service::Service; + +mod registry; + +use ::tonic::async_trait; +pub use registry::{TransportRegistry, GLOBAL_TRANSPORT_REGISTRY}; + +#[async_trait] +pub trait Transport: Send + Sync { + async fn connect(&self, address: String) -> Result, String>; +} + +#[async_trait] +pub trait ConnectedTransport: Service { + async fn disconnected(&self); +} diff --git a/grpc/src/client/transport/registry.rs b/grpc/src/client/transport/registry.rs new file mode 100644 index 000000000..3af6d7de5 --- /dev/null +++ b/grpc/src/client/transport/registry.rs @@ -0,0 +1,62 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use once_cell::sync::Lazy; + +use super::Transport; + +/// A registry to store and retrieve transports. Transports are indexed by +/// the address type they are intended to handle. +#[derive(Clone)] +pub struct TransportRegistry { + m: Arc>>>, +} + +impl std::fmt::Debug for TransportRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let m = self.m.lock().unwrap(); + for key in m.keys() { + write!(f, "k: {:?}", key)? + } + Ok(()) + } +} + +impl TransportRegistry { + /// Construct an empty name resolver registry. + pub fn new() -> Self { + Self { m: Arc::default() } + } + /// Add a name resolver into the registry. + pub fn add_transport(&self, address_type: &str, transport: impl Transport + 'static) { + //let a: Arc = transport; + //let a: Arc> = transport; + self.m + .lock() + .unwrap() + .insert(address_type.to_string(), Arc::new(transport)); + } + /// Retrieve a name resolver from the registry, or None if not found. + pub fn get_transport(&self, address_type: &str) -> Result, String> { + self.m + .lock() + .unwrap() + .get(address_type) + .ok_or(format!( + "no transport found for address type {address_type}" + )) + .cloned() + } +} + +impl Default for TransportRegistry { + fn default() -> Self { + Self::new() + } +} + +/// The registry used if a local registry is not provided to a channel or if it +/// does not exist in the local registry. +pub static GLOBAL_TRANSPORT_REGISTRY: Lazy = Lazy::new(TransportRegistry::new); diff --git a/grpc/src/credentials/mod.rs b/grpc/src/credentials/mod.rs new file mode 100644 index 000000000..8dd788ac0 --- /dev/null +++ b/grpc/src/credentials/mod.rs @@ -0,0 +1 @@ +pub trait Credentials {} diff --git a/grpc/src/inmemory/mod.rs b/grpc/src/inmemory/mod.rs new file mode 100644 index 000000000..6884f4d68 --- /dev/null +++ b/grpc/src/inmemory/mod.rs @@ -0,0 +1,177 @@ +use std::{ + collections::HashMap, + ops::Add, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, +}; + +use crate::{ + client::{ + name_resolution::{ + self, global_registry, Address, ChannelController, Endpoint, Resolver, ResolverBuilder, + ResolverOptions, ResolverUpdate, + }, + transport::{self, ConnectedTransport, GLOBAL_TRANSPORT_REGISTRY}, + }, + server, + service::{Request, Response, Service}, +}; +use once_cell::sync::Lazy; +use tokio::sync::{mpsc, oneshot, Mutex, Notify}; +use tonic::async_trait; + +pub struct Listener { + id: String, + s: Box>>, + r: Arc>>>, + // List of notifiers to call when closed. + closed: Notify, +} + +static ID: AtomicU32 = AtomicU32::new(0); + +impl Listener { + pub fn new() -> Arc { + let (tx, rx) = mpsc::channel(1); + let s = Arc::new(Self { + id: format!("{}", ID.fetch_add(1, Ordering::Relaxed)), + s: Box::new(tx), + r: Arc::new(Mutex::new(rx)), + closed: Notify::new(), + }); + LISTENERS.lock().unwrap().insert(s.id.clone(), s.clone()); + s + } + + pub fn target(&self) -> String { + format!("inmemory:///{}", self.id) + } + + pub fn id(&self) -> String { + self.id.clone() + } + + pub async fn close(&self) { + let _ = self.s.send(None).await; + } +} + +impl Drop for Listener { + fn drop(&mut self) { + self.closed.notify_waiters(); + LISTENERS.lock().unwrap().remove(&self.id); + } +} + +#[async_trait] +impl Service for Arc { + async fn call(&self, method: String, request: Request) -> Response { + // 1. unblock accept, giving it a func back to me + // 2. return what that func had + let (s, r) = oneshot::channel(); + self.s.send(Some((method, request, s))).await.unwrap(); + r.await.unwrap() + } +} + +#[async_trait] +impl ConnectedTransport for Arc { + async fn disconnected(&self) { + self.closed.notified().await; + } +} + +#[async_trait] +impl crate::server::Listener for Arc { + async fn accept(&self) -> Option { + let mut recv = self.r.lock().await; + let r = recv.recv().await; + if r.is_none() { + // Listener was closed. + return None; + } + r.unwrap() + } +} + +static LISTENERS: Lazy>>> = + Lazy::new(std::sync::Mutex::default); + +struct ClientTransport {} + +impl ClientTransport { + fn new() -> Self { + Self {} + } +} + +#[async_trait] +impl transport::Transport for ClientTransport { + async fn connect(&self, address: String) -> Result, String> { + let lis = LISTENERS + .lock() + .unwrap() + .get(&address) + .ok_or(format!("Could not find listener for address {address}"))? + .clone(); + Ok(Box::new(lis)) + } +} + +static INMEMORY_NETWORK_TYPE: &str = "inmemory"; + +pub fn reg() { + GLOBAL_TRANSPORT_REGISTRY.add_transport(INMEMORY_NETWORK_TYPE, ClientTransport::new()); + global_registry().add_builder(Box::new(InMemoryResolverBuilder)); +} + +struct InMemoryResolverBuilder; + +impl ResolverBuilder for InMemoryResolverBuilder { + fn scheme(&self) -> &'static str { + "inmemory" + } + + fn build( + &self, + target: &name_resolution::Target, + options: ResolverOptions, + ) -> Box { + let id = target.path().strip_prefix("/").unwrap().to_string(); + options.work_scheduler.schedule_work(); + Box::new(NopResolver { id }) + } + + fn is_valid_uri(&self, uri: &crate::client::name_resolution::Target) -> bool { + true + } +} + +struct NopResolver { + id: String, +} + +impl Resolver for NopResolver { + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + let mut addresses: Vec
= Vec::new(); + for addr in LISTENERS.lock().unwrap().keys() { + addresses.push(Address { + network_type: INMEMORY_NETWORK_TYPE, + address: addr.clone().into(), + ..Default::default() + }); + } + + let _ = channel_controller.update(ResolverUpdate { + endpoints: Ok(vec![Endpoint { + addresses, + ..Default::default() + }]), + ..Default::default() + }); + } + + fn resolve_now(&mut self) {} +} diff --git a/grpc/src/lib.rs b/grpc/src/lib.rs index 567925131..45352523b 100644 --- a/grpc/src/lib.rs +++ b/grpc/src/lib.rs @@ -29,11 +29,13 @@ //! APIs are unstable. Proceed at your own risk. //! //! [gRPC]: https://grpc.io - -#![allow(dead_code)] +#![allow(dead_code, unused_variables, unused_imports)] pub mod client; -mod rt; +pub mod credentials; +pub mod inmemory; +pub mod rt; +pub mod server; pub mod service; pub(crate) mod attributes; diff --git a/grpc/src/rt/mod.rs b/grpc/src/rt/mod.rs index 0cf4c3361..f550cba90 100644 --- a/grpc/src/rt/mod.rs +++ b/grpc/src/rt/mod.rs @@ -22,7 +22,9 @@ * */ -use std::{future::Future, pin::Pin}; +use ::tokio::io::{AsyncRead, AsyncWrite}; + +use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; pub mod tokio; @@ -73,3 +75,11 @@ pub(super) struct ResolverOptions { /// system's default DNS server will be used. pub(super) server_addr: Option, } + +#[derive(Default)] +pub struct TcpOptions { + pub enable_nodelay: bool, + pub keepalive: Option, +} + +pub trait TcpStream: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/grpc/src/rt/tokio/mod.rs b/grpc/src/rt/tokio/mod.rs index f33886fb8..25d6cdb82 100644 --- a/grpc/src/rt/tokio/mod.rs +++ b/grpc/src/rt/tokio/mod.rs @@ -29,7 +29,10 @@ use std::{ time::Duration, }; -use tokio::task::JoinHandle; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + task::JoinHandle, +}; use super::{DnsResolver, ResolverOptions, Runtime, Sleep, TaskHandle}; diff --git a/grpc/src/server/mod.rs b/grpc/src/server/mod.rs new file mode 100644 index 000000000..18da685ca --- /dev/null +++ b/grpc/src/server/mod.rs @@ -0,0 +1,41 @@ +use std::sync::Arc; + +use tokio::sync::oneshot; +use tonic::async_trait; + +use crate::service::{Request, Response, Service}; + +pub struct Server { + handler: Option>, +} + +pub type Call = (String, Request, oneshot::Sender); + +#[async_trait] +pub trait Listener { + async fn accept(&self) -> Option; +} + +impl Server { + pub fn new() -> Self { + Self { handler: None } + } + + pub fn set_handler(&mut self, f: impl Service + 'static) { + self.handler = Some(Arc::new(f)) + } + + pub async fn serve(&self, l: &impl Listener) { + while let Some((method, req, reply_on)) = l.accept().await { + reply_on + .send(self.handler.as_ref().unwrap().call(method, req).await) + .ok(); // TODO: log error + } + } +} + +impl Default for Server { + fn default() -> Self { + Self::new() + } +} From dac9d0a7e880803103418e5f18a29bad6ab34a92 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 14 Jul 2025 15:44:45 -0700 Subject: [PATCH 2/5] delete TcpStream type; add allowed external types --- grpc/Cargo.toml | 2 ++ grpc/src/rt/mod.rs | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index 95414a591..86f1e8c22 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -42,4 +42,6 @@ dns = ["dep:hickory-resolver"] allowed_external_types = [ "tonic::*", "futures_core::stream::Stream", + "tokio::sync::oneshot::Sender", + "once_cell::sync::Lazy", ] diff --git a/grpc/src/rt/mod.rs b/grpc/src/rt/mod.rs index f550cba90..78accb53f 100644 --- a/grpc/src/rt/mod.rs +++ b/grpc/src/rt/mod.rs @@ -81,5 +81,3 @@ pub struct TcpOptions { pub enable_nodelay: bool, pub keepalive: Option, } - -pub trait TcpStream: AsyncRead + AsyncWrite + Send + Unpin {} From c63ef5f03d2f21f78421b114d2c93e1c4e74b926 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 14 Jul 2025 15:56:07 -0700 Subject: [PATCH 3/5] set rust-version in grpc's Cargo.toml --- grpc/Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index 86f1e8c22..ae5c224a5 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -45,3 +45,6 @@ allowed_external_types = [ "tokio::sync::oneshot::Sender", "once_cell::sync::Lazy", ] + +[workspace.package] +rust-version = "1.86" From 931d7e977492528460b282b091746127fed5f6df Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 14 Jul 2025 15:57:01 -0700 Subject: [PATCH 4/5] remove version --- grpc/Cargo.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index ae5c224a5..86f1e8c22 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -45,6 +45,3 @@ allowed_external_types = [ "tokio::sync::oneshot::Sender", "once_cell::sync::Lazy", ] - -[workspace.package] -rust-version = "1.86" From 64243426ba3f5b72fcb8827b221c5adbf4037f71 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Thu, 17 Jul 2025 15:29:03 -0700 Subject: [PATCH 5/5] update rust version to 1.86 for trait upcasting feature --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ce9bc4d43..7b2aaddd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ members = [ resolver = "2" [workspace.package] -rust-version = "1.75" +rust-version = "1.86" [workspace.lints.rust] missing_debug_implementations = "warn"