Pull request 2052: 4977-multiple-domain-specific-upstreams
Updates #4977. Squashed commit of the following: commit da28c1b508b1aa4838d753fbb5fcac64a5fcebb9 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Oct 27 17:24:38 2023 +0300 all: fix typo commit d6bca6b252c9bd264737c93072869499afa24864 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Oct 27 14:44:20 2023 +0300 all: add todo commit 30875515942c58881305aa963220d57d31e0e67d Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Oct 25 20:00:17 2023 +0300 all: imp docs commit 04003c342fcf82aeb671938fb89592fd6baff16d Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Oct 25 16:59:14 2023 +0300 all: multiple domain specific upstreams
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -444,19 +445,10 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
var ups string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = validateUpstream(ups, domains)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating upstream %q: %w", u, err)
|
||||
}
|
||||
err = validateUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
@@ -467,6 +459,7 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
@@ -475,6 +468,31 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// validateUpstreamConfig validates each upstream from the upstream
|
||||
// configuration and returns an error if any upstream is invalid.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func validateUpstreamConfig(conf []string) (err error) {
|
||||
for _, u := range conf {
|
||||
var ups []string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
for _, addr := range ups {
|
||||
_, err = validateUpstream(addr, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
@@ -567,12 +585,12 @@ func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstream and the specified domains. domains is
|
||||
// nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// separateUpstream returns the upstreams and the specified domains. domains
|
||||
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func separateUpstream(upstreamStr string) (ups string, domains []string, err error) {
|
||||
func separateUpstream(upstreamStr string) (upstreams, domains []string, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return upstreamStr, nil, nil
|
||||
return []string{upstreamStr}, nil, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||
@@ -582,9 +600,9 @@ func separateUpstream(upstreamStr string) (ups string, domains []string, err err
|
||||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return "", nil, errors.Error("missing separator")
|
||||
return nil, nil, errors.Error("missing separator")
|
||||
default:
|
||||
return "", []string{}, errors.Error("duplicated separator")
|
||||
return nil, nil, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
for i, host := range strings.Split(parts[0], "/") {
|
||||
@@ -594,13 +612,13 @@ func separateUpstream(upstreamStr string) (ups string, domains []string, err err
|
||||
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return "", domains, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
return nil, nil, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
domains = append(domains, host)
|
||||
}
|
||||
|
||||
return parts[1], domains, nil
|
||||
return strings.Fields(parts[1]), domains, nil
|
||||
}
|
||||
|
||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||
@@ -683,30 +701,65 @@ func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// parseUpstreamLine parses line and creates the [upstream.Upstream] using opts
|
||||
// and information from [s.dnsFilter.EtcHosts]. It returns an error if the line
|
||||
// is not a valid upstream line, see [upstream.AddressToUpstream]. It's a
|
||||
// caller's responsibility to close u.
|
||||
func (s *Server) parseUpstreamLine(
|
||||
// checkDNS parses line, creates DNS upstreams using opts, and checks if the
|
||||
// upstreams are exchanging correctly. It returns a map where key is an
|
||||
// upstream address and value is "OK", if the upstream exchanges correctly, or
|
||||
// text of the error.
|
||||
func (s *Server) checkDNS(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
) (u upstream.Upstream, specific bool, err error) {
|
||||
// Separate upstream from domains list.
|
||||
upstreamAddr, domains, err := separateUpstream(line)
|
||||
check healthCheckFunc,
|
||||
) (result map[string]string) {
|
||||
result = map[string]string{}
|
||||
upstreams, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("wrong upstream format: %w", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
specific = len(domains) > 0
|
||||
specific := len(domains) > 0
|
||||
|
||||
useDefault, err := validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
return nil, specific, fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
return nil, specific, nil
|
||||
for _, upstreamAddr := range upstreams {
|
||||
var useDefault bool
|
||||
useDefault, err = validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
||||
result[upstreamAddr] = err.Error()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if useDefault {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check)
|
||||
if err != nil {
|
||||
result[upstreamAddr] = err.Error()
|
||||
} else {
|
||||
result[upstreamAddr] = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
return result
|
||||
}
|
||||
|
||||
// checkUpstreamAddr creates the DNS upstream using opts and information from
|
||||
// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It
|
||||
// returns an error if addr is not valid DNS upstream address or the upstream
|
||||
// is not exchanging correctly.
|
||||
func (s *Server) checkUpstreamAddr(
|
||||
addr string,
|
||||
specific bool,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
) (err error) {
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
opts = &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
@@ -716,42 +769,25 @@ func (s *Server) parseUpstreamLine(
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(upstreamAddr))
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr))
|
||||
for _, rec := range recs {
|
||||
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
|
||||
}
|
||||
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
||||
}
|
||||
u, err = upstream.AddressToUpstream(upstreamAddr, opts)
|
||||
|
||||
u, err := upstream.AddressToUpstream(addr, opts)
|
||||
if err != nil {
|
||||
return nil, specific, fmt.Errorf("creating upstream for %q: %w", upstreamAddr, err)
|
||||
return fmt.Errorf("creating upstream for %q: %w", addr, err)
|
||||
}
|
||||
|
||||
return u, specific, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkDNS(line string, opts *upstream.Options, check healthCheckFunc) (err error) {
|
||||
if IsCommentOrEmpty(line) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var u upstream.Upstream
|
||||
var specific bool
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
u, specific, err = s.parseUpstreamLine(line, opts)
|
||||
if err != nil || u == nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
return check(u)
|
||||
}
|
||||
|
||||
// handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns
|
||||
// endpoint.
|
||||
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
req := &upstreamJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
@@ -761,6 +797,10 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: req.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
@@ -770,54 +810,44 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
opts.Bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
type upsCheckResult = struct {
|
||||
err error
|
||||
host string
|
||||
wg := &sync.WaitGroup{}
|
||||
m := &sync.Map{}
|
||||
|
||||
// TODO(s.chzhen): Separate to a different structure/file.
|
||||
worker := func(upstreamLine string, check healthCheckFunc) {
|
||||
defer log.OnPanic("dnsforward: checking upstreams")
|
||||
|
||||
res := s.checkDNS(upstreamLine, opts, check)
|
||||
for ups, status := range res {
|
||||
m.Store(ups, status)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
upsNum := len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)
|
||||
result := make(map[string]string, upsNum)
|
||||
resCh := make(chan upsCheckResult, upsNum)
|
||||
wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
go worker(ups, checkDNSUpstreamExc)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
go worker(ups, checkDNSUpstreamExc)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkPrivateUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
go worker(ups, checkPrivateUpstreamExc)
|
||||
}
|
||||
|
||||
for i := 0; i < upsNum; i++ {
|
||||
// TODO(e.burkov): The upstreams used for both common and private
|
||||
// resolving should be reported separately.
|
||||
pair := <-resCh
|
||||
if pair.err != nil {
|
||||
result[pair.host] = pair.err.Error()
|
||||
} else {
|
||||
result[pair.host] = "OK"
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
result := map[string]string{}
|
||||
m.Range(func(k, v any) bool {
|
||||
ups := k.(string)
|
||||
status := v.(string)
|
||||
|
||||
result[ups] = status
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user