diff --git a/Cargo.toml b/Cargo.toml index 493717e..10be1f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ panic = 'abort' # Abort on panic bitcoin-payment-instructions = { version = "0.6.0" } lightning = { version = "0.2.0" } lightning-invoice = { version = "0.34.0" } +lightning-macros = "0.2.0" [profile.release] panic = "abort" diff --git a/examples/cli/src/main.rs b/examples/cli/src/main.rs index 138716b..ee08006 100644 --- a/examples/cli/src/main.rs +++ b/examples/cli/src/main.rs @@ -4,6 +4,7 @@ use colored::Colorize; use rustyline::DefaultEditor; use rustyline::error::ReadlineError; +use orange_sdk::bitcoin::hex::DisplayHex; use orange_sdk::bitcoin_payment_instructions::amount::Amount; use orange_sdk::{ ChainSource, Event, ExtraConfig, LoggerType, Mnemonic, PaymentInfo, Seed, SparkWalletConfig, @@ -198,6 +199,21 @@ impl WalletState { fee_msat ); }, + Event::RebalanceFailed { + trigger_payment_id, + trusted_rebalance_payment_id, + amount_msat, + reason, + } => { + println!( + "{} Rebalance failed: {} msat, trigger_payment_id: {}, trusted_rebalance_payment_id: {:?}, reason: {}", + "❌".bright_red(), + amount_msat, + trigger_payment_id, + trusted_rebalance_payment_id.map(|id| id.to_lower_hex_string()), + reason + ); + }, Event::SplicePending { new_funding_txo, .. } => { println!( "{} Splice pending: {}", diff --git a/graduated-rebalancer/Cargo.toml b/graduated-rebalancer/Cargo.toml index ab75ce0..3fc7c07 100644 --- a/graduated-rebalancer/Cargo.toml +++ b/graduated-rebalancer/Cargo.toml @@ -10,4 +10,5 @@ license = "MIT OR Apache-2.0" bitcoin-payment-instructions = { workspace = true } lightning = { workspace = true } lightning-invoice = { workspace = true } -tokio = { version = "1", default-features = false } +lightning-macros = { workspace = true } +tokio = { version = "1", default-features = false, features = ["sync", "macros", "rt", "time"] } diff --git a/graduated-rebalancer/src/lib.rs b/graduated-rebalancer/src/lib.rs index fa65da8..e2878da 100644 --- a/graduated-rebalancer/src/lib.rs +++ b/graduated-rebalancer/src/lib.rs @@ -12,14 +12,122 @@ use bitcoin_payment_instructions::PaymentMethod; use lightning::bitcoin::hashes::Hash; use lightning::bitcoin::hex::DisplayHex; use lightning::bitcoin::OutPoint; +use lightning::impl_writeable_tlv_based; use lightning::util::logger::Logger; use lightning::{log_debug, log_error, log_info}; use lightning_invoice::Bolt11Invoice; + use std::fmt::Debug; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +const REBALANCE_RETRY_WAIT_TIME_SECS: u64 = 60; + +/// Represents the state of an in-progress rebalance from trusted to lightning wallet +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RebalanceState { + /// ID from the rebalance trigger + pub trigger_id: [u8; 32], + /// Expected payment hash for the lightning invoice + pub expected_payment_hash: [u8; 32], + /// Amount being rebalanced in millisatoshis + pub amount_msat: u64, + /// Whether the lightning wallet has received the payment + pub ln_payment_received: bool, + /// Whether the trusted wallet has confirmed sending the payment + pub trusted_payment_sent: bool, + /// Payment ID from the trusted wallet + pub trusted_payment_id: Option<[u8; 32]>, + /// Lightning payment ID (set when received) + pub ln_payment_id: Option<[u8; 32]>, + /// Fee paid by lightning wallet in millisatoshis + pub ln_fee_msat: Option, + /// Fee paid by trusted wallet in millisatoshis + pub trusted_fee_msat: Option, +} + +impl_writeable_tlv_based!(RebalanceState, { + (0, trigger_id, required), + (2, expected_payment_hash, required), + (4, amount_msat, required), + (6, ln_payment_received, required), + (8, trusted_payment_sent, required), + (10, trusted_payment_id, option), + (12, ln_payment_id, option), + (14, ln_fee_msat, option), + (16, trusted_fee_msat, option), +}); + +/// Represents the state of an in-progress on-chain rebalance +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OnChainRebalanceState { + /// ID from the rebalance trigger (the triggering on-chain txid) + pub trigger_id: [u8; 32], + /// User channel ID assigned by LDK (set after operation completes) + pub user_channel_id: Option, + /// Amount being rebalanced in satoshis + pub amount_sats: u64, + /// Whether this is a splice (true) or channel open (false) + pub is_splice: bool, + /// Whether the channel/splice pending event has been received + pub pending_confirmed: bool, + /// The channel outpoint (set when pending_confirmed is true) + pub channel_outpoint: Option, +} + +impl_writeable_tlv_based!(OnChainRebalanceState, { + (0, trigger_id, required), + (2, user_channel_id, option), + (4, amount_sats, required), + (6, is_splice, required), + (8, pending_confirmed, required), + (10, channel_outpoint, option), +}); + +/// Trait for persisting rebalance state across restarts +pub trait RebalancePersistence: Send + Sync { + /// Insert a new trusted rebalance state + fn insert_trusted_rebalance_state( + &self, state: RebalanceState, + ) -> Pin + Send + '_>>; + + /// Update an existing trusted rebalance state + fn update_trusted_rebalance_state( + &self, state: RebalanceState, + ) -> Pin + Send + '_>> { + self.insert_trusted_rebalance_state(state) + } + + /// Remove a trusted rebalance state + fn remove_trusted_rebalance_state(&self) -> Pin + Send + '_>>; + + /// Gets the current trusted rebalance state + fn get_trusted_rebalance( + &self, + ) -> Pin, ()>> + Send + '_>>; + + /// Insert a new on-chain rebalance state + fn insert_onchain_rebalance_state( + &self, state: OnChainRebalanceState, + ) -> Pin + Send + '_>>; + + /// Update an existing on-chain rebalance state + fn update_onchain_rebalance_state( + &self, state: OnChainRebalanceState, + ) -> Pin + Send + '_>> { + self.insert_onchain_rebalance_state(state) + } + + /// Remove an on-chain rebalance state + fn remove_onchain_rebalance_state(&self) -> Pin + Send + '_>>; + + /// Gets the current on-chain rebalance state + fn get_onchain_rebalance( + &self, + ) -> Pin, ()>> + Send + '_>>; +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] /// Parameters for doing a rebalance pub struct TriggerParams { @@ -60,6 +168,13 @@ impl Default for RebalanceTunables { } } +/// Information about a pending payment found in the trusted wallet +#[derive(Debug, Clone, Copy)] +pub struct PendingPaymentInfo { + /// The payment ID assigned by the trusted wallet + pub payment_id: [u8; 32], +} + /// Trait representing a trusted wallet backend pub trait TrustedWallet: Send + Sync { /// Error type for trusted wallet operations @@ -79,10 +194,13 @@ pub trait TrustedWallet: Send + Sync { &self, method: PaymentMethod, amount: Amount, ) -> Pin> + Send + '_>>; - /// Wait for a payment success notification - fn await_payment_success( + /// Find a pending outbound payment by its payment hash. + /// Returns `Some(PendingPaymentInfo)` if a pending payment with the given hash exists, + /// `None` otherwise. This is used during recovery to determine if a payment was + /// actually initiated before a crash. + fn find_payment_by_hash( &self, payment_hash: [u8; 32], - ) -> Pin> + Send + '_>>; + ) -> Pin, Self::Error>> + Send + '_>>; } /// Trait representing a lightning wallet backend @@ -103,11 +221,6 @@ pub trait LightningWallet: Send + Sync { &self, method: PaymentMethod, amount: Amount, ) -> Pin> + Send + '_>>; - /// Wait for a payment receipt notification - fn await_payment_receipt( - &self, payment_hash: [u8; 32], - ) -> Pin> + Send + '_>>; - /// Check if we already have a channel with the LSP fn has_channel_with_lsp(&self) -> bool; @@ -116,29 +229,17 @@ pub trait LightningWallet: Send + Sync { &self, amt: Amount, ) -> Pin> + Send + '_>>; - /// Wait for a channel pending notification, returns the new channel's outpoint - fn await_channel_pending( - &self, channel_id: u128, - ) -> Pin + Send + '_>>; - /// Splice funds from on-chain to an existing channel with the LSP fn splice_to_lsp_channel( &self, amt: Amount, ) -> Pin> + Send + '_>>; - /// Wait for a splice pending notification, returns the splice outpoint - fn await_splice_pending( - &self, channel_id: u128, - ) -> Pin + Send + '_>>; -} + /// Get the funding outpoint for a channel by user_channel_id, if it exists + fn get_channel_outpoint(&self, user_channel_id: u128) -> Option; -/// Represents a payment from the lightning wallet -#[derive(Debug, Clone)] -pub struct ReceivedLightningPayment { - /// Unique payment ID - pub id: [u8; 32], - /// Fee paid in millisatoshis - pub fee_paid_msat: Option, + /// Find an untracked pending channel/splice with the LSP. + /// Returns None if no matching channel is found. + fn find_pending_lsp_channel(&self) -> Option; } /// Lightning wallet balance information @@ -175,6 +276,17 @@ pub enum RebalancerEvent { /// Total fee paid in millisatoshis fee_msat: u64, }, + /// Rebalance failed + RebalanceFailed { + /// Trigger id given by the rebalance trigger + trigger_id: [u8; 32], + /// Trusted wallet payment ID for the rebalance + trusted_rebalance_payment_id: Option<[u8; 32]>, + /// Amount that was being rebalanced in millisatoshis + amount_msat: u64, + /// Reason for failure + reason: String, + }, /// We have initiated a lightning channel open OnChainRebalanceInitiated { /// Trigger id given by the rebalance trigger @@ -212,40 +324,117 @@ pub struct GraduatedRebalancer< L: LightningWallet, R: RebalanceTrigger, E: EventHandler, + P: RebalancePersistence, O: Logger, > { trusted: Arc, ln_wallet: Arc, trigger: Arc, event_handler: Arc, + persistence: Arc

, logger: Arc, - /// Mutex to ensure thread-safe balance operations. - balance_mutex: tokio::sync::Mutex<()>, + /// In-memory cache of active rebalance (only one trusted rebalance allowed at a time) + active_trusted_rebalance: Arc>>, + + /// In-memory cache of active on-chain rebalance (only one on-chain rebalance allowed at a time) + active_onchain_rebalance: Arc>>, + + /// Handle to cancel a scheduled retry task + scheduled_retry_handle: Arc>>, } -impl GraduatedRebalancer +impl< + T: TrustedWallet, + L: LightningWallet, + R: RebalanceTrigger, + E: EventHandler, + P: RebalancePersistence, + O: Logger, + > Clone for GraduatedRebalancer +{ + fn clone(&self) -> Self { + Self { + trusted: Arc::clone(&self.trusted), + ln_wallet: Arc::clone(&self.ln_wallet), + trigger: Arc::clone(&self.trigger), + event_handler: Arc::clone(&self.event_handler), + persistence: Arc::clone(&self.persistence), + logger: Arc::clone(&self.logger), + active_trusted_rebalance: Arc::clone(&self.active_trusted_rebalance), + active_onchain_rebalance: Arc::clone(&self.active_onchain_rebalance), + scheduled_retry_handle: Arc::clone(&self.scheduled_retry_handle), + } + } +} + +impl GraduatedRebalancer where - T: TrustedWallet, - LN: LightningWallet, - R: RebalanceTrigger, - E: EventHandler, - L: Logger, + T: TrustedWallet + 'static, + LN: LightningWallet + 'static, + R: RebalanceTrigger + 'static, + E: EventHandler + 'static, + P: RebalancePersistence + 'static, + L: Logger + Send + Sync + 'static, { /// Create a new graduated rebalancer pub fn new( - trusted: Arc, ln_wallet: Arc, trigger: Arc, event_handler: Arc, logger: Arc, + trusted: Arc, ln_wallet: Arc, trigger: Arc, event_handler: Arc, + persistence: Arc

