39
39
) ]
40
40
#![ cfg_attr( docsrs, feature( doc_auto_cfg) ) ]
41
41
42
+ use std:: collections:: HashMap ;
42
43
use std:: fmt;
43
44
use std:: marker:: PhantomData ;
44
45
use std:: panic:: { RefUnwindSafe , UnwindSafe } ;
45
46
use std:: rc:: Rc ;
46
- use std:: sync:: atomic:: { AtomicBool , AtomicPtr , Ordering } ;
47
+ use std:: sync:: atomic:: { AtomicBool , AtomicPtr , AtomicUsize , Ordering } ;
47
48
use std:: sync:: { Arc , Mutex , RwLock , TryLockError } ;
48
49
use std:: task:: { Poll , Waker } ;
50
+ use std:: thread:: { self , ThreadId } ;
49
51
50
52
use async_task:: { Builder , Runnable } ;
51
53
use concurrent_queue:: ConcurrentQueue ;
@@ -347,8 +349,32 @@ impl<'a> Executor<'a> {
347
349
fn schedule ( & self ) -> impl Fn ( Runnable ) + Send + Sync + ' static {
348
350
let state = self . state_as_arc ( ) ;
349
351
350
- // TODO: If possible, push into the current local queue and notify the ticker.
351
- move |runnable| {
352
+ move |mut runnable| {
353
+ // If possible, push into the current local queue and notify the ticker.
354
+ if let Some ( local_queue) = state
355
+ . local_queues
356
+ . read ( )
357
+ . unwrap ( )
358
+ . get ( & thread:: current ( ) . id ( ) )
359
+ . and_then ( |list| list. first ( ) )
360
+ {
361
+ match local_queue. queue . push ( runnable) {
362
+ Ok ( ( ) ) => {
363
+ if let Some ( waker) = state
364
+ . sleepers
365
+ . lock ( )
366
+ . unwrap ( )
367
+ . notify_runner ( local_queue. runner_id )
368
+ {
369
+ waker. wake ( ) ;
370
+ }
371
+ return ;
372
+ }
373
+
374
+ Err ( r) => runnable = r. into_inner ( ) ,
375
+ }
376
+ }
377
+
352
378
state. queue . push ( runnable) . unwrap ( ) ;
353
379
state. notify ( ) ;
354
380
}
@@ -665,7 +691,9 @@ struct State {
665
691
queue : ConcurrentQueue < Runnable > ,
666
692
667
693
/// Local queues created by runners.
668
- local_queues : RwLock < Vec < Arc < ConcurrentQueue < Runnable > > > > ,
694
+ ///
695
+ /// These are keyed by the thread that the runner originated in.
696
+ local_queues : RwLock < HashMap < ThreadId , Vec < Arc < LocalQueue > > > > ,
669
697
670
698
/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
671
699
notified : AtomicBool ,
@@ -682,7 +710,7 @@ impl State {
682
710
const fn new ( ) -> State {
683
711
State {
684
712
queue : ConcurrentQueue :: unbounded ( ) ,
685
- local_queues : RwLock :: new ( Vec :: new ( ) ) ,
713
+ local_queues : RwLock :: new ( HashMap :: new ( ) ) ,
686
714
notified : AtomicBool :: new ( true ) ,
687
715
sleepers : Mutex :: new ( Sleepers {
688
716
count : 0 ,
@@ -756,36 +784,57 @@ struct Sleepers {
756
784
/// IDs and wakers of sleeping unnotified tickers.
757
785
///
758
786
/// A sleeping ticker is notified when its waker is missing from this list.
759
- wakers : Vec < ( usize , Waker ) > ,
787
+ wakers : Vec < Sleeper > ,
760
788
761
789
/// Reclaimed IDs.
762
790
free_ids : Vec < usize > ,
763
791
}
764
792
793
+ /// A single sleeping ticker.
794
+ struct Sleeper {
795
+ /// ID of the sleeping ticker.
796
+ id : usize ,
797
+
798
+ /// Waker associated with this ticker.
799
+ waker : Waker ,
800
+
801
+ /// Specific runner ID for targeted wakeups.
802
+ runner : Option < usize > ,
803
+ }
804
+
765
805
impl Sleepers {
766
806
/// Inserts a new sleeping ticker.
767
- fn insert ( & mut self , waker : & Waker ) -> usize {
807
+ fn insert ( & mut self , waker : & Waker , runner : Option < usize > ) -> usize {
768
808
let id = match self . free_ids . pop ( ) {
769
809
Some ( id) => id,
770
810
None => self . count + 1 ,
771
811
} ;
772
812
self . count += 1 ;
773
- self . wakers . push ( ( id, waker. clone ( ) ) ) ;
813
+ self . wakers . push ( Sleeper {
814
+ id,
815
+ waker : waker. clone ( ) ,
816
+ runner,
817
+ } ) ;
774
818
id
775
819
}
776
820
777
821
/// Re-inserts a sleeping ticker's waker if it was notified.
778
822
///
779
823
/// Returns `true` if the ticker was notified.
780
- fn update ( & mut self , id : usize , waker : & Waker ) -> bool {
824
+ fn update ( & mut self , id : usize , waker : & Waker , runner : Option < usize > ) -> bool {
781
825
for item in & mut self . wakers {
782
- if item. 0 == id {
783
- item. 1 . clone_from ( waker) ;
826
+ if item. id == id {
827
+ debug_assert_eq ! ( item. runner, runner) ;
828
+ item. waker . clone_from ( waker) ;
784
829
return false ;
785
830
}
786
831
}
787
832
788
- self . wakers . push ( ( id, waker. clone ( ) ) ) ;
833
+ self . wakers . push ( Sleeper {
834
+ id,
835
+ waker : waker. clone ( ) ,
836
+ runner,
837
+ } ) ;
789
838
true
790
839
}
791
840
@@ -797,7 +846,7 @@ impl Sleepers {
797
846
self . free_ids . push ( id) ;
798
847
799
848
for i in ( 0 ..self . wakers . len ( ) ) . rev ( ) {
800
- if self . wakers [ i] . 0 == id {
849
+ if self . wakers [ i] . id == id {
801
850
self . wakers . remove ( i) ;
802
851
return false ;
803
852
}
@@ -815,7 +864,20 @@ impl Sleepers {
815
864
/// If a ticker was notified already or there are no tickers, `None` will be returned.
816
865
fn notify ( & mut self ) -> Option < Waker > {
817
866
if self . wakers . len ( ) == self . count {
818
- self . wakers . pop ( ) . map ( |item| item. 1 )
867
+ self . wakers . pop ( ) . map ( |item| item. waker )
868
+ } else {
869
+ None
870
+ }
871
+ }
872
+
873
+ /// Notify a specific waker that was previously sleeping.
874
+ fn notify_runner ( & mut self , runner : usize ) -> Option < Waker > {
875
+ if let Some ( posn) = self
876
+ . wakers
877
+ . iter ( )
878
+ . position ( |sleeper| sleeper. runner == Some ( runner) )
879
+ {
880
+ Some ( self . wakers . swap_remove ( posn) . waker )
819
881
} else {
820
882
None
821
883
}
@@ -834,12 +896,28 @@ struct Ticker<'a> {
834
896
/// 2a) Sleeping and unnotified.
835
897
/// 2b) Sleeping and notified.
836
898
sleeping : usize ,
899
+
900
+ /// Unique runner ID, if this is a runner.
901
+ runner : Option < usize > ,
837
902
}
838
903
839
904
impl Ticker < ' _ > {
840
905
/// Creates a ticker.
841
906
fn new ( state : & State ) -> Ticker < ' _ > {
842
- Ticker { state, sleeping : 0 }
907
+ Ticker {
908
+ state,
909
+ sleeping : 0 ,
910
+ runner : None ,
911
+ }
912
+ }
913
+
914
+ /// Creates a ticker for a runner.
915
+ fn for_runner ( state : & State , runner : usize ) -> Ticker < ' _ > {
916
+ Ticker {
917
+ state,
918
+ sleeping : 0 ,
919
+ runner : Some ( runner) ,
920
+ }
843
921
}
844
922
845
923
/// Moves the ticker into sleeping and unnotified state.
@@ -851,12 +929,12 @@ impl Ticker<'_> {
851
929
match self . sleeping {
852
930
// Move to sleeping state.
853
931
0 => {
854
- self . sleeping = sleepers. insert ( waker) ;
932
+ self . sleeping = sleepers. insert ( waker, self . runner ) ;
855
933
}
856
934
857
935
// Already sleeping, check if notified.
858
936
id => {
859
- if !sleepers. update ( id, waker) {
937
+ if !sleepers. update ( id, waker, self . runner ) {
860
938
return false ;
861
939
}
862
940
}
@@ -946,8 +1024,11 @@ struct Runner<'a> {
946
1024
/// Inner ticker.
947
1025
ticker : Ticker < ' a > ,
948
1026
1027
+ /// The ID of the thread we originated from.
1028
+ origin_id : ThreadId ,
1029
+
949
1030
/// The local queue.
950
- local : Arc < ConcurrentQueue < Runnable > > ,
1031
+ local : Arc < LocalQueue > ,
951
1032
952
1033
/// Bumped every time a runnable task is found.
953
1034
ticks : usize ,
@@ -956,16 +1037,26 @@ struct Runner<'a> {
956
1037
impl Runner < ' _ > {
957
1038
/// Creates a runner and registers it in the executor state.
958
1039
fn new ( state : & State ) -> Runner < ' _ > {
1040
+ static ID_GENERATOR : AtomicUsize = AtomicUsize :: new ( 0 ) ;
1041
+ let runner_id = ID_GENERATOR . fetch_add ( 1 , Ordering :: SeqCst ) ;
1042
+
1043
+ let origin_id = thread:: current ( ) . id ( ) ;
959
1044
let runner = Runner {
960
1045
state,
961
- ticker : Ticker :: new ( state) ,
962
- local : Arc :: new ( ConcurrentQueue :: bounded ( 512 ) ) ,
1046
+ ticker : Ticker :: for_runner ( state, runner_id) ,
1047
+ local : Arc :: new ( LocalQueue {
1048
+ queue : ConcurrentQueue :: bounded ( 512 ) ,
1049
+ runner_id,
1050
+ } ) ,
963
1051
ticks : 0 ,
1052
+ origin_id,
964
1053
} ;
965
1054
state
966
1055
. local_queues
967
1056
. write ( )
968
1057
. unwrap ( )
1058
+ . entry ( origin_id)
1059
+ . or_default ( )
969
1060
. push ( runner. local . clone ( ) ) ;
970
1061
runner
971
1062
}
@@ -976,13 +1067,13 @@ impl Runner<'_> {
976
1067
. ticker
977
1068
. runnable_with ( || {
978
1069
// Try the local queue.
979
- if let Ok ( r) = self . local . pop ( ) {
1070
+ if let Ok ( r) = self . local . queue . pop ( ) {
980
1071
return Some ( r) ;
981
1072
}
982
1073
983
1074
// Try stealing from the global queue.
984
1075
if let Ok ( r) = self . state . queue . pop ( ) {
985
- steal ( & self . state . queue , & self . local ) ;
1076
+ steal ( & self . state . queue , & self . local . queue ) ;
986
1077
return Some ( r) ;
987
1078
}
988
1079
@@ -994,7 +1085,8 @@ impl Runner<'_> {
994
1085
let start = rng. usize ( ..n) ;
995
1086
let iter = local_queues
996
1087
. iter ( )
997
- . chain ( local_queues. iter ( ) )
1088
+ . flat_map ( |( _, list) | list)
1089
+ . chain ( local_queues. iter ( ) . flat_map ( |( _, list) | list) )
998
1090
. skip ( start)
999
1091
. take ( n) ;
1000
1092
@@ -1003,8 +1095,8 @@ impl Runner<'_> {
1003
1095
1004
1096
// Try stealing from each local queue in the list.
1005
1097
for local in iter {
1006
- steal ( local, & self . local ) ;
1007
- if let Ok ( r) = self . local . pop ( ) {
1098
+ steal ( & local. queue , & self . local . queue ) ;
1099
+ if let Ok ( r) = self . local . queue . pop ( ) {
1008
1100
return Some ( r) ;
1009
1101
}
1010
1102
}
@@ -1018,7 +1110,7 @@ impl Runner<'_> {
1018
1110
1019
1111
if self . ticks % 64 == 0 {
1020
1112
// Steal tasks from the global queue to ensure fair task scheduling.
1021
- steal ( & self . state . queue , & self . local ) ;
1113
+ steal ( & self . state . queue , & self . local . queue ) ;
1022
1114
}
1023
1115
1024
1116
runnable
@@ -1032,15 +1124,26 @@ impl Drop for Runner<'_> {
1032
1124
. local_queues
1033
1125
. write ( )
1034
1126
. unwrap ( )
1127
+ . get_mut ( & self . origin_id )
1128
+ . unwrap ( )
1035
1129
. retain ( |local| !Arc :: ptr_eq ( local, & self . local ) ) ;
1036
1130
1037
1131
// Re-schedule remaining tasks in the local queue.
1038
- while let Ok ( r) = self . local . pop ( ) {
1132
+ while let Ok ( r) = self . local . queue . pop ( ) {
1039
1133
r. schedule ( ) ;
1040
1134
}
1041
1135
}
1042
1136
}
1043
1137
1138
+ /// Data associated with a local queue.
1139
+ struct LocalQueue {
1140
+ /// Concurrent queue of active tasks.
1141
+ queue : ConcurrentQueue < Runnable > ,
1142
+
1143
+ /// Unique ID associated with this runner.
1144
+ runner_id : usize ,
1145
+ }
1146
+
1044
1147
/// Steals some items from one queue into another.
1045
1148
fn steal < T > ( src : & ConcurrentQueue < T > , dest : & ConcurrentQueue < T > ) {
1046
1149
// Half of `src`'s length rounded up.
@@ -1104,14 +1207,18 @@ fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Re
1104
1207
}
1105
1208
1106
1209
/// Debug wrapper for the local runners.
1107
- struct LocalRunners < ' a > ( & ' a RwLock < Vec < Arc < ConcurrentQueue < Runnable > > > > ) ;
1210
+ struct LocalRunners < ' a > ( & ' a RwLock < HashMap < ThreadId , Vec < Arc < LocalQueue > > > > ) ;
1108
1211
1109
1212
impl fmt:: Debug for LocalRunners < ' _ > {
1110
1213
fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
1111
1214
match self . 0 . try_read ( ) {
1112
1215
Ok ( lock) => f
1113
1216
. debug_list ( )
1114
- . entries ( lock. iter ( ) . map ( |queue| queue. len ( ) ) )
1217
+ . entries (
1218
+ lock. iter ( )
1219
+ . flat_map ( |( _, list) | list)
1220
+ . map ( |queue| queue. queue . len ( ) ) ,
1221
+ )
1115
1222
. finish ( ) ,
1116
1223
Err ( TryLockError :: WouldBlock ) => f. write_str ( "<locked>" ) ,
1117
1224
Err ( TryLockError :: Poisoned ( _) ) => f. write_str ( "<poisoned>" ) ,
0 commit comments