1+ use std:: any:: Any ;
12use std:: mem;
23use std:: pin:: Pin ;
34use std:: task:: { Context , Poll , Waker } ;
45use std:: time:: Duration ;
56
6- use futures_util:: stream:: { BoxStream , SelectAll } ;
7+ use futures_util:: stream:: SelectAll ;
78use futures_util:: { stream, FutureExt , Stream , StreamExt } ;
89
9- use crate :: { Delay , PushError , Timeout } ;
10+ use crate :: { AnyStream , BoxStream , Delay , PushError , Timeout } ;
1011
1112/// Represents a map of [`Stream`]s.
1213///
1314/// Each stream must finish within the specified time and the map never outgrows its capacity.
1415pub struct StreamMap < ID , O > {
1516 make_delay : Box < dyn Fn ( ) -> Delay + Send + Sync > ,
1617 capacity : usize ,
17- inner : SelectAll < TaggedStream < ID , TimeoutStream < BoxStream < ' static , O > > > > ,
18+ inner : SelectAll < TaggedStream < ID , TimeoutStream < BoxStream < O > > > > ,
1819 empty_waker : Option < Waker > ,
1920 full_waker : Option < Waker > ,
2021}
@@ -42,10 +43,10 @@ where
4243 /// Push a stream into the map.
4344 pub fn try_push < F > ( & mut self , id : ID , stream : F ) -> Result < ( ) , PushError < BoxStream < O > > >
4445 where
45- F : Stream < Item = O > + Send + ' static ,
46+ F : AnyStream < Item = O > ,
4647 {
4748 if self . inner . len ( ) >= self . capacity {
48- return Err ( PushError :: BeyondCapacity ( stream . boxed ( ) ) ) ;
49+ return Err ( PushError :: BeyondCapacity ( Box :: pin ( stream ) ) ) ;
4950 }
5051
5152 if let Some ( waker) = self . empty_waker . take ( ) {
5657 self . inner . push ( TaggedStream :: new (
5758 id,
5859 TimeoutStream {
59- inner : stream . boxed ( ) ,
60+ inner : Box :: pin ( stream ) ,
6061 timeout : ( self . make_delay ) ( ) ,
6162 } ,
6263 ) ) ;
@@ -67,10 +68,10 @@ where
6768 }
6869 }
6970
70- pub fn remove ( & mut self , id : ID ) -> Option < BoxStream < ' static , O > > {
71+ pub fn remove ( & mut self , id : ID ) -> Option < BoxStream < O > > {
7172 let tagged = self . inner . iter_mut ( ) . find ( |s| s. key == id) ?;
7273
73- let inner = mem:: replace ( & mut tagged. inner . inner , stream:: pending ( ) . boxed ( ) ) ;
74+ let inner = mem:: replace ( & mut tagged. inner . inner , Box :: pin ( stream:: pending ( ) ) ) ;
7475 tagged. exhausted = true ; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources.
7576
7677 Some ( inner)
@@ -113,6 +114,37 @@ where
113114 Some ( ( id, None ) ) => Poll :: Ready ( ( id, None ) ) ,
114115 }
115116 }
117+
118+ /// Returns an iterator over all streams of type `T` pushed via [`StreamMap::try_push`].
119+ ///
120+ /// If downcasting a stream to `T` fails it will be skipped in the iterator.
121+ pub fn iter_of_type < T > ( & self ) -> impl Iterator < Item = ( & ID , & T ) >
122+ where
123+ T : ' static ,
124+ {
125+ self . inner . iter ( ) . filter_map ( |a| {
126+ let pin = a. inner . inner . as_ref ( ) ;
127+ let any = Pin :: into_inner ( pin) as & ( dyn Any + Send ) ;
128+ let inner = any. downcast_ref :: < T > ( ) ?;
129+ Some ( ( & a. key , inner) )
130+ } )
131+ }
132+
133+ /// Returns an iterator with mutable access over all streams of type `T`
134+ /// pushed via [`StreamMap::try_push`].
135+ ///
136+ /// If downcasting a stream to `T` fails it will be skipped in the iterator.
137+ pub fn iter_mut_of_type < T > ( & mut self ) -> impl Iterator < Item = ( & mut ID , & mut T ) >
138+ where
139+ T : ' static ,
140+ {
141+ self . inner . iter_mut ( ) . filter_map ( |a| {
142+ let pin = a. inner . inner . as_mut ( ) ;
143+ let any = Pin :: into_inner ( pin) as & mut ( dyn Any + Send ) ;
144+ let inner = any. downcast_mut :: < T > ( ) ?;
145+ Some ( ( & mut a. key , inner) )
146+ } )
147+ }
116148}
117149
118150struct TimeoutStream < S > {
@@ -304,6 +336,29 @@ mod tests {
304336 assert ! ( duration >= DELAY * NUM_STREAMS ) ;
305337 }
306338
339+ #[ test]
340+ fn can_iter_named_streams ( ) {
341+ const N : usize = 10 ;
342+ let mut streams = StreamMap :: new ( || Delay :: futures_timer ( Duration :: from_millis ( 100 ) ) , N ) ;
343+ let mut sender = Vec :: with_capacity ( N ) ;
344+ for i in 0 ..N {
345+ let ( tx, rx) = mpsc:: channel :: < ( ) > ( 1 ) ;
346+ streams. try_push ( format ! ( "ID{i}" ) , rx) . unwrap ( ) ;
347+ sender. push ( tx) ;
348+ }
349+ assert_eq ! ( streams. iter_of_type:: <mpsc:: Receiver <( ) >>( ) . count( ) , N ) ;
350+ for ( i, ( id, _) ) in streams. iter_of_type :: < mpsc:: Receiver < ( ) > > ( ) . enumerate ( ) {
351+ let expect_id = format ! ( "ID{}" , N - i - 1 ) ; // Reverse order.
352+ assert_eq ! ( id, & expect_id) ;
353+ }
354+ assert ! ( !sender. iter( ) . any( |tx| tx. is_closed( ) ) ) ;
355+
356+ for ( _, rx) in streams. iter_mut_of_type :: < mpsc:: Receiver < ( ) > > ( ) {
357+ rx. close ( ) ;
358+ }
359+ assert ! ( sender. iter( ) . all( |tx| tx. is_closed( ) ) ) ;
360+ }
361+
307362 struct Task {
308363 item_delay : Duration ,
309364 num_streams : usize ,
0 commit comments