diff --git a/internal/filtering/dnsrewrite.go b/internal/filtering/dnsrewrite.go index 58a8dd7b..ee93104b 100644 --- a/internal/filtering/dnsrewrite.go +++ b/internal/filtering/dnsrewrite.go @@ -3,6 +3,7 @@ package filtering import ( "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" + "golang.org/x/exp/slices" ) // DNSRewriteResult is the result of application of $dnsrewrite rules. @@ -24,7 +25,13 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) { Response: DNSRewriteResultResponse{}, } + slices.SortFunc(dnsr, rewriteSortsBefore) + for _, nr := range dnsr { + if containsWildcard(nr) { + break + } + dr := nr.DNSRewrite if dr.NewCNAME != "" { // NewCNAME rules have a higher priority than other rules. @@ -73,3 +80,19 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) { Reason: RewrittenRule, } } + +func rewriteSortsBefore(a, b *rules.NetworkRule) (sortsBefore bool) { + return len(a.Shortcut) > len(b.Shortcut) +} + +func containsWildcard(r *rules.NetworkRule) (ok bool) { + for _, c := range r.RuleText { + if c == '*' { + return true + } else if c == '^' { + break + } + } + + return false +} diff --git a/internal/filtering/dnsrewrite_test.go b/internal/filtering/dnsrewrite_test.go index c75ea2b9..7d9ae867 100644 --- a/internal/filtering/dnsrewrite_test.go +++ b/internal/filtering/dnsrewrite_test.go @@ -5,6 +5,7 @@ import ( "path" "testing" + "github.com/AdguardTeam/urlfilter" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -202,3 +203,32 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { assert.Equal(t, "new-ptr-with-dot.", ptr) }) } + +func TestDNSFilter_processDNSRewrites(t *testing.T) { + const text = ` +|www.example.com^$dnsrewrite=127.0.0.1 +|*.example.com^$dnsrewrite=127.0.0.2 +` + + host := "www.example.com" + rrtype := dns.TypeA + + f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}}) + setts := &Settings{ + FilteringEnabled: true, + } + + ufReq := &urlfilter.DNSRequest{ + Hostname: host, + SortedClientTags: setts.ClientTags, + ClientIP: setts.ClientIP.String(), + ClientName: setts.ClientName, + DNSType: rrtype, + } + + dres, matched := f.filteringEngine.MatchRequest(ufReq) + require.False(t, matched) + + res := f.processDNSResultRewrites(dres, host) + assert.Len(t, res.Rules, 1) +}