Compare commits
25 Commits
ADG-1-util
...
2499-rewri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1be2bab4d | ||
|
|
53cd9b7a1a | ||
|
|
18a6066df5 | ||
|
|
18392943fa | ||
|
|
c2abedec70 | ||
|
|
d3bf5fcb05 | ||
|
|
5a794411d9 | ||
|
|
8e058b8042 | ||
|
|
d76834f843 | ||
|
|
5480bed1f7 | ||
|
|
c5fb7e6b0d | ||
|
|
9efc381224 | ||
|
|
299371e0fd | ||
|
|
12f52f07c5 | ||
|
|
990311c9e0 | ||
|
|
526c358697 | ||
|
|
e657899c32 | ||
|
|
fb3602853a | ||
|
|
2cf171f21e | ||
|
|
e56f465ad8 | ||
|
|
a8e80bc583 | ||
|
|
9a186d0a8a | ||
|
|
2d29455d7f | ||
|
|
55a0dec144 | ||
|
|
6b607e982b |
@@ -390,6 +390,7 @@ export const SPECIAL_FILTER_ID = {
|
||||
PARENTAL: -3,
|
||||
SAFE_BROWSING: -4,
|
||||
SAFE_SEARCH: -5,
|
||||
REWRITES: -6,
|
||||
};
|
||||
|
||||
export const BLOCK_ACTIONS = {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -67,7 +68,7 @@ func createTestServer(
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(filterConf, filters)
|
||||
f, err := filtering.New(filterConf, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
@@ -760,7 +761,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
@@ -880,21 +881,22 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
c := &filtering.Config{
|
||||
Rewrites: []*filtering.LegacyRewrite{{
|
||||
Rewrites: []*filtering.RewriteItem{{
|
||||
Domain: "test.com",
|
||||
Answer: "1.2.3.4",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "alias.test.com",
|
||||
Answer: "test.com",
|
||||
Type: dns.TypeCNAME,
|
||||
}, {
|
||||
Domain: "my.alias.example.org",
|
||||
Answer: "example.org",
|
||||
Type: dns.TypeCNAME,
|
||||
}},
|
||||
}
|
||||
f, err := filtering.New(c, nil)
|
||||
|
||||
rewriteStorage, err := rewrite.NewDefaultStorage(c.Rewrites)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, err := filtering.New(c, nil, rewriteStorage)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
@@ -945,6 +947,12 @@ func TestRewrite(t *testing.T) {
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
@@ -952,8 +960,15 @@ func TestRewrite(t *testing.T) {
|
||||
require.Len(t, reply.Answer, 2)
|
||||
|
||||
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
||||
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
@@ -967,6 +982,12 @@ func TestRewrite(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
|
||||
req = createTestMessageWithType("my.alias.test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
}
|
||||
|
||||
for _, protect := range []bool{true, false} {
|
||||
@@ -1011,7 +1032,7 @@ var testDHCP = &dhcpd.MockInterface{
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
const localDomain = "lan"
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{}, nil)
|
||||
flt, err := filtering.New(&filtering.Config{}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
@@ -1085,7 +1106,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
flt.SetEnabled(true)
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters, nil)
|
||||
require.NoError(t, err)
|
||||
f.SetEnabled(true)
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ func TestFilters(t *testing.T) {
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f := &FilterYAML{
|
||||
|
||||
@@ -33,7 +33,6 @@ import (
|
||||
// The IDs of built-in filter lists.
|
||||
//
|
||||
// Keep in sync with client/src/helpers/constants.js.
|
||||
// TODO(d.kolyshev): Add RewritesListID and don't forget to keep in sync.
|
||||
const (
|
||||
CustomListID = -iota
|
||||
SysHostsListID
|
||||
@@ -41,6 +40,7 @@ const (
|
||||
ParentalListID
|
||||
SafeBrowsingListID
|
||||
SafeSearchListID
|
||||
RewritesListID
|
||||
)
|
||||
|
||||
// ServiceEntry - blocked service array element
|
||||
@@ -90,7 +90,7 @@ type Config struct {
|
||||
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
|
||||
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
||||
|
||||
Rewrites []*LegacyRewrite `yaml:"rewrites"`
|
||||
Rewrites []*RewriteItem `yaml:"rewrites"`
|
||||
|
||||
// Names of services to block (globally).
|
||||
// Per-client settings can override this configuration.
|
||||
@@ -194,6 +194,8 @@ type DNSFilter struct {
|
||||
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
|
||||
rewriteStorage RewriteStorage
|
||||
|
||||
hostCheckers []hostChecker
|
||||
}
|
||||
|
||||
@@ -315,7 +317,7 @@ func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||
defer d.confLock.Unlock()
|
||||
|
||||
*c = d.Config
|
||||
c.Rewrites = cloneRewrites(c.Rewrites)
|
||||
c.Rewrites = slices.Clone(c.Rewrites)
|
||||
}()
|
||||
|
||||
d.filtersMu.RLock()
|
||||
@@ -326,16 +328,6 @@ func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||
c.UserRules = slices.Clone(d.UserRules)
|
||||
}
|
||||
|
||||
// cloneRewrites returns a deep copy of entries.
|
||||
func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
|
||||
clone = make([]*LegacyRewrite, len(entries))
|
||||
for i, rw := range entries {
|
||||
clone[i] = rw.clone()
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// SetFilters sets new filters, synchronously or asynchronously. When filters
|
||||
// are set asynchronously, the old filters continue working until the new
|
||||
// filters are ready.
|
||||
@@ -546,75 +538,52 @@ func (d *DNSFilter) matchSysHosts(
|
||||
// CNAME, breaking loops in the process.
|
||||
//
|
||||
// Secondly, it finds A or AAAA rewrites for host and, if found, sets res.IPList
|
||||
// accordingly. If the found rewrite has a special value of "A" or "AAAA", the
|
||||
// result is an exception.
|
||||
// accordingly.
|
||||
func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
rewrites, matched := findRewrites(d.Rewrites, host, qtype)
|
||||
if !matched {
|
||||
return Result{}
|
||||
if d.rewriteStorage == nil {
|
||||
return res
|
||||
}
|
||||
|
||||
res.Reason = Rewritten
|
||||
dnsr := d.rewriteStorage.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
DNSType: qtype,
|
||||
})
|
||||
|
||||
cnames := stringutil.NewSet()
|
||||
origHost := host
|
||||
for matched && len(rewrites) > 0 && rewrites[0].Type == dns.TypeCNAME {
|
||||
rw := rewrites[0]
|
||||
rwPat := rw.Domain
|
||||
rwAns := rw.Answer
|
||||
|
||||
log.Debug("rewrite: cname for %s is %s", host, rwAns)
|
||||
|
||||
if origHost == rwAns || rwPat == rwAns {
|
||||
// Either a request for the hostname itself or a rewrite of
|
||||
// a pattern onto itself, both of which are an exception rules.
|
||||
// Return a not filtered result.
|
||||
return Result{}
|
||||
} else if host == rwAns && isWildcard(rwPat) {
|
||||
// An "*.example.com → sub.example.com" rewrite matching in a loop.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4016.
|
||||
|
||||
res.CanonName = host
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
host = rwAns
|
||||
if cnames.Has(host) {
|
||||
log.Info("rewrite: cname loop for %q on %q", origHost, host)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
cnames.Add(host)
|
||||
res.CanonName = host
|
||||
rewrites, matched = findRewrites(d.Rewrites, host, qtype)
|
||||
}
|
||||
|
||||
setRewriteResult(&res, host, rewrites, qtype)
|
||||
setRewriteResult(&res, host, dnsr, qtype)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// setRewriteResult sets the Reason or IPList of res if necessary. res must not
|
||||
// be nil.
|
||||
func setRewriteResult(res *Result, host string, rewrites []*LegacyRewrite, qtype uint16) {
|
||||
for _, rw := range rewrites {
|
||||
if rw.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
if rw.IP == nil {
|
||||
// "A"/"AAAA" exception: allow getting from upstream.
|
||||
res.Reason = NotFilteredNotFound
|
||||
func setRewriteResult(res *Result, host string, dnsr []*rules.DNSRewrite, qtype uint16) {
|
||||
if len(dnsr) == 0 {
|
||||
res.Reason = NotFilteredNotFound
|
||||
|
||||
return
|
||||
return
|
||||
}
|
||||
|
||||
res.Reason = Rewritten
|
||||
|
||||
for _, dnsRewrite := range dnsr {
|
||||
if dnsRewrite.RRType == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
ip, ok := dnsRewrite.Value.(net.IP)
|
||||
if !ok || ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
res.IPList = append(res.IPList, rw.IP)
|
||||
if qtype == dns.TypeA {
|
||||
ip = ip.To4()
|
||||
}
|
||||
|
||||
log.Debug("rewrite: a/aaaa for %s is %s", host, rw.IP)
|
||||
res.IPList = append(res.IPList, ip)
|
||||
|
||||
log.Debug("rewrite: a/aaaa for %s is %s", host, ip)
|
||||
} else if dnsRewrite.NewCNAME != "" {
|
||||
res.CanonName = dnsRewrite.NewCNAME
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -927,7 +896,7 @@ func InitModule() {
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used. c must
|
||||
// be non-nil.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
func New(c *Config, blockFilters []Filter, rewriteStorage RewriteStorage) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
resolver: net.DefaultResolver,
|
||||
refreshLock: &sync.Mutex{},
|
||||
@@ -980,11 +949,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
|
||||
d.Config = *c
|
||||
d.filtersMu = &sync.RWMutex{}
|
||||
|
||||
err = d.prepareRewrites()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rewrites: preparing: %s", err)
|
||||
}
|
||||
d.rewriteStorage = rewriteStorage
|
||||
|
||||
bsvcs := []string{}
|
||||
for _, s := range d.BlockedServices {
|
||||
|
||||
@@ -46,6 +46,7 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
c.SafeBrowsingCacheSize = 10000
|
||||
c.ParentalCacheSize = 10000
|
||||
@@ -58,7 +59,8 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
||||
// It must not be nil.
|
||||
c = &Config{}
|
||||
}
|
||||
f, err := New(c, filters)
|
||||
|
||||
f, err := New(c, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
purgeCaches(f)
|
||||
@@ -417,274 +419,275 @@ func TestMatching(t *testing.T) {
|
||||
host string
|
||||
wantReason Reason
|
||||
wantIsFiltered bool
|
||||
wantDNSType uint16
|
||||
qtype uint16
|
||||
}{{
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "www.doubleclick.net",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "nodoubleclick.net",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "doubleclick.net.ru",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: sbBlocked,
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "test.test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "test.test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "test.test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "test.test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "test2.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "example.com",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "exampleeee.com",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "onemoreexamsite.com",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "example.co.uk",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeAAAA,
|
||||
qtype: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeAAAA,
|
||||
qtype: dns.TypeAAAA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
|
||||
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
|
||||
d, setts := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, setts)
|
||||
res, err := d.CheckHost(tc.host, tc.qtype, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
||||
|
||||
@@ -105,7 +105,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
||||
},
|
||||
ConfigModified: func() { confModifiedCalled = true },
|
||||
DataDir: filtersDir,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
|
||||
42
internal/filtering/rewrite.go
Normal file
42
internal/filtering/rewrite.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
)
|
||||
|
||||
// RewriteStorage is a storage for rewrite rules.
|
||||
type RewriteStorage interface {
|
||||
// MatchRequest returns matching dnsrewrites for the specified request.
|
||||
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
|
||||
|
||||
// Add adds item to the storage.
|
||||
Add(item *RewriteItem) (err error)
|
||||
|
||||
// Remove deletes item from the storage.
|
||||
Remove(item *RewriteItem) (err error)
|
||||
|
||||
// List returns all items from the storage.
|
||||
List() (items []*RewriteItem)
|
||||
}
|
||||
|
||||
// RewriteItem is a single DNS rewrite record.
|
||||
type RewriteItem struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain" json:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer" json:"answer"`
|
||||
}
|
||||
|
||||
// Equal returns true if rw is Equal to other.
|
||||
func (rw *RewriteItem) Equal(other *RewriteItem) (ok bool) {
|
||||
if rw == nil {
|
||||
return other == nil
|
||||
} else if other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *rw == *other
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Item is a single DNS rewrite record.
|
||||
type Item struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer"`
|
||||
}
|
||||
|
||||
// equal returns true if rw is equal to other.
|
||||
func (rw *Item) equal(other *Item) (ok bool) {
|
||||
if rw == nil {
|
||||
return other == nil
|
||||
} else if other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *rw == *other
|
||||
}
|
||||
|
||||
// toRule converts rw to a filter rule.
|
||||
func (rw *Item) toRule() (res string) {
|
||||
if rw == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
domain := strings.ToLower(rw.Domain)
|
||||
|
||||
dType, exception := rw.rewriteParams()
|
||||
dTypeKey := dns.TypeToString[dType]
|
||||
if exception {
|
||||
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
|
||||
}
|
||||
|
||||
// rewriteParams returns dns request type and exception flag for rw.
|
||||
func (rw *Item) rewriteParams() (dType uint16, exception bool) {
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
return dns.TypeAAAA, true
|
||||
case "A":
|
||||
return dns.TypeA, true
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(rw.Answer)
|
||||
if err != nil {
|
||||
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
|
||||
return dns.TypeCNAME, false
|
||||
}
|
||||
|
||||
if addr.Is4() {
|
||||
dType = dns.TypeA
|
||||
} else {
|
||||
dType = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return dType, false
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestItem_equal(t *testing.T) {
|
||||
const (
|
||||
testDomain = "example.org"
|
||||
testAnswer = "1.1.1.1"
|
||||
)
|
||||
|
||||
testItem := &Item{
|
||||
Domain: testDomain,
|
||||
Answer: testAnswer,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left *Item
|
||||
right *Item
|
||||
want bool
|
||||
}{{
|
||||
name: "nil_left",
|
||||
left: nil,
|
||||
right: testItem,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nil_right",
|
||||
left: testItem,
|
||||
right: nil,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nils",
|
||||
left: nil,
|
||||
right: nil,
|
||||
want: true,
|
||||
}, {
|
||||
name: "equal",
|
||||
left: testItem,
|
||||
right: testItem,
|
||||
want: true,
|
||||
}, {
|
||||
name: "distinct",
|
||||
left: testItem,
|
||||
right: &Item{
|
||||
Domain: "other",
|
||||
Answer: "other",
|
||||
},
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.left.equal(tc.right)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestItem_toRule(t *testing.T) {
|
||||
const testDomain = "example.org"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
item *Item
|
||||
want string
|
||||
}{{
|
||||
name: "nil",
|
||||
item: nil,
|
||||
want: "",
|
||||
}, {
|
||||
name: "a_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "1.1.1.1",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
|
||||
}, {
|
||||
name: "aaaa_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "1:2:3::4",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
|
||||
}, {
|
||||
name: "cname_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "wildcard_rule",
|
||||
item: &Item{
|
||||
Domain: "*.example.org",
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "A",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=A,dnsrewrite",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "AAAA",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.item.toRule()
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,11 @@ package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -15,21 +17,6 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Storage is a storage for rewrite rules.
|
||||
type Storage interface {
|
||||
// MatchRequest returns matching dnsrewrites for the specified request.
|
||||
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
|
||||
|
||||
// Add adds item to the storage.
|
||||
Add(item *Item) (err error)
|
||||
|
||||
// Remove deletes item from the storage.
|
||||
Remove(item *Item) (err error)
|
||||
|
||||
// List returns all items from the storage.
|
||||
List() (items []*Item)
|
||||
}
|
||||
|
||||
// DefaultStorage is the default storage for rewrite rules.
|
||||
type DefaultStorage struct {
|
||||
// mu protects items.
|
||||
@@ -42,7 +29,7 @@ type DefaultStorage struct {
|
||||
ruleList filterlist.RuleList
|
||||
|
||||
// rewrites stores the rewrite entries from configuration.
|
||||
rewrites []*Item
|
||||
rewrites []*filtering.RewriteItem
|
||||
|
||||
// urlFilterID is the synthetic integer identifier for the urlfilter engine.
|
||||
//
|
||||
@@ -53,16 +40,13 @@ type DefaultStorage struct {
|
||||
|
||||
// NewDefaultStorage returns new rewrites storage. listID is used as an
|
||||
// identifier of the underlying rules list. rewrites must not be nil.
|
||||
func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err error) {
|
||||
func NewDefaultStorage(rewrites []*filtering.RewriteItem) (s *DefaultStorage, err error) {
|
||||
s = &DefaultStorage{
|
||||
mu: &sync.RWMutex{},
|
||||
urlFilterID: listID,
|
||||
urlFilterID: filtering.RewritesListID,
|
||||
rewrites: rewrites,
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
err = s.resetRules()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -72,9 +56,9 @@ func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err err
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Storage = (*DefaultStorage)(nil)
|
||||
var _ filtering.RewriteStorage = (*DefaultStorage)(nil)
|
||||
|
||||
// MatchRequest implements the [Storage] interface for *DefaultStorage.
|
||||
// MatchRequest implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -84,28 +68,32 @@ func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Check cnames for cycles on initialisation.
|
||||
// TODO(a.garipov): Check cnames for cycles on initialization.
|
||||
cnames := stringutil.NewSet()
|
||||
host := dReq.Hostname
|
||||
var lastCNAMERule *rules.NetworkRule
|
||||
for len(rrules) > 0 && rrules[0].DNSRewrite != nil && rrules[0].DNSRewrite.NewCNAME != "" {
|
||||
rule := rrules[0]
|
||||
rwAns := rule.DNSRewrite.NewCNAME
|
||||
lastCNAMERule = rrules[0]
|
||||
lastDNSRewrite := lastCNAMERule.DNSRewrite
|
||||
rwAns := lastDNSRewrite.NewCNAME
|
||||
|
||||
log.Debug("rewrite: cname for %s is %s", host, rwAns)
|
||||
|
||||
if dReq.Hostname == rwAns {
|
||||
// A request for the hostname itself is an exception rule.
|
||||
// A request for the hostname itself.
|
||||
// TODO(d.kolyshev): Check rewrite of a pattern onto itself.
|
||||
log.Debug("rewrite: request for hostname itself for %q", dReq.Hostname)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if host == rwAns && isWildcard(rule.RuleText) {
|
||||
if host == rwAns && isWildcard(lastCNAMERule.RuleText) {
|
||||
// An "*.example.com → sub.example.com" rewrite matching in a loop.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4016.
|
||||
log.Debug("rewrite: cname wildcard loop for %q on %q", dReq.Hostname, rwAns)
|
||||
|
||||
return []*rules.DNSRewrite{rule.DNSRewrite}
|
||||
return []*rules.DNSRewrite{lastDNSRewrite}
|
||||
}
|
||||
|
||||
if cnames.Has(rwAns) {
|
||||
@@ -120,21 +108,28 @@ func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.
|
||||
Hostname: rwAns,
|
||||
DNSType: dReq.DNSType,
|
||||
})
|
||||
if drules != nil {
|
||||
rrules = drules
|
||||
|
||||
if drules == nil {
|
||||
break
|
||||
}
|
||||
|
||||
rrules = drules
|
||||
host = rwAns
|
||||
}
|
||||
|
||||
return s.collectDNSRewrites(rrules, dReq.DNSType)
|
||||
return s.collectDNSRewrites(rrules, lastCNAMERule, dReq.DNSType)
|
||||
}
|
||||
|
||||
// collectDNSRewrites filters DNSRewrite by question type.
|
||||
func (s *DefaultStorage) collectDNSRewrites(
|
||||
rewrites []*rules.NetworkRule,
|
||||
cnameRule *rules.NetworkRule,
|
||||
qtyp uint16,
|
||||
) (rws []*rules.DNSRewrite) {
|
||||
if cnameRule != nil {
|
||||
rewrites = append([]*rules.NetworkRule{cnameRule}, rewrites...)
|
||||
}
|
||||
|
||||
for _, rewrite := range rewrites {
|
||||
dnsRewrite := rewrite.DNSRewrite
|
||||
if matchesQType(dnsRewrite, qtyp) {
|
||||
@@ -152,8 +147,8 @@ func (s *DefaultStorage) rewriteRulesForReq(dReq *urlfilter.DNSRequest) (rules [
|
||||
return res.DNSRewrites()
|
||||
}
|
||||
|
||||
// Add implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Add(item *Item) (err error) {
|
||||
// Add implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Add(item *filtering.RewriteItem) (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -163,16 +158,16 @@ func (s *DefaultStorage) Add(item *Item) (err error) {
|
||||
return s.resetRules()
|
||||
}
|
||||
|
||||
// Remove implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Remove(item *Item) (err error) {
|
||||
// Remove implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Remove(item *filtering.RewriteItem) (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
arr := []*Item{}
|
||||
arr := []*filtering.RewriteItem{}
|
||||
|
||||
// TODO(d.kolyshev): Use slices.IndexFunc + slices.Delete?
|
||||
for _, ent := range s.rewrites {
|
||||
if ent.equal(item) {
|
||||
if ent.Equal(item) {
|
||||
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
|
||||
continue
|
||||
@@ -185,8 +180,8 @@ func (s *DefaultStorage) Remove(item *Item) (err error) {
|
||||
return s.resetRules()
|
||||
}
|
||||
|
||||
// List implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) List() (items []*Item) {
|
||||
// List implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) List() (items []*filtering.RewriteItem) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -198,7 +193,7 @@ func (s *DefaultStorage) resetRules() (err error) {
|
||||
// TODO(a.garipov): Use strings.Builder.
|
||||
var rulesText []string
|
||||
for _, rewrite := range s.rewrites {
|
||||
rulesText = append(rulesText, rewrite.toRule())
|
||||
rulesText = append(rulesText, toRule(rewrite))
|
||||
}
|
||||
|
||||
strList := &filterlist.StringRuleList{
|
||||
@@ -222,20 +217,60 @@ func (s *DefaultStorage) resetRules() (err error) {
|
||||
|
||||
// matchesQType returns true if dnsrewrite matches the question type qt.
|
||||
func matchesQType(dnsrr *rules.DNSRewrite, qt uint16) (ok bool) {
|
||||
// Add CNAMEs, since they match for all types requests.
|
||||
if dnsrr.RRType == dns.TypeCNAME {
|
||||
switch qt {
|
||||
case dns.TypeA:
|
||||
return dnsrr.RRType != dns.TypeAAAA
|
||||
case dns.TypeAAAA:
|
||||
return dnsrr.RRType != dns.TypeA
|
||||
default:
|
||||
return true
|
||||
}
|
||||
|
||||
// Reject types other than A and AAAA.
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
return false
|
||||
}
|
||||
|
||||
return dnsrr.RRType == qt
|
||||
}
|
||||
|
||||
// isWildcard returns true if pat is a wildcard domain pattern.
|
||||
func isWildcard(pat string) (res bool) {
|
||||
return strings.HasPrefix(pat, "|*.")
|
||||
}
|
||||
|
||||
// toRule converts rw to a filter rule.
|
||||
func toRule(rw *filtering.RewriteItem) (res string) {
|
||||
if rw == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
domain := strings.ToLower(rw.Domain)
|
||||
|
||||
dType, exception := rewriteParams(rw)
|
||||
dTypeKey := dns.TypeToString[dType]
|
||||
if exception {
|
||||
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
|
||||
}
|
||||
|
||||
// RewriteParams returns dns request type and exception flag for rw.
|
||||
func rewriteParams(rw *filtering.RewriteItem) (dType uint16, exception bool) {
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
return dns.TypeAAAA, true
|
||||
case "A":
|
||||
return dns.TypeA, true
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(rw.Answer)
|
||||
if err != nil {
|
||||
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
|
||||
return dns.TypeCNAME, false
|
||||
}
|
||||
|
||||
if addr.Is4() {
|
||||
dType = dns.TypeA
|
||||
} else {
|
||||
dType = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return dType, false
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
@@ -12,32 +13,32 @@ import (
|
||||
)
|
||||
|
||||
func TestNewDefaultStorage(t *testing.T) {
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "example.com",
|
||||
Answer: "answer.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, s.List(), 1)
|
||||
}
|
||||
|
||||
func TestDefaultStorage_CRUD(t *testing.T) {
|
||||
var items []*Item
|
||||
var items []*filtering.RewriteItem
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, s.List(), 0)
|
||||
|
||||
item := &Item{Domain: "example.com", Answer: "answer.com"}
|
||||
item := &filtering.RewriteItem{Domain: "example.com", Answer: "answer.com"}
|
||||
|
||||
err = s.Add(item)
|
||||
require.NoError(t, err)
|
||||
|
||||
list := s.List()
|
||||
require.Len(t, list, 1)
|
||||
require.True(t, item.equal(list[0]))
|
||||
require.True(t, item.Equal(list[0]))
|
||||
|
||||
err = s.Remove(item)
|
||||
require.NoError(t, err)
|
||||
@@ -45,7 +46,7 @@ func TestDefaultStorage_CRUD(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
// This one and below are about CNAME, A and AAAA.
|
||||
Domain: "somecname",
|
||||
Answer: "somehost.com",
|
||||
@@ -101,7 +102,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
Answer: "sub.issue4016.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -115,14 +116,39 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "not_filtered_qtype",
|
||||
host: "www.host.com",
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeMX,
|
||||
name: "other_qtype",
|
||||
host: "www.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 5}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
Value: net.ParseIP("1:2:3::4"),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
dtyp: dns.TypeMX,
|
||||
}, {
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -138,6 +164,11 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.ParseIP("1:2:3::4"),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -154,21 +185,30 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
//}, {
|
||||
// TODO(d.kolyshev): This is about matching in urlfilter.
|
||||
// name: "wildcard_override",
|
||||
// host: "a.host.com",
|
||||
// wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
// Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
// NewCNAME: "",
|
||||
// RCode: dns.RcodeSuccess,
|
||||
// RRType: dns.TypeA,
|
||||
// }},
|
||||
// dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 5}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -184,6 +224,11 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "somehost.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{0, 0, 0, 0}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -194,6 +239,11 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "x.host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 5}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -221,10 +271,15 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4008",
|
||||
host: "somehost.com",
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeHTTPS,
|
||||
name: "issue4008",
|
||||
host: "somehost.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: net.IP{0, 0, 0, 0}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeHTTPS,
|
||||
}, {
|
||||
name: "issue4016",
|
||||
host: "www.issue4016.com",
|
||||
@@ -256,7 +311,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
|
||||
func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.1.1.1",
|
||||
}, {
|
||||
@@ -267,7 +322,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
Answer: "3.3.3.3",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -295,17 +350,21 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
//}, {
|
||||
// TODO(d.kolyshev): This is about matching in urlfilter.
|
||||
// name: "l3_match",
|
||||
// host: "my.sub.host.com",
|
||||
// wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
// Value: net.IP{3, 3, 3, 3}.To16(),
|
||||
// NewCNAME: "",
|
||||
// RCode: dns.RcodeSuccess,
|
||||
// RRType: dns.TypeA,
|
||||
// }},
|
||||
// dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "l3_match",
|
||||
host: "my.sub.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: net.IP{3, 3, 3, 3}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
Value: net.IP{2, 2, 2, 2}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -322,7 +381,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
|
||||
func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
// Wildcard and exception for a sub-domain.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
@@ -330,10 +389,10 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
Answer: "sub.host.com",
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "*.sub.host.com",
|
||||
Answer: "sub.host.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -356,12 +415,79 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
host: "sub.host.com",
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeA,
|
||||
//}, {
|
||||
// TODO(d.kolyshev): This is about matching in urlfilter.
|
||||
// name: "exception_wildcard",
|
||||
// host: "my.sub.host.com",
|
||||
// wantDNSRewrites: nil,
|
||||
// dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "exception_wildcard",
|
||||
host: "my.sub.host.com",
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dnsRewrites := s.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: tc.host,
|
||||
DNSType: tc.dtyp,
|
||||
})
|
||||
|
||||
assert.Equal(t, tc.wantDNSRewrites, dnsRewrites)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
||||
// Two cname rules for one subdomain
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "cname.org",
|
||||
Answer: "1.1.1.1",
|
||||
}, {
|
||||
Domain: "sub_cname.org",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
Domain: "*.host.com",
|
||||
Answer: "cname.org",
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "sub_cname.org",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantDNSRewrites []*rules.DNSRewrite
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "match_my_domain",
|
||||
host: "my.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "cname.org",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 1, 1, 1}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "match_sub_my_domain",
|
||||
host: "my.sub.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "cname.org",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 1, 1, 1}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -378,7 +504,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
|
||||
func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
||||
// Exception for AAAA record.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
@@ -395,7 +521,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
||||
Answer: "A",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -456,3 +582,66 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToRule(t *testing.T) {
|
||||
const testDomain = "example.org"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
item *filtering.RewriteItem
|
||||
want string
|
||||
}{{
|
||||
name: "nil",
|
||||
item: nil,
|
||||
want: "",
|
||||
}, {
|
||||
name: "a_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "1.1.1.1",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
|
||||
}, {
|
||||
name: "aaaa_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "1:2:3::4",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
|
||||
}, {
|
||||
name: "cname_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "wildcard_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: "*.example.org",
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "A",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=A,dnsrewrite",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "AAAA",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := toRule(tc.item)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
61
internal/filtering/rewrite_test.go
Normal file
61
internal/filtering/rewrite_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestItem_equal(t *testing.T) {
|
||||
const (
|
||||
testDomain = "example.org"
|
||||
testAnswer = "1.1.1.1"
|
||||
)
|
||||
|
||||
testItem := &RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: testAnswer,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left *RewriteItem
|
||||
right *RewriteItem
|
||||
want bool
|
||||
}{{
|
||||
name: "nil_left",
|
||||
left: nil,
|
||||
right: testItem,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nil_right",
|
||||
left: testItem,
|
||||
right: nil,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nils",
|
||||
left: nil,
|
||||
right: nil,
|
||||
want: true,
|
||||
}, {
|
||||
name: "equal",
|
||||
left: testItem,
|
||||
right: testItem,
|
||||
want: true,
|
||||
}, {
|
||||
name: "distinct",
|
||||
left: testItem,
|
||||
right: &RewriteItem{
|
||||
Domain: "other",
|
||||
Answer: "other",
|
||||
},
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.left.Equal(tc.right)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,85 +8,57 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// TODO(d.kolyshev): Use [rewrite.Item] instead.
|
||||
type rewriteEntryJSON struct {
|
||||
Domain string `json:"domain"`
|
||||
Answer string `json:"answer"`
|
||||
}
|
||||
|
||||
// handleRewriteList is the handler for the GET /control/rewrite/list HTTP API.
|
||||
func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
|
||||
arr := []*rewriteEntryJSON{}
|
||||
|
||||
d.confLock.Lock()
|
||||
for _, ent := range d.Config.Rewrites {
|
||||
jsent := rewriteEntryJSON{
|
||||
Domain: ent.Domain,
|
||||
Answer: ent.Answer,
|
||||
}
|
||||
arr = append(arr, &jsent)
|
||||
}
|
||||
d.confLock.Unlock()
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, arr)
|
||||
_ = aghhttp.WriteJSONResponse(w, r, d.rewriteStorage.List())
|
||||
}
|
||||
|
||||
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
|
||||
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
rwJSON := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&rwJSON)
|
||||
rw := &RewriteItem{}
|
||||
err := json.NewDecoder(r.Body).Decode(rw)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rw := &LegacyRewrite{
|
||||
Domain: rwJSON.Domain,
|
||||
Answer: rwJSON.Answer,
|
||||
}
|
||||
|
||||
err = rw.normalize()
|
||||
err = d.rewriteStorage.Add(rw)
|
||||
if err != nil {
|
||||
// Shouldn't happen currently, since normalize only returns a non-nil
|
||||
// error when a rewrite is nil, but be change-proof.
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "add rewrite: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("rewrite: added element: %s -> %s", rw.Domain, rw.Answer)
|
||||
|
||||
d.confLock.Lock()
|
||||
d.Config.Rewrites = append(d.Config.Rewrites, rw)
|
||||
d.Config.Rewrites = d.rewriteStorage.List()
|
||||
d.confLock.Unlock()
|
||||
log.Debug("rewrite: added element: %s -> %s [%d]", rw.Domain, rw.Answer, len(d.Config.Rewrites))
|
||||
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
// handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP
|
||||
// API.
|
||||
func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
|
||||
jsent := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&jsent)
|
||||
entDel := RewriteItem{}
|
||||
err := json.NewDecoder(r.Body).Decode(&entDel)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
entDel := &LegacyRewrite{
|
||||
Domain: jsent.Domain,
|
||||
Answer: jsent.Answer,
|
||||
err = d.rewriteStorage.Remove(&entDel)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "remove rewrite: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
arr := []*LegacyRewrite{}
|
||||
|
||||
d.confLock.Lock()
|
||||
for _, ent := range d.Config.Rewrites {
|
||||
if ent.equal(entDel) {
|
||||
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
arr = append(arr, ent)
|
||||
}
|
||||
d.Config.Rewrites = arr
|
||||
d.Config.Rewrites = d.rewriteStorage.List()
|
||||
d.confLock.Unlock()
|
||||
|
||||
d.Config.ConfigModified()
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
// DNS Rewrites
|
||||
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// LegacyRewrite is a single legacy DNS rewrite record.
|
||||
//
|
||||
// Instances of *LegacyRewrite must never be nil.
|
||||
type LegacyRewrite struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer"`
|
||||
|
||||
// IP is the IP address that should be used in the response if Type is
|
||||
// dns.TypeA or dns.TypeAAAA.
|
||||
IP net.IP `yaml:"-"`
|
||||
|
||||
// Type is the DNS record type: A, AAAA, or CNAME.
|
||||
Type uint16 `yaml:"-"`
|
||||
}
|
||||
|
||||
// clone returns a deep clone of rw.
|
||||
func (rw *LegacyRewrite) clone() (cloneRW *LegacyRewrite) {
|
||||
return &LegacyRewrite{
|
||||
Domain: rw.Domain,
|
||||
Answer: rw.Answer,
|
||||
IP: slices.Clone(rw.IP),
|
||||
Type: rw.Type,
|
||||
}
|
||||
}
|
||||
|
||||
// equal returns true if the rw is equal to the other.
|
||||
func (rw *LegacyRewrite) equal(other *LegacyRewrite) (ok bool) {
|
||||
return rw.Domain == other.Domain && rw.Answer == other.Answer
|
||||
}
|
||||
|
||||
// matchesQType returns true if the entry matches the question type qt.
|
||||
func (rw *LegacyRewrite) matchesQType(qt uint16) (ok bool) {
|
||||
// Add CNAMEs, since they match for all types requests.
|
||||
if rw.Type == dns.TypeCNAME {
|
||||
return true
|
||||
}
|
||||
|
||||
// Reject types other than A and AAAA.
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
return false
|
||||
}
|
||||
|
||||
// If the types match or the entry is set to allow only the other type,
|
||||
// include them.
|
||||
return rw.Type == qt || rw.IP == nil
|
||||
}
|
||||
|
||||
// normalize makes sure that the a new or decoded entry is normalized with
|
||||
// regards to domain name case, IP length, and so on.
|
||||
//
|
||||
// If rw is nil, it returns an errors.
|
||||
func (rw *LegacyRewrite) normalize() (err error) {
|
||||
if rw == nil {
|
||||
return errors.Error("nil rewrite entry")
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix and
|
||||
// use it in matchDomainWildcard instead of using strings.ToLower
|
||||
// everywhere.
|
||||
rw.Domain = strings.ToLower(rw.Domain)
|
||||
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
rw.IP = nil
|
||||
rw.Type = dns.TypeAAAA
|
||||
|
||||
return nil
|
||||
case "A":
|
||||
rw.IP = nil
|
||||
rw.Type = dns.TypeA
|
||||
|
||||
return nil
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
ip := net.ParseIP(rw.Answer)
|
||||
if ip == nil {
|
||||
rw.Type = dns.TypeCNAME
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ip4 := ip.To4()
|
||||
if ip4 != nil {
|
||||
rw.IP = ip4
|
||||
rw.Type = dns.TypeA
|
||||
} else {
|
||||
rw.IP = ip
|
||||
rw.Type = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isWildcard returns true if pat is a wildcard domain pattern.
|
||||
func isWildcard(pat string) bool {
|
||||
return len(pat) > 1 && pat[0] == '*' && pat[1] == '.'
|
||||
}
|
||||
|
||||
// matchDomainWildcard returns true if host matches the wildcard pattern.
|
||||
func matchDomainWildcard(host, wildcard string) (ok bool) {
|
||||
return isWildcard(wildcard) && strings.HasSuffix(host, wildcard[1:])
|
||||
}
|
||||
|
||||
// rewritesSorted is a slice of legacy rewrites for sorting.
|
||||
//
|
||||
// The sorting priority:
|
||||
//
|
||||
// 1. A and AAAA > CNAME
|
||||
// 2. wildcard > exact
|
||||
// 3. lower level wildcard > higher level wildcard
|
||||
//
|
||||
// TODO(a.garipov): Replace with slices.Sort.
|
||||
type rewritesSorted []*LegacyRewrite
|
||||
|
||||
// Len implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Len() (l int) { return len(a) }
|
||||
|
||||
// Swap implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
|
||||
// Less implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Less(i, j int) (less bool) {
|
||||
ith, jth := a[i], a[j]
|
||||
if ith.Type == dns.TypeCNAME && jth.Type != dns.TypeCNAME {
|
||||
return true
|
||||
} else if ith.Type != dns.TypeCNAME && jth.Type == dns.TypeCNAME {
|
||||
return false
|
||||
}
|
||||
|
||||
if iw, jw := isWildcard(ith.Domain), isWildcard(jth.Domain); iw != jw {
|
||||
return jw
|
||||
}
|
||||
|
||||
// Both are either wildcards or not.
|
||||
return len(ith.Domain) > len(jth.Domain)
|
||||
}
|
||||
|
||||
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
||||
func (d *DNSFilter) prepareRewrites() (err error) {
|
||||
for i, r := range d.Rewrites {
|
||||
err = r.normalize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findRewrites returns the list of matched rewrite entries. If rewrites are
|
||||
// empty, but matched is true, the domain is found among the rewrite rules but
|
||||
// not for this question type.
|
||||
//
|
||||
// The result priority is: CNAME, then A and AAAA; exact, then wildcard. If the
|
||||
// host is matched exactly, wildcard entries aren't returned. If the host
|
||||
// matched by wildcards, return the most specific for the question type.
|
||||
func findRewrites(
|
||||
entries []*LegacyRewrite,
|
||||
host string,
|
||||
qtype uint16,
|
||||
) (rewrites []*LegacyRewrite, matched bool) {
|
||||
for _, e := range entries {
|
||||
if e.Domain != host && !matchDomainWildcard(host, e.Domain) {
|
||||
continue
|
||||
}
|
||||
|
||||
matched = true
|
||||
if e.matchesQType(qtype) {
|
||||
rewrites = append(rewrites, e)
|
||||
}
|
||||
}
|
||||
|
||||
if len(rewrites) == 0 {
|
||||
return nil, matched
|
||||
}
|
||||
|
||||
sort.Sort(rewritesSorted(rewrites))
|
||||
|
||||
for i, r := range rewrites {
|
||||
if isWildcard(r.Domain) {
|
||||
// Don't use rewrites[:0], because we need to return at least one
|
||||
// item here.
|
||||
rewrites = rewrites[:max(1, i)]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return rewrites, matched
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
@@ -1,371 +0,0 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
// This one and below are about CNAME, A and AAAA.
|
||||
Domain: "somecname",
|
||||
Answer: "somehost.com",
|
||||
}, {
|
||||
Domain: "somehost.com",
|
||||
Answer: "0.0.0.0",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.5",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1:2:3::4",
|
||||
}, {
|
||||
Domain: "www.host.com",
|
||||
Answer: "host.com",
|
||||
}, {
|
||||
// This one is a wildcard.
|
||||
Domain: "*.host.com",
|
||||
Answer: "1.2.3.5",
|
||||
}, {
|
||||
// This one and below are about wildcard overriding.
|
||||
Domain: "a.host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
// This one is about CNAME and wildcard interacting.
|
||||
Domain: "*.host2.com",
|
||||
Answer: "host.com",
|
||||
}, {
|
||||
// This one and below are about 2 level CNAME.
|
||||
Domain: "b.host.com",
|
||||
Answer: "somecname",
|
||||
}, {
|
||||
// This one and below are about 2 level CNAME and wildcard.
|
||||
Domain: "b.host3.com",
|
||||
Answer: "a.host3.com",
|
||||
}, {
|
||||
Domain: "a.host3.com",
|
||||
Answer: "x.host.com",
|
||||
}, {
|
||||
Domain: "*.hostboth.com",
|
||||
Answer: "1.2.3.6",
|
||||
}, {
|
||||
Domain: "*.hostboth.com",
|
||||
Answer: "1234::5678",
|
||||
}, {
|
||||
Domain: "BIGHOST.COM",
|
||||
Answer: "1.2.3.7",
|
||||
}, {
|
||||
Domain: "*.issue4016.com",
|
||||
Answer: "sub.issue4016.com",
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantCName string
|
||||
wantIPs []net.IP
|
||||
wantReason Reason
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "not_filtered_not_found",
|
||||
host: "hoost.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{net.ParseIP("1:2:3::4")},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "wildcard_match",
|
||||
host: "abc.host.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
wantCName: "somehost.com",
|
||||
wantIPs: []net.IP{{0, 0, 0, 0}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
wantCName: "x.host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue3343",
|
||||
host: "www.hostboth.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{net.ParseIP("1234::5678")},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "issue3351",
|
||||
host: "bighost.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 7}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4008",
|
||||
host: "somehost.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeHTTPS,
|
||||
}, {
|
||||
name: "issue4016",
|
||||
host: "www.issue4016.com",
|
||||
wantCName: "sub.issue4016.com",
|
||||
wantIPs: nil,
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4016_self",
|
||||
host: "sub.issue4016.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, tc.dtyp)
|
||||
require.Equalf(t, tc.wantReason, r.Reason, "got %s", r.Reason)
|
||||
|
||||
if tc.wantCName != "" {
|
||||
assert.Equal(t, tc.wantCName, r.CanonName)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.wantIPs, r.IPList)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesLevels(t *testing.T) {
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.1.1.1",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "3.3.3.3",
|
||||
Type: dns.TypeA,
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "exact_match",
|
||||
host: "host.com",
|
||||
want: net.IP{1, 1, 1, 1},
|
||||
}, {
|
||||
name: "l2_match",
|
||||
host: "sub.host.com",
|
||||
want: net.IP{2, 2, 2, 2},
|
||||
}, {
|
||||
name: "l3_match",
|
||||
host: "my.sub.host.com",
|
||||
want: net.IP{3, 3, 3, 3},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
require.Len(t, r.IPList, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Wildcard and exception for a sub-domain.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
Domain: "sub.host.com",
|
||||
Answer: "sub.host.com",
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "*.sub.host.com",
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "match_subdomain",
|
||||
host: "my.host.com",
|
||||
want: net.IP{2, 2, 2, 2},
|
||||
}, {
|
||||
name: "exception_cname",
|
||||
host: "sub.host.com",
|
||||
want: nil,
|
||||
}, {
|
||||
name: "exception_wildcard",
|
||||
host: "my.sub.host.com",
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, dns.TypeA)
|
||||
if tc.want == nil {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason, "got %s", r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
require.Len(t, r.IPList, 1)
|
||||
assert.True(t, tc.want.Equal(r.IPList[0]))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesExceptionIP(t *testing.T) {
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exception for AAAA record.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "AAAA",
|
||||
Type: dns.TypeAAAA,
|
||||
}, {
|
||||
Domain: "host2.com",
|
||||
Answer: "::1",
|
||||
Type: dns.TypeAAAA,
|
||||
}, {
|
||||
Domain: "host2.com",
|
||||
Answer: "A",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "host3.com",
|
||||
Answer: "A",
|
||||
Type: dns.TypeA,
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want []net.IP
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "match_A",
|
||||
host: "host.com",
|
||||
want: []net.IP{{1, 2, 3, 4}},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "exception_AAAA_host.com",
|
||||
host: "host.com",
|
||||
want: nil,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "exception_A_host2.com",
|
||||
host: "host2.com",
|
||||
want: nil,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "match_AAAA_host2.com",
|
||||
host: "host2.com",
|
||||
want: []net.IP{net.ParseIP("::1")},
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "exception_A_host3.com",
|
||||
host: "host3.com",
|
||||
want: nil,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "match_AAAA_host3.com",
|
||||
host: "host3.com",
|
||||
want: []net.IP{},
|
||||
dtyp: dns.TypeAAAA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name+"_"+tc.host, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, tc.dtyp)
|
||||
if tc.want == nil {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, Rewritten, r.Reason, "got %s", r.Reason)
|
||||
|
||||
require.Len(t, r.IPList, len(tc.want))
|
||||
|
||||
for _, ip := range tc.want {
|
||||
assert.True(t, ip.Equal(r.IPList[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -74,7 +75,12 @@ func initDNS() (err error) {
|
||||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||
rewriteStorage, err := rewrite.NewDefaultStorage(config.DNS.DnsfilterConf.Rewrites)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rewrites: init: %w", err)
|
||||
}
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil, rewriteStorage)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user