Merge branch 'master' into 3717-fix-qq-blocked
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
)
|
||||
@@ -239,7 +240,7 @@ func initBlockedServices() {
|
||||
for _, s := range serviceRulesArray {
|
||||
netRules := []*rules.NetworkRule{}
|
||||
for _, text := range s.rules {
|
||||
rule, err := rules.NewNetworkRule(text, 0)
|
||||
rule, err := rules.NewNetworkRule(text, BlockedSvcsListID)
|
||||
if err != nil {
|
||||
log.Error("rules.NewNetworkRule: %s rule: %s", err, text)
|
||||
continue
|
||||
@@ -287,7 +288,8 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -296,7 +298,8 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
||||
list := []string{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -15,14 +15,10 @@ type DNSRewriteResult struct {
|
||||
// the server returns.
|
||||
type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
|
||||
|
||||
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns
|
||||
// an empty result if dnsr is empty. Otherwise, the result will have
|
||||
// either CanonName or DNSRewriteResult set.
|
||||
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty
|
||||
// result if dnsr is empty. Otherwise, the result will have either CanonName or
|
||||
// DNSRewriteResult set. dnsr is expected to be non-empty.
|
||||
func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
if len(dnsr) == 0 {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
var rules []*ResultRule
|
||||
dnsrr := &DNSRewriteResult{
|
||||
Response: DNSRewriteResultResponse{},
|
||||
@@ -31,8 +27,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
for _, nr := range dnsr {
|
||||
dr := nr.DNSRewrite
|
||||
if dr.NewCNAME != "" {
|
||||
// NewCNAME rules have a higher priority than
|
||||
// the other rules.
|
||||
// NewCNAME rules have a higher priority than other rules.
|
||||
rules = []*ResultRule{{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
Text: nr.RuleText,
|
||||
@@ -54,8 +49,8 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
Text: nr.RuleText,
|
||||
})
|
||||
default:
|
||||
// RcodeRefused and other such codes have higher
|
||||
// priority. Return immediately.
|
||||
// RcodeRefused and other such codes have higher priority. Return
|
||||
// immediately.
|
||||
rules = []*ResultRule{{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
Text: nr.RuleText,
|
||||
|
||||
@@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
|
||||
`
|
||||
|
||||
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
setts := &Settings{
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package filtering
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -24,6 +26,18 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// The IDs of built-in filter lists.
|
||||
//
|
||||
// Keep in sync with client/src/helpers/constants.js.
|
||||
const (
|
||||
CustomListID = -iota
|
||||
SysHostsListID
|
||||
BlockedSvcsListID
|
||||
ParentalListID
|
||||
SafeBrowsingListID
|
||||
SafeSearchListID
|
||||
)
|
||||
|
||||
// ServiceEntry - blocked service array element
|
||||
type ServiceEntry struct {
|
||||
Name string
|
||||
@@ -38,6 +52,7 @@ type Settings struct {
|
||||
|
||||
ServicesRules []ServiceEntry
|
||||
|
||||
ProtectionEnabled bool
|
||||
FilteringEnabled bool
|
||||
SafeSearchEnabled bool
|
||||
SafeBrowsingEnabled bool
|
||||
@@ -65,7 +80,7 @@ type Config struct {
|
||||
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
|
||||
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
||||
|
||||
Rewrites []RewriteEntry `yaml:"rewrites"`
|
||||
Rewrites []*LegacyRewrite `yaml:"rewrites"`
|
||||
|
||||
// Names of services to block (globally).
|
||||
// Per-client settings can override this configuration.
|
||||
@@ -73,7 +88,7 @@ type Config struct {
|
||||
|
||||
// EtcHosts is a container of IP-hostname pairs taken from the operating
|
||||
// system configuration files (e.g. /etc/hosts).
|
||||
EtcHosts *aghnet.EtcHostsContainer `yaml:"-"`
|
||||
EtcHosts *aghnet.HostsContainer `yaml:"-"`
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
@@ -124,6 +139,10 @@ type DNSFilter struct {
|
||||
parentalUpstream upstream.Upstream
|
||||
safeBrowsingUpstream upstream.Upstream
|
||||
|
||||
safebrowsingCache cache.Cache
|
||||
parentalCache cache.Cache
|
||||
safeSearchCache cache.Cache
|
||||
|
||||
Config // for direct access by library users, even a = assignment
|
||||
// confLock protects Config.
|
||||
confLock sync.RWMutex
|
||||
@@ -142,9 +161,14 @@ type DNSFilter struct {
|
||||
|
||||
// Filter represents a filter list
|
||||
type Filter struct {
|
||||
ID int64 // auto-assigned when filter is added (see nextFilterID)
|
||||
Data []byte `yaml:"-"` // List of rules divided by '\n'
|
||||
FilePath string `yaml:"-"` // Path to a filtering rules file
|
||||
// FilePath is the path to a filtering rules list file.
|
||||
FilePath string `yaml:"-"`
|
||||
|
||||
// Data is the content of the file.
|
||||
Data []byte `yaml:"-"`
|
||||
|
||||
// ID is automatically assigned when filter is added using nextFilterID.
|
||||
ID int64
|
||||
}
|
||||
|
||||
// Reason holds an enum detailing why it was filtered or not filtered
|
||||
@@ -176,8 +200,8 @@ const (
|
||||
// FilteredBlockedService - the host is blocked by "blocked services" settings
|
||||
FilteredBlockedService
|
||||
|
||||
// Rewritten is returned when there was a rewrite by a legacy DNS
|
||||
// rewrite rule.
|
||||
// Rewritten is returned when there was a rewrite by a legacy DNS rewrite
|
||||
// rule.
|
||||
Rewritten
|
||||
|
||||
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
|
||||
@@ -186,8 +210,8 @@ const (
|
||||
|
||||
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
|
||||
//
|
||||
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging
|
||||
// their functionality into RewrittenRule.
|
||||
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging their
|
||||
// functionality into RewrittenRule.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
|
||||
RewrittenRule
|
||||
@@ -221,12 +245,13 @@ func (r Reason) String() string {
|
||||
}
|
||||
|
||||
// In returns true if reasons include r.
|
||||
func (r Reason) In(reasons ...Reason) bool {
|
||||
func (r Reason) In(reasons ...Reason) (ok bool) {
|
||||
for _, reason := range reasons {
|
||||
if r == reason {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -245,7 +270,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
return Settings{
|
||||
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1,
|
||||
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
|
||||
SafeSearchEnabled: d.Config.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
|
||||
ParentalEnabled: d.Config.ParentalEnabled,
|
||||
@@ -261,8 +286,14 @@ func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||
c.Rewrites = cloneRewrites(c.Rewrites)
|
||||
}
|
||||
|
||||
func cloneRewrites(entries []RewriteEntry) (clone []RewriteEntry) {
|
||||
return append([]RewriteEntry(nil), entries...)
|
||||
// cloneRewrites returns a deep copy of entries.
|
||||
func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
|
||||
clone = make([]*LegacyRewrite, len(entries))
|
||||
for i, rw := range entries {
|
||||
clone[i] = rw.clone()
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// SetFilters - set new filters (synchronously or asynchronously)
|
||||
@@ -338,14 +369,6 @@ func (d *DNSFilter) reset() {
|
||||
}
|
||||
}
|
||||
|
||||
type dnsFilterContext struct {
|
||||
safebrowsingCache cache.Cache
|
||||
parentalCache cache.Cache
|
||||
safeSearchCache cache.Cache
|
||||
}
|
||||
|
||||
var gctx dnsFilterContext
|
||||
|
||||
// ResultRule contains information about applied rules.
|
||||
type ResultRule struct {
|
||||
// Text is the text of the rule.
|
||||
@@ -371,24 +394,19 @@ type Result struct {
|
||||
// Reason is the reason for blocking or unblocking the request.
|
||||
Reason Reason `json:",omitempty"`
|
||||
|
||||
// Rules are applied rules. If Rules are not empty, each rule
|
||||
// is not nil.
|
||||
// Rules are applied rules. If Rules are not empty, each rule is not nil.
|
||||
Rules []*ResultRule `json:",omitempty"`
|
||||
|
||||
// ReverseHosts is the reverse lookup rewrite result. It is
|
||||
// empty unless Reason is set to RewrittenAutoHosts.
|
||||
ReverseHosts []string `json:",omitempty"`
|
||||
|
||||
// IPList is the lookup rewrite result. It is empty unless
|
||||
// Reason is set to RewrittenAutoHosts or Rewritten.
|
||||
// IPList is the lookup rewrite result. It is empty unless Reason is set to
|
||||
// Rewritten.
|
||||
IPList []net.IP `json:",omitempty"`
|
||||
|
||||
// CanonName is the CNAME value from the lookup rewrite result.
|
||||
// It is empty unless Reason is set to Rewritten or RewrittenRule.
|
||||
// CanonName is the CNAME value from the lookup rewrite result. It is empty
|
||||
// unless Reason is set to Rewritten or RewrittenRule.
|
||||
CanonName string `json:",omitempty"`
|
||||
|
||||
// ServiceName is the name of the blocked service. It is empty
|
||||
// unless Reason is set to FilteredBlockedService.
|
||||
// ServiceName is the name of the blocked service. It is empty unless
|
||||
// Reason is set to FilteredBlockedService.
|
||||
ServiceName string `json:",omitempty"`
|
||||
|
||||
// DNSRewriteResult is the $dnsrewrite filter rule result.
|
||||
@@ -402,14 +420,8 @@ func (r Reason) Matched() bool {
|
||||
}
|
||||
|
||||
// CheckHostRules tries to match the host against filtering rules only.
|
||||
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) {
|
||||
if !setts.FilteringEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
host = strings.ToLower(host)
|
||||
|
||||
return d.matchHost(host, qtype, setts)
|
||||
func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) {
|
||||
return d.matchHost(strings.ToLower(host), rrtype, setts)
|
||||
}
|
||||
|
||||
// CheckHost tries to match the host against filtering rules, then safebrowsing
|
||||
@@ -422,14 +434,16 @@ func (d *DNSFilter) CheckHost(
|
||||
// Sometimes clients try to resolve ".", which is a request to get root
|
||||
// servers.
|
||||
if host == "" {
|
||||
return Result{Reason: NotFilteredNotFound}, nil
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
host = strings.ToLower(host)
|
||||
|
||||
res = d.processRewrites(host, qtype)
|
||||
if res.Reason == Rewritten {
|
||||
return res, nil
|
||||
if setts.FilteringEnabled {
|
||||
res = d.processRewrites(host, qtype)
|
||||
if res.Reason == Rewritten {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, hc := range d.hostCheckers {
|
||||
@@ -446,100 +460,123 @@ func (d *DNSFilter) CheckHost(
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
// checkEtcHosts compares the host against our /etc/hosts table. The err is
|
||||
// always nil, it is only there to make this a valid hostChecker function.
|
||||
func (d *DNSFilter) checkEtcHosts(
|
||||
// matchSysHosts tries to match the host against the operating system's hosts
|
||||
// database. err is always nil.
|
||||
func (d *DNSFilter) matchSysHosts(
|
||||
host string,
|
||||
qtype uint16,
|
||||
_ *Settings,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if d.Config.EtcHosts == nil {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
ips := d.Config.EtcHosts.Process(host, qtype)
|
||||
if ips != nil {
|
||||
res = Result{
|
||||
Reason: RewrittenAutoHosts,
|
||||
IPList: ips,
|
||||
}
|
||||
|
||||
if !setts.FilteringEnabled || d.EtcHosts == nil {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
revHosts := d.Config.EtcHosts.ProcessReverse(host, qtype)
|
||||
if len(revHosts) != 0 {
|
||||
res = Result{
|
||||
Reason: RewrittenAutoHosts,
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Optimize this with a buffer.
|
||||
res.ReverseHosts = make([]string, len(revHosts))
|
||||
for i := range revHosts {
|
||||
res.ReverseHosts[i] = revHosts[i] + "."
|
||||
}
|
||||
|
||||
dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
ClientIP: setts.ClientIP.String(),
|
||||
ClientName: setts.ClientName,
|
||||
DNSType: qtype,
|
||||
})
|
||||
if dnsres == nil {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return Result{}, nil
|
||||
dnsr := dnsres.DNSRewrites()
|
||||
if len(dnsr) == 0 {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
res = d.processDNSRewrites(dnsr)
|
||||
res.Reason = RewrittenAutoHosts
|
||||
for _, r := range res.Rules {
|
||||
r.Text = stringutil.Coalesce(d.EtcHosts.Translate(r.Text), r.Text)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Process rewrites table
|
||||
// . Find CNAME for a domain name (exact match or by wildcard)
|
||||
// . if found and CNAME equals to domain name - this is an exception; exit
|
||||
// . if found, set domain name to canonical name
|
||||
// . repeat for the new domain name (Note: we return only the last CNAME)
|
||||
// . Find A or AAAA record for a domain name (exact match or by wildcard)
|
||||
// . if found, set IP addresses (IPv4 or IPv6 depending on qtype) in Result.IPList array
|
||||
// processRewrites performs filtering based on the legacy rewrite records.
|
||||
//
|
||||
// Firstly, it finds CNAME rewrites for host. If the CNAME is the same as host,
|
||||
// this query isn't filtered. If it's different, repeat the process for the new
|
||||
// CNAME, breaking loops in the process.
|
||||
//
|
||||
// Secondly, it finds A or AAAA rewrites for host and, if found, sets res.IPList
|
||||
// accordingly. If the found rewrite has a special value of "A" or "AAAA", the
|
||||
// result is an exception.
|
||||
func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
rr := findRewrites(d.Rewrites, host, qtype)
|
||||
if len(rr) != 0 {
|
||||
res.Reason = Rewritten
|
||||
rewrites, matched := findRewrites(d.Rewrites, host, qtype)
|
||||
if !matched {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
res.Reason = Rewritten
|
||||
|
||||
cnames := stringutil.NewSet()
|
||||
origHost := host
|
||||
for len(rr) != 0 && rr[0].Type == dns.TypeCNAME {
|
||||
log.Debug("rewrite: CNAME for %s is %s", host, rr[0].Answer)
|
||||
for matched && len(rewrites) > 0 && rewrites[0].Type == dns.TypeCNAME {
|
||||
rw := rewrites[0]
|
||||
rwPat := rw.Domain
|
||||
rwAns := rw.Answer
|
||||
|
||||
if host == rr[0].Answer { // "host == CNAME" is an exception
|
||||
res.Reason = NotFilteredNotFound
|
||||
log.Debug("rewrite: cname for %s is %s", host, rwAns)
|
||||
|
||||
return res
|
||||
if origHost == rwAns || rwPat == rwAns {
|
||||
// Either a request for the hostname itself or a rewrite of
|
||||
// a pattern onto itself, both of which are an exception rules.
|
||||
// Return a not filtered result.
|
||||
return Result{}
|
||||
} else if host == rwAns && isWildcard(rwPat) {
|
||||
// An "*.example.com → sub.example.com" rewrite matching in a loop.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4016.
|
||||
|
||||
res.CanonName = host
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
host = rr[0].Answer
|
||||
host = rwAns
|
||||
if cnames.Has(host) {
|
||||
log.Info("rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost)
|
||||
log.Info("rewrite: cname loop for %q on %q", origHost, host)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
cnames.Add(host)
|
||||
res.CanonName = rr[0].Answer
|
||||
rr = findRewrites(d.Rewrites, host, qtype)
|
||||
res.CanonName = host
|
||||
rewrites, matched = findRewrites(d.Rewrites, host, qtype)
|
||||
}
|
||||
|
||||
for _, r := range rr {
|
||||
if r.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
if r.IP == nil { // IP exception
|
||||
res.Reason = NotFilteredNotFound
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
res.IPList = append(res.IPList, r.IP)
|
||||
log.Debug("rewrite: A/AAAA for %s is %s", host, r.IP)
|
||||
}
|
||||
}
|
||||
setRewriteResult(&res, host, rewrites, qtype)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// setRewriteResult sets the Reason or IPList of res if necessary. res must not
|
||||
// be nil.
|
||||
func setRewriteResult(res *Result, host string, rewrites []*LegacyRewrite, qtype uint16) {
|
||||
for _, rw := range rewrites {
|
||||
if rw.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
if rw.IP == nil {
|
||||
// "A"/"AAAA" exception: allow getting from upstream.
|
||||
res.Reason = NotFilteredNotFound
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
res.IPList = append(res.IPList, rw.IP)
|
||||
|
||||
log.Debug("rewrite: a/aaaa for %s is %s", host, rw.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// matchBlockedServicesRules checks the host against the blocked services rules
|
||||
// in settings, if any. The err is always nil, it is only there to make this
|
||||
// a valid hostChecker function.
|
||||
@@ -548,6 +585,10 @@ func matchBlockedServicesRules(
|
||||
_ uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.ProtectionEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
svcs := setts.ServicesRules
|
||||
if len(svcs) == 0 {
|
||||
return Result{}, nil
|
||||
@@ -582,80 +623,82 @@ func matchBlockedServicesRules(
|
||||
// Adding rule and matching against the rules
|
||||
//
|
||||
|
||||
// fileExists returns true if file exists.
|
||||
func fileExists(fn string) bool {
|
||||
_, err := os.Stat(fn)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func createFilteringEngine(filters []Filter) (*filterlist.RuleStorage, *urlfilter.DNSEngine, error) {
|
||||
listArray := []filterlist.RuleList{}
|
||||
func newRuleStorage(filters []Filter) (rs *filterlist.RuleStorage, err error) {
|
||||
lists := make([]filterlist.RuleList, 0, len(filters))
|
||||
for _, f := range filters {
|
||||
var list filterlist.RuleList
|
||||
|
||||
if f.ID == 0 {
|
||||
list = &filterlist.StringRuleList{
|
||||
ID: 0,
|
||||
switch id := int(f.ID); {
|
||||
case len(f.Data) != 0:
|
||||
lists = append(lists, &filterlist.StringRuleList{
|
||||
ID: id,
|
||||
RulesText: string(f.Data),
|
||||
IgnoreCosmetic: true,
|
||||
}
|
||||
} else if !fileExists(f.FilePath) {
|
||||
list = &filterlist.StringRuleList{
|
||||
ID: int(f.ID),
|
||||
IgnoreCosmetic: true,
|
||||
}
|
||||
} else if runtime.GOOS == "windows" {
|
||||
// On Windows we don't pass a file to urlfilter because
|
||||
// it's difficult to update this file while it's being
|
||||
// used.
|
||||
data, err := os.ReadFile(f.FilePath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading filter content: %w", err)
|
||||
})
|
||||
case f.FilePath == "":
|
||||
continue
|
||||
case runtime.GOOS == "windows":
|
||||
// On Windows we don't pass a file to urlfilter because it's
|
||||
// difficult to update this file while it's being used.
|
||||
var data []byte
|
||||
data, err = os.ReadFile(f.FilePath)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("reading filter content: %w", err)
|
||||
}
|
||||
|
||||
list = &filterlist.StringRuleList{
|
||||
ID: int(f.ID),
|
||||
lists = append(lists, &filterlist.StringRuleList{
|
||||
ID: id,
|
||||
RulesText: string(data),
|
||||
IgnoreCosmetic: true,
|
||||
})
|
||||
default:
|
||||
var list *filterlist.FileRuleList
|
||||
list, err = filterlist.NewFileRuleList(id, f.FilePath, true)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("creating file rule list with %q: %w", f.FilePath, err)
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
list, err = filterlist.NewFileRuleList(int(f.ID), f.FilePath, true)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("filterlist.NewFileRuleList(): %s: %w", f.FilePath, err)
|
||||
}
|
||||
|
||||
lists = append(lists, list)
|
||||
}
|
||||
listArray = append(listArray, list)
|
||||
}
|
||||
|
||||
rulesStorage, err := filterlist.NewRuleStorage(listArray)
|
||||
rs, err = filterlist.NewRuleStorage(lists)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
|
||||
return nil, fmt.Errorf("creating rule storage: %w", err)
|
||||
}
|
||||
filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
|
||||
return rulesStorage, filteringEngine, nil
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// Initialize urlfilter objects.
|
||||
func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
|
||||
rulesStorage, filteringEngine, err := createFilteringEngine(blockFilters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rulesStorageAllow, filteringEngineAllow, err := createFilteringEngine(allowFilters)
|
||||
rulesStorage, err := newRuleStorage(blockFilters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.engineLock.Lock()
|
||||
d.reset()
|
||||
d.rulesStorage = rulesStorage
|
||||
d.filteringEngine = filteringEngine
|
||||
d.rulesStorageAllow = rulesStorageAllow
|
||||
d.filteringEngineAllow = filteringEngineAllow
|
||||
d.engineLock.Unlock()
|
||||
rulesStorageAllow, err := newRuleStorage(allowFilters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure that the OS reclaims memory as soon as possible
|
||||
filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
|
||||
filteringEngineAllow := urlfilter.NewDNSEngine(rulesStorageAllow)
|
||||
|
||||
func() {
|
||||
d.engineLock.Lock()
|
||||
defer d.engineLock.Unlock()
|
||||
|
||||
d.reset()
|
||||
d.rulesStorage = rulesStorage
|
||||
d.filteringEngine = filteringEngine
|
||||
d.rulesStorageAllow = rulesStorageAllow
|
||||
d.filteringEngineAllow = filteringEngineAllow
|
||||
}()
|
||||
|
||||
// Make sure that the OS reclaims memory as soon as possible.
|
||||
debug.FreeOSMemory()
|
||||
log.Debug("initialized filtering engine")
|
||||
|
||||
@@ -677,11 +720,10 @@ func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) {
|
||||
return res
|
||||
}
|
||||
|
||||
// matchHostProcessAllowList processes the allowlist logic of host
|
||||
// matching.
|
||||
// matchHostProcessAllowList processes the allowlist logic of host matching.
|
||||
func (d *DNSFilter) matchHostProcessAllowList(
|
||||
host string,
|
||||
dnsres urlfilter.DNSResult,
|
||||
dnsres *urlfilter.DNSResult,
|
||||
) (res Result, err error) {
|
||||
var matchedRules []rules.Rule
|
||||
if dnsres.NetworkRule != nil {
|
||||
@@ -704,7 +746,7 @@ func (d *DNSFilter) matchHostProcessAllowList(
|
||||
// matchHostProcessDNSResult processes the matched DNS filtering result.
|
||||
func (d *DNSFilter) matchHostProcessDNSResult(
|
||||
qtype uint16,
|
||||
dnsres urlfilter.DNSResult,
|
||||
dnsres *urlfilter.DNSResult,
|
||||
) (res Result) {
|
||||
if dnsres.NetworkRule != nil {
|
||||
reason := FilteredBlockList
|
||||
@@ -734,8 +776,8 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||
}
|
||||
|
||||
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
|
||||
// Question type doesn't match the host rules. Return the first
|
||||
// matched host rule, but without an IP address.
|
||||
// Question type doesn't match the host rules. Return the first matched
|
||||
// host rule, but without an IP address.
|
||||
var matchedRules []rules.Rule
|
||||
if dnsres.HostRulesV4 != nil {
|
||||
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
|
||||
@@ -749,32 +791,34 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||
return Result{}
|
||||
}
|
||||
|
||||
// matchHost is a low-level way to check only if hostname is filtered by rules,
|
||||
// matchHost is a low-level way to check only if host is filtered by rules,
|
||||
// skipping expensive safebrowsing and parental lookups.
|
||||
func (d *DNSFilter) matchHost(
|
||||
host string,
|
||||
qtype uint16,
|
||||
rrtype uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.FilteringEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
d.engineLock.RLock()
|
||||
// Keep in mind that this lock must be held no just when calling Match()
|
||||
// but also while using the rules returned by it.
|
||||
defer d.engineLock.RUnlock()
|
||||
|
||||
ureq := urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
ClientIP: setts.ClientIP.String(),
|
||||
ClientName: setts.ClientName,
|
||||
DNSType: qtype,
|
||||
DNSType: rrtype,
|
||||
}
|
||||
|
||||
if d.filteringEngineAllow != nil {
|
||||
d.engineLock.RLock()
|
||||
// Keep in mind that this lock must be held no just when calling Match() but
|
||||
// also while using the rules returned by it.
|
||||
//
|
||||
// TODO(e.burkov): Inspect if the above is true.
|
||||
defer d.engineLock.RUnlock()
|
||||
|
||||
if setts.ProtectionEnabled && d.filteringEngineAllow != nil {
|
||||
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
|
||||
if ok {
|
||||
return d.matchHostProcessAllowList(host, dnsres)
|
||||
@@ -786,13 +830,12 @@ func (d *DNSFilter) matchHost(
|
||||
}
|
||||
|
||||
dnsres, ok := d.filteringEngine.MatchRequest(ureq)
|
||||
|
||||
// Check DNS rewrites first, because the API there is a bit awkward.
|
||||
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
|
||||
res = d.processDNSRewrites(dnsr)
|
||||
if res.Reason == RewrittenRule && res.CanonName == host {
|
||||
// A rewrite of a host to itself. Go on and try
|
||||
// matching other things.
|
||||
// A rewrite of a host to itself. Go on and try matching other
|
||||
// things.
|
||||
} else {
|
||||
return res, nil
|
||||
}
|
||||
@@ -800,7 +843,12 @@ func (d *DNSFilter) matchHost(
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
res = d.matchHostProcessDNSResult(qtype, dnsres)
|
||||
if !setts.ProtectionEnabled {
|
||||
// Don't check non-dnsrewrite filtering results.
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
res = d.matchHostProcessDNSResult(rrtype, dnsres)
|
||||
for _, r := range res.Rules {
|
||||
log.Debug(
|
||||
"filtering: found rule %q for host %q, filter list id: %d",
|
||||
@@ -836,40 +884,33 @@ func InitModule() {
|
||||
}
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used.
|
||||
func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||
var resolver Resolver = net.DefaultResolver
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||
d = &DNSFilter{
|
||||
resolver: net.DefaultResolver,
|
||||
}
|
||||
if c != nil {
|
||||
cacheConf := cache.Config{
|
||||
|
||||
d.safebrowsingCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
}
|
||||
|
||||
if gctx.safebrowsingCache == nil {
|
||||
cacheConf.MaxSize = c.SafeBrowsingCacheSize
|
||||
gctx.safebrowsingCache = cache.New(cacheConf)
|
||||
}
|
||||
|
||||
if gctx.safeSearchCache == nil {
|
||||
cacheConf.MaxSize = c.SafeSearchCacheSize
|
||||
gctx.safeSearchCache = cache.New(cacheConf)
|
||||
}
|
||||
|
||||
if gctx.parentalCache == nil {
|
||||
cacheConf.MaxSize = c.ParentalCacheSize
|
||||
gctx.parentalCache = cache.New(cacheConf)
|
||||
}
|
||||
MaxSize: c.SafeBrowsingCacheSize,
|
||||
})
|
||||
d.safeSearchCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.SafeSearchCacheSize,
|
||||
})
|
||||
d.parentalCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.ParentalCacheSize,
|
||||
})
|
||||
|
||||
if c.CustomResolver != nil {
|
||||
resolver = c.CustomResolver
|
||||
d.resolver = c.CustomResolver
|
||||
}
|
||||
}
|
||||
|
||||
d := &DNSFilter{
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
d.hostCheckers = []hostChecker{{
|
||||
check: d.checkEtcHosts,
|
||||
name: "etchosts",
|
||||
check: d.matchSysHosts,
|
||||
name: "hosts container",
|
||||
}, {
|
||||
check: d.matchHost,
|
||||
name: "filtering",
|
||||
@@ -890,12 +931,18 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||
err := d.initSecurityServices()
|
||||
if err != nil {
|
||||
log.Error("filtering: initialize services: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
d.Config = *c
|
||||
d.prepareRewrites()
|
||||
err = d.prepareRewrites()
|
||||
if err != nil {
|
||||
log.Error("rewrites: preparing: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
bsvcs := []string{}
|
||||
|
||||
@@ -21,15 +21,17 @@ func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
var setts Settings
|
||||
var setts = Settings{
|
||||
ProtectionEnabled: true,
|
||||
}
|
||||
|
||||
// Helpers.
|
||||
|
||||
func purgeCaches() {
|
||||
func purgeCaches(d *DNSFilter) {
|
||||
for _, c := range []cache.Cache{
|
||||
gctx.safebrowsingCache,
|
||||
gctx.parentalCache,
|
||||
gctx.safeSearchCache,
|
||||
d.safebrowsingCache,
|
||||
d.parentalCache,
|
||||
d.safeSearchCache,
|
||||
} {
|
||||
if c != nil {
|
||||
c.Clear()
|
||||
@@ -37,11 +39,11 @@ func purgeCaches() {
|
||||
}
|
||||
}
|
||||
|
||||
func newForTest(c *Config, filters []Filter) *DNSFilter {
|
||||
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
|
||||
setts = Settings{
|
||||
FilteringEnabled: true,
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
setts.FilteringEnabled = true
|
||||
if c != nil {
|
||||
c.SafeBrowsingCacheSize = 10000
|
||||
c.ParentalCacheSize = 10000
|
||||
@@ -52,7 +54,8 @@ func newForTest(c *Config, filters []Filter) *DNSFilter {
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
}
|
||||
d := New(c, filters)
|
||||
purgeCaches()
|
||||
purgeCaches(d)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -103,7 +106,7 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||
filters := []Filter{{
|
||||
ID: 0, Data: []byte(text),
|
||||
}}
|
||||
d := newForTest(nil, filters)
|
||||
d := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
|
||||
@@ -168,7 +171,7 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -191,7 +194,7 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -215,7 +218,7 @@ func TestParallelSB(t *testing.T) {
|
||||
// Safe Search.
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
require.True(t, ok)
|
||||
@@ -224,7 +227,9 @@ func TestSafeSearch(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
@@ -247,13 +252,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
||||
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(&Config{
|
||||
d := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
@@ -280,12 +286,13 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, ip, res.Rules[0].IP)
|
||||
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const domain = "yandex.ru"
|
||||
|
||||
@@ -299,7 +306,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
|
||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||
@@ -310,7 +317,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
@@ -319,7 +326,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(&Config{
|
||||
d := newForTest(t, &Config{
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
@@ -332,7 +339,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
|
||||
require.Empty(t, res.Rules)
|
||||
|
||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
d.resolver = resolver
|
||||
|
||||
@@ -359,7 +366,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
assert.True(t, res.Rules[0].IP.Equal(ip))
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
@@ -373,7 +380,7 @@ func TestParentalControl(t *testing.T) {
|
||||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
d := newForTest(&Config{ParentalEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "pornhub.com"
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -677,7 +684,7 @@ func TestMatching(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
|
||||
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
|
||||
d := newForTest(nil, filters)
|
||||
d := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
||||
@@ -703,7 +710,7 @@ func TestWhitelist(t *testing.T) {
|
||||
whiteFilters := []Filter{{
|
||||
ID: 0, Data: []byte(whiteRules),
|
||||
}}
|
||||
d := newForTest(nil, filters)
|
||||
d := newForTest(t, nil, filters)
|
||||
|
||||
err := d.SetFilters(filters, whiteFilters, false)
|
||||
require.NoError(t, err)
|
||||
@@ -748,7 +755,7 @@ func applyClientSettings(setts *Settings) {
|
||||
}
|
||||
|
||||
func TestClientSettings(t *testing.T) {
|
||||
d := newForTest(
|
||||
d := newForTest(t,
|
||||
&Config{
|
||||
ParentalEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
@@ -797,7 +804,11 @@ func TestClientSettings(t *testing.T) {
|
||||
|
||||
makeTester := func(tc testCase, before bool) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
r, _ := d.CheckHost(tc.host, dns.TypeA, &setts)
|
||||
t.Helper()
|
||||
|
||||
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
if before {
|
||||
assert.True(t, r.IsFiltered)
|
||||
assert.Equal(t, tc.wantReason, r.Reason)
|
||||
@@ -808,7 +819,7 @@ func TestClientSettings(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check behaviour without any per-client settings, then apply per-client
|
||||
// settings and check behaviour once again.
|
||||
// settings and check behavior once again.
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, makeTester(tc, tc.before))
|
||||
}
|
||||
@@ -823,7 +834,7 @@ func TestClientSettings(t *testing.T) {
|
||||
// Benchmarks.
|
||||
|
||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -839,7 +850,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -857,7 +868,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
for n := 0; n < b.N; n++ {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
@@ -868,7 +879,7 @@ func BenchmarkSafeSearch(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
|
||||
@@ -4,93 +4,121 @@ package filtering
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// RewriteEntry is a rewrite array element
|
||||
type RewriteEntry struct {
|
||||
// Domain is the domain for which this rewrite should work.
|
||||
// LegacyRewrite is a single legacy DNS rewrite record.
|
||||
//
|
||||
// Instances of *LegacyRewrite must never be nil.
|
||||
type LegacyRewrite struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer"`
|
||||
|
||||
// IP is the IP address that should be used in the response if Type is
|
||||
// A or AAAA.
|
||||
// dns.TypeA or dns.TypeAAAA.
|
||||
IP net.IP `yaml:"-"`
|
||||
|
||||
// Type is the DNS record type: A, AAAA, or CNAME.
|
||||
Type uint16 `yaml:"-"`
|
||||
}
|
||||
|
||||
// equal returns true if the entry is considered equal to the other.
|
||||
func (e *RewriteEntry) equal(other RewriteEntry) (ok bool) {
|
||||
return e.Domain == other.Domain && e.Answer == other.Answer
|
||||
// clone returns a deep clone of rw.
|
||||
func (rw *LegacyRewrite) clone() (cloneRW *LegacyRewrite) {
|
||||
return &LegacyRewrite{
|
||||
Domain: rw.Domain,
|
||||
Answer: rw.Answer,
|
||||
IP: netutil.CloneIP(rw.IP),
|
||||
Type: rw.Type,
|
||||
}
|
||||
}
|
||||
|
||||
// matchesQType returns true if the entry matched qtype.
|
||||
func (e *RewriteEntry) matchesQType(qtype uint16) (ok bool) {
|
||||
// equal returns true if the rw is equal to the other.
|
||||
func (rw *LegacyRewrite) equal(other *LegacyRewrite) (ok bool) {
|
||||
return rw.Domain == other.Domain && rw.Answer == other.Answer
|
||||
}
|
||||
|
||||
// matchesQType returns true if the entry matches the question type qt.
|
||||
func (rw *LegacyRewrite) matchesQType(qt uint16) (ok bool) {
|
||||
// Add CNAMEs, since they match for all types requests.
|
||||
if e.Type == dns.TypeCNAME {
|
||||
if rw.Type == dns.TypeCNAME {
|
||||
return true
|
||||
}
|
||||
|
||||
// Reject types other than A and AAAA.
|
||||
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
return false
|
||||
}
|
||||
|
||||
// If the types match or the entry is set to allow only the other type,
|
||||
// include them.
|
||||
return e.Type == qtype || e.IP == nil
|
||||
return rw.Type == qt || rw.IP == nil
|
||||
}
|
||||
|
||||
// normalize makes sure that the a new or decoded entry is normalized with
|
||||
// regards to domain name case, IP length, and so on.
|
||||
func (e *RewriteEntry) normalize() {
|
||||
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix
|
||||
// and use it in matchDomainWildcard instead of using strings.ToLower
|
||||
//
|
||||
// If rw is nil, it returns an errors.
|
||||
func (rw *LegacyRewrite) normalize() (err error) {
|
||||
if rw == nil {
|
||||
return errors.Error("nil rewrite entry")
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix and
|
||||
// use it in matchDomainWildcard instead of using strings.ToLower
|
||||
// everywhere.
|
||||
e.Domain = strings.ToLower(e.Domain)
|
||||
rw.Domain = strings.ToLower(rw.Domain)
|
||||
|
||||
switch e.Answer {
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
e.IP = nil
|
||||
e.Type = dns.TypeAAAA
|
||||
rw.IP = nil
|
||||
rw.Type = dns.TypeAAAA
|
||||
|
||||
return
|
||||
return nil
|
||||
case "A":
|
||||
e.IP = nil
|
||||
e.Type = dns.TypeA
|
||||
rw.IP = nil
|
||||
rw.Type = dns.TypeA
|
||||
|
||||
return
|
||||
return nil
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
ip := net.ParseIP(e.Answer)
|
||||
ip := net.ParseIP(rw.Answer)
|
||||
if ip == nil {
|
||||
e.Type = dns.TypeCNAME
|
||||
rw.Type = dns.TypeCNAME
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
ip4 := ip.To4()
|
||||
if ip4 != nil {
|
||||
e.IP = ip4
|
||||
e.Type = dns.TypeA
|
||||
rw.IP = ip4
|
||||
rw.Type = dns.TypeA
|
||||
} else {
|
||||
e.IP = ip
|
||||
e.Type = dns.TypeAAAA
|
||||
rw.IP = ip
|
||||
rw.Type = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isWildcard(host string) bool {
|
||||
return len(host) > 1 && host[0] == '*' && host[1] == '.'
|
||||
// isWildcard returns true if pat is a wildcard domain pattern.
|
||||
func isWildcard(pat string) bool {
|
||||
return len(pat) > 1 && pat[0] == '*' && pat[1] == '.'
|
||||
}
|
||||
|
||||
// matchDomainWildcard returns true if host matches the wildcard pattern.
|
||||
@@ -106,16 +134,16 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
|
||||
// wildcard > exact
|
||||
// lower level wildcard > higher level wildcard
|
||||
//
|
||||
type rewritesSorted []RewriteEntry
|
||||
type rewritesSorted []*LegacyRewrite
|
||||
|
||||
// Len implements the sort.Interface interface for legacyRewritesSorted.
|
||||
func (a rewritesSorted) Len() int { return len(a) }
|
||||
func (a rewritesSorted) Len() (l int) { return len(a) }
|
||||
|
||||
// Swap implements the sort.Interface interface for legacyRewritesSorted.
|
||||
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
|
||||
// Less implements the sort.Interface interface for legacyRewritesSorted.
|
||||
func (a rewritesSorted) Less(i, j int) bool {
|
||||
func (a rewritesSorted) Less(i, j int) (less bool) {
|
||||
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
|
||||
return true
|
||||
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
|
||||
@@ -132,49 +160,62 @@ func (a rewritesSorted) Less(i, j int) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// both are wildcards
|
||||
// Both are wildcards.
|
||||
return len(a[i].Domain) > len(a[j].Domain)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) prepareRewrites() {
|
||||
for i := range d.Rewrites {
|
||||
d.Rewrites[i].normalize()
|
||||
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
||||
func (d *DNSFilter) prepareRewrites() (err error) {
|
||||
for i, r := range d.Rewrites {
|
||||
err = r.normalize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findRewrites returns the list of matched rewrite entries. The priority is:
|
||||
// CNAME, then A and AAAA; exact, then wildcard. If the host is matched
|
||||
// exactly, wildcard entries aren't returned. If the host matched by wildcards,
|
||||
// return the most specific for the question type.
|
||||
func findRewrites(entries []RewriteEntry, host string, qtype uint16) (matched []RewriteEntry) {
|
||||
rr := rewritesSorted{}
|
||||
// findRewrites returns the list of matched rewrite entries. If rewrites are
|
||||
// empty, but matched is true, the domain is found among the rewrite rules but
|
||||
// not for this question type.
|
||||
//
|
||||
// The result priority is: CNAME, then A and AAAA; exact, then wildcard. If the
|
||||
// host is matched exactly, wildcard entries aren't returned. If the host
|
||||
// matched by wildcards, return the most specific for the question type.
|
||||
func findRewrites(
|
||||
entries []*LegacyRewrite,
|
||||
host string,
|
||||
qtype uint16,
|
||||
) (rewrites []*LegacyRewrite, matched bool) {
|
||||
for _, e := range entries {
|
||||
if e.Domain != host && !matchDomainWildcard(host, e.Domain) {
|
||||
continue
|
||||
}
|
||||
|
||||
matched = true
|
||||
if e.matchesQType(qtype) {
|
||||
rr = append(rr, e)
|
||||
rewrites = append(rewrites, e)
|
||||
}
|
||||
}
|
||||
|
||||
if len(rr) == 0 {
|
||||
return nil
|
||||
if len(rewrites) == 0 {
|
||||
return nil, matched
|
||||
}
|
||||
|
||||
sort.Sort(rr)
|
||||
sort.Sort(rewritesSorted(rewrites))
|
||||
|
||||
for i, r := range rr {
|
||||
for i, r := range rewrites {
|
||||
if isWildcard(r.Domain) {
|
||||
// Don't use rr[:0], because we need to return at least
|
||||
// one item here.
|
||||
rr = rr[:max(1, i)]
|
||||
// Don't use rewrites[:0], because we need to return at least one
|
||||
// item here.
|
||||
rewrites = rewrites[:max(1, i)]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return rr
|
||||
return rewrites, matched
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
@@ -206,29 +247,39 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(arr)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
jsent := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&jsent)
|
||||
rwJSON := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&rwJSON)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ent := RewriteEntry{
|
||||
Domain: jsent.Domain,
|
||||
Answer: jsent.Answer,
|
||||
rw := &LegacyRewrite{
|
||||
Domain: rwJSON.Domain,
|
||||
Answer: rwJSON.Answer,
|
||||
}
|
||||
ent.normalize()
|
||||
|
||||
err = rw.normalize()
|
||||
if err != nil {
|
||||
// Shouldn't happen currently, since normalize only returns a non-nil
|
||||
// error when a rewrite is nil, but be change-proof.
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d.confLock.Lock()
|
||||
d.Config.Rewrites = append(d.Config.Rewrites, ent)
|
||||
d.Config.Rewrites = append(d.Config.Rewrites, rw)
|
||||
d.confLock.Unlock()
|
||||
log.Debug("Rewrites: added element: %s -> %s [%d]",
|
||||
ent.Domain, ent.Answer, len(d.Config.Rewrites))
|
||||
log.Debug("rewrite: added element: %s -> %s [%d]", rw.Domain, rw.Answer, len(d.Config.Rewrites))
|
||||
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
@@ -237,21 +288,25 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
|
||||
jsent := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&jsent)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
entDel := RewriteEntry{
|
||||
entDel := &LegacyRewrite{
|
||||
Domain: jsent.Domain,
|
||||
Answer: jsent.Answer,
|
||||
}
|
||||
arr := []RewriteEntry{}
|
||||
arr := []*LegacyRewrite{}
|
||||
|
||||
d.confLock.Lock()
|
||||
for _, ent := range d.Config.Rewrites {
|
||||
if ent.equal(entDel) {
|
||||
log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
arr = append(arr, ent)
|
||||
}
|
||||
d.Config.Rewrites = arr
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
// This one and below are about CNAME, A and AAAA.
|
||||
Domain: "somecname",
|
||||
Answer: "somehost.com",
|
||||
@@ -66,107 +66,132 @@ func TestRewrites(t *testing.T) {
|
||||
}, {
|
||||
Domain: "BIGHOST.COM",
|
||||
Answer: "1.2.3.7",
|
||||
}, {
|
||||
Domain: "*.issue4016.com",
|
||||
Answer: "sub.issue4016.com",
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantCName string
|
||||
wantVals []net.IP
|
||||
dtyp uint16
|
||||
name string
|
||||
host string
|
||||
wantCName string
|
||||
wantIPs []net.IP
|
||||
wantReason Reason
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "not_filtered_not_found",
|
||||
host: "hoost.com",
|
||||
wantCName: "",
|
||||
wantVals: nil,
|
||||
dtyp: dns.TypeA,
|
||||
name: "not_filtered_not_found",
|
||||
host: "hoost.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
|
||||
dtyp: dns.TypeAAAA,
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{net.ParseIP("1:2:3::4")},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "wildcard_match",
|
||||
host: "abc.host.com",
|
||||
wantCName: "",
|
||||
wantVals: []net.IP{{1, 2, 3, 5}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "wildcard_match",
|
||||
host: "abc.host.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
wantCName: "",
|
||||
wantVals: []net.IP{{1, 2, 3, 4}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
wantCName: "somehost.com",
|
||||
wantVals: []net.IP{{0, 0, 0, 0}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
wantCName: "somehost.com",
|
||||
wantIPs: []net.IP{{0, 0, 0, 0}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
wantCName: "x.host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 5}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
wantCName: "x.host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue3343",
|
||||
host: "www.hostboth.com",
|
||||
wantCName: "",
|
||||
wantVals: []net.IP{net.ParseIP("1234::5678")},
|
||||
dtyp: dns.TypeAAAA,
|
||||
name: "issue3343",
|
||||
host: "www.hostboth.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{net.ParseIP("1234::5678")},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "issue3351",
|
||||
host: "bighost.com",
|
||||
wantCName: "",
|
||||
wantVals: []net.IP{{1, 2, 3, 7}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "issue3351",
|
||||
host: "bighost.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 7}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4008",
|
||||
host: "somehost.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeHTTPS,
|
||||
}, {
|
||||
name: "issue4016",
|
||||
host: "www.issue4016.com",
|
||||
wantCName: "sub.issue4016.com",
|
||||
wantIPs: nil,
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4016_self",
|
||||
host: "sub.issue4016.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
valsNum := len(tc.wantVals)
|
||||
|
||||
r := d.processRewrites(tc.host, tc.dtyp)
|
||||
if valsNum == 0 {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.Equalf(t, Rewritten, r.Reason, "got %s", r.Reason)
|
||||
require.Equalf(t, tc.wantReason, r.Reason, "got %s", r.Reason)
|
||||
|
||||
if tc.wantCName != "" {
|
||||
assert.Equal(t, tc.wantCName, r.CanonName)
|
||||
}
|
||||
|
||||
require.Len(t, r.IPList, valsNum)
|
||||
for i, ip := range tc.wantVals {
|
||||
assert.Equal(t, ip, r.IPList[i])
|
||||
}
|
||||
assert.Equal(t, tc.wantIPs, r.IPList)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesLevels(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.1.1.1",
|
||||
Type: dns.TypeA,
|
||||
@@ -179,7 +204,8 @@ func TestRewritesLevels(t *testing.T) {
|
||||
Answer: "3.3.3.3",
|
||||
Type: dns.TypeA,
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -209,10 +235,10 @@ func TestRewritesLevels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Wildcard and exception for a sub-domain.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
@@ -222,29 +248,32 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "*.sub.host.com",
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "match_sub-domain",
|
||||
name: "match_subdomain",
|
||||
host: "my.host.com",
|
||||
want: net.IP{2, 2, 2, 2},
|
||||
}, {
|
||||
name: "exception_cname",
|
||||
host: "sub.host.com",
|
||||
want: nil,
|
||||
}, {
|
||||
name: "exception_wildcard",
|
||||
host: "my.sub.host.com",
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, dns.TypeA)
|
||||
if tc.want == nil {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason, "got %s", r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -257,10 +286,10 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesExceptionIP(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exception for AAAA record.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
Type: dns.TypeA,
|
||||
@@ -281,7 +310,8 @@ func TestRewritesExceptionIP(t *testing.T) {
|
||||
Answer: "A",
|
||||
Type: dns.TypeA,
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -306,7 +307,7 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
_ uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.SafeBrowsingEnabled {
|
||||
if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
@@ -318,7 +319,7 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
sctx := &sbCtx{
|
||||
host: host,
|
||||
svc: "SafeBrowsing",
|
||||
cache: gctx.safebrowsingCache,
|
||||
cache: d.safebrowsingCache,
|
||||
cacheTime: d.Config.CacheTime,
|
||||
}
|
||||
|
||||
@@ -326,7 +327,8 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
IsFiltered: true,
|
||||
Reason: FilteredSafeBrowsing,
|
||||
Rules: []*ResultRule{{
|
||||
Text: "adguard-malware-shavar",
|
||||
Text: "adguard-malware-shavar",
|
||||
FilterListID: SafeBrowsingListID,
|
||||
}},
|
||||
}
|
||||
|
||||
@@ -339,7 +341,7 @@ func (d *DNSFilter) checkParental(
|
||||
_ uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.ParentalEnabled {
|
||||
if !setts.ProtectionEnabled || !setts.ParentalEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
@@ -351,7 +353,7 @@ func (d *DNSFilter) checkParental(
|
||||
sctx := &sbCtx{
|
||||
host: host,
|
||||
svc: "Parental",
|
||||
cache: gctx.parentalCache,
|
||||
cache: d.parentalCache,
|
||||
cacheTime: d.Config.CacheTime,
|
||||
}
|
||||
|
||||
@@ -359,19 +361,14 @@ func (d *DNSFilter) checkParental(
|
||||
IsFiltered: true,
|
||||
Reason: FilteredParental,
|
||||
Rules: []*ResultRule{{
|
||||
Text: "parental CATEGORY_BLACKLISTED",
|
||||
Text: "parental CATEGORY_BLACKLISTED",
|
||||
FilterListID: ParentalListID,
|
||||
}},
|
||||
}
|
||||
|
||||
return check(sctx, res, d.parentalUpstream)
|
||||
}
|
||||
|
||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
|
||||
d.Config.SafeBrowsingEnabled = true
|
||||
d.Config.ConfigModified()
|
||||
@@ -390,7 +387,8 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
|
||||
Enabled: d.Config.SafeBrowsingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -413,8 +411,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
|
||||
Enabled: d.Config.ParentalEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ func TestSafeBrowsingCache(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ups := &aghtest.TestErrUpstream{}
|
||||
@@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
setts := &Settings{
|
||||
ProtectionEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
}
|
||||
@@ -129,41 +130,42 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSBPC(t *testing.T) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
const hostname = "example.org"
|
||||
|
||||
setts := &Settings{
|
||||
ProtectionEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
testCache cache.Cache
|
||||
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
|
||||
name string
|
||||
block bool
|
||||
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
|
||||
testCache cache.Cache
|
||||
}{{
|
||||
testCache: d.safebrowsingCache,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
name: "sb_no_block",
|
||||
block: false,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
testCache: gctx.safebrowsingCache,
|
||||
}, {
|
||||
testCache: d.safebrowsingCache,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
name: "sb_block",
|
||||
block: true,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
testCache: gctx.safebrowsingCache,
|
||||
}, {
|
||||
testCache: d.parentalCache,
|
||||
testFunc: d.checkParental,
|
||||
name: "pc_no_block",
|
||||
block: false,
|
||||
testFunc: d.checkParental,
|
||||
testCache: gctx.parentalCache,
|
||||
}, {
|
||||
testCache: d.parentalCache,
|
||||
testFunc: d.checkParental,
|
||||
name: "pc_block",
|
||||
block: true,
|
||||
testFunc: d.checkParental,
|
||||
testCache: gctx.parentalCache,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -215,6 +217,6 @@ func TestSBPC(t *testing.T) {
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
})
|
||||
|
||||
purgeCaches()
|
||||
purgeCaches(d)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
@@ -74,7 +75,7 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
_ uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.SafeSearchEnabled {
|
||||
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
@@ -84,7 +85,7 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
}
|
||||
|
||||
// Check cache. Return cached result if it was found
|
||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, host)
|
||||
if isFound {
|
||||
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
|
||||
log.Tracef("SafeSearch: found in cache: %s", host)
|
||||
@@ -99,12 +100,14 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
res = Result{
|
||||
IsFiltered: true,
|
||||
Reason: FilteredSafeSearch,
|
||||
Rules: []*ResultRule{{}},
|
||||
Rules: []*ResultRule{{
|
||||
FilterListID: SafeSearchListID,
|
||||
}},
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(safeHost); ip != nil {
|
||||
res.Rules[0].IP = ip
|
||||
valLen := d.setCacheResult(gctx.safeSearchCache, host, res)
|
||||
valLen := d.setCacheResult(d.safeSearchCache, host, res)
|
||||
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
|
||||
|
||||
return res, nil
|
||||
@@ -123,7 +126,7 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
|
||||
res.Rules[0].IP = ip
|
||||
|
||||
l := d.setCacheResult(gctx.safeSearchCache, host, res)
|
||||
l := d.setCacheResult(d.safeSearchCache, host, res)
|
||||
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, l)
|
||||
|
||||
return res, nil
|
||||
@@ -150,8 +153,13 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
|
||||
Enabled: d.Config.SafeSearchEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to write response json: %s",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user