Pull request: all: sup many ips in host rules

Closes #1381.

Squashed commit of the following:

commit 44965f74d8becb27173cafc533e56e1cde484b59
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Jun 17 13:54:44 2021 +0300

    dnsforward: imp code, docs

commit 23736cf9b407b668faade19b61739536caafff03
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Jun 16 20:33:59 2021 +0300

    dnsforward: rm todo

commit ff086756160d72f3cdbe662862dd1fb447ac5ff3
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Jun 16 20:32:01 2021 +0300

    all: sup many ips in host rules
This commit is contained in:
Ainar Garipov
2021-06-17 14:09:16 +03:00
parent 84e71e912e
commit 7547d3a422
7 changed files with 273 additions and 158 deletions

View File

@@ -344,13 +344,13 @@ var gctx dnsFilterContext
// ResultRule contains information about applied rules.
type ResultRule struct {
// FilterListID is the ID of the rule's filter list.
FilterListID int64 `json:",omitempty"`
// Text is the text of the rule.
Text string `json:",omitempty"`
// IP is the host IP. It is nil unless the rule uses the
// /etc/hosts syntax or the reason is FilteredSafeSearch.
IP net.IP `json:",omitempty"`
// FilterListID is the ID of the rule's filter list.
FilterListID int64 `json:",omitempty"`
}
// Result contains the result of a request check.
@@ -657,26 +657,43 @@ func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
return nil
}
// matchHostProcessAllowList processes the allowlist logic of host
// matching.
func (d *DNSFilter) matchHostProcessAllowList(host string, dnsres urlfilter.DNSResult) (res Result, err error) {
var rule rules.Rule
if dnsres.NetworkRule != nil {
rule = dnsres.NetworkRule
} else if len(dnsres.HostRulesV4) > 0 {
rule = dnsres.HostRulesV4[0]
} else if len(dnsres.HostRulesV6) > 0 {
rule = dnsres.HostRulesV6[0]
// hostRules is a helper that converts a slice of host rules into a slice of the
// rules.Rule interface values.
func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) {
if netRules == nil {
return nil
}
if rule == nil {
res = make([]rules.Rule, len(netRules))
for i, nr := range netRules {
res[i] = nr
}
return res
}
// matchHostProcessAllowList processes the allowlist logic of host
// matching.
func (d *DNSFilter) matchHostProcessAllowList(
host string,
dnsres urlfilter.DNSResult,
) (res Result, err error) {
var matchedRules []rules.Rule
if dnsres.NetworkRule != nil {
matchedRules = []rules.Rule{dnsres.NetworkRule}
} else if len(dnsres.HostRulesV4) > 0 {
matchedRules = hostRulesToRules(dnsres.HostRulesV4)
} else if len(dnsres.HostRulesV6) > 0 {
matchedRules = hostRulesToRules(dnsres.HostRulesV6)
}
if len(matchedRules) == 0 {
return Result{}, fmt.Errorf("invalid dns result: rules are empty")
}
log.Debug("Filtering: found allowlist rule for host %q: %q list_id: %d",
host, rule.Text(), rule.GetFilterListID())
log.Debug("filtering: allowlist rules for host %q: %+v", host, matchedRules)
return makeResult(rule, NotFilteredAllowList), nil
return makeResult(matchedRules, NotFilteredAllowList), nil
}
// matchHostProcessDNSResult processes the matched DNS filtering result.
@@ -690,21 +707,23 @@ func (d *DNSFilter) matchHostProcessDNSResult(
reason = NotFilteredAllowList
}
return makeResult(dnsres.NetworkRule, reason)
return makeResult([]rules.Rule{dnsres.NetworkRule}, reason)
}
if qtype == dns.TypeA && dnsres.HostRulesV4 != nil {
rule := dnsres.HostRulesV4[0]
res = makeResult(rule, FilteredBlockList)
res.Rules[0].IP = rule.IP.To4()
res = makeResult(hostRulesToRules(dnsres.HostRulesV4), FilteredBlockList)
for i, hr := range dnsres.HostRulesV4 {
res.Rules[i].IP = hr.IP.To4()
}
return res
}
if qtype == dns.TypeAAAA && dnsres.HostRulesV6 != nil {
rule := dnsres.HostRulesV6[0]
res = makeResult(rule, FilteredBlockList)
res.Rules[0].IP = rule.IP.To16()
res = makeResult(hostRulesToRules(dnsres.HostRulesV6), FilteredBlockList)
for i, hr := range dnsres.HostRulesV6 {
res.Rules[i].IP = hr.IP.To16()
}
return res
}
@@ -712,17 +731,14 @@ func (d *DNSFilter) matchHostProcessDNSResult(
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
// Question type doesn't match the host rules. Return the first
// matched host rule, but without an IP address.
var rule rules.Rule
var matchedRules []rules.Rule
if dnsres.HostRulesV4 != nil {
rule = dnsres.HostRulesV4[0]
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
} else if dnsres.HostRulesV6 != nil {
rule = dnsres.HostRulesV6[0]
matchedRules = []rules.Rule{dnsres.HostRulesV6[0]}
}
res = makeResult(rule, FilteredBlockList)
res.Rules[0].IP = net.IP{}
return res
return makeResult(matchedRules, FilteredBlockList)
}
return Result{}
@@ -780,8 +796,7 @@ func (d *DNSFilter) matchHost(
}
res = d.matchHostProcessDNSResult(qtype, dnsres)
if len(res.Rules) > 0 {
r := res.Rules[0]
for _, r := range res.Rules {
log.Debug(
"filtering: found rule %q for host %q, filter list id: %d",
r.Text,
@@ -794,20 +809,20 @@ func (d *DNSFilter) matchHost(
}
// makeResult returns a properly constructed Result.
func makeResult(rule rules.Rule, reason Reason) Result {
res := Result{
Reason: reason,
Rules: []*ResultRule{{
FilterListID: int64(rule.GetFilterListID()),
Text: rule.Text(),
}},
func makeResult(matchedRules []rules.Rule, reason Reason) (res Result) {
resRules := make([]*ResultRule, len(matchedRules))
for i, mr := range matchedRules {
resRules[i] = &ResultRule{
FilterListID: int64(mr.GetFilterListID()),
Text: mr.Text(),
}
}
if reason == FilteredBlockList {
res.IsFiltered = true
return Result{
IsFiltered: reason == FilteredBlockList,
Reason: reason,
Rules: resRules,
}
return res
}
// InitModule manually initializes blocked services map.

View File

@@ -60,29 +60,33 @@ func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
require.NoErrorf(t, err, "host %q", hostname)
assert.Truef(t, res.IsFiltered, "host %q", hostname)
}
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) {
t.Helper()
res, err := d.CheckHost(hostname, qtype, &setts)
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Truef(t, res.IsFiltered, "Expected hostname %s to match", hostname)
require.NoErrorf(t, err, "host %q", hostname, err)
require.NotEmpty(t, res.Rules, "host %q", hostname)
assert.Truef(t, res.IsFiltered, "host %q", hostname)
require.NotEmpty(t, res.Rules, "Expected result to have rules")
r := res.Rules[0]
require.NotNilf(t, r.IP, "Expected ip %s to match, actual: %v", ip, r.IP)
assert.Equalf(t, ip, r.IP.String(), "Expected ip %s to match, actual: %v", ip, r.IP)
assert.Equalf(t, ip, r.IP.String(), "host %q", hostname)
}
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
require.Nilf(t, err, "Error while matching host %s: %s", hostname, err)
assert.Falsef(t, res.IsFiltered, "Expected hostname %s to not match", hostname)
require.NoErrorf(t, err, "host %q", hostname)
assert.Falsef(t, res.IsFiltered, "host %q", hostname)
}
func TestEtcHostsMatching(t *testing.T) {
@@ -112,10 +116,12 @@ func TestEtcHostsMatching(t *testing.T) {
// Empty IPv6.
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
assert.Empty(t, res.Rules[0].IP)
@@ -124,27 +130,34 @@ func TestEtcHostsMatching(t *testing.T) {
// Empty IPv4.
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
assert.Empty(t, res.Rules[0].IP)
// Two IPv4, the first one returned.
// Two IPv4, both must be returned.
res, err = d.CheckHost("host2", dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
require.Len(t, res.Rules, 2)
assert.Equal(t, res.Rules[0].IP, net.IP{0, 0, 0, 1})
assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2})
// One IPv6 address.
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, net.IPv6loopback)
}
@@ -205,8 +218,9 @@ func TestSafeSearch(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
val, ok := d.SafeSearchDomain("www.google.com")
require.True(t, ok, "Expected safesearch to find result for www.google.com")
assert.Equal(t, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
require.True(t, ok)
assert.Equal(t, "forcesafesearch.google.com", val)
}
func TestCheckHostSafeSearchYandex(t *testing.T) {
@@ -226,10 +240,12 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
} {
t.Run(strings.ToLower(host), func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
})
}
@@ -257,9 +273,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
} {
t.Run(host, func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, ip, res.Rules[0].IP)
})
}
@@ -272,8 +291,10 @@ func TestSafeSearchCacheYandex(t *testing.T) {
// Check host with disabled safesearch.
res, err := d.CheckHost(domain, dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
require.Empty(t, res.Rules)
yandexIP := net.IPv4(213, 180, 193, 56)
@@ -282,7 +303,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
t.Cleanup(d.Close)
res, err = d.CheckHost(domain, dns.TypeA, &setts)
require.Nilf(t, err, "CheckHost for safesearh domain %s failed cause %s", domain, err)
require.NoError(t, err)
// For yandex we already know valid IP.
require.Len(t, res.Rules, 1)
@@ -292,6 +313,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
@@ -304,8 +326,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
res, err := d.CheckHost(domain, dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
require.Empty(t, res.Rules)
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
@@ -314,10 +338,10 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
// Lookup for safesearch domain.
safeDomain, ok := d.SafeSearchDomain(domain)
require.Truef(t, ok, "Failed to get safesearch domain for %s", domain)
require.True(t, ok)
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
require.Nilf(t, err, "Failed to lookup for %s", safeDomain)
require.NoError(t, err)
var ip net.IP
for _, foundIP := range ips {
@@ -329,14 +353,16 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
}
res, err = d.CheckHost(domain, dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(ip))
// Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
}
@@ -357,6 +383,7 @@ func TestParentalControl(t *testing.T) {
d.checkMatch(t, matching)
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
d.checkMatch(t, "www."+matching)
d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru")
@@ -654,7 +681,8 @@ func TestMatching(t *testing.T) {
t.Cleanup(d.Close)
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
require.Nilf(t, err, "Error while matching host %s: %s", tc.host, err)
require.NoError(t, err)
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
assert.Equalf(t, tc.wantReason, res.Reason, "Hostname %s has wrong reason (%v must be %v)", tc.host, res.Reason, tc.wantReason)
})
@@ -677,23 +705,31 @@ func TestWhitelist(t *testing.T) {
}}
d := newForTest(nil, filters)
require.Nil(t, d.SetFilters(filters, whiteFilters, false))
err := d.SetFilters(filters, whiteFilters, false)
require.NoError(t, err)
t.Cleanup(d.Close)
// Matched by white filter.
res, err := d.CheckHost("host1", dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Equal(t, res.Reason, NotFilteredAllowList)
require.Len(t, res.Rules, 1)
assert.Equal(t, "||host1^", res.Rules[0].Text)
// Not matched by white filter, but matched by block filter.
res, err = d.CheckHost("host2", dns.TypeA, &setts)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
assert.Equal(t, res.Reason, FilteredBlockList)
require.Len(t, res.Rules, 1)
assert.Equal(t, "||host2^", res.Rules[0].Text)
}
@@ -796,7 +832,8 @@ func BenchmarkSafeBrowsing(b *testing.B) {
})
for n := 0; n < b.N; n++ {
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
require.NoError(b, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
}
}
@@ -812,7 +849,8 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
require.Nilf(b, err, "Error while matching host %s: %s", blocked, err)
require.NoError(b, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
}
})
@@ -823,7 +861,8 @@ func BenchmarkSafeSearch(b *testing.B) {
b.Cleanup(d.Close)
for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com")
require.True(b, ok, "Expected safesearch to find result for www.google.com")
require.True(b, ok)
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
}
}
@@ -834,7 +873,8 @@ func BenchmarkSafeSearchParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
val, ok := d.SafeSearchDomain("www.google.com")
require.True(b, ok, "Expected safesearch to find result for www.google.com")
require.True(b, ok)
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
}
})