diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 60044e13..19ec6f50 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -68,7 +68,7 @@ func createTestServer( ID: 0, Data: []byte(rules), }} - f, err := filtering.New(filterConf, filters) + f, err := filtering.New(filterConf, filters, nil) require.NoError(t, err) f.SetEnabled(true) @@ -761,7 +761,7 @@ func TestBlockedCustomIP(t *testing.T) { Data: []byte(rules), }} - f, err := filtering.New(&filtering.Config{}, filters) + f, err := filtering.New(&filtering.Config{}, filters, nil) require.NoError(t, err) s, err := NewServer(DNSCreateParams{ @@ -881,7 +881,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { func TestRewrite(t *testing.T) { c := &filtering.Config{ - Rewrites: []*rewrite.Item{{ + Rewrites: []*filtering.RewriteItem{{ Domain: "test.com", Answer: "1.2.3.4", }, { @@ -892,7 +892,11 @@ func TestRewrite(t *testing.T) { Answer: "example.org", }}, } - f, err := filtering.New(c, nil) + + rewriteStorage, err := rewrite.NewDefaultStorage(c.Rewrites) + require.NoError(t, err) + + f, err := filtering.New(c, nil, rewriteStorage) require.NoError(t, err) f.SetEnabled(true) @@ -943,6 +947,12 @@ func TestRewrite(t *testing.T) { assert.Empty(t, reply.Answer) + req = createTestMessageWithType("test.com.", dns.TypeTXT) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) + + assert.Empty(t, reply.Answer) + req = createTestMessageWithType("alias.test.com.", dns.TypeA) reply, eerr = dns.Exchange(req, addr.String()) require.NoError(t, eerr) @@ -953,6 +963,12 @@ func TestRewrite(t *testing.T) { assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) + req = createTestMessageWithType("alias.test.com.", dns.TypeTXT) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) + + assert.Empty(t, reply.Answer) + req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) reply, eerr = dns.Exchange(req, addr.String()) require.NoError(t, eerr) @@ -966,6 +982,12 @@ func TestRewrite(t *testing.T) { assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) + + req = createTestMessageWithType("my.alias.test.com.", dns.TypeTXT) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) + + assert.Empty(t, reply.Answer) } for _, protect := range []bool{true, false} { @@ -1010,7 +1032,7 @@ var testDHCP = &dhcpd.MockInterface{ func TestPTRResponseFromDHCPLeases(t *testing.T) { const localDomain = "lan" - flt, err := filtering.New(&filtering.Config{}, nil) + flt, err := filtering.New(&filtering.Config{}, nil, nil) require.NoError(t, err) s, err := NewServer(DNSCreateParams{ @@ -1084,7 +1106,7 @@ func TestPTRResponseFromHosts(t *testing.T) { flt, err := filtering.New(&filtering.Config{ EtcHosts: hc, - }, nil) + }, nil, nil) require.NoError(t, err) flt.SetEnabled(true) diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go index 7fa0985a..a5db9879 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_test.go @@ -35,7 +35,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { ID: 0, Data: []byte(rules), }} - f, err := filtering.New(&filtering.Config{}, filters) + f, err := filtering.New(&filtering.Config{}, filters, nil) require.NoError(t, err) f.SetEnabled(true) diff --git a/internal/filtering/filter_test.go b/internal/filtering/filter_test.go index 53e846fc..087e3c52 100644 --- a/internal/filtering/filter_test.go +++ b/internal/filtering/filter_test.go @@ -68,7 +68,7 @@ func TestFilters(t *testing.T) { HTTPClient: &http.Client{ Timeout: 5 * time.Second, }, - }, nil) + }, nil, nil) require.NoError(t, err) f := &FilterYAML{ diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index bf7795bf..d682e917 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -18,7 +18,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" - "github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/errors" @@ -91,7 +90,7 @@ type Config struct { ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes) CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes) - Rewrites []*rewrite.Item `yaml:"rewrites"` + Rewrites []*RewriteItem `yaml:"rewrites"` // Names of services to block (globally). // Per-client settings can override this configuration. @@ -195,7 +194,7 @@ type DNSFilter struct { // TODO(e.burkov): Don't use regexp for such a simple text processing task. filterTitleRegexp *regexp.Regexp - rewriteStorage *rewrite.DefaultStorage + rewriteStorage RewriteStorage hostCheckers []hostChecker } @@ -544,6 +543,10 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) { d.confLock.RLock() defer d.confLock.RUnlock() + if d.rewriteStorage == nil { + return res + } + dnsr := d.rewriteStorage.MatchRequest(&urlfilter.DNSRequest{ Hostname: host, DNSType: qtype, @@ -893,7 +896,7 @@ func InitModule() { // New creates properly initialized DNS Filter that is ready to be used. c must // be non-nil. -func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { +func New(c *Config, blockFilters []Filter, rewriteStorage RewriteStorage) (d *DNSFilter, err error) { d = &DNSFilter{ resolver: net.DefaultResolver, refreshLock: &sync.Mutex{}, @@ -946,11 +949,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { d.Config = *c d.filtersMu = &sync.RWMutex{} - - d.rewriteStorage, err = rewrite.NewDefaultStorage(RewritesListID, d.Rewrites) - if err != nil { - return nil, fmt.Errorf("rewrites: init: %w", err) - } + d.rewriteStorage = rewriteStorage bsvcs := []string{} for _, s := range d.BlockedServices { diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go index dd74ed84..d365150c 100644 --- a/internal/filtering/filtering_test.go +++ b/internal/filtering/filtering_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/testutil" @@ -47,6 +46,7 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts ProtectionEnabled: true, FilteringEnabled: true, } + if c != nil { c.SafeBrowsingCacheSize = 10000 c.ParentalCacheSize = 10000 @@ -59,7 +59,8 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts // It must not be nil. c = &Config{} } - f, err := New(c, filters) + + f, err := New(c, filters, nil) require.NoError(t, err) purgeCaches(f) @@ -695,96 +696,6 @@ func TestMatching(t *testing.T) { } } -func TestRewrites(t *testing.T) { - rewrites := []*rewrite.Item{{ - Domain: "example.org", - Answer: "1.1.1.1", - }, { - Domain: "example-v6.org", - Answer: "1:2:3::4", - }, { - Domain: "cname.org", - Answer: "cname-res.org", - }} - - testCases := []struct { - name string - host string - wantReason Reason - wantIsFiltered bool - qtype uint16 - }{{ - name: "not_found_a", - host: "not-example.org", - wantIsFiltered: false, - wantReason: NotFilteredNotFound, - qtype: dns.TypeA, - }, { - name: "not_found_aaaa", - host: "not-example.org", - wantIsFiltered: false, - wantReason: NotFilteredNotFound, - qtype: dns.TypeAAAA, - }, { - name: "not_found_txt", - host: "not-example.org", - wantIsFiltered: false, - wantReason: NotFilteredNotFound, - qtype: dns.TypeTXT, - }, { - name: "found_a", - host: "example.org", - wantIsFiltered: false, - wantReason: Rewritten, - qtype: dns.TypeA, - }, { - name: "found_aaaa", - host: "example-v6.org", - wantIsFiltered: false, - wantReason: Rewritten, - qtype: dns.TypeAAAA, - }, { - name: "found_txt", - host: "example.org", - wantIsFiltered: false, - wantReason: NotFilteredNotFound, - qtype: dns.TypeTXT, - }, { - name: "cname_a", - host: "cname.org", - wantIsFiltered: false, - wantReason: Rewritten, - qtype: dns.TypeA, - }, { - name: "cname_aaaa", - host: "cname.org", - wantIsFiltered: false, - wantReason: Rewritten, - qtype: dns.TypeAAAA, - }, { - name: "cname_txt", - host: "cname.org", - wantIsFiltered: false, - wantReason: NotFilteredNotFound, - qtype: dns.TypeTXT, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d, setts := newForTest(t, &Config{ - Rewrites: rewrites, - }, nil) - t.Cleanup(d.Close) - - res, err := d.CheckHost(tc.host, tc.qtype, setts) - require.NoError(t, err) - - assert.Equal(t, tc.wantIsFiltered, res.IsFiltered) - assert.Equal(t, tc.wantReason, res.Reason) - }) - } -} - func TestWhitelist(t *testing.T) { rules := `||host1^ ||host2^ diff --git a/internal/filtering/http_test.go b/internal/filtering/http_test.go index df09c3f9..8cc038b1 100644 --- a/internal/filtering/http_test.go +++ b/internal/filtering/http_test.go @@ -105,7 +105,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) { }, ConfigModified: func() { confModifiedCalled = true }, DataDir: filtersDir, - }, nil) + }, nil, nil) require.NoError(t, err) t.Cleanup(d.Close) diff --git a/internal/filtering/rewrite.go b/internal/filtering/rewrite.go new file mode 100644 index 00000000..e668ce61 --- /dev/null +++ b/internal/filtering/rewrite.go @@ -0,0 +1,42 @@ +package filtering + +import ( + "github.com/AdguardTeam/urlfilter" + "github.com/AdguardTeam/urlfilter/rules" +) + +// RewriteStorage is a storage for rewrite rules. +type RewriteStorage interface { + // MatchRequest returns matching dnsrewrites for the specified request. + MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) + + // Add adds item to the storage. + Add(item *RewriteItem) (err error) + + // Remove deletes item from the storage. + Remove(item *RewriteItem) (err error) + + // List returns all items from the storage. + List() (items []*RewriteItem) +} + +// RewriteItem is a single DNS rewrite record. +type RewriteItem struct { + // Domain is the domain pattern for which this rewrite should work. + Domain string `yaml:"domain" json:"domain"` + + // Answer is the IP address, canonical name, or one of the special + // values: "A" or "AAAA". + Answer string `yaml:"answer" json:"answer"` +} + +// Equal returns true if rw is Equal to other. +func (rw *RewriteItem) Equal(other *RewriteItem) (ok bool) { + if rw == nil { + return other == nil + } else if other == nil { + return false + } + + return *rw == *other +} diff --git a/internal/filtering/rewrite/item.go b/internal/filtering/rewrite/item.go deleted file mode 100644 index f5fbe1cf..00000000 --- a/internal/filtering/rewrite/item.go +++ /dev/null @@ -1,73 +0,0 @@ -package rewrite - -import ( - "fmt" - "net/netip" - "strings" - - "github.com/miekg/dns" -) - -// Item is a single DNS rewrite record. -type Item struct { - // Domain is the domain pattern for which this rewrite should work. - Domain string `yaml:"domain" json:"domain"` - - // Answer is the IP address, canonical name, or one of the special - // values: "A" or "AAAA". - Answer string `yaml:"answer" json:"answer"` -} - -// equal returns true if rw is equal to other. -func (rw *Item) equal(other *Item) (ok bool) { - if rw == nil { - return other == nil - } else if other == nil { - return false - } - - return *rw == *other -} - -// toRule converts rw to a filter rule. -func (rw *Item) toRule() (res string) { - if rw == nil { - return "" - } - - domain := strings.ToLower(rw.Domain) - - dType, exception := rw.rewriteParams() - dTypeKey := dns.TypeToString[dType] - if exception { - return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey) - } - - return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer) -} - -// rewriteParams returns dns request type and exception flag for rw. -func (rw *Item) rewriteParams() (dType uint16, exception bool) { - switch rw.Answer { - case "AAAA": - return dns.TypeAAAA, true - case "A": - return dns.TypeA, true - default: - // Go on. - } - - addr, err := netip.ParseAddr(rw.Answer) - if err != nil { - // TODO(d.kolyshev): Validate rw.Answer as a domain name. - return dns.TypeCNAME, false - } - - if addr.Is4() { - dType = dns.TypeA - } else { - dType = dns.TypeAAAA - } - - return dType, false -} diff --git a/internal/filtering/rewrite/item_internal_test.go b/internal/filtering/rewrite/item_internal_test.go deleted file mode 100644 index 68d88223..00000000 --- a/internal/filtering/rewrite/item_internal_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package rewrite - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestItem_equal(t *testing.T) { - const ( - testDomain = "example.org" - testAnswer = "1.1.1.1" - ) - - testItem := &Item{ - Domain: testDomain, - Answer: testAnswer, - } - - testCases := []struct { - name string - left *Item - right *Item - want bool - }{{ - name: "nil_left", - left: nil, - right: testItem, - want: false, - }, { - name: "nil_right", - left: testItem, - right: nil, - want: false, - }, { - name: "nils", - left: nil, - right: nil, - want: true, - }, { - name: "equal", - left: testItem, - right: testItem, - want: true, - }, { - name: "distinct", - left: testItem, - right: &Item{ - Domain: "other", - Answer: "other", - }, - want: false, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - res := tc.left.equal(tc.right) - assert.Equal(t, tc.want, res) - }) - } -} - -func TestItem_toRule(t *testing.T) { - const testDomain = "example.org" - - testCases := []struct { - name string - item *Item - want string - }{{ - name: "nil", - item: nil, - want: "", - }, { - name: "a_rule", - item: &Item{ - Domain: testDomain, - Answer: "1.1.1.1", - }, - want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1", - }, { - name: "aaaa_rule", - item: &Item{ - Domain: testDomain, - Answer: "1:2:3::4", - }, - want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4", - }, { - name: "cname_rule", - item: &Item{ - Domain: testDomain, - Answer: "other.org", - }, - want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org", - }, { - name: "wildcard_rule", - item: &Item{ - Domain: "*.example.org", - Answer: "other.org", - }, - want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org", - }, { - name: "aaaa_exception", - item: &Item{ - Domain: testDomain, - Answer: "A", - }, - want: "@@||example.org^$dnstype=A,dnsrewrite", - }, { - name: "aaaa_exception", - item: &Item{ - Domain: testDomain, - Answer: "AAAA", - }, - want: "@@||example.org^$dnstype=AAAA,dnsrewrite", - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - res := tc.item.toRule() - assert.Equal(t, tc.want, res) - }) - } -} diff --git a/internal/filtering/rewrite/storage.go b/internal/filtering/rewrite/storage.go index 221592e0..4c455f00 100644 --- a/internal/filtering/rewrite/storage.go +++ b/internal/filtering/rewrite/storage.go @@ -3,9 +3,11 @@ package rewrite import ( "fmt" + "net/netip" "strings" "sync" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/urlfilter" @@ -15,21 +17,6 @@ import ( "golang.org/x/exp/slices" ) -// Storage is a storage for rewrite rules. -type Storage interface { - // MatchRequest returns matching dnsrewrites for the specified request. - MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) - - // Add adds item to the storage. - Add(item *Item) (err error) - - // Remove deletes item from the storage. - Remove(item *Item) (err error) - - // List returns all items from the storage. - List() (items []*Item) -} - // DefaultStorage is the default storage for rewrite rules. type DefaultStorage struct { // mu protects items. @@ -42,7 +29,7 @@ type DefaultStorage struct { ruleList filterlist.RuleList // rewrites stores the rewrite entries from configuration. - rewrites []*Item + rewrites []*filtering.RewriteItem // urlFilterID is the synthetic integer identifier for the urlfilter engine. // @@ -53,10 +40,10 @@ type DefaultStorage struct { // NewDefaultStorage returns new rewrites storage. listID is used as an // identifier of the underlying rules list. rewrites must not be nil. -func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err error) { +func NewDefaultStorage(rewrites []*filtering.RewriteItem) (s *DefaultStorage, err error) { s = &DefaultStorage{ mu: &sync.RWMutex{}, - urlFilterID: listID, + urlFilterID: filtering.RewritesListID, rewrites: rewrites, } @@ -69,9 +56,9 @@ func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err err } // type check -var _ Storage = (*DefaultStorage)(nil) +var _ filtering.RewriteStorage = (*DefaultStorage)(nil) -// MatchRequest implements the [Storage] interface for *DefaultStorage. +// MatchRequest implements the [RewriteStorage] interface for *DefaultStorage. func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) { s.mu.RLock() defer s.mu.RUnlock() @@ -160,8 +147,8 @@ func (s *DefaultStorage) rewriteRulesForReq(dReq *urlfilter.DNSRequest) (rules [ return res.DNSRewrites() } -// Add implements the [Storage] interface for *DefaultStorage. -func (s *DefaultStorage) Add(item *Item) (err error) { +// Add implements the [RewriteStorage] interface for *DefaultStorage. +func (s *DefaultStorage) Add(item *filtering.RewriteItem) (err error) { s.mu.Lock() defer s.mu.Unlock() @@ -171,16 +158,16 @@ func (s *DefaultStorage) Add(item *Item) (err error) { return s.resetRules() } -// Remove implements the [Storage] interface for *DefaultStorage. -func (s *DefaultStorage) Remove(item *Item) (err error) { +// Remove implements the [RewriteStorage] interface for *DefaultStorage. +func (s *DefaultStorage) Remove(item *filtering.RewriteItem) (err error) { s.mu.Lock() defer s.mu.Unlock() - arr := []*Item{} + arr := []*filtering.RewriteItem{} // TODO(d.kolyshev): Use slices.IndexFunc + slices.Delete? for _, ent := range s.rewrites { - if ent.equal(item) { + if ent.Equal(item) { log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer) continue @@ -193,8 +180,8 @@ func (s *DefaultStorage) Remove(item *Item) (err error) { return s.resetRules() } -// List implements the [Storage] interface for *DefaultStorage. -func (s *DefaultStorage) List() (items []*Item) { +// List implements the [RewriteStorage] interface for *DefaultStorage. +func (s *DefaultStorage) List() (items []*filtering.RewriteItem) { s.mu.RLock() defer s.mu.RUnlock() @@ -206,7 +193,7 @@ func (s *DefaultStorage) resetRules() (err error) { // TODO(a.garipov): Use strings.Builder. var rulesText []string for _, rewrite := range s.rewrites { - rulesText = append(rulesText, rewrite.toRule()) + rulesText = append(rulesText, toRule(rewrite)) } strList := &filterlist.StringRuleList{ @@ -247,3 +234,46 @@ func matchesQType(dnsrr *rules.DNSRewrite, qt uint16) (ok bool) { func isWildcard(pat string) (res bool) { return strings.HasPrefix(pat, "|*.") } + +// toRule converts rw to a filter rule. +func toRule(rw *filtering.RewriteItem) (res string) { + if rw == nil { + return "" + } + + domain := strings.ToLower(rw.Domain) + + dType, exception := rewriteParams(rw) + dTypeKey := dns.TypeToString[dType] + if exception { + return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey) + } + + return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer) +} + +// RewriteParams returns dns request type and exception flag for rw. +func rewriteParams(rw *filtering.RewriteItem) (dType uint16, exception bool) { + switch rw.Answer { + case "AAAA": + return dns.TypeAAAA, true + case "A": + return dns.TypeA, true + default: + // Go on. + } + + addr, err := netip.ParseAddr(rw.Answer) + if err != nil { + // TODO(d.kolyshev): Validate rw.Answer as a domain name. + return dns.TypeCNAME, false + } + + if addr.Is4() { + dType = dns.TypeA + } else { + dType = dns.TypeAAAA + } + + return dType, false +} diff --git a/internal/filtering/rewrite/storage_test.go b/internal/filtering/rewrite/storage_test.go index 1682e91f..4e0ab7bd 100644 --- a/internal/filtering/rewrite/storage_test.go +++ b/internal/filtering/rewrite/storage_test.go @@ -4,6 +4,7 @@ import ( "net" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -12,32 +13,32 @@ import ( ) func TestNewDefaultStorage(t *testing.T) { - items := []*Item{{ + items := []*filtering.RewriteItem{{ Domain: "example.com", Answer: "answer.com", }} - s, err := NewDefaultStorage(-1, items) + s, err := NewDefaultStorage(items) require.NoError(t, err) require.Len(t, s.List(), 1) } func TestDefaultStorage_CRUD(t *testing.T) { - var items []*Item + var items []*filtering.RewriteItem - s, err := NewDefaultStorage(-1, items) + s, err := NewDefaultStorage(items) require.NoError(t, err) require.Len(t, s.List(), 0) - item := &Item{Domain: "example.com", Answer: "answer.com"} + item := &filtering.RewriteItem{Domain: "example.com", Answer: "answer.com"} err = s.Add(item) require.NoError(t, err) list := s.List() require.Len(t, list, 1) - require.True(t, item.equal(list[0])) + require.True(t, item.Equal(list[0])) err = s.Remove(item) require.NoError(t, err) @@ -45,7 +46,7 @@ func TestDefaultStorage_CRUD(t *testing.T) { } func TestDefaultStorage_MatchRequest(t *testing.T) { - items := []*Item{{ + items := []*filtering.RewriteItem{{ // This one and below are about CNAME, A and AAAA. Domain: "somecname", Answer: "somehost.com", @@ -101,7 +102,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) { Answer: "sub.issue4016.com", }} - s, err := NewDefaultStorage(-1, items) + s, err := NewDefaultStorage(items) require.NoError(t, err) testCases := []struct { @@ -285,7 +286,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) { func TestDefaultStorage_MatchRequest_Levels(t *testing.T) { // Exact host, wildcard L2, wildcard L3. - items := []*Item{{ + items := []*filtering.RewriteItem{{ Domain: "host.com", Answer: "1.1.1.1", }, { @@ -296,7 +297,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) { Answer: "3.3.3.3", }} - s, err := NewDefaultStorage(-1, items) + s, err := NewDefaultStorage(items) require.NoError(t, err) testCases := []struct { @@ -355,7 +356,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) { func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) { // Wildcard and exception for a sub-domain. - items := []*Item{{ + items := []*filtering.RewriteItem{{ Domain: "*.host.com", Answer: "2.2.2.2", }, { @@ -366,7 +367,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) { Answer: "sub.host.com", }} - s, err := NewDefaultStorage(-1, items) + s, err := NewDefaultStorage(items) require.NoError(t, err) testCases := []struct { @@ -410,7 +411,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) { func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) { // Two cname rules for one subdomain - items := []*Item{{ + items := []*filtering.RewriteItem{{ Domain: "cname.org", Answer: "1.1.1.1", }, { @@ -424,7 +425,7 @@ func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) { Answer: "sub_cname.org", }} - s, err := NewDefaultStorage(-1, items) + s, err := NewDefaultStorage(items) require.NoError(t, err) testCases := []struct { @@ -478,7 +479,7 @@ func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) { func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) { // Exception for AAAA record. - items := []*Item{{ + items := []*filtering.RewriteItem{{ Domain: "host.com", Answer: "1.2.3.4", }, { @@ -495,7 +496,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) { Answer: "A", }} - s, err := NewDefaultStorage(-1, items) + s, err := NewDefaultStorage(items) require.NoError(t, err) testCases := []struct { @@ -556,3 +557,66 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) { }) } } + +func TestToRule(t *testing.T) { + const testDomain = "example.org" + + testCases := []struct { + name string + item *filtering.RewriteItem + want string + }{{ + name: "nil", + item: nil, + want: "", + }, { + name: "a_rule", + item: &filtering.RewriteItem{ + Domain: testDomain, + Answer: "1.1.1.1", + }, + want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1", + }, { + name: "aaaa_rule", + item: &filtering.RewriteItem{ + Domain: testDomain, + Answer: "1:2:3::4", + }, + want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4", + }, { + name: "cname_rule", + item: &filtering.RewriteItem{ + Domain: testDomain, + Answer: "other.org", + }, + want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org", + }, { + name: "wildcard_rule", + item: &filtering.RewriteItem{ + Domain: "*.example.org", + Answer: "other.org", + }, + want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org", + }, { + name: "aaaa_exception", + item: &filtering.RewriteItem{ + Domain: testDomain, + Answer: "A", + }, + want: "@@||example.org^$dnstype=A,dnsrewrite", + }, { + name: "aaaa_exception", + item: &filtering.RewriteItem{ + Domain: testDomain, + Answer: "AAAA", + }, + want: "@@||example.org^$dnstype=AAAA,dnsrewrite", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := toRule(tc.item) + assert.Equal(t, tc.want, res) + }) + } +} diff --git a/internal/filtering/rewrite_test.go b/internal/filtering/rewrite_test.go new file mode 100644 index 00000000..cf9b98d1 --- /dev/null +++ b/internal/filtering/rewrite_test.go @@ -0,0 +1,61 @@ +package filtering + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestItem_equal(t *testing.T) { + const ( + testDomain = "example.org" + testAnswer = "1.1.1.1" + ) + + testItem := &RewriteItem{ + Domain: testDomain, + Answer: testAnswer, + } + + testCases := []struct { + name string + left *RewriteItem + right *RewriteItem + want bool + }{{ + name: "nil_left", + left: nil, + right: testItem, + want: false, + }, { + name: "nil_right", + left: testItem, + right: nil, + want: false, + }, { + name: "nils", + left: nil, + right: nil, + want: true, + }, { + name: "equal", + left: testItem, + right: testItem, + want: true, + }, { + name: "distinct", + left: testItem, + right: &RewriteItem{ + Domain: "other", + Answer: "other", + }, + want: false, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := tc.left.Equal(tc.right) + assert.Equal(t, tc.want, res) + }) + } +} diff --git a/internal/filtering/rewritehttp.go b/internal/filtering/rewritehttp.go index efe6f46a..08b418b4 100644 --- a/internal/filtering/rewritehttp.go +++ b/internal/filtering/rewritehttp.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" - "github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite" "github.com/AdguardTeam/golibs/log" ) @@ -16,7 +15,7 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) { // handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API. func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { - rw := &rewrite.Item{} + rw := &RewriteItem{} err := json.NewDecoder(r.Body).Decode(rw) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err) @@ -43,7 +42,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { // handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP // API. func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) { - entDel := rewrite.Item{} + entDel := RewriteItem{} err := json.NewDecoder(r.Body).Decode(&entDel) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err) diff --git a/internal/home/dns.go b/internal/home/dns.go index 1980b252..9ec5a3f1 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" @@ -76,30 +77,21 @@ func initDNSServer() (err error) { } Context.queryLog = querylog.New(conf) - Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil) + rewriteStorage, err := rewrite.NewDefaultStorage(config.DNS.DnsfilterConf.Rewrites) + if err != nil { + return fmt.Errorf("rewrites: init: %w", err) + } + + Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil, rewriteStorage) if err != nil { // Don't wrap the error, since it's informative enough as is. return err } - var privateNets netutil.SubnetSet - switch len(config.DNS.PrivateNets) { - case 0: - // Use an optimized locally-served matcher. - privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) - case 1: - privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0]) - if err != nil { - return fmt.Errorf("preparing the set of private subnets: %w", err) - } - default: - var nets []*net.IPNet - nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...) - if err != nil { - return fmt.Errorf("preparing the set of private subnets: %w", err) - } - - privateNets = netutil.SliceSubnetSet(nets) + privateNets, err := initPrivateNets() + if err != nil { + // Don't wrap the error, since it's informative enough as is. + return err } p := dnsforward.DNSCreateParams{ @@ -146,6 +138,29 @@ func initDNSServer() (err error) { return nil } +func initPrivateNets() (privateNets netutil.SubnetSet, err error) { + switch len(config.DNS.PrivateNets) { + case 0: + // Use an optimized locally-served matcher. + privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) + case 1: + privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0]) + if err != nil { + return nil, fmt.Errorf("preparing the set of private subnets: %w", err) + } + default: + var nets []*net.IPNet + nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...) + if err != nil { + return nil, fmt.Errorf("preparing the set of private subnets: %w", err) + } + + privateNets = netutil.SliceSubnetSet(nets) + } + + return privateNets, nil +} + func isRunning() bool { return Context.dnsServer != nil && Context.dnsServer.IsRunning() }