diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 96f8827..b51f068 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,4 +1,4 @@ -name: rustfmt +name: cargo on: push: @@ -17,3 +17,5 @@ jobs: run: cargo +stable fmt -- --check - name: Check linting run: cargo +stable clippy -- -Dwarnings + - name: Check unit tests + run: cargo +stable test \ No newline at end of file diff --git a/src/dns_parser/builder.rs b/src/dns_parser/builder.rs index 4f8b5a7..a2b17d8 100644 --- a/src/dns_parser/builder.rs +++ b/src/dns_parser/builder.rs @@ -201,6 +201,21 @@ impl> Builder { builder } + + pub fn add_answers<'a, 'b>( + self, + name: &Name<'_>, + cls: QueryClass, + ttl: u32, + data: impl Iterator> + 'a, + ) -> Builder { + let mut builder = self.move_to::(); + for item in data { + builder.write_rr(name, cls, ttl, &item); + Header::inc_answers(&mut builder.buf).expect("Too many answers"); + } + builder + } } impl> Builder { @@ -219,10 +234,25 @@ impl> Builder { builder } -} -impl Builder { #[allow(dead_code)] + pub fn add_nameservers<'a, 'b>( + self, + name: &Name<'_>, + cls: QueryClass, + ttl: u32, + data: impl Iterator> + 'a, + ) -> Builder { + let mut builder = self.move_to::(); + for item in data { + builder.write_rr(name, cls, ttl, &item); + Header::inc_nameservers(&mut builder.buf).expect("Too many nameservers"); + } + builder + } +} + +impl> Builder { pub fn add_additional( self, name: &Name<'_>, @@ -233,8 +263,23 @@ impl Builder { let mut builder = self.move_to::(); builder.write_rr(name, cls, ttl, data); - Header::inc_nameservers(&mut builder.buf).expect("Too many additional answers"); + Header::inc_additional(&mut builder.buf).expect("Too many additional answers"); + + builder + } + pub fn add_additionals<'a, 'b>( + self, + name: &Name<'_>, + cls: QueryClass, + ttl: u32, + data: impl Iterator> + 'a, + ) -> Builder { + let mut builder = self.move_to::(); + for item in data { + builder.write_rr(name, cls, ttl, &item); + Header::inc_additional(&mut builder.buf).expect("Too many additional answers"); + } builder } } diff --git a/src/dns_parser/mod.rs b/src/dns_parser/mod.rs index a90a1a1..bbb573e 100644 --- a/src/dns_parser/mod.rs +++ b/src/dns_parser/mod.rs @@ -12,4 +12,4 @@ pub use self::header::Header; mod rrdata; pub use self::rrdata::RRData; mod builder; -pub use self::builder::{Answers, Builder}; +pub use self::builder::{Additional, Answers, Builder}; diff --git a/src/fsm.rs b/src/fsm.rs index 5738a69..663ad33 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -18,9 +18,10 @@ use tokio::{net::UdpSocket, sync::mpsc}; use super::{DEFAULT_TTL, MDNS_PORT}; use crate::address_family::AddressFamily; -use crate::services::{ServiceData, Services}; +use crate::services::{ServiceData, Services, ServicesInner}; pub type AnswerBuilder = dns_parser::Builder; +pub type AdditionalBuilder = dns_parser::Builder; const SERVICE_TYPE_ENUMERATION_NAME: Cow<'static, str> = Cow::Borrowed("_services._dns-sd._udp.local"); @@ -105,14 +106,6 @@ impl FSM { return; } - let mut unicast_builder = dns_parser::Builder::new_response(packet.header.id, false, true) - .move_to::(); - let mut multicast_builder = - dns_parser::Builder::new_response(packet.header.id, false, true) - .move_to::(); - unicast_builder.set_max_size(None); - multicast_builder.set_max_size(None); - for question in packet.questions { debug!( "received question: {:?} {}", @@ -120,144 +113,191 @@ impl FSM { ); if question.qclass == QueryClass::IN || question.qclass == QueryClass::Any { + let mut builder = dns_parser::Builder::new_response(packet.header.id, false, true) + .move_to::(); + builder.set_max_size(None); + let builder = self.handle_question(&question, builder); + if builder.is_empty() { + continue; + } + let response = builder.build().unwrap_or_else(|x| x); if question.qu { - unicast_builder = self.handle_question(&question, unicast_builder); + self.outgoing.push_back((response, addr)); } else { - multicast_builder = self.handle_question(&question, multicast_builder); + let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT); + self.outgoing.push_back((response, addr)); } } } - - if !multicast_builder.is_empty() { - let response = multicast_builder.build().unwrap_or_else(|x| x); - let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT); - self.outgoing.push_back((response, addr)); - } - - if !unicast_builder.is_empty() { - let response = unicast_builder.build().unwrap_or_else(|x| x); - self.outgoing.push_back((response, addr)); - } } /// - fn handle_service_type_enumeration<'a>( + fn handle_service_type_enumeration( question: &dns_parser::Question<'_>, - services: impl Iterator, + services: &ServicesInner, mut builder: AnswerBuilder, ) -> AnswerBuilder { let service_type_enumeration_name = Name::FromStr(SERVICE_TYPE_ENUMERATION_NAME); if question.qname == service_type_enumeration_name { - for svc in services { - let svc_type = ServiceData { - name: svc.typ.clone(), - typ: service_type_enumeration_name.clone(), - port: svc.port, - txt: vec![], - }; - builder = svc_type.add_ptr_rr(builder, DEFAULT_TTL); + for typ in services.all_types() { + builder = builder.add_answer( + &service_type_enumeration_name, + QueryClass::IN, + DEFAULT_TTL, + &RRData::PTR(typ.clone()), + ); } } builder } + #[allow(clippy::too_many_lines)] fn handle_question( &self, question: &dns_parser::Question<'_>, mut builder: AnswerBuilder, - ) -> AnswerBuilder { + ) -> AdditionalBuilder { let services = self.services.read().unwrap(); let hostname = services.get_hostname(); match question.qtype { - QueryType::A | QueryType::AAAA if question.qname == *hostname => { - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); - } + QueryType::A | QueryType::AAAA if question.qname == *hostname => builder + .add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr()) + .move_to(), QueryType::All => { + let mut include_ip_additionals = false; // A / AAAA if question.qname == *hostname { - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder = + builder.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr()); } // PTR - builder = - Self::handle_service_type_enumeration(question, services.into_iter(), builder); + builder = Self::handle_service_type_enumeration(question, &services, builder); for svc in services.find_by_type(&question.qname) { - builder = svc.add_ptr_rr(builder, DEFAULT_TTL); - builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL); - builder = svc.add_txt_rr(builder, DEFAULT_TTL); - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder = + builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr()); + include_ip_additionals = true; } // SRV if let Some(svc) = services.find_by_name(&question.qname) { - builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL); - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder = builder + .add_answer( + &svc.name, + QueryClass::IN, + DEFAULT_TTL, + &svc.srv_rr(hostname), + ) + .add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr()); + include_ip_additionals = true; + } + let mut builder = builder.move_to::(); + // PTR (additional) + for svc in services.find_by_type(&question.qname) { + builder = builder + .add_additional( + &svc.name, + QueryClass::IN, + DEFAULT_TTL, + &svc.srv_rr(hostname), + ) + .add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr()); + include_ip_additionals = true; } + + if include_ip_additionals { + builder = builder.add_additionals( + hostname, + QueryClass::IN, + DEFAULT_TTL, + self.ip_rr(), + ); + } + builder } QueryType::PTR => { - builder = - Self::handle_service_type_enumeration(question, services.into_iter(), builder); + let mut include_ip_additionals = false; + let mut builder = + Self::handle_service_type_enumeration(question, &services, builder); + for svc in services.find_by_type(&question.qname) { + builder = + builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr()); + } + let mut builder = builder.move_to::(); for svc in services.find_by_type(&question.qname) { - builder = svc.add_ptr_rr(builder, DEFAULT_TTL); - builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL); - builder = svc.add_txt_rr(builder, DEFAULT_TTL); - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder = builder + .add_additional( + &svc.name, + QueryClass::IN, + DEFAULT_TTL, + &svc.srv_rr(hostname), + ) + .add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr()); + include_ip_additionals = true; } + if include_ip_additionals { + builder = builder.add_additionals( + hostname, + QueryClass::IN, + DEFAULT_TTL, + self.ip_rr(), + ); + } + builder } QueryType::SRV => { if let Some(svc) = services.find_by_name(&question.qname) { - builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL); - builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL); + builder + .add_answer( + &svc.name, + QueryClass::IN, + DEFAULT_TTL, + &svc.srv_rr(hostname), + ) + .add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr()) + .move_to() + } else { + builder.move_to() } } QueryType::TXT => { if let Some(svc) = services.find_by_name(&question.qname) { - builder = svc.add_txt_rr(builder, DEFAULT_TTL); + builder + .add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr()) + .move_to() + } else { + builder.move_to() } } - _ => (), + _ => builder.move_to(), } - - builder } - fn add_ip_rr( - &self, - hostname: &Name<'_>, - mut builder: AnswerBuilder, - ttl: u32, - ) -> AnswerBuilder { + fn ip_rr(&self) -> impl Iterator> + '_ { let interfaces = match get_if_addrs() { Ok(interfaces) => interfaces, Err(err) => { error!("could not get list of interfaces: {err}"); - return builder; + vec![] } }; - - for iface in interfaces { + interfaces.into_iter().filter_map(move |iface| { if iface.is_loopback() { - continue; + return None; } trace!("found interface {iface:?}"); if !self.allowed_ip.is_empty() && !self.allowed_ip.contains(&iface.ip()) { trace!(" -> interface dropped"); - continue; + return None; } match (iface.ip(), AF::DOMAIN) { - (IpAddr::V4(ip), Domain::IPV4) => { - builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::A(ip)); - } - (IpAddr::V6(ip), Domain::IPV6) => { - builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::AAAA(ip)); - } - _ => (), + (IpAddr::V4(ip), Domain::IPV4) => Some(RRData::A(ip)), + (IpAddr::V6(ip), Domain::IPV6) => Some(RRData::AAAA(ip)), + _ => None, } - } - - builder + }) } fn send_unsolicited(&mut self, svc: &ServiceData, ttl: u32, include_ip: bool) { @@ -267,11 +307,17 @@ impl FSM { let services = self.services.read().unwrap(); - builder = svc.add_ptr_rr(builder, ttl); - builder = svc.add_srv_rr(services.get_hostname(), builder, ttl); - builder = svc.add_txt_rr(builder, ttl); + builder = builder.add_answer(&svc.typ, QueryClass::IN, ttl, &svc.ptr_rr()); + builder = builder.add_answer( + &svc.name, + QueryClass::IN, + ttl, + &svc.srv_rr(services.get_hostname()), + ); + builder = builder.add_answer(&svc.name, QueryClass::IN, ttl, &svc.txt_rr()); if include_ip { - builder = self.add_ip_rr(services.get_hostname(), builder, ttl); + builder = + builder.add_answers(services.get_hostname(), QueryClass::IN, ttl, self.ip_rr()); } if !builder.is_empty() { @@ -355,7 +401,7 @@ mod tests { answer_builder = FSM::::handle_service_type_enumeration( &question, - services.read().unwrap().into_iter(), + &services.read().unwrap(), answer_builder, ); diff --git a/src/lib.rs b/src/lib.rs index fa1db0c..ce24607 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -309,7 +309,7 @@ impl Responder { /// /// # use std::io; /// # fn main() -> io::Result<()> { - /// let responder = Responder::new()?; + /// let responder = Responder::new(); /// // bind service /// let _http_svc = responder.register_with_ttl( /// "_http._tcp".into(), diff --git a/src/services.rs b/src/services.rs index 5801f39..f524081 100644 --- a/src/services.rs +++ b/src/services.rs @@ -1,12 +1,10 @@ -use crate::dns_parser::{self, Name, QueryClass, RRData}; +use crate::dns_parser::{Name, RRData}; use multimap::MultiMap; use rand::{rng, Rng}; use std::collections::HashMap; use std::slice; use std::sync::{Arc, RwLock}; -pub type AnswerBuilder = dns_parser::Builder; - /// A collection of registered services is shared between threads. pub type Services = Arc>; @@ -82,6 +80,10 @@ impl ServicesInner { svc } + + pub fn all_types(&self) -> impl Iterator> { + self.by_type.keys() + } } impl<'a> IntoIterator for &'a ServicesInner { @@ -120,35 +122,20 @@ pub struct ServiceData { /// Packet building helpers for `fsm` to respond with `ServiceData` impl ServiceData { - pub fn add_ptr_rr(&self, builder: AnswerBuilder, ttl: u32) -> AnswerBuilder { - builder.add_answer( - &self.typ, - QueryClass::IN, - ttl, - &RRData::PTR(self.name.clone()), - ) + pub fn ptr_rr(&self) -> RRData<'_> { + RRData::PTR(self.name.clone()) } - pub fn add_srv_rr( - &self, - hostname: &Name<'_>, - builder: AnswerBuilder, - ttl: u32, - ) -> AnswerBuilder { - builder.add_answer( - &self.name, - QueryClass::IN, - ttl, - &RRData::SRV { - priority: 0, - weight: 0, - port: self.port, - target: hostname.clone(), - }, - ) + pub fn srv_rr<'a>(&self, hostname: &'a Name<'_>) -> RRData<'a> { + RRData::SRV { + priority: 0, + weight: 0, + port: self.port, + target: hostname.clone(), + } } - pub fn add_txt_rr(&self, builder: AnswerBuilder, ttl: u32) -> AnswerBuilder { - builder.add_answer(&self.name, QueryClass::IN, ttl, &RRData::TXT(&self.txt)) + pub fn txt_rr(&self) -> RRData<'_> { + RRData::TXT(&self.txt) } }