Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ pub use futures_tuple_set::FuturesTupleSet;
pub use stream_map::StreamMap;
pub use stream_set::StreamSet;

use std::any::Any;
use std::fmt;
use std::fmt::Formatter;
use std::pin::Pin;
use std::time::Duration;

/// A future failed to complete within the given timeout.
Expand All @@ -35,7 +37,7 @@ impl fmt::Display for Timeout {
}

/// Error of a future pushing
#[derive(PartialEq, Debug)]
#[derive(PartialEq)]
pub enum PushError<T> {
/// The length of the set is equal to the capacity
BeyondCapacity(T),
Expand All @@ -45,4 +47,20 @@ pub enum PushError<T> {
Replaced(T),
}

impl<T> fmt::Debug for PushError<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::BeyondCapacity(_) => f.debug_tuple("BeyondCapacity").finish(),
Self::Replaced(_) => f.debug_tuple("Replaced").finish(),
}
}
}

impl std::error::Error for Timeout {}

#[doc(hidden)]
pub trait AnyStream: futures_util::Stream + Any + Unpin + Send {}

impl<T> AnyStream for T where T: futures_util::Stream + Any + Unpin + Send {}

type BoxStream<T> = Pin<Box<dyn AnyStream<Item = T> + Send>>;
71 changes: 63 additions & 8 deletions src/stream_map.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use std::any::Any;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::time::Duration;

use futures_util::stream::{BoxStream, SelectAll};
use futures_util::stream::SelectAll;
use futures_util::{stream, FutureExt, Stream, StreamExt};

use crate::{Delay, PushError, Timeout};
use crate::{AnyStream, BoxStream, Delay, PushError, Timeout};

/// Represents a map of [`Stream`]s.
///
/// Each stream must finish within the specified time and the map never outgrows its capacity.
pub struct StreamMap<ID, O> {
make_delay: Box<dyn Fn() -> Delay + Send + Sync>,
capacity: usize,
inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<'static, O>>>>,
inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<O>>>>,
empty_waker: Option<Waker>,
full_waker: Option<Waker>,
}
Expand Down Expand Up @@ -42,10 +43,10 @@ where
/// Push a stream into the map.
pub fn try_push<F>(&mut self, id: ID, stream: F) -> Result<(), PushError<BoxStream<O>>>
where
F: Stream<Item = O> + Send + 'static,
F: AnyStream<Item = O>,
{
if self.inner.len() >= self.capacity {
return Err(PushError::BeyondCapacity(stream.boxed()));
return Err(PushError::BeyondCapacity(Box::pin(stream)));
}

if let Some(waker) = self.empty_waker.take() {
Expand All @@ -56,7 +57,7 @@ where
self.inner.push(TaggedStream::new(
id,
TimeoutStream {
inner: stream.boxed(),
inner: Box::pin(stream),
timeout: (self.make_delay)(),
},
));
Expand All @@ -67,10 +68,10 @@ where
}
}

pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
pub fn remove(&mut self, id: ID) -> Option<BoxStream<O>> {
let tagged = self.inner.iter_mut().find(|s| s.key == id)?;

let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
let inner = mem::replace(&mut tagged.inner.inner, Box::pin(stream::pending()));
tagged.exhausted = true; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources.

Some(inner)
Expand Down Expand Up @@ -113,6 +114,37 @@ where
Some((id, None)) => Poll::Ready((id, None)),
}
}

/// Returns an iterator over all streams of type `T` pushed via [`StreamMap::try_push`].
///
/// If downcasting a stream to `T` fails it will be skipped in the iterator.
pub fn iter_of_type<T>(&self) -> impl Iterator<Item = (&ID, &T)>
where
T: 'static,
{
self.inner.iter().filter_map(|a| {
let pin = a.inner.inner.as_ref();
let any = Pin::into_inner(pin) as &(dyn Any + Send);
let inner = any.downcast_ref::<T>()?;
Some((&a.key, inner))
})
}

