diff --git a/openhcl/underhill_core/src/dispatch/mod.rs b/openhcl/underhill_core/src/dispatch/mod.rs index 3b69a3918b..f1b776153c 100644 --- a/openhcl/underhill_core/src/dispatch/mod.rs +++ b/openhcl/underhill_core/src/dispatch/mod.rs @@ -160,7 +160,6 @@ pub(crate) struct LoadedVm { pub host_vmbus_relay: Option, // channels are revoked when dropped, so make sure to keep them alive pub _vmbus_devices: Vec>>, - pub _vmbus_intercept_devices: Vec>, pub _ide_accel_devices: Vec>>, pub network_settings: Option>, pub shutdown_relay: Option<( diff --git a/openhcl/underhill_core/src/worker.rs b/openhcl/underhill_core/src/worker.rs index 059cbb057e..72dc5c4363 100644 --- a/openhcl/underhill_core/src/worker.rs +++ b/openhcl/underhill_core/src/worker.rs @@ -3062,8 +3062,6 @@ async fn new_underhill_vm( ); } - let mut vmbus_intercept_devices = Vec::new(); - let shutdown_relay = if let Some(recv) = intercepted_shutdown_ic { let mut shutdown_guest = ShutdownGuestIc::new(); let recv_host_shutdown = shutdown_guest.get_shutdown_notifier(); @@ -3089,8 +3087,7 @@ async fn new_underhill_vm( .context("shutdown relay dma client")?, shutdown_guest, )?; - vmbus_intercept_devices.push(shutdown_guest.detach(driver_source.simple(), recv)?); - + shutdown_guest.detach(driver_source.simple(), recv)?; Some((recv_host_shutdown, send_guest_shutdown)) } else { None @@ -3218,7 +3215,6 @@ async fn new_underhill_vm( vmbus_server, host_vmbus_relay, _vmbus_devices: vmbus_devices, - _vmbus_intercept_devices: vmbus_intercept_devices, _ide_accel_devices: ide_accel_devices, network_settings, shutdown_relay, diff --git a/vm/devices/vmbus/vmbus_relay/src/lib.rs b/vm/devices/vmbus/vmbus_relay/src/lib.rs index ee68647e51..d28ed3e815 100644 --- a/vm/devices/vmbus/vmbus_relay/src/lib.rs +++ b/vm/devices/vmbus/vmbus_relay/src/lib.rs @@ -647,10 +647,6 @@ impl RelayTask { async fn handle_offer(&mut self, offer: client::OfferInfo) -> Result<()> { let channel_id = offer.offer.channel_id.0; - if self.channels.contains_key(&ChannelId(channel_id)) { - anyhow::bail!("channel {channel_id} already exists"); - } - if let Some(intercept) = self.intercept_channels.get(&offer.offer.instance_id) { self.channels.insert( ChannelId(channel_id), @@ -660,6 +656,10 @@ impl RelayTask { return Ok(()); } + if self.channels.contains_key(&ChannelId(channel_id)) { + anyhow::bail!("channel {channel_id} already exists"); + } + // Used to Recv requests from the server. let (request_send, request_recv) = mesh::channel(); // Used to Send responses from the server diff --git a/vm/devices/vmbus/vmbus_relay_intercept_device/src/lib.rs b/vm/devices/vmbus/vmbus_relay_intercept_device/src/lib.rs index 14599496a0..00bde61a84 100644 --- a/vm/devices/vmbus/vmbus_relay_intercept_device/src/lib.rs +++ b/vm/devices/vmbus/vmbus_relay_intercept_device/src/lib.rs @@ -167,28 +167,36 @@ impl SimpleVmbusClientDeviceWrapper { mut self, driver: impl SpawnDriver, recv_relay: mesh::Receiver, - ) -> Result> { + ) -> Result<()> { + let (send_disconnected, recv_disconnected) = mesh::oneshot(); self.vmbus_listener.insert( &self.spawner, format!("{}", self.instance_id), SimpleVmbusClientDeviceTaskState { offer: None, recv_relay, + send_disconnected: Some(send_disconnected), vtl_pages: None, }, ); - let (driver_send, driver_recv) = mesh::oneshot(); driver .spawn( format!("vmbus_relay_device {}", self.instance_id), async move { self.vmbus_listener.start(); - let _ = driver_recv.await; - self.vmbus_listener.stop().await; + let _ = recv_disconnected.await; + assert!(!self.vmbus_listener.stop().await); + if self.vmbus_listener.state().unwrap().vtl_pages.is_some() { + // The VTL pages were not freed. This can occur if an + // error is hit that drops the vmbus parent tasks. Just + // pend here and let the outer error cause the VM to + // exit. + pending::<()>().await; + } }, ) .detach(); - Ok(driver_send) + Ok(()) } } @@ -215,6 +223,8 @@ struct SimpleVmbusClientDeviceTaskState { offer: Option, #[inspect(skip)] recv_relay: mesh::Receiver, + #[inspect(skip)] + send_disconnected: Option>, #[inspect(hex, with = "|x| x.as_ref().map(|x| inspect::iter_by_index(x.pfns()))")] vtl_pages: Option, } @@ -234,7 +244,13 @@ impl AsyncRun stop: &mut StopTask<'_>, state: &mut SimpleVmbusClientDeviceTaskState, ) -> Result<(), Cancelled> { - stop.until_stopped(self.process_messages(state)).await + stop.until_stopped(self.process_messages(state)).await?; + state + .send_disconnected + .take() + .expect("task should not be restarted") + .send(()); + Ok(()) } } @@ -351,7 +367,7 @@ impl SimpleVmbusClientDeviceTask { }; if state.vtl_pages.is_some() { - if let Err(err) = offer + match offer .request_send .call( ChannelRequest::TeardownGpadl, @@ -359,13 +375,19 @@ impl SimpleVmbusClientDeviceTask { ) .await { - tracing::error!( - error = &err as &dyn std::error::Error, - "failed to teardown gpadl" - ); + Ok(()) => { + state.vtl_pages = None; + } + Err(err) => { + // If the ring buffer pages are still in use by the host, which + // has to be assumed, the memory pages cannot be used again as + // they have been marked as visible to VTL0. + tracing::error!( + error = &err as &dyn std::error::Error, + "Failed to teardown gpadl -- leaking memory." + ); + } } - - state.vtl_pages = None; } } @@ -504,7 +526,7 @@ impl SimpleVmbusClientDeviceTask { /// Responds to the channel being revoked by the host. async fn handle_revoke(&mut self, state: &mut SimpleVmbusClientDeviceTaskState) { - let Some(offer) = state.offer.take() else { + let Some(offer) = state.offer.as_ref() else { return; }; tracing::info!("device revoked"); @@ -513,6 +535,7 @@ impl SimpleVmbusClientDeviceTask { self.device.task_mut().0.close(offer.offer.subchannel_index); } self.cleanup_device_resources(state).await; + drop(state.offer.take()); } fn handle_save(&mut self) -> SavedStateBlob { @@ -545,27 +568,25 @@ impl SimpleVmbusClientDeviceTask { loop { enum Event { Request(InterceptChannelRequest), - Revoke(()), + Revoke, } - let revoke = pin!(async { - if let Some(offer) = &mut state.offer { - (&mut offer.revoke_recv).await.ok(); - } else { - pending().await - } - }); - let Some(r) = ( - (&mut state.recv_relay).map(Event::Request), - futures::stream::once(revoke).map(Event::Revoke), - ) - .merge() - .next() - .await - else { + let r = if let Some(offer) = &mut state.offer { + ( + (&mut state.recv_relay).map(Event::Request), + futures::stream::once(&mut offer.revoke_recv).map(|_| Event::Revoke), + ) + .merge() + .next() + .await + } else { + let mut recv_relay = pin!(&mut state.recv_relay); + recv_relay.next().await.map(Event::Request) + }; + let Some(r) = r else { break; }; match r { - Event::Revoke(()) => { + Event::Revoke => { self.handle_revoke(state).await; } Event::Request(InterceptChannelRequest::Offer(offer)) => {