all: sync with master; upd chlog
This commit is contained in:
@@ -28,6 +28,7 @@ type accessManager struct {
|
||||
allowedClientIDs *stringutil.Set
|
||||
blockedClientIDs *stringutil.Set
|
||||
|
||||
// TODO(s.chzhen): Use [aghnet.IgnoreEngine].
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
|
||||
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||
@@ -131,7 +132,6 @@ func (a *accessManager) isBlockedClientID(id string) (ok bool) {
|
||||
func (a *accessManager) isBlockedHost(host string, qt rules.RRType) (ok bool) {
|
||||
_, ok = a.blockedHostsEng.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
ClientIP: "0.0.0.0",
|
||||
DNSType: qt,
|
||||
})
|
||||
|
||||
@@ -183,7 +183,7 @@ func (s *Server) accessListJSON() (j accessListJSON) {
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
_ = aghhttp.WriteJSONResponse(w, r, s.accessListJSON())
|
||||
aghhttp.WriteJSONResponseOK(w, r, s.accessListJSON())
|
||||
}
|
||||
|
||||
// validateAccessSet checks the internal accessListJSON lists. To search for
|
||||
|
||||
@@ -25,74 +25,19 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// BlockingMode is an enum of all allowed blocking modes.
|
||||
type BlockingMode string
|
||||
|
||||
// Allowed blocking modes.
|
||||
const (
|
||||
// BlockingModeCustomIP means respond with a custom IP address.
|
||||
BlockingModeCustomIP BlockingMode = "custom_ip"
|
||||
|
||||
// BlockingModeDefault is the same as BlockingModeNullIP for
|
||||
// Adblock-style rules, but responds with the IP address specified in
|
||||
// the rule when blocked by an `/etc/hosts`-style rule.
|
||||
BlockingModeDefault BlockingMode = "default"
|
||||
|
||||
// BlockingModeNullIP means respond with a zero IP address: "0.0.0.0"
|
||||
// for A requests and "::" for AAAA ones.
|
||||
BlockingModeNullIP BlockingMode = "null_ip"
|
||||
|
||||
// BlockingModeNXDOMAIN means respond with the NXDOMAIN code.
|
||||
BlockingModeNXDOMAIN BlockingMode = "nxdomain"
|
||||
|
||||
// BlockingModeREFUSED means respond with the REFUSED code.
|
||||
BlockingModeREFUSED BlockingMode = "refused"
|
||||
)
|
||||
|
||||
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
||||
// The zero FilteringConfig is empty and ready for use.
|
||||
type FilteringConfig struct {
|
||||
// Config represents the DNS filtering configuration of AdGuard Home. The zero
|
||||
// Config is empty and ready for use.
|
||||
type Config struct {
|
||||
// Callbacks for other modules
|
||||
|
||||
// FilterHandler is an optional additional filtering callback.
|
||||
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`
|
||||
FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"`
|
||||
|
||||
// GetCustomUpstreamByClient is a callback that returns upstreams
|
||||
// configuration based on the client IP address or ClientID. It returns
|
||||
// nil if there are no custom upstreams for the client.
|
||||
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
|
||||
|
||||
// Protection configuration
|
||||
|
||||
// ProtectionEnabled defines whether or not use any of filtering features.
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"`
|
||||
|
||||
// BlockingMode defines the way how blocked responses are constructed.
|
||||
BlockingMode BlockingMode `yaml:"blocking_mode"`
|
||||
|
||||
// BlockingIPv4 is the IP address to be returned for a blocked A request.
|
||||
BlockingIPv4 net.IP `yaml:"blocking_ipv4"`
|
||||
|
||||
// BlockingIPv6 is the IP address to be returned for a blocked AAAA
|
||||
// request.
|
||||
BlockingIPv6 net.IP `yaml:"blocking_ipv6"`
|
||||
|
||||
// BlockedResponseTTL is the time-to-live value for blocked responses. If
|
||||
// 0, then default value is used (3600).
|
||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"`
|
||||
|
||||
// ProtectionDisabledUntil is the timestamp until when the protection is
|
||||
// disabled.
|
||||
ProtectionDisabledUntil *time.Time `yaml:"protection_disabled_until"`
|
||||
|
||||
// ParentalBlockHost is the IP (or domain name) which is used to respond to
|
||||
// DNS requests blocked by parental control.
|
||||
ParentalBlockHost string `yaml:"parental_block_host"`
|
||||
|
||||
// SafeBrowsingBlockHost is the IP (or domain name) which is used to
|
||||
// respond to DNS requests blocked by safe-browsing.
|
||||
SafeBrowsingBlockHost string `yaml:"safebrowsing_block_host"`
|
||||
|
||||
// Anti-DNS amplification
|
||||
|
||||
// Ratelimit is the maximum number of requests per second from a given IP
|
||||
@@ -118,6 +63,10 @@ type FilteringConfig struct {
|
||||
// resolvers (plain DNS only).
|
||||
BootstrapDNS []string `yaml:"bootstrap_dns"`
|
||||
|
||||
// FallbackDNS is the list of fallback DNS servers used when upstream DNS
|
||||
// servers are not responding.
|
||||
FallbackDNS []string `yaml:"fallback_dns"`
|
||||
|
||||
// AllServers, if true, parallel queries to all configured upstream servers
|
||||
// are enabled.
|
||||
AllServers bool `yaml:"all_servers"`
|
||||
@@ -133,7 +82,7 @@ type FilteringConfig struct {
|
||||
|
||||
// AllowedClients is the slice of IP addresses, CIDR networks, and
|
||||
// ClientIDs of allowed clients. If not empty, only these clients are
|
||||
// allowed, and [FilteringConfig.DisallowedClients] are ignored.
|
||||
// allowed, and [Config.DisallowedClients] are ignored.
|
||||
AllowedClients []string `yaml:"allowed_clients"`
|
||||
|
||||
// DisallowedClients is the slice of IP addresses, CIDR networks, and
|
||||
@@ -279,7 +228,7 @@ type ServerConfig struct {
|
||||
// Remove that.
|
||||
AddrProcConf *client.DefaultAddrProcConfig
|
||||
|
||||
FilteringConfig
|
||||
Config
|
||||
TLSConfig
|
||||
DNSCryptConfig
|
||||
TLSAllowUnencryptedDoH bool
|
||||
@@ -320,13 +269,6 @@ type ServerConfig struct {
|
||||
UseHTTP3Upstreams bool
|
||||
}
|
||||
|
||||
// if any of ServerConfig values are zero, then default values from below are used
|
||||
var defaultValues = ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{Port: 53}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{Port: 53}},
|
||||
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
||||
}
|
||||
|
||||
// createProxyConfig creates and validates configuration for the main proxy.
|
||||
func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
srvConf := s.conf
|
||||
@@ -399,10 +341,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
const (
|
||||
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
defaultParentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
const defaultBlockedResponseTTL = 3600
|
||||
|
||||
// initDefaultSettings initializes default settings if nothing
|
||||
// is configured
|
||||
@@ -415,20 +354,12 @@ func (s *Server) initDefaultSettings() {
|
||||
s.conf.BootstrapDNS = defaultBootstrap
|
||||
}
|
||||
|
||||
if s.conf.ParentalBlockHost == "" {
|
||||
s.conf.ParentalBlockHost = defaultParentalBlockHost
|
||||
}
|
||||
|
||||
if s.conf.SafeBrowsingBlockHost == "" {
|
||||
s.conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
|
||||
}
|
||||
|
||||
if s.conf.UDPListenAddrs == nil {
|
||||
s.conf.UDPListenAddrs = defaultValues.UDPListenAddrs
|
||||
s.conf.UDPListenAddrs = defaultUDPListenAddrs
|
||||
}
|
||||
|
||||
if s.conf.TCPListenAddrs == nil {
|
||||
s.conf.TCPListenAddrs = defaultValues.TCPListenAddrs
|
||||
s.conf.TCPListenAddrs = defaultTCPListenAddrs
|
||||
}
|
||||
|
||||
if len(s.conf.BlockedHosts) == 0 {
|
||||
@@ -561,9 +492,9 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
disabledUntil = s.conf.ProtectionDisabledUntil
|
||||
enabled, disabledUntil = s.dnsFilter.ProtectionStatus()
|
||||
if disabledUntil == nil {
|
||||
return s.conf.ProtectionEnabled, nil
|
||||
return enabled, nil
|
||||
}
|
||||
|
||||
if time.Now().Before(*disabledUntil) {
|
||||
@@ -595,8 +526,7 @@ func (s *Server) enableProtectionAfterPause() {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.conf.ProtectionEnabled = true
|
||||
s.conf.ProtectionDisabledUntil = nil
|
||||
s.dnsFilter.SetProtectionStatus(true, nil)
|
||||
|
||||
log.Info("dns: protection is restarted after pause")
|
||||
}
|
||||
|
||||
@@ -49,9 +49,8 @@ func (s *Server) DialContext(ctx context.Context, network, addr string) (conn ne
|
||||
continue
|
||||
}
|
||||
|
||||
return conn, err
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use errors.Join in Go 1.20.
|
||||
return nil, errors.List(fmt.Sprintf("dialing %q", addr), dialErrs...)
|
||||
return nil, errors.Join(dialErrs...)
|
||||
}
|
||||
|
||||
@@ -283,11 +283,13 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
// right after stop, due to a data race in [proxy.Proxy.Init] method
|
||||
// when setting an OOB size. As a temporary workaround, recreate the
|
||||
// whole server for each test case.
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UseDNS64: true,
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}, localUps)
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
@@ -46,20 +45,16 @@ var defaultBootstrap = []string{"9.9.9.10", "149.112.112.10", "2620:fe::10", "26
|
||||
// Often requested by all kinds of DNS probes
|
||||
var defaultBlockedHosts = []string{"version.bind", "id.server", "hostname.bind"}
|
||||
|
||||
var (
|
||||
// defaultUDPListenAddrs are the default UDP addresses for the server.
|
||||
defaultUDPListenAddrs = []*net.UDPAddr{{Port: 53}}
|
||||
|
||||
// defaultTCPListenAddrs are the default TCP addresses for the server.
|
||||
defaultTCPListenAddrs = []*net.TCPAddr{{Port: 53}}
|
||||
)
|
||||
|
||||
var webRegistered bool
|
||||
|
||||
// hostToIPTable is a convenient type alias for tables of host names to an IP
|
||||
// address.
|
||||
//
|
||||
// TODO(e.burkov): Use the [DHCP] interface instead.
|
||||
type hostToIPTable = map[string]netip.Addr
|
||||
|
||||
// ipToHostTable is a convenient type alias for tables of IP addresses to their
|
||||
// host names. For example, for use with PTR queries.
|
||||
//
|
||||
// TODO(e.burkov): Use the [DHCP] interface instead.
|
||||
type ipToHostTable = map[netip.Addr]string
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data needed in this package.
|
||||
type DHCP interface {
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
@@ -89,18 +84,34 @@ type DHCP interface {
|
||||
//
|
||||
// The zero Server is empty and ready for use.
|
||||
type Server struct {
|
||||
dnsProxy *proxy.Proxy // DNS proxy instance
|
||||
dnsFilter *filtering.DNSFilter // DNS filter instance
|
||||
dhcpServer dhcpd.Interface // DHCP server instance (optional)
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
stats stats.Interface
|
||||
access *accessManager
|
||||
// dnsProxy is the DNS proxy for forwarding client's DNS requests.
|
||||
dnsProxy *proxy.Proxy
|
||||
|
||||
// dnsFilter is the DNS filter for filtering client's DNS requests and
|
||||
// responses.
|
||||
dnsFilter *filtering.DNSFilter
|
||||
|
||||
// dhcpServer is the DHCP server for accessing lease data.
|
||||
dhcpServer DHCP
|
||||
|
||||
// queryLog is the query log for client's DNS requests, responses and
|
||||
// filtering results.
|
||||
queryLog querylog.QueryLog
|
||||
|
||||
// stats is the statistics collector for client's DNS usage data.
|
||||
stats stats.Interface
|
||||
|
||||
// access drops unallowed clients.
|
||||
access *accessManager
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
// must be a valid domain name plus dots on each side.
|
||||
localDomainSuffix string
|
||||
|
||||
ipset ipsetCtx
|
||||
// ipset processes DNS requests using ipset data.
|
||||
ipset ipsetCtx
|
||||
|
||||
// privateNets is the configured set of IP networks considered private.
|
||||
privateNets netutil.SubnetSet
|
||||
|
||||
// addrProc, if not nil, is used to process clients' IP addresses with rDNS,
|
||||
@@ -112,7 +123,10 @@ type Server struct {
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
localResolvers *proxy.Proxy
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
|
||||
// sysResolvers used to fetch system resolvers to use by default for private
|
||||
// PTR resolving.
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
|
||||
// recDetector is a cache for recursive requests. It is used to detect
|
||||
// and prevent recursive requests only for private upstreams.
|
||||
@@ -128,12 +142,6 @@ type Server struct {
|
||||
// anonymizer masks the client's IP addresses if needed.
|
||||
anonymizer *aghnet.IPMut
|
||||
|
||||
tableHostToIP hostToIPTable
|
||||
tableHostToIPLock sync.Mutex
|
||||
|
||||
tableIPToHost ipToHostTable
|
||||
tableIPToHostLock sync.Mutex
|
||||
|
||||
// clientIDCache is a temporary storage for ClientIDs that were extracted
|
||||
// during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
@@ -142,13 +150,16 @@ type Server struct {
|
||||
// We don't Start() it and so no listen port is required.
|
||||
internalProxy *proxy.Proxy
|
||||
|
||||
// isRunning is true if the DNS server is running.
|
||||
isRunning bool
|
||||
|
||||
// protectionUpdateInProgress is used to make sure that only one goroutine
|
||||
// updating the protection configuration after a pause is running at a time.
|
||||
protectionUpdateInProgress atomic.Bool
|
||||
|
||||
// conf is the current configuration of the server.
|
||||
conf ServerConfig
|
||||
|
||||
// serverLock protects Server.
|
||||
serverLock sync.RWMutex
|
||||
}
|
||||
@@ -164,7 +175,7 @@ type DNSCreateParams struct {
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Interface
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.Interface
|
||||
DHCPServer DHCP
|
||||
PrivateNets netutil.SubnetSet
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
@@ -200,11 +211,12 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
p.Anonymizer = aghnet.NewIPMut(nil)
|
||||
}
|
||||
s = &Server{
|
||||
dnsFilter: p.DNSFilter,
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
privateNets: p.PrivateNets,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
dnsFilter: p.DNSFilter,
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
privateNets: p.PrivateNets,
|
||||
// TODO(e.burkov): Use some case-insensitive string comparison.
|
||||
localDomainSuffix: strings.ToLower(localDomainSuffix),
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
@@ -220,11 +232,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
return nil, fmt.Errorf("initializing system resolvers: %w", err)
|
||||
}
|
||||
|
||||
if p.DHCPServer != nil {
|
||||
s.dhcpServer = p.DHCPServer
|
||||
s.dhcpServer.SetOnLeaseChanged(s.onDHCPLeaseChanged)
|
||||
s.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
|
||||
}
|
||||
s.dhcpServer = p.DHCPServer
|
||||
|
||||
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
|
||||
// Use plain DNS on MIPS, encryption is too slow
|
||||
@@ -244,7 +252,7 @@ func (s *Server) Close() {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.dnsFilter = nil
|
||||
// TODO(s.chzhen): Remove it.
|
||||
s.stats = nil
|
||||
s.queryLog = nil
|
||||
s.dnsProxy = nil
|
||||
@@ -255,14 +263,15 @@ func (s *Server) Close() {
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
func (s *Server) WriteDiskConfig(c *FilteringConfig) {
|
||||
func (s *Server) WriteDiskConfig(c *Config) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
sc := s.conf.FilteringConfig
|
||||
sc := s.conf.Config
|
||||
*c = sc
|
||||
c.RatelimitWhitelist = stringutil.CloneSlice(sc.RatelimitWhitelist)
|
||||
c.BootstrapDNS = stringutil.CloneSlice(sc.BootstrapDNS)
|
||||
c.FallbackDNS = stringutil.CloneSlice(sc.FallbackDNS)
|
||||
c.AllowedClients = stringutil.CloneSlice(sc.AllowedClients)
|
||||
c.DisallowedClients = stringutil.CloneSlice(sc.DisallowedClients)
|
||||
c.BlockedHosts = stringutil.CloneSlice(sc.BlockedHosts)
|
||||
@@ -533,9 +542,13 @@ func (s *Server) setupLocalResolvers() (err error) {
|
||||
func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
s.conf = *conf
|
||||
|
||||
err = validateBlockingMode(s.conf.BlockingMode, s.conf.BlockingIPv4, s.conf.BlockingIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking blocking mode: %w", err)
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode()
|
||||
err = validateBlockingMode(mode, bIPv4, bIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking blocking mode: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.initDefaultSettings()
|
||||
@@ -584,6 +597,11 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return fmt.Errorf("setting up resolvers: %w", err)
|
||||
}
|
||||
|
||||
err = s.setupFallbackDNS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up fallback dns servers: %w", err)
|
||||
}
|
||||
|
||||
s.recDetector.clear()
|
||||
|
||||
s.setupAddrProc()
|
||||
@@ -593,6 +611,28 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupFallbackDNS initializes the fallback DNS servers.
|
||||
func (s *Server) setupFallbackDNS() (err error) {
|
||||
fallbacks := s.conf.FallbackDNS
|
||||
if len(fallbacks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
uc, err := proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
|
||||
// TODO(s.chzhen): Investigate if other options are needed.
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
// Do not wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
s.dnsProxy.Fallbacks = uc
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupAddrProc initializes the address processor. For internal use only.
|
||||
func (s *Server) setupAddrProc() {
|
||||
// TODO(a.garipov): This is a crutch for tests; remove.
|
||||
@@ -617,19 +657,22 @@ func (s *Server) setupAddrProc() {
|
||||
}
|
||||
|
||||
// validateBlockingMode returns an error if the blocking mode data aren't valid.
|
||||
func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP) (err error) {
|
||||
func validateBlockingMode(
|
||||
mode filtering.BlockingMode,
|
||||
blockingIPv4, blockingIPv6 netip.Addr,
|
||||
) (err error) {
|
||||
switch mode {
|
||||
case
|
||||
BlockingModeDefault,
|
||||
BlockingModeNXDOMAIN,
|
||||
BlockingModeREFUSED,
|
||||
BlockingModeNullIP:
|
||||
filtering.BlockingModeDefault,
|
||||
filtering.BlockingModeNXDOMAIN,
|
||||
filtering.BlockingModeREFUSED,
|
||||
filtering.BlockingModeNullIP:
|
||||
return nil
|
||||
case BlockingModeCustomIP:
|
||||
if blockingIPv4 == nil {
|
||||
return fmt.Errorf("blocking_ipv4 must be set when blocking_mode is custom_ip")
|
||||
} else if blockingIPv6 == nil {
|
||||
return fmt.Errorf("blocking_ipv6 must be set when blocking_mode is custom_ip")
|
||||
case filtering.BlockingModeCustomIP:
|
||||
if !blockingIPv4.Is4() {
|
||||
return fmt.Errorf("blocking_ipv4 must be valid ipv4 on custom_ip blocking_mode")
|
||||
} else if !blockingIPv6.Is6() {
|
||||
return fmt.Errorf("blocking_ipv6 must be valid ipv6 on custom_ip blocking_mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
@@ -94,17 +93,18 @@ func createTestServer(
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
||||
dhcp := &testDHCP{
|
||||
OnEnabled: func() (ok bool) { return false },
|
||||
OnHostByIP: func(ip netip.Addr) (host string) { return "" },
|
||||
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
|
||||
}
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: testDHCP,
|
||||
DHCPServer: dhcp,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if forwardConf.BlockingMode == "" {
|
||||
forwardConf.BlockingMode = BlockingModeDefault
|
||||
}
|
||||
|
||||
err = s.Prepare(&forwardConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -174,10 +174,12 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
|
||||
var keyPem []byte
|
||||
_, certPem, keyPem = createServerTLSConfig(t)
|
||||
|
||||
s = createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
s = createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}, nil)
|
||||
@@ -231,18 +233,29 @@ func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
|
||||
return req
|
||||
}
|
||||
|
||||
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
|
||||
assertResponse(t, reply, net.IP{8, 8, 8, 8})
|
||||
// newResp returns the new DNS response with response code set to rcode, req
|
||||
// used as request, and rrs added.
|
||||
func newResp(rcode int, req *dns.Msg, ans []dns.RR) (resp *dns.Msg) {
|
||||
resp = (&dns.Msg{}).SetRcode(req, rcode)
|
||||
resp.RecursionAvailable = true
|
||||
resp.Compress = true
|
||||
resp.Answer = ans
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
|
||||
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
|
||||
assertResponse(t, reply, netip.AddrFrom4([4]byte{8, 8, 8, 8}))
|
||||
}
|
||||
|
||||
func assertResponse(t *testing.T, reply *dns.Msg, ip netip.Addr) {
|
||||
t.Helper()
|
||||
|
||||
require.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer))
|
||||
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
require.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0])
|
||||
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
|
||||
assert.Equal(t, net.IP(ip.AsSlice()), a.A)
|
||||
}
|
||||
|
||||
// sendTestMessagesAsync sends messages in parallel to check for race issues.
|
||||
@@ -288,10 +301,12 @@ func sendTestMessages(t *testing.T, conn *dns.Conn) {
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}, nil)
|
||||
@@ -329,13 +344,12 @@ func TestServer_timeout(t *testing.T) {
|
||||
t.Run("custom", func(t *testing.T) {
|
||||
srvConf := &ServerConfig{
|
||||
UpstreamTimeout: testTimeout,
|
||||
FilteringConfig: FilteringConfig{
|
||||
BlockingMode: BlockingModeDefault,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}
|
||||
|
||||
s, err := NewServer(DNSCreateParams{})
|
||||
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(srvConf)
|
||||
@@ -345,11 +359,10 @@ func TestServer_timeout(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
s, err := NewServer(DNSCreateParams{})
|
||||
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.FilteringConfig.BlockingMode = BlockingModeDefault
|
||||
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
}
|
||||
err = s.Prepare(&s.conf)
|
||||
@@ -360,10 +373,12 @@ func TestServer_timeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}, nil)
|
||||
@@ -439,9 +454,8 @@ func TestServerRace(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
@@ -462,7 +476,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4, ip6}, nil
|
||||
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -474,6 +488,8 @@ func TestSafeSearch(t *testing.T) {
|
||||
}
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
ProtectionEnabled: true,
|
||||
SafeSearchConf: safeSearchConf,
|
||||
SafeSearchCacheSize: 1000,
|
||||
CacheTime: 30,
|
||||
@@ -490,8 +506,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
@@ -503,12 +518,12 @@ func TestSafeSearch(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := net.IP{213, 180, 193, 56}
|
||||
yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})
|
||||
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
testCases := []struct {
|
||||
host string
|
||||
want net.IP
|
||||
want netip.Addr
|
||||
}{{
|
||||
host: "yandex.com.",
|
||||
want: yandexIP,
|
||||
@@ -548,10 +563,12 @@ func TestSafeSearch(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInvalidRequest(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
@@ -579,15 +596,16 @@ func TestBlockedRequest(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
@@ -608,14 +626,15 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
@@ -658,10 +677,12 @@ var testIPv4 = map[string][]net.IP{
|
||||
}
|
||||
|
||||
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
@@ -671,7 +692,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
}
|
||||
s.conf.ProtectionEnabled = false
|
||||
|
||||
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{testUpstm},
|
||||
}
|
||||
@@ -693,15 +714,16 @@ func TestBlockCNAME(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
@@ -763,9 +785,8 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
FilterHandler: func(_ net.IP, _ string, settings *filtering.Settings) {
|
||||
Config: Config{
|
||||
FilterHandler: func(_ netip.Addr, _ string, settings *filtering.Settings) {
|
||||
settings.FilteringEnabled = false
|
||||
},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -773,7 +794,9 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
@@ -809,15 +832,16 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeNullIP,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeNullIP,
|
||||
}, forwardConf, nil)
|
||||
startDeferStop(t, s)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
@@ -849,11 +873,21 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeCustomIP,
|
||||
BlockingIPv4: netip.Addr{},
|
||||
BlockingIPv6: netip.Addr{},
|
||||
}, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
dhcp := &testDHCP{
|
||||
OnEnabled: func() (ok bool) { return false },
|
||||
OnHostByIP: func(_ netip.Addr) (host string) { panic("not implemented") },
|
||||
OnIPByHost: func(_ string) (ip netip.Addr) { panic("not implemented") },
|
||||
}
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: testDHCP,
|
||||
DHCPServer: dhcp,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
@@ -862,11 +896,8 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
conf := &ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeCustomIP,
|
||||
BlockingIPv4: nil,
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
@@ -877,8 +908,10 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
err = s.Prepare(conf)
|
||||
assert.Error(t, err)
|
||||
|
||||
conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
|
||||
conf.BlockingIPv6 = net.ParseIP("::1")
|
||||
s.dnsFilter.SetBlockingMode(
|
||||
filtering.BlockingModeCustomIP,
|
||||
netip.AddrFrom4([4]byte{0, 0, 0, 1}),
|
||||
netip.MustParseAddr("::1"))
|
||||
|
||||
err = s.Prepare(conf)
|
||||
require.NoError(t, err)
|
||||
@@ -915,16 +948,17 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
startDeferStop(t, s)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
@@ -955,15 +989,16 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
ans4, _ := aghtest.HostToIPs(hostname)
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingChecker: sbChecker,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
ProtectionEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingChecker: sbChecker,
|
||||
SafeBrowsingBlockHost: ans4.String(),
|
||||
}
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
SafeBrowsingBlockHost: ans4.String(),
|
||||
ProtectionEnabled: true,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
@@ -980,13 +1015,12 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
assert.Equal(t, ans4, a.A, "dns server %s returned wrong answer: %v", addr, a.A)
|
||||
assertResponse(t, reply, ans4)
|
||||
}
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
c := &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
Rewrites: []*filtering.LegacyRewrite{{
|
||||
Domain: "test.com",
|
||||
Answer: "1.2.3.4",
|
||||
@@ -1006,8 +1040,13 @@ func TestRewrite(t *testing.T) {
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
||||
dhcp := &testDHCP{
|
||||
OnEnabled: func() (ok bool) { return false },
|
||||
OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") },
|
||||
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
|
||||
}
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: testDHCP,
|
||||
DHCPServer: dhcp,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
@@ -1016,10 +1055,8 @@ func TestRewrite(t *testing.T) {
|
||||
assert.NoError(t, s.Prepare(&ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
UpstreamDNS: []string{"8.8.8.8:53"},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53"},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
@@ -1102,31 +1139,42 @@ func publicKey(priv any) any {
|
||||
}
|
||||
}
|
||||
|
||||
var testDHCP = &dhcpd.MockInterface{
|
||||
OnStart: func() (err error) { panic("not implemented") },
|
||||
OnStop: func() (err error) { panic("not implemented") },
|
||||
OnEnabled: func() (ok bool) { return true },
|
||||
OnLeases: func(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
return []*dhcpd.Lease{{
|
||||
IP: netip.MustParseAddr("192.168.12.34"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
Hostname: "myhost",
|
||||
}}
|
||||
},
|
||||
OnSetOnLeaseChanged: func(olct dhcpd.OnLeaseChangedT) {},
|
||||
OnFindMACbyIP: func(ip netip.Addr) (mac net.HardwareAddr) { panic("not implemented") },
|
||||
OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") },
|
||||
// testDHCP is a mock implementation of the [DHCP] interface.
|
||||
type testDHCP struct {
|
||||
OnHostByIP func(ip netip.Addr) (host string)
|
||||
OnIPByHost func(host string) (ip netip.Addr)
|
||||
OnEnabled func() (ok bool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ DHCP = (*testDHCP)(nil)
|
||||
|
||||
// HostByIP implements the [DHCP] interface for *testDHCP.
|
||||
func (d *testDHCP) HostByIP(ip netip.Addr) (host string) { return d.OnHostByIP(ip) }
|
||||
|
||||
// IPByHost implements the [DHCP] interface for *testDHCP.
|
||||
func (d *testDHCP) IPByHost(host string) (ip netip.Addr) { return d.OnIPByHost(host) }
|
||||
|
||||
// IsClientHost implements the [DHCP] interface for *testDHCP.
|
||||
func (d *testDHCP) Enabled() (ok bool) { return d.OnEnabled() }
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
const localDomain = "lan"
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{}, nil)
|
||||
flt, err := filtering.New(&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: flt,
|
||||
DHCPServer: testDHCP,
|
||||
DNSFilter: flt,
|
||||
DHCPServer: &testDHCP{
|
||||
OnEnabled: func() (ok bool) { return true },
|
||||
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
|
||||
OnHostByIP: func(ip netip.Addr) (host string) {
|
||||
return "myhost"
|
||||
},
|
||||
},
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
LocalDomain: localDomain,
|
||||
})
|
||||
@@ -1135,9 +1183,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||
s.conf.FilteringConfig.BlockingMode = BlockingModeDefault
|
||||
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||
|
||||
err = s.Prepare(&s.conf)
|
||||
require.NoError(t, err)
|
||||
@@ -1175,8 +1221,14 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
`)},
|
||||
}
|
||||
|
||||
dhcp := &testDHCP{
|
||||
OnEnabled: func() (ok bool) { return false },
|
||||
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
|
||||
OnHostByIP: func(ip netip.Addr) (host string) { return "" },
|
||||
}
|
||||
|
||||
var eventsCalledCounter uint32
|
||||
hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
|
||||
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) {
|
||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||
|
||||
@@ -1195,7 +1247,8 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
})
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{
|
||||
EtcHosts: hc,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
EtcHosts: hc,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1203,7 +1256,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: testDHCP,
|
||||
DHCPServer: dhcp,
|
||||
DNSFilter: flt,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
@@ -1212,8 +1265,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.BlockingMode = BlockingModeDefault
|
||||
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||
|
||||
err = s.Prepare(&s.conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2,7 +2,7 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -44,13 +44,13 @@ func (s *Server) ansFromDNSRewriteIP(
|
||||
rr rules.RRType,
|
||||
req *dns.Msg,
|
||||
) (ans dns.RR, err error) {
|
||||
ip, ok := v.(net.IP)
|
||||
ip, ok := v.(netip.Addr)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value for rr type %s has type %T, not net.IP", dns.Type(rr), v)
|
||||
return nil, fmt.Errorf("value for rr type %s has type %T, not netip.Addr", dns.Type(rr), v)
|
||||
}
|
||||
|
||||
if rr == dns.TypeA {
|
||||
return s.genAnswerA(req, ip.To4()), nil
|
||||
return s.genAnswerA(req, ip), nil
|
||||
}
|
||||
|
||||
return s.genAnswerAAAA(req, ip), nil
|
||||
@@ -160,13 +160,13 @@ func (s *Server) filterDNSRewrite(
|
||||
return errors.Error("no dns rewrite rule responses")
|
||||
}
|
||||
|
||||
rr := req.Question[0].Qtype
|
||||
values := dnsrr.Response[rr]
|
||||
qtype := req.Question[0].Qtype
|
||||
values := dnsrr.Response[qtype]
|
||||
for i, v := range values {
|
||||
var ans dns.RR
|
||||
ans, err = s.filterDNSRewriteResponse(req, rr, v)
|
||||
ans, err = s.filterDNSRewriteResponse(req, qtype, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dns rewrite response for %d[%d]: %w", rr, i, err)
|
||||
return fmt.Errorf("dns rewrite response for %s[%d]: %w", dns.Type(qtype), i, err)
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -15,8 +16,7 @@ import (
|
||||
func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
// Helper data.
|
||||
const domain = "example.com"
|
||||
ip4 := net.IP{127, 0, 0, 1}
|
||||
ip6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
ip4, ip6 := netutil.IPv4Localhost(), netutil.IPv6Localhost()
|
||||
mxVal := &rules.DNSMX{
|
||||
Exchange: "mail.example.com",
|
||||
Preference: 32,
|
||||
@@ -34,7 +34,14 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
}
|
||||
|
||||
// Helper functions and entities.
|
||||
srv := &Server{}
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
|
||||
return &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
@@ -89,7 +96,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
|
||||
assert.Equal(t, net.IP(ip4.AsSlice()), d.Res.Answer[0].(*dns.A).A)
|
||||
})
|
||||
|
||||
t.Run("noerror_aaaa", func(t *testing.T) {
|
||||
@@ -103,7 +110,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
|
||||
assert.Equal(t, net.IP(ip6.AsSlice()), d.Res.Answer[0].(*dns.AAAA).AAAA)
|
||||
})
|
||||
|
||||
t.Run("noerror_ptr", func(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package dnsforward
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// beforeRequestHandler is the handler that is called before any other
|
||||
@@ -57,8 +59,8 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
|
||||
setts = s.dnsFilter.Settings()
|
||||
setts.ProtectionEnabled = dctx.protectionEnabled
|
||||
if s.conf.FilterHandler != nil {
|
||||
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
|
||||
s.conf.FilterHandler(ip, dctx.clientID, setts)
|
||||
addrPort := netutil.NetAddrToAddrPort(dctx.proxyCtx.Addr)
|
||||
s.conf.FilterHandler(addrPort.Addr(), dctx.clientID, setts)
|
||||
}
|
||||
|
||||
return setts
|
||||
@@ -123,7 +125,7 @@ func (s *Server) filterRewritten(
|
||||
for _, ip := range res.IPList {
|
||||
switch qt {
|
||||
case dns.TypeA:
|
||||
a := s.genAnswerA(req, ip.To4())
|
||||
a := s.genAnswerA(req, ip)
|
||||
a.Hdr.Name = dns.Fqdn(name)
|
||||
resp.Answer = append(resp.Answer, a)
|
||||
case dns.TypeAAAA:
|
||||
@@ -145,10 +147,6 @@ func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Set
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if s.dnsFilter == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var res filtering.Result
|
||||
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
|
||||
if err != nil {
|
||||
@@ -176,19 +174,26 @@ func (s *Server) filterDNSResponse(
|
||||
case *dns.CNAME:
|
||||
host = strings.TrimSuffix(a.Target, ".")
|
||||
rrtype = dns.TypeCNAME
|
||||
|
||||
res, err = s.checkHostRules(host, rrtype, setts)
|
||||
case *dns.A:
|
||||
host = a.A.String()
|
||||
rrtype = dns.TypeA
|
||||
|
||||
res, err = s.checkHostRules(host, rrtype, setts)
|
||||
case *dns.AAAA:
|
||||
host = a.AAAA.String()
|
||||
rrtype = dns.TypeAAAA
|
||||
|
||||
res, err = s.checkHostRules(host, rrtype, setts)
|
||||
case *dns.HTTPS:
|
||||
res, err = s.filterHTTPSRecords(a, setts)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
|
||||
log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
|
||||
|
||||
res, err = s.checkHostRules(host, rrtype, setts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil {
|
||||
@@ -203,3 +208,67 @@ func (s *Server) filterDNSResponse(
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// removeIPv6Hints deletes IPv6 hints from RR values.
|
||||
func removeIPv6Hints(rr *dns.HTTPS) {
|
||||
rr.Value = slices.DeleteFunc(rr.Value, func(kv dns.SVCBKeyValue) (del bool) {
|
||||
_, ok := kv.(*dns.SVCBIPv6Hint)
|
||||
|
||||
return ok
|
||||
})
|
||||
}
|
||||
|
||||
// filterHTTPSRecords filters HTTPS answers information through all rule list
|
||||
// filters of the server filters. Removes IPv6 hints if IPv6 resolving is
|
||||
// disabled.
|
||||
func (s *Server) filterHTTPSRecords(rr *dns.HTTPS, setts *filtering.Settings) (r *filtering.Result, err error) {
|
||||
if s.conf.AAAADisabled {
|
||||
removeIPv6Hints(rr)
|
||||
}
|
||||
|
||||
for _, kv := range rr.Value {
|
||||
var ips []net.IP
|
||||
switch hint := kv.(type) {
|
||||
case *dns.SVCBIPv4Hint:
|
||||
ips = hint.Hint
|
||||
case *dns.SVCBIPv6Hint:
|
||||
ips = hint.Hint
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r, err = s.filterSVCBHint(ips, setts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filtering svcb hints: %w", err)
|
||||
}
|
||||
|
||||
if r != nil {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// filterSVCBHint filters SVCB hint information.
|
||||
func (s *Server) filterSVCBHint(
|
||||
hint []net.IP,
|
||||
setts *filtering.Settings,
|
||||
) (res *filtering.Result, err error) {
|
||||
for _, h := range hint {
|
||||
res, err = s.checkHostRules(h.String(), dns.TypeHTTPS, setts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("checking rules for %s: %w", h, err)
|
||||
}
|
||||
|
||||
if res != nil && res.IsFiltered {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
||||
rules := `
|
||||
||blocked.domain^
|
||||
@@||allowed.domain^
|
||||
@@ -23,14 +24,13 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
||::1^$dnstype=~AAAA
|
||||
0.0.0.0 duplicate.domain
|
||||
0.0.0.0 duplicate.domain
|
||||
0.0.0.0 blocked.by.hostrule
|
||||
`
|
||||
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
@@ -40,12 +40,19 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, filters)
|
||||
require.NoError(t, err)
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: testDHCP,
|
||||
DHCPServer: &testDHCP{
|
||||
OnEnabled: func() (ok bool) { return false },
|
||||
OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") },
|
||||
OnIPByHost: func(host string) (ip netip.Addr) { panic("not implemented") },
|
||||
},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
@@ -73,12 +80,19 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
req *dns.Msg
|
||||
name string
|
||||
wantAns []dns.RR
|
||||
req *dns.Msg
|
||||
name string
|
||||
wantRCode int
|
||||
wantAns []dns.RR
|
||||
}{{
|
||||
req: createTestMessage("cname.exception."),
|
||||
name: "cname_exception",
|
||||
req: createTestMessage(aghtest.ReqFQDN),
|
||||
name: "pass",
|
||||
wantRCode: dns.RcodeNameError,
|
||||
wantAns: nil,
|
||||
}, {
|
||||
req: createTestMessage("cname.exception."),
|
||||
name: "cname_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "cname.exception.",
|
||||
@@ -87,8 +101,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
Target: "cname.specific.",
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("should.block."),
|
||||
name: "blocked_by_cname",
|
||||
req: createTestMessage("should.block."),
|
||||
name: "blocked_by_cname",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "should.block.",
|
||||
@@ -98,8 +113,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("a.exception."),
|
||||
name: "a_exception",
|
||||
req: createTestMessage("a.exception."),
|
||||
name: "a_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "a.exception.",
|
||||
@@ -108,8 +124,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
A: net.IP{0, 0, 0, 1},
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
|
||||
name: "aaaa_exception",
|
||||
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
|
||||
name: "aaaa_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "aaaa.exception.",
|
||||
@@ -118,8 +135,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
AAAA: net.ParseIP("::1"),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("allowed.first."),
|
||||
name: "allowed_first",
|
||||
req: createTestMessage("allowed.first."),
|
||||
name: "allowed_first",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "allowed.first.",
|
||||
@@ -129,8 +147,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("blocked.first."),
|
||||
name: "blocked_first",
|
||||
req: createTestMessage("blocked.first."),
|
||||
name: "blocked_first",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "blocked.first.",
|
||||
@@ -140,8 +159,9 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("duplicate.domain."),
|
||||
name: "duplicate_domain",
|
||||
req: createTestMessage("duplicate.domain."),
|
||||
name: "duplicate_domain",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "duplicate.domain.",
|
||||
@@ -150,6 +170,16 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessageWithType("blocked.domain.", dns.TypeHTTPS),
|
||||
name: "blocked_https_req",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: nil,
|
||||
}, {
|
||||
req: createTestMessageWithType("blocked.by.hostrule.", dns.TypeHTTPS),
|
||||
name: "blocked_host_rule_https_req",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -164,7 +194,175 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dctx.Res)
|
||||
|
||||
assert.Equal(t, tc.wantRCode, dctx.Res.Rcode)
|
||||
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
const (
|
||||
passedIPv4Str = "1.1.1.1"
|
||||
blockedIPv4Str = "1.2.3.4"
|
||||
blockedIPv6Str = "1234::cdef"
|
||||
blockRules = blockedIPv4Str + "\n" + blockedIPv6Str + "\n"
|
||||
)
|
||||
|
||||
var (
|
||||
passedIPv4 net.IP = netip.MustParseAddr(passedIPv4Str).AsSlice()
|
||||
blockedIPv4 net.IP = netip.MustParseAddr(blockedIPv4Str).AsSlice()
|
||||
blockedIPv6 net.IP = netip.MustParseAddr(blockedIPv6Str).AsSlice()
|
||||
)
|
||||
|
||||
filters := []filtering.Filter{{
|
||||
ID: 0, Data: []byte(blockRules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
req *dns.Msg
|
||||
name string
|
||||
wantRule string
|
||||
respAns []dns.RR
|
||||
}{{
|
||||
name: "pass",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeA),
|
||||
wantRule: "",
|
||||
respAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: aghtest.ReqFQDN,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: passedIPv4,
|
||||
}},
|
||||
}, {
|
||||
name: "ipv4",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeA),
|
||||
wantRule: blockedIPv4Str,
|
||||
respAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: aghtest.ReqFQDN,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: blockedIPv4,
|
||||
}},
|
||||
}, {
|
||||
name: "ipv6",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeAAAA),
|
||||
wantRule: blockedIPv6Str,
|
||||
respAns: []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: aghtest.ReqFQDN,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
AAAA: blockedIPv6,
|
||||
}},
|
||||
}, {
|
||||
name: "ipv4hint",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
|
||||
wantRule: blockedIPv4Str,
|
||||
respAns: newSVCBHintsAnswer(
|
||||
aghtest.ReqFQDN,
|
||||
[]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{blockedIPv4}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{}},
|
||||
},
|
||||
),
|
||||
}, {
|
||||
name: "ipv6hint",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
|
||||
wantRule: blockedIPv6Str,
|
||||
respAns: newSVCBHintsAnswer(
|
||||
aghtest.ReqFQDN,
|
||||
[]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{blockedIPv6}},
|
||||
},
|
||||
),
|
||||
}, {
|
||||
name: "ipv4_ipv6_hints",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
|
||||
wantRule: blockedIPv4Str,
|
||||
respAns: newSVCBHintsAnswer(
|
||||
aghtest.ReqFQDN,
|
||||
[]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{blockedIPv4}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{blockedIPv6}},
|
||||
},
|
||||
),
|
||||
}, {
|
||||
name: "pass_hints",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
|
||||
wantRule: "",
|
||||
respAns: newSVCBHintsAnswer(
|
||||
aghtest.ReqFQDN,
|
||||
[]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{passedIPv4}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{}},
|
||||
},
|
||||
),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
|
||||
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Res: resp,
|
||||
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
|
||||
}
|
||||
|
||||
res, rErr := s.filterDNSResponse(pctx, &filtering.Settings{
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
})
|
||||
require.NoError(t, rErr)
|
||||
|
||||
if tc.wantRule == "" {
|
||||
assert.Nil(t, res)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
want := &filtering.Result{
|
||||
IsFiltered: true,
|
||||
Reason: filtering.FilteredBlockList,
|
||||
Rules: []*filtering.ResultRule{{
|
||||
Text: tc.wantRule,
|
||||
}},
|
||||
}
|
||||
assert.Equal(t, want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newSVCBHintsAnswer returns a test HTTPS answer RRs with SVCB hints.
|
||||
func newSVCBHintsAnswer(target string, hints []dns.SVCBKeyValue) (rrs []dns.RR) {
|
||||
return []dns.RR{&dns.HTTPS{
|
||||
SVCB: dns.SVCB{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: target,
|
||||
Rrtype: dns.TypeHTTPS,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Target: target,
|
||||
Value: hints,
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -37,6 +37,10 @@ type jsonDNSConfig struct {
|
||||
// upstream DoH/DoT resolvers.
|
||||
Bootstraps *[]string `json:"bootstrap_dns"`
|
||||
|
||||
// Fallbacks is the list of fallback DNS servers used when upstream DNS
|
||||
// servers are not responding.
|
||||
Fallbacks *[]string `json:"fallback_dns"`
|
||||
|
||||
// ProtectionEnabled defines if protection is enabled.
|
||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||
|
||||
@@ -44,7 +48,7 @@ type jsonDNSConfig struct {
|
||||
RateLimit *uint32 `json:"ratelimit"`
|
||||
|
||||
// BlockingMode defines the way blocked responses are constructed.
|
||||
BlockingMode *BlockingMode `json:"blocking_mode"`
|
||||
BlockingMode *filtering.BlockingMode `json:"blocking_mode"`
|
||||
|
||||
// EDNSCSEnabled defines if EDNS Client Subnet is enabled.
|
||||
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
|
||||
@@ -83,10 +87,10 @@ type jsonDNSConfig struct {
|
||||
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
|
||||
|
||||
// BlockingIPv4 is custom IPv4 address for blocked A requests.
|
||||
BlockingIPv4 net.IP `json:"blocking_ipv4"`
|
||||
BlockingIPv4 netip.Addr `json:"blocking_ipv4"`
|
||||
|
||||
// BlockingIPv6 is custom IPv6 address for blocked AAAA requests.
|
||||
BlockingIPv6 net.IP `json:"blocking_ipv6"`
|
||||
BlockingIPv6 netip.Addr `json:"blocking_ipv6"`
|
||||
|
||||
// DisabledUntil is a timestamp until when the protection is disabled.
|
||||
DisabledUntil *time.Time `json:"protection_disabled_until"`
|
||||
@@ -109,9 +113,8 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
upstreams := stringutil.CloneSliceOrEmpty(s.conf.UpstreamDNS)
|
||||
upstreamFile := s.conf.UpstreamDNSFileName
|
||||
bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS)
|
||||
blockingMode := s.conf.BlockingMode
|
||||
blockingIPv4 := s.conf.BlockingIPv4
|
||||
blockingIPv6 := s.conf.BlockingIPv6
|
||||
fallbacks := stringutil.CloneSliceOrEmpty(s.conf.FallbackDNS)
|
||||
blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode()
|
||||
ratelimit := s.conf.Ratelimit
|
||||
|
||||
customIP := s.conf.EDNSClientSubnet.CustomIP
|
||||
@@ -144,6 +147,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
Upstreams: &upstreams,
|
||||
UpstreamsFile: &upstreamFile,
|
||||
Bootstraps: &bootstraps,
|
||||
Fallbacks: &fallbacks,
|
||||
ProtectionEnabled: &protectionEnabled,
|
||||
BlockingMode: &blockingMode,
|
||||
BlockingIPv4: blockingIPv4,
|
||||
@@ -170,7 +174,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
|
||||
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
resp := s.getDNSConfig()
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
func (req *jsonDNSConfig) checkBlockingMode() (err error) {
|
||||
@@ -208,6 +212,20 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkFallbacks returns an error if any fallback address is invalid.
|
||||
func (req *jsonDNSConfig) checkFallbacks() (err error) {
|
||||
if req.Fallbacks == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ValidateUpstreams(*req.Fallbacks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating fallback servers: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
@@ -229,6 +247,11 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkFallbacks()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkBlockingMode()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -295,11 +318,11 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
if dc.BlockingMode != nil {
|
||||
s.conf.BlockingMode = *dc.BlockingMode
|
||||
if *dc.BlockingMode == BlockingModeCustomIP {
|
||||
s.conf.BlockingIPv4 = dc.BlockingIPv4.To4()
|
||||
s.conf.BlockingIPv6 = dc.BlockingIPv6.To16()
|
||||
}
|
||||
s.dnsFilter.SetBlockingMode(*dc.BlockingMode, dc.BlockingIPv4, dc.BlockingIPv6)
|
||||
}
|
||||
|
||||
if dc.ProtectionEnabled != nil {
|
||||
s.dnsFilter.SetProtectionEnabled(*dc.ProtectionEnabled)
|
||||
}
|
||||
|
||||
if dc.UpstreamMode != nil {
|
||||
@@ -311,7 +334,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
s.conf.EDNSClientSubnet.CustomIP = dc.EDNSCSCustomIP
|
||||
}
|
||||
|
||||
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
|
||||
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
|
||||
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
|
||||
|
||||
@@ -342,6 +364,7 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.LocalPTRResolvers, dc.LocalPTRUpstreams),
|
||||
setIfNotNil(&s.conf.UpstreamDNSFileName, dc.UpstreamsFile),
|
||||
setIfNotNil(&s.conf.BootstrapDNS, dc.Bootstraps),
|
||||
setIfNotNil(&s.conf.FallbackDNS, dc.Fallbacks),
|
||||
setIfNotNil(&s.conf.EDNSClientSubnet.Enabled, dc.EDNSCSEnabled),
|
||||
setIfNotNil(&s.conf.EDNSClientSubnet.UseCustom, dc.EDNSCSUseCustom),
|
||||
setIfNotNil(&s.conf.CacheSize, dc.CacheSize),
|
||||
@@ -369,6 +392,7 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
type upstreamJSON struct {
|
||||
Upstreams []string `json:"upstream_dns"`
|
||||
BootstrapDNS []string `json:"bootstrap_dns"`
|
||||
FallbackDNS []string `json:"fallback_dns"`
|
||||
PrivateUpstreams []string `json:"private_upstream"`
|
||||
}
|
||||
|
||||
@@ -441,7 +465,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("creating config: %w", err)
|
||||
}
|
||||
|
||||
if conf == nil {
|
||||
@@ -469,11 +493,7 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errors.List("checking domain-specific upstreams", errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||
}
|
||||
|
||||
var protocols = []string{
|
||||
@@ -667,10 +687,13 @@ func (s *Server) parseUpstreamLine(
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
|
||||
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
|
||||
resolved := s.resolveUpstreamHost(extractUpstreamHost(upstreamAddr))
|
||||
sortNetIPAddrs(resolved, opts.PreferIPv6)
|
||||
opts.ServerIPAddrs = resolved
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(upstreamAddr))
|
||||
for _, rec := range recs {
|
||||
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
|
||||
}
|
||||
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
||||
}
|
||||
u, err = upstream.AddressToUpstream(upstreamAddr, opts)
|
||||
if err != nil {
|
||||
@@ -728,7 +751,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
upsNum := len(req.Upstreams) + len(req.PrivateUpstreams)
|
||||
upsNum := len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)
|
||||
result := make(map[string]string, upsNum)
|
||||
resCh := make(chan upsCheckResult, upsNum)
|
||||
|
||||
@@ -740,6 +763,14 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
@@ -760,7 +791,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, result)
|
||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
||||
}
|
||||
|
||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||
@@ -806,8 +837,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.conf.ProtectionEnabled = protectionReq.Enabled
|
||||
s.conf.ProtectionDisabledUntil = disabledUntil
|
||||
s.dnsFilter.SetProtectionStatus(protectionReq.Enabled, disabledUntil)
|
||||
}()
|
||||
|
||||
s.conf.ConfigModified()
|
||||
|
||||
@@ -58,6 +58,8 @@ const jsonExt = ".json"
|
||||
|
||||
func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
@@ -68,11 +70,10 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{},
|
||||
TCPListenAddrs: []*net.TCPAddr{},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
FallbackDNS: []string{"9.9.9.10"},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
@@ -134,6 +135,8 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
|
||||
func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
@@ -144,11 +147,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{},
|
||||
TCPListenAddrs: []*net.TCPAddr{},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
@@ -177,7 +178,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "blocking_mode_bad",
|
||||
wantSet: "blocking_ipv4 must be set when blocking_mode is custom_ip",
|
||||
wantSet: "blocking_ipv4 must be valid ipv4 on custom_ip blocking_mode",
|
||||
}, {
|
||||
name: "ratelimit",
|
||||
wantSet: "",
|
||||
@@ -225,6 +226,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "local_ptr_upstreams_null",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "fallbacks",
|
||||
wantSet: "",
|
||||
}}
|
||||
|
||||
var data map[string]struct {
|
||||
@@ -243,8 +247,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
s.dnsFilter.SetBlockingMode(filtering.BlockingModeDefault, netip.Addr{}, netip.Addr{})
|
||||
s.conf = defaultConf
|
||||
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{}
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{}
|
||||
})
|
||||
|
||||
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
|
||||
@@ -405,9 +410,9 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
u: "[/128.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "several_bad",
|
||||
wantErr: `checking domain-specific upstreams: 2 errors: ` +
|
||||
`"arpa domain \"1.2.3.4.in-addr.arpa.\" should point to a locally-served network", ` +
|
||||
`"bad arpa domain name \"non.arpa.\": not a reversed ip network"`,
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
|
||||
`bad arpa domain name "non.arpa.": not a reversed ip network`,
|
||||
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "partial_good",
|
||||
@@ -479,7 +484,6 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
}).String()
|
||||
|
||||
hc, err := aghnet.NewHostsContainer(
|
||||
filtering.SysHostsListID,
|
||||
fstest.MapFS{
|
||||
hostsFileName: &fstest.MapFile{
|
||||
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),
|
||||
@@ -495,12 +499,13 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
EtcHosts: hc,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
EtcHosts: hc,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UpstreamTimeout: upsTimeout,
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}, nil)
|
||||
@@ -555,6 +560,23 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
hostsUps: "OK",
|
||||
},
|
||||
name: "etc_hosts",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"fallback_dns": []string{goodUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
},
|
||||
name: "fallback_success",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"fallback_dns": []string{badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "fallback_broken",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// makeResponse creates a DNS response by req and sets necessary flags. It also
|
||||
@@ -26,24 +27,13 @@ func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
return resp
|
||||
}
|
||||
|
||||
// containsIP returns true if the IP is already in the list.
|
||||
func containsIP(ips []net.IP, ip net.IP) bool {
|
||||
for _, a := range ips {
|
||||
if a.Equal(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ipsFromRules extracts unique non-IP addresses from the filtering result
|
||||
// rules.
|
||||
func ipsFromRules(resRules []*filtering.ResultRule) (ips []net.IP) {
|
||||
func ipsFromRules(resRules []*filtering.ResultRule) (ips []netip.Addr) {
|
||||
for _, r := range resRules {
|
||||
// len(resRules) and len(ips) are actually small enough for O(n^2) to do
|
||||
// not raise performance questions.
|
||||
if ip := r.IP; ip != nil && !containsIP(ips, ip) {
|
||||
if ip := r.IP; ip != (netip.Addr{}) && !slices.Contains(ips, ip) {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
@@ -58,19 +48,21 @@ func (s *Server) genDNSFilterMessage(
|
||||
res *filtering.Result,
|
||||
) (resp *dns.Msg) {
|
||||
req := dctx.Req
|
||||
if qt := req.Question[0].Qtype; qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
if s.conf.BlockingMode == BlockingModeNullIP {
|
||||
qt := req.Question[0].Qtype
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
m, _, _ := s.dnsFilter.BlockingMode()
|
||||
if m == filtering.BlockingModeNullIP {
|
||||
return s.makeResponse(req)
|
||||
}
|
||||
|
||||
return s.genNXDomain(req)
|
||||
return s.newMsgNODATA(req)
|
||||
}
|
||||
|
||||
switch res.Reason {
|
||||
case filtering.FilteredSafeBrowsing:
|
||||
return s.genBlockedHost(req, s.conf.SafeBrowsingBlockHost, dctx)
|
||||
return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
|
||||
case filtering.FilteredParental:
|
||||
return s.genBlockedHost(req, s.conf.ParentalBlockHost, dctx)
|
||||
return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost(), dctx)
|
||||
case filtering.FilteredSafeSearch:
|
||||
// If Safe Search generated the necessary IP addresses, use them.
|
||||
// Otherwise, if there were no errors, there are no addresses for the
|
||||
@@ -83,36 +75,45 @@ func (s *Server) genDNSFilterMessage(
|
||||
|
||||
// genForBlockingMode generates a filtered response to req based on the server's
|
||||
// blocking mode.
|
||||
func (s *Server) genForBlockingMode(req *dns.Msg, ips []net.IP) (resp *dns.Msg) {
|
||||
qt := req.Question[0].Qtype
|
||||
switch m := s.conf.BlockingMode; m {
|
||||
case BlockingModeCustomIP:
|
||||
switch qt {
|
||||
case dns.TypeA:
|
||||
return s.genARecord(req, s.conf.BlockingIPv4)
|
||||
case dns.TypeAAAA:
|
||||
return s.genAAAARecord(req, s.conf.BlockingIPv6)
|
||||
default:
|
||||
// Generally shouldn't happen, since the types are checked in
|
||||
// genDNSFilterMessage.
|
||||
log.Error("dns: invalid msg type %s for blocking mode %s", dns.Type(qt), m)
|
||||
|
||||
return s.makeResponse(req)
|
||||
}
|
||||
case BlockingModeDefault:
|
||||
func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
|
||||
switch mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode(); mode {
|
||||
case filtering.BlockingModeCustomIP:
|
||||
return s.makeResponseCustomIP(req, bIPv4, bIPv6)
|
||||
case filtering.BlockingModeDefault:
|
||||
if len(ips) > 0 {
|
||||
return s.genResponseWithIPs(req, ips)
|
||||
}
|
||||
|
||||
return s.makeResponseNullIP(req)
|
||||
case BlockingModeNullIP:
|
||||
case filtering.BlockingModeNullIP:
|
||||
return s.makeResponseNullIP(req)
|
||||
case BlockingModeNXDOMAIN:
|
||||
case filtering.BlockingModeNXDOMAIN:
|
||||
return s.genNXDomain(req)
|
||||
case BlockingModeREFUSED:
|
||||
case filtering.BlockingModeREFUSED:
|
||||
return s.makeResponseREFUSED(req)
|
||||
default:
|
||||
log.Error("dns: invalid blocking mode %q", s.conf.BlockingMode)
|
||||
log.Error("dns: invalid blocking mode %q", mode)
|
||||
|
||||
return s.makeResponse(req)
|
||||
}
|
||||
}
|
||||
|
||||
// makeResponseCustomIP generates a DNS response message for Custom IP blocking
|
||||
// mode with the provided IP addresses and an appropriate resource record type.
|
||||
func (s *Server) makeResponseCustomIP(
|
||||
req *dns.Msg,
|
||||
bIPv4 netip.Addr,
|
||||
bIPv6 netip.Addr,
|
||||
) (resp *dns.Msg) {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
return s.genARecord(req, bIPv4)
|
||||
case dns.TypeAAAA:
|
||||
return s.genAAAARecord(req, bIPv6)
|
||||
default:
|
||||
// Generally shouldn't happen, since the types are checked in
|
||||
// genDNSFilterMessage.
|
||||
log.Error("dns: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
|
||||
|
||||
return s.makeResponse(req)
|
||||
}
|
||||
@@ -125,13 +126,13 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
||||
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
|
||||
resp := s.makeResponse(request)
|
||||
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
||||
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
|
||||
resp := s.makeResponse(request)
|
||||
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
|
||||
return resp
|
||||
@@ -141,22 +142,22 @@ func (s *Server) hdr(req *dns.Msg, rrType rules.RRType) (h dns.RR_Header) {
|
||||
return dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: rrType,
|
||||
Ttl: s.conf.BlockedResponseTTL,
|
||||
Ttl: s.dnsFilter.BlockedResponseTTL(),
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) genAnswerA(req *dns.Msg, ip net.IP) (ans *dns.A) {
|
||||
func (s *Server) genAnswerA(req *dns.Msg, ip netip.Addr) (ans *dns.A) {
|
||||
return &dns.A{
|
||||
Hdr: s.hdr(req, dns.TypeA),
|
||||
A: ip,
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) genAnswerAAAA(req *dns.Msg, ip net.IP) (ans *dns.AAAA) {
|
||||
func (s *Server) genAnswerAAAA(req *dns.Msg, ip netip.Addr) (ans *dns.AAAA) {
|
||||
return &dns.AAAA{
|
||||
Hdr: s.hdr(req, dns.TypeAAAA),
|
||||
AAAA: ip,
|
||||
AAAA: ip.AsSlice(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,22 +204,24 @@ func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) {
|
||||
// addresses and an appropriate resource record type. If any of the IPs cannot
|
||||
// be converted to the correct protocol, genResponseWithIPs returns an empty
|
||||
// response.
|
||||
func (s *Server) genResponseWithIPs(req *dns.Msg, ips []net.IP) (resp *dns.Msg) {
|
||||
func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
|
||||
var ans []dns.RR
|
||||
switch req.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
for _, ip := range ips {
|
||||
if ip4 := ip.To4(); ip4 == nil {
|
||||
if ip.Is4() {
|
||||
ans = append(ans, s.genAnswerA(req, ip))
|
||||
} else {
|
||||
ans = nil
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
ans = append(ans, s.genAnswerA(req, ip))
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
for _, ip := range ips {
|
||||
ans = append(ans, s.genAnswerAAAA(req, ip.To16()))
|
||||
if ip.Is6() {
|
||||
ans = append(ans, s.genAnswerAAAA(req, ip))
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Go on and return an empty response.
|
||||
@@ -239,9 +242,9 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||
// converted into an empty slice instead of the zero IPv4.
|
||||
switch req.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
resp = s.genResponseWithIPs(req, []net.IP{{0, 0, 0, 0}})
|
||||
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv4Unspecified()})
|
||||
case dns.TypeAAAA:
|
||||
resp = s.genResponseWithIPs(req, []net.IP{net.IPv6zero})
|
||||
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
|
||||
default:
|
||||
resp = s.makeResponse(req)
|
||||
}
|
||||
@@ -250,9 +253,15 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||
}
|
||||
|
||||
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
|
||||
ip := net.ParseIP(newAddr)
|
||||
if ip != nil {
|
||||
return s.genResponseWithIPs(request, []net.IP{ip})
|
||||
if newAddr == "" {
|
||||
log.Printf("block host is not specified.")
|
||||
|
||||
return s.genServerFailure(request)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(newAddr)
|
||||
if err == nil {
|
||||
return s.genResponseWithIPs(request, []netip.Addr{ip})
|
||||
}
|
||||
|
||||
// look up the hostname, TODO: cache
|
||||
@@ -274,7 +283,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
return s.genServerFailure(request)
|
||||
}
|
||||
|
||||
err := prx.Resolve(newContext)
|
||||
err = prx.Resolve(newContext)
|
||||
if err != nil {
|
||||
log.Printf("couldn't look up replacement host %q: %s", newAddr, err)
|
||||
|
||||
@@ -314,6 +323,17 @@ func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg {
|
||||
return &resp
|
||||
}
|
||||
|
||||
// newMsgNODATA returns a properly initialized NODATA response.
|
||||
//
|
||||
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
|
||||
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess)
|
||||
resp.RecursionAvailable = true
|
||||
resp.Ns = s.genSOA(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeNameError)
|
||||
@@ -342,13 +362,13 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
|
||||
Hdr: dns.RR_Header{
|
||||
Name: zone,
|
||||
Rrtype: dns.TypeSOA,
|
||||
Ttl: s.conf.BlockedResponseTTL,
|
||||
Ttl: s.dnsFilter.BlockedResponseTTL(),
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
|
||||
}
|
||||
if soa.Hdr.Ttl == 0 {
|
||||
soa.Hdr.Ttl = defaultValues.BlockedResponseTTL
|
||||
soa.Hdr.Ttl = defaultBlockedResponseTTL
|
||||
}
|
||||
if len(zone) > 0 && zone[0] != '.' {
|
||||
soa.Mbox += zone
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
@@ -70,20 +69,26 @@ type dnsContext struct {
|
||||
// isLocalClient shows if client's IP address is from locally served
|
||||
// network.
|
||||
isLocalClient bool
|
||||
|
||||
// isDHCPHost is true if the request for a local domain name and the DHCP is
|
||||
// available for this request.
|
||||
isDHCPHost bool
|
||||
}
|
||||
|
||||
// resultCode is the result of a request processing function.
|
||||
type resultCode int
|
||||
|
||||
const (
|
||||
// resultCodeSuccess is returned when a handler performed successfully,
|
||||
// and the next handler must be called.
|
||||
// resultCodeSuccess is returned when a handler performed successfully, and
|
||||
// the next handler must be called.
|
||||
resultCodeSuccess resultCode = iota
|
||||
// resultCodeFinish is returned when a handler performed successfully,
|
||||
// and the processing of the request must be stopped.
|
||||
|
||||
// resultCodeFinish is returned when a handler performed successfully, and
|
||||
// the processing of the request must be stopped.
|
||||
resultCodeFinish
|
||||
// resultCodeError is returned when a handler failed, and the processing
|
||||
// of the request must be stopped.
|
||||
|
||||
// resultCodeError is returned when a handler failed, and the processing of
|
||||
// the request must be stopped.
|
||||
resultCodeError
|
||||
)
|
||||
|
||||
@@ -239,70 +244,6 @@ func (s *Server) processClientIP(addr net.Addr) {
|
||||
s.addrProc.Process(clientIP)
|
||||
}
|
||||
|
||||
func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||
s.tableHostToIPLock.Lock()
|
||||
defer s.tableHostToIPLock.Unlock()
|
||||
|
||||
s.tableHostToIP = t
|
||||
}
|
||||
|
||||
func (s *Server) setTableIPToHost(t ipToHostTable) {
|
||||
s.tableIPToHostLock.Lock()
|
||||
defer s.tableIPToHostLock.Unlock()
|
||||
|
||||
s.tableIPToHost = t
|
||||
}
|
||||
|
||||
func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
switch flags {
|
||||
case dhcpd.LeaseChangedAdded,
|
||||
dhcpd.LeaseChangedAddedStatic,
|
||||
dhcpd.LeaseChangedRemovedStatic:
|
||||
// Go on.
|
||||
case dhcpd.LeaseChangedRemovedAll:
|
||||
s.setTableHostToIP(nil)
|
||||
s.setTableIPToHost(nil)
|
||||
|
||||
return
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
hostToIP := make(hostToIPTable, len(ll))
|
||||
ipToHost := make(ipToHostTable, len(ll))
|
||||
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished with the client
|
||||
// hostname validations in the DHCP server code.
|
||||
err := netutil.ValidateHostname(l.Hostname)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: skipping invalid hostname %q from dhcp: %s", l.Hostname, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||
|
||||
// Assume that we only process IPv4 now.
|
||||
if !l.IP.Is4() {
|
||||
log.Debug("dnsforward: skipping invalid ip from dhcp: bad ipv4 net.IP %v", l.IP)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
leaseIP := l.IP
|
||||
|
||||
ipToHost[leaseIP] = lowhost
|
||||
hostToIP[lowhost] = leaseIP
|
||||
}
|
||||
|
||||
s.setTableHostToIP(hostToIP)
|
||||
s.setTableIPToHost(ipToHost)
|
||||
|
||||
log.Debug("dnsforward: added %d a and ptr entries from dhcp", len(ipToHost))
|
||||
}
|
||||
|
||||
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
|
||||
// queries. The response contains different types of encryption supported by
|
||||
// current user configuration.
|
||||
@@ -420,18 +361,6 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
return rc
|
||||
}
|
||||
|
||||
// dhcpHostToIP tries to get an IP leased by DHCP and returns the copy of
|
||||
// address since the data inside the internal table may be changed while request
|
||||
// processing. It's safe for concurrent use.
|
||||
func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) {
|
||||
s.tableHostToIPLock.Lock()
|
||||
defer s.tableHostToIPLock.Unlock()
|
||||
|
||||
ip, ok = s.tableHostToIP[host]
|
||||
|
||||
return ip, ok
|
||||
}
|
||||
|
||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||
// the server. It responds with a mapped IP address if the DNS64 is enabled and
|
||||
// the request is for AAAA.
|
||||
@@ -443,30 +372,31 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
reqHost, ok := s.isDHCPClientHostQ(q)
|
||||
if !ok {
|
||||
|
||||
q := &req.Question[0]
|
||||
dhcpHost := s.dhcpHostFromRequest(q)
|
||||
if dctx.isDHCPHost = dhcpHost != ""; !dctx.isDHCPHost {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if !dctx.isLocalClient {
|
||||
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, reqHost)
|
||||
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
ip, ok := s.dhcpHostToIP(reqHost)
|
||||
if !ok {
|
||||
ip := s.dhcpServer.IPByHost(dhcpHost)
|
||||
if ip == (netip.Addr{}) {
|
||||
// Go on and process them with filters, including dnsrewrite ones, and
|
||||
// possibly route them to a domain-specific upstream.
|
||||
log.Debug("dnsforward: no dhcp record for %q", reqHost)
|
||||
log.Debug("dnsforward: no dhcp record for %q", dhcpHost)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: dhcp record for %q is %s", reqHost, ip)
|
||||
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
|
||||
|
||||
resp := s.makeResponse(req)
|
||||
switch q.Qtype {
|
||||
@@ -638,17 +568,6 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// ipToDHCPHost tries to get a hostname leased by DHCP. It's safe for
|
||||
// concurrent use.
|
||||
func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) {
|
||||
s.tableIPToHostLock.Lock()
|
||||
defer s.tableIPToHostLock.Unlock()
|
||||
|
||||
host, ok = s.tableIPToHost[ip]
|
||||
|
||||
return host, ok
|
||||
}
|
||||
|
||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||
// DHCP server.
|
||||
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
@@ -673,12 +592,12 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
host, ok := s.ipToDHCPHost(ipAddr)
|
||||
if !ok {
|
||||
host := s.dhcpServer.HostByIP(ipAddr)
|
||||
if host == "" {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: dhcp reverse record for %s is %q", ip, host)
|
||||
log.Debug("dnsforward: dhcp client %s is %q", ip, host)
|
||||
|
||||
req := pctx.Req
|
||||
resp := s.makeResponse(req)
|
||||
@@ -686,10 +605,12 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
Hdr: dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Ttl: s.conf.BlockedResponseTTL,
|
||||
Class: dns.ClassINET,
|
||||
// TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See
|
||||
// https://github.com/AdguardTeam/AdGuardHome/issues/3932.
|
||||
Ttl: s.dnsFilter.BlockedResponseTTL(),
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Ptr: dns.Fqdn(host),
|
||||
Ptr: dns.Fqdn(strings.Join([]string{host, s.localDomainSuffix}, ".")),
|
||||
}
|
||||
resp.Answer = append(resp.Answer, ptr)
|
||||
pctx.Res = resp
|
||||
@@ -762,10 +683,6 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode)
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if s.dnsFilter == nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
var err error
|
||||
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
|
||||
ctx.err = err
|
||||
@@ -792,17 +709,18 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
|
||||
if pctx.Res != nil {
|
||||
// The response has already been set.
|
||||
return resultCodeSuccess
|
||||
} else if reqHost, ok := s.isDHCPClientHostQ(q); ok {
|
||||
} else if dctx.isDHCPHost {
|
||||
// A DHCP client hostname query that hasn't been handled or filtered.
|
||||
// Respond with an NXDOMAIN.
|
||||
//
|
||||
// TODO(a.garipov): Route such queries to a custom upstream for the
|
||||
// local domain name if there is one.
|
||||
log.Debug("dnsforward: dhcp client hostname %q was not filtered", reqHost)
|
||||
name := req.Question[0].Name
|
||||
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
return resultCodeFinish
|
||||
@@ -889,26 +807,26 @@ func (s *Server) setRespAD(pctx *proxy.DNSContext, reqWantsDNSSEC bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// isDHCPClientHostQ returns true if q is from a request for a DHCP client
|
||||
// hostname. If ok is true, reqHost contains the requested hostname.
|
||||
func (s *Server) isDHCPClientHostQ(q dns.Question) (reqHost string, ok bool) {
|
||||
// dhcpHostFromRequest returns a hostname from question, if the request is for a
|
||||
// DHCP client's hostname when DHCP is enabled, and an empty string otherwise.
|
||||
func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
|
||||
if !s.dhcpServer.Enabled() {
|
||||
return "", false
|
||||
return ""
|
||||
}
|
||||
|
||||
// Include AAAA here, because despite the fact that we don't support it yet,
|
||||
// the expected behavior here is to respond with an empty answer and not
|
||||
// NXDOMAIN.
|
||||
if qt := q.Qtype; qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
return "", false
|
||||
return ""
|
||||
}
|
||||
|
||||
reqHost = strings.ToLower(q.Name[:len(q.Name)-1])
|
||||
if strings.HasSuffix(reqHost, s.localDomainSuffix) {
|
||||
return reqHost, true
|
||||
if !netutil.IsImmediateSubdomain(reqHost, s.localDomainSuffix) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "", false
|
||||
return reqHost[:len(reqHost)-len(s.localDomainSuffix)-1]
|
||||
}
|
||||
|
||||
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
|
||||
@@ -972,7 +890,7 @@ func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (
|
||||
// Check the response only if it's from an upstream. Don't check the
|
||||
// response if the protection is disabled since dnsrewrite rules aren't
|
||||
// applied to it anyway.
|
||||
if !dctx.protectionEnabled || !dctx.responseFromUpstream || s.dnsFilter == nil {
|
||||
if !dctx.protectionEnabled || !dctx.responseFromUpstream {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
|
||||
@@ -77,13 +77,15 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}
|
||||
|
||||
s := createTestServer(t, &filtering.Config{}, c, nil)
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, c, nil)
|
||||
|
||||
var gotAddr netip.Addr
|
||||
s.addrProc = &aghtest.AddressProcessor{
|
||||
@@ -113,6 +115,101 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
testIPv4 net.IP = netip.MustParseAddr("1.1.1.1").AsSlice()
|
||||
testIPv6 net.IP = netip.MustParseAddr("1234::cdef").AsSlice()
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
req *dns.Msg
|
||||
aaaaDisabled bool
|
||||
respAns []dns.RR
|
||||
wantRC resultCode
|
||||
wantRespAns []dns.RR
|
||||
}{{
|
||||
name: "pass",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
|
||||
aaaaDisabled: false,
|
||||
respAns: newSVCBHintsAnswer(
|
||||
aghtest.ReqFQDN,
|
||||
[]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{testIPv4}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{testIPv6}},
|
||||
},
|
||||
),
|
||||
wantRespAns: newSVCBHintsAnswer(
|
||||
aghtest.ReqFQDN,
|
||||
[]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{testIPv4}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{testIPv6}},
|
||||
},
|
||||
),
|
||||
wantRC: resultCodeSuccess,
|
||||
}, {
|
||||
name: "filter",
|
||||
req: createTestMessageWithType(aghtest.ReqFQDN, dns.TypeHTTPS),
|
||||
aaaaDisabled: true,
|
||||
respAns: newSVCBHintsAnswer(
|
||||
aghtest.ReqFQDN,
|
||||
[]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{testIPv4}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{testIPv6}},
|
||||
},
|
||||
),
|
||||
wantRespAns: newSVCBHintsAnswer(
|
||||
aghtest.ReqFQDN,
|
||||
[]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{testIPv4}},
|
||||
},
|
||||
),
|
||||
wantRC: resultCodeSuccess,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := ServerConfig{
|
||||
Config: Config{
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, c, nil)
|
||||
|
||||
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
|
||||
dctx := &dnsContext{
|
||||
setts: &filtering.Settings{
|
||||
FilteringEnabled: true,
|
||||
ProtectionEnabled: true,
|
||||
},
|
||||
protectionEnabled: true,
|
||||
responseFromUpstream: true,
|
||||
result: &filtering.Result{},
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Res: resp,
|
||||
Addr: testClientAddr,
|
||||
},
|
||||
}
|
||||
|
||||
gotRC := s.processFilteringAfterResponse(dctx)
|
||||
assert.Equal(t, tc.wantRC, gotRC)
|
||||
assert.Equal(t, newResp(dns.RcodeSuccess, tc.req, tc.wantRespAns), dctx.proxyCtx.Res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
dohSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
@@ -245,15 +342,28 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// createTestDNSFilter returns the minimum valid DNSFilter.
|
||||
func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
|
||||
t.Helper()
|
||||
|
||||
f, err := filtering.New(&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, []filtering.Filter{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
s = &Server{
|
||||
dnsFilter: createTestDNSFilter(t),
|
||||
dnsProxy: &proxy.Proxy{
|
||||
Config: proxy.Config{},
|
||||
},
|
||||
conf: ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
HandleDDR: ddrEnabled,
|
||||
},
|
||||
TLSConfig: TLSConfig{
|
||||
@@ -323,47 +433,59 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
const (
|
||||
localDomainSuffix = "lan"
|
||||
dhcpClient = "example"
|
||||
|
||||
knownHost = dhcpClient + "." + localDomainSuffix
|
||||
unknownHost = "wronghost." + localDomainSuffix
|
||||
)
|
||||
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
dhcp := &testDHCP{
|
||||
OnEnabled: func() (_ bool) { return true },
|
||||
OnIPByHost: func(host string) (ip netip.Addr) {
|
||||
if host == dhcpClient {
|
||||
ip = knownIP
|
||||
}
|
||||
|
||||
return ip
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
wantIP netip.Addr
|
||||
name string
|
||||
host string
|
||||
wantRes resultCode
|
||||
isLocalCli bool
|
||||
}{{
|
||||
wantIP: knownIP,
|
||||
name: "local_client_success",
|
||||
host: "example.lan",
|
||||
wantRes: resultCodeSuccess,
|
||||
host: knownHost,
|
||||
isLocalCli: true,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "local_client_unknown_host",
|
||||
host: "wronghost.lan",
|
||||
wantRes: resultCodeSuccess,
|
||||
host: unknownHost,
|
||||
isLocalCli: true,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "external_client_known_host",
|
||||
host: "example.lan",
|
||||
wantRes: resultCodeFinish,
|
||||
host: knownHost,
|
||||
isLocalCli: false,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "external_client_unknown_host",
|
||||
host: "wronghost.lan",
|
||||
wantRes: resultCodeFinish,
|
||||
host: unknownHost,
|
||||
isLocalCli: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := &Server{
|
||||
dhcpServer: testDHCP,
|
||||
localDomainSuffix: defaultLocalDomainSuffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example." + defaultLocalDomainSuffix: knownIP,
|
||||
},
|
||||
dnsFilter: createTestDNSFilter(t),
|
||||
dhcpServer: dhcp,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
@@ -385,43 +507,52 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
}
|
||||
|
||||
res := s.processDHCPHosts(dctx)
|
||||
require.Equal(t, tc.wantRes, res)
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if tc.wantRes == resultCodeFinish {
|
||||
if !tc.isLocalCli {
|
||||
require.Equal(t, resultCodeFinish, res)
|
||||
require.NotNil(t, pctx.Res)
|
||||
|
||||
assert.Equal(t, dns.RcodeNameError, pctx.Res.Rcode)
|
||||
assert.Len(t, pctx.Res.Answer, 0)
|
||||
assert.Empty(t, pctx.Res.Answer)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, resultCodeSuccess, res)
|
||||
|
||||
if tc.wantIP == (netip.Addr{}) {
|
||||
assert.Nil(t, pctx.Res)
|
||||
} else {
|
||||
require.NotNil(t, pctx.Res)
|
||||
|
||||
ans := pctx.Res.Answer
|
||||
require.Len(t, ans, 1)
|
||||
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, ans[0])
|
||||
|
||||
ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantIP, ip)
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, pctx.Res)
|
||||
|
||||
ans := pctx.Res.Answer
|
||||
require.Len(t, ans, 1)
|
||||
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, ans[0])
|
||||
|
||||
ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantIP, ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
const (
|
||||
examplecom = "example.com"
|
||||
examplelan = "example." + defaultLocalDomainSuffix
|
||||
localTLD = "lan"
|
||||
|
||||
knownClient = "example"
|
||||
externalHost = knownClient + ".com"
|
||||
clientHost = knownClient + "." + localTLD
|
||||
)
|
||||
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
testCases := []struct {
|
||||
wantIP netip.Addr
|
||||
name string
|
||||
@@ -431,55 +562,65 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
qtyp uint16
|
||||
}{{
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_external",
|
||||
host: examplecom,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
name: "external",
|
||||
host: externalHost,
|
||||
suffix: localTLD,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_external_non_a",
|
||||
host: examplecom,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
name: "external_non_a",
|
||||
host: externalHost,
|
||||
suffix: localTLD,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeCNAME,
|
||||
}, {
|
||||
wantIP: knownIP,
|
||||
name: "success_internal",
|
||||
host: examplelan,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
name: "internal",
|
||||
host: clientHost,
|
||||
suffix: localTLD,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_internal_unknown",
|
||||
name: "internal_unknown",
|
||||
host: "example-new.lan",
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
suffix: localTLD,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_internal_aaaa",
|
||||
host: examplelan,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
name: "internal_aaaa",
|
||||
host: clientHost,
|
||||
suffix: localTLD,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
wantIP: knownIP,
|
||||
name: "success_custom_suffix",
|
||||
host: "example.custom",
|
||||
name: "custom_suffix",
|
||||
host: knownClient + ".custom",
|
||||
suffix: "custom",
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
testDHCP := &testDHCP{
|
||||
OnEnabled: func() (_ bool) { return true },
|
||||
OnIPByHost: func(host string) (ip netip.Addr) {
|
||||
if host == knownClient {
|
||||
ip = knownIP
|
||||
}
|
||||
|
||||
return ip
|
||||
},
|
||||
OnHostByIP: func(ip netip.Addr) (host string) { panic("not implemented") },
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
dnsFilter: createTestDNSFilter(t),
|
||||
dhcpServer: testDHCP,
|
||||
localDomainSuffix: tc.suffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example." + tc.suffix: knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
@@ -504,13 +645,6 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
res := s.processDHCPHosts(dctx)
|
||||
pctx := dctx.proxyCtx
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
if tc.wantRes == resultCodeFinish {
|
||||
require.NotNil(t, pctx.Res)
|
||||
assert.Equal(t, dns.RcodeNameError, pctx.Res.Rcode)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, dctx.err)
|
||||
|
||||
if tc.qtyp == dns.TypeAAAA {
|
||||
@@ -555,12 +689,14 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
), nil
|
||||
})
|
||||
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
|
||||
// Improve FilteringConfig declaration for tests.
|
||||
FilteringConfig: FilteringConfig{
|
||||
// Improve Config declaration for tests.
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}, ups)
|
||||
@@ -631,11 +767,13 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{},
|
||||
&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
},
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -139,10 +139,14 @@ func (s *Server) updateStats(
|
||||
clientIP string,
|
||||
) {
|
||||
pctx := ctx.proxyCtx
|
||||
e := stats.Entry{
|
||||
e := &stats.Entry{
|
||||
Domain: aghnet.NormalizeDomain(pctx.Req.Question[0].Name),
|
||||
Result: stats.RNotFiltered,
|
||||
Time: uint32(elapsed / 1000),
|
||||
Time: elapsed,
|
||||
}
|
||||
|
||||
if pctx.Upstream != nil {
|
||||
e.Upstream = pctx.Upstream.Address()
|
||||
}
|
||||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
|
||||
@@ -41,11 +41,11 @@ type testStats struct {
|
||||
// without actually implementing all methods.
|
||||
stats.Interface
|
||||
|
||||
lastEntry stats.Entry
|
||||
lastEntry *stats.Entry
|
||||
}
|
||||
|
||||
// Update implements the [stats.Interface] interface for *testStats.
|
||||
func (l *testStats) Update(e stats.Entry) {
|
||||
func (l *testStats) Update(e *stats.Entry) {
|
||||
if e.Domain == "" {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,13 +13,13 @@ import (
|
||||
func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
// Preconditions.
|
||||
|
||||
s := &Server{
|
||||
conf: ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
BlockedResponseTTL: 3600,
|
||||
},
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
Config: Config{
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}
|
||||
}, nil)
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [
|
||||
"9.9.9.10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -43,6 +46,9 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [
|
||||
"9.9.9.10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -75,6 +81,9 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [
|
||||
"9.9.9.10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -54,6 +55,7 @@
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -91,6 +93,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -128,6 +131,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -165,6 +169,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 6,
|
||||
@@ -202,6 +207,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -241,6 +247,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -280,6 +287,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -317,6 +325,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -354,6 +363,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -391,6 +401,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -428,6 +439,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -467,6 +479,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -506,6 +519,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -544,6 +558,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -581,6 +596,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -620,6 +636,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -662,6 +679,7 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
@@ -699,6 +717,49 @@
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"fallbacks": {
|
||||
"req": {
|
||||
"fallback_dns": [
|
||||
"9.9.9.10"
|
||||
]
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [
|
||||
"9.9.9.10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@@ -94,7 +92,8 @@ func (s *Server) prepareUpstreamConfig(
|
||||
uc.Upstreams = defaultUpstreamConfig.Upstreams
|
||||
}
|
||||
|
||||
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
err = s.replaceUpstreamsWithHosts(uc, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving upstreams with hosts: %w", err)
|
||||
@@ -158,17 +157,20 @@ func (s *Server) resolveUpstreamsWithHosts(
|
||||
|
||||
withIPs, ok := resolved[host]
|
||||
if !ok {
|
||||
ips := s.resolveUpstreamHost(host)
|
||||
if len(ips) == 0 {
|
||||
recs := s.dnsFilter.EtcHostsRecords(host)
|
||||
if len(recs) == 0 {
|
||||
resolved[host] = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
sortNetIPAddrs(ips, opts.PreferIPv6)
|
||||
|
||||
withIPs = opts.Clone()
|
||||
withIPs.ServerIPAddrs = ips
|
||||
withIPs.ServerIPAddrs = make([]net.IP, 0, len(recs))
|
||||
for _, rec := range recs {
|
||||
withIPs.ServerIPAddrs = append(withIPs.ServerIPAddrs, rec.Addr.AsSlice())
|
||||
}
|
||||
|
||||
sortNetIPAddrs(withIPs.ServerIPAddrs, opts.PreferIPv6)
|
||||
resolved[host] = withIPs
|
||||
} else if withIPs == nil {
|
||||
continue
|
||||
@@ -216,33 +218,6 @@ func extractUpstreamHost(addr string) (host string) {
|
||||
return host
|
||||
}
|
||||
|
||||
// resolveUpstreamHost returns the version of ups with IP addresses from the
|
||||
// system hosts file placed into its options.
|
||||
func (s *Server) resolveUpstreamHost(host string) (addrs []net.IP) {
|
||||
req := &urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
DNSType: dns.TypeA,
|
||||
}
|
||||
aRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
|
||||
|
||||
req.DNSType = dns.TypeAAAA
|
||||
aaaaRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
|
||||
|
||||
var ips []net.IP
|
||||
for _, rw := range append(aRes.DNSRewrites(), aaaaRes.DNSRewrites()...) {
|
||||
dr := rw.DNSRewrite
|
||||
if dr == nil || dr.Value == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if ip, ok := dr.Value.(net.IP); ok {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return ips
|
||||
}
|
||||
|
||||
// sortNetIPAddrs sorts addrs in accordance with the protocol preferences.
|
||||
// Invalid addresses are sorted near the end.
|
||||
//
|
||||
@@ -254,28 +229,38 @@ func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) {
|
||||
return
|
||||
}
|
||||
|
||||
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (sortsBefore bool) {
|
||||
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (res int) {
|
||||
switch len(addrA) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
switch len(addrB) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
// Go on.
|
||||
default:
|
||||
return true
|
||||
return -1
|
||||
}
|
||||
default:
|
||||
return false
|
||||
return 1
|
||||
}
|
||||
|
||||
if aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil; aIs4 != bIs4 {
|
||||
if aIs4 {
|
||||
return !preferIPv6
|
||||
// Treat IPv6-mapped IPv4 addresses as IPv6 addresses.
|
||||
aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil
|
||||
if aIs4 == bIs4 {
|
||||
return bytes.Compare(addrA, addrB)
|
||||
}
|
||||
|
||||
if aIs4 {
|
||||
if preferIPv6 {
|
||||
return 1
|
||||
}
|
||||
|
||||
return preferIPv6
|
||||
return -1
|
||||
}
|
||||
|
||||
return bytes.Compare(addrA, addrB) < 0
|
||||
if preferIPv6 {
|
||||
return -1
|
||||
}
|
||||
|
||||
return 1
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user