/// Returns an iterator with mutable access over all streams of type `T`
/// pushed via [`StreamMap::try_push`].
///
/// If downcasting a stream to `T` fails it will be skipped in the iterator.
pub fn iter_mut_of_type<T>(&mut self) -> impl Iterator<Item = (&mut ID, &mut T)>
where
T: 'static,
{
self.inner.iter_mut().filter_map(|a| {
let pin = a.inner.inner.as_mut();
let any = Pin::into_inner(pin) as &mut (dyn Any + Send);
let inner = any.downcast_mut::<T>()?;
Some((&mut a.key, inner))
})
}
}

struct TimeoutStream<S> {
Expand Down Expand Up @@ -304,6 +336,29 @@ mod tests {
assert!(duration >= DELAY * NUM_STREAMS);
}

#[test]
fn can_iter_named_streams() {
const N: usize = 10;
let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), N);
let mut sender = Vec::with_capacity(N);
for i in 0..N {
let (tx, rx) = mpsc::channel::<()>(1);
streams.try_push(format!("ID{i}"), rx).unwrap();
sender.push(tx);
}
assert_eq!(streams.iter_of_type::<mpsc::Receiver<()>>().count(), N);
for (i, (id, _)) in streams.iter_of_type::<mpsc::Receiver<()>>().enumerate() {
let expect_id = format!("ID{}", N - i - 1); // Reverse order.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh? Why are we handing them out in reverse order?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why, but that's how the inner SelectAll::iter returns the the elements. Probably an implementation detail of the SelectAll data structure?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. It is a Vec internally. Should we sort them by ID?

Copy link
Contributor Author

@elenaf9 elenaf9 May 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we can sort them in the SelectAll, can we? Or where/ how should we sort them?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could sort them as part of the Iterator implementation. itertools has a combinator for that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be better left to the caller: callers who want a sorted iterator can use the combinator, and those who don't care don't have to pay the performance cost of sorting/comparisons (and extra temporary memory).

Documenting that the order is unspecified should be enough to inform people of the need for sorting.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documenting that the order is unspecified should be enough to inform people of the need for sorting.

Makes sense to me! Wanna send a PR?

assert_eq!(id, &expect_id);
}
assert!(!sender.iter().any(|tx| tx.is_closed()));

for (_, rx) in streams.iter_mut_of_type::<mpsc::Receiver<()>>() {
rx.close();
}
assert!(sender.iter().all(|tx| tx.is_closed()));
}

struct Task {
item_delay: Duration,
num_streams: usize,
Expand Down
27 changes: 23 additions & 4 deletions src/stream_set.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use futures_util::stream::BoxStream;
use futures_util::Stream;
use std::task::{ready, Context, Poll};

use crate::{Delay, PushError, StreamMap, Timeout};
use crate::{AnyStream, BoxStream, Delay, PushError, StreamMap, Timeout};

/// Represents a set of [Stream]s.
///
Expand Down Expand Up @@ -32,7 +30,7 @@ where
/// In that case, the stream is not added to the set.
pub fn try_push<F>(&mut self, stream: F) -> Result<(), BoxStream<O>>
where
F: Stream<Item = O> + Send + 'static,
F: AnyStream<Item = O>,
{
self.id = self.id.wrapping_add(1);

Expand Down Expand Up @@ -60,4 +58,25 @@ where

Poll::Ready(res)
}

/// Returns an iterator over all streams of type `T` pushed via [`StreamSet::try_push`].
///
/// If downcasting a stream to `T` fails it will be skipped in the iterator.
pub fn iter_of_type<T>(&self) -> impl Iterator<Item = &T>
where
T: 'static,
{
self.inner.iter_of_type().map(|(_, item)| item)
}

/// Returns an iterator with mutable access over all streams of type `T`
/// pushed via [`StreamSet::try_push`].
///
/// If downcasting a stream to `T` fails it will be skipped in the iterator.
pub fn iter_mut_of_type<T>(&mut self) -> impl Iterator<Item = &mut T>
where
T: 'static,
{
self.inner.iter_mut_of_type().map(|(_, item)| item)
}
}