Pull request 2100: v0.107.42-rc

Squashed commit of the following:

commit 284190f748345c7556e60b67f051ec5f6f080948
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Dec 6 19:36:00 2023 +0300

    all: sync with master; upd chlog
This commit is contained in:
Ainar Garipov
2023-12-07 17:23:00 +03:00
parent df91f016f2
commit 25918e56fa
148 changed files with 7204 additions and 1705 deletions

View File

@@ -27,6 +27,19 @@ import (
"golang.org/x/exp/slices"
)
// ClientsContainer provides information about preconfigured DNS clients.
type ClientsContainer interface {
// UpstreamConfigByID returns the custom upstream configuration for the
// client having id, using boot to initialize the one if necessary. It
// returns nil if there is no custom upstream configuration for the client.
// The id is expected to be either a string representation of an IP address
// or the ClientID.
UpstreamConfigByID(
id string,
boot upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error)
}
// Config represents the DNS filtering configuration of AdGuard Home. The zero
// Config is empty and ready for use.
type Config struct {
@@ -35,10 +48,9 @@ type Config struct {
// FilterHandler is an optional additional filtering callback.
FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"`
// GetCustomUpstreamByClient is a callback that returns upstreams
// configuration based on the client IP address or ClientID. It returns
// nil if there are no custom upstreams for the client.
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
// ClientsContainer stores the information about special handling of some
// DNS clients.
ClientsContainer ClientsContainer `yaml:"-"`
// Anti-DNS amplification
@@ -55,7 +67,7 @@ type Config struct {
RatelimitSubnetLenIPv6 int `yaml:"ratelimit_subnet_len_ipv6"`
// RatelimitWhitelist is the list of whitelisted client IP addresses.
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
RatelimitWhitelist []netip.Addr `yaml:"ratelimit_whitelist"`
// RefuseAny, if true, refuse ANY requests.
RefuseAny bool `yaml:"refuse_any"`
@@ -277,32 +289,33 @@ type ServerConfig struct {
// UseHTTP3Upstreams defines if HTTP/3 is be allowed for DNS-over-HTTPS
// upstreams.
UseHTTP3Upstreams bool
// ServePlainDNS defines if plain DNS is allowed for incoming requests.
ServePlainDNS bool
}
// createProxyConfig creates and validates configuration for the main proxy.
func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
// newProxyConfig creates and validates configuration for the main proxy.
func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
srvConf := s.conf
conf = proxy.Config{
UDPListenAddr: srvConf.UDPListenAddrs,
TCPListenAddr: srvConf.TCPListenAddrs,
HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetMaskIPv4: net.CIDRMask(srvConf.RatelimitSubnetLenIPv4, netutil.IPv4BitLen),
RatelimitSubnetMaskIPv6: net.CIDRMask(srvConf.RatelimitSubnetLenIPv6, netutil.IPv6BitLen),
RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny,
TrustedProxies: srvConf.TrustedProxies,
CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler,
RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: int(srvConf.MaxGoroutines),
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
conf = &proxy.Config{
HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny,
TrustedProxies: srvConf.TrustedProxies,
CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler,
RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: int(srvConf.MaxGoroutines),
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
}
if srvConf.EDNSClientSubnet.UseCustom {
@@ -316,27 +329,25 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
}
setProxyUpstreamMode(
&conf,
conf,
srvConf.AllServers,
srvConf.FastestAddr,
srvConf.FastestTimeout.Duration,
)
for i, s := range srvConf.BogusNXDomain {
var subnet *net.IPNet
subnet, err = netutil.ParseSubnet(s)
if err != nil {
log.Error("subnet at index %d: %s", i, err)
continue
}
conf.BogusNXDomain = append(conf.BogusNXDomain, subnet)
conf.BogusNXDomain, err = parseBogusNXDOMAIN(srvConf.BogusNXDomain)
if err != nil {
return nil, fmt.Errorf("bogus_nxdomain: %w", err)
}
err = s.prepareTLS(&conf)
err = s.prepareTLS(conf)
if err != nil {
return proxy.Config{}, fmt.Errorf("validating tls: %w", err)
return nil, fmt.Errorf("validating tls: %w", err)
}
err = s.preparePlain(conf)
if err != nil {
return nil, fmt.Errorf("validating plain: %w", err)
}
if c := srvConf.DNSCryptConfig; c.Enabled {
@@ -347,12 +358,27 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
}
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
return proxy.Config{}, errors.Error("no default upstream servers configured")
return nil, errors.Error("no default upstream servers configured")
}
return conf, nil
}
// parseBogusNXDOMAIN parses the bogus NXDOMAIN strings into valid subnets.
func parseBogusNXDOMAIN(confBogusNXDOMAIN []string) (subnets []netip.Prefix, err error) {
for i, s := range confBogusNXDOMAIN {
var subnet netip.Prefix
subnet, err = aghnet.ParseSubnet(s)
if err != nil {
return nil, fmt.Errorf("subnet at index %d: %w", i, err)
}
subnets = append(subnets, subnet)
}
return subnets, nil
}
const defaultBlockedResponseTTL = 3600
// initDefaultSettings initializes default settings if nothing
@@ -423,10 +449,7 @@ func collectListenAddr(
// collectDNSAddrs returns configured set of listening addresses. It also
// returns a set of ports of each unspecified listening address.
func (conf *ServerConfig) collectDNSAddrs() (
addrs map[netip.AddrPort]unit,
unspecPorts map[uint16]unit,
) {
func (conf *ServerConfig) collectDNSAddrs() (addrs mapAddrPortSet, unspecPorts map[uint16]unit) {
// TODO(e.burkov): Perhaps, we shouldn't allocate as much memory, since the
// TCP and UDP listening addresses are currently the same.
addrs = make(map[netip.AddrPort]unit, len(conf.TCPListenAddrs)+len(conf.UDPListenAddrs))
@@ -446,20 +469,64 @@ func (conf *ServerConfig) collectDNSAddrs() (
// defaultPlainDNSPort is the default port for plain DNS.
const defaultPlainDNSPort uint16 = 53
// addrPortMatcher is a function that matches an IP address with port.
type addrPortMatcher func(addr netip.AddrPort) (ok bool)
// addrPortSet is a set of [netip.AddrPort] values.
type addrPortSet interface {
// Has returns true if addrPort is in the set.
Has(addrPort netip.AddrPort) (ok bool)
}
// type check
var _ addrPortSet = emptyAddrPortSet{}
// emptyAddrPortSet is the [addrPortSet] containing no values.
type emptyAddrPortSet struct{}
// Has implements the [addrPortSet] interface for [emptyAddrPortSet].
func (emptyAddrPortSet) Has(_ netip.AddrPort) (ok bool) { return false }
// mapAddrPortSet is the [addrPortSet] containing values of [netip.AddrPort] as
// keys of a map.
type mapAddrPortSet map[netip.AddrPort]unit
// type check
var _ addrPortSet = mapAddrPortSet{}
// Has implements the [addrPortSet] interface for [mapAddrPortSet].
func (m mapAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
_, ok = m[addrPort]
return ok
}
// combinedAddrPortSet is the [addrPortSet] defined by some IP addresses along
// with ports, any combination of which is considered being in the set.
type combinedAddrPortSet struct {
// TODO(e.burkov): Use sorted slices in combination with binary search.
ports map[uint16]unit
addrs []netip.Addr
}
// type check
var _ addrPortSet = (*combinedAddrPortSet)(nil)
// Has implements the [addrPortSet] interface for [*combinedAddrPortSet].
func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
_, ok = m.ports[addrPort.Port()]
return ok && slices.Contains(m.addrs, addrPort.Addr())
}
// filterOut filters out all the upstreams that match um. It returns all the
// closing errors joined.
func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) {
var errs []error
delFunc := func(u upstream.Upstream) (ok bool) {
// TODO(e.burkov): We should probably consider the protocol of u to
// only filter out the listening addresses of the same protocol.
addr, parseErr := aghnet.ParseAddrPort(u.Address(), defaultPlainDNSPort)
if parseErr != nil || !m(addr) {
if parseErr != nil || !set.Has(addr) {
// Don't filter out the upstream if it either cannot be parsed, or
// does not match um.
// does not match m.
return false
}
@@ -479,26 +546,20 @@ func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
return errors.Join(errs...)
}
// ourAddrsMatcher returns a matcher that matches all the configured listening
// ourAddrsSet returns an addrPortSet that contains all the configured listening
// addresses.
func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) {
func (conf *ServerConfig) ourAddrsSet() (m addrPortSet, err error) {
addrs, unspecPorts := conf.collectDNSAddrs()
if len(addrs) == 0 {
switch {
case len(addrs) == 0:
log.Debug("dnsforward: no listen addresses")
// Match no addresses.
return func(_ netip.AddrPort) (ok bool) { return false }, nil
}
if len(unspecPorts) == 0 {
return emptyAddrPortSet{}, nil
case len(unspecPorts) == 0:
log.Debug("dnsforward: filtering out addresses %s", addrs)
m = func(a netip.AddrPort) (ok bool) {
_, ok = addrs[a]
return ok
}
} else {
return addrs, nil
default:
var ifaceAddrs []netip.Addr
ifaceAddrs, err = aghnet.CollectAllIfacesAddrs()
if err != nil {
@@ -508,16 +569,11 @@ func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) {
log.Debug("dnsforward: filtering out addresses %s on ports %d", ifaceAddrs, unspecPorts)
m = func(a netip.AddrPort) (ok bool) {
if _, ok = unspecPorts[a.Port()]; ok {
return slices.Contains(ifaceAddrs, a.Addr())
}
return false
}
return &combinedAddrPortSet{
ports: unspecPorts,
addrs: ifaceAddrs,
}, nil
}
return m, nil
}
// prepareTLS - prepares TLS configuration for the DNS proxy
@@ -574,7 +630,7 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) {
// isWildcard returns true if host is a wildcard hostname.
func isWildcard(host string) (ok bool) {
return len(host) >= 2 && host[0] == '*' && host[1] == '.'
return strings.HasPrefix(host, "*.")
}
// matchesDomainWildcard returns true if host matches the domain wildcard
@@ -614,6 +670,31 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
return &s.conf.cert, nil
}
// preparePlain prepares the plain-DNS configuration for the DNS proxy.
// preparePlain assumes that prepareTLS has already been called.
func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) {
if s.conf.ServePlainDNS {
proxyConf.UDPListenAddr = s.conf.UDPListenAddrs
proxyConf.TCPListenAddr = s.conf.TCPListenAddrs
return nil
}
lenEncrypted := len(proxyConf.DNSCryptTCPListenAddr) +
len(proxyConf.DNSCryptUDPListenAddr) +
len(proxyConf.HTTPSListenAddr) +
len(proxyConf.QUICListenAddr) +
len(proxyConf.TLSListenAddr)
if lenEncrypted == 0 {
// TODO(a.garipov): Support full disabling of all DNS.
return errors.Error("disabling plain dns requires at least one encrypted protocol")
}
log.Info("dnsforward: warning: plain dns is disabled")
return nil
}
// UpdatedProtectionStatus updates protection state, if the protection was
// disabled temporarily. Returns the updated state of protection.
func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) {

View File

@@ -0,0 +1,349 @@
package dnsforward
import (
"fmt"
"strings"
"sync"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// upstreamConfigValidator parses the [*proxy.UpstreamConfig] and checks the
// actual DNS availability of each upstream.
type upstreamConfigValidator struct {
// general is the general upstream configuration.
general []*upstreamResult
// fallback is the fallback upstream configuration.
fallback []*upstreamResult
// private is the private upstream configuration.
private []*upstreamResult
}
// upstreamResult is a result of validation of an [upstream.Upstream] within an
// [proxy.UpstreamConfig].
type upstreamResult struct {
// server is the parsed upstream. It is nil when there was an error during
// parsing.
server upstream.Upstream
// err is the error either from parsing or from checking the upstream.
err error
// original is the piece of configuration that have either been turned to an
// upstream or caused an error.
original string
// isSpecific is true if the upstream is domain-specific.
isSpecific bool
}
// compare compares two [upstreamResult]s. It returns 0 if they are equal, -1
// if ur should be sorted before other, and 1 otherwise.
//
// TODO(e.burkov): Perhaps it makes sense to sort the results with errors near
// the end.
func (ur *upstreamResult) compare(other *upstreamResult) (res int) {
return strings.Compare(ur.original, other.original)
}
// newUpstreamConfigValidator parses the upstream configuration and returns a
// validator for it. cv already contains the parsed upstreams along with errors
// related.
func newUpstreamConfigValidator(
general []string,
fallback []string,
private []string,
opts *upstream.Options,
) (cv *upstreamConfigValidator) {
cv = &upstreamConfigValidator{}
for _, line := range general {
cv.general = cv.insertLineResults(cv.general, line, opts)
}
for _, line := range fallback {
cv.fallback = cv.insertLineResults(cv.fallback, line, opts)
}
for _, line := range private {
cv.private = cv.insertLineResults(cv.private, line, opts)
}
return cv
}
// insertLineResults parses line and inserts the result into s. It can insert
// multiple results as well as none.
func (cv *upstreamConfigValidator) insertLineResults(
s []*upstreamResult,
line string,
opts *upstream.Options,
) (result []*upstreamResult) {
upstreams, isSpecific, err := splitUpstreamLine(line)
if err != nil {
return cv.insert(s, &upstreamResult{
err: err,
original: line,
})
}
for _, upstreamAddr := range upstreams {
var res *upstreamResult
if upstreamAddr != "#" {
res = cv.parseUpstream(upstreamAddr, opts)
} else if !isSpecific {
res = &upstreamResult{
err: errNotDomainSpecific,
original: upstreamAddr,
}
} else {
continue
}
res.isSpecific = isSpecific
s = cv.insert(s, res)
}
return s
}
// insert inserts r into slice in a sorted order, except duplicates. slice must
// not be nil.
func (cv *upstreamConfigValidator) insert(
s []*upstreamResult,
r *upstreamResult,
) (result []*upstreamResult) {
i, has := slices.BinarySearchFunc(s, r, (*upstreamResult).compare)
if has {
log.Debug("dnsforward: duplicate configuration %q", r.original)
return s
}
return slices.Insert(s, i, r)
}
// parseUpstream parses addr and returns the result of parsing. It returns nil
// if the specified server points at the default upstream server which is
// validated separately.
func (cv *upstreamConfigValidator) parseUpstream(
addr string,
opts *upstream.Options,
) (r *upstreamResult) {
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
if proto, _, ok := strings.Cut(addr, "://"); ok {
if !slices.Contains(protocols, proto) {
return &upstreamResult{
err: fmt.Errorf("bad protocol %q", proto),
original: addr,
}
}
}
ups, err := upstream.AddressToUpstream(addr, opts)
return &upstreamResult{
server: ups,
err: err,
original: addr,
}
}
// check tries to exchange with each successfully parsed upstream and enriches
// the results with the healthcheck errors. It should not be called after the
// [upsConfValidator.close] method, since it makes no sense to check the closed
// upstreams.
func (cv *upstreamConfigValidator) check() {
const (
// testTLD is the special-use fully-qualified domain name for testing
// the DNS server reachability.
//
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
testTLD = "test."
// inAddrARPATLD is the special-use fully-qualified domain name for PTR
// IP address resolution.
//
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
inAddrARPATLD = "in-addr.arpa."
)
commonChecker := &healthchecker{
hostname: testTLD,
qtype: dns.TypeA,
ansEmpty: true,
}
arpaChecker := &healthchecker{
hostname: inAddrARPATLD,
qtype: dns.TypePTR,
ansEmpty: false,
}
wg := &sync.WaitGroup{}
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
for _, res := range cv.general {
go cv.checkSrv(res, wg, commonChecker)
}
for _, res := range cv.fallback {
go cv.checkSrv(res, wg, commonChecker)
}
for _, res := range cv.private {
go cv.checkSrv(res, wg, arpaChecker)
}
wg.Wait()
}
// checkSrv runs hc on the server from res, if any, and stores any occurred
// error in res. wg is always marked done in the end. It used to be called in
// a separate goroutine.
func (cv *upstreamConfigValidator) checkSrv(
res *upstreamResult,
wg *sync.WaitGroup,
hc *healthchecker,
) {
defer wg.Done()
if res.server == nil {
return
}
res.err = hc.check(res.server)
if res.err != nil && res.isSpecific {
res.err = domainSpecificTestError{Err: res.err}
}
}
// close closes all the upstreams that were successfully parsed. It enriches
// the results with deferred closing errors.
func (cv *upstreamConfigValidator) close() {
for _, slice := range [][]*upstreamResult{cv.general, cv.fallback, cv.private} {
for _, r := range slice {
if r.server != nil {
r.err = errors.WithDeferred(r.err, r.server.Close())
}
}
}
}
// status returns all the data collected during parsing, healthcheck, and
// closing of the upstreams. The returned map is keyed by the original upstream
// configuration piece and contains the corresponding error or "OK" if there was
// no error.
func (cv *upstreamConfigValidator) status() (results map[string]string) {
result := map[string]string{}
for _, res := range cv.general {
resultToStatus("general", res, result)
}
for _, res := range cv.fallback {
resultToStatus("fallback", res, result)
}
for _, res := range cv.private {
resultToStatus("private", res, result)
}
return result
}
// resultToStatus puts "OK" or an error message from res into resMap. section
// is the name of the upstream configuration section, i.e. "general",
// "fallback", or "private", and only used for logging.
//
// TODO(e.burkov): Currently, the HTTP handler expects that all the results are
// put together in a single map, which may lead to collisions, see AG-27539.
// Improve the results compilation.
func resultToStatus(section string, res *upstreamResult, resMap map[string]string) {
val := "OK"
if res.err != nil {
val = res.err.Error()
}
prevVal := resMap[res.original]
switch prevVal {
case "":
resMap[res.original] = val
case val:
log.Debug("dnsforward: duplicating %s config line %q", section, res.original)
default:
log.Debug(
"dnsforward: warning: %s config line %q (%v) had different result %v",
section,
val,
res.original,
prevVal,
)
}
}
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
// the tested upstream domain-specific and therefore consider its errors
// non-critical.
//
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
// warnings (non-critical errors) is desired.
type domainSpecificTestError struct {
// Err is the actual error occurred during healthcheck test.
Err error
}
// type check
var _ error = domainSpecificTestError{}
// Error implements the [error] interface for domainSpecificTestError.
func (err domainSpecificTestError) Error() (msg string) {
return fmt.Sprintf("WARNING: %s", err.Err)
}
// type check
var _ errors.Wrapper = domainSpecificTestError{}
// Unwrap implements the [errors.Wrapper] interface for domainSpecificTestError.
func (err domainSpecificTestError) Unwrap() (wrapped error) {
return err.Err
}
// healthchecker checks the upstream's status by exchanging with it.
type healthchecker struct {
// hostname is the name of the host to put into healthcheck DNS request.
hostname string
// qtype is the type of DNS request to use for healthcheck.
qtype uint16
// ansEmpty defines if the answer section within the response is expected to
// be empty.
ansEmpty bool
}
// check exchanges with u and validates the response.
func (h *healthchecker) check(u upstream.Upstream) (err error) {
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: h.hostname,
Qtype: h.qtype,
Qclass: dns.ClassINET,
}},
}
reply, err := u.Exchange(req)
if err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
} else if h.ansEmpty && len(reply.Answer) > 0 {
return errWrongResponse
}
return nil
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"time"
"github.com/AdguardTeam/golibs/errors"
@@ -11,10 +13,12 @@ import (
)
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
// addr should be a valid host:port address, where host could be a domain name
// or an IP address.
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Debug("dnsforward: dialing %q for network %q", addr, network)
host, port, err := net.SplitHostPort(addr)
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
@@ -28,21 +32,24 @@ func (s *Server) DialContext(ctx context.Context, network, addr string) (conn ne
return dialer.DialContext(ctx, network, addr)
}
addrs, err := s.Resolve(host)
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("resolving %q: %w", host, err)
return nil, fmt.Errorf("invalid port %s: %w", portStr, err)
}
log.Debug("dnsforward: resolving %q: %v", host, addrs)
if len(addrs) == 0 {
ips, err := s.Resolve(ctx, network, host)
if err != nil {
return nil, fmt.Errorf("resolving %q: %w", host, err)
} else if len(ips) == 0 {
return nil, fmt.Errorf("no addresses for host %q", host)
}
log.Debug("dnsforward: resolved %q: %v", host, ips)
var dialErrs []error
for _, a := range addrs {
addr = net.JoinHostPort(a.String(), port)
conn, err = dialer.DialContext(ctx, network, addr)
for _, ip := range ips {
addrPort := netip.AddrPortFrom(ip, uint16(port))
conn, err = dialer.DialContext(ctx, network, addrPort.String())
if err != nil {
dialErrs = append(dialErrs, err)

View File

@@ -292,6 +292,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, localUps)
t.Run(tc.name, func(t *testing.T) {

View File

@@ -2,7 +2,9 @@
package dnsforward
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
@@ -34,6 +36,11 @@ import (
// DefaultTimeout is the default upstream timeout
const DefaultTimeout = 10 * time.Second
// defaultLocalTimeout is the default timeout for resolving addresses from
// locally-served networks. It is assumed that local resolvers should work much
// faster than ordinary upstreams.
const defaultLocalTimeout = 1 * time.Second
// defaultClientIDCacheCount is the default count of items in the LRU ClientID
// cache. The assumption here is that there won't be more than this many
// requests between the BeforeRequestHandler stage and the actual processing.
@@ -108,7 +115,7 @@ type Server struct {
// stats is the statistics collector for client's DNS usage data.
stats stats.Interface
// access drops unallowed clients.
// access drops disallowed clients.
access *accessManager
// localDomainSuffix is the suffix used to detect internal hosts. It
@@ -135,8 +142,21 @@ type Server struct {
// PTR resolving.
sysResolvers SystemResolvers
// recDetector is a cache for recursive requests. It is used to detect
// and prevent recursive requests only for private upstreams.
// etcHosts contains the data from the system's hosts files.
etcHosts upstream.Resolver
// bootstrap is the resolver for upstreams' hostnames.
bootstrap upstream.Resolver
// bootResolvers are the resolvers that should be used for
// bootstrapping along with [etcHosts].
//
// TODO(e.burkov): Use [proxy.UpstreamConfig] when it will implement the
// [upstream.Resolver] interface.
bootResolvers []*upstream.UpstreamResolver
// recDetector is a cache for recursive requests. It is used to detect and
// prevent recursive requests only for private upstreams.
//
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
recDetector *recursionDetector
@@ -153,8 +173,8 @@ type Server struct {
// during the BeforeRequestHandler stage.
clientIDCache cache.Cache
// DNS proxy instance for internal usage
// We don't Start() it and so no listen port is required.
// internalProxy resolves internal requests from the application itself. It
// isn't started and so no listen ports are required.
internalProxy *proxy.Proxy
// isRunning is true if the DNS server is running.
@@ -185,6 +205,7 @@ type DNSCreateParams struct {
DHCPServer DHCP
PrivateNets netutil.SubnetSet
Anonymizer *aghnet.IPMut
EtcHosts *aghnet.HostsContainer
LocalDomain string
}
@@ -217,8 +238,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
if p.Anonymizer == nil {
p.Anonymizer = aghnet.NewIPMut(nil)
}
s = &Server{
dnsFilter: p.DNSFilter,
dhcpServer: p.DHCPServer,
stats: p.Stats,
queryLog: p.QueryLog,
privateNets: p.PrivateNets,
@@ -230,6 +253,12 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
MaxCount: defaultClientIDCacheCount,
}),
anonymizer: p.Anonymizer,
conf: ServerConfig{
ServePlainDNS: true,
},
}
if p.EtcHosts != nil {
s.etcHosts = p.EtcHosts
}
s.sysResolvers, err = sysresolv.NewSystemResolvers(nil, defaultPlainDNSPort)
@@ -237,8 +266,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
return nil, fmt.Errorf("initializing system resolvers: %w", err)
}
s.dhcpServer = p.DHCPServer
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
// Use plain DNS on MIPS, encryption is too slow
defaultDNS = defaultBootstrap
@@ -274,7 +301,7 @@ func (s *Server) WriteDiskConfig(c *Config) {
sc := s.conf.Config
*c = sc
c.RatelimitWhitelist = stringutil.CloneSlice(sc.RatelimitWhitelist)
c.RatelimitWhitelist = slices.Clone(sc.RatelimitWhitelist)
c.BootstrapDNS = stringutil.CloneSlice(sc.BootstrapDNS)
c.FallbackDNS = stringutil.CloneSlice(sc.FallbackDNS)
c.AllowedClients = stringutil.CloneSlice(sc.AllowedClients)
@@ -305,15 +332,14 @@ func (s *Server) AddrProcConfig() (c *client.DefaultAddrProcConfig) {
}
}
// Resolve - get IP addresses by host name from an upstream server.
// No request/response filtering is performed.
// Query log and Stats are not updated.
// This method may be called before Start().
func (s *Server) Resolve(host string) ([]net.IPAddr, error) {
// Resolve gets IP addresses by host name from an upstream server. No
// request/response filtering is performed. Query log and Stats are not
// updated. This method may be called before [Server.Start].
func (s *Server) Resolve(ctx context.Context, net, host string) (addr []netip.Addr, err error) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
return s.internalProxy.LookupIPAddr(host)
return s.internalProxy.LookupNetIP(ctx, net, host)
}
const (
@@ -421,7 +447,7 @@ func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
return "", 0, ErrRDNSNoData
}
// Start starts the DNS server.
// Start starts the DNS server. It must only be called after [Server.Prepare].
func (s *Server) Start() error {
s.serverLock.Lock()
defer s.serverLock.Unlock()
@@ -429,48 +455,42 @@ func (s *Server) Start() error {
return s.startLocked()
}
// startLocked starts the DNS server without locking. For internal use only.
// startLocked starts the DNS server without locking. s.serverLock is expected
// to be locked.
func (s *Server) startLocked() error {
err := s.dnsProxy.Start()
if err == nil {
s.isRunning = true
}
return err
}
// defaultLocalTimeout is the default timeout for resolving addresses from
// locally-served networks. It is assumed that local resolvers should work much
// faster than ordinary upstreams.
const defaultLocalTimeout = 1 * time.Second
// setupLocalResolvers initializes the resolvers for local addresses. For
// internal use only.
func (s *Server) setupLocalResolvers() (err error) {
matcher, err := s.conf.ourAddrsMatcher()
// setupLocalResolvers initializes the resolvers for local addresses. It
// assumes s.serverLock is locked or the Server not running.
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
set, err := s.conf.ourAddrsSet()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
bootstraps := s.conf.BootstrapDNS
resolvers := s.conf.LocalPTRResolvers
filterConfig := false
if len(resolvers) == 0 {
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
confNeedsFiltering := len(resolvers) > 0
if confNeedsFiltering {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has)
resolvers = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
resolvers = append(resolvers, r.String())
}
} else {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
filterConfig = true
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
uc, err := s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
Bootstrap: bootstraps,
Bootstrap: boot,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
@@ -479,8 +499,9 @@ func (s *Server) setupLocalResolvers() (err error) {
return fmt.Errorf("preparing private upstreams: %w", err)
}
if filterConfig {
if err = matcher.filterOut(uc); err != nil {
if confNeedsFiltering {
err = filterOutAddrs(uc, set)
if err != nil {
return fmt.Errorf("filtering private upstreams: %w", err)
}
}
@@ -491,6 +512,7 @@ func (s *Server) setupLocalResolvers() (err error) {
},
}
// TODO(e.burkov): Should we also consider the DNS64 usage?
if s.conf.UsePrivateRDNS &&
// Only set the upstream config if there are any upstreams. It's safe
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
@@ -517,31 +539,19 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.initDefaultSettings()
err = s.prepareIpsetListSettings()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return fmt.Errorf("preparing ipset settings: %w", err)
}
err = s.prepareUpstreamSettings()
boot, err := s.prepareInternalDNS()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
var proxyConfig proxy.Config
proxyConfig, err = s.createProxyConfig()
proxyConfig, err := s.newProxyConfig()
if err != nil {
return fmt.Errorf("preparing proxy: %w", err)
}
s.setupDNS64()
err = s.prepareInternalProxy()
if err != nil {
return fmt.Errorf("preparing internal proxy: %w", err)
}
s.access, err = newAccessCtx(
s.conf.AllowedClients,
s.conf.DisallowedClients,
@@ -554,9 +564,9 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
// Set the proxy here because [setupLocalResolvers] sets its values.
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
s.dnsProxy = &proxy.Proxy{Config: *proxyConfig}
err = s.setupLocalResolvers()
err = s.setupLocalResolvers(boot)
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
@@ -575,6 +585,38 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return nil
}
// prepareInternalDNS initializes the internal state of s before initializing
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
err = s.prepareIpsetListSettings()
if err != nil {
return nil, fmt.Errorf("preparing ipset settings: %w", err)
}
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{
Timeout: DefaultTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
})
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
err = s.prepareUpstreamSettings(s.bootstrap)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return s.bootstrap, err
}
err = s.prepareInternalProxy()
if err != nil {
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err)
}
return s.bootstrap, nil
}
// setupFallbackDNS initializes the fallback DNS servers.
func (s *Server) setupFallbackDNS() (err error) {
fallbacks := s.conf.FallbackDNS
@@ -598,7 +640,8 @@ func (s *Server) setupFallbackDNS() (err error) {
return nil
}
// setupAddrProc initializes the address processor. For internal use only.
// setupAddrProc initializes the address processor. It assumes s.serverLock is
// locked or the Server not running.
func (s *Server) setupAddrProc() {
// TODO(a.garipov): This is a crutch for tests; remove.
if s.conf.AddrProcConf == nil {
@@ -687,7 +730,8 @@ func (s *Server) Stop() error {
return s.stopLocked()
}
// stopLocked stops the DNS server without locking. For internal use only.
// stopLocked stops the DNS server without locking. s.serverLock is expected to
// be locked.
func (s *Server) stopLocked() (err error) {
// TODO(e.burkov, a.garipov): Return critical errors, not just log them.
// This will require filtering all the non-critical errors in
@@ -700,18 +744,11 @@ func (s *Server) stopLocked() (err error) {
}
}
if upsConf := s.internalProxy.UpstreamConfig; upsConf != nil {
err = upsConf.Close()
if err != nil {
log.Error("dnsforward: closing internal resolvers: %s", err)
}
}
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
if upsConf := s.localResolvers.UpstreamConfig; upsConf != nil {
err = upsConf.Close()
if err != nil {
log.Error("dnsforward: closing local resolvers: %s", err)
}
for _, b := range s.bootResolvers {
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
}
s.isRunning = false
@@ -719,6 +756,18 @@ func (s *Server) stopLocked() (err error) {
return nil
}
// logCloserErr logs the error returned by c, if any.
func logCloserErr(c io.Closer, format string, args ...any) {
if c == nil {
return
}
err := c.Close()
if err != nil {
log.Error(format, append(args, err)...)
}
}
// IsRunning returns true if the DNS server is running.
func (s *Server) IsRunning() bool {
s.serverLock.RLock()

View File

@@ -54,13 +54,10 @@ const (
testMessagesCount = 10
)
// testClientAddr is the common net.Addr for tests.
// testClientAddrPort is the common net.Addr for tests.
//
// TODO(a.garipov): Use more.
var testClientAddr net.Addr = &net.TCPAddr{
IP: net.IP{1, 2, 3, 4},
Port: 12345,
}
var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345")
func startDeferStop(t *testing.T, s *Server) {
t.Helper()
@@ -182,6 +179,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
@@ -309,6 +307,7 @@ func TestServer(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s)
@@ -347,6 +346,7 @@ func TestServer_timeout(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
@@ -381,6 +381,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
},
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}
s, err := NewServer(DNSCreateParams{})
@@ -402,6 +403,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s)
@@ -479,6 +481,7 @@ func TestServerRace(t *testing.T) {
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
@@ -532,6 +535,7 @@ func TestSafeSearch(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
startDeferStop(t, s)
@@ -594,6 +598,7 @@ func TestInvalidRequest(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}, nil)
startDeferStop(t, s)
@@ -622,6 +627,7 @@ func TestBlockedRequest(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
@@ -644,45 +650,71 @@ func TestBlockedRequest(t *testing.T) {
}
func TestServerCustomClientUpstream(t *testing.T) {
const defaultCacheSize = 1024 * 1024
var upsCalledCounter uint32
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
CacheSize: defaultCacheSize,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
})
return &proxy.UpstreamConfig{
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
atomic.AddUint32(&upsCalledCounter, 1)
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
})
customUpsConf := proxy.NewCustomUpstreamConfig(
&proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
}, nil
},
true,
defaultCacheSize,
forwardConf.EDNSClientSubnet.Enabled,
)
s.conf.ClientsContainer = &aghtest.ClientsContainer{
OnUpstreamConfigByID: func(
_ string,
_ upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error) {
return customUpsConf, nil
},
}
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
// Send test request.
req := createTestMessage("host.")
reply, err := dns.Exchange(req, addr.String())
reply, err := dns.Exchange(req, addr)
require.NoError(t, err)
require.NotEmpty(t, reply.Answer)
require.Len(t, reply.Answer, 1)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
require.NotEmpty(t, reply.Answer)
require.Len(t, reply.Answer, 1)
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
assert.Equal(t, uint32(1), atomic.LoadUint32(&upsCalledCounter))
_, err = dns.Exchange(req, addr)
require.NoError(t, err)
assert.Equal(t, uint32(1), atomic.LoadUint32(&upsCalledCounter))
}
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
@@ -708,6 +740,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}, nil)
testUpstm := &aghtest.Upstream{
CName: testCNAMEs,
@@ -740,6 +773,7 @@ func TestBlockCNAME(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
@@ -814,6 +848,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
@@ -858,6 +893,7 @@ func TestNullBlockedRequest(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
@@ -923,6 +959,7 @@ func TestBlockedCustomIP(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
// Invalid BlockingIPv4.
@@ -974,6 +1011,7 @@ func TestBlockedByHosts(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
@@ -1024,6 +1062,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
startDeferStop(t, s)
@@ -1082,6 +1121,7 @@ func TestRewrite(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}))
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {

View File

@@ -40,6 +40,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
makeQ := func(qtype rules.RRType) (req *dns.Msg) {

View File

@@ -10,7 +10,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
@@ -28,8 +27,7 @@ func (s *Server) beforeRequestHandler(
return false, fmt.Errorf("getting clientid: %w", err)
}
addrPort := netutil.NetAddrToAddrPort(pctx.Addr)
blocked, _ := s.IsBlockedClient(addrPort.Addr(), clientID)
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
@@ -60,8 +58,7 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
setts = s.dnsFilter.Settings()
setts.ProtectionEnabled = dctx.protectionEnabled
if s.conf.FilterHandler != nil {
addrPort := netutil.NetAddrToAddrPort(dctx.proxyCtx.Addr)
s.conf.FilterHandler(addrPort.Addr(), dctx.clientID, setts)
s.conf.FilterHandler(dctx.proxyCtx.Addr.Addr(), dctx.clientID, setts)
}
return setts

View File

@@ -35,6 +35,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
Enabled: false,
},
},
ServePlainDNS: true,
}
filters := []filtering.Filter{{
ID: 0, Data: []byte(rules),
@@ -186,7 +187,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
dctx := &proxy.DNSContext{
Proto: proxy.ProtoUDP,
Req: tc.req,
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
Addr: testClientAddrPort,
}
t.Run(tc.name, func(t *testing.T) {
@@ -325,7 +326,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
Proto: proxy.ProtoUDP,
Req: tc.req,
Res: resp,
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
Addr: testClientAddrPort,
}
dctx := &dnsContext{

View File

@@ -6,20 +6,15 @@ import (
"io"
"net/http"
"net/netip"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
@@ -45,8 +40,19 @@ type jsonDNSConfig struct {
// ProtectionEnabled defines if protection is enabled.
ProtectionEnabled *bool `json:"protection_enabled"`
// RateLimit is the number of requests per second allowed per client.
RateLimit *uint32 `json:"ratelimit"`
// Ratelimit is the number of requests per second allowed per client.
Ratelimit *uint32 `json:"ratelimit"`
// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
// rate limiting requests.
RatelimitSubnetLenIPv4 *int `json:"ratelimit_subnet_len_ipv4"`
// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
// rate limiting requests.
RatelimitSubnetLenIPv6 *int `json:"ratelimit_subnet_len_ipv6"`
// RatelimitWhitelist is a list of IP addresses excluded from rate limiting.
RatelimitWhitelist *[]netip.Addr `json:"ratelimit_whitelist"`
// BlockingMode defines the way blocked responses are constructed.
BlockingMode *filtering.BlockingMode `json:"blocking_mode"`
@@ -121,6 +127,9 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode()
blockedResponseTTL := s.dnsFilter.BlockedResponseTTL()
ratelimit := s.conf.Ratelimit
ratelimitSubnetLenIPv4 := s.conf.RatelimitSubnetLenIPv4
ratelimitSubnetLenIPv6 := s.conf.RatelimitSubnetLenIPv6
ratelimitWhitelist := append([]netip.Addr{}, s.conf.RatelimitWhitelist...)
customIP := s.conf.EDNSClientSubnet.CustomIP
enableEDNSClientSubnet := s.conf.EDNSClientSubnet.Enabled
@@ -157,7 +166,10 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
BlockingMode: &blockingMode,
BlockingIPv4: blockingIPv4,
BlockingIPv6: blockingIPv6,
RateLimit: &ratelimit,
Ratelimit: &ratelimit,
RatelimitSubnetLenIPv4: &ratelimitSubnetLenIPv4,
RatelimitSubnetLenIPv6: &ratelimitSubnetLenIPv6,
RatelimitWhitelist: &ratelimitWhitelist,
EDNSCSCustomIP: customIP,
EDNSCSEnabled: &enableEDNSClientSubnet,
EDNSCSUseCustom: &useCustom,
@@ -180,13 +192,13 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
// defaultLocalPTRUpstreams returns the list of default local PTR resolvers
// filtered of AdGuard Home's own DNS server addresses. It may appear empty.
func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
matcher, err := s.conf.ourAddrsMatcher()
matcher, err := s.conf.ourAddrsSet()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher.Has)
ups = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
ups = append(ups, r.String())
@@ -201,6 +213,7 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// checkBlockingMode returns an error if blocking mode is invalid.
func (req *jsonDNSConfig) checkBlockingMode() (err error) {
if req.BlockingMode == nil {
return nil
@@ -209,12 +222,21 @@ func (req *jsonDNSConfig) checkBlockingMode() (err error) {
return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6)
}
func (req *jsonDNSConfig) checkUpstreamsMode() bool {
valid := []string{"", "fastest_addr", "parallel"}
// checkUpstreamsMode returns an error if the upstream mode is invalid.
func (req *jsonDNSConfig) checkUpstreamsMode() (err error) {
if req.UpstreamMode == nil {
return nil
}
return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode)
mode := *req.UpstreamMode
if ok := slices.Contains([]string{"", "fastest_addr", "parallel"}, mode); !ok {
return fmt.Errorf("upstream_mode: incorrect value %q", mode)
}
return nil
}
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
@@ -229,6 +251,7 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
}
if _, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
}
@@ -244,67 +267,136 @@ func (req *jsonDNSConfig) checkFallbacks() (err error) {
err = ValidateUpstreams(*req.Fallbacks)
if err != nil {
return fmt.Errorf("validating fallback servers: %w", err)
return fmt.Errorf("fallback servers: %w", err)
}
return nil
}
// validate returns an error if any field of req is invalid.
//
// TODO(s.chzhen): Parse, don't validate.
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
err = req.validateUpstreamDNSServers(privateNets)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = req.checkRatelimitSubnetMaskLen()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = req.checkBlockingMode()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = req.checkUpstreamsMode()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = req.checkCacheTTL()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
return nil
}
// validateUpstreamDNSServers returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil {
err = ValidateUpstreams(*req.Upstreams)
if err != nil {
return fmt.Errorf("validating upstream servers: %w", err)
return fmt.Errorf("upstream servers: %w", err)
}
}
if req.LocalPTRUpstreams != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
if err != nil {
return fmt.Errorf("validating private upstream servers: %w", err)
return fmt.Errorf("private upstream servers: %w", err)
}
}
err = req.checkBootstrap()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = req.checkFallbacks()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = req.checkBlockingMode()
if err != nil {
return err
}
switch {
case !req.checkUpstreamsMode():
return errors.Error("upstream_mode: incorrect value")
case !req.checkCacheTTL():
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
default:
return nil
}
return nil
}
func (req *jsonDNSConfig) checkCacheTTL() bool {
// checkCacheTTL returns an error if the configuration of the cache TTL is
// invalid.
func (req *jsonDNSConfig) checkCacheTTL() (err error) {
if req.CacheMinTTL == nil && req.CacheMaxTTL == nil {
return true
return nil
}
var min, max uint32
var minTTL, maxTTL uint32
if req.CacheMinTTL != nil {
min = *req.CacheMinTTL
minTTL = *req.CacheMinTTL
}
if req.CacheMaxTTL != nil {
max = *req.CacheMaxTTL
maxTTL = *req.CacheMaxTTL
}
return min <= max
if minTTL <= maxTTL {
return nil
}
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
}
// checkRatelimitSubnetMaskLen returns an error if the length of the subnet mask
// for IPv4 or IPv6 addresses is invalid.
func (req *jsonDNSConfig) checkRatelimitSubnetMaskLen() (err error) {
err = checkInclusion(req.RatelimitSubnetLenIPv4, 0, netutil.IPv4BitLen)
if err != nil {
return fmt.Errorf("ratelimit_subnet_len_ipv4 is invalid: %w", err)
}
err = checkInclusion(req.RatelimitSubnetLenIPv6, 0, netutil.IPv6BitLen)
if err != nil {
return fmt.Errorf("ratelimit_subnet_len_ipv6 is invalid: %w", err)
}
return nil
}
// checkInclusion returns an error if a ptr is not nil and points to value,
// that not in the inclusive range between minN and maxN.
func checkInclusion(ptr *int, minN, maxN int) (err error) {
if ptr == nil {
return nil
}
n := *ptr
switch {
case n < minN:
return fmt.Errorf("value %d less than min %d", n, minN)
case n > maxN:
return fmt.Errorf("value %d greater than max %d", n, maxN)
}
return nil
}
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
@@ -401,6 +493,9 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic),
setIfNotNil(&s.conf.AddrProcConf.UseRDNS, dc.ResolveClients),
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS),
setIfNotNil(&s.conf.RatelimitSubnetLenIPv4, dc.RatelimitSubnetLenIPv4),
setIfNotNil(&s.conf.RatelimitSubnetLenIPv6, dc.RatelimitSubnetLenIPv6),
setIfNotNil(&s.conf.RatelimitWhitelist, dc.RatelimitWhitelist),
} {
shouldRestart = shouldRestart || hasSet
if shouldRestart {
@@ -408,8 +503,8 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
}
}
if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit {
s.conf.Ratelimit = *dc.RateLimit
if dc.Ratelimit != nil && s.conf.Ratelimit != *dc.Ratelimit {
s.conf.Ratelimit = *dc.Ratelimit
shouldRestart = true
}
@@ -424,374 +519,11 @@ type upstreamJSON struct {
PrivateUpstreams []string `json:"private_upstream"`
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// newUpstreamConfig validates upstreams and returns an appropriate upstream
// configuration or nil if it can't be built.
//
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
// slice already so that this function may be considered useless.
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
// No need to validate comments and empty lines.
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
if len(upstreams) == 0 {
// Consider this case valid since it means the default server should be
// used.
return nil, nil
// closeBoots closes all the provided bootstrap servers and logs errors if any.
func closeBoots(boots []*upstream.UpstreamResolver) {
for _, c := range boots {
logCloserErr(c, "dnsforward: closing bootstrap %s: %s", c.Address())
}
err = validateUpstreamConfig(upstreams)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: []string{},
Timeout: DefaultTimeout,
},
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
} else if len(conf.Upstreams) == 0 {
return nil, errors.Error("no default upstreams specified")
}
return conf, nil
}
// validateUpstreamConfig validates each upstream from the upstream
// configuration and returns an error if any upstream is invalid.
//
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
func validateUpstreamConfig(conf []string) (err error) {
for _, u := range conf {
var ups []string
var domains []string
ups, domains, err = separateUpstream(u)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
for _, addr := range ups {
_, err = validateUpstream(addr, domains)
if err != nil {
return fmt.Errorf("validating upstream %q: %w", addr, err)
}
}
}
return nil
}
// ValidateUpstreams validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified.
//
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
func ValidateUpstreams(upstreams []string) (err error) {
_, err = newUpstreamConfig(upstreams)
return err
}
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := newUpstreamConfig(upstreams)
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
if conf == nil {
return nil
}
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
var errs []error
for _, domain := range keys {
var subnet netip.Prefix
subnet, err = extractARPASubnet(domain)
if err != nil {
errs = append(errs, err)
continue
}
if !privateNets.Contains(subnet.Addr().AsSlice()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
)
}
}
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}
var protocols = []string{
"h3://",
"https://",
"quic://",
"sdns://",
"tcp://",
"tls://",
"udp://",
}
// validateUpstream returns an error if u alongside with domains is not a valid
// upstream configuration. useDefault is true if the upstream is
// domain-specific and is configured to point at the default upstream server
// which is validated separately. The upstream is considered domain-specific
// only if domains is at least not nil.
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
// The special server address '#' means that default server must be used.
if useDefault = u == "#" && domains != nil; useDefault {
return useDefault, nil
}
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
for _, proto := range protocols {
if strings.HasPrefix(u, proto) {
return false, nil
}
}
if proto, _, ok := strings.Cut(u, "://"); ok {
return false, fmt.Errorf("bad protocol %q", proto)
}
// Check if upstream is either an IP or IP with port.
if _, err = netip.ParseAddr(u); err == nil {
return false, nil
} else if _, err = netip.ParseAddrPort(u); err == nil {
return false, nil
}
return false, err
}
// separateUpstream returns the upstreams and the specified domains. domains
// is nil when the upstream is not domains-specific. Otherwise it may also be
// empty.
func separateUpstream(upstreamStr string) (upstreams, domains []string, err error) {
if !strings.HasPrefix(upstreamStr, "[/") {
return []string{upstreamStr}, nil, nil
}
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
parts := strings.Split(upstreamStr[2:], "/]")
switch len(parts) {
case 2:
// Go on.
case 1:
return nil, nil, errors.Error("missing separator")
default:
return nil, nil, errors.Error("duplicated separator")
}
for i, host := range strings.Split(parts[0], "/") {
if host == "" {
continue
}
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
if err != nil {
return nil, nil, fmt.Errorf("domain at index %d: %w", i, err)
}
domains = append(domains, host)
}
return strings.Fields(parts[1]), domains, nil
}
// healthCheckFunc is a signature of function to check if upstream exchanges
// properly.
type healthCheckFunc func(u upstream.Upstream) (err error)
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
// testTLD is the special-use fully-qualified domain name for testing the
// DNS server reachability.
//
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
const testTLD = "test."
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: testTLD,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
var reply *dns.Msg
reply, err = u.Exchange(req)
if err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
} else if len(reply.Answer) != 0 {
return errors.Error("wrong response")
}
return nil
}
// checkPrivateUpstreamExc checks if the upstream for resolving private
// addresses exchanges correctly.
//
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
// address resolution.
//
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
const inAddrArpaTLD = "in-addr.arpa."
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: inAddrArpaTLD,
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
}},
}
if _, err = u.Exchange(req); err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
}
return nil
}
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
// the tested upstream domain-specific and therefore consider its errors
// non-critical.
//
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
// warnings (non-critical errors) is desired.
type domainSpecificTestError struct {
error
}
// Error implements the [error] interface for domainSpecificTestError.
func (err domainSpecificTestError) Error() (msg string) {
return fmt.Sprintf("WARNING: %s", err.error)
}
// checkDNS parses line, creates DNS upstreams using opts, and checks if the
// upstreams are exchanging correctly. It saves the result into a sync.Map
// where key is an upstream address and value is "OK", if the upstream
// exchanges correctly, or text of the error. It is intended to be used as a
// goroutine.
//
// TODO(s.chzhen): Separate to a different structure/file.
func (s *Server) checkDNS(
line string,
opts *upstream.Options,
check healthCheckFunc,
wg *sync.WaitGroup,
m *sync.Map,
) {
defer wg.Done()
defer log.OnPanic("dnsforward: checking upstreams")
upstreams, domains, err := separateUpstream(line)
if err != nil {
err = fmt.Errorf("wrong upstream format: %w", err)
m.Store(line, err.Error())
return
}
specific := len(domains) > 0
for _, upstreamAddr := range upstreams {
var useDefault bool
useDefault, err = validateUpstream(upstreamAddr, domains)
if err != nil {
err = fmt.Errorf("wrong upstream format: %w", err)
m.Store(upstreamAddr, err.Error())
continue
}
if useDefault {
continue
}
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check)
if err != nil {
m.Store(upstreamAddr, err.Error())
} else {
m.Store(upstreamAddr, "OK")
}
}
}
// checkUpstreamAddr creates the DNS upstream using opts and information from
// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It
// returns an error if addr is not valid DNS upstream address or the upstream
// is not exchanging correctly.
func (s *Server) checkUpstreamAddr(
addr string,
specific bool,
opts *upstream.Options,
check healthCheckFunc,
) (err error) {
defer func() {
if err != nil && specific {
err = domainSpecificTestError{error: err}
}
}()
opts = &upstream.Options{
Bootstrap: opts.Bootstrap,
Timeout: opts.Timeout,
PreferIPv6: opts.PreferIPv6,
}
// dnsFilter can be nil during application update.
if s.dnsFilter != nil {
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr))
for _, rec := range recs {
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
}
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
}
u, err := upstream.AddressToUpstream(addr, opts)
if err != nil {
return fmt.Errorf("creating upstream for %q: %w", addr, err)
}
defer func() { err = errors.WithDeferred(err, u.Close()) }()
return check(u)
}
// handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns
@@ -808,46 +540,27 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
opts := &upstream.Options{
Bootstrap: req.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
PreferIPv6: s.conf.BootstrapPreferIPv6,
}
if len(opts.Bootstrap) == 0 {
opts.Bootstrap = defaultBootstrap
var boots []*upstream.UpstreamResolver
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)
return
}
defer closeBoots(boots)
wg := &sync.WaitGroup{}
m := &sync.Map{}
cv := newUpstreamConfigValidator(req.Upstreams, req.FallbackDNS, req.PrivateUpstreams, opts)
cv.check()
cv.close()
wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
for _, ups := range req.Upstreams {
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
}
for _, ups := range req.FallbackDNS {
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
}
for _, ups := range req.PrivateUpstreams {
go s.checkDNS(ups, opts, checkPrivateUpstreamExc, wg, m)
}
wg.Wait()
result := map[string]string{}
m.Range(func(k, v any) bool {
// TODO(e.burkov): The upstreams used for both common and private
// resolving should be reported separately.
ups := k.(string)
status := v.(string)
result[ups] = status
return true
})
aghhttp.WriteJSONResponseOK(w, r, result)
aghhttp.WriteJSONResponseOK(w, r, cv.status())
}
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.

