@@ -2,12 +2,14 @@ package upstream
22
33import (
44 "context"
5+ "fmt"
56 "net/netip"
67 "time"
78
89 "github.com/AdguardTeam/golibs/errors"
910 "github.com/AdguardTeam/golibs/log"
1011 "github.com/miekg/dns"
12+ "golang.org/x/exp/slices"
1113)
1214
1315const (
@@ -21,37 +23,40 @@ const (
2123
2224// ExchangeParallel returns the dirst successful response from one of u. It
2325// returns an error if all upstreams failed to exchange the request.
24- func ExchangeParallel (u []Upstream , req * dns.Msg ) (reply * dns.Msg , resolved Upstream , err error ) {
25- upsNum := len (u )
26+ func ExchangeParallel (ups []Upstream , req * dns.Msg ) (reply * dns.Msg , resolved Upstream , err error ) {
27+ upsNum := len (ups )
2628 switch upsNum {
2729 case 0 :
2830 return nil , nil , ErrNoUpstreams
2931 case 1 :
30- reply , err = exchangeAndLog (u [0 ], req )
32+ reply , err = exchangeAndLog (ups [0 ], req )
3133
32- return reply , u [0 ], err
34+ return reply , ups [0 ], err
3335 default :
3436 // Go on.
3537 }
3638
37- resCh := make (chan * ExchangeAllResult )
38- errCh := make (chan error )
39- for _ , f := range u {
40- go exchangeAsync (f , req , resCh , errCh )
39+ resCh := make (chan any , upsNum )
40+ for _ , f := range ups {
41+ go exchangeAsync (f , req , resCh )
4142 }
4243
4344 errs := []error {}
44- for range u {
45- select {
46- case excErr := <- errCh :
47- errs = append (errs , excErr )
48- case rep := <- resCh :
49- if rep .Resp != nil {
50- return rep .Resp , rep .Upstream , nil
45+ for range ups {
46+ var r * ExchangeAllResult
47+ r , err = receiveAsyncResult (resCh )
48+ if err != nil {
49+ if ! errors .Is (err , ErrNoReply ) {
50+ errs = append (errs , err )
5151 }
52+ } else {
53+ return r .Resp , r .Upstream , nil
5254 }
5355 }
5456
57+ // TODO(e.burkov): Probably it's better to return the joined error from
58+ // each upstream that returned no response, and get rid of multiple
59+ // [errors.Is] calls. This will change the behavior though.
5560 if len (errs ) == 0 {
5661 return nil , nil , errors .Error ("none of upstream servers responded" )
5762 }
@@ -72,8 +77,8 @@ type ExchangeAllResult struct {
7277// ExchangeAll returns the responses from all of u. It returns an error only if
7378// all upstreams failed to exchange the request.
7479func ExchangeAll (ups []Upstream , req * dns.Msg ) (res []ExchangeAllResult , err error ) {
75- upsl := len (ups )
76- switch upsl {
80+ upsNum := len (ups )
81+ switch upsNum {
7782 case 0 :
7883 return nil , ErrNoUpstreams
7984 case 1 :
@@ -90,62 +95,60 @@ func ExchangeAll(ups []Upstream, req *dns.Msg) (res []ExchangeAllResult, err err
9095 // Go on.
9196 }
9297
93- res = make ([]ExchangeAllResult , 0 , upsl )
98+ res = make ([]ExchangeAllResult , 0 , upsNum )
9499 var errs []error
95100
96- resCh := make (chan * ExchangeAllResult )
97- errCh := make (chan error )
101+ resCh := make (chan any , upsNum )
98102
99103 // Start exchanging concurrently.
100104 for _ , u := range ups {
101- go exchangeAsync (u , req , resCh , errCh )
105+ go exchangeAsync (u , req , resCh )
102106 }
103107
104108 // Wait for all exchanges to finish.
105109 for range ups {
106110 var r * ExchangeAllResult
107- r , err = receiveAsyncResult (resCh , errCh )
111+ r , err = receiveAsyncResult (resCh )
108112 if err != nil {
109113 errs = append (errs , err )
110114 } else {
111115 res = append (res , * r )
112116 }
113117 }
114118
115- if len (errs ) == upsl {
119+ if len (errs ) == upsNum {
116120 // TODO(e.burkov): Use [errors.Join] in Go 1.20.
117121 return res , errors .List ("all upstreams failed to exchange" , errs ... )
118122 }
119123
120- return res , nil
124+ return slices . Clip ( res ) , nil
121125}
122126
123127// receiveAsyncResult receives a single result from resCh or an error from
124128// errCh. It returns either a non-nil result or an error.
125- func receiveAsyncResult (
126- resCh chan * ExchangeAllResult ,
127- errCh chan error ,
128- ) (res * ExchangeAllResult , err error ) {
129- select {
130- case err = <- errCh :
131- return nil , err
132- case rep := <- resCh :
133- if rep .Resp == nil {
129+ func receiveAsyncResult (resCh chan any ) (res * ExchangeAllResult , err error ) {
130+ switch res := (<- resCh ).(type ) {
131+ case error :
132+ return nil , res
133+ case * ExchangeAllResult :
134+ if res .Resp == nil {
134135 return nil , ErrNoReply
135136 }
136137
137- return rep , nil
138+ return res , nil
139+ default :
140+ return nil , fmt .Errorf ("unexpected type %T of result" , res )
138141 }
139142}
140143
141144// exchangeAsync tries to resolve DNS request with one upstream and sends the
142145// result to respCh.
143- func exchangeAsync (u Upstream , req * dns.Msg , respCh chan * ExchangeAllResult , errCh chan error ) {
146+ func exchangeAsync (u Upstream , req * dns.Msg , resCh chan any ) {
144147 reply , err := exchangeAndLog (u , req )
145148 if err != nil {
146- errCh <- err
149+ resCh <- err
147150 } else {
148- respCh <- & ExchangeAllResult {Resp : reply , Upstream : u }
151+ resCh <- & ExchangeAllResult {Resp : reply , Upstream : u }
149152 }
150153}
151154
@@ -156,12 +159,14 @@ func exchangeAndLog(u Upstream, req *dns.Msg) (resp *dns.Msg, err error) {
156159
157160 start := time .Now ()
158161 reply , err := u .Exchange (req )
159- elapsed := time .Since (start )
162+ dur := time .Since (start )
160163
161- if q := & req .Question [0 ]; err == nil {
162- log .Debug ("dnsproxy: upstream %s exchanged %s successfully in %s" , addr , q , elapsed )
163- } else {
164- log .Debug ("dnsproxy: upstream %s failed to exchange %s in %s: %s" , addr , q , elapsed , err )
164+ if len (req .Question ) > 0 {
165+ if q := & req .Question [0 ]; err == nil {
166+ log .Debug ("dnsproxy: upstream %s exchanged %s successfully in %s" , addr , q , dur )
167+ } else {
168+ log .Debug ("dnsproxy: upstream %s failed to exchange %s in %s: %s" , addr , q , dur , err )
169+ }
165170 }
166171
167172 return reply , err
0 commit comments