all: sync with master

This commit is contained in:
Ainar Garipov
2024-05-15 13:34:12 +03:00
parent 6318fc424b
commit 667263a3a8
82 changed files with 2356 additions and 1817 deletions

View File

@@ -3,11 +3,9 @@ package safesearch
import (
"bytes"
"context"
"encoding/binary"
"encoding/gob"
"fmt"
"net"
"net/netip"
"strings"
"sync"
@@ -67,7 +65,6 @@ type Default struct {
engine *urlfilter.DNSEngine
cache cache.Cache
resolver filtering.Resolver
logPrefix string
cacheTTL time.Duration
}
@@ -80,11 +77,6 @@ func NewDefault(
cacheSize uint,
cacheTTL time.Duration,
) (ss *Default, err error) {
var resolver filtering.Resolver = net.DefaultResolver
if conf.CustomResolver != nil {
resolver = conf.CustomResolver
}
ss = &Default{
mu: &sync.RWMutex{},
@@ -92,7 +84,6 @@ func NewDefault(
EnableLRU: true,
MaxSize: cacheSize,
}),
resolver: resolver,
// Use %s, because the client safe-search names already contain double
// quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name),
@@ -170,8 +161,11 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
}()
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype))
switch qtype {
case dns.TypeA, dns.TypeAAAA, dns.TypeHTTPS:
// Go on.
default:
return filtering.Result{}, nil
}
// Check cache. Return cached result if it was found
@@ -195,6 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
}
res = *fltRes
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
// saving results to cache.
ss.setCacheResult(host, qtype, res)
return res, nil
@@ -223,20 +220,13 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe
}
// newResult creates Result object from rewrite rule. qtype must be either
// [dns.TypeA] or [dns.TypeAAAA]. If err is nil, res is never nil, so that the
// empty result is converted into a NODATA response.
//
// TODO(a.garipov): Use the main rewrite result mechanism used in
// [dnsforward.Server.filterDNSRequest]. Now we resolve IPs for CNAME to save
// them in the safe search cache.
// [dns.TypeA] or [dns.TypeAAAA], or [dns.TypeHTTPS]. If err is nil, res is
// never nil, so that the empty result is converted into a NODATA response.
func (ss *Default) newResult(
rewrite *rules.DNSRewrite,
qtype rules.RRType,
) (res *filtering.Result, err error) {
res = &filtering.Result{
Rules: []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
}},
Reason: filtering.FilteredSafeSearch,
IsFiltered: true,
}
@@ -247,69 +237,19 @@ func (ss *Default) newResult(
return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value)
}
res.Rules[0].IP = ip
res.Rules = []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
IP: ip,
}}
return res, nil
}
host := rewrite.NewCNAME
if host == "" {
return res, nil
}
res.CanonName = host
ss.log(log.DEBUG, "resolving %q", host)
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil {
return nil, fmt.Errorf("resolving cname: %w", err)
}
ss.log(log.DEBUG, "resolved %s", ips)
for _, ip := range ips {
// TODO(a.garipov): Remove this filtering once the resolver we use
// actually learns about network.
addr := fitToProto(ip, qtype)
if addr == (netip.Addr{}) {
continue
}
// TODO(e.burkov): Rules[0]?
res.Rules[0].IP = addr
}
res.CanonName = rewrite.NewCNAME
return res, nil
}
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
// It panics for other types.
func qtypeToProto(qtype rules.RRType) (proto string) {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
}
}
// fitToProto returns a non-nil IP address if ip is the correct protocol version
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
func fitToProto(ip net.IP, qtype rules.RRType) (res netip.Addr) {
if ip4 := ip.To4(); qtype == dns.TypeA {
if ip4 != nil {
return netip.AddrFrom4([4]byte(ip4))
}
} else if ip = ip.To16(); ip != nil && qtype == dns.TypeAAAA {
return netip.AddrFrom16([16]byte(ip))
}
return netip.Addr{}
}
// setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {

View File

@@ -1,13 +1,10 @@
package safesearch
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
@@ -79,47 +76,6 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)
wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.Equal(t, wantIP, res.Rules[0].IP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, wantIP, cachedValue.Rules[0].IP)
}
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
@@ -127,7 +83,7 @@ var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ {
for range b.N {
dnsRewriteSink = ss.searchHost(googleHost, testQType)
}

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
@@ -31,8 +30,6 @@ const (
// testConf is the default safe search configuration for tests.
var testConf = filtering.SafeSearchConfig{
CustomResolver: nil,
Enabled: true,
Bing: true,
@@ -52,61 +49,60 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// Check host for each domain.
for _, host := range []string{
hosts := []string{
"yandex.ru",
"yAndeX.ru",
"YANdex.COM",
"yandex.by",
"yandex.kz",
"www.yandex.com",
} {
var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
}
}
func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
testCases := []struct {
want netip.Addr
name string
qt uint16
}{{
want: yandexIP,
name: "a",
qt: dns.TypeA,
}, {
want: netip.Addr{},
name: "aaaa",
qt: dns.TypeAAAA,
}, {
want: netip.Addr{},
name: "https",
qt: dns.TypeHTTPS,
}}
res, err := ss.CheckHost("www.yandex.ru", dns.TypeAAAA)
require.NoError(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, host := range hosts {
// Check host for each domain.
var res filtering.Result
res, err = ss.CheckHost(host, tc.qt)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
assert.True(t, res.IsFiltered)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
// with a nil IP address. This isn't really necessary and should be changed
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
if tc.want == (netip.Addr{}) {
assert.Empty(t, res.Rules)
} else {
require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
rule := res.Rules[0]
assert.Equal(t, tc.want, rule.IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, rule.FilterListID)
}
}
})
}
}
func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
conf := testConf
conf.CustomResolver = resolver
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// Check host for each domain.
@@ -125,11 +121,9 @@ func TestDefault_CheckHost_google(t *testing.T) {
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, wantIP, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
assert.Equal(t, "forcesafesearch.google.com", res.CanonName)
assert.Empty(t, res.Rules)
})
}
}
@@ -154,17 +148,7 @@ func (r *testResolver) LookupIP(
}
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
conf := testConf
conf.CustomResolver = &testResolver{
OnLookupIP: func(_ context.Context, network, host string) (ips []net.IP, err error) {
assert.Equal(t, "ip6", network)
assert.Equal(t, "safe.duckduckgo.com", host)
return nil, nil
},
}
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
@@ -174,14 +158,9 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
require.NoError(t, err)
assert.True(t, res.IsFiltered)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
// with a nil IP address. This isn't really necessary and should be changed
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
assert.Equal(t, "safe.duckduckgo.com", res.CanonName)
assert.Empty(t, res.Rules)
}
func TestDefault_Update(t *testing.T) {