Skip to content

Commit b524675

Browse files
feat: add typed iterators for StreamMap and StreamSet (#10)
Co-authored-by: Thomas Eizinger <[email protected]>
1 parent d0fc40e commit b524675

File tree

3 files changed

+105
-13
lines changed

3 files changed

+105
-13
lines changed

src/lib.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ pub use futures_tuple_set::FuturesTupleSet;
1212
pub use stream_map::StreamMap;
1313
pub use stream_set::StreamSet;
1414

15+
use std::any::Any;
1516
use std::fmt;
1617
use std::fmt::Formatter;
18+
use std::pin::Pin;
1719
use std::time::Duration;
1820

1921
/// A future failed to complete within the given timeout.
@@ -35,7 +37,7 @@ impl fmt::Display for Timeout {
3537
}
3638

3739
/// Error of a future pushing
38-
#[derive(PartialEq, Debug)]
40+
#[derive(PartialEq)]
3941
pub enum PushError<T> {
4042
/// The length of the set is equal to the capacity
4143
BeyondCapacity(T),
@@ -45,4 +47,20 @@ pub enum PushError<T> {
4547
Replaced(T),
4648
}
4749

50+
impl<T> fmt::Debug for PushError<T> {
51+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
52+
match self {
53+
Self::BeyondCapacity(_) => f.debug_tuple("BeyondCapacity").finish(),
54+
Self::Replaced(_) => f.debug_tuple("Replaced").finish(),
55+
}
56+
}
57+
}
58+
4859
impl std::error::Error for Timeout {}
60+
61+
#[doc(hidden)]
62+
pub trait AnyStream: futures_util::Stream + Any + Unpin + Send {}
63+
64+
impl<T> AnyStream for T where T: futures_util::Stream + Any + Unpin + Send {}
65+
66+
type BoxStream<T> = Pin<Box<dyn AnyStream<Item = T> + Send>>;

src/stream_map.rs

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1+
use std::any::Any;
12
use std::mem;
23
use std::pin::Pin;
34
use std::task::{Context, Poll, Waker};
45
use std::time::Duration;
56

6-
use futures_util::stream::{BoxStream, SelectAll};
7+
use futures_util::stream::SelectAll;
78
use futures_util::{stream, FutureExt, Stream, StreamExt};
89

9-
use crate::{Delay, PushError, Timeout};
10+
use crate::{AnyStream, BoxStream, Delay, PushError, Timeout};
1011

1112
/// Represents a map of [`Stream`]s.
1213
///
1314
/// Each stream must finish within the specified time and the map never outgrows its capacity.
1415
pub struct StreamMap<ID, O> {
1516
make_delay: Box<dyn Fn() -> Delay + Send + Sync>,
1617
capacity: usize,
17-
inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<'static, O>>>>,
18+
inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<O>>>>,
1819
empty_waker: Option<Waker>,
1920
full_waker: Option<Waker>,
2021
}
@@ -42,10 +43,10 @@ where
4243
/// Push a stream into the map.
4344
pub fn try_push<F>(&mut self, id: ID, stream: F) -> Result<(), PushError<BoxStream<O>>>
4445
where
45-
F: Stream<Item = O> + Send + 'static,
46+
F: AnyStream<Item = O>,
4647
{
4748
if self.inner.len() >= self.capacity {
48-
return Err(PushError::BeyondCapacity(stream.boxed()));
49+
return Err(PushError::BeyondCapacity(Box::pin(stream)));
4950
}
5051

5152
if let Some(waker) = self.empty_waker.take() {
@@ -56,7 +57,7 @@ where
5657
self.inner.push(TaggedStream::new(
5758
id,
5859
TimeoutStream {
59-
inner: stream.boxed(),
60+
inner: Box::pin(stream),
6061
timeout: (self.make_delay)(),
6162
},
6263
));
@@ -67,10 +68,10 @@ where
6768
}
6869
}
6970

70-
pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
71+
pub fn remove(&mut self, id: ID) -> Option<BoxStream<O>> {
7172
let tagged = self.inner.iter_mut().find(|s| s.key == id)?;
7273

73-
let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
74+
let inner = mem::replace(&mut tagged.inner.inner, Box::pin(stream::pending()));
7475
tagged.exhausted = true; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources.
7576

7677
Some(inner)
@@ -113,6 +114,37 @@ where
113114
Some((id, None)) => Poll::Ready((id, None)),
114115
}
115116
}
117+
118+
/// Returns an iterator over all streams of type `T` pushed via [`StreamMap::try_push`].
119+
///
120+
/// If downcasting a stream to `T` fails it will be skipped in the iterator.
121+
pub fn iter_of_type<T>(&self) -> impl Iterator<Item = (&ID, &T)>
122+
where
123+
T: 'static,
124+
{
125+
self.inner.iter().filter_map(|a| {
126+
let pin = a.inner.inner.as_ref();
127+
let any = Pin::into_inner(pin) as &(dyn Any + Send);
128+
let inner = any.downcast_ref::<T>()?;
129+
Some((&a.key, inner))
130+
})
131+
}
132+
133+
/// Returns an iterator with mutable access over all streams of type `T`
134+
/// pushed via [`StreamMap::try_push`].
135+
///
136+
/// If downcasting a stream to `T` fails it will be skipped in the iterator.
137+
pub fn iter_mut_of_type<T>(&mut self) -> impl Iterator<Item = (&mut ID, &mut T)>
138+
where
139+
T: 'static,
140+
{
141+
self.inner.iter_mut().filter_map(|a| {
142+
let pin = a.inner.inner.as_mut();
143+
let any = Pin::into_inner(pin) as &mut (dyn Any + Send);
144+
let inner = any.downcast_mut::<T>()?;
145+
Some((&mut a.key, inner))
146+
})
147+
}
116148
}
117149

