Pull request: 2639 use testify require vol.4
Merge in DNS/adguard-home from 2639-testify-require-4 to master Closes #2639. Squashed commit of the following: commit 0bb9125f42ab6d2511c1b8e481112aa5edd581d9 Merge: 0e9e9ed12c9992e0Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Mar 11 15:47:21 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit 0e9e9ed16ae13ce648b5e1da6ffd123df911c2d7 Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:43:15 2021 +0300 home: rm deletion error check commit 6bfbbcd2b7f9197a06856f9e6b959c2e1c4b8353 Merge: c8ebe5418811c881Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:30:07 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit c8ebe54142bba780226f76ddb72e33664ed28f30 Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:28:43 2021 +0300 home: imp tests commit f0e1db456f02df5f5f56ca93e7bd40a48475b38c Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 5 14:06:41 2021 +0300 dnsforward: imp tests commit 4528246105ed06471a8778abbe8e5c30fc5483d5 Merge: 54b08d9c90ebc4d8Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Mar 4 18:17:52 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit 54b08d9c980b8d69d019a1a1b3931aa048275691 Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Feb 11 13:17:05 2021 +0300 dnsfilter: imp tests
This commit is contained in:
@@ -5,71 +5,153 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsBlockedIPAllowed(t *testing.T) {
|
||||
a := &accessCtx{}
|
||||
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
const (
|
||||
ip int = iota
|
||||
cidr
|
||||
)
|
||||
|
||||
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
|
||||
assert.False(t, disallowed)
|
||||
assert.Empty(t, disallowedRule)
|
||||
rules := []string{
|
||||
ip: "1.1.1.1",
|
||||
cidr: "2.2.0.0/16",
|
||||
}
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
|
||||
assert.True(t, disallowed)
|
||||
assert.Empty(t, disallowedRule)
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowed bool
|
||||
ip net.IP
|
||||
wantDis bool
|
||||
wantRule string
|
||||
}{{
|
||||
name: "allow_ip",
|
||||
allowed: true,
|
||||
ip: net.IPv4(1, 1, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "disallow_ip",
|
||||
allowed: true,
|
||||
ip: net.IPv4(1, 1, 1, 2),
|
||||
wantDis: true,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_cidr",
|
||||
allowed: true,
|
||||
ip: net.IPv4(2, 2, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "disallow_cidr",
|
||||
allowed: true,
|
||||
ip: net.IPv4(2, 3, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_ip",
|
||||
allowed: false,
|
||||
ip: net.IPv4(1, 1, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: rules[ip],
|
||||
}, {
|
||||
name: "disallow_ip",
|
||||
allowed: false,
|
||||
ip: net.IPv4(1, 1, 1, 2),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_cidr",
|
||||
allowed: false,
|
||||
ip: net.IPv4(2, 2, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: rules[cidr],
|
||||
}, {
|
||||
name: "disallow_cidr",
|
||||
allowed: false,
|
||||
ip: net.IPv4(2, 3, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}}
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
|
||||
assert.False(t, disallowed)
|
||||
assert.Empty(t, disallowedRule)
|
||||
for _, tc := range testCases {
|
||||
prefix := "allowed_"
|
||||
if !tc.allowed {
|
||||
prefix = "disallowed_"
|
||||
}
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
|
||||
assert.True(t, disallowed)
|
||||
assert.Empty(t, disallowedRule)
|
||||
t.Run(prefix+tc.name, func(t *testing.T) {
|
||||
aCtx := &accessCtx{}
|
||||
allowedRules := rules
|
||||
var disallowedRules []string
|
||||
|
||||
if !tc.allowed {
|
||||
allowedRules, disallowedRules = disallowedRules, allowedRules
|
||||
}
|
||||
|
||||
require.Nil(t, aCtx.Init(allowedRules, disallowedRules, nil))
|
||||
|
||||
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
|
||||
assert.Equal(t, tc.wantDis, disallowed)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedIPDisallowed(t *testing.T) {
|
||||
a := &accessCtx{}
|
||||
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))
|
||||
|
||||
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
|
||||
assert.True(t, disallowed)
|
||||
assert.Equal(t, "1.1.1.1", disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
|
||||
assert.False(t, disallowed)
|
||||
assert.Empty(t, disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
|
||||
assert.True(t, disallowed)
|
||||
assert.Equal(t, "2.2.0.0/16", disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
|
||||
assert.False(t, disallowed)
|
||||
assert.Empty(t, disallowedRule)
|
||||
}
|
||||
|
||||
func TestIsBlockedIPBlockedDomain(t *testing.T) {
|
||||
a := &accessCtx{}
|
||||
assert.True(t, a.Init(nil, nil, []string{
|
||||
func TestIsBlockedDomain(t *testing.T) {
|
||||
aCtx := &accessCtx{}
|
||||
require.Nil(t, aCtx.Init(nil, nil, []string{
|
||||
"host1",
|
||||
"host2",
|
||||
"*.host.com",
|
||||
"||host3.com^",
|
||||
}) == nil)
|
||||
}))
|
||||
|
||||
// match by "host2.com"
|
||||
assert.True(t, a.IsBlockedDomain("host1"))
|
||||
assert.True(t, a.IsBlockedDomain("host2"))
|
||||
assert.False(t, a.IsBlockedDomain("host3"))
|
||||
testCases := []struct {
|
||||
name string
|
||||
domain string
|
||||
want bool
|
||||
}{{
|
||||
name: "plain_match",
|
||||
domain: "host1",
|
||||
want: true,
|
||||
}, {
|
||||
name: "plain_mismatch",
|
||||
domain: "host2",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-1_match_short",
|
||||
domain: "asdf.host.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-1_match_long",
|
||||
domain: "qwer.asdf.host.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-1_mismatch_no-lead",
|
||||
domain: "host.com",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-1_mismatch_bad-asterisk",
|
||||
domain: "asdf.zhost.com",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-2_match_simple",
|
||||
domain: "host3.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-2_match_complex",
|
||||
domain: "asdf.host3.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-2_mismatch",
|
||||
domain: ".host3.com",
|
||||
want: false,
|
||||
}}
|
||||
|
||||
// match by wildcard "*.host.com"
|
||||
assert.False(t, a.IsBlockedDomain("host.com"))
|
||||
assert.True(t, a.IsBlockedDomain("asdf.host.com"))
|
||||
assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com"))
|
||||
assert.False(t, a.IsBlockedDomain("asdf.zhost.com"))
|
||||
|
||||
// match by wildcard "||host3.com^"
|
||||
assert.True(t, a.IsBlockedDomain("host3.com"))
|
||||
assert.True(t, a.IsBlockedDomain("asdf.host3.com"))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -42,14 +43,180 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
|
||||
err := s.Start()
|
||||
assert.Nilf(t, err, "failed to start server: %s", err)
|
||||
require.Nilf(t, err, "failed to start server: %s", err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := s.Stop()
|
||||
assert.Nilf(t, err, "dns server failed to stop: %s", err)
|
||||
require.Nilf(t, err, "dns server failed to stop: %s", err)
|
||||
})
|
||||
}
|
||||
|
||||
func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server {
|
||||
t.Helper()
|
||||
|
||||
rules := `||nxdomain.example.org
|
||||
||null.example.org^
|
||||
127.0.0.1 host.example.org
|
||||
@@||whitelist.example.org^
|
||||
||127.0.0.255`
|
||||
filters := []dnsfilter.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := dnsfilter.New(filterConf, filters)
|
||||
|
||||
s := NewServer(DNSCreateParams{DNSFilter: f})
|
||||
s.conf = forwardConf
|
||||
require.Nil(t, s.Prepare(nil))
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.Nilf(t, err, "cannot generate RSA key: %s", err)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
require.Nilf(t, err, "failed to generate serial number: %s", err)
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"AdGuard Tests"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
template.DNSNames = append(template.DNSNames, tlsServerName)
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
|
||||
require.Nilf(t, err, "failed to create certificate: %s", err)
|
||||
|
||||
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||
|
||||
cert, err := tls.X509KeyPair(certPem, keyPem)
|
||||
require.Nilf(t, err, "failed to create certificate: %s", err)
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ServerName: tlsServerName,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}, certPem, keyPem
|
||||
}
|
||||
|
||||
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
|
||||
t.Helper()
|
||||
|
||||
var keyPem []byte
|
||||
_, certPem, keyPem = createServerTLSConfig(t)
|
||||
|
||||
s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{},
|
||||
TCPListenAddr: &net.TCPAddr{},
|
||||
})
|
||||
|
||||
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
||||
s.conf.TLSConfig = tlsConf
|
||||
|
||||
err := s.Prepare(nil)
|
||||
require.Nilf(t, err, "failed to prepare server: %s", err)
|
||||
|
||||
return s, certPem
|
||||
}
|
||||
|
||||
func createGoogleATestMessage() *dns.Msg {
|
||||
return createTestMessage("google-public-dns-a.google.com.")
|
||||
}
|
||||
|
||||
func createTestMessage(host string) *dns.Msg {
|
||||
return &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: host,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
|
||||
req := createTestMessage(host)
|
||||
req.Question[0].Qtype = qtype
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
|
||||
assertResponse(t, reply, net.IP{8, 8, 8, 8})
|
||||
}
|
||||
|
||||
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
|
||||
t.Helper()
|
||||
|
||||
require.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer))
|
||||
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
require.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0])
|
||||
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
|
||||
}
|
||||
|
||||
// sendTestMessagesAsync sends messages in parallel to check for race issues.
|
||||
//
|
||||
//lint:ignore U1000 it's called from the function which is skipped for now.
|
||||
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
|
||||
t.Helper()
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < testMessagesCount; i++ {
|
||||
msg := createGoogleATestMessage()
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := conn.WriteMsg(msg)
|
||||
require.Nilf(t, err, "cannot write message: %s", err)
|
||||
|
||||
res, err := conn.ReadMsg()
|
||||
require.Nilf(t, err, "cannot read response to message: %s", err)
|
||||
|
||||
assertGoogleAResponse(t, res)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func sendTestMessages(t *testing.T, conn *dns.Conn) {
|
||||
t.Helper()
|
||||
|
||||
for i := 0; i < testMessagesCount; i++ {
|
||||
req := createGoogleATestMessage()
|
||||
err := conn.WriteMsg(req)
|
||||
assert.Nilf(t, err, "cannot write message #%d: %s", i, err)
|
||||
|
||||
res, err := conn.ReadMsg()
|
||||
assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err)
|
||||
assertGoogleAResponse(t, res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{},
|
||||
@@ -81,7 +248,7 @@ func TestServer(t *testing.T) {
|
||||
client := dns.Client{Net: tc.proto}
|
||||
|
||||
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
|
||||
assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
|
||||
assertGoogleAResponse(t, reply)
|
||||
})
|
||||
@@ -106,31 +273,12 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
req := createGoogleATestMessage()
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
client := dns.Client{Net: proxy.ProtoUDP}
|
||||
|
||||
reply, _, err := client.Exchange(req, addr.String())
|
||||
assert.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
require.Nilf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
assertGoogleAResponse(t, reply)
|
||||
}
|
||||
|
||||
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
|
||||
t.Helper()
|
||||
|
||||
var keyPem []byte
|
||||
_, certPem, keyPem = createServerTLSConfig(t)
|
||||
|
||||
s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{},
|
||||
TCPListenAddr: &net.TCPAddr{},
|
||||
})
|
||||
|
||||
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
||||
s.conf.TLSConfig = tlsConf
|
||||
|
||||
err := s.Prepare(nil)
|
||||
assert.Nilf(t, err, "failed to prepare server: %s", err)
|
||||
|
||||
return s, certPem
|
||||
}
|
||||
|
||||
func TestDoTServer(t *testing.T) {
|
||||
s, certPem := createTestTLS(t, TLSConfig{
|
||||
TLSListenAddr: &net.TCPAddr{},
|
||||
@@ -156,7 +304,7 @@ func TestDoTServer(t *testing.T) {
|
||||
// Create a DNS-over-TLS client connection.
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoTLS)
|
||||
conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
|
||||
assert.Nilf(t, err, "cannot connect to the proxy: %s", err)
|
||||
require.Nilf(t, err, "cannot connect to the proxy: %s", err)
|
||||
|
||||
sendTestMessages(t, conn)
|
||||
}
|
||||
@@ -178,12 +326,12 @@ func TestDoQServer(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
|
||||
opts := upstream.Options{InsecureSkipVerify: true}
|
||||
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Send the test message.
|
||||
req := createGoogleATestMessage()
|
||||
res, err := u.Exchange(req)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
|
||||
assertGoogleAResponse(t, res)
|
||||
}
|
||||
@@ -221,7 +369,7 @@ func TestServerRace(t *testing.T) {
|
||||
// Message over UDP.
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
|
||||
assert.Nilf(t, err, "cannot connect to the proxy: %s", err)
|
||||
require.Nilf(t, err, "cannot connect to the proxy: %s", err)
|
||||
|
||||
sendTestMessagesAsync(t, conn)
|
||||
}
|
||||
@@ -282,8 +430,9 @@ func TestSafeSearch(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.host, func(t *testing.T) {
|
||||
req := createTestMessage(tc.host)
|
||||
|
||||
reply, _, err := client.Exchange(req, addr)
|
||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
assertResponse(t, reply, tc.want)
|
||||
})
|
||||
}
|
||||
@@ -330,8 +479,10 @@ func TestBlockedRequest(t *testing.T) {
|
||||
req := createTestMessage("nxdomain.example.org.")
|
||||
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
|
||||
require.Len(t, reply.Answer, 1)
|
||||
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified())
|
||||
}
|
||||
|
||||
@@ -364,28 +515,14 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
assert.NotEmpty(t, reply.Answer)
|
||||
require.NotEmpty(t, reply.Answer)
|
||||
|
||||
require.Len(t, reply.Answer, 1)
|
||||
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
|
||||
}
|
||||
|
||||
func (s *Server) startWithUpstream(u upstream.Upstream) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
err := s.Prepare(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{u},
|
||||
}
|
||||
|
||||
return s.dnsProxy.Start()
|
||||
}
|
||||
|
||||
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
||||
var testCNAMEs = map[string]string{
|
||||
"badhost.": "null.example.org.",
|
||||
@@ -409,15 +546,19 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
IPv6: nil,
|
||||
}
|
||||
s.conf.ProtectionEnabled = false
|
||||
err := s.startWithUpstream(testUpstm)
|
||||
assert.Nil(t, err)
|
||||
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{testUpstm},
|
||||
}
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
// 'badhost' has a canonical name 'null.example.org' which is blocked by
|
||||
// filters: but protection is disabled so response is _not_ blocked.
|
||||
// 'badhost' has a canonical name 'null.example.org' which should be
|
||||
// blocked by filters, but protection is disabled so it is not.
|
||||
req := createTestMessage("badhost.")
|
||||
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
}
|
||||
|
||||
@@ -465,11 +606,15 @@ func TestBlockCNAME(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run("block_cname_"+tc.host, func(t *testing.T) {
|
||||
req := createTestMessage(tc.host)
|
||||
|
||||
reply, err := dns.Exchange(req, addr)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
if tc.want {
|
||||
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified())
|
||||
require.Len(t, reply.Answer, 1)
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.True(t, a.A.IsUnspecified())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -513,7 +658,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
// However, in our case it should not be blocked as filtering is
|
||||
// disabled on the client level.
|
||||
reply, err := dns.Exchange(&req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
}
|
||||
|
||||
@@ -544,10 +689,10 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
reply, err := dns.Exchange(&req, addr.String())
|
||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A)
|
||||
}
|
||||
|
||||
@@ -561,7 +706,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
s := NewServer(DNSCreateParams{
|
||||
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters),
|
||||
})
|
||||
conf := ServerConfig{
|
||||
conf := &ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{},
|
||||
TCPListenAddr: &net.TCPAddr{},
|
||||
FilteringConfig: FilteringConfig{
|
||||
@@ -572,11 +717,11 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
// Invalid BlockingIPv4.
|
||||
assert.NotNil(t, s.Prepare(&conf))
|
||||
assert.NotNil(t, s.Prepare(conf))
|
||||
|
||||
conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
|
||||
conf.BlockingIPv6 = net.ParseIP("::1")
|
||||
assert.Nil(t, s.Prepare(&conf))
|
||||
require.Nil(t, s.Prepare(conf))
|
||||
|
||||
startDeferStop(t, s)
|
||||
|
||||
@@ -584,18 +729,18 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
|
||||
req := createTestMessageWithType("null.example.org.", dns.TypeA)
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, reply.Answer, 1)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, reply.Answer, 1)
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
assert.True(t, ok)
|
||||
require.True(t, ok)
|
||||
assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A))
|
||||
|
||||
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, reply.Answer, 1)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, reply.Answer, 1)
|
||||
a6, ok := reply.Answer[0].(*dns.AAAA)
|
||||
assert.True(t, ok)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "::1", a6.AAAA.String())
|
||||
}
|
||||
|
||||
@@ -615,11 +760,10 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
req := createTestMessage("host.example.org.")
|
||||
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
|
||||
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
|
||||
}
|
||||
|
||||
@@ -630,7 +774,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
Hostname: hostname,
|
||||
Block: true,
|
||||
}
|
||||
ans, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
|
||||
filterConf := &dnsfilter.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
@@ -639,7 +783,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
UDPListenAddr: &net.UDPAddr{},
|
||||
TCPListenAddr: &net.TCPAddr{},
|
||||
FilteringConfig: FilteringConfig{
|
||||
SafeBrowsingBlockHost: ans.String(),
|
||||
SafeBrowsingBlockHost: ans4.String(),
|
||||
ProtectionEnabled: true,
|
||||
},
|
||||
}
|
||||
@@ -652,13 +796,12 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
req := createTestMessage(hostname + ".")
|
||||
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
assert.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
require.Nilf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
|
||||
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) {
|
||||
assert.Equal(t, ans, a.A, "dns server %s returned wrong answer: %v", addr, a.A)
|
||||
}
|
||||
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
assert.Equal(t, ans4, a.A, "dns server %s returned wrong answer: %v", addr, a.A)
|
||||
}
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
@@ -680,14 +823,14 @@ func TestRewrite(t *testing.T) {
|
||||
f := dnsfilter.New(c, nil)
|
||||
|
||||
s := NewServer(DNSCreateParams{DNSFilter: f})
|
||||
err := s.Prepare(&ServerConfig{
|
||||
assert.Nil(t, s.Prepare(&ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{},
|
||||
TCPListenAddr: &net.TCPAddr{},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
UpstreamDNS: []string{"8.8.8.8:53"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
CName: map[string]string{
|
||||
@@ -698,185 +841,44 @@ func TestRewrite(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Nil(t, err)
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
req := createTestMessageWithType("test.com.", dns.TypeA)
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, reply.Answer, 1)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, reply.Answer, 1)
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
assert.True(t, ok)
|
||||
require.True(t, ok)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
|
||||
|
||||
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, reply.Answer, 2)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Len(t, reply.Answer, 2)
|
||||
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
||||
|
||||
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
|
||||
// The original question is restored.
|
||||
require.Len(t, reply.Question, 1)
|
||||
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
|
||||
assert.Len(t, reply.Answer, 2)
|
||||
|
||||
require.Len(t, reply.Answer, 2)
|
||||
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
}
|
||||
|
||||
func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server {
|
||||
rules := `||nxdomain.example.org
|
||||
||null.example.org^
|
||||
127.0.0.1 host.example.org
|
||||
@@||whitelist.example.org^
|
||||
||127.0.0.255`
|
||||
filters := []dnsfilter.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := dnsfilter.New(filterConf, filters)
|
||||
|
||||
s := NewServer(DNSCreateParams{DNSFilter: f})
|
||||
s.conf = forwardConf
|
||||
assert.Nil(t, s.Prepare(nil))
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.Nilf(t, err, "cannot generate RSA key: %s", err)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
assert.Nilf(t, err, "failed to generate serial number: %s", err)
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"AdGuard Tests"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
template.DNSNames = append(template.DNSNames, tlsServerName)
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
|
||||
assert.Nilf(t, err, "failed to create certificate: %s", err)
|
||||
|
||||
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||
|
||||
cert, err := tls.X509KeyPair(certPem, keyPem)
|
||||
assert.Nilf(t, err, "failed to create certificate: %s", err)
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ServerName: tlsServerName,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}, certPem, keyPem
|
||||
}
|
||||
|
||||
// sendTestMessagesAsync sends messages in parallel to check for race issues.
|
||||
//lint:ignore U1000 it's called from the function which is skipped for now.
|
||||
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < testMessagesCount; i++ {
|
||||
msg := createGoogleATestMessage()
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := conn.WriteMsg(msg)
|
||||
assert.Nilf(t, err, "cannot write message: %s", err)
|
||||
|
||||
res, err := conn.ReadMsg()
|
||||
assert.Nilf(t, err, "cannot read response to message: %s", err)
|
||||
|
||||
assertGoogleAResponse(t, res)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func sendTestMessages(t *testing.T, conn *dns.Conn) {
|
||||
t.Helper()
|
||||
|
||||
for i := 0; i < testMessagesCount; i++ {
|
||||
req := createGoogleATestMessage()
|
||||
err := conn.WriteMsg(req)
|
||||
assert.Nilf(t, err, "cannot write message #%d: %s", i, err)
|
||||
|
||||
res, err := conn.ReadMsg()
|
||||
assert.Nilf(t, err, "cannot read response to message #%d: %s", i, err)
|
||||
assertGoogleAResponse(t, res)
|
||||
}
|
||||
}
|
||||
|
||||
func createGoogleATestMessage() *dns.Msg {
|
||||
return createTestMessage("google-public-dns-a.google.com.")
|
||||
}
|
||||
|
||||
func createTestMessage(host string) *dns.Msg {
|
||||
return &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: host,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func createTestMessageWithType(host string, qtype uint16) *dns.Msg {
|
||||
req := createTestMessage(host)
|
||||
req.Question[0].Qtype = qtype
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func assertGoogleAResponse(t *testing.T, reply *dns.Msg) {
|
||||
assertResponse(t, reply, net.IP{8, 8, 8, 8})
|
||||
}
|
||||
|
||||
func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) {
|
||||
t.Helper()
|
||||
|
||||
if !assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) {
|
||||
return
|
||||
}
|
||||
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) {
|
||||
assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A)
|
||||
}
|
||||
}
|
||||
|
||||
func publicKey(priv interface{}) interface{} {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
@@ -966,8 +968,8 @@ func TestValidateUpstream(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||
assert.Equal(t, tc.valid, err == nil)
|
||||
if err == nil {
|
||||
require.Equal(t, tc.valid, err == nil)
|
||||
if tc.valid {
|
||||
assert.Equal(t, tc.wantDef, defaultUpstream)
|
||||
}
|
||||
})
|
||||
@@ -975,42 +977,73 @@ func TestValidateUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsSet(t *testing.T) {
|
||||
// Empty upstreams array.
|
||||
var upstreamsSet []string
|
||||
assert.Nil(t, ValidateUpstreams(upstreamsSet), "empty upstreams array should be valid")
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg string
|
||||
set []string
|
||||
wantNil bool
|
||||
}{{
|
||||
name: "empty",
|
||||
msg: "empty upstreams array should be valid",
|
||||
set: nil,
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "comment",
|
||||
msg: "comments should not be validated",
|
||||
set: []string{"# comment"},
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "valid_no_default",
|
||||
msg: "there is no default upstream",
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
},
|
||||
wantNil: false,
|
||||
}, {
|
||||
name: "valid_with_default",
|
||||
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"8.8.8.8",
|
||||
},
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "invalid",
|
||||
msg: "there is an invalid upstream in set, but it pass through validation",
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
wantNil: false,
|
||||
}}
|
||||
|
||||
// Comment in upstreams array.
|
||||
upstreamsSet = []string{"# comment"}
|
||||
assert.Nil(t, ValidateUpstreams(upstreamsSet), "comments should not be validated")
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateUpstreams(tc.set)
|
||||
|
||||
// Set of valid upstreams. There is no default upstream specified.
|
||||
upstreamsSet = []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
|
||||
})
|
||||
}
|
||||
assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is no default upstream")
|
||||
|
||||
// Let's add default upstream.
|
||||
upstreamsSet = append(upstreamsSet, "8.8.8.8")
|
||||
err := ValidateUpstreams(upstreamsSet)
|
||||
assert.Nilf(t, err, "upstreams set is valid, but doesn't pass through validation cause: %s", err)
|
||||
|
||||
// Let's add invalid upstream.
|
||||
upstreamsSet = append(upstreamsSet, "dhcp://fake.dns")
|
||||
assert.NotNil(t, ValidateUpstreams(upstreamsSet), "there is an invalid upstream in set, but it pass through validation")
|
||||
}
|
||||
|
||||
func TestIPStringFromAddr(t *testing.T) {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("1:2:3::4"),
|
||||
Port: 12345,
|
||||
Zone: "eth0",
|
||||
}
|
||||
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
|
||||
assert.Empty(t, IPStringFromAddr(nil))
|
||||
t.Run("not_nil", func(t *testing.T) {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("1:2:3::4"),
|
||||
Port: 12345,
|
||||
Zone: "eth0",
|
||||
}
|
||||
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
|
||||
})
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
assert.Empty(t, IPStringFromAddr(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchDNSName(t *testing.T) {
|
||||
@@ -1071,38 +1104,33 @@ func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
|
||||
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
dhcp := &testDHCP{}
|
||||
|
||||
s := NewServer(DNSCreateParams{
|
||||
DNSFilter: dnsfilter.New(&dnsfilter.Config{}, nil),
|
||||
DHCPServer: dhcp,
|
||||
DHCPServer: &testDHCP{},
|
||||
})
|
||||
|
||||
s.conf.UDPListenAddr = &net.UDPAddr{}
|
||||
s.conf.TCPListenAddr = &net.TCPAddr{}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||
err := s.Prepare(nil)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, s.Prepare(nil))
|
||||
require.Nil(t, s.Start())
|
||||
t.Cleanup(func() {
|
||||
s.Close()
|
||||
})
|
||||
|
||||
assert.Nil(t, s.Start())
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||
|
||||
resp, err := dns.Exchange(req, addr.String())
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, resp.Answer, 1)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, "localhost.", ptr.Ptr)
|
||||
}
|
||||
|
||||
s.Close()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "localhost.", ptr.Ptr)
|
||||
}
|
||||
|
||||
func TestPTRResponseFromHosts(t *testing.T) {
|
||||
@@ -1112,12 +1140,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
|
||||
// Prepare test hosts file.
|
||||
hf, err := ioutil.TempFile("", "")
|
||||
if assert.Nil(t, err) {
|
||||
t.Cleanup(func() {
|
||||
assert.Nil(t, hf.Close())
|
||||
assert.Nil(t, os.Remove(hf.Name()))
|
||||
})
|
||||
}
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.Nil(t, hf.Close())
|
||||
assert.Nil(t, os.Remove(hf.Name()))
|
||||
})
|
||||
|
||||
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
|
||||
_, _ = hf.WriteString(" ::1 localhost#comment \n")
|
||||
@@ -1131,23 +1158,23 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
s.conf.TCPListenAddr = &net.TCPAddr{}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||
assert.Nil(t, s.Prepare(nil))
|
||||
require.Nil(t, s.Prepare(nil))
|
||||
|
||||
assert.Nil(t, s.Start())
|
||||
require.Nil(t, s.Start())
|
||||
t.Cleanup(func() {
|
||||
s.Close()
|
||||
})
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||
|
||||
resp, err := dns.Exchange(req, addr.String())
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, resp.Answer, 1)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, "host.", ptr.Ptr)
|
||||
}
|
||||
|
||||
s.Close()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "host.", ptr.Ptr)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
@@ -54,7 +55,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
|
||||
})
|
||||
|
||||
@@ -64,7 +66,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
assert.Empty(t, d.Res.Answer)
|
||||
})
|
||||
@@ -75,11 +77,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
|
||||
}
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
|
||||
})
|
||||
|
||||
t.Run("noerror_aaaa", func(t *testing.T) {
|
||||
@@ -88,11 +90,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
|
||||
}
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
|
||||
})
|
||||
|
||||
t.Run("noerror_ptr", func(t *testing.T) {
|
||||
@@ -101,11 +103,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr)
|
||||
}
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr)
|
||||
})
|
||||
|
||||
t.Run("noerror_txt", func(t *testing.T) {
|
||||
@@ -114,11 +116,11 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt)
|
||||
}
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt)
|
||||
})
|
||||
|
||||
t.Run("noerror_mx", func(t *testing.T) {
|
||||
@@ -127,15 +129,15 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
ans, ok := d.Res.Answer[0].(*dns.MX)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, mx.Exchange, ans.Mx)
|
||||
assert.Equal(t, mx.Preference, ans.Preference)
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
ans, ok := d.Res.Answer[0].(*dns.MX)
|
||||
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, mx.Exchange, ans.Mx)
|
||||
assert.Equal(t, mx.Preference, ans.Preference)
|
||||
})
|
||||
|
||||
t.Run("noerror_svcb", func(t *testing.T) {
|
||||
@@ -144,17 +146,17 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
ans, ok := d.Res.Answer[0].(*dns.SVCB)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
||||
assert.Equal(t, svcb.Target, ans.Target)
|
||||
assert.Equal(t, svcb.Priority, ans.Priority)
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
ans, ok := d.Res.Answer[0].(*dns.SVCB)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
||||
assert.Equal(t, svcb.Target, ans.Target)
|
||||
assert.Equal(t, svcb.Priority, ans.Priority)
|
||||
})
|
||||
|
||||
t.Run("noerror_https", func(t *testing.T) {
|
||||
@@ -163,16 +165,16 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
||||
assert.Equal(t, svcb.Target, ans.Target)
|
||||
assert.Equal(t, svcb.Priority, ans.Priority)
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
|
||||
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
||||
assert.Equal(t, svcb.Target, ans.Target)
|
||||
assert.Equal(t, svcb.Priority, ans.Priority)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||
@@ -31,9 +32,10 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
err := s.Start()
|
||||
assert.Nil(t, err)
|
||||
defer assert.Nil(t, s.Stop())
|
||||
require.Nil(t, s.Start())
|
||||
t.Cleanup(func() {
|
||||
require.Nil(t, s.Stop())
|
||||
})
|
||||
|
||||
defaultConf := s.conf
|
||||
|
||||
@@ -71,13 +73,14 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(w.Body.Reset)
|
||||
|
||||
s.conf = tc.conf()
|
||||
s.handleGetConfig(w, nil)
|
||||
assert.Equal(t, tc.want, w.Body.String())
|
||||
|
||||
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tc.want, w.Body.String())
|
||||
})
|
||||
w.Body.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,9 +194,13 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
s.conf = defaultConf
|
||||
})
|
||||
|
||||
rBody := ioutil.NopCloser(strings.NewReader(tc.req))
|
||||
r, err := http.NewRequest(http.MethodPost, "http://example.com", rBody)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
|
||||
s.handleSetConfig(w, r)
|
||||
assert.Equal(t, tc.wantSet, w.Body.String())
|
||||
@@ -203,6 +210,5 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
|
||||
assert.Equal(t, tc.wantGet, w.Body.String())
|
||||
w.Body.Reset()
|
||||
})
|
||||
s.conf = defaultConf
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testQueryLog is a simple querylog.QueryLog implementation for tests.
|
||||
@@ -156,7 +157,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
}}
|
||||
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user