all: sync with master; upd chlog
This commit is contained in:
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -436,102 +435,6 @@ func (s *Server) initDefaultSettings() {
|
||||
}
|
||||
}
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
if !http3 {
|
||||
return upstream.DefaultHTTPVersions
|
||||
}
|
||||
|
||||
return []upstream.HTTPVersion{
|
||||
upstream.HTTPVersion3,
|
||||
upstream.HTTPVersion2,
|
||||
upstream.HTTPVersion11,
|
||||
}
|
||||
}
|
||||
|
||||
// prepareUpstreamSettings - prepares upstream DNS server settings
|
||||
func (s *Server) prepareUpstreamSettings() error {
|
||||
// We're setting a customized set of RootCAs. The reason is that Go default
|
||||
// mechanism of loading TLS roots does not always work properly on some
|
||||
// routers so we're loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
upstream.RootCAs = s.conf.TLSv12Roots
|
||||
upstream.CipherSuites = s.conf.TLSCiphers
|
||||
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
if s.conf.UpstreamDNSFileName != "" {
|
||||
data, err := os.ReadFile(s.conf.UpstreamDNSFileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading upstream from file: %w", err)
|
||||
}
|
||||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
|
||||
log.Debug("dns: using %d upstream servers from file %s", len(upstreams), s.conf.UpstreamDNSFileName)
|
||||
} else {
|
||||
upstreams = s.conf.UpstreamDNS
|
||||
}
|
||||
|
||||
httpVersions := UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams)
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
upstreamConfig, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: httpVersions,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing upstream config: %w", err)
|
||||
}
|
||||
|
||||
if len(upstreamConfig.Upstreams) == 0 {
|
||||
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
|
||||
var uc *proxy.UpstreamConfig
|
||||
uc, err = proxy.ParseUpstreamsConfig(
|
||||
defaultDNS,
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: httpVersions,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing default upstreams: %w", err)
|
||||
}
|
||||
|
||||
upstreamConfig.Upstreams = uc.Upstreams
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig = upstreamConfig
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setProxyUpstreamMode sets the upstream mode and related settings in conf
|
||||
// based on provided parameters.
|
||||
func setProxyUpstreamMode(
|
||||
conf *proxy.Config,
|
||||
allServers bool,
|
||||
fastestAddr bool,
|
||||
fastestTimeout time.Duration,
|
||||
) {
|
||||
if allServers {
|
||||
conf.UpstreamMode = proxy.UModeParallel
|
||||
} else if fastestAddr {
|
||||
conf.UpstreamMode = proxy.UModeFastestAddr
|
||||
conf.FastestPingTimeout = fastestTimeout
|
||||
} else {
|
||||
conf.UpstreamMode = proxy.UModeLoadBalance
|
||||
}
|
||||
}
|
||||
|
||||
// prepareIpsetListSettings reads and prepares the ipset configuration either
|
||||
// from a file or from the data in the configuration file.
|
||||
func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
@@ -540,6 +443,7 @@ func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
return s.ipset.init(s.conf.IpsetList)
|
||||
}
|
||||
|
||||
// #nosec G304 -- Trust the path explicitly given by the user.
|
||||
data, err := os.ReadFile(fn)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -145,10 +145,13 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
|
||||
// processRecursion checks the incoming request and halts its handling by
|
||||
// answering NXDOMAIN if s has tried to resolve it recently.
|
||||
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing recursion")
|
||||
defer log.Debug("dnsforward: finished processing recursion")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
|
||||
log.Debug("recursion detected resolving %q", msg.Question[0].Name)
|
||||
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
@@ -158,10 +161,13 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// processInitial terminates the following processing for some requests if
|
||||
// needed and enriches the ctx with some client-specific information.
|
||||
// needed and enriches dctx with some client-specific information.
|
||||
//
|
||||
// TODO(e.burkov): Decompose into less general processors.
|
||||
func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing initial")
|
||||
defer log.Debug("dnsforward: finished processing initial")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
@@ -282,6 +288,9 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
//
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
|
||||
func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing ddr")
|
||||
defer log.Debug("dnsforward: finished processing ddr")
|
||||
|
||||
if !s.conf.HandleDDR {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -375,6 +384,9 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
// processDetermineLocal determines if the client's IP address is from locally
|
||||
// served network and saves the result into the context.
|
||||
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local detection")
|
||||
defer log.Debug("dnsforward: finished processing local detection")
|
||||
|
||||
rc = resultCodeSuccess
|
||||
|
||||
var ip net.IP
|
||||
@@ -405,6 +417,9 @@ func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) {
|
||||
//
|
||||
// TODO(a.garipov): Adapt to AAAA as well.
|
||||
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing dhcp hosts")
|
||||
defer log.Debug("dnsforward: finished processing dhcp hosts")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
@@ -544,6 +559,9 @@ func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
|
||||
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
|
||||
// in locally served network from external clients.
|
||||
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local restriction")
|
||||
defer log.Debug("dnsforward: finished processing local restriction")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
@@ -613,6 +631,9 @@ func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) {
|
||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||
// DHCP server.
|
||||
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing dhcp addrs")
|
||||
defer log.Debug("dnsforward: finished processing dhcp addrs")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
return resultCodeSuccess
|
||||
@@ -658,6 +679,9 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
// processLocalPTR responds to PTR requests if the target IP is detected to be
|
||||
// inside the local network and the query was not answered from DHCP.
|
||||
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local ptr")
|
||||
defer log.Debug("dnsforward: finished processing local ptr")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
return resultCodeSuccess
|
||||
@@ -692,6 +716,9 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// Apply filtering logic
|
||||
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering before req")
|
||||
defer log.Debug("dnsforward: finished processing filtering before req")
|
||||
|
||||
if ctx.proxyCtx.Res != nil {
|
||||
// Go on since the response is already set.
|
||||
return resultCodeSuccess
|
||||
@@ -725,6 +752,9 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing upstream")
|
||||
defer log.Debug("dnsforward: finished processing upstream")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
@@ -871,6 +901,9 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
|
||||
// Apply filtering logic after we have received response from upstream servers
|
||||
func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering after resp")
|
||||
defer log.Debug("dnsforward: finished processing filtering after resp")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
switch res := dctx.result; res.Reason {
|
||||
case filtering.NotFilteredAllowList:
|
||||
|
||||
@@ -48,12 +48,33 @@ var webRegistered bool
|
||||
|
||||
// hostToIPTable is a convenient type alias for tables of host names to an IP
|
||||
// address.
|
||||
//
|
||||
// TODO(e.burkov): Use the [DHCP] interface instead.
|
||||
type hostToIPTable = map[string]netip.Addr
|
||||
|
||||
// ipToHostTable is a convenient type alias for tables of IP addresses to their
|
||||
// host names. For example, for use with PTR queries.
|
||||
//
|
||||
// TODO(e.burkov): Use the [DHCP] interface instead.
|
||||
type ipToHostTable = map[netip.Addr]string
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data needed in this package.
|
||||
type DHCP interface {
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
// address. The address will be netip.Addr{} if there is no such client,
|
||||
// due to an assumption that a DHCP client must always have an IP address.
|
||||
HostByIP(ip netip.Addr) (host string)
|
||||
|
||||
// IPByHost returns the IP address of the DHCP client with the given
|
||||
// hostname. The hostname will be an empty string if there is no such
|
||||
// client, due to an assumption that a DHCP client must always have a
|
||||
// hostname, either set by the client or assigned automatically.
|
||||
IPByHost(host string) (ip netip.Addr)
|
||||
|
||||
// Enabled returns true if DHCP provides information about clients.
|
||||
Enabled() (ok bool)
|
||||
}
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
// Example:
|
||||
@@ -215,7 +236,7 @@ func (s *Server) Close() {
|
||||
s.dnsProxy = nil
|
||||
|
||||
if err := s.ipset.close(); err != nil {
|
||||
log.Error("closing ipset: %s", err)
|
||||
log.Error("dnsforward: closing ipset: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -443,21 +464,17 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", localAddrs)
|
||||
|
||||
var upsConfig *proxy.UpstreamConfig
|
||||
upsConfig, err = proxy.ParseUpstreamsConfig(
|
||||
localAddrs,
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
upsConfig, err := s.prepareUpstreamConfig(localAddrs, nil, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing upstreams: %w", err)
|
||||
return fmt.Errorf("parsing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
@@ -489,7 +506,8 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
err = s.prepareUpstreamSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing upstream settings: %w", err)
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
var proxyConfig proxy.Config
|
||||
@@ -656,7 +674,9 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
log.Print("Start reconfiguring the server")
|
||||
log.Info("dnsforward: starting reconfiguring server")
|
||||
defer log.Info("dnsforward: finished reconfiguring server")
|
||||
|
||||
err := s.stopLocked()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not reconfigure the server: %w", err)
|
||||
@@ -708,13 +728,13 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool,
|
||||
// Allow if at least one of the checks allows in allowlist mode, but block
|
||||
// if at least one of the checks blocks in blocklist mode.
|
||||
if allowlistMode && blockedByIP && blockedByClientID {
|
||||
log.Debug("client %v (id %q) is not in access allowlist", ip, clientID)
|
||||
log.Debug("dnsforward: client %v (id %q) is not in access allowlist", ip, clientID)
|
||||
|
||||
// Return now without substituting the empty rule for the
|
||||
// clientID because the rule can't be empty here.
|
||||
return true, rule
|
||||
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||
log.Debug("client %v (id %q) is in access blocklist", ip, clientID)
|
||||
log.Debug("dnsforward: client %v (id %q) is in access blocklist", ip, clientID)
|
||||
|
||||
blocked = true
|
||||
}
|
||||
|
||||
@@ -53,14 +53,14 @@ func (s *Server) beforeRequestHandler(
|
||||
// getClientRequestFilteringSettings looks up client filtering settings using
|
||||
// the client's IP address and ID, if any, from dctx.
|
||||
func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings {
|
||||
setts := s.dnsFilter.GetConfig()
|
||||
setts := s.dnsFilter.Settings()
|
||||
setts.ProtectionEnabled = dctx.protectionEnabled
|
||||
if s.conf.FilterHandler != nil {
|
||||
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
|
||||
s.conf.FilterHandler(ip, dctx.clientID, &setts)
|
||||
s.conf.FilterHandler(ip, dctx.clientID, setts)
|
||||
}
|
||||
|
||||
return &setts
|
||||
return setts
|
||||
}
|
||||
|
||||
// filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the
|
||||
|
||||
@@ -633,61 +633,70 @@ func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// checkDNS checks the upstream server defined by upstreamConfigStr using
|
||||
// healthCheck for actually exchange messages. It uses bootstrap to resolve the
|
||||
// upstream's address.
|
||||
func checkDNS(
|
||||
upstreamConfigStr string,
|
||||
bootstrap []string,
|
||||
bootstrapPrefIPv6 bool,
|
||||
timeout time.Duration,
|
||||
healthCheck healthCheckFunc,
|
||||
) (err error) {
|
||||
if IsCommentOrEmpty(upstreamConfigStr) {
|
||||
return nil
|
||||
// parseUpstreamLine parses line and creates the [upstream.Upstream] using opts
|
||||
// and information from [s.dnsFilter.EtcHosts]. It returns an error if the line
|
||||
// is not a valid upstream line, see [upstream.AddressToUpstream]. It's a
|
||||
// caller's responsibility to close u.
|
||||
func (s *Server) parseUpstreamLine(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
) (u upstream.Upstream, specific bool, err error) {
|
||||
// Separate upstream from domains list.
|
||||
upstreamAddr, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
// Separate upstream from domains list.
|
||||
upstreamAddr, domains, err := separateUpstream(upstreamConfigStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
specific = len(domains) > 0
|
||||
|
||||
useDefault, err := validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
return nil, specific, fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(bootstrap) == 0 {
|
||||
bootstrap = defaultBootstrap
|
||||
return nil, specific, nil
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
PreferIPv6: bootstrapPrefIPv6,
|
||||
})
|
||||
opts = &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
Timeout: opts.Timeout,
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
|
||||
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
|
||||
resolved := s.resolveUpstreamHost(extractUpstreamHost(upstreamAddr))
|
||||
sortNetIPAddrs(resolved, opts.PreferIPv6)
|
||||
opts.ServerIPAddrs = resolved
|
||||
}
|
||||
u, err = upstream.AddressToUpstream(upstreamAddr, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||
return nil, specific, fmt.Errorf("creating upstream for %q: %w", upstreamAddr, err)
|
||||
}
|
||||
|
||||
return u, specific, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkDNS(line string, opts *upstream.Options, check healthCheckFunc) (err error) {
|
||||
if IsCommentOrEmpty(line) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var u upstream.Upstream
|
||||
var specific bool
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
u, specific, err = s.parseUpstreamLine(line, opts)
|
||||
if err != nil || u == nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
if err = healthCheck(u); err != nil {
|
||||
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
|
||||
if domains != nil {
|
||||
return domainSpecificTestError{error: err}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstream %q is ok", upstreamAddr)
|
||||
|
||||
return nil
|
||||
return check(u)
|
||||
}
|
||||
|
||||
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -699,47 +708,54 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]string{}
|
||||
bootstraps := req.BootstrapDNS
|
||||
bootstrapPrefIPv6 := s.conf.BootstrapPreferIPv6
|
||||
timeout := s.conf.UpstreamTimeout
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: req.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
}
|
||||
if len(opts.Bootstrap) == 0 {
|
||||
opts.Bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
type upsCheckResult = struct {
|
||||
res string
|
||||
err error
|
||||
host string
|
||||
}
|
||||
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
upsNum := len(req.Upstreams) + len(req.PrivateUpstreams)
|
||||
result := make(map[string]string, upsNum)
|
||||
resCh := make(chan upsCheckResult, upsNum)
|
||||
|
||||
checkUps := func(ups string, healthCheck healthCheckFunc) {
|
||||
res := upsCheckResult{
|
||||
host: ups,
|
||||
}
|
||||
defer func() { resCh <- res }()
|
||||
|
||||
checkErr := checkDNS(ups, bootstraps, bootstrapPrefIPv6, timeout, healthCheck)
|
||||
if checkErr != nil {
|
||||
res.res = checkErr.Error()
|
||||
} else {
|
||||
res.res = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go checkUps(ups, checkDNSUpstreamExc)
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go checkUps(ups, checkPrivateUpstreamExc)
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkPrivateUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
|
||||
for i := 0; i < upsNum; i++ {
|
||||
pair := <-resCh
|
||||
// TODO(e.burkov): The upstreams used for both common and private
|
||||
// resolving should be reported separately.
|
||||
result[pair.host] = pair.res
|
||||
pair := <-resCh
|
||||
if pair.err != nil {
|
||||
result[pair.host] = pair.err.Error()
|
||||
} else {
|
||||
result[pair.host] = "OK"
|
||||
}
|
||||
}
|
||||
close(resCh)
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, result)
|
||||
}
|
||||
|
||||
@@ -13,10 +13,12 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -280,6 +282,10 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateUpstreams(t *testing.T) {
|
||||
const sdnsStamp = `sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_J` +
|
||||
`S3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczE` +
|
||||
`uYWRndWFyZC5jb20`
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
@@ -300,7 +306,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/host/]" + sdnsStamp,
|
||||
},
|
||||
}, {
|
||||
name: "with_default",
|
||||
@@ -310,7 +316,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/host/]" + sdnsStamp,
|
||||
"8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
@@ -326,9 +332,10 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
wantErr: `validating upstream "123.3.7m": not an ip:port`,
|
||||
set: []string{"123.3.7m"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
name: "invalid",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": ` +
|
||||
`missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "[host.ru]#": not an ip:port`,
|
||||
@@ -340,14 +347,14 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"1.1.1.1",
|
||||
"tls://1.1.1.1",
|
||||
"https://dns.adguard.com/dns-query",
|
||||
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
sdnsStamp,
|
||||
"udp://dns.google",
|
||||
"udp://8.8.8.8",
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/host/]" + sdnsStamp,
|
||||
"[/пример.рф/]8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
@@ -418,27 +425,28 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newLocalUpstreamListener(t *testing.T, port int, handler dns.Handler) (real net.Addr) {
|
||||
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
|
||||
t.Helper()
|
||||
|
||||
startCh := make(chan struct{})
|
||||
upsSrv := &dns.Server{
|
||||
Addr: netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(port)).String(),
|
||||
Addr: netip.AddrPortFrom(netutil.IPv4Localhost(), port).String(),
|
||||
Net: "tcp",
|
||||
Handler: handler,
|
||||
NotifyStartedFunc: func() { close(startCh) },
|
||||
}
|
||||
go func() {
|
||||
t := testutil.PanicT{}
|
||||
|
||||
err := upsSrv.ListenAndServe()
|
||||
require.NoError(t, err)
|
||||
require.NoError(testutil.PanicT{}, err)
|
||||
}()
|
||||
|
||||
<-startCh
|
||||
testutil.CleanupAndRequireSuccess(t, upsSrv.Shutdown)
|
||||
|
||||
return upsSrv.Listener.Addr()
|
||||
return testutil.RequireTypeAssert[*net.TCPAddr](t, upsSrv.Listener.Addr()).AddrPort()
|
||||
}
|
||||
|
||||
func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
||||
require.NoError(testutil.PanicT{}, err)
|
||||
@@ -457,9 +465,38 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
||||
}).String()
|
||||
|
||||
const upsTimeout = 100 * time.Millisecond
|
||||
const (
|
||||
upsTimeout = 100 * time.Millisecond
|
||||
|
||||
srv := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
hostsFileName = "hosts"
|
||||
upstreamHost = "custom.localhost"
|
||||
)
|
||||
|
||||
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
|
||||
hostsUps := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: netutil.JoinHostPort(upstreamHost, int(hostsListener.Port())),
|
||||
}).String()
|
||||
|
||||
hc, err := aghnet.NewHostsContainer(
|
||||
filtering.SysHostsListID,
|
||||
fstest.MapFS{
|
||||
hostsFileName: &fstest.MapFile{
|
||||
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),
|
||||
},
|
||||
},
|
||||
&aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(_ string) (err error) { return nil },
|
||||
OnClose: func() (err error) { return nil },
|
||||
},
|
||||
hostsFileName,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UpstreamTimeout: upsTimeout,
|
||||
@@ -486,8 +523,7 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
"upstream_dns": []string{badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
badUps: `upstream "` + badUps + `" fails to exchange: ` +
|
||||
`couldn't communicate with upstream: exchanging with ` +
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "broken",
|
||||
@@ -497,20 +533,40 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
badUps: `upstream "` + badUps + `" fails to exchange: ` +
|
||||
`couldn't communicate with upstream: exchanging with ` +
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "both",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]" + badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
"[/domain.example/]" + badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
name: "domain_specific_error",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{hostsUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
hostsUps: "OK",
|
||||
},
|
||||
name: "etc_hosts",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reqBody, err := json.Marshal(tc.body)
|
||||
var reqBody []byte
|
||||
reqBody, err = json.Marshal(tc.body)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.handleTestUpstreamDNS(w, r)
|
||||
@@ -538,11 +594,15 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"upstream_dns": []string{sleepyUps},
|
||||
}
|
||||
reqBody, err := json.Marshal(req)
|
||||
|
||||
var reqBody []byte
|
||||
reqBody, err = json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.handleTestUpstreamDNS(w, r)
|
||||
|
||||
@@ -110,6 +110,9 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
|
||||
|
||||
// process adds the resolved IP addresses to the domain's ipsets, if any.
|
||||
func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: ipset: started processing")
|
||||
defer log.Debug("dnsforward: ipset: finished processing")
|
||||
|
||||
if c.skipIpsetProcessing(dctx) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -125,12 +128,12 @@ func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
|
||||
n, err := c.ipsetMgr.Add(host, ip4s, ip6s)
|
||||
if err != nil {
|
||||
// Consider ipset errors non-critical to the request.
|
||||
log.Error("ipset: adding host ips: %s", err)
|
||||
log.Error("dnsforward: ipset: adding host ips: %s", err)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("ipset: added %d new ipset entries", n)
|
||||
log.Debug("dnsforward: ipset: added %d new ipset entries", n)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -57,16 +57,13 @@ func (s *Server) genDNSFilterMessage(
|
||||
return s.genBlockedHost(req, s.conf.SafeBrowsingBlockHost, dctx)
|
||||
case filtering.FilteredParental:
|
||||
return s.genBlockedHost(req, s.conf.ParentalBlockHost, dctx)
|
||||
case filtering.FilteredSafeSearch:
|
||||
// If Safe Search generated the necessary IP addresses, use them.
|
||||
// Otherwise, if there were no errors, there are no addresses for the
|
||||
// requested IP version, so produce a NODATA response.
|
||||
return s.genResponseWithIPs(req, ipsFromRules(res.Rules))
|
||||
default:
|
||||
// If the query was filtered by Safe Search, filtering also must return
|
||||
// the IP addresses that must be used in response. Return them
|
||||
// regardless of the filtering method.
|
||||
ips := ipsFromRules(res.Rules)
|
||||
if res.Reason == filtering.FilteredSafeSearch && len(ips) > 0 {
|
||||
return s.genResponseWithIPs(req, ips)
|
||||
}
|
||||
|
||||
return s.genForBlockingMode(req, ips)
|
||||
return s.genForBlockingMode(req, ipsFromRules(res.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,60 +17,78 @@ import (
|
||||
|
||||
// Write Stats data and logs
|
||||
func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing querylog and stats")
|
||||
defer log.Debug("dnsforward: finished processing querylog and stats")
|
||||
|
||||
elapsed := time.Since(dctx.startTime)
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
shouldLog := true
|
||||
msg := pctx.Req
|
||||
q := msg.Question[0]
|
||||
q := pctx.Req.Question[0]
|
||||
host := strings.ToLower(strings.TrimSuffix(q.Name, "."))
|
||||
|
||||
// don't log ANY request if refuseAny is enabled
|
||||
if q.Qtype == dns.TypeANY && s.conf.RefuseAny {
|
||||
shouldLog = false
|
||||
}
|
||||
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
ip = slices.Clone(ip)
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
s.anonymizer.Load()(ip)
|
||||
|
||||
log.Debug("client ip: %s", ip)
|
||||
log.Debug("dnsforward: client ip for stats and querylog: %s", ip)
|
||||
|
||||
ipStr := ip.String()
|
||||
ids := []string{ipStr, dctx.clientID}
|
||||
qt, cl := q.Qtype, q.Qclass
|
||||
|
||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
|
||||
// uninitialized while in use. This can happen after proxy server has been
|
||||
// stopped, but its workers haven't yet exited.
|
||||
if shouldLog &&
|
||||
s.queryLog != nil &&
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
|
||||
// containing persistent client.
|
||||
s.queryLog.ShouldLog(host, q.Qtype, q.Qclass, ids) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if s.shouldLog(host, qt, cl, ids) {
|
||||
s.logQuery(dctx, pctx, elapsed, ip)
|
||||
} else {
|
||||
log.Debug(
|
||||
"dnsforward: request %s %s from %s ignored; not logging",
|
||||
dns.Type(q.Qtype),
|
||||
"dnsforward: request %s %s %q from %s ignored; not adding to querylog",
|
||||
dns.Class(cl),
|
||||
dns.Type(qt),
|
||||
host,
|
||||
ip,
|
||||
)
|
||||
}
|
||||
|
||||
if s.stats != nil &&
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
|
||||
// containing persistent client.
|
||||
s.stats.ShouldCount(host, q.Qtype, q.Qclass, ids) {
|
||||
if s.shouldCountStat(host, qt, cl, ids) {
|
||||
s.updateStats(dctx, elapsed, *dctx.result, ipStr)
|
||||
} else {
|
||||
log.Debug(
|
||||
"dnsforward: request %s %s %q from %s ignored; not counting in stats",
|
||||
dns.Class(cl),
|
||||
dns.Type(qt),
|
||||
host,
|
||||
ip,
|
||||
)
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// shouldLog returns true if the query with the given data should be logged in
|
||||
// the query log. s.serverLock is expected to be locked.
|
||||
func (s *Server) shouldLog(host string, qt, cl uint16, ids []string) (ok bool) {
|
||||
if qt == dns.TypeANY && s.conf.RefuseAny {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start containing
|
||||
// persistent client.
|
||||
return s.queryLog != nil && s.queryLog.ShouldLog(host, qt, cl, ids)
|
||||
}
|
||||
|
||||
// shouldCountStat returns true if the query with the given data should be
|
||||
// counted in the statistics. s.serverLock is expected to be locked.
|
||||
func (s *Server) shouldCountStat(host string, qt, cl uint16, ids []string) (ok bool) {
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start containing
|
||||
// persistent client.
|
||||
return s.stats != nil && s.stats.ShouldCount(host, qt, cl, ids)
|
||||
}
|
||||
|
||||
// logQuery pushes the request details into the query log.
|
||||
func (s *Server) logQuery(
|
||||
dctx *dnsContext,
|
||||
@@ -123,7 +141,10 @@ func (s *Server) updateStats(
|
||||
pctx := ctx.proxyCtx
|
||||
e := stats.Entry{}
|
||||
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
|
||||
e.Domain = e.Domain[:len(e.Domain)-1] // remove last "."
|
||||
if e.Domain != "." {
|
||||
// Remove last ".", but save the domain as is for "." queries.
|
||||
e.Domain = e.Domain[:len(e.Domain)-1]
|
||||
}
|
||||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
e.Client = clientID
|
||||
|
||||
@@ -46,6 +46,10 @@ type testStats struct {
|
||||
|
||||
// Update implements the [stats.Interface] interface for *testStats.
|
||||
func (l *testStats) Update(e stats.Entry) {
|
||||
if e.Domain == "" {
|
||||
return
|
||||
}
|
||||
|
||||
l.lastEntry = e
|
||||
}
|
||||
|
||||
@@ -54,9 +58,12 @@ func (l *testStats) ShouldCount(string, uint16, uint16, []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
const domain = "example.com."
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domain string
|
||||
proto proxy.Proto
|
||||
addr net.Addr
|
||||
clientID string
|
||||
@@ -67,6 +74,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult stats.Result
|
||||
}{{
|
||||
name: "success_udp",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -77,6 +85,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_tls_clientid",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "cli42",
|
||||
@@ -87,6 +96,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_tls",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -97,6 +107,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_quic",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoQUIC,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -107,6 +118,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_https",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoHTTPS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -117,6 +129,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_dnscrypt",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoDNSCrypt,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -127,6 +140,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_udp_filtered",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -137,6 +151,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RFiltered,
|
||||
}, {
|
||||
name: "success_udp_sb",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -147,6 +162,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RSafeBrowsing,
|
||||
}, {
|
||||
name: "success_udp_ss",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -157,6 +173,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RSafeSearch,
|
||||
}, {
|
||||
name: "success_udp_pc",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -165,6 +182,17 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: filtering.FilteredParental,
|
||||
wantStatResult: stats.RParental,
|
||||
}, {
|
||||
name: "success_udp_pc_empty_fqdn",
|
||||
domain: ".",
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 5}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.5",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: filtering.FilteredParental,
|
||||
wantStatResult: stats.RParental,
|
||||
}}
|
||||
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
|
||||
@@ -181,7 +209,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "example.com.",
|
||||
Name: tc.domain,
|
||||
}},
|
||||
}
|
||||
pctx := &proxy.DNSContext{
|
||||
|
||||
311
internal/dnsforward/upstreams.go
Normal file
311
internal/dnsforward/upstreams.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
// the configuration itself.
|
||||
func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
if s.conf.UpstreamDNSFileName == "" {
|
||||
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading upstream from file: %w", err)
|
||||
}
|
||||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
|
||||
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)
|
||||
|
||||
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings() (err error) {
|
||||
// We're setting a customized set of RootCAs. The reason is that Go default
|
||||
// mechanism of loading TLS roots does not always work properly on some
|
||||
// routers so we're loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
upstream.RootCAs = s.conf.TLSv12Roots
|
||||
upstream.CipherSuites = s.conf.TLSCiphers
|
||||
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
upstreams, err = s.loadUpstreams()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing upstream config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareUpstreamConfig sets upstream configuration based on upstreams and
|
||||
// configuration of s.
|
||||
func (s *Server) prepareUpstreamConfig(
|
||||
upstreams []string,
|
||||
defaultUpstreams []string,
|
||||
opts *upstream.Options,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing upstream config: %w", err)
|
||||
}
|
||||
|
||||
if len(uc.Upstreams) == 0 && defaultUpstreams != nil {
|
||||
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
|
||||
var defaultUpstreamConfig *proxy.UpstreamConfig
|
||||
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing default upstreams: %w", err)
|
||||
}
|
||||
|
||||
uc.Upstreams = defaultUpstreamConfig.Upstreams
|
||||
}
|
||||
|
||||
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
|
||||
err = s.replaceUpstreamsWithHosts(uc, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving upstreams with hosts: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// replaceUpstreamsWithHosts replaces unique upstreams with their resolved
|
||||
// versions based on the system hosts file.
|
||||
//
|
||||
// TODO(e.burkov): This should be performed inside dnsproxy, which should
|
||||
// actually consider /etc/hosts. See TODO on [aghnet.HostsContainer].
|
||||
func (s *Server) replaceUpstreamsWithHosts(
|
||||
upsConf *proxy.UpstreamConfig,
|
||||
opts *upstream.Options,
|
||||
) (err error) {
|
||||
resolved := map[string]*upstream.Options{}
|
||||
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.Upstreams, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams: %w", err)
|
||||
}
|
||||
|
||||
hosts := maps.Keys(upsConf.DomainReservedUpstreams)
|
||||
// TODO(e.burkov): Think of extracting sorted range into an util function.
|
||||
slices.Sort(hosts)
|
||||
for _, host := range hosts {
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.DomainReservedUpstreams[host], opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams reserved for %s: %w", host, err)
|
||||
}
|
||||
}
|
||||
|
||||
hosts = maps.Keys(upsConf.SpecifiedDomainUpstreams)
|
||||
slices.Sort(hosts)
|
||||
for _, host := range hosts {
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.SpecifiedDomainUpstreams[host], opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams specific for %s: %w", host, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveUpstreamsWithHosts resolves the IP addresses of each of the upstreams
|
||||
// and replaces those both in upstreams and resolved. Upstreams that failed to
|
||||
// resolve are placed to resolved as-is. This function only returns error of
|
||||
// upstreams closing.
|
||||
func (s *Server) resolveUpstreamsWithHosts(
|
||||
resolved map[string]*upstream.Options,
|
||||
upstreams []upstream.Upstream,
|
||||
opts *upstream.Options,
|
||||
) (err error) {
|
||||
for i := range upstreams {
|
||||
u := upstreams[i]
|
||||
addr := u.Address()
|
||||
host := extractUpstreamHost(addr)
|
||||
|
||||
withIPs, ok := resolved[host]
|
||||
if !ok {
|
||||
ips := s.resolveUpstreamHost(host)
|
||||
if len(ips) == 0 {
|
||||
resolved[host] = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
sortNetIPAddrs(ips, opts.PreferIPv6)
|
||||
|
||||
withIPs = opts.Clone()
|
||||
withIPs.ServerIPAddrs = ips
|
||||
resolved[host] = withIPs
|
||||
} else if withIPs == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err = u.Close(); err != nil {
|
||||
return fmt.Errorf("closing upstream %s: %w", addr, err)
|
||||
}
|
||||
|
||||
upstreams[i], err = upstream.AddressToUpstream(addr, withIPs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing upstream %s with resolved %s: %w", addr, host, err)
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: using %s for %s", withIPs.ServerIPAddrs, upstreams[i].Address())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractUpstreamHost returns the hostname of addr without port with an
|
||||
// assumption that any address passed here has already been successfully parsed
|
||||
// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic
|
||||
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
|
||||
func extractUpstreamHost(addr string) (host string) {
|
||||
var err error
|
||||
if strings.Contains(addr, "://") {
|
||||
var u *url.URL
|
||||
u, err = url.Parse(addr)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: parsing upstream %s: %s", addr, err)
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
return u.Hostname()
|
||||
}
|
||||
|
||||
// Probably, plain UDP upstream defined by address or address:port.
|
||||
host, err = netutil.SplitHost(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
// resolveUpstreamHost returns the version of ups with IP addresses from the
|
||||
// system hosts file placed into its options.
|
||||
func (s *Server) resolveUpstreamHost(host string) (addrs []net.IP) {
|
||||
req := &urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
DNSType: dns.TypeA,
|
||||
}
|
||||
aRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
|
||||
|
||||
req.DNSType = dns.TypeAAAA
|
||||
aaaaRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
|
||||
|
||||
var ips []net.IP
|
||||
for _, rw := range append(aRes.DNSRewrites(), aaaaRes.DNSRewrites()...) {
|
||||
dr := rw.DNSRewrite
|
||||
if dr == nil || dr.Value == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if ip, ok := dr.Value.(net.IP); ok {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return ips
|
||||
}
|
||||
|
||||
// sortNetIPAddrs sorts addrs in accordance with the protocol preferences.
|
||||
// Invalid addresses are sorted near the end.
|
||||
//
|
||||
// TODO(e.burkov): This function taken from dnsproxy, which also already
|
||||
// contains a few similar functions. Think of moving to golibs.
|
||||
func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) {
|
||||
l := len(addrs)
|
||||
if l <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (sortsBefore bool) {
|
||||
switch len(addrA) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
switch len(addrB) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
// Go on.
|
||||
default:
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
if aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil; aIs4 != bIs4 {
|
||||
if aIs4 {
|
||||
return !preferIPv6
|
||||
}
|
||||
|
||||
return preferIPv6
|
||||
}
|
||||
|
||||
return bytes.Compare(addrA, addrB) < 0
|
||||
})
|
||||
}
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
if !http3 {
|
||||
return upstream.DefaultHTTPVersions
|
||||
}
|
||||
|
||||
return []upstream.HTTPVersion{
|
||||
upstream.HTTPVersion3,
|
||||
upstream.HTTPVersion2,
|
||||
upstream.HTTPVersion11,
|
||||
}
|
||||
}
|
||||
|
||||
// setProxyUpstreamMode sets the upstream mode and related settings in conf
|
||||
// based on provided parameters.
|
||||
func setProxyUpstreamMode(
|
||||
conf *proxy.Config,
|
||||
allServers bool,
|
||||
fastestAddr bool,
|
||||
fastestTimeout time.Duration,
|
||||
) {
|
||||
if allServers {
|
||||
conf.UpstreamMode = proxy.UModeParallel
|
||||
} else if fastestAddr {
|
||||
conf.UpstreamMode = proxy.UModeFastestAddr
|
||||
conf.FastestPingTimeout = fastestTimeout
|
||||
} else {
|
||||
conf.UpstreamMode = proxy.UModeLoadBalance
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user