118150
struct TimeoutStream<S> {
@@ -304,6 +336,29 @@ mod tests {
304336
assert!(duration >= DELAY * NUM_STREAMS);
305337
}
306338

339+
#[test]
340+
fn can_iter_named_streams() {
341+
const N: usize = 10;
342+
let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), N);
343+
let mut sender = Vec::with_capacity(N);
344+
for i in 0..N {
345+
let (tx, rx) = mpsc::channel::<()>(1);
346+
streams.try_push(format!("ID{i}"), rx).unwrap();
347+
sender.push(tx);
348+
}
349+
assert_eq!(streams.iter_of_type::<mpsc::Receiver<()>>().count(), N);
350+
for (i, (id, _)) in streams.iter_of_type::<mpsc::Receiver<()>>().enumerate() {
351+
let expect_id = format!("ID{}", N - i - 1); // Reverse order.
352+
assert_eq!(id, &expect_id);
353+
}
354+
assert!(!sender.iter().any(|tx| tx.is_closed()));
355+
356+
for (_, rx) in streams.iter_mut_of_type::<mpsc::Receiver<()>>() {
357+
rx.close();
358+
}
359+
assert!(sender.iter().all(|tx| tx.is_closed()));
360+
}
361+
307362
struct Task {
308363
item_delay: Duration,
309364
num_streams: usize,

src/stream_set.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
use futures_util::stream::BoxStream;
2-
use futures_util::Stream;
31
use std::task::{ready, Context, Poll};
42

5-
use crate::{Delay, PushError, StreamMap, Timeout};
3+
use crate::{AnyStream, BoxStream, Delay, PushError, StreamMap, Timeout};
64

75
/// Represents a set of [Stream]s.
86
///
@@ -32,7 +30,7 @@ where
3230
/// In that case, the stream is not added to the set.
3331
pub fn try_push<F>(&mut self, stream: F) -> Result<(), BoxStream<O>>
3432
where
35-
F: Stream<Item = O> + Send + 'static,
33+
F: AnyStream<Item = O>,
3634
{
3735
self.id = self.id.wrapping_add(1);
3836

@@ -60,4 +58,25 @@ where
6058

6159
Poll::Ready(res)
6260
}
61+
62+
/// Returns an iterator over all streams of type `T` pushed via [`StreamSet::try_push`].
63+
///
64+
/// If downcasting a stream to `T` fails it will be skipped in the iterator.
65+
pub fn iter_of_type<T>(&self) -> impl Iterator<Item = &T>
66+
where
67+
T: 'static,
68+
{
69+
self.inner.iter_of_type().map(|(_, item)| item)
70+
}
71+
72+
/// Returns an iterator with mutable access over all streams of type `T`
73+
/// pushed via [`StreamSet::try_push`].
74+
///
75+
/// If downcasting a stream to `T` fails it will be skipped in the iterator.
76+
pub fn iter_mut_of_type<T>(&mut self) -> impl Iterator<Item = &mut T>
77+
where
78+
T: 'static,
79+
{
80+
self.inner.iter_mut_of_type().map(|(_, item)| item)
81+
}
6382
}

0 commit comments

Comments
 (0)