diff --git a/examples/expiring-tags.rs b/examples/expiring-tags.rs index 9c30d1b7..9abaff28 100644 --- a/examples/expiring-tags.rs +++ b/examples/expiring-tags.rs @@ -73,7 +73,7 @@ async fn delete_expired_tags(blobs: &Store, prefix: &str, bulk: bool) -> anyhow: // find tags to delete one by one and then delete them // // this allows us to print the tags before deleting them - let mut tags = blobs.tags().list().await?; + let mut tags = blobs.tags().list().stream(); let mut to_delete = Vec::new(); while let Some(tag) = tags.next().await { let tag = tag?.name; @@ -102,7 +102,7 @@ async fn delete_expired_tags(blobs: &Store, prefix: &str, bulk: bool) -> anyhow: async fn print_store_info(store: &Store) -> anyhow::Result<()> { let now = chrono::Utc::now(); - let mut tags = store.tags().list().await?; + let mut tags = store.tags().list().stream(); println!( "Current time: {}", now.to_rfc3339_opts(chrono::SecondsFormat::Secs, true) @@ -112,7 +112,7 @@ async fn print_store_info(store: &Store) -> anyhow::Result<()> { let tag = tag?; println!(" {tag:?}"); } - let mut blobs = store.list().stream().await?; + let mut blobs = store.list().stream(); println!("Blobs:"); while let Some(item) = blobs.next().await { println!(" {}", item?); diff --git a/src/api/blobs.rs b/src/api/blobs.rs index d0b94859..ae11d15b 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -56,10 +56,13 @@ use super::{ ApiClient, RequestResult, Tags, }; use crate::{ - api::proto::{BatchRequest, ImportByteStreamUpdate}, + api::proto::{BatchRequest, ImportByteStreamUpdate, ListBlobsItem}, provider::StreamContext, store::IROH_BLOCK_SIZE, - util::temp_tag::TempTag, + util::{ + irpc::{IrpcReceiverFutExt, IrpcStreamItem}, + temp_tag::TempTag, + }, BlobFormat, Hash, HashAndFormat, }; @@ -835,34 +838,48 @@ impl ImportBaoHandle { /// A progress handle for a blobs list operation. pub struct BlobsListProgress { - inner: future::Boxed>>>, + inner: future::Boxed>>, } impl BlobsListProgress { fn new( - fut: impl Future>>> + Send + 'static, + fut: impl Future>> + Send + 'static, ) -> Self { Self { inner: Box::pin(fut), } } - pub async fn hashes(self) -> RequestResult> { - let mut rx: mpsc::Receiver> = self.inner.await?; - let mut hashes = Vec::new(); - while let Some(item) = rx.recv().await? { - hashes.push(item?); + pub async fn hashes(self) -> super::Result> { + self.inner.try_collect().await + } + + pub fn stream(self) -> impl Stream> { + self.inner.into_stream() + } +} + +impl IrpcStreamItem for ListBlobsItem { + type Error = super::Error; + type Item = Hash; + + fn into_result_opt(self) -> Option> { + match self { + Self::Item(hash) => Some(Ok(hash)), + Self::Error(e) => Some(Err(e)), + Self::Done => None, } - Ok(hashes) } - pub async fn stream(self) -> irpc::Result>> { - let mut rx = self.inner.await?; - Ok(Gen::new(|co| async move { - while let Ok(Some(item)) = rx.recv().await { - co.yield_(item).await; - } - })) + fn from_result(item: std::result::Result) -> Self { + match item { + Ok(hash) => Self::Item(hash), + Err(e) => Self::Error(e), + } + } + + fn done() -> Self { + Self::Done } } diff --git a/src/api/proto.rs b/src/api/proto.rs index 8b3780bd..936e9a7d 100644 --- a/src/api/proto.rs +++ b/src/api/proto.rs @@ -14,7 +14,9 @@ //! The file system store is quite complex and optimized, so to get started take a look at //! the much simpler memory store. use std::{ + collections::HashSet, fmt::{self, Debug}, + future::{Future, IntoFuture}, io, num::NonZeroU64, ops::{Bound, RangeBounds}, @@ -32,13 +34,20 @@ use irpc::{ channel::{mpsc, oneshot}, rpc_requests, }; -use n0_future::Stream; +use n0_future::{future, Stream}; use range_collections::RangeSet2; use serde::{Deserialize, Serialize}; pub(crate) mod bitfield; pub use bitfield::Bitfield; -use crate::{store::util::Tag, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat}; +use crate::{ + store::util::Tag, + util::{ + irpc::{IrpcReceiverFutExt, IrpcStreamItem}, + temp_tag::TempTag, + }, + BlobFormat, Hash, HashAndFormat, +}; pub(crate) trait HashSpecific { fn hash(&self) -> Hash; @@ -89,7 +98,7 @@ impl HashSpecific for CreateTagMsg { #[rpc_requests(message = Command, alias = "Msg")] #[derive(Debug, Serialize, Deserialize)] pub enum Request { - #[rpc(tx = mpsc::Sender>)] + #[rpc(tx = mpsc::Sender)] ListBlobs(ListRequest), #[rpc(tx = oneshot::Sender, rx = mpsc::Receiver)] Batch(BatchRequest), @@ -113,7 +122,7 @@ pub enum Request { ImportPath(ImportPathRequest), #[rpc(tx = mpsc::Sender)] ExportPath(ExportPathRequest), - #[rpc(tx = oneshot::Sender>>)] + #[rpc(tx = mpsc::Sender)] ListTags(ListTagsRequest), #[rpc(tx = oneshot::Sender>)] SetTag(SetTagRequest), @@ -123,7 +132,7 @@ pub enum Request { RenameTag(RenameTagRequest), #[rpc(tx = oneshot::Sender>)] CreateTag(CreateTagRequest), - #[rpc(tx = oneshot::Sender>)] + #[rpc(tx = mpsc::Sender)] ListTempTags(ListTempTagsRequest), #[rpc(tx = oneshot::Sender)] CreateTempTag(CreateTempTagRequest), @@ -351,6 +360,109 @@ pub struct TagInfo { pub hash: Hash, } +#[derive(Debug, Serialize, Deserialize)] +pub enum ListBlobsItem { + Item(Hash), + Error(super::Error), + Done, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum ListTagsItem { + Item(TagInfo), + Error(super::Error), + Done, +} + +impl From> for ListTagsItem { + fn from(item: std::result::Result) -> Self { + match item { + Ok(item) => ListTagsItem::Item(item), + Err(err) => ListTagsItem::Error(err), + } + } +} + +impl IrpcStreamItem for ListTagsItem { + type Error = super::Error; + type Item = TagInfo; + + fn into_result_opt(self) -> Option> { + match self { + ListTagsItem::Item(item) => Some(Ok(item)), + ListTagsItem::Done => None, + ListTagsItem::Error(err) => Some(Err(err)), + } + } + + fn from_result(item: std::result::Result) -> Self { + match item { + Ok(i) => Self::Item(i), + Err(e) => Self::Error(e), + } + } + + fn done() -> Self { + Self::Done + } +} + +pub struct ListTempTagsProgress { + inner: future::Boxed>>, +} + +impl IntoFuture for ListTempTagsProgress { + fn into_future(self) -> Self::IntoFuture { + Box::pin(self.inner.try_collect()) + } + + type IntoFuture = future::Boxed; + + type Output = super::Result>; +} + +impl ListTempTagsProgress { + pub(super) fn new( + fut: impl Future>> + Send + 'static, + ) -> Self { + Self { + inner: Box::pin(fut), + } + } + + pub fn stream(self) -> impl Stream> { + self.inner.into_stream() + } +} + +pub struct ListTagsProgress { + inner: future::Boxed>>, +} + +impl IntoFuture for ListTagsProgress { + fn into_future(self) -> Self::IntoFuture { + Box::pin(self.inner.try_collect()) + } + + type IntoFuture = future::Boxed; + + type Output = super::Result>; +} + +impl ListTagsProgress { + pub(super) fn new( + fut: impl Future>> + Send + 'static, + ) -> Self { + Self { + inner: Box::pin(fut), + } + } + + pub fn stream(self) -> impl Stream> { + self.inner.into_stream() + } +} + impl From for HashAndFormat { fn from(tag_info: TagInfo) -> Self { HashAndFormat { @@ -410,6 +522,37 @@ pub struct CreateTempTagRequest { #[derive(Debug, Serialize, Deserialize)] pub struct ListTempTagsRequest; +#[derive(Debug, Serialize, Deserialize)] +pub enum ListTempTagsItem { + Item(HashAndFormat), + Error(super::Error), + Done, +} + +impl IrpcStreamItem for ListTempTagsItem { + type Error = super::Error; + type Item = HashAndFormat; + + fn into_result_opt(self) -> Option> { + match self { + ListTempTagsItem::Item(item) => Some(Ok(item)), + ListTempTagsItem::Done => None, + ListTempTagsItem::Error(err) => Some(Err(err)), + } + } + + fn from_result(item: std::result::Result) -> Self { + match item { + Ok(i) => Self::Item(i), + Err(e) => Self::Error(e), + } + } + + fn done() -> Self { + Self::Done + } +} + /// Rename a tag atomically #[derive(Debug, Serialize, Deserialize)] pub struct RenameTagRequest { diff --git a/src/api/tags.rs b/src/api/tags.rs index b235a8c6..e07da06e 100644 --- a/src/api/tags.rs +++ b/src/api/tags.rs @@ -3,7 +3,7 @@ //! The main entry point is the [`Tags`] struct. use std::ops::RangeBounds; -use n0_future::{Stream, StreamExt}; +use n0_future::StreamExt; use ref_cast::RefCast; use tracing::trace; @@ -16,7 +16,10 @@ use super::{ proto::{CreateTempTagRequest, Scope}, ApiClient, Tag, TempTag, }; -use crate::{api::proto::ListTempTagsRequest, HashAndFormat}; +use crate::{ + api::proto::{ListTagsProgress, ListTempTagsProgress, ListTempTagsRequest}, + HashAndFormat, +}; /// The API for interacting with tags and temp tags. #[derive(Debug, Clone, ref_cast::RefCast)] @@ -30,32 +33,25 @@ impl Tags { Self::ref_cast(sender) } - pub async fn list_temp_tags(&self) -> irpc::Result> { + pub fn list_temp_tags(&self) -> ListTempTagsProgress { let options = ListTempTagsRequest; trace!("{:?}", options); - let res = self.client.rpc(options).await?; - Ok(n0_future::stream::iter(res)) + ListTempTagsProgress::new(self.client.server_streaming(options, 32)) } /// List all tags with options. /// /// This is the most flexible way to list tags. All the other list methods are just convenience /// methods that call this one with the appropriate options. - pub async fn list_with_opts( - &self, - options: ListOptions, - ) -> irpc::Result>> { + pub fn list_with_opts(&self, options: ListOptions) -> ListTagsProgress { trace!("{:?}", options); - let res = self.client.rpc(options).await?; - Ok(n0_future::stream::iter(res)) + ListTagsProgress::new(self.client.server_streaming(options, 32)) } /// Get the value of a single tag pub async fn get(&self, name: impl AsRef<[u8]>) -> super::RequestResult> { - let mut stream = self - .list_with_opts(ListOptions::single(name.as_ref())) - .await?; - Ok(stream.next().await.transpose()?) + let progress = self.list_with_opts(ListOptions::single(name.as_ref())); + Ok(progress.stream().next().await.transpose()?) } pub async fn set_with_opts(&self, options: SetOptions) -> super::RequestResult<()> { @@ -77,34 +73,27 @@ impl Tags { } /// List a range of tags - pub async fn list_range( - &self, - range: R, - ) -> irpc::Result>> + pub fn list_range(&self, range: R) -> ListTagsProgress where R: RangeBounds, E: AsRef<[u8]>, { - self.list_with_opts(ListOptions::range(range)).await + self.list_with_opts(ListOptions::range(range)) } /// Lists all tags with the given prefix. - pub async fn list_prefix( - &self, - prefix: impl AsRef<[u8]>, - ) -> irpc::Result>> { + pub fn list_prefix(&self, prefix: impl AsRef<[u8]>) -> ListTagsProgress { self.list_with_opts(ListOptions::prefix(prefix.as_ref())) - .await } /// Lists all tags. - pub async fn list(&self) -> irpc::Result>> { - self.list_with_opts(ListOptions::all()).await + pub fn list(&self) -> ListTagsProgress { + self.list_with_opts(ListOptions::all()) } /// Lists all tags with a hash_seq format. - pub async fn list_hash_seq(&self) -> irpc::Result>> { - self.list_with_opts(ListOptions::hash_seq()).await + pub fn list_hash_seq(&self) -> ListTagsProgress { + self.list_with_opts(ListOptions::hash_seq()) } /// Deletes a tag. diff --git a/src/store/fs.rs b/src/store/fs.rs index 9e11e098..6ed391ee 100644 --- a/src/store/fs.rs +++ b/src/store/fs.rs @@ -94,7 +94,6 @@ use entry_state::{DataLocation, OutboardLocation}; use gc::run_gc; use import::{ImportEntry, ImportSource}; use irpc::channel::mpsc; -use meta::list_blobs; use n0_future::{future::yield_now, io}; use nested_enum_utils::enum_conversions; use range_collections::range_set::RangeSetRange; @@ -124,6 +123,7 @@ use crate::{ }, util::{ channel::oneshot, + irpc::MpscSenderExt, temp_tag::{TagDrop, TempTag, TempTagScope, TempTags}, ChunkRangesExt, }, @@ -507,9 +507,7 @@ impl Actor { } Command::ListBlobs(cmd) => { trace!("{cmd:?}"); - if let Ok(snapshot) = self.db().snapshot(cmd.span.clone()).await { - self.spawn(list_blobs(snapshot, cmd)); - } + self.db().send(cmd.into()).await.ok(); } Command::Batch(cmd) => { trace!("{cmd:?}"); @@ -523,7 +521,7 @@ impl Actor { Command::ListTempTags(cmd) => { trace!("{cmd:?}"); let tts = self.temp_tags.list(); - cmd.tx.send(tts).await.ok(); + cmd.tx.forward_iter(tts.into_iter().map(Ok)).await.ok(); } Command::ImportBytes(cmd) => { trace!("{cmd:?}"); @@ -1420,13 +1418,13 @@ impl FsStore { #[cfg(test)] pub mod tests { use core::panic; - use std::collections::{HashMap, HashSet}; + use std::{collections::HashMap, future::IntoFuture}; use bao_tree::{ io::{outboard::PreOrderMemOutboard, round_up_to_chunks_groups}, ChunkRanges, }; - use n0_future::{stream, Stream, StreamExt}; + use n0_future::{stream, Stream}; use testresult::TestResult; use walkdir::WalkDir; @@ -1976,23 +1974,13 @@ pub mod tests { let batch = store.blobs().batch().await?; let tt1 = batch.temp_tag(Hash::new("foo")).await?; let tt2 = batch.add_slice("boo").await?; - let tts = store - .tags() - .list_temp_tags() - .await? - .collect::>() - .await; + let tts = store.tags().list_temp_tags().into_future().await?; assert!(tts.contains(tt1.hash_and_format())); assert!(tts.contains(tt2.hash_and_format())); drop(batch); store.sync_db().await?; store.wait_idle().await?; - let tts = store - .tags() - .list_temp_tags() - .await? - .collect::>() - .await; + let tts = store.tags().list_temp_tags().await?; // temp tag went out of scope, so it does not work anymore assert!(!tts.contains(tt1.hash_and_format())); assert!(!tts.contains(tt2.hash_and_format())); diff --git a/src/store/fs/gc.rs b/src/store/fs/gc.rs index da7836e7..e53ea7a3 100644 --- a/src/store/fs/gc.rs +++ b/src/store/fs/gc.rs @@ -48,17 +48,17 @@ pub(super) async fn gc_mark_task( } let mut roots = HashSet::new(); trace!("traversing tags"); - let mut tags = store.tags().list().await?; + let mut tags = store.tags().list().stream(); while let Some(tag) = tags.next().await { let info = tag?; trace!("adding root {:?} {:?}", info.name, info.hash_and_format()); roots.insert(info.hash_and_format()); } trace!("traversing temp roots"); - let mut tts = store.tags().list_temp_tags().await?; + let mut tts = store.tags().list_temp_tags().stream(); while let Some(tt) = tts.next().await { trace!("adding temp root {:?}", tt); - roots.insert(tt); + roots.insert(tt?); } for HashAndFormat { hash, format } in roots { // we need to do this for all formats except raw @@ -85,7 +85,7 @@ async fn gc_sweep_task( live: &HashSet, co: &Co, ) -> crate::api::Result<()> { - let mut blobs = store.blobs().list().stream().await?; + let mut blobs = store.blobs().list().stream(); let mut count = 0; let mut batch = Vec::new(); while let Some(hash) = blobs.next().await { diff --git a/src/store/fs/meta.rs b/src/store/fs/meta.rs index 21fbd9ed..1f60dbfd 100644 --- a/src/store/fs/meta.rs +++ b/src/store/fs/meta.rs @@ -15,7 +15,7 @@ use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; use redb::{Database, DatabaseError, ReadableTable}; use snafu::{Backtrace, ResultExt, Snafu}; -use tokio::pin; +use tokio::{pin, task::JoinSet}; use crate::{ api::{ @@ -23,8 +23,9 @@ use crate::{ blobs::BlobStatus, proto::{ BlobDeleteRequest, BlobStatusMsg, BlobStatusRequest, ClearProtectedMsg, - CreateTagRequest, DeleteBlobsMsg, DeleteTagsRequest, ListBlobsMsg, ListRequest, - ListTagsRequest, RenameTagRequest, SetTagRequest, ShutdownMsg, SyncDbMsg, + CreateTagRequest, DeleteBlobsMsg, DeleteTagsRequest, ListBlobsItem, ListBlobsMsg, + ListRequest, ListTagsItem, ListTagsRequest, RenameTagRequest, SetTagRequest, + ShutdownMsg, SyncDbMsg, }, tags::TagInfo, }, @@ -96,15 +97,6 @@ impl Db { Self { sender } } - pub async fn snapshot(&self, span: tracing::Span) -> io::Result { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.sender - .send(Snapshot { tx, span }.into()) - .await - .map_err(|_| io::Error::other("send snapshot"))?; - rx.await.map_err(|_| io::Error::other("receive snapshot")) - } - pub async fn update_await(&self, hash: Hash, state: EntryState) -> io::Result<()> { let (tx, rx) = oneshot::channel(); self.sender @@ -309,7 +301,6 @@ async fn handle_list_tags(msg: ListTagsMsg, tables: &impl ReadableTables) -> Act } = msg; let from = from.map(Bound::Included).unwrap_or(Bound::Unbounded); let to = to.map(Bound::Excluded).unwrap_or(Bound::Unbounded); - let mut res = Vec::new(); for item in tables.tags().range((from, to)).context(StorageSnafu)? { match item { Ok((k, v)) => { @@ -320,15 +311,20 @@ async fn handle_list_tags(msg: ListTagsMsg, tables: &impl ReadableTables) -> Act hash: v.hash, format: v.format, }; - res.push(crate::api::Result::Ok(info)); + if tx.send(ListTagsItem::Item(info)).await.is_err() { + return Ok(()); + } } } Err(e) => { - res.push(Err(crate::api::Error::other(e))); + tx.send(ListTagsItem::Error(crate::api::Error::other(e))) + .await + .ok(); + return Ok(()); } } } - tx.send(res).await.ok(); + tx.send(ListTagsItem::Done).await.ok(); Ok(()) } @@ -463,6 +459,7 @@ pub struct Actor { ds: DeleteHandle, options: BatchOptions, protected: HashSet, + tasks: JoinSet<()>, } impl Actor { @@ -492,6 +489,7 @@ impl Actor { ds, options, protected: Default::default(), + tasks: JoinSet::new(), }) } @@ -707,6 +705,7 @@ impl Actor { async fn handle_toplevel( db: &mut Database, + tasks: &mut JoinSet<()>, cmd: TopLevelCommand, op: TxnNum, ) -> ActorResult> { @@ -726,11 +725,11 @@ impl Actor { // nothing to do here, since the database will be dropped Some(cmd) } - TopLevelCommand::Snapshot(cmd) => { + TopLevelCommand::ListBlobs(cmd) => { trace!("{cmd:?}"); let txn = db.begin_read().context(TransactionSnafu)?; let snapshot = ReadOnlyTables::new(&txn).context(TableSnafu)?; - cmd.tx.send(snapshot).ok(); + tasks.spawn(list_blobs(snapshot, cmd)); None } }) @@ -741,14 +740,20 @@ impl Actor { let options = &self.options; let mut op = 0u64; let shutdown = loop { + let cmd = tokio::select! { + cmd = self.cmds.recv() => cmd, + _ = self.tasks.join_next(), if !self.tasks.is_empty() => continue, + }; op += 1; - let Some(cmd) = self.cmds.recv().await else { + let Some(cmd) = cmd else { break None; }; match cmd { Command::TopLevel(cmd) => { let op = TxnNum::TopLevel(op); - if let Some(shutdown) = Self::handle_toplevel(&mut db, cmd, op).await? { + if let Some(shutdown) = + Self::handle_toplevel(&mut db, &mut self.tasks, cmd, op).await? + { break Some(shutdown); } } @@ -887,7 +892,7 @@ pub async fn list_blobs(snapshot: ReadOnlyTables, cmd: ListBlobsMsg) { Ok(()) => {} Err(e) => { error!("error listing blobs: {}", e); - tx.send(Err(e)).await.ok(); + tx.send(ListBlobsItem::Error(e)).await.ok(); } } } @@ -895,12 +900,13 @@ pub async fn list_blobs(snapshot: ReadOnlyTables, cmd: ListBlobsMsg) { async fn list_blobs_impl( snapshot: ReadOnlyTables, _cmd: ListRequest, - tx: &mut mpsc::Sender>, + tx: &mut mpsc::Sender, ) -> api::Result<()> { for item in snapshot.blobs.iter().map_err(api::Error::other)? { let (k, _) = item.map_err(api::Error::other)?; let k = k.value(); - tx.send(Ok(k)).await.ok(); + tx.send(ListBlobsItem::Item(k)).await.ok(); } + tx.send(ListBlobsItem::Done).await.ok(); Ok(()) } diff --git a/src/store/fs/meta/proto.rs b/src/store/fs/meta/proto.rs index 6f4aaa6c..24f182c4 100644 --- a/src/store/fs/meta/proto.rs +++ b/src/store/fs/meta/proto.rs @@ -5,11 +5,11 @@ use bytes::Bytes; use nested_enum_utils::enum_conversions; use tracing::Span; -use super::{ActorResult, ReadOnlyTables}; +use super::ActorResult; use crate::{ api::proto::{ - BlobStatusMsg, ClearProtectedMsg, DeleteBlobsMsg, ProcessExitRequest, ShutdownMsg, - SyncDbMsg, + BlobStatusMsg, ClearProtectedMsg, DeleteBlobsMsg, ListBlobsMsg, ProcessExitRequest, + ShutdownMsg, SyncDbMsg, }, store::{fs::entry_state::EntryState, util::DD}, util::channel::oneshot, @@ -49,12 +49,6 @@ pub struct Dump { pub span: Span, } -#[derive(Debug)] -pub struct Snapshot { - pub(crate) tx: tokio::sync::oneshot::Sender, - pub span: Span, -} - pub struct Update { pub hash: Hash, pub state: EntryState, @@ -167,7 +161,7 @@ impl ReadWriteCommand { pub enum TopLevelCommand { SyncDb(SyncDbMsg), Shutdown(ShutdownMsg), - Snapshot(Snapshot), + ListBlobs(ListBlobsMsg), } impl TopLevelCommand { @@ -181,7 +175,7 @@ impl TopLevelCommand { match self { Self::SyncDb(x) => x.parent_span_opt(), Self::Shutdown(x) => x.parent_span_opt(), - Self::Snapshot(x) => Some(&x.span), + Self::ListBlobs(x) => Some(&x.span), } } } diff --git a/src/store/mem.rs b/src/store/mem.rs index 6d022e0f..90a7d990 100644 --- a/src/store/mem.rs +++ b/src/store/mem.rs @@ -61,6 +61,7 @@ use crate::{ HashAndFormat, IROH_BLOCK_SIZE, }, util::{ + irpc::MpscSenderExt, temp_tag::{TagDrop, TempTagScope, TempTags}, ChunkRangesExt, }, @@ -297,7 +298,7 @@ impl Actor { format: value.format, }) .map(Ok); - tx.send(tags.collect()).await.ok(); + tx.forward_iter(tags).await.ok(); } Command::SetTag(SetTagMsg { inner: SetTagRequest { name: tag, value }, @@ -323,17 +324,22 @@ impl Actor { Command::ListTempTags(cmd) => { trace!("{cmd:?}"); let tts = self.temp_tags.list(); - cmd.tx.send(tts).await.ok(); + cmd.tx.forward_iter(tts.into_iter().map(Ok)).await.ok(); } Command::ListBlobs(cmd) => { let ListBlobsMsg { tx, .. } = cmd; let blobs = self.state.data.keys().cloned().collect::>(); self.spawn(async move { for blob in blobs { - if tx.send(Ok(blob)).await.is_err() { + if tx + .send(api::proto::ListBlobsItem::Item(blob)) + .await + .is_err() + { break; } } + tx.send(api::proto::ListBlobsItem::Done).await.ok(); }); } Command::BlobStatus(cmd) => { diff --git a/src/store/readonly_mem.rs b/src/store/readonly_mem.rs index 42274b2e..fa302b50 100644 --- a/src/store/readonly_mem.rs +++ b/src/store/readonly_mem.rs @@ -36,13 +36,13 @@ use crate::{ proto::{ self, BlobStatus, Command, ExportBaoMsg, ExportBaoRequest, ExportPathMsg, ExportPathRequest, ExportRangesItem, ExportRangesMsg, ExportRangesRequest, - ImportBaoMsg, ImportByteStreamMsg, ImportBytesMsg, ImportPathMsg, ObserveMsg, - ObserveRequest, WaitIdleMsg, + ImportBaoMsg, ImportByteStreamMsg, ImportBytesMsg, ImportPathMsg, ListBlobsItem, + ObserveMsg, ObserveRequest, WaitIdleMsg, }, ApiClient, TempTag, }, store::{mem::CompleteStorage, IROH_BLOCK_SIZE}, - util::ChunkRangesExt, + util::{irpc::MpscSenderExt, ChunkRangesExt}, Hash, }; @@ -178,8 +178,9 @@ impl Actor { let hashes: Vec = self.data.keys().cloned().collect(); self.tasks.spawn(async move { for hash in hashes { - cmd.tx.send(Ok(hash)).await.ok(); + cmd.tx.send(ListBlobsItem::Item(hash)).await.ok(); } + cmd.tx.send(ListBlobsItem::Done).await.ok(); }); } Command::BlobStatus(cmd) => { @@ -195,7 +196,7 @@ impl Actor { cmd.tx.send(status).await.ok(); } Command::ListTags(cmd) => { - cmd.tx.send(Vec::new()).await.ok(); + cmd.tx.forward_iter(std::iter::empty()).await.ok(); } Command::SetTag(cmd) => { cmd.tx @@ -204,7 +205,7 @@ impl Actor { .ok(); } Command::ListTempTags(cmd) => { - cmd.tx.send(Vec::new()).await.ok(); + cmd.tx.forward_iter(std::iter::empty()).await.ok(); } Command::SyncDb(cmd) => { cmd.tx.send(Ok(())).await.ok(); diff --git a/src/util.rs b/src/util.rs index 7b9ad4e6..06fd60cd 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,6 +4,7 @@ use bao_tree::{io::round_up_to_chunks, ChunkNum, ChunkRanges}; use range_collections::{range_set::RangeSetEntry, RangeSet2}; pub mod channel; +pub mod irpc; pub(crate) mod temp_tag; pub mod serde { // Module that handles io::Error serialization/deserialization diff --git a/src/util/irpc.rs b/src/util/irpc.rs new file mode 100644 index 00000000..364c04c2 --- /dev/null +++ b/src/util/irpc.rs @@ -0,0 +1,177 @@ +use std::{future::Future, io}; + +use irpc::{ + channel::{mpsc, RecvError}, + RpcMessage, +}; +use n0_future::{stream, Stream, StreamExt}; + +/// Trait for an enum that has three variants, item, error, and done. +/// +/// This is very common for irpc stream items if you want to provide an explicit +/// end of stream marker to make sure unsuccessful termination is not mistaken +/// for successful end of stream. +pub(crate) trait IrpcStreamItem: RpcMessage { + /// The error case of the item enum. + type Error; + /// The item case of the item enum. + type Item; + /// Converts the stream item into either None for end of stream, or a Result + /// containing the item or an error. Error is assumed as a termination, so + /// if you get error you won't get an additional end of stream marker. + fn into_result_opt(self) -> Option>; + /// Converts a result into the item enum. + fn from_result(item: std::result::Result) -> Self; + /// Produces a done marker for the item enum. + fn done() -> Self; +} + +pub(crate) trait MpscSenderExt: Sized { + /// Forward a stream of items to the sender. + /// + /// This will convert items and errors into the item enum type, and add + /// a done marker if the stream ends without an error. + #[allow(dead_code)] + async fn forward_stream( + self, + stream: impl Stream>, + ) -> std::result::Result<(), irpc::channel::SendError>; + + /// Forward an iterator of items to the sender. + /// + /// This will convert items and errors into the item enum type, and add + /// a done marker if the iterator ends without an error. + async fn forward_iter( + self, + iter: impl Iterator>, + ) -> std::result::Result<(), irpc::channel::SendError>; +} + +impl MpscSenderExt for mpsc::Sender +where + T: IrpcStreamItem, +{ + async fn forward_stream( + self, + stream: impl Stream>, + ) -> std::result::Result<(), irpc::channel::SendError> { + tokio::pin!(stream); + while let Some(item) = stream.next().await { + let done = item.is_err(); + self.send(T::from_result(item)).await?; + if done { + return Ok(()); + }; + } + self.send(T::done()).await + } + + async fn forward_iter( + self, + iter: impl Iterator>, + ) -> std::result::Result<(), irpc::channel::SendError> { + for item in iter { + let done = item.is_err(); + self.send(T::from_result(item)).await?; + if done { + return Ok(()); + }; + } + self.send(T::done()).await + } +} + +pub(crate) trait IrpcReceiverFutExt { + /// Collects the receiver returned by this future into a collection, + /// provided that we get a receiver and draining the receiver does not + /// produce any error items. + /// + /// The collection must implement Default and Extend. + /// Note that using this with a very large stream might use a lot of memory. + async fn try_collect(self) -> std::result::Result + where + C: Default + Extend, + E: From, + E: From, + E: From; + + /// Converts the receiver returned by this future into a stream of items, + /// where each item is either a successful item or an error. + /// + /// There will be at most one error item, which will terminate the stream. + /// If the future returns an error, the stream will yield that error as the + /// first item and then terminate. + fn into_stream(self) -> impl Stream> + where + E: From, + E: From, + E: From; +} + +impl IrpcReceiverFutExt for F +where + T: IrpcStreamItem, + F: Future, irpc::Error>>, +{ + async fn try_collect(self) -> std::result::Result + where + C: Default + Extend, + E: From, + E: From, + E: From, + { + let mut items = C::default(); + let mut stream = self.into_stream::(); + while let Some(item) = stream.next().await { + match item { + Ok(i) => items.extend(Some(i)), + Err(e) => return Err(e), + } + } + Ok(items) + } + + fn into_stream(self) -> impl Stream> + where + E: From, + E: From, + E: From, + { + enum State { + Init(S), + Receiving(mpsc::Receiver), + Done, + } + fn eof() -> RecvError { + io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected end of stream").into() + } + async fn process_recv( + mut rx: mpsc::Receiver, + ) -> Option<(std::result::Result, State)> + where + T: IrpcStreamItem, + E: From, + E: From, + E: From, + { + match rx.recv().await { + Ok(Some(item)) => match item.into_result_opt()? { + Ok(i) => Some((Ok(i), State::Receiving(rx))), + Err(e) => Some((Err(E::from(e)), State::Done)), + }, + Ok(None) => Some((Err(E::from(eof())), State::Done)), + Err(e) => Some((Err(E::from(e)), State::Done)), + } + } + Box::pin(stream::unfold(State::Init(self), |state| async move { + match state { + State::Init(fut) => match fut.await { + Ok(rx) => process_recv(rx).await, + Err(e) => Some((Err(E::from(e)), State::Done)), + }, + State::Receiving(rx) => process_recv(rx).await, + State::Done => None, + } + })) + } +} diff --git a/tests/tags.rs b/tests/tags.rs index 3864bc54..1b3b57e9 100644 --- a/tests/tags.rs +++ b/tests/tags.rs @@ -5,21 +5,14 @@ use std::{ use iroh_blobs::{ api::{ - self, tags::{TagInfo, Tags}, Store, }, store::{fs::FsStore, mem::MemStore}, BlobFormat, Hash, HashAndFormat, }; -use n0_future::{Stream, StreamExt}; use testresult::TestResult; -async fn to_vec(stream: impl Stream>) -> api::Result> { - let res = stream.collect::>().await; - res.into_iter().collect::>>() -} - fn expected(tags: impl IntoIterator) -> Vec { tags.into_iter() .map(|tag| TagInfo::new(tag, Hash::new(tag))) @@ -35,50 +28,40 @@ async fn set(tags: &Tags, names: impl IntoIterator) -> TestResult<( async fn tags_smoke(tags: &Tags) -> TestResult<()> { set(tags, ["a", "b", "c", "d", "e"]).await?; - let stream = tags.list().await?; - let res = to_vec(stream).await?; + let res = tags.list().await?; assert_eq!(res, expected(["a", "b", "c", "d", "e"])); - let stream = tags.list_range("b".."d").await?; - let res = to_vec(stream).await?; + let res = tags.list_range("b".."d").await?; assert_eq!(res, expected(["b", "c"])); - let stream = tags.list_range("b"..).await?; - let res = to_vec(stream).await?; + let res = tags.list_range("b"..).await?; assert_eq!(res, expected(["b", "c", "d", "e"])); - let stream = tags.list_range(.."d").await?; - let res = to_vec(stream).await?; + let res = tags.list_range(.."d").await?; assert_eq!(res, expected(["a", "b", "c"])); - let stream = tags.list_range(..="d").await?; - let res = to_vec(stream).await?; + let res = tags.list_range(..="d").await?; assert_eq!(res, expected(["a", "b", "c", "d"])); tags.delete_range("b"..).await?; - let stream = tags.list().await?; - let res = to_vec(stream).await?; + let res = tags.list().await?; assert_eq!(res, expected(["a"])); tags.delete_range(..="a").await?; - let stream = tags.list().await?; - let res = to_vec(stream).await?; + let res = tags.list().await?; assert_eq!(res, expected([])); set(tags, ["a", "aa", "aaa", "aab", "b"]).await?; - let stream = tags.list_prefix("aa").await?; - let res = to_vec(stream).await?; + let res = tags.list_prefix("aa").await?; assert_eq!(res, expected(["aa", "aaa", "aab"])); tags.delete_prefix("aa").await?; - let stream = tags.list().await?; - let res = to_vec(stream).await?; + let res = tags.list().await?; assert_eq!(res, expected(["a", "b"])); tags.delete_prefix("").await?; - let stream = tags.list().await?; - let res = to_vec(stream).await?; + let res = tags.list().await?; assert_eq!(res, expected([])); set(tags, ["a", "b", "c"]).await?; @@ -89,8 +72,7 @@ async fn tags_smoke(tags: &Tags) -> TestResult<()> { ); tags.delete("b").await?; - let stream = tags.list().await?; - let res = to_vec(stream).await?; + let res = tags.list().await?; assert_eq!(res, expected(["a", "c"])); assert_eq!(tags.get("b").await?, None); @@ -100,8 +82,7 @@ async fn tags_smoke(tags: &Tags) -> TestResult<()> { tags.set("a", HashAndFormat::hash_seq(Hash::new("a"))) .await?; tags.set("b", HashAndFormat::raw(Hash::new("b"))).await?; - let stream = tags.list_hash_seq().await?; - let res = to_vec(stream).await?; + let res = tags.list_hash_seq().await?; assert_eq!( res, vec![TagInfo { @@ -114,8 +95,7 @@ async fn tags_smoke(tags: &Tags) -> TestResult<()> { tags.delete_all().await?; set(tags, ["c"]).await?; tags.rename("c", "f").await?; - let stream = tags.list().await?; - let res = to_vec(stream).await?; + let res = tags.list().await?; assert_eq!( res, vec![TagInfo {