Skip to content

Commit 451610e

Browse files
author
Lex Vorona
authored
Switching Async session to FuturesUnordered (#1025)
1 parent e93a161 commit 451610e

File tree

2 files changed

+33
-33
lines changed

2 files changed

+33
-33
lines changed

moose/src/execution/asynchronous.rs

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use crate::networking::{AsyncNetworking, LocalAsyncNetworking};
88
use crate::replicated::{RepSetup, ReplicatedPlacement};
99
use crate::storage::{AsyncStorage, LocalAsyncStorage};
1010
use futures::future::{Map, Shared};
11+
use futures::stream::FuturesUnordered;
12+
use futures::StreamExt;
1113
use std::collections::{HashMap, HashSet};
1214
use std::convert::TryFrom;
1315
use std::sync::{Arc, RwLock};
@@ -51,7 +53,7 @@ pub(crate) fn map_receive_error<T>(_: T) -> Error {
5153
}
5254

5355
pub struct AsyncSessionHandle {
54-
pub tasks: Arc<RwLock<Vec<crate::execution::AsyncTask>>>,
56+
pub tasks: Arc<RwLock<FuturesUnordered<AsyncTask>>>,
5557
}
5658

5759
impl AsyncSessionHandle {
@@ -63,18 +65,10 @@ impl AsyncSessionHandle {
6365

6466
pub async fn join_on_first_error(self) -> anyhow::Result<()> {
6567
use crate::error::Error::{OperandUnavailable, ResultUnused};
66-
// use futures::StreamExt;
6768

6869
let mut tasks_guard = self.tasks.write().unwrap();
69-
// TODO (lvorona): should really find a way to use FuturesUnordered here
70-
// let mut tasks = (*tasks_guard)
71-
// .into_iter()
72-
// .collect::<futures::stream::FuturesUnordered<_>>();
73-
74-
let mut tasks = tasks_guard.iter_mut();
75-
76-
while let Some(x) = tasks.next() {
77-
let x = x.await;
70+
let mut maybe_error = None;
71+
while let Some(x) = tasks_guard.next().await {
7872
match x {
7973
Ok(Ok(_)) => {
8074
continue;
@@ -87,26 +81,30 @@ impl AsyncSessionHandle {
8781
OperandUnavailable => continue,
8882
ResultUnused => continue,
8983
_ => {
90-
for task in tasks {
91-
task.abort();
92-
}
93-
return Err(anyhow::Error::from(e));
84+
maybe_error = Some(Err(anyhow::Error::from(e)));
85+
break;
9486
}
9587
}
9688
}
9789
Err(e) => {
9890
if e.is_cancelled() {
9991
continue;
10092
} else if e.is_panic() {
101-
for task in tasks {
102-
task.abort();
103-
}
104-
return Err(anyhow::Error::from(e));
93+
maybe_error = Some(Err(anyhow::Error::from(e)));
94+
break;
10595
}
10696
}
10797
}
10898
}
109-
Ok(())
99+
100+
if let Some(e) = maybe_error {
101+
for task in tasks_guard.iter_mut() {
102+
task.abort();
103+
}
104+
e
105+
} else {
106+
Ok(())
107+
}
110108
}
111109
}
112110

@@ -118,7 +116,7 @@ pub struct AsyncSession {
118116
pub role_assignments: Arc<HashMap<Role, Identity>>,
119117
pub networking: AsyncNetworkingImpl,
120118
pub storage: AsyncStorageImpl,
121-
pub tasks: Arc<RwLock<Vec<crate::execution::AsyncTask>>>,
119+
pub tasks: Arc<RwLock<FuturesUnordered<crate::execution::AsyncTask>>>,
122120
}
123121

124122
impl AsyncSession {
@@ -178,7 +176,7 @@ impl AsyncSession {
178176
map_send_result(sender.send(value))?;
179177
Ok(())
180178
});
181-
let mut tasks = self.tasks.write().unwrap();
179+
let tasks = self.tasks.read().unwrap();
182180
tasks.push(task);
183181

184182
Ok(receiver)
@@ -216,7 +214,7 @@ impl AsyncSession {
216214
map_send_result(sender.send(result.into()))?;
217215
Ok(())
218216
});
219-
let mut tasks = self.tasks.write().unwrap();
217+
let tasks = self.tasks.read().unwrap();
220218
tasks.push(task);
221219

222220
Ok(receiver)
@@ -244,7 +242,7 @@ impl AsyncSession {
244242
map_send_result(sender.send(value))?;
245243
Ok(())
246244
});
247-
let mut tasks = self.tasks.write().unwrap();
245+
let tasks = self.tasks.read().unwrap();
248246
tasks.push(task);
249247

250248
Ok(receiver)
@@ -279,7 +277,7 @@ impl AsyncSession {
279277
map_send_result(sender.send(result.into()))?;
280278
Ok(())
281279
});
282-
let mut tasks = self.tasks.write().unwrap();
280+
let tasks = self.tasks.read().unwrap();
283281
tasks.push(task);
284282

