all: resync with master

This commit is contained in:
Eugene Burkov
2025-03-17 20:56:05 +03:00
parent 2fc1e258ed
commit a829adad10
69 changed files with 1126 additions and 434 deletions

View File

@@ -49,6 +49,9 @@ func initBlockedServices() {
}
// BlockedServices is the configuration of blocked services.
//
// TODO(s.chzhen): Move to a higher-level package to allow importing the client
// package into the filtering package.
type BlockedServices struct {
// Schedule is blocked services schedule for every day of the week.
Schedule *schedule.Weekly `json:"schedule" yaml:"schedule"`

View File

@@ -1,10 +1,11 @@
package filtering
package filtering_test
import (
"net/netip"
"path"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -50,8 +51,17 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
`
f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &Settings{
conf := &filtering.Config{
SafeBrowsingCacheSize: 10000,
ParentalCacheSize: 10000,
SafeSearchCacheSize: 1000,
CacheTime: 30,
}
f, err := filtering.New(conf, []filtering.Filter{{ID: 0, Data: []byte(text)}})
require.NoError(t, err)
setts := &filtering.Settings{
FilteringEnabled: true,
}
@@ -117,7 +127,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
host := path.Base(tc.name)
res, err := f.CheckHostRules(host, tc.dtyp, setts)
var res filtering.Result
res, err = f.CheckHostRules(host, tc.dtyp, setts)
require.NoError(t, err)
dnsrr := res.DNSRewriteResult
@@ -141,7 +152,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
var res filtering.Result
res, err = f.CheckHostRules(host, dtyp, setts)
require.NoError(t, err)
assert.Equal(t, "new-cname", res.CanonName)
@@ -151,7 +163,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
var res filtering.Result
res, err = f.CheckHostRules(host, dtyp, setts)
require.NoError(t, err)
assert.Equal(t, "new-cname-2", res.CanonName)
@@ -162,7 +175,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
var res filtering.Result
res, err = f.CheckHostRules(host, dtyp, setts)
require.NoError(t, err)
assert.Empty(t, res.CanonName)
@@ -173,7 +187,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
dtyp := dns.TypePTR
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
var res filtering.Result
res, err = f.CheckHostRules(host, dtyp, setts)
require.NoError(t, err)
require.NotNil(t, res.DNSRewriteResult)
@@ -193,7 +208,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
dtyp := dns.TypePTR
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
var res filtering.Result
res, err = f.CheckHostRules(host, dtyp, setts)
require.NoError(t, err)
require.NotNil(t, res.DNSRewriteResult)

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net/http"
"net/netip"
"os"
"path/filepath"
"slices"
@@ -629,3 +630,19 @@ func (d *DNSFilter) enableFiltersLocked(async bool) {
d.SetEnabled(d.conf.FilteringEnabled)
}
// ApplyAdditionalFiltering enhances the provided filtering settings with
// blocked services and client-specific configurations.
func (d *DNSFilter) ApplyAdditionalFiltering(cliAddr netip.Addr, clientID string, setts *Settings) {
d.ApplyBlockedServices(setts)
d.applyClientFiltering(clientID, cliAddr, setts)
if setts.BlockedServices != nil {
// TODO(e.burkov): Get rid of this crutch.
setts.ServicesRules = nil
svcs := setts.BlockedServices.IDs
if !setts.BlockedServices.Schedule.Contains(time.Now()) {
d.ApplyBlockedServicesList(setts, svcs)
}
}
}

View File

@@ -40,6 +40,8 @@ type ServiceEntry struct {
}
// Settings are custom filtering settings for a client.
//
// TODO(s.chzhen): Move to the client package.
type Settings struct {
ClientName string
ClientIP netip.Addr
@@ -47,6 +49,10 @@ type Settings struct {
ServicesRules []ServiceEntry
// BlockedServices is the configuration of blocked services of a client. It
// is nil if the client does not have any blocked services.
BlockedServices *BlockedServices
ProtectionEnabled bool
FilteringEnabled bool
SafeSearchEnabled bool
@@ -78,6 +84,11 @@ type Config struct {
SafeSearch SafeSearch `yaml:"-"`
// ApplyClientFiltering retrieves persistent client information using the
// ClientID or client IP address, and applies it to the filtering settings.
// It must not be nil.
ApplyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) `yaml:"-"`
// BlockedServices is the configuration of blocked services.
// Per-client settings can override this configuration.
BlockedServices *BlockedServices `yaml:"blocked_services"`
@@ -244,6 +255,13 @@ type DNSFilter struct {
// parentalControl is the parental control hash-prefix checker.
parentalControlChecker Checker
// applyClientFiltering retrieves persistent client information using the
// ClientID or client IP address, and applies it to the filtering settings.
//
// TODO(s.chzhen): Consider finding a better approach while taking an
// import cycle into account.
applyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings)
engineLock sync.RWMutex
// confMu protects conf.
@@ -998,6 +1016,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
refreshLock: &sync.Mutex{},
safeBrowsingChecker: c.SafeBrowsingChecker,
parentalControlChecker: c.ParentalControlChecker,
applyClientFiltering: c.ApplyClientFiltering,
confMu: &sync.RWMutex{},
}

View File

@@ -1,4 +1,4 @@
package filtering
package filtering_test
import (
"fmt"
@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
@@ -50,27 +51,27 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
conf := &Config{
conf := &filtering.Config{
EtcHosts: hc,
}
f, err := New(conf, nil)
f, err := filtering.New(conf, nil)
require.NoError(t, err)
setts := &Settings{
setts := &filtering.Settings{
FilteringEnabled: true,
}
testCases := []struct {
name string
host string
wantRules []*ResultRule
wantRules []*filtering.ResultRule
wantResps []rules.RRValue
dtyp uint16
}{{
name: "v4",
host: "v4.host.example",
dtyp: dns.TypeA,
wantRules: []*ResultRule{{
wantRules: []*filtering.ResultRule{{
Text: "1.2.3.4 v4.host.example",
FilterListID: rulelist.URLFilterIDEtcHosts,
}},
@@ -79,7 +80,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
name: "v6",
host: "v6.host.example",
dtyp: dns.TypeAAAA,
wantRules: []*ResultRule{{
wantRules: []*filtering.ResultRule{{
Text: "::1 v6.host.example",
FilterListID: rulelist.URLFilterIDEtcHosts,
}},
@@ -88,7 +89,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
name: "mapped",
host: "mapped.host.example",
dtyp: dns.TypeAAAA,
wantRules: []*ResultRule{{
wantRules: []*filtering.ResultRule{{
Text: "::ffff:1.2.3.4 mapped.host.example",
FilterListID: rulelist.URLFilterIDEtcHosts,
}},
@@ -97,7 +98,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
name: "ptr",
host: "4.3.2.1.in-addr.arpa",
dtyp: dns.TypePTR,
wantRules: []*ResultRule{{
wantRules: []*filtering.ResultRule{{
Text: "1.2.3.4 v4.host.example",
FilterListID: rulelist.URLFilterIDEtcHosts,
}},
@@ -106,7 +107,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
name: "ptr-mapped",
host: "4.0.3.0.2.0.1.0.f.f.f.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
dtyp: dns.TypePTR,
wantRules: []*ResultRule{{
wantRules: []*filtering.ResultRule{{
Text: "::ffff:1.2.3.4 mapped.host.example",
FilterListID: rulelist.URLFilterIDEtcHosts,
}},
@@ -133,7 +134,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
name: "v4_mismatch",
host: "v4.host.example",
dtyp: dns.TypeAAAA,
wantRules: []*ResultRule{{
wantRules: []*filtering.ResultRule{{
Text: fmt.Sprintf("%s v4.host.example", addrv4),
FilterListID: rulelist.URLFilterIDEtcHosts,
}},
@@ -142,7 +143,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
name: "v6_mismatch",
host: "v6.host.example",
dtyp: dns.TypeA,
wantRules: []*ResultRule{{
wantRules: []*filtering.ResultRule{{
Text: fmt.Sprintf("%s v6.host.example", addrv6),
FilterListID: rulelist.URLFilterIDEtcHosts,
}},
@@ -163,7 +164,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
name: "v4_dup",
host: "v4.host.with-dup",
dtyp: dns.TypeA,
wantRules: []*ResultRule{{
wantRules: []*filtering.ResultRule{{
Text: "4.3.2.1 v4.host.with-dup",
FilterListID: rulelist.URLFilterIDEtcHosts,
}},
@@ -172,7 +173,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var res Result
var res filtering.Result
res, err = f.CheckHost(tc.host, tc.dtyp, setts)
require.NoError(t, err)

View File

@@ -9,6 +9,8 @@ import (
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
"time"
@@ -420,15 +422,53 @@ type checkHostResp struct {
FilterID rulelist.URLFilterID `json:"filter_id"`
}
// handleCheckHost is the handler for the GET /control/filtering/check_host HTTP
// API.
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
host := r.URL.Query().Get("name")
query := r.URL.Query()
host := query.Get("name")
if host == "" {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
`query parameter "name" is required`,
)
return
}
cli := query.Get("client")
qTypeStr := query.Get("qtype")
qType, err := stringToDNSType(qTypeStr)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusUnprocessableEntity,
"bad qtype query parameter: %q",
qTypeStr,
)
return
}
setts := d.Settings()
setts.FilteringEnabled = true
setts.ProtectionEnabled = true
d.ApplyBlockedServices(setts)
result, err := d.CheckHost(host, dns.TypeA, setts)
addr, err := netip.ParseAddr(cli)
if err == nil {
setts.ClientIP = addr
d.ApplyAdditionalFiltering(addr, "", setts)
} else if cli != "" {
// TODO(s.chzhen): Set [Settings.ClientName] once urlfilter supports
// multiple client names. This will handle the case when a rule exists
// but the persistent client does not.
d.ApplyAdditionalFiltering(netip.Addr{}, cli, setts)
}
result, err := d.CheckHost(host, qType, setts)
if err != nil {
aghhttp.Error(
r,
@@ -466,6 +506,33 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// stringToDNSType is a helper function that converts a string to DNS type. If
// the string is empty, it returns the default value [dns.TypeA].
func stringToDNSType(str string) (qtype uint16, err error) {
if str == "" {
return dns.TypeA, nil
}
qtype, ok := dns.StringToType[str]
if ok {
return qtype, nil
}
// typePref is a prefix for DNS types from experimental RFCs.
const typePref = "TYPE"
if !strings.HasPrefix(str, typePref) {
return 0, errors.ErrBadEnumValue
}
val, err := strconv.ParseUint(str[len(typePref):], 10, 16)
if err != nil {
return 0, errors.ErrBadEnumValue
}
return uint16(val), nil
}
// setProtectedBool sets the value of a boolean pointer under a lock. l must
// protect the value under ptr.
//

View File

@@ -3,11 +3,15 @@ package filtering
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -305,3 +309,168 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) {
})
}
}
func TestDNSFilter_HandleCheckHost(t *testing.T) {
const (
cliName = "client_name"
cliID = "client_id"
notFilteredHost = "not.filterd.example"
allowedHost = "allowed.example"
blockedHost = "blocked.example"
cliHost = "client.example"
qTypeHost = "qtype.example"
cliQTypeHost = "cli.qtype.example"
target = "/control/check_host"
hostFmt = target + "?name=%s"
hostCliFmt = hostFmt + "&client=%s"
hostQTypeFmt = hostFmt + "&qtype=%s"
hostCliQTypeFmt = hostCliFmt + "&qtype=%s"
allowedRuleFmt = "@@||%s^"
blockedRuleFmt = "||%s^"
blockedRuleCliFmt = blockedRuleFmt + "$client=%s"
blockedRuleQTypeFmt = blockedRuleFmt + "$dnstype=%s"
blockedRuleCliQTypeFmt = blockedRuleCliFmt + ",dnstype=%s"
)
var (
allowedRule = fmt.Sprintf(allowedRuleFmt, allowedHost)
blockedRule = fmt.Sprintf(blockedRuleFmt, blockedHost)
blockedClientRule = fmt.Sprintf(blockedRuleCliFmt, cliHost, cliName)
blockedQTypeRule = fmt.Sprintf(blockedRuleQTypeFmt, qTypeHost, "CNAME")
blockedClientQTypeRule = fmt.Sprintf(blockedRuleCliQTypeFmt, cliQTypeHost, cliName, "CNAME")
notFilteredURL = fmt.Sprintf(hostFmt, notFilteredHost)
allowedURL = fmt.Sprintf(hostFmt, allowedHost)
blockedURL = fmt.Sprintf(hostFmt, blockedHost)
blockedClientURL = fmt.Sprintf(hostCliFmt, cliHost, cliID)
allowedQTypeURL = fmt.Sprintf(hostQTypeFmt, qTypeHost, "AAAA")
blockedQTypeURL = fmt.Sprintf(hostQTypeFmt, qTypeHost, "CNAME")
allowedClientQTypeURL = fmt.Sprintf(hostCliQTypeFmt, cliQTypeHost, cliID, "AAAA")
blockedClientQTypeURL = fmt.Sprintf(hostCliQTypeFmt, cliQTypeHost, cliID, "CNAME")
)
rules := []string{
allowedRule,
blockedRule,
blockedClientRule,
blockedQTypeRule,
blockedClientQTypeRule,
}
rulesData := strings.Join(rules, "\n")
filters := []Filter{{
ID: 0, Data: []byte(rulesData),
}}
clientNames := map[string]string{
cliID: cliName,
}
dnsFilter, err := New(&Config{
BlockedServices: &BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
ApplyClientFiltering: func(clientID string, cliAddr netip.Addr, setts *Settings) {
setts.ClientName = clientNames[clientID]
},
}, filters)
require.NoError(t, err)
testCases := []struct {
name string
url string
want *checkHostResp
}{{
name: "not_filtered",
url: notFilteredURL,
want: &checkHostResp{
Reason: reasonNames[NotFilteredNotFound],
Rule: "",
Rules: []*checkHostRespRule{},
},
}, {
name: "allowed",
url: allowedURL,
want: &checkHostResp{
Reason: reasonNames[NotFilteredAllowList],
Rule: allowedRule,
Rules: []*checkHostRespRule{{
Text: allowedRule,
}},
},
}, {
name: "blocked",
url: blockedURL,
want: &checkHostResp{
Reason: reasonNames[FilteredBlockList],
Rule: blockedRule,
Rules: []*checkHostRespRule{{
Text: blockedRule,
}},
},
}, {
name: "blocked_client",
url: blockedClientURL,
want: &checkHostResp{
Reason: reasonNames[FilteredBlockList],
Rule: blockedClientRule,
Rules: []*checkHostRespRule{{
Text: blockedClientRule,
}},
},
}, {
name: "allowed_qtype",
url: allowedQTypeURL,
want: &checkHostResp{
Reason: reasonNames[NotFilteredNotFound],
Rule: "",
Rules: []*checkHostRespRule{},
},
}, {
name: "blocked_qtype",
url: blockedQTypeURL,
want: &checkHostResp{
Reason: reasonNames[FilteredBlockList],
Rule: blockedQTypeRule,
Rules: []*checkHostRespRule{{
Text: blockedQTypeRule,
}},
},
}, {
name: "blocked_client_qtype",
url: blockedClientQTypeURL,
want: &checkHostResp{
Reason: reasonNames[FilteredBlockList],
Rule: blockedClientQTypeRule,
Rules: []*checkHostRespRule{{
Text: blockedClientQTypeRule,
}},
},
}, {
name: "allowed_client_qtype",
url: allowedClientQTypeURL,
want: &checkHostResp{
Reason: reasonNames[NotFilteredNotFound],
Rule: "",
Rules: []*checkHostRespRule{},
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, tc.url, nil)
w := httptest.NewRecorder()
dnsFilter.handleCheckHost(w, r)
res := &checkHostResp{}
err = json.NewDecoder(w.Body).Decode(res)
require.NoError(t, err)
assert.Equal(t, tc.want, res)
})
}
}

View File

@@ -3,6 +3,9 @@ package filtering
import "context"
// SafeSearch interface describes a service for search engines hosts rewrites.
//
// TODO(s.chzhen): Move to a higher-level package to allow importing the client
// package into the filtering package.
type SafeSearch interface {
// CheckHost checks host with safe search filter. CheckHost must be safe
// for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].