Pull request: 2639 use testify require vol.4

Merge in DNS/adguard-home from 2639-testify-require-4 to master

Closes #2639.

Squashed commit of the following:

commit 0bb9125f42ab6d2511c1b8e481112aa5edd581d9
Merge: 0e9e9ed1 2c9992e0
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Mar 11 15:47:21 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit 0e9e9ed16ae13ce648b5e1da6ffd123df911c2d7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:43:15 2021 +0300

    home: rm deletion error check

commit 6bfbbcd2b7f9197a06856f9e6b959c2e1c4b8353
Merge: c8ebe541 8811c881
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:30:07 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit c8ebe54142bba780226f76ddb72e33664ed28f30
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:28:43 2021 +0300

    home: imp tests

commit f0e1db456f02df5f5f56ca93e7bd40a48475b38c
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 5 14:06:41 2021 +0300

    dnsforward: imp tests

commit 4528246105ed06471a8778abbe8e5c30fc5483d5
Merge: 54b08d9c 90ebc4d8
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Mar 4 18:17:52 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit 54b08d9c980b8d69d019a1a1b3931aa048275691
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Feb 11 13:17:05 2021 +0300

    dnsfilter: imp tests
This commit is contained in:
Eugene Burkov
2021-03-11 17:32:58 +03:00
parent 2c9992e0cc
commit dfdbfee4fd
19 changed files with 1375 additions and 1267 deletions

View File