View File

@@ -72,11 +72,14 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{},
TCPListenAddrs: []*net.TCPAddr{},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
FallbackDNS: []string{"9.9.9.10"},
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
FallbackDNS: []string{"9.9.9.10"},
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 56,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ConfigModified: func() {},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
s.sysResolvers = &emptySysResolvers{}
@@ -150,10 +153,13 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{},
TCPListenAddrs: []*net.TCPAddr{},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 56,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ConfigModified: func() {},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
s.sysResolvers = &emptySysResolvers{}
@@ -179,11 +185,18 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
name: "blocking_mode_good",
wantSet: "",
}, {
name: "blocking_mode_bad",
wantSet: "blocking_ipv4 must be valid ipv4 on custom_ip blocking_mode",
name: "blocking_mode_bad",
wantSet: "validating dns config: " +
"blocking_ipv4 must be valid ipv4 on custom_ip blocking_mode",
}, {
name: "ratelimit",
wantSet: "",
}, {
name: "ratelimit_subnet_len",
wantSet: "",
}, {
name: "ratelimit_whitelist_not_ip",
wantSet: `decoding request: ParseAddr("not.ip"): unexpected character (at "not.ip")`,
}, {
name: "edns_cs_enabled",
wantSet: "",
@@ -206,24 +219,26 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
name: "upstream_mode_fastest_addr",
wantSet: "",
}, {
name: "upstream_dns_bad",
wantSet: `validating upstream servers: validating upstream "!!!": not an ip:port`,
name: "upstream_dns_bad",
wantSet: `validating dns config: ` +
`upstream servers: validating upstream "!!!": not an ip:port`,
}, {
name: "bootstraps_bad",
wantSet: `checking bootstrap a: invalid address: bootstrap a:53: ` +
wantSet: `validating dns config: checking bootstrap a: invalid address: not a bootstrap: ` +
`ParseAddr("a"): unable to parse IP`,
}, {
name: "cache_bad_ttl",
wantSet: `cache_ttl_min must be less or equal than cache_ttl_max`,
wantSet: `validating dns config: cache_ttl_min must be less or equal than cache_ttl_max`,
}, {
name: "upstream_mode_bad",
wantSet: `upstream_mode: incorrect value`,
wantSet: `validating dns config: upstream_mode: incorrect value "somethingelse"`,
}, {
name: "local_ptr_upstreams_good",
wantSet: "",
}, {
name: "local_ptr_upstreams_bad",
wantSet: `validating private upstream servers: checking domain-specific upstreams: ` +
wantSet: `validating dns config: ` +
`private upstream servers: checking domain-specific upstreams: ` +
`bad arpa domain name "non.arpa.": not a reversed ip network`,
}, {
name: "local_ptr_upstreams_null",
@@ -347,7 +362,7 @@ func TestValidateUpstreams(t *testing.T) {
set: []string{"123.3.7m"},
}, {
name: "invalid",
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": ` +
wantErr: `splitting upstream line "[/host.com]tls://dns.adguard.com": ` +
`missing separator`,
set: []string{"[/host.com]tls://dns.adguard.com"},
}, {
@@ -373,7 +388,7 @@ func TestValidateUpstreams(t *testing.T) {
},
}, {
name: "bad_domain",
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
wantErr: `splitting upstream line "[/!/]8.8.8.8": domain at index 0: ` +
`bad domain name "!": bad top-level domain name label "!": ` +
`bad top-level domain name label rune '!'`,
set: []string{"[/!/]8.8.8.8"},
@@ -461,25 +476,15 @@ func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (r
}
func TestServer_HandleTestUpstreamDNS(t *testing.T) {
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
hdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
err := w.WriteMsg(new(dns.Msg).SetReply(m))
require.NoError(testutil.PanicT{}, err)
})
badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) {
err := w.WriteMsg(new(dns.Msg))
require.NoError(testutil.PanicT{}, err)
})
goodUps := (&url.URL{
ups := (&url.URL{
Scheme: "tcp",
Host: newLocalUpstreamListener(t, 0, goodHandler).String(),
Host: newLocalUpstreamListener(t, 0, hdlr).String(),
}).String()
badUps := (&url.URL{
Scheme: "tcp",
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
}).String()
goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ")
const (
upsTimeout = 100 * time.Millisecond
@@ -488,7 +493,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
upstreamHost = "custom.localhost"
)
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
hostsListener := newLocalUpstreamListener(t, 0, hdlr)
hostsUps := (&url.URL{
Scheme: "tcp",
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
@@ -519,7 +524,9 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
srv.etcHosts = hc
startDeferStop(t, srv)
testCases := []struct {
@@ -527,43 +534,6 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
wantResp map[string]any
name string
}{{
body: map[string]any{
"upstream_dns": []string{goodUps},
},
wantResp: map[string]any{
goodUps: "OK",
},
name: "success",
}, {
body: map[string]any{
"upstream_dns": []string{badUps},
},
wantResp: map[string]any{
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
name: "broken",
}, {
body: map[string]any{
"upstream_dns": []string{goodUps, badUps},
},
wantResp: map[string]any{
goodUps: "OK",
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
name: "both",
}, {
body: map[string]any{
"upstream_dns": []string{"[/domain.example/]" + badUps},
},
wantResp: map[string]any{
badUps: `WARNING: couldn't communicate ` +
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
`dns: id mismatch`,
},
name: "domain_specific_error",
}, {
body: map[string]any{
"upstream_dns": []string{hostsUps},
},
@@ -573,63 +543,12 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
name: "etc_hosts",
}, {
body: map[string]any{
"fallback_dns": []string{goodUps},
"upstream_dns": []string{ups, "#this.is.comment"},
},
wantResp: map[string]any{
goodUps: "OK",
ups: "OK",
},
name: "fallback_success",
}, {
body: map[string]any{
"fallback_dns": []string{badUps},
},
wantResp: map[string]any{
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
name: "fallback_broken",
}, {
body: map[string]any{
"fallback_dns": []string{goodUps, "#this.is.comment"},
},
wantResp: map[string]any{
goodUps: "OK",
},
name: "fallback_comment_mix",
}, {
body: map[string]any{
"upstream_dns": []string{"[/domain.example/]" + goodUps + " " + badUps},
},
wantResp: map[string]any{
goodUps: "OK",
badUps: `WARNING: couldn't communicate ` +
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
`dns: id mismatch`,
},
name: "multiple_domain_specific_upstreams",
}, {
body: map[string]any{
"upstream_dns": []string{"[/domain.example/]/]1.2.3.4"},
},
wantResp: map[string]any{
"[/domain.example/]/]1.2.3.4": `wrong upstream format: ` +
`bad upstream for domain "[/domain.example/]/]1.2.3.4": ` +
`duplicated separator`,
},
name: "bad_specification",
}, {
body: map[string]any{
"upstream_dns": []string{"[/domain.example/]" + goodAndBadUps},
"fallback_dns": []string{"[/domain.example/]" + goodAndBadUps},
"private_upstream": []string{"[/domain.example/]" + goodAndBadUps},
},
wantResp: map[string]any{
goodUps: "OK",
badUps: `WARNING: couldn't communicate ` +
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
`dns: id mismatch`,
},
name: "all_different",
name: "comment_mix",
}}
for _, tc := range testCases {

View File

@@ -91,7 +91,7 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req)
default:
log.Error("dns: invalid blocking mode %q", mode)
log.Error("dnsforward: invalid blocking mode %q", mode)
return s.makeResponse(req)
}
@@ -112,7 +112,7 @@ func (s *Server) makeResponseCustomIP(
default:
// Generally shouldn't happen, since the types are checked in
// genDNSFilterMessage.
log.Error("dns: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
return s.makeResponse(req)
}
@@ -207,15 +207,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
var ans []dns.RR
switch req.Question[0].Qtype {
case dns.TypeA:
for _, ip := range ips {
if ip.Is4() {
ans = append(ans, s.genAnswerA(req, ip))
} else {
ans = nil
break
}
}
ans = s.genAnswersWithIPv4s(req, ips)
case dns.TypeAAAA:
for _, ip := range ips {
if ip.Is6() {
@@ -232,6 +224,23 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
return resp
}
// genAnswersWithIPv4s generates DNS A answers provided IPv4 addresses. If any
// of the IPs isn't an IPv4 address, genAnswersWithIPv4s logs a warning and
// returns nil,
func (s *Server) genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) {
for _, ip := range ips {
if !ip.Is4() {
log.Info("dnsforward: warning: ip %s is not ipv4 address", ip)
return nil
}
ans = append(ans, s.genAnswerA(req, ip))
}
return ans
}
// makeResponseNullIP creates a response with 0.0.0.0 for A requests, :: for
// AAAA requests, and an empty response for other types.
func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
@@ -253,7 +262,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
if newAddr == "" {
log.Printf("block host is not specified.")
log.Info("dnsforward: block host is not specified")
return s.genServerFailure(request)
}
@@ -276,14 +285,14 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
prx := s.proxy()
if prx == nil {
log.Debug("dns: %s", srvClosedErr)
log.Debug("dnsforward: %s", srvClosedErr)
return s.genServerFailure(request)
}
err = prx.Resolve(newContext)
if err != nil {
log.Printf("couldn't look up replacement host %q: %s", newAddr, err)
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
return s.genServerFailure(request)
}

View File

@@ -191,7 +191,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
defer log.Debug("dnsforward: finished processing initial")
pctx := dctx.proxyCtx
s.processClientIP(pctx.Addr)
s.processClientIP(pctx.Addr.Addr())
q := pctx.Req.Question[0]
qt := q.Qtype
@@ -228,9 +228,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
}
// processClientIP sends the client IP address to s.addrProc, if needed.
func (s *Server) processClientIP(addr net.Addr) {
clientIP := netutil.NetAddrToAddrPort(addr).Addr()
if clientIP == (netip.Addr{}) {
func (s *Server) processClientIP(addr netip.Addr) {
if !addr.IsValid() {
log.Info("dnsforward: warning: bad client addr %q", addr)
return
@@ -241,7 +240,7 @@ func (s *Server) processClientIP(addr net.Addr) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
s.addrProc.Process(clientIP)
s.addrProc.Process(addr)
}
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
@@ -351,12 +350,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
rc = resultCodeSuccess
var ip net.IP
if ip, _ = netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr); ip == nil {
return rc
}
dctx.isLocalClient = s.privateNets.Contains(ip)
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr().AsSlice())
return rc
}
@@ -831,14 +825,13 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
customUpsByClient := s.conf.GetCustomUpstreamByClient
if pctx.Addr == nil || customUpsByClient == nil {
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
return
}
// Use the ClientID first, since it has a higher priority.
id := stringutil.Coalesce(clientID, ipStringFromAddr(pctx.Addr))
upsConf, err := customUpsByClient(id)
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
@@ -847,9 +840,9 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
if upsConf != nil {
log.Debug("dnsforward: using custom upstreams for client %s", id)
}
pctx.CustomUpstreamConfig = upsConf
pctx.CustomUpstreamConfig = upsConf
}
}
// Apply filtering logic after we have received response from upstream servers

View File

@@ -81,6 +81,7 @@ func TestServer_ProcessInitial(t *testing.T) {
AAAADisabled: tc.aaaaDisabled,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
@@ -96,14 +97,14 @@ func TestServer_ProcessInitial(t *testing.T) {
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: createTestMessageWithType(tc.target, tc.qType),
Addr: testClientAddr,
Addr: testClientAddrPort,
RequestID: 1234,
},
}
gotRC := s.processInitial(dctx)
assert.Equal(t, tc.wantRC, gotRC)
assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr)
assert.Equal(t, testClientAddrPort.Addr(), gotAddr)
if tc.wantRCode > 0 {
gotResp := dctx.proxyCtx.Res
@@ -180,6 +181,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
AAAADisabled: tc.aaaaDisabled,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
@@ -199,7 +201,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
Proto: proxy.ProtoUDP,
Req: tc.req,
Res: resp,
Addr: testClientAddr,
Addr: testClientAddrPort,
},
}
@@ -369,6 +371,7 @@ func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled b
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
},
ServePlainDNS: true,
},
}
@@ -394,33 +397,27 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliIP net.IP
want assert.BoolAssertionFunc
name string
cliAddr netip.AddrPort
}{{
want: assert.True,
name: "local",
cliIP: net.IP{192, 168, 0, 1},
want: assert.True,
name: "local",
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
}, {
want: assert.False,
name: "external",
cliIP: net.IP{250, 249, 0, 1},
want: assert.False,
name: "external",
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
}, {
want: assert.False,
name: "invalid",
cliIP: net.IP{1, 2, 3, 4, 5},
}, {
want: assert.False,
name: "nil",
cliIP: nil,
want: assert.False,
name: "invalid",
cliAddr: netip.AddrPort{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
proxyCtx := &proxy.DNSContext{
Addr: &net.TCPAddr{
IP: tc.cliIP,
},
Addr: tc.cliAddr,
}
dctx := &dnsContext{
proxyCtx: proxyCtx,
@@ -699,6 +696,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, ups)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
startDeferStop(t, s)
@@ -707,31 +705,31 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
name string
want string
question net.IP
cliIP net.IP
cliAddr netip.AddrPort
wantLen int
}{{
name: "from_local_to_external",
want: "host1.example.net.",
question: net.IP{254, 253, 252, 251},
cliIP: net.IP{192, 168, 10, 10},
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"),
wantLen: 1,
}, {
name: "from_external_for_local",
want: "",
question: net.IP{192, 168, 1, 1},
cliIP: net.IP{254, 253, 252, 251},
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"),
wantLen: 0,
}, {
name: "from_local_for_local",
want: "some.local-client.",
question: net.IP{192, 168, 1, 1},
cliIP: net.IP{192, 168, 1, 2},
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"),
wantLen: 1,
}, {
name: "from_external_for_external",
want: "host1.example.net.",
question: net.IP{254, 253, 252, 251},
cliIP: net.IP{254, 253, 252, 255},
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"),
wantLen: 1,
}}
@@ -743,9 +741,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
pctx := &proxy.DNSContext{
Proto: proxy.ProtoTCP,
Req: req,
Addr: &net.TCPAddr{
IP: tc.cliIP,
},
Addr: tc.cliAddr,
}
t.Run(tc.name, func(t *testing.T) {
@@ -776,6 +772,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
},
aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
@@ -789,7 +786,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
var dnsCtx *dnsContext
setup := func(use bool) {
proxyCtx = &proxy.DNSContext{
Addr: testClientAddr,
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
}
dnsCtx = &dnsContext{

View File

@@ -10,9 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// Write Stats data and logs
@@ -25,13 +23,12 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
host := aghnet.NormalizeDomain(q.Name)
processingTime := time.Since(dctx.startTime)
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
ip = slices.Clone(ip)
ip := pctx.Addr.Addr().AsSlice()
s.anonymizer.Load()(ip)
log.Debug("dnsforward: client ip for stats and querylog: %s", ip)
ipStr := ip.String()
ipStr := pctx.Addr.Addr().String()
ids := []string{ipStr, dctx.clientID}
qt, cl := q.Qtype, q.Qclass

View File

@@ -1,7 +1,7 @@
package dnsforward
import (
"net"
"net/netip"
"testing"
"time"
@@ -65,7 +65,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name string
domain string
proto proxy.Proto
addr net.Addr
addr netip.AddrPort
clientID string
wantLogProto querylog.ClientProto
wantStatClient string
@@ -76,7 +76,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -87,7 +87,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_tls_clientid",
domain: domain,
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "cli42",
wantLogProto: querylog.ClientProtoDoT,
wantStatClient: "cli42",
@@ -98,7 +98,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_tls",
domain: domain,
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: querylog.ClientProtoDoT,
wantStatClient: "1.2.3.4",
@@ -109,7 +109,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_quic",
domain: domain,
proto: proxy.ProtoQUIC,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: querylog.ClientProtoDoQ,
wantStatClient: "1.2.3.4",
@@ -120,7 +120,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_https",
domain: domain,
proto: proxy.ProtoHTTPS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: querylog.ClientProtoDoH,
wantStatClient: "1.2.3.4",
@@ -131,7 +131,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_dnscrypt",
domain: domain,
proto: proxy.ProtoDNSCrypt,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: querylog.ClientProtoDNSCrypt,
wantStatClient: "1.2.3.4",
@@ -142,7 +142,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_filtered",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -153,7 +153,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_sb",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -164,7 +164,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_ss",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -175,7 +175,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_pc",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -186,10 +186,10 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_pc_empty_fqdn",
domain: ".",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 5}, Port: 1234},
addr: netip.MustParseAddrPort("4.3.2.1:1234"),
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.5",
wantStatClient: "4.3.2.1",
wantCode: resultCodeSuccess,
reason: filtering.FilteredParental,
wantStatResult: stats.RParental,

View File

@@ -19,6 +19,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
req := &dns.Msg{

View File

@@ -17,6 +17,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -53,6 +56,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -89,6 +95,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",

View File

@@ -22,6 +22,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -60,6 +63,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -99,6 +105,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "refused",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -138,6 +147,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -177,6 +189,98 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 6,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
"upstream_mode": "",
"cache_size": 0,
"cache_ttl_min": 0,
"cache_ttl_max": 0,
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"ratelimit_subnet_len": {
"req": {
"ratelimit": 12,
"ratelimit_subnet_len_ipv4": 32,
"ratelimit_subnet_len_ipv6": 128
},
"want": {
"upstream_dns": [
"8.8.8.8:53",
"8.8.4.4:53"
],
"upstream_dns_file": "",
"bootstrap_dns": [
"9.9.9.10",
"149.112.112.10",
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 12,
"ratelimit_subnet_len_ipv4": 32,
"ratelimit_subnet_len_ipv6": 128,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
"upstream_mode": "",
"cache_size": 0,
"cache_ttl_min": 0,
"cache_ttl_max": 0,
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"ratelimit_whitelist_not_ip": {
"req": {
"ratelimit_whitelist": [
"1.2.3.4",
"not.ip"
]
},
"want": {
"upstream_dns": [
"8.8.8.8:53",
"8.8.4.4:53"
],
"upstream_dns_file": "",
"bootstrap_dns": [
"9.9.9.10",
"149.112.112.10",
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -216,6 +320,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -257,6 +364,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -298,6 +408,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -337,6 +450,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -376,6 +492,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -415,6 +534,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -454,6 +576,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -495,6 +620,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -536,6 +664,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -576,6 +707,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -615,6 +749,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -656,6 +793,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -700,6 +840,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -739,6 +882,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -782,6 +928,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -821,6 +970,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
@@ -863,6 +1015,9 @@
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"ratelimit_subnet_len_ipv4": 24,
"ratelimit_subnet_len_ipv6": 56,
"ratelimit_whitelist": [],
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",

View File

@@ -1,16 +1,17 @@
package dnsforward
import (
"bytes"
"fmt"
"net"
"net/url"
"net/netip"
"os"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
@@ -18,6 +19,28 @@ import (
"golang.org/x/exp/slices"
)
const (
// errNotDomainSpecific is returned when the upstream should be
// domain-specific, but isn't.
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
// errMissingSeparator is returned when the domain-specific part of the
// upstream configuration line isn't closed.
errMissingSeparator errors.Error = "missing separator"
// errDupSeparator is returned when the domain-specific part of the upstream
// configuration line contains more than one ending separator.
errDupSeparator errors.Error = "duplicated separator"
// errNoDefaultUpstreams is returned when there are no default upstreams
// specified in the upstream configuration.
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
// errWrongResponse is returned when the checked upstream replies in an
// unexpected way.
errWrongResponse errors.Error = "wrong response"
)
// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (s *Server) loadUpstreams() (upstreams []string, err error) {
@@ -39,7 +62,7 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) {
}
// prepareUpstreamSettings sets upstream DNS server settings.
func (s *Server) prepareUpstreamSettings() (err error) {
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
// Load upstreams either from the file, or from the settings
var upstreams []string
upstreams, err = s.loadUpstreams()
@@ -48,7 +71,7 @@ func (s *Server) prepareUpstreamSettings() (err error) {
}
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
@@ -92,178 +115,9 @@ func (s *Server) prepareUpstreamConfig(
uc.Upstreams = defaultUpstreamConfig.Upstreams
}
// dnsFilter can be nil during application update.
if s.dnsFilter != nil {
err = s.replaceUpstreamsWithHosts(uc, opts)
if err != nil {
return nil, fmt.Errorf("resolving upstreams with hosts: %w", err)
}
}
return uc, nil
}
// replaceUpstreamsWithHosts replaces unique upstreams with their resolved
// versions based on the system hosts file.
//
// TODO(e.burkov): This should be performed inside dnsproxy, which should
// actually consider /etc/hosts. See TODO on [aghnet.HostsContainer].
func (s *Server) replaceUpstreamsWithHosts(
upsConf *proxy.UpstreamConfig,
opts *upstream.Options,
) (err error) {
resolved := map[string]*upstream.Options{}
err = s.resolveUpstreamsWithHosts(resolved, upsConf.Upstreams, opts)
if err != nil {
return fmt.Errorf("resolving upstreams: %w", err)
}
hosts := maps.Keys(upsConf.DomainReservedUpstreams)
// TODO(e.burkov): Think of extracting sorted range into an util function.
slices.Sort(hosts)
for _, host := range hosts {
err = s.resolveUpstreamsWithHosts(resolved, upsConf.DomainReservedUpstreams[host], opts)
if err != nil {
return fmt.Errorf("resolving upstreams reserved for %s: %w", host, err)
}
}
hosts = maps.Keys(upsConf.SpecifiedDomainUpstreams)
slices.Sort(hosts)
for _, host := range hosts {
err = s.resolveUpstreamsWithHosts(resolved, upsConf.SpecifiedDomainUpstreams[host], opts)
if err != nil {
return fmt.Errorf("resolving upstreams specific for %s: %w", host, err)
}
}
return nil
}
// resolveUpstreamsWithHosts resolves the IP addresses of each of the upstreams
// and replaces those both in upstreams and resolved. Upstreams that failed to
// resolve are placed to resolved as-is. This function only returns error of
// upstreams closing.
func (s *Server) resolveUpstreamsWithHosts(
resolved map[string]*upstream.Options,
upstreams []upstream.Upstream,
opts *upstream.Options,
) (err error) {
for i := range upstreams {
u := upstreams[i]
addr := u.Address()
host := extractUpstreamHost(addr)
withIPs, ok := resolved[host]
if !ok {
recs := s.dnsFilter.EtcHostsRecords(host)
if len(recs) == 0 {
resolved[host] = nil
return nil
}
withIPs = opts.Clone()
withIPs.ServerIPAddrs = make([]net.IP, 0, len(recs))
for _, rec := range recs {
withIPs.ServerIPAddrs = append(withIPs.ServerIPAddrs, rec.Addr.AsSlice())
}
sortNetIPAddrs(withIPs.ServerIPAddrs, opts.PreferIPv6)
resolved[host] = withIPs
} else if withIPs == nil {
continue
}
if err = u.Close(); err != nil {
return fmt.Errorf("closing upstream %s: %w", addr, err)
}
upstreams[i], err = upstream.AddressToUpstream(addr, withIPs)
if err != nil {
return fmt.Errorf("replacing upstream %s with resolved %s: %w", addr, host, err)
}
log.Debug("dnsforward: using %s for %s", withIPs.ServerIPAddrs, upstreams[i].Address())
}
return nil
}
// extractUpstreamHost returns the hostname of addr without port with an
// assumption that any address passed here has already been successfully parsed
// by [upstream.AddressToUpstream]. This function essentially mirrors the logic
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
func extractUpstreamHost(addr string) (host string) {
var err error
if strings.Contains(addr, "://") {
var u *url.URL
u, err = url.Parse(addr)
if err != nil {
log.Debug("dnsforward: parsing upstream %s: %s", addr, err)
return addr
}
return u.Hostname()
}
// Probably, plain UDP upstream defined by address or address:port.
host, err = netutil.SplitHost(addr)
if err != nil {
return addr
}
return host
}
// sortNetIPAddrs sorts addrs in accordance with the protocol preferences.
// Invalid addresses are sorted near the end.
//
// TODO(e.burkov): This function taken from dnsproxy, which also already
// contains a few similar functions. Think of moving to golibs.
func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) {
l := len(addrs)
if l <= 1 {
return
}
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (res int) {
switch len(addrA) {
case net.IPv4len, net.IPv6len:
switch len(addrB) {
case net.IPv4len, net.IPv6len:
// Go on.
default:
return -1
}
default:
return 1
}
// Treat IPv6-mapped IPv4 addresses as IPv6 addresses.
aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil
if aIs4 == bIs4 {
return bytes.Compare(addrA, addrB)
}
if aIs4 {
if preferIPv6 {
return 1
}
return -1
}
if preferIPv6 {
return -1
}
return 1
})
}
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
@@ -295,3 +149,221 @@ func setProxyUpstreamMode(
conf.UpstreamMode = proxy.UModeLoadBalance
}
}
// createBootstrap returns a bootstrap resolver based on the configuration of s.
// boots are the upstream resolvers that should be closed after use. r is the
// actual bootstrap resolver, which may include the system hosts.
//
// TODO(e.burkov): This function currently returns a resolver and a slice of
// the upstream resolvers, which are essentially the same. boots are returned
// for being able to close them afterwards, but it introduces an implicit
// contract that r could only be used before that. Anyway, this code should
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
// and be used here.
func (s *Server) createBootstrap(
addrs []string,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
}
boots, err = aghnet.ParseBootstraps(addrs, opts)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return nil, nil, err
}
var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, b)
}
if s.etcHosts != nil {
r = upstream.ConsequentResolver{s.etcHosts, parallel}
} else {
r = parallel
}
return r, boots, nil
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// newUpstreamConfig validates upstreams and returns an appropriate upstream
// configuration or nil if it can't be built.
//
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
// slice already so that this function may be considered useless.
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
// No need to validate comments and empty lines.
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
if len(upstreams) == 0 {
// Consider this case valid since it means the default server should be
// used.
return nil, nil
}
err = validateUpstreamConfig(upstreams)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: net.DefaultResolver,
Timeout: DefaultTimeout,
},
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
} else if len(conf.Upstreams) == 0 {
return nil, errNoDefaultUpstreams
}
return conf, nil
}
// validateUpstreamConfig validates each upstream from the upstream
// configuration and returns an error if any upstream is invalid.
//
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
func validateUpstreamConfig(conf []string) (err error) {
for _, u := range conf {
var ups []string
var isSpecific bool
ups, isSpecific, err = splitUpstreamLine(u)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
for _, addr := range ups {
_, err = validateUpstream(addr, isSpecific)
if err != nil {
return fmt.Errorf("validating upstream %q: %w", addr, err)
}
}
}
return nil
}
// ValidateUpstreams validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified.
//
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
func ValidateUpstreams(upstreams []string) (err error) {
_, err = newUpstreamConfig(upstreams)
return err
}
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := newUpstreamConfig(upstreams)
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
if conf == nil {
return nil
}
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
var errs []error
for _, domain := range keys {
var subnet netip.Prefix
subnet, err = extractARPASubnet(domain)
if err != nil {
errs = append(errs, err)
continue
}
if !privateNets.Contains(subnet.Addr().AsSlice()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
)
}
}
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}
// protocols are the supported URL schemes for upstreams.
var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"}
// validateUpstream returns an error if u alongside with domains is not a valid
// upstream configuration. useDefault is true if the upstream is
// domain-specific and is configured to point at the default upstream server
// which is validated separately. The upstream is considered domain-specific
// only if domains is at least not nil.
func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) {
// The special server address '#' means that default server must be used.
if useDefault = u == "#" && isSpecific; useDefault {
return useDefault, nil
}
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
if proto, _, ok := strings.Cut(u, "://"); ok {
if !slices.Contains(protocols, proto) {
return false, fmt.Errorf("bad protocol %q", proto)
}
} else if _, err = netip.ParseAddr(u); err == nil {
return false, nil
} else if _, err = netip.ParseAddrPort(u); err == nil {
return false, nil
}
return false, err
}
// splitUpstreamLine returns the upstreams and the specified domains. domains
// is nil when the upstream is not domains-specific. Otherwise it may also be
// empty.
func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) {
if !strings.HasPrefix(upstreamStr, "[/") {
return []string{upstreamStr}, false, nil
}
defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }()
doms, ups, found := strings.Cut(upstreamStr[2:], "/]")
if !found {
return nil, false, errMissingSeparator
} else if strings.Contains(ups, "/]") {
return nil, false, errDupSeparator
}
for i, host := range strings.Split(doms, "/") {
if host == "" {
continue
}
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
if err != nil {
return nil, false, fmt.Errorf("domain at index %d: %w", i, err)
}
isSpecific = true
}
return strings.Fields(ups), isSpecific, nil
}

