Skip to content

Commit de1112b

Browse files
committed
feat: Thread-local queue push take 3
This commit attempts to re-introduce the thread-local optimization. It stores the runners' local queues in a hash map keyed by the ID of the thread on which each runner was created; every key maps to the list of queues created on that thread. It also gives each runner a unique ID so that the sleeping ticker of one specific runner can be woken directly. cc #64 Signed-off-by: John Nunley <[email protected]>
1 parent 444d0c1 commit de1112b

File tree

1 file changed

+136
-29
lines changed

1 file changed

+136
-29
lines changed

src/lib.rs

Lines changed: 136 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@
3939
)]
4040
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4141

42+
use std::collections::HashMap;
4243
use std::fmt;
4344
use std::marker::PhantomData;
4445
use std::panic::{RefUnwindSafe, UnwindSafe};
4546
use std::rc::Rc;
46-
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
47+
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
4748
use std::sync::{Arc, Mutex, RwLock, TryLockError};
4849
use std::task::{Poll, Waker};
50+
use std::thread::{self, ThreadId};
4951

5052
use async_task::{Builder, Runnable};
5153
use concurrent_queue::ConcurrentQueue;
@@ -347,8 +349,32 @@ impl<'a> Executor<'a> {
347349
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
348350
let state = self.state_as_arc();
349351

350-
// TODO: If possible, push into the current local queue and notify the ticker.
351-
move |runnable| {
352+
move |mut runnable| {
353+
// If possible, push into the current local queue and notify the ticker.
354+
if let Some(local_queue) = state
355+
.local_queues
356+
.read()
357+
.unwrap()
358+
.get(&thread::current().id())
359+
.and_then(|list| list.first())
360+
{
361+
match local_queue.queue.push(runnable) {
362+
Ok(()) => {
363+
if let Some(waker) = state
364+
.sleepers
365+
.lock()
366+
.unwrap()
367+
.notify_runner(local_queue.runner_id)
368+
{
369+
waker.wake();
370+
}
371+
return;
372+
}
373+
374+
Err(r) => runnable = r.into_inner(),
375+
}
376+
}
377+
352378
state.queue.push(runnable).unwrap();
353379
state.notify();
354380
}
@@ -665,7 +691,9 @@ struct State {
665691
queue: ConcurrentQueue<Runnable>,
666692

667693
/// Local queues created by runners.
668-
local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
694+
///
695+
/// These are keyed by the thread that the runner originated in.
696+
local_queues: RwLock<HashMap<ThreadId, Vec<Arc<LocalQueue>>>>,
669697

670698
/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
671699
notified: AtomicBool,
@@ -682,7 +710,7 @@ impl State {
682710
const fn new() -> State {
683711
State {
684712
queue: ConcurrentQueue::unbounded(),
685-
local_queues: RwLock::new(Vec::new()),
713+
local_queues: RwLock::new(HashMap::new()),
686714
notified: AtomicBool::new(true),
687715
sleepers: Mutex::new(Sleepers {
688716
count: 0,
@@ -756,36 +784,57 @@ struct Sleepers {
756784
/// IDs and wakers of sleeping unnotified tickers.
757785
///
758786
/// A sleeping ticker is notified when its waker is missing from this list.
759-
wakers: Vec<(usize, Waker)>,
787+
wakers: Vec<Sleeper>,
760788

761789
/// Reclaimed IDs.
762790
free_ids: Vec<usize>,
763791
}
764792

793+
/// A single sleeping ticker.
794+
struct Sleeper {
795+
/// ID of the sleeping ticker.
796+
id: usize,
797+
798+
/// Waker associated with this ticker.
799+
waker: Waker,
800+
801+
/// Specific runner ID for targeted wakeups.
802+
runner: Option<usize>,
803+
}
804+
765805
impl Sleepers {
766806
/// Inserts a new sleeping ticker.
767-
fn insert(&mut self, waker: &Waker) -> usize {
807+
fn insert(&mut self, waker: &Waker, runner: Option<usize>) -> usize {
768808
let id = match self.free_ids.pop() {
769809
Some(id) => id,
770810
None => self.count + 1,
771811
};
772812
self.count += 1;
773-
self.wakers.push((id, waker.clone()));
813+
self.wakers.push(Sleeper {
814+
id,
815+
waker: waker.clone(),
816+
runner,
817+
});
774818
id
775819
}
776820

777821
/// Re-inserts a sleeping ticker's waker if it was notified.
778822
///
779823
/// Returns `true` if the ticker was notified.
780-
fn update(&mut self, id: usize, waker: &Waker) -> bool {
824+
fn update(&mut self, id: usize, waker: &Waker, runner: Option<usize>) -> bool {
781825
for item in &mut self.wakers {
782-
if item.0 == id {
783-
item.1.clone_from(waker);
826+
if item.id == id {
827+
debug_assert_eq!(item.runner, runner);
828+
item.waker.clone_from(waker);
784829
return false;
785830
}
786831
}
787832

788-
self.wakers.push((id, waker.clone()));
833+
self.wakers.push(Sleeper {
834+
id,
835+
waker: waker.clone(),
836+
runner,
837+
});
789838
true
790839
}
791840

@@ -797,7 +846,7 @@ impl Sleepers {
797846
self.free_ids.push(id);
798847

799848
for i in (0..self.wakers.len()).rev() {
800-
if self.wakers[i].0 == id {
849+
if self.wakers[i].id == id {
801850
self.wakers.remove(i);
802851
return false;
803852
}
@@ -815,7 +864,20 @@ impl Sleepers {
815864
/// If a ticker was notified already or there are no tickers, `None` will be returned.
816865
fn notify(&mut self) -> Option<Waker> {
817866
if self.wakers.len() == self.count {
818-
self.wakers.pop().map(|item| item.1)
867+
self.wakers.pop().map(|item| item.waker)
868+
} else {
869+
None
870+
}
871+
}
872+
873+
/// Notifies the sleeping, unnotified ticker that belongs to the given runner and returns its waker, or `None` if that runner has no waker in the list.
874+
fn notify_runner(&mut self, runner: usize) -> Option<Waker> {
875+
if let Some(posn) = self
876+
.wakers
877+
.iter()
878+
.position(|sleeper| sleeper.runner == Some(runner))
879+
{
880+
Some(self.wakers.swap_remove(posn).waker)
819881
} else {
820882
None
821883
}
@@ -834,12 +896,28 @@ struct Ticker<'a> {
834896
/// 2a) Sleeping and unnotified.
835897
/// 2b) Sleeping and notified.
836898
sleeping: usize,
899+
900+
/// Unique runner ID, if this is a runner.
901+
runner: Option<usize>,
837902
}
838903

839904
impl Ticker<'_> {
840905
/// Creates a ticker.
841906
fn new(state: &State) -> Ticker<'_> {
842-
Ticker { state, sleeping: 0 }
907+
Ticker {
908+
state,
909+
sleeping: 0,
910+
runner: None,
911+
}
912+
}
913+
914+
/// Creates a ticker for a runner.
915+
fn for_runner(state: &State, runner: usize) -> Ticker<'_> {
916+
Ticker {
917+
state,
918+
sleeping: 0,
919+
runner: Some(runner),
920+
}
843921
}
844922

845923
/// Moves the ticker into sleeping and unnotified state.
@@ -851,12 +929,12 @@ impl Ticker<'_> {
851929
match self.sleeping {
852930
// Move to sleeping state.
853931
0 => {
854-
self.sleeping = sleepers.insert(waker);
932+
self.sleeping = sleepers.insert(waker, self.runner);
855933
}
856934

857935
// Already sleeping, check if notified.
858936
id => {
859-
if !sleepers.update(id, waker) {
937+
if !sleepers.update(id, waker, self.runner) {
860938
return false;
861939
}
862940
}
@@ -946,8 +1024,11 @@ struct Runner<'a> {
9461024
/// Inner ticker.
9471025
ticker: Ticker<'a>,
9481026

1027+
/// The ID of the thread we originated from.
1028+
origin_id: ThreadId,
1029+
9491030
/// The local queue.
950-
local: Arc<ConcurrentQueue<Runnable>>,
1031+
local: Arc<LocalQueue>,
9511032

9521033
/// Bumped every time a runnable task is found.
9531034
ticks: usize,
@@ -956,16 +1037,26 @@ struct Runner<'a> {
9561037
impl Runner<'_> {
9571038
/// Creates a runner and registers it in the executor state.
9581039
fn new(state: &State) -> Runner<'_> {
1040+
static ID_GENERATOR: AtomicUsize = AtomicUsize::new(0);
1041+
let runner_id = ID_GENERATOR.fetch_add(1, Ordering::SeqCst);
1042+
1043+
let origin_id = thread::current().id();
9591044
let runner = Runner {
9601045
state,
961-
ticker: Ticker::new(state),
962-
local: Arc::new(ConcurrentQueue::bounded(512)),
1046+
ticker: Ticker::for_runner(state, runner_id),
1047+
local: Arc::new(LocalQueue {
1048+
queue: ConcurrentQueue::bounded(512),
1049+
runner_id,
1050+
}),
9631051
ticks: 0,
1052+
origin_id,
9641053
};
9651054
state
9661055
.local_queues
9671056
.write()
9681057
.unwrap()
1058+
.entry(origin_id)
1059+
.or_default()
9691060
.push(runner.local.clone());
9701061
runner
9711062
}
@@ -976,13 +1067,13 @@ impl Runner<'_> {
9761067
.ticker
9771068
.runnable_with(|| {
9781069
// Try the local queue.
979-
if let Ok(r) = self.local.pop() {
1070+
if let Ok(r) = self.local.queue.pop() {
9801071
return Some(r);
9811072
}
9821073

9831074
// Try stealing from the global queue.
9841075
if let Ok(r) = self.state.queue.pop() {
985-
steal(&self.state.queue, &self.local);
1076+
steal(&self.state.queue, &self.local.queue);
9861077
return Some(r);
9871078
}
9881079

@@ -994,7 +1085,8 @@ impl Runner<'_> {
9941085
let start = rng.usize(..n);
9951086
let iter = local_queues
9961087
.iter()
997-
.chain(local_queues.iter())
1088+
.flat_map(|(_, list)| list)
1089+
.chain(local_queues.iter().flat_map(|(_, list)| list))
9981090
.skip(start)
9991091
.take(n);
10001092

@@ -1003,8 +1095,8 @@ impl Runner<'_> {
10031095

10041096
// Try stealing from each local queue in the list.
10051097
for local in iter {
1006-
steal(local, &self.local);
1007-
if let Ok(r) = self.local.pop() {
1098+
steal(&local.queue, &self.local.queue);
1099+
if let Ok(r) = self.local.queue.pop() {
10081100
return Some(r);
10091101
}
10101102
}
@@ -1018,7 +1110,7 @@ impl Runner<'_> {
10181110

10191111
if self.ticks % 64 == 0 {
10201112
// Steal tasks from the global queue to ensure fair task scheduling.
1021-
steal(&self.state.queue, &self.local);
1113+
steal(&self.state.queue, &self.local.queue);
10221114
}
10231115

10241116
runnable
@@ -1032,15 +1124,26 @@ impl Drop for Runner<'_> {
10321124
.local_queues
10331125
.write()
10341126
.unwrap()
1127+
.get_mut(&self.origin_id)
1128+
.unwrap()
10351129
.retain(|local| !Arc::ptr_eq(local, &self.local));
10361130

10371131
// Re-schedule remaining tasks in the local queue.
1038-
while let Ok(r) = self.local.pop() {
1132+
while let Ok(r) = self.local.queue.pop() {
10391133
r.schedule();
10401134
}
10411135
}
10421136
}
10431137

1138+
/// Data associated with a local queue.
1139+
struct LocalQueue {
1140+
/// Concurrent queue of active tasks.
1141+
queue: ConcurrentQueue<Runnable>,
1142+
1143+
/// Unique ID associated with this runner.
1144+
runner_id: usize,
1145+
}
1146+
10441147
/// Steals some items from one queue into another.
10451148
fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
10461149
// Half of `src`'s length rounded up.
@@ -1104,14 +1207,18 @@ fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Re
11041207
}
11051208

11061209
/// Debug wrapper for the local runners.
1107-
struct LocalRunners<'a>(&'a RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>);
1210+
struct LocalRunners<'a>(&'a RwLock<HashMap<ThreadId, Vec<Arc<LocalQueue>>>>);
11081211

11091212
impl fmt::Debug for LocalRunners<'_> {
11101213
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
11111214
match self.0.try_read() {
11121215
Ok(lock) => f
11131216
.debug_list()
1114-
.entries(lock.iter().map(|queue| queue.len()))
1217+
.entries(
1218+
lock.iter()
1219+
.flat_map(|(_, list)| list)
1220+
.map(|queue| queue.queue.len()),
1221+
)
11151222
.finish(),
11161223
Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
11171224
Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),

0 commit comments

Comments
 (0)