dnsforward: imp code, rm wg
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -552,11 +553,11 @@ var protocols = []string{
|
||||
}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// upstream configuration. usesDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, specific bool) (useDefault bool, err error) {
|
||||
// which is validated separately. specific reflects if the upstream is
|
||||
// domain-specific.
|
||||
func validateUpstream(u string, specific bool) (usesDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if u == "#" && specific {
|
||||
return true, nil
|
||||
@@ -625,8 +626,8 @@ func separateUpstream(upstreamStr string) (upstreams, domains []string, err erro
|
||||
// properly.
|
||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||
|
||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// checkExchange checks if the DNS upstream exchanges correctly.
|
||||
func checkExchange(u upstream.Upstream) (err error) {
|
||||
// testTLD is the special-use fully-qualified domain name for testing the
|
||||
// DNS server reachability.
|
||||
//
|
||||
@@ -656,11 +657,11 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
||||
// addresses exchanges correctly.
|
||||
// checkPrivateExchange checks if the upstream for resolving private addresses
|
||||
// exchanges correctly.
|
||||
//
|
||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
func checkPrivateExchange(u upstream.Upstream) (err error) {
|
||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||
// address resolution.
|
||||
//
|
||||
@@ -705,16 +706,18 @@ func (err domainSpecificTestError) Error() (msg string) {
|
||||
// system hosts files. Checks if the DNS upstream exchanges correctly. It
|
||||
// returns an error if addr is not valid DNS upstream address or the upstream
|
||||
// is not exchanging correctly.
|
||||
//
|
||||
// TODO(e.burkov): Remove the receiver.
|
||||
func (s *Server) checkUpstreamAddr(
|
||||
addr string,
|
||||
specific bool,
|
||||
basicOpts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
) (err error) {
|
||||
useDefault, err := validateUpstream(addr, specific)
|
||||
usesDefault, err := validateUpstream(addr, specific)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
} else if usesDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -733,6 +736,8 @@ func (s *Server) checkUpstreamAddr(
|
||||
}
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
//
|
||||
// TODO(e.burkov): Remove when update dnsproxy.
|
||||
if s.dnsFilter != nil {
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr))
|
||||
for _, rec := range recs {
|
||||
@@ -756,50 +761,48 @@ type checkResult = struct {
|
||||
// nil for working upstreams.
|
||||
status error
|
||||
|
||||
// ups is the upstream server address as given in the request. It may
|
||||
// appear a domain-specific upstream line if it isn't correct itself.
|
||||
ups string
|
||||
// address is the upstream server address as given in the request. It may
|
||||
// appear to be a whole line if it's incorrect itself.
|
||||
address string
|
||||
}
|
||||
|
||||
// checkDNS parses an upstream configuration line using opts and checks if the
|
||||
// specified upstreams are working using check. addWG is decremented when the
|
||||
// expected number of results is added to resWG, then results are sent to resCh.
|
||||
//
|
||||
// TODO(e.burkov): Remove the receiver.
|
||||
func (s *Server) checkDNS(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
addWG *sync.WaitGroup,
|
||||
resWG *sync.WaitGroup,
|
||||
resCh chan checkResult,
|
||||
countWG *sync.WaitGroup,
|
||||
resNum *atomic.Int32,
|
||||
resCh chan<- checkResult,
|
||||
) {
|
||||
defer log.OnPanic("dnsforward: checking upstreams")
|
||||
|
||||
upstreams, domains, err := separateUpstream(line)
|
||||
addrs, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
resWG.Add(1)
|
||||
addWG.Done()
|
||||
resNum.Add(1)
|
||||
countWG.Done()
|
||||
|
||||
resCh <- checkResult{
|
||||
ups: line,
|
||||
status: fmt.Errorf("wrong upstream format: %w", err),
|
||||
address: line,
|
||||
status: fmt.Errorf("wrong upstream format: %w", err),
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resWG.Add(len(upstreams))
|
||||
addWG.Done()
|
||||
resNum.Add(int32(len(addrs)))
|
||||
countWG.Done()
|
||||
|
||||
specific := len(domains) > 0
|
||||
for _, ups := range upstreams {
|
||||
cr := checkResult{ups: ups}
|
||||
|
||||
checkErr := s.checkUpstreamAddr(ups, specific, opts, check)
|
||||
if checkErr != nil {
|
||||
cr.status = checkErr
|
||||
for _, addr := range addrs {
|
||||
resCh <- checkResult{
|
||||
address: addr,
|
||||
status: s.checkUpstreamAddr(addr, specific, opts, check),
|
||||
}
|
||||
|
||||
resCh <- cr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -809,39 +812,37 @@ func (s *Server) check(req *upstreamJSON, opts *upstream.Options) (result map[st
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
result = map[string]string{}
|
||||
countWG := &sync.WaitGroup{}
|
||||
countWG.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||
|
||||
resNum := &atomic.Int32{}
|
||||
resCh := make(chan checkResult)
|
||||
resWG := &sync.WaitGroup{}
|
||||
go func() {
|
||||
for res := range resCh {
|
||||
// TODO(e.burkov): The servers used for at least two of common,
|
||||
// private and fallback resolving should be reported separately.
|
||||
if res.status != nil {
|
||||
result[res.ups] = res.status.Error()
|
||||
} else {
|
||||
result[res.ups] = "OK"
|
||||
}
|
||||
resWG.Done()
|
||||
|
||||
for _, addr := range req.Upstreams {
|
||||
go s.checkDNS(addr, opts, checkExchange, countWG, resNum, resCh)
|
||||
}
|
||||
for _, addr := range req.FallbackDNS {
|
||||
go s.checkDNS(addr, opts, checkExchange, countWG, resNum, resCh)
|
||||
}
|
||||
for _, addr := range req.PrivateUpstreams {
|
||||
go s.checkDNS(addr, opts, checkPrivateExchange, countWG, resNum, resCh)
|
||||
}
|
||||
|
||||
// Wait until all the servers are counted and enqueued.
|
||||
countWG.Wait()
|
||||
n := resNum.Load()
|
||||
|
||||
result = make(map[string]string, n)
|
||||
for i := int32(0); i < n; i++ {
|
||||
// TODO(e.burkov): Upstreams intended for different purposes should
|
||||
// be distinguished in the result, even if specified equally.
|
||||
res := <-resCh
|
||||
if res.status != nil {
|
||||
result[res.address] = res.status.Error()
|
||||
} else {
|
||||
result[res.address] = "OK"
|
||||
}
|
||||
}()
|
||||
|
||||
// addWG is used to wait for all goroutines to count the expected number of
|
||||
// results and to add it to resWG.
|
||||
addWG := &sync.WaitGroup{}
|
||||
addWG.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, addWG, resWG, resCh)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, addWG, resWG, resCh)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go s.checkDNS(ups, opts, checkPrivateUpstreamExc, addWG, resWG, resCh)
|
||||
}
|
||||
|
||||
addWG.Wait()
|
||||
resWG.Wait()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user