Skip to content

Commit 9f56a5c

Browse files
authored
Improve ClientIoTask start logic (#1731)
In the Crucible upstairs, we stop the client IO task for various reasons, which fall into two broad categories: - Something has gone wrong, and we need to recover with a fresh connection. In this case, it's wise to sleep for a little while; we don't want to spam connections to an offline Downstairs as fast as possible. This is controlled by a `delay: bool`, flag which adds a 10-second delay before reconnecting. - We are deliberately stopping the upstairs. In this case, we don't want to try connecting at all until the upstairs goes active again. This is controlled by an `Option<oneshot::Receiver<()>>`, which is awaited before connecting if present. In the current codebase, when the upstairs is disabled, we restart the Client IO task with **both** of these measures in place: the Client IO task waits for a message on a oneshot, _then_ sleeps for 10 seconds before connecting. This is awkward, because it means that a `GoActive` request is blocked for 10 seconds. We could just fix this issue, but I realized that the flags are mutually exclusive: we should only ever start an IO task in one or the other mode. As such, we can use a new `enum ClientConnectDelay` to represent the choice. This ends up being a nice simplification!
1 parent 9327f6c commit 9f56a5c

File tree

1 file changed

+80
-129
lines changed

1 file changed

+80
-129
lines changed

upstairs/src/client.rs

Lines changed: 80 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,10 @@ impl DownstairsClient {
188188
tls_context: Option<Arc<crucible_common::x509::TLSContext>>,
189189
) -> Self {
190190
let client_delay_us = Arc::new(AtomicU64::new(0));
191-
let (client_task, client_connect_tx) = Self::new_io_task(
191+
let (client_connect_tx, client_connect_rx) = oneshot::channel();
192+
let client_task = Self::new_io_task(
192193
target_addr,
193-
false, // do not delay in starting the task
194-
false, // do not start the task until GoActive
194+
ClientConnectDelay::Wait(client_connect_rx),
195195
client_id,
196196
tls_context.clone(),
197197
client_delay_us.clone(),
@@ -209,10 +209,7 @@ impl DownstairsClient {
209209
repair_addr: None,
210210
state: DsStateData::Connecting {
211211
mode: ConnectionMode::New,
212-
state: match client_connect_tx {
213-
Some(t) => NegotiationStateData::WaitConnect(t),
214-
None => NegotiationStateData::Start,
215-
},
212+
state: NegotiationStateData::WaitConnect(client_connect_tx),
216213
},
217214
last_flush: None,
218215
stats: DownstairsStats::default(),
@@ -572,11 +569,26 @@ impl DownstairsClient {
572569

573570
self.connection_id.update();
574571

575-
// Restart with a short delay, connecting if we're not disabled
576-
let state = match self.start_task(true, auto_connect) {
577-
Some(t) => NegotiationStateData::WaitConnect(t),
578-
None => NegotiationStateData::Start,
572+
let (client_connect, state) = if auto_connect {
573+
(
574+
ClientConnectDelay::Delay(CLIENT_RECONNECT_DELAY),
575+
NegotiationStateData::Start,
576+
)
577+
} else {
578+
let (client_connect_tx, client_connect_rx) = oneshot::channel();
579+
(
580+
ClientConnectDelay::Wait(client_connect_rx),
581+
NegotiationStateData::WaitConnect(client_connect_tx),
582+
)
579583
};
584+
self.client_task = Self::new_io_task(
585+
self.target_addr,
586+
client_connect,
587+
self.client_id,
588+
self.tls_context.clone(),
589+
self.client_delay_us.clone(),
590+
&self.log,
591+
);
580592

581593
let new_state = DsStateData::Connecting {
582594
mode: new_mode,
@@ -591,65 +603,32 @@ impl DownstairsClient {
591603
self.last_flush
592604
}
593605

594-
/// Starts a client IO task, saving the handle in `self.client_task`
595-
///
596-
/// If we are running unit tests and `self.target_addr` is not populated, we
597-
/// start a dummy task instead.
598-
///
599-
/// Returns the connection oneshot, or `None` if we're connecting
600-
/// automatically.
601-
///
602-
/// # Panics
603-
/// If `self.client_task` is not `None`, or `self.target_addr` is `None` and
604-
/// this isn't running in test mode
605-
#[must_use]
606-
fn start_task(
607-
&mut self,
608-
delay: bool,
609-
connect: bool,
610-
) -> Option<oneshot::Sender<()>> {
611-
let (client_task, client_connect_tx) = Self::new_io_task(
612-
self.target_addr,
613-
delay,
614-
connect,
615-
self.client_id,
616-
self.tls_context.clone(),
617-
self.client_delay_us.clone(),
618-
&self.log,
619-
);
620-
self.client_task = client_task;
621-
client_connect_tx
622-
}
623-
624606
fn new_io_task(
625607
target: Option<SocketAddr>,
626-
delay: bool,
627-
connect: bool,
608+
start: ClientConnectDelay,
628609
client_id: ClientId,
629610
tls_context: Option<Arc<TLSContext>>,
630611
client_delay_us: Arc<AtomicU64>,
631612
log: &Logger,
632-
) -> (ClientTaskHandle, Option<oneshot::Sender<()>>) {
613+
) -> ClientTaskHandle {
633614
#[cfg(test)]
634615
if let Some(target) = target {
635616
Self::new_network_task(
636617
target,
637-
delay,
638-
connect,
618+
start,
639619
client_id,
640620
tls_context,
641621
client_delay_us,
642622
log,
643623
)
644624
} else {
645-
Self::new_dummy_task(connect)
625+
Self::new_dummy_task(start)
646626
}
647627

648628
#[cfg(not(test))]
649629
Self::new_network_task(
650630
target.expect("must provide socketaddr"),
651-
delay,
652-
connect,
631+
start,
653632
client_id,
654633
tls_context,
655634
client_delay_us,
@@ -659,27 +638,18 @@ impl DownstairsClient {
659638

660639
fn new_network_task(
661640
target: SocketAddr,
662-
delay: bool,
663-
connect: bool,
641+
start: ClientConnectDelay,
664642
client_id: ClientId,
665643
tls_context: Option<Arc<TLSContext>>,
666644
client_delay_us: Arc<AtomicU64>,
667645
log: &Logger,
668-
) -> (ClientTaskHandle, Option<oneshot::Sender<()>>) {
646+
) -> ClientTaskHandle {
669647
// Messages in flight are limited by backpressure, so we can use
670648
// unbounded channels here without fear of runaway.
671649
let (client_request_tx, client_request_rx) = mpsc::unbounded_channel();
672650
let (client_response_tx, client_response_rx) =
673651
mpsc::unbounded_channel();
674652
let (client_stop_tx, client_stop_rx) = oneshot::channel();
675-
let (client_connect_tx, client_connect_rx) = oneshot::channel();
676-
677-
let client_connect_tx = if connect {
678-
client_connect_tx.send(()).unwrap();
679-
None
680-
} else {
681-
Some(client_connect_tx)
682-
};
683653

684654
let log = log.new(o!("" => "io task"));
685655
tokio::spawn(async move {
@@ -689,57 +659,41 @@ impl DownstairsClient {
689659
target,
690660
request_rx: client_request_rx,
691661
response_tx: client_response_tx,
692-
start: client_connect_rx,
693662
stop: client_stop_rx,
694663
recv_task: ClientRxTask {
695664
handle: None,
696665
log: log.clone(),
697666
},
698-
delay,
699667
client_delay_us,
700668
log,
701669
};
702-
c.run().await
670+
c.run(start).await
703671
});
704-
(
705-
ClientTaskHandle {
706-
client_request_tx,
707-
client_stop_tx: Some(client_stop_tx),
708-
client_response_rx,
709-
},
710-
client_connect_tx,
711-
)
672+
ClientTaskHandle {
673+
client_request_tx,
674+
client_stop_tx: Some(client_stop_tx),
675+
client_response_rx,
676+
}
712677
}
713678

714679
/// Starts a dummy IO task, returning its IO handle
715680
#[cfg(test)]
716-
fn new_dummy_task(
717-
connect: bool,
718-
) -> (ClientTaskHandle, Option<oneshot::Sender<()>>) {
681+
fn new_dummy_task(_start: ClientConnectDelay) -> ClientTaskHandle {
719682
let (client_request_tx, client_request_rx) = mpsc::unbounded_channel();
720683
let (_client_response_tx, client_response_rx) =
721684
mpsc::unbounded_channel();
722685
let (client_stop_tx, client_stop_rx) = oneshot::channel();
723-
let (client_connect_tx, client_connect_rx) = oneshot::channel();
724686

725687
// Forget these without dropping them, so that we can send values into
726688
// the void!
727689
std::mem::forget(client_request_rx);
728690
std::mem::forget(client_stop_rx);
729-
std::mem::forget(client_connect_rx);
730691

731-
(
732-
ClientTaskHandle {
733-
client_request_tx,
734-
client_stop_tx: Some(client_stop_tx),
735-
client_response_rx,
736-
},
737-
if connect {
738-
None
739-
} else {
740-
Some(client_connect_tx)
741-
},
742-
)
692+
ClientTaskHandle {
693+
client_request_tx,
694+
client_stop_tx: Some(client_stop_tx),
695+
client_response_rx,
696+
}
743697
}
744698

745699
/// Indicate that the upstairs has requested that we go active
@@ -2225,15 +2179,9 @@ struct ClientIoTask {
22252179
/// Reply channel to the main task
22262180
response_tx: mpsc::UnboundedSender<ClientResponse>,
22272181

2228-
/// Oneshot used to start the task
2229-
start: oneshot::Receiver<()>,
2230-
22312182
/// Oneshot used to stop the task
22322183
stop: oneshot::Receiver<ClientStopReason>,
22332184

2234-
/// Delay on startup, to avoid a busy-loop if connections always fail
2235-
delay: bool,
2236-
22372185
/// Handle for the rx task
22382186
recv_task: ClientRxTask,
22392187

@@ -2245,6 +2193,32 @@ struct ClientIoTask {
22452193
log: Logger,
22462194
}
22472195

2196+
enum ClientConnectDelay {
2197+
/// Connect after a fixed delay
2198+
Delay(std::time::Duration),
2199+
/// Wait for a oneshot to fire before connecting
2200+
Wait(oneshot::Receiver<()>),
2201+
}
2202+
2203+
impl ClientConnectDelay {
2204+
async fn wait(self, log: &Logger) -> Result<(), ClientRunResult> {
2205+
match self {
2206+
ClientConnectDelay::Delay(dur) => {
2207+
info!(log, "sleeping for {dur:?} before connecting");
2208+
tokio::time::sleep(dur).await;
2209+
Ok(())
2210+
}
2211+
ClientConnectDelay::Wait(rx) => {
2212+
info!(log, "client is waiting for oneshot");
2213+
rx.await.map_err(|e| {
2214+
warn!(log, "failed to await start oneshot: {e}");
2215+
ClientRunResult::QueueClosed
2216+
})
2217+
}
2218+
}
2219+
}
2220+
}
2221+
22482222
/// Handle for the rx side of client IO
22492223
///
22502224
/// This is a convenient wrapper so that we can join the task exactly once,
@@ -2293,8 +2267,8 @@ impl Drop for ClientRxTask {
22932267
}
22942268

22952269
impl ClientIoTask {
2296-
async fn run(&mut self) {
2297-
let r = self.run_inner().await;
2270+
async fn run(&mut self, start: ClientConnectDelay) {
2271+
let r = self.run_inner(start).await;
22982272

22992273
warn!(self.log, "client task is sending Done({r:?})");
23002274
if self.response_tx.send(ClientResponse::Done(r)).is_err() {
@@ -2309,41 +2283,18 @@ impl ClientIoTask {
23092283
info!(self.log, "client task is exiting");
23102284
}
23112285

2312-
async fn run_inner(&mut self) -> ClientRunResult {
2313-
// If we're reconnecting, then add a short delay to avoid constantly
2314-
// spinning (e.g. if something is fundamentally wrong with the
2315-
// Downstairs)
2316-
//
2317-
// The upstairs can still stop us here, e.g. if we need to transition
2318-
// from Offline -> Faulted because we hit a job limit, that bounces the
2319-
// IO task (whether it *should* is debatable).
2320-
if self.delay {
2321-
tokio::select! {
2322-
s = &mut self.stop => {
2323-
warn!(
2324-
self.log,
2325-
"client IO task stopped during sleep: {s:?}"
2326-
);
2327-
return s.into();
2328-
}
2329-
_ = tokio::time::sleep(CLIENT_RECONNECT_DELAY) => {
2330-
// this is fine
2331-
},
2332-
}
2333-
}
2334-
2335-
// Wait for the start oneshot to fire. This may happen immediately, but
2336-
// not necessarily (for example, if the client was deactivated). We
2337-
// also wait for the stop oneshot here, in case someone decides to stop
2338-
// the IO task before it tries to connect.
2286+
async fn run_inner(
2287+
&mut self,
2288+
start: ClientConnectDelay,
2289+
) -> ClientRunResult {
2290+
// Wait for either the connection delay to expire (either time-based or
2291+
// a oneshot), or for the stop oneshot to receive a message.
23392292
tokio::select! {
2340-
s = &mut self.start => {
2341-
if let Err(e) = s {
2342-
warn!(self.log, "failed to await start oneshot: {e}");
2343-
return ClientRunResult::QueueClosed;
2293+
r = start.wait(&self.log) => {
2294+
if let Err(e) = r {
2295+
return e;
23442296
}
2345-
// Otherwise, continue as usual
2346-
}
2297+
},
23472298
s = &mut self.stop => {
23482299
warn!(
23492300
self.log,

0 commit comments

Comments
 (0)