View File

@@ -0,0 +1,223 @@
package dnsforward
import (
"net"
"net/url"
"strings"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUpstreamConfigValidator(t *testing.T) {
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
err := w.WriteMsg(new(dns.Msg).SetReply(m))
require.NoError(testutil.PanicT{}, err)
})
badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) {
err := w.WriteMsg(new(dns.Msg))
require.NoError(testutil.PanicT{}, err)
})
goodUps := (&url.URL{
Scheme: "tcp",
Host: newLocalUpstreamListener(t, 0, goodHandler).String(),
}).String()
badUps := (&url.URL{
Scheme: "tcp",
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
}).String()
goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ")
// upsTimeout restricts the checking process to prevent the test from
// hanging.
const upsTimeout = 100 * time.Millisecond
testCases := []struct {
want map[string]string
name string
general []string
fallback []string
private []string
}{{
name: "success",
general: []string{goodUps},
want: map[string]string{
goodUps: "OK",
},
}, {
name: "broken",
general: []string{badUps},
want: map[string]string{
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
}, {
name: "both",
general: []string{goodUps, badUps, goodUps},
want: map[string]string{
goodUps: "OK",
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
}, {
name: "domain_specific_error",
general: []string{"[/domain.example/]" + badUps},
want: map[string]string{
badUps: `WARNING: couldn't communicate ` +
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
`dns: id mismatch`,
},
}, {
name: "fallback_success",
fallback: []string{goodUps},
want: map[string]string{
goodUps: "OK",
},
}, {
name: "fallback_broken",
fallback: []string{badUps},
want: map[string]string{
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
}, {
name: "multiple_domain_specific_upstreams",
general: []string{"[/domain.example/]" + goodAndBadUps},
want: map[string]string{
goodUps: "OK",
badUps: `WARNING: couldn't communicate ` +
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
`dns: id mismatch`,
},
}, {
name: "bad_specification",
general: []string{"[/domain.example/]/]1.2.3.4"},
want: map[string]string{
"[/domain.example/]/]1.2.3.4": `splitting upstream line ` +
`"[/domain.example/]/]1.2.3.4": duplicated separator`,
},
}, {
name: "all_different",
general: []string{"[/domain.example/]" + goodAndBadUps},
fallback: []string{"[/domain.example/]" + goodAndBadUps},
private: []string{"[/domain.example/]" + goodAndBadUps},
want: map[string]string{
goodUps: "OK",
badUps: `WARNING: couldn't communicate ` +
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
`dns: id mismatch`,
},
}, {
name: "bad_specific_domains",
general: []string{"[/example/]/]" + goodUps},
fallback: []string{"[/example/" + goodUps},
private: []string{"[/example//bad.123/]" + goodUps},
want: map[string]string{
`[/example/]/]` + goodUps: `splitting upstream line ` +
`"[/example/]/]` + goodUps + `": duplicated separator`,
`[/example/` + goodUps: `splitting upstream line ` +
`"[/example/` + goodUps + `": missing separator`,
`[/example//bad.123/]` + goodUps: `splitting upstream line ` +
`"[/example//bad.123/]` + goodUps + `": domain at index 2: ` +
`bad domain name "bad.123": ` +
`bad top-level domain name label "123": all octets are numeric`,
},
}, {
name: "non-specific_default",
general: []string{
"#",
"[/example/]#",
},
want: map[string]string{
"#": "not a domain-specific upstream",
},
}, {
name: "bad_proto",
general: []string{
"bad://1.2.3.4",
},
want: map[string]string{
"bad://1.2.3.4": `bad protocol "bad"`,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cv := newUpstreamConfigValidator(tc.general, tc.fallback, tc.private, &upstream.Options{
Timeout: upsTimeout,
Bootstrap: net.DefaultResolver,
})
cv.check()
cv.close()
assert.Equal(t, tc.want, cv.status())
})
}
}
func TestUpstreamConfigValidator_Check_once(t *testing.T) {
type signal = struct{}
reqCh := make(chan signal)
hdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
pt := testutil.PanicT{}
err := w.WriteMsg(new(dns.Msg).SetReply(m))
require.NoError(pt, err)
testutil.RequireSend(pt, reqCh, signal{}, testTimeout)
})
addr := (&url.URL{
Scheme: "tcp",
Host: newLocalUpstreamListener(t, 0, hdlr).String(),
}).String()
twoAddrs := strings.Join([]string{addr, addr}, " ")
wantStatus := map[string]string{
addr: "OK",
}
testCases := []struct {
name string
ups []string
}{{
name: "common",
ups: []string{addr, addr, addr},
}, {
name: "domain-specific",
ups: []string{"[/one.example/]" + addr, "[/two.example/]" + twoAddrs},
}, {
name: "both",
ups: []string{addr, "[/one.example/]" + addr, addr, "[/two.example/]" + twoAddrs},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cv := newUpstreamConfigValidator(tc.ups, nil, nil, &upstream.Options{
Timeout: testTimeout,
})
go func() {
cv.check()
testutil.RequireSend(testutil.PanicT{}, reqCh, signal{}, testTimeout)
}()
// Wait for the only request to be sent.
testutil.RequireReceive(t, reqCh, testTimeout)
// Wait for the check to finish.
testutil.RequireReceive(t, reqCh, testTimeout)
cv.close()
require.Equal(t, wantStatus, cv.status())
})
}
}