285283
Ok(receiver)
@@ -638,8 +636,11 @@ impl AsyncTestRuntime {
638636
session_handles.push(AsyncSessionHandle::for_session(&moose_session))
639637
}
640638

641-
for handle in session_handles {
642-
let result = rt.block_on(handle.join_on_first_error());
639+
let mut futures: FuturesUnordered<_> = session_handles
640+
.into_iter()
641+
.map(|h| h.join_on_first_error())
642+
.collect();
643+
while let Some(result) = rt.block_on(futures.next()) {
643644
if let Err(e) = result {
644645
return Err(Error::TestRuntime(e.to_string()));
645646
}

moose/src/lib.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ macro_rules! concrete_dispatch_kernel {
238238

239239
Ok(Box::new(move |sess, operands: Operands<crate::computation::Value>| {
240240
assert_eq!(operands.len(), 0);
241-
242241
let y: $u = k(sess, &plc)?;
243242
if y.placement()? == plc.clone().into() {
244243
Ok(y.into())
@@ -294,7 +293,7 @@ macro_rules! concrete_dispatch_kernel {
294293
Err(crate::error::Error::KernelError(format!("Placement mismatch after running {:?}. Expected {:?} got {:?}", op, plc, y.placement())))
295294
}
296295
});
297-
let mut tasks = tasks.write().unwrap();
296+
let tasks = tasks.read().unwrap();
298297
tasks.push(task);
299298

300299
Ok(result)
@@ -406,7 +405,7 @@ macro_rules! concrete_dispatch_kernel {
406405
Err(crate::error::Error::KernelError(format!("Placement mismatch after running {:?}. Expected {:?} got {:?}", op, plc, y.placement())))
407406
}
408407
});
409-
let mut tasks = tasks.write().unwrap();
408+
let tasks = tasks.read().unwrap();
410409
tasks.push(task);
411410

412411
Ok(result)
@@ -541,7 +540,7 @@ macro_rules! concrete_dispatch_kernel {
541540
Err(crate::error::Error::KernelError(format!("Placement mismatch after running {:?}. Expected {:?} got {:?}", op, plc, y.placement())))
542541
}
543542
});
544-
let mut tasks = tasks.write().unwrap();
543+
let tasks = tasks.read().unwrap();
545544
tasks.push(task);
546545

547546
Ok(result)
@@ -671,7 +670,7 @@ macro_rules! concrete_dispatch_kernel {
671670
Err(crate::error::Error::KernelError(format!("Placement mismatch after running {:?}. Expected {:?} got {:?}", op, plc, y.placement())))
672671
}
673672
});
674-
let mut tasks = tasks.write().unwrap();
673+
let tasks = tasks.read().unwrap();
675674
tasks.push(task);
676675

677676
Ok(result)
@@ -779,7 +778,7 @@ macro_rules! concrete_dispatch_kernel {
779778
Err(crate::error::Error::KernelError(format!("Placement mismatch after running {:?}. Expected {:?} got {:?}", op, plc, y.placement())))
780779
}
781780
});
782-
let mut tasks = tasks.write().unwrap();
781+
let tasks = tasks.read().unwrap();
783782
tasks.push(task);
784783

785784
Ok(result)

0 commit comments

Comments
 (0)