diff --git a/iroh/examples/echo.rs b/iroh/examples/echo.rs index 2a1cb7b6d23..40e994f569d 100644 --- a/iroh/examples/echo.rs +++ b/iroh/examples/echo.rs @@ -22,6 +22,15 @@ const ALPN: &[u8] = b"iroh-example/echo/0"; #[tokio::main] async fn main() -> Result<()> { + tracing::subscriber::set_global_default( + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::NONE) + .compact() + .finish(), + ) + .e()?; + let router = start_accept_side().await?; let node_addr = router.endpoint().node_addr().initialized().await?; @@ -57,6 +66,8 @@ async fn connect_side(addr: NodeAddr) -> Result<()> { // The above call only queues a close message to be sent (see how it's not async!). // We need to actually call this to make sure this message is sent out. + // + // TODO this causes the above close to not get cleaned up gracefully and early exits, check the warn messages in logs... endpoint.close().await; // If we don't call this, but continue using the endpoint, we then the queued // close call will eventually be picked up and sent. diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index 76354bcbc3c..4ac2c6566e4 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -3039,8 +3039,11 @@ mod tests { let (mut send, mut recv) = conn.open_bi().await.e()?; send.write_all(b"Hello, world!").await.e()?; send.finish().e()?; - recv.read_to_end(1_000).await.e()?; + let response = recv.read_to_end(1_000).await.e()?; + assert_eq!(&response, b"Hello, world!"); conn.close(42u32.into(), b"thanks, bye!"); + // TODO this causes a warn that things are not cleaned up gracefully, how can we fail a + // test due to that? client.close().await; let close_err = server_task.await.e()??; @@ -3054,6 +3057,83 @@ mod tests { Ok(()) } + #[tokio::test] + #[traced_test] + async fn connecting_is_fast_from_same_id_consecutively() -> Result { + const ECHOS: usize = 10; + + let server = Endpoint::builder() + .alpns(vec![TEST_ALPN.to_vec()]) + .relay_mode(RelayMode::Disabled) + .bind() + .await?; + let server_addr = server.node_addr().initialized().await?; + let server_task = tokio::spawn(async move { + let mut close_reasons = Vec::new(); + + for _ in 0..ECHOS { + let incoming = server.accept().await.e()?; + let conn = incoming.await.e()?; + let (mut send, mut recv) = conn.accept_bi().await.e()?; + let msg = recv.read_to_end(1000).await.e()?; + send.write_all(&msg).await.e()?; + send.finish().e()?; + let close_reason = conn.closed().await; + close_reasons.push(close_reason); + } + Ok::<_, Error>(close_reasons) + }); + + let mut elapsed_times = Vec::with_capacity(ECHOS); + for i in 0..ECHOS { + let timer = std::time::Instant::now(); + let client_secret_key = SecretKey::from_bytes(&[0u8; 32]); + let client = Endpoint::builder() + .secret_key(client_secret_key) + // NOTE this is not necessary to trigger the failure so I have it commented out for + // now + // .relay_mode(RelayMode::Disabled) + .bind() + .await?; + let conn = client.connect(server_addr.clone(), TEST_ALPN).await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; + let bytes = format!("Hello, world {i}").into_bytes(); + send.write_all(&bytes).await.e()?; + send.finish().e()?; + let response = recv.read_to_end(1_000).await.e()?; + assert_eq!(&response, &bytes); + conn.close(42u32.into(), b"thanks, bye!"); + client.close().await; + let elapsed = timer.elapsed(); + elapsed_times.push(elapsed); + } + + let close_errs = server_task.await.e()??; + assert_eq!(close_errs.len(), ECHOS); + + for (i, err) in close_errs.into_iter().enumerate() { + let ConnectionError::ApplicationClosed(app_close) = err else { + panic!("Unexpected close reason for conn {i}: {err:?}"); + }; + assert_eq!(app_close.error_code, 42u32.into()); + assert_eq!(app_close.reason.as_ref(), b"thanks, bye!" as &[u8]); + } + + elapsed_times.iter().enumerate().for_each(|(i, elapsed)| { + println!("Elapsed time for connection {i}: {elapsed:?}"); + }); + + // If any of the elapsed times are greater than 3x the minimum throw an error + let min_elapsed = elapsed_times.iter().min().unwrap_or(&Duration::ZERO); + for (i, elapsed) in elapsed_times.iter().enumerate() { + if *elapsed > *min_elapsed * 3 { + panic!("Connection {i} took too long compared to baseline ({min_elapsed:?}): {elapsed:?}"); + } + } + + Ok(()) + } + #[cfg(feature = "metrics")] #[tokio::test] #[traced_test]