From ca898fe74e9c447f57c21bd99f292bc0fb6ade15 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Thu, 9 Nov 2023 12:26:44 +0300 Subject: [PATCH] dnsforward: imp code, rm wg --- internal/dnsforward/http.go | 125 ++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 62 deletions(-) diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 8445e56a..9b7eaea2 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -8,6 +8,7 @@ import ( "net/netip" "strings" "sync" + "sync/atomic" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" @@ -552,11 +553,11 @@ var protocols = []string{ } // validateUpstream returns an error if u alongside with domains is not a valid -// upstream configuration. useDefault is true if the upstream is +// upstream configuration. usesDefault is true if the upstream is // domain-specific and is configured to point at the default upstream server -// which is validated separately. The upstream is considered domain-specific -// only if domains is at least not nil. -func validateUpstream(u string, specific bool) (useDefault bool, err error) { +// which is validated separately. specific reflects if the upstream is +// domain-specific. +func validateUpstream(u string, specific bool) (usesDefault bool, err error) { // The special server address '#' means that default server must be used. if u == "#" && specific { return true, nil @@ -625,8 +626,8 @@ func separateUpstream(upstreamStr string) (upstreams, domains []string, err erro // properly. type healthCheckFunc func(u upstream.Upstream) (err error) -// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly. -func checkDNSUpstreamExc(u upstream.Upstream) (err error) { +// checkExchange checks if the DNS upstream exchanges correctly. +func checkExchange(u upstream.Upstream) (err error) { // testTLD is the special-use fully-qualified domain name for testing the // DNS server reachability. // @@ -656,11 +657,11 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) { return nil } -// checkPrivateUpstreamExc checks if the upstream for resolving private -// addresses exchanges correctly. +// checkPrivateExchange checks if the upstream for resolving private addresses +// exchanges correctly. // // TODO(e.burkov): Think about testing the ip6.arpa. as well. -func checkPrivateUpstreamExc(u upstream.Upstream) (err error) { +func checkPrivateExchange(u upstream.Upstream) (err error) { // inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP // address resolution. // @@ -705,16 +706,18 @@ func (err domainSpecificTestError) Error() (msg string) { // system hosts files. Checks if the DNS upstream exchanges correctly. It // returns an error if addr is not valid DNS upstream address or the upstream // is not exchanging correctly. +// +// TODO(e.burkov): Remove the receiver. func (s *Server) checkUpstreamAddr( addr string, specific bool, basicOpts *upstream.Options, check healthCheckFunc, ) (err error) { - useDefault, err := validateUpstream(addr, specific) + usesDefault, err := validateUpstream(addr, specific) if err != nil { return fmt.Errorf("wrong upstream format: %w", err) - } else if useDefault { + } else if usesDefault { return nil } @@ -733,6 +736,8 @@ func (s *Server) checkUpstreamAddr( } // dnsFilter can be nil during application update. + // + // TODO(e.burkov): Remove when update dnsproxy. if s.dnsFilter != nil { recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr)) for _, rec := range recs { @@ -756,50 +761,48 @@ type checkResult = struct { // nil for working upstreams. status error - // ups is the upstream server address as given in the request. It may - // appear a domain-specific upstream line if it isn't correct itself. - ups string + // address is the upstream server address as given in the request. It may + // appear to be a whole line if it's incorrect itself. + address string } // checkDNS parses an upstream configuration line using opts and checks if the // specified upstreams are working using check. addWG is decremented when the // expected number of results is added to resWG, then results are sent to resCh. +// +// TODO(e.burkov): Remove the receiver. func (s *Server) checkDNS( line string, opts *upstream.Options, check healthCheckFunc, - addWG *sync.WaitGroup, - resWG *sync.WaitGroup, - resCh chan checkResult, + countWG *sync.WaitGroup, + resNum *atomic.Int32, + resCh chan<- checkResult, ) { defer log.OnPanic("dnsforward: checking upstreams") - upstreams, domains, err := separateUpstream(line) + addrs, domains, err := separateUpstream(line) if err != nil { - resWG.Add(1) - addWG.Done() + resNum.Add(1) + countWG.Done() resCh <- checkResult{ - ups: line, - status: fmt.Errorf("wrong upstream format: %w", err), + address: line, + status: fmt.Errorf("wrong upstream format: %w", err), } return } - resWG.Add(len(upstreams)) - addWG.Done() + resNum.Add(int32(len(addrs))) + countWG.Done() specific := len(domains) > 0 - for _, ups := range upstreams { - cr := checkResult{ups: ups} - - checkErr := s.checkUpstreamAddr(ups, specific, opts, check) - if checkErr != nil { - cr.status = checkErr + for _, addr := range addrs { + resCh <- checkResult{ + address: addr, + status: s.checkUpstreamAddr(addr, specific, opts, check), } - - resCh <- cr } } @@ -809,39 +812,37 @@ func (s *Server) check(req *upstreamJSON, opts *upstream.Options) (result map[st req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty) req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty) - result = map[string]string{} + countWG := &sync.WaitGroup{} + countWG.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)) + + resNum := &atomic.Int32{} resCh := make(chan checkResult) - resWG := &sync.WaitGroup{} - go func() { - for res := range resCh { - // TODO(e.burkov): The servers used for at least two of common, - // private and fallback resolving should be reported separately. - if res.status != nil { - result[res.ups] = res.status.Error() - } else { - result[res.ups] = "OK" - } - resWG.Done() + + for _, addr := range req.Upstreams { + go s.checkDNS(addr, opts, checkExchange, countWG, resNum, resCh) + } + for _, addr := range req.FallbackDNS { + go s.checkDNS(addr, opts, checkExchange, countWG, resNum, resCh) + } + for _, addr := range req.PrivateUpstreams { + go s.checkDNS(addr, opts, checkPrivateExchange, countWG, resNum, resCh) + } + + // Wait until all the servers are counted and enqueued. + countWG.Wait() + n := resNum.Load() + + result = make(map[string]string, n) + for i := int32(0); i < n; i++ { + // TODO(e.burkov): Upstreams intended for different purposes should + // be distinguished in the result, even if specified equally. + res := <-resCh + if res.status != nil { + result[res.address] = res.status.Error() + } else { + result[res.address] = "OK" } - }() - - // addWG is used to wait for all goroutines to count the expected number of - // results and to add it to resWG. - addWG := &sync.WaitGroup{} - addWG.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)) - - for _, ups := range req.Upstreams { - go s.checkDNS(ups, opts, checkDNSUpstreamExc, addWG, resWG, resCh) } - for _, ups := range req.FallbackDNS { - go s.checkDNS(ups, opts, checkDNSUpstreamExc, addWG, resWG, resCh) - } - for _, ups := range req.PrivateUpstreams { - go s.checkDNS(ups, opts, checkPrivateUpstreamExc, addWG, resWG, resCh) - } - - addWG.Wait() - resWG.Wait() return result }