diff --git a/vm/devices/net/netvsp/src/lib.rs b/vm/devices/net/netvsp/src/lib.rs index a82f849928..ef449c4776 100644 --- a/vm/devices/net/netvsp/src/lib.rs +++ b/vm/devices/net/netvsp/src/lib.rs @@ -147,11 +147,19 @@ const LINK_DELAY_DURATION: Duration = Duration::from_secs(5); #[cfg(test)] const LINK_DELAY_DURATION: Duration = Duration::from_millis(333); -#[derive(PartialEq)] -enum CoordinatorMessage { +#[derive(Default, PartialEq)] +struct CoordinatorMessageUpdateType { /// Update guest VF state based on current availability and the guest VF state tracked by the primary channel. /// This includes adding the guest VF device and switching the data path. - UpdateGuestVfState, + guest_vf_state: bool, + /// Update the receive filter for all channels. + filter_state: bool, +} + +#[derive(PartialEq)] +enum CoordinatorMessage { + /// Update network state. + Update(CoordinatorMessageUpdateType), /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current /// expectations. Restart, @@ -382,6 +390,7 @@ struct NetChannel { pending_send_size: usize, restart: Option, can_use_ring_size_opt: bool, + packet_filter: u32, } /// Buffers used during packet processing. @@ -1364,6 +1373,7 @@ impl Nic { pending_send_size: 0, restart: None, can_use_ring_size_opt, + packet_filter: rndisprot::NDIS_PACKET_TYPE_NONE, }, state, coordinator_send: self.coordinator_send.clone().unwrap(), @@ -1453,6 +1463,7 @@ impl Nic { mut control: RestoreControl<'_>, state: saved_state::SavedState, ) -> Result<(), NetRestoreError> { + let mut saved_packet_filter = 0u32; if let Some(state) = state.open { let open = match &state.primary { saved_state::Primary::Version => vec![true], @@ -1537,8 +1548,12 @@ impl Nic { tx_spread_sent, guest_link_down, pending_link_action, + packet_filter, } = ready; + // If saved state does not have a packet filter set, default to directed, multicast, and broadcast. + saved_packet_filter = packet_filter.unwrap_or(rndisprot::NPROTO_PACKET_FILTER); + let version = check_version(version) .ok_or(NetRestoreError::UnsupportedVersion(version))?; @@ -1621,6 +1636,11 @@ impl Nic { self.insert_worker(channel_idx as u16, &request.unwrap(), state, false)?; } } + for worker in self.coordinator.state_mut().unwrap().workers.iter_mut() { + if let Some(worker_state) = worker.state_mut() { + worker_state.channel.packet_filter = saved_packet_filter; + } + } } else { control .restore(&[false]) @@ -1781,6 +1801,11 @@ impl Nic { PrimaryChannelGuestVfState::Restoring(saved_state) => saved_state, }; + let worker_0_packet_filter = coordinator.workers[0] + .state() + .unwrap() + .channel + .packet_filter; saved_state::Primary::Ready(saved_state::ReadyPrimary { version: ready.buffers.version as u32, receive_buffer: ready.buffers.recv_buffer.saved_state(), @@ -1810,6 +1835,7 @@ impl Nic { tx_spread_sent: primary.tx_spread_sent, guest_link_down: !primary.guest_link_up, pending_link_action, + packet_filter: Some(worker_0_packet_filter), }) } }; @@ -2593,7 +2619,12 @@ impl NetChannel { if primary.rndis_state == RndisState::Operational { if self.guest_vf_is_available(Some(vfid), buffers.version, buffers.ndis_config)? { primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised; - return Ok(Some(CoordinatorMessage::UpdateGuestVfState)); + return Ok(Some(CoordinatorMessage::Update( + CoordinatorMessageUpdateType { + guest_vf_state: true, + ..Default::default() + }, + ))); } else if let Some(true) = primary.is_data_path_switched { tracing::error!( "Data path switched, but current guest negotiation does not support VTL0 VF" @@ -2733,10 +2764,7 @@ impl NetChannel { // flag on inband packets and won't send a completion // packet. primary.guest_vf_state = PrimaryChannelGuestVfState::AvailableAdvertised; - // restart will also add the VF based on the guest_vf_state - if self.restart.is_none() { - self.restart = Some(CoordinatorMessage::UpdateGuestVfState); - } + self.send_coordinator_update_vf(); } else if let Some(true) = primary.is_data_path_switched { tracing::error!( "Data path switched, but current guest negotiation does not support VTL0 VF" @@ -2784,12 +2812,18 @@ impl NetChannel { tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG"); let status = match self.adapter.handle_oid_set(primary, request.oid, reader) { - Ok(restart_endpoint) => { + Ok((restart_endpoint, packet_filter)) => { // Restart the endpoint if the OID changed some critical // endpoint property. if restart_endpoint { self.restart = Some(CoordinatorMessage::Restart); } + if let Some(filter) = packet_filter { + if self.packet_filter != filter { + self.packet_filter = filter; + self.send_coordinator_update_filter(); + } + } rndisprot::STATUS_SUCCESS } Err(err) => { @@ -2973,6 +3007,31 @@ impl NetChannel { } Ok(()) } + + fn send_coordinator_update_message(&mut self, guest_vf: bool, packet_filter: bool) { + if self.restart.is_none() { + self.restart = Some(CoordinatorMessage::Update(CoordinatorMessageUpdateType { + guest_vf_state: guest_vf, + filter_state: packet_filter, + })); + } else if let Some(CoordinatorMessage::Restart) = self.restart { + // If a restart message is pending, do nothing. + // A restart will try to switch the data path based on primary.guest_vf_state. + // A restart will apply packet filter changes. + } else if let Some(CoordinatorMessage::Update(ref mut update)) = self.restart { + // Add the new update to the existing message. + update.guest_vf_state |= guest_vf; + update.filter_state |= packet_filter; + } + } + + fn send_coordinator_update_vf(&mut self) { + self.send_coordinator_update_message(true, false); + } + + fn send_coordinator_update_filter(&mut self) { + self.send_coordinator_update_message(false, true); + } } /// Writes an RNDIS message to `writer`. @@ -3290,13 +3349,14 @@ impl Adapter { primary: &mut PrimaryChannelState, oid: rndisprot::Oid, reader: impl MemoryRead + Clone, - ) -> Result { + ) -> Result<(bool, Option), OidError> { tracing::debug!(?oid, "oid set"); let mut restart_endpoint = false; + let mut packet_filter = None; match oid { rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER => { - // TODO + packet_filter = self.oid_set_packet_filter(reader)?; } rndisprot::Oid::OID_TCP_OFFLOAD_PARAMETERS => { self.oid_set_offload_parameters(reader, primary)?; @@ -3323,7 +3383,7 @@ impl Adapter { return Err(OidError::UnknownOid); } } - Ok(restart_endpoint) + Ok((restart_endpoint, packet_filter)) } fn oid_set_rss_parameters( @@ -3381,6 +3441,15 @@ impl Adapter { Ok(()) } + fn oid_set_packet_filter( + &self, + reader: impl MemoryRead + Clone, + ) -> Result, OidError> { + let filter: rndisprot::RndisPacketFilterOidValue = reader.clone().read_plain()?; + tracing::debug!(filter, "set packet filter"); + Ok(Some(filter)) + } + fn oid_set_offload_parameters( &self, reader: impl MemoryRead + Clone, @@ -3871,8 +3940,26 @@ impl Coordinator { } sleep_duration = None; } - Message::Internal(CoordinatorMessage::UpdateGuestVfState) => { - self.update_guest_vf_state(state).await; + Message::Internal(CoordinatorMessage::Update(update_type)) => { + if update_type.filter_state { + self.stop_workers().await; + let worker_0_packet_filter = + self.workers[0].state().unwrap().channel.packet_filter; + self.workers.iter_mut().skip(1).for_each(|worker| { + if let Some(state) = worker.state_mut() { + state.channel.packet_filter = worker_0_packet_filter; + tracing::debug!( + packet_filter = ?worker_0_packet_filter, + channel_idx = state.channel_idx, + "update packet filter" + ); + } + }); + } + + if update_type.guest_vf_state { + self.update_guest_vf_state(state).await; + } } Message::UpdateFromEndpoint(EndpointAction::RestartRequired) => self.restart = true, Message::UpdateFromEndpoint(EndpointAction::LinkStatusNotify(connect)) => { @@ -4315,6 +4402,7 @@ impl Coordinator { self.num_queues = num_queues; } + let worker_0_packet_filter = self.workers[0].state().unwrap().channel.packet_filter; // Provide the queue and receive buffer ranges for each worker. for ((worker, queue), rx_buffer) in self.workers.iter_mut().zip(queues).zip(rx_buffers) { worker.task_mut().queue_state = Some(QueueState { @@ -4322,6 +4410,10 @@ impl Coordinator { target_vp_set: false, rx_buffer_range: rx_buffer, }); + // Update the receive packet filter for the subchannel worker. + if let Some(worker) = worker.state_mut() { + worker.channel.packet_filter = worker_0_packet_filter; + } } Ok(()) @@ -4929,6 +5021,13 @@ impl NetChannel { data: &mut ProcessingData, epqueue: &mut dyn net_backend::Queue, ) -> Result { + if self.packet_filter == rndisprot::NDIS_PACKET_TYPE_NONE { + tracing::trace!( + packet_filter = self.packet_filter, + "rx packet not processed" + ); + return Ok(false); + } let n = epqueue .rx_poll(&mut data.rx_ready) .map_err(WorkerError::Endpoint)?; @@ -5071,10 +5170,7 @@ impl NetChannel { _ => (), }; if queue_switch_operation { - // A restart will also try to switch the data path based on primary.guest_vf_state. - if self.restart.is_none() { - self.restart = Some(CoordinatorMessage::UpdateGuestVfState) - }; + self.send_coordinator_update_vf(); } else { self.send_completion(transaction_id, &[])?; } diff --git a/vm/devices/net/netvsp/src/rndisprot.rs b/vm/devices/net/netvsp/src/rndisprot.rs index d9c1c3a820..43765d5103 100644 --- a/vm/devices/net/netvsp/src/rndisprot.rs +++ b/vm/devices/net/netvsp/src/rndisprot.rs @@ -751,6 +751,7 @@ open_enum! { DEFAULT = 0x80, RSS_CAPABILITIES = 0x88, RSS_PARAMETERS = 0x89, + OID_REQUEST = 0x96, OFFLOAD = 0xA7, OFFLOAD_ENCAPSULATION = 0xA8, } @@ -1082,3 +1083,14 @@ open_enum! { BINARY = 4, } } + +pub type RndisPacketFilterOidValue = u32; + +// Rndis Packet Filter Flags (OID_GEN_CURRENT_PACKET_FILTER) +pub const NDIS_PACKET_TYPE_NONE: u32 = 0x00000000; +pub const NDIS_PACKET_TYPE_DIRECTED: u32 = 0x00000001; +pub const NDIS_PACKET_TYPE_MULTICAST: u32 = 0x00000002; +pub const NDIS_PACKET_TYPE_ALL_MULTICAST: u32 = 0x00000004; +pub const NDIS_PACKET_TYPE_BROADCAST: u32 = 0x00000008; +pub const NPROTO_PACKET_FILTER: u32 = + NDIS_PACKET_TYPE_DIRECTED | NDIS_PACKET_TYPE_ALL_MULTICAST | NDIS_PACKET_TYPE_BROADCAST; diff --git a/vm/devices/net/netvsp/src/saved_state.rs b/vm/devices/net/netvsp/src/saved_state.rs index 8be332e21d..0dfc7d7a79 100644 --- a/vm/devices/net/netvsp/src/saved_state.rs +++ b/vm/devices/net/netvsp/src/saved_state.rs @@ -120,6 +120,8 @@ pub struct ReadyPrimary { pub guest_link_down: bool, #[mesh(15)] pub pending_link_action: Option, + #[mesh(16)] + pub packet_filter: Option, } #[derive(Debug, Protobuf)] diff --git a/vm/devices/net/netvsp/src/test.rs b/vm/devices/net/netvsp/src/test.rs index 880f315e31..fb582fb543 100644 --- a/vm/devices/net/netvsp/src/test.rs +++ b/vm/devices/net/netvsp/src/test.rs @@ -3861,6 +3861,237 @@ async fn send_rndis_indicate_status_message(driver: DefaultDriver) { .await; } +#[async_test] +async fn send_rndis_set_packet_filter(driver: DefaultDriver) { + const TOTAL_QUEUES: u32 = 4; + let endpoint_state = TestNicEndpointState::new(); + let endpoint = TestNicEndpoint::new(Some(endpoint_state.clone())); + let test_vf = Box::new(TestVirtualFunction::new(123)); + let builder = Nic::builder(); + let nic = builder.virtual_function(test_vf).build( + &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())), + Guid::new_random(), + Box::new(endpoint), + [1, 2, 3, 4, 5, 6].into(), + 0, + ); + + let mut nic = TestNicDevice::new_with_nic(&driver, nic).await; + nic.start_vmbus_channel(); + let mut channel = nic.connect_vmbus_channel().await; + channel + .initialize( + TOTAL_QUEUES as usize - 1, + protocol::NdisConfigCapabilities::new().with_sriov(true), + ) + .await; + + let rndis_parser = channel.rndis_message_parser(); + + // Send and verify Initialization + channel + .send_rndis_control_message( + rndisprot::MESSAGE_TYPE_INITIALIZE_MSG, + rndisprot::InitializeRequest { + request_id: 123, + major_version: rndisprot::MAJOR_VERSION, + minor_version: rndisprot::MINOR_VERSION, + max_transfer_size: 0, + }, + &[], + ) + .await; + + let _: rndisprot::InitializeComplete = channel + .read_rndis_control_message(rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT) + .await + .unwrap(); + + channel + .read_with(|packet| match packet { + IncomingPacket::Data(_) => (), + _ => panic!("Unexpected packet"), + }) + .await + .expect("association packet"); + + // Allocate subchannels + let message = NvspMessage { + header: protocol::MessageHeader { + message_type: protocol::MESSAGE5_TYPE_SUB_CHANNEL, + }, + data: protocol::Message5SubchannelRequest { + operation: protocol::SubchannelOperation::ALLOCATE, + num_sub_channels: TOTAL_QUEUES - 1, + }, + padding: &[], + }; + channel + .write(OutgoingPacket { + transaction_id: 123, + packet_type: OutgoingPacketType::InBandWithCompletion, + payload: &message.payload(), + }) + .await; + channel + .read_with(|packet| match packet { + IncomingPacket::Completion(completion) => { + let mut reader = completion.reader(); + let header: protocol::MessageHeader = reader.read_plain().unwrap(); + assert_eq!(header.message_type, protocol::MESSAGE5_TYPE_SUB_CHANNEL); + let completion_data: protocol::Message5SubchannelComplete = + reader.read_plain().unwrap(); + assert_eq!(completion_data.status, protocol::Status::SUCCESS); + assert_eq!(completion_data.num_sub_channels, TOTAL_QUEUES - 1); + } + _ => panic!("Unexpected packet"), + }) + .await + .expect("completion message"); + + for idx in 1..TOTAL_QUEUES { + channel.connect_subchannel(idx).await; + } + + // Send Indirection Table + let transaction_id = channel + .read_with(|packet| match packet { + IncomingPacket::Data(packet) => { + let mut reader = packet.reader(); + let header: protocol::MessageHeader = reader.read_plain().unwrap(); + assert_eq!( + header.message_type, + protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE + ); + packet.transaction_id() + } + _ => panic!("Unexpected packet"), + }) + .await + .expect("indirection table message after all channels connected"); + if let Some(transaction_id) = transaction_id { + channel + .write(OutgoingPacket { + transaction_id, + packet_type: OutgoingPacketType::Completion, + payload: &NvspMessage { + header: protocol::MessageHeader { + message_type: protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE, + }, + data: protocol::Message1SendRndisPacketComplete { + status: protocol::Status::SUCCESS, + }, + padding: &[], + } + .payload(), + }) + .await; + } + + // Send a packet on every queue. + { + let locked_state = endpoint_state.lock(); + for (idx, queue) in locked_state.queues.iter().enumerate() { + queue.send(vec![idx as u8]); + } + } + + // Expect no packets + for idx in 0..TOTAL_QUEUES { + channel + .read_subchannel_with(idx, |_| panic!("Unexpected packet on subchannel {}", idx)) + .await + .expect_err("Packet should have been filtered"); + } + + // Set packet filter + let request_id = 456; + channel + .send_rndis_control_message( + rndisprot::MESSAGE_TYPE_SET_MSG, + rndisprot::SetRequest { + request_id, + oid: rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER, + information_buffer_length: size_of::() as u32, + information_buffer_offset: size_of::() as u32, + device_vc_handle: 0, + }, + &rndisprot::NPROTO_PACKET_FILTER.to_le_bytes(), + ) + .await; + + let set_complete: rndisprot::SetComplete = channel + .read_rndis_control_message(rndisprot::MESSAGE_TYPE_SET_CMPLT) + .await + .unwrap(); + + assert_eq!(set_complete.request_id, request_id); + assert_eq!(set_complete.status, rndisprot::STATUS_SUCCESS); + + // Send a packet on every queue. + { + let locked_state = endpoint_state.lock(); + for (idx, queue) in locked_state.queues.iter().enumerate() { + queue.send(vec![idx as u8]); + } + } + + // Get the transaction IDs for all of the received packets. + for idx in 0..TOTAL_QUEUES { + channel + .read_subchannel_with(idx, |packet| match packet { + IncomingPacket::Data(packet) => { + let (_, external_ranges) = rndis_parser.parse_data_message(packet); + let data: u8 = rndis_parser.get_data_packet_content(&external_ranges); + assert_eq!(idx, data as u32); + } + _ => panic!("Unexpected packet on subchannel {}", idx), + }) + .await + .expect("Data packet"); + } + + // Set packet filter to None + let request_id = 789; + channel + .send_rndis_control_message( + rndisprot::MESSAGE_TYPE_SET_MSG, + rndisprot::SetRequest { + request_id, + oid: rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER, + information_buffer_length: size_of::() as u32, + information_buffer_offset: size_of::() as u32, + device_vc_handle: 0, + }, + &rndisprot::NDIS_PACKET_TYPE_NONE.to_le_bytes(), + ) + .await; + + let set_complete: rndisprot::SetComplete = channel + .read_rndis_control_message(rndisprot::MESSAGE_TYPE_SET_CMPLT) + .await + .unwrap(); + + assert_eq!(set_complete.request_id, request_id); + assert_eq!(set_complete.status, rndisprot::STATUS_SUCCESS); + + // Test sending packets with the filter set to None. + { + let locked_state = endpoint_state.lock(); + for (idx, queue) in locked_state.queues.iter().enumerate() { + queue.send(vec![idx as u8]); + } + } + + // Expect no packets + for idx in 0..TOTAL_QUEUES { + channel + .read_subchannel_with(idx, |_| panic!("Unexpected packet on subchannel {}", idx)) + .await + .expect_err("Packet should have been filtered"); + } +} + #[async_test] async fn send_rndis_set_ex_message(driver: DefaultDriver) { let endpoint_state = TestNicEndpointState::new(); @@ -4425,6 +4656,29 @@ async fn set_rss_parameter_bufs_not_evenly_divisible(driver: DefaultDriver) { .await; } + // Set packet filter + channel + .send_rndis_control_message( + rndisprot::MESSAGE_TYPE_SET_MSG, + rndisprot::SetRequest { + request_id: 0, + oid: rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER, + information_buffer_length: size_of::() as u32, + information_buffer_offset: size_of::() as u32, + device_vc_handle: 0, + }, + &rndisprot::NPROTO_PACKET_FILTER.to_le_bytes(), + ) + .await; + + let set_complete: rndisprot::SetComplete = channel + .read_rndis_control_message(rndisprot::MESSAGE_TYPE_SET_CMPLT) + .await + .unwrap(); + + assert_eq!(set_complete.request_id, 0); + assert_eq!(set_complete.status, rndisprot::STATUS_SUCCESS); + // Receive a packet on every queue. { let locked_state = endpoint_state.lock();