Pull request 2100: v0.107.42-rc
Squashed commit of the following: commit 284190f748345c7556e60b67f051ec5f6f080948 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Dec 6 19:36:00 2023 +0300 all: sync with master; upd chlog
This commit is contained in:
@@ -27,6 +27,19 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// ClientsContainer provides information about preconfigured DNS clients.
|
||||
type ClientsContainer interface {
|
||||
// UpstreamConfigByID returns the custom upstream configuration for the
|
||||
// client having id, using boot to initialize the one if necessary. It
|
||||
// returns nil if there is no custom upstream configuration for the client.
|
||||
// The id is expected to be either a string representation of an IP address
|
||||
// or the ClientID.
|
||||
UpstreamConfigByID(
|
||||
id string,
|
||||
boot upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error)
|
||||
}
|
||||
|
||||
// Config represents the DNS filtering configuration of AdGuard Home. The zero
|
||||
// Config is empty and ready for use.
|
||||
type Config struct {
|
||||
@@ -35,10 +48,9 @@ type Config struct {
|
||||
// FilterHandler is an optional additional filtering callback.
|
||||
FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"`
|
||||
|
||||
// GetCustomUpstreamByClient is a callback that returns upstreams
|
||||
// configuration based on the client IP address or ClientID. It returns
|
||||
// nil if there are no custom upstreams for the client.
|
||||
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
|
||||
// ClientsContainer stores the information about special handling of some
|
||||
// DNS clients.
|
||||
ClientsContainer ClientsContainer `yaml:"-"`
|
||||
|
||||
// Anti-DNS amplification
|
||||
|
||||
@@ -55,7 +67,7 @@ type Config struct {
|
||||
RatelimitSubnetLenIPv6 int `yaml:"ratelimit_subnet_len_ipv6"`
|
||||
|
||||
// RatelimitWhitelist is the list of whitelisted client IP addresses.
|
||||
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
|
||||
RatelimitWhitelist []netip.Addr `yaml:"ratelimit_whitelist"`
|
||||
|
||||
// RefuseAny, if true, refuse ANY requests.
|
||||
RefuseAny bool `yaml:"refuse_any"`
|
||||
@@ -277,32 +289,33 @@ type ServerConfig struct {
|
||||
// UseHTTP3Upstreams defines if HTTP/3 is be allowed for DNS-over-HTTPS
|
||||
// upstreams.
|
||||
UseHTTP3Upstreams bool
|
||||
|
||||
// ServePlainDNS defines if plain DNS is allowed for incoming requests.
|
||||
ServePlainDNS bool
|
||||
}
|
||||
|
||||
// createProxyConfig creates and validates configuration for the main proxy.
|
||||
func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
// newProxyConfig creates and validates configuration for the main proxy.
|
||||
func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||
srvConf := s.conf
|
||||
conf = proxy.Config{
|
||||
UDPListenAddr: srvConf.UDPListenAddrs,
|
||||
TCPListenAddr: srvConf.TCPListenAddrs,
|
||||
HTTP3: srvConf.ServeHTTP3,
|
||||
Ratelimit: int(srvConf.Ratelimit),
|
||||
RatelimitSubnetMaskIPv4: net.CIDRMask(srvConf.RatelimitSubnetLenIPv4, netutil.IPv4BitLen),
|
||||
RatelimitSubnetMaskIPv6: net.CIDRMask(srvConf.RatelimitSubnetLenIPv6, netutil.IPv6BitLen),
|
||||
RatelimitWhitelist: srvConf.RatelimitWhitelist,
|
||||
RefuseAny: srvConf.RefuseAny,
|
||||
TrustedProxies: srvConf.TrustedProxies,
|
||||
CacheMinTTL: srvConf.CacheMinTTL,
|
||||
CacheMaxTTL: srvConf.CacheMaxTTL,
|
||||
CacheOptimistic: srvConf.CacheOptimistic,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
BeforeRequestHandler: s.beforeRequestHandler,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: int(srvConf.MaxGoroutines),
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
DNS64Prefs: srvConf.DNS64Prefixes,
|
||||
conf = &proxy.Config{
|
||||
HTTP3: srvConf.ServeHTTP3,
|
||||
Ratelimit: int(srvConf.Ratelimit),
|
||||
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
||||
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
|
||||
RatelimitWhitelist: srvConf.RatelimitWhitelist,
|
||||
RefuseAny: srvConf.RefuseAny,
|
||||
TrustedProxies: srvConf.TrustedProxies,
|
||||
CacheMinTTL: srvConf.CacheMinTTL,
|
||||
CacheMaxTTL: srvConf.CacheMaxTTL,
|
||||
CacheOptimistic: srvConf.CacheOptimistic,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
BeforeRequestHandler: s.beforeRequestHandler,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: int(srvConf.MaxGoroutines),
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
DNS64Prefs: srvConf.DNS64Prefixes,
|
||||
}
|
||||
|
||||
if srvConf.EDNSClientSubnet.UseCustom {
|
||||
@@ -316,27 +329,25 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
}
|
||||
|
||||
setProxyUpstreamMode(
|
||||
&conf,
|
||||
conf,
|
||||
srvConf.AllServers,
|
||||
srvConf.FastestAddr,
|
||||
srvConf.FastestTimeout.Duration,
|
||||
)
|
||||
|
||||
for i, s := range srvConf.BogusNXDomain {
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.ParseSubnet(s)
|
||||
if err != nil {
|
||||
log.Error("subnet at index %d: %s", i, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
conf.BogusNXDomain = append(conf.BogusNXDomain, subnet)
|
||||
conf.BogusNXDomain, err = parseBogusNXDOMAIN(srvConf.BogusNXDomain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bogus_nxdomain: %w", err)
|
||||
}
|
||||
|
||||
err = s.prepareTLS(&conf)
|
||||
err = s.prepareTLS(conf)
|
||||
if err != nil {
|
||||
return proxy.Config{}, fmt.Errorf("validating tls: %w", err)
|
||||
return nil, fmt.Errorf("validating tls: %w", err)
|
||||
}
|
||||
|
||||
err = s.preparePlain(conf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating plain: %w", err)
|
||||
}
|
||||
|
||||
if c := srvConf.DNSCryptConfig; c.Enabled {
|
||||
@@ -347,12 +358,27 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
}
|
||||
|
||||
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
|
||||
return proxy.Config{}, errors.Error("no default upstream servers configured")
|
||||
return nil, errors.Error("no default upstream servers configured")
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// parseBogusNXDOMAIN parses the bogus NXDOMAIN strings into valid subnets.
|
||||
func parseBogusNXDOMAIN(confBogusNXDOMAIN []string) (subnets []netip.Prefix, err error) {
|
||||
for i, s := range confBogusNXDOMAIN {
|
||||
var subnet netip.Prefix
|
||||
subnet, err = aghnet.ParseSubnet(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("subnet at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
subnets = append(subnets, subnet)
|
||||
}
|
||||
|
||||
return subnets, nil
|
||||
}
|
||||
|
||||
const defaultBlockedResponseTTL = 3600
|
||||
|
||||
// initDefaultSettings initializes default settings if nothing
|
||||
@@ -423,10 +449,7 @@ func collectListenAddr(
|
||||
|
||||
// collectDNSAddrs returns configured set of listening addresses. It also
|
||||
// returns a set of ports of each unspecified listening address.
|
||||
func (conf *ServerConfig) collectDNSAddrs() (
|
||||
addrs map[netip.AddrPort]unit,
|
||||
unspecPorts map[uint16]unit,
|
||||
) {
|
||||
func (conf *ServerConfig) collectDNSAddrs() (addrs mapAddrPortSet, unspecPorts map[uint16]unit) {
|
||||
// TODO(e.burkov): Perhaps, we shouldn't allocate as much memory, since the
|
||||
// TCP and UDP listening addresses are currently the same.
|
||||
addrs = make(map[netip.AddrPort]unit, len(conf.TCPListenAddrs)+len(conf.UDPListenAddrs))
|
||||
@@ -446,20 +469,64 @@ func (conf *ServerConfig) collectDNSAddrs() (
|
||||
// defaultPlainDNSPort is the default port for plain DNS.
|
||||
const defaultPlainDNSPort uint16 = 53
|
||||
|
||||
// addrPortMatcher is a function that matches an IP address with port.
|
||||
type addrPortMatcher func(addr netip.AddrPort) (ok bool)
|
||||
// addrPortSet is a set of [netip.AddrPort] values.
|
||||
type addrPortSet interface {
|
||||
// Has returns true if addrPort is in the set.
|
||||
Has(addrPort netip.AddrPort) (ok bool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ addrPortSet = emptyAddrPortSet{}
|
||||
|
||||
// emptyAddrPortSet is the [addrPortSet] containing no values.
|
||||
type emptyAddrPortSet struct{}
|
||||
|
||||
// Has implements the [addrPortSet] interface for [emptyAddrPortSet].
|
||||
func (emptyAddrPortSet) Has(_ netip.AddrPort) (ok bool) { return false }
|
||||
|
||||
// mapAddrPortSet is the [addrPortSet] containing values of [netip.AddrPort] as
|
||||
// keys of a map.
|
||||
type mapAddrPortSet map[netip.AddrPort]unit
|
||||
|
||||
// type check
|
||||
var _ addrPortSet = mapAddrPortSet{}
|
||||
|
||||
// Has implements the [addrPortSet] interface for [mapAddrPortSet].
|
||||
func (m mapAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
|
||||
_, ok = m[addrPort]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// combinedAddrPortSet is the [addrPortSet] defined by some IP addresses along
|
||||
// with ports, any combination of which is considered being in the set.
|
||||
type combinedAddrPortSet struct {
|
||||
// TODO(e.burkov): Use sorted slices in combination with binary search.
|
||||
ports map[uint16]unit
|
||||
addrs []netip.Addr
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ addrPortSet = (*combinedAddrPortSet)(nil)
|
||||
|
||||
// Has implements the [addrPortSet] interface for [*combinedAddrPortSet].
|
||||
func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
|
||||
_, ok = m.ports[addrPort.Port()]
|
||||
|
||||
return ok && slices.Contains(m.addrs, addrPort.Addr())
|
||||
}
|
||||
|
||||
// filterOut filters out all the upstreams that match um. It returns all the
|
||||
// closing errors joined.
|
||||
func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
|
||||
func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) {
|
||||
var errs []error
|
||||
delFunc := func(u upstream.Upstream) (ok bool) {
|
||||
// TODO(e.burkov): We should probably consider the protocol of u to
|
||||
// only filter out the listening addresses of the same protocol.
|
||||
addr, parseErr := aghnet.ParseAddrPort(u.Address(), defaultPlainDNSPort)
|
||||
if parseErr != nil || !m(addr) {
|
||||
if parseErr != nil || !set.Has(addr) {
|
||||
// Don't filter out the upstream if it either cannot be parsed, or
|
||||
// does not match um.
|
||||
// does not match m.
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -479,26 +546,20 @@ func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// ourAddrsMatcher returns a matcher that matches all the configured listening
|
||||
// ourAddrsSet returns an addrPortSet that contains all the configured listening
|
||||
// addresses.
|
||||
func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) {
|
||||
func (conf *ServerConfig) ourAddrsSet() (m addrPortSet, err error) {
|
||||
addrs, unspecPorts := conf.collectDNSAddrs()
|
||||
if len(addrs) == 0 {
|
||||
switch {
|
||||
case len(addrs) == 0:
|
||||
log.Debug("dnsforward: no listen addresses")
|
||||
|
||||
// Match no addresses.
|
||||
return func(_ netip.AddrPort) (ok bool) { return false }, nil
|
||||
}
|
||||
|
||||
if len(unspecPorts) == 0 {
|
||||
return emptyAddrPortSet{}, nil
|
||||
case len(unspecPorts) == 0:
|
||||
log.Debug("dnsforward: filtering out addresses %s", addrs)
|
||||
|
||||
m = func(a netip.AddrPort) (ok bool) {
|
||||
_, ok = addrs[a]
|
||||
|
||||
return ok
|
||||
}
|
||||
} else {
|
||||
return addrs, nil
|
||||
default:
|
||||
var ifaceAddrs []netip.Addr
|
||||
ifaceAddrs, err = aghnet.CollectAllIfacesAddrs()
|
||||
if err != nil {
|
||||
@@ -508,16 +569,11 @@ func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) {
|
||||
|
||||
log.Debug("dnsforward: filtering out addresses %s on ports %d", ifaceAddrs, unspecPorts)
|
||||
|
||||
m = func(a netip.AddrPort) (ok bool) {
|
||||
if _, ok = unspecPorts[a.Port()]; ok {
|
||||
return slices.Contains(ifaceAddrs, a.Addr())
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
return &combinedAddrPortSet{
|
||||
ports: unspecPorts,
|
||||
addrs: ifaceAddrs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// prepareTLS - prepares TLS configuration for the DNS proxy
|
||||
@@ -574,7 +630,7 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) {
|
||||
|
||||
// isWildcard returns true if host is a wildcard hostname.
|
||||
func isWildcard(host string) (ok bool) {
|
||||
return len(host) >= 2 && host[0] == '*' && host[1] == '.'
|
||||
return strings.HasPrefix(host, "*.")
|
||||
}
|
||||
|
||||
// matchesDomainWildcard returns true if host matches the domain wildcard
|
||||
@@ -614,6 +670,31 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
|
||||
return &s.conf.cert, nil
|
||||
}
|
||||
|
||||
// preparePlain prepares the plain-DNS configuration for the DNS proxy.
|
||||
// preparePlain assumes that prepareTLS has already been called.
|
||||
func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) {
|
||||
if s.conf.ServePlainDNS {
|
||||
proxyConf.UDPListenAddr = s.conf.UDPListenAddrs
|
||||
proxyConf.TCPListenAddr = s.conf.TCPListenAddrs
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
lenEncrypted := len(proxyConf.DNSCryptTCPListenAddr) +
|
||||
len(proxyConf.DNSCryptUDPListenAddr) +
|
||||
len(proxyConf.HTTPSListenAddr) +
|
||||
len(proxyConf.QUICListenAddr) +
|
||||
len(proxyConf.TLSListenAddr)
|
||||
if lenEncrypted == 0 {
|
||||
// TODO(a.garipov): Support full disabling of all DNS.
|
||||
return errors.Error("disabling plain dns requires at least one encrypted protocol")
|
||||
}
|
||||
|
||||
log.Info("dnsforward: warning: plain dns is disabled")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatedProtectionStatus updates protection state, if the protection was
|
||||
// disabled temporarily. Returns the updated state of protection.
|
||||
func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) {
|
||||
|
||||
349
internal/dnsforward/configvalidator.go
Normal file
349
internal/dnsforward/configvalidator.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// upstreamConfigValidator parses the [*proxy.UpstreamConfig] and checks the
|
||||
// actual DNS availability of each upstream.
|
||||
type upstreamConfigValidator struct {
|
||||
// general is the general upstream configuration.
|
||||
general []*upstreamResult
|
||||
|
||||
// fallback is the fallback upstream configuration.
|
||||
fallback []*upstreamResult
|
||||
|
||||
// private is the private upstream configuration.
|
||||
private []*upstreamResult
|
||||
}
|
||||
|
||||
// upstreamResult is a result of validation of an [upstream.Upstream] within an
|
||||
// [proxy.UpstreamConfig].
|
||||
type upstreamResult struct {
|
||||
// server is the parsed upstream. It is nil when there was an error during
|
||||
// parsing.
|
||||
server upstream.Upstream
|
||||
|
||||
// err is the error either from parsing or from checking the upstream.
|
||||
err error
|
||||
|
||||
// original is the piece of configuration that have either been turned to an
|
||||
// upstream or caused an error.
|
||||
original string
|
||||
|
||||
// isSpecific is true if the upstream is domain-specific.
|
||||
isSpecific bool
|
||||
}
|
||||
|
||||
// compare compares two [upstreamResult]s. It returns 0 if they are equal, -1
|
||||
// if ur should be sorted before other, and 1 otherwise.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps it makes sense to sort the results with errors near
|
||||
// the end.
|
||||
func (ur *upstreamResult) compare(other *upstreamResult) (res int) {
|
||||
return strings.Compare(ur.original, other.original)
|
||||
}
|
||||
|
||||
// newUpstreamConfigValidator parses the upstream configuration and returns a
|
||||
// validator for it. cv already contains the parsed upstreams along with errors
|
||||
// related.
|
||||
func newUpstreamConfigValidator(
|
||||
general []string,
|
||||
fallback []string,
|
||||
private []string,
|
||||
opts *upstream.Options,
|
||||
) (cv *upstreamConfigValidator) {
|
||||
cv = &upstreamConfigValidator{}
|
||||
|
||||
for _, line := range general {
|
||||
cv.general = cv.insertLineResults(cv.general, line, opts)
|
||||
}
|
||||
for _, line := range fallback {
|
||||
cv.fallback = cv.insertLineResults(cv.fallback, line, opts)
|
||||
}
|
||||
for _, line := range private {
|
||||
cv.private = cv.insertLineResults(cv.private, line, opts)
|
||||
}
|
||||
|
||||
return cv
|
||||
}
|
||||
|
||||
// insertLineResults parses line and inserts the result into s. It can insert
|
||||
// multiple results as well as none.
|
||||
func (cv *upstreamConfigValidator) insertLineResults(
|
||||
s []*upstreamResult,
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
) (result []*upstreamResult) {
|
||||
upstreams, isSpecific, err := splitUpstreamLine(line)
|
||||
if err != nil {
|
||||
return cv.insert(s, &upstreamResult{
|
||||
err: err,
|
||||
original: line,
|
||||
})
|
||||
}
|
||||
|
||||
for _, upstreamAddr := range upstreams {
|
||||
var res *upstreamResult
|
||||
if upstreamAddr != "#" {
|
||||
res = cv.parseUpstream(upstreamAddr, opts)
|
||||
} else if !isSpecific {
|
||||
res = &upstreamResult{
|
||||
err: errNotDomainSpecific,
|
||||
original: upstreamAddr,
|
||||
}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
res.isSpecific = isSpecific
|
||||
s = cv.insert(s, res)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// insert inserts r into slice in a sorted order, except duplicates. slice must
|
||||
// not be nil.
|
||||
func (cv *upstreamConfigValidator) insert(
|
||||
s []*upstreamResult,
|
||||
r *upstreamResult,
|
||||
) (result []*upstreamResult) {
|
||||
i, has := slices.BinarySearchFunc(s, r, (*upstreamResult).compare)
|
||||
if has {
|
||||
log.Debug("dnsforward: duplicate configuration %q", r.original)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
return slices.Insert(s, i, r)
|
||||
}
|
||||
|
||||
// parseUpstream parses addr and returns the result of parsing. It returns nil
|
||||
// if the specified server points at the default upstream server which is
|
||||
// validated separately.
|
||||
func (cv *upstreamConfigValidator) parseUpstream(
|
||||
addr string,
|
||||
opts *upstream.Options,
|
||||
) (r *upstreamResult) {
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
//
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
if proto, _, ok := strings.Cut(addr, "://"); ok {
|
||||
if !slices.Contains(protocols, proto) {
|
||||
return &upstreamResult{
|
||||
err: fmt.Errorf("bad protocol %q", proto),
|
||||
original: addr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ups, err := upstream.AddressToUpstream(addr, opts)
|
||||
|
||||
return &upstreamResult{
|
||||
server: ups,
|
||||
err: err,
|
||||
original: addr,
|
||||
}
|
||||
}
|
||||
|
||||
// check tries to exchange with each successfully parsed upstream and enriches
|
||||
// the results with the healthcheck errors. It should not be called after the
|
||||
// [upsConfValidator.close] method, since it makes no sense to check the closed
|
||||
// upstreams.
|
||||
func (cv *upstreamConfigValidator) check() {
|
||||
const (
|
||||
// testTLD is the special-use fully-qualified domain name for testing
|
||||
// the DNS server reachability.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
|
||||
testTLD = "test."
|
||||
|
||||
// inAddrARPATLD is the special-use fully-qualified domain name for PTR
|
||||
// IP address resolution.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
|
||||
inAddrARPATLD = "in-addr.arpa."
|
||||
)
|
||||
|
||||
commonChecker := &healthchecker{
|
||||
hostname: testTLD,
|
||||
qtype: dns.TypeA,
|
||||
ansEmpty: true,
|
||||
}
|
||||
|
||||
arpaChecker := &healthchecker{
|
||||
hostname: inAddrARPATLD,
|
||||
qtype: dns.TypePTR,
|
||||
ansEmpty: false,
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
|
||||
|
||||
for _, res := range cv.general {
|
||||
go cv.checkSrv(res, wg, commonChecker)
|
||||
}
|
||||
for _, res := range cv.fallback {
|
||||
go cv.checkSrv(res, wg, commonChecker)
|
||||
}
|
||||
for _, res := range cv.private {
|
||||
go cv.checkSrv(res, wg, arpaChecker)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// checkSrv runs hc on the server from res, if any, and stores any occurred
|
||||
// error in res. wg is always marked done in the end. It used to be called in
|
||||
// a separate goroutine.
|
||||
func (cv *upstreamConfigValidator) checkSrv(
|
||||
res *upstreamResult,
|
||||
wg *sync.WaitGroup,
|
||||
hc *healthchecker,
|
||||
) {
|
||||
defer wg.Done()
|
||||
|
||||
if res.server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
res.err = hc.check(res.server)
|
||||
if res.err != nil && res.isSpecific {
|
||||
res.err = domainSpecificTestError{Err: res.err}
|
||||
}
|
||||
}
|
||||
|
||||
// close closes all the upstreams that were successfully parsed. It enriches
|
||||
// the results with deferred closing errors.
|
||||
func (cv *upstreamConfigValidator) close() {
|
||||
for _, slice := range [][]*upstreamResult{cv.general, cv.fallback, cv.private} {
|
||||
for _, r := range slice {
|
||||
if r.server != nil {
|
||||
r.err = errors.WithDeferred(r.err, r.server.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// status returns all the data collected during parsing, healthcheck, and
|
||||
// closing of the upstreams. The returned map is keyed by the original upstream
|
||||
// configuration piece and contains the corresponding error or "OK" if there was
|
||||
// no error.
|
||||
func (cv *upstreamConfigValidator) status() (results map[string]string) {
|
||||
result := map[string]string{}
|
||||
|
||||
for _, res := range cv.general {
|
||||
resultToStatus("general", res, result)
|
||||
}
|
||||
for _, res := range cv.fallback {
|
||||
resultToStatus("fallback", res, result)
|
||||
}
|
||||
for _, res := range cv.private {
|
||||
resultToStatus("private", res, result)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// resultToStatus puts "OK" or an error message from res into resMap. section
|
||||
// is the name of the upstream configuration section, i.e. "general",
|
||||
// "fallback", or "private", and only used for logging.
|
||||
//
|
||||
// TODO(e.burkov): Currently, the HTTP handler expects that all the results are
|
||||
// put together in a single map, which may lead to collisions, see AG-27539.
|
||||
// Improve the results compilation.
|
||||
func resultToStatus(section string, res *upstreamResult, resMap map[string]string) {
|
||||
val := "OK"
|
||||
if res.err != nil {
|
||||
val = res.err.Error()
|
||||
}
|
||||
|
||||
prevVal := resMap[res.original]
|
||||
switch prevVal {
|
||||
case "":
|
||||
resMap[res.original] = val
|
||||
case val:
|
||||
log.Debug("dnsforward: duplicating %s config line %q", section, res.original)
|
||||
default:
|
||||
log.Debug(
|
||||
"dnsforward: warning: %s config line %q (%v) had different result %v",
|
||||
section,
|
||||
val,
|
||||
res.original,
|
||||
prevVal,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
//
|
||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||
// warnings (non-critical errors) is desired.
|
||||
type domainSpecificTestError struct {
|
||||
// Err is the actual error occurred during healthcheck test.
|
||||
Err error
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ error = domainSpecificTestError{}
|
||||
|
||||
// Error implements the [error] interface for domainSpecificTestError.
|
||||
func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.Err)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ errors.Wrapper = domainSpecificTestError{}
|
||||
|
||||
// Unwrap implements the [errors.Wrapper] interface for domainSpecificTestError.
|
||||
func (err domainSpecificTestError) Unwrap() (wrapped error) {
|
||||
return err.Err
|
||||
}
|
||||
|
||||
// healthchecker checks the upstream's status by exchanging with it.
|
||||
type healthchecker struct {
|
||||
// hostname is the name of the host to put into healthcheck DNS request.
|
||||
hostname string
|
||||
|
||||
// qtype is the type of DNS request to use for healthcheck.
|
||||
qtype uint16
|
||||
|
||||
// ansEmpty defines if the answer section within the response is expected to
|
||||
// be empty.
|
||||
ansEmpty bool
|
||||
}
|
||||
|
||||
// check exchanges with u and validates the response.
|
||||
func (h *healthchecker) check(u upstream.Upstream) (err error) {
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: h.hostname,
|
||||
Qtype: h.qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
reply, err := u.Exchange(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
} else if h.ansEmpty && len(reply.Answer) > 0 {
|
||||
return errWrongResponse
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -11,10 +13,12 @@ import (
|
||||
)
|
||||
|
||||
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
|
||||
// addr should be a valid host:port address, where host could be a domain name
|
||||
// or an IP address.
|
||||
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
log.Debug("dnsforward: dialing %q for network %q", addr, network)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -28,21 +32,24 @@ func (s *Server) DialContext(ctx context.Context, network, addr string) (conn ne
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
addrs, err := s.Resolve(host)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving %q: %w", host, err)
|
||||
return nil, fmt.Errorf("invalid port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: resolving %q: %v", host, addrs)
|
||||
|
||||
if len(addrs) == 0 {
|
||||
ips, err := s.Resolve(ctx, network, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving %q: %w", host, err)
|
||||
} else if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no addresses for host %q", host)
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: resolved %q: %v", host, ips)
|
||||
|
||||
var dialErrs []error
|
||||
for _, a := range addrs {
|
||||
addr = net.JoinHostPort(a.String(), port)
|
||||
conn, err = dialer.DialContext(ctx, network, addr)
|
||||
for _, ip := range ips {
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
||||
conn, err = dialer.DialContext(ctx, network, addrPort.String())
|
||||
if err != nil {
|
||||
dialErrs = append(dialErrs, err)
|
||||
|
||||
|
||||
@@ -292,6 +292,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, localUps)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -34,6 +36,11 @@ import (
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
const DefaultTimeout = 10 * time.Second
|
||||
|
||||
// defaultLocalTimeout is the default timeout for resolving addresses from
|
||||
// locally-served networks. It is assumed that local resolvers should work much
|
||||
// faster than ordinary upstreams.
|
||||
const defaultLocalTimeout = 1 * time.Second
|
||||
|
||||
// defaultClientIDCacheCount is the default count of items in the LRU ClientID
|
||||
// cache. The assumption here is that there won't be more than this many
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
@@ -108,7 +115,7 @@ type Server struct {
|
||||
// stats is the statistics collector for client's DNS usage data.
|
||||
stats stats.Interface
|
||||
|
||||
// access drops unallowed clients.
|
||||
// access drops disallowed clients.
|
||||
access *accessManager
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
@@ -135,8 +142,21 @@ type Server struct {
|
||||
// PTR resolving.
|
||||
sysResolvers SystemResolvers
|
||||
|
||||
// recDetector is a cache for recursive requests. It is used to detect
|
||||
// and prevent recursive requests only for private upstreams.
|
||||
// etcHosts contains the data from the system's hosts files.
|
||||
etcHosts upstream.Resolver
|
||||
|
||||
// bootstrap is the resolver for upstreams' hostnames.
|
||||
bootstrap upstream.Resolver
|
||||
|
||||
// bootResolvers are the resolvers that should be used for
|
||||
// bootstrapping along with [etcHosts].
|
||||
//
|
||||
// TODO(e.burkov): Use [proxy.UpstreamConfig] when it will implement the
|
||||
// [upstream.Resolver] interface.
|
||||
bootResolvers []*upstream.UpstreamResolver
|
||||
|
||||
// recDetector is a cache for recursive requests. It is used to detect and
|
||||
// prevent recursive requests only for private upstreams.
|
||||
//
|
||||
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
|
||||
recDetector *recursionDetector
|
||||
@@ -153,8 +173,8 @@ type Server struct {
|
||||
// during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
|
||||
// DNS proxy instance for internal usage
|
||||
// We don't Start() it and so no listen port is required.
|
||||
// internalProxy resolves internal requests from the application itself. It
|
||||
// isn't started and so no listen ports are required.
|
||||
internalProxy *proxy.Proxy
|
||||
|
||||
// isRunning is true if the DNS server is running.
|
||||
@@ -185,6 +205,7 @@ type DNSCreateParams struct {
|
||||
DHCPServer DHCP
|
||||
PrivateNets netutil.SubnetSet
|
||||
Anonymizer *aghnet.IPMut
|
||||
EtcHosts *aghnet.HostsContainer
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
@@ -217,8 +238,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
if p.Anonymizer == nil {
|
||||
p.Anonymizer = aghnet.NewIPMut(nil)
|
||||
}
|
||||
|
||||
s = &Server{
|
||||
dnsFilter: p.DNSFilter,
|
||||
dhcpServer: p.DHCPServer,
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
privateNets: p.PrivateNets,
|
||||
@@ -230,6 +253,12 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
MaxCount: defaultClientIDCacheCount,
|
||||
}),
|
||||
anonymizer: p.Anonymizer,
|
||||
conf: ServerConfig{
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
}
|
||||
if p.EtcHosts != nil {
|
||||
s.etcHosts = p.EtcHosts
|
||||
}
|
||||
|
||||
s.sysResolvers, err = sysresolv.NewSystemResolvers(nil, defaultPlainDNSPort)
|
||||
@@ -237,8 +266,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
return nil, fmt.Errorf("initializing system resolvers: %w", err)
|
||||
}
|
||||
|
||||
s.dhcpServer = p.DHCPServer
|
||||
|
||||
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
|
||||
// Use plain DNS on MIPS, encryption is too slow
|
||||
defaultDNS = defaultBootstrap
|
||||
@@ -274,7 +301,7 @@ func (s *Server) WriteDiskConfig(c *Config) {
|
||||
|
||||
sc := s.conf.Config
|
||||
*c = sc
|
||||
c.RatelimitWhitelist = stringutil.CloneSlice(sc.RatelimitWhitelist)
|
||||
c.RatelimitWhitelist = slices.Clone(sc.RatelimitWhitelist)
|
||||
c.BootstrapDNS = stringutil.CloneSlice(sc.BootstrapDNS)
|
||||
c.FallbackDNS = stringutil.CloneSlice(sc.FallbackDNS)
|
||||
c.AllowedClients = stringutil.CloneSlice(sc.AllowedClients)
|
||||
@@ -305,15 +332,14 @@ func (s *Server) AddrProcConfig() (c *client.DefaultAddrProcConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve - get IP addresses by host name from an upstream server.
|
||||
// No request/response filtering is performed.
|
||||
// Query log and Stats are not updated.
|
||||
// This method may be called before Start().
|
||||
func (s *Server) Resolve(host string) ([]net.IPAddr, error) {
|
||||
// Resolve gets IP addresses by host name from an upstream server. No
|
||||
// request/response filtering is performed. Query log and Stats are not
|
||||
// updated. This method may be called before [Server.Start].
|
||||
func (s *Server) Resolve(ctx context.Context, net, host string) (addr []netip.Addr, err error) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return s.internalProxy.LookupIPAddr(host)
|
||||
return s.internalProxy.LookupNetIP(ctx, net, host)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -421,7 +447,7 @@ func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
|
||||
return "", 0, ErrRDNSNoData
|
||||
}
|
||||
|
||||
// Start starts the DNS server.
|
||||
// Start starts the DNS server. It must only be called after [Server.Prepare].
|
||||
func (s *Server) Start() error {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
@@ -429,48 +455,42 @@ func (s *Server) Start() error {
|
||||
return s.startLocked()
|
||||
}
|
||||
|
||||
// startLocked starts the DNS server without locking. For internal use only.
|
||||
// startLocked starts the DNS server without locking. s.serverLock is expected
|
||||
// to be locked.
|
||||
func (s *Server) startLocked() error {
|
||||
err := s.dnsProxy.Start()
|
||||
if err == nil {
|
||||
s.isRunning = true
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// defaultLocalTimeout is the default timeout for resolving addresses from
|
||||
// locally-served networks. It is assumed that local resolvers should work much
|
||||
// faster than ordinary upstreams.
|
||||
const defaultLocalTimeout = 1 * time.Second
|
||||
|
||||
// setupLocalResolvers initializes the resolvers for local addresses. For
|
||||
// internal use only.
|
||||
func (s *Server) setupLocalResolvers() (err error) {
|
||||
matcher, err := s.conf.ourAddrsMatcher()
|
||||
// setupLocalResolvers initializes the resolvers for local addresses. It
|
||||
// assumes s.serverLock is locked or the Server not running.
|
||||
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
|
||||
set, err := s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
bootstraps := s.conf.BootstrapDNS
|
||||
resolvers := s.conf.LocalPTRResolvers
|
||||
filterConfig := false
|
||||
|
||||
if len(resolvers) == 0 {
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
|
||||
confNeedsFiltering := len(resolvers) > 0
|
||||
if confNeedsFiltering {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
} else {
|
||||
sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has)
|
||||
resolvers = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
resolvers = append(resolvers, r.String())
|
||||
}
|
||||
} else {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
filterConfig = true
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
|
||||
|
||||
uc, err := s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Bootstrap: boot,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
@@ -479,8 +499,9 @@ func (s *Server) setupLocalResolvers() (err error) {
|
||||
return fmt.Errorf("preparing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
if filterConfig {
|
||||
if err = matcher.filterOut(uc); err != nil {
|
||||
if confNeedsFiltering {
|
||||
err = filterOutAddrs(uc, set)
|
||||
if err != nil {
|
||||
return fmt.Errorf("filtering private upstreams: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -491,6 +512,7 @@ func (s *Server) setupLocalResolvers() (err error) {
|
||||
},
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Should we also consider the DNS64 usage?
|
||||
if s.conf.UsePrivateRDNS &&
|
||||
// Only set the upstream config if there are any upstreams. It's safe
|
||||
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
@@ -517,31 +539,19 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
s.initDefaultSettings()
|
||||
|
||||
err = s.prepareIpsetListSettings()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return fmt.Errorf("preparing ipset settings: %w", err)
|
||||
}
|
||||
|
||||
err = s.prepareUpstreamSettings()
|
||||
boot, err := s.prepareInternalDNS()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
var proxyConfig proxy.Config
|
||||
proxyConfig, err = s.createProxyConfig()
|
||||
proxyConfig, err := s.newProxyConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing proxy: %w", err)
|
||||
}
|
||||
|
||||
s.setupDNS64()
|
||||
|
||||
err = s.prepareInternalProxy()
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing internal proxy: %w", err)
|
||||
}
|
||||
|
||||
s.access, err = newAccessCtx(
|
||||
s.conf.AllowedClients,
|
||||
s.conf.DisallowedClients,
|
||||
@@ -554,9 +564,9 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
// Set the proxy here because [setupLocalResolvers] sets its values.
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
|
||||
s.dnsProxy = &proxy.Proxy{Config: *proxyConfig}
|
||||
|
||||
err = s.setupLocalResolvers()
|
||||
err = s.setupLocalResolvers(boot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up resolvers: %w", err)
|
||||
}
|
||||
@@ -575,6 +585,38 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareInternalDNS initializes the internal state of s before initializing
|
||||
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
|
||||
err = s.prepareIpsetListSettings()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing ipset settings: %w", err)
|
||||
}
|
||||
|
||||
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{
|
||||
Timeout: DefaultTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
})
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.prepareUpstreamSettings(s.bootstrap)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return s.bootstrap, err
|
||||
}
|
||||
|
||||
err = s.prepareInternalProxy()
|
||||
if err != nil {
|
||||
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err)
|
||||
}
|
||||
|
||||
return s.bootstrap, nil
|
||||
}
|
||||
|
||||
// setupFallbackDNS initializes the fallback DNS servers.
|
||||
func (s *Server) setupFallbackDNS() (err error) {
|
||||
fallbacks := s.conf.FallbackDNS
|
||||
@@ -598,7 +640,8 @@ func (s *Server) setupFallbackDNS() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupAddrProc initializes the address processor. For internal use only.
|
||||
// setupAddrProc initializes the address processor. It assumes s.serverLock is
|
||||
// locked or the Server not running.
|
||||
func (s *Server) setupAddrProc() {
|
||||
// TODO(a.garipov): This is a crutch for tests; remove.
|
||||
if s.conf.AddrProcConf == nil {
|
||||
@@ -687,7 +730,8 @@ func (s *Server) Stop() error {
|
||||
return s.stopLocked()
|
||||
}
|
||||
|
||||
// stopLocked stops the DNS server without locking. For internal use only.
|
||||
// stopLocked stops the DNS server without locking. s.serverLock is expected to
|
||||
// be locked.
|
||||
func (s *Server) stopLocked() (err error) {
|
||||
// TODO(e.burkov, a.garipov): Return critical errors, not just log them.
|
||||
// This will require filtering all the non-critical errors in
|
||||
@@ -700,18 +744,11 @@ func (s *Server) stopLocked() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if upsConf := s.internalProxy.UpstreamConfig; upsConf != nil {
|
||||
err = upsConf.Close()
|
||||
if err != nil {
|
||||
log.Error("dnsforward: closing internal resolvers: %s", err)
|
||||
}
|
||||
}
|
||||
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
|
||||
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
|
||||
|
||||
if upsConf := s.localResolvers.UpstreamConfig; upsConf != nil {
|
||||
err = upsConf.Close()
|
||||
if err != nil {
|
||||
log.Error("dnsforward: closing local resolvers: %s", err)
|
||||
}
|
||||
for _, b := range s.bootResolvers {
|
||||
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
|
||||
}
|
||||
|
||||
s.isRunning = false
|
||||
@@ -719,6 +756,18 @@ func (s *Server) stopLocked() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// logCloserErr logs the error returned by c, if any.
|
||||
func logCloserErr(c io.Closer, format string, args ...any) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.Close()
|
||||
if err != nil {
|
||||
log.Error(format, append(args, err)...)
|
||||
}
|
||||
}
|
||||
|
||||
// IsRunning returns true if the DNS server is running.
|
||||
func (s *Server) IsRunning() bool {
|
||||
s.serverLock.RLock()
|
||||
|
||||
@@ -54,13 +54,10 @@ const (
|
||||
testMessagesCount = 10
|
||||
)
|
||||
|
||||
// testClientAddr is the common net.Addr for tests.
|
||||
// testClientAddrPort is the common net.Addr for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
var testClientAddr net.Addr = &net.TCPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
Port: 12345,
|
||||
}
|
||||
var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345")
|
||||
|
||||
func startDeferStop(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
@@ -182,6 +179,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
|
||||
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
||||
@@ -309,6 +307,7 @@ func TestServer(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
@@ -347,6 +346,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
|
||||
@@ -381,6 +381,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
|
||||
},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
s, err := NewServer(DNSCreateParams{})
|
||||
@@ -402,6 +403,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
@@ -479,6 +481,7 @@ func TestServerRace(t *testing.T) {
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
@@ -532,6 +535,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
startDeferStop(t, s)
|
||||
@@ -594,6 +598,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
startDeferStop(t, s)
|
||||
|
||||
@@ -622,6 +627,7 @@ func TestBlockedRequest(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
@@ -644,45 +650,71 @@ func TestBlockedRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServerCustomClientUpstream(t *testing.T) {
|
||||
const defaultCacheSize = 1024 * 1024
|
||||
|
||||
var upsCalledCounter uint32
|
||||
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
CacheSize: defaultCacheSize,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
})
|
||||
|
||||
return &proxy.UpstreamConfig{
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
atomic.AddUint32(&upsCalledCounter, 1)
|
||||
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
})
|
||||
|
||||
customUpsConf := proxy.NewCustomUpstreamConfig(
|
||||
&proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{ups},
|
||||
}, nil
|
||||
},
|
||||
true,
|
||||
defaultCacheSize,
|
||||
forwardConf.EDNSClientSubnet.Enabled,
|
||||
)
|
||||
|
||||
s.conf.ClientsContainer = &aghtest.ClientsContainer{
|
||||
OnUpstreamConfigByID: func(
|
||||
_ string,
|
||||
_ upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
||||
return customUpsConf, nil
|
||||
},
|
||||
}
|
||||
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
|
||||
// Send test request.
|
||||
req := createTestMessage("host.")
|
||||
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
reply, err := dns.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, reply.Answer)
|
||||
require.Len(t, reply.Answer, 1)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
require.NotEmpty(t, reply.Answer)
|
||||
|
||||
require.Len(t, reply.Answer, 1)
|
||||
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&upsCalledCounter))
|
||||
|
||||
_, err = dns.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&upsCalledCounter))
|
||||
}
|
||||
|
||||
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
||||
@@ -708,6 +740,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
testUpstm := &aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
@@ -740,6 +773,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
@@ -814,6 +848,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
@@ -858,6 +893,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
@@ -923,6 +959,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
// Invalid BlockingIPv4.
|
||||
@@ -974,6 +1011,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
@@ -1024,6 +1062,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
startDeferStop(t, s)
|
||||
@@ -1082,6 +1121,7 @@ func TestRewrite(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}))
|
||||
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
|
||||
@@ -40,6 +40,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
|
||||
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
@@ -28,8 +27,7 @@ func (s *Server) beforeRequestHandler(
|
||||
return false, fmt.Errorf("getting clientid: %w", err)
|
||||
}
|
||||
|
||||
addrPort := netutil.NetAddrToAddrPort(pctx.Addr)
|
||||
blocked, _ := s.IsBlockedClient(addrPort.Addr(), clientID)
|
||||
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
|
||||
if blocked {
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
@@ -60,8 +58,7 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
|
||||
setts = s.dnsFilter.Settings()
|
||||
setts.ProtectionEnabled = dctx.protectionEnabled
|
||||
if s.conf.FilterHandler != nil {
|
||||
addrPort := netutil.NetAddrToAddrPort(dctx.proxyCtx.Addr)
|
||||
s.conf.FilterHandler(addrPort.Addr(), dctx.clientID, setts)
|
||||
s.conf.FilterHandler(dctx.proxyCtx.Addr.Addr(), dctx.clientID, setts)
|
||||
}
|
||||
|
||||
return setts
|
||||
|
||||
@@ -35,6 +35,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
filters := []filtering.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
@@ -186,7 +187,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
|
||||
Addr: testClientAddrPort,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -325,7 +326,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Res: resp,
|
||||
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
|
||||
Addr: testClientAddrPort,
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
|
||||
@@ -6,20 +6,15 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
@@ -45,8 +40,19 @@ type jsonDNSConfig struct {
|
||||
// ProtectionEnabled defines if protection is enabled.
|
||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||
|
||||
// RateLimit is the number of requests per second allowed per client.
|
||||
RateLimit *uint32 `json:"ratelimit"`
|
||||
// Ratelimit is the number of requests per second allowed per client.
|
||||
Ratelimit *uint32 `json:"ratelimit"`
|
||||
|
||||
// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
|
||||
// rate limiting requests.
|
||||
RatelimitSubnetLenIPv4 *int `json:"ratelimit_subnet_len_ipv4"`
|
||||
|
||||
// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
|
||||
// rate limiting requests.
|
||||
RatelimitSubnetLenIPv6 *int `json:"ratelimit_subnet_len_ipv6"`
|
||||
|
||||
// RatelimitWhitelist is a list of IP addresses excluded from rate limiting.
|
||||
RatelimitWhitelist *[]netip.Addr `json:"ratelimit_whitelist"`
|
||||
|
||||
// BlockingMode defines the way blocked responses are constructed.
|
||||
BlockingMode *filtering.BlockingMode `json:"blocking_mode"`
|
||||
@@ -121,6 +127,9 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode()
|
||||
blockedResponseTTL := s.dnsFilter.BlockedResponseTTL()
|
||||
ratelimit := s.conf.Ratelimit
|
||||
ratelimitSubnetLenIPv4 := s.conf.RatelimitSubnetLenIPv4
|
||||
ratelimitSubnetLenIPv6 := s.conf.RatelimitSubnetLenIPv6
|
||||
ratelimitWhitelist := append([]netip.Addr{}, s.conf.RatelimitWhitelist...)
|
||||
|
||||
customIP := s.conf.EDNSClientSubnet.CustomIP
|
||||
enableEDNSClientSubnet := s.conf.EDNSClientSubnet.Enabled
|
||||
@@ -157,7 +166,10 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
BlockingMode: &blockingMode,
|
||||
BlockingIPv4: blockingIPv4,
|
||||
BlockingIPv6: blockingIPv6,
|
||||
RateLimit: &ratelimit,
|
||||
Ratelimit: &ratelimit,
|
||||
RatelimitSubnetLenIPv4: &ratelimitSubnetLenIPv4,
|
||||
RatelimitSubnetLenIPv6: &ratelimitSubnetLenIPv6,
|
||||
RatelimitWhitelist: &ratelimitWhitelist,
|
||||
EDNSCSCustomIP: customIP,
|
||||
EDNSCSEnabled: &enableEDNSClientSubnet,
|
||||
EDNSCSUseCustom: &useCustom,
|
||||
@@ -180,13 +192,13 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
// defaultLocalPTRUpstreams returns the list of default local PTR resolvers
|
||||
// filtered of AdGuard Home's own DNS server addresses. It may appear empty.
|
||||
func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
|
||||
matcher, err := s.conf.ourAddrsMatcher()
|
||||
matcher, err := s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher.Has)
|
||||
ups = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
ups = append(ups, r.String())
|
||||
@@ -201,6 +213,7 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// checkBlockingMode returns an error if blocking mode is invalid.
|
||||
func (req *jsonDNSConfig) checkBlockingMode() (err error) {
|
||||
if req.BlockingMode == nil {
|
||||
return nil
|
||||
@@ -209,12 +222,21 @@ func (req *jsonDNSConfig) checkBlockingMode() (err error) {
|
||||
return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6)
|
||||
}
|
||||
|
||||
func (req *jsonDNSConfig) checkUpstreamsMode() bool {
|
||||
valid := []string{"", "fastest_addr", "parallel"}
|
||||
// checkUpstreamsMode returns an error if the upstream mode is invalid.
|
||||
func (req *jsonDNSConfig) checkUpstreamsMode() (err error) {
|
||||
if req.UpstreamMode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode)
|
||||
mode := *req.UpstreamMode
|
||||
if ok := slices.Contains([]string{"", "fastest_addr", "parallel"}, mode); !ok {
|
||||
return fmt.Errorf("upstream_mode: incorrect value %q", mode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBootstrap returns an error if any bootstrap address is invalid.
|
||||
func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return nil
|
||||
@@ -229,6 +251,7 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
}
|
||||
|
||||
if _, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -244,67 +267,136 @@ func (req *jsonDNSConfig) checkFallbacks() (err error) {
|
||||
|
||||
err = ValidateUpstreams(*req.Fallbacks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating fallback servers: %w", err)
|
||||
return fmt.Errorf("fallback servers: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
//
|
||||
// TODO(s.chzhen): Parse, don't validate.
|
||||
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
|
||||
|
||||
err = req.validateUpstreamDNSServers(privateNets)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkRatelimitSubnetMaskLen()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkBlockingMode()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkUpstreamsMode()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkCacheTTL()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream servers: %w", err)
|
||||
return fmt.Errorf("upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating private upstream servers: %w", err)
|
||||
return fmt.Errorf("private upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = req.checkBootstrap()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkFallbacks()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkBlockingMode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case !req.checkUpstreamsMode():
|
||||
return errors.Error("upstream_mode: incorrect value")
|
||||
case !req.checkCacheTTL():
|
||||
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *jsonDNSConfig) checkCacheTTL() bool {
|
||||
// checkCacheTTL returns an error if the configuration of the cache TTL is
|
||||
// invalid.
|
||||
func (req *jsonDNSConfig) checkCacheTTL() (err error) {
|
||||
if req.CacheMinTTL == nil && req.CacheMaxTTL == nil {
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
var min, max uint32
|
||||
var minTTL, maxTTL uint32
|
||||
if req.CacheMinTTL != nil {
|
||||
min = *req.CacheMinTTL
|
||||
minTTL = *req.CacheMinTTL
|
||||
}
|
||||
if req.CacheMaxTTL != nil {
|
||||
max = *req.CacheMaxTTL
|
||||
maxTTL = *req.CacheMaxTTL
|
||||
}
|
||||
|
||||
return min <= max
|
||||
if minTTL <= maxTTL {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
|
||||
}
|
||||
|
||||
// checkRatelimitSubnetMaskLen returns an error if the length of the subnet mask
|
||||
// for IPv4 or IPv6 addresses is invalid.
|
||||
func (req *jsonDNSConfig) checkRatelimitSubnetMaskLen() (err error) {
|
||||
err = checkInclusion(req.RatelimitSubnetLenIPv4, 0, netutil.IPv4BitLen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ratelimit_subnet_len_ipv4 is invalid: %w", err)
|
||||
}
|
||||
|
||||
err = checkInclusion(req.RatelimitSubnetLenIPv6, 0, netutil.IPv6BitLen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ratelimit_subnet_len_ipv6 is invalid: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkInclusion returns an error if a ptr is not nil and points to value,
|
||||
// that not in the inclusive range between minN and maxN.
|
||||
func checkInclusion(ptr *int, minN, maxN int) (err error) {
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
n := *ptr
|
||||
switch {
|
||||
case n < minN:
|
||||
return fmt.Errorf("value %d less than min %d", n, minN)
|
||||
case n > maxN:
|
||||
return fmt.Errorf("value %d greater than max %d", n, maxN)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
|
||||
@@ -401,6 +493,9 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic),
|
||||
setIfNotNil(&s.conf.AddrProcConf.UseRDNS, dc.ResolveClients),
|
||||
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS),
|
||||
setIfNotNil(&s.conf.RatelimitSubnetLenIPv4, dc.RatelimitSubnetLenIPv4),
|
||||
setIfNotNil(&s.conf.RatelimitSubnetLenIPv6, dc.RatelimitSubnetLenIPv6),
|
||||
setIfNotNil(&s.conf.RatelimitWhitelist, dc.RatelimitWhitelist),
|
||||
} {
|
||||
shouldRestart = shouldRestart || hasSet
|
||||
if shouldRestart {
|
||||
@@ -408,8 +503,8 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
}
|
||||
}
|
||||
|
||||
if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit {
|
||||
s.conf.Ratelimit = *dc.RateLimit
|
||||
if dc.Ratelimit != nil && s.conf.Ratelimit != *dc.Ratelimit {
|
||||
s.conf.Ratelimit = *dc.Ratelimit
|
||||
shouldRestart = true
|
||||
}
|
||||
|
||||
@@ -424,374 +519,11 @@ type upstreamJSON struct {
|
||||
PrivateUpstreams []string `json:"private_upstream"`
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
||||
// slice already so that this function may be considered useless.
|
||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
||||
// No need to validate comments and empty lines.
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
// Consider this case valid since it means the default server should be
|
||||
// used.
|
||||
return nil, nil
|
||||
// closeBoots closes all the provided bootstrap servers and logs errors if any.
|
||||
func closeBoots(boots []*upstream.UpstreamResolver) {
|
||||
for _, c := range boots {
|
||||
logCloserErr(c, "dnsforward: closing bootstrap %s: %s", c.Address())
|
||||
}
|
||||
|
||||
err = validateUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// validateUpstreamConfig validates each upstream from the upstream
|
||||
// configuration and returns an error if any upstream is invalid.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func validateUpstreamConfig(conf []string) (err error) {
|
||||
for _, u := range conf {
|
||||
var ups []string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
for _, addr := range ups {
|
||||
_, err = validateUpstream(addr, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func ValidateUpstreams(upstreams []string) (err error) {
|
||||
_, err = newUpstreamConfig(upstreams)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating config: %w", err)
|
||||
}
|
||||
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||
slices.Sort(keys)
|
||||
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet netip.Prefix
|
||||
subnet, err = extractARPASubnet(domain)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !privateNets.Contains(subnet.Addr().AsSlice()) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||
}
|
||||
|
||||
var protocols = []string{
|
||||
"h3://",
|
||||
"https://",
|
||||
"quic://",
|
||||
"sdns://",
|
||||
"tcp://",
|
||||
"tls://",
|
||||
"udp://",
|
||||
}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = u == "#" && domains != nil; useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
//
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
for _, proto := range protocols {
|
||||
if strings.HasPrefix(u, proto) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if proto, _, ok := strings.Cut(u, "://"); ok {
|
||||
return false, fmt.Errorf("bad protocol %q", proto)
|
||||
}
|
||||
|
||||
// Check if upstream is either an IP or IP with port.
|
||||
if _, err = netip.ParseAddr(u); err == nil {
|
||||
return false, nil
|
||||
} else if _, err = netip.ParseAddrPort(u); err == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstreams and the specified domains. domains
|
||||
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func separateUpstream(upstreamStr string) (upstreams, domains []string, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return []string{upstreamStr}, nil, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||
|
||||
parts := strings.Split(upstreamStr[2:], "/]")
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return nil, nil, errors.Error("missing separator")
|
||||
default:
|
||||
return nil, nil, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
for i, host := range strings.Split(parts[0], "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
domains = append(domains, host)
|
||||
}
|
||||
|
||||
return strings.Fields(parts[1]), domains, nil
|
||||
}
|
||||
|
||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||
// properly.
|
||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||
|
||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// testTLD is the special-use fully-qualified domain name for testing the
|
||||
// DNS server reachability.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
|
||||
const testTLD = "test."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: testTLD,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
var reply *dns.Msg
|
||||
reply, err = u.Exchange(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
} else if len(reply.Answer) != 0 {
|
||||
return errors.Error("wrong response")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
||||
// addresses exchanges correctly.
|
||||
//
|
||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||
// address resolution.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
|
||||
const inAddrArpaTLD = "in-addr.arpa."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: inAddrArpaTLD,
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
if _, err = u.Exchange(req); err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
//
|
||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||
// warnings (non-critical errors) is desired.
|
||||
type domainSpecificTestError struct {
|
||||
error
|
||||
}
|
||||
|
||||
// Error implements the [error] interface for domainSpecificTestError.
|
||||
func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// checkDNS parses line, creates DNS upstreams using opts, and checks if the
|
||||
// upstreams are exchanging correctly. It saves the result into a sync.Map
|
||||
// where key is an upstream address and value is "OK", if the upstream
|
||||
// exchanges correctly, or text of the error. It is intended to be used as a
|
||||
// goroutine.
|
||||
//
|
||||
// TODO(s.chzhen): Separate to a different structure/file.
|
||||
func (s *Server) checkDNS(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
wg *sync.WaitGroup,
|
||||
m *sync.Map,
|
||||
) {
|
||||
defer wg.Done()
|
||||
defer log.OnPanic("dnsforward: checking upstreams")
|
||||
|
||||
upstreams, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
||||
m.Store(line, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
specific := len(domains) > 0
|
||||
|
||||
for _, upstreamAddr := range upstreams {
|
||||
var useDefault bool
|
||||
useDefault, err = validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
||||
m.Store(upstreamAddr, err.Error())
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if useDefault {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check)
|
||||
if err != nil {
|
||||
m.Store(upstreamAddr, err.Error())
|
||||
} else {
|
||||
m.Store(upstreamAddr, "OK")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkUpstreamAddr creates the DNS upstream using opts and information from
|
||||
// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It
|
||||
// returns an error if addr is not valid DNS upstream address or the upstream
|
||||
// is not exchanging correctly.
|
||||
func (s *Server) checkUpstreamAddr(
|
||||
addr string,
|
||||
specific bool,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
) (err error) {
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
opts = &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
Timeout: opts.Timeout,
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr))
|
||||
for _, rec := range recs {
|
||||
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
|
||||
}
|
||||
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
||||
}
|
||||
|
||||
u, err := upstream.AddressToUpstream(addr, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating upstream for %q: %w", addr, err)
|
||||
}
|
||||
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
return check(u)
|
||||
}
|
||||
|
||||
// handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns
|
||||
@@ -808,46 +540,27 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
||||
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: req.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
}
|
||||
if len(opts.Bootstrap) == 0 {
|
||||
opts.Bootstrap = defaultBootstrap
|
||||
|
||||
var boots []*upstream.UpstreamResolver
|
||||
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
defer closeBoots(boots)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
m := &sync.Map{}
|
||||
cv := newUpstreamConfigValidator(req.Upstreams, req.FallbackDNS, req.PrivateUpstreams, opts)
|
||||
cv.check()
|
||||
cv.close()
|
||||
|
||||
wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go s.checkDNS(ups, opts, checkPrivateUpstreamExc, wg, m)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
result := map[string]string{}
|
||||
m.Range(func(k, v any) bool {
|
||||
// TODO(e.burkov): The upstreams used for both common and private
|
||||
// resolving should be reported separately.
|
||||
ups := k.(string)
|
||||
status := v.(string)
|
||||
|
||||
result[ups] = status
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
||||
aghhttp.WriteJSONResponseOK(w, r, cv.status())
|
||||
}
|
||||
|
||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||
|
||||
@@ -72,11 +72,14 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{},
|
||||
TCPListenAddrs: []*net.TCPAddr{},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
FallbackDNS: []string{"9.9.9.10"},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
FallbackDNS: []string{"9.9.9.10"},
|
||||
RatelimitSubnetLenIPv4: 24,
|
||||
RatelimitSubnetLenIPv6: 56,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
@@ -150,10 +153,13 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{},
|
||||
TCPListenAddrs: []*net.TCPAddr{},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
RatelimitSubnetLenIPv4: 24,
|
||||
RatelimitSubnetLenIPv6: 56,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
@@ -179,11 +185,18 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
name: "blocking_mode_good",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "blocking_mode_bad",
|
||||
wantSet: "blocking_ipv4 must be valid ipv4 on custom_ip blocking_mode",
|
||||
name: "blocking_mode_bad",
|
||||
wantSet: "validating dns config: " +
|
||||
"blocking_ipv4 must be valid ipv4 on custom_ip blocking_mode",
|
||||
}, {
|
||||
name: "ratelimit",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "ratelimit_subnet_len",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "ratelimit_whitelist_not_ip",
|
||||
wantSet: `decoding request: ParseAddr("not.ip"): unexpected character (at "not.ip")`,
|
||||
}, {
|
||||
name: "edns_cs_enabled",
|
||||
wantSet: "",
|
||||
@@ -206,24 +219,26 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
name: "upstream_mode_fastest_addr",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating upstream servers: validating upstream "!!!": not an ip:port`,
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating dns config: ` +
|
||||
`upstream servers: validating upstream "!!!": not an ip:port`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
wantSet: `checking bootstrap a: invalid address: bootstrap a:53: ` +
|
||||
wantSet: `validating dns config: checking bootstrap a: invalid address: not a bootstrap: ` +
|
||||
`ParseAddr("a"): unable to parse IP`,
|
||||
}, {
|
||||
name: "cache_bad_ttl",
|
||||
wantSet: `cache_ttl_min must be less or equal than cache_ttl_max`,
|
||||
wantSet: `validating dns config: cache_ttl_min must be less or equal than cache_ttl_max`,
|
||||
}, {
|
||||
name: "upstream_mode_bad",
|
||||
wantSet: `upstream_mode: incorrect value`,
|
||||
wantSet: `validating dns config: upstream_mode: incorrect value "somethingelse"`,
|
||||
}, {
|
||||
name: "local_ptr_upstreams_good",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "local_ptr_upstreams_bad",
|
||||
wantSet: `validating private upstream servers: checking domain-specific upstreams: ` +
|
||||
wantSet: `validating dns config: ` +
|
||||
`private upstream servers: checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "non.arpa.": not a reversed ip network`,
|
||||
}, {
|
||||
name: "local_ptr_upstreams_null",
|
||||
@@ -347,7 +362,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
set: []string{"123.3.7m"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": ` +
|
||||
wantErr: `splitting upstream line "[/host.com]tls://dns.adguard.com": ` +
|
||||
`missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
}, {
|
||||
@@ -373,7 +388,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "bad_domain",
|
||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
wantErr: `splitting upstream line "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad top-level domain name label "!": ` +
|
||||
`bad top-level domain name label rune '!'`,
|
||||
set: []string{"[/!/]8.8.8.8"},
|
||||
@@ -461,25 +476,15 @@ func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (r
|
||||
}
|
||||
|
||||
func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
hdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
||||
require.NoError(testutil.PanicT{}, err)
|
||||
})
|
||||
badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) {
|
||||
err := w.WriteMsg(new(dns.Msg))
|
||||
require.NoError(testutil.PanicT{}, err)
|
||||
})
|
||||
|
||||
goodUps := (&url.URL{
|
||||
ups := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: newLocalUpstreamListener(t, 0, goodHandler).String(),
|
||||
Host: newLocalUpstreamListener(t, 0, hdlr).String(),
|
||||
}).String()
|
||||
badUps := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
||||
}).String()
|
||||
|
||||
goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ")
|
||||
|
||||
const (
|
||||
upsTimeout = 100 * time.Millisecond
|
||||
@@ -488,7 +493,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
upstreamHost = "custom.localhost"
|
||||
)
|
||||
|
||||
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
|
||||
hostsListener := newLocalUpstreamListener(t, 0, hdlr)
|
||||
hostsUps := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
|
||||
@@ -519,7 +524,9 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
srv.etcHosts = hc
|
||||
startDeferStop(t, srv)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -527,43 +534,6 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
wantResp map[string]any
|
||||
name string
|
||||
}{{
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{goodUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
},
|
||||
name: "success",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "broken",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{goodUps, badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "both",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]" + badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
name: "domain_specific_error",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{hostsUps},
|
||||
},
|
||||
@@ -573,63 +543,12 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
name: "etc_hosts",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"fallback_dns": []string{goodUps},
|
||||
"upstream_dns": []string{ups, "#this.is.comment"},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
ups: "OK",
|
||||
},
|
||||
name: "fallback_success",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"fallback_dns": []string{badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "fallback_broken",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"fallback_dns": []string{goodUps, "#this.is.comment"},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
},
|
||||
name: "fallback_comment_mix",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]" + goodUps + " " + badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
name: "multiple_domain_specific_upstreams",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]/]1.2.3.4"},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
"[/domain.example/]/]1.2.3.4": `wrong upstream format: ` +
|
||||
`bad upstream for domain "[/domain.example/]/]1.2.3.4": ` +
|
||||
`duplicated separator`,
|
||||
},
|
||||
name: "bad_specification",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]" + goodAndBadUps},
|
||||
"fallback_dns": []string{"[/domain.example/]" + goodAndBadUps},
|
||||
"private_upstream": []string{"[/domain.example/]" + goodAndBadUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
name: "all_different",
|
||||
name: "comment_mix",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -91,7 +91,7 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
case filtering.BlockingModeREFUSED:
|
||||
return s.makeResponseREFUSED(req)
|
||||
default:
|
||||
log.Error("dns: invalid blocking mode %q", mode)
|
||||
log.Error("dnsforward: invalid blocking mode %q", mode)
|
||||
|
||||
return s.makeResponse(req)
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func (s *Server) makeResponseCustomIP(
|
||||
default:
|
||||
// Generally shouldn't happen, since the types are checked in
|
||||
// genDNSFilterMessage.
|
||||
log.Error("dns: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
|
||||
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
|
||||
|
||||
return s.makeResponse(req)
|
||||
}
|
||||
@@ -207,15 +207,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
var ans []dns.RR
|
||||
switch req.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
ans = append(ans, s.genAnswerA(req, ip))
|
||||
} else {
|
||||
ans = nil
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
ans = s.genAnswersWithIPv4s(req, ips)
|
||||
case dns.TypeAAAA:
|
||||
for _, ip := range ips {
|
||||
if ip.Is6() {
|
||||
@@ -232,6 +224,23 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
return resp
|
||||
}
|
||||
|
||||
// genAnswersWithIPv4s generates DNS A answers provided IPv4 addresses. If any
|
||||
// of the IPs isn't an IPv4 address, genAnswersWithIPv4s logs a warning and
|
||||
// returns nil,
|
||||
func (s *Server) genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) {
|
||||
for _, ip := range ips {
|
||||
if !ip.Is4() {
|
||||
log.Info("dnsforward: warning: ip %s is not ipv4 address", ip)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ans = append(ans, s.genAnswerA(req, ip))
|
||||
}
|
||||
|
||||
return ans
|
||||
}
|
||||
|
||||
// makeResponseNullIP creates a response with 0.0.0.0 for A requests, :: for
|
||||
// AAAA requests, and an empty response for other types.
|
||||
func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||
@@ -253,7 +262,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||
|
||||
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
|
||||
if newAddr == "" {
|
||||
log.Printf("block host is not specified.")
|
||||
log.Info("dnsforward: block host is not specified")
|
||||
|
||||
return s.genServerFailure(request)
|
||||
}
|
||||
@@ -276,14 +285,14 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
|
||||
prx := s.proxy()
|
||||
if prx == nil {
|
||||
log.Debug("dns: %s", srvClosedErr)
|
||||
log.Debug("dnsforward: %s", srvClosedErr)
|
||||
|
||||
return s.genServerFailure(request)
|
||||
}
|
||||
|
||||
err = prx.Resolve(newContext)
|
||||
if err != nil {
|
||||
log.Printf("couldn't look up replacement host %q: %s", newAddr, err)
|
||||
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
|
||||
|
||||
return s.genServerFailure(request)
|
||||
}
|
||||
|
||||
@@ -191,7 +191,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
defer log.Debug("dnsforward: finished processing initial")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
s.processClientIP(pctx.Addr)
|
||||
s.processClientIP(pctx.Addr.Addr())
|
||||
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
@@ -228,9 +228,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// processClientIP sends the client IP address to s.addrProc, if needed.
|
||||
func (s *Server) processClientIP(addr net.Addr) {
|
||||
clientIP := netutil.NetAddrToAddrPort(addr).Addr()
|
||||
if clientIP == (netip.Addr{}) {
|
||||
func (s *Server) processClientIP(addr netip.Addr) {
|
||||
if !addr.IsValid() {
|
||||
log.Info("dnsforward: warning: bad client addr %q", addr)
|
||||
|
||||
return
|
||||
@@ -241,7 +240,7 @@ func (s *Server) processClientIP(addr net.Addr) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
s.addrProc.Process(clientIP)
|
||||
s.addrProc.Process(addr)
|
||||
}
|
||||
|
||||
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
|
||||
@@ -351,12 +350,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
rc = resultCodeSuccess
|
||||
|
||||
var ip net.IP
|
||||
if ip, _ = netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
||||
return rc
|
||||
}
|
||||
|
||||
dctx.isLocalClient = s.privateNets.Contains(ip)
|
||||
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr().AsSlice())
|
||||
|
||||
return rc
|
||||
}
|
||||
@@ -831,14 +825,13 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
|
||||
|
||||
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
|
||||
func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
customUpsByClient := s.conf.GetCustomUpstreamByClient
|
||||
if pctx.Addr == nil || customUpsByClient == nil {
|
||||
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Use the ClientID first, since it has a higher priority.
|
||||
id := stringutil.Coalesce(clientID, ipStringFromAddr(pctx.Addr))
|
||||
upsConf, err := customUpsByClient(id)
|
||||
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String())
|
||||
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
|
||||
|
||||
@@ -847,9 +840,9 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
|
||||
if upsConf != nil {
|
||||
log.Debug("dnsforward: using custom upstreams for client %s", id)
|
||||
}
|
||||
|
||||
pctx.CustomUpstreamConfig = upsConf
|
||||
pctx.CustomUpstreamConfig = upsConf
|
||||
}
|
||||
}
|
||||
|
||||
// Apply filtering logic after we have received response from upstream servers
|
||||
|
||||
@@ -81,6 +81,7 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
@@ -96,14 +97,14 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: createTestMessageWithType(tc.target, tc.qType),
|
||||
Addr: testClientAddr,
|
||||
Addr: testClientAddrPort,
|
||||
RequestID: 1234,
|
||||
},
|
||||
}
|
||||
|
||||
gotRC := s.processInitial(dctx)
|
||||
assert.Equal(t, tc.wantRC, gotRC)
|
||||
assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr)
|
||||
assert.Equal(t, testClientAddrPort.Addr(), gotAddr)
|
||||
|
||||
if tc.wantRCode > 0 {
|
||||
gotResp := dctx.proxyCtx.Res
|
||||
@@ -180,6 +181,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
@@ -199,7 +201,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Res: resp,
|
||||
Addr: testClientAddr,
|
||||
Addr: testClientAddrPort,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -369,6 +371,7 @@ func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled b
|
||||
TLSConfig: TLSConfig{
|
||||
ServerName: ddrTestDomainName,
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -394,33 +397,27 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
cliIP net.IP
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
cliAddr netip.AddrPort
|
||||
}{{
|
||||
want: assert.True,
|
||||
name: "local",
|
||||
cliIP: net.IP{192, 168, 0, 1},
|
||||
want: assert.True,
|
||||
name: "local",
|
||||
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "external",
|
||||
cliIP: net.IP{250, 249, 0, 1},
|
||||
want: assert.False,
|
||||
name: "external",
|
||||
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "invalid",
|
||||
cliIP: net.IP{1, 2, 3, 4, 5},
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "nil",
|
||||
cliIP: nil,
|
||||
want: assert.False,
|
||||
name: "invalid",
|
||||
cliAddr: netip.AddrPort{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
proxyCtx := &proxy.DNSContext{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: tc.cliIP,
|
||||
},
|
||||
Addr: tc.cliAddr,
|
||||
}
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: proxyCtx,
|
||||
@@ -699,6 +696,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, ups)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
|
||||
startDeferStop(t, s)
|
||||
@@ -707,31 +705,31 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
name string
|
||||
want string
|
||||
question net.IP
|
||||
cliIP net.IP
|
||||
cliAddr netip.AddrPort
|
||||
wantLen int
|
||||
}{{
|
||||
name: "from_local_to_external",
|
||||
want: "host1.example.net.",
|
||||
question: net.IP{254, 253, 252, 251},
|
||||
cliIP: net.IP{192, 168, 10, 10},
|
||||
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"),
|
||||
wantLen: 1,
|
||||
}, {
|
||||
name: "from_external_for_local",
|
||||
want: "",
|
||||
question: net.IP{192, 168, 1, 1},
|
||||
cliIP: net.IP{254, 253, 252, 251},
|
||||
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"),
|
||||
wantLen: 0,
|
||||
}, {
|
||||
name: "from_local_for_local",
|
||||
want: "some.local-client.",
|
||||
question: net.IP{192, 168, 1, 1},
|
||||
cliIP: net.IP{192, 168, 1, 2},
|
||||
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"),
|
||||
wantLen: 1,
|
||||
}, {
|
||||
name: "from_external_for_external",
|
||||
want: "host1.example.net.",
|
||||
question: net.IP{254, 253, 252, 251},
|
||||
cliIP: net.IP{254, 253, 252, 255},
|
||||
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"),
|
||||
wantLen: 1,
|
||||
}}
|
||||
|
||||
@@ -743,9 +741,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoTCP,
|
||||
Req: req,
|
||||
Addr: &net.TCPAddr{
|
||||
IP: tc.cliIP,
|
||||
},
|
||||
Addr: tc.cliAddr,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -776,6 +772,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
@@ -789,7 +786,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
var dnsCtx *dnsContext
|
||||
setup := func(use bool) {
|
||||
proxyCtx = &proxy.DNSContext{
|
||||
Addr: testClientAddr,
|
||||
Addr: testClientAddrPort,
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
}
|
||||
dnsCtx = &dnsContext{
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Write Stats data and logs
|
||||
@@ -25,13 +23,12 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
host := aghnet.NormalizeDomain(q.Name)
|
||||
processingTime := time.Since(dctx.startTime)
|
||||
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
ip = slices.Clone(ip)
|
||||
ip := pctx.Addr.Addr().AsSlice()
|
||||
s.anonymizer.Load()(ip)
|
||||
|
||||
log.Debug("dnsforward: client ip for stats and querylog: %s", ip)
|
||||
|
||||
ipStr := ip.String()
|
||||
ipStr := pctx.Addr.Addr().String()
|
||||
ids := []string{ipStr, dctx.clientID}
|
||||
qt, cl := q.Qtype, q.Qclass
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -65,7 +65,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name string
|
||||
domain string
|
||||
proto proxy.Proto
|
||||
addr net.Addr
|
||||
addr netip.AddrPort
|
||||
clientID string
|
||||
wantLogProto querylog.ClientProto
|
||||
wantStatClient string
|
||||
@@ -76,7 +76,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_udp",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -87,7 +87,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_tls_clientid",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "cli42",
|
||||
wantLogProto: querylog.ClientProtoDoT,
|
||||
wantStatClient: "cli42",
|
||||
@@ -98,7 +98,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_tls",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: querylog.ClientProtoDoT,
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -109,7 +109,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_quic",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoQUIC,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: querylog.ClientProtoDoQ,
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -120,7 +120,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_https",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoHTTPS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: querylog.ClientProtoDoH,
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -131,7 +131,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_dnscrypt",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoDNSCrypt,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: querylog.ClientProtoDNSCrypt,
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -142,7 +142,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_udp_filtered",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -153,7 +153,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_udp_sb",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -164,7 +164,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_udp_ss",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -175,7 +175,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_udp_pc",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
addr: testClientAddrPort,
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
@@ -186,10 +186,10 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
name: "success_udp_pc_empty_fqdn",
|
||||
domain: ".",
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 5}, Port: 1234},
|
||||
addr: netip.MustParseAddrPort("4.3.2.1:1234"),
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.5",
|
||||
wantStatClient: "4.3.2.1",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: filtering.FilteredParental,
|
||||
wantStatResult: stats.RParental,
|
||||
|
||||
@@ -19,6 +19,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
|
||||
req := &dns.Msg{
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -53,6 +56,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -89,6 +95,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
|
||||
@@ -22,6 +22,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -60,6 +63,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -99,6 +105,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "refused",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -138,6 +147,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -177,6 +189,98 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 6,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"ratelimit_subnet_len": {
|
||||
"req": {
|
||||
"ratelimit": 12,
|
||||
"ratelimit_subnet_len_ipv4": 32,
|
||||
"ratelimit_subnet_len_ipv6": 128
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 12,
|
||||
"ratelimit_subnet_len_ipv4": 32,
|
||||
"ratelimit_subnet_len_ipv6": 128,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"ratelimit_whitelist_not_ip": {
|
||||
"req": {
|
||||
"ratelimit_whitelist": [
|
||||
"1.2.3.4",
|
||||
"not.ip"
|
||||
]
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -216,6 +320,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -257,6 +364,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -298,6 +408,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -337,6 +450,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -376,6 +492,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -415,6 +534,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -454,6 +576,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -495,6 +620,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -536,6 +664,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -576,6 +707,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -615,6 +749,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -656,6 +793,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -700,6 +840,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -739,6 +882,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -782,6 +928,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -821,6 +970,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
@@ -863,6 +1015,9 @@
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"ratelimit_subnet_len_ipv4": 24,
|
||||
"ratelimit_subnet_len_ipv6": 56,
|
||||
"ratelimit_whitelist": [],
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -18,6 +19,28 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
// errNotDomainSpecific is returned when the upstream should be
|
||||
// domain-specific, but isn't.
|
||||
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
|
||||
|
||||
// errMissingSeparator is returned when the domain-specific part of the
|
||||
// upstream configuration line isn't closed.
|
||||
errMissingSeparator errors.Error = "missing separator"
|
||||
|
||||
// errDupSeparator is returned when the domain-specific part of the upstream
|
||||
// configuration line contains more than one ending separator.
|
||||
errDupSeparator errors.Error = "duplicated separator"
|
||||
|
||||
// errNoDefaultUpstreams is returned when there are no default upstreams
|
||||
// specified in the upstream configuration.
|
||||
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
|
||||
|
||||
// errWrongResponse is returned when the checked upstream replies in an
|
||||
// unexpected way.
|
||||
errWrongResponse errors.Error = "wrong response"
|
||||
)
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
// the configuration itself.
|
||||
func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
@@ -39,7 +62,7 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
}
|
||||
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings() (err error) {
|
||||
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
upstreams, err = s.loadUpstreams()
|
||||
@@ -48,7 +71,7 @@ func (s *Server) prepareUpstreamSettings() (err error) {
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Bootstrap: boot,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
@@ -92,178 +115,9 @@ func (s *Server) prepareUpstreamConfig(
|
||||
uc.Upstreams = defaultUpstreamConfig.Upstreams
|
||||
}
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
err = s.replaceUpstreamsWithHosts(uc, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving upstreams with hosts: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// replaceUpstreamsWithHosts replaces unique upstreams with their resolved
|
||||
// versions based on the system hosts file.
|
||||
//
|
||||
// TODO(e.burkov): This should be performed inside dnsproxy, which should
|
||||
// actually consider /etc/hosts. See TODO on [aghnet.HostsContainer].
|
||||
func (s *Server) replaceUpstreamsWithHosts(
|
||||
upsConf *proxy.UpstreamConfig,
|
||||
opts *upstream.Options,
|
||||
) (err error) {
|
||||
resolved := map[string]*upstream.Options{}
|
||||
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.Upstreams, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams: %w", err)
|
||||
}
|
||||
|
||||
hosts := maps.Keys(upsConf.DomainReservedUpstreams)
|
||||
// TODO(e.burkov): Think of extracting sorted range into an util function.
|
||||
slices.Sort(hosts)
|
||||
for _, host := range hosts {
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.DomainReservedUpstreams[host], opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams reserved for %s: %w", host, err)
|
||||
}
|
||||
}
|
||||
|
||||
hosts = maps.Keys(upsConf.SpecifiedDomainUpstreams)
|
||||
slices.Sort(hosts)
|
||||
for _, host := range hosts {
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.SpecifiedDomainUpstreams[host], opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams specific for %s: %w", host, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveUpstreamsWithHosts resolves the IP addresses of each of the upstreams
|
||||
// and replaces those both in upstreams and resolved. Upstreams that failed to
|
||||
// resolve are placed to resolved as-is. This function only returns error of
|
||||
// upstreams closing.
|
||||
func (s *Server) resolveUpstreamsWithHosts(
|
||||
resolved map[string]*upstream.Options,
|
||||
upstreams []upstream.Upstream,
|
||||
opts *upstream.Options,
|
||||
) (err error) {
|
||||
for i := range upstreams {
|
||||
u := upstreams[i]
|
||||
addr := u.Address()
|
||||
host := extractUpstreamHost(addr)
|
||||
|
||||
withIPs, ok := resolved[host]
|
||||
if !ok {
|
||||
recs := s.dnsFilter.EtcHostsRecords(host)
|
||||
if len(recs) == 0 {
|
||||
resolved[host] = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
withIPs = opts.Clone()
|
||||
withIPs.ServerIPAddrs = make([]net.IP, 0, len(recs))
|
||||
for _, rec := range recs {
|
||||
withIPs.ServerIPAddrs = append(withIPs.ServerIPAddrs, rec.Addr.AsSlice())
|
||||
}
|
||||
|
||||
sortNetIPAddrs(withIPs.ServerIPAddrs, opts.PreferIPv6)
|
||||
resolved[host] = withIPs
|
||||
} else if withIPs == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err = u.Close(); err != nil {
|
||||
return fmt.Errorf("closing upstream %s: %w", addr, err)
|
||||
}
|
||||
|
||||
upstreams[i], err = upstream.AddressToUpstream(addr, withIPs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing upstream %s with resolved %s: %w", addr, host, err)
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: using %s for %s", withIPs.ServerIPAddrs, upstreams[i].Address())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractUpstreamHost returns the hostname of addr without port with an
|
||||
// assumption that any address passed here has already been successfully parsed
|
||||
// by [upstream.AddressToUpstream]. This function essentially mirrors the logic
|
||||
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
|
||||
func extractUpstreamHost(addr string) (host string) {
|
||||
var err error
|
||||
if strings.Contains(addr, "://") {
|
||||
var u *url.URL
|
||||
u, err = url.Parse(addr)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: parsing upstream %s: %s", addr, err)
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
return u.Hostname()
|
||||
}
|
||||
|
||||
// Probably, plain UDP upstream defined by address or address:port.
|
||||
host, err = netutil.SplitHost(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
// sortNetIPAddrs sorts addrs in accordance with the protocol preferences.
|
||||
// Invalid addresses are sorted near the end.
|
||||
//
|
||||
// TODO(e.burkov): This function taken from dnsproxy, which also already
|
||||
// contains a few similar functions. Think of moving to golibs.
|
||||
func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) {
|
||||
l := len(addrs)
|
||||
if l <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (res int) {
|
||||
switch len(addrA) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
switch len(addrB) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
// Go on.
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
|
||||
// Treat IPv6-mapped IPv4 addresses as IPv6 addresses.
|
||||
aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil
|
||||
if aIs4 == bIs4 {
|
||||
return bytes.Compare(addrA, addrB)
|
||||
}
|
||||
|
||||
if aIs4 {
|
||||
if preferIPv6 {
|
||||
return 1
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
if preferIPv6 {
|
||||
return -1
|
||||
}
|
||||
|
||||
return 1
|
||||
})
|
||||
}
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
@@ -295,3 +149,221 @@ func setProxyUpstreamMode(
|
||||
conf.UpstreamMode = proxy.UModeLoadBalance
|
||||
}
|
||||
}
|
||||
|
||||
// createBootstrap returns a bootstrap resolver based on the configuration of s.
|
||||
// boots are the upstream resolvers that should be closed after use. r is the
|
||||
// actual bootstrap resolver, which may include the system hosts.
|
||||
//
|
||||
// TODO(e.burkov): This function currently returns a resolver and a slice of
|
||||
// the upstream resolvers, which are essentially the same. boots are returned
|
||||
// for being able to close them afterwards, but it introduces an implicit
|
||||
// contract that r could only be used before that. Anyway, this code should
|
||||
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
|
||||
// and be used here.
|
||||
func (s *Server) createBootstrap(
|
||||
addrs []string,
|
||||
opts *upstream.Options,
|
||||
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
|
||||
if len(addrs) == 0 {
|
||||
addrs = defaultBootstrap
|
||||
}
|
||||
|
||||
boots, err = aghnet.ParseBootstraps(addrs, opts)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var parallel upstream.ParallelResolver
|
||||
for _, b := range boots {
|
||||
parallel = append(parallel, b)
|
||||
}
|
||||
|
||||
if s.etcHosts != nil {
|
||||
r = upstream.ConsequentResolver{s.etcHosts, parallel}
|
||||
} else {
|
||||
r = parallel
|
||||
}
|
||||
|
||||
return r, boots, nil
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
||||
// slice already so that this function may be considered useless.
|
||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
||||
// No need to validate comments and empty lines.
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
// Consider this case valid since it means the default server should be
|
||||
// used.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err = validateUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: net.DefaultResolver,
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errNoDefaultUpstreams
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// validateUpstreamConfig validates each upstream from the upstream
|
||||
// configuration and returns an error if any upstream is invalid.
|
||||
//
|
||||
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
||||
func validateUpstreamConfig(conf []string) (err error) {
|
||||
for _, u := range conf {
|
||||
var ups []string
|
||||
var isSpecific bool
|
||||
ups, isSpecific, err = splitUpstreamLine(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
for _, addr := range ups {
|
||||
_, err = validateUpstream(addr, isSpecific)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
||||
func ValidateUpstreams(upstreams []string) (err error) {
|
||||
_, err = newUpstreamConfig(upstreams)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating config: %w", err)
|
||||
}
|
||||
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||
slices.Sort(keys)
|
||||
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet netip.Prefix
|
||||
subnet, err = extractARPASubnet(domain)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !privateNets.Contains(subnet.Addr().AsSlice()) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||
}
|
||||
|
||||
// protocols are the supported URL schemes for upstreams.
|
||||
var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = u == "#" && isSpecific; useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
//
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
if proto, _, ok := strings.Cut(u, "://"); ok {
|
||||
if !slices.Contains(protocols, proto) {
|
||||
return false, fmt.Errorf("bad protocol %q", proto)
|
||||
}
|
||||
} else if _, err = netip.ParseAddr(u); err == nil {
|
||||
return false, nil
|
||||
} else if _, err = netip.ParseAddrPort(u); err == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
// splitUpstreamLine returns the upstreams and the specified domains. domains
|
||||
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return []string{upstreamStr}, false, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }()
|
||||
|
||||
doms, ups, found := strings.Cut(upstreamStr[2:], "/]")
|
||||
if !found {
|
||||
return nil, false, errMissingSeparator
|
||||
} else if strings.Contains(ups, "/]") {
|
||||
return nil, false, errDupSeparator
|
||||
}
|
||||
|
||||
for i, host := range strings.Split(doms, "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
isSpecific = true
|
||||
}
|
||||
|
||||
return strings.Fields(ups), isSpecific, nil
|
||||
}
|
||||
|
||||
223
internal/dnsforward/upstreams_internal_test.go
Normal file
223
internal/dnsforward/upstreams_internal_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpstreamConfigValidator(t *testing.T) {
|
||||
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
||||
require.NoError(testutil.PanicT{}, err)
|
||||
})
|
||||
badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) {
|
||||
err := w.WriteMsg(new(dns.Msg))
|
||||
require.NoError(testutil.PanicT{}, err)
|
||||
})
|
||||
|
||||
goodUps := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: newLocalUpstreamListener(t, 0, goodHandler).String(),
|
||||
}).String()
|
||||
badUps := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
||||
}).String()
|
||||
|
||||
goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ")
|
||||
|
||||
// upsTimeout restricts the checking process to prevent the test from
|
||||
// hanging.
|
||||
const upsTimeout = 100 * time.Millisecond
|
||||
|
||||
testCases := []struct {
|
||||
want map[string]string
|
||||
name string
|
||||
general []string
|
||||
fallback []string
|
||||
private []string
|
||||
}{{
|
||||
name: "success",
|
||||
general: []string{goodUps},
|
||||
want: map[string]string{
|
||||
goodUps: "OK",
|
||||
},
|
||||
}, {
|
||||
name: "broken",
|
||||
general: []string{badUps},
|
||||
want: map[string]string{
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
}, {
|
||||
name: "both",
|
||||
general: []string{goodUps, badUps, goodUps},
|
||||
want: map[string]string{
|
||||
goodUps: "OK",
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
}, {
|
||||
name: "domain_specific_error",
|
||||
general: []string{"[/domain.example/]" + badUps},
|
||||
want: map[string]string{
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
}, {
|
||||
name: "fallback_success",
|
||||
fallback: []string{goodUps},
|
||||
want: map[string]string{
|
||||
goodUps: "OK",
|
||||
},
|
||||
}, {
|
||||
name: "fallback_broken",
|
||||
fallback: []string{badUps},
|
||||
want: map[string]string{
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
}, {
|
||||
name: "multiple_domain_specific_upstreams",
|
||||
general: []string{"[/domain.example/]" + goodAndBadUps},
|
||||
want: map[string]string{
|
||||
goodUps: "OK",
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
}, {
|
||||
name: "bad_specification",
|
||||
general: []string{"[/domain.example/]/]1.2.3.4"},
|
||||
want: map[string]string{
|
||||
"[/domain.example/]/]1.2.3.4": `splitting upstream line ` +
|
||||
`"[/domain.example/]/]1.2.3.4": duplicated separator`,
|
||||
},
|
||||
}, {
|
||||
name: "all_different",
|
||||
general: []string{"[/domain.example/]" + goodAndBadUps},
|
||||
fallback: []string{"[/domain.example/]" + goodAndBadUps},
|
||||
private: []string{"[/domain.example/]" + goodAndBadUps},
|
||||
want: map[string]string{
|
||||
goodUps: "OK",
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
}, {
|
||||
name: "bad_specific_domains",
|
||||
general: []string{"[/example/]/]" + goodUps},
|
||||
fallback: []string{"[/example/" + goodUps},
|
||||
private: []string{"[/example//bad.123/]" + goodUps},
|
||||
want: map[string]string{
|
||||
`[/example/]/]` + goodUps: `splitting upstream line ` +
|
||||
`"[/example/]/]` + goodUps + `": duplicated separator`,
|
||||
`[/example/` + goodUps: `splitting upstream line ` +
|
||||
`"[/example/` + goodUps + `": missing separator`,
|
||||
`[/example//bad.123/]` + goodUps: `splitting upstream line ` +
|
||||
`"[/example//bad.123/]` + goodUps + `": domain at index 2: ` +
|
||||
`bad domain name "bad.123": ` +
|
||||
`bad top-level domain name label "123": all octets are numeric`,
|
||||
},
|
||||
}, {
|
||||
name: "non-specific_default",
|
||||
general: []string{
|
||||
"#",
|
||||
"[/example/]#",
|
||||
},
|
||||
want: map[string]string{
|
||||
"#": "not a domain-specific upstream",
|
||||
},
|
||||
}, {
|
||||
name: "bad_proto",
|
||||
general: []string{
|
||||
"bad://1.2.3.4",
|
||||
},
|
||||
want: map[string]string{
|
||||
"bad://1.2.3.4": `bad protocol "bad"`,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cv := newUpstreamConfigValidator(tc.general, tc.fallback, tc.private, &upstream.Options{
|
||||
Timeout: upsTimeout,
|
||||
Bootstrap: net.DefaultResolver,
|
||||
})
|
||||
cv.check()
|
||||
cv.close()
|
||||
|
||||
assert.Equal(t, tc.want, cv.status())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamConfigValidator_Check_once(t *testing.T) {
|
||||
type signal = struct{}
|
||||
|
||||
reqCh := make(chan signal)
|
||||
hdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
||||
require.NoError(pt, err)
|
||||
|
||||
testutil.RequireSend(pt, reqCh, signal{}, testTimeout)
|
||||
})
|
||||
|
||||
addr := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: newLocalUpstreamListener(t, 0, hdlr).String(),
|
||||
}).String()
|
||||
twoAddrs := strings.Join([]string{addr, addr}, " ")
|
||||
|
||||
wantStatus := map[string]string{
|
||||
addr: "OK",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ups []string
|
||||
}{{
|
||||
name: "common",
|
||||
ups: []string{addr, addr, addr},
|
||||
}, {
|
||||
name: "domain-specific",
|
||||
ups: []string{"[/one.example/]" + addr, "[/two.example/]" + twoAddrs},
|
||||
}, {
|
||||
name: "both",
|
||||
ups: []string{addr, "[/one.example/]" + addr, addr, "[/two.example/]" + twoAddrs},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cv := newUpstreamConfigValidator(tc.ups, nil, nil, &upstream.Options{
|
||||
Timeout: testTimeout,
|
||||
})
|
||||
|
||||
go func() {
|
||||
cv.check()
|
||||
testutil.RequireSend(testutil.PanicT{}, reqCh, signal{}, testTimeout)
|
||||
}()
|
||||
|
||||
// Wait for the only request to be sent.
|
||||
testutil.RequireReceive(t, reqCh, testTimeout)
|
||||
|
||||
// Wait for the check to finish.
|
||||
testutil.RequireReceive(t, reqCh, testTimeout)
|
||||
|
||||
cv.close()
|
||||
require.Equal(t, wantStatus, cv.status())
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user