Pull request 2100: v0.107.42-rc
Squashed commit of the following: commit 284190f748345c7556e60b67f051ec5f6f080948 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Dec 6 19:36:00 2023 +0300 all: sync with master; upd chlog
This commit is contained in:
@@ -6,20 +6,15 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
@@ -45,8 +40,19 @@ type jsonDNSConfig struct {
|
||||
// ProtectionEnabled defines if protection is enabled.
|
||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||
|
||||
// RateLimit is the number of requests per second allowed per client.
|
||||
RateLimit *uint32 `json:"ratelimit"`
|
||||
// Ratelimit is the number of requests per second allowed per client.
|
||||
Ratelimit *uint32 `json:"ratelimit"`
|
||||
|
||||
// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
|
||||
// rate limiting requests.
|
||||
RatelimitSubnetLenIPv4 *int `json:"ratelimit_subnet_len_ipv4"`
|
||||
|
||||
// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
|
||||
// rate limiting requests.
|
||||
RatelimitSubnetLenIPv6 *int `json:"ratelimit_subnet_len_ipv6"`
|
||||
|
||||
// RatelimitWhitelist is a list of IP addresses excluded from rate limiting.
|
||||
RatelimitWhitelist *[]netip.Addr `json:"ratelimit_whitelist"`
|
||||
|
||||
// BlockingMode defines the way blocked responses are constructed.
|
||||
BlockingMode *filtering.BlockingMode `json:"blocking_mode"`
|
||||
@@ -121,6 +127,9 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode()
|
||||
blockedResponseTTL := s.dnsFilter.BlockedResponseTTL()
|
||||
ratelimit := s.conf.Ratelimit
|
||||
ratelimitSubnetLenIPv4 := s.conf.RatelimitSubnetLenIPv4
|
||||
ratelimitSubnetLenIPv6 := s.conf.RatelimitSubnetLenIPv6
|
||||
ratelimitWhitelist := append([]netip.Addr{}, s.conf.RatelimitWhitelist...)
|
||||
|
||||
customIP := s.conf.EDNSClientSubnet.CustomIP
|
||||
enableEDNSClientSubnet := s.conf.EDNSClientSubnet.Enabled
|
||||
@@ -157,7 +166,10 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
BlockingMode: &blockingMode,
|
||||
BlockingIPv4: blockingIPv4,
|
||||
BlockingIPv6: blockingIPv6,
|
||||
RateLimit: &ratelimit,
|
||||
Ratelimit: &ratelimit,
|
||||
RatelimitSubnetLenIPv4: &ratelimitSubnetLenIPv4,
|
||||
RatelimitSubnetLenIPv6: &ratelimitSubnetLenIPv6,
|
||||
RatelimitWhitelist: &ratelimitWhitelist,
|
||||
EDNSCSCustomIP: customIP,
|
||||
EDNSCSEnabled: &enableEDNSClientSubnet,
|
||||
EDNSCSUseCustom: &useCustom,
|
||||
@@ -180,13 +192,13 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
// defaultLocalPTRUpstreams returns the list of default local PTR resolvers
|
||||
// filtered of AdGuard Home's own DNS server addresses. It may appear empty.
|
||||
func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
|
||||
matcher, err := s.conf.ourAddrsMatcher()
|
||||
matcher, err := s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher.Has)
|
||||
ups = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
ups = append(ups, r.String())
|
||||
@@ -201,6 +213,7 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// checkBlockingMode returns an error if blocking mode is invalid.
|
||||
func (req *jsonDNSConfig) checkBlockingMode() (err error) {
|
||||
if req.BlockingMode == nil {
|
||||
return nil
|
||||
@@ -209,12 +222,21 @@ func (req *jsonDNSConfig) checkBlockingMode() (err error) {
|
||||
return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6)
|
||||
}
|
||||
|
||||
func (req *jsonDNSConfig) checkUpstreamsMode() bool {
|
||||
valid := []string{"", "fastest_addr", "parallel"}
|
||||
// checkUpstreamsMode returns an error if the upstream mode is invalid.
|
||||
func (req *jsonDNSConfig) checkUpstreamsMode() (err error) {
|
||||
if req.UpstreamMode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode)
|
||||
mode := *req.UpstreamMode
|
||||
if ok := slices.Contains([]string{"", "fastest_addr", "parallel"}, mode); !ok {
|
||||
return fmt.Errorf("upstream_mode: incorrect value %q", mode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBootstrap returns an error if any bootstrap address is invalid.
|
||||
func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return nil
|
||||
@@ -229,6 +251,7 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
}
|
||||
|
||||
if _, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -244,67 +267,136 @@ func (req *jsonDNSConfig) checkFallbacks() (err error) {
|
||||
|
||||
err = ValidateUpstreams(*req.Fallbacks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating fallback servers: %w", err)
|
||||
return fmt.Errorf("fallback servers: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
//
|
||||
// TODO(s.chzhen): Parse, don't validate.
|
||||
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
|
||||
|
||||
err = req.validateUpstreamDNSServers(privateNets)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkRatelimitSubnetMaskLen()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkBlockingMode()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkUpstreamsMode()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkCacheTTL()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream servers: %w", err)
|
||||
return fmt.Errorf("upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating private upstream servers: %w", err)
|
||||
return fmt.Errorf("private upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = req.checkBootstrap()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkFallbacks()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkBlockingMode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case !req.checkUpstreamsMode():
|
||||
return errors.Error("upstream_mode: incorrect value")
|
||||
case !req.checkCacheTTL():
|
||||
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *jsonDNSConfig) checkCacheTTL() bool {
|
||||
// checkCacheTTL returns an error if the configuration of the cache TTL is
|
||||
// invalid.
|
||||
func (req *jsonDNSConfig) checkCacheTTL() (err error) {
|
||||
if req.CacheMinTTL == nil && req.CacheMaxTTL == nil {
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
var min, max uint32
|
||||
var minTTL, maxTTL uint32
|
||||
if req.CacheMinTTL != nil {
|
||||
min = *req.CacheMinTTL
|
||||
minTTL = *req.CacheMinTTL
|
||||
}
|
||||
if req.CacheMaxTTL != nil {
|
||||
max = *req.CacheMaxTTL
|
||||
maxTTL = *req.CacheMaxTTL
|
||||
}
|
||||
|
||||
return min <= max
|
||||
if minTTL <= maxTTL {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
|
||||
}
|
||||
|
||||
// checkRatelimitSubnetMaskLen returns an error if the length of the subnet mask
|
||||
// for IPv4 or IPv6 addresses is invalid.
|
||||
func (req *jsonDNSConfig) checkRatelimitSubnetMaskLen() (err error) {
|
||||
err = checkInclusion(req.RatelimitSubnetLenIPv4, 0, netutil.IPv4BitLen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ratelimit_subnet_len_ipv4 is invalid: %w", err)
|
||||
}
|
||||
|
||||
err = checkInclusion(req.RatelimitSubnetLenIPv6, 0, netutil.IPv6BitLen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ratelimit_subnet_len_ipv6 is invalid: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkInclusion returns an error if a ptr is not nil and points to value,
|
||||
// that not in the inclusive range between minN and maxN.
|
||||
func checkInclusion(ptr *int, minN, maxN int) (err error) {
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
n := *ptr
|
||||
switch {
|
||||
case n < minN:
|
||||
return fmt.Errorf("value %d less than min %d", n, minN)
|
||||
case n > maxN:
|
||||
return fmt.Errorf("value %d greater than max %d", n, maxN)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
|
||||
@@ -401,6 +493,9 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic),
|
||||
setIfNotNil(&s.conf.AddrProcConf.UseRDNS, dc.ResolveClients),
|
||||
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS),
|
||||
setIfNotNil(&s.conf.RatelimitSubnetLenIPv4, dc.RatelimitSubnetLenIPv4),
|
||||
setIfNotNil(&s.conf.RatelimitSubnetLenIPv6, dc.RatelimitSubnetLenIPv6),
|
||||
setIfNotNil(&s.conf.RatelimitWhitelist, dc.RatelimitWhitelist),
|
||||
} {
|
||||
shouldRestart = shouldRestart || hasSet
|
||||
if shouldRestart {
|
||||
@@ -408,8 +503,8 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
}
|
||||
}
|
||||
|
||||
if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit {
|
||||
s.conf.Ratelimit = *dc.RateLimit
|
||||
if dc.Ratelimit != nil && s.conf.Ratelimit != *dc.Ratelimit {
|
||||
s.conf.Ratelimit = *dc.Ratelimit
|
||||
shouldRestart = true
|
||||
}
|
||||
|
||||
@@ -424,374 +519,11 @@ type upstreamJSON struct {
|
||||
PrivateUpstreams []string `json:"private_upstream"`
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
||||
// slice already so that this function may be considered useless.
|
||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
||||
// No need to validate comments and empty lines.
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
// Consider this case valid since it means the default server should be
|
||||
// used.
|
||||
return nil, nil
|
||||
// closeBoots closes all the provided bootstrap servers and logs errors if any.
|
||||
func closeBoots(boots []*upstream.UpstreamResolver) {
|
||||
for _, c := range boots {
|
||||
logCloserErr(c, "dnsforward: closing bootstrap %s: %s", c.Address())
|
||||
}
|
||||
|
||||
err = validateUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// validateUpstreamConfig validates each upstream from the upstream
|
||||
// configuration and returns an error if any upstream is invalid.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func validateUpstreamConfig(conf []string) (err error) {
|
||||
for _, u := range conf {
|
||||
var ups []string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
for _, addr := range ups {
|
||||
_, err = validateUpstream(addr, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func ValidateUpstreams(upstreams []string) (err error) {
|
||||
_, err = newUpstreamConfig(upstreams)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating config: %w", err)
|
||||
}
|
||||
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||
slices.Sort(keys)
|
||||
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet netip.Prefix
|
||||
subnet, err = extractARPASubnet(domain)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !privateNets.Contains(subnet.Addr().AsSlice()) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||
}
|
||||
|
||||
var protocols = []string{
|
||||
"h3://",
|
||||
"https://",
|
||||
"quic://",
|
||||
"sdns://",
|
||||
"tcp://",
|
||||
"tls://",
|
||||
"udp://",
|
||||
}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = u == "#" && domains != nil; useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
//
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
for _, proto := range protocols {
|
||||
if strings.HasPrefix(u, proto) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if proto, _, ok := strings.Cut(u, "://"); ok {
|
||||
return false, fmt.Errorf("bad protocol %q", proto)
|
||||
}
|
||||
|
||||
// Check if upstream is either an IP or IP with port.
|
||||
if _, err = netip.ParseAddr(u); err == nil {
|
||||
return false, nil
|
||||
} else if _, err = netip.ParseAddrPort(u); err == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstreams and the specified domains. domains
|
||||
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func separateUpstream(upstreamStr string) (upstreams, domains []string, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return []string{upstreamStr}, nil, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||
|
||||
parts := strings.Split(upstreamStr[2:], "/]")
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return nil, nil, errors.Error("missing separator")
|
||||
default:
|
||||
return nil, nil, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
for i, host := range strings.Split(parts[0], "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
domains = append(domains, host)
|
||||
}
|
||||
|
||||
return strings.Fields(parts[1]), domains, nil
|
||||
}
|
||||
|
||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||
// properly.
|
||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||
|
||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// testTLD is the special-use fully-qualified domain name for testing the
|
||||
// DNS server reachability.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
|
||||
const testTLD = "test."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: testTLD,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
var reply *dns.Msg
|
||||
reply, err = u.Exchange(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
} else if len(reply.Answer) != 0 {
|
||||
return errors.Error("wrong response")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
||||
// addresses exchanges correctly.
|
||||
//
|
||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||
// address resolution.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
|
||||
const inAddrArpaTLD = "in-addr.arpa."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: inAddrArpaTLD,
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
if _, err = u.Exchange(req); err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
//
|
||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||
// warnings (non-critical errors) is desired.
|
||||
type domainSpecificTestError struct {
|
||||
error
|
||||
}
|
||||
|
||||
// Error implements the [error] interface for domainSpecificTestError.
|
||||
func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// checkDNS parses line, creates DNS upstreams using opts, and checks if the
|
||||
// upstreams are exchanging correctly. It saves the result into a sync.Map
|
||||
// where key is an upstream address and value is "OK", if the upstream
|
||||
// exchanges correctly, or text of the error. It is intended to be used as a
|
||||
// goroutine.
|
||||
//
|
||||
// TODO(s.chzhen): Separate to a different structure/file.
|
||||
func (s *Server) checkDNS(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
wg *sync.WaitGroup,
|
||||
m *sync.Map,
|
||||
) {
|
||||
defer wg.Done()
|
||||
defer log.OnPanic("dnsforward: checking upstreams")
|
||||
|
||||
upstreams, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
||||
m.Store(line, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
specific := len(domains) > 0
|
||||
|
||||
for _, upstreamAddr := range upstreams {
|
||||
var useDefault bool
|
||||
useDefault, err = validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
||||
m.Store(upstreamAddr, err.Error())
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if useDefault {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check)
|
||||
if err != nil {
|
||||
m.Store(upstreamAddr, err.Error())
|
||||
} else {
|
||||
m.Store(upstreamAddr, "OK")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkUpstreamAddr creates the DNS upstream using opts and information from
|
||||
// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It
|
||||
// returns an error if addr is not valid DNS upstream address or the upstream
|
||||
// is not exchanging correctly.
|
||||
func (s *Server) checkUpstreamAddr(
|
||||
addr string,
|
||||
specific bool,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
) (err error) {
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
opts = &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
Timeout: opts.Timeout,
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr))
|
||||
for _, rec := range recs {
|
||||
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
|
||||
}
|
||||
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
||||
}
|
||||
|
||||
u, err := upstream.AddressToUpstream(addr, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating upstream for %q: %w", addr, err)
|
||||
}
|
||||
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
return check(u)
|
||||
}
|
||||
|
||||
// handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns
|
||||
@@ -808,46 +540,27 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
||||
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: req.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
}
|
||||
if len(opts.Bootstrap) == 0 {
|
||||
opts.Bootstrap = defaultBootstrap
|
||||
|
||||
var boots []*upstream.UpstreamResolver
|
||||
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
defer closeBoots(boots)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
m := &sync.Map{}
|
||||
cv := newUpstreamConfigValidator(req.Upstreams, req.FallbackDNS, req.PrivateUpstreams, opts)
|
||||
cv.check()
|
||||
cv.close()
|
||||
|
||||
wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go s.checkDNS(ups, opts, checkPrivateUpstreamExc, wg, m)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
result := map[string]string{}
|
||||
m.Range(func(k, v any) bool {
|
||||
// TODO(e.burkov): The upstreams used for both common and private
|
||||
// resolving should be reported separately.
|
||||
ups := k.(string)
|
||||
status := v.(string)
|
||||
|
||||
result[ups] = status
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
||||
aghhttp.WriteJSONResponseOK(w, r, cv.status())
|
||||
}
|
||||
|
||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||
|
||||
Reference in New Issue
Block a user