diff --git a/config/config.go b/config/config.go index 7e14a7a8..d327fc32 100644 --- a/config/config.go +++ b/config/config.go @@ -297,6 +297,8 @@ type HTTPProbe struct { Compression string `yaml:"compression,omitempty"` BodySizeLimit units.Base2Bytes `yaml:"body_size_limit,omitempty"` UseHTTP3 bool `yaml:"enable_http3,omitempty"` + DNSServer string `yaml:"dns_server,omitempty"` + DNSTimeout time.Duration `yaml:"dns_timeout,omitempty"` } type GRPCProbe struct { @@ -305,6 +307,8 @@ type GRPCProbe struct { TLSConfig config.TLSConfig `yaml:"tls_config,omitempty"` IPProtocolFallback bool `yaml:"ip_protocol_fallback,omitempty"` PreferredIPProtocol string `yaml:"preferred_ip_protocol,omitempty"` + DNSServer string `yaml:"dns_server,omitempty"` + DNSTimeout time.Duration `yaml:"dns_timeout,omitempty"` } type HeaderMatch struct { @@ -332,6 +336,8 @@ type TCPProbe struct { QueryResponse []QueryResponse `yaml:"query_response,omitempty"` TLS bool `yaml:"tls,omitempty"` TLSConfig config.TLSConfig `yaml:"tls_config,omitempty"` + DNSServer string `yaml:"dns_server,omitempty"` + DNSTimeout time.Duration `yaml:"dns_timeout,omitempty"` } type ICMPProbe struct { @@ -341,6 +347,8 @@ type ICMPProbe struct { PayloadSize int `yaml:"payload_size,omitempty"` DontFragment bool `yaml:"dont_fragment,omitempty"` TTL int `yaml:"ttl,omitempty"` + DNSServer string `yaml:"dns_server,omitempty"` + DNSTimeout time.Duration `yaml:"dns_timeout,omitempty"` } type DNSProbe struct { @@ -358,6 +366,8 @@ type DNSProbe struct { ValidateAnswer DNSRRValidator `yaml:"validate_answer_rrs,omitempty"` ValidateAuthority DNSRRValidator `yaml:"validate_authority_rrs,omitempty"` ValidateAdditional DNSRRValidator `yaml:"validate_additional_rrs,omitempty"` + DNSServer string `yaml:"dns_server,omitempty"` + DNSTimeout time.Duration `yaml:"dns_timeout,omitempty"` } type DNSRRValidator struct { diff --git a/prober/dns.go b/prober/dns.go index 3d301b9b..2d477ce5 100644 --- a/prober/dns.go +++ b/prober/dns.go @@ -196,7 +196,7 @@ func ProbeDNS(ctx context.Context, target string, module config.Module, registry } targetAddr = target } - ip, lookupTime, err := chooseProtocol(ctx, module.DNS.IPProtocol, module.DNS.IPProtocolFallback, targetAddr, registry, logger) + ip, lookupTime, err := chooseProtocol(ctx, module.DNS.IPProtocol, module.DNS.IPProtocolFallback, targetAddr, registry, logger, module.DNS.DNSServer, module.DNS.DNSTimeout) if err != nil { logger.Error("Error resolving address", "err", err) return false diff --git a/prober/grpc.go b/prober/grpc.go index fe8d1eca..fb1a6c11 100644 --- a/prober/grpc.go +++ b/prober/grpc.go @@ -144,7 +144,7 @@ func ProbeGRPC(ctx context.Context, target string, module config.Module, registr return false } - ip, lookupTime, err := chooseProtocol(ctx, module.GRPC.PreferredIPProtocol, module.GRPC.IPProtocolFallback, targetHost, registry, logger) + ip, lookupTime, err := chooseProtocol(ctx, module.GRPC.PreferredIPProtocol, module.GRPC.IPProtocolFallback, targetHost, registry, logger, module.GRPC.DNSServer, module.GRPC.DNSTimeout) if err != nil { logger.Error("Error resolving address", "err", err) return false diff --git a/prober/http.go b/prober/http.go index f5d546f6..05affba2 100644 --- a/prober/http.go +++ b/prober/http.go @@ -404,7 +404,7 @@ func ProbeHTTP(ctx context.Context, target string, module config.Module, registr var ip *net.IPAddr if shouldResolveDNSWithProxy(module.HTTP) { var lookupTime float64 - ip, lookupTime, err = chooseProtocol(ctx, module.HTTP.IPProtocol, module.HTTP.IPProtocolFallback, targetHost, registry, logger) + ip, lookupTime, err = chooseProtocol(ctx, module.HTTP.IPProtocol, module.HTTP.IPProtocolFallback, targetHost, registry, logger, module.HTTP.DNSServer, module.HTTP.DNSTimeout) durationGaugeVec.WithLabelValues("resolve").Add(lookupTime) if err != nil { logger.Error("Error resolving address", "err", err) diff --git a/prober/icmp.go b/prober/icmp.go index 2f07e6b0..739d2d5c 100644 --- a/prober/icmp.go +++ b/prober/icmp.go @@ -87,7 +87,7 @@ func ProbeICMP(ctx context.Context, target string, module config.Module, registr registry.MustRegister(durationGaugeVec) - dstIPAddr, lookupTime, err := chooseProtocol(ctx, module.ICMP.IPProtocol, module.ICMP.IPProtocolFallback, target, registry, logger) + dstIPAddr, lookupTime, err := chooseProtocol(ctx, module.ICMP.IPProtocol, module.ICMP.IPProtocolFallback, target, registry, logger, module.ICMP.DNSServer, module.ICMP.DNSTimeout) if err != nil { logger.Error("Error resolving address", "err", err) diff --git a/prober/tcp.go b/prober/tcp.go index f2f1396d..f45ea48c 100644 --- a/prober/tcp.go +++ b/prober/tcp.go @@ -36,7 +36,7 @@ func dialTCP(ctx context.Context, target string, module config.Module, registry return nil, err } - ip, _, err := chooseProtocol(ctx, module.TCP.IPProtocol, module.TCP.IPProtocolFallback, targetAddress, registry, logger) + ip, _, err := chooseProtocol(ctx, module.TCP.IPProtocol, module.TCP.IPProtocolFallback, targetAddress, registry, logger, module.TCP.DNSServer, module.TCP.DNSTimeout) if err != nil { logger.Error("Error resolving address", "err", err) return nil, err diff --git a/prober/utils.go b/prober/utils.go index 3dc4153c..2f97266a 100644 --- a/prober/utils.go +++ b/prober/utils.go @@ -30,7 +30,7 @@ var protocolToGauge = map[string]float64{ } // Returns the IP for the IPProtocol and lookup time. -func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger *slog.Logger) (ip *net.IPAddr, lookupTime float64, err error) { +func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol bool, target string, registry *prometheus.Registry, logger *slog.Logger, dnsServer string, dnsTimeout time.Duration) (ip *net.IPAddr, lookupTime float64, err error) { var fallbackProtocol string probeDNSLookupTimeSeconds := prometheus.NewGauge(prometheus.GaugeOpts{ Name: "probe_dns_lookup_time_seconds", @@ -67,6 +67,19 @@ func chooseProtocol(ctx context.Context, IPProtocol string, fallbackIPProtocol b }() resolver := &net.Resolver{} + + if dnsServer != "" { + resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{ + Timeout: dnsTimeout, + } + return d.DialContext(ctx, "udp", dnsServer) + }, + } + } + if !fallbackIPProtocol { ips, err := resolver.LookupIP(ctx, IPProtocol, target) if err == nil {