Pull request: 2639 use testify require vol.4
Merge in DNS/adguard-home from 2639-testify-require-4 to master Closes #2639. Squashed commit of the following: commit 0bb9125f42ab6d2511c1b8e481112aa5edd581d9 Merge: 0e9e9ed12c9992e0Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Mar 11 15:47:21 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit 0e9e9ed16ae13ce648b5e1da6ffd123df911c2d7 Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:43:15 2021 +0300 home: rm deletion error check commit 6bfbbcd2b7f9197a06856f9e6b959c2e1c4b8353 Merge: c8ebe5418811c881Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:30:07 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit c8ebe54142bba780226f76ddb72e33664ed28f30 Author: Eugene Burkov <e.burkov@adguard.com> Date: Wed Mar 10 12:28:43 2021 +0300 home: imp tests commit f0e1db456f02df5f5f56ca93e7bd40a48475b38c Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Mar 5 14:06:41 2021 +0300 dnsforward: imp tests commit 4528246105ed06471a8778abbe8e5c30fc5483d5 Merge: 54b08d9c90ebc4d8Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Mar 4 18:17:52 2021 +0300 Merge branch 'master' into 2639-testify-require-4 commit 54b08d9c980b8d69d019a1a1b3931aa048275691 Author: Eugene Burkov <e.burkov@adguard.com> Date: Thu Feb 11 13:17:05 2021 +0300 dnsfilter: imp tests
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -58,7 +60,7 @@ func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
|
||||
}
|
||||
|
||||
@@ -66,20 +68,20 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, qtype, &setts)
|
||||
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
|
||||
if assert.NotEmpty(t, res.Rules, "Expected result to have rules") {
|
||||
r := res.Rules[0]
|
||||
assert.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
|
||||
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
|
||||
}
|
||||
|
||||
require.NotEmpty(t, res.Rules, "Expected result to have rules")
|
||||
r := res.Rules[0]
|
||||
require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
|
||||
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||
assert.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
|
||||
assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname)
|
||||
}
|
||||
|
||||
@@ -110,40 +112,40 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||
|
||||
// Empty IPv6.
|
||||
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, res.IsFiltered)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
}
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
|
||||
// IPv6 match.
|
||||
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
|
||||
|
||||
// Empty IPv4.
|
||||
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, res.IsFiltered)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
}
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
|
||||
// Two IPv4, the first one returned.
|
||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, res.IsFiltered)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
|
||||
}
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
|
||||
|
||||
// One IPv6 address.
|
||||
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, res.IsFiltered)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
|
||||
}
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
|
||||
}
|
||||
|
||||
// Safe Browsing.
|
||||
@@ -155,14 +157,14 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
matching := "wmconvirus.narod.ru"
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
d.checkMatch(t, matching)
|
||||
|
||||
assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
|
||||
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
|
||||
|
||||
d.checkMatch(t, "test."+matching)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
@@ -178,7 +180,7 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
matching := "wmconvirus.narod.ru"
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
@@ -203,7 +205,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
assert.True(t, ok, "Expected safesearch to find result for www.google.com")
|
||||
require.True(t, ok, "Expected safesearch to find result for www.google.com")
|
||||
assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||
}
|
||||
|
||||
@@ -211,6 +213,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
|
||||
// Check host for each domain.
|
||||
for _, host := range []string{
|
||||
"yAndeX.ru",
|
||||
@@ -220,22 +224,27 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
"yandex.kz",
|
||||
"www.yandex.com",
|
||||
} {
|
||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, res.IsFiltered)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56))
|
||||
}
|
||||
t.Run(strings.ToLower(host), func(t *testing.T) {
|
||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(&Config{
|
||||
SafeSearchEnabled: true,
|
||||
CustomResolver: &aghtest.TestResolver{},
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
// Check host for each domain.
|
||||
for _, host := range []string{
|
||||
"www.google.com",
|
||||
@@ -248,11 +257,10 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
} {
|
||||
t.Run(host, func(t *testing.T) {
|
||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, res.IsFiltered)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.NotEqual(t, res.Rules[0].IP.String(), "0.0.0.0")
|
||||
}
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, ip, res.Rules[0].IP)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -260,31 +268,31 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
domain := "yandex.ru"
|
||||
const domain = "yandex.ru"
|
||||
|
||||
// Check host with disabled safesearch.
|
||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Empty(t, res.Rules)
|
||||
require.Empty(t, res.Rules)
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
|
||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||
assert.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
|
||||
require.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
|
||||
|
||||
// For yandex we already know valid IP.
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.Equal(t, res.Rules[0].IP, net.IPv4(213, 180, 193, 56))
|
||||
}
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||
assert.True(t, isFound)
|
||||
if assert.Len(t, cachedValue.Rules, 1) {
|
||||
assert.Equal(t, cachedValue.Rules[0].IP, net.IPv4(213, 180, 193, 56))
|
||||
}
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
@@ -294,11 +302,11 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
domain := "www.google.ru"
|
||||
const domain = "www.google.ru"
|
||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Empty(t, res.Rules)
|
||||
require.Empty(t, res.Rules)
|
||||
|
||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
@@ -306,12 +314,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
|
||||
// Lookup for safesearch domain.
|
||||
safeDomain, ok := d.SafeSearchDomain(domain)
|
||||
assert.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
|
||||
require.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
|
||||
|
||||
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to lookup for %s", safeDomain)
|
||||
}
|
||||
require.Nilf(t, err, "Failed to lookup for %s", safeDomain)
|
||||
|
||||
var ip net.IP
|
||||
for _, foundIP := range ips {
|
||||
@@ -323,17 +329,15 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
}
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.True(t, res.Rules[0].IP.Equal(ip))
|
||||
}
|
||||
require.Nil(t, err)
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.True(t, res.Rules[0].IP.Equal(ip))
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||
assert.True(t, isFound)
|
||||
if assert.Len(t, cachedValue.Rules, 1) {
|
||||
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
|
||||
}
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
|
||||
}
|
||||
|
||||
// Parental.
|
||||
@@ -345,24 +349,23 @@ func TestParentalControl(t *testing.T) {
|
||||
|
||||
d := newForTest(&Config{ParentalEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
matching := "pornhub.com"
|
||||
const matching = "pornhub.com"
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.checkMatch(t, matching)
|
||||
assert.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
||||
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
||||
d.checkMatch(t, "www."+matching)
|
||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "api.jquery.com")
|
||||
|
||||
// test cached result
|
||||
// Test cached result.
|
||||
d.parentalServer = "127.0.0.1"
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.parentalServer = defaultParentalServer
|
||||
}
|
||||
|
||||
// Filtering.
|
||||
@@ -651,7 +654,7 @@ func TestMatching(t *testing.T) {
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
||||
assert.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
|
||||
require.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
|
||||
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
||||
assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason)
|
||||
})
|
||||
@@ -674,28 +677,24 @@ func TestWhitelist(t *testing.T) {
|
||||
}}
|
||||
d := newForTest(nil, filters)
|
||||
|
||||
err := d.SetFilters(filters, whiteFilters, false)
|
||||
assert.Nil(t, err)
|
||||
|
||||
require.Nil(t, d.SetFilters(filters, whiteFilters, false))
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
// Matched by white filter.
|
||||
res, err := d.CheckHost("host1", dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Equal(t, res.Reason, NotFilteredAllowList)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
||||
}
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
||||
|
||||
// Not matched by white filter, but matched by block filter.
|
||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, res.IsFiltered)
|
||||
assert.Equal(t, res.Reason, FilteredBlockList)
|
||||
if assert.Len(t, res.Rules, 1) {
|
||||
assert.Equal(t, "||host2^", res.Rules[0].Text)
|
||||
}
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, "||host2^", res.Rules[0].Text)
|
||||
}
|
||||
|
||||
// Client Settings.
|
||||
@@ -797,7 +796,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
})
|
||||
for n := 0; n < b.N; n++ {
|
||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
||||
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||
}
|
||||
}
|
||||
@@ -813,7 +812,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||
assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
||||
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
|
||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||
}
|
||||
})
|
||||
@@ -824,7 +823,7 @@ func BenchmarkSafeSearch(b *testing.B) {
|
||||
b.Cleanup(d.Close)
|
||||
for n := 0; n < b.N; n++ {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
assert.True(b, ok, "Expected safesearch to find result for www.google.com")
|
||||
require.True(b, ok, "Expected safesearch to find result for www.google.com")
|
||||
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||
}
|
||||
}
|
||||
@@ -835,7 +834,7 @@ func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
assert.True(b, ok, "Expected safesearch to find result for www.google.com")
|
||||
require.True(b, ok, "Expected safesearch to find result for www.google.com")
|
||||
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
@@ -55,138 +56,89 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
ipv6p1 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
ipv6p2 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
||||
|
||||
testCasesA := []struct {
|
||||
name string
|
||||
dtyp uint16
|
||||
rcode int
|
||||
want []interface{}
|
||||
}{{
|
||||
name: "a-record",
|
||||
dtyp: dns.TypeA,
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{ipv4p1},
|
||||
}, {
|
||||
name: "aaaa-record",
|
||||
dtyp: dns.TypeAAAA,
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{ipv6p1},
|
||||
}, {
|
||||
name: "txt-record",
|
||||
dtyp: dns.TypeTXT,
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{"hello-world"},
|
||||
}, {
|
||||
name: "refused",
|
||||
rcode: dns.RcodeRefused,
|
||||
}, {
|
||||
name: "a-records",
|
||||
dtyp: dns.TypeA,
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{ipv4p1, ipv4p2},
|
||||
}, {
|
||||
name: "aaaa-records",
|
||||
dtyp: dns.TypeAAAA,
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{ipv6p1, ipv6p2},
|
||||
}, {
|
||||
name: "disable-one",
|
||||
dtyp: dns.TypeA,
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{ipv4p2},
|
||||
}, {
|
||||
name: "disable-cname",
|
||||
dtyp: dns.TypeA,
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{ipv4p1},
|
||||
}}
|
||||
|
||||
for _, tc := range testCasesA {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
host := path.Base(tc.name)
|
||||
|
||||
res, err := f.CheckHostRules(host, tc.dtyp, setts)
|
||||
require.Nil(t, err)
|
||||
|
||||
dnsrr := res.DNSRewriteResult
|
||||
require.NotNil(t, dnsrr)
|
||||
assert.Equal(t, tc.rcode, dnsrr.RCode)
|
||||
|
||||
if tc.rcode == dns.RcodeRefused {
|
||||
return
|
||||
}
|
||||
|
||||
ipVals := dnsrr.Response[tc.dtyp]
|
||||
require.Len(t, ipVals, len(tc.want))
|
||||
for i, val := range tc.want {
|
||||
require.Equal(t, val, ipVals[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("cname", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "new-cname", res.CanonName)
|
||||
})
|
||||
|
||||
t.Run("a-record", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
|
||||
assert.Equal(t, ipv4p1, ipVals[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("aaaa-record", func(t *testing.T) {
|
||||
dtyp := dns.TypeAAAA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
|
||||
assert.Equal(t, ipv6p1, ipVals[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("txt-record", func(t *testing.T) {
|
||||
dtyp := dns.TypeTXT
|
||||
host := path.Base(t.Name())
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
if strVals := dnsrr.Response[dtyp]; assert.Len(t, strVals, 1) {
|
||||
assert.Equal(t, "hello-world", strVals[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("refused", func(t *testing.T) {
|
||||
host := path.Base(t.Name())
|
||||
res, err := f.CheckHostRules(host, dns.TypeA, setts)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeRefused, dnsrr.RCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a-records", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) {
|
||||
assert.Equal(t, ipv4p1, ipVals[0])
|
||||
assert.Equal(t, ipv4p2, ipVals[1])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("aaaa-records", func(t *testing.T) {
|
||||
dtyp := dns.TypeAAAA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 2) {
|
||||
assert.Equal(t, ipv6p1, ipVals[0])
|
||||
assert.Equal(t, ipv6p2, ipVals[1])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disable-one", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
|
||||
assert.Equal(t, ipv4p2, ipVals[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disable-cname", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, res.CanonName)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
if ipVals := dnsrr.Response[dtyp]; assert.Len(t, ipVals, 1) {
|
||||
assert.Equal(t, ipv4p1, ipVals[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disable-cname-many", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "new-cname-2", res.CanonName)
|
||||
assert.Nil(t, res.DNSRewriteResult)
|
||||
})
|
||||
@@ -196,7 +148,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.Empty(t, res.CanonName)
|
||||
assert.Empty(t, res.Rules)
|
||||
})
|
||||
|
||||
@@ -6,215 +6,297 @@ import (
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// CNAME, A, AAAA
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"somecname", "somehost.com", 0, nil},
|
||||
{"somehost.com", "0.0.0.0", 0, nil},
|
||||
|
||||
{"host.com", "1.2.3.4", 0, nil},
|
||||
{"host.com", "1.2.3.5", 0, nil},
|
||||
{"host.com", "1:2:3::4", 0, nil},
|
||||
{"www.host.com", "host.com", 0, nil},
|
||||
}
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
// This one and below are about CNAME, A and AAAA.
|
||||
Domain: "somecname",
|
||||
Answer: "somehost.com",
|
||||
}, {
|
||||
Domain: "somehost.com",
|
||||
Answer: "0.0.0.0",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.5",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1:2:3::4",
|
||||
}, {
|
||||
Domain: "www.host.com",
|
||||
Answer: "host.com",
|
||||
}, {
|
||||
// This one is a wildcard.
|
||||
Domain: "*.host.com",
|
||||
Answer: "1.2.3.5",
|
||||
}, {
|
||||
// This one and below are about wildcard overriding.
|
||||
Domain: "a.host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
// This one is about CNAME and wildcard interacting.
|
||||
Domain: "*.host2.com",
|
||||
Answer: "host.com",
|
||||
}, {
|
||||
// This one and below are about 2 level CNAME.
|
||||
Domain: "b.host.com",
|
||||
Answer: "somecname",
|
||||
}, {
|
||||
// This one and below are about 2 level CNAME and wildcard.
|
||||
Domain: "b.host3.com",
|
||||
Answer: "a.host3.com",
|
||||
}, {
|
||||
Domain: "a.host3.com",
|
||||
Answer: "x.host.com",
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
r := d.processRewrites("host2.com", dns.TypeA)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "host.com", r.CanonName)
|
||||
assert.Len(t, r.IPList, 2)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5}))
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
dtyp uint16
|
||||
wantCName string
|
||||
wantVals []net.IP
|
||||
}{{
|
||||
name: "not_filtered_not_found",
|
||||
host: "hoost.com",
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
dtyp: dns.TypeA,
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
}, {
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
dtyp: dns.TypeAAAA,
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
|
||||
}, {
|
||||
name: "wildcard_match",
|
||||
host: "abc.host.com",
|
||||
dtyp: dns.TypeA,
|
||||
wantVals: []net.IP{{1, 2, 3, 5}},
|
||||
}, {
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
dtyp: dns.TypeA,
|
||||
wantVals: []net.IP{{1, 2, 3, 4}},
|
||||
}, {
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
dtyp: dns.TypeA,
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
}, {
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
dtyp: dns.TypeA,
|
||||
wantCName: "somehost.com",
|
||||
wantVals: []net.IP{{0, 0, 0, 0}},
|
||||
}, {
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
dtyp: dns.TypeA,
|
||||
wantCName: "x.host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 5}},
|
||||
}}
|
||||
|
||||
r = d.processRewrites("www.host.com", dns.TypeAAAA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "host.com", r.CanonName)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
valsNum := len(tc.wantVals)
|
||||
|
||||
// wildcard
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"host.com", "1.2.3.4", 0, nil},
|
||||
{"*.host.com", "1.2.3.5", 0, nil},
|
||||
r := d.processRewrites(tc.host, tc.dtyp)
|
||||
if valsNum == 0 {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, Rewritten, r.Reason)
|
||||
if tc.wantCName != "" {
|
||||
assert.Equal(t, tc.wantCName, r.CanonName)
|
||||
}
|
||||
|
||||
require.Len(t, r.IPList, valsNum)
|
||||
for i, ip := range tc.wantVals {
|
||||
assert.Equal(t, ip, r.IPList[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
|
||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5}))
|
||||
|
||||
r = d.processRewrites("www.host2.com", dns.TypeA)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
// override a wildcard
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"a.host.com", "1.2.3.4", 0, nil},
|
||||
{"*.host.com", "1.2.3.5", 0, nil},
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("a.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
|
||||
// wildcard + CNAME
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"host.com", "1.2.3.4", 0, nil},
|
||||
{"*.host.com", "host.com", 0, nil},
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "host.com", r.CanonName)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
|
||||
// 2 CNAMEs
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"b.host.com", "a.host.com", 0, nil},
|
||||
{"a.host.com", "host.com", 0, nil},
|
||||
{"host.com", "1.2.3.4", 0, nil},
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("b.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "host.com", r.CanonName)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
|
||||
// 2 CNAMEs + wildcard
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"b.host.com", "a.host.com", 0, nil},
|
||||
{"a.host.com", "x.somehost.com", 0, nil},
|
||||
{"*.somehost.com", "1.2.3.4", 0, nil},
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("b.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "x.somehost.com", r.CanonName)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
}
|
||||
|
||||
func TestRewritesLevels(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// exact host, wildcard L2, wildcard L3
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"host.com", "1.1.1.1", 0, nil},
|
||||
{"*.host.com", "2.2.2.2", 0, nil},
|
||||
{"*.sub.host.com", "3.3.3.3", 0, nil},
|
||||
}
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.1.1.1",
|
||||
}, {
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "3.3.3.3",
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
// match exact
|
||||
r := d.processRewrites("host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0]))
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "exact_match",
|
||||
host: "host.com",
|
||||
want: net.IP{1, 1, 1, 1},
|
||||
}, {
|
||||
name: "l2_match",
|
||||
host: "sub.host.com",
|
||||
want: net.IP{2, 2, 2, 2},
|
||||
}, {
|
||||
name: "l3_match",
|
||||
host: "my.sub.host.com",
|
||||
want: net.IP{3, 3, 3, 3},
|
||||
}}
|
||||
|
||||
// match L2
|
||||
r = d.processRewrites("sub.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
||||
|
||||
// match L3
|
||||
r = d.processRewrites("my.sub.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0]))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
require.Len(t, r.IPList, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// wildcard; exception for a sub-domain
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"*.host.com", "2.2.2.2", 0, nil},
|
||||
{"sub.host.com", "sub.host.com", 0, nil},
|
||||
}
|
||||
// Wildcard and exception for a sub-domain.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
Domain: "sub.host.com",
|
||||
Answer: "sub.host.com",
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "*.sub.host.com",
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
// match sub-domain
|
||||
r := d.processRewrites("my.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "match_sub-domain",
|
||||
host: "my.host.com",
|
||||
want: net.IP{2, 2, 2, 2},
|
||||
}, {
|
||||
name: "exception_cname",
|
||||
host: "sub.host.com",
|
||||
}, {
|
||||
name: "exception_wildcard",
|
||||
host: "my.sub.host.com",
|
||||
}}
|
||||
|
||||
// match sub-domain, but handle exception
|
||||
r = d.processRewrites("sub.host.com", dns.TypeA)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, dns.TypeA)
|
||||
if tc.want == nil {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
func TestRewritesExceptionWC(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// wildcard; exception for a sub-wildcard
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"*.host.com", "2.2.2.2", 0, nil},
|
||||
{"*.sub.host.com", "*.sub.host.com", 0, nil},
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
require.Len(t, r.IPList, 1)
|
||||
assert.True(t, tc.want.Equal(r.IPList[0]))
|
||||
})
|
||||
}
|
||||
d.prepareRewrites()
|
||||
|
||||
// match sub-domain
|
||||
r := d.processRewrites("my.host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
||||
|
||||
// match sub-domain, but handle exception
|
||||
r = d.processRewrites("my.sub.host.com", dns.TypeA)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
}
|
||||
|
||||
func TestRewritesExceptionIP(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// exception for AAAA record
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"host.com", "1.2.3.4", 0, nil},
|
||||
{"host.com", "AAAA", 0, nil},
|
||||
{"host2.com", "::1", 0, nil},
|
||||
{"host2.com", "A", 0, nil},
|
||||
{"host3.com", "A", 0, nil},
|
||||
}
|
||||
// Exception for AAAA record.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "AAAA",
|
||||
}, {
|
||||
Domain: "host2.com",
|
||||
Answer: "::1",
|
||||
}, {
|
||||
Domain: "host2.com",
|
||||
Answer: "A",
|
||||
}, {
|
||||
Domain: "host3.com",
|
||||
Answer: "A",
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
// match domain
|
||||
r := d.processRewrites("host.com", dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0]))
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
dtyp uint16
|
||||
want []net.IP
|
||||
}{{
|
||||
name: "match_A",
|
||||
host: "host.com",
|
||||
dtyp: dns.TypeA,
|
||||
want: []net.IP{{1, 2, 3, 4}},
|
||||
}, {
|
||||
name: "exception_AAAA_host.com",
|
||||
host: "host.com",
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "exception_A_host2.com",
|
||||
host: "host2.com",
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "match_AAAA_host2.com",
|
||||
host: "host2.com",
|
||||
dtyp: dns.TypeAAAA,
|
||||
want: []net.IP{net.ParseIP("::1")},
|
||||
}, {
|
||||
name: "exception_A_host3.com",
|
||||
host: "host3.com",
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "match_AAAA_host3.com",
|
||||
host: "host3.com",
|
||||
dtyp: dns.TypeAAAA,
|
||||
want: []net.IP{},
|
||||
}}
|
||||
|
||||
// match exception
|
||||
r = d.processRewrites("host.com", dns.TypeAAAA)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name+"_"+tc.host, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, tc.dtyp)
|
||||
if tc.want == nil {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
// match exception
|
||||
r = d.processRewrites("host2.com", dns.TypeA)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
return
|
||||
}
|
||||
|
||||
// match domain
|
||||
r = d.processRewrites("host2.com", dns.TypeAAAA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.Equal(t, "::1", r.IPList[0].String())
|
||||
|
||||
// match exception
|
||||
r = d.processRewrites("host3.com", dns.TypeA)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
// match domain
|
||||
r = d.processRewrites("host3.com", dns.TypeAAAA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Empty(t, r.IPList)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
require.Len(t, r.IPList, len(tc.want))
|
||||
for _, ip := range tc.want {
|
||||
assert.True(t, ip.Equal(r.IPList[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeBrowsingHash(t *testing.T) {
|
||||
@@ -155,25 +156,25 @@ func TestSBPC(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Prepare the upstream.
|
||||
ups := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: tc.block,
|
||||
}
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
// Prepare the upstream.
|
||||
ups := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: tc.block,
|
||||
}
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Firstly, check the request blocking.
|
||||
hits := 0
|
||||
res, err := tc.testFunc(hostname)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
if tc.block {
|
||||
assert.True(t, res.IsFiltered)
|
||||
assert.Len(t, res.Rules, 1)
|
||||
require.Len(t, res.Rules, 1)
|
||||
hits++
|
||||
} else {
|
||||
assert.False(t, res.IsFiltered)
|
||||
require.False(t, res.IsFiltered)
|
||||
}
|
||||
|
||||
// Check the cache state, check the response is now cached.
|
||||
@@ -185,12 +186,12 @@ func TestSBPC(t *testing.T) {
|
||||
|
||||
// Now make the same request to check the cache was used.
|
||||
res, err = tc.testFunc(hostname)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
if tc.block {
|
||||
assert.True(t, res.IsFiltered)
|
||||
assert.Len(t, res.Rules, 1)
|
||||
require.Len(t, res.Rules, 1)
|
||||
} else {
|
||||
assert.False(t, res.IsFiltered)
|
||||
require.False(t, res.IsFiltered)
|
||||
}
|
||||
|
||||
// Check the cache state, it should've been used.
|
||||
@@ -199,8 +200,8 @@ func TestSBPC(t *testing.T) {
|
||||
|
||||
// Check that there were no additional requests.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
|
||||
purgeCaches()
|
||||
})
|
||||
|
||||
purgeCaches()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user