From e851bccba88722757848473027873eb5f7347a69 Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Thu, 5 Dec 2024 18:56:13 +0200
Subject: [PATCH 1/5] Make blobs more cheaply cloneable by giving it an Inner

---
 src/net_protocol.rs | 73 ++++++++++++++++++++++++++-------------------
 1 file changed, 42 insertions(+), 31 deletions(-)

diff --git a/src/net_protocol.rs b/src/net_protocol.rs
index c02a19acc..395d275c4 100644
--- a/src/net_protocol.rs
+++ b/src/net_protocol.rs
@@ -47,16 +47,21 @@ impl Default for GcState {
     }
 }
 
-#[derive(Debug, Clone)]
-pub struct Blobs<S> {
+#[derive(Debug)]
+struct BlobsInner<S> {
     rt: LocalPoolHandle,
     pub(crate) store: S,
     events: EventSender,
     downloader: Downloader,
-    #[cfg(feature = "rpc")]
-    batches: Arc<tokio::sync::Mutex<BlobBatches>>,
     endpoint: Endpoint,
-    gc_state: Arc<std::sync::Mutex<GcState>>,
+    gc_state: std::sync::Mutex<GcState>,
+    #[cfg(feature = "rpc")]
+    batches: tokio::sync::Mutex<BlobBatches>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Blobs<S> {
+    inner: Arc<BlobsInner<S>>,
     #[cfg(feature = "rpc")]
     pub(crate) rpc_handler: Arc<std::sync::OnceLock<crate::rpc::RpcHandler>>,
 }
@@ -178,40 +183,46 @@ impl Blobs {
         endpoint: Endpoint,
     ) -> Self {
         Self {
-            rt,
-            store,
-            events,
-            downloader,
-            endpoint,
-            #[cfg(feature = "rpc")]
-            batches: Default::default(),
-            gc_state: Default::default(),
+            inner: Arc::new(BlobsInner {
+                rt,
+                store,
+                events,
+                downloader,
+                endpoint,
+                #[cfg(feature = "rpc")]
+                batches: Default::default(),
+                gc_state: Default::default(),
+            }),
             #[cfg(feature = "rpc")]
             rpc_handler: Default::default(),
         }
     }
 
     pub fn store(&self) -> &S {
-        &self.store
+        &self.inner.store
+    }
+
+    pub fn events(&self) -> &EventSender {
+        &self.inner.events
     }
 
     pub fn rt(&self) -> &LocalPoolHandle {
-        &self.rt
+        &self.inner.rt
    }
 
     pub fn downloader(&self) -> &Downloader {
-        &self.downloader
+        &self.inner.downloader
     }
 
     pub fn endpoint(&self) -> &Endpoint {
-        &self.endpoint
+        &self.inner.endpoint
     }
 
     /// Add a callback that will be called before the garbage collector runs.
     ///
     /// This can only be called before the garbage collector has started, otherwise it will return an error.
     pub fn add_protected(&self, cb: ProtectCb) -> Result<()> {
-        let mut state = self.gc_state.lock().unwrap();
+        let mut state = self.inner.gc_state.lock().unwrap();
         match &mut *state {
             GcState::Initial(cbs) => {
                 cbs.push(cb);
@@ -225,7 +236,7 @@ impl Blobs {
 
     /// Start garbage collection with the given settings.
     pub fn start_gc(&self, config: GcConfig) -> Result<()> {
-        let mut state = self.gc_state.lock().unwrap();
+        let mut state = self.inner.gc_state.lock().unwrap();
         let protected = match state.deref_mut() {
             GcState::Initial(items) => std::mem::take(items),
             GcState::Started(_) => bail!("gc already started"),
         };
         let protected_cb = move || {
             let mut set = BTreeSet::new();
             for cb in protected.iter() {
                 set
             }
         };
-        let store = self.store.clone();
+        let store = self.store().clone();
         let run = self
-            .rt
+            .rt()
             .spawn(move || async move { store.gc_run(config, protected_cb).await });
         *state = GcState::Started(Some(run));
         Ok(())
     }
@@ -251,7 +262,7 @@ impl Blobs {
 
     #[cfg(feature = "rpc")]
     pub(crate) async fn batches(&self) -> tokio::sync::MutexGuard<'_, BlobBatches> {
-        self.batches.lock().await
+        self.inner.batches.lock().await
     }
 
     pub(crate) async fn download(
         &self,
         endpoint: Endpoint,
         req: BlobDownloadRequest,
         progress: AsyncChannelProgressSender<DownloadProgress>,
     ) -> Result<()> {
         let BlobDownloadRequest {
             hash,
             format,
             nodes,
             tag,
             mode,
         } = req;
         let hash_and_format = HashAndFormat { hash, format };
-        let temp_tag = self.store.temp_tag(hash_and_format);
+        let temp_tag = self.store().temp_tag(hash_and_format);
         let stats = match mode {
             DownloadMode::Queued => {
                 self.download_queued(endpoint, hash_and_format, nodes, progress.clone())
                     .await?
             }
@@ -283,10 +294,10 @@ impl Blobs {
 
         progress.send(DownloadProgress::AllDone(stats)).await.ok();
         match tag {
             SetTagOption::Named(tag) => {
-                self.store.set_tag(tag, Some(hash_and_format)).await?;
+                self.store().set_tag(tag, Some(hash_and_format)).await?;
             }
             SetTagOption::Auto => {
-                self.store.create_tag(hash_and_format).await?;
+                self.store().create_tag(hash_and_format).await?;
             }
         }
         drop(temp_tag);
@@ -316,7 +327,7 @@ impl Blobs {
         let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some());
         anyhow::ensure!(can_download, "no way to reach a node for download");
         let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress);
-        let handle = self.downloader.queue(req).await;
+        let handle = self.downloader().queue(req).await;
         let stats = handle.await?;
         Ok(stats)
     }
@@ -334,7 +345,7 @@ impl Blobs {
         let mut nodes_iter = nodes.into_iter();
         'outer: loop {
             match crate::get::db::get_to_db_in_steps(
-                self.store.clone(),
+                self.store().clone(),
                 hash_and_format,
                 progress.clone(),
             )
@@ -393,9 +404,9 @@ impl Blobs {
 
 impl<S: crate::store::Store> ProtocolHandler for Blobs<S> {
     fn accept(&self, conn: Connecting) -> BoxedFuture<Result<()>> {
-        let db = self.store.clone();
-        let events = self.events.clone();
-        let rt = self.rt.clone();
+        let db = self.store().clone();
+        let events = self.events().clone();
+        let rt = self.rt().clone();
 
         Box::pin(async move {
             crate::provider::handle_connection(conn.await?, db, events, rt).await;
@@ -404,7 +415,7 @@ impl ProtocolHandler for Blobs {
     }
 
     fn shutdown(&self) -> BoxedFuture<()> {
-        let store = self.store.clone();
+        let store = self.store().clone();
         Box::pin(async move {
             store.shutdown().await;
         })

From 09562ce4b7c9b04f082713145ccfae46dee650b5 Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Fri, 6 Dec 2024 11:17:05 +0200
Subject: [PATCH 2/5] Remove the lazy part.

The lazy handler kept a reference to Blobs alive, so neither the task nor
the blobs were ever dropped. To solve this you could split the inner part
into two parts, one that has the handle and one that has the logic, but
that is not nice. I think it is best for the mem rpc handler to exist
completely separately, especially given that rpc is a non-default feature.
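Roughly the shape of the problem, as a stand-alone sketch with made-up
names (not the real blobs types): a lazily initialized helper that captures
a clone of its own handle closes an Arc cycle, so the inner state is never
dropped.

    use std::sync::{Arc, OnceLock, Weak};

    // The handle lazily stores a helper ...
    struct Inner {
        helper: OnceLock<Helper>,
    }

    // ... and the helper captures a clone of the handle, closing an Arc cycle.
    struct Helper {
        _keepalive: Handle,
    }

    #[derive(Clone)]
    struct Handle(Arc<Inner>);

    fn main() {
        let handle = Handle(Arc::new(Inner { helper: OnceLock::new() }));
        let weak: Weak<Inner> = Arc::downgrade(&handle.0);

        // Lazy initialization captures a clone of the handle.
        handle.0.helper.get_or_init(|| Helper { _keepalive: handle.clone() });

        drop(handle);
        // The cycle keeps the inner state alive even though every external handle is gone.
        assert!(weak.upgrade().is_some());
    }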
---
 src/net_protocol.rs |  4 ----
 src/rpc.rs          | 30 ++++++++++++++++++++----------
 tests/blobs.rs      |  7 +------
 tests/gc.rs         |  6 +++---
 4 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/src/net_protocol.rs b/src/net_protocol.rs
index 395d275c4..5dab29d5b 100644
--- a/src/net_protocol.rs
+++ b/src/net_protocol.rs
@@ -62,8 +62,6 @@ struct BlobsInner {
 #[derive(Debug, Clone)]
 pub struct Blobs<S> {
     inner: Arc<BlobsInner<S>>,
-    #[cfg(feature = "rpc")]
-    pub(crate) rpc_handler: Arc<std::sync::OnceLock<crate::rpc::RpcHandler>>,
 }
 
 /// Keeps track of all the currently active batch operations of the blobs api.
@@ -193,8 +191,6 @@ impl Blobs {
                 batches: Default::default(),
                 gc_state: Default::default(),
             }),
-            #[cfg(feature = "rpc")]
-            rpc_handler: Default::default(),
         }
     }
 
diff --git a/src/rpc.rs b/src/rpc.rs
index 6f5ee8ba3..cf46ac499 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -2,12 +2,13 @@
 use std::{
     io,
+    ops::Deref,
     sync::{Arc, Mutex},
 };
 
 use anyhow::anyhow;
 use client::{
-    blobs::{self, BlobInfo, BlobStatus, IncompleteBlobInfo, WrapOption},
+    blobs::{BlobInfo, BlobStatus, IncompleteBlobInfo, MemClient, WrapOption},
     tags::TagInfo,
     MemConnector,
 };
@@ -62,13 +63,8 @@ const RPC_BLOB_GET_CHANNEL_CAP: usize = 2;
 
 impl<D: crate::store::Store> Blobs<D> {
     /// Get a client for the blobs protocol
-    pub fn client(&self) -> blobs::MemClient {
-        let client = self
-            .rpc_handler
-            .get_or_init(|| RpcHandler::new(self))
-            .client
-            .clone();
-        blobs::Client::new(client)
+    pub fn client(&self) -> RpcHandler {
+        RpcHandler::new(self)
     }
 
     /// Handle an RPC request
@@ -874,20 +870,34 @@ impl Blobs {
     }
 }
 
+/// A rpc handler for the blobs rpc protocol
+///
+/// This struct contains both a task that handles rpc requests and a client
+/// that can be used to send rpc requests. Dropping it will stop the handler task,
+/// so you need to put it somewhere where it will be kept alive.
 #[derive(Debug)]
-pub(crate) struct RpcHandler {
+pub struct RpcHandler {
     /// Client to hand out
-    client: RpcClient<RpcService, MemConnector>,
+    client: MemClient,
     /// Handler task
     _handler: AbortOnDropHandle<()>,
 }
 
+impl Deref for RpcHandler {
+    type Target = MemClient;
+
+    fn deref(&self) -> &Self::Target {
+        &self.client
+    }
+}
+
 impl RpcHandler {
     fn new<D: crate::store::Store>(blobs: &Blobs<D>) -> Self {
         let blobs = blobs.clone();
         let (listener, connector) = quic_rpc::transport::flume::channel(1);
         let listener = RpcServer::new(listener);
         let client = RpcClient::new(connector);
+        let client = MemClient::new(client);
         let _handler = listener
             .spawn_accept_loop(move |req, chan| blobs.clone().handle_rpc_request(req, chan));
         Self { client, _handler }
diff --git a/tests/blobs.rs b/tests/blobs.rs
index ad1198f92..ab273b118 100644
--- a/tests/blobs.rs
+++ b/tests/blobs.rs
@@ -32,12 +32,7 @@ async fn blobs_gc_protected() -> TestResult<()> {
     let pool = LocalPool::default();
     let endpoint = Endpoint::builder().bind().await?;
     let blobs = Blobs::memory().build(pool.handle(), &endpoint);
-    let client: iroh_blobs::rpc::client::blobs::Client<
-        quic_rpc::transport::flume::FlumeConnector<
-            iroh_blobs::rpc::proto::Response,
-            iroh_blobs::rpc::proto::Request,
-        >,
-    > = blobs.clone().client();
+    let client = blobs.clone().client();
     let h1 = client.add_bytes(b"test".to_vec()).await?;
     let protected = Arc::new(Mutex::new(Vec::new()));
     blobs.add_protected(Box::new({
diff --git a/tests/gc.rs b/tests/gc.rs
index a703ce5d2..56c79f3a3 100644
--- a/tests/gc.rs
+++ b/tests/gc.rs
@@ -20,7 +20,7 @@ use iroh::{protocol::Router, Endpoint, NodeAddr, NodeId};
 use iroh_blobs::{
     hashseq::HashSeq,
     net_protocol::Blobs,
-    rpc::client::{blobs, tags},
+    rpc::{client::tags, RpcHandler},
     store::{
         bao_tree, BaoBatchWriter, ConsistencyCheckProgress, EntryStatus, GcConfig, MapEntryMut,
         MapMut, ReportLevel, Store,
@@ -66,8 +66,8 @@ impl Node {
     }
 
     /// Returns an in-memory blobs client
-    pub fn blobs(&self) -> blobs::MemClient {
-        self.blobs.clone().client()
+    pub fn blobs(&self) -> RpcHandler {
+        self.blobs.client()
     }
 
     /// Returns an in-memory tags client

From eeddbaa0b5bd632c60a3e5431bf810f3123259be Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Fri, 6 Dec 2024 11:53:08 +0200
Subject: [PATCH 3/5] spawn_rpc should make it sufficiently clear that this is
 a thing you need to put away somewhere. Or maybe spawn_client?

---
 examples/custom-protocol.rs       |  2 +-
 examples/hello-world-fetch.rs     |  2 +-
 examples/hello-world-provide.rs   |  2 +-
 examples/local-swarm-discovery.rs |  2 +-
 examples/transfer.rs              |  2 +-
 src/rpc.rs                        | 19 ++++++++++++++-----
 tests/blobs.rs                    |  4 ++--
 tests/gc.rs                       |  2 +-
 8 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/examples/custom-protocol.rs b/examples/custom-protocol.rs
index a5127e2dd..41dfc1cc0 100644
--- a/examples/custom-protocol.rs
+++ b/examples/custom-protocol.rs
@@ -91,7 +91,7 @@ async fn main() -> Result<()> {
     let local_pool = LocalPool::default();
     let blobs = Blobs::memory().build(local_pool.handle(), builder.endpoint());
     let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
-    let blobs_client = blobs.client();
+    let blobs_client = blobs.spawn_rpc();
 
     // Build our custom protocol handler. The `builder` exposes access to various subsystems in the
     // iroh node. In our case, we need a blobs client and the endpoint.
diff --git a/examples/hello-world-fetch.rs b/examples/hello-world-fetch.rs
index 0741a5cd9..7806366e0 100644
--- a/examples/hello-world-fetch.rs
+++ b/examples/hello-world-fetch.rs
@@ -42,7 +42,7 @@ async fn main() -> Result<()> {
     let blobs = Blobs::memory().build(local_pool.handle(), builder.endpoint());
     let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
     let node = builder.spawn().await?;
-    let blobs_client = blobs.client();
+    let blobs_client = blobs.spawn_rpc();
 
     println!("fetching hash: {}", ticket.hash());
     println!("node id: {}", node.endpoint().node_id());
diff --git a/examples/hello-world-provide.rs b/examples/hello-world-provide.rs
index 96fa028c2..10d2aac9a 100644
--- a/examples/hello-world-provide.rs
+++ b/examples/hello-world-provide.rs
@@ -28,7 +28,7 @@ async fn main() -> anyhow::Result<()> {
     let local_pool = LocalPool::default();
     let blobs = Blobs::memory().build(local_pool.handle(), builder.endpoint());
     let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
-    let blobs_client = blobs.client();
+    let blobs_client = blobs.spawn_rpc();
     let node = builder.spawn().await?;
 
     // add some data and remember the hash
diff --git a/examples/local-swarm-discovery.rs b/examples/local-swarm-discovery.rs
index fd84a8f56..9cad60bf3 100644
--- a/examples/local-swarm-discovery.rs
+++ b/examples/local-swarm-discovery.rs
@@ -78,7 +78,7 @@ async fn main() -> anyhow::Result<()> {
     let blobs = Blobs::memory().build(local_pool.handle(), builder.endpoint());
     let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
     let node = builder.spawn().await?;
-    let blobs_client = blobs.client();
+    let blobs_client = blobs.spawn_rpc();
 
     match &cli.command {
         Commands::Accept { path } => {
diff --git a/examples/transfer.rs b/examples/transfer.rs
index 4e73909ea..1729f5c7b 100644
--- a/examples/transfer.rs
+++ b/examples/transfer.rs
@@ -26,7 +26,7 @@ async fn main() -> Result<()> {
         .spawn()
         .await?;
 
-    let blobs = blobs.client();
+    let blobs = blobs.spawn_rpc();
 
     let args = std::env::args().collect::<Vec<_>>();
     match &args.iter().map(String::as_str).collect::<Vec<_>>()[..] {
diff --git a/src/rpc.rs b/src/rpc.rs
index cf46ac499..33bb57898 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -62,8 +62,9 @@ const RPC_BLOB_GET_CHUNK_SIZE: usize = 1024 * 64;
 const RPC_BLOB_GET_CHANNEL_CAP: usize = 2;
 
 impl<D: crate::store::Store> Blobs<D> {
-    /// Get a client for the blobs protocol
-    pub fn client(&self) -> RpcHandler {
+    /// Spawns an in-memory RPC client and server pair.
+    #[must_use = "Dropping the RpcHandler will stop the client"]
+    pub fn spawn_rpc(&self) -> RpcHandler {
         RpcHandler::new(self)
     }
 
     /// Handle an RPC request
@@ -870,11 +871,14 @@ impl Blobs {
     }
 }
 
-/// A rpc handler for the blobs rpc protocol
+/// An in memory rpc handler for the blobs rpc protocol
 ///
 /// This struct contains both a task that handles rpc requests and a client
-/// that can be used to send rpc requests. Dropping it will stop the handler task,
-/// so you need to put it somewhere where it will be kept alive.
+/// that can be used to send rpc requests.
+///
+/// Dropping it will stop the handler task, so you need to put it somewhere
+/// where it will be kept alive. This struct will capture a copy of
+/// [`crate::net_protocol::Blobs`] and keep it alive.
 #[derive(Debug)]
 pub struct RpcHandler {
     /// Client to hand out
     client: MemClient,
     /// Handler task
     _handler: AbortOnDropHandle<()>,
 }
 
 impl Deref for RpcHandler {
     type Target = MemClient;
 
     fn deref(&self) -> &Self::Target {
         &self.client
     }
 }
 
 impl RpcHandler {
     fn new<D: crate::store::Store>(blobs: &Blobs<D>) -> Self {
         let blobs = blobs.clone();
         let (listener, connector) = quic_rpc::transport::flume::channel(1);
         let listener = RpcServer::new(listener);
         let client = RpcClient::new(connector);
         let client = MemClient::new(client);
         let _handler = listener
             .spawn_accept_loop(move |req, chan| blobs.clone().handle_rpc_request(req, chan));
         Self { client, _handler }
     }
+
+    /// Get a reference to the rpc client api
+    pub fn client(&self) -> &MemClient {
+        &self.client
+    }
 }
diff --git a/tests/blobs.rs b/tests/blobs.rs
index ab273b118..8a9f41468 100644
--- a/tests/blobs.rs
+++ b/tests/blobs.rs
@@ -13,7 +13,7 @@ async fn blobs_gc_smoke() -> TestResult<()> {
     let pool = LocalPool::default();
     let endpoint = Endpoint::builder().bind().await?;
     let blobs = Blobs::memory().build(pool.handle(), &endpoint);
-    let client = blobs.clone().client();
+    let client = blobs.spawn_rpc();
     blobs.start_gc(GcConfig {
         period: Duration::from_millis(1),
         done_callback: None,
@@ -32,7 +32,7 @@ async fn blobs_gc_protected() -> TestResult<()> {
     let pool = LocalPool::default();
     let endpoint = Endpoint::builder().bind().await?;
     let blobs = Blobs::memory().build(pool.handle(), &endpoint);
-    let client = blobs.clone().client();
+    let client = blobs.spawn_rpc();
     let h1 = client.add_bytes(b"test".to_vec()).await?;
     let protected = Arc::new(Mutex::new(Vec::new()));
     blobs.add_protected(Box::new({
diff --git a/tests/gc.rs b/tests/gc.rs
index 56c79f3a3..e799febbe 100644
--- a/tests/gc.rs
+++ b/tests/gc.rs
@@ -67,7 +67,7 @@ impl Node {
 
     /// Returns an in-memory blobs client
     pub fn blobs(&self) -> RpcHandler {
-        self.blobs.client()
+        self.blobs.spawn_rpc()
     }
 
     /// Returns an in-memory tags client

From e4273dace5ad0a2ac77f0ff060e59187e3c1e3f6 Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Fri, 6 Dec 2024 12:33:50 +0200
Subject: [PATCH 4/5] back to the lazy client

---
 examples/custom-protocol.rs       |   4 +-
 examples/hello-world-fetch.rs     |   2 +-
 examples/hello-world-provide.rs   |   2 +-
 examples/local-swarm-discovery.rs |   2 +-
 examples/transfer.rs              |   2 +-
 src/net_protocol.rs               | 177 +++---------------------
 src/rpc.rs                        | 223 +++++++++++++++++++++++++++---
 tests/blobs.rs                    |   4 +-
 tests/gc.rs                       |   6 +-
 9 files changed, 237 insertions(+), 185 deletions(-)

diff --git a/examples/custom-protocol.rs b/examples/custom-protocol.rs
index 41dfc1cc0..a15194792 100644
--- a/examples/custom-protocol.rs
+++ b/examples/custom-protocol.rs
@@ -91,7 +91,7 @@ async fn main() -> Result<()> {
     let local_pool = LocalPool::default();
     let blobs = Blobs::memory().build(local_pool.handle(), builder.endpoint());
     let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
-    let blobs_client = blobs.spawn_rpc();
+    let blobs_client = blobs.client();
 
     // Build our custom protocol handler. The `builder` exposes access to various subsystems in the
     // iroh node. In our case, we need a blobs client and the endpoint.
@@ -122,7 +122,7 @@ async fn main() -> Result<()> {
 
             // Print out our query results.
             for hash in hashes {
-                read_and_print(&blobs_client, hash).await?;
+                read_and_print(blobs_client, hash).await?;
             }
         }
     }
diff --git a/examples/hello-world-fetch.rs b/examples/hello-world-fetch.rs
index 7806366e0..0741a5cd9 100644
--- a/examples/hello-world-fetch.rs
+++ b/examples/hello-world-fetch.rs
@@ -42,7 +42,7 @@ async fn main() -> Result<()> {
     let blobs = Blobs::memory().build(local_pool.handle(), builder.endpoint());
     let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
     let node = builder.spawn().await?;
-    let blobs_client = blobs.spawn_rpc();
+    let blobs_client = blobs.client();
 
     println!("fetching hash: {}", ticket.hash());
     println!("node id: {}", node.endpoint().node_id());
diff --git a/examples/hello-world-provide.rs b/examples/hello-world-provide.rs
index 10d2aac9a..96fa028c2 100644
--- a/examples/hello-world-provide.rs
+++ b/examples/hello-world-provide.rs
@@ -28,7 +28,7 @@ async fn main() -> anyhow::Result<()> {
     let local_pool = LocalPool::default();
     let blobs = Blobs::memory().build(local_pool.handle(), builder.endpoint());
     let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
-    let blobs_client = blobs.spawn_rpc();
+    let blobs_client = blobs.client();
     let node = builder.spawn().await?;
 
     // add some data and remember the hash
diff --git a/examples/local-swarm-discovery.rs b/examples/local-swarm-discovery.rs
index 9cad60bf3..fd84a8f56 100644
--- a/examples/local-swarm-discovery.rs
+++ b/examples/local-swarm-discovery.rs
@@ -78,7 +78,7 @@ async fn main() -> anyhow::Result<()> {
     let blobs = Blobs::memory().build(local_pool.handle(), builder.endpoint());
     let builder = builder.accept(iroh_blobs::ALPN, blobs.clone());
     let node = builder.spawn().await?;
-    let blobs_client = blobs.spawn_rpc();
+    let blobs_client = blobs.client();
 
     match &cli.command {
         Commands::Accept { path } => {
diff --git a/examples/transfer.rs b/examples/transfer.rs
index 1729f5c7b..4e73909ea 100644
--- a/examples/transfer.rs
+++ b/examples/transfer.rs
@@ -26,7 +26,7 @@ async fn main() -> Result<()> {
         .spawn()
         .await?;
 
-    let blobs = blobs.spawn_rpc();
+    let blobs = blobs.client();
 
     let args = std::env::args().collect::<Vec<_>>();
     match &args.iter().map(String::as_str).collect::<Vec<_>>()[..] {
diff --git a/src/net_protocol.rs b/src/net_protocol.rs
index 5dab29d5b..48667aa5c 100644
--- a/src/net_protocol.rs
+++ b/src/net_protocol.rs
@@ -5,28 +5,22 @@
 use std::{collections::BTreeSet, fmt::Debug, ops::DerefMut, sync::Arc};
 
-use anyhow::{anyhow, bail, Result};
+use anyhow::{bail, Result};
 use futures_lite::future::Boxed as BoxedFuture;
 use futures_util::future::BoxFuture;
 use iroh::{endpoint::Connecting, protocol::ProtocolHandler, Endpoint, NodeAddr};
 use iroh_base::hash::{BlobFormat, Hash};
 use serde::{Deserialize, Serialize};
-use tracing::{debug, warn};
+use tracing::debug;
 
 use crate::{
-    downloader::{DownloadRequest, Downloader},
-    get::{
-        db::{DownloadProgress, GetState},
-        Stats,
-    },
+    downloader::Downloader,
     provider::EventSender,
     store::GcConfig,
     util::{
         local_pool::{self, LocalPoolHandle},
-        progress::{AsyncChannelProgressSender, ProgressSender},
         SetTagOption,
     },
-    HashAndFormat,
 };
 
 /// A callback that blobs can ask about a set of hashes that should not be garbage collected.
@@ -48,20 +42,22 @@ impl Default for GcState {
 }
 
 #[derive(Debug)]
-struct BlobsInner<S> {
-    rt: LocalPoolHandle,
+pub(crate) struct BlobsInner<S> {
+    pub(crate) rt: LocalPoolHandle,
     pub(crate) store: S,
     events: EventSender,
-    downloader: Downloader,
-    endpoint: Endpoint,
+    pub(crate) downloader: Downloader,
+    pub(crate) endpoint: Endpoint,
     gc_state: std::sync::Mutex<GcState>,
     #[cfg(feature = "rpc")]
-    batches: tokio::sync::Mutex<BlobBatches>,
+    pub(crate) batches: tokio::sync::Mutex<BlobBatches>,
 }
 
 #[derive(Debug, Clone)]
 pub struct Blobs<S> {
-    inner: Arc<BlobsInner<S>>,
+    pub(crate) inner: Arc<BlobsInner<S>>,
+    #[cfg(feature = "rpc")]
+    pub(crate) rpc_handler: Arc<std::sync::OnceLock<crate::rpc::RpcHandler>>,
 }
 
 /// Keeps track of all the currently active batch operations of the blobs api.
@@ -79,7 +75,7 @@ pub(crate) struct BlobBatches {
 #[derive(Debug, Default)]
 struct BlobBatch {
     /// The tags in this batch.
-    tags: std::collections::BTreeMap<HashAndFormat, Vec<TempTag>>,
+    tags: std::collections::BTreeMap<iroh::hash::HashAndFormat, Vec<TempTag>>,
 }
 
 #[cfg(feature = "rpc")]
@@ -98,7 +94,11 @@ impl BlobBatches {
     }
 
     /// Remove a tag from a batch.
-    pub fn remove_one(&mut self, batch: BatchId, content: &HashAndFormat) -> Result<()> {
+    pub fn remove_one(
+        &mut self,
+        batch: BatchId,
+        content: &iroh::hash::HashAndFormat,
+    ) -> Result<()> {
         if let Some(batch) = self.batches.get_mut(&batch) {
             if let Some(tags) = batch.tags.get_mut(content) {
                 tags.pop();
@@ -191,6 +191,8 @@ impl Blobs {
                 batches: Default::default(),
                 gc_state: Default::default(),
             }),
+            #[cfg(feature = "rpc")]
+            rpc_handler: Default::default(),
         }
     }
 
@@ -255,147 +257,6 @@ impl Blobs {
         *state = GcState::Started(Some(run));
         Ok(())
     }
-
-    #[cfg(feature = "rpc")]
-    pub(crate) async fn batches(&self) -> tokio::sync::MutexGuard<'_, BlobBatches> {
-        self.inner.batches.lock().await
-    }
-
-    pub(crate) async fn download(
-        &self,
-        endpoint: Endpoint,
-        req: BlobDownloadRequest,
-        progress: AsyncChannelProgressSender<DownloadProgress>,
-    ) -> Result<()> {
-        let BlobDownloadRequest {
-            hash,
-            format,
-            nodes,
-            tag,
-            mode,
-        } = req;
-        let hash_and_format = HashAndFormat { hash, format };
-        let temp_tag = self.store().temp_tag(hash_and_format);
-        let stats = match mode {
-            DownloadMode::Queued => {
-                self.download_queued(endpoint, hash_and_format, nodes, progress.clone())
-                    .await?
-            }
-            DownloadMode::Direct => {
-                self.download_direct_from_nodes(endpoint, hash_and_format, nodes, progress.clone())
-                    .await?
-            }
-        };
-
-        progress.send(DownloadProgress::AllDone(stats)).await.ok();
-        match tag {
-            SetTagOption::Named(tag) => {
-                self.store().set_tag(tag, Some(hash_and_format)).await?;
-            }
-            SetTagOption::Auto => {
-                self.store().create_tag(hash_and_format).await?;
-            }
-        }
-        drop(temp_tag);
-
-        Ok(())
-    }
-
-    async fn download_queued(
-        &self,
-        endpoint: Endpoint,
-        hash_and_format: HashAndFormat,
-        nodes: Vec<NodeAddr>,
-        progress: AsyncChannelProgressSender<DownloadProgress>,
-    ) -> Result<Stats> {
-        /// Name used for logging when new node addresses are added from gossip.
-        const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download";
-
-        let mut node_ids = Vec::with_capacity(nodes.len());
-        let mut any_added = false;
-        for node in nodes {
-            node_ids.push(node.node_id);
-            if !node.info.is_empty() {
-                endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?;
-                any_added = true;
-            }
-        }
-        let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some());
-        anyhow::ensure!(can_download, "no way to reach a node for download");
-        let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress);
-        let handle = self.downloader().queue(req).await;
-        let stats = handle.await?;
-        Ok(stats)
-    }
-
-    #[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))]
-    async fn download_direct_from_nodes(
-        &self,
-        endpoint: Endpoint,
-        hash_and_format: HashAndFormat,
-        nodes: Vec<NodeAddr>,
-        progress: AsyncChannelProgressSender<DownloadProgress>,
-    ) -> Result<Stats> {
-        let mut last_err = None;
-        let mut remaining_nodes = nodes.len();
-        let mut nodes_iter = nodes.into_iter();
-        'outer: loop {
-            match crate::get::db::get_to_db_in_steps(
-                self.store().clone(),
-                hash_and_format,
-                progress.clone(),
-            )
-            .await?
-            {
-                GetState::Complete(stats) => return Ok(stats),
-                GetState::NeedsConn(needs_conn) => {
-                    let (conn, node_id) = 'inner: loop {
-                        match nodes_iter.next() {
-                            None => break 'outer,
-                            Some(node) => {
-                                remaining_nodes -= 1;
-                                let node_id = node.node_id;
-                                if node_id == endpoint.node_id() {
-                                    debug!(
-                                        ?remaining_nodes,
-                                        "skip node {} (it is the node id of ourselves)",
-                                        node_id.fmt_short()
-                                    );
-                                    continue 'inner;
-                                }
-                                match endpoint.connect(node, crate::protocol::ALPN).await {
-                                    Ok(conn) => break 'inner (conn, node_id),
-                                    Err(err) => {
-                                        debug!(
-                                            ?remaining_nodes,
-                                            "failed to connect to {}: {err}",
-                                            node_id.fmt_short()
-                                        );
-                                        continue 'inner;
-                                    }
-                                }
-                            }
-                        }
-                    };
-                    match needs_conn.proceed(conn).await {
-                        Ok(stats) => return Ok(stats),
-                        Err(err) => {
-                            warn!(
-                                ?remaining_nodes,
-                                "failed to download from {}: {err}",
-                                node_id.fmt_short()
-                            );
-                            last_err = Some(err);
-                        }
-                    }
-                }
-            }
-        }
-        match last_err {
-            Some(err) => Err(err.into()),
-            None => Err(anyhow!("No nodes to download from provided")),
-        }
-    }
 }
 
 impl<S: crate::store::Store> ProtocolHandler for Blobs<S> {
diff --git a/src/rpc.rs b/src/rpc.rs
index 33bb57898..4df12e724 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -8,7 +8,7 @@
 use anyhow::anyhow;
 use client::{
-    blobs::{BlobInfo, BlobStatus, IncompleteBlobInfo, MemClient, WrapOption},
+    blobs::{self, BlobInfo, BlobStatus, DownloadMode, IncompleteBlobInfo, MemClient, WrapOption},
     tags::TagInfo,
     MemConnector,
 };
@@ -16,6 +16,7 @@
 use futures_buffered::BufferedStreamExt;
 use futures_lite::StreamExt;
 use futures_util::{FutureExt, Stream};
 use genawaiter::sync::{Co, Gen};
+use iroh::{Endpoint, NodeAddr};
 use iroh_base::hash::{BlobFormat, HashAndFormat};
 use iroh_io::AsyncSliceReader;
 use proto::{
@@ -39,15 +40,21 @@ use quic_rpc::{
     RpcClient, RpcServer,
 };
 use tokio_util::task::AbortOnDropHandle;
+use tracing::{debug, warn};
 
 use crate::{
+    downloader::{DownloadRequest, Downloader},
     export::ExportProgress,
     format::collection::Collection,
-    get::db::DownloadProgress,
-    net_protocol::{BlobDownloadRequest, Blobs},
+    get::{
+        db::{DownloadProgress, GetState},
+        Stats,
+    },
+    net_protocol::{BlobDownloadRequest, Blobs, BlobsInner},
     provider::{AddProgress, BatchAddPathProgress},
     store::{ConsistencyCheckProgress, ImportProgress, MapEntry, ValidateProgress},
     util::{
+        local_pool::LocalPoolHandle,
         progress::{AsyncChannelProgressSender, ProgressSender},
         SetTagOption,
     },
@@ -62,10 +69,62 @@ const RPC_BLOB_GET_CHUNK_SIZE: usize = 1024 * 64;
 const RPC_BLOB_GET_CHANNEL_CAP: usize = 2;
 
 impl<D: crate::store::Store> Blobs<D> {
-    /// Spawns an in-memory RPC client and server pair.
-    #[must_use = "Dropping the RpcHandler will stop the client"]
-    pub fn spawn_rpc(&self) -> RpcHandler {
-        RpcHandler::new(self)
+    /// Get a client for the blobs protocol
+    pub fn client(&self) -> &blobs::MemClient {
+        &self
+            .rpc_handler
+            .get_or_init(|| RpcHandler::new(&self.inner))
+            .client
+    }
+
+    /// Handle an RPC request
+    pub async fn handle_rpc_request<C>(
+        self,
+        msg: Request,
+        chan: RpcChannel<RpcService, C>,
+    ) -> std::result::Result<(), RpcServerError<C>>
+    where
+        C: ChannelTypes<RpcService>,
+    {
+        Handler(self.inner.clone())
+            .handle_rpc_request(msg, chan)
+            .await
+    }
+}
+
+#[derive(Clone)]
+struct Handler<D>(Arc<BlobsInner<D>>);
+
+impl<D> Deref for Handler<D> {
+    type Target = BlobsInner<D>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl<D: crate::store::Store> Handler<D> {
+    fn store(&self) -> &D {
+        &self.0.store
+    }
+
+    fn rt(&self) -> &LocalPoolHandle {
+        &self.0.rt
+    }
+
+    fn endpoint(&self) -> &Endpoint {
+        &self.0.endpoint
+    }
+
+    fn downloader(&self) -> &Downloader {
+        &self.0.downloader
+    }
+
+    #[cfg(feature = "rpc")]
+    pub(crate) async fn batches(
+        &self,
+    ) -> tokio::sync::MutexGuard<'_, crate::net_protocol::BlobBatches> {
+        self.0.batches.lock().await
     }
 
     /// Handle an RPC request
@@ -869,6 +928,142 @@ impl Blobs {
 
         Ok(CreateCollectionResponse { hash, tag })
     }
+
+    pub(crate) async fn download(
+        &self,
+        endpoint: Endpoint,
+        req: BlobDownloadRequest,
+        progress: AsyncChannelProgressSender<DownloadProgress>,
+    ) -> anyhow::Result<()> {
+        let BlobDownloadRequest {
+            hash,
+            format,
+            nodes,
+            tag,
+            mode,
+        } = req;
+        let hash_and_format = HashAndFormat { hash, format };
+        let temp_tag = self.store().temp_tag(hash_and_format);
+        let stats = match mode {
+            DownloadMode::Queued => {
+                self.download_queued(endpoint, hash_and_format, nodes, progress.clone())
+                    .await?
+            }
+            DownloadMode::Direct => {
+                self.download_direct_from_nodes(endpoint, hash_and_format, nodes, progress.clone())
+                    .await?
+            }
+        };
+
+        progress.send(DownloadProgress::AllDone(stats)).await.ok();
+        match tag {
+            SetTagOption::Named(tag) => {
+                self.store().set_tag(tag, Some(hash_and_format)).await?;
+            }
+            SetTagOption::Auto => {
+                self.store().create_tag(hash_and_format).await?;
+            }
+        }
+        drop(temp_tag);
+
+        Ok(())
+    }
+
+    async fn download_queued(
+        &self,
+        endpoint: Endpoint,
+        hash_and_format: HashAndFormat,
+        nodes: Vec<NodeAddr>,
+        progress: AsyncChannelProgressSender<DownloadProgress>,
+    ) -> anyhow::Result<Stats> {
+        /// Name used for logging when new node addresses are added from gossip.
+        const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download";
+
+        let mut node_ids = Vec::with_capacity(nodes.len());
+        let mut any_added = false;
+        for node in nodes {
+            node_ids.push(node.node_id);
+            if !node.info.is_empty() {
+                endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?;
+                any_added = true;
+            }
+        }
+        let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some());
+        anyhow::ensure!(can_download, "no way to reach a node for download");
+        let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress);
+        let handle = self.downloader().queue(req).await;
+        let stats = handle.await?;
+        Ok(stats)
+    }
+
+    #[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))]
+    async fn download_direct_from_nodes(
+        &self,
+        endpoint: Endpoint,
+        hash_and_format: HashAndFormat,
+        nodes: Vec<NodeAddr>,
+        progress: AsyncChannelProgressSender<DownloadProgress>,
+    ) -> anyhow::Result<Stats> {
+        let mut last_err = None;
+        let mut remaining_nodes = nodes.len();
+        let mut nodes_iter = nodes.into_iter();
+        'outer: loop {
+            match crate::get::db::get_to_db_in_steps(
+                self.store().clone(),
+                hash_and_format,
+                progress.clone(),
+            )
+            .await?
+            {
+                GetState::Complete(stats) => return Ok(stats),
+                GetState::NeedsConn(needs_conn) => {
+                    let (conn, node_id) = 'inner: loop {
+                        match nodes_iter.next() {
+                            None => break 'outer,
+                            Some(node) => {
+                                remaining_nodes -= 1;
+                                let node_id = node.node_id;
+                                if node_id == endpoint.node_id() {
+                                    debug!(
+                                        ?remaining_nodes,
+                                        "skip node {} (it is the node id of ourselves)",
+                                        node_id.fmt_short()
+                                    );
+                                    continue 'inner;
+                                }
+                                match endpoint.connect(node, crate::protocol::ALPN).await {
+                                    Ok(conn) => break 'inner (conn, node_id),
+                                    Err(err) => {
+                                        debug!(
+                                            ?remaining_nodes,
+                                            "failed to connect to {}: {err}",
+                                            node_id.fmt_short()
+                                        );
+                                        continue 'inner;
+                                    }
+                                }
+                            }
+                        }
+                    };
+                    match needs_conn.proceed(conn).await {
+                        Ok(stats) => return Ok(stats),
+                        Err(err) => {
+                            warn!(
+                                ?remaining_nodes,
+                                "failed to download from {}: {err}",
+                                node_id.fmt_short()
+                            );
+                            last_err = Some(err);
+                        }
+                    }
+                }
+            }
+        }
+        match last_err {
+            Some(err) => Err(err.into()),
+            None => Err(anyhow!("No nodes to download from provided")),
+        }
+    }
 }
 
 /// An in memory rpc handler for the blobs rpc protocol
 ///
 /// This struct contains both a task that handles rpc requests and a client
 /// that can be used to send rpc requests.
 ///
 /// Dropping it will stop the handler task, so you need to put it somewhere
 /// where it will be kept alive. This struct will capture a copy of
 /// [`crate::net_protocol::Blobs`] and keep it alive.
 #[derive(Debug)]
-pub struct RpcHandler {
+pub(crate) struct RpcHandler {
     /// Client to hand out
     client: MemClient,
     /// Handler task
     _handler: AbortOnDropHandle<()>,
 }
@@ -896,19 +1091,15 @@ impl Deref for RpcHandler {
 }
 
 impl RpcHandler {
-    fn new<D: crate::store::Store>(blobs: &Blobs<D>) -> Self {
+    fn new<D: crate::store::Store>(blobs: &Arc<BlobsInner<D>>) -> Self {
         let blobs = blobs.clone();
         let (listener, connector) = quic_rpc::transport::flume::channel(1);
         let listener = RpcServer::new(listener);
         let client = RpcClient::new(connector);
         let client = MemClient::new(client);
-        let _handler = listener
-            .spawn_accept_loop(move |req, chan| blobs.clone().handle_rpc_request(req, chan));
+        let _handler = listener.spawn_accept_loop(move |req, chan| {
+            Handler(blobs.clone()).handle_rpc_request(req, chan)
+        });
         Self { client, _handler }
     }
-
-    /// Get a reference to the rpc client api
-    pub fn client(&self) -> &MemClient {
-        &self.client
-    }
 }
diff --git a/tests/blobs.rs b/tests/blobs.rs
index 8a9f41468..c74484050 100644
--- a/tests/blobs.rs
+++ b/tests/blobs.rs
@@ -13,7 +13,7 @@ async fn blobs_gc_smoke() -> TestResult<()> {
     let pool = LocalPool::default();
     let endpoint = Endpoint::builder().bind().await?;
     let blobs = Blobs::memory().build(pool.handle(), &endpoint);
-    let client = blobs.spawn_rpc();
+    let client = blobs.client();
     blobs.start_gc(GcConfig {
         period: Duration::from_millis(1),
         done_callback: None,
@@ -32,7 +32,7 @@ async fn blobs_gc_protected() -> TestResult<()> {
     let pool = LocalPool::default();
     let endpoint = Endpoint::builder().bind().await?;
     let blobs = Blobs::memory().build(pool.handle(), &endpoint);
-    let client = blobs.spawn_rpc();
+    let client = blobs.client();
     let h1 = client.add_bytes(b"test".to_vec()).await?;
     let protected = Arc::new(Mutex::new(Vec::new()));
     blobs.add_protected(Box::new({
diff --git a/tests/gc.rs b/tests/gc.rs
index e799febbe..56ab4746c 100644
--- a/tests/gc.rs
+++ b/tests/gc.rs
@@ -20,7 +20,7 @@ use iroh::{protocol::Router, Endpoint, NodeAddr, NodeId};
 use iroh_blobs::{
     hashseq::HashSeq,
     net_protocol::Blobs,
-    rpc::{client::tags, RpcHandler},
+    rpc::client::{blobs, tags},
     store::{
         bao_tree, BaoBatchWriter, ConsistencyCheckProgress, EntryStatus, GcConfig, MapEntryMut,
         MapMut, ReportLevel, Store,
@@ -66,8 +66,8 @@ impl Node {
     }
 
     /// Returns an in-memory blobs client
-    pub fn blobs(&self) -> RpcHandler {
-        self.blobs.spawn_rpc()
+    pub fn blobs(&self) -> &blobs::MemClient {
+        self.blobs.client()
     }
 
     /// Returns an in-memory tags client

From f211439a0179aab6519ca78ce76d4c652bd67aca Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Fri, 6 Dec 2024 12:41:57 +0200
Subject: [PATCH 5/5] add comment about the purpose of the handler

---
 src/rpc.rs | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/rpc.rs b/src/rpc.rs
index 4df12e724..f3b24ab82 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -92,6 +92,8 @@ impl Blobs {
     }
 }
 
+/// This is just an internal helper so I don't have to
+/// define all the rpc methods on `self: Arc<BlobsInner<D>>`
 #[derive(Clone)]
 struct Handler<D>(Arc<BlobsInner<D>>);
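
For reference, the shape the series settles on — a cheaply cloneable handle
around an Arc'd inner, plus a small newtype with Deref so methods can be
written against the inner state — shown in isolation with made-up names
(a sketch, not the crate's actual API):

    use std::{ops::Deref, sync::Arc};

    #[derive(Debug)]
    struct Inner {
        store: String,
    }

    // Public handle: cloning only bumps the Arc refcount.
    #[derive(Debug, Clone)]
    struct Handle(Arc<Inner>);

    // Internal helper: lets methods read the inner state directly,
    // without needing `self: Arc<Inner>` receivers on the handle.
    #[derive(Clone)]
    struct Helper(Arc<Inner>);

    impl Deref for Helper {
        type Target = Inner;
        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl Helper {
        fn describe(&self) -> String {
            // Field access goes through Deref to the shared Inner.
            format!("store: {}", self.store)
        }
    }

    fn main() {
        let handle = Handle(Arc::new(Inner { store: "mem".to_string() }));
        let cheap_copy = handle.clone();
        let helper = Helper(cheap_copy.0.clone());
        assert_eq!(helper.describe(), "store: mem");
        drop(handle);
    }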