@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
@@ -13,6 +14,7 @@ import (
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
@@ -58,7 +60,7 @@ func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
}
@@ -66,20 +68,20 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
t.Helper()
res, err := d.CheckHost(hostname, qtype, &setts)
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
if assert.NotEmpty(t, res.Rules, "Expected result to have rules") {
r := res.Rules[0]
assert.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
}
require.NotEmpty(t, res.Rules, "Expected result to have rules")
r := res.Rules[0]
require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
}
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname)
}
@@ -110,40 +112,40 @@ func TestEtcHostsMatching(t *testing.T) {
// Empty IPv6.
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
assert.Empty(t, res.Rules[0].IP)
}
require.Len(t, res.Rules, 1)
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
assert.Empty(t, res.Rules[0].IP)
// IPv6 match.
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
// Empty IPv4.
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
assert.Empty(t, res.Rules[0].IP)
}
require.Len(t, res.Rules, 1)
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
assert.Empty(t, res.Rules[0].IP)
// Two IPv4, the first one returned.
res, err = d.CheckHost("host2", dns.TypeA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
}
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
// One IPv6 address.
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
}
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
}
// Safe Browsing.
@@ -155,14 +157,14 @@ func TestSafeBrowsing(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
matching := "wmconvirus.narod.ru"
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching)
assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
d.checkMatch(t, "test."+matching)
d.checkMatchEmpty(t, "yandex.ru")
@@ -178,7 +180,7 @@ func TestSafeBrowsing(t *testing.T) {
func TestParallelSB(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
matching := "wmconvirus.narod.ru"
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
@@ -203,7 +205,7 @@ func TestSafeSearch(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
val, ok := d.SafeSearchDomain("www.google.com")
assert.True(t, ok, "Expected safesearch to find result for www.google.com")
require.True(t, ok, "Expected safesearch to find result for www.google.com")
assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
}
@@ -211,6 +213,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
yandexIP := net.IPv4(213, 180, 193, 56)
// Check host for each domain.
for _, host := range []string{
"yAndeX.ru",
@@ -220,22 +224,27 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
"yandex.kz",
"www.yandex.com",
} {
res, err := d.CheckHost(host, dns.TypeA, &setts)
assert.Nil(t, err)
assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56))
}
t.Run(strings.ToLower(host), func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts)
require.Nil(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
})
}
}
func TestCheckHostSafeSearchGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d := newForTest(&Config{
SafeSearchEnabled: true,
CustomResolver: &aghtest.TestResolver{},
CustomResolver: resolver,
}, nil)
t.Cleanup(d.Close)
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
// Check host for each domain.
for _, host := range []string{
"www.google.com",
@@ -248,11 +257,10 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
} {
t.Run(host, func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.NotEqual(t, res.Rules[0].IP.String(), "0.0.0.0")
}
require.Len(t, res.Rules, 1)
assert.Equal(t, ip, res.Rules[0].IP)
})
}
}
@@ -260,31 +268,31 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
func TestSafeSearchCacheYandex(t *testing.T) {
d := newForTest(nil, nil)
t.Cleanup(d.Close)
domain := "yandex.ru"
const domain = "yandex.ru"
// Check host with disabled safesearch.
res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
require.Empty(t, res.Rules)
yandexIP := net.IPv4(213, 180, 193, 56)
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
res, err = d.CheckHost(domain, dns.TypeA, &setts)
assert.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
require.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
// For yandex we already know valid IP.
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56))
}
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
assert.True(t, isFound)
if assert.Len(t, cachedValue.Rules, 1) {
assert.Equal(t, cachedValue.Rules[0].IP, net.IPv4(213, 180, 193, 56))
}
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
@@ -294,11 +302,11 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
}, nil)
t.Cleanup(d.Close)
domain := "www.google.ru"
const domain = "www.google.ru"
res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
require.Empty(t, res.Rules)
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
@@ -306,12 +314,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
// Lookup for safesearch domain.
safeDomain, ok := d.SafeSearchDomain(domain)
assert.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
require.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
if err != nil {
t.Fatalf("Failed to lookup for %s", safeDomain)
}
require.Nilf(t, err, "Failed to lookup for %s", safeDomain)
var ip net.IP
for _, foundIP := range ips {
@@ -323,17 +329,15 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
}
res, err = d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err)
if assert.Len(t, res.Rules, 1) {
assert.True(t, res.Rules[0].IP.Equal(ip))
}
require.Nil(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(ip))
// Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
assert.True(t, isFound)
if assert.Len(t, cachedValue.Rules, 1) {
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
}
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
}
// Parental.
@@ -345,24 +349,23 @@ func TestParentalControl(t *testing.T) {
d := newForTest(&Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close)
matching := "pornhub.com"
const matching = "pornhub.com"
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching)
assert.Contains(t, logOutput.String(), "Parental lookup for "+matching)
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
d.checkMatch(t, "www."+matching)
d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "api.jquery.com")
// test cached result
// Test cached result.
d.parentalServer = "127.0.0.1"
d.checkMatch(t, matching)
d.checkMatchEmpty(t, "yandex.ru")
d.parentalServer = defaultParentalServer
}
// Filtering.
@@ -651,7 +654,7 @@ func TestMatching(t *testing.T) {
t.Cleanup(d.Close)
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
assert.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
require.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason)
})
@@ -674,28 +677,24 @@ func TestWhitelist(t *testing.T) {
}}
d := newForTest(nil, filters)
err := d.SetFilters(filters, whiteFilters, false)
assert.Nil(t, err)
require.Nil(t, d.SetFilters(filters, whiteFilters, false))
t.Cleanup(d.Close)
// Matched by white filter.
res, err := d.CheckHost("host1", dns.TypeA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.False(t, res.IsFiltered)
assert.Equal(t, res.Reason, NotFilteredAllowList)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "||host1^", res.Rules[0].Text)
}
require.Len(t, res.Rules, 1)
assert.Equal(t, "||host1^", res.Rules[0].Text)
// Not matched by white filter, but matched by block filter.
res, err = d.CheckHost("host2", dns.TypeA, &setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.True(t, res.IsFiltered)
assert.Equal(t, res.Reason, FilteredBlockList)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "||host2^", res.Rules[0].Text)
}
require.Len(t, res.Rules, 1)
assert.Equal(t, "||host2^", res.Rules[0].Text)
}
// Client Settings.
@@ -797,7 +796,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
})
for n := 0; n < b.N; n++ {
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
}
}
@@ -813,7 +812,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
}
})
@@ -824,7 +823,7 @@ func BenchmarkSafeSearch(b *testing.B) {
b.Cleanup(d.Close)
for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com")
assert.True(b, ok, "Expected safesearch to find result for www.google.com")
require.True(b, ok, "Expected safesearch to find result for www.google.com")
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
}
}
@@ -835,7 +834,7 @@ func BenchmarkSafeSearchParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
val, ok := d.SafeSearchDomain("www.google.com")
assert.True(b, ok, "Expected safesearch to find result for www.google.com")
require.True(b, ok, "Expected safesearch to find result for www.google.com")
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
}
})

