all: sync with master; upd chlog
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@@ -52,7 +53,7 @@ const textPlainDeprMsg = `using this api with the text/plain content-type is dep
|
||||
// deprecation and removal of a plain-text API if the request is made with the
|
||||
// "text/plain" content-type.
|
||||
func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainText bool) {
|
||||
if r.Header.Get(HdrNameContentType) != HdrValTextPlain {
|
||||
if r.Header.Get(httphdr.ContentType) != HdrValTextPlain {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -72,7 +73,7 @@ func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err er
|
||||
// redefine the status code.
|
||||
func WriteJSONResponseCode(w http.ResponseWriter, r *http.Request, code int, resp any) (err error) {
|
||||
w.WriteHeader(code)
|
||||
w.Header().Set(HdrNameContentType, HdrValApplicationJSON)
|
||||
w.Header().Set(httphdr.ContentType, HdrValApplicationJSON)
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
Error(r, w, http.StatusInternalServerError, "encoding resp: %s", err)
|
||||
|
||||
@@ -1,22 +1,6 @@
|
||||
package aghhttp
|
||||
|
||||
// HTTP Headers
|
||||
|
||||
// HTTP header name constants.
|
||||
//
|
||||
// TODO(a.garipov): Remove unused.
|
||||
const (
|
||||
HdrNameAcceptEncoding = "Accept-Encoding"
|
||||
HdrNameAccessControlAllowOrigin = "Access-Control-Allow-Origin"
|
||||
HdrNameAltSvc = "Alt-Svc"
|
||||
HdrNameContentEncoding = "Content-Encoding"
|
||||
HdrNameContentType = "Content-Type"
|
||||
HdrNameOrigin = "Origin"
|
||||
HdrNameServer = "Server"
|
||||
HdrNameTrailer = "Trailer"
|
||||
HdrNameUserAgent = "User-Agent"
|
||||
HdrNameVary = "Vary"
|
||||
)
|
||||
// HTTP headers
|
||||
|
||||
// HTTP header value constants.
|
||||
const (
|
||||
|
||||
@@ -1,47 +1,14 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// The maximum lengths of generated hostnames for different IP versions.
|
||||
const (
|
||||
ipv4HostnameMaxLen = len("192-168-100-100")
|
||||
ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010")
|
||||
)
|
||||
|
||||
// generateIPv4Hostname generates the hostname by IP address version 4.
|
||||
func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
|
||||
hnData := make([]byte, 0, ipv4HostnameMaxLen)
|
||||
for i, part := range ipv4 {
|
||||
if i > 0 {
|
||||
hnData = append(hnData, '-')
|
||||
}
|
||||
hnData = strconv.AppendUint(hnData, uint64(part), 10)
|
||||
}
|
||||
|
||||
return string(hnData)
|
||||
}
|
||||
|
||||
// generateIPv6Hostname generates the hostname by IP address version 6.
|
||||
func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
|
||||
hnData := make([]byte, 0, ipv6HostnameMaxLen)
|
||||
for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ {
|
||||
if i > 0 {
|
||||
hnData = append(hnData, '-')
|
||||
}
|
||||
for _, val := range ipv6[i*2 : i*2+2] {
|
||||
if val < 10 {
|
||||
hnData = append(hnData, '0')
|
||||
}
|
||||
hnData = strconv.AppendUint(hnData, uint64(val), 16)
|
||||
}
|
||||
}
|
||||
|
||||
return string(hnData)
|
||||
}
|
||||
|
||||
// GenerateHostname generates the hostname from ip. In case of using IPv4 the
|
||||
// result should be like:
|
||||
//
|
||||
@@ -52,10 +19,42 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
|
||||
// ff80-f076-0000-0000-0000-0000-0000-0010
|
||||
//
|
||||
// ip must be either an IPv4 or an IPv6.
|
||||
func GenerateHostname(ip net.IP) (hostname string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return generateIPv4Hostname(ipv4)
|
||||
func GenerateHostname(ip netip.Addr) (hostname string) {
|
||||
if !ip.IsValid() {
|
||||
// TODO(s.chzhen): Get rid of it.
|
||||
panic("aghnet generate hostname: invalid ip")
|
||||
}
|
||||
|
||||
return generateIPv6Hostname(ip)
|
||||
ip = ip.Unmap()
|
||||
hostname = ip.StringExpanded()
|
||||
|
||||
if ip.Is4() {
|
||||
return strings.Replace(hostname, ".", "-", -1)
|
||||
}
|
||||
|
||||
return strings.Replace(hostname, ":", "-", -1)
|
||||
}
|
||||
|
||||
// NewDomainNameSet returns nil and error, if list has duplicate or empty
|
||||
// domain name. Otherwise returns a set, which contains non-FQDN domain names,
|
||||
// and nil error.
|
||||
func NewDomainNameSet(list []string) (set *stringutil.Set, err error) {
|
||||
set = stringutil.NewSet()
|
||||
|
||||
for i, v := range list {
|
||||
host := strings.ToLower(strings.TrimSuffix(v, "."))
|
||||
// TODO(a.garipov): Think about ignoring empty (".") names in the
|
||||
// future.
|
||||
if host == "" {
|
||||
return nil, errors.Error("host name is empty")
|
||||
}
|
||||
|
||||
if set.Has(host) {
|
||||
return nil, fmt.Errorf("duplicate host name %q at index %d", host, i)
|
||||
}
|
||||
|
||||
set.Add(host)
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,19 +12,19 @@ func TestGenerateHostName(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
}{{
|
||||
name: "good_ipv4",
|
||||
want: "127-0-0-1",
|
||||
ip: net.IP{127, 0, 0, 1},
|
||||
ip: netip.MustParseAddr("127.0.0.1"),
|
||||
}, {
|
||||
name: "good_ipv6",
|
||||
want: "fe00-0000-0000-0000-0000-0000-0000-0001",
|
||||
ip: net.ParseIP("fe00::1"),
|
||||
ip: netip.MustParseAddr("fe00::1"),
|
||||
}, {
|
||||
name: "4to6",
|
||||
want: "1-2-3-4",
|
||||
ip: net.ParseIP("::ffff:1.2.3.4"),
|
||||
ip: netip.MustParseAddr("::ffff:1.2.3.4"),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -36,29 +36,6 @@ func TestGenerateHostName(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
}{{
|
||||
name: "bad_ipv4",
|
||||
ip: net.IP{127, 0, 0, 1, 0},
|
||||
}, {
|
||||
name: "bad_ipv6",
|
||||
ip: net.IP{
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff,
|
||||
},
|
||||
}, {
|
||||
name: "nil",
|
||||
ip: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Panics(t, func() { GenerateHostname(tc.ip) })
|
||||
})
|
||||
}
|
||||
assert.Panics(t, func() { GenerateHostname(netip.Addr{}) })
|
||||
})
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ func (rm *requestMatcher) MatchRequest(
|
||||
) (res *urlfilter.DNSResult, ok bool) {
|
||||
switch req.DNSType {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
|
||||
log.Debug("%s: handling the request for %s", hostsContainerPref, req.Hostname)
|
||||
log.Debug("%s: handling the request for %s", hostsContainerPrefix, req.Hostname)
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
@@ -90,9 +90,9 @@ func (rm *requestMatcher) resetEng(rulesStrg *filterlist.RuleStorage, tr map[str
|
||||
rm.translator = tr
|
||||
}
|
||||
|
||||
// hostsContainerPref is a prefix for logging and wrapping errors in
|
||||
// hostsContainerPrefix is a prefix for logging and wrapping errors in
|
||||
// HostsContainer's methods.
|
||||
const hostsContainerPref = "hosts container"
|
||||
const hostsContainerPrefix = "hosts container"
|
||||
|
||||
// HostsContainer stores the relevant hosts database provided by the OS and
|
||||
// processes both A/AAAA and PTR DNS requests for those.
|
||||
@@ -115,8 +115,8 @@ type HostsContainer struct {
|
||||
// fsys is the working file system to read hosts files from.
|
||||
fsys fs.FS
|
||||
|
||||
// w tracks the changes in specified files and directories.
|
||||
w aghos.FSWatcher
|
||||
// watcher tracks the changes in specified files and directories.
|
||||
watcher aghos.FSWatcher
|
||||
|
||||
// patterns stores specified paths in the fs.Glob-compatible form.
|
||||
patterns []string
|
||||
@@ -160,7 +160,7 @@ func NewHostsContainer(
|
||||
w aghos.FSWatcher,
|
||||
paths ...string,
|
||||
) (hc *HostsContainer, err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }()
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPrefix) }()
|
||||
|
||||
if len(paths) == 0 {
|
||||
return nil, ErrNoHostsPaths
|
||||
@@ -182,11 +182,11 @@ func NewHostsContainer(
|
||||
done: make(chan struct{}, 1),
|
||||
updates: make(chan HostsRecords, 1),
|
||||
fsys: fsys,
|
||||
w: w,
|
||||
watcher: w,
|
||||
patterns: patterns,
|
||||
}
|
||||
|
||||
log.Debug("%s: starting", hostsContainerPref)
|
||||
log.Debug("%s: starting", hostsContainerPrefix)
|
||||
|
||||
// Load initially.
|
||||
if err = hc.refresh(); err != nil {
|
||||
@@ -199,7 +199,7 @@ func NewHostsContainer(
|
||||
return nil, fmt.Errorf("adding path: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPref, p)
|
||||
log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPrefix, p)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,14 +208,21 @@ func NewHostsContainer(
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
// Close implements the io.Closer interface for *HostsContainer. Close must
|
||||
// only be called once. The returned err is always nil.
|
||||
// Close implements the [io.Closer] interface for *HostsContainer. It closes
|
||||
// both itself and its [aghos.FSWatcher]. Close must only be called once.
|
||||
func (hc *HostsContainer) Close() (err error) {
|
||||
log.Debug("%s: closing", hostsContainerPref)
|
||||
log.Debug("%s: closing", hostsContainerPrefix)
|
||||
|
||||
err = hc.watcher.Close()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("closing fs watcher: %w", err)
|
||||
|
||||
// Go on and close the container either way.
|
||||
}
|
||||
|
||||
close(hc.done)
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Upd returns the channel into which the updates are sent.
|
||||
@@ -251,22 +258,22 @@ func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error)
|
||||
// update channel of HostsContainer when finishes. It's used to be called
|
||||
// within a separate goroutine.
|
||||
func (hc *HostsContainer) handleEvents() {
|
||||
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPref))
|
||||
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPrefix))
|
||||
|
||||
defer close(hc.updates)
|
||||
|
||||
ok, eventsCh := true, hc.w.Events()
|
||||
ok, eventsCh := true, hc.watcher.Events()
|
||||
for ok {
|
||||
select {
|
||||
case _, ok = <-eventsCh:
|
||||
if !ok {
|
||||
log.Debug("%s: watcher closed the events channel", hostsContainerPref)
|
||||
log.Debug("%s: watcher closed the events channel", hostsContainerPrefix)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if err := hc.refresh(); err != nil {
|
||||
log.Error("%s: %s", hostsContainerPref, err)
|
||||
log.Error("%s: %s", hostsContainerPrefix, err)
|
||||
}
|
||||
case _, ok = <-hc.done:
|
||||
// Go on.
|
||||
@@ -345,7 +352,7 @@ func (hp *hostsParser) parseLine(line string) (ip netip.Addr, hosts []string) {
|
||||
// TODO(e.burkov): Investigate if hosts may contain DNS-SD domains.
|
||||
err = netutil.ValidateHostname(f)
|
||||
if err != nil {
|
||||
log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f)
|
||||
log.Error("%s: host %q is invalid, ignoring", hostsContainerPrefix, f)
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -389,7 +396,7 @@ func (hp *hostsParser) addRules(ip netip.Addr, host, line string) {
|
||||
rule, rulePtr := hp.writeRules(host, ip)
|
||||
hp.translations[rule], hp.translations[rulePtr] = line, line
|
||||
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host)
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPrefix, ip, host)
|
||||
}
|
||||
|
||||
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
|
||||
@@ -443,7 +450,7 @@ func (hp *hostsParser) writeRules(host string, ip netip.Addr) (rule, rulePtr str
|
||||
|
||||
// sendUpd tries to send the parsed data to the ch.
|
||||
func (hp *hostsParser) sendUpd(ch chan HostsRecords) {
|
||||
log.Debug("%s: sending upd", hostsContainerPref)
|
||||
log.Debug("%s: sending upd", hostsContainerPrefix)
|
||||
|
||||
upd := hp.table
|
||||
select {
|
||||
@@ -451,11 +458,11 @@ func (hp *hostsParser) sendUpd(ch chan HostsRecords) {
|
||||
// Updates are delivered. Go on.
|
||||
case <-ch:
|
||||
ch <- upd
|
||||
log.Debug("%s: replaced the last update", hostsContainerPref)
|
||||
log.Debug("%s: replaced the last update", hostsContainerPrefix)
|
||||
case ch <- upd:
|
||||
// The previous update was just read and the next one pushed. Go on.
|
||||
default:
|
||||
log.Error("%s: the updates channel is broken", hostsContainerPref)
|
||||
log.Error("%s: the updates channel is broken", hostsContainerPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,7 +480,7 @@ func (hp *hostsParser) newStrg(id int) (s *filterlist.RuleStorage, err error) {
|
||||
//
|
||||
// TODO(e.burkov): Accept a parameter to specify the files to refresh.
|
||||
func (hc *HostsContainer) refresh() (err error) {
|
||||
log.Debug("%s: refreshing", hostsContainerPref)
|
||||
log.Debug("%s: refreshing", hostsContainerPrefix)
|
||||
|
||||
hp := hc.newHostsParser()
|
||||
if _, err = aghos.FileWalker(hp.parseFile).Walk(hc.fsys, hc.patterns...); err != nil {
|
||||
@@ -482,7 +489,7 @@ func (hc *HostsContainer) refresh() (err error) {
|
||||
|
||||
// hc.last is nil on the first refresh, so let that one through.
|
||||
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
|
||||
log.Debug("%s: no changes detected", hostsContainerPref)
|
||||
log.Debug("%s: no changes detected", hostsContainerPrefix)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
2
internal/aghnet/testdata/etc_hosts
vendored
2
internal/aghnet/testdata/etc_hosts
vendored
@@ -35,4 +35,4 @@
|
||||
1.3.5.7 domain4 domain4.alias
|
||||
7.5.3.1 domain4.alias domain4
|
||||
::13 domain6 domain6.alias
|
||||
::31 domain6.alias domain6
|
||||
::31 domain6.alias domain6
|
||||
|
||||
2
internal/aghnet/testdata/ifaces
vendored
2
internal/aghnet/testdata/ifaces
vendored
@@ -1 +1 @@
|
||||
iface sample_name inet static
|
||||
iface sample_name inet static
|
||||
|
||||
2
internal/aghnet/testdata/include-subsources
vendored
2
internal/aghnet/testdata/include-subsources
vendored
@@ -2,4 +2,4 @@
|
||||
# parent directory. Real interface files usually contain only absolute paths.
|
||||
|
||||
source ./testdata/ifaces
|
||||
source ./testdata/*
|
||||
source ./testdata/*
|
||||
|
||||
2
internal/aghnet/testdata/proc_net_arp
vendored
2
internal/aghnet/testdata/proc_net_arp
vendored
@@ -3,4 +3,4 @@ IP address HW type Flags HW address Mask Device
|
||||
::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan
|
||||
0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec
|
||||
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
|
||||
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
|
||||
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
//go:build mips || mips64
|
||||
|
||||
// This file is an adapted version of github.com/josharian/native.
|
||||
|
||||
package aghos
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// NativeEndian is the native endianness of this system.
|
||||
var NativeEndian = binary.BigEndian
|
||||
@@ -1,10 +0,0 @@
|
||||
//go:build amd64 || 386 || arm || arm64 || mipsle || mips64le || ppc64le
|
||||
|
||||
// This file is an adapted version of github.com/josharian/native.
|
||||
|
||||
package aghos
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// NativeEndian is the native endianness of this system.
|
||||
var NativeEndian = binary.LittleEndian
|
||||
@@ -15,18 +15,16 @@ import (
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"github.com/mdlayher/ethernet"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/mdlayher/packet"
|
||||
)
|
||||
|
||||
// dhcpUnicastAddr is the combination of MAC and IP addresses for responding to
|
||||
// the unconfigured host.
|
||||
type dhcpUnicastAddr struct {
|
||||
// raw.Addr is embedded here to make *dhcpUcastAddr a net.Addr without
|
||||
// packet.Addr is embedded here to make *dhcpUcastAddr a net.Addr without
|
||||
// actually implementing all methods. It also contains the client's
|
||||
// hardware address.
|
||||
raw.Addr
|
||||
packet.Addr
|
||||
|
||||
// yiaddr is an IP address just allocated by server for the host.
|
||||
yiaddr net.IP
|
||||
@@ -52,7 +50,7 @@ type dhcpConn struct {
|
||||
// newDHCPConn creates the special connection for DHCP server.
|
||||
func (s *v4Server) newDHCPConn(iface *net.Interface) (c net.PacketConn, err error) {
|
||||
var ucast net.PacketConn
|
||||
if ucast, err = raw.ListenPacket(iface, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
|
||||
if ucast, err = packet.Listen(iface, packet.Raw, int(ethernet.EtherTypeIPv4), nil); err != nil {
|
||||
return nil, fmt.Errorf("creating raw udp connection: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,11 +10,9 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/packet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
func TestDHCPConn_WriteTo_common(t *testing.T) {
|
||||
@@ -57,7 +55,7 @@ func TestBuildEtherPkt(t *testing.T) {
|
||||
srcIP: net.IP{1, 2, 3, 4},
|
||||
}
|
||||
peer := &dhcpUnicastAddr{
|
||||
Addr: raw.Addr{HardwareAddr: net.HardwareAddr{6, 5, 4, 3, 2, 1}},
|
||||
Addr: packet.Addr{HardwareAddr: net.HardwareAddr{6, 5, 4, 3, 2, 1}},
|
||||
yiaddr: net.IP{4, 3, 2, 1},
|
||||
}
|
||||
payload := (&dhcpv4.DHCPv4{}).ToBytes()
|
||||
@@ -102,7 +100,7 @@ func TestBuildEtherPkt(t *testing.T) {
|
||||
t.Run("serializing_error", func(t *testing.T) {
|
||||
// Create a peer with invalid MAC.
|
||||
badPeer := &dhcpUnicastAddr{
|
||||
Addr: raw.Addr{HardwareAddr: net.HardwareAddr{5, 4, 3, 2, 1}},
|
||||
Addr: packet.Addr{HardwareAddr: net.HardwareAddr{5, 4, 3, 2, 1}},
|
||||
yiaddr: net.IP{4, 3, 2, 1},
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -32,6 +33,8 @@ func normalizeIP(ip net.IP) net.IP {
|
||||
}
|
||||
|
||||
// Load lease table from DB
|
||||
//
|
||||
// TODO(s.chzhen): Decrease complexity.
|
||||
func (s *server) dbLoad() (err error) {
|
||||
dynLeases := []*Lease{}
|
||||
staticLeases := []*Lease{}
|
||||
@@ -57,26 +60,28 @@ func (s *server) dbLoad() (err error) {
|
||||
for i := range obj {
|
||||
obj[i].IP = normalizeIP(obj[i].IP)
|
||||
|
||||
if !(len(obj[i].IP) == 4 || len(obj[i].IP) == 16) {
|
||||
ip, ok := netip.AddrFromSlice(obj[i].IP)
|
||||
if !ok {
|
||||
log.Info("dhcp: invalid IP: %s", obj[i].IP)
|
||||
continue
|
||||
}
|
||||
|
||||
lease := Lease{
|
||||
HWAddr: obj[i].HWAddr,
|
||||
IP: obj[i].IP,
|
||||
IP: ip,
|
||||
Hostname: obj[i].Hostname,
|
||||
Expiry: time.Unix(obj[i].Expiry, 0),
|
||||
IsStatic: obj[i].Expiry == leaseExpireStatic,
|
||||
}
|
||||
|
||||
if len(obj[i].IP) == 16 {
|
||||
if obj[i].Expiry == leaseExpireStatic {
|
||||
if lease.IsStatic {
|
||||
v6StaticLeases = append(v6StaticLeases, &lease)
|
||||
} else {
|
||||
v6DynLeases = append(v6DynLeases, &lease)
|
||||
}
|
||||
} else {
|
||||
if obj[i].Expiry == leaseExpireStatic {
|
||||
if lease.IsStatic {
|
||||
staticLeases = append(staticLeases, &lease)
|
||||
} else {
|
||||
dynLeases = append(dynLeases, &lease)
|
||||
@@ -145,7 +150,7 @@ func (s *server) dbStore() (err error) {
|
||||
|
||||
lease := leaseJSON{
|
||||
HWAddr: l.HWAddr,
|
||||
IP: l.IP,
|
||||
IP: l.IP.AsSlice(),
|
||||
Hostname: l.Hostname,
|
||||
Expiry: l.Expiry.Unix(),
|
||||
}
|
||||
@@ -162,7 +167,7 @@ func (s *server) dbStore() (err error) {
|
||||
|
||||
lease := leaseJSON{
|
||||
HWAddr: l.HWAddr,
|
||||
IP: l.IP,
|
||||
IP: l.IP.AsSlice(),
|
||||
Hostname: l.Hostname,
|
||||
Expiry: l.Expiry.Unix(),
|
||||
}
|
||||
|
||||
@@ -41,13 +41,19 @@ type Lease struct {
|
||||
// of 1 means that this is a static lease.
|
||||
Expiry time.Time `json:"expires"`
|
||||
|
||||
Hostname string `json:"hostname"`
|
||||
HWAddr net.HardwareAddr `json:"mac"`
|
||||
// Hostname of the client.
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// HWAddr is the physical hardware address (MAC address).
|
||||
HWAddr net.HardwareAddr `json:"mac"`
|
||||
|
||||
// IP is the IP address leased to the client.
|
||||
//
|
||||
// TODO(a.garipov): Migrate leases.db and use netip.Addr.
|
||||
IP net.IP `json:"ip"`
|
||||
// TODO(a.garipov): Migrate leases.db.
|
||||
IP netip.Addr `json:"ip"`
|
||||
|
||||
// IsStatic defines if the lease is static.
|
||||
IsStatic bool `json:"static"`
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of l.
|
||||
@@ -60,7 +66,8 @@ func (l *Lease) Clone() (clone *Lease) {
|
||||
Expiry: l.Expiry,
|
||||
Hostname: l.Hostname,
|
||||
HWAddr: slices.Clone(l.HWAddr),
|
||||
IP: slices.Clone(l.IP),
|
||||
IP: l.IP,
|
||||
IsStatic: l.IsStatic,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,17 +88,10 @@ func (l *Lease) IsBlocklisted() (ok bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsStatic returns true if the lease is static.
|
||||
//
|
||||
// TODO(a.garipov): Just make it a boolean field.
|
||||
func (l *Lease) IsStatic() (ok bool) {
|
||||
return l != nil && l.Expiry.Unix() == leaseExpireStatic
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Lease.
|
||||
func (l Lease) MarshalJSON() ([]byte, error) {
|
||||
var expiryStr string
|
||||
if !l.IsStatic() {
|
||||
if !l.IsStatic {
|
||||
// The front-end is waiting for RFC 3999 format of the time
|
||||
// value. It also shouldn't got an Expiry field for static
|
||||
// leases.
|
||||
|
||||
@@ -48,11 +48,11 @@ func TestDB(t *testing.T) {
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 100},
|
||||
IP: netip.MustParseAddr("192.168.10.100"),
|
||||
}, {
|
||||
Hostname: "static-2.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xBB},
|
||||
IP: net.IP{192, 168, 10, 101},
|
||||
IP: netip.MustParseAddr("192.168.10.101"),
|
||||
}}
|
||||
|
||||
srv4, ok := s.srv4.(*v4Server)
|
||||
@@ -80,7 +80,7 @@ func TestDB(t *testing.T) {
|
||||
|
||||
assert.Equal(t, leases[1].HWAddr, ll[0].HWAddr)
|
||||
assert.Equal(t, leases[1].IP, ll[0].IP)
|
||||
assert.True(t, ll[0].IsStatic())
|
||||
assert.True(t, ll[0].IsStatic)
|
||||
|
||||
assert.Equal(t, leases[0].HWAddr, ll[1].HWAddr)
|
||||
assert.Equal(t, leases[0].IP, ll[1].IP)
|
||||
@@ -96,7 +96,7 @@ func TestNormalizeLeases(t *testing.T) {
|
||||
|
||||
staticLeases := []*Lease{{
|
||||
HWAddr: net.HardwareAddr{1, 2, 3, 4},
|
||||
IP: net.IP{0, 2, 3, 4},
|
||||
IP: netip.MustParseAddr("0.2.3.4"),
|
||||
}, {
|
||||
HWAddr: net.HardwareAddr{2, 2, 3, 4},
|
||||
}}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -57,12 +58,77 @@ func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf {
|
||||
|
||||
// dhcpStatusResponse is the response for /control/dhcp/status endpoint.
|
||||
type dhcpStatusResponse struct {
|
||||
IfaceName string `json:"interface_name"`
|
||||
V4 V4ServerConf `json:"v4"`
|
||||
V6 V6ServerConf `json:"v6"`
|
||||
Leases []*Lease `json:"leases"`
|
||||
StaticLeases []*Lease `json:"static_leases"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IfaceName string `json:"interface_name"`
|
||||
V4 V4ServerConf `json:"v4"`
|
||||
V6 V6ServerConf `json:"v6"`
|
||||
Leases []*leaseDynamic `json:"leases"`
|
||||
StaticLeases []*leaseStatic `json:"static_leases"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// leaseStatic is the JSON form of static DHCP lease.
|
||||
type leaseStatic struct {
|
||||
HWAddr string `json:"mac"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
// leasesToStatic converts list of leases to their JSON form.
|
||||
func leasesToStatic(leases []*Lease) (static []*leaseStatic) {
|
||||
static = make([]*leaseStatic, len(leases))
|
||||
|
||||
for i, l := range leases {
|
||||
static[i] = &leaseStatic{
|
||||
HWAddr: l.HWAddr.String(),
|
||||
IP: l.IP,
|
||||
Hostname: l.Hostname,
|
||||
}
|
||||
}
|
||||
|
||||
return static
|
||||
}
|
||||
|
||||
// toLease converts leaseStatic to Lease or returns error.
|
||||
func (l *leaseStatic) toLease() (lease *Lease, err error) {
|
||||
addr, err := net.ParseMAC(l.HWAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't parse MAC address: %w", err)
|
||||
}
|
||||
|
||||
return &Lease{
|
||||
HWAddr: addr,
|
||||
IP: l.IP,
|
||||
Hostname: l.Hostname,
|
||||
IsStatic: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// leaseDynamic is the JSON form of dynamic DHCP lease.
|
||||
type leaseDynamic struct {
|
||||
HWAddr string `json:"mac"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Hostname string `json:"hostname"`
|
||||
Expiry string `json:"expires"`
|
||||
}
|
||||
|
||||
// leasesToDynamic converts list of leases to their JSON form.
|
||||
func leasesToDynamic(leases []*Lease) (dynamic []*leaseDynamic) {
|
||||
dynamic = make([]*leaseDynamic, len(leases))
|
||||
|
||||
for i, l := range leases {
|
||||
dynamic[i] = &leaseDynamic{
|
||||
HWAddr: l.HWAddr.String(),
|
||||
IP: l.IP,
|
||||
Hostname: l.Hostname,
|
||||
// The front-end is waiting for RFC 3999 format of the time
|
||||
// value.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2692.
|
||||
Expiry: l.Expiry.Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
return dynamic
|
||||
}
|
||||
|
||||
func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -76,8 +142,8 @@ func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv4.WriteDiskConfig4(&status.V4)
|
||||
s.srv6.WriteDiskConfig6(&status.V6)
|
||||
|
||||
status.Leases = s.Leases(LeasesDynamic)
|
||||
status.StaticLeases = s.Leases(LeasesStatic)
|
||||
status.Leases = leasesToDynamic(s.Leases(LeasesDynamic))
|
||||
status.StaticLeases = leasesToStatic(s.Leases(LeasesStatic))
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, status)
|
||||
}
|
||||
@@ -488,7 +554,7 @@ func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
|
||||
}
|
||||
|
||||
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
l := &Lease{}
|
||||
l := &leaseStatic{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
@@ -496,22 +562,29 @@ func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if l.IP == nil {
|
||||
if !l.IP.IsValid() {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l.IP = l.IP.Unmap()
|
||||
|
||||
var srv DHCPServer
|
||||
if ip4 := l.IP.To4(); ip4 != nil {
|
||||
l.IP = ip4
|
||||
if l.IP.Is4() {
|
||||
srv = s.srv4
|
||||
} else {
|
||||
l.IP = l.IP.To16()
|
||||
srv = s.srv6
|
||||
}
|
||||
|
||||
err = srv.AddStaticLease(l)
|
||||
lease, err := l.toLease()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "parsing: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = srv.AddStaticLease(lease)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -520,7 +593,7 @@ func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
l := &Lease{}
|
||||
l := &leaseStatic{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
@@ -528,27 +601,29 @@ func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
if l.IP == nil {
|
||||
if !l.IP.IsValid() {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ip4 := l.IP.To4()
|
||||
l.IP = l.IP.Unmap()
|
||||
|
||||
if ip4 == nil {
|
||||
l.IP = l.IP.To16()
|
||||
var srv DHCPServer
|
||||
if l.IP.Is4() {
|
||||
srv = s.srv4
|
||||
} else {
|
||||
srv = s.srv6
|
||||
}
|
||||
|
||||
err = s.srv6.RemoveStaticLease(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
lease, err := l.toLease()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "parsing: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l.IP = ip4
|
||||
err = s.srv4.RemoveStaticLease(l)
|
||||
err = srv.RemoveStaticLease(lease)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
|
||||
160
internal/dhcpd/http_unix_test.go
Normal file
160
internal/dhcpd/http_unix_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_handleDHCPStatus(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
staticMAC = "aa:aa:aa:aa:aa:aa"
|
||||
)
|
||||
|
||||
staticIP := netip.MustParseAddr("192.168.10.10")
|
||||
|
||||
staticLease := &leaseStatic{
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
Hostname: staticName,
|
||||
}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
WorkDir: t.TempDir(),
|
||||
DBFilePath: dbFilename,
|
||||
ConfigModified: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// checkStatus is a helper that asserts the response of
|
||||
// [*server.handleDHCPStatus].
|
||||
checkStatus := func(t *testing.T, want *dhcpStatusResponse) {
|
||||
w := httptest.NewRecorder()
|
||||
var req *http.Request
|
||||
req, err = http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(&want)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPStatus(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
assert.JSONEq(t, b.String(), w.Body.String())
|
||||
}
|
||||
|
||||
// defaultResponse is a helper that returs the response with default
|
||||
// configuration.
|
||||
defaultResponse := func() *dhcpStatusResponse {
|
||||
conf4 := defaultV4ServerConf()
|
||||
conf4.LeaseDuration = 86400
|
||||
|
||||
resp := &dhcpStatusResponse{
|
||||
V4: *conf4,
|
||||
V6: V6ServerConf{},
|
||||
Leases: []*leaseDynamic{},
|
||||
StaticLeases: []*leaseStatic{},
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
ok := t.Run("status", func(t *testing.T) {
|
||||
resp := defaultResponse()
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("add_static_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(staticLease)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPAddStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
resp.StaticLeases = []*leaseStatic{staticLease}
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("add_invalid_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
err = json.NewEncoder(b).Encode(&leaseStatic{})
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPAddStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("remove_static_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(staticLease)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPRemoveStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("set_config", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
resp := defaultResponse()
|
||||
resp.Enabled = false
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPSetConfig(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
}
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
//
|
||||
// TODO(a.garipov): Perhaps create an optimized version with uint32 for IPv4
|
||||
// ranges? Or use one of uint128 packages?
|
||||
//
|
||||
// TODO(e.burkov): Use netip.Addr.
|
||||
type ipRange struct {
|
||||
start *big.Int
|
||||
end *big.Int
|
||||
@@ -27,8 +29,6 @@ const maxRangeLen = math.MaxUint32
|
||||
|
||||
// newIPRange creates a new IP address range. start must be less than end. The
|
||||
// resulting range must not be greater than maxRangeLen.
|
||||
//
|
||||
// TODO(e.burkov): Use netip.Addr.
|
||||
func newIPRange(start, end net.IP) (r *ipRange, err error) {
|
||||
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()
|
||||
|
||||
|
||||
@@ -20,10 +20,8 @@ import (
|
||||
"github.com/go-ping/ping"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"github.com/mdlayher/packet"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
// v4Server is a DHCPv4 server.
|
||||
@@ -98,7 +96,7 @@ func normalizeHostname(hostname string) (norm string, err error) {
|
||||
// validHostnameForClient accepts the hostname sent by the client and its IP and
|
||||
// returns either a normalized version of that hostname, or a new hostname
|
||||
// generated from the IP address, or an empty string.
|
||||
func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostname string) {
|
||||
func (s *v4Server) validHostnameForClient(cliHostname string, ip netip.Addr) (hostname string) {
|
||||
hostname, err := normalizeHostname(cliHostname)
|
||||
if err != nil {
|
||||
log.Info("dhcpv4: %s", err)
|
||||
@@ -130,7 +128,7 @@ func (s *v4Server) ResetLeases(leases []*Lease) (err error) {
|
||||
s.leases = nil
|
||||
|
||||
for _, l := range leases {
|
||||
if !l.IsStatic() {
|
||||
if !l.IsStatic {
|
||||
l.Hostname = s.validHostnameForClient(l.Hostname, l.IP)
|
||||
}
|
||||
err = s.addLease(l)
|
||||
@@ -192,7 +190,7 @@ func (s *v4Server) GetLeases(flags GetLeasesFlags) (leases []*Lease) {
|
||||
continue
|
||||
}
|
||||
|
||||
if getStatic && l.IsStatic() {
|
||||
if getStatic && l.IsStatic {
|
||||
leases = append(leases, l.Clone())
|
||||
}
|
||||
}
|
||||
@@ -211,10 +209,9 @@ func (s *v4Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
netIP := ip.AsSlice()
|
||||
for _, l := range s.leases {
|
||||
if l.IP.Equal(netIP) {
|
||||
if l.Expiry.After(now) || l.IsStatic() {
|
||||
if l.IP == ip {
|
||||
if l.IsStatic || l.Expiry.After(now) {
|
||||
return l.HWAddr
|
||||
}
|
||||
}
|
||||
@@ -247,7 +244,8 @@ func (s *v4Server) rmLeaseByIndex(i int) {
|
||||
s.leases = append(s.leases[:i], s.leases[i+1:]...)
|
||||
|
||||
r := s.conf.ipRange
|
||||
offset, ok := r.offset(l.IP)
|
||||
leaseIP := net.IP(l.IP.AsSlice())
|
||||
offset, ok := r.offset(leaseIP)
|
||||
if ok {
|
||||
s.leasedOffsets.set(offset, false)
|
||||
}
|
||||
@@ -261,9 +259,9 @@ func (s *v4Server) rmLeaseByIndex(i int) {
|
||||
// Return error if a static lease is found
|
||||
func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
for i, l := range s.leases {
|
||||
isStatic := l.IsStatic()
|
||||
isStatic := l.IsStatic
|
||||
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP.Equal(lease.IP) {
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP == lease.IP {
|
||||
if isStatic {
|
||||
return errors.Error("static lease already exists")
|
||||
}
|
||||
@@ -291,13 +289,13 @@ const ErrDupHostname = errors.Error("hostname is not unique")
|
||||
// addLease adds a dynamic or static lease.
|
||||
func (s *v4Server) addLease(l *Lease) (err error) {
|
||||
r := s.conf.ipRange
|
||||
offset, inOffset := r.offset(l.IP)
|
||||
leaseIP := net.IP(l.IP.AsSlice())
|
||||
offset, inOffset := r.offset(leaseIP)
|
||||
|
||||
if l.IsStatic() {
|
||||
if l.IsStatic {
|
||||
// TODO(a.garipov, d.seregin): Subnet can be nil when dhcp server is
|
||||
// disabled.
|
||||
addr := netip.AddrFrom4(*(*[4]byte)(l.IP.To4()))
|
||||
if sn := s.conf.subnet; !sn.Contains(addr) {
|
||||
if sn := s.conf.subnet; !sn.Contains(l.IP) {
|
||||
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
|
||||
}
|
||||
} else if !inOffset {
|
||||
@@ -325,7 +323,7 @@ func (s *v4Server) rmLease(lease *Lease) (err error) {
|
||||
}
|
||||
|
||||
for i, l := range s.leases {
|
||||
if l.IP.Equal(lease.IP) {
|
||||
if l.IP == lease.IP {
|
||||
if !bytes.Equal(l.HWAddr, lease.HWAddr) || l.Hostname != lease.Hostname {
|
||||
return fmt.Errorf("lease for ip %s is different: %+v", lease.IP, l)
|
||||
}
|
||||
@@ -352,14 +350,16 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
return ErrUnconfigured
|
||||
}
|
||||
|
||||
ip := l.IP.To4()
|
||||
if ip == nil {
|
||||
l.IP = l.IP.Unmap()
|
||||
|
||||
if !l.IP.Is4() {
|
||||
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
|
||||
} else if gwIP := s.conf.GatewayIP; gwIP == netip.AddrFrom4(*(*[4]byte)(ip)) {
|
||||
} else if gwIP := s.conf.GatewayIP; gwIP == l.IP {
|
||||
return fmt.Errorf("can't assign the gateway IP %s to the lease", gwIP)
|
||||
}
|
||||
|
||||
l.Expiry = time.Unix(leaseExpireStatic, 0)
|
||||
l.IsStatic = true
|
||||
|
||||
err = netutil.ValidateMAC(l.HWAddr)
|
||||
if err != nil {
|
||||
@@ -396,7 +396,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf(
|
||||
"removing dynamic leases for %s (%s): %w",
|
||||
ip,
|
||||
l.IP,
|
||||
l.HWAddr,
|
||||
err,
|
||||
)
|
||||
@@ -406,7 +406,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
|
||||
err = s.addLease(l)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("adding static lease for %s (%s): %w", ip, l.HWAddr, err)
|
||||
err = fmt.Errorf("adding static lease for %s (%s): %w", l.IP, l.HWAddr, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -429,7 +429,7 @@ func (s *v4Server) RemoveStaticLease(l *Lease) (err error) {
|
||||
return ErrUnconfigured
|
||||
}
|
||||
|
||||
if len(l.IP) != 4 {
|
||||
if !l.IP.Is4() {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
|
||||
@@ -529,7 +529,7 @@ func (s *v4Server) nextIP() (ip net.IP) {
|
||||
func (s *v4Server) findExpiredLease() int {
|
||||
now := time.Now()
|
||||
for i, lease := range s.leases {
|
||||
if !lease.IsStatic() && lease.Expiry.Before(now) {
|
||||
if !lease.IsStatic && lease.Expiry.Before(now) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
@@ -542,8 +542,8 @@ func (s *v4Server) findExpiredLease() int {
|
||||
func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
|
||||
l = &Lease{HWAddr: slices.Clone(mac)}
|
||||
|
||||
l.IP = s.nextIP()
|
||||
if l.IP == nil {
|
||||
nextIP := s.nextIP()
|
||||
if nextIP == nil {
|
||||
i := s.findExpiredLease()
|
||||
if i < 0 {
|
||||
return nil, nil
|
||||
@@ -554,6 +554,13 @@ func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
|
||||
return s.leases[i], nil
|
||||
}
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(nextIP)
|
||||
if !ok {
|
||||
return nil, errors.Error("invalid ip")
|
||||
}
|
||||
|
||||
l.IP = netIP
|
||||
|
||||
err = s.addLease(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -603,7 +610,8 @@ func (s *v4Server) allocateLease(mac net.HardwareAddr) (l *Lease, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if s.addrAvailable(l.IP) {
|
||||
leaseIP := l.IP.AsSlice()
|
||||
if s.addrAvailable(leaseIP) {
|
||||
return l, nil
|
||||
}
|
||||
|
||||
@@ -623,8 +631,9 @@ func (s *v4Server) handleDiscover(req, resp *dhcpv4.DHCPv4) (l *Lease, err error
|
||||
l = s.findLease(mac)
|
||||
if l != nil {
|
||||
reqIP := req.RequestedIPAddress()
|
||||
if len(reqIP) != 0 && !reqIP.Equal(l.IP) {
|
||||
log.Debug("dhcpv4: different RequestedIP: %s != %s", reqIP, l.IP)
|
||||
leaseIP := net.IP(l.IP.AsSlice())
|
||||
if len(reqIP) != 0 && !reqIP.Equal(leaseIP) {
|
||||
log.Debug("dhcpv4: different RequestedIP: %s != %s", reqIP, leaseIP)
|
||||
}
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
|
||||
@@ -674,12 +683,19 @@ func (s *v4Server) checkLease(mac net.HardwareAddr, ip net.IP) (lease *Lease, mi
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Info("check lease: invalid IP: %s", ip)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, l := range s.leases {
|
||||
if !bytes.Equal(l.HWAddr, mac) {
|
||||
continue
|
||||
}
|
||||
|
||||
if l.IP.Equal(ip) {
|
||||
if l.IP == netIP {
|
||||
return l, false
|
||||
}
|
||||
|
||||
@@ -845,7 +861,7 @@ func (s *v4Server) handleRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, needsR
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
if lease.IsStatic() {
|
||||
if lease.IsStatic {
|
||||
if lease.Hostname != "" {
|
||||
// TODO(e.burkov): This option is used to update the server's DNS
|
||||
// mapping. The option should only be answered when it has been
|
||||
@@ -878,9 +894,16 @@ func (s *v4Server) handleDecline(req, resp *dhcpv4.DHCPv4) (err error) {
|
||||
reqIP = req.ClientIPAddr
|
||||
}
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(reqIP)
|
||||
if !ok {
|
||||
log.Info("dhcpv4: invalid IP: %s", reqIP)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var oldLease *Lease
|
||||
for _, l := range s.leases {
|
||||
if bytes.Equal(l.HWAddr, mac) && l.IP.Equal(reqIP) {
|
||||
if bytes.Equal(l.HWAddr, mac) && l.IP == netIP {
|
||||
oldLease = l
|
||||
|
||||
break
|
||||
@@ -920,8 +943,7 @@ func (s *v4Server) handleDecline(req, resp *dhcpv4.DHCPv4) (err error) {
|
||||
|
||||
log.Info("dhcpv4: changed ip from %s to %s for %s", reqIP, newLease.IP, mac)
|
||||
|
||||
resp.YourIPAddr = make([]byte, 4)
|
||||
copy(resp.YourIPAddr, newLease.IP)
|
||||
resp.YourIPAddr = net.IP(newLease.IP.AsSlice())
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
|
||||
|
||||
@@ -944,8 +966,15 @@ func (s *v4Server) handleRelease(req, resp *dhcpv4.DHCPv4) (err error) {
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(reqIP)
|
||||
if !ok {
|
||||
log.Info("dhcpv4: invalid IP: %s", reqIP)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, l := range s.leases {
|
||||
if !bytes.Equal(l.HWAddr, mac) || !l.IP.Equal(reqIP) {
|
||||
if !bytes.Equal(l.HWAddr, mac) || l.IP != netIP {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1018,7 +1047,7 @@ func (s *v4Server) handle(req, resp *dhcpv4.DHCPv4) int {
|
||||
}
|
||||
|
||||
if l != nil {
|
||||
resp.YourIPAddr = slices.Clone(l.IP)
|
||||
resp.YourIPAddr = net.IP(l.IP.AsSlice())
|
||||
}
|
||||
|
||||
s.updateOptions(req, resp)
|
||||
@@ -1136,7 +1165,7 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the client's
|
||||
// hardware address and yiaddr.
|
||||
peer = &dhcpUnicastAddr{
|
||||
Addr: raw.Addr{HardwareAddr: req.ClientHWAddr},
|
||||
Addr: packet.Addr{HardwareAddr: req.ClientHWAddr},
|
||||
yiaddr: resp.YourIPAddr,
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -15,11 +15,9 @@ import (
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/packet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -62,7 +60,7 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
staticIP := net.IP{192, 168, 10, 10}
|
||||
staticIP := netip.MustParseAddr("192.168.10.10")
|
||||
anotherIP := DefaultRangeStart
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
|
||||
@@ -75,6 +73,7 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
IsStatic: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -83,7 +82,8 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP.AsSlice(),
|
||||
IP: anotherIP,
|
||||
IsStatic: true,
|
||||
})
|
||||
assert.ErrorIs(t, err, ErrDupHostname)
|
||||
})
|
||||
@@ -97,7 +97,8 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: staticMAC,
|
||||
IP: anotherIP.AsSlice(),
|
||||
IP: anotherIP,
|
||||
IsStatic: true,
|
||||
})
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
@@ -112,6 +113,7 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: staticIP,
|
||||
IsStatic: true,
|
||||
})
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
@@ -124,13 +126,14 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
discoverAnOffer := func(
|
||||
t *testing.T,
|
||||
name string,
|
||||
ip net.IP,
|
||||
netIP netip.Addr,
|
||||
mac net.HardwareAddr,
|
||||
) (resp *dhcpv4.DHCPv4) {
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return s.ResetLeases(s.GetLeases(LeasesStatic))
|
||||
})
|
||||
|
||||
ip := net.IP(netIP.AsSlice())
|
||||
req, err := dhcpv4.NewDiscovery(
|
||||
mac,
|
||||
dhcpv4.WithOption(dhcpv4.OptHostName(name)),
|
||||
@@ -151,7 +154,7 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("same_name", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, staticName, anotherIP.AsSlice(), anotherMAC)
|
||||
resp := discoverAnOffer(t, staticName, anotherIP, anotherMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(staticName),
|
||||
@@ -161,11 +164,15 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
res := s4.handle(req, resp)
|
||||
require.Positive(t, res)
|
||||
|
||||
assert.Equal(t, aghnet.GenerateHostname(resp.YourIPAddr), resp.HostName())
|
||||
var netIP netip.Addr
|
||||
netIP, ok = netip.AddrFromSlice(resp.YourIPAddr)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, aghnet.GenerateHostname(netIP), resp.HostName())
|
||||
})
|
||||
|
||||
t.Run("same_mac", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, anotherName, anotherIP.AsSlice(), staticMAC)
|
||||
resp := discoverAnOffer(t, anotherName, anotherIP, staticMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(anotherName),
|
||||
@@ -179,7 +186,8 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
require.Len(t, fqdnOptData, 3+len(staticName))
|
||||
assert.Equal(t, []uint8(staticName), fqdnOptData[3:])
|
||||
|
||||
assert.Equal(t, staticIP, resp.YourIPAddr)
|
||||
ip := net.IP(staticIP.AsSlice())
|
||||
assert.Equal(t, ip, resp.YourIPAddr)
|
||||
})
|
||||
|
||||
t.Run("same_ip", func(t *testing.T) {
|
||||
@@ -212,7 +220,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "success.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
},
|
||||
name: "success",
|
||||
wantErrMsg: "",
|
||||
@@ -220,7 +228,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "probably-router.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: DefaultGatewayIP.AsSlice(),
|
||||
IP: DefaultGatewayIP,
|
||||
},
|
||||
name: "with_gateway_ip",
|
||||
wantErrMsg: "dhcpv4: adding static lease: " +
|
||||
@@ -229,7 +237,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "ip6.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.ParseIP("ffff::1"),
|
||||
IP: netip.MustParseAddr("ffff::1"),
|
||||
},
|
||||
name: "ipv6",
|
||||
wantErrMsg: `dhcpv4: adding static lease: ` +
|
||||
@@ -238,7 +246,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "bad-mac.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
},
|
||||
name: "bad_mac",
|
||||
wantErrMsg: `dhcpv4: adding static lease: bad mac address "aa:aa": ` +
|
||||
@@ -247,7 +255,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "bad-lbl-.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
},
|
||||
name: "bad_hostname",
|
||||
wantErrMsg: `dhcpv4: adding static lease: validating hostname: ` +
|
||||
@@ -289,11 +297,11 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
dynLeases := []Lease{{
|
||||
Hostname: "dynamic-1.local",
|
||||
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
}, {
|
||||
Hostname: "dynamic-2.local",
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 151},
|
||||
IP: netip.MustParseAddr("192.168.10.151"),
|
||||
}}
|
||||
|
||||
for i := range dynLeases {
|
||||
@@ -304,11 +312,11 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
stLeases := []*Lease{{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
}, {
|
||||
Hostname: "static-2.local",
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 152},
|
||||
IP: netip.MustParseAddr("192.168.10.152"),
|
||||
}}
|
||||
|
||||
for _, l := range stLeases {
|
||||
@@ -320,9 +328,9 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
require.Len(t, ls, 2)
|
||||
|
||||
for i, l := range ls {
|
||||
assert.True(t, stLeases[i].IP.Equal(l.IP))
|
||||
assert.Equal(t, stLeases[i].IP, l.IP)
|
||||
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
|
||||
assert.True(t, l.IsStatic())
|
||||
assert.True(t, l.IsStatic)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -513,7 +521,7 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
l := &Lease{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
}
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
@@ -539,7 +547,9 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
t.Run("offer", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, l.IP.Equal(resp.YourIPAddr))
|
||||
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
assert.True(t, ip.Equal(resp.YourIPAddr))
|
||||
|
||||
assert.True(t, resp.Router()[0].Equal(s.conf.GatewayIP.AsSlice()))
|
||||
assert.True(t, resp.ServerIdentifier().Equal(s.conf.GatewayIP.AsSlice()))
|
||||
@@ -564,7 +574,9 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
t.Run("ack", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, l.IP.Equal(resp.YourIPAddr))
|
||||
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
assert.True(t, ip.Equal(resp.YourIPAddr))
|
||||
|
||||
assert.True(t, resp.Router()[0].Equal(s.conf.GatewayIP.AsSlice()))
|
||||
assert.True(t, resp.ServerIdentifier().Equal(s.conf.GatewayIP.AsSlice()))
|
||||
@@ -583,7 +595,7 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, l.IP.Equal(ls[0].IP))
|
||||
assert.Equal(t, l.IP, ls[0].IP)
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
@@ -681,7 +693,8 @@ func TestV4DynamicLease_Get(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
|
||||
ip := netip.MustParseAddr("192.168.10.100")
|
||||
assert.Equal(t, ip, ls[0].IP)
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
@@ -810,7 +823,7 @@ func TestV4Server_Send(t *testing.T) {
|
||||
req: &dhcpv4.DHCPv4{ClientHWAddr: knownMAC},
|
||||
resp: &dhcpv4.DHCPv4{YourIPAddr: knownIP},
|
||||
want: &dhcpUnicastAddr{
|
||||
Addr: raw.Addr{HardwareAddr: knownMAC},
|
||||
Addr: packet.Addr{HardwareAddr: knownMAC},
|
||||
yiaddr: knownIP,
|
||||
},
|
||||
}, {
|
||||
@@ -862,3 +875,144 @@ func TestV4Server_Send(t *testing.T) {
|
||||
assert.True(t, resp.IsBroadcast())
|
||||
})
|
||||
}
|
||||
|
||||
func TestV4Server_FindMACbyIP(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
staticIP := netip.MustParseAddr("192.168.10.10")
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
anotherIP := netip.MustParseAddr("192.168.100.100")
|
||||
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
|
||||
|
||||
s := &v4Server{
|
||||
leases: []*Lease{{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Expiry: time.Unix(10, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP,
|
||||
}},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want net.HardwareAddr
|
||||
ip netip.Addr
|
||||
name string
|
||||
}{{
|
||||
name: "basic",
|
||||
ip: staticIP,
|
||||
want: staticMAC,
|
||||
}, {
|
||||
name: "not_found",
|
||||
ip: netip.MustParseAddr("1.2.3.4"),
|
||||
want: nil,
|
||||
}, {
|
||||
name: "expired",
|
||||
ip: anotherIP,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "v6",
|
||||
ip: netip.MustParseAddr("ffff::1"),
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mac := s.FindMACbyIP(tc.ip)
|
||||
|
||||
require.Equal(t, tc.want, mac)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestV4Server_handleDecline(t *testing.T) {
|
||||
const (
|
||||
dynamicName = "dynamic-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
dynamicIP := netip.MustParseAddr("192.168.10.200")
|
||||
dynamicMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
s := defaultSrv(t)
|
||||
|
||||
s4, ok := s.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s4.leases = []*Lease{{
|
||||
Hostname: dynamicName,
|
||||
HWAddr: dynamicMAC,
|
||||
IP: dynamicIP,
|
||||
}}
|
||||
|
||||
req, err := dhcpv4.New(
|
||||
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(net.IP(dynamicIP.AsSlice()))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.ClientIPAddr = net.IP(dynamicIP.AsSlice())
|
||||
req.ClientHWAddr = dynamicMAC
|
||||
|
||||
resp := &dhcpv4.DHCPv4{}
|
||||
err = s4.handleDecline(req, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
wantResp := &dhcpv4.DHCPv4{
|
||||
YourIPAddr: net.IP(s4.conf.RangeStart.AsSlice()),
|
||||
Options: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptMessageType(dhcpv4.MessageTypeAck),
|
||||
),
|
||||
}
|
||||
|
||||
require.Equal(t, wantResp, resp)
|
||||
}
|
||||
|
||||
func TestV4Server_handleRelease(t *testing.T) {
|
||||
const (
|
||||
dynamicName = "dymamic-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
dynamicIP := netip.MustParseAddr("192.168.10.200")
|
||||
dynamicMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
s := defaultSrv(t)
|
||||
|
||||
s4, ok := s.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s4.leases = []*Lease{{
|
||||
Hostname: dynamicName,
|
||||
HWAddr: dynamicMAC,
|
||||
IP: dynamicIP,
|
||||
}}
|
||||
|
||||
req, err := dhcpv4.New(
|
||||
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(net.IP(dynamicIP.AsSlice()))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.ClientIPAddr = net.IP(dynamicIP.AsSlice())
|
||||
req.ClientHWAddr = dynamicMAC
|
||||
|
||||
resp := &dhcpv4.DHCPv4{}
|
||||
err = s4.handleRelease(req, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
wantResp := &dhcpv4.DHCPv4{
|
||||
Options: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptMessageType(dhcpv4.MessageTypeAck),
|
||||
),
|
||||
}
|
||||
|
||||
require.Equal(t, wantResp, resp)
|
||||
}
|
||||
|
||||
@@ -61,13 +61,13 @@ func ip6InRange(start, ip net.IP) bool {
|
||||
|
||||
// ResetLeases resets leases.
|
||||
func (s *v6Server) ResetLeases(leases []*Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
s.leases = nil
|
||||
for _, l := range leases {
|
||||
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
if l.Expiry.Unix() != leaseExpireStatic &&
|
||||
!ip6InRange(s.conf.ipStart, l.IP) {
|
||||
!ip6InRange(s.conf.ipStart, ip) {
|
||||
|
||||
log.Debug("dhcpv6: skipping a lease with IP %v: not within current IP range", l.IP)
|
||||
|
||||
@@ -119,10 +119,9 @@ func (s *v6Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
netIP := ip.AsSlice()
|
||||
for _, l := range s.leases {
|
||||
if l.IP.Equal(netIP) {
|
||||
if l.Expiry.After(now) || l.IsStatic() {
|
||||
if l.IP == ip {
|
||||
if l.IsStatic || l.Expiry.After(now) {
|
||||
return l.HWAddr
|
||||
}
|
||||
}
|
||||
@@ -133,7 +132,8 @@ func (s *v6Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
|
||||
|
||||
// Remove (swap) lease by index
|
||||
func (s *v6Server) leaseRemoveSwapByIndex(i int) {
|
||||
s.ipAddrs[s.leases[i].IP[15]] = 0
|
||||
leaseIP := s.leases[i].IP.As16()
|
||||
s.ipAddrs[leaseIP[15]] = 0
|
||||
log.Debug("dhcpv6: removed lease %s", s.leases[i].HWAddr)
|
||||
|
||||
n := len(s.leases)
|
||||
@@ -162,7 +162,7 @@ func (s *v6Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
if l.IP == lease.IP {
|
||||
if l.Expiry.Unix() == leaseExpireStatic {
|
||||
return fmt.Errorf("static lease already exists")
|
||||
}
|
||||
@@ -178,7 +178,7 @@ func (s *v6Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
func (s *v6Server) AddStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
if len(l.IP) != net.IPv6len {
|
||||
if !l.IP.Is6() {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ func (s *v6Server) AddStaticLease(l *Lease) (err error) {
|
||||
func (s *v6Server) RemoveStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
if len(l.IP) != 16 {
|
||||
if !l.IP.Is6() {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
|
||||
@@ -234,14 +234,15 @@ func (s *v6Server) RemoveStaticLease(l *Lease) (err error) {
|
||||
// Add a lease
|
||||
func (s *v6Server) addLease(l *Lease) {
|
||||
s.leases = append(s.leases, l)
|
||||
s.ipAddrs[l.IP[15]] = 1
|
||||
ip := l.IP.As16()
|
||||
s.ipAddrs[ip[15]] = 1
|
||||
log.Debug("dhcpv6: added lease %s <-> %s", l.IP, l.HWAddr)
|
||||
}
|
||||
|
||||
// Remove a lease with the same properties
|
||||
func (s *v6Server) rmLease(lease *Lease) (err error) {
|
||||
for i, l := range s.leases {
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
if l.IP == lease.IP {
|
||||
if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
|
||||
l.Hostname != lease.Hostname {
|
||||
return fmt.Errorf("lease not found")
|
||||
@@ -308,18 +309,27 @@ func (s *v6Server) reserveLease(mac net.HardwareAddr) *Lease {
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
copy(l.IP, s.conf.ipStart)
|
||||
l.IP = s.findFreeIP()
|
||||
if l.IP == nil {
|
||||
ip := s.findFreeIP()
|
||||
if ip == nil {
|
||||
i := s.findExpiredLease()
|
||||
if i < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
copy(s.leases[i].HWAddr, mac)
|
||||
|
||||
return s.leases[i]
|
||||
}
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.IP = netIP
|
||||
|
||||
s.addLease(&l)
|
||||
|
||||
return &l
|
||||
}
|
||||
|
||||
@@ -388,7 +398,8 @@ func (s *v6Server) checkIA(msg *dhcpv6.Message, lease *Lease) error {
|
||||
return fmt.Errorf("no IANA.Addr option in %s", msg.Type().String())
|
||||
}
|
||||
|
||||
if !oiaAddr.IPv6Addr.Equal(lease.IP) {
|
||||
leaseIP := net.IP(lease.IP.AsSlice())
|
||||
if !oiaAddr.IPv6Addr.Equal(leaseIP) {
|
||||
return fmt.Errorf("invalid IANA.Addr option in %s", msg.Type().String())
|
||||
}
|
||||
}
|
||||
@@ -475,7 +486,7 @@ func (s *v6Server) process(msg *dhcpv6.Message, req, resp dhcpv6.DHCPv6) bool {
|
||||
copy(oia.IaId[:], []byte(valueIAID))
|
||||
}
|
||||
oiaAddr := &dhcpv6.OptIAAddress{
|
||||
IPv6Addr: lease.IP,
|
||||
IPv6Addr: net.IP(lease.IP.AsSlice()),
|
||||
PreferredLifetime: lifetime,
|
||||
ValidLifetime: lifetime,
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/iana"
|
||||
@@ -27,7 +29,7 @@ func TestV6_AddRemove_static(t *testing.T) {
|
||||
|
||||
// Add static lease.
|
||||
l := &Lease{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
IP: netip.MustParseAddr("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
err = s.AddStaticLease(l)
|
||||
@@ -46,7 +48,7 @@ func TestV6_AddRemove_static(t *testing.T) {
|
||||
|
||||
// Try to remove non-existent static lease.
|
||||
err = s.RemoveStaticLease(&Lease{
|
||||
IP: net.ParseIP("2001::2"),
|
||||
IP: netip.MustParseAddr("2001::2"),
|
||||
HWAddr: l.HWAddr,
|
||||
})
|
||||
require.Error(t, err)
|
||||
@@ -71,10 +73,10 @@ func TestV6_AddReplace(t *testing.T) {
|
||||
|
||||
// Add dynamic leases.
|
||||
dynLeases := []*Lease{{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
IP: netip.MustParseAddr("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}, {
|
||||
IP: net.ParseIP("2001::2"),
|
||||
IP: netip.MustParseAddr("2001::2"),
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}}
|
||||
|
||||
@@ -83,10 +85,10 @@ func TestV6_AddReplace(t *testing.T) {
|
||||
}
|
||||
|
||||
stLeases := []*Lease{{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
IP: netip.MustParseAddr("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}, {
|
||||
IP: net.ParseIP("2001::3"),
|
||||
IP: netip.MustParseAddr("2001::3"),
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}}
|
||||
|
||||
@@ -99,7 +101,7 @@ func TestV6_AddReplace(t *testing.T) {
|
||||
require.Len(t, ls, 2)
|
||||
|
||||
for i, l := range ls {
|
||||
assert.True(t, stLeases[i].IP.Equal(l.IP))
|
||||
assert.Equal(t, stLeases[i].IP, l.IP)
|
||||
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
|
||||
assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix())
|
||||
}
|
||||
@@ -126,7 +128,7 @@ func TestV6GetLease(t *testing.T) {
|
||||
}
|
||||
|
||||
l := &Lease{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
IP: netip.MustParseAddr("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
err = s.AddStaticLease(l)
|
||||
@@ -158,7 +160,8 @@ func TestV6GetLease(t *testing.T) {
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
|
||||
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
assert.Equal(t, ip, oiaAddr.IPv6Addr)
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
})
|
||||
|
||||
@@ -182,7 +185,8 @@ func TestV6GetLease(t *testing.T) {
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
|
||||
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
assert.Equal(t, ip, oiaAddr.IPv6Addr)
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
})
|
||||
|
||||
@@ -308,3 +312,74 @@ func TestIP6InRange(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestV6_FindMACbyIP(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
staticIP := netip.MustParseAddr("2001::1")
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
anotherIP := netip.MustParseAddr("2001::100")
|
||||
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
|
||||
|
||||
s := &v6Server{
|
||||
leases: []*Lease{{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Expiry: time.Unix(10, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP,
|
||||
}},
|
||||
}
|
||||
|
||||
s.leases = []*Lease{{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Expiry: time.Unix(10, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP,
|
||||
}}
|
||||
|
||||
testCases := []struct {
|
||||
want net.HardwareAddr
|
||||
ip netip.Addr
|
||||
name string
|
||||
}{{
|
||||
name: "basic",
|
||||
ip: staticIP,
|
||||
want: staticMAC,
|
||||
}, {
|
||||
name: "not_found",
|
||||
ip: netip.MustParseAddr("ffff::1"),
|
||||
want: nil,
|
||||
}, {
|
||||
name: "expired",
|
||||
ip: anotherIP,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "v4",
|
||||
ip: netip.MustParseAddr("1.2.3.4"),
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mac := s.FindMACbyIP(tc.ip)
|
||||
|
||||
require.Equal(t, tc.want, mac)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +81,10 @@ type FilteringConfig struct {
|
||||
// 0, then default value is used (3600).
|
||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"`
|
||||
|
||||
// ProtectionDisabledUntil is the timestamp until when the protection is
|
||||
// disabled.
|
||||
ProtectionDisabledUntil *time.Time `yaml:"protection_disabled_until"`
|
||||
|
||||
// ParentalBlockHost is the IP (or domain name) which is used to respond to
|
||||
// DNS requests blocked by parental control.
|
||||
ParentalBlockHost string `yaml:"parental_block_host"`
|
||||
@@ -195,12 +199,16 @@ type FilteringConfig struct {
|
||||
// IpsetListFileName, if set, points to the file with ipset configuration.
|
||||
// The format is the same as in [IpsetList].
|
||||
IpsetListFileName string `yaml:"ipset_file"`
|
||||
|
||||
// BootstrapPreferIPv6, if true, instructs the bootstrapper to prefer IPv6
|
||||
// addresses to IPv4 ones for DoH, DoQ, and DoT.
|
||||
BootstrapPreferIPv6 bool `yaml:"bootstrap_prefer_ipv6"`
|
||||
}
|
||||
|
||||
// EDNSClientSubnet is the settings list for EDNS Client Subnet.
|
||||
type EDNSClientSubnet struct {
|
||||
// CustomIP for EDNS Client Subnet.
|
||||
CustomIP string `yaml:"custom_ip"`
|
||||
CustomIP netip.Addr `yaml:"custom_ip"`
|
||||
|
||||
// Enabled defines if EDNS Client Subnet is enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
@@ -340,15 +348,8 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
}
|
||||
|
||||
if srvConf.EDNSClientSubnet.UseCustom {
|
||||
// TODO(s.chzhen): Add wrapper around netip.Addr.
|
||||
var ip net.IP
|
||||
ip, err = netutil.ParseIP(srvConf.EDNSClientSubnet.CustomIP)
|
||||
if err != nil {
|
||||
return conf, fmt.Errorf("edns: %w", err)
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy.
|
||||
conf.EDNSAddr = ip
|
||||
conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice())
|
||||
}
|
||||
|
||||
if srvConf.CacheSize != 0 {
|
||||
@@ -377,7 +378,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
|
||||
err = s.prepareTLS(&conf)
|
||||
if err != nil {
|
||||
return conf, fmt.Errorf("validating tls: %w", err)
|
||||
return proxy.Config{}, fmt.Errorf("validating tls: %w", err)
|
||||
}
|
||||
|
||||
if c := srvConf.DNSCryptConfig; c.Enabled {
|
||||
@@ -388,7 +389,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
}
|
||||
|
||||
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
|
||||
return conf, errors.Error("no default upstream servers configured")
|
||||
return proxy.Config{}, errors.Error("no default upstream servers configured")
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
@@ -482,6 +483,7 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: httpVersions,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -497,6 +499,7 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: httpVersions,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -584,11 +587,11 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) {
|
||||
if s.conf.StrictSNICheck {
|
||||
if len(cert.DNSNames) != 0 {
|
||||
s.conf.dnsNames = cert.DNSNames
|
||||
log.Debug("dnsforward: using certificate's SAN as DNS names: %v", cert.DNSNames)
|
||||
log.Debug("dns: using certificate's SAN as DNS names: %v", cert.DNSNames)
|
||||
slices.Sort(s.conf.dnsNames)
|
||||
} else {
|
||||
s.conf.dnsNames = append(s.conf.dnsNames, cert.Subject.CommonName)
|
||||
log.Debug("dnsforward: using certificate's CN as DNS name: %s", cert.Subject.CommonName)
|
||||
log.Debug("dns: using certificate's CN as DNS name: %s", cert.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -642,3 +645,49 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
|
||||
}
|
||||
return &s.conf.cert, nil
|
||||
}
|
||||
|
||||
// UpdatedProtectionStatus updates protection state, if the protection was
|
||||
// disabled temporarily. Returns the updated state of protection.
|
||||
func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
disabledUntil = s.conf.ProtectionDisabledUntil
|
||||
if disabledUntil == nil {
|
||||
return s.conf.ProtectionEnabled, nil
|
||||
}
|
||||
|
||||
if time.Now().Before(*disabledUntil) {
|
||||
return false, disabledUntil
|
||||
}
|
||||
|
||||
// Update the values in a separate goroutine, unless an update is already in
|
||||
// progress. Since this method is called very often, and this update is a
|
||||
// relatively rare situation, do not lock s.serverLock for writing, as that
|
||||
// can lead to freezes.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/5661.
|
||||
if s.protectionUpdateInProgress.CompareAndSwap(false, true) {
|
||||
go s.enableProtectionAfterPause()
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// enableProtectionAfterPause sets the protection configuration to enabled
|
||||
// values. It is intended to be used as a goroutine.
|
||||
func (s *Server) enableProtectionAfterPause() {
|
||||
defer log.OnPanic("dns: enabling protection after pause")
|
||||
|
||||
defer s.protectionUpdateInProgress.Store(false)
|
||||
|
||||
defer s.conf.ConfigModified()
|
||||
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.conf.ProtectionEnabled = true
|
||||
s.conf.ProtectionDisabledUntil = nil
|
||||
|
||||
log.Info("dns: protection is restarted after pause")
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
// To transfer information between modules
|
||||
//
|
||||
// TODO(s.chzhen): Add lowercased, non-FQDN version of the hostname from the
|
||||
// question of the request.
|
||||
// question of the request. Add persistent client.
|
||||
type dnsContext struct {
|
||||
proxyCtx *proxy.DNSContext
|
||||
|
||||
@@ -37,6 +38,8 @@ type dnsContext struct {
|
||||
// was parsed successfully and belongs to one of the locally served IP
|
||||
// ranges. It is also filled with unmapped version of the address if it's
|
||||
// within DNS64 prefixes.
|
||||
//
|
||||
// TODO(e.burkov): Use netip.Addr when we switch to netip more fully.
|
||||
unreversedReqIP net.IP
|
||||
|
||||
// err is the error returned from a processing function.
|
||||
@@ -181,6 +184,21 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
// Handle a reserved domain healthcheck.adguardhome.test.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
|
||||
// grant requests to register test names in the normal way to any person or
|
||||
// entity, making domain names under test. TLD free to use in internal
|
||||
// purposes.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
|
||||
if q.Name == "healthcheck.adguardhome.test." {
|
||||
// Generate a NODATA negative response to make nslookup exit with 0.
|
||||
pctx.Res = s.makeResponse(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
// Get the ClientID, if any, before getting client-specific filtering
|
||||
// settings.
|
||||
var key [8]byte
|
||||
@@ -188,7 +206,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
dctx.clientID = string(s.clientIDCache.Get(key[:]))
|
||||
|
||||
// Get the client-specific filtering settings.
|
||||
dctx.protectionEnabled = s.conf.ProtectionEnabled
|
||||
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
|
||||
dctx.setts = s.getClientRequestFilteringSettings(dctx)
|
||||
|
||||
return resultCodeSuccess
|
||||
@@ -240,17 +258,16 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||
|
||||
// Assume that we only process IPv4 now.
|
||||
//
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ip, err := netutil.IPToAddr(l.IP, netutil.AddrFamilyIPv4)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: skipping invalid ip %v from dhcp: %s", l.IP, err)
|
||||
if !l.IP.Is4() {
|
||||
log.Debug("dnsforward: skipping invalid ip from dhcp: bad ipv4 net.IP %v", l.IP)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ipToHost[ip] = lowhost
|
||||
hostToIP[lowhost] = ip
|
||||
leaseIP := l.IP
|
||||
|
||||
ipToHost[leaseIP] = lowhost
|
||||
hostToIP[lowhost] = leaseIP
|
||||
}
|
||||
|
||||
s.setTableHostToIP(hostToIP)
|
||||
@@ -442,6 +459,88 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// indexFirstV4Label returns the index at which the reversed IPv4 address
|
||||
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
|
||||
// arpa labels removed.
|
||||
func indexFirstV4Label(domain string) (idx int) {
|
||||
idx = len(domain)
|
||||
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
|
||||
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
|
||||
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
|
||||
if parseErr != nil {
|
||||
return idx
|
||||
}
|
||||
|
||||
idx = curIdx
|
||||
}
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
// indexFirstV6Label returns the index at which the reversed IPv6 address
|
||||
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
|
||||
// labels removed.
|
||||
func indexFirstV6Label(domain string) (idx int) {
|
||||
idx = len(domain)
|
||||
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
|
||||
curIdx := idx - len("a.")
|
||||
if curIdx > 1 && domain[curIdx-1] != '.' {
|
||||
return idx
|
||||
}
|
||||
|
||||
nibble := domain[curIdx]
|
||||
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
|
||||
return idx
|
||||
}
|
||||
|
||||
idx = curIdx
|
||||
}
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
// extractARPASubnet tries to convert a reversed ARPA address being a part of
|
||||
// domain to an IP network. domain must be an FQDN.
|
||||
//
|
||||
// TODO(e.burkov): Move to golibs.
|
||||
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
|
||||
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
|
||||
const (
|
||||
v4Suffix = "in-addr.arpa."
|
||||
v6Suffix = "ip6.arpa."
|
||||
)
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
var idx int
|
||||
switch {
|
||||
case strings.HasSuffix(domain, v4Suffix):
|
||||
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
|
||||
case strings.HasSuffix(domain, v6Suffix):
|
||||
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
|
||||
default:
|
||||
return netip.Prefix{}, &netutil.AddrError{
|
||||
Err: netutil.ErrNotAReversedSubnet,
|
||||
Kind: netutil.AddrKindARPA,
|
||||
Addr: domain,
|
||||
}
|
||||
}
|
||||
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.SubnetFromReversedAddr(domain[idx:])
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
|
||||
return netutil.IPNetToPrefixNoMapped(subnet)
|
||||
}
|
||||
|
||||
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
|
||||
// in locally served network from external clients.
|
||||
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
@@ -453,34 +552,29 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
ip, err := netutil.IPFromReversedAddr(q.Name)
|
||||
subnet, err := extractARPASubnet(q.Name)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: parsing reversed addr: %s", err)
|
||||
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
|
||||
log.Debug("dnsforward: request is not for arpa domain")
|
||||
|
||||
// DNS-Based Service Discovery uses PTR records having not an ARPA
|
||||
// format of the domain name in question. Those shouldn't be
|
||||
// invalidated. See http://www.dns-sd.org/ServerStaticSetup.html and
|
||||
// RFC 2782.
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
if err = netutil.ValidateSRVDomainName(name); err != nil {
|
||||
log.Debug("dnsforward: validating service domain: %s", err)
|
||||
|
||||
return resultCodeError
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: request is not for arpa domain")
|
||||
log.Debug("dnsforward: parsing reversed addr: %s", err)
|
||||
|
||||
return resultCodeSuccess
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally served or at least
|
||||
// shouldn't be accessible externally.
|
||||
if !s.privateNets.Contains(ip) {
|
||||
subnetAddr := subnet.Addr()
|
||||
addrData := subnetAddr.AsSlice()
|
||||
if !s.privateNets.Contains(addrData) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: addr %s is from locally served network", ip)
|
||||
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
|
||||
|
||||
if !dctx.isLocalClient {
|
||||
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
|
||||
@@ -491,7 +585,7 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// Do not perform unreversing ever again.
|
||||
dctx.unreversedReqIP = ip
|
||||
dctx.unreversedReqIP = addrData
|
||||
|
||||
// There is no need to filter request from external addresses since this
|
||||
// code is only executed when the request is for locally served ARPA
|
||||
|
||||
@@ -36,8 +36,6 @@ func (s *Server) setupDNS64() {
|
||||
// valid IPv4. It panics, if there are no configured DNS64 prefixes, because
|
||||
// synthesis should not be performed unless DNS64 function enabled.
|
||||
func (s *Server) mapDNS64(ip netip.Addr) (mapped net.IP) {
|
||||
// Don't mask the address here since it should have already been masked on
|
||||
// initialization stage.
|
||||
pref := s.dns64Pref.Masked().Addr().As16()
|
||||
ipData := ip.As4()
|
||||
|
||||
|
||||
@@ -605,3 +605,129 @@ func TestIPStringFromAddr(t *testing.T) {
|
||||
assert.Empty(t, ipStringFromAddr(nil))
|
||||
})
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Add fuzzing when moving to golibs.
|
||||
func TestExtractARPASubnet(t *testing.T) {
|
||||
const (
|
||||
v4Suf = `in-addr.arpa.`
|
||||
v4Part = `2.1.` + v4Suf
|
||||
v4Whole = `4.3.` + v4Part
|
||||
|
||||
v6Suf = `ip6.arpa.`
|
||||
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
|
||||
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
|
||||
)
|
||||
|
||||
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
|
||||
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
|
||||
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
|
||||
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
|
||||
|
||||
testCases := []struct {
|
||||
want netip.Prefix
|
||||
name string
|
||||
domain string
|
||||
wantErr string
|
||||
}{{
|
||||
want: netip.Prefix{},
|
||||
name: "not_an_arpa",
|
||||
domain: "some.domain.name.",
|
||||
wantErr: `bad arpa domain name "some.domain.name.": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "bad_domain_name",
|
||||
domain: "abc.123.",
|
||||
wantErr: `bad domain name "abc.123": ` +
|
||||
`bad top-level domain name label "123": all octets are numeric`,
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4",
|
||||
domain: v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "partial_v4",
|
||||
domain: v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4_within_domain",
|
||||
domain: "a." + v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4_additional_label",
|
||||
domain: "5." + v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "partial_v4_within_domain",
|
||||
domain: "a." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "overflow_v4",
|
||||
domain: "256." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "overflow_v4_within_domain",
|
||||
domain: "a.256." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v4",
|
||||
domain: v4Suf,
|
||||
wantErr: `bad arpa domain name "in-addr.arpa": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v4_within_domain",
|
||||
domain: "a." + v4Suf,
|
||||
wantErr: `bad arpa domain name "in-addr.arpa": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6",
|
||||
domain: v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6PrefPart,
|
||||
name: "partial_v6",
|
||||
domain: v6Part,
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6_within_domain",
|
||||
domain: "g." + v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6_additional_label",
|
||||
domain: "1." + v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6PrefPart,
|
||||
name: "partial_v6_within_domain",
|
||||
domain: "label." + v6Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v6",
|
||||
domain: v6Suf,
|
||||
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v6_within_domain",
|
||||
domain: "g." + v6Suf,
|
||||
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
subnet, err := extractARPASubnet(tc.domain)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
assert.Equal(t, tc.want, subnet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
@@ -111,6 +112,10 @@ type Server struct {
|
||||
|
||||
isRunning bool
|
||||
|
||||
// protectionUpdateInProgress is used to make sure that only one goroutine
|
||||
// updating the protection configuration after a pause is running at a time.
|
||||
protectionUpdateInProgress atomic.Bool
|
||||
|
||||
conf ServerConfig
|
||||
// serverLock protects Server.
|
||||
serverLock sync.RWMutex
|
||||
@@ -447,6 +452,8 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -412,7 +413,7 @@ func TestServerRace(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchEnabled: true,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
SafeSearchCacheSize: 1000,
|
||||
ParentalCacheSize: 1000,
|
||||
CacheTime: 30,
|
||||
@@ -440,12 +441,27 @@ func TestServerRace(t *testing.T) {
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
safeSearchConf := filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Google: true,
|
||||
Yandex: true,
|
||||
CustomResolver: resolver,
|
||||
}
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
SafeSearchEnabled: true,
|
||||
SafeSearchConf: safeSearchConf,
|
||||
SafeSearchCacheSize: 1000,
|
||||
CacheTime: 30,
|
||||
CustomResolver: resolver,
|
||||
}
|
||||
safeSearch, err := safesearch.NewDefault(
|
||||
safeSearchConf,
|
||||
"",
|
||||
filterConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(filterConf.CacheTime),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
filterConf.SafeSearch = safeSearch
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
@@ -498,7 +514,8 @@ func TestSafeSearch(t *testing.T) {
|
||||
t.Run(tc.host, func(t *testing.T) {
|
||||
req := createTestMessage(tc.host)
|
||||
|
||||
reply, _, err := client.Exchange(req, addr)
|
||||
var reply *dns.Msg
|
||||
reply, _, err = client.Exchange(req, addr)
|
||||
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
|
||||
assertResponse(t, reply, tc.want)
|
||||
})
|
||||
@@ -1057,7 +1074,7 @@ var testDHCP = &dhcpd.MockInterface{
|
||||
OnEnabled: func() (ok bool) { return true },
|
||||
OnLeases: func(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
return []*dhcpd.Lease{{
|
||||
IP: net.IP{192, 168, 12, 34},
|
||||
IP: netip.MustParseAddr("192.168.12.34"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
Hostname: "myhost",
|
||||
}}
|
||||
|
||||
@@ -23,41 +23,101 @@ import (
|
||||
)
|
||||
|
||||
// jsonDNSConfig is the JSON representation of the DNS server configuration.
|
||||
//
|
||||
// TODO(s.chzhen): Split it into smaller pieces. Use aghalg.NullBool instead
|
||||
// of *bool.
|
||||
type jsonDNSConfig struct {
|
||||
Upstreams *[]string `json:"upstream_dns"`
|
||||
UpstreamsFile *string `json:"upstream_dns_file"`
|
||||
Bootstraps *[]string `json:"bootstrap_dns"`
|
||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||
RateLimit *uint32 `json:"ratelimit"`
|
||||
BlockingMode *BlockingMode `json:"blocking_mode"`
|
||||
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
|
||||
DNSSECEnabled *bool `json:"dnssec_enabled"`
|
||||
DisableIPv6 *bool `json:"disable_ipv6"`
|
||||
UpstreamMode *string `json:"upstream_mode"`
|
||||
CacheSize *uint32 `json:"cache_size"`
|
||||
CacheMinTTL *uint32 `json:"cache_ttl_min"`
|
||||
CacheMaxTTL *uint32 `json:"cache_ttl_max"`
|
||||
CacheOptimistic *bool `json:"cache_optimistic"`
|
||||
ResolveClients *bool `json:"resolve_clients"`
|
||||
UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"`
|
||||
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
|
||||
BlockingIPv4 net.IP `json:"blocking_ipv4"`
|
||||
BlockingIPv6 net.IP `json:"blocking_ipv6"`
|
||||
// Upstreams is the list of upstream DNS servers.
|
||||
Upstreams *[]string `json:"upstream_dns"`
|
||||
|
||||
// UpstreamsFile is the file containing upstream DNS servers.
|
||||
UpstreamsFile *string `json:"upstream_dns_file"`
|
||||
|
||||
// Bootstraps is the list of DNS servers resolving IP addresses of the
|
||||
// upstream DoH/DoT resolvers.
|
||||
Bootstraps *[]string `json:"bootstrap_dns"`
|
||||
|
||||
// ProtectionEnabled defines if protection is enabled.
|
||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||
|
||||
// RateLimit is the number of requests per second allowed per client.
|
||||
RateLimit *uint32 `json:"ratelimit"`
|
||||
|
||||
// BlockingMode defines the way blocked responses are constructed.
|
||||
BlockingMode *BlockingMode `json:"blocking_mode"`
|
||||
|
||||
// EDNSCSEnabled defines if EDNS Client Subnet is enabled.
|
||||
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
|
||||
|
||||
// EDNSCSUseCustom defines if EDNSCSCustomIP should be used.
|
||||
EDNSCSUseCustom *bool `json:"edns_cs_use_custom"`
|
||||
|
||||
// DNSSECEnabled defines if DNSSEC is enabled.
|
||||
DNSSECEnabled *bool `json:"dnssec_enabled"`
|
||||
|
||||
// DisableIPv6 defines if IPv6 addresses should be dropped.
|
||||
DisableIPv6 *bool `json:"disable_ipv6"`
|
||||
|
||||
// UpstreamMode defines the way DNS requests are constructed.
|
||||
UpstreamMode *string `json:"upstream_mode"`
|
||||
|
||||
// CacheSize in bytes.
|
||||
CacheSize *uint32 `json:"cache_size"`
|
||||
|
||||
// CacheMinTTL is custom minimum TTL for cached DNS responses.
|
||||
CacheMinTTL *uint32 `json:"cache_ttl_min"`
|
||||
|
||||
// CacheMaxTTL is custom maximum TTL for cached DNS responses.
|
||||
CacheMaxTTL *uint32 `json:"cache_ttl_max"`
|
||||
|
||||
// CacheOptimistic defines if expired entries should be served.
|
||||
CacheOptimistic *bool `json:"cache_optimistic"`
|
||||
|
||||
// ResolveClients defines if clients IPs should be resolved into hostnames.
|
||||
ResolveClients *bool `json:"resolve_clients"`
|
||||
|
||||
// UsePrivateRDNS defines if privates DNS resolvers should be used.
|
||||
UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"`
|
||||
|
||||
// LocalPTRUpstreams is the list of local private DNS resolvers.
|
||||
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
|
||||
|
||||
// BlockingIPv4 is custom IPv4 address for blocked A requests.
|
||||
BlockingIPv4 net.IP `json:"blocking_ipv4"`
|
||||
|
||||
// BlockingIPv6 is custom IPv6 address for blocked AAAA requests.
|
||||
BlockingIPv6 net.IP `json:"blocking_ipv6"`
|
||||
|
||||
// DisabledUntil is a timestamp until when the protection is disabled.
|
||||
DisabledUntil *time.Time `json:"protection_disabled_until"`
|
||||
|
||||
// EDNSCSCustomIP is custom IP for EDNS Client Subnet.
|
||||
EDNSCSCustomIP netip.Addr `json:"edns_cs_custom_ip"`
|
||||
|
||||
// DefaultLocalPTRUpstreams is used to pass the addresses from
|
||||
// systemResolvers to the front-end. It's not a pointer to the slice since
|
||||
// there is no need to omit it while decoding from JSON.
|
||||
DefaultLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus()
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
upstreams := stringutil.CloneSliceOrEmpty(s.conf.UpstreamDNS)
|
||||
upstreamFile := s.conf.UpstreamDNSFileName
|
||||
bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS)
|
||||
protectionEnabled := s.conf.ProtectionEnabled
|
||||
blockingMode := s.conf.BlockingMode
|
||||
blockingIPv4 := s.conf.BlockingIPv4
|
||||
blockingIPv6 := s.conf.BlockingIPv6
|
||||
ratelimit := s.conf.Ratelimit
|
||||
|
||||
customIP := s.conf.EDNSClientSubnet.CustomIP
|
||||
enableEDNSClientSubnet := s.conf.EDNSClientSubnet.Enabled
|
||||
useCustom := s.conf.EDNSClientSubnet.UseCustom
|
||||
|
||||
enableDNSSEC := s.conf.EnableDNSSEC
|
||||
aaaaDisabled := s.conf.AAAADisabled
|
||||
cacheSize := s.conf.CacheSize
|
||||
@@ -67,6 +127,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
resolveClients := s.conf.ResolveClients
|
||||
usePrivateRDNS := s.conf.UsePrivateRDNS
|
||||
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
|
||||
|
||||
var upstreamMode string
|
||||
if s.conf.FastestAddr {
|
||||
upstreamMode = "fastest_addr"
|
||||
@@ -74,46 +135,41 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
upstreamMode = "parallel"
|
||||
}
|
||||
|
||||
return &jsonDNSConfig{
|
||||
Upstreams: &upstreams,
|
||||
UpstreamsFile: &upstreamFile,
|
||||
Bootstraps: &bootstraps,
|
||||
ProtectionEnabled: &protectionEnabled,
|
||||
BlockingMode: &blockingMode,
|
||||
BlockingIPv4: blockingIPv4,
|
||||
BlockingIPv6: blockingIPv6,
|
||||
RateLimit: &ratelimit,
|
||||
EDNSCSEnabled: &enableEDNSClientSubnet,
|
||||
DNSSECEnabled: &enableDNSSEC,
|
||||
DisableIPv6: &aaaaDisabled,
|
||||
CacheSize: &cacheSize,
|
||||
CacheMinTTL: &cacheMinTTL,
|
||||
CacheMaxTTL: &cacheMaxTTL,
|
||||
CacheOptimistic: &cacheOptimistic,
|
||||
UpstreamMode: &upstreamMode,
|
||||
ResolveClients: &resolveClients,
|
||||
UsePrivateRDNS: &usePrivateRDNS,
|
||||
LocalPTRUpstreams: &localPTRUpstreams,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
defLocalPTRUps, err := s.filterOurDNSAddrs(s.sysResolvers.Get())
|
||||
if err != nil {
|
||||
log.Debug("getting dns configuration: %s", err)
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
jsonDNSConfig
|
||||
// DefautLocalPTRUpstreams is used to pass the addresses from
|
||||
// systemResolvers to the front-end. It's not a pointer to the slice
|
||||
// since there is no need to omit it while decoding from JSON.
|
||||
DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
|
||||
}{
|
||||
jsonDNSConfig: *s.getDNSConfig(),
|
||||
DefautLocalPTRUpstreams: defLocalPTRUps,
|
||||
return &jsonDNSConfig{
|
||||
Upstreams: &upstreams,
|
||||
UpstreamsFile: &upstreamFile,
|
||||
Bootstraps: &bootstraps,
|
||||
ProtectionEnabled: &protectionEnabled,
|
||||
BlockingMode: &blockingMode,
|
||||
BlockingIPv4: blockingIPv4,
|
||||
BlockingIPv6: blockingIPv6,
|
||||
RateLimit: &ratelimit,
|
||||
EDNSCSCustomIP: customIP,
|
||||
EDNSCSEnabled: &enableEDNSClientSubnet,
|
||||
EDNSCSUseCustom: &useCustom,
|
||||
DNSSECEnabled: &enableDNSSEC,
|
||||
DisableIPv6: &aaaaDisabled,
|
||||
CacheSize: &cacheSize,
|
||||
CacheMinTTL: &cacheMinTTL,
|
||||
CacheMaxTTL: &cacheMaxTTL,
|
||||
CacheOptimistic: &cacheOptimistic,
|
||||
UpstreamMode: &upstreamMode,
|
||||
ResolveClients: &resolveClients,
|
||||
UsePrivateRDNS: &usePrivateRDNS,
|
||||
LocalPTRUpstreams: &localPTRUpstreams,
|
||||
DefaultLocalPTRUpstreams: defLocalPTRUps,
|
||||
DisabledUntil: protectionDisabledUntil,
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
|
||||
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
resp := s.getDNSConfig()
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
@@ -204,6 +260,7 @@ func (req *jsonDNSConfig) checkCacheTTL() bool {
|
||||
return min <= max
|
||||
}
|
||||
|
||||
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
|
||||
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := &jsonDNSConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
@@ -231,8 +288,8 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// setConfigRestartable sets the server parameters. shouldRestart is true if
|
||||
// the server should be restarted to apply changes.
|
||||
// setConfig sets the server parameters. shouldRestart is true if the server
|
||||
// should be restarted to apply changes.
|
||||
func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
@@ -250,6 +307,10 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
s.conf.FastestAddr = *dc.UpstreamMode == "fastest_addr"
|
||||
}
|
||||
|
||||
if dc.EDNSCSUseCustom != nil && *dc.EDNSCSUseCustom {
|
||||
s.conf.EDNSClientSubnet.CustomIP = dc.EDNSCSCustomIP
|
||||
}
|
||||
|
||||
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
|
||||
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
|
||||
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
|
||||
@@ -281,6 +342,7 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.UpstreamDNSFileName, dc.UpstreamsFile),
|
||||
setIfNotNil(&s.conf.BootstrapDNS, dc.Bootstraps),
|
||||
setIfNotNil(&s.conf.EDNSClientSubnet.Enabled, dc.EDNSCSEnabled),
|
||||
setIfNotNil(&s.conf.EDNSClientSubnet.UseCustom, dc.EDNSCSUseCustom),
|
||||
setIfNotNil(&s.conf.CacheSize, dc.CacheSize),
|
||||
setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL),
|
||||
setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL),
|
||||
@@ -388,15 +450,15 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
|
||||
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.SubnetFromReversedAddr(domain)
|
||||
var subnet netip.Prefix
|
||||
subnet, err = extractARPASubnet(domain)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !privateNets.Contains(subnet.IP) {
|
||||
if !privateNets.Contains(subnet.Addr().AsSlice()) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
@@ -577,6 +639,7 @@ func (err domainSpecificTestError) Error() (msg string) {
|
||||
func checkDNS(
|
||||
upstreamConfigStr string,
|
||||
bootstrap []string,
|
||||
bootstrapPrefIPv6 bool,
|
||||
timeout time.Duration,
|
||||
healthCheck healthCheckFunc,
|
||||
) (err error) {
|
||||
@@ -604,8 +667,9 @@ func checkDNS(
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
PreferIPv6: bootstrapPrefIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||
@@ -637,6 +701,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
result := map[string]string{}
|
||||
bootstraps := req.BootstrapDNS
|
||||
bootstrapPrefIPv6 := s.conf.BootstrapPreferIPv6
|
||||
timeout := s.conf.UpstreamTimeout
|
||||
|
||||
type upsCheckResult = struct {
|
||||
@@ -653,7 +718,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer func() { resCh <- res }()
|
||||
|
||||
checkErr := checkDNS(ups, bootstraps, timeout, healthCheck)
|
||||
checkErr := checkDNS(ups, bootstraps, bootstrapPrefIPv6, timeout, healthCheck)
|
||||
if checkErr != nil {
|
||||
res.res = checkErr.Error()
|
||||
} else {
|
||||
@@ -685,6 +750,52 @@ func (s *Server) handleCacheClear(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "OK")
|
||||
}
|
||||
|
||||
// protectionJSON is an object for /control/protection endpoint.
|
||||
type protectionJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Duration uint `json:"duration"`
|
||||
}
|
||||
|
||||
// handleSetProtection is a handler for the POST /control/protection HTTP API.
|
||||
func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
|
||||
protectionReq := &protectionJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(protectionReq)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var disabledUntil *time.Time
|
||||
if protectionReq.Duration > 0 {
|
||||
if protectionReq.Enabled {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"Setting a duration is only allowed with protection disabling",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
calcTime := time.Now().Add(time.Duration(protectionReq.Duration) * time.Millisecond)
|
||||
disabledUntil = &calcTime
|
||||
}
|
||||
|
||||
func() {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.conf.ProtectionEnabled = protectionReq.Enabled
|
||||
s.conf.ProtectionDisabledUntil = disabledUntil
|
||||
}()
|
||||
|
||||
s.conf.ConfigModified()
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
// handleDoH is the DNS-over-HTTPs handler.
|
||||
//
|
||||
// Control flow:
|
||||
@@ -719,6 +830,7 @@ func (s *Server) registerHandlers() {
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dns_info", s.handleGetConfig)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dns_config", s.handleSetConfig)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/protection", s.handleSetProtection)
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
@@ -57,7 +58,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchEnabled: true,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
SafeSearchCacheSize: 1000,
|
||||
ParentalCacheSize: 1000,
|
||||
CacheTime: 30,
|
||||
@@ -122,7 +123,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
s.conf = tc.conf()
|
||||
s.handleGetConfig(w, nil)
|
||||
|
||||
cType := w.Header().Get(aghhttp.HdrNameContentType)
|
||||
cType := w.Header().Get(httphdr.ContentType)
|
||||
assert.Equal(t, aghhttp.HdrValApplicationJSON, cType)
|
||||
assert.JSONEq(t, string(caseWant), w.Body.String())
|
||||
})
|
||||
@@ -133,7 +134,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchEnabled: true,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
SafeSearchCacheSize: 1000,
|
||||
ParentalCacheSize: 1000,
|
||||
CacheTime: 30,
|
||||
@@ -181,6 +182,12 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "edns_cs_enabled",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "edns_cs_use_custom",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "edns_cs_use_custom_bad_ip",
|
||||
wantSet: "decoding request: ParseAddr(\"bad.ip\"): unexpected character (at \"bad.ip\")",
|
||||
}, {
|
||||
name: "dnssec_enabled",
|
||||
wantSet: "",
|
||||
@@ -212,7 +219,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "local_ptr_upstreams_bad",
|
||||
wantSet: `validating private upstream servers: checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "non.arpa": not a reversed ip network`,
|
||||
`bad arpa domain name "non.arpa.": not a reversed ip network`,
|
||||
}, {
|
||||
name: "local_ptr_upstreams_null",
|
||||
wantSet: "",
|
||||
@@ -222,16 +229,20 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
Req json.RawMessage `json:"req"`
|
||||
Want json.RawMessage `json:"want"`
|
||||
}
|
||||
loadTestData(t, t.Name()+jsonExt, &data)
|
||||
|
||||
testData := t.Name() + jsonExt
|
||||
loadTestData(t, testData, &data)
|
||||
|
||||
for _, tc := range testCases {
|
||||
// NOTE: Do not use require.Contains, because the size of the data
|
||||
// prevents it from printing a meaningful error message.
|
||||
caseData, ok := data[tc.name]
|
||||
require.True(t, ok)
|
||||
require.Truef(t, ok, "%q does not contain test data for test case %s", testData, tc.name)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
s.conf = defaultConf
|
||||
s.conf.FilteringConfig.EDNSClientSubnet.Enabled = false
|
||||
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{}
|
||||
})
|
||||
|
||||
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
|
||||
@@ -373,7 +384,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
}, {
|
||||
name: "not_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "hello.world": not a reversed ip network`,
|
||||
`bad arpa domain name "hello.world.": not a reversed ip network`,
|
||||
u: "[/hello.world/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_address",
|
||||
@@ -389,8 +400,12 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
name: "several_bad",
|
||||
wantErr: `checking domain-specific upstreams: 2 errors: ` +
|
||||
`"arpa domain \"1.2.3.4.in-addr.arpa.\" should point to a locally-served network", ` +
|
||||
`"bad arpa domain name \"non.arpa\": not a reversed ip network"`,
|
||||
`"bad arpa domain name \"non.arpa.\": not a reversed ip network"`,
|
||||
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "partial_good",
|
||||
wantErr: "",
|
||||
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -40,12 +40,17 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
log.Debug("client ip: %s", ip)
|
||||
|
||||
ipStr := ip.String()
|
||||
ids := []string{ipStr, dctx.clientID}
|
||||
|
||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
|
||||
// uninitialized while in use. This can happen after proxy server has been
|
||||
// stopped, but its workers haven't yet exited.
|
||||
if shouldLog &&
|
||||
s.queryLog != nil &&
|
||||
s.queryLog.ShouldLog(host, q.Qtype, q.Qclass) {
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
|
||||
// containing persistent client.
|
||||
s.queryLog.ShouldLog(host, q.Qtype, q.Qclass, ids) {
|
||||
s.logQuery(dctx, pctx, elapsed, ip)
|
||||
} else {
|
||||
log.Debug(
|
||||
@@ -56,8 +61,11 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
)
|
||||
}
|
||||
|
||||
if s.stats != nil && s.stats.ShouldCount(host, q.Qtype, q.Qclass) {
|
||||
s.updateStats(dctx, elapsed, *dctx.result, ip)
|
||||
if s.stats != nil &&
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
|
||||
// containing persistent client.
|
||||
s.stats.ShouldCount(host, q.Qtype, q.Qclass, ids) {
|
||||
s.updateStats(dctx, elapsed, *dctx.result, ipStr)
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
@@ -110,7 +118,7 @@ func (s *Server) updateStats(
|
||||
ctx *dnsContext,
|
||||
elapsed time.Duration,
|
||||
res filtering.Result,
|
||||
clientIP net.IP,
|
||||
clientIP string,
|
||||
) {
|
||||
pctx := ctx.proxyCtx
|
||||
e := stats.Entry{}
|
||||
@@ -119,8 +127,8 @@ func (s *Server) updateStats(
|
||||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
e.Client = clientID
|
||||
} else if clientIP != nil {
|
||||
e.Client = clientIP.String()
|
||||
} else {
|
||||
e.Client = clientIP
|
||||
}
|
||||
|
||||
e.Time = uint32(elapsed / 1000)
|
||||
|
||||
@@ -31,7 +31,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
|
||||
}
|
||||
|
||||
// ShouldLog implements the [querylog.QueryLog] interface for *testQueryLog.
|
||||
func (l *testQueryLog) ShouldLog(string, uint16, uint16) bool {
|
||||
func (l *testQueryLog) ShouldLog(string, uint16, uint16, []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (l *testStats) Update(e stats.Entry) {
|
||||
}
|
||||
|
||||
// ShouldCount implements the [stats.Interface] interface for *testStats.
|
||||
func (l *testStats) ShouldCount(string, uint16, uint16) bool {
|
||||
func (l *testStats) ShouldCount(string, uint16, uint16, []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -26,7 +27,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
},
|
||||
"fastest_addr": {
|
||||
"upstream_dns": [
|
||||
@@ -41,6 +44,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -55,7 +59,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
},
|
||||
"parallel": {
|
||||
"upstream_dns": [
|
||||
@@ -70,6 +76,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -84,6 +91,8 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -33,7 +34,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"bootstraps": {
|
||||
@@ -52,6 +55,7 @@
|
||||
"9.9.9.10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -66,7 +70,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"blocking_mode_good": {
|
||||
@@ -86,6 +92,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "refused",
|
||||
"blocking_ipv4": "",
|
||||
@@ -100,7 +107,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"blocking_mode_bad": {
|
||||
@@ -120,6 +129,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -134,7 +144,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"ratelimit": {
|
||||
@@ -154,6 +166,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 6,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -168,7 +181,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"edns_cs_enabled": {
|
||||
@@ -188,6 +203,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -202,7 +218,87 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"edns_cs_use_custom": {
|
||||
"req": {
|
||||
"edns_cs_enabled": true,
|
||||
"edns_cs_use_custom": true,
|
||||
"edns_cs_custom_ip": "1.2.3.4"
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"edns_cs_enabled": true,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": true,
|
||||
"edns_cs_custom_ip": "1.2.3.4"
|
||||
}
|
||||
},
|
||||
"edns_cs_use_custom_bad_ip": {
|
||||
"req": {
|
||||
"edns_cs_enabled": true,
|
||||
"edns_cs_use_custom": true,
|
||||
"edns_cs_custom_ip": "bad.ip"
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"dnssec_enabled": {
|
||||
@@ -222,6 +318,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -236,7 +333,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"cache_size": {
|
||||
@@ -256,6 +355,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -270,7 +370,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"upstream_mode_parallel": {
|
||||
@@ -290,6 +392,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -304,7 +407,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"upstream_mode_fastest_addr": {
|
||||
@@ -324,6 +429,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -338,7 +444,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"upstream_dns_bad": {
|
||||
@@ -360,6 +468,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -374,7 +483,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"bootstraps_bad": {
|
||||
@@ -396,6 +507,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -410,7 +522,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"cache_bad_ttl": {
|
||||
@@ -431,6 +545,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -445,7 +560,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"upstream_mode_bad": {
|
||||
@@ -465,6 +582,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -479,7 +597,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_good": {
|
||||
@@ -501,6 +621,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -517,7 +638,9 @@
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [
|
||||
"123.123.123.123"
|
||||
]
|
||||
],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_bad": {
|
||||
@@ -540,6 +663,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -554,7 +678,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_null": {
|
||||
@@ -574,6 +700,7 @@
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
@@ -588,7 +715,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,13 +176,16 @@ func (d *DNSFilter) filterExistsLocked(url string) (ok bool) {
|
||||
|
||||
// Add a filter
|
||||
// Return FALSE if a filter with this URL exists
|
||||
func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
|
||||
func (d *DNSFilter) filterAdd(flt FilterYAML) (err error) {
|
||||
// Defer annotating to unlock sooner.
|
||||
defer func() { err = errors.Annotate(err, "adding filter: %w") }()
|
||||
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
// Check for duplicates
|
||||
// Check for duplicates.
|
||||
if d.filterExistsLocked(flt.URL) {
|
||||
return false
|
||||
return errFilterExists
|
||||
}
|
||||
|
||||
if flt.white {
|
||||
@@ -190,7 +193,8 @@ func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
|
||||
} else {
|
||||
d.Filters = append(d.Filters, flt)
|
||||
}
|
||||
return true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load filters from the disk
|
||||
@@ -238,6 +242,7 @@ func updateUniqueFilterID(filters []FilterYAML) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Improve this inexhaustible source of races.
|
||||
func assignUniqueFilterID() int64 {
|
||||
value := nextFilterID
|
||||
nextFilterID++
|
||||
@@ -343,29 +348,31 @@ func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int,
|
||||
}
|
||||
|
||||
updateCount := 0
|
||||
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated := updateFlags[i]
|
||||
|
||||
d.filtersMu.Lock()
|
||||
for k := range *filters {
|
||||
f := &(*filters)[k]
|
||||
if f.ID != uf.ID || f.URL != uf.URL {
|
||||
continue
|
||||
}
|
||||
|
||||
f.LastUpdated = uf.LastUpdated
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("Updated filter #%d. Rules: %d -> %d",
|
||||
f.ID, f.RulesCount, uf.RulesCount)
|
||||
log.Info("Updated filter #%d. Rules: %d -> %d", f.ID, f.RulesCount, uf.RulesCount)
|
||||
f.Name = uf.Name
|
||||
f.RulesCount = uf.RulesCount
|
||||
f.checksum = uf.checksum
|
||||
updateCount++
|
||||
}
|
||||
d.filtersMu.Unlock()
|
||||
}
|
||||
|
||||
return updateCount, updateFilters, updateFlags, false
|
||||
@@ -421,11 +428,16 @@ func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
_ = os.Remove(uf.Path(d.DataDir) + ".old")
|
||||
|
||||
p := uf.Path(d.DataDir)
|
||||
err := os.Remove(p + ".old")
|
||||
if err != nil {
|
||||
log.Debug("filtering: removing old filter file %q: %s", p, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("filtering: update finished")
|
||||
log.Debug("filtering: update finished: %d lists updated", updNum)
|
||||
|
||||
return updNum, false
|
||||
}
|
||||
@@ -467,8 +479,8 @@ func scanLinesWithBreak(data []byte, atEOF bool) (advance int, token []byte, err
|
||||
}
|
||||
|
||||
// parseFilter copies filter's content from src to dst and returns the number of
|
||||
// rules, name, number of bytes written, checksum, and title of the parsed list.
|
||||
// dst must not be nil.
|
||||
// rules, number of bytes written, checksum, and title of the parsed list. dst
|
||||
// must not be nil.
|
||||
func (d *DNSFilter) parseFilter(
|
||||
src io.Reader,
|
||||
dst io.Writer,
|
||||
@@ -550,14 +562,18 @@ func isHTML(line string) (ok bool) {
|
||||
return strings.HasPrefix(line, "<html") || strings.HasPrefix(line, "<!doctype")
|
||||
}
|
||||
|
||||
// Perform upgrade on a filter and update LastUpdated value
|
||||
func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
|
||||
b, err := d.updateIntl(filter)
|
||||
// update refreshes filter's content and a/mtimes of it's file.
|
||||
func (d *DNSFilter) update(filter *FilterYAML) (b bool, err error) {
|
||||
b, err = d.updateIntl(filter)
|
||||
filter.LastUpdated = time.Now()
|
||||
if !b {
|
||||
e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated)
|
||||
if e != nil {
|
||||
log.Error("os.Chtimes(): %v", e)
|
||||
chErr := os.Chtimes(
|
||||
filter.Path(d.DataDir),
|
||||
filter.LastUpdated,
|
||||
filter.LastUpdated,
|
||||
)
|
||||
if chErr != nil {
|
||||
log.Error("os.Chtimes(): %v", chErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -591,11 +607,13 @@ func (d *DNSFilter) finalizeUpdate(
|
||||
return os.Remove(tmpFileName)
|
||||
}
|
||||
|
||||
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
|
||||
fltPath := flt.Path(d.DataDir)
|
||||
|
||||
log.Printf("saving contents of filter #%d into %s", flt.ID, fltPath)
|
||||
|
||||
// Don't use renamio or maybe packages, since those will require loading the
|
||||
// whole filter content to the memory on Windows.
|
||||
err = os.Rename(tmpFileName, flt.Path(d.DataDir))
|
||||
err = os.Rename(tmpFileName, fltPath)
|
||||
if err != nil {
|
||||
return errors.WithDeferred(err, os.Remove(tmpFileName))
|
||||
}
|
||||
@@ -620,10 +638,14 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
||||
if ok && err == nil {
|
||||
finErr := d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs)
|
||||
if ok && finErr == nil {
|
||||
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = errors.WithDeferred(err, finErr)
|
||||
}()
|
||||
|
||||
// Change the default 0o600 permission to something more acceptable by end
|
||||
@@ -634,7 +656,7 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
return false, fmt.Errorf("changing file mode: %w", err)
|
||||
}
|
||||
|
||||
var rc io.ReadCloser
|
||||
var r io.Reader
|
||||
if !filepath.IsAbs(flt.URL) {
|
||||
var resp *http.Response
|
||||
resp, err = d.HTTPClient.Get(flt.URL)
|
||||
@@ -651,16 +673,19 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
return false, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
rc = resp.Body
|
||||
r = resp.Body
|
||||
} else {
|
||||
rc, err = os.Open(flt.URL)
|
||||
var f *os.File
|
||||
f, err = os.Open(flt.URL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, rc.Close()) }()
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
r = f
|
||||
}
|
||||
|
||||
rnum, n, cs, name, err = d.parseFilter(rc, tmpFile)
|
||||
rnum, n, cs, name, err = d.parseFilter(r, tmpFile)
|
||||
|
||||
return cs != flt.checksum && err == nil, err
|
||||
}
|
||||
@@ -705,10 +730,11 @@ func (d *DNSFilter) EnableFilters(async bool) {
|
||||
}
|
||||
|
||||
func (d *DNSFilter) enableFiltersLocked(async bool) {
|
||||
filters := []Filter{{
|
||||
filters := make([]Filter, 1, len(d.Filters)+len(d.WhitelistFilters)+1)
|
||||
filters[0] = Filter{
|
||||
ID: CustomListID,
|
||||
Data: []byte(strings.Join(d.UserRules, "\n")),
|
||||
}}
|
||||
}
|
||||
|
||||
for _, filter := range d.Filters {
|
||||
if !filter.Enabled {
|
||||
|
||||
@@ -63,6 +63,9 @@ type Settings struct {
|
||||
SafeSearchEnabled bool
|
||||
SafeBrowsingEnabled bool
|
||||
ParentalEnabled bool
|
||||
|
||||
// ClientSafeSearch is a client configured safe search.
|
||||
ClientSafeSearch SafeSearch
|
||||
}
|
||||
|
||||
// Resolver is the interface for net.Resolver to simplify testing.
|
||||
@@ -83,13 +86,16 @@ type Config struct {
|
||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
|
||||
SafeBrowsingCacheSize uint `yaml:"safebrowsing_cache_size"` // (in bytes)
|
||||
SafeSearchCacheSize uint `yaml:"safesearch_cache_size"` // (in bytes)
|
||||
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
|
||||
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
||||
// TODO(a.garipov): Use timeutil.Duration
|
||||
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
||||
|
||||
SafeSearchConf SafeSearchConfig `yaml:"safe_search"`
|
||||
SafeSearch SafeSearch `yaml:"-"`
|
||||
|
||||
Rewrites []*LegacyRewrite `yaml:"rewrites"`
|
||||
|
||||
@@ -107,9 +113,6 @@ type Config struct {
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
// CustomResolver is the resolver used by DNSFilter.
|
||||
CustomResolver Resolver `yaml:"-"`
|
||||
|
||||
// HTTPClient is the client to use for updating the remote filters.
|
||||
HTTPClient *http.Client `yaml:"-"`
|
||||
|
||||
@@ -172,7 +175,6 @@ type DNSFilter struct {
|
||||
|
||||
safebrowsingCache cache.Cache
|
||||
parentalCache cache.Cache
|
||||
safeSearchCache cache.Cache
|
||||
|
||||
Config // for direct access by library users, even a = assignment
|
||||
// confLock protects Config.
|
||||
@@ -182,11 +184,6 @@ type DNSFilter struct {
|
||||
filtersInitializerChan chan filtersInitializerParams
|
||||
filtersInitializerLock sync.Mutex
|
||||
|
||||
// resolver only looks up the IP address of the host while safe search.
|
||||
//
|
||||
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
|
||||
resolver Resolver
|
||||
|
||||
refreshLock *sync.Mutex
|
||||
|
||||
// filterTitleRegexp is the regular expression to retrieve a name of a
|
||||
@@ -195,6 +192,7 @@ type DNSFilter struct {
|
||||
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
|
||||
safeSearch SafeSearch
|
||||
hostCheckers []hostChecker
|
||||
}
|
||||
|
||||
@@ -298,7 +296,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
|
||||
|
||||
return Settings{
|
||||
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
|
||||
SafeSearchEnabled: d.Config.SafeSearchEnabled,
|
||||
SafeSearchEnabled: d.Config.SafeSearchConf.Enabled,
|
||||
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
|
||||
ParentalEnabled: d.Config.ParentalEnabled,
|
||||
}
|
||||
@@ -942,7 +940,6 @@ func InitModule() {
|
||||
// be non-nil.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
resolver: net.DefaultResolver,
|
||||
refreshLock: &sync.Mutex{},
|
||||
filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`),
|
||||
}
|
||||
@@ -951,18 +948,12 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
EnableLRU: true,
|
||||
MaxSize: c.SafeBrowsingCacheSize,
|
||||
})
|
||||
d.safeSearchCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.SafeSearchCacheSize,
|
||||
})
|
||||
d.parentalCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.ParentalCacheSize,
|
||||
})
|
||||
|
||||
if r := c.CustomResolver; r != nil {
|
||||
d.resolver = r
|
||||
}
|
||||
d.safeSearch = c.SafeSearch
|
||||
|
||||
d.hostCheckers = []hostChecker{{
|
||||
check: d.matchSysHosts,
|
||||
|
||||
@@ -2,10 +2,8 @@ package filtering
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
@@ -33,7 +31,6 @@ func purgeCaches(d *DNSFilter) {
|
||||
for _, c := range []cache.Cache{
|
||||
d.safebrowsingCache,
|
||||
d.parentalCache,
|
||||
d.safeSearchCache,
|
||||
} {
|
||||
if c != nil {
|
||||
c.Clear()
|
||||
@@ -51,7 +48,7 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
||||
c.ParentalCacheSize = 10000
|
||||
c.SafeSearchCacheSize = 1000
|
||||
c.CacheTime = 30
|
||||
setts.SafeSearchEnabled = c.SafeSearchEnabled
|
||||
setts.SafeSearchEnabled = c.SafeSearchConf.Enabled
|
||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
} else {
|
||||
@@ -216,164 +213,6 @@ func TestParallelSB(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// Safe Search.
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
d, _ := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "forcesafesearch.google.com", val)
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
d, setts := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
|
||||
// Check host for each domain.
|
||||
for _, host := range []string{
|
||||
"yAndeX.ru",
|
||||
"YANdex.COM",
|
||||
"yandex.ua",
|
||||
"yandex.by",
|
||||
"yandex.kz",
|
||||
"www.yandex.com",
|
||||
} {
|
||||
t.Run(strings.ToLower(host), func(t *testing.T) {
|
||||
res, err := d.CheckHost(host, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
||||
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d, setts := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
// Check host for each domain.
|
||||
for _, host := range []string{
|
||||
"www.google.com",
|
||||
"www.google.im",
|
||||
"www.google.co.in",
|
||||
"www.google.iq",
|
||||
"www.google.is",
|
||||
"www.google.it",
|
||||
"www.google.je",
|
||||
} {
|
||||
t.Run(host, func(t *testing.T) {
|
||||
res, err := d.CheckHost(host, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, ip, res.Rules[0].IP)
|
||||
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
d, setts := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const domain = "yandex.ru"
|
||||
|
||||
// Check host with disabled safesearch.
|
||||
res, err := d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
|
||||
require.Empty(t, res.Rules)
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
|
||||
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// For yandex we already know valid IP.
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d, setts := newForTest(t, &Config{
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
const domain = "www.google.ru"
|
||||
res, err := d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
|
||||
require.Empty(t, res.Rules)
|
||||
|
||||
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
d.resolver = resolver
|
||||
|
||||
// Lookup for safesearch domain.
|
||||
safeDomain, ok := d.SafeSearchDomain(domain)
|
||||
require.True(t, ok)
|
||||
|
||||
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
var ip net.IP
|
||||
for _, foundIP := range ips {
|
||||
if foundIP.To4() != nil {
|
||||
ip = foundIP
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.True(t, res.Rules[0].IP.Equal(ip))
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
|
||||
}
|
||||
|
||||
// Parental.
|
||||
|
||||
func TestParentalControl(t *testing.T) {
|
||||
@@ -854,27 +693,3 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
for n := 0; n < b.N; n++ {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
require.True(b, ok)
|
||||
|
||||
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
require.True(b, ok)
|
||||
|
||||
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,26 +14,33 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// validateFilterURL validates the filter list URL or file name.
|
||||
func validateFilterURL(urlStr string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "checking filter: %w") }()
|
||||
|
||||
if filepath.IsAbs(urlStr) {
|
||||
_, err = os.Stat(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking filter file: %w", err)
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
url, err := url.ParseRequestURI(urlStr)
|
||||
u, err := url.ParseRequestURI(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking filter url: %w", err)
|
||||
}
|
||||
|
||||
if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
|
||||
return fmt.Errorf("checking filter url: invalid scheme %q", s)
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
} else if s := u.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
|
||||
return &url.Error{
|
||||
Op: "Check scheme",
|
||||
URL: urlStr,
|
||||
Err: fmt.Errorf("only %v allowed", []string{aghhttp.SchemeHTTP, aghhttp.SchemeHTTPS}),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -63,7 +70,8 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// Check for duplicates
|
||||
if d.filterExists(fj.URL) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
||||
err = errFilterExists
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter with URL %q: %s", fj.URL, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -99,7 +107,7 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"Filter at the url %s is invalid (maybe it points to blank page?)",
|
||||
"Filter with URL %q is invalid (maybe it points to blank page?)",
|
||||
filt.URL,
|
||||
)
|
||||
|
||||
@@ -108,8 +116,9 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// URL is assumed valid so append it to filters, update config, write new
|
||||
// file and reload it to engines.
|
||||
if !d.filterAdd(filt) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
||||
err = d.filterAdd(filt)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter with URL %q: %s", filt.URL, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -137,31 +146,38 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
d.filtersMu.Lock()
|
||||
filters := &d.Filters
|
||||
if req.Whitelist {
|
||||
filters = &d.WhitelistFilters
|
||||
}
|
||||
|
||||
var deleted FilterYAML
|
||||
var newFilters []FilterYAML
|
||||
for _, flt := range *filters {
|
||||
if flt.URL != req.URL {
|
||||
newFilters = append(newFilters, flt)
|
||||
func() {
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
continue
|
||||
filters := &d.Filters
|
||||
if req.Whitelist {
|
||||
filters = &d.WhitelistFilters
|
||||
}
|
||||
|
||||
deleted = flt
|
||||
path := flt.Path(d.DataDir)
|
||||
err = os.Rename(path, path+".old")
|
||||
delIdx := slices.IndexFunc(*filters, func(flt FilterYAML) bool {
|
||||
return flt.URL == req.URL
|
||||
})
|
||||
if delIdx == -1 {
|
||||
log.Error("deleting filter with url %q: %s", req.URL, errFilterNotExist)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
deleted = (*filters)[delIdx]
|
||||
p := deleted.Path(d.DataDir)
|
||||
err = os.Rename(p, p+".old")
|
||||
if err != nil {
|
||||
log.Error("deleting filter %q: %s", path, err)
|
||||
}
|
||||
}
|
||||
log.Error("deleting filter %d: renaming file %q: %s", deleted.ID, p, err)
|
||||
|
||||
*filters = newFilters
|
||||
d.filtersMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
*filters = slices.Delete(*filters, delIdx, delIdx+1)
|
||||
|
||||
log.Info("deleted filter %d", deleted.ID)
|
||||
}()
|
||||
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
@@ -258,10 +274,6 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
type Req struct {
|
||||
White bool `json:"whitelist"`
|
||||
}
|
||||
type Resp struct {
|
||||
Updated int `json:"updated"`
|
||||
}
|
||||
resp := Resp{}
|
||||
var err error
|
||||
|
||||
req := Req{}
|
||||
@@ -273,6 +285,9 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
var ok bool
|
||||
resp := struct {
|
||||
Updated int `json:"updated"`
|
||||
}{}
|
||||
resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
|
||||
if !ok {
|
||||
aghhttp.Error(
|
||||
@@ -461,6 +476,7 @@ func (d *DNSFilter) RegisterFilteringHandlers() {
|
||||
registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||
registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||
registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||
registerHTTP(http.MethodPut, "/control/safesearch/settings", d.handleSafeSearchSettings)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||
|
||||
@@ -1,34 +1,23 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
)
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// SafeSearch interface describes a service for search engines hosts rewrites.
|
||||
type SafeSearch interface {
|
||||
// SearchHost returns a replacement address for the search engine host.
|
||||
SearchHost(host string, qtype uint16) (res *rules.DNSRewrite)
|
||||
|
||||
// CheckHost checks host with safe search engine.
|
||||
// CheckHost checks host with safe search filter. CheckHost must be safe
|
||||
// for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].
|
||||
CheckHost(host string, qtype uint16) (res Result, err error)
|
||||
|
||||
// Update updates the configuration of the safe search filter. Update must
|
||||
// be safe for concurrent use. An implementation of Update may ignore some
|
||||
// fields, but it must document which.
|
||||
Update(conf SafeSearchConfig) (err error)
|
||||
}
|
||||
|
||||
// SafeSearchConfig is a struct with safe search related settings.
|
||||
type SafeSearchConfig struct {
|
||||
// CustomResolver is the resolver used by safe search.
|
||||
CustomResolver Resolver `yaml:"-"`
|
||||
CustomResolver Resolver `yaml:"-" json:"-"`
|
||||
|
||||
// Enabled indicates if safe search is enabled entirely.
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
@@ -44,358 +33,27 @@ type SafeSearchConfig struct {
|
||||
YouTube bool `yaml:"youtube" json:"youtube"`
|
||||
}
|
||||
|
||||
/*
|
||||
expire byte[4]
|
||||
res Result
|
||||
*/
|
||||
func (d *DNSFilter) setCacheResult(cache cache.Cache, host string, res Result) int {
|
||||
var buf bytes.Buffer
|
||||
|
||||
expire := uint(time.Now().Unix()) + d.Config.CacheTime*60
|
||||
exp := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(exp, uint32(expire))
|
||||
_, _ = buf.Write(exp)
|
||||
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err := enc.Encode(res)
|
||||
if err != nil {
|
||||
log.Error("gob.Encode(): %s", err)
|
||||
return 0
|
||||
}
|
||||
val := buf.Bytes()
|
||||
_ = cache.Set([]byte(host), val)
|
||||
return len(val)
|
||||
}
|
||||
|
||||
func getCachedResult(cache cache.Cache, host string) (Result, bool) {
|
||||
data := cache.Get([]byte(host))
|
||||
if data == nil {
|
||||
return Result{}, false
|
||||
}
|
||||
|
||||
exp := int(binary.BigEndian.Uint32(data[:4]))
|
||||
if exp <= int(time.Now().Unix()) {
|
||||
cache.Del([]byte(host))
|
||||
return Result{}, false
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write(data[4:])
|
||||
dec := gob.NewDecoder(&buf)
|
||||
r := Result{}
|
||||
err := dec.Decode(&r)
|
||||
if err != nil {
|
||||
log.Debug("gob.Decode(): %s", err)
|
||||
return Result{}, false
|
||||
}
|
||||
|
||||
return r, true
|
||||
}
|
||||
|
||||
// SafeSearchDomain returns replacement address for search engine
|
||||
func (d *DNSFilter) SafeSearchDomain(host string) (string, bool) {
|
||||
val, ok := safeSearchDomains[host]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// checkSafeSearch checks host with safe search engine. Matches
|
||||
// [hostChecker.check].
|
||||
func (d *DNSFilter) checkSafeSearch(
|
||||
host string,
|
||||
_ uint16,
|
||||
qtype uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
|
||||
if !setts.ProtectionEnabled ||
|
||||
!setts.SafeSearchEnabled ||
|
||||
(qtype != dns.TypeA && qtype != dns.TypeAAAA) {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
if log.GetLevel() >= log.DEBUG {
|
||||
timer := log.StartTimer()
|
||||
defer timer.LogElapsed("SafeSearch: lookup for %s", host)
|
||||
}
|
||||
|
||||
// Check cache. Return cached result if it was found
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, host)
|
||||
if isFound {
|
||||
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
|
||||
log.Tracef("SafeSearch: found in cache: %s", host)
|
||||
return cachedValue, nil
|
||||
}
|
||||
|
||||
safeHost, ok := d.SafeSearchDomain(host)
|
||||
if !ok {
|
||||
if d.safeSearch == nil {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
res = Result{
|
||||
Rules: []*ResultRule{{
|
||||
FilterListID: SafeSearchListID,
|
||||
}},
|
||||
Reason: FilteredSafeSearch,
|
||||
IsFiltered: true,
|
||||
clientSafeSearch := setts.ClientSafeSearch
|
||||
if clientSafeSearch != nil {
|
||||
return clientSafeSearch.CheckHost(host, qtype)
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(safeHost); ip != nil {
|
||||
res.Rules[0].IP = ip
|
||||
valLen := d.setCacheResult(d.safeSearchCache, host, res)
|
||||
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
ips, err := d.resolver.LookupIP(context.Background(), "ip", safeHost)
|
||||
if err != nil {
|
||||
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip = ip.To4(); ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
res.Rules[0].IP = ip
|
||||
|
||||
l := d.setCacheResult(d.safeSearchCache, host, res)
|
||||
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, l)
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(&d.confLock, &d.Config.SafeSearchEnabled, true)
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(&d.confLock, &d.Config.SafeSearchEnabled, false)
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp := &struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}{
|
||||
Enabled: protectedBool(&d.confLock, &d.Config.SafeSearchEnabled),
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
var safeSearchDomains = map[string]string{
|
||||
"yandex.com": "213.180.193.56",
|
||||
"yandex.ru": "213.180.193.56",
|
||||
"yandex.ua": "213.180.193.56",
|
||||
"yandex.by": "213.180.193.56",
|
||||
"yandex.kz": "213.180.193.56",
|
||||
"www.yandex.com": "213.180.193.56",
|
||||
"www.yandex.ru": "213.180.193.56",
|
||||
"www.yandex.ua": "213.180.193.56",
|
||||
"www.yandex.by": "213.180.193.56",
|
||||
"www.yandex.kz": "213.180.193.56",
|
||||
|
||||
"www.bing.com": "strict.bing.com",
|
||||
|
||||
"duckduckgo.com": "safe.duckduckgo.com",
|
||||
"www.duckduckgo.com": "safe.duckduckgo.com",
|
||||
"start.duckduckgo.com": "safe.duckduckgo.com",
|
||||
|
||||
"www.google.com": "forcesafesearch.google.com",
|
||||
"www.google.ad": "forcesafesearch.google.com",
|
||||
"www.google.ae": "forcesafesearch.google.com",
|
||||
"www.google.com.af": "forcesafesearch.google.com",
|
||||
"www.google.com.ag": "forcesafesearch.google.com",
|
||||
"www.google.com.ai": "forcesafesearch.google.com",
|
||||
"www.google.al": "forcesafesearch.google.com",
|
||||
"www.google.am": "forcesafesearch.google.com",
|
||||
"www.google.co.ao": "forcesafesearch.google.com",
|
||||
"www.google.com.ar": "forcesafesearch.google.com",
|
||||
"www.google.as": "forcesafesearch.google.com",
|
||||
"www.google.at": "forcesafesearch.google.com",
|
||||
"www.google.com.au": "forcesafesearch.google.com",
|
||||
"www.google.az": "forcesafesearch.google.com",
|
||||
"www.google.ba": "forcesafesearch.google.com",
|
||||
"www.google.com.bd": "forcesafesearch.google.com",
|
||||
"www.google.be": "forcesafesearch.google.com",
|
||||
"www.google.bf": "forcesafesearch.google.com",
|
||||
"www.google.bg": "forcesafesearch.google.com",
|
||||
"www.google.com.bh": "forcesafesearch.google.com",
|
||||
"www.google.bi": "forcesafesearch.google.com",
|
||||
"www.google.bj": "forcesafesearch.google.com",
|
||||
"www.google.com.bn": "forcesafesearch.google.com",
|
||||
"www.google.com.bo": "forcesafesearch.google.com",
|
||||
"www.google.com.br": "forcesafesearch.google.com",
|
||||
"www.google.bs": "forcesafesearch.google.com",
|
||||
"www.google.bt": "forcesafesearch.google.com",
|
||||
"www.google.co.bw": "forcesafesearch.google.com",
|
||||
"www.google.by": "forcesafesearch.google.com",
|
||||
"www.google.com.bz": "forcesafesearch.google.com",
|
||||
"www.google.ca": "forcesafesearch.google.com",
|
||||
"www.google.cd": "forcesafesearch.google.com",
|
||||
"www.google.cf": "forcesafesearch.google.com",
|
||||
"www.google.cg": "forcesafesearch.google.com",
|
||||
"www.google.ch": "forcesafesearch.google.com",
|
||||
"www.google.ci": "forcesafesearch.google.com",
|
||||
"www.google.co.ck": "forcesafesearch.google.com",
|
||||
"www.google.cl": "forcesafesearch.google.com",
|
||||
"www.google.cm": "forcesafesearch.google.com",
|
||||
"www.google.cn": "forcesafesearch.google.com",
|
||||
"www.google.com.co": "forcesafesearch.google.com",
|
||||
"www.google.co.cr": "forcesafesearch.google.com",
|
||||
"www.google.com.cu": "forcesafesearch.google.com",
|
||||
"www.google.cv": "forcesafesearch.google.com",
|
||||
"www.google.com.cy": "forcesafesearch.google.com",
|
||||
"www.google.cz": "forcesafesearch.google.com",
|
||||
"www.google.de": "forcesafesearch.google.com",
|
||||
"www.google.dj": "forcesafesearch.google.com",
|
||||
"www.google.dk": "forcesafesearch.google.com",
|
||||
"www.google.dm": "forcesafesearch.google.com",
|
||||
"www.google.com.do": "forcesafesearch.google.com",
|
||||
"www.google.dz": "forcesafesearch.google.com",
|
||||
"www.google.com.ec": "forcesafesearch.google.com",
|
||||
"www.google.ee": "forcesafesearch.google.com",
|
||||
"www.google.com.eg": "forcesafesearch.google.com",
|
||||
"www.google.es": "forcesafesearch.google.com",
|
||||
"www.google.com.et": "forcesafesearch.google.com",
|
||||
"www.google.fi": "forcesafesearch.google.com",
|
||||
"www.google.com.fj": "forcesafesearch.google.com",
|
||||
"www.google.fm": "forcesafesearch.google.com",
|
||||
"www.google.fr": "forcesafesearch.google.com",
|
||||
"www.google.ga": "forcesafesearch.google.com",
|
||||
"www.google.ge": "forcesafesearch.google.com",
|
||||
"www.google.gg": "forcesafesearch.google.com",
|
||||
"www.google.com.gh": "forcesafesearch.google.com",
|
||||
"www.google.com.gi": "forcesafesearch.google.com",
|
||||
"www.google.gl": "forcesafesearch.google.com",
|
||||
"www.google.gm": "forcesafesearch.google.com",
|
||||
"www.google.gp": "forcesafesearch.google.com",
|
||||
"www.google.gr": "forcesafesearch.google.com",
|
||||
"www.google.com.gt": "forcesafesearch.google.com",
|
||||
"www.google.gy": "forcesafesearch.google.com",
|
||||
"www.google.com.hk": "forcesafesearch.google.com",
|
||||
"www.google.hn": "forcesafesearch.google.com",
|
||||
"www.google.hr": "forcesafesearch.google.com",
|
||||
"www.google.ht": "forcesafesearch.google.com",
|
||||
"www.google.hu": "forcesafesearch.google.com",
|
||||
"www.google.co.id": "forcesafesearch.google.com",
|
||||
"www.google.ie": "forcesafesearch.google.com",
|
||||
"www.google.co.il": "forcesafesearch.google.com",
|
||||
"www.google.im": "forcesafesearch.google.com",
|
||||
"www.google.co.in": "forcesafesearch.google.com",
|
||||
"www.google.iq": "forcesafesearch.google.com",
|
||||
"www.google.is": "forcesafesearch.google.com",
|
||||
"www.google.it": "forcesafesearch.google.com",
|
||||
"www.google.je": "forcesafesearch.google.com",
|
||||
"www.google.com.jm": "forcesafesearch.google.com",
|
||||
"www.google.jo": "forcesafesearch.google.com",
|
||||
"www.google.co.jp": "forcesafesearch.google.com",
|
||||
"www.google.co.ke": "forcesafesearch.google.com",
|
||||
"www.google.com.kh": "forcesafesearch.google.com",
|
||||
"www.google.ki": "forcesafesearch.google.com",
|
||||
"www.google.kg": "forcesafesearch.google.com",
|
||||
"www.google.co.kr": "forcesafesearch.google.com",
|
||||
"www.google.com.kw": "forcesafesearch.google.com",
|
||||
"www.google.kz": "forcesafesearch.google.com",
|
||||
"www.google.la": "forcesafesearch.google.com",
|
||||
"www.google.com.lb": "forcesafesearch.google.com",
|
||||
"www.google.li": "forcesafesearch.google.com",
|
||||
"www.google.lk": "forcesafesearch.google.com",
|
||||
"www.google.co.ls": "forcesafesearch.google.com",
|
||||
"www.google.lt": "forcesafesearch.google.com",
|
||||
"www.google.lu": "forcesafesearch.google.com",
|
||||
"www.google.lv": "forcesafesearch.google.com",
|
||||
"www.google.com.ly": "forcesafesearch.google.com",
|
||||
"www.google.co.ma": "forcesafesearch.google.com",
|
||||
"www.google.md": "forcesafesearch.google.com",
|
||||
"www.google.me": "forcesafesearch.google.com",
|
||||
"www.google.mg": "forcesafesearch.google.com",
|
||||
"www.google.mk": "forcesafesearch.google.com",
|
||||
"www.google.ml": "forcesafesearch.google.com",
|
||||
"www.google.com.mm": "forcesafesearch.google.com",
|
||||
"www.google.mn": "forcesafesearch.google.com",
|
||||
"www.google.ms": "forcesafesearch.google.com",
|
||||
"www.google.com.mt": "forcesafesearch.google.com",
|
||||
"www.google.mu": "forcesafesearch.google.com",
|
||||
"www.google.mv": "forcesafesearch.google.com",
|
||||
"www.google.mw": "forcesafesearch.google.com",
|
||||
"www.google.com.mx": "forcesafesearch.google.com",
|
||||
"www.google.com.my": "forcesafesearch.google.com",
|
||||
"www.google.co.mz": "forcesafesearch.google.com",
|
||||
"www.google.com.na": "forcesafesearch.google.com",
|
||||
"www.google.com.nf": "forcesafesearch.google.com",
|
||||
"www.google.com.ng": "forcesafesearch.google.com",
|
||||
"www.google.com.ni": "forcesafesearch.google.com",
|
||||
"www.google.ne": "forcesafesearch.google.com",
|
||||
"www.google.nl": "forcesafesearch.google.com",
|
||||
"www.google.no": "forcesafesearch.google.com",
|
||||
"www.google.com.np": "forcesafesearch.google.com",
|
||||
"www.google.nr": "forcesafesearch.google.com",
|
||||
"www.google.nu": "forcesafesearch.google.com",
|
||||
"www.google.co.nz": "forcesafesearch.google.com",
|
||||
"www.google.com.om": "forcesafesearch.google.com",
|
||||
"www.google.com.pa": "forcesafesearch.google.com",
|
||||
"www.google.com.pe": "forcesafesearch.google.com",
|
||||
"www.google.com.pg": "forcesafesearch.google.com",
|
||||
"www.google.com.ph": "forcesafesearch.google.com",
|
||||
"www.google.com.pk": "forcesafesearch.google.com",
|
||||
"www.google.pl": "forcesafesearch.google.com",
|
||||
"www.google.pn": "forcesafesearch.google.com",
|
||||
"www.google.com.pr": "forcesafesearch.google.com",
|
||||
"www.google.ps": "forcesafesearch.google.com",
|
||||
"www.google.pt": "forcesafesearch.google.com",
|
||||
"www.google.com.py": "forcesafesearch.google.com",
|
||||
"www.google.com.qa": "forcesafesearch.google.com",
|
||||
"www.google.ro": "forcesafesearch.google.com",
|
||||
"www.google.ru": "forcesafesearch.google.com",
|
||||
"www.google.rw": "forcesafesearch.google.com",
|
||||
"www.google.com.sa": "forcesafesearch.google.com",
|
||||
"www.google.com.sb": "forcesafesearch.google.com",
|
||||
"www.google.sc": "forcesafesearch.google.com",
|
||||
"www.google.se": "forcesafesearch.google.com",
|
||||
"www.google.com.sg": "forcesafesearch.google.com",
|
||||
"www.google.sh": "forcesafesearch.google.com",
|
||||
"www.google.si": "forcesafesearch.google.com",
|
||||
"www.google.sk": "forcesafesearch.google.com",
|
||||
"www.google.com.sl": "forcesafesearch.google.com",
|
||||
"www.google.sn": "forcesafesearch.google.com",
|
||||
"www.google.so": "forcesafesearch.google.com",
|
||||
"www.google.sm": "forcesafesearch.google.com",
|
||||
"www.google.sr": "forcesafesearch.google.com",
|
||||
"www.google.st": "forcesafesearch.google.com",
|
||||
"www.google.com.sv": "forcesafesearch.google.com",
|
||||
"www.google.td": "forcesafesearch.google.com",
|
||||
"www.google.tg": "forcesafesearch.google.com",
|
||||
"www.google.co.th": "forcesafesearch.google.com",
|
||||
"www.google.com.tj": "forcesafesearch.google.com",
|
||||
"www.google.tk": "forcesafesearch.google.com",
|
||||
"www.google.tl": "forcesafesearch.google.com",
|
||||
"www.google.tm": "forcesafesearch.google.com",
|
||||
"www.google.tn": "forcesafesearch.google.com",
|
||||
"www.google.to": "forcesafesearch.google.com",
|
||||
"www.google.com.tr": "forcesafesearch.google.com",
|
||||
"www.google.tt": "forcesafesearch.google.com",
|
||||
"www.google.com.tw": "forcesafesearch.google.com",
|
||||
"www.google.co.tz": "forcesafesearch.google.com",
|
||||
"www.google.com.ua": "forcesafesearch.google.com",
|
||||
"www.google.co.ug": "forcesafesearch.google.com",
|
||||
"www.google.co.uk": "forcesafesearch.google.com",
|
||||
"www.google.com.uy": "forcesafesearch.google.com",
|
||||
"www.google.co.uz": "forcesafesearch.google.com",
|
||||
"www.google.com.vc": "forcesafesearch.google.com",
|
||||
"www.google.co.ve": "forcesafesearch.google.com",
|
||||
"www.google.vg": "forcesafesearch.google.com",
|
||||
"www.google.co.vi": "forcesafesearch.google.com",
|
||||
"www.google.com.vn": "forcesafesearch.google.com",
|
||||
"www.google.vu": "forcesafesearch.google.com",
|
||||
"www.google.ws": "forcesafesearch.google.com",
|
||||
"www.google.rs": "forcesafesearch.google.com",
|
||||
|
||||
"www.youtube.com": "restrictmoderate.youtube.com",
|
||||
"m.youtube.com": "restrictmoderate.youtube.com",
|
||||
"youtubei.googleapis.com": "restrictmoderate.youtube.com",
|
||||
"youtube.googleapis.com": "restrictmoderate.youtube.com",
|
||||
"www.youtube-nocookie.com": "restrictmoderate.youtube.com",
|
||||
|
||||
"pixabay.com": "safesearch.pixabay.com",
|
||||
return d.safeSearch.CheckHost(host, qtype)
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
|www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com
|
||||
|www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
|duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|
||||
|start.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|
||||
|www.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|
||||
|www.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|
||||
|
||||
@@ -188,4 +188,4 @@
|
||||
|www.google.tt^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.vg^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.vu^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.ws^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.ws^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|
||||
@@ -1 +1 @@
|
||||
|pixabay.com^$dnsrewrite=NOERROR;CNAME;safesearch.pixabay.com
|
||||
|pixabay.com^$dnsrewrite=NOERROR;CNAME;safesearch.pixabay.com
|
||||
|
||||
@@ -49,4 +49,4 @@
|
||||
|yandex.ru^$dnsrewrite=NOERROR;A;213.180.193.56
|
||||
|yandex.tj^$dnsrewrite=NOERROR;A;213.180.193.56
|
||||
|yandex.tm^$dnsrewrite=NOERROR;A;213.180.193.56
|
||||
|yandex.uz^$dnsrewrite=NOERROR;A;213.180.193.56
|
||||
|yandex.uz^$dnsrewrite=NOERROR;A;213.180.193.56
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|m.youtube.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|
||||
|youtubei.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|
||||
|youtube.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|
||||
|www.youtube-nocookie.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|
||||
|www.youtube-nocookie.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -53,44 +54,85 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool)
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSafeSearch is the default safesearch struct.
|
||||
type DefaultSafeSearch struct {
|
||||
engine *urlfilter.DNSEngine
|
||||
safeSearchCache cache.Cache
|
||||
resolver filtering.Resolver
|
||||
cacheTime time.Duration
|
||||
// Default is the default safe search filter that uses filtering rules with the
|
||||
// dnsrewrite modifier.
|
||||
type Default struct {
|
||||
// mu protects engine.
|
||||
mu *sync.RWMutex
|
||||
|
||||
// engine is the filtering engine that contains the DNS rewrite rules.
|
||||
// engine may be nil, which means that this safe search filter is disabled.
|
||||
engine *urlfilter.DNSEngine
|
||||
|
||||
cache cache.Cache
|
||||
resolver filtering.Resolver
|
||||
logPrefix string
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewDefaultSafeSearch returns new safesearch struct. CacheTime is an element
|
||||
// TTL (in minutes).
|
||||
func NewDefaultSafeSearch(
|
||||
// NewDefault returns an initialized default safe search filter. name is used
|
||||
// for logging.
|
||||
func NewDefault(
|
||||
conf filtering.SafeSearchConfig,
|
||||
name string,
|
||||
cacheSize uint,
|
||||
cacheTime time.Duration,
|
||||
) (ss *DefaultSafeSearch, err error) {
|
||||
engine, err := newEngine(filtering.SafeSearchListID, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cacheTTL time.Duration,
|
||||
) (ss *Default, err error) {
|
||||
var resolver filtering.Resolver = net.DefaultResolver
|
||||
if conf.CustomResolver != nil {
|
||||
resolver = conf.CustomResolver
|
||||
}
|
||||
|
||||
return &DefaultSafeSearch{
|
||||
engine: engine,
|
||||
safeSearchCache: cache.New(cache.Config{
|
||||
ss = &Default{
|
||||
mu: &sync.RWMutex{},
|
||||
|
||||
cache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: cacheSize,
|
||||
}),
|
||||
cacheTime: cacheTime,
|
||||
resolver: resolver,
|
||||
}, nil
|
||||
resolver: resolver,
|
||||
// Use %s, because the client safe-search names already contain double
|
||||
// quotes.
|
||||
logPrefix: fmt.Sprintf("safesearch %s: ", name),
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
|
||||
err = ss.resetEngine(filtering.SafeSearchListID, conf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// newEngine creates new engine for provided safe search configuration.
|
||||
func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.DNSEngine, err error) {
|
||||
// log is a helper for logging that includes the name of the safe search
|
||||
// filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR].
|
||||
func (ss *Default) log(level log.Level, msg string, args ...any) {
|
||||
switch level {
|
||||
case log.DEBUG:
|
||||
log.Debug(ss.logPrefix+msg, args...)
|
||||
case log.INFO:
|
||||
log.Info(ss.logPrefix+msg, args...)
|
||||
case log.ERROR:
|
||||
log.Error(ss.logPrefix+msg, args...)
|
||||
default:
|
||||
panic(fmt.Errorf("safesearch: unsupported logging level %d", level))
|
||||
}
|
||||
}
|
||||
|
||||
// resetEngine creates new engine for provided safe search configuration and
|
||||
// sets it in ss.
|
||||
func (ss *Default) resetEngine(
|
||||
listID int,
|
||||
conf filtering.SafeSearchConfig,
|
||||
) (err error) {
|
||||
if !conf.Enabled {
|
||||
ss.log(log.INFO, "disabled")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for service, serviceRules := range safeSearchRules {
|
||||
if isServiceProtected(conf, service) {
|
||||
@@ -106,20 +148,73 @@ func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.D
|
||||
|
||||
rs, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating rule storage: %w", err)
|
||||
return fmt.Errorf("creating rule storage: %w", err)
|
||||
}
|
||||
|
||||
engine = urlfilter.NewDNSEngine(rs)
|
||||
log.Info("safesearch: filter %d: reset %d rules", listID, engine.RulesCount)
|
||||
ss.engine = urlfilter.NewDNSEngine(rs)
|
||||
|
||||
return engine, nil
|
||||
ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ filtering.SafeSearch = (*DefaultSafeSearch)(nil)
|
||||
var _ filtering.SafeSearch = (*Default)(nil)
|
||||
|
||||
// CheckHost implements the [filtering.SafeSearch] interface for
|
||||
// *DefaultSafeSearch.
|
||||
func (ss *Default) CheckHost(
|
||||
host string,
|
||||
qtype rules.RRType,
|
||||
) (res filtering.Result, err error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
|
||||
}()
|
||||
|
||||
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
|
||||
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype))
|
||||
}
|
||||
|
||||
// Check cache. Return cached result if it was found
|
||||
cachedValue, isFound := ss.getCachedResult(host, qtype)
|
||||
if isFound {
|
||||
ss.log(log.DEBUG, "found in cache: %q", host)
|
||||
|
||||
return cachedValue, nil
|
||||
}
|
||||
|
||||
rewrite := ss.searchHost(host, qtype)
|
||||
if rewrite == nil {
|
||||
return filtering.Result{}, nil
|
||||
}
|
||||
|
||||
fltRes, err := ss.newResult(rewrite, qtype)
|
||||
if err != nil {
|
||||
ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err)
|
||||
|
||||
return filtering.Result{}, err
|
||||
}
|
||||
|
||||
if fltRes != nil {
|
||||
res = *fltRes
|
||||
ss.setCacheResult(host, qtype, res)
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return filtering.Result{}, fmt.Errorf("no ipv4 addresses for %q", host)
|
||||
}
|
||||
|
||||
// searchHost looks up DNS rewrites in the internal DNS filtering engine.
|
||||
func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRewrite) {
|
||||
ss.mu.RLock()
|
||||
defer ss.mu.RUnlock()
|
||||
|
||||
if ss.engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchHost implements the [filtering.SafeSearch] interface for *DefaultSafeSearch.
|
||||
func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.DNSRewrite) {
|
||||
r, _ := ss.engine.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: strings.ToLower(host),
|
||||
DNSType: qtype,
|
||||
@@ -133,51 +228,11 @@ func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.D
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckHost implements the [filtering.SafeSearch] interface for
|
||||
// *DefaultSafeSearch.
|
||||
func (ss *DefaultSafeSearch) CheckHost(
|
||||
host string,
|
||||
qtype uint16,
|
||||
) (res filtering.Result, err error) {
|
||||
if log.GetLevel() >= log.DEBUG {
|
||||
timer := log.StartTimer()
|
||||
defer timer.LogElapsed("safesearch: lookup for %s", host)
|
||||
}
|
||||
|
||||
// Check cache. Return cached result if it was found
|
||||
cachedValue, isFound := ss.getCachedResult(host)
|
||||
if isFound {
|
||||
log.Debug("safesearch: found in cache: %s", host)
|
||||
|
||||
return cachedValue, nil
|
||||
}
|
||||
|
||||
rewrite := ss.SearchHost(host, qtype)
|
||||
if rewrite == nil {
|
||||
return filtering.Result{}, nil
|
||||
}
|
||||
|
||||
dRes, err := ss.newResult(rewrite, qtype)
|
||||
if err != nil {
|
||||
log.Debug("safesearch: failed to lookup addresses for %s: %s", host, err)
|
||||
|
||||
return filtering.Result{}, err
|
||||
}
|
||||
|
||||
if dRes != nil {
|
||||
res = *dRes
|
||||
ss.setCacheResult(host, res)
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return filtering.Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", host)
|
||||
}
|
||||
|
||||
// newResult creates Result object from rewrite rule.
|
||||
func (ss *DefaultSafeSearch) newResult(
|
||||
// newResult creates Result object from rewrite rule. qtype must be either
|
||||
// [dns.TypeA] or [dns.TypeAAAA].
|
||||
func (ss *Default) newResult(
|
||||
rewrite *rules.DNSRewrite,
|
||||
qtype uint16,
|
||||
qtype rules.RRType,
|
||||
) (res *filtering.Result, err error) {
|
||||
res = &filtering.Result{
|
||||
Rules: []*filtering.ResultRule{{
|
||||
@@ -187,7 +242,7 @@ func (ss *DefaultSafeSearch) newResult(
|
||||
IsFiltered: true,
|
||||
}
|
||||
|
||||
if rewrite.RRType == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
if rewrite.RRType == qtype {
|
||||
ip, ok := rewrite.Value.(net.IP)
|
||||
if !ok || ip == nil {
|
||||
return nil, nil
|
||||
@@ -198,17 +253,25 @@ func (ss *DefaultSafeSearch) newResult(
|
||||
return res, nil
|
||||
}
|
||||
|
||||
if rewrite.NewCNAME == "" {
|
||||
host := rewrite.NewCNAME
|
||||
if host == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ips, err := ss.resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
|
||||
ss.log(log.DEBUG, "resolving %q", host)
|
||||
|
||||
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss.log(log.DEBUG, "resolved %s", ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip = ip.To4(); ip == nil {
|
||||
// TODO(a.garipov): Remove this filtering once the resolver we use
|
||||
// actually learns about network.
|
||||
ip = fitToProto(ip, qtype)
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -220,38 +283,71 @@ func (ss *DefaultSafeSearch) newResult(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// setCacheResult stores data in cache for host.
|
||||
func (ss *DefaultSafeSearch) setCacheResult(host string, res filtering.Result) {
|
||||
expire := uint32(time.Now().Add(ss.cacheTime).Unix())
|
||||
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
|
||||
// It panics for other types.
|
||||
func qtypeToProto(qtype rules.RRType) (proto string) {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
return "ip4"
|
||||
case dns.TypeAAAA:
|
||||
return "ip6"
|
||||
default:
|
||||
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
|
||||
}
|
||||
}
|
||||
|
||||
// fitToProto returns a non-nil IP address if ip is the correct protocol version
|
||||
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
|
||||
func fitToProto(ip net.IP, qtype rules.RRType) (res net.IP) {
|
||||
ip4 := ip.To4()
|
||||
if qtype == dns.TypeA {
|
||||
return ip4
|
||||
}
|
||||
|
||||
if ip4 == nil {
|
||||
return ip
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setCacheResult stores data in cache for host. qtype is expected to be either
|
||||
// [dns.TypeA] or [dns.TypeAAAA].
|
||||
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {
|
||||
expire := uint32(time.Now().Add(ss.cacheTTL).Unix())
|
||||
exp := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(exp, expire)
|
||||
buf := bytes.NewBuffer(exp)
|
||||
|
||||
err := gob.NewEncoder(buf).Encode(res)
|
||||
if err != nil {
|
||||
log.Error("safesearch: cache encoding: %s", err)
|
||||
ss.log(log.ERROR, "cache encoding: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
val := buf.Bytes()
|
||||
_ = ss.safeSearchCache.Set([]byte(host), val)
|
||||
_ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val)
|
||||
|
||||
log.Debug("safesearch: stored in cache: %s (%d bytes)", host, len(val))
|
||||
ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val))
|
||||
}
|
||||
|
||||
// getCachedResult returns stored data from cache for host.
|
||||
func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result, ok bool) {
|
||||
// getCachedResult returns stored data from cache for host. qtype is expected
|
||||
// to be either [dns.TypeA] or [dns.TypeAAAA].
|
||||
func (ss *Default) getCachedResult(
|
||||
host string,
|
||||
qtype rules.RRType,
|
||||
) (res filtering.Result, ok bool) {
|
||||
res = filtering.Result{}
|
||||
|
||||
data := ss.safeSearchCache.Get([]byte(host))
|
||||
data := ss.cache.Get([]byte(dns.Type(qtype).String() + " " + host))
|
||||
if data == nil {
|
||||
return res, false
|
||||
}
|
||||
|
||||
exp := binary.BigEndian.Uint32(data[:4])
|
||||
if exp <= uint32(time.Now().Unix()) {
|
||||
ss.safeSearchCache.Del([]byte(host))
|
||||
ss.cache.Del([]byte(host))
|
||||
|
||||
return res, false
|
||||
}
|
||||
@@ -260,10 +356,27 @@ func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result,
|
||||
|
||||
err := gob.NewDecoder(buf).Decode(&res)
|
||||
if err != nil {
|
||||
log.Debug("safesearch: cache decoding: %s", err)
|
||||
ss.log(log.ERROR, "cache decoding: %s", err)
|
||||
|
||||
return filtering.Result{}, false
|
||||
}
|
||||
|
||||
return res, true
|
||||
}
|
||||
|
||||
// Update implements the [filtering.SafeSearch] interface for *Default. Update
|
||||
// ignores the CustomResolver and Enabled fields.
|
||||
func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) {
|
||||
ss.mu.Lock()
|
||||
defer ss.mu.Unlock()
|
||||
|
||||
err = ss.resetEngine(filtering.SafeSearchListID, conf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
ss.cache.Clear()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
137
internal/filtering/safesearch/safesearch_internal_test.go
Normal file
137
internal/filtering/safesearch/safesearch_internal_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package safesearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Move as much of this as possible into proper external tests.
|
||||
|
||||
const (
|
||||
// TODO(a.garipov): Add IPv6 tests.
|
||||
testQType = dns.TypeA
|
||||
testCacheSize = 5000
|
||||
testCacheTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
var defaultSafeSearchConf = filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Bing: true,
|
||||
DuckDuckGo: true,
|
||||
Google: true,
|
||||
Pixabay: true,
|
||||
Yandex: true,
|
||||
YouTube: true,
|
||||
}
|
||||
|
||||
var yandexIP = net.IPv4(213, 180, 193, 56)
|
||||
|
||||
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) {
|
||||
ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
return ss
|
||||
}
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
ss := newForTest(t, defaultSafeSearchConf)
|
||||
val := ss.searchHost("www.google.com", testQType)
|
||||
|
||||
assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val)
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
const domain = "yandex.ru"
|
||||
|
||||
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
|
||||
|
||||
// Check host with disabled safesearch.
|
||||
res, err := ss.CheckHost(domain, testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Empty(t, res.Rules)
|
||||
|
||||
ss = newForTest(t, defaultSafeSearchConf)
|
||||
res, err = ss.CheckHost(domain, testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
// For yandex we already know valid IP.
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := ss.getCachedResult(domain, testQType)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
const domain = "www.google.ru"
|
||||
|
||||
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
|
||||
|
||||
res, err := ss.CheckHost(domain, testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Empty(t, res.Rules)
|
||||
|
||||
resolver := &aghtest.TestResolver{}
|
||||
ss = newForTest(t, defaultSafeSearchConf)
|
||||
ss.resolver = resolver
|
||||
|
||||
// Lookup for safesearch domain.
|
||||
rewrite := ss.searchHost(domain, testQType)
|
||||
|
||||
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundIP net.IP
|
||||
for _, ip := range ips {
|
||||
if ip.To4() != nil {
|
||||
foundIP = ip
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
res, err = ss.CheckHost(domain, testQType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.True(t, res.Rules[0].IP.Equal(foundIP))
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := ss.getCachedResult(domain, testQType)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
|
||||
}
|
||||
|
||||
const googleHost = "www.google.com"
|
||||
|
||||
var dnsRewriteSink *rules.DNSRewrite
|
||||
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
ss := newForTest(b, defaultSafeSearchConf)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
dnsRewriteSink = ss.searchHost(googleHost, testQType)
|
||||
}
|
||||
|
||||
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
|
||||
}
|
||||
@@ -1,26 +1,37 @@
|
||||
package safesearch
|
||||
package safesearch_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// Common test constants.
|
||||
const (
|
||||
safeSearchCacheSize = 5000
|
||||
cacheTime = 30 * time.Minute
|
||||
// TODO(a.garipov): Add IPv6 tests.
|
||||
testQType = dns.TypeA
|
||||
testCacheSize = 5000
|
||||
testCacheTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
var defaultSafeSearchConf = filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
// testConf is the default safe search configuration for tests.
|
||||
var testConf = filtering.SafeSearchConfig{
|
||||
CustomResolver: nil,
|
||||
|
||||
Enabled: true,
|
||||
|
||||
Bing: true,
|
||||
DuckDuckGo: true,
|
||||
Google: true,
|
||||
@@ -29,25 +40,15 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{
|
||||
YouTube: true,
|
||||
}
|
||||
|
||||
// yandexIP is the expected IP address of Yandex safe search results. Keep in
|
||||
// sync with the rules data.
|
||||
var yandexIP = net.IPv4(213, 180, 193, 56)
|
||||
|
||||
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *DefaultSafeSearch) {
|
||||
ss, err := NewDefaultSafeSearch(ssConf, safeSearchCacheSize, cacheTime)
|
||||
func TestDefault_CheckHost_yandex(t *testing.T) {
|
||||
conf := testConf
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
return ss
|
||||
}
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
ss := newForTest(t, defaultSafeSearchConf)
|
||||
val := ss.SearchHost("www.google.com", dns.TypeA)
|
||||
|
||||
assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val)
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
ss := newForTest(t, defaultSafeSearchConf)
|
||||
|
||||
// Check host for each domain.
|
||||
for _, host := range []string{
|
||||
"yandex.ru",
|
||||
@@ -57,7 +58,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
"yandex.kz",
|
||||
"www.yandex.com",
|
||||
} {
|
||||
res, err := ss.CheckHost(host, dns.TypeA)
|
||||
var res filtering.Result
|
||||
res, err = ss.CheckHost(host, testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -69,12 +71,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
func TestDefault_CheckHost_google(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
ss := newForTest(t, defaultSafeSearchConf)
|
||||
ss.resolver = resolver
|
||||
conf := testConf
|
||||
conf.CustomResolver = resolver
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check host for each domain.
|
||||
for _, host := range []string{
|
||||
@@ -87,7 +91,8 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
"www.google.je",
|
||||
} {
|
||||
t.Run(host, func(t *testing.T) {
|
||||
res, err := ss.CheckHost(host, dns.TypeA)
|
||||
var res filtering.Result
|
||||
res, err = ss.CheckHost(host, testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -100,103 +105,35 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
const domain = "yandex.ru"
|
||||
|
||||
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
|
||||
|
||||
// Check host with disabled safesearch.
|
||||
res, err := ss.CheckHost(domain, dns.TypeA)
|
||||
func TestDefault_Update(t *testing.T) {
|
||||
conf := testConf
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Empty(t, res.Rules)
|
||||
|
||||
ss = newForTest(t, defaultSafeSearchConf)
|
||||
res, err = ss.CheckHost(domain, dns.TypeA)
|
||||
res, err := ss.CheckHost("www.yandex.com", testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
// For yandex we already know valid IP.
|
||||
require.Len(t, res.Rules, 1)
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := ss.getCachedResult(domain)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
const domain = "www.google.ru"
|
||||
|
||||
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
|
||||
|
||||
res, err := ss.CheckHost(domain, dns.TypeA)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Empty(t, res.Rules)
|
||||
|
||||
resolver := &aghtest.TestResolver{}
|
||||
ss = newForTest(t, defaultSafeSearchConf)
|
||||
ss.resolver = resolver
|
||||
|
||||
// Lookup for safesearch domain.
|
||||
rewrite := ss.SearchHost(domain, dns.TypeA)
|
||||
|
||||
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundIP net.IP
|
||||
for _, ip := range ips {
|
||||
if ip.To4() != nil {
|
||||
foundIP = ip
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
res, err = ss.CheckHost(domain, dns.TypeA)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.True(t, res.Rules[0].IP.Equal(foundIP))
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := ss.getCachedResult(domain)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
|
||||
}
|
||||
|
||||
const googleHost = "www.google.com"
|
||||
|
||||
var dnsRewriteSink *rules.DNSRewrite
|
||||
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
ss := newForTest(b, defaultSafeSearchConf)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
dnsRewriteSink = ss.SearchHost(googleHost, dns.TypeA)
|
||||
}
|
||||
|
||||
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
|
||||
}
|
||||
|
||||
var dnsRewriteParallelSink *rules.DNSRewrite
|
||||
|
||||
func BenchmarkSafeSearch_parallel(b *testing.B) {
|
||||
ss := newForTest(b, defaultSafeSearchConf)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
dnsRewriteParallelSink = ss.SearchHost(googleHost, dns.TypeA)
|
||||
}
|
||||
err = ss.Update(filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Google: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteParallelSink.NewCNAME)
|
||||
res, err = ss.CheckHost("www.yandex.com", testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
|
||||
err = ss.Update(filtering.SafeSearchConfig{
|
||||
Enabled: false,
|
||||
Google: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err = ss.CheckHost("www.yandex.com", testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
}
|
||||
|
||||
71
internal/filtering/safesearchhttp.go
Normal file
71
internal/filtering/safesearchhttp.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
)
|
||||
|
||||
// handleSafeSearchEnable is the handler for POST /control/safesearch/enable
|
||||
// HTTP API.
|
||||
//
|
||||
// Deprecated: Use handleSafeSearchSettings.
|
||||
func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled, true)
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
// handleSafeSearchDisable is the handler for POST /control/safesearch/disable
|
||||
// HTTP API.
|
||||
//
|
||||
// Deprecated: Use handleSafeSearchSettings.
|
||||
func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled, false)
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
// handleSafeSearchStatus is the handler for GET /control/safesearch/status
|
||||
// HTTP API.
|
||||
func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
|
||||
var resp SafeSearchConfig
|
||||
func() {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
resp = d.Config.SafeSearchConf
|
||||
}()
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handleSafeSearchSettings is the handler for PUT /control/safesearch/settings
|
||||
// HTTP API.
|
||||
func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Request) {
|
||||
req := &SafeSearchConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
conf := *req
|
||||
err = d.safeSearch.Update(conf)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
d.confLock.Lock()
|
||||
defer d.confLock.Unlock()
|
||||
|
||||
d.Config.SafeSearchConf = conf
|
||||
}()
|
||||
|
||||
d.Config.ConfigModified()
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
@@ -311,6 +311,14 @@ var blockedServices = []blockedService{{
|
||||
"||warp.plus^",
|
||||
"||workers.dev^",
|
||||
},
|
||||
}, {
|
||||
ID: "crunchyroll",
|
||||
Name: "Crunchyroll",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M 25 3 C 12.85 3 3 12.85 3 25 C 3 40.188 13.387672 44.538609 20.388672 45.974609 C 20.427672 45.982609 20.465953 45.986328 20.501953 45.986328 C 21.006953 45.986328 21.206312 45.25525 20.695312 45.03125 C 13.285312 41.79025 8.0301562 34.327141 9.1601562 25.494141 C 10.256156 16.920141 17.244938 10.069141 25.835938 9.1191406 C 26.564937 9.0381406 27.287 9 28 9 C 35.541 9 42.044422 13.395672 45.107422 19.763672 C 45.206422 19.968672 45.382594 20.058594 45.558594 20.058594 C 45.853594 20.058594 46.144828 19.8075 46.048828 19.4375 C 44.302828 12.7105 39 3 25 3 z M 29 14 C 20.481 14 13.619625 21.101031 14.015625 29.707031 C 14.366625 37.346031 20.653016 43.631422 28.291016 43.982422 C 28.528016 43.994422 28.766 44 29 44 C 37.285 44 44 37.285 44 29 C 44 27.819 43.860563 26.670359 43.601562 25.568359 C 43.542563 25.319359 43.332234 25.183594 43.115234 25.183594 C 42.961234 25.183594 42.806266 25.251484 42.697266 25.396484 C 41.512266 26.976484 39.627 28 37.5 28 C 37.397 28 37.293453 27.997188 37.189453 27.992188 C 34.031453 27.845188 31.348203 25.317875 31.033203 22.171875 C 30.763203 19.477875 32.142297 17.082328 34.279297 15.861328 C 34.656297 15.646328 34.62475 15.100266 34.21875 14.947266 C 32.59375 14.340266 30.838 14 29 14 z M 44.296875 26.595703 L 44.300781 26.595703 L 44.296875 26.595703 z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||crunchyroll.com^",
|
||||
"||gccrunchyroll.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "dailymotion",
|
||||
Name: "Dailymotion",
|
||||
@@ -1182,6 +1190,24 @@ var blockedServices = []blockedService{{
|
||||
"||gog.com^",
|
||||
"||gogalaxy.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "hbomax",
|
||||
Name: "HBO Max",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M0 14v22h5v-9h3v9h5V14H8v8H5v-8H0zm15 0v22h8.4c3.1 0 5.7-2.3 6.2-5.4a11 11 0 1 0 0-11.2c-.5-3-3-5.4-6.2-5.4H15zm5 5h3a2 2 0 1 1 0 4h-3v-4zm19 0a6 6 0 1 1 0 12 6 6 0 0 1 0-12zm0 2a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4zm-11 2.8v2.4c-.4-.5-1-1-2-1.2 1-.3 1.7-.8 2-1.3zm-8 4h3a2 2 0 1 1 0 4h-3v-4z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||hbo.com^",
|
||||
"||hbogo.co.th^",
|
||||
"||hbogo.com^",
|
||||
"||hbogo.eu^",
|
||||
"||hbogoasia.com^",
|
||||
"||hbogoasia.id^",
|
||||
"||hbogoasia.ph^",
|
||||
"||hbomax-images.warnermediacdn.com^",
|
||||
"||hbomax.com^",
|
||||
"||hbomaxcdn.com^",
|
||||
"||hbonow.com^",
|
||||
"||maxgo.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "hulu",
|
||||
Name: "Hulu",
|
||||
@@ -1299,6 +1325,21 @@ var blockedServices = []blockedService{{
|
||||
"||kakao.com^",
|
||||
"||kgslb.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "lazada",
|
||||
Name: "Lazada",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 100 100\"><path d=\"M26 17a2 2 0 0 0-1.1.3L8.5 27.4A3 3 0 0 0 7 30v29.1c0 1 .6 2 1.5 2.6l40 24.8c.2.3.6.4 1 .5h1a3 3 0 0 0 1-.5l40-24.8a3 3 0 0 0 1.5-2.6v-29c0-1.1-.6-2.1-1.5-2.7l-16.4-10a2.1 2.1 0 0 0-2.2 0l-15.4 9.4a14.3 14.3 0 0 1-15 0L27 17.3c-.3-.2-.7-.3-1.1-.3zm0 2 15.4 9.5a16.3 16.3 0 0 0 17.2 0L74 19l16.5 10.1v.1L50 51.4 9.4 29.2h.1L26 19zm48 4c-.4 0-.9 0-1.3.3l-5.5 3.4a.5.5 0 1 0 .6.9l5.4-3.4c.5-.3 1.1-.3 1.6 0l9.4 5.7a.5.5 0 1 0 .5-.8l-9.4-5.8c-.4-.2-.8-.4-1.3-.4zm-8.7 5a.5.5 0 0 0-.3 0l-1.6 1a.5.5 0 0 0 .6 1l1.6-1a.5.5 0 0 0-.3-1zM9 30.1l40.5 22.2v32.6L9.5 60a1 1 0 0 1-.5-.9v-29zm82 0v29c0 .4-.2.7-.5 1l-40 24.7V52.4L91 30.1zM12.5 35a.5.5 0 0 0-.5.5v21.2c0 .8.4 1.6 1.2 2l16 10a.5.5 0 1 0 .6-.8l-16-10c-.5-.2-.8-.7-.8-1.2V35.5a.5.5 0 0 0-.5-.5zm24 37.2a.5.5 0 0 0-.3.9l4 2.5a.5.5 0 1 0 .6-.9l-4-2.4a.5.5 0 0 0-.3-.1zm7 4.3a.5.5 0 0 0-.3 1l1 .6a.5.5 0 1 0 .6-.9l-1-.6a.5.5 0 0 0-.3 0z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||k1-lazadasg-oversea.gslb.ksyuncdn.com^",
|
||||
"||lazada.co.id^",
|
||||
"||lazada.co.th^",
|
||||
"||lazada.com.my^",
|
||||
"||lazada.com.ph^",
|
||||
"||lazada.com^",
|
||||
"||lazada.sg^",
|
||||
"||lazada.vn^",
|
||||
"||slatic.net^",
|
||||
},
|
||||
}, {
|
||||
ID: "leagueoflegends",
|
||||
Name: "League of Legends",
|
||||
@@ -1383,7 +1424,6 @@ var blockedServices = []blockedService{{
|
||||
"||mastodon.sdf.org^",
|
||||
"||mastodon.social^",
|
||||
"||mastodon.social^",
|
||||
"||mastodon.top^",
|
||||
"||mastodon.uno^",
|
||||
"||mastodon.world^",
|
||||
"||mastodon.xyz^",
|
||||
@@ -1400,6 +1440,7 @@ var blockedServices = []blockedService{{
|
||||
"||mstdn.jp^",
|
||||
"||mstdn.social^",
|
||||
"||muenchen.social^",
|
||||
"||muenster.im^",
|
||||
"||newsie.social^",
|
||||
"||noc.social^",
|
||||
"||norden.social^",
|
||||
@@ -1428,6 +1469,7 @@ var blockedServices = []blockedService{{
|
||||
"||techhub.social^",
|
||||
"||theblower.au^",
|
||||
"||tkz.one^",
|
||||
"||todon.eu^",
|
||||
"||toot.aquilenet.fr^",
|
||||
"||toot.community^",
|
||||
"||toot.funami.tech^",
|
||||
@@ -1438,7 +1480,6 @@ var blockedServices = []blockedService{{
|
||||
"||union.place^",
|
||||
"||universeodon.com^",
|
||||
"||urbanists.social^",
|
||||
"||vocalodon.net^",
|
||||
"||wien.rocks^",
|
||||
"||wxw.moe^",
|
||||
},
|
||||
@@ -1568,6 +1609,20 @@ var blockedServices = []blockedService{{
|
||||
"||pinterest.vn^",
|
||||
"||pinterestmail.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "playstation",
|
||||
Name: "PlayStation",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 30 30\"><path d=\"M11.18 3.74v21.12l4.58 1.4V8.58c0-.51 0-.77.26-1.02.12-.26.38-.26.63-.13.64.26 1.02.76 1.02 1.78v7c1.53.76 2.8.76 3.81 0 1.02-.77 1.53-1.9 1.53-3.82 0-2.03-.38-3.3-1.27-4.32-.76-1.02-2.16-1.91-4.2-2.55-2.54-.76-4.7-1.4-6.36-1.78zM9.91 16.97l-5.85 2.04-.89.38c-1.4.63-2.16 1.27-2.16 1.9.12.77.38 1.79 2.29 2.42 1.78.64 3.18.9 6.74-.12v-2.3c-3.44 1.15-3.95 1.02-4.45.77-.51-.25-.51-.5-.39-.64.39-.25 1.78-.76 1.78-.76l2.93-1.02v-2.67zm12.94 1c-.41-.02-.82-.01-1.24.02-1.4 0-2.67.25-4.2.64v2.67l2.8-1.02 1.53-.51s.64-.13 1.02-.25c.63-.13 1.4.12 1.4.12.38 0 .63.13.63.38.13.26-.12.39-.76.64l-1.4.51-5.09 1.9v2.68l2.3-.77 6.35-2.28.77-.39c1.52-.5 2.16-1.14 2.03-1.9 0-.77-.89-1.28-2.42-1.79a14.28 14.28 0 0 0-3.72-.66z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||gaikai.com",
|
||||
"||playstation-cloud.com",
|
||||
"||playstation-cloud.net",
|
||||
"||playstation.com",
|
||||
"||playstation.net",
|
||||
"||scea.com",
|
||||
"||sonyentertainmentnetwork.com",
|
||||
"||station.sony.com",
|
||||
},
|
||||
}, {
|
||||
ID: "qq",
|
||||
Name: "QQ",
|
||||
@@ -1597,6 +1652,18 @@ var blockedServices = []blockedService{{
|
||||
"||redditmedia.com^",
|
||||
"||redditstatic.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "riot_games",
|
||||
Name: "Riot Games",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 64 64\"><path d=\"M31.3 2.2a1 1 0 0 0-.8 1.4l.8 1.8a1 1 0 1 0 1.8-.8l-.8-1.8a1 1 0 0 0-1-.6zm-4.3 2a1 1 0 0 0-.9 1.5l1 1.8a1 1 0 1 0 1.7-.9L28 4.8a1 1 0 0 0-1-.6zm12 1a1 1 0 0 0-.4.1l-34 16.2a1 1 0 0 0-.6 1.1L9.3 47a1 1 0 0 0 1 .8h7a1 1 0 0 0 1-1l-1-12.2 3 12.4a1 1 0 0 0 .9.8h7.3a1 1 0 0 0 1-1l-.2-15.8L32 47a1 1 0 0 0 1 .8h7.6a1 1 0 0 0 1-1l1.3-19.4 1.4 19.5a1 1 0 0 0 1 1h10.2a1 1 0 0 0 1-1L60 11.2a1 1 0 0 0-.8-1l-20-5a1 1 0 0 0-.3 0zm-16.3 1a1 1 0 0 0-.9 1.5l.9 1.8a1 1 0 1 0 1.8-.8l-.9-1.8a1 1 0 0 0-1-.6zm16.4 1L57.9 12l-3.3 33.8h-8.3l-1.9-25a1 1 0 0 0-1.2-.8l-1.2.3a1 1 0 0 0-.7.9l-1.7 24.6h-5.8L30.3 25a1 1 0 0 0-1.3-.8l-1 .4a1 1 0 0 0-.8 1l.3 20.2H22l-4-17a1 1 0 0 0-1.2-.7l-1.1.3a1 1 0 0 0-.7 1l1.2 16.4H11L6.1 23l33-15.7zM18.5 8.4a1 1 0 0 0-1 1.5l.9 1.8a1 1 0 1 0 1.8-.9L19.3 9a1 1 0 0 0-.8-.6zm-4.3 2.1a1 1 0 0 0-.1 0 1 1 0 0 0-.9 1.4l.9 1.8a1 1 0 1 0 1.8-.8L15 11a1 1 0 0 0-.8-.6zm-4.4 2a1 1 0 0 0-.9 1.5l.9 1.8a1 1 0 1 0 1.8-.9l-.9-1.8a1 1 0 0 0-1-.5zm-4.3 2.1a1 1 0 0 0-.9 1.4l.9 1.9a1 1 0 1 0 1.8-1l-.9-1.7a1 1 0 0 0-.9-.6zM30.7 49a1 1 0 0 0-.9 1.4l2.5 6.5a1 1 0 0 0 .7.6l20.7 5.3a1 1 0 0 0 1.2-.9L56 51.4a1 1 0 0 0-1-1L30.9 49a1 1 0 0 0-.1 0zm1.5 2.1L54 52.3l-.9 8.2-19-4.8-1.8-4.6z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||dradis-prod.rdatasrv.net^",
|
||||
"||pvp.net^",
|
||||
"||rgpub.io^",
|
||||
"||riotcdn.com^",
|
||||
"||riotcdn.net^",
|
||||
"||riotgames.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "roblox",
|
||||
Name: "Roblox",
|
||||
@@ -1610,6 +1677,32 @@ var blockedServices = []blockedService{{
|
||||
"||robloxcdn.com^",
|
||||
"||robloxdev.cn^",
|
||||
},
|
||||
}, {
|
||||
ID: "shopee",
|
||||
Name: "Shopee",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M25 1c-5.3 0-9.4 5-9.8 11H5a2 2 0 0 0-2 2.1l1.7 30.2a5 5 0 0 0 5 4.7h30.4a5 5 0 0 0 5-4.7L47 14a2 2 0 0 0-2-2.1H35C34.3 6 30.2 1 25 1zm0 2c4 0 7.4 3.9 7.8 9H17.2c.4-5.1 3.8-9 7.8-9zM5 14h10.8a1 1 0 0 0 .4 0h17.6a1 1 0 0 0 .4 0h10.7l-1.7 30.2a3 3 0 0 1-3 2.8H9.8a3 3 0 0 1-3-2.8L5 14zm20 4c-4.2 0-7.5 2.7-7.5 6.3 0 4 3.8 5.4 7 6.6 4 1.4 6.5 2.5 6.5 5.7 0 2.4-2.7 4.4-6 4.4-3.8 0-7-2.7-7-2.7l-1.2 1.6c.8.7 4.1 3.1 8.1 3.1 4.5 0 8-2.8 8-6.4 0-4.8-4-6.3-7.7-7.6-3.5-1.3-5.7-2.3-5.7-4.7 0-2.5 2.3-4.3 5.6-4.3a11 11 0 0 1 6 1.9l1-1.7c-.3-.1-3.2-2.2-7-2.2z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||shopee.cl^",
|
||||
"||shopee.cn^",
|
||||
"||shopee.co.id^",
|
||||
"||shopee.co.th^",
|
||||
"||shopee.com.br^",
|
||||
"||shopee.com.co^",
|
||||
"||shopee.com.mx^",
|
||||
"||shopee.com.my^",
|
||||
"||shopee.com^",
|
||||
"||shopee.es^",
|
||||
"||shopee.fr^",
|
||||
"||shopee.id^",
|
||||
"||shopee.in^",
|
||||
"||shopee.io^",
|
||||
"||shopee.ph^",
|
||||
"||shopee.sg^",
|
||||
"||shopee.tw^",
|
||||
"||shopee.vn^",
|
||||
"||shopeemobile.com^",
|
||||
"||shp.ee^",
|
||||
},
|
||||
}, {
|
||||
ID: "skype",
|
||||
Name: "Skype",
|
||||
@@ -1813,6 +1906,15 @@ var blockedServices = []blockedService{{
|
||||
"||twvid.com^",
|
||||
"||vine.co^",
|
||||
},
|
||||
}, {
|
||||
ID: "valorant",
|
||||
Name: "Valorant",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M4 6a1 1 0 0 0-1 1v18a1 1 0 0 0 .2.6l14 17a1 1 0 0 0 .8.4h14a1 1 0 0 0 .8-1.6l-28-35A1 1 0 0 0 4 6zm42 1a1 1 0 0 0-.8.4l-18 22A1 1 0 0 0 28 31h14a1 1 0 0 0 .8-.4l4-5a1 1 0 0 0 .2-.6V8a1 1 0 0 0-1-1zM5 9.9 30 41H18.4L5 24.6V10zm40 .9v13.8L41.5 29H30.1L45 10.8z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||playvalorant.com",
|
||||
"||valorant.scd.riotcdn.net",
|
||||
"||valorant.secure.dyn.riotcdn.net",
|
||||
},
|
||||
}, {
|
||||
ID: "viber",
|
||||
Name: "Viber",
|
||||
@@ -1869,6 +1971,13 @@ var blockedServices = []blockedService{{
|
||||
"||vkuservideo.com^",
|
||||
"||vkuservideo.net^",
|
||||
},
|
||||
}, {
|
||||
ID: "voot",
|
||||
Name: "Voot",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 512 512\"><path d=\"M96 340c-1 4-4 6-7 6H48c-3 0-6-2-8-6L0 213c-1-2 0-5 2-5l2-1h30c4 0 7 3 8 6l25 87c1 3 2 3 3 0l25-87c1-3 4-6 7-6h31c2 0 4 2 4 4v2L96 340zm46-50v-32c0-29 14-56 63-56s63 27 63 56v32c0 28-14 56-63 56s-63-28-63-56zm85 1v-35c0-13-7-20-22-20s-22 7-22 20v35c0 13 7 20 22 20s22-7 22-20zm54-1v-32c0-29 14-56 63-56s63 27 63 56v32c0 28-14 56-63 56s-63-28-63-56zm85 1v-35c0-13-7-20-22-20s-21 7-21 20v35c0 13 6 20 21 20s22-7 22-20zm144 44-2-17-1-2c-1-3-3-5-6-5h-2l-10 1c-2 1-4 0-6-2l-2-11v-56c0-3 2-6 6-6h17c4 0 6-2 7-5l1-22c0-3-2-5-5-6h-21c-3 0-5-2-5-5v-28c0-3-2-5-5-5h-1l-30 4c-3 1-5 4-5 7v22c0 3-3 5-6 5h-7c-4 0-6 3-6 6v22c0 3 2 5 6 5h7c3 0 6 3 6 6v67c0 26 15 36 42 36 8 0 16-1 23-4h1c2 0 5-3 5-6l-1-1z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||voot.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "wechat",
|
||||
Name: "WeChat",
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@@ -379,9 +380,9 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
|
||||
// TODO(a.garipov): Support header Forwarded from RFC 7329.
|
||||
func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
proxyHeaders := []string{
|
||||
"CF-Connecting-IP",
|
||||
"True-Client-IP",
|
||||
"X-Real-IP",
|
||||
httphdr.CFConnectingIP,
|
||||
httphdr.TrueClientIP,
|
||||
httphdr.XRealIP,
|
||||
}
|
||||
|
||||
for _, h := range proxyHeaders {
|
||||
@@ -394,7 +395,7 @@ func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
|
||||
// If none of the above yielded any results, get the leftmost IP address
|
||||
// from the X-Forwarded-For header.
|
||||
s := r.Header.Get("X-Forwarded-For")
|
||||
s := r.Header.Get(httphdr.XForwardedFor)
|
||||
ipStrs := strings.SplitN(s, ", ", 2)
|
||||
ip = net.ParseIP(ipStrs[0])
|
||||
if ip != nil {
|
||||
@@ -411,6 +412,21 @@ func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
return net.ParseIP(ipStr), nil
|
||||
}
|
||||
|
||||
// writeErrorWithIP is like [aghhttp.Error], but includes the remote IP address
|
||||
// when it writes to the log.
|
||||
func writeErrorWithIP(
|
||||
r *http.Request,
|
||||
w http.ResponseWriter,
|
||||
code int,
|
||||
remoteIP string,
|
||||
format string,
|
||||
args ...any,
|
||||
) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s %s: from ip %s: %s", r.Method, r.Host, r.URL, remoteIP, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
req := loginJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
@@ -420,31 +436,45 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var remoteAddr string
|
||||
var remoteIP string
|
||||
// realIP cannot be used here without taking TrustedProxies into account due
|
||||
// to security issues.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
||||
//
|
||||
// TODO(e.burkov): Use realIP when the issue will be fixed.
|
||||
if remoteAddr, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "auth: getting remote address: %s", err)
|
||||
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||
writeErrorWithIP(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
r.RemoteAddr,
|
||||
"auth: getting remote address: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rateLimiter := Context.auth.raleLimiter; rateLimiter != nil {
|
||||
if left := rateLimiter.check(remoteAddr); left > 0 {
|
||||
w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds())))
|
||||
aghhttp.Error(r, w, http.StatusTooManyRequests, "auth: blocked for %s", left)
|
||||
if left := rateLimiter.check(remoteIP); left > 0 {
|
||||
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
|
||||
writeErrorWithIP(
|
||||
r,
|
||||
w,
|
||||
http.StatusTooManyRequests,
|
||||
remoteIP,
|
||||
"auth: blocked for %s",
|
||||
left,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := Context.auth.newCookie(req, remoteAddr)
|
||||
cookie, err := Context.auth.newCookie(req, remoteIP)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusForbidden, "%s", err)
|
||||
writeErrorWithIP(r, w, http.StatusForbidden, remoteIP, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -452,10 +482,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
// Use realIP here, since this IP address is only used for logging.
|
||||
ip, err := realIP(r)
|
||||
if err != nil {
|
||||
log.Error("auth: getting real ip from request: %s", err)
|
||||
} else if ip == nil {
|
||||
// Technically shouldn't happen.
|
||||
log.Error("auth: unknown ip")
|
||||
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
||||
}
|
||||
|
||||
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
|
||||
@@ -463,9 +490,9 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
h := w.Header()
|
||||
h.Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||
h.Set("Pragma", "no-cache")
|
||||
h.Set("Expires", "0")
|
||||
h.Set(httphdr.CacheControl, "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||
h.Set(httphdr.Pragma, "no-cache")
|
||||
h.Set(httphdr.Expires, "0")
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
@@ -476,7 +503,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||
// The user is already logged out.
|
||||
respHdr.Set("Location", "/login.html")
|
||||
respHdr.Set(httphdr.Location, "/login.html")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
|
||||
return
|
||||
@@ -494,8 +521,8 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
respHdr.Set("Location", "/login.html")
|
||||
respHdr.Set("Set-Cookie", c.String())
|
||||
respHdr.Set(httphdr.Location, "/login.html")
|
||||
respHdr.Set(httphdr.SetCookie, c.String())
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -543,8 +570,7 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
||||
log.Debug("auth: redirected to login page by GL-Inet submodule")
|
||||
} else {
|
||||
log.Debug("auth: redirected to login page")
|
||||
w.Header().Set("Location", "/login.html")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
http.Redirect(w, r, "login.html", http.StatusFound)
|
||||
}
|
||||
} else {
|
||||
log.Debug("auth: responded with forbidden to %s %s", r.Method, p)
|
||||
@@ -569,8 +595,7 @@ func optionalAuth(
|
||||
// Redirect to the dashboard if already authenticated.
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
if res == checkSessionOK {
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
http.Redirect(w, r, "", http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -135,11 +136,11 @@ func TestAuthHTTP(t *testing.T) {
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.Equal(t, http.StatusFound, w.statusCode)
|
||||
assert.NotEmpty(t, w.hdr.Get("Location"))
|
||||
assert.NotEmpty(t, w.hdr.Get(httphdr.Location))
|
||||
assert.False(t, handlerCalled)
|
||||
|
||||
// go to login page
|
||||
loginURL := w.hdr.Get("Location")
|
||||
loginURL := w.hdr.Get(httphdr.Location)
|
||||
r.URL = &url.URL{Path: loginURL}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
@@ -153,13 +154,13 @@ func TestAuthHTTP(t *testing.T) {
|
||||
// get /
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.Header.Set("Cookie", cookie.String())
|
||||
r.Header.Set(httphdr.Cookie, cookie.String())
|
||||
r.URL = &url.URL{Path: "/"}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, handlerCalled)
|
||||
|
||||
r.Header.Del("Cookie")
|
||||
r.Header.Del(httphdr.Cookie)
|
||||
|
||||
// get / with basic auth
|
||||
handler2 = optionalAuth(handler)
|
||||
@@ -169,28 +170,28 @@ func TestAuthHTTP(t *testing.T) {
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, handlerCalled)
|
||||
r.Header.Del("Authorization")
|
||||
r.Header.Del(httphdr.Authorization)
|
||||
|
||||
// get login page with a valid cookie - we're redirected to /
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.Header.Set("Cookie", cookie.String())
|
||||
r.Header.Set(httphdr.Cookie, cookie.String())
|
||||
r.URL = &url.URL{Path: loginURL}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.NotEmpty(t, w.hdr.Get("Location"))
|
||||
assert.NotEmpty(t, w.hdr.Get(httphdr.Location))
|
||||
assert.False(t, handlerCalled)
|
||||
r.Header.Del("Cookie")
|
||||
r.Header.Del(httphdr.Cookie)
|
||||
|
||||
// get login page with an invalid cookie
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.Header.Set("Cookie", "bad")
|
||||
r.Header.Set(httphdr.Cookie, "bad")
|
||||
r.URL = &url.URL{Path: loginURL}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, handlerCalled)
|
||||
r.Header.Del("Cookie")
|
||||
r.Header.Del(httphdr.Cookie)
|
||||
|
||||
Context.auth.Close()
|
||||
}
|
||||
@@ -213,7 +214,7 @@ func TestRealIP(t *testing.T) {
|
||||
}, {
|
||||
name: "success_proxy",
|
||||
header: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("X-Real-IP"): []string{"1.2.3.5"},
|
||||
textproto.CanonicalMIMEHeaderKey(httphdr.XRealIP): []string{"1.2.3.5"},
|
||||
},
|
||||
remoteAddr: remoteAddr,
|
||||
wantErrMsg: "",
|
||||
@@ -221,7 +222,7 @@ func TestRealIP(t *testing.T) {
|
||||
}, {
|
||||
name: "success_proxy_multiple",
|
||||
header: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("X-Forwarded-For"): []string{
|
||||
textproto.CanonicalMIMEHeaderKey(httphdr.XForwardedFor): []string{
|
||||
"1.2.3.6, 1.2.3.5",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/josharian/native"
|
||||
)
|
||||
|
||||
// GLMode - enable GL-Inet compatibility mode
|
||||
@@ -102,7 +102,7 @@ func glGetTokenDate(file string) uint32 {
|
||||
|
||||
buf := bytes.NewBuffer(bs)
|
||||
|
||||
err = binary.Read(buf, aghos.NativeEndian, &dateToken)
|
||||
err = binary.Read(buf, native.Endian, &dateToken)
|
||||
if err != nil {
|
||||
log.Error("decoding token: %s", err)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/josharian/native"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -19,13 +19,13 @@ func TestAuthGL(t *testing.T) {
|
||||
glFilePrefix = dir + "/gl_token_"
|
||||
|
||||
data := make([]byte, 4)
|
||||
aghos.NativeEndian.PutUint32(data, 1)
|
||||
native.Endian.PutUint32(data, 1)
|
||||
|
||||
require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||
assert.False(t, glCheckToken("test"))
|
||||
|
||||
data = make([]byte, 4)
|
||||
aghos.NativeEndian.PutUint32(data, uint32(time.Now().UTC().Unix()+60))
|
||||
native.Endian.PutUint32(data, uint32(time.Now().UTC().Unix()+60))
|
||||
|
||||
require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
|
||||
|
||||
@@ -3,7 +3,10 @@ package home
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
)
|
||||
|
||||
@@ -15,6 +18,9 @@ type Client struct {
|
||||
// these upstream must be used.
|
||||
upstreamConfig *proxy.UpstreamConfig
|
||||
|
||||
safeSearchConf filtering.SafeSearchConfig
|
||||
SafeSearch filtering.SafeSearch
|
||||
|
||||
Name string
|
||||
|
||||
IDs []string
|
||||
@@ -24,10 +30,11 @@ type Client struct {
|
||||
|
||||
UseOwnSettings bool
|
||||
FilteringEnabled bool
|
||||
SafeSearchEnabled bool
|
||||
SafeBrowsingEnabled bool
|
||||
ParentalEnabled bool
|
||||
UseOwnBlockedServices bool
|
||||
IgnoreQueryLog bool
|
||||
IgnoreStatistics bool
|
||||
}
|
||||
|
||||
// closeUpstreams closes the client-specific upstream config of c if any.
|
||||
@@ -42,6 +49,23 @@ func (c *Client) closeUpstreams() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setSafeSearch initializes and sets the safe search filter for this client.
|
||||
func (c *Client) setSafeSearch(
|
||||
conf filtering.SafeSearchConfig,
|
||||
cacheSize uint,
|
||||
cacheTTL time.Duration,
|
||||
) (err error) {
|
||||
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
c.SafeSearch = ss
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientSource represents the source from which the information about the
|
||||
// client has been obtained.
|
||||
type clientSource uint
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
@@ -52,9 +51,17 @@ type clientsContainer struct {
|
||||
// lock protects all fields.
|
||||
//
|
||||
// TODO(a.garipov): Use a pointer and describe which fields are protected in
|
||||
// more detail.
|
||||
// more detail. Use sync.RWMutex.
|
||||
lock sync.Mutex
|
||||
|
||||
// safeSearchCacheSize is the size of the safe search cache to use for
|
||||
// persistent clients.
|
||||
safeSearchCacheSize uint
|
||||
|
||||
// safeSearchCacheTTL is the TTL of the safe search cache to use for
|
||||
// persistent clients.
|
||||
safeSearchCacheTTL time.Duration
|
||||
|
||||
// testing is a flag that disables some features for internal tests.
|
||||
//
|
||||
// TODO(a.garipov): Awful. Remove.
|
||||
@@ -69,10 +76,12 @@ func (clients *clientsContainer) Init(
|
||||
dhcpServer dhcpd.Interface,
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
arpdb aghnet.ARPDB,
|
||||
filteringConf *filtering.Config,
|
||||
) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
|
||||
clients.list = make(map[string]*Client)
|
||||
clients.idIndex = make(map[string]*Client)
|
||||
clients.ipToRC = map[netip.Addr]*RuntimeClient{}
|
||||
@@ -82,7 +91,10 @@ func (clients *clientsContainer) Init(
|
||||
clients.dhcpServer = dhcpServer
|
||||
clients.etcHosts = etcHosts
|
||||
clients.arpdb = arpdb
|
||||
clients.addFromConfig(objects)
|
||||
clients.addFromConfig(objects, filteringConf)
|
||||
|
||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||
|
||||
if clients.testing {
|
||||
return
|
||||
@@ -133,6 +145,8 @@ func (clients *clientsContainer) reloadARP() {
|
||||
|
||||
// clientObject is the YAML representation of a persistent client.
|
||||
type clientObject struct {
|
||||
SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"`
|
||||
|
||||
Name string `yaml:"name"`
|
||||
|
||||
Tags []string `yaml:"tags"`
|
||||
@@ -143,14 +157,16 @@ type clientObject struct {
|
||||
UseGlobalSettings bool `yaml:"use_global_settings"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
UseGlobalBlockedServices bool `yaml:"use_global_blocked_services"`
|
||||
|
||||
IgnoreQueryLog bool `yaml:"ignore_querylog"`
|
||||
IgnoreStatistics bool `yaml:"ignore_statistics"`
|
||||
}
|
||||
|
||||
// addFromConfig initializes the clients container with objects from the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) addFromConfig(objects []*clientObject) {
|
||||
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) {
|
||||
for _, o := range objects {
|
||||
cli := &Client{
|
||||
Name: o.Name,
|
||||
@@ -161,9 +177,26 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject) {
|
||||
UseOwnSettings: !o.UseGlobalSettings,
|
||||
FilteringEnabled: o.FilteringEnabled,
|
||||
ParentalEnabled: o.ParentalEnabled,
|
||||
SafeSearchEnabled: o.SafeSearchEnabled,
|
||||
safeSearchConf: o.SafeSearchConf,
|
||||
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
|
||||
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
|
||||
IgnoreQueryLog: o.IgnoreQueryLog,
|
||||
IgnoreStatistics: o.IgnoreStatistics,
|
||||
}
|
||||
|
||||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
err := cli.setSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(filteringConf.CacheTime),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("clients: init client safesearch %q: %s", cli.Name, err)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range o.BlockedServices {
|
||||
@@ -210,9 +243,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
UseGlobalSettings: !cli.UseOwnSettings,
|
||||
FilteringEnabled: cli.FilteringEnabled,
|
||||
ParentalEnabled: cli.ParentalEnabled,
|
||||
SafeSearchEnabled: cli.SafeSearchEnabled,
|
||||
SafeSearchConf: cli.safeSearchConf,
|
||||
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
|
||||
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
||||
IgnoreQueryLog: cli.IgnoreQueryLog,
|
||||
IgnoreStatistics: cli.IgnoreStatistics,
|
||||
}
|
||||
|
||||
objs = append(objs, o)
|
||||
@@ -324,7 +359,8 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
client, ok := clients.Find(id)
|
||||
if ok {
|
||||
return &querylog.Client{
|
||||
Name: client.Name,
|
||||
Name: client.Name,
|
||||
IgnoreQueryLog: client.IgnoreQueryLog,
|
||||
}, false
|
||||
}
|
||||
|
||||
@@ -359,6 +395,20 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||
return c, true
|
||||
}
|
||||
|
||||
// shouldCountClient is a wrapper around Find to make it a valid client
|
||||
// information finder for the statistics. If no information about the client
|
||||
// is found, it returns true.
|
||||
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
||||
for _, id := range ids {
|
||||
client, ok := clients.Find(id)
|
||||
if ok {
|
||||
return !client.IgnoreStatistics
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// findUpstreams returns upstreams configured for the client, identified either
|
||||
// by its IP address or its ClientID. upsConf is nil if the client isn't found
|
||||
// or if the client has no custom upstreams.
|
||||
@@ -389,6 +439,7 @@ func (clients *clientsContainer) findUpstreams(
|
||||
Bootstrap: config.DNS.BootstrapDNS,
|
||||
Timeout: config.DNS.UpstreamTimeout.Duration,
|
||||
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
|
||||
PreferIPv6: config.DNS.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -839,15 +890,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(l.IP)
|
||||
if err != nil {
|
||||
log.Error("clients: bad client ip %v from dhcp: %s", l.IP, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(ipAddr, l.Hostname, ClientSourceDHCP)
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
|
||||
@@ -9,17 +9,27 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
// newClientsContainer is a helper that creates a new clients container for
|
||||
// tests.
|
||||
func newClientsContainer() (c *clientsContainer) {
|
||||
c = &clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
var (
|
||||
@@ -198,10 +208,7 @@ func TestClients(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsWHOIS(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
clients := newClientsContainer()
|
||||
whois := &RuntimeClientWHOISInfo{
|
||||
Country: "AU",
|
||||
Orgname: "Example Org",
|
||||
@@ -247,10 +254,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
clients := newClientsContainer()
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
@@ -275,7 +279,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
t.Skip("skipping dhcp test on windows")
|
||||
}
|
||||
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
config := &dhcpd.ServerConfig{
|
||||
@@ -325,10 +329,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
clients := newClientsContainer()
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.Add(&Client{
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
)
|
||||
|
||||
// clientJSON is a common structure used by several handlers to deal with
|
||||
@@ -26,7 +27,8 @@ type clientJSON struct {
|
||||
// the allowlist.
|
||||
DisallowedRule *string `json:"disallowed_rule,omitempty"`
|
||||
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
|
||||
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
|
||||
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -35,9 +37,10 @@ type clientJSON struct {
|
||||
Tags []string `json:"tags"`
|
||||
Upstreams []string `json:"upstreams"`
|
||||
|
||||
FilteringEnabled bool `json:"filtering_enabled"`
|
||||
ParentalEnabled bool `json:"parental_enabled"`
|
||||
SafeBrowsingEnabled bool `json:"safebrowsing_enabled"`
|
||||
FilteringEnabled bool `json:"filtering_enabled"`
|
||||
ParentalEnabled bool `json:"parental_enabled"`
|
||||
SafeBrowsingEnabled bool `json:"safebrowsing_enabled"`
|
||||
// Deprecated: use safeSearchConf.
|
||||
SafeSearchEnabled bool `json:"safesearch_enabled"`
|
||||
UseGlobalBlockedServices bool `json:"use_global_blocked_services"`
|
||||
UseGlobalSettings bool `json:"use_global_settings"`
|
||||
@@ -46,8 +49,8 @@ type clientJSON struct {
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
|
||||
Name string `json:"name"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source clientSource `json:"source"`
|
||||
}
|
||||
|
||||
@@ -57,7 +60,7 @@ type clientListJSON struct {
|
||||
Tags []string `json:"supported_tags"`
|
||||
}
|
||||
|
||||
// respond with information about configured clients
|
||||
// handleGetClients is the handler for GET /control/clients HTTP API.
|
||||
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
|
||||
data := clientListJSON{}
|
||||
|
||||
@@ -86,27 +89,67 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
// Convert JSON object to Client object
|
||||
func jsonToClient(cj clientJSON) (c *Client) {
|
||||
return &Client{
|
||||
Name: cj.Name,
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
SafeSearchEnabled: cj.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
// jsonToClient converts JSON object to Client object.
|
||||
func (clients *clientsContainer) jsonToClient(cj clientJSON) (c *Client, err error) {
|
||||
var safeSearchConf filtering.SafeSearchConfig
|
||||
if cj.SafeSearchConf != nil {
|
||||
safeSearchConf = *cj.SafeSearchConf
|
||||
} else {
|
||||
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
||||
// [clientJSON.SafeSearchEnabled] field.
|
||||
safeSearchConf = filtering.SafeSearchConfig{
|
||||
Enabled: cj.SafeSearchEnabled,
|
||||
}
|
||||
|
||||
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
|
||||
Upstreams: cj.Upstreams,
|
||||
// Set default service flags for enabled safesearch.
|
||||
if safeSearchConf.Enabled {
|
||||
safeSearchConf.Bing = true
|
||||
safeSearchConf.DuckDuckGo = true
|
||||
safeSearchConf.Google = true
|
||||
safeSearchConf.Pixabay = true
|
||||
safeSearchConf.Yandex = true
|
||||
safeSearchConf.YouTube = true
|
||||
}
|
||||
}
|
||||
|
||||
c = &Client{
|
||||
safeSearchConf: safeSearchConf,
|
||||
|
||||
Name: cj.Name,
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
Upstreams: cj.Upstreams,
|
||||
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
|
||||
}
|
||||
|
||||
if safeSearchConf.Enabled {
|
||||
err = c.setSafeSearch(
|
||||
safeSearchConf,
|
||||
clients.safeSearchCacheSize,
|
||||
clients.safeSearchCacheTTL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Convert Client object to JSON
|
||||
// clientToJSON converts Client object to JSON.
|
||||
func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
||||
// [clientJSON.SafeSearchEnabled] field.
|
||||
cloneVal := c.safeSearchConf
|
||||
safeSearchConf := &cloneVal
|
||||
|
||||
return &clientJSON{
|
||||
Name: c.Name,
|
||||
IDs: c.IDs,
|
||||
@@ -114,7 +157,8 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
UseGlobalSettings: !c.UseOwnSettings,
|
||||
FilteringEnabled: c.FilteringEnabled,
|
||||
ParentalEnabled: c.ParentalEnabled,
|
||||
SafeSearchEnabled: c.SafeSearchEnabled,
|
||||
SafeSearchEnabled: safeSearchConf.Enabled,
|
||||
SafeSearchConf: safeSearchConf,
|
||||
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
|
||||
|
||||
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
|
||||
@@ -124,7 +168,7 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new client
|
||||
// handleAddClient is the handler for POST /control/clients/add HTTP API.
|
||||
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
@@ -134,7 +178,13 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
c := jsonToClient(cj)
|
||||
c, err := clients.jsonToClient(cj)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := clients.Add(c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
@@ -151,7 +201,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// Remove client
|
||||
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
|
||||
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
@@ -181,7 +231,7 @@ type updateJSON struct {
|
||||
Data clientJSON `json:"data"`
|
||||
}
|
||||
|
||||
// Update client's properties
|
||||
// handleUpdateClient is the handler for POST /control/clients/update HTTP API.
|
||||
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
||||
dj := updateJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&dj)
|
||||
@@ -197,7 +247,13 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
return
|
||||
}
|
||||
|
||||
c := jsonToClient(dj.Data)
|
||||
c, err := clients.jsonToClient(dj.Data)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = clients.Update(dj.Name, c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
@@ -208,7 +264,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// Get the list of clients by IP address list
|
||||
// handleFindClient is the handler for GET /control/clients/find HTTP API.
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]*clientJSON{}
|
||||
|
||||
@@ -228,34 +228,32 @@ type tlsConfigSettings struct {
|
||||
}
|
||||
|
||||
type queryLogConfig struct {
|
||||
// Ignored is the list of host names, which should not be written to log.
|
||||
Ignored []string `yaml:"ignored"`
|
||||
|
||||
// Interval is the interval for query log's files rotation.
|
||||
Interval timeutil.Duration `yaml:"interval"`
|
||||
|
||||
// MemSize is the number of entries kept in memory before they are flushed
|
||||
// to disk.
|
||||
MemSize uint32 `yaml:"size_memory"`
|
||||
|
||||
// Enabled defines if the query log is enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// FileEnabled defines, if the query log is written to the file.
|
||||
FileEnabled bool `yaml:"file_enabled"`
|
||||
|
||||
// Interval is the interval for query log's files rotation.
|
||||
Interval timeutil.Duration `yaml:"interval"`
|
||||
|
||||
// MemSize is the number of entries kept in memory before they are
|
||||
// flushed to disk.
|
||||
MemSize uint32 `yaml:"size_memory"`
|
||||
|
||||
// Ignored is the list of host names, which should not be written to
|
||||
// log.
|
||||
Ignored []string `yaml:"ignored"`
|
||||
}
|
||||
|
||||
type statsConfig struct {
|
||||
// Enabled defines if the statistics are enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// Interval is the time interval for flushing statistics to the disk in
|
||||
// days.
|
||||
Interval uint32 `yaml:"interval"`
|
||||
|
||||
// Ignored is the list of host names, which should not be counted.
|
||||
Ignored []string `yaml:"ignored"`
|
||||
|
||||
// Interval is the retention interval for statistics.
|
||||
Interval timeutil.Duration `yaml:"interval"`
|
||||
|
||||
// Enabled defines if the statistics are enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// config is the global configuration structure.
|
||||
@@ -286,7 +284,7 @@ var config = &configuration{
|
||||
CacheSize: 4 * 1024 * 1024,
|
||||
|
||||
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{
|
||||
CustomIP: "",
|
||||
CustomIP: netip.Addr{},
|
||||
Enabled: false,
|
||||
UseCustom: false,
|
||||
},
|
||||
@@ -322,7 +320,7 @@ var config = &configuration{
|
||||
},
|
||||
Stats: statsConfig{
|
||||
Enabled: true,
|
||||
Interval: 1,
|
||||
Interval: timeutil.Duration{Duration: 1 * timeutil.Day},
|
||||
Ignored: []string{},
|
||||
},
|
||||
// NOTE: Keep these parameters in sync with the one put into
|
||||
@@ -503,7 +501,7 @@ func (c *configuration) write() (err error) {
|
||||
if Context.stats != nil {
|
||||
statsConf := stats.Config{}
|
||||
Context.stats.WriteDiskConfig(&statsConf)
|
||||
config.Stats.Interval = statsConf.LimitDays
|
||||
config.Stats.Interval = timeutil.Duration{Duration: statsConf.Limit}
|
||||
config.Stats.Enabled = statsConf.Enabled
|
||||
config.Stats.Ignored = statsConf.Ignored.Values()
|
||||
slices.Sort(config.Stats.Ignored)
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
)
|
||||
@@ -97,12 +99,17 @@ func collectDNSAddresses() (addrs []string, err error) {
|
||||
|
||||
// statusResponse is a response for /control/status endpoint.
|
||||
type statusResponse struct {
|
||||
Version string `json:"version"`
|
||||
Language string `json:"language"`
|
||||
DNSAddrs []string `json:"dns_addresses"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
HTTPPort int `json:"http_port"`
|
||||
IsProtectionEnabled bool `json:"protection_enabled"`
|
||||
Version string `json:"version"`
|
||||
Language string `json:"language"`
|
||||
DNSAddrs []string `json:"dns_addresses"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
HTTPPort int `json:"http_port"`
|
||||
|
||||
// ProtectionDisabledDuration is the duration of the protection pause in
|
||||
// milliseconds.
|
||||
ProtectionDisabledDuration int64 `json:"protection_disabled_duration"`
|
||||
|
||||
ProtectionEnabled bool `json:"protection_enabled"`
|
||||
// TODO(e.burkov): Inspect if front-end doesn't requires this field as
|
||||
// openapi.yaml declares.
|
||||
IsDHCPAvailable bool `json:"dhcp_available"`
|
||||
@@ -119,28 +126,45 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
fltConf *dnsforward.FilteringConfig
|
||||
protectionDisabledUntil *time.Time
|
||||
protectionEnabled bool
|
||||
)
|
||||
if Context.dnsServer != nil {
|
||||
fltConf = &dnsforward.FilteringConfig{}
|
||||
Context.dnsServer.WriteDiskConfig(fltConf)
|
||||
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
|
||||
}
|
||||
|
||||
var resp statusResponse
|
||||
func() {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
var protectionDisabledDuration int64
|
||||
if protectionDisabledUntil != nil {
|
||||
// Make sure that we don't send negative numbers to the frontend,
|
||||
// since enough time might have passed to make the difference less
|
||||
// than zero.
|
||||
protectionDisabledDuration = mathutil.Max(
|
||||
0,
|
||||
time.Until(*protectionDisabledUntil).Milliseconds(),
|
||||
)
|
||||
}
|
||||
|
||||
resp = statusResponse{
|
||||
Version: version.Version(),
|
||||
DNSAddrs: dnsAddrs,
|
||||
DNSPort: config.DNS.Port,
|
||||
HTTPPort: config.BindPort,
|
||||
Language: config.Language,
|
||||
IsRunning: isRunning(),
|
||||
Version: version.Version(),
|
||||
Language: config.Language,
|
||||
DNSAddrs: dnsAddrs,
|
||||
DNSPort: config.DNS.Port,
|
||||
HTTPPort: config.BindPort,
|
||||
ProtectionDisabledDuration: protectionDisabledDuration,
|
||||
ProtectionEnabled: protectionEnabled,
|
||||
IsRunning: isRunning(),
|
||||
}
|
||||
}()
|
||||
|
||||
var c *dnsforward.FilteringConfig
|
||||
if Context.dnsServer != nil {
|
||||
c = &dnsforward.FilteringConfig{}
|
||||
Context.dnsServer.WriteDiskConfig(c)
|
||||
resp.IsProtectionEnabled = c.ProtectionEnabled
|
||||
}
|
||||
|
||||
// IsDHCPAvailable field is now false by default for Windows.
|
||||
if runtime.GOOS != "windows" {
|
||||
resp.IsDHCPAvailable = Context.dhcpServer != nil
|
||||
@@ -219,7 +243,7 @@ func modifiesData(m string) (ok bool) {
|
||||
func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
const statusUnsup = http.StatusUnsupportedMediaType
|
||||
|
||||
cType := r.Header.Get(aghhttp.HdrNameContentType)
|
||||
cType := r.Header.Get(httphdr.ContentType)
|
||||
if r.ContentLength == 0 {
|
||||
if cType == "" {
|
||||
return true
|
||||
@@ -308,13 +332,17 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
var serveHTTP3 bool
|
||||
var portHTTPS int
|
||||
var (
|
||||
forceHTTPS bool
|
||||
serveHTTP3 bool
|
||||
portHTTPS int
|
||||
)
|
||||
func() {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
serveHTTP3, portHTTPS = config.DNS.ServeHTTP3, config.TLS.PortHTTPS
|
||||
forceHTTPS = config.TLS.ForceHTTPS && config.TLS.Enabled && config.TLS.PortHTTPS != 0
|
||||
}()
|
||||
|
||||
respHdr := w.Header()
|
||||
@@ -327,13 +355,13 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
// default is 24 hours.
|
||||
if serveHTTP3 {
|
||||
altSvc := fmt.Sprintf(`h3=":%d"`, portHTTPS)
|
||||
respHdr.Set(aghhttp.HdrNameAltSvc, altSvc)
|
||||
respHdr.Set(httphdr.AltSvc, altSvc)
|
||||
}
|
||||
|
||||
if r.TLS == nil && web.forceHTTPS {
|
||||
if r.TLS == nil && forceHTTPS {
|
||||
hostPort := host
|
||||
if port := web.conf.PortHTTPS; port != defaultPortHTTPS {
|
||||
hostPort = netutil.JoinHostPort(host, port)
|
||||
if portHTTPS != defaultPortHTTPS {
|
||||
hostPort = netutil.JoinHostPort(host, portHTTPS)
|
||||
}
|
||||
|
||||
httpsURL := &url.URL{
|
||||
@@ -357,8 +385,8 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
Host: r.Host,
|
||||
}
|
||||
|
||||
respHdr.Set(aghhttp.HdrNameAccessControlAllowOrigin, originURL.String())
|
||||
respHdr.Set(aghhttp.HdrNameVary, aghhttp.HdrNameOrigin)
|
||||
respHdr.Set(httphdr.AccessControlAllowOrigin, originURL.String())
|
||||
respHdr.Set(httphdr.Vary, httphdr.Origin)
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -371,7 +399,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
|
||||
path := r.URL.Path
|
||||
if Context.firstRun && !strings.HasPrefix(path, "/install.") &&
|
||||
!strings.HasPrefix(path, "/assets/") {
|
||||
http.Redirect(w, r, "/install.html", http.StatusFound)
|
||||
http.Redirect(w, r, "install.html", http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ type getAddrsResponse struct {
|
||||
}
|
||||
|
||||
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
|
||||
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
data := getAddrsResponse{
|
||||
Version: version.Version(),
|
||||
|
||||
@@ -167,7 +167,7 @@ func (req *checkConfReq) validateDNS(
|
||||
}
|
||||
|
||||
// handleInstallCheckConfig handles the /check_config endpoint.
|
||||
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := &checkConfReq{}
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
@@ -375,7 +375,7 @@ func shutdownSrv3(srv *http3.Server) {
|
||||
const PasswordMinRunes = 8
|
||||
|
||||
// Apply new configuration, start DNS server, restart Web server
|
||||
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
@@ -503,7 +503,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
return req, restartHTTP, err
|
||||
}
|
||||
|
||||
func (web *Web) registerInstallHandlers() {
|
||||
func (web *webAPI) registerInstallHandlers() {
|
||||
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
|
||||
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
||||
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -52,14 +51,15 @@ func initDNS() (err error) {
|
||||
anonymizer := config.anonymizer()
|
||||
|
||||
statsConf := stats.Config{
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
LimitDays: config.Stats.Interval,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
Enabled: config.Stats.Enabled,
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
Limit: config.Stats.Interval.Duration,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
Enabled: config.Stats.Enabled,
|
||||
ShouldCountClient: Context.clients.shouldCountClient,
|
||||
}
|
||||
|
||||
set, err := nonDupEmptyHostNames(config.Stats.Ignored)
|
||||
set, err := aghnet.NewDomainNameSet(config.Stats.Ignored)
|
||||
if err != nil {
|
||||
return fmt.Errorf("statistics: ignored list: %w", err)
|
||||
}
|
||||
@@ -83,13 +83,16 @@ func initDNS() (err error) {
|
||||
FileEnabled: config.QueryLog.FileEnabled,
|
||||
}
|
||||
|
||||
set, err = nonDupEmptyHostNames(config.QueryLog.Ignored)
|
||||
set, err = aghnet.NewDomainNameSet(config.QueryLog.Ignored)
|
||||
if err != nil {
|
||||
return fmt.Errorf("querylog: ignored list: %w", err)
|
||||
}
|
||||
|
||||
conf.Ignored = set
|
||||
Context.queryLog = querylog.New(conf)
|
||||
Context.queryLog, err = querylog.New(conf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init querylog: %w", err)
|
||||
}
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||
if err != nil {
|
||||
@@ -426,7 +429,8 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
||||
}
|
||||
|
||||
setts.FilteringEnabled = c.FilteringEnabled
|
||||
setts.SafeSearchEnabled = c.SafeSearchEnabled
|
||||
setts.SafeSearchEnabled = c.safeSearchConf.Enabled
|
||||
setts.ClientSafeSearch = c.SafeSearch
|
||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
}
|
||||
@@ -533,26 +537,30 @@ func closeDNSServer() {
|
||||
log.Debug("all dns modules are closed")
|
||||
}
|
||||
|
||||
// nonDupEmptyHostNames returns nil and error, if list has duplicate or empty
|
||||
// host name. Otherwise returns a set, which contains lowercase host names
|
||||
// without dot at the end, and nil error.
|
||||
func nonDupEmptyHostNames(list []string) (set *stringutil.Set, err error) {
|
||||
set = stringutil.NewSet()
|
||||
// safeSearchResolver is a [filtering.Resolver] implementation used for safe
|
||||
// search.
|
||||
type safeSearchResolver struct{}
|
||||
|
||||
for _, v := range list {
|
||||
host := strings.ToLower(strings.TrimSuffix(v, "."))
|
||||
// TODO(a.garipov): Think about ignoring empty (".") names in
|
||||
// the future.
|
||||
if host == "" {
|
||||
return nil, errors.Error("host name is empty")
|
||||
}
|
||||
// type check
|
||||
var _ filtering.Resolver = safeSearchResolver{}
|
||||
|
||||
if set.Has(host) {
|
||||
return nil, fmt.Errorf("duplicate host name %q", host)
|
||||
}
|
||||
|
||||
set.Add(host)
|
||||
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
|
||||
// It returns the slice of net.IP with IPv4 and IPv6 instances.
|
||||
//
|
||||
// TODO(a.garipov): Support network.
|
||||
func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
addrs, err := Context.dnsServer.Resolve(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return set, nil
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("couldn't lookup host: %s", host)
|
||||
}
|
||||
|
||||
for _, a := range addrs {
|
||||
ips = append(ips, a.IP)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
@@ -58,13 +58,12 @@ type homeContext struct {
|
||||
dhcpServer dhcpd.Interface // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters *filtering.DNSFilter // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
web *webAPI // Web (HTTP, HTTPS) module
|
||||
tls *tlsManager // TLS module
|
||||
// etcHosts is an IP-hostname pairs set taken from system configuration
|
||||
// (e.g. /etc/hosts) files.
|
||||
|
||||
// etcHosts contains IP-hostname mappings taken from the OS-specific hosts
|
||||
// configuration files, for example /etc/hosts.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
// hostsWatcher is the watcher to detect changes in the hosts files.
|
||||
hostsWatcher aghos.FSWatcher
|
||||
|
||||
updater *updater.Updater
|
||||
|
||||
@@ -79,7 +78,6 @@ type homeContext struct {
|
||||
pidFileName string // PID file name. Empty if no PID file was created.
|
||||
controlLock sync.Mutex
|
||||
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
transport *http.Transport
|
||||
client *http.Client
|
||||
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
||||
|
||||
@@ -149,18 +147,17 @@ func setupContext(opts options) {
|
||||
setupContextFlags(opts)
|
||||
|
||||
Context.tlsRoots = aghtls.SystemRootCAs()
|
||||
Context.transport = &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
Proxy: getHTTPProxy,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
Context.client = &http.Client{
|
||||
Timeout: time.Minute * 5,
|
||||
Transport: Context.transport,
|
||||
Timeout: time.Minute * 5,
|
||||
Transport: &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
Proxy: getHTTPProxy,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !Context.firstRun {
|
||||
@@ -263,7 +260,7 @@ func configureOS(conf *configuration) (err error) {
|
||||
// setupHostsContainer initializes the structures to keep up-to-date the hosts
|
||||
// provided by the OS.
|
||||
func setupHostsContainer() (err error) {
|
||||
Context.hostsWatcher, err = aghos.NewOSWritesWatcher()
|
||||
hostsWatcher, err := aghos.NewOSWritesWatcher()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initing hosts watcher: %w", err)
|
||||
}
|
||||
@@ -271,18 +268,18 @@ func setupHostsContainer() (err error) {
|
||||
Context.etcHosts, err = aghnet.NewHostsContainer(
|
||||
filtering.SysHostsListID,
|
||||
aghos.RootDirFS(),
|
||||
Context.hostsWatcher,
|
||||
hostsWatcher,
|
||||
aghnet.DefaultHostsPaths()...,
|
||||
)
|
||||
if err != nil {
|
||||
cerr := Context.hostsWatcher.Close()
|
||||
if errors.Is(err, aghnet.ErrNoHostsPaths) && cerr == nil {
|
||||
closeErr := hostsWatcher.Close()
|
||||
if errors.Is(err, aghnet.ErrNoHostsPaths) && closeErr == nil {
|
||||
log.Info("warning: initing hosts container: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.WithDeferred(fmt.Errorf("initing hosts container: %w", err), cerr)
|
||||
return errors.WithDeferred(fmt.Errorf("initing hosts container: %w", err), closeErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -298,6 +295,17 @@ func setupConfig(opts options) (err error) {
|
||||
config.DNS.DnsfilterConf.UserRules = slices.Clone(config.UserRules)
|
||||
config.DNS.DnsfilterConf.HTTPClient = Context.client
|
||||
|
||||
config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault(
|
||||
config.DNS.DnsfilterConf.SafeSearchConf,
|
||||
"default",
|
||||
config.DNS.DnsfilterConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initializing safesearch: %w", err)
|
||||
}
|
||||
|
||||
config.DHCP.WorkDir = Context.workDir
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
config.DHCP.ConfigModified = onConfigModified
|
||||
@@ -328,33 +336,16 @@ func setupConfig(opts options) (err error) {
|
||||
arpdb = aghnet.NewARPDB()
|
||||
}
|
||||
|
||||
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
|
||||
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb, config.DNS.DnsfilterConf)
|
||||
|
||||
if opts.bindPort != 0 {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(opts.bindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
tcpPorts,
|
||||
tcpPort(config.TLS.PortHTTPS),
|
||||
tcpPort(config.TLS.PortDNSOverTLS),
|
||||
tcpPort(config.TLS.PortDNSCrypt),
|
||||
)
|
||||
|
||||
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
|
||||
}
|
||||
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
} else if err = udpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
config.BindPort = opts.bindPort
|
||||
|
||||
err = checkPorts()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// override bind host/port from the console
|
||||
@@ -368,7 +359,35 @@ func setupConfig(opts options) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
|
||||
// checkPorts is a helper for ports validation in config.
|
||||
func checkPorts() (err error) {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(config.BindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
tcpPorts,
|
||||
tcpPort(config.TLS.PortHTTPS),
|
||||
tcpPort(config.TLS.PortDNSOverTLS),
|
||||
tcpPort(config.TLS.PortDNSCrypt),
|
||||
)
|
||||
|
||||
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
|
||||
}
|
||||
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
} else if err = udpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
|
||||
var clientFS fs.FS
|
||||
if opts.localFrontend {
|
||||
log.Info("warning: using local frontend files")
|
||||
@@ -395,7 +414,7 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
|
||||
serveHTTP3: config.DNS.ServeHTTP3,
|
||||
}
|
||||
|
||||
web = newWeb(&webConf)
|
||||
web = newWebAPI(&webConf)
|
||||
if web == nil {
|
||||
return nil, fmt.Errorf("initializing web: %w", err)
|
||||
}
|
||||
@@ -450,26 +469,8 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
fatalOnError(err)
|
||||
|
||||
if config.DebugPProf {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
|
||||
// See profileSupportsDelta in src/net/http/pprof/pprof.go.
|
||||
mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
|
||||
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
|
||||
mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
|
||||
mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
|
||||
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
|
||||
mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
|
||||
|
||||
go func() {
|
||||
log.Info("pprof: listening on localhost:6060")
|
||||
lerr := http.ListenAndServe("localhost:6060", mux)
|
||||
log.Error("Error while running the pprof server: %s", lerr)
|
||||
}()
|
||||
// TODO(a.garipov): Make the address configurable.
|
||||
startPprof("localhost:6060")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -532,7 +533,7 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
}
|
||||
}
|
||||
|
||||
Context.web.Start()
|
||||
Context.web.start()
|
||||
|
||||
// wait indefinitely for other go-routines to complete their job
|
||||
select {}
|
||||
@@ -712,7 +713,7 @@ func cleanup(ctx context.Context) {
|
||||
log.Info("stopping AdGuard Home")
|
||||
|
||||
if Context.web != nil {
|
||||
Context.web.Close(ctx)
|
||||
Context.web.close(ctx)
|
||||
Context.web = nil
|
||||
}
|
||||
if Context.auth != nil {
|
||||
@@ -733,13 +734,6 @@ func cleanup(ctx context.Context) {
|
||||
}
|
||||
|
||||
if Context.etcHosts != nil {
|
||||
// Currently Context.hostsWatcher is only used in Context.etcHosts and
|
||||
// needs closing only in case of the successful initialization of
|
||||
// Context.etcHosts.
|
||||
if err = Context.hostsWatcher.Close(); err != nil {
|
||||
log.Error("closing hosts watcher: %s", err)
|
||||
}
|
||||
|
||||
if err = Context.etcHosts.Close(); err != nil {
|
||||
log.Error("closing hosts container: %s", err)
|
||||
}
|
||||
@@ -857,8 +851,10 @@ func detectFirstRun() bool {
|
||||
// Connect to a remote server resolving hostname using our own DNS server.
|
||||
//
|
||||
// TODO(e.burkov): This messy logic should be decomposed and clarified.
|
||||
//
|
||||
// TODO(a.garipov): Support network.
|
||||
func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
log.Tracef("network:%v addr:%v", network, addr)
|
||||
log.Debug("home: customdial: dialing addr %q for network %s", addr, network)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/uuid"
|
||||
"howett.net/plist"
|
||||
@@ -170,7 +171,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.Header().Set(httphdr.ContentType, "application/xml")
|
||||
|
||||
const (
|
||||
dohContDisp = `attachment; filename=doh.mobileconfig`
|
||||
@@ -182,7 +183,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
||||
contDisp = dotContDisp
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Disposition", contDisp)
|
||||
w.Header().Set(httphdr.ContentDisposition, contDisp)
|
||||
|
||||
_, _ = w.Write(mobileconfig)
|
||||
}
|
||||
|
||||
39
internal/home/pprof.go
Normal file
39
internal/home/pprof.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// startPprof launches the debug and profiling server on addr.
|
||||
func startPprof(addr string) {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
|
||||
// See profileSupportsDelta in src/net/http/pprof/pprof.go.
|
||||
mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
|
||||
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
|
||||
mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
|
||||
mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
|
||||
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
|
||||
mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("pprof server")
|
||||
|
||||
log.Info("pprof: listening on %q", addr)
|
||||
err := http.ListenAndServe(addr, mux)
|
||||
log.Info("pprof server errors: %v", err)
|
||||
}()
|
||||
}
|
||||
@@ -108,7 +108,7 @@ func (m *tlsManager) start() {
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
Context.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// reload updates the configuration and restarts t.
|
||||
@@ -156,7 +156,7 @@ func (m *tlsManager) reload() {
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
Context.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
||||
@@ -454,7 +454,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
// same reason.
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
// currentSchemaVersion is the current schema version.
|
||||
const currentSchemaVersion = 17
|
||||
const currentSchemaVersion = 20
|
||||
|
||||
// These aliases are provided for convenience.
|
||||
type (
|
||||
@@ -90,6 +90,9 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
||||
upgradeSchema14to15,
|
||||
upgradeSchema15to16,
|
||||
upgradeSchema16to17,
|
||||
upgradeSchema17to18,
|
||||
upgradeSchema18to19,
|
||||
upgradeSchema19to20,
|
||||
}
|
||||
|
||||
n := 0
|
||||
@@ -836,9 +839,9 @@ func upgradeSchema14to15(diskConf yobj) (err error) {
|
||||
}
|
||||
|
||||
type temp struct {
|
||||
val any
|
||||
from string
|
||||
to string
|
||||
val any
|
||||
}
|
||||
replaces := []temp{
|
||||
{from: "querylog_enabled", to: "enabled", val: true},
|
||||
@@ -873,6 +876,18 @@ func upgradeSchema14to15(diskConf yobj) (err error) {
|
||||
// 'enabled': true
|
||||
// 'interval': 1
|
||||
// 'ignored': []
|
||||
//
|
||||
// If statistics were disabled:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'dns':
|
||||
// 'statistics_interval': 0
|
||||
//
|
||||
// # AFTER:
|
||||
// 'statistics':
|
||||
// 'enabled': false
|
||||
// 'interval': 1
|
||||
// 'ignored': []
|
||||
func upgradeSchema15to16(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 15 to 16")
|
||||
diskConf["schema_version"] = 16
|
||||
@@ -894,10 +909,23 @@ func upgradeSchema15to16(diskConf yobj) (err error) {
|
||||
}
|
||||
|
||||
const field = "statistics_interval"
|
||||
v, has := dns[field]
|
||||
statsIvlVal, has := dns[field]
|
||||
if has {
|
||||
stats["enabled"] = v != 0
|
||||
stats["interval"] = v
|
||||
var statsIvl int
|
||||
statsIvl, ok = statsIvlVal.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of dns.statistics_interval: %T", statsIvlVal)
|
||||
}
|
||||
|
||||
if statsIvl == 0 {
|
||||
// Set the interval to the default value of one day to make sure
|
||||
// that it passes the validations.
|
||||
stats["interval"] = 1
|
||||
stats["enabled"] = false
|
||||
} else {
|
||||
stats["interval"] = statsIvl
|
||||
stats["enabled"] = true
|
||||
}
|
||||
}
|
||||
delete(dns, field)
|
||||
|
||||
@@ -943,6 +971,172 @@ func upgradeSchema16to17(diskConf yobj) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema17to18 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'dns':
|
||||
// 'safesearch_enabled': true
|
||||
//
|
||||
// # AFTER:
|
||||
// 'dns':
|
||||
// 'safe_search':
|
||||
// 'enabled': true
|
||||
// 'bing': true
|
||||
// 'duckduckgo': true
|
||||
// 'google': true
|
||||
// 'pixabay': true
|
||||
// 'yandex': true
|
||||
// 'youtube': true
|
||||
func upgradeSchema17to18(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 17 to 18")
|
||||
diskConf["schema_version"] = 18
|
||||
|
||||
dnsVal, ok := diskConf["dns"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
dns, ok := dnsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
|
||||
}
|
||||
|
||||
safeSearch := yobj{
|
||||
"enabled": true,
|
||||
"bing": true,
|
||||
"duckduckgo": true,
|
||||
"google": true,
|
||||
"pixabay": true,
|
||||
"yandex": true,
|
||||
"youtube": true,
|
||||
}
|
||||
|
||||
const safeSearchKey = "safesearch_enabled"
|
||||
|
||||
v, has := dns[safeSearchKey]
|
||||
if has {
|
||||
safeSearch["enabled"] = v
|
||||
}
|
||||
delete(dns, safeSearchKey)
|
||||
|
||||
dns["safe_search"] = safeSearch
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema18to19 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'clients':
|
||||
// 'persistent':
|
||||
// - 'name': 'client-name'
|
||||
// 'safesearch_enabled': true
|
||||
//
|
||||
// # AFTER:
|
||||
// 'clients':
|
||||
// 'persistent':
|
||||
// - 'name': 'client-name'
|
||||
// 'safe_search':
|
||||
// 'enabled': true
|
||||
// 'bing': true
|
||||
// 'duckduckgo': true
|
||||
// 'google': true
|
||||
// 'pixabay': true
|
||||
// 'yandex': true
|
||||
// 'youtube': true
|
||||
func upgradeSchema18to19(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 18 to 19")
|
||||
diskConf["schema_version"] = 19
|
||||
|
||||
clientsVal, ok := diskConf["clients"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
clients, ok := clientsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
|
||||
}
|
||||
|
||||
persistent, ok := clients["persistent"].([]yobj)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
const safeSearchKey = "safesearch_enabled"
|
||||
|
||||
for i := range persistent {
|
||||
c := persistent[i]
|
||||
|
||||
safeSearch := yobj{
|
||||
"enabled": true,
|
||||
"bing": true,
|
||||
"duckduckgo": true,
|
||||
"google": true,
|
||||
"pixabay": true,
|
||||
"yandex": true,
|
||||
"youtube": true,
|
||||
}
|
||||
|
||||
v, has := c[safeSearchKey]
|
||||
if has {
|
||||
safeSearch["enabled"] = v
|
||||
}
|
||||
delete(c, safeSearchKey)
|
||||
|
||||
c["safe_search"] = safeSearch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema19to20 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'statistics':
|
||||
// 'interval': 1
|
||||
//
|
||||
// # AFTER:
|
||||
// 'statistics':
|
||||
// 'interval': 24h
|
||||
func upgradeSchema19to20(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 19 to 20")
|
||||
diskConf["schema_version"] = 20
|
||||
|
||||
statsVal, ok := diskConf["statistics"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var stats yobj
|
||||
stats, ok = statsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of stats: %T", statsVal)
|
||||
}
|
||||
|
||||
const field = "interval"
|
||||
|
||||
// Set the initial value from the global configuration structure.
|
||||
statsIvl := 1
|
||||
statsIvlVal, ok := stats[field]
|
||||
if ok {
|
||||
statsIvl, ok = statsIvlVal.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of %s: %T", field, statsIvlVal)
|
||||
}
|
||||
|
||||
// The initial version of upgradeSchema16to17 did not set the zero
|
||||
// interval to a non-zero one. So, reset it now.
|
||||
if statsIvl == 0 {
|
||||
statsIvl = 1
|
||||
}
|
||||
}
|
||||
|
||||
stats[field] = timeutil.Duration{Duration: time.Duration(statsIvl) * timeutil.Day}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Replace with log.Output when we port it to our logging
|
||||
// package.
|
||||
func funcName() string {
|
||||
|
||||
@@ -729,7 +729,7 @@ func TestUpgradeSchema15to16(t *testing.T) {
|
||||
want: yobj{
|
||||
"statistics": map[string]any{
|
||||
"enabled": false,
|
||||
"interval": 0,
|
||||
"interval": 1,
|
||||
"ignored": []any{},
|
||||
},
|
||||
"dns": map[string]any{},
|
||||
@@ -808,3 +808,246 @@ func TestUpgradeSchema16to17(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema17to18(t *testing.T) {
|
||||
const newSchemaVer = 18
|
||||
|
||||
defaultWantObj := yobj{
|
||||
"dns": yobj{
|
||||
"safe_search": yobj{
|
||||
"enabled": true,
|
||||
"bing": true,
|
||||
"duckduckgo": true,
|
||||
"google": true,
|
||||
"pixabay": true,
|
||||
"yandex": true,
|
||||
"youtube": true,
|
||||
},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
in: yobj{"dns": yobj{}},
|
||||
want: defaultWantObj,
|
||||
name: "default_values",
|
||||
}, {
|
||||
in: yobj{"dns": yobj{"safesearch_enabled": true}},
|
||||
want: defaultWantObj,
|
||||
name: "enabled",
|
||||
}, {
|
||||
in: yobj{"dns": yobj{"safesearch_enabled": false}},
|
||||
want: yobj{
|
||||
"dns": yobj{
|
||||
"safe_search": map[string]any{
|
||||
"enabled": false,
|
||||
"bing": true,
|
||||
"duckduckgo": true,
|
||||
"google": true,
|
||||
"pixabay": true,
|
||||
"yandex": true,
|
||||
"youtube": true,
|
||||
},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "disabled",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema17to18(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema18to19(t *testing.T) {
|
||||
const newSchemaVer = 19
|
||||
|
||||
defaultWantObj := yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []yobj{{
|
||||
"name": "localhost",
|
||||
"safe_search": yobj{
|
||||
"enabled": true,
|
||||
"bing": true,
|
||||
"duckduckgo": true,
|
||||
"google": true,
|
||||
"pixabay": true,
|
||||
"yandex": true,
|
||||
"youtube": true,
|
||||
},
|
||||
}},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
in: yobj{
|
||||
"clients": yobj{},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "no_clients",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []yobj{{"name": "localhost"}},
|
||||
},
|
||||
},
|
||||
want: defaultWantObj,
|
||||
name: "default_values",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []yobj{{"name": "localhost", "safesearch_enabled": true}},
|
||||
},
|
||||
},
|
||||
want: defaultWantObj,
|
||||
name: "enabled",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []yobj{{"name": "localhost", "safesearch_enabled": false}},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{"persistent": []yobj{{
|
||||
"name": "localhost",
|
||||
"safe_search": yobj{
|
||||
"enabled": false,
|
||||
"bing": true,
|
||||
"duckduckgo": true,
|
||||
"google": true,
|
||||
"pixabay": true,
|
||||
"yandex": true,
|
||||
"youtube": true,
|
||||
},
|
||||
}}},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "disabled",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema18to19(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema19to20(t *testing.T) {
|
||||
testCases := []struct {
|
||||
ivl any
|
||||
want any
|
||||
wantErr string
|
||||
name string
|
||||
}{{
|
||||
ivl: 1,
|
||||
want: timeutil.Duration{Duration: timeutil.Day},
|
||||
wantErr: "",
|
||||
name: "success",
|
||||
}, {
|
||||
ivl: 0,
|
||||
want: timeutil.Duration{Duration: timeutil.Day},
|
||||
wantErr: "",
|
||||
name: "success",
|
||||
}, {
|
||||
ivl: 0.25,
|
||||
want: 0,
|
||||
wantErr: "unexpected type of interval: float64",
|
||||
name: "fail",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
conf := yobj{
|
||||
"statistics": yobj{
|
||||
"interval": tc.ivl,
|
||||
},
|
||||
"schema_version": 19,
|
||||
}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema19to20(conf)
|
||||
|
||||
if tc.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErr, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conf["schema_version"], 20)
|
||||
|
||||
statsVal, ok := conf["statistics"]
|
||||
require.True(t, ok)
|
||||
|
||||
var stats yobj
|
||||
stats, ok = statsVal.(yobj)
|
||||
require.True(t, ok)
|
||||
|
||||
var newIvl timeutil.Duration
|
||||
newIvl, ok = stats["interval"].(timeutil.Duration)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, tc.want, newIvl)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("no_stats", func(t *testing.T) {
|
||||
err := upgradeSchema19to20(yobj{})
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("bad_stats", func(t *testing.T) {
|
||||
err := upgradeSchema19to20(yobj{
|
||||
"statistics": 0,
|
||||
})
|
||||
|
||||
testutil.AssertErrorMsg(t, "unexpected type of stats: int", err)
|
||||
})
|
||||
|
||||
t.Run("no_field", func(t *testing.T) {
|
||||
conf := yobj{
|
||||
"statistics": yobj{},
|
||||
}
|
||||
|
||||
err := upgradeSchema19to20(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
statsVal, ok := conf["statistics"]
|
||||
require.True(t, ok)
|
||||
|
||||
var stats yobj
|
||||
stats, ok = statsVal.(yobj)
|
||||
require.True(t, ok)
|
||||
|
||||
var ivl any
|
||||
ivl, ok = stats["interval"]
|
||||
require.True(t, ok)
|
||||
|
||||
var ivlVal timeutil.Duration
|
||||
ivlVal, ok = ivl.(timeutil.Duration)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, 24*time.Hour, ivlVal.Duration)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -35,9 +35,8 @@ const (
|
||||
type webConfig struct {
|
||||
clientFS fs.FS
|
||||
|
||||
BindHost netip.Addr
|
||||
BindPort int
|
||||
PortHTTPS int
|
||||
BindHost netip.Addr
|
||||
BindPort int
|
||||
|
||||
// ReadTimeout is an option to pass to http.Server for setting an
|
||||
// appropriate field.
|
||||
@@ -72,8 +71,8 @@ type httpsServer struct {
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// Web is the web UI and API server.
|
||||
type Web struct {
|
||||
// webAPI is the web UI and API server.
|
||||
type webAPI struct {
|
||||
conf *webConfig
|
||||
|
||||
// TODO(a.garipov): Refactor all these servers.
|
||||
@@ -82,15 +81,13 @@ type Web struct {
|
||||
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
|
||||
// [Web.http3Server] must also not be nil.
|
||||
httpsServer httpsServer
|
||||
|
||||
forceHTTPS bool
|
||||
}
|
||||
|
||||
// newWeb creates a new instance of the web UI and API server.
|
||||
func newWeb(conf *webConfig) (w *Web) {
|
||||
// newWebAPI creates a new instance of the web UI and API server.
|
||||
func newWebAPI(conf *webConfig) (w *webAPI) {
|
||||
log.Info("web: initializing")
|
||||
|
||||
w = &Web{
|
||||
w = &webAPI{
|
||||
conf: conf,
|
||||
}
|
||||
|
||||
@@ -125,12 +122,10 @@ func webCheckPortAvailable(port int) (ok bool) {
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
|
||||
}
|
||||
|
||||
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
// if necessary.
|
||||
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
|
||||
func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
|
||||
log.Debug("web: applying new tls configuration")
|
||||
web.conf.PortHTTPS = tlsConf.PortHTTPS
|
||||
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
|
||||
|
||||
enabled := tlsConf.Enabled &&
|
||||
tlsConf.PortHTTPS != 0 &&
|
||||
@@ -161,8 +156,8 @@ func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings)
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
}
|
||||
|
||||
// Start - start serving HTTP requests
|
||||
func (web *Web) Start() {
|
||||
// start - start serving HTTP requests
|
||||
func (web *webAPI) start() {
|
||||
log.Println("AdGuard Home is available at the following addresses:")
|
||||
|
||||
// for https, we have a separate goroutine loop
|
||||
@@ -203,8 +198,8 @@ func (web *Web) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the HTTP servers.
|
||||
func (web *Web) Close(ctx context.Context) {
|
||||
// close gracefully shuts down the HTTP servers.
|
||||
func (web *webAPI) close(ctx context.Context) {
|
||||
log.Info("stopping http server...")
|
||||
|
||||
web.httpsServer.cond.L.Lock()
|
||||
@@ -222,7 +217,7 @@ func (web *Web) Close(ctx context.Context) {
|
||||
log.Info("stopped http server")
|
||||
}
|
||||
|
||||
func (web *Web) tlsServerLoop() {
|
||||
func (web *webAPI) tlsServerLoop() {
|
||||
for {
|
||||
web.httpsServer.cond.L.Lock()
|
||||
if web.httpsServer.inShutdown {
|
||||
@@ -241,7 +236,15 @@ func (web *Web) tlsServerLoop() {
|
||||
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
|
||||
addr := netutil.JoinHostPort(web.conf.BindHost.String(), web.conf.PortHTTPS)
|
||||
var portHTTPS int
|
||||
func() {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
portHTTPS = config.TLS.PortHTTPS
|
||||
}()
|
||||
|
||||
addr := netutil.JoinHostPort(web.conf.BindHost.String(), portHTTPS)
|
||||
web.httpsServer.server = &http.Server{
|
||||
ErrorLog: log.StdLog("web: https", log.DEBUG),
|
||||
Addr: addr,
|
||||
@@ -272,7 +275,7 @@ func (web *Web) tlsServerLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (web *Web) mustStartHTTP3(address string) {
|
||||
func (web *webAPI) mustStartHTTP3(address string) {
|
||||
defer log.OnPanic("web: http3")
|
||||
|
||||
web.httpsServer.server3 = &http3.Server{
|
||||
|
||||
@@ -7,6 +7,7 @@ type Client struct {
|
||||
Name string `json:"name"`
|
||||
DisallowedRule string `json:"disallowed_rule"`
|
||||
Disallowed bool `json:"disallowed"`
|
||||
IgnoreQueryLog bool `json:"-"`
|
||||
}
|
||||
|
||||
// ClientWHOIS is the filtered WHOIS data for the client.
|
||||
|
||||
@@ -166,86 +166,6 @@ var logEntryHandlers = map[string]logEntryHandler{
|
||||
},
|
||||
}
|
||||
|
||||
var resultHandlers = map[string]logEntryHandler{
|
||||
"IsFiltered": func(t json.Token, ent *logEntry) error {
|
||||
v, ok := t.(bool)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ent.Result.IsFiltered = v
|
||||
return nil
|
||||
},
|
||||
"Rule": func(t json.Token, ent *logEntry) error {
|
||||
s, ok := t.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
l := len(ent.Result.Rules)
|
||||
if l == 0 {
|
||||
ent.Result.Rules = []*filtering.ResultRule{{}}
|
||||
l++
|
||||
}
|
||||
|
||||
ent.Result.Rules[l-1].Text = s
|
||||
|
||||
return nil
|
||||
},
|
||||
"FilterID": func(t json.Token, ent *logEntry) error {
|
||||
n, ok := t.(json.Number)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
i, err := n.Int64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l := len(ent.Result.Rules)
|
||||
if l == 0 {
|
||||
ent.Result.Rules = []*filtering.ResultRule{{}}
|
||||
l++
|
||||
}
|
||||
|
||||
ent.Result.Rules[l-1].FilterListID = i
|
||||
|
||||
return nil
|
||||
},
|
||||
"Reason": func(t json.Token, ent *logEntry) error {
|
||||
v, ok := t.(json.Number)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
i, err := v.Int64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ent.Result.Reason = filtering.Reason(i)
|
||||
return nil
|
||||
},
|
||||
"ServiceName": func(t json.Token, ent *logEntry) error {
|
||||
s, ok := t.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ent.Result.ServiceName = s
|
||||
|
||||
return nil
|
||||
},
|
||||
"CanonName": func(t json.Token, ent *logEntry) error {
|
||||
s, ok := t.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ent.Result.CanonName = s
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
|
||||
var vToken json.Token
|
||||
switch key {
|
||||
@@ -582,25 +502,11 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
|
||||
return
|
||||
}
|
||||
|
||||
switch key {
|
||||
case "ReverseHosts":
|
||||
decodeResultReverseHosts(dec, ent)
|
||||
decHandler, ok := resultDecHandlers[key]
|
||||
if ok {
|
||||
decHandler(dec, ent)
|
||||
|
||||
continue
|
||||
case "IPList":
|
||||
decodeResultIPList(dec, ent)
|
||||
|
||||
continue
|
||||
case "Rules":
|
||||
decodeResultRules(dec, ent)
|
||||
|
||||
continue
|
||||
case "DNSRewriteResult":
|
||||
decodeResultDNSRewriteResult(dec, ent)
|
||||
|
||||
continue
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
handler, ok := resultHandlers[key]
|
||||
@@ -621,6 +527,93 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
|
||||
}
|
||||
}
|
||||
|
||||
var resultHandlers = map[string]logEntryHandler{
|
||||
"IsFiltered": func(t json.Token, ent *logEntry) error {
|
||||
v, ok := t.(bool)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ent.Result.IsFiltered = v
|
||||
return nil
|
||||
},
|
||||
"Rule": func(t json.Token, ent *logEntry) error {
|
||||
s, ok := t.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
l := len(ent.Result.Rules)
|
||||
if l == 0 {
|
||||
ent.Result.Rules = []*filtering.ResultRule{{}}
|
||||
l++
|
||||
}
|
||||
|
||||
ent.Result.Rules[l-1].Text = s
|
||||
|
||||
return nil
|
||||
},
|
||||
"FilterID": func(t json.Token, ent *logEntry) error {
|
||||
n, ok := t.(json.Number)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
i, err := n.Int64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l := len(ent.Result.Rules)
|
||||
if l == 0 {
|
||||
ent.Result.Rules = []*filtering.ResultRule{{}}
|
||||
l++
|
||||
}
|
||||
|
||||
ent.Result.Rules[l-1].FilterListID = i
|
||||
|
||||
return nil
|
||||
},
|
||||
"Reason": func(t json.Token, ent *logEntry) error {
|
||||
v, ok := t.(json.Number)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
i, err := v.Int64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ent.Result.Reason = filtering.Reason(i)
|
||||
return nil
|
||||
},
|
||||
"ServiceName": func(t json.Token, ent *logEntry) error {
|
||||
s, ok := t.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ent.Result.ServiceName = s
|
||||
|
||||
return nil
|
||||
},
|
||||
"CanonName": func(t json.Token, ent *logEntry) error {
|
||||
s, ok := t.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ent.Result.CanonName = s
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){
|
||||
"ReverseHosts": decodeResultReverseHosts,
|
||||
"IPList": decodeResultIPList,
|
||||
"Rules": decodeResultRules,
|
||||
"DNSRewriteResult": decodeResultDNSRewriteResult,
|
||||
}
|
||||
|
||||
func decodeLogEntry(ent *logEntry, str string) {
|
||||
dec := json.NewDecoder(strings.NewReader(str))
|
||||
dec.UseNumber()
|
||||
|
||||
70
internal/querylog/entry.go
Normal file
70
internal/querylog/entry.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// logEntry represents a single entry in the file.
|
||||
type logEntry struct {
|
||||
// client is the found client information, if any.
|
||||
client *Client
|
||||
|
||||
Time time.Time `json:"T"`
|
||||
|
||||
QHost string `json:"QH"`
|
||||
QType string `json:"QT"`
|
||||
QClass string `json:"QC"`
|
||||
|
||||
ReqECS string `json:"ECS,omitempty"`
|
||||
|
||||
ClientID string `json:"CID,omitempty"`
|
||||
ClientProto ClientProto `json:"CP"`
|
||||
|
||||
Upstream string `json:",omitempty"`
|
||||
|
||||
Answer []byte `json:",omitempty"`
|
||||
OrigAnswer []byte `json:",omitempty"`
|
||||
|
||||
IP net.IP `json:"IP"`
|
||||
|
||||
Result filtering.Result
|
||||
|
||||
Elapsed time.Duration
|
||||
|
||||
Cached bool `json:",omitempty"`
|
||||
AuthenticatedData bool `json:"AD,omitempty"`
|
||||
}
|
||||
|
||||
// shallowClone returns a shallow clone of e.
|
||||
func (e *logEntry) shallowClone() (clone *logEntry) {
|
||||
cloneVal := *e
|
||||
|
||||
return &cloneVal
|
||||
}
|
||||
|
||||
// addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is
|
||||
// true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any
|
||||
// errors are logged.
|
||||
func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if isOrig {
|
||||
e.Answer, err = resp.Pack()
|
||||
err = errors.Annotate(err, "packing answer: %w")
|
||||
} else {
|
||||
e.OrigAnswer, err = resp.Pack()
|
||||
err = errors.Annotate(err, "packing orig answer: %w")
|
||||
}
|
||||
if err != nil {
|
||||
log.Error("querylog: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -13,9 +13,11 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
@@ -25,8 +27,8 @@ type configJSON struct {
|
||||
// fractional numbers and not mess the API users by changing the units.
|
||||
Interval float64 `json:"interval"`
|
||||
|
||||
// Enabled shows if the querylog is enabled. It is an [aghalg.NullBool]
|
||||
// to be able to tell when it's set without using pointers.
|
||||
// Enabled shows if the querylog is enabled. It is an aghalg.NullBool to
|
||||
// be able to tell when it's set without using pointers.
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
|
||||
// AnonymizeClientIP shows if the clients' IP addresses must be anonymized.
|
||||
@@ -35,44 +37,115 @@ type configJSON struct {
|
||||
AnonymizeClientIP aghalg.NullBool `json:"anonymize_client_ip"`
|
||||
}
|
||||
|
||||
// getConfigResp is the JSON structure for the querylog configuration.
|
||||
type getConfigResp struct {
|
||||
// Ignored is the list of host names, which should not be written to log.
|
||||
Ignored []string `json:"ignored"`
|
||||
|
||||
// Interval is the querylog rotation interval in milliseconds.
|
||||
Interval float64 `json:"interval"`
|
||||
|
||||
// Enabled shows if the querylog is enabled. It is an aghalg.NullBool to
|
||||
// be able to tell when it's set without using pointers.
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
|
||||
// AnonymizeClientIP shows if the clients' IP addresses must be anonymized.
|
||||
// It is an aghalg.NullBool to be able to tell when it's set without using
|
||||
// pointers.
|
||||
//
|
||||
// TODO(a.garipov): Consider using separate setting for statistics.
|
||||
AnonymizeClientIP aghalg.NullBool `json:"anonymize_client_ip"`
|
||||
}
|
||||
|
||||
// Register web handlers
|
||||
func (l *queryLog) initWeb() {
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog", l.handleQueryLog)
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
|
||||
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog/config", l.handleGetQueryLogConfig)
|
||||
l.conf.HTTPRegister(
|
||||
http.MethodPut,
|
||||
"/control/querylog/config/update",
|
||||
l.handlePutQueryLogConfig,
|
||||
)
|
||||
|
||||
// Deprecated handlers.
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
|
||||
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
|
||||
}
|
||||
|
||||
// handleQueryLog is the handler for the GET /control/querylog HTTP API.
|
||||
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
params, err := l.parseSearchParams(r)
|
||||
params, err := parseSearchParams(r)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse params: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
entries, oldest := l.search(params)
|
||||
data := l.entriesToJSON(entries, oldest)
|
||||
var entries []*logEntry
|
||||
var oldest time.Time
|
||||
func() {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
entries, oldest = l.search(params)
|
||||
}()
|
||||
|
||||
resp := entriesToJSON(entries, oldest, l.anonymizer.Load())
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
|
||||
// API.
|
||||
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
|
||||
l.clear()
|
||||
}
|
||||
|
||||
// Get configuration
|
||||
// handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP
|
||||
// API.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
ivl := l.conf.RotationIvl
|
||||
|
||||
if !checkInterval(ivl) {
|
||||
// NOTE: If interval is custom we set it to 90 days for compatibility
|
||||
// with old API.
|
||||
ivl = timeutil.Day * 90
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, configJSON{
|
||||
Enabled: aghalg.BoolToNullBool(l.conf.Enabled),
|
||||
Interval: l.conf.RotationIvl.Hours() / 24,
|
||||
Interval: ivl.Hours() / 24,
|
||||
AnonymizeClientIP: aghalg.BoolToNullBool(l.conf.AnonymizeClientIP),
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetQueryLogConfig is the handler for the GET /control/querylog/config
|
||||
// HTTP API.
|
||||
func (l *queryLog) handleGetQueryLogConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var resp *getConfigResp
|
||||
func() {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
resp = &getConfigResp{
|
||||
Interval: float64(l.conf.RotationIvl.Milliseconds()),
|
||||
Enabled: aghalg.BoolToNullBool(l.conf.Enabled),
|
||||
AnonymizeClientIP: aghalg.BoolToNullBool(l.conf.AnonymizeClientIP),
|
||||
Ignored: l.conf.Ignored.Values(),
|
||||
}
|
||||
}()
|
||||
|
||||
slices.Sort(resp.Ignored)
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// AnonymizeIP masks ip to anonymize the client if the ip is a valid one.
|
||||
func AnonymizeIP(ip net.IP) {
|
||||
// zeroes is a slice of zero bytes from which the IP address tail is copied.
|
||||
@@ -87,7 +160,10 @@ func AnonymizeIP(ip net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleQueryLogConfig handles the POST /control/querylog_config queries.
|
||||
// handleQueryLogConfig is the handler for the POST /control/querylog_config
|
||||
// HTTP API.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) {
|
||||
// Set NaN as initial value to be able to know if it changed later by
|
||||
// comparing it to NaN.
|
||||
@@ -103,6 +179,7 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
ivl := time.Duration(float64(timeutil.Day) * newConf.Interval)
|
||||
|
||||
hasIvl := !math.IsNaN(newConf.Interval)
|
||||
if hasIvl && !checkInterval(ivl) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "unsupported interval")
|
||||
@@ -112,11 +189,9 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
defer l.conf.ConfigModified()
|
||||
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
l.confMu.Lock()
|
||||
defer l.confMu.Unlock()
|
||||
|
||||
// Copy data, modify it, then activate. Other threads (readers) don't need
|
||||
// to use this lock.
|
||||
conf := *l.conf
|
||||
if newConf.Enabled != aghalg.NBNull {
|
||||
conf.Enabled = newConf.Enabled == aghalg.NBTrue
|
||||
@@ -138,6 +213,65 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
|
||||
l.conf = &conf
|
||||
}
|
||||
|
||||
// handlePutQueryLogConfig is the handler for the PUT
|
||||
// /control/querylog/config/update HTTP API.
|
||||
func (l *queryLog) handlePutQueryLogConfig(w http.ResponseWriter, r *http.Request) {
|
||||
newConf := &getConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(newConf)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
set, err := aghnet.NewDomainNameSet(newConf.Ignored)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "ignored: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ivl := time.Duration(newConf.Interval) * time.Millisecond
|
||||
err = validateIvl(ivl)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "unsupported interval: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if newConf.Enabled == aghalg.NBNull {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "enabled is null")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if newConf.AnonymizeClientIP == aghalg.NBNull {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "anonymize_client_ip is null")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer l.conf.ConfigModified()
|
||||
|
||||
l.confMu.Lock()
|
||||
defer l.confMu.Unlock()
|
||||
|
||||
conf := *l.conf
|
||||
|
||||
conf.Ignored = set
|
||||
conf.RotationIvl = ivl
|
||||
conf.Enabled = newConf.Enabled == aghalg.NBTrue
|
||||
|
||||
conf.AnonymizeClientIP = newConf.AnonymizeClientIP == aghalg.NBTrue
|
||||
if conf.AnonymizeClientIP {
|
||||
l.anonymizer.Store(AnonymizeIP)
|
||||
} else {
|
||||
l.anonymizer.Store(nil)
|
||||
}
|
||||
|
||||
l.conf = &conf
|
||||
}
|
||||
|
||||
// "value" -> value, return TRUE
|
||||
func getDoubleQuotesEnclosedValue(s *string) bool {
|
||||
t := *s
|
||||
@@ -149,7 +283,7 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
|
||||
}
|
||||
|
||||
// parseSearchCriterion parses a search criterion from the query parameter.
|
||||
func (l *queryLog) parseSearchCriterion(q url.Values, name string, ct criterionType) (
|
||||
func parseSearchCriterion(q url.Values, name string, ct criterionType) (
|
||||
ok bool,
|
||||
sc searchCriterion,
|
||||
err error,
|
||||
@@ -198,8 +332,9 @@ func (l *queryLog) parseSearchCriterion(q url.Values, name string, ct criterionT
|
||||
return true, sc, nil
|
||||
}
|
||||
|
||||
// parseSearchParams - parses "searchParams" from the HTTP request's query string
|
||||
func (l *queryLog) parseSearchParams(r *http.Request) (p *searchParams, err error) {
|
||||
// parseSearchParams parses search parameters from the HTTP request's query
|
||||
// string.
|
||||
func parseSearchParams(r *http.Request) (p *searchParams, err error) {
|
||||
p = newSearchParams()
|
||||
|
||||
q := r.URL.Query()
|
||||
@@ -237,7 +372,7 @@ func (l *queryLog) parseSearchParams(r *http.Request) (p *searchParams, err erro
|
||||
}} {
|
||||
var ok bool
|
||||
var c searchCriterion
|
||||
ok, c, err = l.parseSearchCriterion(q, v.urlField, v.ct)
|
||||
ok, c, err = parseSearchCriterion(q, v.urlField, v.ct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -19,12 +19,16 @@ import (
|
||||
type jobject = map[string]any
|
||||
|
||||
// entriesToJSON converts query log entries to JSON.
|
||||
func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) (res jobject) {
|
||||
func entriesToJSON(
|
||||
entries []*logEntry,
|
||||
oldest time.Time,
|
||||
anonFunc aghnet.IPMutFunc,
|
||||
) (res jobject) {
|
||||
data := make([]jobject, 0, len(entries))
|
||||
|
||||
// The elements order is already reversed to be from newer to older.
|
||||
for _, entry := range entries {
|
||||
jsonEntry := l.entryToJSON(entry, l.anonymizer.Load())
|
||||
jsonEntry := entryToJSON(entry, anonFunc)
|
||||
data = append(data, jsonEntry)
|
||||
}
|
||||
|
||||
@@ -40,7 +44,7 @@ func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) (res job
|
||||
}
|
||||
|
||||
// entryToJSON converts a log entry's data into an entry for the JSON API.
|
||||
func (l *queryLog) entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) {
|
||||
func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) {
|
||||
hostname := entry.QHost
|
||||
question := jobject{
|
||||
"type": entry.QType,
|
||||
@@ -92,14 +96,14 @@ func (l *queryLog) entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (json
|
||||
jsonEntry["service_name"] = entry.Result.ServiceName
|
||||
}
|
||||
|
||||
l.setMsgData(entry, jsonEntry)
|
||||
l.setOrigAns(entry, jsonEntry)
|
||||
setMsgData(entry, jsonEntry)
|
||||
setOrigAns(entry, jsonEntry)
|
||||
|
||||
return jsonEntry
|
||||
}
|
||||
|
||||
// setMsgData sets the message data in jsonEntry.
|
||||
func (l *queryLog) setMsgData(entry *logEntry, jsonEntry jobject) {
|
||||
func setMsgData(entry *logEntry, jsonEntry jobject) {
|
||||
if len(entry.Answer) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -122,7 +126,7 @@ func (l *queryLog) setMsgData(entry *logEntry, jsonEntry jobject) {
|
||||
}
|
||||
|
||||
// setOrigAns sets the original answer data in jsonEntry.
|
||||
func (l *queryLog) setOrigAns(entry *logEntry, jsonEntry jobject) {
|
||||
func setOrigAns(entry *logEntry, jsonEntry jobject) {
|
||||
if len(entry.OrigAnswer) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package querylog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -25,9 +24,12 @@ const (
|
||||
type queryLog struct {
|
||||
findClient func(ids []string) (c *Client, err error)
|
||||
|
||||
conf *Config
|
||||
lock sync.Mutex
|
||||
logFile string // path to the log file
|
||||
// confMu protects conf.
|
||||
confMu *sync.RWMutex
|
||||
conf *Config
|
||||
|
||||
// logFile is the path to the log file.
|
||||
logFile string
|
||||
|
||||
// bufferLock protects buffer.
|
||||
bufferLock sync.RWMutex
|
||||
@@ -71,52 +73,24 @@ func NewClientProto(s string) (cp ClientProto, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// logEntry - represents a single log entry
|
||||
type logEntry struct {
|
||||
// client is the found client information, if any.
|
||||
client *Client
|
||||
|
||||
Time time.Time `json:"T"`
|
||||
|
||||
QHost string `json:"QH"`
|
||||
QType string `json:"QT"`
|
||||
QClass string `json:"QC"`
|
||||
|
||||
ReqECS string `json:"ECS,omitempty"`
|
||||
|
||||
ClientID string `json:"CID,omitempty"`
|
||||
ClientProto ClientProto `json:"CP"`
|
||||
|
||||
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
|
||||
OrigAnswer []byte `json:",omitempty"`
|
||||
|
||||
Result filtering.Result
|
||||
Upstream string `json:",omitempty"`
|
||||
|
||||
IP net.IP `json:"IP"`
|
||||
|
||||
Elapsed time.Duration
|
||||
|
||||
Cached bool `json:",omitempty"`
|
||||
AuthenticatedData bool `json:"AD,omitempty"`
|
||||
}
|
||||
|
||||
// shallowClone returns a shallow clone of e.
|
||||
func (e *logEntry) shallowClone() (clone *logEntry) {
|
||||
cloneVal := *e
|
||||
|
||||
return &cloneVal
|
||||
}
|
||||
|
||||
func (l *queryLog) Start() {
|
||||
if l.conf.HTTPRegister != nil {
|
||||
l.initWeb()
|
||||
}
|
||||
|
||||
go l.periodicRotate()
|
||||
}
|
||||
|
||||
func (l *queryLog) Close() {
|
||||
_ = l.flushLogBuffer(true)
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
if l.conf.FileEnabled {
|
||||
err := l.flushLogBuffer()
|
||||
if err != nil {
|
||||
log.Error("querylog: closing: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkInterval(ivl time.Duration) (ok bool) {
|
||||
@@ -132,8 +106,26 @@ func checkInterval(ivl time.Duration) (ok bool) {
|
||||
return ivl == quarterDay || ivl == day || ivl == week || ivl == month || ivl == threeMonths
|
||||
}
|
||||
|
||||
// validateIvl returns an error if ivl is less than an hour or more than a
|
||||
// year.
|
||||
func validateIvl(ivl time.Duration) (err error) {
|
||||
if ivl < time.Hour {
|
||||
return errors.Error("less than an hour")
|
||||
}
|
||||
|
||||
if ivl > timeutil.Day*365 {
|
||||
return errors.Error("more than a year")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *queryLog) WriteDiskConfig(c *Config) {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
*c = *l.conf
|
||||
c.Ignored = l.conf.Ignored.Clone()
|
||||
}
|
||||
|
||||
// Clear memory buffer and remove log files
|
||||
@@ -141,10 +133,13 @@ func (l *queryLog) clear() {
|
||||
l.fileFlushLock.Lock()
|
||||
defer l.fileFlushLock.Unlock()
|
||||
|
||||
l.bufferLock.Lock()
|
||||
l.buffer = nil
|
||||
l.flushPending = false
|
||||
l.bufferLock.Unlock()
|
||||
func() {
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
|
||||
l.buffer = nil
|
||||
l.flushPending = false
|
||||
}()
|
||||
|
||||
oldLogFile := l.logFile + ".1"
|
||||
err := os.Remove(oldLogFile)
|
||||
@@ -161,7 +156,17 @@ func (l *queryLog) clear() {
|
||||
}
|
||||
|
||||
func (l *queryLog) Add(params *AddParams) {
|
||||
if !l.conf.Enabled {
|
||||
var isEnabled, fileIsEnabled bool
|
||||
var memSize uint32
|
||||
func() {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
isEnabled, fileIsEnabled = l.conf.Enabled, l.conf.FileEnabled
|
||||
memSize = l.conf.MemSize
|
||||
}()
|
||||
|
||||
if !isEnabled {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -178,7 +183,7 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
|
||||
now := time.Now()
|
||||
q := params.Question.Question[0]
|
||||
entry := logEntry{
|
||||
entry := &logEntry{
|
||||
Time: now,
|
||||
|
||||
QHost: strings.ToLower(q.Name[:len(q.Name)-1]),
|
||||
@@ -203,65 +208,63 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
entry.ReqECS = params.ReqECS.String()
|
||||
}
|
||||
|
||||
if params.Answer != nil {
|
||||
var a []byte
|
||||
a, err = params.Answer.Pack()
|
||||
if err != nil {
|
||||
log.Error("querylog: Answer.Pack(): %s", err)
|
||||
entry.addResponse(params.Answer, false)
|
||||
entry.addResponse(params.OrigAnswer, true)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
entry.Answer = a
|
||||
}
|
||||
|
||||
if params.OrigAnswer != nil {
|
||||
var a []byte
|
||||
a, err = params.OrigAnswer.Pack()
|
||||
if err != nil {
|
||||
log.Error("querylog: OrigAnswer.Pack(): %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
entry.OrigAnswer = a
|
||||
}
|
||||
|
||||
l.bufferLock.Lock()
|
||||
l.buffer = append(l.buffer, &entry)
|
||||
needFlush := false
|
||||
func() {
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
|
||||
if !l.conf.FileEnabled {
|
||||
if len(l.buffer) > int(l.conf.MemSize) {
|
||||
// writing to file is disabled - just remove the oldest entry from array
|
||||
//
|
||||
// TODO(a.garipov): This should be replaced by a proper ring buffer,
|
||||
// but it's currently difficult to do that.
|
||||
l.buffer[0] = nil
|
||||
l.buffer = l.buffer[1:]
|
||||
}
|
||||
} else if !l.flushPending {
|
||||
needFlush = len(l.buffer) >= int(l.conf.MemSize)
|
||||
if needFlush {
|
||||
l.flushPending = true
|
||||
}
|
||||
}
|
||||
l.bufferLock.Unlock()
|
||||
l.buffer = append(l.buffer, entry)
|
||||
|
||||
if !fileIsEnabled {
|
||||
if len(l.buffer) > int(memSize) {
|
||||
// Writing to file is disabled, so just remove the oldest entry
|
||||
// from the slices.
|
||||
//
|
||||
// TODO(a.garipov): This should be replaced by a proper ring
|
||||
// buffer, but it's currently difficult to do that.
|
||||
l.buffer[0] = nil
|
||||
l.buffer = l.buffer[1:]
|
||||
}
|
||||
} else if !l.flushPending {
|
||||
needFlush = len(l.buffer) >= int(memSize)
|
||||
if needFlush {
|
||||
l.flushPending = true
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// if buffer needs to be flushed to disk, do it now
|
||||
if needFlush {
|
||||
go func() {
|
||||
_ = l.flushLogBuffer(false)
|
||||
flushErr := l.flushLogBuffer()
|
||||
if flushErr != nil {
|
||||
log.Error("querylog: flushing after adding: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldLog returns true if request for the host should be logged.
|
||||
func (l *queryLog) ShouldLog(host string, _, _ uint16) bool {
|
||||
func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
c, err := l.findClient(ids)
|
||||
if err != nil {
|
||||
log.Error("querylog: finding client: %s", err)
|
||||
}
|
||||
|
||||
if c != nil && c.IgnoreQueryLog {
|
||||
return false
|
||||
}
|
||||
|
||||
return !l.isIgnored(host)
|
||||
}
|
||||
|
||||
// isIgnored returns true if the host is in the Ignored list.
|
||||
// isIgnored returns true if the host is in the ignored domains list. It
|
||||
// assumes that l.confMu is locked for reading.
|
||||
func (l *queryLog) isIgnored(host string) bool {
|
||||
return l.conf.Ignored.Has(host)
|
||||
}
|
||||
|
||||
@@ -22,24 +22,25 @@ func TestMain(m *testing.M) {
|
||||
// TestQueryLog tests adding and loading (with filtering) entries from disk and
|
||||
// memory.
|
||||
func TestQueryLog(t *testing.T) {
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
BaseDir: t.TempDir(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add disk entries.
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
// Write to disk (first file).
|
||||
require.NoError(t, l.flushLogBuffer(true))
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
// Start writing to the second file.
|
||||
require.NoError(t, l.rotate())
|
||||
// Add disk entries.
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
|
||||
// Write to disk.
|
||||
require.NoError(t, l.flushLogBuffer(true))
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
// Add memory entries.
|
||||
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
|
||||
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
|
||||
@@ -125,12 +126,13 @@ func TestQueryLog(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
BaseDir: t.TempDir(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
entNum = 10
|
||||
@@ -142,7 +144,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
// Write them to the first file.
|
||||
require.NoError(t, l.flushLogBuffer(true))
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
// Add more to the in-memory part of log.
|
||||
for i := 0; i < entNum; i++ {
|
||||
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
@@ -199,13 +201,14 @@ func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
BaseDir: t.TempDir(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const entNum = 10
|
||||
// Add entries to the log.
|
||||
@@ -213,7 +216,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
// Write them to disk.
|
||||
require.NoError(t, l.flushLogBuffer(true))
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
|
||||
params := newSearchParams()
|
||||
|
||||
@@ -227,13 +230,14 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryLogFileDisabled(t *testing.T) {
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: false,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 2,
|
||||
BaseDir: t.TempDir(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
@@ -254,35 +258,52 @@ func TestQueryLogShouldLog(t *testing.T) {
|
||||
)
|
||||
set := stringutil.NewSet(ignored1, ignored2)
|
||||
|
||||
l := newQueryLog(Config{
|
||||
findClient := func(ids []string) (c *Client, err error) {
|
||||
log := ids[0] == "no_log"
|
||||
|
||||
return &Client{IgnoreQueryLog: log}, nil
|
||||
}
|
||||
|
||||
l, err := newQueryLog(Config{
|
||||
Ignored: set,
|
||||
Enabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
BaseDir: t.TempDir(),
|
||||
Ignored: set,
|
||||
FindClient: findClient,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
ids []string
|
||||
wantLog bool
|
||||
}{{
|
||||
name: "log",
|
||||
host: "example.com",
|
||||
ids: []string{"whatever"},
|
||||
wantLog: true,
|
||||
}, {
|
||||
name: "no_log_ignored_1",
|
||||
host: ignored1,
|
||||
ids: []string{"whatever"},
|
||||
wantLog: false,
|
||||
}, {
|
||||
name: "no_log_ignored_2",
|
||||
host: ignored2,
|
||||
ids: []string{"whatever"},
|
||||
wantLog: false,
|
||||
}, {
|
||||
name: "no_log_client_ignore",
|
||||
host: "example.com",
|
||||
ids: []string{"no_log"},
|
||||
wantLog: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := l.ShouldLog(tc.host, dns.TypeA, dns.ClassINET)
|
||||
res := l.ShouldLog(tc.host, dns.TypeA, dns.ClassINET, tc.ids)
|
||||
|
||||
assert.Equal(t, tc.wantLog, res)
|
||||
})
|
||||
|
||||
@@ -106,6 +106,7 @@ func (r *QLogReader) SeekStart() error {
|
||||
|
||||
r.currentFile = len(r.qFiles) - 1
|
||||
_, err := r.qFiles[r.currentFile].SeekStart()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -29,16 +29,21 @@ type QueryLog interface {
|
||||
WriteDiskConfig(c *Config)
|
||||
|
||||
// ShouldLog returns true if request for the host should be logged.
|
||||
ShouldLog(host string, qType, qClass uint16) bool
|
||||
ShouldLog(host string, qType, qClass uint16, ids []string) bool
|
||||
}
|
||||
|
||||
// Config is the query log configuration structure.
|
||||
//
|
||||
// Do not alter any fields of this structure after using it.
|
||||
type Config struct {
|
||||
// Ignored is the list of host names, which should not be written to log.
|
||||
Ignored *stringutil.Set
|
||||
|
||||
// Anonymizer processes the IP addresses to anonymize those if needed.
|
||||
Anonymizer *aghnet.IPMut
|
||||
|
||||
// ConfigModified is called when the configuration is changed, for
|
||||
// example by HTTP requests.
|
||||
// ConfigModified is called when the configuration is changed, for example
|
||||
// by HTTP requests.
|
||||
ConfigModified func()
|
||||
|
||||
// HTTPRegister registers an HTTP handler.
|
||||
@@ -50,20 +55,13 @@ type Config struct {
|
||||
// BaseDir is the base directory for log files.
|
||||
BaseDir string
|
||||
|
||||
// RotationIvl is the interval for log rotation. After that period, the
|
||||
// old log file will be renamed, NOT deleted, so the actual log
|
||||
// retention time is twice the interval. The value must be one of:
|
||||
//
|
||||
// 6 * time.Hour
|
||||
// 1 * timeutil.Day
|
||||
// 7 * timeutil.Day
|
||||
// 30 * timeutil.Day
|
||||
// 90 * timeutil.Day
|
||||
//
|
||||
// RotationIvl is the interval for log rotation. After that period, the old
|
||||
// log file will be renamed, NOT deleted, so the actual log retention time
|
||||
// is twice the interval.
|
||||
RotationIvl time.Duration
|
||||
|
||||
// MemSize is the number of entries kept in a memory buffer before they
|
||||
// are flushed to disk.
|
||||
// MemSize is the number of entries kept in a memory buffer before they are
|
||||
// flushed to disk.
|
||||
MemSize uint32
|
||||
|
||||
// Enabled tells if the query log is enabled.
|
||||
@@ -75,10 +73,6 @@ type Config struct {
|
||||
// AnonymizeClientIP tells if the query log should anonymize clients' IP
|
||||
// addresses.
|
||||
AnonymizeClientIP bool
|
||||
|
||||
// Ignored is the list of host names, which should not be written to
|
||||
// log.
|
||||
Ignored *stringutil.Set
|
||||
}
|
||||
|
||||
// AddParams is the parameters for adding an entry.
|
||||
@@ -135,12 +129,12 @@ func (p *AddParams) validate() (err error) {
|
||||
}
|
||||
|
||||
// New creates a new instance of the query log.
|
||||
func New(conf Config) (ql QueryLog) {
|
||||
func New(conf Config) (ql QueryLog, err error) {
|
||||
return newQueryLog(conf)
|
||||
}
|
||||
|
||||
// newQueryLog crates a new queryLog.
|
||||
func newQueryLog(conf Config) (l *queryLog) {
|
||||
func newQueryLog(conf Config) (l *queryLog, err error) {
|
||||
findClient := conf.FindClient
|
||||
if findClient == nil {
|
||||
findClient = func(_ []string) (_ *Client, _ error) {
|
||||
@@ -151,20 +145,19 @@ func newQueryLog(conf Config) (l *queryLog) {
|
||||
l = &queryLog{
|
||||
findClient: findClient,
|
||||
|
||||
logFile: filepath.Join(conf.BaseDir, queryLogFileName),
|
||||
conf: &Config{},
|
||||
confMu: &sync.RWMutex{},
|
||||
logFile: filepath.Join(conf.BaseDir, queryLogFileName),
|
||||
|
||||
anonymizer: conf.Anonymizer,
|
||||
}
|
||||
|
||||
l.conf = &Config{}
|
||||
*l.conf = conf
|
||||
|
||||
if !checkInterval(conf.RotationIvl) {
|
||||
log.Info(
|
||||
"querylog: warning: unsupported rotation interval %s, setting to 1 day",
|
||||
conf.RotationIvl,
|
||||
)
|
||||
l.conf.RotationIvl = timeutil.Day
|
||||
err = validateIvl(conf.RotationIvl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unsupported interval: %w", err)
|
||||
}
|
||||
|
||||
return l
|
||||
return l, nil
|
||||
}
|
||||
|
||||
@@ -11,40 +11,35 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// flushLogBuffer flushes the current buffer to file and resets the current buffer
|
||||
func (l *queryLog) flushLogBuffer(fullFlush bool) error {
|
||||
if !l.conf.FileEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushLogBuffer flushes the current buffer to file and resets the current
|
||||
// buffer.
|
||||
func (l *queryLog) flushLogBuffer() (err error) {
|
||||
l.fileFlushLock.Lock()
|
||||
defer l.fileFlushLock.Unlock()
|
||||
|
||||
// flush remainder to file
|
||||
l.bufferLock.Lock()
|
||||
needFlush := len(l.buffer) >= int(l.conf.MemSize)
|
||||
if !needFlush && !fullFlush {
|
||||
l.bufferLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
flushBuffer := l.buffer
|
||||
l.buffer = nil
|
||||
l.flushPending = false
|
||||
l.bufferLock.Unlock()
|
||||
err := l.flushToFile(flushBuffer)
|
||||
if err != nil {
|
||||
log.Error("Saving querylog to file failed: %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
var flushBuffer []*logEntry
|
||||
func() {
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
|
||||
flushBuffer = l.buffer
|
||||
l.buffer = nil
|
||||
l.flushPending = false
|
||||
}()
|
||||
|
||||
err = l.flushToFile(flushBuffer)
|
||||
|
||||
return errors.Annotate(err, "writing to file: %w")
|
||||
}
|
||||
|
||||
// flushToFile saves the specified log entries to the query log file
|
||||
func (l *queryLog) flushToFile(buffer []*logEntry) (err error) {
|
||||
if len(buffer) == 0 {
|
||||
log.Debug("querylog: there's nothing to write to a file")
|
||||
log.Debug("querylog: nothing to write to a file")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var b bytes.Buffer
|
||||
@@ -155,8 +150,13 @@ func (l *queryLog) periodicRotate() {
|
||||
// checkAndRotate rotates log files if those are older than the specified
|
||||
// rotation interval.
|
||||
func (l *queryLog) checkAndRotate() {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
var rotationIvl time.Duration
|
||||
func() {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
rotationIvl = l.conf.RotationIvl
|
||||
}()
|
||||
|
||||
oldest, err := l.readFileFirstTimeValue()
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
@@ -165,11 +165,11 @@ func (l *queryLog) checkAndRotate() {
|
||||
return
|
||||
}
|
||||
|
||||
if rot, now := oldest.Add(l.conf.RotationIvl), time.Now(); rot.After(now) {
|
||||
if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) {
|
||||
log.Debug(
|
||||
"querylog: %s <= %s, not rotating",
|
||||
now.Format(time.RFC3339),
|
||||
rot.Format(time.RFC3339),
|
||||
rotTime.Format(time.RFC3339),
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
@@ -76,15 +76,20 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
|
||||
// search - searches log entries in the query log using specified parameters
|
||||
// returns the list of entries found + time of the oldest entry
|
||||
func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) {
|
||||
now := time.Now()
|
||||
start := time.Now()
|
||||
|
||||
if params.limit == 0 {
|
||||
return []*logEntry{}, time.Time{}
|
||||
}
|
||||
|
||||
cache := clientCache{}
|
||||
fileEntries, oldest, total := l.searchFiles(params, cache)
|
||||
|
||||
memoryEntries, bufLen := l.searchMemory(params, cache)
|
||||
log.Debug("querylog: got %d entries from memory", len(memoryEntries))
|
||||
|
||||
fileEntries, oldest, total := l.searchFiles(params, cache)
|
||||
log.Debug("querylog: got %d entries from files", len(fileEntries))
|
||||
|
||||
total += bufLen
|
||||
|
||||
totalLimit := params.offset + params.limit
|
||||
@@ -123,7 +128,7 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
|
||||
len(entries),
|
||||
total,
|
||||
params.olderThan,
|
||||
time.Since(now),
|
||||
time.Since(start),
|
||||
)
|
||||
|
||||
return entries, oldest
|
||||
@@ -145,13 +150,14 @@ func (l *queryLog) searchFiles(
|
||||
|
||||
r, err := NewQLogReader(files)
|
||||
if err != nil {
|
||||
log.Error("querylog: failed to open qlog reader: %s", err)
|
||||
log.Error("querylog: opening qlog reader: %s", err)
|
||||
|
||||
return entries, oldest, 0
|
||||
}
|
||||
|
||||
defer func() {
|
||||
derr := r.Close()
|
||||
if derr != nil {
|
||||
closeErr := r.Close()
|
||||
if closeErr != nil {
|
||||
log.Error("querylog: closing file: %s", err)
|
||||
}
|
||||
}()
|
||||
@@ -161,8 +167,8 @@ func (l *queryLog) searchFiles(
|
||||
} else {
|
||||
err = r.seekTS(params.olderThan.UnixNano())
|
||||
if err == nil {
|
||||
// Read to the next record, because we only need the one
|
||||
// that goes after it.
|
||||
// Read to the next record, because we only need the one that goes
|
||||
// after it.
|
||||
_, err = r.ReadNext()
|
||||
}
|
||||
}
|
||||
@@ -176,9 +182,9 @@ func (l *queryLog) searchFiles(
|
||||
totalLimit := params.offset + params.limit
|
||||
oldestNano := int64(0)
|
||||
|
||||
// By default, we do not scan more than maxFileScanEntries at once.
|
||||
// The idea is to make search calls faster so that the UI could handle
|
||||
// it and show something quicker. This behavior can be overridden if
|
||||
// By default, we do not scan more than maxFileScanEntries at once. The
|
||||
// idea is to make search calls faster so that the UI could handle it and
|
||||
// show something quicker. This behavior can be overridden if
|
||||
// maxFileScanEntries is set to 0.
|
||||
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
|
||||
var e *logEntry
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
FindClient: findClient,
|
||||
BaseDir: t.TempDir(),
|
||||
RotationIvl: timeutil.Day,
|
||||
@@ -44,6 +44,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
|
||||
FileEnabled: true,
|
||||
AnonymizeClientIP: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(l.Close)
|
||||
|
||||
q := &dns.Msg{
|
||||
|
||||
@@ -7,8 +7,12 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// topAddrs is an alias for the types of the TopFoo fields of statsResponse.
|
||||
@@ -38,13 +42,21 @@ type StatsResp struct {
|
||||
AvgProcessingTime float64 `json:"avg_processing_time"`
|
||||
}
|
||||
|
||||
// handleStats handles requests to the GET /control/stats endpoint.
|
||||
// handleStats is the handler for the GET /control/stats HTTP API.
|
||||
func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
resp, ok := s.getData(s.limitHours)
|
||||
|
||||
var (
|
||||
resp StatsResp
|
||||
ok bool
|
||||
)
|
||||
func() {
|
||||
s.confMu.RLock()
|
||||
defer s.confMu.RUnlock()
|
||||
|
||||
resp, ok = s.getData(uint32(s.limit.Hours()))
|
||||
}()
|
||||
|
||||
log.Debug("stats: prepared data in %v", time.Since(start))
|
||||
|
||||
if !ok {
|
||||
@@ -63,20 +75,73 @@ type configResp struct {
|
||||
IntervalDays uint32 `json:"interval"`
|
||||
}
|
||||
|
||||
// handleStatsInfo handles requests to the GET /control/stats_info endpoint.
|
||||
func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
// getConfigResp is the response to the GET /control/stats_info.
|
||||
type getConfigResp struct {
|
||||
// Ignored is the list of host names, which should not be counted.
|
||||
Ignored []string `json:"ignored"`
|
||||
|
||||
resp := configResp{IntervalDays: s.limitHours / 24}
|
||||
if !s.enabled {
|
||||
// Interval is the statistics rotation interval in milliseconds.
|
||||
Interval float64 `json:"interval"`
|
||||
|
||||
// Enabled shows if statistics are enabled. It is an aghalg.NullBool to be
|
||||
// able to tell when it's set without using pointers.
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
}
|
||||
|
||||
// handleStatsInfo is the handler for the GET /control/stats_info HTTP API.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
enabled bool
|
||||
limit time.Duration
|
||||
)
|
||||
func() {
|
||||
s.confMu.RLock()
|
||||
defer s.confMu.RUnlock()
|
||||
|
||||
enabled, limit = s.enabled, s.limit
|
||||
}()
|
||||
|
||||
days := uint32(limit / timeutil.Day)
|
||||
ok := checkInterval(days)
|
||||
if !ok || (enabled && days == 0) {
|
||||
// NOTE: If interval is custom we set it to 90 days for compatibility
|
||||
// with old API.
|
||||
days = 90
|
||||
}
|
||||
|
||||
resp := configResp{IntervalDays: days}
|
||||
if !enabled {
|
||||
resp.IntervalDays = 0
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handleStatsConfig handles requests to the POST /control/stats_config
|
||||
// endpoint.
|
||||
// handleGetStatsConfig is the handler for the GET /control/stats/config HTTP
|
||||
// API.
|
||||
func (s *StatsCtx) handleGetStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var resp *getConfigResp
|
||||
func() {
|
||||
s.confMu.RLock()
|
||||
defer s.confMu.RUnlock()
|
||||
|
||||
resp = &getConfigResp{
|
||||
Ignored: s.ignored.Values(),
|
||||
Interval: float64(s.limit.Milliseconds()),
|
||||
Enabled: aghalg.BoolToNullBool(s.enabled),
|
||||
}
|
||||
}()
|
||||
|
||||
slices.Sort(resp.Ignored)
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handleStatsConfig is the handler for the POST /control/stats_config HTTP API.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := configResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
@@ -92,11 +157,59 @@ func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
s.setLimit(int(reqData.IntervalDays))
|
||||
s.configModified()
|
||||
limit := time.Duration(reqData.IntervalDays) * timeutil.Day
|
||||
|
||||
defer s.configModified()
|
||||
|
||||
s.confMu.Lock()
|
||||
defer s.confMu.Unlock()
|
||||
|
||||
s.setLimit(limit)
|
||||
}
|
||||
|
||||
// handleStatsReset handles requests to the POST /control/stats_reset endpoint.
|
||||
// handlePutStatsConfig is the handler for the PUT /control/stats/config/update
|
||||
// HTTP API.
|
||||
func (s *StatsCtx) handlePutStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := getConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
set, err := aghnet.NewDomainNameSet(reqData.Ignored)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "ignored: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ivl := time.Duration(reqData.Interval) * time.Millisecond
|
||||
err = validateIvl(ivl)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "unsupported interval: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if reqData.Enabled == aghalg.NBNull {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "enabled is null")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer s.configModified()
|
||||
|
||||
s.confMu.Lock()
|
||||
defer s.confMu.Unlock()
|
||||
|
||||
s.ignored = set
|
||||
s.limit = ivl
|
||||
s.enabled = reqData.Enabled == aghalg.NBTrue
|
||||
}
|
||||
|
||||
// handleStatsReset is the handler for the POST /control/stats_reset HTTP API.
|
||||
func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.clear()
|
||||
if err != nil {
|
||||
@@ -112,6 +225,10 @@ func (s *StatsCtx) initWeb() {
|
||||
|
||||
s.httpRegister(http.MethodGet, "/control/stats", s.handleStats)
|
||||
s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
|
||||
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
|
||||
s.httpRegister(http.MethodGet, "/control/stats/config", s.handleGetStatsConfig)
|
||||
s.httpRegister(http.MethodPut, "/control/stats/config/update", s.handlePutStatsConfig)
|
||||
|
||||
// Deprecated handlers.
|
||||
s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
|
||||
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
|
||||
}
|
||||
|
||||
153
internal/stats/http_test.go
Normal file
153
internal/stats/http_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandleStatsConfig(t *testing.T) {
|
||||
const (
|
||||
smallIvl = 1 * time.Minute
|
||||
minIvl = 1 * time.Hour
|
||||
maxIvl = 365 * timeutil.Day
|
||||
)
|
||||
|
||||
conf := Config{
|
||||
UnitID: func() (id uint32) { return 0 },
|
||||
ConfigModified: func() {},
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: time.Hour * 24,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
body getConfigResp
|
||||
wantCode int
|
||||
}{{
|
||||
name: "set_ivl_1_minIvl",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(minIvl.Milliseconds()),
|
||||
Ignored: []string{},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantErr: "",
|
||||
}, {
|
||||
name: "small_interval",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(smallIvl.Milliseconds()),
|
||||
Ignored: []string{},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "unsupported interval: less than an hour\n",
|
||||
}, {
|
||||
name: "big_interval",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(maxIvl.Milliseconds() + minIvl.Milliseconds()),
|
||||
Ignored: []string{},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "unsupported interval: more than a year\n",
|
||||
}, {
|
||||
name: "set_ignored_ivl_1_maxIvl",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(maxIvl.Milliseconds()),
|
||||
Ignored: []string{
|
||||
"ignor.ed",
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantErr: "",
|
||||
}, {
|
||||
name: "ignored_duplicate",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(minIvl.Milliseconds()),
|
||||
Ignored: []string{
|
||||
"ignor.ed",
|
||||
"ignor.ed",
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "ignored: duplicate host name \"ignor.ed\" at index 1\n",
|
||||
}, {
|
||||
name: "ignored_empty",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(minIvl.Milliseconds()),
|
||||
Ignored: []string{
|
||||
"",
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "ignored: host name is empty\n",
|
||||
}, {
|
||||
name: "enabled_is_null",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBNull,
|
||||
Interval: float64(minIvl.Milliseconds()),
|
||||
Ignored: []string{},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "enabled is null\n",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s, err := New(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.Start()
|
||||
testutil.CleanupAndRequireSuccess(t, s.Close)
|
||||
|
||||
buf, err := json.Marshal(tc.body)
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
configGet = "/control/stats/config"
|
||||
configPut = "/control/stats/config/update"
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPut, configPut, bytes.NewReader(buf))
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.handlePutStatsConfig(rw, req)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
if tc.wantCode != http.StatusOK {
|
||||
assert.Equal(t, tc.wantErr, rw.Body.String())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := httptest.NewRequest(http.MethodGet, configGet, nil)
|
||||
rw = httptest.NewRecorder()
|
||||
|
||||
s.handleGetStatsConfig(rw, resp)
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
ans := getConfigResp{}
|
||||
err = json.Unmarshal(rw.Body.Bytes(), &ans)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.body, ans)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
@@ -25,7 +26,23 @@ func checkInterval(days uint32) (ok bool) {
|
||||
return days == 0 || days == 1 || days == 7 || days == 30 || days == 90
|
||||
}
|
||||
|
||||
// validateIvl returns an error if ivl is less than an hour or more than a
|
||||
// year.
|
||||
func validateIvl(ivl time.Duration) (err error) {
|
||||
if ivl < time.Hour {
|
||||
return errors.Error("less than an hour")
|
||||
}
|
||||
|
||||
if ivl > timeutil.Day*365 {
|
||||
return errors.Error("more than a year")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Config is the configuration structure for the statistics collecting.
|
||||
//
|
||||
// Do not alter any fields of this structure after using it.
|
||||
type Config struct {
|
||||
// UnitID is the function to generate the identifier for current unit. If
|
||||
// nil, the default function is used, see newUnitID.
|
||||
@@ -35,22 +52,24 @@ type Config struct {
|
||||
// interface.
|
||||
ConfigModified func()
|
||||
|
||||
// ShouldCountClient returns client's ignore setting.
|
||||
ShouldCountClient func([]string) bool
|
||||
|
||||
// HTTPRegister is the function that registers handlers for the stats
|
||||
// endpoints.
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
|
||||
// Ignored is the list of host names, which should not be counted.
|
||||
Ignored *stringutil.Set
|
||||
|
||||
// Filename is the name of the database file.
|
||||
Filename string
|
||||
|
||||
// LimitDays is the maximum number of days to collect statistics into the
|
||||
// current unit.
|
||||
LimitDays uint32
|
||||
// Limit is an upper limit for collecting statistics.
|
||||
Limit time.Duration
|
||||
|
||||
// Enabled tells if the statistics are enabled.
|
||||
Enabled bool
|
||||
|
||||
// Ignored is the list of host names, which should not be counted.
|
||||
Ignored *stringutil.Set
|
||||
}
|
||||
|
||||
// Interface is the statistics interface to be used by other packages.
|
||||
@@ -71,7 +90,7 @@ type Interface interface {
|
||||
WriteDiskConfig(dc *Config)
|
||||
|
||||
// ShouldCount returns true if request for the host should be counted.
|
||||
ShouldCount(host string, qType, qClass uint16) bool
|
||||
ShouldCount(host string, qType, qClass uint16, ids []string) bool
|
||||
}
|
||||
|
||||
// StatsCtx collects the statistics and flushes it to the database. Its default
|
||||
@@ -96,23 +115,23 @@ type StatsCtx struct {
|
||||
// interface.
|
||||
configModified func()
|
||||
|
||||
// filename is the name of database file.
|
||||
filename string
|
||||
|
||||
// lock protects all the fields below.
|
||||
lock sync.Mutex
|
||||
|
||||
// enabled tells if the statistics are enabled.
|
||||
enabled bool
|
||||
|
||||
// limitHours is the maximum number of hours to collect statistics into the
|
||||
// current unit.
|
||||
//
|
||||
// TODO(s.chzhen): Rewrite to use time.Duration.
|
||||
limitHours uint32
|
||||
// confMu protects ignored, limit, and enabled.
|
||||
confMu *sync.RWMutex
|
||||
|
||||
// ignored is the list of host names, which should not be counted.
|
||||
ignored *stringutil.Set
|
||||
|
||||
// shouldCountClient returns client's ignore setting.
|
||||
shouldCountClient func([]string) bool
|
||||
|
||||
// filename is the name of database file.
|
||||
filename string
|
||||
|
||||
// limit is an upper limit for collecting statistics.
|
||||
limit time.Duration
|
||||
|
||||
// enabled tells if the statistics are enabled.
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// New creates s from conf and properly initializes it. Don't use s before
|
||||
@@ -120,17 +139,28 @@ type StatsCtx struct {
|
||||
func New(conf Config) (s *StatsCtx, err error) {
|
||||
defer withRecovered(&err)
|
||||
|
||||
err = validateIvl(conf.Limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unsupported interval: %w", err)
|
||||
}
|
||||
|
||||
if conf.ShouldCountClient == nil {
|
||||
return nil, errors.Error("should count client is unspecified")
|
||||
}
|
||||
|
||||
s = &StatsCtx{
|
||||
enabled: conf.Enabled,
|
||||
currMu: &sync.RWMutex{},
|
||||
filename: conf.Filename,
|
||||
configModified: conf.ConfigModified,
|
||||
httpRegister: conf.HTTPRegister,
|
||||
ignored: conf.Ignored,
|
||||
}
|
||||
if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) {
|
||||
s.limitHours = 24
|
||||
configModified: conf.ConfigModified,
|
||||
filename: conf.Filename,
|
||||
|
||||
confMu: &sync.RWMutex{},
|
||||
ignored: conf.Ignored,
|
||||
shouldCountClient: conf.ShouldCountClient,
|
||||
limit: conf.Limit,
|
||||
enabled: conf.Enabled,
|
||||
}
|
||||
|
||||
if s.unitIDGen = newUnitID; conf.UnitID != nil {
|
||||
s.unitIDGen = conf.UnitID
|
||||
}
|
||||
@@ -150,7 +180,7 @@ func New(conf Config) (s *StatsCtx, err error) {
|
||||
return nil, fmt.Errorf("stats: opening a transaction: %w", err)
|
||||
}
|
||||
|
||||
deleted := deleteOldUnits(tx, id-s.limitHours-1)
|
||||
deleted := deleteOldUnits(tx, id-uint32(s.limit.Hours())-1)
|
||||
udb = loadUnitFromDB(tx, id)
|
||||
|
||||
err = finishTxn(tx, deleted > 0)
|
||||
@@ -228,10 +258,10 @@ func (s *StatsCtx) Close() (err error) {
|
||||
|
||||
// Update implements the Interface interface for *StatsCtx.
|
||||
func (s *StatsCtx) Update(e Entry) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.confMu.Lock()
|
||||
defer s.confMu.Unlock()
|
||||
|
||||
if !s.enabled || s.limitHours == 0 {
|
||||
if !s.enabled || s.limit == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -260,20 +290,20 @@ func (s *StatsCtx) Update(e Entry) {
|
||||
|
||||
// WriteDiskConfig implements the Interface interface for *StatsCtx.
|
||||
func (s *StatsCtx) WriteDiskConfig(dc *Config) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.confMu.RLock()
|
||||
defer s.confMu.RUnlock()
|
||||
|
||||
dc.LimitDays = s.limitHours / 24
|
||||
dc.Ignored = s.ignored.Clone()
|
||||
dc.Limit = s.limit
|
||||
dc.Enabled = s.enabled
|
||||
dc.Ignored = s.ignored
|
||||
}
|
||||
|
||||
// TopClientsIP implements the [Interface] interface for *StatsCtx.
|
||||
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []netip.Addr) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.confMu.RLock()
|
||||
defer s.confMu.RUnlock()
|
||||
|
||||
limit := s.limitHours
|
||||
limit := uint32(s.limit.Hours())
|
||||
if !s.enabled || limit == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -366,8 +396,8 @@ func (s *StatsCtx) openDB() (err error) {
|
||||
func (s *StatsCtx) flush() (cont bool, sleepFor time.Duration) {
|
||||
id := s.unitIDGen()
|
||||
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.confMu.Lock()
|
||||
defer s.confMu.Unlock()
|
||||
|
||||
s.currMu.Lock()
|
||||
defer s.currMu.Unlock()
|
||||
@@ -377,7 +407,7 @@ func (s *StatsCtx) flush() (cont bool, sleepFor time.Duration) {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
limit := s.limitHours
|
||||
limit := uint32(s.limit.Hours())
|
||||
if limit == 0 || ptr.id == id {
|
||||
return true, time.Second
|
||||
}
|
||||
@@ -436,14 +466,14 @@ func (s *StatsCtx) periodicFlush() {
|
||||
log.Debug("periodic flushing finished")
|
||||
}
|
||||
|
||||
func (s *StatsCtx) setLimit(limitDays int) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if limitDays != 0 {
|
||||
// setLimit sets the limit. s.lock is expected to be locked.
|
||||
//
|
||||
// TODO(s.chzhen): Remove it when migration to the new API is over.
|
||||
func (s *StatsCtx) setLimit(limit time.Duration) {
|
||||
if limit != 0 {
|
||||
s.enabled = true
|
||||
s.limitHours = uint32(24 * limitDays)
|
||||
log.Debug("stats: set limit: %d days", limitDays)
|
||||
s.limit = limit
|
||||
log.Debug("stats: set limit: %d days", limit/timeutil.Day)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -558,11 +588,19 @@ func (s *StatsCtx) loadUnits(limit uint32) (units []*unitDB, firstID uint32) {
|
||||
}
|
||||
|
||||
// ShouldCount returns true if request for the host should be counted.
|
||||
func (s *StatsCtx) ShouldCount(host string, _, _ uint16) bool {
|
||||
func (s *StatsCtx) ShouldCount(host string, _, _ uint16, ids []string) bool {
|
||||
s.confMu.RLock()
|
||||
defer s.confMu.RUnlock()
|
||||
|
||||
if !s.shouldCountClient(ids) {
|
||||
return false
|
||||
}
|
||||
|
||||
return !s.isIgnored(host)
|
||||
}
|
||||
|
||||
// isIgnored returns true if the host is in the Ignored list.
|
||||
// isIgnored returns true if the host is in the ignored domains list. It
|
||||
// assumes that s.confMu is locked for reading.
|
||||
func (s *StatsCtx) isIgnored(host string) bool {
|
||||
return s.ignored.Has(host)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -35,9 +36,10 @@ func TestStats_races(t *testing.T) {
|
||||
var r uint32
|
||||
idGen := func() (id uint32) { return atomic.LoadUint32(&r) }
|
||||
conf := Config{
|
||||
UnitID: idGen,
|
||||
Filename: filepath.Join(t.TempDir(), "./stats.db"),
|
||||
LimitDays: 1,
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
UnitID: idGen,
|
||||
Filename: filepath.Join(t.TempDir(), "./stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
}
|
||||
|
||||
s, err := New(conf)
|
||||
|
||||
@@ -12,7 +12,10 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -51,10 +54,11 @@ func TestStats(t *testing.T) {
|
||||
|
||||
handlers := map[string]http.Handler{}
|
||||
conf := stats.Config{
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
LimitDays: 1,
|
||||
Enabled: true,
|
||||
UnitID: constUnitID,
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
Enabled: true,
|
||||
UnitID: constUnitID,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
@@ -157,11 +161,12 @@ func TestLargeNumbers(t *testing.T) {
|
||||
handlers := map[string]http.Handler{}
|
||||
|
||||
conf := stats.Config{
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
LimitDays: 1,
|
||||
Enabled: true,
|
||||
UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) },
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler },
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
Enabled: true,
|
||||
UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) },
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler },
|
||||
}
|
||||
|
||||
s, err := stats.New(conf)
|
||||
@@ -196,3 +201,60 @@ func TestLargeNumbers(t *testing.T) {
|
||||
assertSuccessAndUnmarshal(t, data, handlers["/control/stats"], req)
|
||||
assert.Equal(t, hoursNum*cliNumPerHour, int(data.NumDNSQueries))
|
||||
}
|
||||
|
||||
func TestShouldCount(t *testing.T) {
|
||||
const (
|
||||
ignored1 = "ignor.ed"
|
||||
ignored2 = "ignored.to"
|
||||
)
|
||||
set := stringutil.NewSet(ignored1, ignored2)
|
||||
|
||||
s, err := stats.New(stats.Config{
|
||||
Enabled: true,
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
Ignored: set,
|
||||
ShouldCountClient: func(ids []string) (a bool) {
|
||||
return ids[0] != "no_count"
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.Start()
|
||||
testutil.CleanupAndRequireSuccess(t, s.Close)
|
||||
|
||||
testCases := []struct {
|
||||
wantCount assert.BoolAssertionFunc
|
||||
name string
|
||||
host string
|
||||
ids []string
|
||||
}{{
|
||||
name: "count",
|
||||
host: "example.com",
|
||||
ids: []string{"whatever"},
|
||||
wantCount: assert.True,
|
||||
}, {
|
||||
name: "no_count_ignored_1",
|
||||
host: ignored1,
|
||||
ids: []string{"whatever"},
|
||||
wantCount: assert.False,
|
||||
}, {
|
||||
name: "no_count_ignored_2",
|
||||
host: ignored2,
|
||||
ids: []string{"whatever"},
|
||||
wantCount: assert.False,
|
||||
}, {
|
||||
name: "no_count_client_ignore",
|
||||
host: "example.com",
|
||||
ids: []string{"no_count"},
|
||||
wantCount: assert.False,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := s.ShouldCount(tc.host, dns.TypeA, dns.ClassINET, tc.ids)
|
||||
|
||||
tc.wantCount(t, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +72,19 @@ type Entry struct {
|
||||
|
||||
// unit collects the statistics data for a specific period of time.
|
||||
type unit struct {
|
||||
// domains stores the number of requests for each domain.
|
||||
domains map[string]uint64
|
||||
|
||||
// blockedDomains stores the number of requests for each domain that has
|
||||
// been blocked.
|
||||
blockedDomains map[string]uint64
|
||||
|
||||
// clients stores the number of requests from each client.
|
||||
clients map[string]uint64
|
||||
|
||||
// nResult stores the number of requests grouped by it's result.
|
||||
nResult []uint64
|
||||
|
||||
// id is the unique unit's identifier. It's set to an absolute hour number
|
||||
// since the beginning of UNIX time by the default ID generating function.
|
||||
//
|
||||
@@ -81,29 +94,20 @@ type unit struct {
|
||||
|
||||
// nTotal stores the total number of requests.
|
||||
nTotal uint64
|
||||
// nResult stores the number of requests grouped by it's result.
|
||||
nResult []uint64
|
||||
|
||||
// timeSum stores the sum of processing time in milliseconds of each request
|
||||
// written by the unit.
|
||||
timeSum uint64
|
||||
|
||||
// domains stores the number of requests for each domain.
|
||||
domains map[string]uint64
|
||||
// blockedDomains stores the number of requests for each domain that has
|
||||
// been blocked.
|
||||
blockedDomains map[string]uint64
|
||||
// clients stores the number of requests from each client.
|
||||
clients map[string]uint64
|
||||
}
|
||||
|
||||
// newUnit allocates the new *unit.
|
||||
func newUnit(id uint32) (u *unit) {
|
||||
return &unit{
|
||||
id: id,
|
||||
domains: map[string]uint64{},
|
||||
blockedDomains: map[string]uint64{},
|
||||
clients: map[string]uint64{},
|
||||
nResult: make([]uint64, resultLast),
|
||||
domains: make(map[string]uint64),
|
||||
blockedDomains: make(map[string]uint64),
|
||||
clients: make(map[string]uint64),
|
||||
id: id,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,19 +119,25 @@ type countPair struct {
|
||||
}
|
||||
|
||||
// unitDB is the structure for serializing statistics data into the database.
|
||||
//
|
||||
// NOTE: Do not change the names or types of fields, as this structure is used
|
||||
// for GOB encoding.
|
||||
type unitDB struct {
|
||||
// NTotal is the total number of requests.
|
||||
NTotal uint64
|
||||
// NResult is the number of requests by the result's kind.
|
||||
NResult []uint64
|
||||
|
||||
// Domains is the number of requests for each domain name.
|
||||
Domains []countPair
|
||||
|
||||
// BlockedDomains is the number of requests blocked for each domain name.
|
||||
BlockedDomains []countPair
|
||||
|
||||
// Clients is the number of requests from each client.
|
||||
Clients []countPair
|
||||
|
||||
// NTotal is the total number of requests.
|
||||
NTotal uint64
|
||||
|
||||
// TimeAvg is the average of processing times in milliseconds of all the
|
||||
// requests in the unit.
|
||||
TimeAvg uint32
|
||||
|
||||
Reference in New Issue
Block a user