diff --git a/src/socket/icmp.rs b/src/socket/icmp.rs index dee3416a3..1f87801b4 100644 --- a/src/socket/icmp.rs +++ b/src/socket/icmp.rs @@ -14,6 +14,7 @@ use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Repr}; #[cfg(feature = "proto-ipv6")] use crate::wire::{Icmpv6Packet, Icmpv6Repr, Ipv6Repr}; use crate::wire::{IpAddress, IpListenEndpoint, IpProtocol, IpRepr}; +use crate::wire::{TcpPacket, TcpRepr}; use crate::wire::{UdpPacket, UdpRepr}; /// Error returned by [`Socket::bind`] @@ -86,15 +87,17 @@ pub enum Endpoint { #[default] Unspecified, Ident(u16), + Tcp(IpListenEndpoint), Udp(IpListenEndpoint), } impl Endpoint { pub fn is_specified(&self) -> bool { match *self { + Endpoint::Unspecified => false, Endpoint::Ident(_) => true, + Endpoint::Tcp(endpoint) => endpoint.port != 0, Endpoint::Udp(endpoint) => endpoint.port != 0, - Endpoint::Unspecified => false, } } } @@ -453,6 +456,26 @@ impl<'a> Socket<'a> { Err(_) => false, } } + // If we are bound to ICMP errors associated to a TCP port, only + // accept Destination Unreachable or Time Exceeded messages with + // the data containing a UDP packet send from the local port we + // are bound to. + ( + &Endpoint::Tcp(endpoint), + &Icmpv4Repr::DstUnreachable { data, header, .. } + | &Icmpv4Repr::TimeExceeded { data, header, .. }, + ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr.into()) => { + let packet = TcpPacket::new_unchecked(data); + match TcpRepr::parse( + &packet, + &header.src_addr.into(), + &header.dst_addr.into(), + &cx.checksum_caps(), + ) { + Ok(repr) => endpoint.port == repr.src_port, + Err(_) => false, + } + } // If we are bound to a specific ICMP identifier value, only accept an // Echo Request/Reply with the identifier field matching the endpoint // port. @@ -495,6 +518,26 @@ impl<'a> Socket<'a> { Err(_) => false, } } + // If we are bound to ICMP errors associated to a TCP port, only + // accept Destination Unreachable or Time Exceeded messages with + // the data containing a UDP packet send from the local port we + // are bound to. + ( + &Endpoint::Tcp(endpoint), + &Icmpv6Repr::DstUnreachable { data, header, .. } + | &Icmpv6Repr::TimeExceeded { data, header, .. }, + ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr.into()) => { + let packet = TcpPacket::new_unchecked(data); + match TcpRepr::parse( + &packet, + &header.src_addr.into(), + &header.dst_addr.into(), + &cx.checksum_caps(), + ) { + Ok(repr) => endpoint.port == repr.src_port, + Err(_) => false, + } + } // If we are bound to a specific ICMP identifier value, only accept an // Echo Request/Reply with the identifier field matching the endpoint // port.