View File

@@ -7,6 +7,7 @@ import (
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
@@ -55,138 +56,89 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
ipv6p1 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
ipv6p2 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
testCasesA := []struct {
name string
dtyp uint16
rcode int
want []interface{}
}{{
name: "a-record",
dtyp: dns.TypeA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p1},
}, {
name: "aaaa-record",
dtyp: dns.TypeAAAA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv6p1},
}, {
name: "txt-record",
dtyp: dns.TypeTXT,
rcode: dns.RcodeSuccess,
want: []interface{}{"hello-world"},
}, {
name: "refused",
rcode: dns.RcodeRefused,
}, {
name: "a-records",
dtyp: dns.TypeA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p1, ipv4p2},
}, {
name: "aaaa-records",
dtyp: dns.TypeAAAA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv6p1, ipv6p2},
}, {
name: "disable-one",
dtyp: dns.TypeA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p2},
}, {
name: "disable-cname",
dtyp: dns.TypeA,
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p1},
}}
for _, tc := range testCasesA {
t.Run(tc.name, func(t *testing.T) {
host := path.Base(tc.name)
res, err := f.CheckHostRules(host, tc.dtyp, setts)
require.Nil(t, err)
dnsrr := res.DNSRewriteResult
require.NotNil(t, dnsrr)
assert.Equal(t, tc.rcode, dnsrr.RCode)
if tc.rcode == dns.RcodeRefused {
return
}
ipVals := dnsrr.Response[tc.dtyp]
require.Len(t, ipVals, len(tc.want))
for i, val := range tc.want {
require.Equal(t, val, ipVals[i])
}
})
}
t.Run("cname", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.Equal(t, "new-cname", res.CanonName)
})
t.Run("a-record", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
assert.Equal(t, ipv4p1, ipVals[0])
}
}
})
t.Run("aaaa-record", func(t *testing.T) {
dtyp := dns.TypeAAAA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
assert.Equal(t, ipv6p1, ipVals[0])
}
}
})
t.Run("txt-record", func(t *testing.T) {
dtyp := dns.TypeTXT
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if strVals := dnsrr.Response[dtyp]; assert.Len(t, strVals, 1) {
assert.Equal(t, "hello-world", strVals[0])
}
}
})
t.Run("refused", func(t *testing.T) {
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dns.TypeA, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeRefused, dnsrr.RCode)
}
})
t.Run("a-records", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) {
assert.Equal(t, ipv4p1, ipVals[0])
assert.Equal(t, ipv4p2, ipVals[1])
}
}
})
t.Run("aaaa-records", func(t *testing.T) {
dtyp := dns.TypeAAAA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) {
assert.Equal(t, ipv6p1, ipVals[0])
assert.Equal(t, ipv6p2, ipVals[1])
}
}
})
t.Run("disable-one", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
assert.Equal(t, ipv4p2, ipVals[0])
}
}
})
t.Run("disable-cname", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
assert.Empty(t, res.CanonName)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
assert.Equal(t, ipv4p1, ipVals[0])
}
}
})
t.Run("disable-cname-many", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.Equal(t, "new-cname-2", res.CanonName)
assert.Nil(t, res.DNSRewriteResult)
})
@@ -196,7 +148,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
require.Nil(t, err)
assert.Empty(t, res.CanonName)
assert.Empty(t, res.Rules)
})

