all: rewrite package dependency
This commit is contained in:
@@ -68,7 +68,7 @@ func createTestServer(
|
|||||||
ID: 0, Data: []byte(rules),
|
ID: 0, Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
f, err := filtering.New(filterConf, filters)
|
f, err := filtering.New(filterConf, filters, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
f.SetEnabled(true)
|
f.SetEnabled(true)
|
||||||
@@ -761,7 +761,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
|||||||
Data: []byte(rules),
|
Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
f, err := filtering.New(&filtering.Config{}, filters)
|
f, err := filtering.New(&filtering.Config{}, filters, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s, err := NewServer(DNSCreateParams{
|
s, err := NewServer(DNSCreateParams{
|
||||||
@@ -881,7 +881,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
|||||||
|
|
||||||
func TestRewrite(t *testing.T) {
|
func TestRewrite(t *testing.T) {
|
||||||
c := &filtering.Config{
|
c := &filtering.Config{
|
||||||
Rewrites: []*rewrite.Item{{
|
Rewrites: []*filtering.RewriteItem{{
|
||||||
Domain: "test.com",
|
Domain: "test.com",
|
||||||
Answer: "1.2.3.4",
|
Answer: "1.2.3.4",
|
||||||
}, {
|
}, {
|
||||||
@@ -892,7 +892,11 @@ func TestRewrite(t *testing.T) {
|
|||||||
Answer: "example.org",
|
Answer: "example.org",
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
f, err := filtering.New(c, nil)
|
|
||||||
|
rewriteStorage, err := rewrite.NewDefaultStorage(c.Rewrites)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
f, err := filtering.New(c, nil, rewriteStorage)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
f.SetEnabled(true)
|
f.SetEnabled(true)
|
||||||
@@ -943,6 +947,12 @@ func TestRewrite(t *testing.T) {
|
|||||||
|
|
||||||
assert.Empty(t, reply.Answer)
|
assert.Empty(t, reply.Answer)
|
||||||
|
|
||||||
|
req = createTestMessageWithType("test.com.", dns.TypeTXT)
|
||||||
|
reply, eerr = dns.Exchange(req, addr.String())
|
||||||
|
require.NoError(t, eerr)
|
||||||
|
|
||||||
|
assert.Empty(t, reply.Answer)
|
||||||
|
|
||||||
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
||||||
reply, eerr = dns.Exchange(req, addr.String())
|
reply, eerr = dns.Exchange(req, addr.String())
|
||||||
require.NoError(t, eerr)
|
require.NoError(t, eerr)
|
||||||
@@ -953,6 +963,12 @@ func TestRewrite(t *testing.T) {
|
|||||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
||||||
|
|
||||||
|
req = createTestMessageWithType("alias.test.com.", dns.TypeTXT)
|
||||||
|
reply, eerr = dns.Exchange(req, addr.String())
|
||||||
|
require.NoError(t, eerr)
|
||||||
|
|
||||||
|
assert.Empty(t, reply.Answer)
|
||||||
|
|
||||||
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
||||||
reply, eerr = dns.Exchange(req, addr.String())
|
reply, eerr = dns.Exchange(req, addr.String())
|
||||||
require.NoError(t, eerr)
|
require.NoError(t, eerr)
|
||||||
@@ -966,6 +982,12 @@ func TestRewrite(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
||||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||||
|
|
||||||
|
req = createTestMessageWithType("my.alias.test.com.", dns.TypeTXT)
|
||||||
|
reply, eerr = dns.Exchange(req, addr.String())
|
||||||
|
require.NoError(t, eerr)
|
||||||
|
|
||||||
|
assert.Empty(t, reply.Answer)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, protect := range []bool{true, false} {
|
for _, protect := range []bool{true, false} {
|
||||||
@@ -1010,7 +1032,7 @@ var testDHCP = &dhcpd.MockInterface{
|
|||||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||||
const localDomain = "lan"
|
const localDomain = "lan"
|
||||||
|
|
||||||
flt, err := filtering.New(&filtering.Config{}, nil)
|
flt, err := filtering.New(&filtering.Config{}, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s, err := NewServer(DNSCreateParams{
|
s, err := NewServer(DNSCreateParams{
|
||||||
@@ -1084,7 +1106,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
|||||||
|
|
||||||
flt, err := filtering.New(&filtering.Config{
|
flt, err := filtering.New(&filtering.Config{
|
||||||
EtcHosts: hc,
|
EtcHosts: hc,
|
||||||
}, nil)
|
}, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
flt.SetEnabled(true)
|
flt.SetEnabled(true)
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
|||||||
ID: 0, Data: []byte(rules),
|
ID: 0, Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
f, err := filtering.New(&filtering.Config{}, filters)
|
f, err := filtering.New(&filtering.Config{}, filters, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
f.SetEnabled(true)
|
f.SetEnabled(true)
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ func TestFilters(t *testing.T) {
|
|||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
},
|
},
|
||||||
}, nil)
|
}, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
f := &FilterYAML{
|
f := &FilterYAML{
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/cache"
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
@@ -91,7 +90,7 @@ type Config struct {
|
|||||||
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
|
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
|
||||||
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
||||||
|
|
||||||
Rewrites []*rewrite.Item `yaml:"rewrites"`
|
Rewrites []*RewriteItem `yaml:"rewrites"`
|
||||||
|
|
||||||
// Names of services to block (globally).
|
// Names of services to block (globally).
|
||||||
// Per-client settings can override this configuration.
|
// Per-client settings can override this configuration.
|
||||||
@@ -195,7 +194,7 @@ type DNSFilter struct {
|
|||||||
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
|
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
|
||||||
filterTitleRegexp *regexp.Regexp
|
filterTitleRegexp *regexp.Regexp
|
||||||
|
|
||||||
rewriteStorage *rewrite.DefaultStorage
|
rewriteStorage RewriteStorage
|
||||||
|
|
||||||
hostCheckers []hostChecker
|
hostCheckers []hostChecker
|
||||||
}
|
}
|
||||||
@@ -544,6 +543,10 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
|||||||
d.confLock.RLock()
|
d.confLock.RLock()
|
||||||
defer d.confLock.RUnlock()
|
defer d.confLock.RUnlock()
|
||||||
|
|
||||||
|
if d.rewriteStorage == nil {
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
dnsr := d.rewriteStorage.MatchRequest(&urlfilter.DNSRequest{
|
dnsr := d.rewriteStorage.MatchRequest(&urlfilter.DNSRequest{
|
||||||
Hostname: host,
|
Hostname: host,
|
||||||
DNSType: qtype,
|
DNSType: qtype,
|
||||||
@@ -893,7 +896,7 @@ func InitModule() {
|
|||||||
|
|
||||||
// New creates properly initialized DNS Filter that is ready to be used. c must
|
// New creates properly initialized DNS Filter that is ready to be used. c must
|
||||||
// be non-nil.
|
// be non-nil.
|
||||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
func New(c *Config, blockFilters []Filter, rewriteStorage RewriteStorage) (d *DNSFilter, err error) {
|
||||||
d = &DNSFilter{
|
d = &DNSFilter{
|
||||||
resolver: net.DefaultResolver,
|
resolver: net.DefaultResolver,
|
||||||
refreshLock: &sync.Mutex{},
|
refreshLock: &sync.Mutex{},
|
||||||
@@ -946,11 +949,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
|||||||
|
|
||||||
d.Config = *c
|
d.Config = *c
|
||||||
d.filtersMu = &sync.RWMutex{}
|
d.filtersMu = &sync.RWMutex{}
|
||||||
|
d.rewriteStorage = rewriteStorage
|
||||||
d.rewriteStorage, err = rewrite.NewDefaultStorage(RewritesListID, d.Rewrites)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("rewrites: init: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bsvcs := []string{}
|
bsvcs := []string{}
|
||||||
for _, s := range d.BlockedServices {
|
for _, s := range d.BlockedServices {
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
|
||||||
"github.com/AdguardTeam/golibs/cache"
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
@@ -47,6 +46,7 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
|||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
FilteringEnabled: true,
|
FilteringEnabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if c != nil {
|
if c != nil {
|
||||||
c.SafeBrowsingCacheSize = 10000
|
c.SafeBrowsingCacheSize = 10000
|
||||||
c.ParentalCacheSize = 10000
|
c.ParentalCacheSize = 10000
|
||||||
@@ -59,7 +59,8 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
|||||||
// It must not be nil.
|
// It must not be nil.
|
||||||
c = &Config{}
|
c = &Config{}
|
||||||
}
|
}
|
||||||
f, err := New(c, filters)
|
|
||||||
|
f, err := New(c, filters, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
purgeCaches(f)
|
purgeCaches(f)
|
||||||
@@ -695,96 +696,6 @@ func TestMatching(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRewrites(t *testing.T) {
|
|
||||||
rewrites := []*rewrite.Item{{
|
|
||||||
Domain: "example.org",
|
|
||||||
Answer: "1.1.1.1",
|
|
||||||
}, {
|
|
||||||
Domain: "example-v6.org",
|
|
||||||
Answer: "1:2:3::4",
|
|
||||||
}, {
|
|
||||||
Domain: "cname.org",
|
|
||||||
Answer: "cname-res.org",
|
|
||||||
}}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
host string
|
|
||||||
wantReason Reason
|
|
||||||
wantIsFiltered bool
|
|
||||||
qtype uint16
|
|
||||||
}{{
|
|
||||||
name: "not_found_a",
|
|
||||||
host: "not-example.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: NotFilteredNotFound,
|
|
||||||
qtype: dns.TypeA,
|
|
||||||
}, {
|
|
||||||
name: "not_found_aaaa",
|
|
||||||
host: "not-example.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: NotFilteredNotFound,
|
|
||||||
qtype: dns.TypeAAAA,
|
|
||||||
}, {
|
|
||||||
name: "not_found_txt",
|
|
||||||
host: "not-example.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: NotFilteredNotFound,
|
|
||||||
qtype: dns.TypeTXT,
|
|
||||||
}, {
|
|
||||||
name: "found_a",
|
|
||||||
host: "example.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: Rewritten,
|
|
||||||
qtype: dns.TypeA,
|
|
||||||
}, {
|
|
||||||
name: "found_aaaa",
|
|
||||||
host: "example-v6.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: Rewritten,
|
|
||||||
qtype: dns.TypeAAAA,
|
|
||||||
}, {
|
|
||||||
name: "found_txt",
|
|
||||||
host: "example.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: NotFilteredNotFound,
|
|
||||||
qtype: dns.TypeTXT,
|
|
||||||
}, {
|
|
||||||
name: "cname_a",
|
|
||||||
host: "cname.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: Rewritten,
|
|
||||||
qtype: dns.TypeA,
|
|
||||||
}, {
|
|
||||||
name: "cname_aaaa",
|
|
||||||
host: "cname.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: Rewritten,
|
|
||||||
qtype: dns.TypeAAAA,
|
|
||||||
}, {
|
|
||||||
name: "cname_txt",
|
|
||||||
host: "cname.org",
|
|
||||||
wantIsFiltered: false,
|
|
||||||
wantReason: NotFilteredNotFound,
|
|
||||||
qtype: dns.TypeTXT,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
d, setts := newForTest(t, &Config{
|
|
||||||
Rewrites: rewrites,
|
|
||||||
}, nil)
|
|
||||||
t.Cleanup(d.Close)
|
|
||||||
|
|
||||||
res, err := d.CheckHost(tc.host, tc.qtype, setts)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.wantIsFiltered, res.IsFiltered)
|
|
||||||
assert.Equal(t, tc.wantReason, res.Reason)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWhitelist(t *testing.T) {
|
func TestWhitelist(t *testing.T) {
|
||||||
rules := `||host1^
|
rules := `||host1^
|
||||||
||host2^
|
||host2^
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
|||||||
},
|
},
|
||||||
ConfigModified: func() { confModifiedCalled = true },
|
ConfigModified: func() { confModifiedCalled = true },
|
||||||
DataDir: filtersDir,
|
DataDir: filtersDir,
|
||||||
}, nil)
|
}, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
|
|||||||
42
internal/filtering/rewrite.go
Normal file
42
internal/filtering/rewrite.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package filtering
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/AdguardTeam/urlfilter"
|
||||||
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RewriteStorage is a storage for rewrite rules.
|
||||||
|
type RewriteStorage interface {
|
||||||
|
// MatchRequest returns matching dnsrewrites for the specified request.
|
||||||
|
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
|
||||||
|
|
||||||
|
// Add adds item to the storage.
|
||||||
|
Add(item *RewriteItem) (err error)
|
||||||
|
|
||||||
|
// Remove deletes item from the storage.
|
||||||
|
Remove(item *RewriteItem) (err error)
|
||||||
|
|
||||||
|
// List returns all items from the storage.
|
||||||
|
List() (items []*RewriteItem)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteItem is a single DNS rewrite record.
|
||||||
|
type RewriteItem struct {
|
||||||
|
// Domain is the domain pattern for which this rewrite should work.
|
||||||
|
Domain string `yaml:"domain" json:"domain"`
|
||||||
|
|
||||||
|
// Answer is the IP address, canonical name, or one of the special
|
||||||
|
// values: "A" or "AAAA".
|
||||||
|
Answer string `yaml:"answer" json:"answer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equal returns true if rw is Equal to other.
|
||||||
|
func (rw *RewriteItem) Equal(other *RewriteItem) (ok bool) {
|
||||||
|
if rw == nil {
|
||||||
|
return other == nil
|
||||||
|
} else if other == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *rw == *other
|
||||||
|
}
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
package rewrite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Item is a single DNS rewrite record.
|
|
||||||
type Item struct {
|
|
||||||
// Domain is the domain pattern for which this rewrite should work.
|
|
||||||
Domain string `yaml:"domain" json:"domain"`
|
|
||||||
|
|
||||||
// Answer is the IP address, canonical name, or one of the special
|
|
||||||
// values: "A" or "AAAA".
|
|
||||||
Answer string `yaml:"answer" json:"answer"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// equal returns true if rw is equal to other.
|
|
||||||
func (rw *Item) equal(other *Item) (ok bool) {
|
|
||||||
if rw == nil {
|
|
||||||
return other == nil
|
|
||||||
} else if other == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return *rw == *other
|
|
||||||
}
|
|
||||||
|
|
||||||
// toRule converts rw to a filter rule.
|
|
||||||
func (rw *Item) toRule() (res string) {
|
|
||||||
if rw == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
domain := strings.ToLower(rw.Domain)
|
|
||||||
|
|
||||||
dType, exception := rw.rewriteParams()
|
|
||||||
dTypeKey := dns.TypeToString[dType]
|
|
||||||
if exception {
|
|
||||||
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewriteParams returns dns request type and exception flag for rw.
|
|
||||||
func (rw *Item) rewriteParams() (dType uint16, exception bool) {
|
|
||||||
switch rw.Answer {
|
|
||||||
case "AAAA":
|
|
||||||
return dns.TypeAAAA, true
|
|
||||||
case "A":
|
|
||||||
return dns.TypeA, true
|
|
||||||
default:
|
|
||||||
// Go on.
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := netip.ParseAddr(rw.Answer)
|
|
||||||
if err != nil {
|
|
||||||
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
|
|
||||||
return dns.TypeCNAME, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if addr.Is4() {
|
|
||||||
dType = dns.TypeA
|
|
||||||
} else {
|
|
||||||
dType = dns.TypeAAAA
|
|
||||||
}
|
|
||||||
|
|
||||||
return dType, false
|
|
||||||
}
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
package rewrite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestItem_equal(t *testing.T) {
|
|
||||||
const (
|
|
||||||
testDomain = "example.org"
|
|
||||||
testAnswer = "1.1.1.1"
|
|
||||||
)
|
|
||||||
|
|
||||||
testItem := &Item{
|
|
||||||
Domain: testDomain,
|
|
||||||
Answer: testAnswer,
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
left *Item
|
|
||||||
right *Item
|
|
||||||
want bool
|
|
||||||
}{{
|
|
||||||
name: "nil_left",
|
|
||||||
left: nil,
|
|
||||||
right: testItem,
|
|
||||||
want: false,
|
|
||||||
}, {
|
|
||||||
name: "nil_right",
|
|
||||||
left: testItem,
|
|
||||||
right: nil,
|
|
||||||
want: false,
|
|
||||||
}, {
|
|
||||||
name: "nils",
|
|
||||||
left: nil,
|
|
||||||
right: nil,
|
|
||||||
want: true,
|
|
||||||
}, {
|
|
||||||
name: "equal",
|
|
||||||
left: testItem,
|
|
||||||
right: testItem,
|
|
||||||
want: true,
|
|
||||||
}, {
|
|
||||||
name: "distinct",
|
|
||||||
left: testItem,
|
|
||||||
right: &Item{
|
|
||||||
Domain: "other",
|
|
||||||
Answer: "other",
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
res := tc.left.equal(tc.right)
|
|
||||||
assert.Equal(t, tc.want, res)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestItem_toRule(t *testing.T) {
|
|
||||||
const testDomain = "example.org"
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
item *Item
|
|
||||||
want string
|
|
||||||
}{{
|
|
||||||
name: "nil",
|
|
||||||
item: nil,
|
|
||||||
want: "",
|
|
||||||
}, {
|
|
||||||
name: "a_rule",
|
|
||||||
item: &Item{
|
|
||||||
Domain: testDomain,
|
|
||||||
Answer: "1.1.1.1",
|
|
||||||
},
|
|
||||||
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
|
|
||||||
}, {
|
|
||||||
name: "aaaa_rule",
|
|
||||||
item: &Item{
|
|
||||||
Domain: testDomain,
|
|
||||||
Answer: "1:2:3::4",
|
|
||||||
},
|
|
||||||
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
|
|
||||||
}, {
|
|
||||||
name: "cname_rule",
|
|
||||||
item: &Item{
|
|
||||||
Domain: testDomain,
|
|
||||||
Answer: "other.org",
|
|
||||||
},
|
|
||||||
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
|
||||||
}, {
|
|
||||||
name: "wildcard_rule",
|
|
||||||
item: &Item{
|
|
||||||
Domain: "*.example.org",
|
|
||||||
Answer: "other.org",
|
|
||||||
},
|
|
||||||
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
|
||||||
}, {
|
|
||||||
name: "aaaa_exception",
|
|
||||||
item: &Item{
|
|
||||||
Domain: testDomain,
|
|
||||||
Answer: "A",
|
|
||||||
},
|
|
||||||
want: "@@||example.org^$dnstype=A,dnsrewrite",
|
|
||||||
}, {
|
|
||||||
name: "aaaa_exception",
|
|
||||||
item: &Item{
|
|
||||||
Domain: testDomain,
|
|
||||||
Answer: "AAAA",
|
|
||||||
},
|
|
||||||
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
res := tc.item.toRule()
|
|
||||||
assert.Equal(t, tc.want, res)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -3,9 +3,11 @@ package rewrite
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"github.com/AdguardTeam/urlfilter"
|
"github.com/AdguardTeam/urlfilter"
|
||||||
@@ -15,21 +17,6 @@ import (
|
|||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Storage is a storage for rewrite rules.
|
|
||||||
type Storage interface {
|
|
||||||
// MatchRequest returns matching dnsrewrites for the specified request.
|
|
||||||
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
|
|
||||||
|
|
||||||
// Add adds item to the storage.
|
|
||||||
Add(item *Item) (err error)
|
|
||||||
|
|
||||||
// Remove deletes item from the storage.
|
|
||||||
Remove(item *Item) (err error)
|
|
||||||
|
|
||||||
// List returns all items from the storage.
|
|
||||||
List() (items []*Item)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultStorage is the default storage for rewrite rules.
|
// DefaultStorage is the default storage for rewrite rules.
|
||||||
type DefaultStorage struct {
|
type DefaultStorage struct {
|
||||||
// mu protects items.
|
// mu protects items.
|
||||||
@@ -42,7 +29,7 @@ type DefaultStorage struct {
|
|||||||
ruleList filterlist.RuleList
|
ruleList filterlist.RuleList
|
||||||
|
|
||||||
// rewrites stores the rewrite entries from configuration.
|
// rewrites stores the rewrite entries from configuration.
|
||||||
rewrites []*Item
|
rewrites []*filtering.RewriteItem
|
||||||
|
|
||||||
// urlFilterID is the synthetic integer identifier for the urlfilter engine.
|
// urlFilterID is the synthetic integer identifier for the urlfilter engine.
|
||||||
//
|
//
|
||||||
@@ -53,10 +40,10 @@ type DefaultStorage struct {
|
|||||||
|
|
||||||
// NewDefaultStorage returns new rewrites storage. listID is used as an
|
// NewDefaultStorage returns new rewrites storage. listID is used as an
|
||||||
// identifier of the underlying rules list. rewrites must not be nil.
|
// identifier of the underlying rules list. rewrites must not be nil.
|
||||||
func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err error) {
|
func NewDefaultStorage(rewrites []*filtering.RewriteItem) (s *DefaultStorage, err error) {
|
||||||
s = &DefaultStorage{
|
s = &DefaultStorage{
|
||||||
mu: &sync.RWMutex{},
|
mu: &sync.RWMutex{},
|
||||||
urlFilterID: listID,
|
urlFilterID: filtering.RewritesListID,
|
||||||
rewrites: rewrites,
|
rewrites: rewrites,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,9 +56,9 @@ func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ Storage = (*DefaultStorage)(nil)
|
var _ filtering.RewriteStorage = (*DefaultStorage)(nil)
|
||||||
|
|
||||||
// MatchRequest implements the [Storage] interface for *DefaultStorage.
|
// MatchRequest implements the [RewriteStorage] interface for *DefaultStorage.
|
||||||
func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) {
|
func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
@@ -160,8 +147,8 @@ func (s *DefaultStorage) rewriteRulesForReq(dReq *urlfilter.DNSRequest) (rules [
|
|||||||
return res.DNSRewrites()
|
return res.DNSRewrites()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add implements the [Storage] interface for *DefaultStorage.
|
// Add implements the [RewriteStorage] interface for *DefaultStorage.
|
||||||
func (s *DefaultStorage) Add(item *Item) (err error) {
|
func (s *DefaultStorage) Add(item *filtering.RewriteItem) (err error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
@@ -171,16 +158,16 @@ func (s *DefaultStorage) Add(item *Item) (err error) {
|
|||||||
return s.resetRules()
|
return s.resetRules()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove implements the [Storage] interface for *DefaultStorage.
|
// Remove implements the [RewriteStorage] interface for *DefaultStorage.
|
||||||
func (s *DefaultStorage) Remove(item *Item) (err error) {
|
func (s *DefaultStorage) Remove(item *filtering.RewriteItem) (err error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
arr := []*Item{}
|
arr := []*filtering.RewriteItem{}
|
||||||
|
|
||||||
// TODO(d.kolyshev): Use slices.IndexFunc + slices.Delete?
|
// TODO(d.kolyshev): Use slices.IndexFunc + slices.Delete?
|
||||||
for _, ent := range s.rewrites {
|
for _, ent := range s.rewrites {
|
||||||
if ent.equal(item) {
|
if ent.Equal(item) {
|
||||||
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
@@ -193,8 +180,8 @@ func (s *DefaultStorage) Remove(item *Item) (err error) {
|
|||||||
return s.resetRules()
|
return s.resetRules()
|
||||||
}
|
}
|
||||||
|
|
||||||
// List implements the [Storage] interface for *DefaultStorage.
|
// List implements the [RewriteStorage] interface for *DefaultStorage.
|
||||||
func (s *DefaultStorage) List() (items []*Item) {
|
func (s *DefaultStorage) List() (items []*filtering.RewriteItem) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
@@ -206,7 +193,7 @@ func (s *DefaultStorage) resetRules() (err error) {
|
|||||||
// TODO(a.garipov): Use strings.Builder.
|
// TODO(a.garipov): Use strings.Builder.
|
||||||
var rulesText []string
|
var rulesText []string
|
||||||
for _, rewrite := range s.rewrites {
|
for _, rewrite := range s.rewrites {
|
||||||
rulesText = append(rulesText, rewrite.toRule())
|
rulesText = append(rulesText, toRule(rewrite))
|
||||||
}
|
}
|
||||||
|
|
||||||
strList := &filterlist.StringRuleList{
|
strList := &filterlist.StringRuleList{
|
||||||
@@ -247,3 +234,46 @@ func matchesQType(dnsrr *rules.DNSRewrite, qt uint16) (ok bool) {
|
|||||||
func isWildcard(pat string) (res bool) {
|
func isWildcard(pat string) (res bool) {
|
||||||
return strings.HasPrefix(pat, "|*.")
|
return strings.HasPrefix(pat, "|*.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toRule converts rw to a filter rule.
|
||||||
|
func toRule(rw *filtering.RewriteItem) (res string) {
|
||||||
|
if rw == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
domain := strings.ToLower(rw.Domain)
|
||||||
|
|
||||||
|
dType, exception := rewriteParams(rw)
|
||||||
|
dTypeKey := dns.TypeToString[dType]
|
||||||
|
if exception {
|
||||||
|
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteParams returns dns request type and exception flag for rw.
|
||||||
|
func rewriteParams(rw *filtering.RewriteItem) (dType uint16, exception bool) {
|
||||||
|
switch rw.Answer {
|
||||||
|
case "AAAA":
|
||||||
|
return dns.TypeAAAA, true
|
||||||
|
case "A":
|
||||||
|
return dns.TypeA, true
|
||||||
|
default:
|
||||||
|
// Go on.
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := netip.ParseAddr(rw.Answer)
|
||||||
|
if err != nil {
|
||||||
|
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
|
||||||
|
return dns.TypeCNAME, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.Is4() {
|
||||||
|
dType = dns.TypeA
|
||||||
|
} else {
|
||||||
|
dType = dns.TypeAAAA
|
||||||
|
}
|
||||||
|
|
||||||
|
return dType, false
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/urlfilter"
|
"github.com/AdguardTeam/urlfilter"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -12,32 +13,32 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestNewDefaultStorage(t *testing.T) {
|
func TestNewDefaultStorage(t *testing.T) {
|
||||||
items := []*Item{{
|
items := []*filtering.RewriteItem{{
|
||||||
Domain: "example.com",
|
Domain: "example.com",
|
||||||
Answer: "answer.com",
|
Answer: "answer.com",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
s, err := NewDefaultStorage(-1, items)
|
s, err := NewDefaultStorage(items)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Len(t, s.List(), 1)
|
require.Len(t, s.List(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultStorage_CRUD(t *testing.T) {
|
func TestDefaultStorage_CRUD(t *testing.T) {
|
||||||
var items []*Item
|
var items []*filtering.RewriteItem
|
||||||
|
|
||||||
s, err := NewDefaultStorage(-1, items)
|
s, err := NewDefaultStorage(items)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, s.List(), 0)
|
require.Len(t, s.List(), 0)
|
||||||
|
|
||||||
item := &Item{Domain: "example.com", Answer: "answer.com"}
|
item := &filtering.RewriteItem{Domain: "example.com", Answer: "answer.com"}
|
||||||
|
|
||||||
err = s.Add(item)
|
err = s.Add(item)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
list := s.List()
|
list := s.List()
|
||||||
require.Len(t, list, 1)
|
require.Len(t, list, 1)
|
||||||
require.True(t, item.equal(list[0]))
|
require.True(t, item.Equal(list[0]))
|
||||||
|
|
||||||
err = s.Remove(item)
|
err = s.Remove(item)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -45,7 +46,7 @@ func TestDefaultStorage_CRUD(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultStorage_MatchRequest(t *testing.T) {
|
func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||||
items := []*Item{{
|
items := []*filtering.RewriteItem{{
|
||||||
// This one and below are about CNAME, A and AAAA.
|
// This one and below are about CNAME, A and AAAA.
|
||||||
Domain: "somecname",
|
Domain: "somecname",
|
||||||
Answer: "somehost.com",
|
Answer: "somehost.com",
|
||||||
@@ -101,7 +102,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
|||||||
Answer: "sub.issue4016.com",
|
Answer: "sub.issue4016.com",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
s, err := NewDefaultStorage(-1, items)
|
s, err := NewDefaultStorage(items)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -285,7 +286,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
|||||||
|
|
||||||
func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||||
// Exact host, wildcard L2, wildcard L3.
|
// Exact host, wildcard L2, wildcard L3.
|
||||||
items := []*Item{{
|
items := []*filtering.RewriteItem{{
|
||||||
Domain: "host.com",
|
Domain: "host.com",
|
||||||
Answer: "1.1.1.1",
|
Answer: "1.1.1.1",
|
||||||
}, {
|
}, {
|
||||||
@@ -296,7 +297,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
|||||||
Answer: "3.3.3.3",
|
Answer: "3.3.3.3",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
s, err := NewDefaultStorage(-1, items)
|
s, err := NewDefaultStorage(items)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -355,7 +356,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
|||||||
|
|
||||||
func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||||
// Wildcard and exception for a sub-domain.
|
// Wildcard and exception for a sub-domain.
|
||||||
items := []*Item{{
|
items := []*filtering.RewriteItem{{
|
||||||
Domain: "*.host.com",
|
Domain: "*.host.com",
|
||||||
Answer: "2.2.2.2",
|
Answer: "2.2.2.2",
|
||||||
}, {
|
}, {
|
||||||
@@ -366,7 +367,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
|||||||
Answer: "sub.host.com",
|
Answer: "sub.host.com",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
s, err := NewDefaultStorage(-1, items)
|
s, err := NewDefaultStorage(items)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -410,7 +411,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
|||||||
|
|
||||||
func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
||||||
// Two cname rules for one subdomain
|
// Two cname rules for one subdomain
|
||||||
items := []*Item{{
|
items := []*filtering.RewriteItem{{
|
||||||
Domain: "cname.org",
|
Domain: "cname.org",
|
||||||
Answer: "1.1.1.1",
|
Answer: "1.1.1.1",
|
||||||
}, {
|
}, {
|
||||||
@@ -424,7 +425,7 @@ func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
|||||||
Answer: "sub_cname.org",
|
Answer: "sub_cname.org",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
s, err := NewDefaultStorage(-1, items)
|
s, err := NewDefaultStorage(items)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -478,7 +479,7 @@ func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
|||||||
|
|
||||||
func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
||||||
// Exception for AAAA record.
|
// Exception for AAAA record.
|
||||||
items := []*Item{{
|
items := []*filtering.RewriteItem{{
|
||||||
Domain: "host.com",
|
Domain: "host.com",
|
||||||
Answer: "1.2.3.4",
|
Answer: "1.2.3.4",
|
||||||
}, {
|
}, {
|
||||||
@@ -495,7 +496,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
|||||||
Answer: "A",
|
Answer: "A",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
s, err := NewDefaultStorage(-1, items)
|
s, err := NewDefaultStorage(items)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -556,3 +557,66 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToRule(t *testing.T) {
|
||||||
|
const testDomain = "example.org"
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
item *filtering.RewriteItem
|
||||||
|
want string
|
||||||
|
}{{
|
||||||
|
name: "nil",
|
||||||
|
item: nil,
|
||||||
|
want: "",
|
||||||
|
}, {
|
||||||
|
name: "a_rule",
|
||||||
|
item: &filtering.RewriteItem{
|
||||||
|
Domain: testDomain,
|
||||||
|
Answer: "1.1.1.1",
|
||||||
|
},
|
||||||
|
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
|
||||||
|
}, {
|
||||||
|
name: "aaaa_rule",
|
||||||
|
item: &filtering.RewriteItem{
|
||||||
|
Domain: testDomain,
|
||||||
|
Answer: "1:2:3::4",
|
||||||
|
},
|
||||||
|
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
|
||||||
|
}, {
|
||||||
|
name: "cname_rule",
|
||||||
|
item: &filtering.RewriteItem{
|
||||||
|
Domain: testDomain,
|
||||||
|
Answer: "other.org",
|
||||||
|
},
|
||||||
|
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||||
|
}, {
|
||||||
|
name: "wildcard_rule",
|
||||||
|
item: &filtering.RewriteItem{
|
||||||
|
Domain: "*.example.org",
|
||||||
|
Answer: "other.org",
|
||||||
|
},
|
||||||
|
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||||
|
}, {
|
||||||
|
name: "aaaa_exception",
|
||||||
|
item: &filtering.RewriteItem{
|
||||||
|
Domain: testDomain,
|
||||||
|
Answer: "A",
|
||||||
|
},
|
||||||
|
want: "@@||example.org^$dnstype=A,dnsrewrite",
|
||||||
|
}, {
|
||||||
|
name: "aaaa_exception",
|
||||||
|
item: &filtering.RewriteItem{
|
||||||
|
Domain: testDomain,
|
||||||
|
Answer: "AAAA",
|
||||||
|
},
|
||||||
|
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
res := toRule(tc.item)
|
||||||
|
assert.Equal(t, tc.want, res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
61
internal/filtering/rewrite_test.go
Normal file
61
internal/filtering/rewrite_test.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package filtering
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestItem_equal(t *testing.T) {
|
||||||
|
const (
|
||||||
|
testDomain = "example.org"
|
||||||
|
testAnswer = "1.1.1.1"
|
||||||
|
)
|
||||||
|
|
||||||
|
testItem := &RewriteItem{
|
||||||
|
Domain: testDomain,
|
||||||
|
Answer: testAnswer,
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
left *RewriteItem
|
||||||
|
right *RewriteItem
|
||||||
|
want bool
|
||||||
|
}{{
|
||||||
|
name: "nil_left",
|
||||||
|
left: nil,
|
||||||
|
right: testItem,
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "nil_right",
|
||||||
|
left: testItem,
|
||||||
|
right: nil,
|
||||||
|
want: false,
|
||||||
|
}, {
|
||||||
|
name: "nils",
|
||||||
|
left: nil,
|
||||||
|
right: nil,
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "equal",
|
||||||
|
left: testItem,
|
||||||
|
right: testItem,
|
||||||
|
want: true,
|
||||||
|
}, {
|
||||||
|
name: "distinct",
|
||||||
|
left: testItem,
|
||||||
|
right: &RewriteItem{
|
||||||
|
Domain: "other",
|
||||||
|
Answer: "other",
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
res := tc.left.Equal(tc.right)
|
||||||
|
assert.Equal(t, tc.want, res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,7 +15,7 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
|
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
|
||||||
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||||
rw := &rewrite.Item{}
|
rw := &RewriteItem{}
|
||||||
err := json.NewDecoder(r.Body).Decode(rw)
|
err := json.NewDecoder(r.Body).Decode(rw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||||
@@ -43,7 +42,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
|||||||
// handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP
|
// handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP
|
||||||
// API.
|
// API.
|
||||||
func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
|
||||||
entDel := rewrite.Item{}
|
entDel := RewriteItem{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&entDel)
|
err := json.NewDecoder(r.Body).Decode(&entDel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
@@ -76,30 +77,21 @@ func initDNSServer() (err error) {
|
|||||||
}
|
}
|
||||||
Context.queryLog = querylog.New(conf)
|
Context.queryLog = querylog.New(conf)
|
||||||
|
|
||||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
rewriteStorage, err := rewrite.NewDefaultStorage(config.DNS.DnsfilterConf.Rewrites)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("rewrites: init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil, rewriteStorage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, since it's informative enough as is.
|
// Don't wrap the error, since it's informative enough as is.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var privateNets netutil.SubnetSet
|
privateNets, err := initPrivateNets()
|
||||||
switch len(config.DNS.PrivateNets) {
|
if err != nil {
|
||||||
case 0:
|
// Don't wrap the error, since it's informative enough as is.
|
||||||
// Use an optimized locally-served matcher.
|
return err
|
||||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
|
||||||
case 1:
|
|
||||||
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
var nets []*net.IPNet
|
|
||||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
privateNets = netutil.SliceSubnetSet(nets)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
p := dnsforward.DNSCreateParams{
|
p := dnsforward.DNSCreateParams{
|
||||||
@@ -146,6 +138,29 @@ func initDNSServer() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initPrivateNets() (privateNets netutil.SubnetSet, err error) {
|
||||||
|
switch len(config.DNS.PrivateNets) {
|
||||||
|
case 0:
|
||||||
|
// Use an optimized locally-served matcher.
|
||||||
|
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||||
|
case 1:
|
||||||
|
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
var nets []*net.IPNet
|
||||||
|
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateNets = netutil.SliceSubnetSet(nets)
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateNets, nil
|
||||||
|
}
|
||||||
|
|
||||||
func isRunning() bool {
|
func isRunning() bool {
|
||||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user