Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use delay::Delay;
pub use futures_map::FuturesMap;
pub use futures_set::FuturesSet;
pub use futures_tuple_set::FuturesTupleSet;
pub use stream_map::StreamMap;
pub use stream_map::{StreamMap, StreamMapIterable};
pub use stream_set::StreamSet;

use std::fmt;
Expand Down
175 changes: 145 additions & 30 deletions src/stream_map.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,77 @@
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::time::Duration;

use futures_util::stream::{BoxStream, SelectAll};
use futures_util::{stream, FutureExt, Stream, StreamExt};
use futures_util::stream::{select_all, BoxStream, SelectAll};
use futures_util::{FutureExt, Stream, StreamExt};

use crate::{Delay, PushError, Timeout};

/// Represents a map of [`Stream`]s.
///
/// Each stream must finish within the specified time and the map never outgrows its capacity.
pub struct StreamMap<ID, O> {
pub struct StreamMap<ID, O>(StreamMapIterable<ID, BoxStream<'static, O>>);

impl<ID, O> StreamMap<ID, O>
where
ID: Clone + Unpin,
{
pub fn new(make_delay: impl Fn() -> Delay + Send + Sync + 'static, capacity: usize) -> Self {
Self(StreamMapIterable::new(make_delay, capacity))
}
}

impl<ID, O> StreamMap<ID, O>
where
ID: Clone + PartialEq + Send + Unpin + 'static,
O: Send + 'static,
{
/// Push a stream into the map.
pub fn try_push<F>(&mut self, id: ID, stream: F) -> Result<(), PushError<BoxStream<O>>>
where
F: Stream<Item = O> + Send + 'static,
{
self.0.try_push(id, stream.boxed())
}

pub fn remove(&mut self, id: ID) -> Option<BoxStream<O>> {
self.0.remove(id)
}

pub fn len(&self) -> usize {
self.0.len()
}

pub fn is_empty(&self) -> bool {
self.0.is_empty()
}

#[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic.
pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
self.0.poll_ready_unpin(cx)
}

pub fn poll_next_unpin(
&mut self,
cx: &mut Context<'_>,
) -> Poll<(ID, Option<Result<O, Timeout>>)> {
self.0.poll_next_unpin(cx)
}
}

/// Iterable variant of [`StreamMap`] without boxed streams.
pub struct StreamMapIterable<ID, F> {
make_delay: Box<dyn Fn() -> Delay + Send + Sync>,
capacity: usize,
inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<'static, O>>>>,
inner: SelectAll<TaggedStream<ID, TimeoutStream<F>>>,
empty_waker: Option<Waker>,
full_waker: Option<Waker>,
}

impl<ID, O> StreamMap<ID, O>
impl<ID, F> StreamMapIterable<ID, F>
where
ID: Clone + Unpin,
F: Stream + Unpin,
{
pub fn new(make_delay: impl Fn() -> Delay + Send + Sync + 'static, capacity: usize) -> Self {
Self {
Expand All @@ -34,18 +84,18 @@ where
}
}

impl<ID, O> StreamMap<ID, O>
impl<ID, F> StreamMapIterable<ID, F>
where
ID: Clone + PartialEq + Send + Unpin + 'static,
O: Send + 'static,
F: Stream + Unpin,
{
/// Push a stream into the map.
pub fn try_push<F>(&mut self, id: ID, stream: F) -> Result<(), PushError<BoxStream<O>>>
pub fn try_push(&mut self, id: ID, stream: F) -> Result<(), PushError<F>>
where
F: Stream<Item = O> + Send + 'static,
F: Stream + Send + 'static,
{
if self.inner.len() >= self.capacity {
return Err(PushError::BeyondCapacity(stream.boxed()));
return Err(PushError::BeyondCapacity(stream));
}

if let Some(waker) = self.empty_waker.take() {
Expand All @@ -56,7 +106,7 @@ where
self.inner.push(TaggedStream::new(
id,
TimeoutStream {
inner: stream.boxed(),
inner: stream,
timeout: (self.make_delay)(),
},
));
Expand All @@ -67,13 +117,10 @@ where
}
}

pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
pub fn remove(&mut self, id: ID) -> Option<F> {
let tagged = self.inner.iter_mut().find(|s| s.key == id)?;

let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
tagged.exhausted = true; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources.

Some(inner)
let inner = tagged.inner.take()?; // `TaggedStream` will emit `None` on the next poll and ensure `SelectAll` cleans up the resources.
Some(inner.inner)
}

pub fn len(&self) -> usize {
Expand All @@ -98,7 +145,7 @@ where
pub fn poll_next_unpin(
&mut self,
cx: &mut Context<'_>,
) -> Poll<(ID, Option<Result<O, Timeout>>)> {
) -> Poll<(ID, Option<Result<F::Item, Timeout>>)> {
match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
None => {
self.empty_waker = Some(cx.waker().clone());
Expand All @@ -113,6 +160,14 @@ where
Some((id, None)) => Poll::Ready((id, None)),
}
}

pub fn iter(&self) -> Iter<ID, F> {
Iter(self.inner.iter())
}

pub fn iter_mut(&mut self) -> IterMut<ID, F> {
IterMut(self.inner.iter_mut())
}
}

struct TimeoutStream<S> {
Expand All @@ -137,17 +192,14 @@ where

struct TaggedStream<K, S> {
key: K,
inner: S,

exhausted: bool,
inner: Option<S>,
}

impl<K, S> TaggedStream<K, S> {
fn new(key: K, inner: S) -> Self {
Self {
key,
inner,
exhausted: false,
inner: Some(inner),
}
}
}
Expand All @@ -160,21 +212,60 @@ where
type Item = (K, Option<S::Item>);

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.exhausted {
let Some(inner) = self.inner.as_mut() else {
return Poll::Ready(None);
}
};

match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
match futures_util::ready!(inner.poll_next_unpin(cx)) {
Some(item) => Poll::Ready(Some((self.key.clone(), Some(item)))),
None => {
self.exhausted = true;

self.inner.take();
Poll::Ready(Some((self.key.clone(), None)))
}
}
}
}

pub struct Iter<'a, ID: Unpin, F: Unpin>(select_all::Iter<'a, TaggedStream<ID, TimeoutStream<F>>>);

impl<'a, ID, F> Iterator for Iter<'a, ID, F>
where
ID: Clone + Unpin,
F: Unpin + Stream,
{
type Item = (ID, &'a F);

fn next(&mut self) -> Option<Self::Item> {
let next = self.0.next()?;
Some((next.key.clone(), &next.inner.as_ref()?.inner))
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
}

pub struct IterMut<'a, ID: Unpin, F: Unpin>(
select_all::IterMut<'a, TaggedStream<ID, TimeoutStream<F>>>,
);

impl<'a, ID, F> Iterator for IterMut<'a, ID, F>
where
ID: Clone + Unpin,
F: Unpin + Stream,
{
type Item = (ID, &'a mut F);

fn next(&mut self) -> Option<Self::Item> {
let next = self.0.next()?;
Some((next.key.clone(), &mut next.inner.as_mut()?.inner))
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
}

#[cfg(all(test, feature = "futures-timer"))]
mod tests {
use futures::channel::mpsc;
Expand Down Expand Up @@ -237,7 +328,7 @@ mod tests {
fn removing_stream() {
let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

let _ = streams.try_push("ID", stream::once(ready(())));
let _ = streams.try_push("ID", once(ready(())));

{
let cancelled_stream = streams.remove("ID");
Expand All @@ -256,6 +347,30 @@ mod tests {
);
}

#[test]
fn iterating_streams() {
const N: usize = 10;
let mut streams =
StreamMapIterable::new(|| Delay::futures_timer(Duration::from_millis(100)), N);
let mut sender = Vec::with_capacity(N);
for i in 0..N {
let (tx, rx) = mpsc::channel::<()>(1);
let _ = streams.try_push(i, rx);
sender.push(tx);
}
assert_eq!(streams.iter().count(), N);
for (i, (id, _)) in streams.iter().enumerate() {
let expect_id = N - i - 1; // Reverse order.
assert_eq!(id, expect_id);
}
assert!(!sender.iter().any(|tx| tx.is_closed()));

for (_, rx) in streams.iter_mut() {
rx.close();
}
assert!(sender.iter().all(|tx| tx.is_closed()));
}

#[tokio::test]
async fn replaced_stream_is_still_registered() {
let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 3);
Expand Down
Loading