Merge branch 'master' into 3717-fix-qq-blocked

This commit is contained in:
Ainar Garipov
2022-03-07 19:11:03 +03:00
232 changed files with 12754 additions and 8110 deletions

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
)
@@ -239,7 +240,7 @@ func initBlockedServices() {
for _, s := range serviceRulesArray {
netRules := []*rules.NetworkRule{}
for _, text := range s.rules {
rule, err := rules.NewNetworkRule(text, 0)
rule, err := rules.NewNetworkRule(text, BlockedSvcsListID)
if err != nil {
log.Error("rules.NewNetworkRule: %s rule: %s", err, text)
continue
@@ -287,7 +288,8 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(list)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return
}
}
@@ -296,7 +298,8 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
list := []string{}
err := json.NewDecoder(r.Body).Decode(&list)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}

View File

@@ -15,14 +15,10 @@ type DNSRewriteResult struct {
// the server returns.
type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns
// an empty result if dnsr is empty. Otherwise, the result will have
// either CanonName or DNSRewriteResult set.
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty
// result if dnsr is empty. Otherwise, the result will have either CanonName or
// DNSRewriteResult set. dnsr is expected to be non-empty.
func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
if len(dnsr) == 0 {
return Result{}
}
var rules []*ResultRule
dnsrr := &DNSRewriteResult{
Response: DNSRewriteResultResponse{},
@@ -31,8 +27,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
for _, nr := range dnsr {
dr := nr.DNSRewrite
if dr.NewCNAME != "" {
// NewCNAME rules have a higher priority than
// the other rules.
// NewCNAME rules have a higher priority than other rules.
rules = []*ResultRule{{
FilterListID: int64(nr.GetFilterListID()),
Text: nr.RuleText,
@@ -54,8 +49,8 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
Text: nr.RuleText,
})
default:
// RcodeRefused and other such codes have higher
// priority. Return immediately.
// RcodeRefused and other such codes have higher priority. Return
// immediately.
rules = []*ResultRule{{
FilterListID: int64(nr.GetFilterListID()),
Text: nr.RuleText,

View File

@@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
`
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &Settings{
FilteringEnabled: true,
}

View File

@@ -4,6 +4,7 @@ package filtering
import (
"context"
"fmt"
"io/fs"
"net"
"net/http"
"os"
@@ -16,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
@@ -24,6 +26,18 @@ import (
"github.com/miekg/dns"
)
// The IDs of built-in filter lists.
//
// Keep in sync with client/src/helpers/constants.js.
const (
CustomListID = -iota
SysHostsListID
BlockedSvcsListID
ParentalListID
SafeBrowsingListID
SafeSearchListID
)
// ServiceEntry - blocked service array element
type ServiceEntry struct {
Name string
@@ -38,6 +52,7 @@ type Settings struct {
ServicesRules []ServiceEntry
ProtectionEnabled bool
FilteringEnabled bool
SafeSearchEnabled bool
SafeBrowsingEnabled bool
@@ -65,7 +80,7 @@ type Config struct {
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
Rewrites []RewriteEntry `yaml:"rewrites"`
Rewrites []*LegacyRewrite `yaml:"rewrites"`
// Names of services to block (globally).
// Per-client settings can override this configuration.
@@ -73,7 +88,7 @@ type Config struct {
// EtcHosts is a container of IP-hostname pairs taken from the operating
// system configuration files (e.g. /etc/hosts).
EtcHosts *aghnet.EtcHostsContainer `yaml:"-"`
EtcHosts *aghnet.HostsContainer `yaml:"-"`
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
@@ -124,6 +139,10 @@ type DNSFilter struct {
parentalUpstream upstream.Upstream
safeBrowsingUpstream upstream.Upstream
safebrowsingCache cache.Cache
parentalCache cache.Cache
safeSearchCache cache.Cache
Config // for direct access by library users, even a = assignment
// confLock protects Config.
confLock sync.RWMutex
@@ -142,9 +161,14 @@ type DNSFilter struct {
// Filter represents a filter list
type Filter struct {
ID int64 // auto-assigned when filter is added (see nextFilterID)
Data []byte `yaml:"-"` // List of rules divided by '\n'
FilePath string `yaml:"-"` // Path to a filtering rules file
// FilePath is the path to a filtering rules list file.
FilePath string `yaml:"-"`
// Data is the content of the file.
Data []byte `yaml:"-"`
// ID is automatically assigned when filter is added using nextFilterID.
ID int64
}
// Reason holds an enum detailing why it was filtered or not filtered
@@ -176,8 +200,8 @@ const (
// FilteredBlockedService - the host is blocked by "blocked services" settings
FilteredBlockedService
// Rewritten is returned when there was a rewrite by a legacy DNS
// rewrite rule.
// Rewritten is returned when there was a rewrite by a legacy DNS rewrite
// rule.
Rewritten
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
@@ -186,8 +210,8 @@ const (
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
//
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging
// their functionality into RewrittenRule.
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging their
// functionality into RewrittenRule.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
RewrittenRule
@@ -221,12 +245,13 @@ func (r Reason) String() string {
}
// In returns true if reasons include r.
func (r Reason) In(reasons ...Reason) bool {
func (r Reason) In(reasons ...Reason) (ok bool) {
for _, reason := range reasons {
if r == reason {
return true
}
}
return false
}
@@ -245,7 +270,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
defer d.confLock.RUnlock()
return Settings{
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1,
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
SafeSearchEnabled: d.Config.SafeSearchEnabled,
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
ParentalEnabled: d.Config.ParentalEnabled,
@@ -261,8 +286,14 @@ func (d *DNSFilter) WriteDiskConfig(c *Config) {
c.Rewrites = cloneRewrites(c.Rewrites)
}
func cloneRewrites(entries []RewriteEntry) (clone []RewriteEntry) {
return append([]RewriteEntry(nil), entries...)
// cloneRewrites returns a deep copy of entries.
func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
clone = make([]*LegacyRewrite, len(entries))
for i, rw := range entries {
clone[i] = rw.clone()
}
return clone
}
// SetFilters - set new filters (synchronously or asynchronously)
@@ -338,14 +369,6 @@ func (d *DNSFilter) reset() {
}
}
type dnsFilterContext struct {
safebrowsingCache cache.Cache
parentalCache cache.Cache
safeSearchCache cache.Cache
}
var gctx dnsFilterContext
// ResultRule contains information about applied rules.
type ResultRule struct {
// Text is the text of the rule.
@@ -371,24 +394,19 @@ type Result struct {
// Reason is the reason for blocking or unblocking the request.
Reason Reason `json:",omitempty"`
// Rules are applied rules. If Rules are not empty, each rule
// is not nil.
// Rules are applied rules. If Rules are not empty, each rule is not nil.
Rules []*ResultRule `json:",omitempty"`
// ReverseHosts is the reverse lookup rewrite result. It is
// empty unless Reason is set to RewrittenAutoHosts.
ReverseHosts []string `json:",omitempty"`
// IPList is the lookup rewrite result. It is empty unless
// Reason is set to RewrittenAutoHosts or Rewritten.
// IPList is the lookup rewrite result. It is empty unless Reason is set to
// Rewritten.
IPList []net.IP `json:",omitempty"`
// CanonName is the CNAME value from the lookup rewrite result.
// It is empty unless Reason is set to Rewritten or RewrittenRule.
// CanonName is the CNAME value from the lookup rewrite result. It is empty
// unless Reason is set to Rewritten or RewrittenRule.
CanonName string `json:",omitempty"`
// ServiceName is the name of the blocked service. It is empty
// unless Reason is set to FilteredBlockedService.
// ServiceName is the name of the blocked service. It is empty unless
// Reason is set to FilteredBlockedService.
ServiceName string `json:",omitempty"`
// DNSRewriteResult is the $dnsrewrite filter rule result.
@@ -402,14 +420,8 @@ func (r Reason) Matched() bool {
}
// CheckHostRules tries to match the host against filtering rules only.
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) {
if !setts.FilteringEnabled {
return Result{}, nil
}
host = strings.ToLower(host)
return d.matchHost(host, qtype, setts)
func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) {
return d.matchHost(strings.ToLower(host), rrtype, setts)
}
// CheckHost tries to match the host against filtering rules, then safebrowsing
@@ -422,14 +434,16 @@ func (d *DNSFilter) CheckHost(
// Sometimes clients try to resolve ".", which is a request to get root
// servers.
if host == "" {
return Result{Reason: NotFilteredNotFound}, nil
return Result{}, nil
}
host = strings.ToLower(host)
res = d.processRewrites(host, qtype)
if res.Reason == Rewritten {
return res, nil
if setts.FilteringEnabled {
res = d.processRewrites(host, qtype)
if res.Reason == Rewritten {
return res, nil
}
}
for _, hc := range d.hostCheckers {
@@ -446,100 +460,123 @@ func (d *DNSFilter) CheckHost(
return Result{}, nil
}
// checkEtcHosts compares the host against our /etc/hosts table. The err is
// always nil, it is only there to make this a valid hostChecker function.
func (d *DNSFilter) checkEtcHosts(
// matchSysHosts tries to match the host against the operating system's hosts
// database. err is always nil.
func (d *DNSFilter) matchSysHosts(
host string,
qtype uint16,
_ *Settings,
setts *Settings,
) (res Result, err error) {
if d.Config.EtcHosts == nil {
return Result{}, nil
}
ips := d.Config.EtcHosts.Process(host, qtype)
if ips != nil {
res = Result{
Reason: RewrittenAutoHosts,
IPList: ips,
}
if !setts.FilteringEnabled || d.EtcHosts == nil {
return res, nil
}
revHosts := d.Config.EtcHosts.ProcessReverse(host, qtype)
if len(revHosts) != 0 {
res = Result{
Reason: RewrittenAutoHosts,
}
// TODO(a.garipov): Optimize this with a buffer.
res.ReverseHosts = make([]string, len(revHosts))
for i := range revHosts {
res.ReverseHosts[i] = revHosts[i] + "."
}
dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: qtype,
})
if dnsres == nil {
return res, nil
}
return Result{}, nil
dnsr := dnsres.DNSRewrites()
if len(dnsr) == 0 {
return res, nil
}
res = d.processDNSRewrites(dnsr)
res.Reason = RewrittenAutoHosts
for _, r := range res.Rules {
r.Text = stringutil.Coalesce(d.EtcHosts.Translate(r.Text), r.Text)
}
return res, nil
}
// Process rewrites table
// . Find CNAME for a domain name (exact match or by wildcard)
// . if found and CNAME equals to domain name - this is an exception; exit
// . if found, set domain name to canonical name
// . repeat for the new domain name (Note: we return only the last CNAME)
// . Find A or AAAA record for a domain name (exact match or by wildcard)
// . if found, set IP addresses (IPv4 or IPv6 depending on qtype) in Result.IPList array
// processRewrites performs filtering based on the legacy rewrite records.
//
// Firstly, it finds CNAME rewrites for host. If the CNAME is the same as host,
// this query isn't filtered. If it's different, repeat the process for the new
// CNAME, breaking loops in the process.
//
// Secondly, it finds A or AAAA rewrites for host and, if found, sets res.IPList
// accordingly. If the found rewrite has a special value of "A" or "AAAA", the
// result is an exception.
func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
d.confLock.RLock()
defer d.confLock.RUnlock()
rr := findRewrites(d.Rewrites, host, qtype)
if len(rr) != 0 {
res.Reason = Rewritten
rewrites, matched := findRewrites(d.Rewrites, host, qtype)
if !matched {
return Result{}
}
res.Reason = Rewritten
cnames := stringutil.NewSet()
origHost := host
for len(rr) != 0 && rr[0].Type == dns.TypeCNAME {
log.Debug("rewrite: CNAME for %s is %s", host, rr[0].Answer)
for matched && len(rewrites) > 0 && rewrites[0].Type == dns.TypeCNAME {
rw := rewrites[0]
rwPat := rw.Domain
rwAns := rw.Answer
if host == rr[0].Answer { // "host == CNAME" is an exception
res.Reason = NotFilteredNotFound
log.Debug("rewrite: cname for %s is %s", host, rwAns)
return res
if origHost == rwAns || rwPat == rwAns {
// Either a request for the hostname itself or a rewrite of
// a pattern onto itself, both of which are an exception rules.
// Return a not filtered result.
return Result{}
} else if host == rwAns && isWildcard(rwPat) {
// An "*.example.com → sub.example.com" rewrite matching in a loop.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/4016.
res.CanonName = host
break
}
host = rr[0].Answer
host = rwAns
if cnames.Has(host) {
log.Info("rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost)
log.Info("rewrite: cname loop for %q on %q", origHost, host)
return res
}
cnames.Add(host)
res.CanonName = rr[0].Answer
rr = findRewrites(d.Rewrites, host, qtype)
res.CanonName = host
rewrites, matched = findRewrites(d.Rewrites, host, qtype)
}
for _, r := range rr {
if r.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
if r.IP == nil { // IP exception
res.Reason = NotFilteredNotFound
return res
}
res.IPList = append(res.IPList, r.IP)
log.Debug("rewrite: A/AAAA for %s is %s", host, r.IP)
}
}
setRewriteResult(&res, host, rewrites, qtype)
return res
}
// setRewriteResult sets the Reason or IPList of res if necessary. res must not
// be nil.
func setRewriteResult(res *Result, host string, rewrites []*LegacyRewrite, qtype uint16) {
for _, rw := range rewrites {
if rw.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
if rw.IP == nil {
// "A"/"AAAA" exception: allow getting from upstream.
res.Reason = NotFilteredNotFound
return
}
res.IPList = append(res.IPList, rw.IP)
log.Debug("rewrite: a/aaaa for %s is %s", host, rw.IP)
}
}
}
// matchBlockedServicesRules checks the host against the blocked services rules
// in settings, if any. The err is always nil, it is only there to make this
// a valid hostChecker function.
@@ -548,6 +585,10 @@ func matchBlockedServicesRules(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled {
return Result{}, nil
}
svcs := setts.ServicesRules
if len(svcs) == 0 {
return Result{}, nil
@@ -582,80 +623,82 @@ func matchBlockedServicesRules(
// Adding rule and matching against the rules
//
// fileExists returns true if file exists.
func fileExists(fn string) bool {
_, err := os.Stat(fn)
return err == nil
}
func createFilteringEngine(filters []Filter) (*filterlist.RuleStorage, *urlfilter.DNSEngine, error) {
listArray := []filterlist.RuleList{}
func newRuleStorage(filters []Filter) (rs *filterlist.RuleStorage, err error) {
lists := make([]filterlist.RuleList, 0, len(filters))
for _, f := range filters {
var list filterlist.RuleList
if f.ID == 0 {
list = &filterlist.StringRuleList{
ID: 0,
switch id := int(f.ID); {
case len(f.Data) != 0:
lists = append(lists, &filterlist.StringRuleList{
ID: id,
RulesText: string(f.Data),
IgnoreCosmetic: true,
}
} else if !fileExists(f.FilePath) {
list = &filterlist.StringRuleList{
ID: int(f.ID),
IgnoreCosmetic: true,
}
} else if runtime.GOOS == "windows" {
// On Windows we don't pass a file to urlfilter because
// it's difficult to update this file while it's being
// used.
data, err := os.ReadFile(f.FilePath)
if err != nil {
return nil, nil, fmt.Errorf("reading filter content: %w", err)
})
case f.FilePath == "":
continue
case runtime.GOOS == "windows":
// On Windows we don't pass a file to urlfilter because it's
// difficult to update this file while it's being used.
var data []byte
data, err = os.ReadFile(f.FilePath)
if errors.Is(err, fs.ErrNotExist) {
continue
} else if err != nil {
return nil, fmt.Errorf("reading filter content: %w", err)
}
list = &filterlist.StringRuleList{
ID: int(f.ID),
lists = append(lists, &filterlist.StringRuleList{
ID: id,
RulesText: string(data),
IgnoreCosmetic: true,
})
default:
var list *filterlist.FileRuleList
list, err = filterlist.NewFileRuleList(id, f.FilePath, true)
if errors.Is(err, fs.ErrNotExist) {
continue
} else if err != nil {
return nil, fmt.Errorf("creating file rule list with %q: %w", f.FilePath, err)
}
} else {
var err error
list, err = filterlist.NewFileRuleList(int(f.ID), f.FilePath, true)
if err != nil {
return nil, nil, fmt.Errorf("filterlist.NewFileRuleList(): %s: %w", f.FilePath, err)
}
lists = append(lists, list)
}
listArray = append(listArray, list)
}
rulesStorage, err := filterlist.NewRuleStorage(listArray)
rs, err = filterlist.NewRuleStorage(lists)
if err != nil {
return nil, nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
return nil, fmt.Errorf("creating rule storage: %w", err)
}
filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
return rulesStorage, filteringEngine, nil
return rs, nil
}
// Initialize urlfilter objects.
func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
rulesStorage, filteringEngine, err := createFilteringEngine(blockFilters)
if err != nil {
return err
}
rulesStorageAllow, filteringEngineAllow, err := createFilteringEngine(allowFilters)
rulesStorage, err := newRuleStorage(blockFilters)
if err != nil {
return err
}
d.engineLock.Lock()
d.reset()
d.rulesStorage = rulesStorage
d.filteringEngine = filteringEngine
d.rulesStorageAllow = rulesStorageAllow
d.filteringEngineAllow = filteringEngineAllow
d.engineLock.Unlock()
rulesStorageAllow, err := newRuleStorage(allowFilters)
if err != nil {
return err
}
// Make sure that the OS reclaims memory as soon as possible
filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
filteringEngineAllow := urlfilter.NewDNSEngine(rulesStorageAllow)
func() {
d.engineLock.Lock()
defer d.engineLock.Unlock()
d.reset()
d.rulesStorage = rulesStorage
d.filteringEngine = filteringEngine
d.rulesStorageAllow = rulesStorageAllow
d.filteringEngineAllow = filteringEngineAllow
}()
// Make sure that the OS reclaims memory as soon as possible.
debug.FreeOSMemory()
log.Debug("initialized filtering engine")
@@ -677,11 +720,10 @@ func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) {
return res
}
// matchHostProcessAllowList processes the allowlist logic of host
// matching.
// matchHostProcessAllowList processes the allowlist logic of host matching.
func (d *DNSFilter) matchHostProcessAllowList(
host string,
dnsres urlfilter.DNSResult,
dnsres *urlfilter.DNSResult,
) (res Result, err error) {
var matchedRules []rules.Rule
if dnsres.NetworkRule != nil {
@@ -704,7 +746,7 @@ func (d *DNSFilter) matchHostProcessAllowList(
// matchHostProcessDNSResult processes the matched DNS filtering result.
func (d *DNSFilter) matchHostProcessDNSResult(
qtype uint16,
dnsres urlfilter.DNSResult,
dnsres *urlfilter.DNSResult,
) (res Result) {
if dnsres.NetworkRule != nil {
reason := FilteredBlockList
@@ -734,8 +776,8 @@ func (d *DNSFilter) matchHostProcessDNSResult(
}
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
// Question type doesn't match the host rules. Return the first
// matched host rule, but without an IP address.
// Question type doesn't match the host rules. Return the first matched
// host rule, but without an IP address.
var matchedRules []rules.Rule
if dnsres.HostRulesV4 != nil {
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
@@ -749,32 +791,34 @@ func (d *DNSFilter) matchHostProcessDNSResult(
return Result{}
}
// matchHost is a low-level way to check only if hostname is filtered by rules,
// matchHost is a low-level way to check only if host is filtered by rules,
// skipping expensive safebrowsing and parental lookups.
func (d *DNSFilter) matchHost(
host string,
qtype uint16,
rrtype uint16,
setts *Settings,
) (res Result, err error) {
if !setts.FilteringEnabled {
return Result{}, nil
}
d.engineLock.RLock()
// Keep in mind that this lock must be held no just when calling Match()
// but also while using the rules returned by it.
defer d.engineLock.RUnlock()
ureq := urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: qtype,
DNSType: rrtype,
}
if d.filteringEngineAllow != nil {
d.engineLock.RLock()
// Keep in mind that this lock must be held no just when calling Match() but
// also while using the rules returned by it.
//
// TODO(e.burkov): Inspect if the above is true.
defer d.engineLock.RUnlock()
if setts.ProtectionEnabled && d.filteringEngineAllow != nil {
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
if ok {
return d.matchHostProcessAllowList(host, dnsres)
@@ -786,13 +830,12 @@ func (d *DNSFilter) matchHost(
}
dnsres, ok := d.filteringEngine.MatchRequest(ureq)
// Check DNS rewrites first, because the API there is a bit awkward.
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
res = d.processDNSRewrites(dnsr)
if res.Reason == RewrittenRule && res.CanonName == host {
// A rewrite of a host to itself. Go on and try
// matching other things.
// A rewrite of a host to itself. Go on and try matching other
// things.
} else {
return res, nil
}
@@ -800,7 +843,12 @@ func (d *DNSFilter) matchHost(
return Result{}, nil
}
res = d.matchHostProcessDNSResult(qtype, dnsres)
if !setts.ProtectionEnabled {
// Don't check non-dnsrewrite filtering results.
return Result{}, nil
}
res = d.matchHostProcessDNSResult(rrtype, dnsres)
for _, r := range res.Rules {
log.Debug(
"filtering: found rule %q for host %q, filter list id: %d",
@@ -836,40 +884,33 @@ func InitModule() {
}
// New creates properly initialized DNS Filter that is ready to be used.
func New(c *Config, blockFilters []Filter) *DNSFilter {
var resolver Resolver = net.DefaultResolver
func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
d = &DNSFilter{
resolver: net.DefaultResolver,
}
if c != nil {
cacheConf := cache.Config{
d.safebrowsingCache = cache.New(cache.Config{
EnableLRU: true,
}
if gctx.safebrowsingCache == nil {
cacheConf.MaxSize = c.SafeBrowsingCacheSize
gctx.safebrowsingCache = cache.New(cacheConf)
}
if gctx.safeSearchCache == nil {
cacheConf.MaxSize = c.SafeSearchCacheSize
gctx.safeSearchCache = cache.New(cacheConf)
}
if gctx.parentalCache == nil {
cacheConf.MaxSize = c.ParentalCacheSize
gctx.parentalCache = cache.New(cacheConf)
}
MaxSize: c.SafeBrowsingCacheSize,
})
d.safeSearchCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.SafeSearchCacheSize,
})
d.parentalCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.ParentalCacheSize,
})
if c.CustomResolver != nil {
resolver = c.CustomResolver
d.resolver = c.CustomResolver
}
}
d := &DNSFilter{
resolver: resolver,
}
d.hostCheckers = []hostChecker{{
check: d.checkEtcHosts,
name: "etchosts",
check: d.matchSysHosts,
name: "hosts container",
}, {
check: d.matchHost,
name: "filtering",
@@ -890,12 +931,18 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
err := d.initSecurityServices()
if err != nil {
log.Error("filtering: initialize services: %s", err)
return nil
}
if c != nil {
d.Config = *c
d.prepareRewrites()
err = d.prepareRewrites()
if err != nil {
log.Error("rewrites: preparing: %s", err)
return nil
}
}
bsvcs := []string{}

View File

@@ -21,15 +21,17 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
var setts Settings
var setts = Settings{
ProtectionEnabled: true,
}
// Helpers.
func purgeCaches() {
func purgeCaches(d *DNSFilter) {
for _, c := range []cache.Cache{
gctx.safebrowsingCache,
gctx.parentalCache,
gctx.safeSearchCache,
d.safebrowsingCache,
d.parentalCache,
d.safeSearchCache,
} {
if c != nil {
c.Clear()
@@ -37,11 +39,11 @@ func purgeCaches() {
}
}
func newForTest(c *Config, filters []Filter) *DNSFilter {
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
setts = Settings{
FilteringEnabled: true,
ProtectionEnabled: true,
FilteringEnabled: true,
}
setts.FilteringEnabled = true
if c != nil {
c.SafeBrowsingCacheSize = 10000
c.ParentalCacheSize = 10000
@@ -52,7 +54,8 @@ func newForTest(c *Config, filters []Filter) *DNSFilter {
setts.ParentalEnabled = c.ParentalEnabled
}
d := New(c, filters)
purgeCaches()
purgeCaches(d)
return d
}
@@ -103,7 +106,7 @@ func TestEtcHostsMatching(t *testing.T) {
filters := []Filter{{
ID: 0, Data: []byte(text),
}}
d := newForTest(nil, filters)
d := newForTest(t, nil, filters)
t.Cleanup(d.Close)
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
@@ -168,7 +171,7 @@ func TestSafeBrowsing(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
@@ -191,7 +194,7 @@ func TestSafeBrowsing(t *testing.T) {
}
func TestParallelSB(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
@@ -215,7 +218,7 @@ func TestParallelSB(t *testing.T) {
// Safe Search.
func TestSafeSearch(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
val, ok := d.SafeSearchDomain("www.google.com")
require.True(t, ok)
@@ -224,7 +227,9 @@ func TestSafeSearch(t *testing.T) {
}
func TestCheckHostSafeSearchYandex(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
d := newForTest(t, &Config{
SafeSearchEnabled: true,
}, nil)
t.Cleanup(d.Close)
yandexIP := net.IPv4(213, 180, 193, 56)
@@ -247,13 +252,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
})
}
}
func TestCheckHostSafeSearchGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d := newForTest(&Config{
d := newForTest(t, &Config{
SafeSearchEnabled: true,
CustomResolver: resolver,
}, nil)
@@ -280,12 +286,13 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
require.Len(t, res.Rules, 1)
assert.Equal(t, ip, res.Rules[0].IP)
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
})
}
}
func TestSafeSearchCacheYandex(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
const domain = "yandex.ru"
@@ -299,7 +306,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
yandexIP := net.IPv4(213, 180, 193, 56)
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
res, err = d.CheckHost(domain, dns.TypeA, &setts)
@@ -310,7 +317,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
@@ -319,7 +326,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
func TestSafeSearchCacheGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d := newForTest(&Config{
d := newForTest(t, &Config{
CustomResolver: resolver,
}, nil)
t.Cleanup(d.Close)
@@ -332,7 +339,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
require.Empty(t, res.Rules)
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
d.resolver = resolver
@@ -359,7 +366,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
assert.True(t, res.Rules[0].IP.Equal(ip))
// Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
@@ -373,7 +380,7 @@ func TestParentalControl(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(&Config{ParentalEnabled: true}, nil)
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "pornhub.com"
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
@@ -677,7 +684,7 @@ func TestMatching(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
d := newForTest(nil, filters)
d := newForTest(t, nil, filters)
t.Cleanup(d.Close)
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
@@ -703,7 +710,7 @@ func TestWhitelist(t *testing.T) {
whiteFilters := []Filter{{
ID: 0, Data: []byte(whiteRules),
}}
d := newForTest(nil, filters)
d := newForTest(t, nil, filters)
err := d.SetFilters(filters, whiteFilters, false)
require.NoError(t, err)
@@ -748,7 +755,7 @@ func applyClientSettings(setts *Settings) {
}
func TestClientSettings(t *testing.T) {
d := newForTest(
d := newForTest(t,
&Config{
ParentalEnabled: true,
SafeBrowsingEnabled: false,
@@ -797,7 +804,11 @@ func TestClientSettings(t *testing.T) {
makeTester := func(tc testCase, before bool) func(t *testing.T) {
return func(t *testing.T) {
r, _ := d.CheckHost(tc.host, dns.TypeA, &setts)
t.Helper()
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
require.NoError(t, err)
if before {
assert.True(t, r.IsFiltered)
assert.Equal(t, tc.wantReason, r.Reason)
@@ -808,7 +819,7 @@ func TestClientSettings(t *testing.T) {
}
// Check behaviour without any per-client settings, then apply per-client
// settings and check behaviour once again.
// settings and check behavior once again.
for _, tc := range testCases {
t.Run(tc.name, makeTester(tc, tc.before))
}
@@ -823,7 +834,7 @@ func TestClientSettings(t *testing.T) {
// Benchmarks.
func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
@@ -839,7 +850,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
}
func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
@@ -857,7 +868,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
}
func BenchmarkSafeSearch(b *testing.B) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close)
for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com")
@@ -868,7 +879,7 @@ func BenchmarkSafeSearch(b *testing.B) {
}
func BenchmarkSafeSearchParallel(b *testing.B) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {

View File

@@ -4,93 +4,121 @@ package filtering
import (
"encoding/json"
"fmt"
"net"
"net/http"
"sort"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// RewriteEntry is a rewrite array element
type RewriteEntry struct {
// Domain is the domain for which this rewrite should work.
// LegacyRewrite is a single legacy DNS rewrite record.
//
// Instances of *LegacyRewrite must never be nil.
type LegacyRewrite struct {
// Domain is the domain pattern for which this rewrite should work.
Domain string `yaml:"domain"`
// Answer is the IP address, canonical name, or one of the special
// values: "A" or "AAAA".
Answer string `yaml:"answer"`
// IP is the IP address that should be used in the response if Type is
// A or AAAA.
// dns.TypeA or dns.TypeAAAA.
IP net.IP `yaml:"-"`
// Type is the DNS record type: A, AAAA, or CNAME.
Type uint16 `yaml:"-"`
}
// equal returns true if the entry is considered equal to the other.
func (e *RewriteEntry) equal(other RewriteEntry) (ok bool) {
return e.Domain == other.Domain && e.Answer == other.Answer
// clone returns a deep clone of rw.
func (rw *LegacyRewrite) clone() (cloneRW *LegacyRewrite) {
return &LegacyRewrite{
Domain: rw.Domain,
Answer: rw.Answer,
IP: netutil.CloneIP(rw.IP),
Type: rw.Type,
}
}
// matchesQType returns true if the entry matched qtype.
func (e *RewriteEntry) matchesQType(qtype uint16) (ok bool) {
// equal returns true if the rw is equal to the other.
func (rw *LegacyRewrite) equal(other *LegacyRewrite) (ok bool) {
return rw.Domain == other.Domain && rw.Answer == other.Answer
}
// matchesQType returns true if the entry matches the question type qt.
func (rw *LegacyRewrite) matchesQType(qt uint16) (ok bool) {
// Add CNAMEs, since they match for all types requests.
if e.Type == dns.TypeCNAME {
if rw.Type == dns.TypeCNAME {
return true
}
// Reject types other than A and AAAA.
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
if qt != dns.TypeA && qt != dns.TypeAAAA {
return false
}
// If the types match or the entry is set to allow only the other type,
// include them.
return e.Type == qtype || e.IP == nil
return rw.Type == qt || rw.IP == nil
}
// normalize makes sure that the a new or decoded entry is normalized with
// regards to domain name case, IP length, and so on.
func (e *RewriteEntry) normalize() {
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix
// and use it in matchDomainWildcard instead of using strings.ToLower
//
// If rw is nil, it returns an errors.
func (rw *LegacyRewrite) normalize() (err error) {
if rw == nil {
return errors.Error("nil rewrite entry")
}
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix and
// use it in matchDomainWildcard instead of using strings.ToLower
// everywhere.
e.Domain = strings.ToLower(e.Domain)
rw.Domain = strings.ToLower(rw.Domain)
switch e.Answer {
switch rw.Answer {
case "AAAA":
e.IP = nil
e.Type = dns.TypeAAAA
rw.IP = nil
rw.Type = dns.TypeAAAA
return
return nil
case "A":
e.IP = nil
e.Type = dns.TypeA
rw.IP = nil
rw.Type = dns.TypeA
return
return nil
default:
// Go on.
}
ip := net.ParseIP(e.Answer)
ip := net.ParseIP(rw.Answer)
if ip == nil {
e.Type = dns.TypeCNAME
rw.Type = dns.TypeCNAME
return
return nil
}
ip4 := ip.To4()
if ip4 != nil {
e.IP = ip4
e.Type = dns.TypeA
rw.IP = ip4
rw.Type = dns.TypeA
} else {
e.IP = ip
e.Type = dns.TypeAAAA
rw.IP = ip
rw.Type = dns.TypeAAAA
}
return nil
}
func isWildcard(host string) bool {
return len(host) > 1 && host[0] == '*' && host[1] == '.'
// isWildcard returns true if pat is a wildcard domain pattern.
func isWildcard(pat string) bool {
return len(pat) > 1 && pat[0] == '*' && pat[1] == '.'
}
// matchDomainWildcard returns true if host matches the wildcard pattern.
@@ -106,16 +134,16 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
// wildcard > exact
// lower level wildcard > higher level wildcard
//
type rewritesSorted []RewriteEntry
type rewritesSorted []*LegacyRewrite
// Len implements the sort.Interface interface for legacyRewritesSorted.
func (a rewritesSorted) Len() int { return len(a) }
func (a rewritesSorted) Len() (l int) { return len(a) }
// Swap implements the sort.Interface interface for legacyRewritesSorted.
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// Less implements the sort.Interface interface for legacyRewritesSorted.
func (a rewritesSorted) Less(i, j int) bool {
func (a rewritesSorted) Less(i, j int) (less bool) {
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
return true
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
@@ -132,49 +160,62 @@ func (a rewritesSorted) Less(i, j int) bool {
}
}
// both are wildcards
// Both are wildcards.
return len(a[i].Domain) > len(a[j].Domain)
}
func (d *DNSFilter) prepareRewrites() {
for i := range d.Rewrites {
d.Rewrites[i].normalize()
// prepareRewrites normalizes and validates all legacy DNS rewrites.
func (d *DNSFilter) prepareRewrites() (err error) {
for i, r := range d.Rewrites {
err = r.normalize()
if err != nil {
return fmt.Errorf("at index %d: %w", i, err)
}
}
return nil
}
// findRewrites returns the list of matched rewrite entries. The priority is:
// CNAME, then A and AAAA; exact, then wildcard. If the host is matched
// exactly, wildcard entries aren't returned. If the host matched by wildcards,
// return the most specific for the question type.
func findRewrites(entries []RewriteEntry, host string, qtype uint16) (matched []RewriteEntry) {
rr := rewritesSorted{}
// findRewrites returns the list of matched rewrite entries. If rewrites are
// empty, but matched is true, the domain is found among the rewrite rules but
// not for this question type.
//
// The result priority is: CNAME, then A and AAAA; exact, then wildcard. If the
// host is matched exactly, wildcard entries aren't returned. If the host
// matched by wildcards, return the most specific for the question type.
func findRewrites(
entries []*LegacyRewrite,
host string,
qtype uint16,
) (rewrites []*LegacyRewrite, matched bool) {
for _, e := range entries {
if e.Domain != host && !matchDomainWildcard(host, e.Domain) {
continue
}
matched = true
if e.matchesQType(qtype) {
rr = append(rr, e)
rewrites = append(rewrites, e)
}
}
if len(rr) == 0 {
return nil
if len(rewrites) == 0 {
return nil, matched
}
sort.Sort(rr)
sort.Sort(rewritesSorted(rewrites))
for i, r := range rr {
for i, r := range rewrites {
if isWildcard(r.Domain) {
// Don't use rr[:0], because we need to return at least
// one item here.
rr = rr[:max(1, i)]
// Don't use rewrites[:0], because we need to return at least one
// item here.
rewrites = rewrites[:max(1, i)]
break
}
}
return rr
return rewrites, matched
}
func max(a, b int) int {
@@ -206,29 +247,39 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(arr)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return
}
}
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
rwJSON := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&rwJSON)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ent := RewriteEntry{
Domain: jsent.Domain,
Answer: jsent.Answer,
rw := &LegacyRewrite{
Domain: rwJSON.Domain,
Answer: rwJSON.Answer,
}
ent.normalize()
err = rw.normalize()
if err != nil {
// Shouldn't happen currently, since normalize only returns a non-nil
// error when a rewrite is nil, but be change-proof.
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
return
}
d.confLock.Lock()
d.Config.Rewrites = append(d.Config.Rewrites, ent)
d.Config.Rewrites = append(d.Config.Rewrites, rw)
d.confLock.Unlock()
log.Debug("Rewrites: added element: %s -> %s [%d]",
ent.Domain, ent.Answer, len(d.Config.Rewrites))
log.Debug("rewrite: added element: %s -> %s [%d]", rw.Domain, rw.Answer, len(d.Config.Rewrites))
d.Config.ConfigModified()
}
@@ -237,21 +288,25 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
entDel := RewriteEntry{
entDel := &LegacyRewrite{
Domain: jsent.Domain,
Answer: jsent.Answer,
}
arr := []RewriteEntry{}
arr := []*LegacyRewrite{}
d.confLock.Lock()
for _, ent := range d.Config.Rewrites {
if ent.equal(entDel) {
log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer)
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
continue
}
arr = append(arr, ent)
}
d.Config.Rewrites = arr

View File

@@ -12,10 +12,10 @@ import (
// TODO(e.burkov): All the tests in this file may and should me merged together.
func TestRewrites(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
d.Rewrites = []RewriteEntry{{
d.Rewrites = []*LegacyRewrite{{
// This one and below are about CNAME, A and AAAA.
Domain: "somecname",
Answer: "somehost.com",
@@ -66,107 +66,132 @@ func TestRewrites(t *testing.T) {
}, {
Domain: "BIGHOST.COM",
Answer: "1.2.3.7",
}, {
Domain: "*.issue4016.com",
Answer: "sub.issue4016.com",
}}
d.prepareRewrites()
require.NoError(t, d.prepareRewrites())
testCases := []struct {
name string
host string
wantCName string
wantVals []net.IP
dtyp uint16
name string
host string
wantCName string
wantIPs []net.IP
wantReason Reason
dtyp uint16
}{{
name: "not_filtered_not_found",
host: "hoost.com",
wantCName: "",
wantVals: nil,
dtyp: dns.TypeA,
name: "not_filtered_not_found",
host: "hoost.com",
wantCName: "",
wantIPs: nil,
wantReason: NotFilteredNotFound,
dtyp: dns.TypeA,
}, {
name: "rewritten_a",
host: "www.host.com",
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "rewritten_a",
host: "www.host.com",
wantCName: "host.com",
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "rewritten_aaaa",
host: "www.host.com",
wantCName: "host.com",
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
dtyp: dns.TypeAAAA,
name: "rewritten_aaaa",
host: "www.host.com",
wantCName: "host.com",
wantIPs: []net.IP{net.ParseIP("1:2:3::4")},
wantReason: Rewritten,
dtyp: dns.TypeAAAA,
}, {
name: "wildcard_match",
host: "abc.host.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "wildcard_match",
host: "abc.host.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "wildcard_override",
host: "a.host.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 4}},
dtyp: dns.TypeA,
name: "wildcard_override",
host: "a.host.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 4}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "wildcard_cname_interaction",
host: "www.host2.com",
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "wildcard_cname_interaction",
host: "www.host2.com",
wantCName: "host.com",
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "two_cnames",
host: "b.host.com",
wantCName: "somehost.com",
wantVals: []net.IP{{0, 0, 0, 0}},
dtyp: dns.TypeA,
name: "two_cnames",
host: "b.host.com",
wantCName: "somehost.com",
wantIPs: []net.IP{{0, 0, 0, 0}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "two_cnames_and_wildcard",
host: "b.host3.com",
wantCName: "x.host.com",
wantVals: []net.IP{{1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "two_cnames_and_wildcard",
host: "b.host3.com",
wantCName: "x.host.com",
wantIPs: []net.IP{{1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue3343",
host: "www.hostboth.com",
wantCName: "",
wantVals: []net.IP{net.ParseIP("1234::5678")},
dtyp: dns.TypeAAAA,
name: "issue3343",
host: "www.hostboth.com",
wantCName: "",
wantIPs: []net.IP{net.ParseIP("1234::5678")},
wantReason: Rewritten,
dtyp: dns.TypeAAAA,
}, {
name: "issue3351",
host: "bighost.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 7}},
dtyp: dns.TypeA,
name: "issue3351",
host: "bighost.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 7}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue4008",
host: "somehost.com",
wantCName: "",
wantIPs: nil,
wantReason: Rewritten,
dtyp: dns.TypeHTTPS,
}, {
name: "issue4016",
host: "www.issue4016.com",
wantCName: "sub.issue4016.com",
wantIPs: nil,
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue4016_self",
host: "sub.issue4016.com",
wantCName: "",
wantIPs: nil,
wantReason: NotFilteredNotFound,
dtyp: dns.TypeA,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
valsNum := len(tc.wantVals)
r := d.processRewrites(tc.host, tc.dtyp)
if valsNum == 0 {
assert.Equal(t, NotFilteredNotFound, r.Reason)
return
}
require.Equalf(t, Rewritten, r.Reason, "got %s", r.Reason)
require.Equalf(t, tc.wantReason, r.Reason, "got %s", r.Reason)
if tc.wantCName != "" {
assert.Equal(t, tc.wantCName, r.CanonName)
}
require.Len(t, r.IPList, valsNum)
for i, ip := range tc.wantVals {
assert.Equal(t, ip, r.IPList[i])
}
assert.Equal(t, tc.wantIPs, r.IPList)
})
}
}
func TestRewritesLevels(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Exact host, wildcard L2, wildcard L3.
d.Rewrites = []RewriteEntry{{
d.Rewrites = []*LegacyRewrite{{
Domain: "host.com",
Answer: "1.1.1.1",
Type: dns.TypeA,
@@ -179,7 +204,8 @@ func TestRewritesLevels(t *testing.T) {
Answer: "3.3.3.3",
Type: dns.TypeA,
}}
d.prepareRewrites()
require.NoError(t, d.prepareRewrites())
testCases := []struct {
name string
@@ -209,10 +235,10 @@ func TestRewritesLevels(t *testing.T) {
}
func TestRewritesExceptionCNAME(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Wildcard and exception for a sub-domain.
d.Rewrites = []RewriteEntry{{
d.Rewrites = []*LegacyRewrite{{
Domain: "*.host.com",
Answer: "2.2.2.2",
}, {
@@ -222,29 +248,32 @@ func TestRewritesExceptionCNAME(t *testing.T) {
Domain: "*.sub.host.com",
Answer: "*.sub.host.com",
}}
d.prepareRewrites()
require.NoError(t, d.prepareRewrites())
testCases := []struct {
name string
host string
want net.IP
}{{
name: "match_sub-domain",
name: "match_subdomain",
host: "my.host.com",
want: net.IP{2, 2, 2, 2},
}, {
name: "exception_cname",
host: "sub.host.com",
want: nil,
}, {
name: "exception_wildcard",
host: "my.sub.host.com",
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := d.processRewrites(tc.host, dns.TypeA)
if tc.want == nil {
assert.Equal(t, NotFilteredNotFound, r.Reason)
assert.Equal(t, NotFilteredNotFound, r.Reason, "got %s", r.Reason)
return
}
@@ -257,10 +286,10 @@ func TestRewritesExceptionCNAME(t *testing.T) {
}
func TestRewritesExceptionIP(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Exception for AAAA record.
d.Rewrites = []RewriteEntry{{
d.Rewrites = []*LegacyRewrite{{
Domain: "host.com",
Answer: "1.2.3.4",
Type: dns.TypeA,
@@ -281,7 +310,8 @@ func TestRewritesExceptionIP(t *testing.T) {
Answer: "A",
Type: dns.TypeA,
}}
d.prepareRewrites()
require.NoError(t, d.prepareRewrites())
testCases := []struct {
name string

View File

@@ -13,6 +13,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
@@ -306,7 +307,7 @@ func (d *DNSFilter) checkSafeBrowsing(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.SafeBrowsingEnabled {
if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
return Result{}, nil
}
@@ -318,7 +319,7 @@ func (d *DNSFilter) checkSafeBrowsing(
sctx := &sbCtx{
host: host,
svc: "SafeBrowsing",
cache: gctx.safebrowsingCache,
cache: d.safebrowsingCache,
cacheTime: d.Config.CacheTime,
}
@@ -326,7 +327,8 @@ func (d *DNSFilter) checkSafeBrowsing(
IsFiltered: true,
Reason: FilteredSafeBrowsing,
Rules: []*ResultRule{{
Text: "adguard-malware-shavar",
Text: "adguard-malware-shavar",
FilterListID: SafeBrowsingListID,
}},
}
@@ -339,7 +341,7 @@ func (d *DNSFilter) checkParental(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ParentalEnabled {
if !setts.ProtectionEnabled || !setts.ParentalEnabled {
return Result{}, nil
}
@@ -351,7 +353,7 @@ func (d *DNSFilter) checkParental(
sctx := &sbCtx{
host: host,
svc: "Parental",
cache: gctx.parentalCache,
cache: d.parentalCache,
cacheTime: d.Config.CacheTime,
}
@@ -359,19 +361,14 @@ func (d *DNSFilter) checkParental(
IsFiltered: true,
Reason: FilteredParental,
Rules: []*ResultRule{{
Text: "parental CATEGORY_BLACKLISTED",
Text: "parental CATEGORY_BLACKLISTED",
FilterListID: ParentalListID,
}},
}
return check(sctx, res, d.parentalUpstream)
}
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeBrowsingEnabled = true
d.Config.ConfigModified()
@@ -390,7 +387,8 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
Enabled: d.Config.SafeBrowsingEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@@ -413,8 +411,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
Enabled: d.Config.ParentalEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
}
}

View File

@@ -108,7 +108,7 @@ func TestSafeBrowsingCache(t *testing.T) {
}
func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
ups := &aghtest.TestErrUpstream{}
@@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d.SetParentalUpstream(ups)
setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
@@ -129,41 +130,42 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
}
func TestSBPC(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const hostname = "example.org"
setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
testCases := []struct {
testCache cache.Cache
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
name string
block bool
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
testCache cache.Cache
}{{
testCache: d.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_no_block",
block: false,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
testCache: d.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_block",
block: true,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
testCache: d.parentalCache,
testFunc: d.checkParental,
name: "pc_no_block",
block: false,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}, {
testCache: d.parentalCache,
testFunc: d.checkParental,
name: "pc_block",
block: true,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}}
for _, tc := range testCases {
@@ -215,6 +217,6 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, 1, ups.RequestsCount())
})
purgeCaches()
purgeCaches(d)
}
}

View File

@@ -11,6 +11,7 @@ import (
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
)
@@ -74,7 +75,7 @@ func (d *DNSFilter) checkSafeSearch(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.SafeSearchEnabled {
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
return Result{}, nil
}
@@ -84,7 +85,7 @@ func (d *DNSFilter) checkSafeSearch(
}
// Check cache. Return cached result if it was found
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
cachedValue, isFound := getCachedResult(d.safeSearchCache, host)
if isFound {
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("SafeSearch: found in cache: %s", host)
@@ -99,12 +100,14 @@ func (d *DNSFilter) checkSafeSearch(
res = Result{
IsFiltered: true,
Reason: FilteredSafeSearch,
Rules: []*ResultRule{{}},
Rules: []*ResultRule{{
FilterListID: SafeSearchListID,
}},
}
if ip := net.ParseIP(safeHost); ip != nil {
res.Rules[0].IP = ip
valLen := d.setCacheResult(gctx.safeSearchCache, host, res)
valLen := d.setCacheResult(d.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
return res, nil
@@ -123,7 +126,7 @@ func (d *DNSFilter) checkSafeSearch(
res.Rules[0].IP = ip
l := d.setCacheResult(gctx.safeSearchCache, host, res)
l := d.setCacheResult(d.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, l)
return res, nil
@@ -150,8 +153,13 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
Enabled: d.Config.SafeSearchEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to write response json: %s",
err,
)
}
}