all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-09-07 17:13:48 +03:00
parent 3be7676970
commit 7b93f5d7cf
306 changed files with 19770 additions and 4916 deletions

View File

@@ -28,6 +28,7 @@ type accessManager struct {
allowedClientIDs *stringutil.Set
blockedClientIDs *stringutil.Set
// TODO(s.chzhen): Use [aghnet.IgnoreEngine].
blockedHostsEng *urlfilter.DNSEngine
// TODO(a.garipov): Create a type for a set of IP networks.
@@ -131,7 +132,6 @@ func (a *accessManager) isBlockedClientID(id string) (ok bool) {
func (a *accessManager) isBlockedHost(host string, qt rules.RRType) (ok bool) {
_, ok = a.blockedHostsEng.MatchRequest(&urlfilter.DNSRequest{
Hostname: host,
ClientIP: "0.0.0.0",
DNSType: qt,
})
@@ -183,7 +183,7 @@ func (s *Server) accessListJSON() (j accessListJSON) {
}
func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
_ = aghhttp.WriteJSONResponse(w, r, s.accessListJSON())
aghhttp.WriteJSONResponseOK(w, r, s.accessListJSON())
}
// validateAccessSet checks the internal accessListJSON lists. To search for

View File

@@ -25,74 +25,19 @@ import (
"golang.org/x/exp/slices"
)
// BlockingMode is an enum of all allowed blocking modes.
type BlockingMode string
// Allowed blocking modes.
const (
// BlockingModeCustomIP means respond with a custom IP address.
BlockingModeCustomIP BlockingMode = "custom_ip"
// BlockingModeDefault is the same as BlockingModeNullIP for
// Adblock-style rules, but responds with the IP address specified in
// the rule when blocked by an `/etc/hosts`-style rule.
BlockingModeDefault BlockingMode = "default"
// BlockingModeNullIP means respond with a zero IP address: "0.0.0.0"
// for A requests and "::" for AAAA ones.
BlockingModeNullIP BlockingMode = "null_ip"
// BlockingModeNXDOMAIN means respond with the NXDOMAIN code.
BlockingModeNXDOMAIN BlockingMode = "nxdomain"
// BlockingModeREFUSED means respond with the REFUSED code.
BlockingModeREFUSED BlockingMode = "refused"
)
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
// The zero FilteringConfig is empty and ready for use.
type FilteringConfig struct {
// Config represents the DNS filtering configuration of AdGuard Home. The zero
// Config is empty and ready for use.
type Config struct {
// Callbacks for other modules
// FilterHandler is an optional additional filtering callback.
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`
FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"`
// GetCustomUpstreamByClient is a callback that returns upstreams
// configuration based on the client IP address or ClientID. It returns
// nil if there are no custom upstreams for the client.
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
// Protection configuration
// ProtectionEnabled defines whether or not use any of filtering features.
ProtectionEnabled bool `yaml:"protection_enabled"`
// BlockingMode defines the way how blocked responses are constructed.
BlockingMode BlockingMode `yaml:"blocking_mode"`
// BlockingIPv4 is the IP address to be returned for a blocked A request.
BlockingIPv4 net.IP `yaml:"blocking_ipv4"`
// BlockingIPv6 is the IP address to be returned for a blocked AAAA
// request.
BlockingIPv6 net.IP `yaml:"blocking_ipv6"`
// BlockedResponseTTL is the time-to-live value for blocked responses. If
// 0, then default value is used (3600).
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"`
// ProtectionDisabledUntil is the timestamp until when the protection is
// disabled.
ProtectionDisabledUntil *time.Time `yaml:"protection_disabled_until"`
// ParentalBlockHost is the IP (or domain name) which is used to respond to
// DNS requests blocked by parental control.
ParentalBlockHost string `yaml:"parental_block_host"`
// SafeBrowsingBlockHost is the IP (or domain name) which is used to
// respond to DNS requests blocked by safe-browsing.
SafeBrowsingBlockHost string `yaml:"safebrowsing_block_host"`
// Anti-DNS amplification
// Ratelimit is the maximum number of requests per second from a given IP
@@ -118,6 +63,10 @@ type FilteringConfig struct {
// resolvers (plain DNS only).
BootstrapDNS []string `yaml:"bootstrap_dns"`
// FallbackDNS is the list of fallback DNS servers used when upstream DNS
// servers are not responding.
FallbackDNS []string `yaml:"fallback_dns"`
// AllServers, if true, parallel queries to all configured upstream servers
// are enabled.
AllServers bool `yaml:"all_servers"`
@@ -133,7 +82,7 @@ type FilteringConfig struct {
// AllowedClients is the slice of IP addresses, CIDR networks, and
// ClientIDs of allowed clients. If not empty, only these clients are
// allowed, and [FilteringConfig.DisallowedClients] are ignored.
// allowed, and [Config.DisallowedClients] are ignored.
AllowedClients []string `yaml:"allowed_clients"`
// DisallowedClients is the slice of IP addresses, CIDR networks, and
@@ -279,7 +228,7 @@ type ServerConfig struct {
// Remove that.
AddrProcConf *client.DefaultAddrProcConfig
FilteringConfig
Config
TLSConfig
DNSCryptConfig
TLSAllowUnencryptedDoH bool
@@ -320,13 +269,6 @@ type ServerConfig struct {
UseHTTP3Upstreams bool
}
// if any of ServerConfig values are zero, then default values from below are used
var defaultValues = ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{Port: 53}},
TCPListenAddrs: []*net.TCPAddr{{Port: 53}},
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
}
// createProxyConfig creates and validates configuration for the main proxy.
func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
srvConf := s.conf
@@ -399,10 +341,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
return conf, nil
}
const (
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
defaultParentalBlockHost = "family-block.dns.adguard.com"
)
const defaultBlockedResponseTTL = 3600
// initDefaultSettings initializes default settings if nothing
// is configured
@@ -415,20 +354,12 @@ func (s *Server) initDefaultSettings() {
s.conf.BootstrapDNS = defaultBootstrap
}
if s.conf.ParentalBlockHost == "" {
s.conf.ParentalBlockHost = defaultParentalBlockHost
}
if s.conf.SafeBrowsingBlockHost == "" {
s.conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
}
if s.conf.UDPListenAddrs == nil {
s.conf.UDPListenAddrs = defaultValues.UDPListenAddrs
s.conf.UDPListenAddrs = defaultUDPListenAddrs
}
if s.conf.TCPListenAddrs == nil {
s.conf.TCPListenAddrs = defaultValues.TCPListenAddrs
s.conf.TCPListenAddrs = defaultTCPListenAddrs
}
if len(s.conf.BlockedHosts) == 0 {
@@ -561,9 +492,9 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti
s.serverLock.RLock()
defer s.serverLock.RUnlock()
disabledUntil = s.conf.ProtectionDisabledUntil
enabled, disabledUntil = s.dnsFilter.ProtectionStatus()
if disabledUntil == nil {
return s.conf.ProtectionEnabled, nil
return enabled, nil
}
if time.Now().Before(*disabledUntil) {
@@ -595,8 +526,7 @@ func (s *Server) enableProtectionAfterPause() {
s.serverLock.Lock()
defer s.serverLock.Unlock()
s.conf.ProtectionEnabled = true
s.conf.ProtectionDisabledUntil = nil
s.dnsFilter.SetProtectionStatus(true, nil)
log.Info("dns: protection is restarted after pause")
}

View File

@@ -49,9 +49,8 @@ func (s *Server) DialContext(ctx context.Context, network, addr string) (conn ne
continue
}
return conn, err
return conn, nil
}
// TODO(a.garipov): Use errors.Join in Go 1.20.
return nil, errors.List(fmt.Sprintf("dialing %q", addr), dialErrs...)
return nil, errors.Join(dialErrs...)
}

View File

@@ -283,11 +283,13 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
// right after stop, due to a data race in [proxy.Proxy.Init] method
// when setting an OOB size. As a temporary workaround, recreate the
// whole server for each test case.
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
FilteringConfig: FilteringConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}, localUps)

View File

@@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
@@ -46,20 +45,16 @@ var defaultBootstrap = []string{"9.9.9.10", "149.112.112.10", "2620:fe::10", "26
// Often requested by all kinds of DNS probes
var defaultBlockedHosts = []string{"version.bind", "id.server", "hostname.bind"}
var (
// defaultUDPListenAddrs are the default UDP addresses for the server.
defaultUDPListenAddrs = []*net.UDPAddr{{Port: 53}}
// defaultTCPListenAddrs are the default TCP addresses for the server.
defaultTCPListenAddrs = []*net.TCPAddr{{Port: 53}}
)
var webRegistered bool
// hostToIPTable is a convenient type alias for tables of host names to an IP
// address.
//
// TODO(e.burkov): Use the [DHCP] interface instead.
type hostToIPTable = map[string]netip.Addr
// ipToHostTable is a convenient type alias for tables of IP addresses to their
// host names. For example, for use with PTR queries.
//
// TODO(e.burkov): Use the [DHCP] interface instead.
type ipToHostTable = map[netip.Addr]string
// DHCP is an interface for accessing DHCP lease data needed in this package.
type DHCP interface {
// HostByIP returns the hostname of the DHCP client with the given IP
@@ -89,18 +84,34 @@ type DHCP interface {
//
// The zero Server is empty and ready for use.
type Server struct {
dnsProxy *proxy.Proxy // DNS proxy instance
dnsFilter *filtering.DNSFilter // DNS filter instance
dhcpServer dhcpd.Interface // DHCP server instance (optional)
queryLog querylog.QueryLog // Query log instance
stats stats.Interface
access *accessManager
// dnsProxy is the DNS proxy for forwarding client's DNS requests.
dnsProxy *proxy.Proxy
// dnsFilter is the DNS filter for filtering client's DNS requests and
// responses.
dnsFilter *filtering.DNSFilter
// dhcpServer is the DHCP server for accessing lease data.
dhcpServer DHCP
// queryLog is the query log for client's DNS requests, responses and
// filtering results.
queryLog querylog.QueryLog
// stats is the statistics collector for client's DNS usage data.
stats stats.Interface
// access drops unallowed clients.
access *accessManager
// localDomainSuffix is the suffix used to detect internal hosts. It
// must be a valid domain name plus dots on each side.
localDomainSuffix string
ipset ipsetCtx
// ipset processes DNS requests using ipset data.
ipset ipsetCtx
// privateNets is the configured set of IP networks considered private.
privateNets netutil.SubnetSet
// addrProc, if not nil, is used to process clients' IP addresses with rDNS,
@@ -112,7 +123,10 @@ type Server struct {
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
localResolvers *proxy.Proxy
sysResolvers aghnet.SystemResolvers
// sysResolvers used to fetch system resolvers to use by default for private
// PTR resolving.
sysResolvers aghnet.SystemResolvers
// recDetector is a cache for recursive requests. It is used to detect
// and prevent recursive requests only for private upstreams.
@@ -128,12 +142,6 @@ type Server struct {
// anonymizer masks the client's IP addresses if needed.
anonymizer *aghnet.IPMut
tableHostToIP hostToIPTable
tableHostToIPLock sync.Mutex
tableIPToHost ipToHostTable
tableIPToHostLock sync.Mutex
// clientIDCache is a temporary storage for ClientIDs that were extracted
// during the BeforeRequestHandler stage.
clientIDCache cache.Cache
@@ -142,13 +150,16 @@ type Server struct {
// We don't Start() it and so no listen port is required.
internalProxy *proxy.Proxy
// isRunning is true if the DNS server is running.
isRunning bool
// protectionUpdateInProgress is used to make sure that only one goroutine
// updating the protection configuration after a pause is running at a time.
protectionUpdateInProgress atomic.Bool
// conf is the current configuration of the server.
conf ServerConfig
// serverLock protects Server.
serverLock sync.RWMutex
}
@@ -164,7 +175,7 @@ type DNSCreateParams struct {
DNSFilter *filtering.DNSFilter
Stats stats.Interface
QueryLog querylog.QueryLog
DHCPServer dhcpd.Interface
DHCPServer DHCP
PrivateNets netutil.SubnetSet
Anonymizer *aghnet.IPMut
LocalDomain string
@@ -200,11 +211,12 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
p.Anonymizer = aghnet.NewIPMut(nil)
}
s = &Server{
dnsFilter: p.DNSFilter,
stats: p.Stats,
queryLog: p.QueryLog,
privateNets: p.PrivateNets,
localDomainSuffix: localDomainSuffix,
dnsFilter: p.DNSFilter,
stats: p.Stats,
queryLog: p.QueryLog,
privateNets: p.PrivateNets,
// TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix),
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
EnableLRU: true,
@@ -220,11 +232,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
return nil, fmt.Errorf("initializing system resolvers: %w", err)
}
if p.DHCPServer != nil {
s.dhcpServer = p.DHCPServer
s.dhcpServer.SetOnLeaseChanged(s.onDHCPLeaseChanged)
s.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
}
s.dhcpServer = p.DHCPServer
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
// Use plain DNS on MIPS, encryption is too slow
@@ -244,7 +252,7 @@ func (s *Server) Close() {
s.serverLock.Lock()
defer s.serverLock.Unlock()
s.dnsFilter = nil
// TODO(s.chzhen): Remove it.
s.stats = nil
s.queryLog = nil
s.dnsProxy = nil
@@ -255,14 +263,15 @@ func (s *Server) Close() {
}
// WriteDiskConfig - write configuration
func (s *Server) WriteDiskConfig(c *FilteringConfig) {
func (s *Server) WriteDiskConfig(c *Config) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
sc := s.conf.FilteringConfig
sc := s.conf.Config
*c = sc
c.RatelimitWhitelist = stringutil.CloneSlice(sc.RatelimitWhitelist)
c.BootstrapDNS = stringutil.CloneSlice(sc.BootstrapDNS)
c.FallbackDNS = stringutil.CloneSlice(sc.FallbackDNS)
c.AllowedClients = stringutil.CloneSlice(sc.AllowedClients)
c.DisallowedClients = stringutil.CloneSlice(sc.DisallowedClients)
c.BlockedHosts = stringutil.CloneSlice(sc.BlockedHosts)
@@ -533,9 +542,13 @@ func (s *Server) setupLocalResolvers() (err error) {
func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.conf = *conf
err = validateBlockingMode(s.conf.BlockingMode, s.conf.BlockingIPv4, s.conf.BlockingIPv6)
if err != nil {
return fmt.Errorf("checking blocking mode: %w", err)
// dnsFilter can be nil during application update.
if s.dnsFilter != nil {
mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode()
err = validateBlockingMode(mode, bIPv4, bIPv6)
if err != nil {
return fmt.Errorf("checking blocking mode: %w", err)
}
}
s.initDefaultSettings()
@@ -584,6 +597,11 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("setting up resolvers: %w", err)
}
err = s.setupFallbackDNS()
if err != nil {
return fmt.Errorf("setting up fallback dns servers: %w", err)
}
s.recDetector.clear()
s.setupAddrProc()
@@ -593,6 +611,28 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return nil
}
// setupFallbackDNS initializes the fallback DNS servers.
func (s *Server) setupFallbackDNS() (err error) {
fallbacks := s.conf.FallbackDNS
if len(fallbacks) == 0 {
return nil
}
uc, err := proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
// TODO(s.chzhen): Investigate if other options are needed.
Timeout: s.conf.UpstreamTimeout,
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
// Do not wrap the error because it's informative enough as is.
return err
}
s.dnsProxy.Fallbacks = uc
return nil
}
// setupAddrProc initializes the address processor. For internal use only.
func (s *Server) setupAddrProc() {
// TODO(a.garipov): This is a crutch for tests; remove.
@@ -617,19 +657,22 @@ func (s *Server) setupAddrProc() {
}
// validateBlockingMode returns an error if the blocking mode data aren't valid.
func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP) (err error) {
func validateBlockingMode(
mode filtering.BlockingMode,
blockingIPv4, blockingIPv6 netip.Addr,
) (err error) {
switch mode {
case
BlockingModeDefault,
BlockingModeNXDOMAIN,
BlockingModeREFUSED,
BlockingModeNullIP:
filtering.BlockingModeDefault,
filtering.BlockingModeNXDOMAIN,
filtering.BlockingModeREFUSED,
filtering.BlockingModeNullIP:
return nil
case BlockingModeCustomIP:
if blockingIPv4 == nil {
return fmt.Errorf("blocking_ipv4 must be set when blocking_mode is custom_ip")
} else if blockingIPv6 == nil {
return fmt.Errorf("blocking_ipv6 must be set when blocking_mode is custom_ip")
case filtering.BlockingModeCustomIP:
if !blockingIPv4.Is4() {
return fmt.Errorf("blocking_ipv4 must be valid ipv4 on custom_ip blocking_mode")
} else if !blockingIPv6.Is6() {
return fmt.Errorf("blocking_ipv6 must be valid ipv6 on custom_ip blocking_mode")
}
return nil

View File

@@ -22,7 +22,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
@@ -94,17 +93,18 @@ func createTestServer(
f.SetEnabled(true)
dhcp := &testDHCP{
OnEnabled: func() (ok bool) { return false },
OnHostByIP: func(ip netip.Addr) (host string) { return "" },
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
}
s, err = NewServer(DNSCreateParams{
DHCPServer: testDHCP,
DHCPServer: dhcp,
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
if forwardConf.BlockingMode == "" {
forwardConf.BlockingMode = BlockingModeDefault
}
err = s.Prepare(&forwardConf)
require.NoError(t, err)
@@ -174,10 +174,12 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
var keyPem []byte
_, certPem, keyPem = createServerTLSConfig(t)
s = createTestServer(t, &filtering.Config{}, ServerConfig{
s = createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}, nil)
@@ -231,18 +233,29 @@ func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
return req
}
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
assertResponse(t, reply, net.IP{8, 8, 8, 8})
// newResp returns the new DNS response with response code set to rcode, req
// used as request, and rrs added.
func newResp(rcode int, req *dns.Msg, ans []dns.RR) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, rcode)
resp.RecursionAvailable = true
resp.Compress = true
resp.Answer = ans
return resp
}
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
assertResponse(t, reply, netip.AddrFrom4([4]byte{8, 8, 8, 8}))
}
func assertResponse(t *testing.T, reply *dns.Msg, ip netip.Addr) {
t.Helper()
require.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A)
require.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0])
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
assert.Equal(t, net.IP(ip.AsSlice()), a.A)
}
// sendTestMessagesAsync sends messages in parallel to check for race issues.
@@ -288,10 +301,12 @@ func sendTestMessages(t *testing.T, conn *dns.Conn) {
}
func TestServer(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}, nil)
@@ -329,13 +344,12 @@ func TestServer_timeout(t *testing.T) {
t.Run("custom", func(t *testing.T) {
srvConf := &ServerConfig{
UpstreamTimeout: testTimeout,
FilteringConfig: FilteringConfig{
BlockingMode: BlockingModeDefault,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}
s, err := NewServer(DNSCreateParams{})
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
require.NoError(t, err)
err = s.Prepare(srvConf)
@@ -345,11 +359,10 @@ func TestServer_timeout(t *testing.T) {
})
t.Run("default", func(t *testing.T) {
s, err := NewServer(DNSCreateParams{})
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
require.NoError(t, err)
s.conf.FilteringConfig.BlockingMode = BlockingModeDefault
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
Enabled: false,
}
err = s.Prepare(&s.conf)
@@ -360,10 +373,12 @@ func TestServer_timeout(t *testing.T) {
}
func TestServerWithProtectionDisabled(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}, nil)
@@ -439,9 +454,8 @@ func TestServerRace(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
}
@@ -462,7 +476,7 @@ func TestSafeSearch(t *testing.T) {
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4, ip6}, nil
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
@@ -474,6 +488,8 @@ func TestSafeSearch(t *testing.T) {
}
filterConf := &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
ProtectionEnabled: true,
SafeSearchConf: safeSearchConf,
SafeSearchCacheSize: 1000,
CacheTime: 30,
@@ -490,8 +506,7 @@ func TestSafeSearch(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -503,12 +518,12 @@ func TestSafeSearch(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
client := &dns.Client{}
yandexIP := net.IP{213, 180, 193, 56}
yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
testCases := []struct {
host string
want net.IP
want netip.Addr
}{{
host: "yandex.com.",
want: yandexIP,
@@ -548,10 +563,12 @@ func TestSafeSearch(t *testing.T) {
}
func TestInvalidRequest(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -579,15 +596,16 @@ func TestBlockedRequest(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@@ -608,14 +626,15 @@ func TestServerCustomClientUpstream(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
@@ -658,10 +677,12 @@ var testIPv4 = map[string][]net.IP{
}
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -671,7 +692,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
CName: testCNAMEs,
IPv4: testIPv4,
}
s.conf.ProtectionEnabled = false
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{testUpstm},
}
@@ -693,15 +714,16 @@ func TestBlockCNAME(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: testCNAMEs,
@@ -763,9 +785,8 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
FilterHandler: func(_ net.IP, _ string, settings *filtering.Settings) {
Config: Config{
FilterHandler: func(_ netip.Addr, _ string, settings *filtering.Settings) {
settings.FilteringEnabled = false
},
EDNSClientSubnet: &EDNSClientSubnet{
@@ -773,7 +794,9 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
},
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: testCNAMEs,
@@ -809,15 +832,16 @@ func TestNullBlockedRequest(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeNullIP,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeNullIP,
}, forwardConf, nil)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@@ -849,11 +873,21 @@ func TestBlockedCustomIP(t *testing.T) {
Data: []byte(rules),
}}
f, err := filtering.New(&filtering.Config{}, filters)
f, err := filtering.New(&filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeCustomIP,
BlockingIPv4: netip.Addr{},
BlockingIPv6: netip.Addr{},
}, filters)
require.NoError(t, err)
dhcp := &testDHCP{
OnEnabled: func() (ok bool) { return false },
OnHostByIP: func(_ netip.Addr) (host string) { panic("not implemented") },
OnIPByHost: func(_ string) (ip netip.Addr) { panic("not implemented") },
}
s, err := NewServer(DNSCreateParams{
DHCPServer: testDHCP,
DHCPServer: dhcp,
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
@@ -862,11 +896,8 @@ func TestBlockedCustomIP(t *testing.T) {
conf := &ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeCustomIP,
BlockingIPv4: nil,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -877,8 +908,10 @@ func TestBlockedCustomIP(t *testing.T) {
err = s.Prepare(conf)
assert.Error(t, err)
conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
conf.BlockingIPv6 = net.ParseIP("::1")
s.dnsFilter.SetBlockingMode(
filtering.BlockingModeCustomIP,
netip.AddrFrom4([4]byte{0, 0, 0, 1}),
netip.MustParseAddr("::1"))
err = s.Prepare(conf)
require.NoError(t, err)
@@ -915,16 +948,17 @@ func TestBlockedByHosts(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@@ -955,15 +989,16 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
ans4, _ := aghtest.HostToIPs(hostname)
filterConf := &filtering.Config{
SafeBrowsingEnabled: true,
SafeBrowsingChecker: sbChecker,
BlockingMode: filtering.BlockingModeDefault,
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
SafeBrowsingChecker: sbChecker,
SafeBrowsingBlockHost: ans4.String(),
}
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
SafeBrowsingBlockHost: ans4.String(),
ProtectionEnabled: true,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -980,13 +1015,12 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A)
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
assert.Equal(t, ans4, a.A, "dns server %s returned wrong answer: %v", addr, a.A)
assertResponse(t, reply, ans4)
}
func TestRewrite(t *testing.T) {
c := &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
Rewrites: []*filtering.LegacyRewrite{{
Domain: "test.com",
Answer: "1.2.3.4",
@@ -1006,8 +1040,13 @@ func TestRewrite(t *testing.T) {
f.SetEnabled(true)
dhcp := &testDHCP{
OnEnabled: func() (ok bool) { return false },
OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") },
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
}
s, err := NewServer(DNSCreateParams{
DHCPServer: testDHCP,
DHCPServer: dhcp,
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
@@ -1016,10 +1055,8 @@ func TestRewrite(t *testing.T) {
assert.NoError(t, s.Prepare(&ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
UpstreamDNS: []string{"8.8.8.8:53"},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53"},
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -1102,31 +1139,42 @@ func publicKey(priv any) any {
}
}
var testDHCP = &dhcpd.MockInterface{
OnStart: func() (err error) { panic("not implemented") },
OnStop: func() (err error) { panic("not implemented") },
OnEnabled: func() (ok bool) { return true },
OnLeases: func(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
return []*dhcpd.Lease{{
IP: netip.MustParseAddr("192.168.12.34"),
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
Hostname: "myhost",
}}
},
OnSetOnLeaseChanged: func(olct dhcpd.OnLeaseChangedT) {},
OnFindMACbyIP: func(ip netip.Addr) (mac net.HardwareAddr) { panic("not implemented") },
OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") },
// testDHCP is a mock implementation of the [DHCP] interface.
type testDHCP struct {
OnHostByIP func(ip netip.Addr) (host string)
OnIPByHost func(host string) (ip netip.Addr)
OnEnabled func() (ok bool)
}
// type check
var _ DHCP = (*testDHCP)(nil)
// HostByIP implements the [DHCP] interface for *testDHCP.
func (d *testDHCP) HostByIP(ip netip.Addr) (host string) { return d.OnHostByIP(ip) }
// IPByHost implements the [DHCP] interface for *testDHCP.
func (d *testDHCP) IPByHost(host string) (ip netip.Addr) { return d.OnIPByHost(host) }
// IsClientHost implements the [DHCP] interface for *testDHCP.
func (d *testDHCP) Enabled() (ok bool) { return d.OnEnabled() }
func TestPTRResponseFromDHCPLeases(t *testing.T) {
const localDomain = "lan"
flt, err := filtering.New(&filtering.Config{}, nil)
flt, err := filtering.New(&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, nil)
require.NoError(t, err)
s, err := NewServer(DNSCreateParams{
DNSFilter: flt,
DHCPServer: testDHCP,
DNSFilter: flt,
DHCPServer: &testDHCP{
OnEnabled: func() (ok bool) { return true },
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
OnHostByIP: func(ip netip.Addr) (host string) {
return "myhost"
},
},
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
LocalDomain: localDomain,
})
@@ -1135,9 +1183,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
s.conf.FilteringConfig.BlockingMode = BlockingModeDefault
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
err = s.Prepare(&s.conf)
require.NoError(t, err)
@@ -1175,8 +1221,14 @@ func TestPTRResponseFromHosts(t *testing.T) {
`)},
}
dhcp := &testDHCP{
OnEnabled: func() (ok bool) { return false },
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
OnHostByIP: func(ip netip.Addr) (host string) { return "" },
}
var eventsCalledCounter uint32
hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
@@ -1195,7 +1247,8 @@ func TestPTRResponseFromHosts(t *testing.T) {
})
flt, err := filtering.New(&filtering.Config{
EtcHosts: hc,
BlockingMode: filtering.BlockingModeDefault,
EtcHosts: hc,
}, nil)
require.NoError(t, err)
@@ -1203,7 +1256,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: testDHCP,
DHCPServer: dhcp,
DNSFilter: flt,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
@@ -1212,8 +1265,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.BlockingMode = BlockingModeDefault
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
err = s.Prepare(&s.conf)
require.NoError(t, err)

View File

@@ -2,7 +2,7 @@ package dnsforward
import (
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@@ -44,13 +44,13 @@ func (s *Server) ansFromDNSRewriteIP(
rr rules.RRType,
req *dns.Msg,
) (ans dns.RR, err error) {
ip, ok := v.(net.IP)
ip, ok := v.(netip.Addr)
if !ok {
return nil, fmt.Errorf("value for rr type %s has type %T, not net.IP", dns.Type(rr), v)
return nil, fmt.Errorf("value for rr type %s has type %T, not netip.Addr", dns.Type(rr), v)
}
if rr == dns.TypeA {
return s.genAnswerA(req, ip.To4()), nil
return s.genAnswerA(req, ip), nil
}
return s.genAnswerAAAA(req, ip), nil
@@ -160,13 +160,13 @@ func (s *Server) filterDNSRewrite(
return errors.Error("no dns rewrite rule responses")
}
rr := req.Question[0].Qtype
values := dnsrr.Response[rr]
qtype := req.Question[0].Qtype
values := dnsrr.Response[qtype]
for i, v := range values {
var ans dns.RR
ans, err = s.filterDNSRewriteResponse(req, rr, v)
ans, err = s.filterDNSRewriteResponse(req, qtype, v)
if err != nil {
return fmt.Errorf("dns rewrite response for %d[%d]: %w", rr, i, err)
return fmt.Errorf("dns rewrite response for %s[%d]: %w", dns.Type(qtype), i, err)
}
resp.Answer = append(resp.Answer, ans)

View File

@@ -6,6 +6,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -15,8 +16,7 @@ import (
func TestServer_FilterDNSRewrite(t *testing.T) {
// Helper data.
const domain = "example.com"
ip4 := net.IP{127, 0, 0, 1}
ip6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
ip4, ip6 := netutil.IPv4Localhost(), netutil.IPv6Localhost()
mxVal := &rules.DNSMX{
Exchange: "mail.example.com",
Preference: 32,
@@ -34,7 +34,14 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
}
// Helper functions and entities.
srv := &Server{}
srv := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}, nil)
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
return &dns.Msg{
Question: []dns.Question{{
@@ -89,7 +96,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
assert.Equal(t, net.IP(ip4.AsSlice()), d.Res.Answer[0].(*dns.A).A)
})
t.Run("noerror_aaaa", func(t *testing.T) {
@@ -103,7 +110,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
assert.Equal(t, net.IP(ip6.AsSlice()), d.Res.Answer[0].(*dns.AAAA).AAAA)
})
t.Run("noerror_ptr", func(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package dnsforward
import (
"encoding/binary"
"fmt"
"net"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
@@ -11,6 +12,7 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// beforeRequestHandler is the handler that is called before any other
@@ -57,8 +59,8 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
setts = s.dnsFilter.Settings()
setts.ProtectionEnabled = dctx.protectionEnabled
if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
s.conf.FilterHandler(ip, dctx.clientID, setts)
addrPort := netutil.NetAddrToAddrPort(dctx.proxyCtx.Addr)
s.conf.FilterHandler(addrPort.Addr(), dctx.clientID, setts)
}
return setts
@@ -123,7 +125,7 @@ func (s *Server) filterRewritten(
for _, ip := range res.IPList {
switch qt {
case dns.TypeA:
a := s.genAnswerA(req, ip.To4())
a := s.genAnswerA(req, ip)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
case dns.TypeAAAA:
@@ -145,10 +147,6 @@ func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Set
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if s.dnsFilter == nil {
return nil, nil
}
var res filtering.Result
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
if err != nil {
@@ -176,19 +174,26 @@ func (s *Server) filterDNSResponse(
case *dns.CNAME:
host = strings.TrimSuffix(a.Target, ".")
rrtype = dns.TypeCNAME
res, err = s.checkHostRules(host, rrtype, setts)
case *dns.A:
host = a.A.String()
rrtype = dns.TypeA
res, err = s.checkHostRules(host, rrtype, setts)
case *dns.AAAA:
host = a.AAAA.String()
rrtype = dns.TypeAAAA
res, err = s.checkHostRules(host, rrtype, setts)
case *dns.HTTPS:
res, err = s.filterHTTPSRecords(a, setts)
default:
continue
}
log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
res, err = s.checkHostRules(host, rrtype, setts)
if err != nil {
return nil, err
} else if res == nil {
@@ -203,3 +208,67 @@ func (s *Server) filterDNSResponse(
return nil, nil
}
// removeIPv6Hints deletes IPv6 hints from RR values.
func removeIPv6Hints(rr *dns.HTTPS) {
rr.Value = slices.DeleteFunc(rr.Value, func(kv dns.SVCBKeyValue) (del bool) {
_, ok := kv.(*dns.SVCBIPv6Hint)
return ok
})
}
// filterHTTPSRecords filters HTTPS answers information through all rule list
// filters of the server filters. Removes IPv6 hints if IPv6 resolving is
// disabled.
func (s *Server) filterHTTPSRecords(rr *dns.HTTPS, setts *filtering.Settings) (r *filtering.Result, err error) {
if s.conf.AAAADisabled {
removeIPv6Hints(rr)
}
for _, kv := range rr.Value {
var ips []net.IP
switch hint := kv.(type) {
case *dns.SVCBIPv4Hint:
ips = hint.Hint
case *dns.SVCBIPv6Hint:
ips = hint.Hint
default:
// Go on.
}
if len(ips) == 0 {
continue
}
r, err = s.filterSVCBHint(ips, setts)
if err != nil {
return nil, fmt.Errorf("filtering svcb hints: %w", err)
}
if r != nil {
return r, nil
}
}
return nil, nil
}
// filterSVCBHint filters SVCB hint information.
func (s *Server) filterSVCBHint(
hint []net.IP,
setts *filtering.Settings,
) (res *filtering.Result, err error) {
for _, h := range hint {
res, err = s.checkHostRules(h.String(), dns.TypeHTTPS, setts)
if err != nil {
return nil, fmt.Errorf("checking rules for %s: %w", h, err)
}
if res != nil && res.IsFiltered {
return res, nil
}
}
return nil, nil
}

View File

@@ -2,6 +2,7 @@ package dnsforward
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
@@ -14,7 +15,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
rules := `
||blocked.domain^
@@||allowed.domain^
@@ -23,14 +24,13 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
||::1^$dnstype=~AAAA
0.0.0.0 duplicate.domain
0.0.0.0 duplicate.domain
0.0.0.0 blocked.by.hostrule
`
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -40,12 +40,19 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
ID: 0, Data: []byte(rules),
}}
f, err := filtering.New(&filtering.Config{}, filters)
f, err := filtering.New(&filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
}, filters)
require.NoError(t, err)
f.SetEnabled(true)
s, err := NewServer(DNSCreateParams{
DHCPServer: testDHCP,
DHCPServer: &testDHCP{
OnEnabled: func() (ok bool) { return false },
OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") },
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
@@ -73,12 +80,19 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
startDeferStop(t, s)
testCases := []struct {
req *dns.Msg
name string
wantAns []dns.RR
req *dns.Msg
name string
wantRCode int
wantAns []dns.RR
}{{
req: createTestMessage("cname.exception."),
name: "cname_exception",
req: createTestMessage(aghtest.ReqFQDN),
name: "pass",
wantRCode: dns.RcodeNameError,
wantAns: nil,
}, {
req: createTestMessage("cname.exception."),
name: "cname_exception",
wantRCode: dns.RcodeSuccess,
wantAns: []dns.RR{&dns.CNAME{
Hdr: dns.RR_Header{
Name: "cname.exception.",
@@ -87,8 +101,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
Target: "cname.specific.",
}},
}, {
req: createTestMessage("should.block."),
name: "blocked_by_cname",
req: createTestMessage("should.block."),
name: "blocked_by_cname",
wantRCode: dns.RcodeSuccess,
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "should.block.",
@@ -98,8 +113,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
A: netutil.IPv4Zero(),
}},
}, {
req: createTestMessage("a.exception."),
name: "a_exception",
req: createTestMessage("a.exception."),
name: "a_exception",
wantRCode: dns.RcodeSuccess,
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "a.exception.",
@@ -108,8 +124,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
A: net.IP{0, 0, 0, 1},
}},
}, {
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
name: "aaaa_exception",
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
name: "aaaa_exception",
wantRCode: dns.RcodeSuccess,
wantAns: []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: "aaaa.exception.",
@@ -118,8 +135,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
AAAA: net.ParseIP("::1"),
}},
}, {
req: createTestMessage("allowed.first."),
name: "allowed_first",
req: createTestMessage("allowed.first."),
name: "allowed_first",
wantRCode: dns.RcodeSuccess,
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "allowed.first.",
@@ -129,8 +147,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
A: netutil.IPv4Zero(),
}},
}, {
req: createTestMessage("blocked.first."),
name: "blocked_first",
req: createTestMessage("blocked.first."),
name: "blocked_first",
wantRCode: dns.RcodeSuccess,
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "blocked.first.",
@@ -140,8 +159,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
A: netutil.IPv4Zero(),
}},
}, {
req: createTestMessage("duplicate.domain."),
name: "duplicate_domain",
req: createTestMessage("duplicate.domain."),
name: "duplicate_domain",
wantRCode: dns.RcodeSuccess,
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "duplicate.domain.",
@@ -150,6 +170,16 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
},
A: netutil.IPv4Zero(),
}},
}, {
req: createTestMessageWithType("blocked.domain.", dns.TypeHTTPS),
name: "blocked_https_req",
wantRCode: dns.RcodeSuccess,
wantAns: nil,
}, {
req: createTestMessageWithType("blocked.by.hostrule.", dns.TypeHTTPS),
name: "blocked_host_rule_https_req",
wantRCode: dns.RcodeSuccess,
wantAns: nil,
}}
for _, tc := range testCases {
@@ -164,7 +194,175 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, dctx.Res)
assert.Equal(t, tc.wantRCode, dctx.Res.Rcode)
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
})
}
}
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
const (
passedIPv4Str = "1.1.1.1"
blockedIPv4Str = "1.2.3.4"
blockedIPv6Str = "1234::cdef"
blockRules = blockedIPv4Str + "\n" + blockedIPv6Str + "\n"
)
var (
passedIPv4 net.IP = netip.MustParseAddr(passedIPv4Str).AsSlice()
blockedIPv4 net.IP = netip.MustParseAddr(blockedIPv4Str).AsSlice()
blockedIPv6 net.IP = netip.MustParseAddr(blockedIPv6Str).AsSlice()
)
filters := []filtering.Filter{{
ID: 0, Data: []byte(blockRules),
}}
f, err := filtering.New(&filtering.Config{}, filters)
require.NoError(t, err)
f.SetEnabled(true)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
testCases := []struct {
req *dns.Msg
name string
wantRule string
respAns []dns.RR
}{{
name: "pass",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeA),
wantRule: "",
respAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: aghtest.ReqFQDN,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: passedIPv4,
}},
}, {
name: "ipv4",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeA),
wantRule: blockedIPv4Str,
respAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: aghtest.ReqFQDN,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: blockedIPv4,
}},
}, {
name: "ipv6",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeAAAA),
wantRule: blockedIPv6Str,
respAns: []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: aghtest.ReqFQDN,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
},
AAAA: blockedIPv6,
}},
}, {
name: "ipv4hint",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
wantRule: blockedIPv4Str,
respAns: newSVCBHintsAnswer(
aghtest.ReqFQDN,
[]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{blockedIPv4}},
&dns.SVCBIPv6Hint{Hint: []net.IP{}},
},
),
}, {
name: "ipv6hint",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
wantRule: blockedIPv6Str,
respAns: newSVCBHintsAnswer(
aghtest.ReqFQDN,
[]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{}},
&dns.SVCBIPv6Hint{Hint: []net.IP{blockedIPv6}},
},
),
}, {
name: "ipv4_ipv6_hints",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
wantRule: blockedIPv4Str,
respAns: newSVCBHintsAnswer(
aghtest.ReqFQDN,
[]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{blockedIPv4}},
&dns.SVCBIPv6Hint{Hint: []net.IP{blockedIPv6}},
},
),
}, {
name: "pass_hints",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
wantRule: "",
respAns: newSVCBHintsAnswer(
aghtest.ReqFQDN,
[]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{passedIPv4}},
&dns.SVCBIPv6Hint{Hint: []net.IP{}},
},
),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
pctx := &proxy.DNSContext{
Proto: proxy.ProtoUDP,
Req: tc.req,
Res: resp,
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
}
res, rErr := s.filterDNSResponse(pctx, &filtering.Settings{
ProtectionEnabled: true,
FilteringEnabled: true,
})
require.NoError(t, rErr)
if tc.wantRule == "" {
assert.Nil(t, res)
return
}
want := &filtering.Result{
IsFiltered: true,
Reason: filtering.FilteredBlockList,
Rules: []*filtering.ResultRule{{
Text: tc.wantRule,
}},
}
assert.Equal(t, want, res)
})
}
}
// newSVCBHintsAnswer returns a test HTTPS answer RRs with SVCB hints.
func newSVCBHintsAnswer(target string, hints []dns.SVCBKeyValue) (rrs []dns.RR) {
return []dns.RR{&dns.HTTPS{
SVCB: dns.SVCB{
Hdr: dns.RR_Header{
Name: target,
Rrtype: dns.TypeHTTPS,
Class: dns.ClassINET,
},
Target: target,
Value: hints,
},
}}
}

View File

@@ -4,13 +4,13 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
@@ -37,6 +37,10 @@ type jsonDNSConfig struct {
// upstream DoH/DoT resolvers.
Bootstraps *[]string `json:"bootstrap_dns"`
// Fallbacks is the list of fallback DNS servers used when upstream DNS
// servers are not responding.
Fallbacks *[]string `json:"fallback_dns"`
// ProtectionEnabled defines if protection is enabled.
ProtectionEnabled *bool `json:"protection_enabled"`
@@ -44,7 +48,7 @@ type jsonDNSConfig struct {
RateLimit *uint32 `json:"ratelimit"`
// BlockingMode defines the way blocked responses are constructed.
BlockingMode *BlockingMode `json:"blocking_mode"`
BlockingMode *filtering.BlockingMode `json:"blocking_mode"`
// EDNSCSEnabled defines if EDNS Client Subnet is enabled.
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
@@ -83,10 +87,10 @@ type jsonDNSConfig struct {
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
// BlockingIPv4 is custom IPv4 address for blocked A requests.
BlockingIPv4 net.IP `json:"blocking_ipv4"`
BlockingIPv4 netip.Addr `json:"blocking_ipv4"`
// BlockingIPv6 is custom IPv6 address for blocked AAAA requests.
BlockingIPv6 net.IP `json:"blocking_ipv6"`
BlockingIPv6 netip.Addr `json:"blocking_ipv6"`
// DisabledUntil is a timestamp until when the protection is disabled.
DisabledUntil *time.Time `json:"protection_disabled_until"`
@@ -109,9 +113,8 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
upstreams := stringutil.CloneSliceOrEmpty(s.conf.UpstreamDNS)
upstreamFile := s.conf.UpstreamDNSFileName
bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS)
blockingMode := s.conf.BlockingMode
blockingIPv4 := s.conf.BlockingIPv4
blockingIPv6 := s.conf.BlockingIPv6
fallbacks := stringutil.CloneSliceOrEmpty(s.conf.FallbackDNS)
blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode()
ratelimit := s.conf.Ratelimit
customIP := s.conf.EDNSClientSubnet.CustomIP
@@ -144,6 +147,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
Upstreams: &upstreams,
UpstreamsFile: &upstreamFile,
Bootstraps: &bootstraps,
Fallbacks: &fallbacks,
ProtectionEnabled: &protectionEnabled,
BlockingMode: &blockingMode,
BlockingIPv4: blockingIPv4,
@@ -170,7 +174,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
resp := s.getDNSConfig()
_ = aghhttp.WriteJSONResponse(w, r, resp)
aghhttp.WriteJSONResponseOK(w, r, resp)
}
func (req *jsonDNSConfig) checkBlockingMode() (err error) {
@@ -208,6 +212,20 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
return nil
}
// checkFallbacks returns an error if any fallback address is invalid.
func (req *jsonDNSConfig) checkFallbacks() (err error) {
if req.Fallbacks == nil {
return nil
}
err = ValidateUpstreams(*req.Fallbacks)
if err != nil {
return fmt.Errorf("validating fallback servers: %w", err)
}
return nil
}
// validate returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil {
@@ -229,6 +247,11 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return err
}
err = req.checkFallbacks()
if err != nil {
return err
}
err = req.checkBlockingMode()
if err != nil {
return err
@@ -295,11 +318,11 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
defer s.serverLock.Unlock()
if dc.BlockingMode != nil {
s.conf.BlockingMode = *dc.BlockingMode
if *dc.BlockingMode == BlockingModeCustomIP {
s.conf.BlockingIPv4 = dc.BlockingIPv4.To4()
s.conf.BlockingIPv6 = dc.BlockingIPv6.To16()
}
s.dnsFilter.SetBlockingMode(*dc.BlockingMode, dc.BlockingIPv4, dc.BlockingIPv6)
}
if dc.ProtectionEnabled != nil {
s.dnsFilter.SetProtectionEnabled(*dc.ProtectionEnabled)
}
if dc.UpstreamMode != nil {
@@ -311,7 +334,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
s.conf.EDNSClientSubnet.CustomIP = dc.EDNSCSCustomIP
}
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
@@ -342,6 +364,7 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
setIfNotNil(&s.conf.LocalPTRResolvers, dc.LocalPTRUpstreams),
setIfNotNil(&s.conf.UpstreamDNSFileName, dc.UpstreamsFile),
setIfNotNil(&s.conf.BootstrapDNS, dc.Bootstraps),
setIfNotNil(&s.conf.FallbackDNS, dc.Fallbacks),
setIfNotNil(&s.conf.EDNSClientSubnet.Enabled, dc.EDNSCSEnabled),
setIfNotNil(&s.conf.EDNSClientSubnet.UseCustom, dc.EDNSCSUseCustom),
setIfNotNil(&s.conf.CacheSize, dc.CacheSize),
@@ -369,6 +392,7 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
type upstreamJSON struct {
Upstreams []string `json:"upstream_dns"`
BootstrapDNS []string `json:"bootstrap_dns"`
FallbackDNS []string `json:"fallback_dns"`
PrivateUpstreams []string `json:"private_upstream"`
}
@@ -441,7 +465,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := newUpstreamConfig(upstreams)
if err != nil {
return err
return fmt.Errorf("creating config: %w", err)
}
if conf == nil {
@@ -469,11 +493,7 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
}
}
if len(errs) > 0 {
return errors.List("checking domain-specific upstreams", errs...)
}
return nil
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}
var protocols = []string{
@@ -667,10 +687,13 @@ func (s *Server) parseUpstreamLine(
PreferIPv6: opts.PreferIPv6,
}
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
resolved := s.resolveUpstreamHost(extractUpstreamHost(upstreamAddr))
sortNetIPAddrs(resolved, opts.PreferIPv6)
opts.ServerIPAddrs = resolved
// dnsFilter can be nil during application update.
if s.dnsFilter != nil {
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(upstreamAddr))
for _, rec := range recs {
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
}
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
}
u, err = upstream.AddressToUpstream(upstreamAddr, opts)
if err != nil {
@@ -728,7 +751,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
upsNum := len(req.Upstreams) + len(req.PrivateUpstreams)
upsNum := len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)
result := make(map[string]string, upsNum)
resCh := make(chan upsCheckResult, upsNum)
@@ -740,6 +763,14 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
}(ups)
}
for _, ups := range req.FallbackDNS {
go func(ups string) {
resCh <- upsCheckResult{
host: ups,
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
}
}(ups)
}
for _, ups := range req.PrivateUpstreams {
go func(ups string) {
resCh <- upsCheckResult{
@@ -760,7 +791,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
}
_ = aghhttp.WriteJSONResponse(w, r, result)
aghhttp.WriteJSONResponseOK(w, r, result)
}
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
@@ -806,8 +837,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
s.serverLock.Lock()
defer s.serverLock.Unlock()
s.conf.ProtectionEnabled = protectionReq.Enabled
s.conf.ProtectionDisabledUntil = disabledUntil
s.dnsFilter.SetProtectionStatus(protectionReq.Enabled, disabledUntil)
}()
s.conf.ConfigModified()

View File

@@ -58,6 +58,8 @@ const jsonExt = ".json"
func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
filterConf := &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
@@ -68,11 +70,10 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{},
TCPListenAddrs: []*net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
FallbackDNS: []string{"9.9.9.10"},
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ConfigModified: func() {},
}
@@ -134,6 +135,8 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
filterConf := &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
@@ -144,11 +147,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{},
TCPListenAddrs: []*net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ConfigModified: func() {},
}
@@ -177,7 +178,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "",
}, {
name: "blocking_mode_bad",
wantSet: "blocking_ipv4 must be set when blocking_mode is custom_ip",
wantSet: "blocking_ipv4 must be valid ipv4 on custom_ip blocking_mode",
}, {
name: "ratelimit",
wantSet: "",
@@ -225,6 +226,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
}, {
name: "local_ptr_upstreams_null",
wantSet: "",
}, {
name: "fallbacks",
wantSet: "",
}}
var data map[string]struct {
@@ -243,8 +247,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
s.dnsFilter.SetBlockingMode(filtering.BlockingModeDefault, netip.Addr{}, netip.Addr{})
s.conf = defaultConf
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{}
})
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
@@ -405,9 +410,9 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
u: "[/128.in-addr.arpa/]#",
}, {
name: "several_bad",
wantErr: `checking domain-specific upstreams: 2 errors: ` +
`"arpa domain \"1.2.3.4.in-addr.arpa.\" should point to a locally-served network", ` +
`"bad arpa domain name \"non.arpa.\": not a reversed ip network"`,
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
`bad arpa domain name "non.arpa.": not a reversed ip network`,
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
}, {
name: "partial_good",
@@ -479,7 +484,6 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
}).String()
hc, err := aghnet.NewHostsContainer(
filtering.SysHostsListID,
fstest.MapFS{
hostsFileName: &fstest.MapFile{
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),
@@ -495,12 +499,13 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
require.NoError(t, err)
srv := createTestServer(t, &filtering.Config{
EtcHosts: hc,
BlockingMode: filtering.BlockingModeDefault,
EtcHosts: hc,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UpstreamTimeout: upsTimeout,
FilteringConfig: FilteringConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}, nil)
@@ -555,6 +560,23 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
hostsUps: "OK",
},
name: "etc_hosts",
}, {
body: map[string]any{
"fallback_dns": []string{goodUps},
},
wantResp: map[string]any{
goodUps: "OK",
},
name: "fallback_success",
}, {
body: map[string]any{
"fallback_dns": []string{badUps},
},
wantResp: map[string]any{
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
name: "fallback_broken",
}}
for _, tc := range testCases {

View File

@@ -1,7 +1,7 @@
package dnsforward
import (
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// makeResponse creates a DNS response by req and sets necessary flags. It also
@@ -26,24 +27,13 @@ func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
return resp
}
// containsIP returns true if the IP is already in the list.
func containsIP(ips []net.IP, ip net.IP) bool {
for _, a := range ips {
if a.Equal(ip) {
return true
}
}
return false
}
// ipsFromRules extracts unique non-IP addresses from the filtering result
// rules.
func ipsFromRules(resRules []*filtering.ResultRule) (ips []net.IP) {
func ipsFromRules(resRules []*filtering.ResultRule) (ips []netip.Addr) {
for _, r := range resRules {
// len(resRules) and len(ips) are actually small enough for O(n^2) to do
// not raise performance questions.
if ip := r.IP; ip != nil && !containsIP(ips, ip) {
if ip := r.IP; ip != (netip.Addr{}) && !slices.Contains(ips, ip) {
ips = append(ips, ip)
}
}
@@ -58,19 +48,21 @@ func (s *Server) genDNSFilterMessage(
res *filtering.Result,
) (resp *dns.Msg) {
req := dctx.Req
if qt := req.Question[0].Qtype; qt != dns.TypeA && qt != dns.TypeAAAA {
if s.conf.BlockingMode == BlockingModeNullIP {
qt := req.Question[0].Qtype
if qt != dns.TypeA && qt != dns.TypeAAAA {
m, _, _ := s.dnsFilter.BlockingMode()
if m == filtering.BlockingModeNullIP {
return s.makeResponse(req)
}
return s.genNXDomain(req)
return s.newMsgNODATA(req)
}
switch res.Reason {
case filtering.FilteredSafeBrowsing:
return s.genBlockedHost(req, s.conf.SafeBrowsingBlockHost, dctx)
return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
case filtering.FilteredParental:
return s.genBlockedHost(req, s.conf.ParentalBlockHost, dctx)
return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost(), dctx)
case filtering.FilteredSafeSearch:
// If Safe Search generated the necessary IP addresses, use them.
// Otherwise, if there were no errors, there are no addresses for the
@@ -83,36 +75,45 @@ func (s *Server) genDNSFilterMessage(
// genForBlockingMode generates a filtered response to req based on the server's
// blocking mode.
func (s *Server) genForBlockingMode(req *dns.Msg, ips []net.IP) (resp *dns.Msg) {
qt := req.Question[0].Qtype
switch m := s.conf.BlockingMode; m {
case BlockingModeCustomIP:
switch qt {
case dns.TypeA:
return s.genARecord(req, s.conf.BlockingIPv4)
case dns.TypeAAAA:
return s.genAAAARecord(req, s.conf.BlockingIPv6)
default:
// Generally shouldn't happen, since the types are checked in
// genDNSFilterMessage.
log.Error("dns: invalid msg type %s for blocking mode %s", dns.Type(qt), m)
return s.makeResponse(req)
}
case BlockingModeDefault:
func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
switch mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode(); mode {
case filtering.BlockingModeCustomIP:
return s.makeResponseCustomIP(req, bIPv4, bIPv6)
case filtering.BlockingModeDefault:
if len(ips) > 0 {
return s.genResponseWithIPs(req, ips)
}
return s.makeResponseNullIP(req)
case BlockingModeNullIP:
case filtering.BlockingModeNullIP:
return s.makeResponseNullIP(req)
case BlockingModeNXDOMAIN:
case filtering.BlockingModeNXDOMAIN:
return s.genNXDomain(req)
case BlockingModeREFUSED:
case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req)
default:
log.Error("dns: invalid blocking mode %q", s.conf.BlockingMode)
log.Error("dns: invalid blocking mode %q", mode)
return s.makeResponse(req)
}
}
// makeResponseCustomIP generates a DNS response message for Custom IP blocking
// mode with the provided IP addresses and an appropriate resource record type.
func (s *Server) makeResponseCustomIP(
req *dns.Msg,
bIPv4 netip.Addr,
bIPv6 netip.Addr,
) (resp *dns.Msg) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
return s.genARecord(req, bIPv4)
case dns.TypeAAAA:
return s.genAAAARecord(req, bIPv6)
default:
// Generally shouldn't happen, since the types are checked in
// genDNSFilterMessage.
log.Error("dns: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
return s.makeResponse(req)
}
@@ -125,13 +126,13 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
return &resp
}
func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request)
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
return resp
}
func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg {
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request)
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
return resp
@@ -141,22 +142,22 @@ func (s *Server) hdr(req *dns.Msg, rrType rules.RRType) (h dns.RR_Header) {
return dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: rrType,
Ttl: s.conf.BlockedResponseTTL,
Ttl: s.dnsFilter.BlockedResponseTTL(),
Class: dns.ClassINET,
}
}
func (s *Server) genAnswerA(req *dns.Msg, ip net.IP) (ans *dns.A) {
func (s *Server) genAnswerA(req *dns.Msg, ip netip.Addr) (ans *dns.A) {
return &dns.A{
Hdr: s.hdr(req, dns.TypeA),
A: ip,
A: ip.AsSlice(),
}
}
func (s *Server) genAnswerAAAA(req *dns.Msg, ip net.IP) (ans *dns.AAAA) {
func (s *Server) genAnswerAAAA(req *dns.Msg, ip netip.Addr) (ans *dns.AAAA) {
return &dns.AAAA{
Hdr: s.hdr(req, dns.TypeAAAA),
AAAA: ip,
AAAA: ip.AsSlice(),
}
}
@@ -203,22 +204,24 @@ func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) {
// addresses and an appropriate resource record type. If any of the IPs cannot
// be converted to the correct protocol, genResponseWithIPs returns an empty
// response.
func (s *Server) genResponseWithIPs(req *dns.Msg, ips []net.IP) (resp *dns.Msg) {
func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
var ans []dns.RR
switch req.Question[0].Qtype {
case dns.TypeA:
for _, ip := range ips {
if ip4 := ip.To4(); ip4 == nil {
if ip.Is4() {
ans = append(ans, s.genAnswerA(req, ip))
} else {
ans = nil
break
}
ans = append(ans, s.genAnswerA(req, ip))
}
case dns.TypeAAAA:
for _, ip := range ips {
ans = append(ans, s.genAnswerAAAA(req, ip.To16()))
if ip.Is6() {
ans = append(ans, s.genAnswerAAAA(req, ip))
}
}
default:
// Go on and return an empty response.
@@ -239,9 +242,9 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
// converted into an empty slice instead of the zero IPv4.
switch req.Question[0].Qtype {
case dns.TypeA:
resp = s.genResponseWithIPs(req, []net.IP{{0, 0, 0, 0}})
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv4Unspecified()})
case dns.TypeAAAA:
resp = s.genResponseWithIPs(req, []net.IP{net.IPv6zero})
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
default:
resp = s.makeResponse(req)
}
@@ -250,9 +253,15 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
}
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
ip := net.ParseIP(newAddr)
if ip != nil {
return s.genResponseWithIPs(request, []net.IP{ip})
if newAddr == "" {
log.Printf("block host is not specified.")
return s.genServerFailure(request)
}
ip, err := netip.ParseAddr(newAddr)
if err == nil {
return s.genResponseWithIPs(request, []netip.Addr{ip})
}
// look up the hostname, TODO: cache
@@ -274,7 +283,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
return s.genServerFailure(request)
}
err := prx.Resolve(newContext)
err = prx.Resolve(newContext)
if err != nil {
log.Printf("couldn't look up replacement host %q: %s", newAddr, err)
@@ -314,6 +323,17 @@ func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg {
return &resp
}
// newMsgNODATA returns a properly initialized NODATA response.
//
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(req)
return resp
}
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError)
@@ -342,13 +362,13 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
Hdr: dns.RR_Header{
Name: zone,
Rrtype: dns.TypeSOA,
Ttl: s.conf.BlockedResponseTTL,
Ttl: s.dnsFilter.BlockedResponseTTL(),
Class: dns.ClassINET,
},
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
}
if soa.Hdr.Ttl == 0 {
soa.Hdr.Ttl = defaultValues.BlockedResponseTTL
soa.Hdr.Ttl = defaultBlockedResponseTTL
}
if len(zone) > 0 && zone[0] != '.' {
soa.Mbox += zone

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
@@ -70,20 +69,26 @@ type dnsContext struct {
// isLocalClient shows if client's IP address is from locally served
// network.
isLocalClient bool
// isDHCPHost is true if the request for a local domain name and the DHCP is
// available for this request.
isDHCPHost bool
}
// resultCode is the result of a request processing function.
type resultCode int
const (
// resultCodeSuccess is returned when a handler performed successfully,
// and the next handler must be called.
// resultCodeSuccess is returned when a handler performed successfully, and
// the next handler must be called.
resultCodeSuccess resultCode = iota
// resultCodeFinish is returned when a handler performed successfully,
// and the processing of the request must be stopped.
// resultCodeFinish is returned when a handler performed successfully, and
// the processing of the request must be stopped.
resultCodeFinish
// resultCodeError is returned when a handler failed, and the processing
// of the request must be stopped.
// resultCodeError is returned when a handler failed, and the processing of
// the request must be stopped.
resultCodeError
)
@@ -239,70 +244,6 @@ func (s *Server) processClientIP(addr net.Addr) {
s.addrProc.Process(clientIP)
}
func (s *Server) setTableHostToIP(t hostToIPTable) {
s.tableHostToIPLock.Lock()
defer s.tableHostToIPLock.Unlock()
s.tableHostToIP = t
}
func (s *Server) setTableIPToHost(t ipToHostTable) {
s.tableIPToHostLock.Lock()
defer s.tableIPToHostLock.Unlock()
s.tableIPToHost = t
}
func (s *Server) onDHCPLeaseChanged(flags int) {
switch flags {
case dhcpd.LeaseChangedAdded,
dhcpd.LeaseChangedAddedStatic,
dhcpd.LeaseChangedRemovedStatic:
// Go on.
case dhcpd.LeaseChangedRemovedAll:
s.setTableHostToIP(nil)
s.setTableIPToHost(nil)
return
default:
return
}
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
hostToIP := make(hostToIPTable, len(ll))
ipToHost := make(ipToHostTable, len(ll))
for _, l := range ll {
// TODO(a.garipov): Remove this after we're finished with the client
// hostname validations in the DHCP server code.
err := netutil.ValidateHostname(l.Hostname)
if err != nil {
log.Debug("dnsforward: skipping invalid hostname %q from dhcp: %s", l.Hostname, err)
continue
}
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
// Assume that we only process IPv4 now.
if !l.IP.Is4() {
log.Debug("dnsforward: skipping invalid ip from dhcp: bad ipv4 net.IP %v", l.IP)
continue
}
leaseIP := l.IP
ipToHost[leaseIP] = lowhost
hostToIP[lowhost] = leaseIP
}
s.setTableHostToIP(hostToIP)
s.setTableIPToHost(ipToHost)
log.Debug("dnsforward: added %d a and ptr entries from dhcp", len(ipToHost))
}
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
// queries. The response contains different types of encryption supported by
// current user configuration.
@@ -420,18 +361,6 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
return rc
}
// dhcpHostToIP tries to get an IP leased by DHCP and returns the copy of
// address since the data inside the internal table may be changed while request
// processing. It's safe for concurrent use.
func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) {
s.tableHostToIPLock.Lock()
defer s.tableHostToIPLock.Unlock()
ip, ok = s.tableHostToIP[host]
return ip, ok
}
// processDHCPHosts respond to A requests if the target hostname is known to
// the server. It responds with a mapped IP address if the DNS64 is enabled and
// the request is for AAAA.
@@ -443,30 +372,31 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
reqHost, ok := s.isDHCPClientHostQ(q)
if !ok {
q := &req.Question[0]
dhcpHost := s.dhcpHostFromRequest(q)
if dctx.isDHCPHost = dhcpHost != ""; !dctx.isDHCPHost {
return resultCodeSuccess
}
if !dctx.isLocalClient {
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, reqHost)
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
pctx.Res = s.genNXDomain(req)
// Do not even put into query log.
return resultCodeFinish
}
ip, ok := s.dhcpHostToIP(reqHost)
if !ok {
ip := s.dhcpServer.IPByHost(dhcpHost)
if ip == (netip.Addr{}) {
// Go on and process them with filters, including dnsrewrite ones, and
// possibly route them to a domain-specific upstream.
log.Debug("dnsforward: no dhcp record for %q", reqHost)
log.Debug("dnsforward: no dhcp record for %q", dhcpHost)
return resultCodeSuccess
}
log.Debug("dnsforward: dhcp record for %q is %s", reqHost, ip)
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
resp := s.makeResponse(req)
switch q.Qtype {
@@ -638,17 +568,6 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
// ipToDHCPHost tries to get a hostname leased by DHCP. It's safe for
// concurrent use.
func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) {
s.tableIPToHostLock.Lock()
defer s.tableIPToHostLock.Unlock()
host, ok = s.tableIPToHost[ip]
return host, ok
}
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server.
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
@@ -673,12 +592,12 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
host, ok := s.ipToDHCPHost(ipAddr)
if !ok {
host := s.dhcpServer.HostByIP(ipAddr)
if host == "" {
return resultCodeSuccess
}
log.Debug("dnsforward: dhcp reverse record for %s is %q", ip, host)
log.Debug("dnsforward: dhcp client %s is %q", ip, host)
req := pctx.Req
resp := s.makeResponse(req)
@@ -686,10 +605,12 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
Hdr: dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
// TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See
// https://github.com/AdguardTeam/AdGuardHome/issues/3932.
Ttl: s.dnsFilter.BlockedResponseTTL(),
Class: dns.ClassINET,
},
Ptr: dns.Fqdn(host),
Ptr: dns.Fqdn(strings.Join([]string{host, s.localDomainSuffix}, ".")),
}
resp.Answer = append(resp.Answer, ptr)
pctx.Res = resp
@@ -762,10 +683,6 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode)
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if s.dnsFilter == nil {
return resultCodeSuccess
}
var err error
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
ctx.err = err
@@ -792,17 +709,18 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
if pctx.Res != nil {
// The response has already been set.
return resultCodeSuccess
} else if reqHost, ok := s.isDHCPClientHostQ(q); ok {
} else if dctx.isDHCPHost {
// A DHCP client hostname query that hasn't been handled or filtered.
// Respond with an NXDOMAIN.
//
// TODO(a.garipov): Route such queries to a custom upstream for the
// local domain name if there is one.
log.Debug("dnsforward: dhcp client hostname %q was not filtered", reqHost)
name := req.Question[0].Name
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
pctx.Res = s.genNXDomain(req)
return resultCodeFinish
@@ -889,26 +807,26 @@ func (s *Server) setRespAD(pctx *proxy.DNSContext, reqWantsDNSSEC bool) {
}
}
// isDHCPClientHostQ returns true if q is from a request for a DHCP client
// hostname. If ok is true, reqHost contains the requested hostname.
func (s *Server) isDHCPClientHostQ(q dns.Question) (reqHost string, ok bool) {
// dhcpHostFromRequest returns a hostname from question, if the request is for a
// DHCP client's hostname when DHCP is enabled, and an empty string otherwise.
func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
if !s.dhcpServer.Enabled() {
return "", false
return ""
}
// Include AAAA here, because despite the fact that we don't support it yet,
// the expected behavior here is to respond with an empty answer and not
// NXDOMAIN.
if qt := q.Qtype; qt != dns.TypeA && qt != dns.TypeAAAA {
return "", false
return ""
}
reqHost = strings.ToLower(q.Name[:len(q.Name)-1])
if strings.HasSuffix(reqHost, s.localDomainSuffix) {
return reqHost, true
if !netutil.IsImmediateSubdomain(reqHost, s.localDomainSuffix) {
return ""
}
return "", false
return reqHost[:len(reqHost)-len(s.localDomainSuffix)-1]
}
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
@@ -972,7 +890,7 @@ func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (
// Check the response only if it's from an upstream. Don't check the
// response if the protection is disabled since dnsrewrite rules aren't
// applied to it anyway.
if !dctx.protectionEnabled || !dctx.responseFromUpstream || s.dnsFilter == nil {
if !dctx.protectionEnabled || !dctx.responseFromUpstream {
return resultCodeSuccess
}

View File

@@ -77,13 +77,15 @@ func TestServer_ProcessInitial(t *testing.T) {
t.Parallel()
c := ServerConfig{
FilteringConfig: FilteringConfig{
Config: Config{
AAAADisabled: tc.aaaaDisabled,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}
s := createTestServer(t, &filtering.Config{}, c, nil)
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, c, nil)
var gotAddr netip.Addr
s.addrProc = &aghtest.AddressProcessor{
@@ -113,6 +115,101 @@ func TestServer_ProcessInitial(t *testing.T) {
}
}
func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
t.Parallel()
var (
testIPv4 net.IP = netip.MustParseAddr("1.1.1.1").AsSlice()
testIPv6 net.IP = netip.MustParseAddr("1234::cdef").AsSlice()
)
testCases := []struct {
name string
req *dns.Msg
aaaaDisabled bool
respAns []dns.RR
wantRC resultCode
wantRespAns []dns.RR
}{{
name: "pass",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
aaaaDisabled: false,
respAns: newSVCBHintsAnswer(
aghtest.ReqFQDN,
[]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{testIPv4}},
&dns.SVCBIPv6Hint{Hint: []net.IP{testIPv6}},
},
),
wantRespAns: newSVCBHintsAnswer(
aghtest.ReqFQDN,
[]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{testIPv4}},
&dns.SVCBIPv6Hint{Hint: []net.IP{testIPv6}},
},
),
wantRC: resultCodeSuccess,
}, {
name: "filter",
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
aaaaDisabled: true,
respAns: newSVCBHintsAnswer(
aghtest.ReqFQDN,
[]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{testIPv4}},
&dns.SVCBIPv6Hint{Hint: []net.IP{testIPv6}},
},
),
wantRespAns: newSVCBHintsAnswer(
aghtest.ReqFQDN,
[]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{testIPv4}},
},
),
wantRC: resultCodeSuccess,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
c := ServerConfig{
Config: Config{
AAAADisabled: tc.aaaaDisabled,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, c, nil)
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
dctx := &dnsContext{
setts: &filtering.Settings{
FilteringEnabled: true,
ProtectionEnabled: true,
},
protectionEnabled: true,
responseFromUpstream: true,
result: &filtering.Result{},
proxyCtx: &proxy.DNSContext{
Proto: proxy.ProtoUDP,
Req: tc.req,
Res: resp,
Addr: testClientAddr,
},
}
gotRC := s.processFilteringAfterResponse(dctx)
assert.Equal(t, tc.wantRC, gotRC)
assert.Equal(t, newResp(dns.RcodeSuccess, tc.req, tc.wantRespAns), dctx.proxyCtx.Res)
})
}
}
func TestServer_ProcessDDRQuery(t *testing.T) {
dohSVCB := &dns.SVCB{
Priority: 1,
@@ -245,15 +342,28 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
}
}
// createTestDNSFilter returns the minimum valid DNSFilter.
func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
t.Helper()
f, err := filtering.New(&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, []filtering.Filter{})
require.NoError(t, err)
return f
}
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
t.Helper()
s = &Server{
dnsFilter: createTestDNSFilter(t),
dnsProxy: &proxy.Proxy{
Config: proxy.Config{},
},
conf: ServerConfig{
FilteringConfig: FilteringConfig{
Config: Config{
HandleDDR: ddrEnabled,
},
TLSConfig: TLSConfig{
@@ -323,47 +433,59 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
}
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
const (
localDomainSuffix = "lan"
dhcpClient = "example"
knownHost = dhcpClient + "." + localDomainSuffix
unknownHost = "wronghost." + localDomainSuffix
)
knownIP := netip.MustParseAddr("1.2.3.4")
dhcp := &testDHCP{
OnEnabled: func() (_ bool) { return true },
OnIPByHost: func(host string) (ip netip.Addr) {
if host == dhcpClient {
ip = knownIP
}
return ip
},
}
testCases := []struct {
wantIP netip.Addr
name string
host string
wantRes resultCode
isLocalCli bool
}{{
wantIP: knownIP,
name: "local_client_success",
host: "example.lan",
wantRes: resultCodeSuccess,
host: knownHost,
isLocalCli: true,
}, {
wantIP: netip.Addr{},
name: "local_client_unknown_host",
host: "wronghost.lan",
wantRes: resultCodeSuccess,
host: unknownHost,
isLocalCli: true,
}, {
wantIP: netip.Addr{},
name: "external_client_known_host",
host: "example.lan",
wantRes: resultCodeFinish,
host: knownHost,
isLocalCli: false,
}, {
wantIP: netip.Addr{},
name: "external_client_unknown_host",
host: "wronghost.lan",
wantRes: resultCodeFinish,
host: unknownHost,
isLocalCli: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := &Server{
dhcpServer: testDHCP,
localDomainSuffix: defaultLocalDomainSuffix,
tableHostToIP: hostToIPTable{
"example." + defaultLocalDomainSuffix: knownIP,
},
dnsFilter: createTestDNSFilter(t),
dhcpServer: dhcp,
localDomainSuffix: localDomainSuffix,
}
req := &dns.Msg{
@@ -385,43 +507,52 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
}
res := s.processDHCPHosts(dctx)
require.Equal(t, tc.wantRes, res)
pctx := dctx.proxyCtx
if tc.wantRes == resultCodeFinish {
if !tc.isLocalCli {
require.Equal(t, resultCodeFinish, res)
require.NotNil(t, pctx.Res)
assert.Equal(t, dns.RcodeNameError, pctx.Res.Rcode)
assert.Len(t, pctx.Res.Answer, 0)
assert.Empty(t, pctx.Res.Answer)
return
}
require.Equal(t, resultCodeSuccess, res)
if tc.wantIP == (netip.Addr{}) {
assert.Nil(t, pctx.Res)
} else {
require.NotNil(t, pctx.Res)
ans := pctx.Res.Answer
require.Len(t, ans, 1)
a := testutil.RequireTypeAssert[*dns.A](t, ans[0])
ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
require.NoError(t, err)
assert.Equal(t, tc.wantIP, ip)
return
}
require.NotNil(t, pctx.Res)
ans := pctx.Res.Answer
require.Len(t, ans, 1)
a := testutil.RequireTypeAssert[*dns.A](t, ans[0])
ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
require.NoError(t, err)
assert.Equal(t, tc.wantIP, ip)
})
}
}
func TestServer_ProcessDHCPHosts(t *testing.T) {
const (
examplecom = "example.com"
examplelan = "example." + defaultLocalDomainSuffix
localTLD = "lan"
knownClient = "example"
externalHost = knownClient + ".com"
clientHost = knownClient + "." + localTLD
)
knownIP := netip.MustParseAddr("1.2.3.4")
testCases := []struct {
wantIP netip.Addr
name string
@@ -431,55 +562,65 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
qtyp uint16
}{{
wantIP: netip.Addr{},
name: "success_external",
host: examplecom,
suffix: defaultLocalDomainSuffix,
name: "external",
host: externalHost,
suffix: localTLD,
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}, {
wantIP: netip.Addr{},
name: "success_external_non_a",
host: examplecom,
suffix: defaultLocalDomainSuffix,
name: "external_non_a",
host: externalHost,
suffix: localTLD,
wantRes: resultCodeSuccess,
qtyp: dns.TypeCNAME,
}, {
wantIP: knownIP,
name: "success_internal",
host: examplelan,
suffix: defaultLocalDomainSuffix,
name: "internal",
host: clientHost,
suffix: localTLD,
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}, {
wantIP: netip.Addr{},
name: "success_internal_unknown",
name: "internal_unknown",
host: "example-new.lan",
suffix: defaultLocalDomainSuffix,
suffix: localTLD,
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}, {
wantIP: netip.Addr{},
name: "success_internal_aaaa",
host: examplelan,
suffix: defaultLocalDomainSuffix,
name: "internal_aaaa",
host: clientHost,
suffix: localTLD,
wantRes: resultCodeSuccess,
qtyp: dns.TypeAAAA,
}, {
wantIP: knownIP,
name: "success_custom_suffix",
host: "example.custom",
name: "custom_suffix",
host: knownClient + ".custom",
suffix: "custom",
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}}
for _, tc := range testCases {
testDHCP := &testDHCP{
OnEnabled: func() (_ bool) { return true },
OnIPByHost: func(host string) (ip netip.Addr) {
if host == knownClient {
ip = knownIP
}
return ip
},
OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") },
}
s := &Server{
dnsFilter: createTestDNSFilter(t),
dhcpServer: testDHCP,
localDomainSuffix: tc.suffix,
tableHostToIP: hostToIPTable{
"example." + tc.suffix: knownIP,
},
}
req := &dns.Msg{
@@ -504,13 +645,6 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
res := s.processDHCPHosts(dctx)
pctx := dctx.proxyCtx
assert.Equal(t, tc.wantRes, res)
if tc.wantRes == resultCodeFinish {
require.NotNil(t, pctx.Res)
assert.Equal(t, dns.RcodeNameError, pctx.Res.Rcode)
return
}
require.NoError(t, dctx.err)
if tc.qtyp == dns.TypeAAAA {
@@ -555,12 +689,14 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
), nil
})
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
// Improve FilteringConfig declaration for tests.
FilteringConfig: FilteringConfig{
// Improve Config declaration for tests.
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}, ups)
@@ -631,11 +767,13 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
s := createTestServer(
t,
&filtering.Config{},
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
},

View File

@@ -139,10 +139,14 @@ func (s *Server) updateStats(
clientIP string,
) {
pctx := ctx.proxyCtx
e := stats.Entry{
e := &stats.Entry{
Domain: aghnet.NormalizeDomain(pctx.Req.Question[0].Name),
Result: stats.RNotFiltered,
Time: uint32(elapsed / 1000),
Time: elapsed,
}
if pctx.Upstream != nil {
e.Upstream = pctx.Upstream.Address()
}
if clientID := ctx.clientID; clientID != "" {

View File

@@ -41,11 +41,11 @@ type testStats struct {
// without actually implementing all methods.
stats.Interface
lastEntry stats.Entry
lastEntry *stats.Entry
}
// Update implements the [stats.Interface] interface for *testStats.
func (l *testStats) Update(e stats.Entry) {
func (l *testStats) Update(e *stats.Entry) {
if e.Domain == "" {
return
}

View File

@@ -4,6 +4,7 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -12,13 +13,13 @@ import (
func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
// Preconditions.
s := &Server{
conf: ServerConfig{
FilteringConfig: FilteringConfig{
BlockedResponseTTL: 3600,
},
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}
}, nil)
req := &dns.Msg{
Question: []dns.Question{{

View File

@@ -11,6 +11,9 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [
"9.9.9.10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -43,6 +46,9 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [
"9.9.9.10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -75,6 +81,9 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [
"9.9.9.10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,

View File

@@ -18,6 +18,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -54,6 +55,7 @@
"bootstrap_dns": [
"9.9.9.10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -91,6 +93,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -128,6 +131,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -165,6 +169,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 6,
@@ -202,6 +207,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -241,6 +247,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -280,6 +287,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -317,6 +325,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -354,6 +363,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -391,6 +401,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -428,6 +439,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -467,6 +479,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -506,6 +519,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -544,6 +558,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -581,6 +596,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -620,6 +636,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -662,6 +679,7 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
@@ -699,6 +717,49 @@
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
"upstream_mode": "",
"cache_size": 0,
"cache_ttl_min": 0,
"cache_ttl_max": 0,
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"fallbacks": {
"req": {
"fallback_dns": [
"9.9.9.10"
]
},
"want": {
"upstream_dns": [
"8.8.8.8:53",
"8.8.4.4:53"
],
"upstream_dns_file": "",
"bootstrap_dns": [
"9.9.9.10",
"149.112.112.10",
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [
"9.9.9.10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,

View File

@@ -14,8 +14,6 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
@@ -94,7 +92,8 @@ func (s *Server) prepareUpstreamConfig(
uc.Upstreams = defaultUpstreamConfig.Upstreams
}
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
// dnsFilter can be nil during application update.
if s.dnsFilter != nil {
err = s.replaceUpstreamsWithHosts(uc, opts)
if err != nil {
return nil, fmt.Errorf("resolving upstreams with hosts: %w", err)
@@ -158,17 +157,20 @@ func (s *Server) resolveUpstreamsWithHosts(
withIPs, ok := resolved[host]
if !ok {
ips := s.resolveUpstreamHost(host)
if len(ips) == 0 {
recs := s.dnsFilter.EtcHostsRecords(host)
if len(recs) == 0 {
resolved[host] = nil
return nil
}
sortNetIPAddrs(ips, opts.PreferIPv6)
withIPs = opts.Clone()
withIPs.ServerIPAddrs = ips
withIPs.ServerIPAddrs = make([]net.IP, 0, len(recs))
for _, rec := range recs {
withIPs.ServerIPAddrs = append(withIPs.ServerIPAddrs, rec.Addr.AsSlice())
}
sortNetIPAddrs(withIPs.ServerIPAddrs, opts.PreferIPv6)
resolved[host] = withIPs
} else if withIPs == nil {
continue
@@ -216,33 +218,6 @@ func extractUpstreamHost(addr string) (host string) {
return host
}
// resolveUpstreamHost returns the version of ups with IP addresses from the
// system hosts file placed into its options.
func (s *Server) resolveUpstreamHost(host string) (addrs []net.IP) {
req := &urlfilter.DNSRequest{
Hostname: host,
DNSType: dns.TypeA,
}
aRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
req.DNSType = dns.TypeAAAA
aaaaRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
var ips []net.IP
for _, rw := range append(aRes.DNSRewrites(), aaaaRes.DNSRewrites()...) {
dr := rw.DNSRewrite
if dr == nil || dr.Value == nil {
continue
}
if ip, ok := dr.Value.(net.IP); ok {
ips = append(ips, ip)
}
}
return ips
}
// sortNetIPAddrs sorts addrs in accordance with the protocol preferences.
// Invalid addresses are sorted near the end.
//
@@ -254,28 +229,38 @@ func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) {
return
}
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (sortsBefore bool) {
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (res int) {
switch len(addrA) {
case net.IPv4len, net.IPv6len:
switch len(addrB) {
case net.IPv4len, net.IPv6len:
// Go on.
default:
return true
return -1
}
default:
return false
return 1
}
if aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil; aIs4 != bIs4 {
if aIs4 {
return !preferIPv6
// Treat IPv6-mapped IPv4 addresses as IPv6 addresses.
aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil
if aIs4 == bIs4 {
return bytes.Compare(addrA, addrB)
}
if aIs4 {
if preferIPv6 {
return 1
}
return preferIPv6
return -1
}
return bytes.Compare(addrA, addrB) < 0
if preferIPv6 {
return -1
}
return 1
})
}