diff --git a/src/lib.rs b/src/lib.rs index 235cb54..1606d25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,8 +12,10 @@ pub use futures_tuple_set::FuturesTupleSet; pub use stream_map::StreamMap; pub use stream_set::StreamSet; +use std::any::Any; use std::fmt; use std::fmt::Formatter; +use std::pin::Pin; use std::time::Duration; /// A future failed to complete within the given timeout. @@ -35,7 +37,7 @@ impl fmt::Display for Timeout { } /// Error of a future pushing -#[derive(PartialEq, Debug)] +#[derive(PartialEq)] pub enum PushError { /// The length of the set is equal to the capacity BeyondCapacity(T), @@ -45,4 +47,20 @@ pub enum PushError { Replaced(T), } +impl fmt::Debug for PushError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::BeyondCapacity(_) => f.debug_tuple("BeyondCapacity").finish(), + Self::Replaced(_) => f.debug_tuple("Replaced").finish(), + } + } +} + impl std::error::Error for Timeout {} + +#[doc(hidden)] +pub trait AnyStream: futures_util::Stream + Any + Unpin + Send {} + +impl AnyStream for T where T: futures_util::Stream + Any + Unpin + Send {} + +type BoxStream = Pin + Send>>; diff --git a/src/stream_map.rs b/src/stream_map.rs index 75fc4b9..c0edaf8 100644 --- a/src/stream_map.rs +++ b/src/stream_map.rs @@ -1,12 +1,13 @@ +use std::any::Any; use std::mem; use std::pin::Pin; use std::task::{Context, Poll, Waker}; use std::time::Duration; -use futures_util::stream::{BoxStream, SelectAll}; +use futures_util::stream::SelectAll; use futures_util::{stream, FutureExt, Stream, StreamExt}; -use crate::{Delay, PushError, Timeout}; +use crate::{AnyStream, BoxStream, Delay, PushError, Timeout}; /// Represents a map of [`Stream`]s. /// @@ -14,7 +15,7 @@ use crate::{Delay, PushError, Timeout}; pub struct StreamMap { make_delay: Box Delay + Send + Sync>, capacity: usize, - inner: SelectAll>>>, + inner: SelectAll>>>, empty_waker: Option, full_waker: Option, } @@ -42,10 +43,10 @@ where /// Push a stream into the map. pub fn try_push(&mut self, id: ID, stream: F) -> Result<(), PushError>> where - F: Stream + Send + 'static, + F: AnyStream, { if self.inner.len() >= self.capacity { - return Err(PushError::BeyondCapacity(stream.boxed())); + return Err(PushError::BeyondCapacity(Box::pin(stream))); } if let Some(waker) = self.empty_waker.take() { @@ -56,7 +57,7 @@ where self.inner.push(TaggedStream::new( id, TimeoutStream { - inner: stream.boxed(), + inner: Box::pin(stream), timeout: (self.make_delay)(), }, )); @@ -67,10 +68,10 @@ where } } - pub fn remove(&mut self, id: ID) -> Option> { + pub fn remove(&mut self, id: ID) -> Option> { let tagged = self.inner.iter_mut().find(|s| s.key == id)?; - let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed()); + let inner = mem::replace(&mut tagged.inner.inner, Box::pin(stream::pending())); tagged.exhausted = true; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources. Some(inner) @@ -113,6 +114,37 @@ where Some((id, None)) => Poll::Ready((id, None)), } } + + /// Returns an iterator over all streams of type `T` pushed via [`StreamMap::try_push`]. + /// + /// If downcasting a stream to `T` fails it will be skipped in the iterator. + pub fn iter_of_type(&self) -> impl Iterator + where + T: 'static, + { + self.inner.iter().filter_map(|a| { + let pin = a.inner.inner.as_ref(); + let any = Pin::into_inner(pin) as &(dyn Any + Send); + let inner = any.downcast_ref::()?; + Some((&a.key, inner)) + }) + } + + /// Returns an iterator with mutable access over all streams of type `T` + /// pushed via [`StreamMap::try_push`]. + /// + /// If downcasting a stream to `T` fails it will be skipped in the iterator. + pub fn iter_mut_of_type(&mut self) -> impl Iterator + where + T: 'static, + { + self.inner.iter_mut().filter_map(|a| { + let pin = a.inner.inner.as_mut(); + let any = Pin::into_inner(pin) as &mut (dyn Any + Send); + let inner = any.downcast_mut::()?; + Some((&mut a.key, inner)) + }) + } } struct TimeoutStream { @@ -304,6 +336,29 @@ mod tests { assert!(duration >= DELAY * NUM_STREAMS); } + #[test] + fn can_iter_named_streams() { + const N: usize = 10; + let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), N); + let mut sender = Vec::with_capacity(N); + for i in 0..N { + let (tx, rx) = mpsc::channel::<()>(1); + streams.try_push(format!("ID{i}"), rx).unwrap(); + sender.push(tx); + } + assert_eq!(streams.iter_of_type::>().count(), N); + for (i, (id, _)) in streams.iter_of_type::>().enumerate() { + let expect_id = format!("ID{}", N - i - 1); // Reverse order. + assert_eq!(id, &expect_id); + } + assert!(!sender.iter().any(|tx| tx.is_closed())); + + for (_, rx) in streams.iter_mut_of_type::>() { + rx.close(); + } + assert!(sender.iter().all(|tx| tx.is_closed())); + } + struct Task { item_delay: Duration, num_streams: usize, diff --git a/src/stream_set.rs b/src/stream_set.rs index c4ddd8a..815621e 100644 --- a/src/stream_set.rs +++ b/src/stream_set.rs @@ -1,8 +1,6 @@ -use futures_util::stream::BoxStream; -use futures_util::Stream; use std::task::{ready, Context, Poll}; -use crate::{Delay, PushError, StreamMap, Timeout}; +use crate::{AnyStream, BoxStream, Delay, PushError, StreamMap, Timeout}; /// Represents a set of [Stream]s. /// @@ -32,7 +30,7 @@ where /// In that case, the stream is not added to the set. pub fn try_push(&mut self, stream: F) -> Result<(), BoxStream> where - F: Stream + Send + 'static, + F: AnyStream, { self.id = self.id.wrapping_add(1); @@ -60,4 +58,25 @@ where Poll::Ready(res) } + + /// Returns an iterator over all streams of type `T` pushed via [`StreamSet::try_push`]. + /// + /// If downcasting a stream to `T` fails it will be skipped in the iterator. + pub fn iter_of_type(&self) -> impl Iterator + where + T: 'static, + { + self.inner.iter_of_type().map(|(_, item)| item) + } + + /// Returns an iterator with mutable access over all streams of type `T` + /// pushed via [`StreamSet::try_push`]. + /// + /// If downcasting a stream to `T` fails it will be skipped in the iterator. + pub fn iter_mut_of_type(&mut self) -> impl Iterator + where + T: 'static, + { + self.inner.iter_mut_of_type().map(|(_, item)| item) + } }