Skip to content

Decrease Variance of LogNormal to Converge on Expected Value Sooner #301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 49 additions & 27 deletions simln-lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1592,7 +1592,7 @@ async fn track_payment_result(

#[cfg(test)]
mod tests {
use crate::clock::SystemClock;
use crate::clock::{Clock, SimulationClock};
use crate::test_utils::{MockLightningNode, TestNodesResult};
use crate::{
get_payment_delay, test_utils, test_utils::LightningTestNodeBuilder, LightningError,
Expand Down Expand Up @@ -2031,20 +2031,20 @@ mod tests {
let (shutdown_trigger, shutdown_listener) = triggered::trigger();

// Create simulation without a timeout.
let clock = Arc::new(SimulationClock::new(10).unwrap());
let start = clock.now();
let simulation = Simulation::new(
SimulationCfg::new(None, 100, 2.0, None, None),
network.get_client_hashmap(),
TaskTracker::new(),
Arc::new(SystemClock {}),
clock.clone(),
shutdown_trigger,
shutdown_listener,
);

// Run the simulation
let start = std::time::Instant::now();
let _ = simulation.run(&vec![activity_1, activity_2]).await;
let elapsed = start.elapsed();

let elapsed = clock.now().duration_since(start).unwrap();
let expected_payment_list = vec![
network.nodes[1].pubkey,
network.nodes[3].pubkey,
Expand All @@ -2058,13 +2058,14 @@ mod tests {
network.nodes[3].pubkey,
];

// Check that simulation ran 20ish seconds because
// from activity_1 there are 5 payments with a wait_time of 2s -> 10s
// from activity_2 there are 5 payments with a wait_time of 4s -> 20s
// but the wait time is interleave between the payments.
// Check that simulation ran 20ish seconds because:
// - from activity_1 there are 5 payments with a wait_time of 2s -> 10s
// - from activity_2 there are 5 payments with a wait_time of 4s -> 20s
// - but the wait time is interleave between the payments.
// Since we're running with a sped up clock, we allow a little more leeway.
assert!(
elapsed <= Duration::from_secs(21),
"Simulation should have run no more than 21, took {:?}",
elapsed <= Duration::from_secs(30),
"Simulation should have run no more than 30, took {:?}",
elapsed
);

Expand Down Expand Up @@ -2098,55 +2099,76 @@ mod tests {

let (shutdown_trigger, shutdown_listener) = triggered::trigger();

// Create simulation with a defined seed.
// Create simulation with a defined seed, and limit it to running for 45 seconds.
let clock = Arc::new(SimulationClock::new(20).unwrap());
let simulation = Simulation::new(
SimulationCfg::new(Some(25), 100, 2.0, None, Some(42)),
SimulationCfg::new(Some(45), 100, 2.0, None, Some(42)),
network.get_client_hashmap(),
TaskTracker::new(),
Arc::new(SystemClock {}),
clock.clone(),
shutdown_trigger,
shutdown_listener,
);

// Run the simulation
let start = std::time::Instant::now();
let start = clock.now();
let _ = simulation.run(&[]).await;
let elapsed = start.elapsed();
let elapsed = clock.now().duration_since(start).unwrap();

assert!(
elapsed >= Duration::from_secs(25),
"Simulation should have run at least for 25s, took {:?}",
elapsed >= Duration::from_secs(45),
"Simulation should have run at least for 45s, took {:?}",
elapsed
);

// We're running with a sped up clock, so we're not going to hit exactly the same number
// of payments each time. We settle for asserting that our first 20 are deterministic.
// This ordering is set by running the simulation for 25 seconds, and we run for a total
// of 45 seconds so we can reasonably expect that we'll always get at least these 20
// payments.
let expected_payment_list = vec![
pk1, pk2, pk1, pk1, pk1, pk3, pk3, pk3, pk4, pk3, pk2, pk1, pk4,
pk2, pk1, pk1, pk3, pk2, pk4, pk3, pk2, pk2, pk4, pk3, pk2, pk3, pk2, pk3, pk4, pk4,
pk2, pk3, pk1,
];

assert!(
payments_list.lock().unwrap().as_ref() == expected_payment_list,
let actual_payments: Vec<PublicKey> = payments_list
.lock()
.unwrap()
.iter()
.cloned()
.take(20)
.collect();
assert_eq!(
actual_payments,
expected_payment_list,
"The expected order of payments is not correct: {:?} vs {:?}",
payments_list.lock().unwrap(),
expected_payment_list,
);

// remove all the payments made in the previous execution
payments_list.lock().unwrap().clear();

let (shutdown_trigger, shutdown_listener) = triggered::trigger();

// Create the same simulation as before but with different seed.
let simulation2 = Simulation::new(
SimulationCfg::new(Some(25), 100, 2.0, None, Some(500)),
SimulationCfg::new(Some(45), 100, 2.0, None, Some(500)),
network.get_client_hashmap(),
TaskTracker::new(),
Arc::new(SystemClock {}),
clock.clone(),
shutdown_trigger,
shutdown_listener,
);
let _ = simulation2.run(&[]).await;

assert!(
payments_list.lock().unwrap().as_ref() != expected_payment_list,
let actual_payments: Vec<PublicKey> = payments_list
.lock()
.unwrap()
.iter()
.cloned()
.take(20)
.collect();
assert_ne!(
actual_payments, expected_payment_list,
"The expected order of payments shoud be different because a different is used"
);
}
Expand Down
96 changes: 52 additions & 44 deletions simln-lib/src/random_activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,20 @@ impl RandomPaymentActivity {

Ok(())
}

/// Returns a log normal distribution with our expected payment size as its mean and variance
/// that is scaled by the channel size (larger for larger channels).
fn log_normal(&self, channel_size_msat: f64) -> Result<LogNormal<f64>, PaymentGenerationError> {
let expected_payment_amt_msat = self.expected_payment_amt as f64;
let variance = 1000.0 * channel_size_msat.ln();
let sigma_square =
((variance * variance) / (expected_payment_amt_msat * expected_payment_amt_msat) + 1.0)
.ln();
let sigma = sigma_square.sqrt();
let mu = expected_payment_amt_msat.ln() - sigma_square / 2.0;

LogNormal::new(mu, sigma).map_err(|e| PaymentGenerationError(e.to_string()))
}
}

/// Returns the number of events that the simulation expects the node to process per month based on its capacity, a
Expand Down Expand Up @@ -260,29 +274,17 @@ impl PaymentGenerator for RandomPaymentActivity {
"destination amount required for payment activity generator".to_string(),
))?;

let payment_limit = std::cmp::min(self.source_capacity, destination_capacity) / 2;

let ln_pmt_amt = (self.expected_payment_amt as f64).ln();
let ln_limit = (payment_limit as f64).ln();

let mu = 2.0 * ln_pmt_amt - ln_limit;
let sigma_square = 2.0 * (ln_limit - ln_pmt_amt);

if sigma_square < 0.0 {
return Err(PaymentGenerationError(format!(
"payment amount not possible for limit: {payment_limit}, sigma squared: {sigma_square}"
)));
}

let log_normal = LogNormal::new(mu, sigma_square.sqrt())
.map_err(|e| PaymentGenerationError(e.to_string()))?;
let largest_channel_capacity_msat =
std::cmp::min(self.source_capacity, destination_capacity) / 2;

let mut rng = self
.rng
.0
.lock()
.map_err(|e| PaymentGenerationError(e.to_string()))?;
let payment_amount = log_normal.sample(&mut *rng) as u64;
let payment_amount = self
.log_normal(largest_channel_capacity_msat as f64)?
.sample(&mut *rng) as u64;

Ok(payment_amount)
}
Expand Down Expand Up @@ -456,39 +458,45 @@ mod tests {
}

#[test]
fn test_payment_amount() {
// The special cases for payment_amount are those who may make the internal log normal distribution fail to build, which happens if
// sigma squared is either +-INF or NaN. Given that the constructor of the PaymentActivityGenerator already forces its internal values
// to be greater than zero, the only values that are left are all values of `destination_capacity` smaller or equal to the `source_capacity`
// All of them will yield a sigma squared smaller than 0, which we have a sanity check for.
let expected_payment = get_random_int(1, 100);
let source_capacity = 2 * expected_payment;
let rng = MutRng::new(Some((u64::MAX, None)));
fn test_log_normal_distribution_within_one_std_dev() {
// Tests that samples from the log normal distribution fall within one standard
// deviation of our expected variance. We intentionally use fresh randomness in each
// run of this test because this property should hold for any seed.
let dest_capacity_msat = 200_000_000_000.0;
let pag =
RandomPaymentActivity::new(source_capacity, expected_payment, 1.0, rng).unwrap();
RandomPaymentActivity::new(100_000_000_000, 38_000_000, 1.0, MutRng::new(None))
.unwrap();

// Wrong cases
for i in 0..source_capacity {
assert!(matches!(
pag.payment_amount(Some(i)),
Err(PaymentGenerationError(..))
))
}
let dist = pag.log_normal(dest_capacity_msat).unwrap();
let mut rng = rand::thread_rng();

// All other cases will work. We are not going to exhaustively test for the rest up to u64::MAX, let just pick a bunch
for i in source_capacity + 1..100 * source_capacity {
assert!(pag.payment_amount(Some(i)).is_ok())
let mut samples = Vec::new();
for _ in 0..1000 {
let sample = dist.sample(&mut rng);
samples.push(sample);
}

// We can even try really high numbers to make sure they are not troublesome
for i in u64::MAX - 10000..u64::MAX {
assert!(pag.payment_amount(Some(i)).is_ok())
}
let mean = samples.iter().sum::<f64>() / samples.len() as f64;
let variance =
samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
let std_dev = variance.sqrt();

assert!(matches!(
pag.payment_amount(None),
Err(PaymentGenerationError(..))
));
let lower_bound = mean - std_dev;
let upper_bound = mean + std_dev;

let within_one_std_dev = samples
.iter()
.filter(|&&x| x >= lower_bound && x <= upper_bound)
.count();

// For a normal distribution, approximately 68% of values should be within 1 standard
// deviation. We allow some tolerance in the test so that it doesn't flake.
let percentage = (within_one_std_dev as f64 / samples.len() as f64) * 100.0;
assert!(
(60.0..=75.0).contains(&percentage),
"Expected 60-75% of values within 1 std dev, got {:.1}%",
percentage
);
}
}
}