View File

@@ -6,215 +6,297 @@ import (
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(e.burkov): All the tests in this file may and should me merged together.
func TestRewrites(t *testing.T) {
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// CNAME, A, AAAA
d.Rewrites = []RewriteEntry{
{"somecname", "somehost.com", 0, nil},
{"somehost.com", "0.0.0.0", 0, nil},
{"host.com", "1.2.3.4", 0, nil},
{"host.com", "1.2.3.5", 0, nil},
{"host.com", "1:2:3::4", 0, nil},
{"www.host.com", "host.com", 0, nil},
}
d.Rewrites = []RewriteEntry{{
// This one and below are about CNAME, A and AAAA.
Domain: "somecname",
Answer: "somehost.com",
}, {
Domain: "somehost.com",
Answer: "0.0.0.0",
}, {
Domain: "host.com",
Answer: "1.2.3.4",
}, {
Domain: "host.com",
Answer: "1.2.3.5",
}, {
Domain: "host.com",
Answer: "1:2:3::4",
}, {
Domain: "www.host.com",
Answer: "host.com",
}, {
// This one is a wildcard.
Domain: "*.host.com",
Answer: "1.2.3.5",
}, {
// This one and below are about wildcard overriding.
Domain: "a.host.com",
Answer: "1.2.3.4",
}, {
// This one is about CNAME and wildcard interacting.
Domain: "*.host2.com",
Answer: "host.com",
}, {
// This one and below are about 2 level CNAME.
Domain: "b.host.com",
Answer: "somecname",
}, {
// This one and below are about 2 level CNAME and wildcard.
Domain: "b.host3.com",
Answer: "a.host3.com",
}, {
Domain: "a.host3.com",
Answer: "x.host.com",
}}
d.prepareRewrites()
r := d.processRewrites("host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.Len(t, r.IPList, 2)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5}))
testCases := []struct {
name string
host string
dtyp uint16
wantCName string
wantVals []net.IP
}{{
name: "not_filtered_not_found",
host: "hoost.com",
dtyp: dns.TypeA,
}, {
name: "rewritten_a",
host: "www.host.com",
dtyp: dns.TypeA,
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
}, {
name: "rewritten_aaaa",
host: "www.host.com",
dtyp: dns.TypeAAAA,
wantCName: "host.com",
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
}, {
name: "wildcard_match",
host: "abc.host.com",
dtyp: dns.TypeA,
wantVals: []net.IP{{1, 2, 3, 5}},
}, {
name: "wildcard_override",
host: "a.host.com",
dtyp: dns.TypeA,
wantVals: []net.IP{{1, 2, 3, 4}},
}, {
name: "wildcard_cname_interaction",
host: "www.host2.com",
dtyp: dns.TypeA,
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
}, {
name: "two_cnames",
host: "b.host.com",
dtyp: dns.TypeA,
wantCName: "somehost.com",
wantVals: []net.IP{{0, 0, 0, 0}},
}, {
name: "two_cnames_and_wildcard",
host: "b.host3.com",
dtyp: dns.TypeA,
wantCName: "x.host.com",
wantVals: []net.IP{{1, 2, 3, 5}},
}}
r = d.processRewrites("www.host.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
valsNum := len(tc.wantVals)
// wildcard
d.Rewrites = []RewriteEntry{
{"host.com", "1.2.3.4", 0, nil},
{"*.host.com", "1.2.3.5", 0, nil},
r := d.processRewrites(tc.host, tc.dtyp)
if valsNum == 0 {
assert.Equal(t, NotFilteredNotFound, r.Reason)
return
}
require.Equal(t, Rewritten, r.Reason)
if tc.wantCName != "" {
assert.Equal(t, tc.wantCName, r.CanonName)
}
require.Len(t, r.IPList, valsNum)
for i, ip := range tc.wantVals {
assert.Equal(t, ip, r.IPList[i])
}
})
}
d.prepareRewrites()
r = d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5}))
r = d.processRewrites("www.host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
// override a wildcard
d.Rewrites = []RewriteEntry{
{"a.host.com", "1.2.3.4", 0, nil},
{"*.host.com", "1.2.3.5", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("a.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// wildcard + CNAME
d.Rewrites = []RewriteEntry{
{"host.com", "1.2.3.4", 0, nil},
{"*.host.com", "host.com", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs
d.Rewrites = []RewriteEntry{
{"b.host.com", "a.host.com", 0, nil},
{"a.host.com", "host.com", 0, nil},
{"host.com", "1.2.3.4", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs + wildcard
d.Rewrites = []RewriteEntry{
{"b.host.com", "a.host.com", 0, nil},
{"a.host.com", "x.somehost.com", 0, nil},
{"*.somehost.com", "1.2.3.4", 0, nil},
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "x.somehost.com", r.CanonName)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
}
func TestRewritesLevels(t *testing.T) {
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// exact host, wildcard L2, wildcard L3
d.Rewrites = []RewriteEntry{
{"host.com", "1.1.1.1", 0, nil},
{"*.host.com", "2.2.2.2", 0, nil},
{"*.sub.host.com", "3.3.3.3", 0, nil},
}
// Exact host, wildcard L2, wildcard L3.
d.Rewrites = []RewriteEntry{{
Domain: "host.com",
Answer: "1.1.1.1",
}, {
Domain: "*.host.com",
Answer: "2.2.2.2",
}, {
Domain: "*.sub.host.com",
Answer: "3.3.3.3",
}}
d.prepareRewrites()
// match exact
r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0]))
testCases := []struct {
name string
host string
want net.IP
}{{
name: "exact_match",
host: "host.com",
want: net.IP{1, 1, 1, 1},
}, {
name: "l2_match",
host: "sub.host.com",
want: net.IP{2, 2, 2, 2},
}, {
name: "l3_match",
host: "my.sub.host.com",
want: net.IP{3, 3, 3, 3},
}}
// match L2
r = d.processRewrites("sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match L3
r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0]))
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := d.processRewrites(tc.host, dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
require.Len(t, r.IPList, 1)
})
}
}
func TestRewritesExceptionCNAME(t *testing.T) {
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// wildcard; exception for a sub-domain
d.Rewrites = []RewriteEntry{
{"*.host.com", "2.2.2.2", 0, nil},
{"sub.host.com", "sub.host.com", 0, nil},
}
// Wildcard and exception for a sub-domain.
d.Rewrites = []RewriteEntry{{
Domain: "*.host.com",
Answer: "2.2.2.2",
}, {
Domain: "sub.host.com",
Answer: "sub.host.com",
}, {
Domain: "*.sub.host.com",
Answer: "*.sub.host.com",
}}
d.prepareRewrites()
// match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
testCases := []struct {
name string
host string
want net.IP
}{{
name: "match_sub-domain",
host: "my.host.com",
want: net.IP{2, 2, 2, 2},
}, {
name: "exception_cname",
host: "sub.host.com",
}, {
name: "exception_wildcard",
host: "my.sub.host.com",
}}
// match sub-domain, but handle exception
r = d.processRewrites("sub.host.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := d.processRewrites(tc.host, dns.TypeA)
if tc.want == nil {
assert.Equal(t, NotFilteredNotFound, r.Reason)
func TestRewritesExceptionWC(t *testing.T) {
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// wildcard; exception for a sub-wildcard
d.Rewrites = []RewriteEntry{
{"*.host.com", "2.2.2.2", 0, nil},
{"*.sub.host.com", "*.sub.host.com", 0, nil},
return
}
assert.Equal(t, Rewritten, r.Reason)
require.Len(t, r.IPList, 1)
assert.True(t, tc.want.Equal(r.IPList[0]))
})
}
d.prepareRewrites()
// match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match sub-domain, but handle exception
r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
}
func TestRewritesExceptionIP(t *testing.T) {
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// exception for AAAA record
d.Rewrites = []RewriteEntry{
{"host.com", "1.2.3.4", 0, nil},
{"host.com", "AAAA", 0, nil},
{"host2.com", "::1", 0, nil},
{"host2.com", "A", 0, nil},
{"host3.com", "A", 0, nil},
}
// Exception for AAAA record.
d.Rewrites = []RewriteEntry{{
Domain: "host.com",
Answer: "1.2.3.4",
}, {
Domain: "host.com",
Answer: "AAAA",
}, {
Domain: "host2.com",
Answer: "::1",
}, {
Domain: "host2.com",
Answer: "A",
}, {
Domain: "host3.com",
Answer: "A",
}}
d.prepareRewrites()
// match domain
r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0]))
testCases := []struct {
name string
host string
dtyp uint16
want []net.IP
}{{
name: "match_A",
host: "host.com",
dtyp: dns.TypeA,
want: []net.IP{{1, 2, 3, 4}},
}, {
name: "exception_AAAA_host.com",
host: "host.com",
dtyp: dns.TypeAAAA,
}, {
name: "exception_A_host2.com",
host: "host2.com",
dtyp: dns.TypeA,
}, {
name: "match_AAAA_host2.com",
host: "host2.com",
dtyp: dns.TypeAAAA,
want: []net.IP{net.ParseIP("::1")},
}, {
name: "exception_A_host3.com",
host: "host3.com",
dtyp: dns.TypeA,
}, {
name: "match_AAAA_host3.com",
host: "host3.com",
dtyp: dns.TypeAAAA,
want: []net.IP{},
}}
// match exception
r = d.processRewrites("host.com", dns.TypeAAAA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
for _, tc := range testCases {
t.Run(tc.name+"_"+tc.host, func(t *testing.T) {
r := d.processRewrites(tc.host, tc.dtyp)
if tc.want == nil {
assert.Equal(t, NotFilteredNotFound, r.Reason)
// match exception
r = d.processRewrites("host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
return
}
// match domain
r = d.processRewrites("host2.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.Equal(t, "::1", r.IPList[0].String())
// match exception
r = d.processRewrites("host3.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
// match domain
r = d.processRewrites("host3.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason)
assert.Empty(t, r.IPList)
assert.Equal(t, Rewritten, r.Reason)
require.Len(t, r.IPList, len(tc.want))
for _, ip := range tc.want {
assert.True(t, ip.Equal(r.IPList[0]))
}
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSafeBrowsingHash(t *testing.T) {
@@ -155,25 +156,25 @@ func TestSBPC(t *testing.T) {
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Prepare the upstream.
ups := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: tc.block,
}
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
// Prepare the upstream.
ups := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: tc.block,
}
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
t.Run(tc.name, func(t *testing.T) {
// Firstly, check the request blocking.
hits := 0
res, err := tc.testFunc(hostname)
assert.Nil(t, err)
require.Nil(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1)
require.Len(t, res.Rules, 1)
hits++
} else {
assert.False(t, res.IsFiltered)
require.False(t, res.IsFiltered)
}
// Check the cache state, check the response is now cached.
@@ -185,12 +186,12 @@ func TestSBPC(t *testing.T) {
// Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname)
assert.Nil(t, err)
require.Nil(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1)
require.Len(t, res.Rules, 1)
} else {
assert.False(t, res.IsFiltered)
require.False(t, res.IsFiltered)
}
// Check the cache state, it should've been used.
@@ -199,8 +200,8 @@ func TestSBPC(t *testing.T) {
// Check that there were no additional requests.
assert.Equal(t, 1, ups.RequestsCount())
purgeCaches()
})
purgeCaches()
}
}