@@ -188,10 +188,10 @@ impl DownstairsClient {
188
188
tls_context : Option < Arc < crucible_common:: x509:: TLSContext > > ,
189
189
) -> Self {
190
190
let client_delay_us = Arc :: new ( AtomicU64 :: new ( 0 ) ) ;
191
- let ( client_task, client_connect_tx) = Self :: new_io_task (
191
+ let ( client_connect_tx, client_connect_rx) = oneshot:: channel ( ) ;
192
+ let client_task = Self :: new_io_task (
192
193
target_addr,
193
- false , // do not delay in starting the task
194
- false , // do not start the task until GoActive
194
+ ClientConnectDelay :: Wait ( client_connect_rx) ,
195
195
client_id,
196
196
tls_context. clone ( ) ,
197
197
client_delay_us. clone ( ) ,
@@ -209,10 +209,7 @@ impl DownstairsClient {
209
209
repair_addr : None ,
210
210
state : DsStateData :: Connecting {
211
211
mode : ConnectionMode :: New ,
212
- state : match client_connect_tx {
213
- Some ( t) => NegotiationStateData :: WaitConnect ( t) ,
214
- None => NegotiationStateData :: Start ,
215
- } ,
212
+ state : NegotiationStateData :: WaitConnect ( client_connect_tx) ,
216
213
} ,
217
214
last_flush : None ,
218
215
stats : DownstairsStats :: default ( ) ,
@@ -572,11 +569,26 @@ impl DownstairsClient {
572
569
573
570
self . connection_id . update ( ) ;
574
571
575
- // Restart with a short delay, connecting if we're not disabled
576
- let state = match self . start_task ( true , auto_connect) {
577
- Some ( t) => NegotiationStateData :: WaitConnect ( t) ,
578
- None => NegotiationStateData :: Start ,
572
+ let ( client_connect, state) = if auto_connect {
573
+ (
574
+ ClientConnectDelay :: Delay ( CLIENT_RECONNECT_DELAY ) ,
575
+ NegotiationStateData :: Start ,
576
+ )
577
+ } else {
578
+ let ( client_connect_tx, client_connect_rx) = oneshot:: channel ( ) ;
579
+ (
580
+ ClientConnectDelay :: Wait ( client_connect_rx) ,
581
+ NegotiationStateData :: WaitConnect ( client_connect_tx) ,
582
+ )
579
583
} ;
584
+ self . client_task = Self :: new_io_task (
585
+ self . target_addr ,
586
+ client_connect,
587
+ self . client_id ,
588
+ self . tls_context . clone ( ) ,
589
+ self . client_delay_us . clone ( ) ,
590
+ & self . log ,
591
+ ) ;
580
592
581
593
let new_state = DsStateData :: Connecting {
582
594
mode : new_mode,
@@ -591,65 +603,32 @@ impl DownstairsClient {
591
603
self . last_flush
592
604
}
593
605
594
- /// Starts a client IO task, saving the handle in `self.client_task`
595
- ///
596
- /// If we are running unit tests and `self.target_addr` is not populated, we
597
- /// start a dummy task instead.
598
- ///
599
- /// Returns the connection oneshot, or `None` if we're connecting
600
- /// automatically.
601
- ///
602
- /// # Panics
603
- /// If `self.client_task` is not `None`, or `self.target_addr` is `None` and
604
- /// this isn't running in test mode
605
- #[ must_use]
606
- fn start_task (
607
- & mut self ,
608
- delay : bool ,
609
- connect : bool ,
610
- ) -> Option < oneshot:: Sender < ( ) > > {
611
- let ( client_task, client_connect_tx) = Self :: new_io_task (
612
- self . target_addr ,
613
- delay,
614
- connect,
615
- self . client_id ,
616
- self . tls_context . clone ( ) ,
617
- self . client_delay_us . clone ( ) ,
618
- & self . log ,
619
- ) ;
620
- self . client_task = client_task;
621
- client_connect_tx
622
- }
623
-
624
606
fn new_io_task (
625
607
target : Option < SocketAddr > ,
626
- delay : bool ,
627
- connect : bool ,
608
+ start : ClientConnectDelay ,
628
609
client_id : ClientId ,
629
610
tls_context : Option < Arc < TLSContext > > ,
630
611
client_delay_us : Arc < AtomicU64 > ,
631
612
log : & Logger ,
632
- ) -> ( ClientTaskHandle , Option < oneshot :: Sender < ( ) > > ) {
613
+ ) -> ClientTaskHandle {
633
614
#[ cfg( test) ]
634
615
if let Some ( target) = target {
635
616
Self :: new_network_task (
636
617
target,
637
- delay,
638
- connect,
618
+ start,
639
619
client_id,
640
620
tls_context,
641
621
client_delay_us,
642
622
log,
643
623
)
644
624
} else {
645
- Self :: new_dummy_task ( connect )
625
+ Self :: new_dummy_task ( start )
646
626
}
647
627
648
628
#[ cfg( not( test) ) ]
649
629
Self :: new_network_task (
650
630
target. expect ( "must provide socketaddr" ) ,
651
- delay,
652
- connect,
631
+ start,
653
632
client_id,
654
633
tls_context,
655
634
client_delay_us,
@@ -659,27 +638,18 @@ impl DownstairsClient {
659
638
660
639
fn new_network_task (
661
640
target : SocketAddr ,
662
- delay : bool ,
663
- connect : bool ,
641
+ start : ClientConnectDelay ,
664
642
client_id : ClientId ,
665
643
tls_context : Option < Arc < TLSContext > > ,
666
644
client_delay_us : Arc < AtomicU64 > ,
667
645
log : & Logger ,
668
- ) -> ( ClientTaskHandle , Option < oneshot :: Sender < ( ) > > ) {
646
+ ) -> ClientTaskHandle {
669
647
// Messages in flight are limited by backpressure, so we can use
670
648
// unbounded channels here without fear of runaway.
671
649
let ( client_request_tx, client_request_rx) = mpsc:: unbounded_channel ( ) ;
672
650
let ( client_response_tx, client_response_rx) =
673
651
mpsc:: unbounded_channel ( ) ;
674
652
let ( client_stop_tx, client_stop_rx) = oneshot:: channel ( ) ;
675
- let ( client_connect_tx, client_connect_rx) = oneshot:: channel ( ) ;
676
-
677
- let client_connect_tx = if connect {
678
- client_connect_tx. send ( ( ) ) . unwrap ( ) ;
679
- None
680
- } else {
681
- Some ( client_connect_tx)
682
- } ;
683
653
684
654
let log = log. new ( o ! ( "" => "io task" ) ) ;
685
655
tokio:: spawn ( async move {
@@ -689,57 +659,41 @@ impl DownstairsClient {
689
659
target,
690
660
request_rx : client_request_rx,
691
661
response_tx : client_response_tx,
692
- start : client_connect_rx,
693
662
stop : client_stop_rx,
694
663
recv_task : ClientRxTask {
695
664
handle : None ,
696
665
log : log. clone ( ) ,
697
666
} ,
698
- delay,
699
667
client_delay_us,
700
668
log,
701
669
} ;
702
- c. run ( ) . await
670
+ c. run ( start ) . await
703
671
} ) ;
704
- (
705
- ClientTaskHandle {
706
- client_request_tx,
707
- client_stop_tx : Some ( client_stop_tx) ,
708
- client_response_rx,
709
- } ,
710
- client_connect_tx,
711
- )
672
+ ClientTaskHandle {
673
+ client_request_tx,
674
+ client_stop_tx : Some ( client_stop_tx) ,
675
+ client_response_rx,
676
+ }
712
677
}
713
678
714
679
/// Starts a dummy IO task, returning its IO handle
715
680
#[ cfg( test) ]
716
- fn new_dummy_task (
717
- connect : bool ,
718
- ) -> ( ClientTaskHandle , Option < oneshot:: Sender < ( ) > > ) {
681
+ fn new_dummy_task ( _start : ClientConnectDelay ) -> ClientTaskHandle {
719
682
let ( client_request_tx, client_request_rx) = mpsc:: unbounded_channel ( ) ;
720
683
let ( _client_response_tx, client_response_rx) =
721
684
mpsc:: unbounded_channel ( ) ;
722
685
let ( client_stop_tx, client_stop_rx) = oneshot:: channel ( ) ;
723
- let ( client_connect_tx, client_connect_rx) = oneshot:: channel ( ) ;
724
686
725
687
// Forget these without dropping them, so that we can send values into
726
688
// the void!
727
689
std:: mem:: forget ( client_request_rx) ;
728
690
std:: mem:: forget ( client_stop_rx) ;
729
- std:: mem:: forget ( client_connect_rx) ;
730
691
731
- (
732
- ClientTaskHandle {
733
- client_request_tx,
734
- client_stop_tx : Some ( client_stop_tx) ,
735
- client_response_rx,
736
- } ,
737
- if connect {
738
- None
739
- } else {
740
- Some ( client_connect_tx)
741
- } ,
742
- )
692
+ ClientTaskHandle {
693
+ client_request_tx,
694
+ client_stop_tx : Some ( client_stop_tx) ,
695
+ client_response_rx,
696
+ }
743
697
}
744
698
745
699
/// Indicate that the upstairs has requested that we go active
@@ -2225,15 +2179,9 @@ struct ClientIoTask {
2225
2179
/// Reply channel to the main task
2226
2180
response_tx : mpsc:: UnboundedSender < ClientResponse > ,
2227
2181
2228
- /// Oneshot used to start the task
2229
- start : oneshot:: Receiver < ( ) > ,
2230
-
2231
2182
/// Oneshot used to stop the task
2232
2183
stop : oneshot:: Receiver < ClientStopReason > ,
2233
2184
2234
- /// Delay on startup, to avoid a busy-loop if connections always fail
2235
- delay : bool ,
2236
-
2237
2185
/// Handle for the rx task
2238
2186
recv_task : ClientRxTask ,
2239
2187
@@ -2245,6 +2193,32 @@ struct ClientIoTask {
2245
2193
log : Logger ,
2246
2194
}
2247
2195
2196
+ enum ClientConnectDelay {
2197
+ /// Connect after a fixed delay
2198
+ Delay ( std:: time:: Duration ) ,
2199
+ /// Wait for a oneshot to fire before connecting
2200
+ Wait ( oneshot:: Receiver < ( ) > ) ,
2201
+ }
2202
+
2203
+ impl ClientConnectDelay {
2204
+ async fn wait ( self , log : & Logger ) -> Result < ( ) , ClientRunResult > {
2205
+ match self {
2206
+ ClientConnectDelay :: Delay ( dur) => {
2207
+ info ! ( log, "sleeping for {dur:?} before connecting" ) ;
2208
+ tokio:: time:: sleep ( dur) . await ;
2209
+ Ok ( ( ) )
2210
+ }
2211
+ ClientConnectDelay :: Wait ( rx) => {
2212
+ info ! ( log, "client is waiting for oneshot" ) ;
2213
+ rx. await . map_err ( |e| {
2214
+ warn ! ( log, "failed to await start oneshot: {e}" ) ;
2215
+ ClientRunResult :: QueueClosed
2216
+ } )
2217
+ }
2218
+ }
2219
+ }
2220
+ }
2221
+
2248
2222
/// Handle for the rx side of client IO
2249
2223
///
2250
2224
/// This is a convenient wrapper so that we can join the task exactly once,
@@ -2293,8 +2267,8 @@ impl Drop for ClientRxTask {
2293
2267
}
2294
2268
2295
2269
impl ClientIoTask {
2296
- async fn run ( & mut self ) {
2297
- let r = self . run_inner ( ) . await ;
2270
+ async fn run ( & mut self , start : ClientConnectDelay ) {
2271
+ let r = self . run_inner ( start ) . await ;
2298
2272
2299
2273
warn ! ( self . log, "client task is sending Done({r:?})" ) ;
2300
2274
if self . response_tx . send ( ClientResponse :: Done ( r) ) . is_err ( ) {
@@ -2309,41 +2283,18 @@ impl ClientIoTask {
2309
2283
info ! ( self . log, "client task is exiting" ) ;
2310
2284
}
2311
2285
2312
- async fn run_inner ( & mut self ) -> ClientRunResult {
2313
- // If we're reconnecting, then add a short delay to avoid constantly
2314
- // spinning (e.g. if something is fundamentally wrong with the
2315
- // Downstairs)
2316
- //
2317
- // The upstairs can still stop us here, e.g. if we need to transition
2318
- // from Offline -> Faulted because we hit a job limit, that bounces the
2319
- // IO task (whether it *should* is debatable).
2320
- if self . delay {
2321
- tokio:: select! {
2322
- s = & mut self . stop => {
2323
- warn!(
2324
- self . log,
2325
- "client IO task stopped during sleep: {s:?}"
2326
- ) ;
2327
- return s. into( ) ;
2328
- }
2329
- _ = tokio:: time:: sleep( CLIENT_RECONNECT_DELAY ) => {
2330
- // this is fine
2331
- } ,
2332
- }
2333
- }
2334
-
2335
- // Wait for the start oneshot to fire. This may happen immediately, but
2336
- // not necessarily (for example, if the client was deactivated). We
2337
- // also wait for the stop oneshot here, in case someone decides to stop
2338
- // the IO task before it tries to connect.
2286
+ async fn run_inner (
2287
+ & mut self ,
2288
+ start : ClientConnectDelay ,
2289
+ ) -> ClientRunResult {
2290
+ // Wait for either the connection delay to expire (either time-based or
2291
+ // a oneshot), or for the stop oneshot to receive a message.
2339
2292
tokio:: select! {
2340
- s = & mut self . start => {
2341
- if let Err ( e) = s {
2342
- warn!( self . log, "failed to await start oneshot: {e}" ) ;
2343
- return ClientRunResult :: QueueClosed ;
2293
+ r = start. wait( & self . log) => {
2294
+ if let Err ( e) = r {
2295
+ return e;
2344
2296
}
2345
- // Otherwise, continue as usual
2346
- }
2297
+ } ,
2347
2298
s = & mut self . stop => {
2348
2299
warn!(
2349
2300
self . log,
0 commit comments