diff --git a/fuzz/src/lsps_message.rs b/fuzz/src/lsps_message.rs index 8371d1c5fc7..7c9f74777a1 100644 --- a/fuzz/src/lsps_message.rs +++ b/fuzz/src/lsps_message.rs @@ -88,6 +88,7 @@ pub fn do_test(data: &[u8]) { Arc::clone(&tx_broadcaster), None, None, + None, ) .unwrap(), ); diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index da415c70a32..00f09cb0e7e 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -2519,6 +2519,7 @@ mod tests { Arc::clone(&tx_broadcaster), None, None, + None, ) .unwrap(), ); diff --git a/lightning-liquidity/src/lsps2/event.rs b/lightning-liquidity/src/lsps2/event.rs index 502429b79ec..9ca20863387 100644 --- a/lightning-liquidity/src/lsps2/event.rs +++ b/lightning-liquidity/src/lsps2/event.rs @@ -49,7 +49,17 @@ pub enum LSPS2ClientEvent { /// When the invoice is paid, the LSP will open a channel with the previously agreed upon /// parameters to you. /// + /// For BOLT11 JIT invoices, `intercept_scid` and `cltv_expiry_delta` can be used in a route + /// hint. + /// + /// For BOLT12 JIT flows, register these parameters for your offer id on an + /// [`LSPS2BOLT12Router`] and then proceed with the regular BOLT12 offer + /// flow. The router will inject the LSPS2-specific blinded payment path when creating the + /// invoice. + /// /// **Note: ** This event will *not* be persisted across restarts. + /// + /// [`LSPS2BOLT12Router`]: crate::lsps2::router::LSPS2BOLT12Router InvoiceParametersReady { /// The identifier of the issued bLIP-52 / LSPS2 `buy` request, as returned by /// [`LSPS2ClientHandler::select_opening_params`]. diff --git a/lightning-liquidity/src/lsps2/mod.rs b/lightning-liquidity/src/lsps2/mod.rs index 1d5fb76d3b4..684ad9b26f7 100644 --- a/lightning-liquidity/src/lsps2/mod.rs +++ b/lightning-liquidity/src/lsps2/mod.rs @@ -13,5 +13,6 @@ pub mod client; pub mod event; pub mod msgs; pub(crate) mod payment_queue; +pub mod router; pub mod service; pub mod utils; diff --git a/lightning-liquidity/src/lsps2/router.rs b/lightning-liquidity/src/lsps2/router.rs new file mode 100644 index 00000000000..2eb456982a1 --- /dev/null +++ b/lightning-liquidity/src/lsps2/router.rs @@ -0,0 +1,450 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +//! Router helpers for combining LSPS2 with BOLT12 offer flows. + +use alloc::vec::Vec; + +use crate::prelude::{new_hash_map, HashMap}; +use crate::sync::Mutex; + +use bitcoin::secp256k1::{self, PublicKey, Secp256k1}; + +use lightning::blinded_path::payment::{ + BlindedPaymentPath, Bolt12OfferContext, ForwardTlvs, PaymentConstraints, PaymentContext, + PaymentForwardNode, PaymentRelay, ReceiveTlvs, +}; +use lightning::ln::channel_state::ChannelDetails; +use lightning::ln::channelmanager::{PaymentId, MIN_FINAL_CLTV_EXPIRY_DELTA}; +use lightning::offers::offer::OfferId; +use lightning::routing::router::{InFlightHtlcs, Route, RouteParameters, Router}; +use lightning::sign::{EntropySource, ReceiveAuthKey}; +use lightning::types::features::BlindedHopFeatures; +use lightning::types::payment::PaymentHash; + +/// LSPS2 invoice parameters required to construct BOLT12 blinded payment paths through an LSP. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct LSPS2Bolt12InvoiceParameters { + /// The LSP node id to use as the blinded path introduction node. + pub counterparty_node_id: PublicKey, + /// The LSPS2 intercept short channel id. + pub intercept_scid: u64, + /// The CLTV expiry delta the LSP requires for forwarding over `intercept_scid`. + pub cltv_expiry_delta: u32, +} + +/// A router wrapper that injects LSPS2-specific BOLT12 blinded payment paths for registered +/// offer ids while delegating all other routing behavior to an inner [`Router`]. +pub struct LSPS2BOLT12Router { + inner_router: R, + entropy_source: ES, + offer_to_invoice_params: Mutex>, +} + +impl LSPS2BOLT12Router { + /// Constructs a new wrapper around `inner_router`. + pub fn new(inner_router: R, entropy_source: ES) -> Self { + Self { inner_router, entropy_source, offer_to_invoice_params: Mutex::new(new_hash_map()) } + } + + /// Registers LSPS2 parameters to be used when generating blinded payment paths for `offer_id`. + pub fn register_offer( + &self, offer_id: OfferId, invoice_params: LSPS2Bolt12InvoiceParameters, + ) -> Option { + self.offer_to_invoice_params.lock().unwrap().insert(offer_id.0, invoice_params) + } + + /// Removes any previously registered LSPS2 parameters for `offer_id`. + pub fn unregister_offer(&self, offer_id: &OfferId) -> Option { + self.offer_to_invoice_params.lock().unwrap().remove(&offer_id.0) + } + + /// Clears all LSPS2 parameters previously registered via [`Self::register_offer`]. + pub fn clear_registered_offers(&self) { + self.offer_to_invoice_params.lock().unwrap().clear(); + } + + fn registered_lsps2_params( + &self, payment_context: &PaymentContext, + ) -> Option { + // We intentionally only match `Bolt12Offer` here and not `AsyncBolt12Offer`, as LSPS2 + // JIT channels are not applicable to async (always-online) BOLT12 offer flows. + let Bolt12OfferContext { offer_id, .. } = match payment_context { + PaymentContext::Bolt12Offer(context) => context, + _ => return None, + }; + + self.offer_to_invoice_params.lock().unwrap().get(&offer_id.0).copied() + } +} + +impl Router for LSPS2BOLT12Router { + fn find_route( + &self, payer: &PublicKey, route_params: &RouteParameters, + first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs, + ) -> Result { + self.inner_router.find_route(payer, route_params, first_hops, inflight_htlcs) + } + + fn find_route_with_id( + &self, payer: &PublicKey, route_params: &RouteParameters, + first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs, + payment_hash: PaymentHash, payment_id: PaymentId, + ) -> Result { + self.inner_router.find_route_with_id( + payer, + route_params, + first_hops, + inflight_htlcs, + payment_hash, + payment_id, + ) + } + + fn create_blinded_payment_paths( + &self, recipient: PublicKey, local_node_receive_key: ReceiveAuthKey, + first_hops: Vec, tlvs: ReceiveTlvs, amount_msats: Option, + secp_ctx: &Secp256k1, + ) -> Result, ()> { + let lsps2_invoice_params = match self.registered_lsps2_params(&tlvs.payment_context) { + Some(params) => params, + None => { + return self.inner_router.create_blinded_payment_paths( + recipient, + local_node_receive_key, + first_hops, + tlvs, + amount_msats, + secp_ctx, + ) + }, + }; + + let payment_relay = PaymentRelay { + cltv_expiry_delta: u16::try_from(lsps2_invoice_params.cltv_expiry_delta) + .map_err(|_| ())?, + fee_proportional_millionths: 0, + fee_base_msat: 0, + }; + let payment_constraints = PaymentConstraints { + max_cltv_expiry: tlvs + .payment_constraints + .max_cltv_expiry + .saturating_add(lsps2_invoice_params.cltv_expiry_delta), + htlc_minimum_msat: 0, + }; + + let forward_node = PaymentForwardNode { + tlvs: ForwardTlvs { + short_channel_id: lsps2_invoice_params.intercept_scid, + payment_relay, + payment_constraints, + features: BlindedHopFeatures::empty(), + next_blinding_override: None, + }, + node_id: lsps2_invoice_params.counterparty_node_id, + htlc_maximum_msat: u64::MAX, + }; + + // We deliberately use `BlindedPaymentPath::new` without dummy hops here. Since the LSP + // is the introduction node and already knows the recipient, adding dummy hops would not + // provide meaningful privacy benefits in the LSPS2 JIT channel context. + let path = BlindedPaymentPath::new( + &[forward_node], + recipient, + local_node_receive_key, + tlvs, + u64::MAX, + MIN_FINAL_CLTV_EXPIRY_DELTA, + &self.entropy_source, + secp_ctx, + )?; + + Ok(vec![path]) + } +} + +#[cfg(test)] +mod tests { + use super::{LSPS2BOLT12Router, LSPS2Bolt12InvoiceParameters}; + + use bitcoin::network::Network; + use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; + + use lightning::blinded_path::payment::{ + Bolt12OfferContext, Bolt12RefundContext, PaymentConstraints, PaymentContext, ReceiveTlvs, + }; + use lightning::blinded_path::NodeIdLookUp; + use lightning::ln::channel_state::ChannelDetails; + use lightning::ln::channelmanager::MIN_FINAL_CLTV_EXPIRY_DELTA; + use lightning::offers::invoice_request::InvoiceRequestFields; + use lightning::offers::offer::OfferId; + use lightning::routing::router::{InFlightHtlcs, Route, RouteParameters, Router}; + use lightning::sign::{EntropySource, NodeSigner, ReceiveAuthKey, Recipient}; + use lightning::types::payment::PaymentSecret; + use lightning::util::test_utils::TestKeysInterface; + + use crate::sync::Mutex; + + use core::sync::atomic::{AtomicUsize, Ordering}; + + struct RecordingLookup { + next_node_id: PublicKey, + short_channel_id: Mutex>, + } + + impl NodeIdLookUp for RecordingLookup { + fn next_node_id(&self, short_channel_id: u64) -> Option { + *self.short_channel_id.lock().unwrap() = Some(short_channel_id); + Some(self.next_node_id) + } + } + + #[derive(Clone)] + struct TestEntropy; + + impl EntropySource for TestEntropy { + fn get_secure_random_bytes(&self) -> [u8; 32] { + [42; 32] + } + } + + struct MockRouter { + create_blinded_payment_paths_calls: AtomicUsize, + } + + impl MockRouter { + fn new() -> Self { + Self { create_blinded_payment_paths_calls: AtomicUsize::new(0) } + } + + fn create_blinded_payment_paths_calls(&self) -> usize { + self.create_blinded_payment_paths_calls.load(Ordering::Acquire) + } + } + + impl Router for MockRouter { + fn find_route( + &self, _payer: &PublicKey, _route_params: &RouteParameters, + _first_hops: Option<&[&ChannelDetails]>, _inflight_htlcs: InFlightHtlcs, + ) -> Result { + Err("mock router") + } + + fn create_blinded_payment_paths< + T: bitcoin::secp256k1::Signing + bitcoin::secp256k1::Verification, + >( + &self, _recipient: PublicKey, _local_node_receive_key: ReceiveAuthKey, + _first_hops: Vec, _tlvs: ReceiveTlvs, _amount_msats: Option, + _secp_ctx: &Secp256k1, + ) -> Result, ()> { + self.create_blinded_payment_paths_calls.fetch_add(1, Ordering::AcqRel); + Err(()) + } + } + + fn pubkey(byte: u8) -> PublicKey { + let secret_key = SecretKey::from_slice(&[byte; 32]).unwrap(); + PublicKey::from_secret_key(&Secp256k1::new(), &secret_key) + } + + fn bolt12_offer_tlvs(offer_id: OfferId) -> ReceiveTlvs { + ReceiveTlvs { + payment_secret: PaymentSecret([2; 32]), + payment_constraints: PaymentConstraints { max_cltv_expiry: 100, htlc_minimum_msat: 1 }, + payment_context: PaymentContext::Bolt12Offer(Bolt12OfferContext { + offer_id, + invoice_request: InvoiceRequestFields { + payer_signing_pubkey: pubkey(9), + quantity: None, + payer_note_truncated: None, + human_readable_name: None, + }, + }), + } + } + + fn bolt12_refund_tlvs() -> ReceiveTlvs { + ReceiveTlvs { + payment_secret: PaymentSecret([2; 32]), + payment_constraints: PaymentConstraints { max_cltv_expiry: 100, htlc_minimum_msat: 1 }, + payment_context: PaymentContext::Bolt12Refund(Bolt12RefundContext {}), + } + } + + #[test] + fn creates_lsps2_blinded_path_for_registered_offer() { + let inner_router = MockRouter::new(); + let entropy_source = TestEntropy; + let router = LSPS2BOLT12Router::new(inner_router, entropy_source); + + let offer_id = OfferId([8; 32]); + let lsp_keys = TestKeysInterface::new(&[43; 32], Network::Testnet); + let lsp_node_id = lsp_keys.get_node_id(Recipient::Node).unwrap(); + + let expected_scid = 42; + let expected_cltv_delta = 48; + let recipient = pubkey(10); + + router.register_offer( + offer_id, + LSPS2Bolt12InvoiceParameters { + counterparty_node_id: lsp_node_id, + intercept_scid: expected_scid, + cltv_expiry_delta: expected_cltv_delta, + }, + ); + + let secp_ctx = Secp256k1::new(); + let mut paths = router + .create_blinded_payment_paths( + recipient, + ReceiveAuthKey([3; 32]), + Vec::new(), + bolt12_offer_tlvs(offer_id), + Some(5_000), + &secp_ctx, + ) + .unwrap(); + + assert_eq!(paths.len(), 1); + let mut path = paths.pop().unwrap(); + assert_eq!( + path.introduction_node(), + &lightning::blinded_path::IntroductionNode::NodeId(lsp_node_id) + ); + assert_eq!(path.payinfo.fee_base_msat, 0); + assert_eq!(path.payinfo.fee_proportional_millionths, 0); + assert_eq!( + path.payinfo.cltv_expiry_delta, + expected_cltv_delta as u16 + MIN_FINAL_CLTV_EXPIRY_DELTA + ); + + let lookup = + RecordingLookup { next_node_id: recipient, short_channel_id: Mutex::new(None) }; + path.advance_path_by_one(&lsp_keys, &lookup, &secp_ctx).unwrap(); + assert_eq!(*lookup.short_channel_id.lock().unwrap(), Some(expected_scid)); + } + + #[test] + fn delegates_when_offer_is_not_registered() { + let inner_router = MockRouter::new(); + let entropy_source = TestEntropy; + let router = LSPS2BOLT12Router::new(inner_router, entropy_source); + let secp_ctx = Secp256k1::new(); + + let result = router.create_blinded_payment_paths( + pubkey(10), + ReceiveAuthKey([3; 32]), + Vec::new(), + bolt12_refund_tlvs(), + Some(10_000), + &secp_ctx, + ); + + assert!(result.is_err()); + assert_eq!(router.inner_router.create_blinded_payment_paths_calls(), 1); + } + + #[test] + fn delegates_when_offer_id_is_not_registered() { + let inner_router = MockRouter::new(); + let entropy_source = TestEntropy; + let router = LSPS2BOLT12Router::new(inner_router, entropy_source); + let secp_ctx = Secp256k1::new(); + + // Use a Bolt12Offer context with an OfferId that was never registered. + let unregistered_offer_id = OfferId([99; 32]); + let result = router.create_blinded_payment_paths( + pubkey(10), + ReceiveAuthKey([3; 32]), + Vec::new(), + bolt12_offer_tlvs(unregistered_offer_id), + Some(10_000), + &secp_ctx, + ); + + assert!(result.is_err()); + assert_eq!(router.inner_router.create_blinded_payment_paths_calls(), 1); + } + + #[test] + fn rejects_out_of_range_cltv_delta() { + let inner_router = MockRouter::new(); + let entropy_source = TestEntropy; + let router = LSPS2BOLT12Router::new(inner_router, entropy_source); + + let offer_id = OfferId([11; 32]); + router.register_offer( + offer_id, + LSPS2Bolt12InvoiceParameters { + counterparty_node_id: pubkey(12), + intercept_scid: 21, + cltv_expiry_delta: u32::from(u16::MAX) + 1, + }, + ); + + let secp_ctx = Secp256k1::new(); + let result = router.create_blinded_payment_paths( + pubkey(13), + ReceiveAuthKey([3; 32]), + Vec::new(), + bolt12_offer_tlvs(offer_id), + Some(1_000), + &secp_ctx, + ); + + assert!(result.is_err()); + } + + #[test] + fn can_unregister_offer() { + let inner_router = MockRouter::new(); + let entropy_source = TestEntropy; + let router = LSPS2BOLT12Router::new(inner_router, entropy_source); + + let offer_id = OfferId([1; 32]); + let params = LSPS2Bolt12InvoiceParameters { + counterparty_node_id: pubkey(2), + intercept_scid: 7, + cltv_expiry_delta: 40, + }; + assert_eq!(router.register_offer(offer_id, params), None); + assert_eq!(router.unregister_offer(&offer_id), Some(params)); + assert_eq!(router.unregister_offer(&offer_id), None); + } + + #[test] + fn can_clear_registered_offers() { + let inner_router = MockRouter::new(); + let entropy_source = TestEntropy; + let router = LSPS2BOLT12Router::new(inner_router, entropy_source); + + router.register_offer( + OfferId([1; 32]), + LSPS2Bolt12InvoiceParameters { + counterparty_node_id: pubkey(2), + intercept_scid: 7, + cltv_expiry_delta: 40, + }, + ); + router.register_offer( + OfferId([2; 32]), + LSPS2Bolt12InvoiceParameters { + counterparty_node_id: pubkey(3), + intercept_scid: 8, + cltv_expiry_delta: 41, + }, + ); + + router.clear_registered_offers(); + assert_eq!(router.unregister_offer(&OfferId([1; 32])), None); + assert_eq!(router.unregister_offer(&OfferId([2; 32])), None); + } +} diff --git a/lightning-liquidity/src/lsps2/service.rs b/lightning-liquidity/src/lsps2/service.rs index 35942dcd624..4fe80f62c58 100644 --- a/lightning-liquidity/src/lsps2/service.rs +++ b/lightning-liquidity/src/lsps2/service.rs @@ -45,6 +45,7 @@ use lightning::events::HTLCHandlingFailureType; use lightning::ln::channelmanager::{AChannelManager, FailureCode, InterceptId}; use lightning::ln::msgs::{ErrorAction, LightningError}; use lightning::ln::types::ChannelId; +use lightning::onion_message::messenger::OnionMessageInterceptor; use lightning::util::errors::APIError; use lightning::util::logger::Level; use lightning::util::ser::Writeable; @@ -631,17 +632,20 @@ impl PeerState { }); } - fn prune_expired_request_state(&mut self) { + fn prune_expired_request_state(&mut self) -> Vec { + let mut pruned_scids = Vec::new(); self.outbound_channels_by_intercept_scid.retain(|intercept_scid, entry| { if entry.is_prunable() { // We abort the flow, and prune any data kept. self.intercept_scid_by_channel_id.retain(|_, iscid| intercept_scid != iscid); self.intercept_scid_by_user_channel_id.retain(|_, iscid| intercept_scid != iscid); self.needs_persist |= true; + pruned_scids.push(*intercept_scid); return false; } true }); + pruned_scids } fn pending_requests_and_channels(&self) -> usize { @@ -717,6 +721,7 @@ where total_pending_requests: AtomicUsize, config: LSPS2ServiceConfig, persistence_in_flight: AtomicUsize, + onion_message_interceptor: Option>, } impl LSPS2ServiceHandler @@ -728,6 +733,7 @@ where per_peer_state: HashMap>, pending_messages: Arc, pending_events: Arc>, channel_manager: CM, kv_store: K, tx_broadcaster: T, config: LSPS2ServiceConfig, + onion_message_interceptor: Option>, ) -> Result { let mut peer_by_intercept_scid = new_hash_map(); let mut peer_by_channel_id = new_hash_map(); @@ -756,6 +762,14 @@ where } } + // Register all peers with active intercept SCIDs for onion message interception, + // so that messages for offline peers are held rather than dropped. + if let Some(ref interceptor) = onion_message_interceptor { + for node_id in peer_by_intercept_scid.values() { + interceptor.register_peer_for_interception(*node_id); + } + } + Ok(Self { pending_messages, pending_events, @@ -768,6 +782,7 @@ where kv_store, tx_broadcaster, config, + onion_message_interceptor, }) } @@ -776,6 +791,29 @@ where &self.config } + /// Cleans up `peer_by_intercept_scid` entries for the given SCIDs, and deregisters the peer + /// from onion message interception if they have no remaining active intercept SCIDs. + fn cleanup_intercept_scids( + &self, counterparty_node_id: &PublicKey, pruned_scids: &[u64], has_remaining_channels: bool, + ) { + if pruned_scids.is_empty() { + return; + } + + { + let mut peer_by_intercept_scid = self.peer_by_intercept_scid.write().unwrap(); + for scid in pruned_scids { + peer_by_intercept_scid.remove(scid); + } + } + + if !has_remaining_channels { + if let Some(ref interceptor) = self.onion_message_interceptor { + interceptor.deregister_peer_for_interception(counterparty_node_id); + } + } + } + /// Returns whether the peer has any active LSPS2 requests. pub(crate) fn has_active_requests(&self, counterparty_node_id: &PublicKey) -> bool { let outer_state_lock = self.per_peer_state.read().unwrap(); @@ -921,6 +959,10 @@ where peer_by_intercept_scid.insert(intercept_scid, *counterparty_node_id); } + if let Some(ref interceptor) = self.onion_message_interceptor { + interceptor.register_peer_for_interception(*counterparty_node_id); + } + let outbound_jit_channel = OutboundJITChannel::new( buy_request.payment_size_msat, buy_request.opening_fee_params, @@ -1051,7 +1093,15 @@ where peer_state .outbound_channels_by_intercept_scid .remove(&intercept_scid); - // TODO: cleanup peer_by_intercept_scid + let has_remaining = + !peer_state.outbound_channels_by_intercept_scid.is_empty(); + drop(peer_state); + drop(outer_state_lock); + self.cleanup_intercept_scids( + counterparty_node_id, + &[intercept_scid], + has_remaining, + ); return Err(APIError::APIMisuseError { err: e.err }); }, } @@ -1270,7 +1320,7 @@ where pub async fn channel_open_abandoned( &self, counterparty_node_id: &PublicKey, user_channel_id: u128, ) -> Result<(), APIError> { - { + let (intercept_scid, has_remaining) = { let outer_state_lock = self.per_peer_state.read().unwrap(); let inner_state_lock = outer_state_lock.get(counterparty_node_id).ok_or_else(|| { APIError::APIMisuseError { @@ -1317,7 +1367,11 @@ where peer_state.outbound_channels_by_intercept_scid.remove(&intercept_scid); peer_state.intercept_scid_by_channel_id.retain(|_, &mut scid| scid != intercept_scid); peer_state.needs_persist |= true; - } + let has_remaining = !peer_state.outbound_channels_by_intercept_scid.is_empty(); + (intercept_scid, has_remaining) + }; + + self.cleanup_intercept_scids(counterparty_node_id, &[intercept_scid], has_remaining); self.persist_peer_state(*counterparty_node_id).await.map_err(|e| { APIError::APIMisuseError { @@ -1801,10 +1855,16 @@ where { // First build a list of peers to persist and prune with the read lock. This allows // us to avoid the write lock unless we actually need to remove a node. + let mut all_pruned_scids = Vec::new(); let outer_state_lock = self.per_peer_state.read().unwrap(); for (counterparty_node_id, inner_state_lock) in outer_state_lock.iter() { let mut peer_state_lock = inner_state_lock.lock().unwrap(); - peer_state_lock.prune_expired_request_state(); + let pruned_scids = peer_state_lock.prune_expired_request_state(); + if !pruned_scids.is_empty() { + let has_remaining = + !peer_state_lock.outbound_channels_by_intercept_scid.is_empty(); + all_pruned_scids.push((*counterparty_node_id, pruned_scids, has_remaining)); + } let is_prunable = peer_state_lock.is_prunable(); if is_prunable { need_remove.push(*counterparty_node_id); @@ -1812,6 +1872,15 @@ where need_persist.push(*counterparty_node_id); } } + drop(outer_state_lock); + + for (counterparty_node_id, pruned_scids, has_remaining) in all_pruned_scids { + self.cleanup_intercept_scids( + &counterparty_node_id, + &pruned_scids, + has_remaining, + ); + } } for counterparty_node_id in need_persist.into_iter() { @@ -1822,6 +1891,7 @@ where for counterparty_node_id in need_remove { let mut future_opt = None; + let mut was_removed = false; { // We need to take the `per_peer_state` write lock to remove an entry, but also // have to hold it until after the `remove` call returns (but not through @@ -1833,6 +1903,7 @@ where let state = entry.get_mut().get_mut().unwrap(); if state.is_prunable() { entry.remove(); + was_removed = true; let key = counterparty_node_id.to_string(); future_opt = Some(self.kv_store.remove( LIQUIDITY_MANAGER_PERSISTENCE_PRIMARY_NAMESPACE, @@ -1850,6 +1921,20 @@ where debug_assert!(false); } } + if was_removed { + // Clean up handler-level maps for the removed peer. + self.peer_by_intercept_scid + .write() + .unwrap() + .retain(|_, node_id| *node_id != counterparty_node_id); + self.peer_by_channel_id + .write() + .unwrap() + .retain(|_, node_id| *node_id != counterparty_node_id); + if let Some(ref interceptor) = self.onion_message_interceptor { + interceptor.deregister_peer_for_interception(&counterparty_node_id); + } + } if let Some(future) = future_opt { future.await?; did_persist = true; @@ -1877,7 +1962,11 @@ where // We clean up the peer state, but leave removing the peer entry to the prune logic in // `persist` which removes it from the store. peer_state_lock.prune_pending_requests(); - peer_state_lock.prune_expired_request_state(); + let pruned_scids = peer_state_lock.prune_expired_request_state(); + let has_remaining = !peer_state_lock.outbound_channels_by_intercept_scid.is_empty(); + drop(peer_state_lock); + drop(outer_state_lock); + self.cleanup_intercept_scids(&counterparty_node_id, &pruned_scids, has_remaining); } } diff --git a/lightning-liquidity/src/manager.rs b/lightning-liquidity/src/manager.rs index 1f11fc8add7..0d71cb2f6bd 100644 --- a/lightning-liquidity/src/manager.rs +++ b/lightning-liquidity/src/manager.rs @@ -48,6 +48,7 @@ use lightning::ln::channelmanager::{AChannelManager, ChainParameters}; use lightning::ln::msgs::{ErrorAction, LightningError}; use lightning::ln::peer_handler::CustomMessageHandler; use lightning::ln::wire::CustomMessageReader; +use lightning::onion_message::messenger::OnionMessageInterceptor; use lightning::sign::{EntropySource, NodeSigner}; use lightning::util::logger::Level; use lightning::util::persist::{KVStore, KVStoreSync, KVStoreSyncWrapper}; @@ -330,6 +331,7 @@ where chain_params: Option, kv_store: K, transaction_broadcaster: T, service_config: Option, client_config: Option, + onion_message_interceptor: Option>, ) -> Result { Self::new_with_custom_time_provider( entropy_source, @@ -342,6 +344,7 @@ where service_config, client_config, DefaultTimeProvider, + onion_message_interceptor, ) .await } @@ -373,6 +376,7 @@ where chain_source: Option, chain_params: Option, kv_store: K, service_config: Option, client_config: Option, time_provider: TP, + onion_message_interceptor: Option>, ) -> Result { let pending_msgs_or_needs_persist_notifier = Arc::new(Notifier::new()); let pending_messages = @@ -415,6 +419,7 @@ where kv_store.clone(), transaction_broadcaster.clone(), lsps2_service_config.clone(), + onion_message_interceptor.clone(), )?) } else { None @@ -1044,6 +1049,7 @@ where chain_params: Option, kv_store_sync: KS, transaction_broadcaster: T, service_config: Option, client_config: Option, + onion_message_interceptor: Option>, ) -> Result { let kv_store = KVStoreSyncWrapper(kv_store_sync); @@ -1057,6 +1063,7 @@ where transaction_broadcaster, service_config, client_config, + onion_message_interceptor, )); let mut waker = dummy_waker(); @@ -1094,6 +1101,7 @@ where chain_params: Option, kv_store_sync: KS, transaction_broadcaster: T, service_config: Option, client_config: Option, time_provider: TP, + onion_message_interceptor: Option>, ) -> Result { let kv_store = KVStoreSyncWrapper(kv_store_sync); let mut fut = pin!(LiquidityManager::new_with_custom_time_provider( @@ -1107,6 +1115,7 @@ where service_config, client_config, time_provider, + onion_message_interceptor, )); let mut waker = dummy_waker(); diff --git a/lightning-liquidity/tests/common/mod.rs b/lightning-liquidity/tests/common/mod.rs index dea987527ad..7bd3adf7043 100644 --- a/lightning-liquidity/tests/common/mod.rs +++ b/lightning-liquidity/tests/common/mod.rs @@ -47,6 +47,7 @@ fn build_service_and_client_nodes<'a, 'b, 'c>( Some(service_config), None, Arc::clone(&time_provider), + None, ) .unwrap(); @@ -61,6 +62,7 @@ fn build_service_and_client_nodes<'a, 'b, 'c>( None, Some(client_config), time_provider, + None, ) .unwrap(); diff --git a/lightning-liquidity/tests/lsps2_integration_tests.rs b/lightning-liquidity/tests/lsps2_integration_tests.rs index 33a6dd697cf..65df9911dda 100644 --- a/lightning-liquidity/tests/lsps2_integration_tests.rs +++ b/lightning-liquidity/tests/lsps2_integration_tests.rs @@ -14,7 +14,12 @@ use lightning::ln::functional_test_utils::*; use lightning::ln::msgs::BaseMessageHandler; use lightning::ln::msgs::ChannelMessageHandler; use lightning::ln::msgs::MessageSendEvent; +use lightning::ln::msgs::OnionMessageHandler; use lightning::ln::types::ChannelId; +use lightning::offers::invoice_request::InvoiceRequestFields; +use lightning::offers::offer::OfferId; +use lightning::routing::router::{InFlightHtlcs, Route, RouteParameters, Router}; +use lightning::sign::{RandomBytes, ReceiveAuthKey}; use lightning_liquidity::events::LiquidityEvent; use lightning_liquidity::lsps0::ser::LSPSDateTime; @@ -22,11 +27,16 @@ use lightning_liquidity::lsps2::client::LSPS2ClientConfig; use lightning_liquidity::lsps2::event::LSPS2ClientEvent; use lightning_liquidity::lsps2::event::LSPS2ServiceEvent; use lightning_liquidity::lsps2::msgs::LSPS2RawOpeningFeeParams; +use lightning_liquidity::lsps2::router::{LSPS2BOLT12Router, LSPS2Bolt12InvoiceParameters}; use lightning_liquidity::lsps2::service::LSPS2ServiceConfig; use lightning_liquidity::lsps2::utils::is_valid_opening_fee_params; use lightning_liquidity::utils::time::{DefaultTimeProvider, TimeProvider}; use lightning_liquidity::{LiquidityClientConfig, LiquidityManagerSync, LiquidityServiceConfig}; +use lightning::blinded_path::payment::{ + Bolt12OfferContext, PaymentConstraints, PaymentContext, ReceiveTlvs, +}; +use lightning::blinded_path::NodeIdLookUp; use lightning::chain::{BestBlock, Filter}; use lightning::ln::channelmanager::{ChainParameters, InterceptId, MIN_FINAL_CLTV_EXPIRY_DELTA}; use lightning::ln::functional_test_utils::{ @@ -57,6 +67,46 @@ use std::time::Duration; const MAX_PENDING_REQUESTS_PER_PEER: usize = 10; const MAX_TOTAL_PENDING_REQUESTS: usize = 1000; +struct RecordingLookup { + next_node_id: PublicKey, + short_channel_id: std::sync::Mutex>, +} + +impl NodeIdLookUp for RecordingLookup { + fn next_node_id(&self, short_channel_id: u64) -> Option { + *self.short_channel_id.lock().unwrap() = Some(short_channel_id); + Some(self.next_node_id) + } +} + +struct FailingRouter; + +impl FailingRouter { + fn new() -> Self { + Self + } +} + +impl Router for FailingRouter { + fn find_route( + &self, _payer: &PublicKey, _route_params: &RouteParameters, + _first_hops: Option<&[&lightning::ln::channel_state::ChannelDetails]>, + _inflight_htlcs: InFlightHtlcs, + ) -> Result { + Err("failing test router") + } + + fn create_blinded_payment_paths< + T: bitcoin::secp256k1::Signing + bitcoin::secp256k1::Verification, + >( + &self, _recipient: PublicKey, _local_node_receive_key: ReceiveAuthKey, + _first_hops: Vec, _tlvs: ReceiveTlvs, + _amount_msats: Option, _secp_ctx: &Secp256k1, + ) -> Result, ()> { + Err(()) + } +} + fn build_lsps2_configs() -> ([u8; 32], LiquidityServiceConfig, LiquidityClientConfig) { let promise_secret = [42; 32]; let lsps2_service_config = LSPS2ServiceConfig { promise_secret }; @@ -1089,6 +1139,7 @@ fn lsps2_service_handler_persistence_across_restarts() { Some(service_config), None, time_provider, + None, ) .unwrap(); @@ -1486,6 +1537,349 @@ fn execute_lsps2_dance( } } +#[test] +fn bolt12_custom_router_uses_lsps2_intercept_scid() { + let chanmon_cfgs = create_chanmon_cfgs(3); + let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); + let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); + let nodes = create_network(3, &node_cfgs, &node_chanmgrs); + let (lsps_nodes, promise_secret) = setup_test_lsps2_nodes_with_payer(nodes); + + let service_node_id = lsps_nodes.service_node.inner.node.get_our_node_id(); + let client_node_id = lsps_nodes.client_node.inner.node.get_our_node_id(); + + let intercept_scid = lsps_nodes.service_node.node.get_intercept_scid(); + let cltv_expiry_delta = 72; + + execute_lsps2_dance( + &lsps_nodes, + intercept_scid, + 42, + cltv_expiry_delta, + promise_secret, + Some(250_000), + 1_000, + ); + + let inner_router = FailingRouter::new(); + let router = LSPS2BOLT12Router::new(inner_router, lsps_nodes.client_node.keys_manager); + let offer_id = OfferId([42; 32]); + + router.register_offer( + offer_id, + LSPS2Bolt12InvoiceParameters { + counterparty_node_id: service_node_id, + intercept_scid, + cltv_expiry_delta, + }, + ); + + let tlvs = ReceiveTlvs { + payment_secret: lightning_types::payment::PaymentSecret([7; 32]), + payment_constraints: PaymentConstraints { max_cltv_expiry: 50, htlc_minimum_msat: 1 }, + payment_context: PaymentContext::Bolt12Offer(Bolt12OfferContext { + offer_id, + invoice_request: InvoiceRequestFields { + payer_signing_pubkey: lsps_nodes.payer_node.node.get_our_node_id(), + quantity: None, + payer_note_truncated: None, + human_readable_name: None, + }, + }), + }; + + let secp_ctx = Secp256k1::new(); + let mut paths = router + .create_blinded_payment_paths( + client_node_id, + ReceiveAuthKey([3; 32]), + Vec::new(), + tlvs, + Some(100_000), + &secp_ctx, + ) + .unwrap(); + + assert_eq!(paths.len(), 1); + let mut path = paths.pop().unwrap(); + assert_eq!( + path.introduction_node(), + &lightning::blinded_path::IntroductionNode::NodeId(service_node_id) + ); + assert_eq!(path.payinfo.fee_base_msat, 0); + assert_eq!(path.payinfo.fee_proportional_millionths, 0); + + let lookup = RecordingLookup { + next_node_id: client_node_id, + short_channel_id: std::sync::Mutex::new(None), + }; + path.advance_path_by_one(lsps_nodes.service_node.keys_manager, &lookup, &secp_ctx).unwrap(); + assert_eq!(*lookup.short_channel_id.lock().unwrap(), Some(intercept_scid)); +} + +#[test] +fn bolt12_lsps2_end_to_end_test() { + // End-to-end test of the BOLT12 + LSPS2 JIT channel flow. Three nodes: payer, service, client. + // client_trusts_lsp=true; funding transaction broadcast happens after client claims the HTLC. + let chanmon_cfgs = create_chanmon_cfgs(3); + let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); + + let mut service_node_config = test_default_channel_config(); + service_node_config.htlc_interception_flags = HTLCInterceptionFlags::ToInterceptSCIDs as u8; + + let mut client_node_config = test_default_channel_config(); + client_node_config.accept_inbound_channels = true; + client_node_config.channel_config.accept_underpaying_htlcs = true; + + let node_chanmgrs = create_node_chanmgrs( + 3, + &node_cfgs, + &[Some(service_node_config), Some(client_node_config), None], + ); + let nodes = create_network(3, &node_cfgs, &node_chanmgrs); + let (lsps_nodes, promise_secret) = setup_test_lsps2_nodes_with_payer(nodes); + let LSPSNodesWithPayer { ref service_node, ref client_node, ref payer_node } = lsps_nodes; + + let payer_node_id = payer_node.node.get_our_node_id(); + let service_node_id = service_node.inner.node.get_our_node_id(); + let client_node_id = client_node.inner.node.get_our_node_id(); + + let service_handler = service_node.liquidity_manager.lsps2_service_handler().unwrap(); + + create_chan_between_nodes_with_value(&payer_node, &service_node.inner, 2_000_000, 100_000); + + let intercept_scid = service_node.node.get_intercept_scid(); + let user_channel_id = 42; + let cltv_expiry_delta: u32 = 144; + let payment_size_msat = Some(1_000_000); + let fee_base_msat = 1_000; + + execute_lsps2_dance( + &lsps_nodes, + intercept_scid, + user_channel_id, + cltv_expiry_delta, + promise_secret, + payment_size_msat, + fee_base_msat, + ); + + // Disconnect payer from client to ensure deterministic onion message routing through service. + payer_node.node.peer_disconnected(client_node_id); + client_node.node.peer_disconnected(payer_node_id); + payer_node.onion_messenger.peer_disconnected(client_node_id); + client_node.onion_messenger.peer_disconnected(payer_node_id); + + #[cfg(c_bindings)] + let offer = { + let mut offer_builder = client_node.node.create_offer_builder().unwrap(); + offer_builder.amount_msats(payment_size_msat.unwrap()); + offer_builder.build().unwrap() + }; + #[cfg(not(c_bindings))] + let offer = client_node + .node + .create_offer_builder() + .unwrap() + .amount_msats(payment_size_msat.unwrap()) + .build() + .unwrap(); + + let lsps2_router = Arc::new(LSPS2BOLT12Router::new( + FailingRouter::new(), + Arc::new(RandomBytes::new([43; 32])), + )); + lsps2_router.register_offer( + offer.id(), + LSPS2Bolt12InvoiceParameters { + counterparty_node_id: service_node_id, + intercept_scid, + cltv_expiry_delta, + }, + ); + + let lsps2_router = Arc::clone(&lsps2_router); + *client_node.router.override_create_blinded_payment_paths.lock().unwrap() = + Some(Box::new(move |recipient, local_node_receive_key, first_hops, tlvs, amount_msats| { + let secp_ctx = Secp256k1::new(); + lsps2_router.create_blinded_payment_paths( + recipient, + local_node_receive_key, + first_hops, + tlvs, + amount_msats, + &secp_ctx, + ) + })); + + let payment_id = PaymentId([1; 32]); + payer_node.node.pay_for_offer(&offer, None, payment_id, Default::default()).unwrap(); + + let onion_msg = payer_node + .onion_messenger + .next_onion_message_for_peer(service_node_id) + .expect("Payer should send InvoiceRequest toward service"); + service_node.onion_messenger.handle_onion_message(payer_node_id, &onion_msg); + + let fwd_msg = service_node + .onion_messenger + .next_onion_message_for_peer(client_node_id) + .expect("Service should forward InvoiceRequest to client"); + client_node.onion_messenger.handle_onion_message(service_node_id, &fwd_msg); + + let onion_msg = client_node + .onion_messenger + .next_onion_message_for_peer(service_node_id) + .expect("Client should send Invoice toward service"); + service_node.onion_messenger.handle_onion_message(client_node_id, &onion_msg); + + let fwd_msg = service_node + .onion_messenger + .next_onion_message_for_peer(payer_node_id) + .expect("Service should forward Invoice to payer"); + payer_node.onion_messenger.handle_onion_message(service_node_id, &fwd_msg); + + check_added_monitors(&payer_node, 1); + let events = payer_node.node.get_and_clear_pending_msg_events(); + assert_eq!(events.len(), 1); + let ev = SendEvent::from_event(events[0].clone()); + + service_node.inner.node.handle_update_add_htlc(payer_node_id, &ev.msgs[0]); + do_commitment_signed_dance(&service_node.inner, &payer_node, &ev.commitment_msg, false, true); + service_node.inner.node.process_pending_htlc_forwards(); + + let events = service_node.inner.node.get_and_clear_pending_events(); + assert_eq!(events.len(), 1); + let (payment_hash, expected_outbound_amount_msat) = match &events[0] { + Event::HTLCIntercepted { + intercept_id, + requested_next_hop_scid, + payment_hash, + expected_outbound_amount_msat, + .. + } => { + assert_eq!(*requested_next_hop_scid, intercept_scid); + + service_handler + .htlc_intercepted( + *requested_next_hop_scid, + *intercept_id, + *expected_outbound_amount_msat, + *payment_hash, + ) + .unwrap(); + (*payment_hash, expected_outbound_amount_msat) + }, + other => panic!("Expected HTLCIntercepted event, got: {:?}", other), + }; + + let open_channel_event = service_node.liquidity_manager.next_event().unwrap(); + + match open_channel_event { + LiquidityEvent::LSPS2Service(LSPS2ServiceEvent::OpenChannel { + their_network_key, + amt_to_forward_msat, + opening_fee_msat, + user_channel_id: uc_id, + intercept_scid: iscd, + }) => { + assert_eq!(their_network_key, client_node_id); + assert_eq!(amt_to_forward_msat, payment_size_msat.unwrap() - fee_base_msat); + assert_eq!(opening_fee_msat, fee_base_msat); + assert_eq!(uc_id, user_channel_id); + assert_eq!(iscd, intercept_scid); + }, + other => panic!("Expected OpenChannel event, got: {:?}", other), + }; + + let result = + service_handler.channel_needs_manual_broadcast(user_channel_id, &client_node_id).unwrap(); + assert!(result, "Channel should require manual broadcast"); + + let (channel_id, funding_tx) = create_channel_with_manual_broadcast( + &service_node_id, + &client_node_id, + &service_node, + &client_node, + user_channel_id, + expected_outbound_amount_msat, + true, + ); + + service_handler.channel_ready(user_channel_id, &channel_id, &client_node_id).unwrap(); + + service_node.inner.node.process_pending_htlc_forwards(); + + let pay_event = { + { + let mut added_monitors = + service_node.inner.chain_monitor.added_monitors.lock().unwrap(); + assert_eq!(added_monitors.len(), 1); + added_monitors.clear(); + } + let mut events = service_node.inner.node.get_and_clear_pending_msg_events(); + assert_eq!(events.len(), 1); + SendEvent::from_event(events.remove(0)) + }; + + client_node.inner.node.handle_update_add_htlc(service_node_id, &pay_event.msgs[0]); + do_commitment_signed_dance( + &client_node.inner, + &service_node.inner, + &pay_event.commitment_msg, + false, + true, + ); + client_node.inner.node.process_pending_htlc_forwards(); + + let client_events = client_node.inner.node.get_and_clear_pending_events(); + assert_eq!(client_events.len(), 1); + let preimage = match &client_events[0] { + Event::PaymentClaimable { payment_hash: ph, purpose, .. } => { + assert_eq!(*ph, payment_hash); + purpose.preimage() + }, + other => panic!("Expected PaymentClaimable event on client, got: {:?}", other), + }; + + let broadcasted = service_node.inner.tx_broadcaster.txn_broadcasted.lock().unwrap(); + assert!(broadcasted.is_empty(), "There should be no broadcasted txs yet"); + drop(broadcasted); + + client_node.inner.node.claim_funds(preimage.unwrap()); + + claim_and_assert_forwarded_only( + &payer_node, + &service_node.inner, + &client_node.inner, + preimage.unwrap(), + ); + + let service_events = service_node.node.get_and_clear_pending_events(); + assert_eq!(service_events.len(), 1); + + let total_fee_msat = match service_events[0].clone() { + Event::PaymentForwarded { + prev_node_id, + next_node_id, + skimmed_fee_msat, + total_fee_earned_msat, + .. + } => { + assert_eq!(prev_node_id, Some(payer_node_id)); + assert_eq!(next_node_id, Some(client_node_id)); + service_handler.payment_forwarded(channel_id, skimmed_fee_msat.unwrap_or(0)).unwrap(); + Some(total_fee_earned_msat.unwrap() - skimmed_fee_msat.unwrap()) + }, + _ => panic!("Expected PaymentForwarded event, got: {:?}", service_events[0]), + }; + + let broadcasted = service_node.inner.tx_broadcaster.txn_broadcasted.lock().unwrap(); + assert!(broadcasted.iter().any(|b| b.compute_txid() == funding_tx.compute_txid())); + + expect_payment_sent(&payer_node, preimage.unwrap(), Some(total_fee_msat), true, true); +} + fn create_channel_with_manual_broadcast( service_node_id: &PublicKey, client_node_id: &PublicKey, service_node: &LiquidityNode, client_node: &LiquidityNode, user_channel_id: u128, expected_outbound_amount_msat: &u64, diff --git a/lightning-liquidity/tests/lsps5_integration_tests.rs b/lightning-liquidity/tests/lsps5_integration_tests.rs index 16f20fd095f..623bf42a88f 100644 --- a/lightning-liquidity/tests/lsps5_integration_tests.rs +++ b/lightning-liquidity/tests/lsps5_integration_tests.rs @@ -1618,6 +1618,7 @@ fn lsps5_service_handler_persistence_across_restarts() { Some(service_config), None, Arc::clone(&time_provider), + None, ) .unwrap(); diff --git a/lightning/src/onion_message/messenger.rs b/lightning/src/onion_message/messenger.rs index f94eb7877f5..e81b1993058 100644 --- a/lightning/src/onion_message/messenger.rs +++ b/lightning/src/onion_message/messenger.rs @@ -125,6 +125,54 @@ impl< } } +/// A trait for registering specific peers for onion message interception. +/// +/// When a peer is registered for interception and is currently offline, any onion messages +/// intended to be forwarded to them will generate an [`Event::OnionMessageIntercepted`] instead +/// of being dropped. When a registered peer connects, an [`Event::OnionMessagePeerConnected`] +/// will be generated. +/// +/// [`OnionMessenger`] implements this trait, but it is also useful as a trait object to allow +/// external components (e.g., an LSPS2 service) to register peers for interception without +/// needing to know the concrete [`OnionMessenger`] type. +/// +/// [`Event::OnionMessageIntercepted`]: crate::events::Event::OnionMessageIntercepted +/// [`Event::OnionMessagePeerConnected`]: crate::events::Event::OnionMessagePeerConnected +pub trait OnionMessageInterceptor { + /// Registers a peer for onion message interception. + /// + /// See [`OnionMessenger::register_peer_for_interception`] for more details. + fn register_peer_for_interception(&self, peer_node_id: PublicKey); + + /// Deregisters a peer from onion message interception. + /// + /// See [`OnionMessenger::deregister_peer_for_interception`] for more details. + /// + /// Returns whether the peer was previously registered. + fn deregister_peer_for_interception(&self, peer_node_id: &PublicKey) -> bool; +} + +impl< + ES: EntropySource, + NS: NodeSigner, + L: Logger, + NL: NodeIdLookUp, + MR: MessageRouter, + OMH: OffersMessageHandler, + APH: AsyncPaymentsMessageHandler, + DRH: DNSResolverMessageHandler, + CMH: CustomOnionMessageHandler, + > OnionMessageInterceptor for OnionMessenger +{ + fn register_peer_for_interception(&self, peer_node_id: PublicKey) { + OnionMessenger::register_peer_for_interception(self, peer_node_id) + } + + fn deregister_peer_for_interception(&self, peer_node_id: &PublicKey) -> bool { + OnionMessenger::deregister_peer_for_interception(self, peer_node_id) + } +} + /// A sender, receiver and forwarder of [`OnionMessage`]s. /// /// # Handling Messages @@ -273,6 +321,7 @@ pub struct OnionMessenger< dns_resolver_handler: DRH, custom_handler: CMH, intercept_messages_for_offline_peers: bool, + peers_registered_for_interception: Mutex>, pending_intercepted_msgs_events: Mutex>, pending_peer_connected_events: Mutex>, pending_events_processor: AtomicBool, @@ -1453,6 +1502,7 @@ impl< dns_resolver_handler: dns_resolver, custom_handler, intercept_messages_for_offline_peers, + peers_registered_for_interception: Mutex::new(new_hash_set()), pending_intercepted_msgs_events: Mutex::new(Vec::new()), pending_peer_connected_events: Mutex::new(Vec::new()), pending_events_processor: AtomicBool::new(false), @@ -1470,6 +1520,37 @@ impl< self.async_payments_handler = async_payments_handler; } + /// Registers a peer for onion message interception. + /// + /// When an onion message needs to be forwarded to a registered peer that is currently offline, + /// an [`Event::OnionMessageIntercepted`] will be generated, allowing the message to be stored + /// and forwarded later when the peer reconnects. + /// + /// Similarly, when a registered peer connects, an [`Event::OnionMessagePeerConnected`] will + /// be generated. + /// + /// This is useful for services like LSPS2 that need to intercept onion messages for specific + /// peers (e.g., those with active JIT channel sessions) without enabling blanket interception + /// for all offline peers via [`Self::new_with_offline_peer_interception`]. + /// + /// Use [`Self::deregister_peer_for_interception`] to stop intercepting messages for this peer. + /// + /// [`Event::OnionMessageIntercepted`]: crate::events::Event::OnionMessageIntercepted + /// [`Event::OnionMessagePeerConnected`]: crate::events::Event::OnionMessagePeerConnected + pub fn register_peer_for_interception(&self, peer_node_id: PublicKey) { + self.peers_registered_for_interception.lock().unwrap().insert(peer_node_id); + } + + /// Deregisters a peer from onion message interception. + /// + /// After this call, onion messages for this peer will no longer be intercepted (unless + /// blanket interception is enabled via [`Self::new_with_offline_peer_interception`]). + /// + /// Returns whether the peer was previously registered. + pub fn deregister_peer_for_interception(&self, peer_node_id: &PublicKey) -> bool { + self.peers_registered_for_interception.lock().unwrap().remove(peer_node_id) + } + /// Sends an [`OnionMessage`] based on its [`MessageSendInstructions`]. pub fn send_onion_message( &self, contents: T, instructions: MessageSendInstructions, @@ -1686,6 +1767,9 @@ impl< .entry(next_node_id) .or_insert_with(|| OnionMessageRecipient::ConnectedPeer(VecDeque::new())); + let should_intercept = self.intercept_messages_for_offline_peers + || self.peers_registered_for_interception.lock().unwrap().contains(&next_node_id); + match message_recipients.entry(next_node_id) { hash_map::Entry::Occupied(mut e) if matches!(e.get(), OnionMessageRecipient::ConnectedPeer(..)) => @@ -1699,7 +1783,7 @@ impl< ); Ok(()) }, - _ if self.intercept_messages_for_offline_peers => { + _ if should_intercept => { log_trace!( self.logger, "Generating OnionMessageIntercepted event for peer {} {}", @@ -2142,7 +2226,9 @@ impl< .or_insert_with(|| OnionMessageRecipient::ConnectedPeer(VecDeque::new())) .mark_connected(); } - if self.intercept_messages_for_offline_peers { + let is_registered = + self.peers_registered_for_interception.lock().unwrap().contains(&their_node_id); + if self.intercept_messages_for_offline_peers || is_registered { let mut pending_peer_connected_events = self.pending_peer_connected_events.lock().unwrap(); pending_peer_connected_events diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 22be4367c7a..8e5f9ad1e42 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -165,6 +165,23 @@ impl chaininterface::FeeEstimator for TestFeeEstimator { } } +/// Override closure type for [`TestRouter::override_create_blinded_payment_paths`]. +/// +/// This closure is called instead of the default [`Router::create_blinded_payment_paths`] +/// implementation when set, receiving the actual [`ReceiveTlvs`] so tests can construct custom +/// blinded payment paths using the same TLVs the caller generated. +pub type BlindedPaymentPathOverrideFn = Box< + dyn Fn( + PublicKey, + ReceiveAuthKey, + Vec, + ReceiveTlvs, + Option, + ) -> Result, ()> + + Send + + Sync, +>; + pub struct TestRouter<'a> { pub router: DefaultRouter< Arc>, @@ -177,6 +194,7 @@ pub struct TestRouter<'a> { pub network_graph: Arc>, pub next_routes: Mutex>)>>, pub next_blinded_payment_paths: Mutex>, + pub override_create_blinded_payment_paths: Mutex>, pub scorer: &'a RwLock, } @@ -188,6 +206,7 @@ impl<'a> TestRouter<'a> { let entropy_source = Arc::new(RandomBytes::new([42; 32])); let next_routes = Mutex::new(VecDeque::new()); let next_blinded_payment_paths = Mutex::new(Vec::new()); + let override_create_blinded_payment_paths = Mutex::new(None); Self { router: DefaultRouter::new( Arc::clone(&network_graph), @@ -199,6 +218,7 @@ impl<'a> TestRouter<'a> { network_graph, next_routes, next_blinded_payment_paths, + override_create_blinded_payment_paths, scorer, } } @@ -321,6 +341,12 @@ impl<'a> Router for TestRouter<'a> { first_hops: Vec, tlvs: ReceiveTlvs, amount_msats: Option, secp_ctx: &Secp256k1, ) -> Result, ()> { + if let Some(override_fn) = + self.override_create_blinded_payment_paths.lock().unwrap().as_ref() + { + return override_fn(recipient, local_node_receive_key, first_hops, tlvs, amount_msats); + } + let mut expected_paths = self.next_blinded_payment_paths.lock().unwrap(); if expected_paths.is_empty() { self.router.create_blinded_payment_paths(