all: sync with master; upd chlog
This commit is contained in:
@@ -1,10 +1,7 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// NormalizeDomain returns a lowercased version of host without the final dot,
|
||||
@@ -19,25 +16,3 @@ func NormalizeDomain(host string) (norm string) {
|
||||
|
||||
return strings.ToLower(strings.TrimSuffix(host, "."))
|
||||
}
|
||||
|
||||
// NewDomainNameSet returns nil and error, if list has duplicate or empty domain
|
||||
// name. Otherwise returns a set, which contains domain names normalized using
|
||||
// [NormalizeDomain].
|
||||
func NewDomainNameSet(list []string) (set *stringutil.Set, err error) {
|
||||
set = stringutil.NewSet()
|
||||
|
||||
for i, host := range list {
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("at index %d: hostname is empty", i)
|
||||
}
|
||||
|
||||
host = NormalizeDomain(host)
|
||||
if set.Has(host) {
|
||||
return nil, fmt.Errorf("duplicate hostname %q at index %d", host, i)
|
||||
}
|
||||
|
||||
set.Add(host)
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
package aghnet_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewDomainNameSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
in []string
|
||||
}{{
|
||||
name: "nil",
|
||||
wantErrMsg: "",
|
||||
in: nil,
|
||||
}, {
|
||||
name: "success",
|
||||
wantErrMsg: "",
|
||||
in: []string{
|
||||
"Domain.Example",
|
||||
".",
|
||||
},
|
||||
}, {
|
||||
name: "dups",
|
||||
wantErrMsg: `duplicate hostname "domain.example" at index 1`,
|
||||
in: []string{
|
||||
"Domain.Example",
|
||||
"domain.example",
|
||||
},
|
||||
}, {
|
||||
name: "bad_domain",
|
||||
wantErrMsg: "at index 0: hostname is empty",
|
||||
in: []string{
|
||||
"",
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
set, err := aghnet.NewDomainNameSet(tc.in)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, host := range tc.in {
|
||||
assert.Truef(t, set.Has(aghnet.NormalizeDomain(host)), "%q not matched", host)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// FileWalker is the signature of a function called for files in the file tree.
|
||||
@@ -56,7 +56,7 @@ func checkFile(
|
||||
// srcSet. srcSet must be non-nil.
|
||||
func handlePatterns(
|
||||
fsys fs.FS,
|
||||
srcSet *stringutil.Set,
|
||||
srcSet *container.MapSet[string],
|
||||
patterns ...string,
|
||||
) (sub []string, err error) {
|
||||
sub = make([]string, 0, len(patterns))
|
||||
@@ -87,7 +87,7 @@ func handlePatterns(
|
||||
func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
|
||||
// The slice of sources keeps the order in which the files are walked since
|
||||
// srcSet.Values() returns strings in undefined order.
|
||||
srcSet := stringutil.NewSet()
|
||||
srcSet := container.NewMapSet[string]()
|
||||
var src []string
|
||||
src, err = handlePatterns(fsys, srcSet, initial...)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
@@ -46,7 +46,7 @@ type osWatcher struct {
|
||||
events chan event
|
||||
|
||||
// files is the set of tracked files.
|
||||
files *stringutil.Set
|
||||
files *container.MapSet[string]
|
||||
}
|
||||
|
||||
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
|
||||
@@ -67,7 +67,7 @@ func NewOSWritesWatcher() (w FSWatcher, err error) {
|
||||
return &osWatcher{
|
||||
watcher: watcher,
|
||||
events: make(chan event, 1),
|
||||
files: stringutil.NewSet(),
|
||||
files: container.NewMapSet[string](),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -98,7 +98,7 @@ type Persistent struct {
|
||||
}
|
||||
|
||||
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
|
||||
func (c *Persistent) SetTags(tags []string, known *stringutil.Set) {
|
||||
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
|
||||
for _, t := range tags {
|
||||
if !known.Has(t) {
|
||||
log.Info("skipping unknown tag %q", t)
|
||||
|
||||
@@ -5,10 +5,12 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -16,22 +18,19 @@ import (
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
)
|
||||
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
// accessManager controls IP and client blocking that takes place before all
|
||||
// other processing. An accessManager is safe for concurrent use.
|
||||
type accessManager struct {
|
||||
allowedIPs map[netip.Addr]unit
|
||||
blockedIPs map[netip.Addr]unit
|
||||
allowedIPs *container.MapSet[netip.Addr]
|
||||
blockedIPs *container.MapSet[netip.Addr]
|
||||
|
||||
allowedClientIDs *stringutil.Set
|
||||
blockedClientIDs *stringutil.Set
|
||||
allowedClientIDs *container.MapSet[string]
|
||||
blockedClientIDs *container.MapSet[string]
|
||||
|
||||
// TODO(s.chzhen): Use [aghnet.IgnoreEngine].
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
|
||||
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||
// TODO(a.garipov): Create a type for an efficient tree set of IP networks.
|
||||
allowedNets []netip.Prefix
|
||||
blockedNets []netip.Prefix
|
||||
}
|
||||
@@ -40,15 +39,15 @@ type accessManager struct {
|
||||
// which may be an IP address, a CIDR, or a ClientID.
|
||||
func processAccessClients(
|
||||
clientStrs []string,
|
||||
ips map[netip.Addr]unit,
|
||||
ips *container.MapSet[netip.Addr],
|
||||
nets *[]netip.Prefix,
|
||||
clientIDs *stringutil.Set,
|
||||
clientIDs *container.MapSet[string],
|
||||
) (err error) {
|
||||
for i, s := range clientStrs {
|
||||
var ip netip.Addr
|
||||
var ipnet netip.Prefix
|
||||
if ip, err = netip.ParseAddr(s); err == nil {
|
||||
ips[ip] = unit{}
|
||||
ips.Add(ip)
|
||||
} else if ipnet, err = netip.ParsePrefix(s); err == nil {
|
||||
*nets = append(*nets, ipnet)
|
||||
} else {
|
||||
@@ -67,11 +66,11 @@ func processAccessClients(
|
||||
// newAccessCtx creates a new accessCtx.
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessManager, err error) {
|
||||
a = &accessManager{
|
||||
allowedIPs: map[netip.Addr]unit{},
|
||||
blockedIPs: map[netip.Addr]unit{},
|
||||
allowedIPs: container.NewMapSet[netip.Addr](),
|
||||
blockedIPs: container.NewMapSet[netip.Addr](),
|
||||
|
||||
allowedClientIDs: stringutil.NewSet(),
|
||||
blockedClientIDs: stringutil.NewSet(),
|
||||
allowedClientIDs: container.NewMapSet[string](),
|
||||
blockedClientIDs: container.NewMapSet[string](),
|
||||
}
|
||||
|
||||
err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs)
|
||||
@@ -109,7 +108,7 @@ func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessManager, er
|
||||
|
||||
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
|
||||
func (a *accessManager) allowlistMode() (ok bool) {
|
||||
return len(a.allowedIPs) != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
}
|
||||
|
||||
// isBlockedClientID returns true if the ClientID should be blocked.
|
||||
@@ -152,7 +151,7 @@ func (a *accessManager) isBlockedIP(ip netip.Addr) (blocked bool, rule string) {
|
||||
ipnets = a.allowedNets
|
||||
}
|
||||
|
||||
if _, ok := ips[ip]; ok {
|
||||
if ips.Has(ip) {
|
||||
return blocked, ip.String()
|
||||
}
|
||||
|
||||
@@ -176,9 +175,9 @@ func (s *Server) accessListJSON() (j accessListJSON) {
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return accessListJSON{
|
||||
AllowedClients: stringutil.CloneSlice(s.conf.AllowedClients),
|
||||
DisallowedClients: stringutil.CloneSlice(s.conf.DisallowedClients),
|
||||
BlockedHosts: stringutil.CloneSlice(s.conf.BlockedHosts),
|
||||
AllowedClients: slices.Clone(s.conf.AllowedClients),
|
||||
DisallowedClients: slices.Clone(s.conf.DisallowedClients),
|
||||
BlockedHosts: slices.Clone(s.conf.BlockedHosts),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -461,26 +462,27 @@ func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
// unspecPorts if its address is unspecified.
|
||||
func collectListenAddr(
|
||||
addrPort netip.AddrPort,
|
||||
addrs map[netip.AddrPort]unit,
|
||||
unspecPorts map[uint16]unit,
|
||||
addrs *container.MapSet[netip.AddrPort],
|
||||
unspecPorts *container.MapSet[uint16],
|
||||
) {
|
||||
if addrPort == (netip.AddrPort{}) {
|
||||
return
|
||||
}
|
||||
|
||||
addrs[addrPort] = unit{}
|
||||
addrs.Add(addrPort)
|
||||
if addrPort.Addr().IsUnspecified() {
|
||||
unspecPorts[addrPort.Port()] = unit{}
|
||||
unspecPorts.Add(addrPort.Port())
|
||||
}
|
||||
}
|
||||
|
||||
// collectDNSAddrs returns configured set of listening addresses. It also
|
||||
// returns a set of ports of each unspecified listening address.
|
||||
func (conf *ServerConfig) collectDNSAddrs() (addrs mapAddrPortSet, unspecPorts map[uint16]unit) {
|
||||
// TODO(e.burkov): Perhaps, we shouldn't allocate as much memory, since the
|
||||
// TCP and UDP listening addresses are currently the same.
|
||||
addrs = make(map[netip.AddrPort]unit, len(conf.TCPListenAddrs)+len(conf.UDPListenAddrs))
|
||||
unspecPorts = map[uint16]unit{}
|
||||
func (conf *ServerConfig) collectDNSAddrs() (
|
||||
addrs *container.MapSet[netip.AddrPort],
|
||||
unspecPorts *container.MapSet[uint16],
|
||||
) {
|
||||
addrs = container.NewMapSet[netip.AddrPort]()
|
||||
unspecPorts = container.NewMapSet[uint16]()
|
||||
|
||||
for _, laddr := range conf.TCPListenAddrs {
|
||||
collectListenAddr(laddr.AddrPort(), addrs, unspecPorts)
|
||||
@@ -511,26 +513,12 @@ type emptyAddrPortSet struct{}
|
||||
// Has implements the [addrPortSet] interface for [emptyAddrPortSet].
|
||||
func (emptyAddrPortSet) Has(_ netip.AddrPort) (ok bool) { return false }
|
||||
|
||||
// mapAddrPortSet is the [addrPortSet] containing values of [netip.AddrPort] as
|
||||
// keys of a map.
|
||||
type mapAddrPortSet map[netip.AddrPort]unit
|
||||
|
||||
// type check
|
||||
var _ addrPortSet = mapAddrPortSet{}
|
||||
|
||||
// Has implements the [addrPortSet] interface for [mapAddrPortSet].
|
||||
func (m mapAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
|
||||
_, ok = m[addrPort]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// combinedAddrPortSet is the [addrPortSet] defined by some IP addresses along
|
||||
// with ports, any combination of which is considered being in the set.
|
||||
type combinedAddrPortSet struct {
|
||||
// TODO(e.burkov): Use sorted slices in combination with binary search.
|
||||
ports map[uint16]unit
|
||||
addrs []netip.Addr
|
||||
// TODO(e.burkov): Use container.SliceSet when available.
|
||||
ports *container.MapSet[uint16]
|
||||
addrs *container.MapSet[netip.Addr]
|
||||
}
|
||||
|
||||
// type check
|
||||
@@ -538,9 +526,7 @@ var _ addrPortSet = (*combinedAddrPortSet)(nil)
|
||||
|
||||
// Has implements the [addrPortSet] interface for [*combinedAddrPortSet].
|
||||
func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
|
||||
_, ok = m.ports[addrPort.Port()]
|
||||
|
||||
return ok && slices.Contains(m.addrs, addrPort.Addr())
|
||||
return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr())
|
||||
}
|
||||
|
||||
// filterOut filters out all the upstreams that match um. It returns all the
|
||||
@@ -578,11 +564,11 @@ func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error)
|
||||
func (conf *ServerConfig) ourAddrsSet() (m addrPortSet, err error) {
|
||||
addrs, unspecPorts := conf.collectDNSAddrs()
|
||||
switch {
|
||||
case len(addrs) == 0:
|
||||
case addrs.Len() == 0:
|
||||
log.Debug("dnsforward: no listen addresses")
|
||||
|
||||
return emptyAddrPortSet{}, nil
|
||||
case len(unspecPorts) == 0:
|
||||
case unspecPorts.Len() == 0:
|
||||
log.Debug("dnsforward: filtering out addresses %s", addrs)
|
||||
|
||||
return addrs, nil
|
||||
@@ -598,7 +584,7 @@ func (conf *ServerConfig) ourAddrsSet() (m addrPortSet, err error) {
|
||||
|
||||
return &combinedAddrPortSet{
|
||||
ports: unspecPorts,
|
||||
addrs: ifaceAddrs,
|
||||
addrs: container.NewMapSet(ifaceAddrs...),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,13 +308,13 @@ func (s *Server) WriteDiskConfig(c *Config) {
|
||||
sc := s.conf.Config
|
||||
*c = sc
|
||||
c.RatelimitWhitelist = slices.Clone(sc.RatelimitWhitelist)
|
||||
c.BootstrapDNS = stringutil.CloneSlice(sc.BootstrapDNS)
|
||||
c.FallbackDNS = stringutil.CloneSlice(sc.FallbackDNS)
|
||||
c.AllowedClients = stringutil.CloneSlice(sc.AllowedClients)
|
||||
c.DisallowedClients = stringutil.CloneSlice(sc.DisallowedClients)
|
||||
c.BlockedHosts = stringutil.CloneSlice(sc.BlockedHosts)
|
||||
c.BootstrapDNS = slices.Clone(sc.BootstrapDNS)
|
||||
c.FallbackDNS = slices.Clone(sc.FallbackDNS)
|
||||
c.AllowedClients = slices.Clone(sc.AllowedClients)
|
||||
c.DisallowedClients = slices.Clone(sc.DisallowedClients)
|
||||
c.BlockedHosts = slices.Clone(sc.BlockedHosts)
|
||||
c.TrustedProxies = slices.Clone(sc.TrustedProxies)
|
||||
c.UpstreamDNS = stringutil.CloneSlice(sc.UpstreamDNS)
|
||||
c.UpstreamDNS = slices.Clone(sc.UpstreamDNS)
|
||||
}
|
||||
|
||||
// LocalPTRResolvers returns the current local PTR resolver configuration.
|
||||
@@ -322,7 +322,7 @@ func (s *Server) LocalPTRResolvers() (localPTRResolvers []string) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return stringutil.CloneSlice(s.conf.LocalPTRResolvers)
|
||||
return slices.Clone(s.conf.LocalPTRResolvers)
|
||||
}
|
||||
|
||||
// AddrProcConfig returns the current address processing configuration. Only
|
||||
|
||||
@@ -474,8 +474,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
|
||||
if dc.UpstreamMode != nil {
|
||||
s.conf.UpstreamMode = mustParseUpstreamMode(*dc.UpstreamMode)
|
||||
} else {
|
||||
s.conf.UpstreamMode = UpstreamModeLoadBalance
|
||||
}
|
||||
|
||||
if dc.EDNSCSUseCustom != nil && *dc.EDNSCSUseCustom {
|
||||
|
||||
@@ -29,6 +29,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(e.burkov): Use the better approach to testdata with a separate
|
||||
// directory for each test, and a separate file for each subtest. See the
|
||||
// [configmigrate] package.
|
||||
|
||||
// emptySysResolvers is an empty [SystemResolvers] implementation that always
|
||||
// returns nil.
|
||||
type emptySysResolvers struct{}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
@@ -28,7 +29,7 @@ func initBlockedServices() {
|
||||
for i, s := range blockedServices {
|
||||
netRules := make([]*rules.NetworkRule, 0, len(s.Rules))
|
||||
for _, text := range s.Rules {
|
||||
rule, err := rules.NewNetworkRule(text, BlockedSvcsListID)
|
||||
rule, err := rules.NewNetworkRule(text, rulelist.URLFilterIDBlockedService)
|
||||
if err != nil {
|
||||
log.Error("parsing blocked service %q rule %q: %s", s.ID, text, err)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
if dr.NewCNAME != "" {
|
||||
// NewCNAME rules have a higher priority than other rules.
|
||||
rules = []*ResultRule{{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
FilterListID: nr.GetFilterListID(),
|
||||
Text: nr.RuleText,
|
||||
}}
|
||||
|
||||
@@ -46,14 +46,14 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
dnsrr.RCode = dr.RCode
|
||||
dnsrr.Response[dr.RRType] = append(dnsrr.Response[dr.RRType], dr.Value)
|
||||
rules = append(rules, &ResultRule{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
FilterListID: nr.GetFilterListID(),
|
||||
Text: nr.RuleText,
|
||||
})
|
||||
default:
|
||||
// RcodeRefused and other such codes have higher priority. Return
|
||||
// immediately.
|
||||
rules = []*ResultRule{{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
FilterListID: nr.GetFilterListID(),
|
||||
Text: nr.RuleText,
|
||||
}}
|
||||
dnsrr = &DNSRewriteResult{
|
||||
|
||||
@@ -13,20 +13,15 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghrenameio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// filterDir is the subdirectory of a data directory to store downloaded
|
||||
// filters.
|
||||
const filterDir = "filters"
|
||||
|
||||
// nextFilterID is a way to seed a unique ID generation.
|
||||
//
|
||||
// TODO(e.burkov): Use more deterministic approach.
|
||||
var nextFilterID = time.Now().Unix()
|
||||
|
||||
// FilterYAML represents a filter list in the configuration file.
|
||||
//
|
||||
// TODO(e.burkov): Investigate if the field ordering is important.
|
||||
@@ -50,7 +45,10 @@ func (filter *FilterYAML) unload() {
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *FilterYAML) Path(dataDir string) string {
|
||||
return filepath.Join(dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
return filepath.Join(
|
||||
dataDir,
|
||||
filterDir,
|
||||
strconv.FormatInt(int64(filter.ID), 10)+".txt")
|
||||
}
|
||||
|
||||
// ensureName sets provided title or default name for the filter if it doesn't
|
||||
@@ -217,7 +215,10 @@ func (d *DNSFilter) loadFilters(array []FilterYAML) {
|
||||
for i := range array {
|
||||
filter := &array[i] // otherwise we're operating on a copy
|
||||
if filter.ID == 0 {
|
||||
filter.ID = assignUniqueFilterID()
|
||||
newID := d.idGen.next()
|
||||
log.Info("filtering: warning: filter at index %d has no id; assigning to %d", i, newID)
|
||||
|
||||
filter.ID = newID
|
||||
}
|
||||
|
||||
if !filter.Enabled {
|
||||
@@ -233,7 +234,7 @@ func (d *DNSFilter) loadFilters(array []FilterYAML) {
|
||||
}
|
||||
|
||||
func deduplicateFilters(filters []FilterYAML) (deduplicated []FilterYAML) {
|
||||
urls := stringutil.NewSet()
|
||||
urls := container.NewMapSet[string]()
|
||||
lastIdx := 0
|
||||
|
||||
for _, filter := range filters {
|
||||
@@ -247,22 +248,6 @@ func deduplicateFilters(filters []FilterYAML) (deduplicated []FilterYAML) {
|
||||
return filters[:lastIdx]
|
||||
}
|
||||
|
||||
// Set the next filter ID to max(filter.ID) + 1
|
||||
func updateUniqueFilterID(filters []FilterYAML) {
|
||||
for _, filter := range filters {
|
||||
if nextFilterID < filter.ID {
|
||||
nextFilterID = filter.ID + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Improve this inexhaustible source of races.
|
||||
func assignUniqueFilterID() int64 {
|
||||
value := nextFilterID
|
||||
nextFilterID++
|
||||
return value
|
||||
}
|
||||
|
||||
// tryRefreshFilters is like [refreshFilters], but backs down if the update is
|
||||
// already going on.
|
||||
//
|
||||
@@ -608,7 +593,7 @@ func (d *DNSFilter) EnableFilters(async bool) {
|
||||
func (d *DNSFilter) enableFiltersLocked(async bool) {
|
||||
filters := make([]Filter, 1, len(d.conf.Filters)+len(d.conf.WhitelistFilters)+1)
|
||||
filters[0] = Filter{
|
||||
ID: CustomListID,
|
||||
ID: rulelist.URLFilterIDCustom,
|
||||
Data: []byte(strings.Join(d.conf.UserRules, "\n")),
|
||||
}
|
||||
|
||||
|
||||
@@ -20,11 +20,11 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/syncutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
@@ -32,19 +32,6 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// The IDs of built-in filter lists.
|
||||
//
|
||||
// Keep in sync with client/src/helpers/constants.js.
|
||||
// TODO(d.kolyshev): Add RewritesListID and don't forget to keep in sync.
|
||||
const (
|
||||
CustomListID = -iota
|
||||
SysHostsListID
|
||||
BlockedSvcsListID
|
||||
ParentalListID
|
||||
SafeBrowsingListID
|
||||
SafeSearchListID
|
||||
)
|
||||
|
||||
// ServiceEntry - blocked service array element
|
||||
type ServiceEntry struct {
|
||||
Name string
|
||||
@@ -232,6 +219,9 @@ type Checker interface {
|
||||
|
||||
// DNSFilter matches hostnames and DNS requests against filtering rules.
|
||||
type DNSFilter struct {
|
||||
// idGen is used to generate IDs for package urlfilter.
|
||||
idGen *idGenerator
|
||||
|
||||
// bufPool is a pool of buffers used for filtering-rule list parsing.
|
||||
bufPool *syncutil.Pool[[]byte]
|
||||
|
||||
@@ -278,7 +268,7 @@ type Filter struct {
|
||||
Data []byte `yaml:"-"`
|
||||
|
||||
// ID is automatically assigned when filter is added using nextFilterID.
|
||||
ID int64 `yaml:"id"`
|
||||
ID rulelist.URLFilterID `yaml:"id"`
|
||||
}
|
||||
|
||||
// Reason holds an enum detailing why it was filtered or not filtered
|
||||
@@ -530,11 +520,13 @@ func (d *DNSFilter) ParentalBlockHost() (host string) {
|
||||
type ResultRule struct {
|
||||
// Text is the text of the rule.
|
||||
Text string `json:",omitempty"`
|
||||
|
||||
// IP is the host IP. It is nil unless the rule uses the
|
||||
// /etc/hosts syntax or the reason is FilteredSafeSearch.
|
||||
IP netip.Addr `json:",omitempty"`
|
||||
|
||||
// FilterListID is the ID of the rule's filter list.
|
||||
FilterListID int64 `json:",omitempty"`
|
||||
FilterListID rulelist.URLFilterID `json:",omitempty"`
|
||||
}
|
||||
|
||||
// Result contains the result of a request check.
|
||||
@@ -637,7 +629,7 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
||||
|
||||
res.Reason = Rewritten
|
||||
|
||||
cnames := stringutil.NewSet()
|
||||
cnames := container.NewMapSet[string]()
|
||||
origHost := host
|
||||
for matched && len(rewrites) > 0 && rewrites[0].Type == dns.TypeCNAME {
|
||||
rw := rewrites[0]
|
||||
@@ -705,7 +697,7 @@ func matchBlockedServicesRules(
|
||||
|
||||
ruleText := rule.Text()
|
||||
res.Rules = []*ResultRule{{
|
||||
FilterListID: int64(rule.GetFilterListID()),
|
||||
FilterListID: rule.GetFilterListID(),
|
||||
Text: ruleText,
|
||||
}}
|
||||
|
||||
@@ -970,7 +962,7 @@ func makeResult(matchedRules []rules.Rule, reason Reason) (res Result) {
|
||||
resRules := make([]*ResultRule, len(matchedRules))
|
||||
for i, mr := range matchedRules {
|
||||
resRules[i] = &ResultRule{
|
||||
FilterListID: int64(mr.GetFilterListID()),
|
||||
FilterListID: mr.GetFilterListID(),
|
||||
Text: mr.Text(),
|
||||
}
|
||||
}
|
||||
@@ -991,6 +983,7 @@ func InitModule() {
|
||||
// be non-nil.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
idGen: newIDGenerator(int32(time.Now().Unix())),
|
||||
bufPool: syncutil.NewSlicePool[byte](rulelist.DefaultRuleBufSize),
|
||||
refreshLock: &sync.Mutex{},
|
||||
safeBrowsingChecker: c.SafeBrowsingChecker,
|
||||
@@ -1054,8 +1047,8 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
d.conf.Filters = deduplicateFilters(d.conf.Filters)
|
||||
d.conf.WhitelistFilters = deduplicateFilters(d.conf.WhitelistFilters)
|
||||
|
||||
updateUniqueFilterID(d.conf.Filters)
|
||||
updateUniqueFilterID(d.conf.WhitelistFilters)
|
||||
d.idGen.fix(d.conf.Filters)
|
||||
d.idGen.fix(d.conf.WhitelistFilters)
|
||||
|
||||
return d, nil
|
||||
}
|
||||
@@ -1139,7 +1132,7 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
res = Result{
|
||||
Rules: []*ResultRule{{
|
||||
Text: "adguard-malware-shavar",
|
||||
FilterListID: SafeBrowsingListID,
|
||||
FilterListID: rulelist.URLFilterIDSafeBrowsing,
|
||||
}},
|
||||
Reason: FilteredSafeBrowsing,
|
||||
IsFiltered: true,
|
||||
@@ -1171,7 +1164,7 @@ func (d *DNSFilter) checkParental(
|
||||
res = Result{
|
||||
Rules: []*ResultRule{{
|
||||
Text: "parental CATEGORY_BLACKLISTED",
|
||||
FilterListID: ParentalListID,
|
||||
FilterListID: rulelist.URLFilterIDParentalControl,
|
||||
}},
|
||||
Reason: FilteredParental,
|
||||
IsFiltered: true,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -66,7 +67,7 @@ func hostsRewrites(
|
||||
vals = append(vals, name)
|
||||
rls = append(rls, &ResultRule{
|
||||
Text: fmt.Sprintf("%s %s", addr, name),
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -84,7 +85,7 @@ func hostsRewrites(
|
||||
}
|
||||
rls = append(rls, &ResultRule{
|
||||
Text: fmt.Sprintf("%s %s", addr, host),
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
@@ -71,7 +72,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||
dtyp: dns.TypeA,
|
||||
wantRules: []*ResultRule{{
|
||||
Text: "1.2.3.4 v4.host.example",
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
}},
|
||||
wantResps: []rules.RRValue{addrv4},
|
||||
}, {
|
||||
@@ -80,7 +81,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||
dtyp: dns.TypeAAAA,
|
||||
wantRules: []*ResultRule{{
|
||||
Text: "::1 v6.host.example",
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
}},
|
||||
wantResps: []rules.RRValue{addrv6},
|
||||
}, {
|
||||
@@ -89,7 +90,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||
dtyp: dns.TypeAAAA,
|
||||
wantRules: []*ResultRule{{
|
||||
Text: "::ffff:1.2.3.4 mapped.host.example",
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
}},
|
||||
wantResps: []rules.RRValue{addrMapped},
|
||||
}, {
|
||||
@@ -98,7 +99,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||
dtyp: dns.TypePTR,
|
||||
wantRules: []*ResultRule{{
|
||||
Text: "1.2.3.4 v4.host.example",
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
}},
|
||||
wantResps: []rules.RRValue{"v4.host.example"},
|
||||
}, {
|
||||
@@ -107,7 +108,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||
dtyp: dns.TypePTR,
|
||||
wantRules: []*ResultRule{{
|
||||
Text: "::ffff:1.2.3.4 mapped.host.example",
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
}},
|
||||
wantResps: []rules.RRValue{"mapped.host.example"},
|
||||
}, {
|
||||
@@ -134,7 +135,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||
dtyp: dns.TypeAAAA,
|
||||
wantRules: []*ResultRule{{
|
||||
Text: fmt.Sprintf("%s v4.host.example", addrv4),
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
}},
|
||||
wantResps: nil,
|
||||
}, {
|
||||
@@ -143,7 +144,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||
dtyp: dns.TypeA,
|
||||
wantRules: []*ResultRule{{
|
||||
Text: fmt.Sprintf("%s v6.host.example", addrv6),
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
}},
|
||||
wantResps: nil,
|
||||
}, {
|
||||
@@ -164,7 +165,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||
dtyp: dns.TypeA,
|
||||
wantRules: []*ResultRule{{
|
||||
Text: "4.3.2.1 v4.host.with-dup",
|
||||
FilterListID: SysHostsListID,
|
||||
FilterListID: rulelist.URLFilterIDEtcHosts,
|
||||
}},
|
||||
wantResps: []rules.RRValue{addrv4Dup},
|
||||
}}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
@@ -86,7 +87,7 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
Name: fj.Name,
|
||||
white: fj.Whitelist,
|
||||
Filter: Filter{
|
||||
ID: assignUniqueFilterID(),
|
||||
ID: d.idGen.next(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -307,12 +308,12 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
LastUpdated string `json:"last_updated,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
RulesCount uint32 `json:"rules_count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
LastUpdated string `json:"last_updated,omitempty"`
|
||||
ID rulelist.URLFilterID `json:"id"`
|
||||
RulesCount uint32 `json:"rules_count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type filteringConfig struct {
|
||||
@@ -388,8 +389,8 @@ func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
type checkHostRespRule struct {
|
||||
Text string `json:"text"`
|
||||
FilterListID int64 `json:"filter_list_id"`
|
||||
Text string `json:"text"`
|
||||
FilterListID rulelist.URLFilterID `json:"filter_list_id"`
|
||||
}
|
||||
|
||||
type checkHostResp struct {
|
||||
@@ -412,7 +413,7 @@ type checkHostResp struct {
|
||||
// FilterID is the ID of the rule's filter list.
|
||||
//
|
||||
// Deprecated: Use Rules[*].FilterListID.
|
||||
FilterID int64 `json:"filter_id"`
|
||||
FilterID rulelist.URLFilterID `json:"filter_id"`
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
74
internal/filtering/idgenerator.go
Normal file
74
internal/filtering/idgenerator.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// idGenerator generates filtering-list IDs in a way broadly compatible with the
|
||||
// legacy approach of AdGuard Home.
|
||||
//
|
||||
// TODO(a.garipov): Get rid of this once we switch completely to the new
|
||||
// rule-list architecture.
|
||||
type idGenerator struct {
|
||||
current *atomic.Int32
|
||||
}
|
||||
|
||||
// newIDGenerator returns a new ID generator initialized with the given seed
|
||||
// value.
|
||||
func newIDGenerator(seed int32) (g *idGenerator) {
|
||||
g = &idGenerator{
|
||||
current: &atomic.Int32{},
|
||||
}
|
||||
|
||||
g.current.Store(seed)
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
// next returns the next ID from the generator. It is safe for concurrent use.
|
||||
func (g *idGenerator) next() (id rulelist.URLFilterID) {
|
||||
id32 := g.current.Add(1)
|
||||
if id32 < 0 {
|
||||
panic(fmt.Errorf("invalid current id value %d", id32))
|
||||
}
|
||||
|
||||
return rulelist.URLFilterID(id32)
|
||||
}
|
||||
|
||||
// fix ensures that flts all have unique IDs.
|
||||
func (g *idGenerator) fix(flts []FilterYAML) {
|
||||
set := container.NewMapSet[rulelist.URLFilterID]()
|
||||
for i, f := range flts {
|
||||
id := f.ID
|
||||
if id == 0 {
|
||||
id = g.next()
|
||||
flts[i].ID = id
|
||||
}
|
||||
|
||||
if !set.Has(id) {
|
||||
set.Add(id)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
newID := g.next()
|
||||
for set.Has(newID) {
|
||||
newID = g.next()
|
||||
}
|
||||
|
||||
log.Info(
|
||||
"filtering: warning: filter at index %d has duplicate id %d; reassigning to %d",
|
||||
i,
|
||||
id,
|
||||
newID,
|
||||
)
|
||||
|
||||
flts[i].ID = newID
|
||||
set.Add(newID)
|
||||
}
|
||||
}
|
||||
88
internal/filtering/idgenerator_internal_test.go
Normal file
88
internal/filtering/idgenerator_internal_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIDGenerator_Fix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
in []FilterYAML
|
||||
}{{
|
||||
name: "nil",
|
||||
in: nil,
|
||||
}, {
|
||||
name: "empty",
|
||||
in: []FilterYAML{},
|
||||
}, {
|
||||
name: "one_zero",
|
||||
in: []FilterYAML{{}},
|
||||
}, {
|
||||
name: "two_zeros",
|
||||
in: []FilterYAML{{}, {}},
|
||||
}, {
|
||||
name: "many_good",
|
||||
in: []FilterYAML{{
|
||||
Filter: Filter{
|
||||
ID: 1,
|
||||
},
|
||||
}, {
|
||||
Filter: Filter{
|
||||
ID: 2,
|
||||
},
|
||||
}, {
|
||||
Filter: Filter{
|
||||
ID: 3,
|
||||
},
|
||||
}},
|
||||
}, {
|
||||
name: "two_dups",
|
||||
in: []FilterYAML{{
|
||||
Filter: Filter{
|
||||
ID: 1,
|
||||
},
|
||||
}, {
|
||||
Filter: Filter{
|
||||
ID: 3,
|
||||
},
|
||||
}, {
|
||||
Filter: Filter{
|
||||
ID: 1,
|
||||
},
|
||||
}, {
|
||||
Filter: Filter{
|
||||
ID: 2,
|
||||
},
|
||||
}},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
g := newIDGenerator(1)
|
||||
g.fix(tc.in)
|
||||
|
||||
assertUniqueIDs(t, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// assertUniqueIDs is a test helper that asserts that the IDs of filters are
|
||||
// unique.
|
||||
func assertUniqueIDs(t testing.TB, flts []FilterYAML) {
|
||||
t.Helper()
|
||||
|
||||
uc := aghalg.UniqChecker[rulelist.URLFilterID]{}
|
||||
for _, f := range flts {
|
||||
uc.Add(f.ID)
|
||||
}
|
||||
|
||||
assert.NoError(t, uc.Validate())
|
||||
}
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
@@ -85,7 +85,7 @@ func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Check cnames for cycles on initialization.
|
||||
cnames := stringutil.NewSet()
|
||||
cnames := container.NewMapSet[string]()
|
||||
host := dReq.Hostname
|
||||
for len(rrules) > 0 && rrules[0].DNSRewrite != nil && rrules[0].DNSRewrite.NewCNAME != "" {
|
||||
rule := rrules[0]
|
||||
|
||||
254
internal/filtering/rulelist/engine.go
Normal file
254
internal/filtering/rulelist/engine.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package rulelist
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
// Engine is a single DNS filter based on one or more rule lists. This
|
||||
// structure contains the filtering engine combining several rule lists.
|
||||
//
|
||||
// TODO(a.garipov): Merge with [TextEngine] in some way?
|
||||
type Engine struct {
|
||||
// mu protects engine and storage.
|
||||
//
|
||||
// TODO(a.garipov): See if anything else should be protected.
|
||||
mu *sync.RWMutex
|
||||
|
||||
// engine is the filtering engine.
|
||||
engine *urlfilter.DNSEngine
|
||||
|
||||
// storage is the filtering-rule storage. It is saved here to close it.
|
||||
storage *filterlist.RuleStorage
|
||||
|
||||
// name is the human-readable name of the engine, like "allowed", "blocked",
|
||||
// or "custom".
|
||||
name string
|
||||
|
||||
// filters is the data about rule filters in this engine.
|
||||
filters []*Filter
|
||||
}
|
||||
|
||||
// EngineConfig is the configuration for rule-list filtering engines created by
|
||||
// combining refreshable filters.
|
||||
type EngineConfig struct {
|
||||
// Name is the human-readable name of this engine, like "allowed",
|
||||
// "blocked", or "custom".
|
||||
Name string
|
||||
|
||||
// Filters is the data about rule lists in this engine. There must be no
|
||||
// other references to the elements of this slice.
|
||||
Filters []*Filter
|
||||
}
|
||||
|
||||
// NewEngine returns a new rule-list filtering engine. The engine is not
|
||||
// refreshed, so a refresh should be performed before use.
|
||||
func NewEngine(c *EngineConfig) (e *Engine) {
|
||||
return &Engine{
|
||||
mu: &sync.RWMutex{},
|
||||
name: c.Name,
|
||||
filters: c.Filters,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying rule-list engine as well as the rule lists.
|
||||
func (e *Engine) Close() (err error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.storage == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = e.storage.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing engine %q: %w", e.name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FilterRequest returns the result of filtering req using the DNS filtering
|
||||
// engine.
|
||||
func (e *Engine) FilterRequest(
|
||||
req *urlfilter.DNSRequest,
|
||||
) (res *urlfilter.DNSResult, hasMatched bool) {
|
||||
return e.currentEngine().MatchRequest(req)
|
||||
}
|
||||
|
||||
// currentEngine returns the current filtering engine.
|
||||
func (e *Engine) currentEngine() (enging *urlfilter.DNSEngine) {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
return e.engine
|
||||
}
|
||||
|
||||
// Refresh updates all rule lists in e. ctx is used for cancellation.
|
||||
// parseBuf, cli, cacheDir, and maxSize are used for updates of rule-list
|
||||
// filters; see [Filter.Refresh].
|
||||
//
|
||||
// TODO(a.garipov): Unexport and test in an internal test or through enigne
|
||||
// tests.
|
||||
func (e *Engine) Refresh(
|
||||
ctx context.Context,
|
||||
parseBuf []byte,
|
||||
cli *http.Client,
|
||||
cacheDir string,
|
||||
maxSize datasize.ByteSize,
|
||||
) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "updating engine %q: %w", e.name) }()
|
||||
|
||||
var filtersToRefresh []*Filter
|
||||
for _, f := range e.filters {
|
||||
if f.enabled {
|
||||
filtersToRefresh = append(filtersToRefresh, f)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filtersToRefresh) == 0 {
|
||||
log.Info("filtering: updating engine %q: no rule-list filters", e.name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
engRefr := &engineRefresh{
|
||||
httpCli: cli,
|
||||
cacheDir: cacheDir,
|
||||
engineName: e.name,
|
||||
parseBuf: parseBuf,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
|
||||
ruleLists, errs := engRefr.process(ctx, e.filters)
|
||||
if isOneTimeoutError(errs) {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
storage, err := filterlist.NewRuleStorage(ruleLists)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("creating rule storage: %w", err))
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
e.resetStorage(storage)
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// resetStorage sets e.storage and e.engine and closes the previous storage.
|
||||
// Errors from closing the previous storage are logged.
|
||||
func (e *Engine) resetStorage(storage *filterlist.RuleStorage) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
prevStorage := e.storage
|
||||
e.storage, e.engine = storage, urlfilter.NewDNSEngine(storage)
|
||||
|
||||
if prevStorage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := prevStorage.Close()
|
||||
if err != nil {
|
||||
log.Error("filtering: engine %q: closing old storage: %s", e.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// isOneTimeoutError returns true if the sole error in errs is either
|
||||
// [context.Canceled] or [context.DeadlineExceeded].
|
||||
func isOneTimeoutError(errs []error) (ok bool) {
|
||||
if len(errs) != 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
err := errs[0]
|
||||
|
||||
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// engineRefresh represents a single ongoing engine refresh.
|
||||
type engineRefresh struct {
|
||||
httpCli *http.Client
|
||||
cacheDir string
|
||||
engineName string
|
||||
parseBuf []byte
|
||||
maxSize datasize.ByteSize
|
||||
}
|
||||
|
||||
// process runs updates of all given rule-list filters. All errors are logged
|
||||
// as they appear, since the update can take a significant amount of time.
|
||||
// errs contains all errors that happened during the update, unless the context
|
||||
// is canceled or its deadline is reached, in which case errs will only contain
|
||||
// a single timeout error.
|
||||
//
|
||||
// TODO(a.garipov): Think of a better way to communicate the timeout condition?
|
||||
func (r *engineRefresh) process(
|
||||
ctx context.Context,
|
||||
filters []*Filter,
|
||||
) (ruleLists []filterlist.RuleList, errs []error) {
|
||||
ruleLists = make([]filterlist.RuleList, 0, len(filters))
|
||||
for i, f := range filters {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, []error{fmt.Errorf("timeout after updating %d filters: %w", i, ctx.Err())}
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
err := r.processFilter(ctx, f)
|
||||
if err == nil {
|
||||
ruleLists = append(ruleLists, f.ruleList)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
errs = append(errs, err)
|
||||
|
||||
// Also log immediately, since the update can take a lot of time.
|
||||
log.Error(
|
||||
"filtering: updating engine %q: rule list %s from url %q: %s\n",
|
||||
r.engineName,
|
||||
f.uid,
|
||||
f.url,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return ruleLists, errs
|
||||
}
|
||||
|
||||
// processFilter runs an update of a single rule-list filter.
|
||||
func (r *engineRefresh) processFilter(ctx context.Context, f *Filter) (err error) {
|
||||
prevChecksum := f.checksum
|
||||
parseRes, err := f.Refresh(ctx, r.parseBuf, r.httpCli, r.cacheDir, r.maxSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating %s: %w", f.uid, err)
|
||||
}
|
||||
|
||||
if prevChecksum == parseRes.Checksum {
|
||||
log.Info("filtering: engine %q: filter %q: no change", r.engineName, f.uid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info(
|
||||
"filtering: updated engine %q: filter %q: %d bytes, %d rules",
|
||||
r.engineName,
|
||||
f.uid,
|
||||
parseRes.BytesWritten,
|
||||
parseRes.RulesCount,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
63
internal/filtering/rulelist/engine_test.go
Normal file
63
internal/filtering/rulelist/engine_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package rulelist_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEngine_Refresh(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
|
||||
fileURL, srvURL := newFilterLocations(t, cacheDir, testRuleTextBlocked, testRuleTextBlocked2)
|
||||
|
||||
fileFlt := newFilter(t, fileURL, "File Filter")
|
||||
httpFlt := newFilter(t, srvURL, "HTTP Filter")
|
||||
|
||||
eng := rulelist.NewEngine(&rulelist.EngineConfig{
|
||||
Name: "Engine",
|
||||
Filters: []*rulelist.Filter{fileFlt, httpFlt},
|
||||
})
|
||||
require.NotNil(t, eng)
|
||||
testutil.CleanupAndRequireSuccess(t, eng.Close)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
buf := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
cli := &http.Client{
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
|
||||
err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
|
||||
require.NoError(t, err)
|
||||
|
||||
fltReq := &urlfilter.DNSRequest{
|
||||
Hostname: "blocked.example",
|
||||
Answer: false,
|
||||
DNSType: dns.TypeA,
|
||||
}
|
||||
|
||||
fltRes, hasMatched := eng.FilterRequest(fltReq)
|
||||
assert.True(t, hasMatched)
|
||||
|
||||
require.NotNil(t, fltRes)
|
||||
|
||||
fltReq = &urlfilter.DNSRequest{
|
||||
Hostname: "blocked-2.example",
|
||||
Answer: false,
|
||||
DNSType: dns.TypeA,
|
||||
}
|
||||
|
||||
fltRes, hasMatched = eng.FilterRequest(fltReq)
|
||||
assert.True(t, hasMatched)
|
||||
|
||||
require.NotNil(t, fltRes)
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghrenameio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
@@ -52,8 +51,6 @@ type Filter struct {
|
||||
checksum uint32
|
||||
|
||||
// enabled, if true, means that this rule-list filter is used for filtering.
|
||||
//
|
||||
// TODO(a.garipov): Take into account.
|
||||
enabled bool
|
||||
}
|
||||
|
||||
@@ -106,6 +103,11 @@ func NewFilter(c *FilterConfig) (f *Filter, err error) {
|
||||
// Refresh updates the data in the rule-list filter. parseBuf is the initial
|
||||
// buffer used to parse information from the data. cli and maxSize are only
|
||||
// used when f is a URL-based list.
|
||||
//
|
||||
// TODO(a.garipov): Unexport and test in an internal test or through enigne
|
||||
// tests.
|
||||
//
|
||||
// TODO(a.garipov): Consider not returning parseRes.
|
||||
func (f *Filter) Refresh(
|
||||
ctx context.Context,
|
||||
parseBuf []byte,
|
||||
@@ -300,39 +302,3 @@ func (f *Filter) Close() (err error) {
|
||||
|
||||
return f.ruleList.Close()
|
||||
}
|
||||
|
||||
// filterUpdate represents a single ongoing rule-list filter update.
|
||||
//
|
||||
//lint:ignore U1000 TODO(a.garipov): Use.
|
||||
type filterUpdate struct {
|
||||
httpCli *http.Client
|
||||
cacheDir string
|
||||
name string
|
||||
parseBuf []byte
|
||||
maxSize datasize.ByteSize
|
||||
}
|
||||
|
||||
// process runs an update of a single rule-list.
|
||||
func (u *filterUpdate) process(ctx context.Context, f *Filter) (err error) {
|
||||
prevChecksum := f.checksum
|
||||
parseRes, err := f.Refresh(ctx, u.parseBuf, u.httpCli, u.cacheDir, u.maxSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating %s: %w", f.uid, err)
|
||||
}
|
||||
|
||||
if prevChecksum == parseRes.Checksum {
|
||||
log.Info("filtering: filter %q: filter %q: no change", u.name, f.uid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info(
|
||||
"filtering: updated filter %q: filter %q: %d bytes, %d rules",
|
||||
u.name,
|
||||
f.uid,
|
||||
parseRes.BytesWritten,
|
||||
parseRes.RulesCount,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@ package rulelist_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -20,23 +18,8 @@ func TestFilter_Refresh(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
uid := rulelist.MustNewUID()
|
||||
|
||||
initialFile := filepath.Join(cacheDir, "initial.txt")
|
||||
initialData := []byte(
|
||||
testRuleTextTitle +
|
||||
testRuleTextBlocked,
|
||||
)
|
||||
writeErr := os.WriteFile(initialFile, initialData, 0o644)
|
||||
require.NoError(t, writeErr)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
_, err := io.WriteString(w, testRuleTextTitle+testRuleTextBlocked)
|
||||
require.NoError(pt, err)
|
||||
}))
|
||||
|
||||
srvURL, urlErr := url.Parse(srv.URL)
|
||||
require.NoError(t, urlErr)
|
||||
const fltData = testRuleTextTitle + testRuleTextBlocked
|
||||
fileURL, srvURL := newFilterLocations(t, cacheDir, fltData, fltData)
|
||||
|
||||
testCases := []struct {
|
||||
url *url.URL
|
||||
@@ -56,7 +39,7 @@ func TestFilter_Refresh(t *testing.T) {
|
||||
name: "file",
|
||||
url: &url.URL{
|
||||
Scheme: "file",
|
||||
Path: initialFile,
|
||||
Path: fileURL.Path,
|
||||
},
|
||||
wantNewErrMsg: "",
|
||||
}, {
|
||||
|
||||
@@ -25,6 +25,24 @@ const DefaultMaxRuleListSize = 64 * datasize.MB
|
||||
// urlfilter.
|
||||
type URLFilterID = int
|
||||
|
||||
// The IDs of built-in filter lists.
|
||||
//
|
||||
// NOTE: Do not change without the need for it and keep in sync with
|
||||
// client/src/helpers/constants.js.
|
||||
//
|
||||
// TODO(a.garipov): Add type [URLFilterID] once it is used consistently in
|
||||
// package filtering.
|
||||
//
|
||||
// TODO(d.kolyshev): Add URLFilterIDLegacyRewrite here and to the UI.
|
||||
const (
|
||||
URLFilterIDCustom URLFilterID = 0
|
||||
URLFilterIDEtcHosts URLFilterID = -1
|
||||
URLFilterIDBlockedService URLFilterID = -2
|
||||
URLFilterIDParentalControl URLFilterID = -3
|
||||
URLFilterIDSafeBrowsing URLFilterID = -4
|
||||
URLFilterIDSafeSearch URLFilterID = -5
|
||||
)
|
||||
|
||||
// UID is the type for the unique IDs of filtering-rule lists.
|
||||
type UID uuid.UUID
|
||||
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
package rulelist_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -35,3 +43,70 @@ const (
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/6003.
|
||||
testRuleTextCosmetic = "||cosmetic.example## :has-text(/\u200c/i)\n"
|
||||
)
|
||||
|
||||
// urlFilterIDCounter is the atomic integer used to create unique filter IDs.
|
||||
var urlFilterIDCounter = &atomic.Int32{}
|
||||
|
||||
// newURLFilterID returns a new unique URLFilterID.
|
||||
func newURLFilterID() (id rulelist.URLFilterID) {
|
||||
return rulelist.URLFilterID(urlFilterIDCounter.Add(1))
|
||||
}
|
||||
|
||||
// newFilter is a helper for creating new filters in tests. It does not
|
||||
// register the closing of the filter using t.Cleanup; callers must do that
|
||||
// either directly or by using the filter in an engine.
|
||||
func newFilter(t testing.TB, u *url.URL, name string) (f *rulelist.Filter) {
|
||||
t.Helper()
|
||||
|
||||
f, err := rulelist.NewFilter(&rulelist.FilterConfig{
|
||||
URL: u,
|
||||
Name: name,
|
||||
UID: rulelist.MustNewUID(),
|
||||
URLFilterID: newURLFilterID(),
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// newFilterLocations is a test helper that sets up both the filtering-rule list
|
||||
// file and the HTTP-server. It also registers file removal and server stopping
|
||||
// using t.Cleanup.
|
||||
func newFilterLocations(
|
||||
t testing.TB,
|
||||
cacheDir string,
|
||||
fileData string,
|
||||
httpData string,
|
||||
) (fileURL, srvURL *url.URL) {
|
||||
filePath := filepath.Join(cacheDir, "initial.txt")
|
||||
err := os.WriteFile(filePath, []byte(fileData), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return os.Remove(filePath)
|
||||
})
|
||||
|
||||
fileURL = &url.URL{
|
||||
Scheme: "file",
|
||||
Path: filePath,
|
||||
}
|
||||
|
||||
srv := newStringHTTPServer(httpData)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
srvURL, err = url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
return fileURL, srvURL
|
||||
}
|
||||
|
||||
// newStringHTTPServer returns a new HTTP server that serves s.
|
||||
func newStringHTTPServer(s string) (srv *httptest.Server) {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
_, err := io.WriteString(w, s)
|
||||
require.NoError(pt, err)
|
||||
}))
|
||||
}
|
||||
|
||||
98
internal/filtering/rulelist/textengine.go
Normal file
98
internal/filtering/rulelist/textengine.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package rulelist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
)
|
||||
|
||||
// TextEngine is a single DNS filter based on a list of rules in text form.
|
||||
type TextEngine struct {
|
||||
// mu protects engine and storage.
|
||||
mu *sync.RWMutex
|
||||
|
||||
// engine is the filtering engine.
|
||||
engine *urlfilter.DNSEngine
|
||||
|
||||
// storage is the filtering-rule storage. It is saved here to close it.
|
||||
storage *filterlist.RuleStorage
|
||||
|
||||
// name is the human-readable name of the engine, like "custom".
|
||||
name string
|
||||
}
|
||||
|
||||
// TextEngineConfig is the configuration for a rule-list filtering engine
|
||||
// created from a filtering rule text.
|
||||
type TextEngineConfig struct {
|
||||
// Name is the human-readable name of this engine, like "allowed",
|
||||
// "blocked", or "custom".
|
||||
Name string
|
||||
|
||||
// Rules is the text of the filtering rules for this engine.
|
||||
Rules []string
|
||||
|
||||
// ID is the ID to use inside a URL-filter engine.
|
||||
ID URLFilterID
|
||||
}
|
||||
|
||||
// NewTextEngine returns a new rule-list filtering engine that uses rules
|
||||
// directly. The engine is ready to use and should not be refreshed.
|
||||
func NewTextEngine(c *TextEngineConfig) (e *TextEngine, err error) {
|
||||
text := strings.Join(c.Rules, "\n")
|
||||
storage, err := filterlist.NewRuleStorage([]filterlist.RuleList{
|
||||
&filterlist.StringRuleList{
|
||||
RulesText: text,
|
||||
ID: c.ID,
|
||||
IgnoreCosmetic: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating rule storage: %w", err)
|
||||
}
|
||||
|
||||
engine := urlfilter.NewDNSEngine(storage)
|
||||
|
||||
return &TextEngine{
|
||||
mu: &sync.RWMutex{},
|
||||
engine: engine,
|
||||
storage: storage,
|
||||
name: c.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FilterRequest returns the result of filtering req using the DNS filtering
|
||||
// engine.
|
||||
func (e *TextEngine) FilterRequest(
|
||||
req *urlfilter.DNSRequest,
|
||||
) (res *urlfilter.DNSResult, hasMatched bool) {
|
||||
var engine *urlfilter.DNSEngine
|
||||
|
||||
func() {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
engine = e.engine
|
||||
}()
|
||||
|
||||
return engine.MatchRequest(req)
|
||||
}
|
||||
|
||||
// Close closes the underlying rule list engine as well as the rule lists.
|
||||
func (e *TextEngine) Close() (err error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.storage == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = e.storage.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing text engine %q: %w", e.name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
40
internal/filtering/rulelist/textengine_test.go
Normal file
40
internal/filtering/rulelist/textengine_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package rulelist_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTextEngine(t *testing.T) {
|
||||
eng, err := rulelist.NewTextEngine(&rulelist.TextEngineConfig{
|
||||
Name: "RulesEngine",
|
||||
Rules: []string{
|
||||
testRuleTextTitle,
|
||||
testRuleTextBlocked,
|
||||
},
|
||||
ID: testURLFilterID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, eng)
|
||||
testutil.CleanupAndRequireSuccess(t, eng.Close)
|
||||
|
||||
fltReq := &urlfilter.DNSRequest{
|
||||
Hostname: "blocked.example",
|
||||
Answer: false,
|
||||
DNSType: dns.TypeA,
|
||||
}
|
||||
|
||||
fltRes, hasMatched := eng.FilterRequest(fltReq)
|
||||
assert.True(t, hasMatched)
|
||||
|
||||
require.NotNil(t, fltRes)
|
||||
require.NotNil(t, fltRes.NetworkRule)
|
||||
|
||||
assert.Equal(t, fltRes.NetworkRule.FilterListID, testURLFilterID)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -98,7 +99,7 @@ func NewDefault(
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
|
||||
err = ss.resetEngine(filtering.SafeSearchListID, conf)
|
||||
err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
@@ -234,7 +235,7 @@ func (ss *Default) newResult(
|
||||
) (res *filtering.Result, err error) {
|
||||
res = &filtering.Result{
|
||||
Rules: []*filtering.ResultRule{{
|
||||
FilterListID: filtering.SafeSearchListID,
|
||||
FilterListID: rulelist.URLFilterIDSafeSearch,
|
||||
}},
|
||||
Reason: filtering.FilteredSafeSearch,
|
||||
IsFiltered: true,
|
||||
@@ -368,7 +369,7 @@ func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) {
|
||||
ss.mu.Lock()
|
||||
defer ss.mu.Unlock()
|
||||
|
||||
err = ss.resetEngine(filtering.SafeSearchListID, conf)
|
||||
err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
@@ -69,7 +70,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
||||
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +90,7 @@ func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
|
||||
}
|
||||
|
||||
func TestDefault_CheckHost_google(t *testing.T) {
|
||||
@@ -128,7 +129,7 @@ func TestDefault_CheckHost_google(t *testing.T) {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, wantIP, res.Rules[0].IP)
|
||||
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -180,7 +181,7 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
|
||||
}
|
||||
|
||||
func TestDefault_Update(t *testing.T) {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"go.etcd.io/bbolt"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
@@ -51,14 +52,15 @@ func (s *session) deserialize(data []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Auth - global object
|
||||
// Auth is the global authentication object.
|
||||
type Auth struct {
|
||||
db *bbolt.DB
|
||||
rateLimiter *authRateLimiter
|
||||
sessions map[string]*session
|
||||
users []webUser
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
trustedProxies netutil.SubnetSet
|
||||
db *bbolt.DB
|
||||
rateLimiter *authRateLimiter
|
||||
sessions map[string]*session
|
||||
users []webUser
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
}
|
||||
|
||||
// webUser represents a user of the Web UI.
|
||||
@@ -69,15 +71,22 @@ type webUser struct {
|
||||
PasswordHash string `yaml:"password"`
|
||||
}
|
||||
|
||||
// InitAuth - create a global object
|
||||
func InitAuth(dbFilename string, users []webUser, sessionTTL uint32, rateLimiter *authRateLimiter) *Auth {
|
||||
// InitAuth initializes the global authentication object.
|
||||
func InitAuth(
|
||||
dbFilename string,
|
||||
users []webUser,
|
||||
sessionTTL uint32,
|
||||
rateLimiter *authRateLimiter,
|
||||
trustedProxies netutil.SubnetSet,
|
||||
) (a *Auth) {
|
||||
log.Info("Initializing auth module: %s", dbFilename)
|
||||
|
||||
a := &Auth{
|
||||
sessionTTL: sessionTTL,
|
||||
rateLimiter: rateLimiter,
|
||||
sessions: make(map[string]*session),
|
||||
users: users,
|
||||
a = &Auth{
|
||||
sessionTTL: sessionTTL,
|
||||
rateLimiter: rateLimiter,
|
||||
sessions: make(map[string]*session),
|
||||
users: users,
|
||||
trustedProxies: trustedProxies,
|
||||
}
|
||||
var err error
|
||||
a.db, err = bbolt.Open(dbFilename, 0o644, nil)
|
||||
@@ -95,7 +104,7 @@ func InitAuth(dbFilename string, users []webUser, sessionTTL uint32, rateLimiter
|
||||
return a
|
||||
}
|
||||
|
||||
// Close - close module
|
||||
// Close closes the authentication database.
|
||||
func (a *Auth) Close() {
|
||||
_ = a.db.Close()
|
||||
}
|
||||
@@ -104,7 +113,8 @@ func bucketName() []byte {
|
||||
return []byte("sessions-2")
|
||||
}
|
||||
|
||||
// load sessions from file, remove expired sessions
|
||||
// loadSessions loads sessions from the database file and removes expired
|
||||
// sessions.
|
||||
func (a *Auth) loadSessions() {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
@@ -156,7 +166,8 @@ func (a *Auth) loadSessions() {
|
||||
log.Debug("auth: loaded %d sessions from DB (removed %d expired)", len(a.sessions), removed)
|
||||
}
|
||||
|
||||
// store session data in file
|
||||
// addSession adds a new session to the list of sessions and saves it in the
|
||||
// database file.
|
||||
func (a *Auth) addSession(data []byte, s *session) {
|
||||
name := hex.EncodeToString(data)
|
||||
a.lock.Lock()
|
||||
@@ -167,7 +178,7 @@ func (a *Auth) addSession(data []byte, s *session) {
|
||||
}
|
||||
}
|
||||
|
||||
// store session data in file
|
||||
// storeSession saves a session in the database file.
|
||||
func (a *Auth) storeSession(data []byte, s *session) bool {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
|
||||
@@ -37,7 +37,7 @@ func TestAuth(t *testing.T) {
|
||||
Name: "name",
|
||||
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
|
||||
}}
|
||||
a := InitAuth(fn, nil, 60, nil)
|
||||
a := InitAuth(fn, nil, 60, nil, nil)
|
||||
s := session{}
|
||||
|
||||
user := webUser{Name: "name"}
|
||||
@@ -66,7 +66,7 @@ func TestAuth(t *testing.T) {
|
||||
a.Close()
|
||||
|
||||
// load saved session
|
||||
a = InitAuth(fn, users, 60, nil)
|
||||
a = InitAuth(fn, users, 60, nil, nil)
|
||||
|
||||
// the session is still alive
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
@@ -82,7 +82,7 @@ func TestAuth(t *testing.T) {
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// load and remove expired sessions
|
||||
a = InitAuth(fn, users, 60, nil)
|
||||
a = InitAuth(fn, users, 60, nil, nil)
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -78,7 +78,7 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
|
||||
// a well-maintained third-party module.
|
||||
//
|
||||
// TODO(a.garipov): Support header Forwarded from RFC 7329.
|
||||
func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
func realIP(r *http.Request) (ip netip.Addr, err error) {
|
||||
proxyHeaders := []string{
|
||||
httphdr.CFConnectingIP,
|
||||
httphdr.TrueClientIP,
|
||||
@@ -87,8 +87,8 @@ func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
|
||||
for _, h := range proxyHeaders {
|
||||
v := r.Header.Get(h)
|
||||
ip = net.ParseIP(v)
|
||||
if ip != nil {
|
||||
ip, err = netip.ParseAddr(v)
|
||||
if err == nil {
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
@@ -96,20 +96,20 @@ func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
// If none of the above yielded any results, get the leftmost IP address
|
||||
// from the X-Forwarded-For header.
|
||||
s := r.Header.Get(httphdr.XForwardedFor)
|
||||
ipStrs := strings.SplitN(s, ", ", 2)
|
||||
ip = net.ParseIP(ipStrs[0])
|
||||
if ip != nil {
|
||||
ipStr, _, _ := strings.Cut(s, ",")
|
||||
ip, err = netip.ParseAddr(ipStr)
|
||||
if err == nil {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// When everything else fails, just return the remote address as understood
|
||||
// by the stdlib.
|
||||
ipStr, err := netutil.SplitHost(r.RemoteAddr)
|
||||
ipStr, err = netutil.SplitHost(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ip from client addr: %w", err)
|
||||
return netip.Addr{}, fmt.Errorf("getting ip from client addr: %w", err)
|
||||
}
|
||||
|
||||
return net.ParseIP(ipStr), nil
|
||||
return netip.ParseAddr(ipStr)
|
||||
}
|
||||
|
||||
// writeErrorWithIP is like [aghhttp.Error], but includes the remote IP address
|
||||
@@ -142,8 +142,6 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
// to security issues.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
||||
//
|
||||
// TODO(e.burkov): Use realIP when the issue will be fixed.
|
||||
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||
writeErrorWithIP(
|
||||
r,
|
||||
@@ -173,20 +171,24 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := Context.auth.newCookie(req, remoteIP)
|
||||
if err != nil {
|
||||
writeErrorWithIP(r, w, http.StatusForbidden, remoteIP, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Use realIP here, since this IP address is only used for logging.
|
||||
ip, err := realIP(r)
|
||||
if err != nil {
|
||||
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
||||
}
|
||||
|
||||
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
|
||||
cookie, err := Context.auth.newCookie(req, remoteIP)
|
||||
if err != nil {
|
||||
logIP := remoteIP
|
||||
if Context.auth.trustedProxies.Contains(ip.Unmap()) {
|
||||
logIP = ip.String()
|
||||
}
|
||||
|
||||
writeErrorWithIP(r, w, http.StatusForbidden, logIP, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("auth: user %q successfully logged in from ip %s", req.Name, ip)
|
||||
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
@@ -39,7 +39,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
users := []webUser{
|
||||
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||
}
|
||||
Context.auth = InitAuth(fn, users, 60, nil)
|
||||
Context.auth = InitAuth(fn, users, 60, nil, nil)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
||||
@@ -125,13 +125,13 @@ func TestRealIP(t *testing.T) {
|
||||
header http.Header
|
||||
remoteAddr string
|
||||
wantErrMsg string
|
||||
wantIP net.IP
|
||||
wantIP netip.Addr
|
||||
}{{
|
||||
name: "success_no_proxy",
|
||||
header: nil,
|
||||
remoteAddr: remoteAddr,
|
||||
wantErrMsg: "",
|
||||
wantIP: net.IPv4(1, 2, 3, 4),
|
||||
wantIP: netip.MustParseAddr("1.2.3.4"),
|
||||
}, {
|
||||
name: "success_proxy",
|
||||
header: http.Header{
|
||||
@@ -139,7 +139,7 @@ func TestRealIP(t *testing.T) {
|
||||
},
|
||||
remoteAddr: remoteAddr,
|
||||
wantErrMsg: "",
|
||||
wantIP: net.IPv4(1, 2, 3, 5),
|
||||
wantIP: netip.MustParseAddr("1.2.3.5"),
|
||||
}, {
|
||||
name: "success_proxy_multiple",
|
||||
header: http.Header{
|
||||
@@ -149,14 +149,14 @@ func TestRealIP(t *testing.T) {
|
||||
},
|
||||
remoteAddr: remoteAddr,
|
||||
wantErrMsg: "",
|
||||
wantIP: net.IPv4(1, 2, 3, 6),
|
||||
wantIP: netip.MustParseAddr("1.2.3.6"),
|
||||
}, {
|
||||
name: "error_no_proxy",
|
||||
header: nil,
|
||||
remoteAddr: "1:::2",
|
||||
wantErrMsg: `getting ip from client addr: address 1:::2: ` +
|
||||
`too many colons in address`,
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -54,7 +55,7 @@ type clientsContainer struct {
|
||||
// ipToRC maps IP addresses to runtime client information.
|
||||
ipToRC map[netip.Addr]*client.Runtime
|
||||
|
||||
allTags *stringutil.Set
|
||||
allTags *container.MapSet[string]
|
||||
|
||||
// dhcp is the DHCP service implementation.
|
||||
dhcp DHCP
|
||||
@@ -108,7 +109,7 @@ func (clients *clientsContainer) Init(
|
||||
|
||||
clients.clientIndex = client.NewIndex()
|
||||
|
||||
clients.allTags = stringutil.NewSet(clientTags...)
|
||||
clients.allTags = container.NewMapSet(clientTags...)
|
||||
|
||||
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
|
||||
clients.dhcp = dhcpServer
|
||||
@@ -213,7 +214,7 @@ type clientObject struct {
|
||||
// toPersistent returns an initialized persistent client if there are no errors.
|
||||
func (o *clientObject) toPersistent(
|
||||
filteringConf *filtering.Config,
|
||||
allTags *stringutil.Set,
|
||||
allTags *container.MapSet[string],
|
||||
) (cli *client.Persistent, err error) {
|
||||
cli = &client.Persistent{
|
||||
Name: o.Name,
|
||||
@@ -307,8 +308,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
BlockedServices: cli.BlockedServices.Clone(),
|
||||
|
||||
IDs: cli.IDs(),
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
Tags: slices.Clone(cli.Tags),
|
||||
Upstreams: slices.Clone(cli.Upstreams),
|
||||
|
||||
UID: cli.UID,
|
||||
|
||||
|
||||
@@ -539,13 +539,13 @@ func fatalOnError(err error) {
|
||||
|
||||
// run configures and starts AdGuard Home.
|
||||
func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
// Configure config filename.
|
||||
initConfigFilename(opts)
|
||||
|
||||
// Configure working dir and config path.
|
||||
// Configure working dir.
|
||||
err := initWorkingDir(opts)
|
||||
fatalOnError(err)
|
||||
|
||||
// Configure config filename.
|
||||
initConfigFilename(opts)
|
||||
|
||||
// Configure log level and output.
|
||||
err = configureLogger(opts)
|
||||
fatalOnError(err)
|
||||
@@ -674,8 +674,10 @@ func initUsers() (auth *Auth, err error) {
|
||||
log.Info("authratelimiter is disabled")
|
||||
}
|
||||
|
||||
trustedProxies := netutil.SliceSubnetSet(netutil.UnembedPrefixes(config.DNS.TrustedProxies))
|
||||
|
||||
sessionTTL := config.HTTPConfig.SessionTTL.Seconds()
|
||||
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter)
|
||||
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter, trustedProxies)
|
||||
if auth == nil {
|
||||
return nil, errors.Error("initializing auth module failed")
|
||||
}
|
||||
@@ -758,11 +760,12 @@ func writePIDFile(fn string) bool {
|
||||
}
|
||||
|
||||
// initConfigFilename sets up context config file path. This file path can be
|
||||
// overridden by command-line arguments, or is set to default.
|
||||
// overridden by command-line arguments, or is set to default. Must only be
|
||||
// called after initializing the workDir with initWorkingDir.
|
||||
func initConfigFilename(opts options) {
|
||||
confPath := opts.confFilename
|
||||
if confPath == "" {
|
||||
Context.confFilePath = "AdGuardHome.yaml"
|
||||
Context.confFilePath = filepath.Join(Context.workDir, "AdGuardHome.yaml")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -5,12 +5,12 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
|
||||
var allowedLanguages = stringutil.NewSet(
|
||||
var allowedLanguages = container.NewMapSet(
|
||||
"ar",
|
||||
"be",
|
||||
"bg",
|
||||
|
||||
@@ -271,11 +271,12 @@ func handleServiceCommand(s service.Service, action string, opts options) (err e
|
||||
return fmt.Errorf("failed to run service: %w", err)
|
||||
}
|
||||
case "install":
|
||||
initConfigFilename(opts)
|
||||
if err = initWorkingDir(opts); err != nil {
|
||||
return fmt.Errorf("failed to init working dir: %w", err)
|
||||
}
|
||||
|
||||
initConfigFilename(opts)
|
||||
|
||||
handleServiceInstallCommand(s)
|
||||
case "uninstall":
|
||||
handleServiceUninstallCommand(s)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/digineo/go-ipset/v2"
|
||||
@@ -174,18 +175,6 @@ func (p *props) parseAttrData(a netfilter.Attribute) {
|
||||
}
|
||||
}
|
||||
|
||||
// unit is a convenient alias for struct{}.
|
||||
type unit = struct{}
|
||||
|
||||
// ipsInIpset is the type of a set of IP-address-to-ipset mappings.
|
||||
type ipsInIpset map[ipInIpsetEntry]unit
|
||||
|
||||
// ipInIpsetEntry is the type for entries in an ipsInIpset set.
|
||||
type ipInIpsetEntry struct {
|
||||
ipsetName string
|
||||
ipArr [net.IPv6len]byte
|
||||
}
|
||||
|
||||
// manager is the Linux Netfilter ipset manager.
|
||||
type manager struct {
|
||||
nameToIpset map[string]props
|
||||
@@ -196,17 +185,24 @@ type manager struct {
|
||||
// mu protects all properties below.
|
||||
mu *sync.Mutex
|
||||
|
||||
// TODO(a.garipov): Currently, the ipset list is static, and we don't
|
||||
// read the IPs already in sets, so we can assume that all incoming IPs
|
||||
// are either added to all corresponding ipsets or not. When that stops
|
||||
// being the case, for example if we add dynamic reconfiguration of
|
||||
// ipsets, this map will need to become a per-ipset-name one.
|
||||
addedIPs ipsInIpset
|
||||
// TODO(a.garipov): Currently, the ipset list is static, and we don't read
|
||||
// the IPs already in sets, so we can assume that all incoming IPs are
|
||||
// either added to all corresponding ipsets or not. When that stops being
|
||||
// the case, for example if we add dynamic reconfiguration of ipsets, this
|
||||
// map will need to become a per-ipset-name one.
|
||||
addedIPs *container.MapSet[ipInIpsetEntry]
|
||||
|
||||
ipv4Conn ipsetConn
|
||||
ipv6Conn ipsetConn
|
||||
}
|
||||
|
||||
// ipInIpsetEntry is the type for entries in [manager.addIPs].
|
||||
type ipInIpsetEntry struct {
|
||||
ipsetName string
|
||||
// TODO(schzen): Use netip.Addr.
|
||||
ipArr [net.IPv6len]byte
|
||||
}
|
||||
|
||||
// dialNetfilter establishes connections to Linux's netfilter module.
|
||||
func (m *manager) dialNetfilter(conf *netlink.Config) (err error) {
|
||||
// The kernel API does not actually require two sockets but package
|
||||
@@ -372,7 +368,7 @@ func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err err
|
||||
|
||||
dial: dial,
|
||||
|
||||
addedIPs: make(ipsInIpset),
|
||||
addedIPs: container.NewMapSet[ipInIpsetEntry](),
|
||||
}
|
||||
|
||||
err = m.dialNetfilter(&netlink.Config{})
|
||||
@@ -438,7 +434,7 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error
|
||||
}
|
||||
copy(e.ipArr[:], ip.To16())
|
||||
|
||||
if _, added := m.addedIPs[e]; added {
|
||||
if m.addedIPs.Has(e) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -471,7 +467,7 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error
|
||||
for _, e := range newAddedEntries {
|
||||
s := m.nameToIpset[e.ipsetName]
|
||||
if s.isPersistent {
|
||||
m.addedIPs[e] = unit{}
|
||||
m.addedIPs.Add(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
@@ -179,7 +180,8 @@ func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
|
||||
case "FilterListID":
|
||||
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
|
||||
if n, ok := vToken.(json.Number); ok {
|
||||
ent.Result.Rules[i].FilterListID, _ = n.Int64()
|
||||
id, _ := n.Int64()
|
||||
ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id)
|
||||
}
|
||||
case "IP":
|
||||
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
|
||||
@@ -582,7 +584,7 @@ var resultHandlers = map[string]logEntryHandler{
|
||||
return nil
|
||||
}
|
||||
|
||||
i, err := n.Int64()
|
||||
id, err := n.Int64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -593,7 +595,7 @@ var resultHandlers = map[string]logEntryHandler{
|
||||
l++
|
||||
}
|
||||
|
||||
ent.Result.Rules[l-1].FilterListID = i
|
||||
ent.Result.Rules[l-1].FilterListID = rulelist.URLFilterID(id)
|
||||
|
||||
return nil
|
||||
},
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
@@ -308,7 +308,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
|
||||
asciiVal = ""
|
||||
}
|
||||
case ctFilteringStatus:
|
||||
if !stringutil.InSlice(filteringStatusValues, val) {
|
||||
if !slices.Contains(filteringStatusValues, val) {
|
||||
return false, sc, fmt.Errorf("invalid value %s", val)
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -26,7 +26,7 @@ require (
|
||||
github.com/kyoh86/nolint v0.0.1 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20240222234643-814bf88cf225 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
||||
golang.org/x/mod v0.16.0 // indirect
|
||||
golang.org/x/sync v0.6.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
|
||||
@@ -63,8 +63,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240222234643-814bf88cf225 h1:BzKNaIRXh1bD+1557OcFIHlpYBiVbK4zEyn8zBHi1SE=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240222234643-814bf88cf225/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240325151524-a685a6edb6d8 h1:ShhqwXlNzuDeQzaa6htzo1S333ACXZzJZgZLpKAza8E=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
|
||||
Reference in New Issue
Block a user