, logger: Arc, ) -> Self { Self { trusted, ln_wallet, trigger, event_handler, + persistence, logger, - balance_mutex: tokio::sync::Mutex::new(()), + active_trusted_rebalance: Arc::new(tokio::sync::Mutex::new(None)), + active_onchain_rebalance: Arc::new(tokio::sync::Mutex::new(None)), + scheduled_retry_handle: Arc::new(tokio::sync::Mutex::new(None)), + } + } + + /// Cancel any scheduled retry task + async fn cancel_scheduled_retry(&self) { + let mut handle = self.scheduled_retry_handle.lock().await; + if let Some(abort_handle) = handle.take() { + log_debug!(self.logger, "Cancelling scheduled retry"); + abort_handle.abort(); } } + /// Schedule a retry of the rebalance after a delay. + /// If a rebalance is triggered before the delay, the retry will be cancelled. + fn schedule_retry(&self) { + let this = self.clone(); + let handle_mutex = Arc::clone(&self.scheduled_retry_handle); + + tokio::spawn(async move { + // Cancel any existing scheduled retry + let mut handle = handle_mutex.lock().await; + if let Some(abort_handle) = handle.take() { + abort_handle.abort(); + } + + // Create the retry task + let retry_this = this.clone(); + let task = tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_secs(REBALANCE_RETRY_WAIT_TIME_SECS)) + .await; + log_info!(retry_this.logger, "Executing scheduled rebalance retry"); + retry_this.do_rebalance_if_needed().await; + }); + + *handle = Some(task.abort_handle()); + + log_info!( + this.logger, + "Scheduled rebalance retry in {REBALANCE_RETRY_WAIT_TIME_SECS} seconds" + ); + }); + } + /// Does any rebalance if it meets the conditions of the tunables pub async fn do_rebalance_if_needed(&self) { self.do_trusted_rebalance_if_needed().await; @@ -268,136 +457,539 @@ where } /// Perform a rebalance from trusted to lightning wallet + /// This method initiates the rebalance and returns immediately. + /// Completion is handled via the observer pattern (on_trusted_payment_sent, on_ln_payment_received). async fn do_trusted_rebalance(&self, params: TriggerParams) { - let transfer_amt = params.amount; - let _lock = self.balance_mutex.lock().await; - log_info!(self.logger, "Initiating rebalance"); + // Check if there's already an active rebalance + let mut rebalance = self.active_trusted_rebalance.lock().await; + if rebalance.is_some() { + log_info!(self.logger, "Skipping rebalance trigger - already have an active rebalance"); + return; + } - if let Ok(inv) = self.ln_wallet.get_bolt11_invoice(Some(transfer_amt)).await { - log_debug!( - self.logger, - "Attempting to pay invoice {inv} to rebalance for {transfer_amt:?}", - ); - let expected_hash = *inv.payment_hash(); - match self.trusted.pay(PaymentMethod::LightningBolt11(inv), transfer_amt).await { - Ok(rebalance_id) => { - log_debug!( - self.logger, - "Rebalance trusted transaction initiated, id {}. Waiting for LN payment.", - rebalance_id.as_hex() - ); + // Cancel any scheduled retry since we're actually doing a rebalance now + self.cancel_scheduled_retry().await; - self.event_handler - .handle_event(RebalancerEvent::RebalanceInitiated { - trigger_id: params.id, - trusted_rebalance_payment_id: rebalance_id, - amount_msat: transfer_amt.milli_sats(), - }) - .await; - - let ln_payment = match self - .ln_wallet - .await_payment_receipt(expected_hash.to_byte_array()) - .await - { - Some(receipt) => receipt, - None => { - log_error!(self.logger, "Failed to receive rebalance payment!"); - return; - }, - }; - - let trusted_payment = match self - .trusted - .await_payment_success(expected_hash.to_byte_array()) - .await - { - Some(success) => success, - None => { - log_error!(self.logger, "Failed to send rebalance payment!"); - return; - }, - }; + log_info!(self.logger, "Initiating rebalance"); - log_info!( - self.logger, - "Rebalance succeeded. Sent trusted tx {} to lightning tx {}", - rebalance_id.as_hex(), - ln_payment.id.as_hex(), - ); + let transfer_amt = params.amount; + let inv = match self.ln_wallet.get_bolt11_invoice(Some(transfer_amt)).await { + Ok(inv) => inv, + Err(_) => return, + }; - self.event_handler - .handle_event(RebalancerEvent::RebalanceSuccessful { - trigger_id: params.id, - trusted_rebalance_payment_id: rebalance_id, - ln_rebalance_payment_id: ln_payment.id, - amount_msat: transfer_amt.milli_sats(), - fee_msat: ln_payment.fee_paid_msat.unwrap_or_default() - + trusted_payment.fee_paid_msat.unwrap_or_default(), - }) - .await; - }, - Err(e) => { - log_info!(self.logger, "Rebalance trusted transaction failed with {e:?}",); - }, - } + log_debug!( + self.logger, + "Attempting to pay invoice {inv} to rebalance for {transfer_amt:?}", + ); + + let expected_hash = inv.payment_hash().to_byte_array(); + + // Create and persist rebalance state + let mut state = RebalanceState { + trigger_id: params.id, + expected_payment_hash: expected_hash, + amount_msat: transfer_amt.milli_sats(), + ln_payment_received: false, + trusted_payment_sent: false, + trusted_payment_id: None, + ln_payment_id: None, + ln_fee_msat: None, + trusted_fee_msat: None, + }; + self.persistence.insert_trusted_rebalance_state(state).await; + + match self.trusted.pay(PaymentMethod::LightningBolt11(inv), transfer_amt).await { + Ok(trusted_payment_id) => { + log_debug!( + self.logger, + "Rebalance trusted transaction initiated, id {}. Will complete via observer callbacks.", + trusted_payment_id.as_hex() + ); + + // persist trusted payment id + state.trusted_payment_id = Some(trusted_payment_id); + self.persistence.update_trusted_rebalance_state(state).await; + + // Set active rebalance + *rebalance = Some(state); + + // Post initiated event + self.event_handler + .handle_event(RebalancerEvent::RebalanceInitiated { + trigger_id: params.id, + trusted_rebalance_payment_id: trusted_payment_id, + amount_msat: transfer_amt.milli_sats(), + }) + .await; + }, + Err(e) => { + log_info!(self.logger, "Rebalance trusted transaction failed with {e:?}",); + + // Clean up persisted state + self.persistence.remove_trusted_rebalance_state().await; + + // Post failure event + self.event_handler + .handle_event(RebalancerEvent::RebalanceFailed { + trigger_id: params.id, + trusted_rebalance_payment_id: None, + amount_msat: transfer_amt.milli_sats(), + reason: format!("Failed to initiate trusted payment: {e:?}"), + }) + .await; + + // Release the lock before scheduling retry + drop(rebalance); + + // Schedule a retry + self.schedule_retry(); + }, } } /// Perform on-chain to lightning rebalance by opening a channel or splicing into an existing one async fn do_onchain_rebalance(&self, params: TriggerParams) { - let _ = self.balance_mutex.lock().await; + let mut onchain_rebalance = self.active_onchain_rebalance.lock().await; - let (channel_outpoint, user_channel_id) = if self.ln_wallet.has_channel_with_lsp() { - log_info!(self.logger, "Splicing into channel with LSP with on-chain funds"); + // Check if there's already an active on-chain rebalance + if onchain_rebalance.is_some() { + log_info!( + self.logger, + "Skipping on-chain rebalance trigger - already have an active on-chain rebalance" + ); + return; + } - let user_chan_id = match self.ln_wallet.splice_to_lsp_channel(params.amount).await { - Ok(chan_id) => chan_id, - Err(e) => { - log_error!(self.logger, "Failed to open channel with LSP: {e:?}"); - return; - }, - }; + // Cancel any scheduled retry since we're actually doing a rebalance now + self.cancel_scheduled_retry().await; - log_info!(self.logger, "Initiated splice opened with LSP"); + let is_splice = self.ln_wallet.has_channel_with_lsp(); - let channel_outpoint = self.ln_wallet.await_splice_pending(user_chan_id).await; + if is_splice { + log_info!(self.logger, "Splicing into channel with LSP with on-chain funds"); + } else { + log_info!(self.logger, "Opening channel with LSP with on-chain funds"); + } - log_info!(self.logger, "Splice initiated at: {channel_outpoint}"); + // Persist state BEFORE initiating the operation to ensure we don't lose track + let mut state = OnChainRebalanceState { + trigger_id: params.id, + user_channel_id: None, + amount_sats: params.amount.sats_rounding_up(), + is_splice, + pending_confirmed: false, + channel_outpoint: None, + }; + self.persistence.insert_onchain_rebalance_state(state).await; + *onchain_rebalance = Some(state); - (channel_outpoint, user_chan_id) + // Now initiate the actual operation + let channel_result = if is_splice { + self.ln_wallet.splice_to_lsp_channel(params.amount).await } else { - log_info!(self.logger, "Opening channel with LSP with on-chain funds"); + self.ln_wallet.open_channel_with_lsp(params.amount).await + }; + + match channel_result { + Ok(user_channel_id) => { + // Update state with the user_channel_id + state.user_channel_id = Some(user_channel_id); + *onchain_rebalance = Some(state); + self.persistence.update_onchain_rebalance_state(state).await; + + log_info!( + self.logger, + "On-chain rebalance initiated for user_channel_id {user_channel_id}. Will complete via observer callback." + ); + }, + Err(e) => { + let op_name = if is_splice { "splice to" } else { "open channel with" }; + log_error!(self.logger, "Failed to {op_name} LSP channel: {e:?}"); + // Clean up the state if we failed + let _ = onchain_rebalance.take(); + self.persistence.remove_onchain_rebalance_state().await; + self.event_handler + .handle_event(RebalancerEvent::RebalanceFailed { + trigger_id: params.id, + trusted_rebalance_payment_id: None, + amount_msat: params.amount.milli_sats(), + reason: format!("Failed to {op_name} LSP channel: {e:?}"), + }) + .await; + + // Release the lock before scheduling retry + drop(onchain_rebalance); + + // Schedule a retry + self.schedule_retry(); + }, + } + } - let user_chan_id = match self.ln_wallet.open_channel_with_lsp(params.amount).await { - Ok(chan_id) => chan_id, - Err(e) => { - log_error!(self.logger, "Failed to open channel with LSP: {e:?}"); - return; - }, - }; + /// Called when the trusted wallet confirms sending a payment + /// This is part of the observer pattern to handle rebalance completion across restarts + pub async fn on_trusted_payment_sent(&self, payment_hash: [u8; 32], fee_msat: Option) { + let mut rebalance = self.active_trusted_rebalance.lock().await; + + if let Some(state) = rebalance.as_mut() { + if state.expected_payment_hash == payment_hash { + state.trusted_payment_sent = true; + state.trusted_fee_msat = fee_msat; + + log_debug!( + self.logger, + "Trusted payment sent for rebalance {}", + payment_hash.as_hex() + ); + + // Check if rebalance is complete, otherwise persist state + if state.ln_payment_received { + self.complete_rebalance(*state, &mut rebalance).await; + } else { + self.persistence.update_trusted_rebalance_state(*state).await; + } + } + } + } - log_info!(self.logger, "Initiated channel opened with LSP"); + /// Called when the lightning wallet receives a payment + /// This is part of the observer pattern to handle rebalance completion across restarts + pub async fn on_ln_payment_received( + &self, payment_hash: [u8; 32], payment_id: [u8; 32], fee_msat: Option, + ) { + let mut rebalance = self.active_trusted_rebalance.lock().await; + + if let Some(state) = rebalance.as_mut() { + if state.expected_payment_hash == payment_hash { + state.ln_payment_received = true; + state.ln_payment_id = Some(payment_id); + state.ln_fee_msat = fee_msat; + + log_debug!( + self.logger, + "Lightning payment received for rebalance {}", + payment_hash.as_hex() + ); + + // Check if rebalance is complete, otherwise persist state + if state.trusted_payment_sent { + self.complete_rebalance(*state, &mut rebalance).await; + } else { + self.persistence.update_trusted_rebalance_state(*state).await; + } + } + } + } - let channel_outpoint = self.ln_wallet.await_channel_pending(user_chan_id).await; + /// Called when a trusted wallet payment fails + /// This is part of the observer pattern to handle rebalance failures + pub async fn on_trusted_payment_failed(&self, payment_hash: [u8; 32], reason: String) { + let mut rebalance = self.active_trusted_rebalance.lock().await; + + if let Some(state) = rebalance.as_ref() { + if state.expected_payment_hash == payment_hash { + log_info!( + self.logger, + "Trusted payment failed for rebalance {}: {}", + payment_hash.as_hex(), + reason + ); + + // Post failure event + self.event_handler + .handle_event(RebalancerEvent::RebalanceFailed { + trigger_id: state.trigger_id, + trusted_rebalance_payment_id: state.trusted_payment_id, + amount_msat: state.amount_msat, + reason, + }) + .await; + + // Clean up + let _ = rebalance.take(); + self.persistence.remove_trusted_rebalance_state().await; + + // Release the lock before scheduling retry + drop(rebalance); + + // Schedule a retry + self.schedule_retry(); + } + } + } - log_info!(self.logger, "Channel open succeeded at: {channel_outpoint}"); + /// Called when a channel or splice becomes pending + /// This is part of the observer pattern to handle on-chain rebalance completion across restarts + pub async fn on_channel_splice_pending( + &self, user_channel_id: u128, channel_outpoint: OutPoint, + ) { + let mut onchain_rebalance = self.active_onchain_rebalance.lock().await; + + // Check if there's an active on-chain rebalance matching this user_channel_id + if let Some(state) = onchain_rebalance.as_mut() { + if state.user_channel_id == Some(user_channel_id) { + state.pending_confirmed = true; + state.channel_outpoint = Some(channel_outpoint); + + log_info!( + self.logger, + "On-chain rebalance pending confirmed for user_channel_id {} at outpoint {}", + user_channel_id, + channel_outpoint + ); + + self.complete_onchain_rebalance(*state, &mut onchain_rebalance).await; + } + } + } - (channel_outpoint, user_chan_id) - }; + /// Complete an on-chain rebalance by posting the event and cleaning up + async fn complete_onchain_rebalance( + &self, state: OnChainRebalanceState, + onchain_rebalance: &mut tokio::sync::MutexGuard<'_, Option>, + ) { + log_info!( + self.logger, + "On-chain rebalance completed for user_channel_id {:?} at {}", + state.user_channel_id, + state.channel_outpoint.expect("channel_outpoint must be set") + ); self.event_handler .handle_event(RebalancerEvent::OnChainRebalanceInitiated { - trigger_id: params.id, - user_channel_id, - channel_outpoint, + trigger_id: state.trigger_id, + user_channel_id: state.user_channel_id.expect("user_channel_id must be set"), + channel_outpoint: state.channel_outpoint.expect("channel_outpoint must be set"), }) .await; + + // Clean up + let _ = onchain_rebalance.take(); + self.persistence.remove_onchain_rebalance_state().await; + } + + /// Complete a rebalance by posting the success event and cleaning up + async fn complete_rebalance( + &self, state: RebalanceState, + rebalance: &mut tokio::sync::MutexGuard<'_, Option>, + ) { + log_info!( + self.logger, + "Rebalance succeeded. Sent trusted tx {} to lightning tx {}", + state.trusted_payment_id.expect("trusted_payment_id must be set").as_hex(), + state.ln_payment_id.expect("ln_payment_id must be set").as_hex(), + ); + + self.event_handler + .handle_event(RebalancerEvent::RebalanceSuccessful { + trigger_id: state.trigger_id, + trusted_rebalance_payment_id: state + .trusted_payment_id + .expect("trusted_payment_id must be set"), + ln_rebalance_payment_id: state.ln_payment_id.expect("ln_payment_id must be set"), + amount_msat: state.amount_msat, + fee_msat: state.ln_fee_msat.unwrap_or_default() + + state.trusted_fee_msat.unwrap_or_default(), + }) + .await; + + // Clean up + let _ = rebalance.take(); + self.persistence.remove_trusted_rebalance_state().await; + } + + /// Recover incomplete rebalances from persistence on startup + /// This should be called during wallet initialization + /// + /// Returns true if you should call [`Self::do_trusted_rebalance_if_needed`] again + /// to retry any cleaned up rebalances. + pub async fn recover_incomplete_trusted_rebalances(&self) -> Result { + let state_opt = self.persistence.get_trusted_rebalance().await?; + + if let Some(mut state) = state_opt { + let mut rebalance = self.active_trusted_rebalance.lock().await; + log_debug!( + self.logger, + "Recovering rebalance {} (ln_received: {}, trusted_sent: {}, trusted_payment_id: {:?})", + state.expected_payment_hash.as_hex(), + state.ln_payment_received, + state.trusted_payment_sent, + state.trusted_payment_id.map(|id| id.to_lower_hex_string()) + ); + + // If trusted_payment_id is None, we crashed after persisting state but before + // persisting the trusted_payment_id. However, the payment might have actually + // been sent, we need to check with the trusted wallet. + if state.trusted_payment_id.is_none() { + log_info!( + self.logger, + "Rebalance {} has no trusted_payment_id, checking if payment was actually sent", + state.expected_payment_hash.as_hex() + ); + + match self.trusted.find_payment_by_hash(state.expected_payment_hash).await { + Ok(Some(pending_info)) => { + // Payment was actually sent, Recover the payment id and continue. + log_info!( + self.logger, + "Found pending payment {} for rebalance {}, recovering", + pending_info.payment_id.as_hex(), + state.expected_payment_hash.as_hex() + ); + state.trusted_payment_id = Some(pending_info.payment_id); + // Update persistence with the recovered payment id. + self.persistence.update_trusted_rebalance_state(state).await; + }, + Ok(None) => { + // The payment was never actually sent. Clean up and retry. + log_info!( + self.logger, + "Rebalance {} was never initiated (no payment found), cleaning up and retrying", + state.expected_payment_hash.as_hex() + ); + self.persistence.remove_trusted_rebalance_state().await; + return Ok(true); + }, + Err(e) => { + log_error!( + self.logger, + "Failed to query trusted wallet ({e:?}) for rebalance {}, assuming payment was not sent", + state.expected_payment_hash.as_hex() + ); + self.persistence.remove_trusted_rebalance_state().await; + return Ok(true); + }, + } + } + + // If both sides already completed, finalize now + if state.ln_payment_received && state.trusted_payment_sent { + log_info!( + self.logger, + "Rebalance {} was already complete, finalizing", + state.expected_payment_hash.as_hex() + ); + + // Temporarily set to complete + *rebalance = Some(state); + self.complete_rebalance(state, &mut rebalance).await; + } else { + // Still waiting for one or both sides + // Note: We only recover the first incomplete rebalance since we only allow one at a time + if rebalance.is_none() { + *rebalance = Some(state); + } else { + debug_assert!( + false, + "Called recover_incomplete_rebalances with multiple times" + ); + } + } + } + + Ok(false) + } + + /// Recover incomplete on-chain rebalances from persistence on startup + /// This should be called during wallet initialization + /// + /// Returns true if you should call [`Self::do_onchain_rebalance_if_needed`] again + /// to retry any cleaned up rebalances. + pub async fn recover_incomplete_onchain_rebalances(&self) -> Result { + let state_opt = self.persistence.get_onchain_rebalance().await?; + + if let Some(mut state) = state_opt { + let mut onchain_rebalance = self.active_onchain_rebalance.lock().await; + log_debug!( + self.logger, + "Recovering on-chain rebalance for user_channel_id {:?} (pending_confirmed: {})", + state.user_channel_id, + state.pending_confirmed + ); + + // Check if the channel is already ready by querying LDK (only if user_channel_id is set) + if let Some(user_channel_id) = state.user_channel_id { + if let Some(outpoint) = self.ln_wallet.get_channel_outpoint(user_channel_id) { + // Channel is already pending/ready, complete now + log_info!( + self.logger, + "On-chain rebalance for user_channel_id {user_channel_id} was already pending, finalizing" + ); + + state.pending_confirmed = true; + state.channel_outpoint = Some(outpoint); + + *onchain_rebalance = Some(state); + self.complete_onchain_rebalance(state, &mut onchain_rebalance).await; + } else { + // Still waiting for channel/splice pending + if onchain_rebalance.is_none() { + *onchain_rebalance = Some(state); + } else { + debug_assert!( + false, + "Called recover_incomplete_onchain_rebalances multiple times" + ); + } + } + } else { + // user_channel_id not set yet, we crashed after persisting state but before + // persisting the user_channel_id. However, the channel/splice might have actually + // been initiated, we need to check with the lightning wallet. + log_info!( + self.logger, + "On-chain rebalance has no user_channel_id, checking if channel was actually opened" + ); + + match self.ln_wallet.find_pending_lsp_channel() { + Some(user_channel_id) => { + log_info!( + self.logger, + "Found pending channel {user_channel_id} for on-chain rebalance, recovering" + ); + state.user_channel_id = Some(user_channel_id); + self.persistence.update_onchain_rebalance_state(state).await; + + // Check if it's already pending/ready + if let Some(outpoint) = self.ln_wallet.get_channel_outpoint(user_channel_id) + { + log_info!( + self.logger, + "Recovered channel {user_channel_id} is already pending, finalizing" + ); + state.pending_confirmed = true; + state.channel_outpoint = Some(outpoint); + *onchain_rebalance = Some(state); + self.complete_onchain_rebalance(state, &mut onchain_rebalance).await; + } else { + // Still waiting for channel pending + *onchain_rebalance = Some(state); + } + }, + None => { + // Pending channel not found, the splice/channel open was never initiated. + // Clean up and retry. + log_info!( + self.logger, + "On-chain rebalance was never initiated (no pending channel found), cleaning up and retrying" + ); + self.persistence.remove_onchain_rebalance_state().await; + return Ok(true); + }, + } + } + } + + Ok(false) } /// Stops the rebalancer, waits for any active rebalances to complete pub async fn stop(&self) { - log_debug!(self.logger, "Waiting for balance mutex..."); - let _ = self.balance_mutex.lock().await; + log_debug!(self.logger, "Waiting for balance mutexes..."); + let _ = tokio::join!( + self.active_trusted_rebalance.lock(), + self.active_onchain_rebalance.lock() + ); } } diff --git a/orange-sdk/Cargo.toml b/orange-sdk/Cargo.toml index 3ad47b2..9ebb291 100644 --- a/orange-sdk/Cargo.toml +++ b/orange-sdk/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" authors = ["benthecarman "] documentation = "https://docs.rs/orange-sdk/" license = "MIT OR Apache-2.0" -keywords = [ "lightning", "bitcoin", "spark", "cashu" ] +keywords = ["lightning", "bitcoin", "spark", "cashu"] readme = "../README.md" [lib] @@ -24,7 +24,7 @@ _cashu-tests = ["_test-utils", "cdk-ldk-node", "cdk/mint", "cdk-sqlite", "cdk-ax graduated-rebalancer = { path = "../graduated-rebalancer", version = "0.1.0" } ldk-node = { version = "0.7.0" } -lightning-macros = "0.2.0" +lightning-macros = { workspace = true } bitcoin-payment-instructions = { workspace = true, features = ["http"] } chrono = { version = "0.4", default-features = false } rand = { version = "0.8.5", optional = true } diff --git a/orange-sdk/src/event.rs b/orange-sdk/src/event.rs index 87d179d..a288e1e 100644 --- a/orange-sdk/src/event.rs +++ b/orange-sdk/src/event.rs @@ -1,4 +1,5 @@ use crate::logging::Logger; +use crate::rebalancer::RebalanceEventHandlerHolder; use crate::store::{self, PaymentId}; use ldk_node::bitcoin::secp256k1::PublicKey; @@ -131,6 +132,17 @@ pub enum Event { /// The fee paid, in msats, for the rebalance payment. fee_msat: u64, }, + /// A rebalance from our trusted wallet has failed. + RebalanceFailed { + /// The `payment_id` of the transaction that triggered the rebalance. + trigger_payment_id: PaymentId, + /// The `payment_id` of the rebalance payment sent from the trusted wallet. + trusted_rebalance_payment_id: Option<[u8; 32]>, + /// The amount, in msats, of the rebalance payment. + amount_msat: u64, + /// The reason for the failure. + reason: String, + }, /// We have initiated a splice and are waiting for it to confirm. SplicePending { /// The `channel_id` of the channel. @@ -199,6 +211,12 @@ impl_writeable_tlv_based_enum!(Event, (5, user_channel_id, required), (7, new_funding_txo, required), }, + (9, RebalanceFailed) => { + (0, trigger_payment_id, required), + (2, trusted_rebalance_payment_id, option), + (4, amount_msat, required), + (6, reason, required), + }, ); /// A queue for events emitted by the [`Wallet`]. @@ -321,9 +339,8 @@ pub(crate) struct LdkEventHandler { pub(crate) event_queue: Arc, pub(crate) ldk_node: Arc, pub(crate) tx_metadata: store::TxMetadataStore, - pub(crate) payment_receipt_sender: watch::Sender<()>, - pub(crate) channel_pending_sender: watch::Sender, pub(crate) splice_pending_sender: watch::Sender, + pub(crate) rebalance_event_handler: RebalanceEventHandlerHolder, pub(crate) logger: Arc, } @@ -399,7 +416,11 @@ impl LdkEventHandler { { log_error!(self.logger, "Failed to add PaymentReceived event: {e:?}"); } - let _ = self.payment_receipt_sender.send(()); + + // Notify rebalancer if this might be a rebalance payment + self.rebalance_event_handler + .notify_ln_payment_received(payment_hash.0, payment_id.0, lsp_fee_msats) + .await; }, ldk_node::Event::PaymentForwarded { .. } => {}, ldk_node::Event::PaymentClaimable { .. } => { @@ -419,6 +440,11 @@ impl LdkEventHandler { } => { let funding_txo = funding_txo.unwrap(); // safe + // Notify rebalancer if this might be an on-chain rebalance + self.rebalance_event_handler + .notify_channel_splice_pending(user_channel_id.0, funding_txo) + .await; + if let Err(e) = self .event_queue .add_event(Event::ChannelOpened { @@ -432,7 +458,6 @@ impl LdkEventHandler { log_error!(self.logger, "Failed to add ChannelOpened event: {e:?}"); return; } - let _ = self.channel_pending_sender.send(user_channel_id.0); }, ldk_node::Event::ChannelClosed { channel_id, @@ -465,6 +490,12 @@ impl LdkEventHandler { new_funding_txo, } => { log_debug!(self.logger, "Received SplicePending event {event:?}"); + + // Notify rebalancer if this might be an on-chain rebalance + self.rebalance_event_handler + .notify_channel_splice_pending(user_channel_id.0, new_funding_txo) + .await; + let _ = self.splice_pending_sender.send(user_channel_id.0); if let Err(e) = self diff --git a/orange-sdk/src/ffi/orange/mod.rs b/orange-sdk/src/ffi/orange/mod.rs index b348a36..6987988 100644 --- a/orange-sdk/src/ffi/orange/mod.rs +++ b/orange-sdk/src/ffi/orange/mod.rs @@ -274,6 +274,17 @@ pub enum Event { /// The outpoint of the channel's splice funding transaction. new_funding_txo: String, }, + /// A rebalance from our trusted wallet has failed. + RebalanceFailed { + /// The `payment_id` of the transaction that triggered the rebalance. + trigger_payment_id: PaymentId, + /// The `payment_id` of the rebalance payment sent from the trusted wallet. + trusted_rebalance_payment_id: Option>, + /// The amount, in msats, of the rebalance payment. + amount_msat: u64, + /// The reason for the failure. + reason: String, + }, } impl From for Event { @@ -361,6 +372,17 @@ impl From for Event { amount_msat, fee_msat, }, + OrangeEvent::RebalanceFailed { + trigger_payment_id, + trusted_rebalance_payment_id, + amount_msat, + reason, + } => Event::RebalanceFailed { + trigger_payment_id: trigger_payment_id.into(), + trusted_rebalance_payment_id: trusted_rebalance_payment_id.map(|id| id.to_vec()), + amount_msat, + reason, + }, OrangeEvent::SplicePending { channel_id, user_channel_id, diff --git a/orange-sdk/src/lib.rs b/orange-sdk/src/lib.rs index 27f9f74..e927868 100644 --- a/orange-sdk/src/lib.rs +++ b/orange-sdk/src/lib.rs @@ -72,6 +72,7 @@ type Rebalancer = GraduatedRebalancer< LightningWallet, OrangeTrigger, OrangeRebalanceEventHandler, + store::RebalancePersistenceStore, Logger, >; @@ -529,6 +530,13 @@ impl Wallet { let tx_metadata = TxMetadataStore::new(Arc::clone(&store)).await; + // Create the rebalance event handler early so it can be passed to wallet init functions + let rebalance_events = Arc::new(OrangeRebalanceEventHandler::new( + tx_metadata.clone(), + Arc::clone(&event_queue), + Arc::clone(&logger), + )); + let (trusted, ln_wallet) = tokio::join!( async { let trusted: Arc> = match &config.extra_config { @@ -540,6 +548,7 @@ impl Wallet { Arc::clone(&store), Arc::clone(&event_queue), tx_metadata.clone(), + Arc::clone(&rebalance_events), Arc::clone(&logger), Arc::clone(&runtime), ) @@ -553,6 +562,7 @@ impl Wallet { Arc::clone(&store), Arc::clone(&event_queue), tx_metadata.clone(), + Arc::clone(&rebalance_events), Arc::clone(&logger), Arc::clone(&runtime), ) @@ -566,6 +576,7 @@ impl Wallet { &cfg.bitcoind, tx_metadata.clone(), Arc::clone(&event_queue), + Arc::clone(&rebalance_events), Arc::clone(&runtime), ) .await, @@ -581,6 +592,7 @@ impl Wallet { Arc::clone(&store), Arc::clone(&event_queue), tx_metadata.clone(), + Arc::clone(&rebalance_events), Arc::clone(&logger), ) .await?, @@ -602,20 +614,32 @@ impl Wallet { Arc::clone(&logger), )); - let rebalance_events = Arc::new(OrangeRebalanceEventHandler::new( - tx_metadata.clone(), - Arc::clone(&event_queue), - Arc::clone(&logger), - )); - let rebalancer = Arc::new(GraduatedRebalancer::new( wt, Arc::clone(&ln_wallet), trigger, - rebalance_events, + Arc::clone(&rebalance_events), + Arc::new(store::RebalancePersistenceStore::new(Arc::clone(&store))), Arc::clone(&logger), )); + // Set the rebalancer reference in the event handler + rebalance_events.set_rebalancer(Arc::clone(&rebalancer)); + + // Recover incomplete rebalances from previous sessions + // we ignore the return value because we always attempt a rebalance on startup + rebalancer.recover_incomplete_trusted_rebalances().await.map_err(|()| { + log_error!(logger, "Failed to recover incomplete rebalances"); + BuildError::WalletSetupFailed + })?; + + // Recover incomplete on-chain rebalances from previous sessions + // we ignore the return value because we always attempt a rebalance on startup + rebalancer.recover_incomplete_onchain_rebalances().await.map_err(|()| { + log_error!(logger, "Failed to recover incomplete on-chain rebalances"); + BuildError::WalletSetupFailed + })?; + // Spawn a background thread to initiate a rebalance if needed. // We only do this once as we generally rebalance in response to // `Event`s which indicated our balance has changed. diff --git a/orange-sdk/src/lightning_wallet.rs b/orange-sdk/src/lightning_wallet.rs index 4ccdef2..6ea1eae 100644 --- a/orange-sdk/src/lightning_wallet.rs +++ b/orange-sdk/src/lightning_wallet.rs @@ -2,6 +2,7 @@ use crate::bitcoin::OutPoint; use crate::bitcoin::hashes::Hash; use crate::event::{EventQueue, LdkEventHandler}; use crate::logging::Logger; +use crate::rebalancer::RebalanceEventHandlerHolder; use crate::runtime::Runtime; use crate::store::{TxMetadataStore, TxStatus}; use crate::{ChainSource, InitFailure, PaymentType, Seed, WalletConfig, store}; @@ -24,7 +25,7 @@ use ldk_node::payment::{ }; use ldk_node::{DynStore, NodeError, UserChannelId}; -use graduated_rebalancer::{LightningBalance, ReceivedLightningPayment}; +use graduated_rebalancer::LightningBalance; use std::collections::HashMap; use std::fmt::Debug; @@ -43,8 +44,6 @@ pub(crate) struct LightningWalletImpl { pub(crate) ldk_node: Arc, logger: Arc, store: Arc, - payment_receipt_flag: watch::Receiver<()>, - channel_pending_receipt_flag: watch::Receiver, splice_pending_receipt_flag: watch::Receiver, lsp_node_id: PublicKey, lsp_socket_addr: SocketAddress, @@ -57,9 +56,11 @@ pub(crate) struct LightningWallet { const DEFAULT_INVOICE_EXPIRY_SECS: u32 = 86_400; // 24 hours impl LightningWallet { + #[allow(clippy::too_many_arguments)] pub(super) async fn init( runtime: Arc, config: WalletConfig, store: Arc, - event_queue: Arc, tx_metadata: TxMetadataStore, logger: Arc, + event_queue: Arc, tx_metadata: TxMetadataStore, + rebalance_event_handler: RebalanceEventHandlerHolder, logger: Arc, ) -> Result { log_info!(logger, "Creating LDK node..."); let anchor_channels_config = ldk_node::config::AnchorChannelsConfig { @@ -164,24 +165,19 @@ impl LightningWallet { } let ldk_node = Arc::new(builder.build_with_store(Arc::clone(&store))?); - let (payment_receipt_sender, payment_receipt_flag) = watch::channel(()); - let (channel_pending_sender, channel_pending_receipt_flag) = watch::channel(0); let (splice_pending_sender, splice_pending_receipt_flag) = watch::channel(0); let ev_handler = Arc::new(LdkEventHandler { event_queue, ldk_node: Arc::clone(&ldk_node), tx_metadata, - payment_receipt_sender, - channel_pending_sender, splice_pending_sender, + rebalance_event_handler, logger: Arc::clone(&logger), }); let inner = Arc::new(LightningWalletImpl { ldk_node, logger, store, - payment_receipt_flag, - channel_pending_receipt_flag, splice_pending_receipt_flag, lsp_node_id, lsp_socket_addr, @@ -200,18 +196,6 @@ impl LightningWallet { Ok(Self { inner }) } - pub(crate) async fn await_payment_receipt(&self) { - let mut flag = self.inner.payment_receipt_flag.clone(); - flag.mark_unchanged(); - let _ = flag.changed().await; - } - - pub(crate) async fn await_channel_pending(&self, channel_id: u128) { - let mut flag = self.inner.channel_pending_receipt_flag.clone(); - flag.mark_unchanged(); - flag.wait_for(|t| t == &channel_id).await.expect("channel pending not received"); - } - pub(crate) async fn await_splice_pending(&self, channel_id: u128) { let mut flag = self.inner.splice_pending_receipt_flag.clone(); flag.mark_unchanged(); @@ -488,40 +472,6 @@ impl graduated_rebalancer::LightningWallet for LightningWallet { Box::pin(async move { self.pay(&method, amount).await.map(|p| p.0) }) } - fn await_payment_receipt( - &self, payment_hash: [u8; 32], - ) -> Pin> + Send + '_>> { - Box::pin(async move { - let id = PaymentId(payment_hash); - loop { - if let Some(payment) = self.inner.ldk_node.payment(&id) { - let counterparty_skimmed_fee_msat = match payment.kind { - PaymentKind::Bolt11 { hash, .. } => { - debug_assert!(hash.0 == payment_hash, "Payment Hash mismatch"); - None - }, - PaymentKind::Bolt11Jit { hash, counterparty_skimmed_fee_msat, .. } => { - debug_assert!(hash.0 == payment_hash, "Payment Hash mismatch"); - counterparty_skimmed_fee_msat - }, - _ => return None, // Ignore other payment kinds, we only care about the one we just sent. - }; - match payment.status { - PaymentStatus::Succeeded => { - return Some(ReceivedLightningPayment { - id: payment.id.0, - fee_paid_msat: counterparty_skimmed_fee_msat, - }); - }, - PaymentStatus::Pending => {}, - PaymentStatus::Failed => return None, - } - } - self.await_payment_receipt().await; - } - }) - } - fn has_channel_with_lsp(&self) -> bool { let channels = self.inner.ldk_node.list_channels(); channels.iter().any(|c| c.counterparty_node_id == self.inner.lsp_node_id) @@ -536,26 +486,6 @@ impl graduated_rebalancer::LightningWallet for LightningWallet { }) } - fn await_channel_pending( - &self, channel_id: u128, - ) -> Pin + Send + '_>> { - Box::pin(async move { - loop { - let channels = self.inner.ldk_node.list_channels(); - let chan = channels - .into_iter() - .find(|c| c.user_channel_id.0 == channel_id && c.funding_txo.is_some()); - match chan { - Some(c) => return c.funding_txo.expect("channel has no funding txo"), - None => { - self.await_channel_pending(channel_id).await; - // Wait for the next channel pending event - }, - } - } - }) - } - fn splice_to_lsp_channel( &self, amt: Amount, ) -> Pin> + Send + '_>> { @@ -574,29 +504,23 @@ impl graduated_rebalancer::LightningWallet for LightningWallet { Box::pin(async move { self.splice_balance_into_channel(amt).await.map(|c| c.0) }) } - fn await_splice_pending( - &self, channel_id: u128, - ) -> Pin + Send + '_>> { - Box::pin(async move { - // todo since we can't see if we have any active splices, we just await the next splice pending event - // this is kinda race-y hopefully we can fix - self.await_splice_pending(channel_id).await; - loop { - let channels = self.inner.ldk_node.list_channels(); - let chan = channels - .into_iter() - .find(|c| c.user_channel_id.0 == channel_id && c.funding_txo.is_some()); - match chan { - Some(c) => { - return c.funding_txo.expect("channel has no funding txo"); - }, - None => { - self.await_splice_pending(channel_id).await; - // Wait for the next channel pending event - }, - } - } - }) + fn get_channel_outpoint(&self, user_channel_id: u128) -> Option { + self.inner + .ldk_node + .list_channels() + .into_iter() + .find(|c| c.user_channel_id.0 == user_channel_id) + .and_then(|c| c.funding_txo) + } + + fn find_pending_lsp_channel(&self) -> Option { + let channels = self.inner.ldk_node.list_channels(); + let lsp_channels: Vec<_> = + channels.iter().filter(|c| c.counterparty_node_id == self.inner.lsp_node_id).collect(); + + // there should only be one channel with the LSP at a time + debug_assert_eq!(lsp_channels.len(), 1, "More than one channel with LSP found"); + lsp_channels.first().map(|c| c.user_channel_id.0) } } diff --git a/orange-sdk/src/rebalancer.rs b/orange-sdk/src/rebalancer.rs index 335717c..4e7b9d8 100644 --- a/orange-sdk/src/rebalancer.rs +++ b/orange-sdk/src/rebalancer.rs @@ -270,16 +270,72 @@ pub(crate) struct OrangeRebalanceEventHandler { tx_metadata: TxMetadataStore, /// The event handler for processing wallet events. event_queue: Arc, + /// Reference to the rebalancer, set once after initialization. + rebalancer: tokio::sync::OnceCell>, /// Logger for logging events and errors. logger: Arc, } +/// A holder type for the event handler that can be shared across wallets. +pub(crate) type RebalanceEventHandlerHolder = Arc; + impl OrangeRebalanceEventHandler { /// Creates a new `OrangeRebalanceEventHandler` instance. pub(crate) fn new( tx_metadata: TxMetadataStore, event_queue: Arc, logger: Arc, ) -> Self { - Self { tx_metadata, event_queue, logger } + Self { tx_metadata, event_queue, rebalancer: tokio::sync::OnceCell::new(), logger } + } + + /// Sets the rebalancer reference after initialization. Panics if called more than once. + pub(crate) fn set_rebalancer(&self, rebalancer: Arc) { + if self.rebalancer.set(rebalancer).is_err() { + panic!("rebalancer already set"); + } + } + + /// Notify that a trusted wallet payment has been sent. + pub(crate) async fn notify_trusted_payment_sent( + &self, payment_hash: [u8; 32], fee_msat: Option, + ) { + if let Some(rebalancer) = self.rebalancer.get() { + rebalancer.on_trusted_payment_sent(payment_hash, fee_msat).await; + } else { + debug_assert!(false, "Rebalancer not set in OrangeRebalanceEventHandler"); + } + } + + /// Notify that a lightning wallet payment has been received. + pub(crate) async fn notify_ln_payment_received( + &self, payment_hash: [u8; 32], payment_id: [u8; 32], fee_msat: Option, + ) { + if let Some(rebalancer) = self.rebalancer.get() { + rebalancer.on_ln_payment_received(payment_hash, payment_id, fee_msat).await; + } else { + debug_assert!(false, "Rebalancer not set in OrangeRebalanceEventHandler"); + } + } + + /// Notify that a trusted wallet payment has failed. + pub(crate) async fn notify_trusted_payment_failed( + &self, payment_hash: [u8; 32], reason: String, + ) { + if let Some(rebalancer) = self.rebalancer.get() { + rebalancer.on_trusted_payment_failed(payment_hash, reason).await; + } else { + debug_assert!(false, "Rebalancer not set in OrangeRebalanceEventHandler"); + } + } + + /// Notify that a channel or splice has become pending. + pub(crate) async fn notify_channel_splice_pending( + &self, user_channel_id: u128, outpoint: ldk_node::bitcoin::OutPoint, + ) { + if let Some(rebalancer) = self.rebalancer.get() { + rebalancer.on_channel_splice_pending(user_channel_id, outpoint).await; + } else { + debug_assert!(false, "Rebalancer not set in OrangeRebalanceEventHandler"); + } } } @@ -301,6 +357,7 @@ impl graduated_rebalancer::EventHandler for OrangeRebalanceEventHandler { self.tx_metadata .insert(PaymentId::Trusted(trusted_rebalance_payment_id), metadata) .await; + println!("=========== Rebalance Initiated Event ==========="); if let Err(e) = self .event_queue .add_event(Event::RebalanceInitiated { @@ -353,6 +410,33 @@ impl graduated_rebalancer::EventHandler for OrangeRebalanceEventHandler { } }); }, + RebalancerEvent::RebalanceFailed { + trigger_id, + trusted_rebalance_payment_id, + amount_msat, + reason, + } => { + log_info!( + self.logger, + "Rebalance failed for trigger {}: {}", + trigger_id.as_hex(), + reason + ); + + // Post a RebalanceFailed event to the event queue + if let Err(e) = self + .event_queue + .add_event(Event::RebalanceFailed { + trigger_payment_id: PaymentId::Trusted(trigger_id), + trusted_rebalance_payment_id, + amount_msat, + reason, + }) + .await + { + log_error!(self.logger, "Failed to add RebalanceFailed event: {e:?}"); + } + }, RebalancerEvent::OnChainRebalanceInitiated { trigger_id, channel_outpoint, diff --git a/orange-sdk/src/store.rs b/orange-sdk/src/store.rs index f6eb15a..dc48a14 100644 --- a/orange-sdk/src/store.rs +++ b/orange-sdk/src/store.rs @@ -12,6 +12,7 @@ //! shifted to minimize fees and ensure maximal security. use bitcoin_payment_instructions::amount::Amount; +use graduated_rebalancer::RebalancePersistence; use ldk_node::DynStore; use ldk_node::bitcoin::Txid; @@ -26,6 +27,8 @@ use ldk_node::payment::PaymentDetails; use std::collections::HashMap; use std::fmt; +use std::future::Future; +use std::pin::Pin; use std::str::FromStr; use std::sync::{Arc, RwLock, RwLockReadGuard}; use std::time::Duration; @@ -33,6 +36,8 @@ use std::time::Duration; const STORE_PRIMARY_KEY: &str = "orange_sdk"; const STORE_SECONDARY_KEY: &str = "payment_store"; const SPLICE_OUT_SECONDARY_KEY: &str = "splice_out"; +const REBALANCE_STATE_KEY: &str = "rebalance_state"; +const ONCHAIN_REBALANCE_STATE_KEY: &str = "onchain_rebalance_state"; /// The status of a transaction. This is used to track the state of a transaction #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -131,7 +136,7 @@ impl From for StoreTransaction { /// A PaymentId is a unique identifier for a payment. It can be either a Lightning payment or a /// Trusted payment. It is used to track the state of a payment and to provide information about /// the payment to the user. -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +#[derive(Copy, Clone, Hash, PartialEq, Eq)] pub enum PaymentId { /// A self-custodial payment identifier. SelfCustodial([u8; 32]), @@ -139,6 +144,15 @@ pub enum PaymentId { Trusted([u8; 32]), } +impl fmt::Debug for PaymentId { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match self { + PaymentId::SelfCustodial(bytes) => write!(fmt, "SelfCustodial({})", bytes.as_hex()), + PaymentId::Trusted(s) => write!(fmt, "Trusted({})", s.as_hex()), + } + } +} + impl fmt::Display for PaymentId { fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { match self { @@ -427,6 +441,132 @@ impl TxMetadataStore { } } +/// Wrapper for rebalance state persistence +#[derive(Clone)] +pub(crate) struct RebalancePersistenceStore { + store: Arc, +} + +impl RebalancePersistenceStore { + pub fn new(store: Arc) -> Self { + Self { store } + } +} + +impl RebalancePersistence for RebalancePersistenceStore { + fn insert_trusted_rebalance_state( + &self, state: graduated_rebalancer::RebalanceState, + ) -> Pin + Send + '_>> { + Box::pin(async move { + KVStore::write( + self.store.as_ref(), + STORE_PRIMARY_KEY, + "", + REBALANCE_STATE_KEY, + state.encode(), + ) + .await + .expect("We do not allow writes to fail"); + }) + } + + fn remove_trusted_rebalance_state(&self) -> Pin + Send + '_>> { + Box::pin(async move { + KVStore::remove(self.store.as_ref(), STORE_PRIMARY_KEY, "", REBALANCE_STATE_KEY, false) + .await + .expect("We do not allow removes to fail"); + }) + } + + fn get_trusted_rebalance( + &self, + ) -> Pin< + Box< + dyn Future, ()>> + + Send + + '_, + >, + > { + Box::pin(async move { + let res = + KVStore::read(self.store.as_ref(), STORE_PRIMARY_KEY, "", REBALANCE_STATE_KEY) + .await; + + match res { + Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(None), + Err(_) => Err(()), + Ok(data_bytes) => { + let state: graduated_rebalancer::RebalanceState = + Readable::read(&mut &data_bytes[..]).map_err(|_| ())?; + + Ok(Some(state)) + }, + } + }) + } + + fn insert_onchain_rebalance_state( + &self, state: graduated_rebalancer::OnChainRebalanceState, + ) -> Pin + Send + '_>> { + Box::pin(async move { + KVStore::write( + self.store.as_ref(), + STORE_PRIMARY_KEY, + "", + ONCHAIN_REBALANCE_STATE_KEY, + state.encode(), + ) + .await + .expect("We do not allow writes to fail"); + }) + } + + fn remove_onchain_rebalance_state(&self) -> Pin + Send + '_>> { + Box::pin(async move { + KVStore::remove( + self.store.as_ref(), + STORE_PRIMARY_KEY, + "", + ONCHAIN_REBALANCE_STATE_KEY, + false, + ) + .await + .expect("We do not allow removes to fail"); + }) + } + + fn get_onchain_rebalance( + &self, + ) -> Pin< + Box< + dyn Future, ()>> + + Send + + '_, + >, + > { + Box::pin(async move { + let res = KVStore::read( + self.store.as_ref(), + STORE_PRIMARY_KEY, + "", + ONCHAIN_REBALANCE_STATE_KEY, + ) + .await; + + match res { + Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(None), + Err(_) => Err(()), + Ok(data_bytes) => { + let state: graduated_rebalancer::OnChainRebalanceState = + Readable::read(&mut &data_bytes[..]).map_err(|_| ())?; + + Ok(Some(state)) + }, + } + }) + } +} + const REBALANCE_ENABLED_KEY: &str = "rebalance_enabled"; pub(crate) async fn get_rebalance_enabled(store: &DynStore) -> bool { diff --git a/orange-sdk/src/trusted_wallet/cashu/mod.rs b/orange-sdk/src/trusted_wallet/cashu/mod.rs index d001c95..b9773ae 100644 --- a/orange-sdk/src/trusted_wallet/cashu/mod.rs +++ b/orange-sdk/src/trusted_wallet/cashu/mod.rs @@ -1,7 +1,7 @@ //! An implementation of `TrustedWalletInterface` using the Cashu (CDK) SDK. -use crate::bitcoin::hex::DisplayHex; use crate::logging::Logger; +use crate::rebalancer::RebalanceEventHandlerHolder; use crate::runtime::Runtime; use crate::store::{PaymentId, TxMetadataStore, TxStatus}; use crate::trusted_wallet::{Payment, TrustedError, TrustedWalletInterface}; @@ -28,8 +28,6 @@ use cdk::wallet::Wallet; use cdk::wallet::types::{Transaction, TransactionDirection}; use cdk::{Amount as CdkAmount, StreamExt}; -use graduated_rebalancer::ReceivedLightningPayment; - use tokio::sync::{mpsc, watch}; use std::collections::HashMap; @@ -59,8 +57,7 @@ pub struct Cashu { cashu_wallet: Arc, unit: CurrencyUnit, shutdown_sender: watch::Sender<()>, - payment_success_sender: watch::Sender<()>, - payment_success_flag: watch::Receiver<()>, + rebalance_event_handler: RebalanceEventHandlerHolder, logger: Arc, supports_bolt12: bool, mint_quote_sender: mpsc::Sender, @@ -297,7 +294,8 @@ impl TrustedWalletInterface for Cashu { let event_queue = Arc::clone(&self.event_queue); let tx_metadata = self.tx_metadata.clone(); let quote_id = quote.id.clone(); - let payment_success_sender = self.payment_success_sender.clone(); + let rebalance_handler = Arc::clone(&self.rebalance_event_handler); + let unit = self.unit.clone(); self.runtime.spawn_background_task(async move { let mut metadata = HashMap::new(); if let Some(hash) = &payment_hash { @@ -311,15 +309,14 @@ impl TrustedWalletInterface for Cashu { log_info!(logger, "Successfully sent for quote: {quote_id}"); let payment_id = PaymentId::Trusted(payment_id); + let fee_msat = convert_amount(res.fee_paid, &unit) + .unwrap_or(Amount::ZERO) + .milli_sats(); + let is_rebalance = { let map = tx_metadata.read(); map.get(&payment_id).is_some_and(|m| m.ty.is_rebalance()) }; - if is_rebalance { - // make sure we still send payment success - payment_success_sender.send(()).unwrap(); - return; - } let preimage: Option = match &res.preimage { Some(str) => match FromHex::from_hex(str) { @@ -368,6 +365,18 @@ impl TrustedWalletInterface for Cashu { let payment_preimage = preimage.unwrap_or(PaymentPreimage([0u8; 32])); + if is_rebalance { + log_info!( + logger, + "Notifying rebalancer of successful trusted payment: {payment_id:?}" + ); + // Notify the rebalance event handler + rebalance_handler + .notify_trusted_payment_sent(hash.0, Some(fee_msat)) + .await; + return; + } + if tx_metadata .set_preimage(payment_id, payment_preimage.0) .await @@ -379,17 +388,14 @@ impl TrustedWalletInterface for Cashu { ); } - let fee_paid_sat: u64 = res.fee_paid.into(); let _ = event_queue .add_event(Event::PaymentSuccessful { payment_id, payment_hash: hash, payment_preimage, - fee_paid_msat: Some(fee_paid_sat * 1_000), // convert to msats + fee_paid_msat: Some(fee_msat), }) .await; - - payment_success_sender.send(()).unwrap(); }, MeltQuoteState::Failed => { log_error!(logger, "Melt failed for quote: {quote_id}"); @@ -399,7 +405,21 @@ impl TrustedWalletInterface for Cashu { map.get(&payment_id).is_some_and(|m| m.ty.is_rebalance()) }; - if !is_rebalance { + if is_rebalance { + log_info!( + logger, + "Notifying rebalancer of failed trusted payment: {payment_id:?}" + ); + // Notify the rebalance event handler + if let Some(hash) = payment_hash { + rebalance_handler + .notify_trusted_payment_failed( + hash.0, + format!("Cashu melt failed for quote {quote_id}"), + ) + .await; + } + } else { let _ = event_queue .add_event(Event::PaymentFailed { payment_id, @@ -426,7 +446,21 @@ impl TrustedWalletInterface for Cashu { map.get(&payment_id).is_some_and(|m| m.ty.is_rebalance()) }; - if !is_rebalance { + if is_rebalance { + log_info!( + logger, + "Notifying rebalancer of failed trusted payment: {payment_id:?}" + ); + // Notify the rebalance event handler + if let Some(hash) = payment_hash { + rebalance_handler + .notify_trusted_payment_failed( + hash.0, + format!("Cashu melt error for quote {quote_id}: {e}"), + ) + .await; + } + } else { let _ = event_queue .add_event(Event::PaymentFailed { payment_id, @@ -443,39 +477,52 @@ impl TrustedWalletInterface for Cashu { }) } - fn await_payment_success( + fn stop(&self) -> Pin + Send + '_>> { + Box::pin(async move { + log_info!(self.logger, "Stopping Cashu wallet"); + let _ = self.shutdown_sender.send(()); + }) + } + + fn find_payment_by_hash( &self, payment_hash: [u8; 32], - ) -> Pin> + Send + '_>> { + ) -> Pin, TrustedError>> + Send + '_>> { Box::pin(async move { - loop { - let txs = self - .cashu_wallet - .list_transactions(Some(TransactionDirection::Outgoing)) - .await - .ok()?; - - let hex = payment_hash.to_lower_hex_string(); - let tx = txs.iter().find(|tx| { - tx.metadata.get(PAYMENT_HASH_METADATA_KEY).is_some_and(|h| h == &hex) - }); - - if let Some(tx) = tx { - let payment_id = Self::id_to_32_byte_array(tx.quote_id.as_ref().expect("safe")); - return Some(ReceivedLightningPayment { - id: payment_id, - fee_paid_msat: Some(convert_amount(tx.fee, &self.unit).ok()?.milli_sats()), - }); + // Check all active melt quotes (pending outbound payments) + let quotes = self.cashu_wallet.get_active_melt_quotes().await.map_err(|e| { + TrustedError::WalletOperationFailed(format!( + "Failed to get active melt quotes: {e}" + )) + })?; + + for quote in quotes { + if let Ok(invoice) = Bolt11Invoice::from_str("e.request) { + if invoice.payment_hash().to_byte_array() == payment_hash { + let payment_id = Self::id_to_32_byte_array("e.id); + return Ok(Some(payment_id)); + } } + } + + // if not found, check completed transactions + let transactions = self.cashu_wallet.list_transactions(None).await.map_err(|e| { + TrustedError::WalletOperationFailed(format!("Failed to list transactions: {e}")) + })?; - self.await_payment_success().await; + for transaction in transactions { + if transaction.direction == TransactionDirection::Outgoing { + if let Some(quote_id) = &transaction.quote_id { + if let Ok(invoice) = Bolt11Invoice::from_str(quote_id) { + if invoice.payment_hash().to_byte_array() == payment_hash { + let payment_id = Self::id_to_32_byte_array(quote_id); + return Ok(Some(payment_id)); + } + } + } + } } - }) - } - fn stop(&self) -> Pin + Send + '_>> { - Box::pin(async move { - log_info!(self.logger, "Stopping Cashu wallet"); - let _ = self.shutdown_sender.send(()); + Ok(None) }) } } @@ -483,9 +530,11 @@ impl TrustedWalletInterface for Cashu { const PAYMENT_HASH_METADATA_KEY: &str = "payment_hash"; impl Cashu { + #[allow(clippy::too_many_arguments)] pub(crate) async fn init( config: &WalletConfig, cashu_config: CashuConfig, store: Arc, - event_queue: Arc, tx_metadata: TxMetadataStore, logger: Arc, + event_queue: Arc, tx_metadata: TxMetadataStore, + rebalance_event_handler: RebalanceEventHandlerHolder, logger: Arc, runtime: Arc, ) -> Result { match &cashu_config.unit { @@ -541,7 +590,6 @@ impl Cashu { .unwrap_or(false); let (shutdown_sender, mut shutdown_receiver) = watch::channel::<()>(()); - let (payment_success_sender, payment_success_flag) = watch::channel(()); // Create channel for mint quote monitoring with bounded capacity let (mint_quote_sender, mut mint_quote_receiver) = mpsc::channel::(32); @@ -620,8 +668,7 @@ impl Cashu { cashu_wallet, unit: cashu_config.unit, shutdown_sender, - payment_success_sender, - payment_success_flag, + rebalance_event_handler, logger, supports_bolt12, mint_quote_sender, @@ -713,12 +760,6 @@ impl Cashu { } Ok(()) } - - pub(crate) async fn await_payment_success(&self) { - let mut flag = self.payment_success_flag.clone(); - flag.mark_unchanged(); - let _ = flag.changed().await; - } } fn convert_amount(cdk_amount: CdkAmount, unit: &CurrencyUnit) -> Result { diff --git a/orange-sdk/src/trusted_wallet/dummy.rs b/orange-sdk/src/trusted_wallet/dummy.rs index 2c8c241..7fb1775 100644 --- a/orange-sdk/src/trusted_wallet/dummy.rs +++ b/orange-sdk/src/trusted_wallet/dummy.rs @@ -2,6 +2,7 @@ use crate::EventQueue; use crate::bitcoin::hashes::Hash; +use crate::rebalancer::RebalanceEventHandlerHolder; use crate::runtime::Runtime; use crate::store::{PaymentId, TxMetadataStore, TxStatus}; use crate::trusted_wallet::{Payment, TrustedError, TrustedWalletInterface}; @@ -9,11 +10,9 @@ use bitcoin_payment_instructions::PaymentMethod; use bitcoin_payment_instructions::amount::Amount; use corepc_node::client::bitcoin::Network; use corepc_node::{Node as Bitcoind, get_available_port}; -use graduated_rebalancer::ReceivedLightningPayment; -use ldk_node::lightning::ln::channelmanager; use ldk_node::lightning::ln::msgs::SocketAddress; use ldk_node::lightning_invoice::{Bolt11Invoice, Bolt11InvoiceDescription, Description}; -use ldk_node::payment::{PaymentKind, PaymentStatus}; +use ldk_node::payment::{PaymentDirection, PaymentKind}; use ldk_node::{Event, Node}; use rand::RngCore; use std::env::temp_dir; @@ -21,7 +20,7 @@ use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; -use tokio::sync::{RwLock, watch}; +use tokio::sync::RwLock; use uuid::Uuid; /// A dummy implementation of `TrustedWalletInterface` for testing purposes. @@ -32,7 +31,6 @@ pub(crate) struct DummyTrustedWallet { current_bal_msats: Arc, payments: Arc>>, ldk_node: Arc, - payment_success_flag: watch::Receiver<()>, } #[derive(Clone)] @@ -50,7 +48,8 @@ impl DummyTrustedWallet { /// Creates a new `DummyTrustedWallet` instance. pub(crate) async fn new( uuid: Uuid, lsp: &Node, bitcoind: &Bitcoind, tx_metadata: TxMetadataStore, - event_queue: Arc, rt: Arc, + event_queue: Arc, rebalance_event_handler: RebalanceEventHandlerHolder, + rt: Arc, ) -> Self { let mut builder = ldk_node::Builder::new(); builder.set_network(Network::Regtest); @@ -81,11 +80,10 @@ impl DummyTrustedWallet { let current_bal_msats = Arc::new(AtomicU64::new(0)); let payments: Arc>> = Arc::new(RwLock::new(vec![])); - let (payment_success_sender, payment_success_flag) = watch::channel(()); - let events_ref = Arc::clone(&ldk_node); let bal = Arc::clone(¤t_bal_msats); let pays = Arc::clone(&payments); + let rebalance_handler = Arc::clone(&rebalance_event_handler); rt.spawn_cancellable_background_task(async move { loop { let event = events_ref.next_event_async().await; @@ -114,6 +112,16 @@ impl DummyTrustedWallet { map.get(&payment_id).is_some_and(|m| m.ty.is_rebalance()) }; + // Notify rebalancer if this is a rebalance payment + if is_rebalance { + println!( + "Notifying rebalancer of successful trusted payment: {payment_id:?}" + ); + rebalance_handler + .notify_trusted_payment_sent(payment_hash.0, fee_paid_msat) + .await; + } + // Send a PaymentSuccessful event if not a rebalance if !is_rebalance { if tx_metadata @@ -133,8 +141,6 @@ impl DummyTrustedWallet { .await .unwrap(); } - - payment_success_sender.send(()).unwrap(); }, Event::PaymentFailed { payment_id, payment_hash, reason } => { // convert id @@ -154,8 +160,20 @@ impl DummyTrustedWallet { map.get(&payment_id).is_some_and(|m| m.ty.is_rebalance()) }; - // Send a PaymentFailed event if not a rebalance - if !is_rebalance { + // Notify rebalancer or send event depending on payment type + if is_rebalance { + println!( + "Notifying rebalancer of failed trusted payment: {payment_id:?}" + ); + if let Some(hash) = payment_hash { + let reason_str = reason + .map(|r| format!("{:?}", r)) + .unwrap_or_else(|| "Unknown".to_string()); + rebalance_handler + .notify_trusted_payment_failed(hash.0, reason_str) + .await; + } + } else { event_queue .add_event(crate::Event::PaymentFailed { payment_id, @@ -254,13 +272,7 @@ impl DummyTrustedWallet { panic!("No usable channels found {channels:?}"); } - DummyTrustedWallet { current_bal_msats, payments, ldk_node, payment_success_flag } - } - - pub(crate) async fn await_payment_success(&self) { - let mut flag = self.payment_success_flag.clone(); - flag.mark_unchanged(); - let _ = flag.changed().await; + DummyTrustedWallet { current_bal_msats, payments, ldk_node } } } @@ -378,43 +390,35 @@ impl TrustedWalletInterface for DummyTrustedWallet { }) } - fn await_payment_success( - &self, payment_hash: [u8; 32], - ) -> Pin> + Send + '_>> { + fn stop(&self) -> Pin + Send + '_>> { Box::pin(async move { - let id = channelmanager::PaymentId(payment_hash); - loop { - if let Some(payment) = self.ldk_node.payment(&id) { - let counterparty_skimmed_fee_msat = match payment.kind { - PaymentKind::Bolt11 { hash, .. } => { - debug_assert!(hash.0 == payment_hash, "Payment Hash mismatch"); - None - }, - PaymentKind::Bolt11Jit { hash, counterparty_skimmed_fee_msat, .. } => { - debug_assert!(hash.0 == payment_hash, "Payment Hash mismatch"); - counterparty_skimmed_fee_msat - }, - _ => return None, /* Ignore other payment kinds, we only care about the one we just sent. */ - }; - match payment.status { - PaymentStatus::Succeeded => { - return Some(ReceivedLightningPayment { - id: payment.id.0, - fee_paid_msat: counterparty_skimmed_fee_msat, - }); - }, - PaymentStatus::Pending => {}, - PaymentStatus::Failed => return None, - } - } - self.await_payment_success().await; - } + let _ = self.ldk_node.stop(); }) } - fn stop(&self) -> Pin + Send + '_>> { + fn find_payment_by_hash( + &self, payment_hash: [u8; 32], + ) -> Pin, TrustedError>> + Send + '_>> { Box::pin(async move { - let _ = self.ldk_node.stop(); + for payment in self.ldk_node.list_payments() { + // Only look at outbound payments + if payment.direction != PaymentDirection::Outbound { + continue; + } + + let hash = match &payment.kind { + PaymentKind::Bolt11 { hash, .. } | PaymentKind::Bolt11Jit { hash, .. } => { + Some(hash.0) + }, + _ => None, + }; + + if hash == Some(payment_hash) { + return Ok(Some(mangle_payment_id(payment.id.0))); + } + } + + Ok(None) }) } } diff --git a/orange-sdk/src/trusted_wallet/mod.rs b/orange-sdk/src/trusted_wallet/mod.rs index 19d9874..b5c4b79 100644 --- a/orange-sdk/src/trusted_wallet/mod.rs +++ b/orange-sdk/src/trusted_wallet/mod.rs @@ -7,7 +7,7 @@ use ldk_node::lightning_invoice::Bolt11Invoice; use bitcoin_payment_instructions::PaymentMethod; use bitcoin_payment_instructions::amount::Amount; -use graduated_rebalancer::ReceivedLightningPayment; +use graduated_rebalancer::PendingPaymentInfo; use std::future::Future; use std::pin::Pin; @@ -44,7 +44,7 @@ pub struct Payment { pub(crate) type DynTrustedWalletInterface = dyn TrustedWalletInterface + Send + Sync; /// Represents a trait for a trusted wallet interface. -pub trait TrustedWalletInterface: Send + Sync + private::Sealed { +pub(crate) trait TrustedWalletInterface: Send + Sync + private::Sealed { /// Returns the current balance of the wallet. fn get_balance( &self, @@ -83,15 +83,18 @@ pub trait TrustedWalletInterface: Send + Sync + private::Sealed { &self, method: PaymentMethod, amount: Amount, ) -> Pin> + Send + '_>>; - /// Waits for a payment with the given payment hash to succeed. - /// Returns the `ReceivedLightningPayment` if successful, or `None` if it fails or times out. - fn await_payment_success( - &self, payment_hash: [u8; 32], - ) -> Pin> + Send + '_>>; - /// Stops the wallet, cleaning up any resources. /// This is typically used to gracefully shut down the wallet. fn stop(&self) -> Pin + Send + '_>>; + + /// Find a pending outbound payment by its payment hash. + /// Returns `Some(payment_id)` if a pending payment with the given hash exists, + /// `None` otherwise. This is used during recovery to determine if a payment was + /// actually initiated before a crash. + #[allow(clippy::type_complexity)] + fn find_payment_by_hash( + &self, payment_hash: [u8; 32], + ) -> Pin, TrustedError>> + Send + '_>>; } pub(crate) struct WalletTrusted(pub(crate) Arc>); @@ -117,10 +120,14 @@ impl graduated_rebalancer::TrustedWallet for Box::pin(async move { self.0.pay(method, amount).await }) } - fn await_payment_success( + fn find_payment_by_hash( &self, payment_hash: [u8; 32], - ) -> Pin> + Send + '_>> { - Box::pin(async move { self.0.await_payment_success(payment_hash).await }) + ) -> Pin, Self::Error>> + Send + '_>> + { + Box::pin(async move { + let payment_id = self.0.find_payment_by_hash(payment_hash).await?; + Ok(payment_id.map(|id| PendingPaymentInfo { payment_id: id })) + }) } } diff --git a/orange-sdk/src/trusted_wallet/spark/mod.rs b/orange-sdk/src/trusted_wallet/spark/mod.rs index ab8d546..ccfd9d1 100644 --- a/orange-sdk/src/trusted_wallet/spark/mod.rs +++ b/orange-sdk/src/trusted_wallet/spark/mod.rs @@ -3,8 +3,11 @@ pub(crate) mod spark_store; use crate::bitcoin::Network; +use crate::bitcoin::hex::DisplayHex; use crate::bitcoin::hex::FromHex; use crate::logging::Logger; +use crate::rebalancer::RebalanceEventHandlerHolder; +use crate::runtime::Runtime; use crate::store::{PaymentId, TxMetadataStore, TxStatus}; use crate::trusted_wallet::{Payment, TrustedError, TrustedWalletInterface}; use crate::{Event, EventQueue, InitFailure, Seed, WalletConfig}; @@ -20,15 +23,12 @@ use bitcoin_payment_instructions::amount::Amount; use breez_sdk_spark::{ BreezSdk, EventListener, GetInfoRequest, ListPaymentsRequest, OptimizationConfig, - PaymentDetails, PaymentStatus, PaymentType, PrepareSendPaymentRequest, ReceivePaymentMethod, + PaymentDetails, PaymentType, PrepareSendPaymentRequest, ReceivePaymentMethod, ReceivePaymentRequest, SdkBuilder, SdkError, SdkEvent, SendPaymentMethod, SendPaymentRequest, }; -use graduated_rebalancer::ReceivedLightningPayment; - use tokio::sync::watch; -use crate::runtime::Runtime; use std::future::Future; use std::pin::Pin; use std::str::FromStr; @@ -87,7 +87,6 @@ impl SparkWalletConfig { pub(crate) struct Spark { spark_wallet: Arc, shutdown_sender: watch::Sender<()>, - payment_success_flag: watch::Receiver<()>, logger: Arc, } @@ -222,48 +221,48 @@ impl TrustedWalletInterface for Spark { }) } - fn await_payment_success( + fn stop(&self) -> Pin + Send + '_>> { + Box::pin(async move { + log_info!(self.logger, "Stopping Spark wallet"); + let _ = self.shutdown_sender.send(()); + }) + } + + fn find_payment_by_hash( &self, payment_hash: [u8; 32], - ) -> Pin> + Send + '_>> { + ) -> Pin, TrustedError>> + Send + '_>> { Box::pin(async move { - loop { - let res = - self.spark_wallet.list_payments(ListPaymentsRequest::default()).await.ok()?; - - let tx = res.payments.into_iter().find(|p| { - if let Some(PaymentDetails::Lightning { payment_hash: ph, .. }) = &p.details { - let hash: Option<[u8; 32]> = FromHex::from_hex(ph).ok(); - hash == Some(payment_hash) - } else { - false - } - })?; + let payment_hash_hex = payment_hash.to_lower_hex_string(); + let resp = self.spark_wallet.list_payments(ListPaymentsRequest::default()).await?; - if tx.status == PaymentStatus::Completed { - return Some(ReceivedLightningPayment { - id: payment_hash, - fee_paid_msat: Some((tx.fees * 1_000) as u64), - }); + for payment in resp.payments { + // Only look at outbound payments + if payment.payment_type != PaymentType::Send { + continue; } - self.await_payment_success().await; + // Check if the payment hash matches + if let Some(PaymentDetails::Lightning { payment_hash: hash, .. }) = &payment.details + { + if hash == &payment_hash_hex { + let id = parse_payment_id(&payment.id)?; + return Ok(Some(id)); + } + } } - }) - } - fn stop(&self) -> Pin + Send + '_>> { - Box::pin(async move { - log_info!(self.logger, "Stopping Spark wallet"); - let _ = self.shutdown_sender.send(()); + Ok(None) }) } } impl Spark { /// Initialize a new Spark wallet instance with the given configuration. + #[allow(clippy::too_many_arguments)] pub(crate) async fn init( config: &WalletConfig, spark_config: SparkWalletConfig, store: Arc, - event_queue: Arc, tx_metadata: TxMetadataStore, logger: Arc, + event_queue: Arc, tx_metadata: TxMetadataStore, + rebalance_event_handler: RebalanceEventHandlerHolder, logger: Arc, runtime: Arc, ) -> Result { let spark_config: breez_sdk_spark::Config = spark_config.to_breez_config(config.network)?; @@ -287,12 +286,11 @@ impl Spark { log_info!(logger, "Started Spark wallet!"); let (shutdown_sender, shutdown_receiver) = watch::channel::<()>(()); - let (payment_success_sender, payment_success_flag) = watch::channel(()); let listener = SparkEventHandler { event_queue: Arc::clone(&event_queue), tx_metadata, - payment_success_sender, + rebalance_event_handler, logger: Arc::clone(&logger), }; @@ -307,20 +305,14 @@ impl Spark { log_info!(logger, "Spark wallet initialized"); - Ok(Spark { spark_wallet, shutdown_sender, payment_success_flag, logger }) - } - - pub(crate) async fn await_payment_success(&self) { - let mut flag = self.payment_success_flag.clone(); - flag.mark_unchanged(); - let _ = flag.changed().await; + Ok(Spark { spark_wallet, shutdown_sender, logger }) } } struct SparkEventHandler { event_queue: Arc, tx_metadata: TxMetadataStore, - payment_success_sender: watch::Sender<()>, + rebalance_event_handler: RebalanceEventHandlerHolder, logger: Arc, } @@ -381,16 +373,6 @@ impl SparkEventHandler { map.get(&payment_id).is_some_and(|m| m.ty.is_rebalance()) }; - if is_rebalance { - log_info!( - self.logger, - "Ignoring successful payment event for rebalance payment: {payment_id:?}" - ); - // make sure we still send payment success - self.payment_success_sender.send(()).unwrap(); - return Ok(()); - } - let preimage = preimage.ok_or_else(|| { TrustedError::Other( "Payment succeeded but preimage is missing".to_string(), @@ -405,6 +387,19 @@ impl SparkEventHandler { TrustedError::Other(format!("Invalid payment_hash hex: {e:?}")) })?; + if is_rebalance { + log_info!( + self.logger, + "Notifying rebalancer of successful trusted payment: {payment_id:?}" + ); + // Notify the rebalance event handler + let fee_msat = Some((payment.fees * 1_000) as u64); + self.rebalance_event_handler + .notify_trusted_payment_sent(payment_hash, fee_msat) + .await; + return Ok(()); + } + if self.tx_metadata.set_preimage(payment_id, preimage).await.is_err() { log_error!( self.logger, @@ -420,8 +415,6 @@ impl SparkEventHandler { fee_paid_msat: Some((payment.fees * 1_000) as u64), // convert to msats }) .await?; - - self.payment_success_sender.send(()).unwrap(); }, _ => { log_debug!(self.logger, "Unsupported payment details for Send: {payment:?}") @@ -476,6 +469,10 @@ impl SparkEventHandler { PaymentType::Send => match payment.details { Some(PaymentDetails::Lightning { payment_hash, .. }) => { let payment_id = PaymentId::Trusted(id); + let payment_hash: [u8; 32] = FromHex::from_hex(&payment_hash).map_err(|e| { + TrustedError::Other(format!("Invalid payment_hash hex: {e:?}")) + })?; + let is_rebalance = { let map = self.tx_metadata.read(); map.get(&payment_id).is_some_and(|m| m.ty.is_rebalance()) @@ -484,15 +481,16 @@ impl SparkEventHandler { if is_rebalance { log_info!( self.logger, - "Ignoring failed payment event for rebalance payment: {payment_id:?}" + "Notifying rebalancer of failed trusted payment: {payment_id:?}" ); + // Notify the rebalance event handler + let reason = format!("Spark payment failed for payment {}", payment.id); + self.rebalance_event_handler + .notify_trusted_payment_failed(payment_hash, reason) + .await; return Ok(()); } - let payment_hash: [u8; 32] = FromHex::from_hex(&payment_hash).map_err(|e| { - TrustedError::Other(format!("Invalid payment_hash hex: {e:?}")) - })?; - self.event_queue .add_event(Event::PaymentFailed { payment_id, diff --git a/orange-sdk/tests/integration_tests.rs b/orange-sdk/tests/integration_tests.rs index e169a39..c9a77a8 100644 --- a/orange-sdk/tests/integration_tests.rs +++ b/orange-sdk/tests/integration_tests.rs @@ -169,6 +169,7 @@ async fn test_pay_from_trusted() { } #[tokio::test(flavor = "multi_thread")] +#[test_log::test] async fn test_sweep_to_ln() { test_utils::run_test(|params| async move { let wallet = Arc::clone(¶ms.wallet); @@ -195,6 +196,7 @@ async fn test_sweep_to_ln() { }) .await; + println!("waiting for payment recv"); let event = wait_next_event(&wallet).await; match event { Event::PaymentReceived { .. } => {}, @@ -216,12 +218,14 @@ async fn test_sweep_to_ln() { .await; // receive to trusted wallet + println!("waiting for payment recv 2"); let event = wait_next_event(&wallet).await; match event { Event::PaymentReceived { .. } => {}, e => panic!("Expected PaymentReceived event, got {e:?}"), } + println!("waiting for rebalance"); let event = wait_next_event(&wallet).await; assert!( matches!(event, Event::RebalanceInitiated { .. }), @@ -235,16 +239,18 @@ async fn test_sweep_to_ln() { .await; // wait for payment received + println!("waiting for channel opened"); let event = wait_next_event(&wallet).await; match event { Event::ChannelOpened { counterparty_node_id, .. } => { assert_eq!(counterparty_node_id, lsp.node_id()); }, - _ => panic!("Expected ChannelOpened event"), + e => panic!("Expected ChannelOpened event, got {e:?}"), } let expect_amt = intermediate_amt.saturating_add(recv_amt); + println!("waiting for payment recv"); let event = wait_next_event(&wallet).await; match event { Event::PaymentReceived { payment_id, amount_msat, lsp_fee_msats, .. } => { @@ -255,6 +261,7 @@ async fn test_sweep_to_ln() { e => panic!("Expected RebalanceSuccessful event, got {e:?}"), } + println!("waiting for rebalance successful"); let event = wait_next_event(&wallet).await; match event { Event::RebalanceSuccessful { amount_msat, fee_msat, .. } => {