@@ -8,6 +8,8 @@ use crate::networking::{AsyncNetworking, LocalAsyncNetworking};
8
8
use crate :: replicated:: { RepSetup , ReplicatedPlacement } ;
9
9
use crate :: storage:: { AsyncStorage , LocalAsyncStorage } ;
10
10
use futures:: future:: { Map , Shared } ;
11
+ use futures:: stream:: FuturesUnordered ;
12
+ use futures:: StreamExt ;
11
13
use std:: collections:: { HashMap , HashSet } ;
12
14
use std:: convert:: TryFrom ;
13
15
use std:: sync:: { Arc , RwLock } ;
@@ -51,7 +53,7 @@ pub(crate) fn map_receive_error<T>(_: T) -> Error {
51
53
}
52
54
53
55
pub struct AsyncSessionHandle {
54
- pub tasks : Arc < RwLock < Vec < crate :: execution :: AsyncTask > > > ,
56
+ pub tasks : Arc < RwLock < FuturesUnordered < AsyncTask > > > ,
55
57
}
56
58
57
59
impl AsyncSessionHandle {
@@ -63,18 +65,10 @@ impl AsyncSessionHandle {
63
65
64
66
pub async fn join_on_first_error ( self ) -> anyhow:: Result < ( ) > {
65
67
use crate :: error:: Error :: { OperandUnavailable , ResultUnused } ;
66
- // use futures::StreamExt;
67
68
68
69
let mut tasks_guard = self . tasks . write ( ) . unwrap ( ) ;
69
- // TODO (lvorona): should really find a way to use FuturesUnordered here
70
- // let mut tasks = (*tasks_guard)
71
- // .into_iter()
72
- // .collect::<futures::stream::FuturesUnordered<_>>();
73
-
74
- let mut tasks = tasks_guard. iter_mut ( ) ;
75
-
76
- while let Some ( x) = tasks. next ( ) {
77
- let x = x. await ;
70
+ let mut maybe_error = None ;
71
+ while let Some ( x) = tasks_guard. next ( ) . await {
78
72
match x {
79
73
Ok ( Ok ( _) ) => {
80
74
continue ;
@@ -87,26 +81,30 @@ impl AsyncSessionHandle {
87
81
OperandUnavailable => continue ,
88
82
ResultUnused => continue ,
89
83
_ => {
90
- for task in tasks {
91
- task. abort ( ) ;
92
- }
93
- return Err ( anyhow:: Error :: from ( e) ) ;
84
+ maybe_error = Some ( Err ( anyhow:: Error :: from ( e) ) ) ;
85
+ break ;
94
86
}
95
87
}
96
88
}
97
89
Err ( e) => {
98
90
if e. is_cancelled ( ) {
99
91
continue ;
100
92
} else if e. is_panic ( ) {
101
- for task in tasks {
102
- task. abort ( ) ;
103
- }
104
- return Err ( anyhow:: Error :: from ( e) ) ;
93
+ maybe_error = Some ( Err ( anyhow:: Error :: from ( e) ) ) ;
94
+ break ;
105
95
}
106
96
}
107
97
}
108
98
}
109
- Ok ( ( ) )
99
+
100
+ if let Some ( e) = maybe_error {
101
+ for task in tasks_guard. iter_mut ( ) {
102
+ task. abort ( ) ;
103
+ }
104
+ e
105
+ } else {
106
+ Ok ( ( ) )
107
+ }
110
108
}
111
109
}
112
110
@@ -118,7 +116,7 @@ pub struct AsyncSession {
118
116
pub role_assignments : Arc < HashMap < Role , Identity > > ,
119
117
pub networking : AsyncNetworkingImpl ,
120
118
pub storage : AsyncStorageImpl ,
121
- pub tasks : Arc < RwLock < Vec < crate :: execution:: AsyncTask > > > ,
119
+ pub tasks : Arc < RwLock < FuturesUnordered < crate :: execution:: AsyncTask > > > ,
122
120
}
123
121
124
122
impl AsyncSession {
@@ -178,7 +176,7 @@ impl AsyncSession {
178
176
map_send_result ( sender. send ( value) ) ?;
179
177
Ok ( ( ) )
180
178
} ) ;
181
- let mut tasks = self . tasks . write ( ) . unwrap ( ) ;
179
+ let tasks = self . tasks . read ( ) . unwrap ( ) ;
182
180
tasks. push ( task) ;
183
181
184
182
Ok ( receiver)
@@ -216,7 +214,7 @@ impl AsyncSession {
216
214
map_send_result ( sender. send ( result. into ( ) ) ) ?;
217
215
Ok ( ( ) )
218
216
} ) ;
219
- let mut tasks = self . tasks . write ( ) . unwrap ( ) ;
217
+ let tasks = self . tasks . read ( ) . unwrap ( ) ;
220
218
tasks. push ( task) ;
221
219
222
220
Ok ( receiver)
@@ -244,7 +242,7 @@ impl AsyncSession {
244
242
map_send_result ( sender. send ( value) ) ?;
245
243
Ok ( ( ) )
246
244
} ) ;
247
- let mut tasks = self . tasks . write ( ) . unwrap ( ) ;
245
+ let tasks = self . tasks . read ( ) . unwrap ( ) ;
248
246
tasks. push ( task) ;
249
247
250
248
Ok ( receiver)
@@ -279,7 +277,7 @@ impl AsyncSession {
279
277
map_send_result ( sender. send ( result. into ( ) ) ) ?;
280
278
Ok ( ( ) )
281
279
} ) ;
282
- let mut tasks = self . tasks . write ( ) . unwrap ( ) ;
280
+ let tasks = self . tasks . read ( ) . unwrap ( ) ;
283
281
tasks. push ( task) ;
284
282
285
283
Ok ( receiver)
@@ -638,8 +636,11 @@ impl AsyncTestRuntime {
638
636
session_handles. push ( AsyncSessionHandle :: for_session ( & moose_session) )
639
637
}
640
638
641
- for handle in session_handles {
642
- let result = rt. block_on ( handle. join_on_first_error ( ) ) ;
639
+ let mut futures: FuturesUnordered < _ > = session_handles
640
+ . into_iter ( )
641
+ . map ( |h| h. join_on_first_error ( ) )
642
+ . collect ( ) ;
643
+ while let Some ( result) = rt. block_on ( futures. next ( ) ) {
643
644
if let Err ( e) = result {
644
645
return Err ( Error :: TestRuntime ( e. to_string ( ) ) ) ;
645
646
}
0 commit comments