all: sync more
This commit is contained in:
@@ -132,8 +132,9 @@ type FilteringConfig struct {
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
type TLSConfig struct {
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
|
||||
// Reject connection if the client uses server name (in SNI) that doesn't match the certificate
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
|
||||
|
||||
@@ -214,9 +214,8 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
ipToHost = netutil.NewIPMap(len(ll))
|
||||
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished
|
||||
// with the client hostname validations in the DHCP
|
||||
// server code.
|
||||
// TODO(a.garipov): Remove this after we're finished with the client
|
||||
// hostname validations in the DHCP server code.
|
||||
err = netutil.ValidateDomainName(l.Hostname)
|
||||
if err != nil {
|
||||
log.Debug(
|
||||
@@ -252,7 +251,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
return rc
|
||||
}
|
||||
|
||||
dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip)
|
||||
dctx.isLocalClient = s.privateNets.Contains(ip)
|
||||
|
||||
return rc
|
||||
}
|
||||
@@ -300,6 +299,8 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
reqHost := strings.ToLower(q.Name)
|
||||
// TODO(a.garipov): Move everything related to DHCP local domain to the DHCP
|
||||
// server.
|
||||
host := strings.TrimSuffix(reqHost, s.localDomainSuffix)
|
||||
if host == reqHost {
|
||||
return resultCodeSuccess
|
||||
@@ -372,7 +373,7 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally-served or at least
|
||||
// don't need to be inaccessible externally.
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if !s.privateNets.Contains(ip) {
|
||||
log.Debug("dns: addr %s is not from locally-served network", ip)
|
||||
|
||||
return resultCodeSuccess
|
||||
@@ -479,7 +480,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if !s.privateNets.Contains(ip) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -611,9 +612,9 @@ func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode)
|
||||
d.Res.Answer = answer
|
||||
}
|
||||
default:
|
||||
// Check the response only if the it's from an upstream. Don't check
|
||||
// the response if the protection is disabled since dnsrewrite rules
|
||||
// aren't applied to it anyway.
|
||||
// Check the response only if it's from an upstream. Don't check the
|
||||
// response if the protection is disabled since dnsrewrite rules aren't
|
||||
// applied to it anyway.
|
||||
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -4,35 +4,41 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
s := &Server{
|
||||
subnetDetector: snd,
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
cliIP net.IP
|
||||
want bool
|
||||
}{{
|
||||
want: assert.True,
|
||||
name: "local",
|
||||
cliIP: net.IP{192, 168, 0, 1},
|
||||
want: true,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "external",
|
||||
cliIP: net.IP{250, 249, 0, 1},
|
||||
want: false,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "invalid",
|
||||
cliIP: net.IP{1, 2, 3, 4, 5},
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "nil",
|
||||
cliIP: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -47,7 +53,7 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
}
|
||||
s.processDetermineLocal(dctx)
|
||||
|
||||
assert.Equal(t, tc.want, dctx.isLocalClient)
|
||||
tc.want(t, dctx.isLocalClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -261,7 +267,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
ups := &aghtest.TestUpstream{
|
||||
ups := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"251.252.253.254.in-addr.arpa.": {"host1.example.net."},
|
||||
"1.1.168.192.in-addr.arpa.": {"some.local-client."},
|
||||
@@ -339,7 +345,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, &aghtest.TestUpstream{
|
||||
}, &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
reqAddr: {locDomain},
|
||||
},
|
||||
|
||||
@@ -74,7 +74,7 @@ type Server struct {
|
||||
localDomainSuffix string
|
||||
|
||||
ipset ipsetCtx
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
privateNets netutil.SubnetSet
|
||||
localResolvers *proxy.Proxy
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
recDetector *recursionDetector
|
||||
@@ -111,13 +111,13 @@ const defaultLocalDomainSuffix = ".lan."
|
||||
|
||||
// DNSCreateParams are parameters to create a new server.
|
||||
type DNSCreateParams struct {
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Stats
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
SubnetDetector *aghnet.SubnetDetector
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Stats
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
PrivateNets netutil.SubnetSet
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
// domainNameToSuffix converts a domain name into a local domain suffix.
|
||||
@@ -161,7 +161,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
dnsFilter: p.DNSFilter,
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
subnetDetector: p.SubnetDetector,
|
||||
privateNets: p.PrivateNets,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
@@ -315,7 +315,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||
}
|
||||
|
||||
var resolver *proxy.Proxy
|
||||
if s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if s.privateNets.Contains(ip) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
@@ -69,14 +70,11 @@ func createTestServer(
|
||||
f := filtering.New(filterConf, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var err error
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -89,7 +87,7 @@ func createTestServer(
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
if localUps != nil {
|
||||
s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
|
||||
s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
|
||||
s.conf.UsePrivateRDNS = true
|
||||
}
|
||||
|
||||
@@ -247,7 +245,7 @@ func TestServer(t *testing.T) {
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -316,7 +314,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -339,7 +337,7 @@ func TestDoTServer(t *testing.T) {
|
||||
TLSListenAddrs: []*net.TCPAddr{{}},
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -369,7 +367,7 @@ func TestDoQServer(t *testing.T) {
|
||||
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -413,7 +411,7 @@ func TestServerRace(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -552,7 +550,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||
ups := &aghtest.TestUpstream{
|
||||
ups := &aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"host.": {{192, 168, 0, 1}},
|
||||
},
|
||||
@@ -580,9 +578,9 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
||||
var testCNAMEs = map[string]string{
|
||||
"badhost.": "NULL.example.org.",
|
||||
"whitelist.example.org.": "NULL.example.org.",
|
||||
var testCNAMEs = map[string][]string{
|
||||
"badhost.": {"NULL.example.org."},
|
||||
"whitelist.example.org.": {"NULL.example.org."},
|
||||
}
|
||||
|
||||
// testIPv4 is a map of names and IPv4s necessary for the TestUpstream work.
|
||||
@@ -596,7 +594,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, nil)
|
||||
testUpstm := &aghtest.TestUpstream{
|
||||
testUpstm := &aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
IPv6: nil,
|
||||
@@ -630,7 +628,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
},
|
||||
@@ -640,14 +638,17 @@ func TestBlockCNAME(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want bool
|
||||
}{{
|
||||
name: "block_request",
|
||||
host: "badhost.",
|
||||
// 'badhost' has a canonical name 'NULL.example.org' which is
|
||||
// blocked by filters: response is blocked.
|
||||
want: true,
|
||||
}, {
|
||||
name: "allowed",
|
||||
host: "whitelist.example.org.",
|
||||
// 'whitelist.example.org' has a canonical name
|
||||
// 'NULL.example.org' which is blocked by filters
|
||||
@@ -655,6 +656,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
// response isn't blocked.
|
||||
want: false,
|
||||
}, {
|
||||
name: "block_response",
|
||||
host: "example.org.",
|
||||
// 'example.org' has a canonical name 'cname1' with IP
|
||||
// 127.0.0.255 which is blocked by filters: response is blocked.
|
||||
@@ -662,9 +664,9 @@ func TestBlockCNAME(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run("block_cname_"+tc.host, func(t *testing.T) {
|
||||
req := createTestMessage(tc.host)
|
||||
req := createTestMessage(tc.host)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reply, err := dns.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -674,7 +676,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
|
||||
ans := reply.Answer[0]
|
||||
a, ok := ans.(*dns.A)
|
||||
require.Truef(t, ok, "got %T", ans)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.True(t, a.A.IsUnspecified())
|
||||
}
|
||||
@@ -695,7 +697,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
},
|
||||
@@ -766,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -909,15 +906,10 @@ func TestRewrite(t *testing.T) {
|
||||
f := filtering.New(c, nil)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -931,9 +923,9 @@ func TestRewrite(t *testing.T) {
|
||||
}))
|
||||
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
CName: map[string]string{
|
||||
"example.org": "somename",
|
||||
&aghtest.Upstream{
|
||||
CName: map[string][]string{
|
||||
"example.org": {"somename"},
|
||||
},
|
||||
IPv4: map[string][]net.IP{
|
||||
"example.org.": {{4, 3, 2, 1}},
|
||||
@@ -1024,15 +1016,10 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1101,16 +1088,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
}, nil)
|
||||
flt.SetEnabled(true)
|
||||
|
||||
var snd *aghnet.SubnetDetector
|
||||
snd, err = aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1193,12 +1175,12 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
extUpstream := &aghtest.TestUpstream{
|
||||
extUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
||||
},
|
||||
}
|
||||
locUpstream := &aghtest.TestUpstream{
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.168.192.in-addr.arpa.": {"local.domain"},
|
||||
"2.1.168.192.in-addr.arpa.": {},
|
||||
@@ -1223,9 +1205,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
srv.conf.ResolveClients = true
|
||||
srv.conf.UsePrivateRDNS = true
|
||||
|
||||
var err error
|
||||
srv.subnetDetector, err = aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
testCases := []struct {
|
||||
|
||||
@@ -116,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
|
||||
|
||||
// checkHostRules checks the host against filters. It is safe for concurrent
|
||||
// use.
|
||||
func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) (
|
||||
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
|
||||
r *filtering.Result,
|
||||
err error,
|
||||
) {
|
||||
@@ -128,7 +128,7 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
|
||||
}
|
||||
|
||||
var res filtering.Result
|
||||
res, err = s.dnsFilter.CheckHostRules(host, qtype, setts)
|
||||
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -136,33 +136,36 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
|
||||
return &res, err
|
||||
}
|
||||
|
||||
// If response contains CNAME, A or AAAA records, we apply filtering to each
|
||||
// canonical host name or IP address. If this is a match, we set a new response
|
||||
// in d.Res and return.
|
||||
func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
|
||||
// filterDNSResponse checks each resource record of the response's answer
|
||||
// section from ctx and returns a non-nil res if at least one of canonnical
|
||||
// names or IP addresses in it matches the filtering rules.
|
||||
func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) {
|
||||
d := ctx.proxyCtx
|
||||
setts := ctx.setts
|
||||
if !setts.FilteringEnabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, a := range d.Res.Answer {
|
||||
host := ""
|
||||
|
||||
switch v := a.(type) {
|
||||
var rrtype uint16
|
||||
switch a := a.(type) {
|
||||
case *dns.CNAME:
|
||||
log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name)
|
||||
host = strings.TrimSuffix(v.Target, ".")
|
||||
|
||||
host = strings.TrimSuffix(a.Target, ".")
|
||||
rrtype = dns.TypeCNAME
|
||||
case *dns.A:
|
||||
host = v.A.String()
|
||||
log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name)
|
||||
|
||||
host = a.A.String()
|
||||
rrtype = dns.TypeA
|
||||
case *dns.AAAA:
|
||||
host = v.AAAA.String()
|
||||
log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name)
|
||||
|
||||
host = a.AAAA.String()
|
||||
rrtype = dns.TypeAAAA
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
host = strings.TrimSuffix(host, ".")
|
||||
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
|
||||
log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
|
||||
|
||||
res, err = s.checkHostRules(host, rrtype, setts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil {
|
||||
|
||||
154
internal/dnsforward/filter_test.go
Normal file
154
internal/dnsforward/filter_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
rules := `
|
||||
||blocked.domain^
|
||||
@@||allowed.domain^
|
||||
||cname.specific^$dnstype=~CNAME
|
||||
||0.0.0.1^$dnstype=~A
|
||||
||::1^$dnstype=~AAAA
|
||||
`
|
||||
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
},
|
||||
}
|
||||
filters := []filtering.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf = forwardConf
|
||||
err = s.Prepare(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: map[string][]string{
|
||||
"cname.exception.": {"cname.specific."},
|
||||
"should.block.": {"blocked.domain."},
|
||||
"allowed.first.": {"allowed.domain.", "blocked.domain."},
|
||||
"blocked.first.": {"blocked.domain.", "allowed.domain."},
|
||||
},
|
||||
IPv4: map[string][]net.IP{
|
||||
"a.exception.": {{0, 0, 0, 1}},
|
||||
},
|
||||
IPv6: map[string][]net.IP{
|
||||
"aaaa.exception.": {net.ParseIP("::1")},
|
||||
},
|
||||
},
|
||||
}
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
req *dns.Msg
|
||||
name string
|
||||
wantAns []dns.RR
|
||||
}{{
|
||||
req: createTestMessage("cname.exception."),
|
||||
name: "cname_exception",
|
||||
wantAns: []dns.RR{&dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "cname.exception.",
|
||||
Rrtype: dns.TypeCNAME,
|
||||
},
|
||||
Target: "cname.specific.",
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("should.block."),
|
||||
name: "blocked_by_cname",
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "should.block.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("a.exception."),
|
||||
name: "a_exception",
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "a.exception.",
|
||||
Rrtype: dns.TypeA,
|
||||
},
|
||||
A: net.IP{0, 0, 0, 1},
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
|
||||
name: "aaaa_exception",
|
||||
wantAns: []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "aaaa.exception.",
|
||||
Rrtype: dns.TypeAAAA,
|
||||
},
|
||||
AAAA: net.ParseIP("::1"),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("allowed.first."),
|
||||
name: "allowed_first",
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "allowed.first.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("blocked.first."),
|
||||
name: "blocked_first",
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "blocked.first.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.handleDNSRequest(nil, dctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dctx.Res)
|
||||
|
||||
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -41,7 +42,7 @@ type dnsConfig struct {
|
||||
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
|
||||
}
|
||||
|
||||
func (s *Server) getDNSConfig() dnsConfig {
|
||||
func (s *Server) getDNSConfig() (c *dnsConfig) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
@@ -70,7 +71,7 @@ func (s *Server) getDNSConfig() dnsConfig {
|
||||
upstreamMode = "parallel"
|
||||
}
|
||||
|
||||
return dnsConfig{
|
||||
return &dnsConfig{
|
||||
Upstreams: &upstreams,
|
||||
UpstreamsFile: &upstreamFile,
|
||||
Bootstraps: &bootstraps,
|
||||
@@ -106,7 +107,7 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
// since there is no need to omit it while decoding from JSON.
|
||||
DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
|
||||
}{
|
||||
dnsConfig: s.getDNSConfig(),
|
||||
dnsConfig: *s.getDNSConfig(),
|
||||
DefautLocalPTRUpstreams: defLocalPTRUps,
|
||||
}
|
||||
|
||||
@@ -138,39 +139,63 @@ func (req *dnsConfig) checkBlockingMode() bool {
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkUpstreamsMode() bool {
|
||||
if req.UpstreamMode == nil {
|
||||
return true
|
||||
}
|
||||
valid := []string{"", "fastest_addr", "parallel"}
|
||||
|
||||
for _, valid := range []string{
|
||||
"",
|
||||
"fastest_addr",
|
||||
"parallel",
|
||||
} {
|
||||
if *req.UpstreamMode == valid {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode)
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkBootstrap() (string, error) {
|
||||
func (req *dnsConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, boot := range *req.Bootstraps {
|
||||
if boot == "" {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: empty")
|
||||
var b string
|
||||
defer func() { err = errors.Annotate(err, "checking bootstrap %s: invalid address: %w", b) }()
|
||||
|
||||
for _, b = range *req.Bootstraps {
|
||||
if b == "" {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
if _, err := upstream.NewResolver(boot, nil); err != nil {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
|
||||
if _, err = upstream.NewResolver(b, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating private upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = req.checkBootstrap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case !req.checkBlockingMode():
|
||||
return errors.Error("blocking_mode: incorrect value")
|
||||
case !req.checkUpstreamsMode():
|
||||
return errors.Error("upstream_mode: incorrect value")
|
||||
case !req.checkCacheTTL():
|
||||
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkCacheTTL() bool {
|
||||
@@ -190,69 +215,33 @@ func (req *dnsConfig) checkCacheTTL() bool {
|
||||
}
|
||||
|
||||
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := dnsConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
req := &dnsConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.Upstreams != nil {
|
||||
if err = ValidateUpstreams(*req.Upstreams); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var errBoot string
|
||||
if errBoot, err = req.checkBootstrap(); err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"%s can not be used as bootstrap dns cause: %s",
|
||||
errBoot,
|
||||
err,
|
||||
)
|
||||
err = req.validate(s.privateNets)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case !req.checkBlockingMode():
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value")
|
||||
|
||||
return
|
||||
case !req.checkUpstreamsMode():
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value")
|
||||
|
||||
return
|
||||
case !req.checkCacheTTL():
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"cache_ttl_min must be less or equal than cache_ttl_max",
|
||||
)
|
||||
|
||||
return
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
restart := s.setConfig(req)
|
||||
s.conf.ConfigModified()
|
||||
|
||||
if restart {
|
||||
if err = s.Reconfigure(nil); err != nil {
|
||||
err = s.Reconfigure(nil)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
func (s *Server) setConfigRestartable(dc *dnsConfig) (restart bool) {
|
||||
if dc.Upstreams != nil {
|
||||
s.conf.UpstreamDNS = *dc.Upstreams
|
||||
restart = true
|
||||
@@ -273,9 +262,9 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
restart = true
|
||||
}
|
||||
|
||||
if dc.RateLimit != nil {
|
||||
restart = restart || s.conf.Ratelimit != *dc.RateLimit
|
||||
if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit {
|
||||
s.conf.Ratelimit = *dc.RateLimit
|
||||
restart = true
|
||||
}
|
||||
|
||||
if dc.EDNSCSEnabled != nil {
|
||||
@@ -306,7 +295,7 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
return restart
|
||||
}
|
||||
|
||||
func (s *Server) setConfig(dc dnsConfig) (restart bool) {
|
||||
func (s *Server) setConfig(dc *dnsConfig) (restart bool) {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
@@ -353,52 +342,106 @@ type upstreamJSON struct {
|
||||
PrivateUpstreams []string `json:"private_upstream"`
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true of the string starts with a "#" character or is
|
||||
// an empty string. This function is useful for filtering out non-upstream
|
||||
// lines from upstream configs.
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
||||
// slice already so that this function may be considered useless.
|
||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
||||
// No need to validate comments and empty lines.
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
// Consider this case valid since it means the default server should be
|
||||
// used.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
_, err = validateUpstream(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func ValidateUpstreams(upstreams []string) (err error) {
|
||||
// No need to validate comments
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
_, err = newUpstreamConfig(upstreams)
|
||||
|
||||
// Consider this case valid because defaultDNS will be used
|
||||
if len(upstreams) == 0 {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// stringKeysSorted returns the sorted slice of string keys of m.
|
||||
//
|
||||
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
|
||||
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
|
||||
sorted = make([]string, 0, len(m))
|
||||
for s := range m {
|
||||
sorted = append(sorted, s)
|
||||
}
|
||||
|
||||
_, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
)
|
||||
sort.Strings(sorted)
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var defaultUpstreamFound bool
|
||||
for _, u := range upstreams {
|
||||
var useDefault bool
|
||||
useDefault, err = validateUpstream(u)
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.SubnetFromReversedAddr(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !defaultUpstreamFound {
|
||||
defaultUpstreamFound = useDefault
|
||||
if !privateNets.Contains(subnet.IP) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if !defaultUpstreamFound {
|
||||
return fmt.Errorf("no default upstreams specified")
|
||||
if len(errs) > 0 {
|
||||
return errors.List("checking domain-specific upstreams", errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -184,12 +185,11 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `wrong upstreams specification: bad ipport address "!!!": address !!!: ` +
|
||||
`missing port in address`,
|
||||
wantSet: `validating upstream servers: bad ipport address "!!!": ` +
|
||||
`address !!!: missing port in address`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
wantSet: `a can not be used as bootstrap dns cause: ` +
|
||||
`invalid bootstrap server address: ` +
|
||||
wantSet: `checking bootstrap a: invalid address: ` +
|
||||
`Resolver a is not eligible to be a bootstrap DNS server`,
|
||||
}, {
|
||||
name: "cache_bad_ttl",
|
||||
@@ -200,6 +200,10 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "local_ptr_upstreams_good",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "local_ptr_upstreams_bad",
|
||||
wantSet: `validating private upstream servers: checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "non.arpa": not a reversed ip network`,
|
||||
}, {
|
||||
name: "local_ptr_upstreams_null",
|
||||
wantSet: "",
|
||||
@@ -358,7 +362,7 @@ func TestValidateUpstream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsSet(t *testing.T) {
|
||||
func TestValidateUpstreams(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
@@ -405,3 +409,51 @@ func TestValidateUpstreamsSet(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
u string
|
||||
}{{
|
||||
name: "success_address",
|
||||
wantErr: ``,
|
||||
u: "[/1.0.0.127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "success_subnet",
|
||||
wantErr: ``,
|
||||
u: "[/127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "not_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "hello.world": not a reversed ip network`,
|
||||
u: "[/hello.world/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_address",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
|
||||
u: "[/1.2.3.4.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
|
||||
u: "[/128.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "several_bad",
|
||||
wantErr: `checking domain-specific upstreams: 2 errors: ` +
|
||||
`"arpa domain \"1.2.3.4.in-addr.arpa.\" should point to a locally-served network", ` +
|
||||
`"bad arpa domain name \"non.arpa\": not a reversed ip network"`,
|
||||
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
set := []string{"192.168.0.1", tc.u}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateUpstreamsPrivate(set, ss)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -520,6 +520,43 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_bad": {
|
||||
"req": {
|
||||
"local_ptr_upstreams": [
|
||||
"123.123.123.123",
|
||||
"[/non.arpa/]#"
|
||||
]
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_null": {
|
||||
"req": {
|
||||
"local_ptr_upstreams": null
|
||||
|
||||
Reference in New Issue
Block a user