@@ -36,6 +36,7 @@ use bitcoin::{BlockHash, ScriptBuf, Transaction, Txid};
3636
3737use core:: future:: Future ;
3838use core:: ops:: Deref ;
39+ use core:: pin:: Pin ;
3940use core:: sync:: atomic:: { AtomicBool , Ordering } ;
4041use core:: task;
4142
@@ -414,7 +415,7 @@ where
414415 /// Returns `Err` on persistence failure, in which case the call may be safely retried.
415416 ///
416417 /// [`Event::SpendableOutputs`]: crate::events::Event::SpendableOutputs
417- pub fn track_spendable_outputs (
418+ pub async fn track_spendable_outputs (
418419 & self , output_descriptors : Vec < SpendableOutputDescriptor > , channel_id : Option < ChannelId > ,
419420 exclude_static_outputs : bool , delay_until_height : Option < u32 > ,
420421 ) -> Result < ( ) , ( ) > {
@@ -430,29 +431,34 @@ where
430431 return Ok ( ( ) ) ;
431432 }
432433
433- let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
434- for descriptor in relevant_descriptors {
435- let output_info = TrackedSpendableOutput {
436- descriptor,
437- channel_id,
438- status : OutputSpendStatus :: PendingInitialBroadcast {
439- delayed_until_height : delay_until_height,
440- } ,
441- } ;
442-
443- let mut outputs = state_lock. persistent . outputs . iter ( ) ;
444- if outputs. find ( |o| o. descriptor == output_info. descriptor ) . is_some ( ) {
445- continue ;
446- }
434+ let persist_fut;
435+ {
436+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
437+ for descriptor in relevant_descriptors {
438+ let output_info = TrackedSpendableOutput {
439+ descriptor,
440+ channel_id,
441+ status : OutputSpendStatus :: PendingInitialBroadcast {
442+ delayed_until_height : delay_until_height,
443+ } ,
444+ } ;
447445
448- state_lock. persistent . outputs . push ( output_info) ;
446+ let mut outputs = state_lock. persistent . outputs . iter ( ) ;
447+ if outputs. find ( |o| o. descriptor == output_info. descriptor ) . is_some ( ) {
448+ continue ;
449+ }
450+
451+ state_lock. persistent . outputs . push ( output_info) ;
452+ }
453+ persist_fut = self . persist_state ( & state_lock. persistent ) ;
454+ state_lock. dirty = false ;
449455 }
450- self . persist_state ( & state_lock. persistent ) . map_err ( |e| {
451- log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
452- } ) ?;
453- state_lock. dirty = false ;
454456
455- Ok ( ( ) )
457+ persist_fut. await . map_err ( |e| {
458+ self . sweeper_state . lock ( ) . unwrap ( ) . dirty = true ;
459+
460+ log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
461+ } )
456462 }
457463
458464 /// Returns a list of the currently tracked spendable outputs.
@@ -508,30 +514,42 @@ where
508514 } ;
509515
510516 // See if there is anything to sweep before requesting a change address.
517+ let persist_fut;
518+ let has_respends;
511519 {
512520 let mut sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
513521
514522 let cur_height = sweeper_state. persistent . best_block . height ;
515- let has_respends =
523+ has_respends =
516524 sweeper_state. persistent . outputs . iter ( ) . any ( |o| filter_fn ( o, cur_height) ) ;
517- if !has_respends {
525+ if !has_respends && sweeper_state . dirty {
518526 // If there is nothing to sweep, we still persist the state if it is dirty.
519- if sweeper_state. dirty {
520- self . persist_state ( & sweeper_state. persistent ) . map_err ( |e| {
521- log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
522- } ) ?;
523- sweeper_state. dirty = false ;
524- }
525-
526- return Ok ( ( ) ) ;
527+ persist_fut = Some ( self . persist_state ( & sweeper_state. persistent ) ) ;
528+ sweeper_state. dirty = false ;
529+ } else {
530+ persist_fut = None ;
527531 }
528532 }
529533
534+ if let Some ( persist_fut) = persist_fut {
535+ persist_fut. await . map_err ( |e| {
536+ self . sweeper_state . lock ( ) . unwrap ( ) . dirty = true ;
537+
538+ log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
539+ } ) ?;
540+ } ;
541+
542+ if !has_respends {
543+ // If there is nothing to sweep, we return early.
544+ return Ok ( ( ) ) ;
545+ }
546+
530547 // Request a new change address outside of the mutex to avoid the mutex crossing await.
531548 let change_destination_script =
532549 self . change_destination_source . get_change_destination_script ( ) . await ?;
533550
534551 // Sweep the outputs.
552+ let persist_fut;
535553 {
536554 let mut sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
537555
@@ -581,14 +599,17 @@ where
581599 output_info. status . broadcast ( cur_hash, cur_height, spending_tx. clone ( ) ) ;
582600 }
583601
584- self . persist_state ( & sweeper_state. persistent ) . map_err ( |e| {
585- log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
586- } ) ?;
602+ persist_fut = self . persist_state ( & sweeper_state. persistent ) ;
587603 sweeper_state. dirty = false ;
588-
589604 self . broadcaster . broadcast_transactions ( & [ & spending_tx] ) ;
590605 }
591606
607+ persist_fut. await . map_err ( |e| {
608+ self . sweeper_state . lock ( ) . unwrap ( ) . dirty = true ;
609+
610+ log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
611+ } ) ?;
612+
592613 Ok ( ( ) )
593614 }
594615
@@ -614,25 +635,19 @@ where
614635 sweeper_state. dirty = true ;
615636 }
616637
617- fn persist_state ( & self , sweeper_state : & PersistentSweeperState ) -> Result < ( ) , io:: Error > {
618- self . kv_store
619- . write (
620- OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE ,
621- OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE ,
622- OUTPUT_SWEEPER_PERSISTENCE_KEY ,
623- & sweeper_state. encode ( ) ,
624- )
625- . map_err ( |e| {
626- log_error ! (
627- self . logger,
628- "Write for key {}/{}/{} failed due to: {}" ,
629- OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE ,
630- OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE ,
631- OUTPUT_SWEEPER_PERSISTENCE_KEY ,
632- e
633- ) ;
634- e
635- } )
638+ fn persist_state < ' a > (
639+ & self , sweeper_state : & PersistentSweeperState ,
640+ ) -> Pin < Box < dyn Future < Output = Result < ( ) , io:: Error > > + ' a + Send > > {
641+ let encoded = & sweeper_state. encode ( ) ;
642+
643+ let result = self . kv_store . write (
644+ OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE ,
645+ OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE ,
646+ OUTPUT_SWEEPER_PERSISTENCE_KEY ,
647+ encoded,
648+ ) ;
649+
650+ Box :: pin ( async move { result } )
636651 }
637652
638653 fn spend_outputs (
@@ -1005,16 +1020,18 @@ where
10051020 }
10061021
10071022 /// Tells the sweeper to track the given outputs descriptors. Wraps [`OutputSweeper::track_spendable_outputs`].
1008- pub fn track_spendable_outputs (
1023+ pub async fn track_spendable_outputs (
10091024 & self , output_descriptors : Vec < SpendableOutputDescriptor > , channel_id : Option < ChannelId > ,
10101025 exclude_static_outputs : bool , delay_until_height : Option < u32 > ,
10111026 ) -> Result < ( ) , ( ) > {
1012- self . sweeper . track_spendable_outputs (
1013- output_descriptors,
1014- channel_id,
1015- exclude_static_outputs,
1016- delay_until_height,
1017- )
1027+ self . sweeper
1028+ . track_spendable_outputs (
1029+ output_descriptors,
1030+ channel_id,
1031+ exclude_static_outputs,
1032+ delay_until_height,
1033+ )
1034+ . await
10181035 }
10191036
10201037 /// Returns a list of the currently tracked spendable outputs. Wraps [`OutputSweeper::tracked_spendable_outputs`].
0 commit comments