all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-04-12 14:48:42 +03:00
parent 0dad53b5f7
commit d9c57cdd9a
181 changed files with 6992 additions and 3430 deletions

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
)
@@ -52,7 +53,7 @@ const textPlainDeprMsg = `using this api with the text/plain content-type is dep
// deprecation and removal of a plain-text API if the request is made with the
// "text/plain" content-type.
func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainText bool) {
if r.Header.Get(HdrNameContentType) != HdrValTextPlain {
if r.Header.Get(httphdr.ContentType) != HdrValTextPlain {
return false
}
@@ -72,7 +73,7 @@ func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err er
// redefine the status code.
func WriteJSONResponseCode(w http.ResponseWriter, r *http.Request, code int, resp any) (err error) {
w.WriteHeader(code)
w.Header().Set(HdrNameContentType, HdrValApplicationJSON)
w.Header().Set(httphdr.ContentType, HdrValApplicationJSON)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
Error(r, w, http.StatusInternalServerError, "encoding resp: %s", err)

View File

@@ -1,22 +1,6 @@
package aghhttp
// HTTP Headers
// HTTP header name constants.
//
// TODO(a.garipov): Remove unused.
const (
HdrNameAcceptEncoding = "Accept-Encoding"
HdrNameAccessControlAllowOrigin = "Access-Control-Allow-Origin"
HdrNameAltSvc = "Alt-Svc"
HdrNameContentEncoding = "Content-Encoding"
HdrNameContentType = "Content-Type"
HdrNameOrigin = "Origin"
HdrNameServer = "Server"
HdrNameTrailer = "Trailer"
HdrNameUserAgent = "User-Agent"
HdrNameVary = "Vary"
)
// HTTP headers
// HTTP header value constants.
const (

View File

@@ -1,47 +1,14 @@
package aghnet
import (
"net"
"strconv"
"fmt"
"net/netip"
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
)
// The maximum lengths of generated hostnames for different IP versions.
const (
ipv4HostnameMaxLen = len("192-168-100-100")
ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010")
)
// generateIPv4Hostname generates the hostname by IP address version 4.
func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
hnData := make([]byte, 0, ipv4HostnameMaxLen)
for i, part := range ipv4 {
if i > 0 {
hnData = append(hnData, '-')
}
hnData = strconv.AppendUint(hnData, uint64(part), 10)
}
return string(hnData)
}
// generateIPv6Hostname generates the hostname by IP address version 6.
func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
hnData := make([]byte, 0, ipv6HostnameMaxLen)
for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ {
if i > 0 {
hnData = append(hnData, '-')
}
for _, val := range ipv6[i*2 : i*2+2] {
if val < 10 {
hnData = append(hnData, '0')
}
hnData = strconv.AppendUint(hnData, uint64(val), 16)
}
}
return string(hnData)
}
// GenerateHostname generates the hostname from ip. In case of using IPv4 the
// result should be like:
//
@@ -52,10 +19,42 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
// ff80-f076-0000-0000-0000-0000-0000-0010
//
// ip must be either an IPv4 or an IPv6.
func GenerateHostname(ip net.IP) (hostname string) {
if ipv4 := ip.To4(); ipv4 != nil {
return generateIPv4Hostname(ipv4)
func GenerateHostname(ip netip.Addr) (hostname string) {
if !ip.IsValid() {
// TODO(s.chzhen): Get rid of it.
panic("aghnet generate hostname: invalid ip")
}
return generateIPv6Hostname(ip)
ip = ip.Unmap()
hostname = ip.StringExpanded()
if ip.Is4() {
return strings.Replace(hostname, ".", "-", -1)
}
return strings.Replace(hostname, ":", "-", -1)
}
// NewDomainNameSet returns nil and error, if list has duplicate or empty
// domain name. Otherwise returns a set, which contains non-FQDN domain names,
// and nil error.
func NewDomainNameSet(list []string) (set *stringutil.Set, err error) {
set = stringutil.NewSet()
for i, v := range list {
host := strings.ToLower(strings.TrimSuffix(v, "."))
// TODO(a.garipov): Think about ignoring empty (".") names in the
// future.
if host == "" {
return nil, errors.Error("host name is empty")
}
if set.Has(host) {
return nil, fmt.Errorf("duplicate host name %q at index %d", host, i)
}
set.Add(host)
}
return set, nil
}

View File

@@ -1,7 +1,7 @@
package aghnet
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
@@ -12,19 +12,19 @@ func TestGenerateHostName(t *testing.T) {
testCases := []struct {
name string
want string
ip net.IP
ip netip.Addr
}{{
name: "good_ipv4",
want: "127-0-0-1",
ip: net.IP{127, 0, 0, 1},
ip: netip.MustParseAddr("127.0.0.1"),
}, {
name: "good_ipv6",
want: "fe00-0000-0000-0000-0000-0000-0000-0001",
ip: net.ParseIP("fe00::1"),
ip: netip.MustParseAddr("fe00::1"),
}, {
name: "4to6",
want: "1-2-3-4",
ip: net.ParseIP("::ffff:1.2.3.4"),
ip: netip.MustParseAddr("::ffff:1.2.3.4"),
}}
for _, tc := range testCases {
@@ -36,29 +36,6 @@ func TestGenerateHostName(t *testing.T) {
})
t.Run("invalid", func(t *testing.T) {
testCases := []struct {
name string
ip net.IP
}{{
name: "bad_ipv4",
ip: net.IP{127, 0, 0, 1, 0},
}, {
name: "bad_ipv6",
ip: net.IP{
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff,
},
}, {
name: "nil",
ip: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Panics(t, func() { GenerateHostname(tc.ip) })
})
}
assert.Panics(t, func() { GenerateHostname(netip.Addr{}) })
})
}

View File

@@ -56,7 +56,7 @@ func (rm *requestMatcher) MatchRequest(
) (res *urlfilter.DNSResult, ok bool) {
switch req.DNSType {
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
log.Debug("%s: handling the request for %s", hostsContainerPref, req.Hostname)
log.Debug("%s: handling the request for %s", hostsContainerPrefix, req.Hostname)
default:
return nil, false
}
@@ -90,9 +90,9 @@ func (rm *requestMatcher) resetEng(rulesStrg *filterlist.RuleStorage, tr map[str
rm.translator = tr
}
// hostsContainerPref is a prefix for logging and wrapping errors in
// hostsContainerPrefix is a prefix for logging and wrapping errors in
// HostsContainer's methods.
const hostsContainerPref = "hosts container"
const hostsContainerPrefix = "hosts container"
// HostsContainer stores the relevant hosts database provided by the OS and
// processes both A/AAAA and PTR DNS requests for those.
@@ -115,8 +115,8 @@ type HostsContainer struct {
// fsys is the working file system to read hosts files from.
fsys fs.FS
// w tracks the changes in specified files and directories.
w aghos.FSWatcher
// watcher tracks the changes in specified files and directories.
watcher aghos.FSWatcher
// patterns stores specified paths in the fs.Glob-compatible form.
patterns []string
@@ -160,7 +160,7 @@ func NewHostsContainer(
w aghos.FSWatcher,
paths ...string,
) (hc *HostsContainer, err error) {
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }()
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPrefix) }()
if len(paths) == 0 {
return nil, ErrNoHostsPaths
@@ -182,11 +182,11 @@ func NewHostsContainer(
done: make(chan struct{}, 1),
updates: make(chan HostsRecords, 1),
fsys: fsys,
w: w,
watcher: w,
patterns: patterns,
}
log.Debug("%s: starting", hostsContainerPref)
log.Debug("%s: starting", hostsContainerPrefix)
// Load initially.
if err = hc.refresh(); err != nil {
@@ -199,7 +199,7 @@ func NewHostsContainer(
return nil, fmt.Errorf("adding path: %w", err)
}
log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPref, p)
log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPrefix, p)
}
}
@@ -208,14 +208,21 @@ func NewHostsContainer(
return hc, nil
}
// Close implements the io.Closer interface for *HostsContainer. Close must
// only be called once. The returned err is always nil.
// Close implements the [io.Closer] interface for *HostsContainer. It closes
// both itself and its [aghos.FSWatcher]. Close must only be called once.
func (hc *HostsContainer) Close() (err error) {
log.Debug("%s: closing", hostsContainerPref)
log.Debug("%s: closing", hostsContainerPrefix)
err = hc.watcher.Close()
if err != nil {
err = fmt.Errorf("closing fs watcher: %w", err)
// Go on and close the container either way.
}
close(hc.done)
return nil
return err
}
// Upd returns the channel into which the updates are sent.
@@ -251,22 +258,22 @@ func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error)
// update channel of HostsContainer when finishes. It's used to be called
// within a separate goroutine.
func (hc *HostsContainer) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPref))
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPrefix))
defer close(hc.updates)
ok, eventsCh := true, hc.w.Events()
ok, eventsCh := true, hc.watcher.Events()
for ok {
select {
case _, ok = <-eventsCh:
if !ok {
log.Debug("%s: watcher closed the events channel", hostsContainerPref)
log.Debug("%s: watcher closed the events channel", hostsContainerPrefix)
continue
}
if err := hc.refresh(); err != nil {
log.Error("%s: %s", hostsContainerPref, err)
log.Error("%s: %s", hostsContainerPrefix, err)
}
case _, ok = <-hc.done:
// Go on.
@@ -345,7 +352,7 @@ func (hp *hostsParser) parseLine(line string) (ip netip.Addr, hosts []string) {
// TODO(e.burkov): Investigate if hosts may contain DNS-SD domains.
err = netutil.ValidateHostname(f)
if err != nil {
log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f)
log.Error("%s: host %q is invalid, ignoring", hostsContainerPrefix, f)
continue
}
@@ -389,7 +396,7 @@ func (hp *hostsParser) addRules(ip netip.Addr, host, line string) {
rule, rulePtr := hp.writeRules(host, ip)
hp.translations[rule], hp.translations[rulePtr] = line, line
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host)
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPrefix, ip, host)
}
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
@@ -443,7 +450,7 @@ func (hp *hostsParser) writeRules(host string, ip netip.Addr) (rule, rulePtr str
// sendUpd tries to send the parsed data to the ch.
func (hp *hostsParser) sendUpd(ch chan HostsRecords) {
log.Debug("%s: sending upd", hostsContainerPref)
log.Debug("%s: sending upd", hostsContainerPrefix)
upd := hp.table
select {
@@ -451,11 +458,11 @@ func (hp *hostsParser) sendUpd(ch chan HostsRecords) {
// Updates are delivered. Go on.
case <-ch:
ch <- upd
log.Debug("%s: replaced the last update", hostsContainerPref)
log.Debug("%s: replaced the last update", hostsContainerPrefix)
case ch <- upd:
// The previous update was just read and the next one pushed. Go on.
default:
log.Error("%s: the updates channel is broken", hostsContainerPref)
log.Error("%s: the updates channel is broken", hostsContainerPrefix)
}
}
@@ -473,7 +480,7 @@ func (hp *hostsParser) newStrg(id int) (s *filterlist.RuleStorage, err error) {
//
// TODO(e.burkov): Accept a parameter to specify the files to refresh.
func (hc *HostsContainer) refresh() (err error) {
log.Debug("%s: refreshing", hostsContainerPref)
log.Debug("%s: refreshing", hostsContainerPrefix)
hp := hc.newHostsParser()
if _, err = aghos.FileWalker(hp.parseFile).Walk(hc.fsys, hc.patterns...); err != nil {
@@ -482,7 +489,7 @@ func (hc *HostsContainer) refresh() (err error) {
// hc.last is nil on the first refresh, so let that one through.
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
log.Debug("%s: no changes detected", hostsContainerPref)
log.Debug("%s: no changes detected", hostsContainerPrefix)
return nil
}

View File

@@ -35,4 +35,4 @@
1.3.5.7 domain4 domain4.alias
7.5.3.1 domain4.alias domain4
::13 domain6 domain6.alias
::31 domain6.alias domain6
::31 domain6.alias domain6

View File

@@ -1 +1 @@
iface sample_name inet static
iface sample_name inet static

View File

@@ -2,4 +2,4 @@
# parent directory. Real interface files usually contain only absolute paths.
source ./testdata/ifaces
source ./testdata/*
source ./testdata/*

View File

@@ -3,4 +3,4 @@ IP address HW type Flags HW address Mask Device
::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan
0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan

View File

@@ -1,10 +0,0 @@
//go:build mips || mips64
// This file is an adapted version of github.com/josharian/native.
package aghos
import "encoding/binary"
// NativeEndian is the native endianness of this system.
var NativeEndian = binary.BigEndian

View File

@@ -1,10 +0,0 @@
//go:build amd64 || 386 || arm || arm64 || mipsle || mips64le || ppc64le
// This file is an adapted version of github.com/josharian/native.
package aghos
import "encoding/binary"
// NativeEndian is the native endianness of this system.
var NativeEndian = binary.LittleEndian

View File

@@ -15,18 +15,16 @@ import (
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
"github.com/mdlayher/ethernet"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
"github.com/mdlayher/packet"
)
// dhcpUnicastAddr is the combination of MAC and IP addresses for responding to
// the unconfigured host.
type dhcpUnicastAddr struct {
// raw.Addr is embedded here to make *dhcpUcastAddr a net.Addr without
// packet.Addr is embedded here to make *dhcpUcastAddr a net.Addr without
// actually implementing all methods. It also contains the client's
// hardware address.
raw.Addr
packet.Addr
// yiaddr is an IP address just allocated by server for the host.
yiaddr net.IP
@@ -52,7 +50,7 @@ type dhcpConn struct {
// newDHCPConn creates the special connection for DHCP server.
func (s *v4Server) newDHCPConn(iface *net.Interface) (c net.PacketConn, err error) {
var ucast net.PacketConn
if ucast, err = raw.ListenPacket(iface, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
if ucast, err = packet.Listen(iface, packet.Raw, int(ethernet.EtherTypeIPv4), nil); err != nil {
return nil, fmt.Errorf("creating raw udp connection: %w", err)
}

View File

@@ -10,11 +10,9 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/packet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
func TestDHCPConn_WriteTo_common(t *testing.T) {
@@ -57,7 +55,7 @@ func TestBuildEtherPkt(t *testing.T) {
srcIP: net.IP{1, 2, 3, 4},
}
peer := &dhcpUnicastAddr{
Addr: raw.Addr{HardwareAddr: net.HardwareAddr{6, 5, 4, 3, 2, 1}},
Addr: packet.Addr{HardwareAddr: net.HardwareAddr{6, 5, 4, 3, 2, 1}},
yiaddr: net.IP{4, 3, 2, 1},
}
payload := (&dhcpv4.DHCPv4{}).ToBytes()
@@ -102,7 +100,7 @@ func TestBuildEtherPkt(t *testing.T) {
t.Run("serializing_error", func(t *testing.T) {
// Create a peer with invalid MAC.
badPeer := &dhcpUnicastAddr{
Addr: raw.Addr{HardwareAddr: net.HardwareAddr{5, 4, 3, 2, 1}},
Addr: packet.Addr{HardwareAddr: net.HardwareAddr{5, 4, 3, 2, 1}},
yiaddr: net.IP{4, 3, 2, 1},
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"time"
@@ -32,6 +33,8 @@ func normalizeIP(ip net.IP) net.IP {
}
// Load lease table from DB
//
// TODO(s.chzhen): Decrease complexity.
func (s *server) dbLoad() (err error) {
dynLeases := []*Lease{}
staticLeases := []*Lease{}
@@ -57,26 +60,28 @@ func (s *server) dbLoad() (err error) {
for i := range obj {
obj[i].IP = normalizeIP(obj[i].IP)
if !(len(obj[i].IP) == 4 || len(obj[i].IP) == 16) {
ip, ok := netip.AddrFromSlice(obj[i].IP)
if !ok {
log.Info("dhcp: invalid IP: %s", obj[i].IP)
continue
}
lease := Lease{
HWAddr: obj[i].HWAddr,
IP: obj[i].IP,
IP: ip,
Hostname: obj[i].Hostname,
Expiry: time.Unix(obj[i].Expiry, 0),
IsStatic: obj[i].Expiry == leaseExpireStatic,
}
if len(obj[i].IP) == 16 {
if obj[i].Expiry == leaseExpireStatic {
if lease.IsStatic {
v6StaticLeases = append(v6StaticLeases, &lease)
} else {
v6DynLeases = append(v6DynLeases, &lease)
}
} else {
if obj[i].Expiry == leaseExpireStatic {
if lease.IsStatic {
staticLeases = append(staticLeases, &lease)
} else {
dynLeases = append(dynLeases, &lease)
@@ -145,7 +150,7 @@ func (s *server) dbStore() (err error) {
lease := leaseJSON{
HWAddr: l.HWAddr,
IP: l.IP,
IP: l.IP.AsSlice(),
Hostname: l.Hostname,
Expiry: l.Expiry.Unix(),
}
@@ -162,7 +167,7 @@ func (s *server) dbStore() (err error) {
lease := leaseJSON{
HWAddr: l.HWAddr,
IP: l.IP,
IP: l.IP.AsSlice(),
Hostname: l.Hostname,
Expiry: l.Expiry.Unix(),
}

View File

@@ -41,13 +41,19 @@ type Lease struct {
// of 1 means that this is a static lease.
Expiry time.Time `json:"expires"`
Hostname string `json:"hostname"`
HWAddr net.HardwareAddr `json:"mac"`
// Hostname of the client.
Hostname string `json:"hostname"`
// HWAddr is the physical hardware address (MAC address).
HWAddr net.HardwareAddr `json:"mac"`
// IP is the IP address leased to the client.
//
// TODO(a.garipov): Migrate leases.db and use netip.Addr.
IP net.IP `json:"ip"`
// TODO(a.garipov): Migrate leases.db.
IP netip.Addr `json:"ip"`
// IsStatic defines if the lease is static.
IsStatic bool `json:"static"`
}
// Clone returns a deep copy of l.
@@ -60,7 +66,8 @@ func (l *Lease) Clone() (clone *Lease) {
Expiry: l.Expiry,
Hostname: l.Hostname,
HWAddr: slices.Clone(l.HWAddr),
IP: slices.Clone(l.IP),
IP: l.IP,
IsStatic: l.IsStatic,
}
}
@@ -81,17 +88,10 @@ func (l *Lease) IsBlocklisted() (ok bool) {
return true
}
// IsStatic returns true if the lease is static.
//
// TODO(a.garipov): Just make it a boolean field.
func (l *Lease) IsStatic() (ok bool) {
return l != nil && l.Expiry.Unix() == leaseExpireStatic
}
// MarshalJSON implements the json.Marshaler interface for Lease.
func (l Lease) MarshalJSON() ([]byte, error) {
var expiryStr string
if !l.IsStatic() {
if !l.IsStatic {
// The front-end is waiting for RFC 3999 format of the time
// value. It also shouldn't got an Expiry field for static
// leases.

View File

@@ -48,11 +48,11 @@ func TestDB(t *testing.T) {
Expiry: time.Now().Add(time.Hour),
Hostname: "static-1.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 100},
IP: netip.MustParseAddr("192.168.10.100"),
}, {
Hostname: "static-2.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xBB},
IP: net.IP{192, 168, 10, 101},
IP: netip.MustParseAddr("192.168.10.101"),
}}
srv4, ok := s.srv4.(*v4Server)
@@ -80,7 +80,7 @@ func TestDB(t *testing.T) {
assert.Equal(t, leases[1].HWAddr, ll[0].HWAddr)
assert.Equal(t, leases[1].IP, ll[0].IP)
assert.True(t, ll[0].IsStatic())
assert.True(t, ll[0].IsStatic)
assert.Equal(t, leases[0].HWAddr, ll[1].HWAddr)
assert.Equal(t, leases[0].IP, ll[1].IP)
@@ -96,7 +96,7 @@ func TestNormalizeLeases(t *testing.T) {
staticLeases := []*Lease{{
HWAddr: net.HardwareAddr{1, 2, 3, 4},
IP: net.IP{0, 2, 3, 4},
IP: netip.MustParseAddr("0.2.3.4"),
}, {
HWAddr: net.HardwareAddr{2, 2, 3, 4},
}}

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/netip"
"os"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
@@ -57,12 +58,77 @@ func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf {
// dhcpStatusResponse is the response for /control/dhcp/status endpoint.
type dhcpStatusResponse struct {
IfaceName string `json:"interface_name"`
V4 V4ServerConf `json:"v4"`
V6 V6ServerConf `json:"v6"`
Leases []*Lease `json:"leases"`
StaticLeases []*Lease `json:"static_leases"`
Enabled bool `json:"enabled"`
IfaceName string `json:"interface_name"`
V4 V4ServerConf `json:"v4"`
V6 V6ServerConf `json:"v6"`
Leases []*leaseDynamic `json:"leases"`
StaticLeases []*leaseStatic `json:"static_leases"`
Enabled bool `json:"enabled"`
}
// leaseStatic is the JSON form of static DHCP lease.
type leaseStatic struct {
HWAddr string `json:"mac"`
IP netip.Addr `json:"ip"`
Hostname string `json:"hostname"`
}
// leasesToStatic converts list of leases to their JSON form.
func leasesToStatic(leases []*Lease) (static []*leaseStatic) {
static = make([]*leaseStatic, len(leases))
for i, l := range leases {
static[i] = &leaseStatic{
HWAddr: l.HWAddr.String(),
IP: l.IP,
Hostname: l.Hostname,
}
}
return static
}
// toLease converts leaseStatic to Lease or returns error.
func (l *leaseStatic) toLease() (lease *Lease, err error) {
addr, err := net.ParseMAC(l.HWAddr)
if err != nil {
return nil, fmt.Errorf("couldn't parse MAC address: %w", err)
}
return &Lease{
HWAddr: addr,
IP: l.IP,
Hostname: l.Hostname,
IsStatic: true,
}, nil
}
// leaseDynamic is the JSON form of dynamic DHCP lease.
type leaseDynamic struct {
HWAddr string `json:"mac"`
IP netip.Addr `json:"ip"`
Hostname string `json:"hostname"`
Expiry string `json:"expires"`
}
// leasesToDynamic converts list of leases to their JSON form.
func leasesToDynamic(leases []*Lease) (dynamic []*leaseDynamic) {
dynamic = make([]*leaseDynamic, len(leases))
for i, l := range leases {
dynamic[i] = &leaseDynamic{
HWAddr: l.HWAddr.String(),
IP: l.IP,
Hostname: l.Hostname,
// The front-end is waiting for RFC 3999 format of the time
// value.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2692.
Expiry: l.Expiry.Format(time.RFC3339),
}
}
return dynamic
}
func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
@@ -76,8 +142,8 @@ func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
s.srv4.WriteDiskConfig4(&status.V4)
s.srv6.WriteDiskConfig6(&status.V6)
status.Leases = s.Leases(LeasesDynamic)
status.StaticLeases = s.Leases(LeasesStatic)
status.Leases = leasesToDynamic(s.Leases(LeasesDynamic))
status.StaticLeases = leasesToStatic(s.Leases(LeasesStatic))
_ = aghhttp.WriteJSONResponse(w, r, status)
}
@@ -488,7 +554,7 @@ func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
}
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
l := &Lease{}
l := &leaseStatic{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
@@ -496,22 +562,29 @@ func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
return
}
if l.IP == nil {
if !l.IP.IsValid() {
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
return
}
l.IP = l.IP.Unmap()
var srv DHCPServer
if ip4 := l.IP.To4(); ip4 != nil {
l.IP = ip4
if l.IP.Is4() {
srv = s.srv4
} else {
l.IP = l.IP.To16()
srv = s.srv6
}
err = srv.AddStaticLease(l)
lease, err := l.toLease()
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "parsing: %s", err)
return
}
err = srv.AddStaticLease(lease)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -520,7 +593,7 @@ func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
}
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
l := &Lease{}
l := &leaseStatic{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
@@ -528,27 +601,29 @@ func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
return
}
if l.IP == nil {
if !l.IP.IsValid() {
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
return
}
ip4 := l.IP.To4()
l.IP = l.IP.Unmap()
if ip4 == nil {
l.IP = l.IP.To16()
var srv DHCPServer
if l.IP.Is4() {
srv = s.srv4
} else {
srv = s.srv6
}
err = s.srv6.RemoveStaticLease(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
}
lease, err := l.toLease()
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "parsing: %s", err)
return
}
l.IP = ip4
err = s.srv4.RemoveStaticLease(l)
err = srv.RemoveStaticLease(lease)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

View File

@@ -0,0 +1,160 @@
//go:build darwin || freebsd || linux || openbsd
package dhcpd
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_handleDHCPStatus(t *testing.T) {
const (
staticName = "static-client"
staticMAC = "aa:aa:aa:aa:aa:aa"
)
staticIP := netip.MustParseAddr("192.168.10.10")
staticLease := &leaseStatic{
HWAddr: staticMAC,
IP: staticIP,
Hostname: staticName,
}
s, err := Create(&ServerConfig{
Enabled: true,
Conf4: *defaultV4ServerConf(),
WorkDir: t.TempDir(),
DBFilePath: dbFilename,
ConfigModified: func() {},
})
require.NoError(t, err)
// checkStatus is a helper that asserts the response of
// [*server.handleDHCPStatus].
checkStatus := func(t *testing.T, want *dhcpStatusResponse) {
w := httptest.NewRecorder()
var req *http.Request
req, err = http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(&want)
require.NoError(t, err)
s.handleDHCPStatus(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, b.String(), w.Body.String())
}
// defaultResponse is a helper that returs the response with default
// configuration.
defaultResponse := func() *dhcpStatusResponse {
conf4 := defaultV4ServerConf()
conf4.LeaseDuration = 86400
resp := &dhcpStatusResponse{
V4: *conf4,
V6: V6ServerConf{},
Leases: []*leaseDynamic{},
StaticLeases: []*leaseStatic{},
Enabled: true,
}
return resp
}
ok := t.Run("status", func(t *testing.T) {
resp := defaultResponse()
checkStatus(t, resp)
})
require.True(t, ok)
ok = t.Run("add_static_lease", func(t *testing.T) {
w := httptest.NewRecorder()
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(staticLease)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPAddStaticLease(w, r)
assert.Equal(t, http.StatusOK, w.Code)
resp := defaultResponse()
resp.StaticLeases = []*leaseStatic{staticLease}
checkStatus(t, resp)
})
require.True(t, ok)
ok = t.Run("add_invalid_lease", func(t *testing.T) {
w := httptest.NewRecorder()
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(&leaseStatic{})
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPAddStaticLease(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
require.True(t, ok)
ok = t.Run("remove_static_lease", func(t *testing.T) {
w := httptest.NewRecorder()
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(staticLease)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPRemoveStaticLease(w, r)
assert.Equal(t, http.StatusOK, w.Code)
resp := defaultResponse()
checkStatus(t, resp)
})
require.True(t, ok)
ok = t.Run("set_config", func(t *testing.T) {
w := httptest.NewRecorder()
resp := defaultResponse()
resp.Enabled = false
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(&resp)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPSetConfig(w, r)
assert.Equal(t, http.StatusOK, w.Code)
checkStatus(t, resp)
})
require.True(t, ok)
}

View File

@@ -16,6 +16,8 @@ import (
//
// TODO(a.garipov): Perhaps create an optimized version with uint32 for IPv4
// ranges? Or use one of uint128 packages?
//
// TODO(e.burkov): Use netip.Addr.
type ipRange struct {
start *big.Int
end *big.Int
@@ -27,8 +29,6 @@ const maxRangeLen = math.MaxUint32
// newIPRange creates a new IP address range. start must be less than end. The
// resulting range must not be greater than maxRangeLen.
//
// TODO(e.burkov): Use netip.Addr.
func newIPRange(start, end net.IP) (r *ipRange, err error) {
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()

View File

@@ -20,10 +20,8 @@ import (
"github.com/go-ping/ping"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
"github.com/mdlayher/packet"
"golang.org/x/exp/slices"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
// v4Server is a DHCPv4 server.
@@ -98,7 +96,7 @@ func normalizeHostname(hostname string) (norm string, err error) {
// validHostnameForClient accepts the hostname sent by the client and its IP and
// returns either a normalized version of that hostname, or a new hostname
// generated from the IP address, or an empty string.
func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostname string) {
func (s *v4Server) validHostnameForClient(cliHostname string, ip netip.Addr) (hostname string) {
hostname, err := normalizeHostname(cliHostname)
if err != nil {
log.Info("dhcpv4: %s", err)
@@ -130,7 +128,7 @@ func (s *v4Server) ResetLeases(leases []*Lease) (err error) {
s.leases = nil
for _, l := range leases {
if !l.IsStatic() {
if !l.IsStatic {
l.Hostname = s.validHostnameForClient(l.Hostname, l.IP)
}
err = s.addLease(l)
@@ -192,7 +190,7 @@ func (s *v4Server) GetLeases(flags GetLeasesFlags) (leases []*Lease) {
continue
}
if getStatic && l.IsStatic() {
if getStatic && l.IsStatic {
leases = append(leases, l.Clone())
}
}
@@ -211,10 +209,9 @@ func (s *v4Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
return nil
}
netIP := ip.AsSlice()
for _, l := range s.leases {
if l.IP.Equal(netIP) {
if l.Expiry.After(now) || l.IsStatic() {
if l.IP == ip {
if l.IsStatic || l.Expiry.After(now) {
return l.HWAddr
}
}
@@ -247,7 +244,8 @@ func (s *v4Server) rmLeaseByIndex(i int) {
s.leases = append(s.leases[:i], s.leases[i+1:]...)
r := s.conf.ipRange
offset, ok := r.offset(l.IP)
leaseIP := net.IP(l.IP.AsSlice())
offset, ok := r.offset(leaseIP)
if ok {
s.leasedOffsets.set(offset, false)
}
@@ -261,9 +259,9 @@ func (s *v4Server) rmLeaseByIndex(i int) {
// Return error if a static lease is found
func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
for i, l := range s.leases {
isStatic := l.IsStatic()
isStatic := l.IsStatic
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP.Equal(lease.IP) {
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP == lease.IP {
if isStatic {
return errors.Error("static lease already exists")
}
@@ -291,13 +289,13 @@ const ErrDupHostname = errors.Error("hostname is not unique")
// addLease adds a dynamic or static lease.
func (s *v4Server) addLease(l *Lease) (err error) {
r := s.conf.ipRange
offset, inOffset := r.offset(l.IP)
leaseIP := net.IP(l.IP.AsSlice())
offset, inOffset := r.offset(leaseIP)
if l.IsStatic() {
if l.IsStatic {
// TODO(a.garipov, d.seregin): Subnet can be nil when dhcp server is
// disabled.
addr := netip.AddrFrom4(*(*[4]byte)(l.IP.To4()))
if sn := s.conf.subnet; !sn.Contains(addr) {
if sn := s.conf.subnet; !sn.Contains(l.IP) {
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
}
} else if !inOffset {
@@ -325,7 +323,7 @@ func (s *v4Server) rmLease(lease *Lease) (err error) {
}
for i, l := range s.leases {
if l.IP.Equal(lease.IP) {
if l.IP == lease.IP {
if !bytes.Equal(l.HWAddr, lease.HWAddr) || l.Hostname != lease.Hostname {
return fmt.Errorf("lease for ip %s is different: %+v", lease.IP, l)
}
@@ -352,14 +350,16 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
return ErrUnconfigured
}
ip := l.IP.To4()
if ip == nil {
l.IP = l.IP.Unmap()
if !l.IP.Is4() {
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
} else if gwIP := s.conf.GatewayIP; gwIP == netip.AddrFrom4(*(*[4]byte)(ip)) {
} else if gwIP := s.conf.GatewayIP; gwIP == l.IP {
return fmt.Errorf("can't assign the gateway IP %s to the lease", gwIP)
}
l.Expiry = time.Unix(leaseExpireStatic, 0)
l.IsStatic = true
err = netutil.ValidateMAC(l.HWAddr)
if err != nil {
@@ -396,7 +396,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
if err != nil {
err = fmt.Errorf(
"removing dynamic leases for %s (%s): %w",
ip,
l.IP,
l.HWAddr,
err,
)
@@ -406,7 +406,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
err = s.addLease(l)
if err != nil {
err = fmt.Errorf("adding static lease for %s (%s): %w", ip, l.HWAddr, err)
err = fmt.Errorf("adding static lease for %s (%s): %w", l.IP, l.HWAddr, err)
return
}
@@ -429,7 +429,7 @@ func (s *v4Server) RemoveStaticLease(l *Lease) (err error) {
return ErrUnconfigured
}
if len(l.IP) != 4 {
if !l.IP.Is4() {
return fmt.Errorf("invalid IP")
}
@@ -529,7 +529,7 @@ func (s *v4Server) nextIP() (ip net.IP) {
func (s *v4Server) findExpiredLease() int {
now := time.Now()
for i, lease := range s.leases {
if !lease.IsStatic() && lease.Expiry.Before(now) {
if !lease.IsStatic && lease.Expiry.Before(now) {
return i
}
}
@@ -542,8 +542,8 @@ func (s *v4Server) findExpiredLease() int {
func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
l = &Lease{HWAddr: slices.Clone(mac)}
l.IP = s.nextIP()
if l.IP == nil {
nextIP := s.nextIP()
if nextIP == nil {
i := s.findExpiredLease()
if i < 0 {
return nil, nil
@@ -554,6 +554,13 @@ func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
return s.leases[i], nil
}
netIP, ok := netip.AddrFromSlice(nextIP)
if !ok {
return nil, errors.Error("invalid ip")
}
l.IP = netIP
err = s.addLease(l)
if err != nil {
return nil, err
@@ -603,7 +610,8 @@ func (s *v4Server) allocateLease(mac net.HardwareAddr) (l *Lease, err error) {
return nil, nil
}
if s.addrAvailable(l.IP) {
leaseIP := l.IP.AsSlice()
if s.addrAvailable(leaseIP) {
return l, nil
}
@@ -623,8 +631,9 @@ func (s *v4Server) handleDiscover(req, resp *dhcpv4.DHCPv4) (l *Lease, err error
l = s.findLease(mac)
if l != nil {
reqIP := req.RequestedIPAddress()
if len(reqIP) != 0 && !reqIP.Equal(l.IP) {
log.Debug("dhcpv4: different RequestedIP: %s != %s", reqIP, l.IP)
leaseIP := net.IP(l.IP.AsSlice())
if len(reqIP) != 0 && !reqIP.Equal(leaseIP) {
log.Debug("dhcpv4: different RequestedIP: %s != %s", reqIP, leaseIP)
}
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
@@ -674,12 +683,19 @@ func (s *v4Server) checkLease(mac net.HardwareAddr, ip net.IP) (lease *Lease, mi
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
netIP, ok := netip.AddrFromSlice(ip)
if !ok {
log.Info("check lease: invalid IP: %s", ip)
return nil, false
}
for _, l := range s.leases {
if !bytes.Equal(l.HWAddr, mac) {
continue
}
if l.IP.Equal(ip) {
if l.IP == netIP {
return l, false
}
@@ -845,7 +861,7 @@ func (s *v4Server) handleRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, needsR
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
if lease.IsStatic() {
if lease.IsStatic {
if lease.Hostname != "" {
// TODO(e.burkov): This option is used to update the server's DNS
// mapping. The option should only be answered when it has been
@@ -878,9 +894,16 @@ func (s *v4Server) handleDecline(req, resp *dhcpv4.DHCPv4) (err error) {
reqIP = req.ClientIPAddr
}
netIP, ok := netip.AddrFromSlice(reqIP)
if !ok {
log.Info("dhcpv4: invalid IP: %s", reqIP)
return nil
}
var oldLease *Lease
for _, l := range s.leases {
if bytes.Equal(l.HWAddr, mac) && l.IP.Equal(reqIP) {
if bytes.Equal(l.HWAddr, mac) && l.IP == netIP {
oldLease = l
break
@@ -920,8 +943,7 @@ func (s *v4Server) handleDecline(req, resp *dhcpv4.DHCPv4) (err error) {
log.Info("dhcpv4: changed ip from %s to %s for %s", reqIP, newLease.IP, mac)
resp.YourIPAddr = make([]byte, 4)
copy(resp.YourIPAddr, newLease.IP)
resp.YourIPAddr = net.IP(newLease.IP.AsSlice())
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
@@ -944,8 +966,15 @@ func (s *v4Server) handleRelease(req, resp *dhcpv4.DHCPv4) (err error) {
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
netIP, ok := netip.AddrFromSlice(reqIP)
if !ok {
log.Info("dhcpv4: invalid IP: %s", reqIP)
return nil
}
for _, l := range s.leases {
if !bytes.Equal(l.HWAddr, mac) || !l.IP.Equal(reqIP) {
if !bytes.Equal(l.HWAddr, mac) || l.IP != netIP {
continue
}
@@ -1018,7 +1047,7 @@ func (s *v4Server) handle(req, resp *dhcpv4.DHCPv4) int {
}
if l != nil {
resp.YourIPAddr = slices.Clone(l.IP)
resp.YourIPAddr = net.IP(l.IP.AsSlice())
}
s.updateOptions(req, resp)
@@ -1136,7 +1165,7 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH
// Unicast DHCPOFFER and DHCPACK messages to the client's
// hardware address and yiaddr.
peer = &dhcpUnicastAddr{
Addr: raw.Addr{HardwareAddr: req.ClientHWAddr},
Addr: packet.Addr{HardwareAddr: req.ClientHWAddr},
yiaddr: resp.YourIPAddr,
}
default:

View File

@@ -15,11 +15,9 @@ import (
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/packet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
var (
@@ -62,7 +60,7 @@ func TestV4Server_leasing(t *testing.T) {
anotherName = "another-client"
)
staticIP := net.IP{192, 168, 10, 10}
staticIP := netip.MustParseAddr("192.168.10.10")
anotherIP := DefaultRangeStart
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
@@ -75,6 +73,7 @@ func TestV4Server_leasing(t *testing.T) {
Hostname: staticName,
HWAddr: staticMAC,
IP: staticIP,
IsStatic: true,
})
require.NoError(t, err)
@@ -83,7 +82,8 @@ func TestV4Server_leasing(t *testing.T) {
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: staticName,
HWAddr: anotherMAC,
IP: anotherIP.AsSlice(),
IP: anotherIP,
IsStatic: true,
})
assert.ErrorIs(t, err, ErrDupHostname)
})
@@ -97,7 +97,8 @@ func TestV4Server_leasing(t *testing.T) {
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: anotherName,
HWAddr: staticMAC,
IP: anotherIP.AsSlice(),
IP: anotherIP,
IsStatic: true,
})
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
@@ -112,6 +113,7 @@ func TestV4Server_leasing(t *testing.T) {
Hostname: anotherName,
HWAddr: anotherMAC,
IP: staticIP,
IsStatic: true,
})
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
@@ -124,13 +126,14 @@ func TestV4Server_leasing(t *testing.T) {
discoverAnOffer := func(
t *testing.T,
name string,
ip net.IP,
netIP netip.Addr,
mac net.HardwareAddr,
) (resp *dhcpv4.DHCPv4) {
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return s.ResetLeases(s.GetLeases(LeasesStatic))
})
ip := net.IP(netIP.AsSlice())
req, err := dhcpv4.NewDiscovery(
mac,
dhcpv4.WithOption(dhcpv4.OptHostName(name)),
@@ -151,7 +154,7 @@ func TestV4Server_leasing(t *testing.T) {
}
t.Run("same_name", func(t *testing.T) {
resp := discoverAnOffer(t, staticName, anotherIP.AsSlice(), anotherMAC)
resp := discoverAnOffer(t, staticName, anotherIP, anotherMAC)
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
dhcpv4.OptHostName(staticName),
@@ -161,11 +164,15 @@ func TestV4Server_leasing(t *testing.T) {
res := s4.handle(req, resp)
require.Positive(t, res)
assert.Equal(t, aghnet.GenerateHostname(resp.YourIPAddr), resp.HostName())
var netIP netip.Addr
netIP, ok = netip.AddrFromSlice(resp.YourIPAddr)
require.True(t, ok)
assert.Equal(t, aghnet.GenerateHostname(netIP), resp.HostName())
})
t.Run("same_mac", func(t *testing.T) {
resp := discoverAnOffer(t, anotherName, anotherIP.AsSlice(), staticMAC)
resp := discoverAnOffer(t, anotherName, anotherIP, staticMAC)
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
dhcpv4.OptHostName(anotherName),
@@ -179,7 +186,8 @@ func TestV4Server_leasing(t *testing.T) {
require.Len(t, fqdnOptData, 3+len(staticName))
assert.Equal(t, []uint8(staticName), fqdnOptData[3:])
assert.Equal(t, staticIP, resp.YourIPAddr)
ip := net.IP(staticIP.AsSlice())
assert.Equal(t, ip, resp.YourIPAddr)
})
t.Run("same_ip", func(t *testing.T) {
@@ -212,7 +220,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
lease: &Lease{
Hostname: "success.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
IP: netip.MustParseAddr("192.168.10.150"),
},
name: "success",
wantErrMsg: "",
@@ -220,7 +228,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
lease: &Lease{
Hostname: "probably-router.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: DefaultGatewayIP.AsSlice(),
IP: DefaultGatewayIP,
},
name: "with_gateway_ip",
wantErrMsg: "dhcpv4: adding static lease: " +
@@ -229,7 +237,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
lease: &Lease{
Hostname: "ip6.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.ParseIP("ffff::1"),
IP: netip.MustParseAddr("ffff::1"),
},
name: "ipv6",
wantErrMsg: `dhcpv4: adding static lease: ` +
@@ -238,7 +246,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
lease: &Lease{
Hostname: "bad-mac.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
IP: netip.MustParseAddr("192.168.10.150"),
},
name: "bad_mac",
wantErrMsg: `dhcpv4: adding static lease: bad mac address "aa:aa": ` +
@@ -247,7 +255,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
lease: &Lease{
Hostname: "bad-lbl-.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
IP: netip.MustParseAddr("192.168.10.150"),
},
name: "bad_hostname",
wantErrMsg: `dhcpv4: adding static lease: validating hostname: ` +
@@ -289,11 +297,11 @@ func TestV4_AddReplace(t *testing.T) {
dynLeases := []Lease{{
Hostname: "dynamic-1.local",
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
IP: netip.MustParseAddr("192.168.10.150"),
}, {
Hostname: "dynamic-2.local",
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 151},
IP: netip.MustParseAddr("192.168.10.151"),
}}
for i := range dynLeases {
@@ -304,11 +312,11 @@ func TestV4_AddReplace(t *testing.T) {
stLeases := []*Lease{{
Hostname: "static-1.local",
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
IP: netip.MustParseAddr("192.168.10.150"),
}, {
Hostname: "static-2.local",
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 152},
IP: netip.MustParseAddr("192.168.10.152"),
}}
for _, l := range stLeases {
@@ -320,9 +328,9 @@ func TestV4_AddReplace(t *testing.T) {
require.Len(t, ls, 2)
for i, l := range ls {
assert.True(t, stLeases[i].IP.Equal(l.IP))
assert.Equal(t, stLeases[i].IP, l.IP)
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
assert.True(t, l.IsStatic())
assert.True(t, l.IsStatic)
}
}
@@ -513,7 +521,7 @@ func TestV4StaticLease_Get(t *testing.T) {
l := &Lease{
Hostname: "static-1.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
IP: netip.MustParseAddr("192.168.10.150"),
}
err := s.AddStaticLease(l)
require.NoError(t, err)
@@ -539,7 +547,9 @@ func TestV4StaticLease_Get(t *testing.T) {
t.Run("offer", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, l.IP.Equal(resp.YourIPAddr))
ip := net.IP(l.IP.AsSlice())
assert.True(t, ip.Equal(resp.YourIPAddr))
assert.True(t, resp.Router()[0].Equal(s.conf.GatewayIP.AsSlice()))
assert.True(t, resp.ServerIdentifier().Equal(s.conf.GatewayIP.AsSlice()))
@@ -564,7 +574,9 @@ func TestV4StaticLease_Get(t *testing.T) {
t.Run("ack", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, l.IP.Equal(resp.YourIPAddr))
ip := net.IP(l.IP.AsSlice())
assert.True(t, ip.Equal(resp.YourIPAddr))
assert.True(t, resp.Router()[0].Equal(s.conf.GatewayIP.AsSlice()))
assert.True(t, resp.ServerIdentifier().Equal(s.conf.GatewayIP.AsSlice()))
@@ -583,7 +595,7 @@ func TestV4StaticLease_Get(t *testing.T) {
ls := s.GetLeases(LeasesStatic)
require.Len(t, ls, 1)
assert.True(t, l.IP.Equal(ls[0].IP))
assert.Equal(t, l.IP, ls[0].IP)
assert.Equal(t, mac, ls[0].HWAddr)
})
}
@@ -681,7 +693,8 @@ func TestV4DynamicLease_Get(t *testing.T) {
ls := s.GetLeases(LeasesDynamic)
require.Len(t, ls, 1)
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
ip := netip.MustParseAddr("192.168.10.100")
assert.Equal(t, ip, ls[0].IP)
assert.Equal(t, mac, ls[0].HWAddr)
})
}
@@ -810,7 +823,7 @@ func TestV4Server_Send(t *testing.T) {
req: &dhcpv4.DHCPv4{ClientHWAddr: knownMAC},
resp: &dhcpv4.DHCPv4{YourIPAddr: knownIP},
want: &dhcpUnicastAddr{
Addr: raw.Addr{HardwareAddr: knownMAC},
Addr: packet.Addr{HardwareAddr: knownMAC},
yiaddr: knownIP,
},
}, {
@@ -862,3 +875,144 @@ func TestV4Server_Send(t *testing.T) {
assert.True(t, resp.IsBroadcast())
})
}
func TestV4Server_FindMACbyIP(t *testing.T) {
const (
staticName = "static-client"
anotherName = "another-client"
)
staticIP := netip.MustParseAddr("192.168.10.10")
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
anotherIP := netip.MustParseAddr("192.168.100.100")
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
s := &v4Server{
leases: []*Lease{{
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: staticName,
HWAddr: staticMAC,
IP: staticIP,
IsStatic: true,
}, {
Expiry: time.Unix(10, 0),
Hostname: anotherName,
HWAddr: anotherMAC,
IP: anotherIP,
}},
}
testCases := []struct {
want net.HardwareAddr
ip netip.Addr
name string
}{{
name: "basic",
ip: staticIP,
want: staticMAC,
}, {
name: "not_found",
ip: netip.MustParseAddr("1.2.3.4"),
want: nil,
}, {
name: "expired",
ip: anotherIP,
want: nil,
}, {
name: "v6",
ip: netip.MustParseAddr("ffff::1"),
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mac := s.FindMACbyIP(tc.ip)
require.Equal(t, tc.want, mac)
})
}
}
func TestV4Server_handleDecline(t *testing.T) {
const (
dynamicName = "dynamic-client"
anotherName = "another-client"
)
dynamicIP := netip.MustParseAddr("192.168.10.200")
dynamicMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
s := defaultSrv(t)
s4, ok := s.(*v4Server)
require.True(t, ok)
s4.leases = []*Lease{{
Hostname: dynamicName,
HWAddr: dynamicMAC,
IP: dynamicIP,
}}
req, err := dhcpv4.New(
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(net.IP(dynamicIP.AsSlice()))),
)
require.NoError(t, err)
req.ClientIPAddr = net.IP(dynamicIP.AsSlice())
req.ClientHWAddr = dynamicMAC
resp := &dhcpv4.DHCPv4{}
err = s4.handleDecline(req, resp)
require.NoError(t, err)
wantResp := &dhcpv4.DHCPv4{
YourIPAddr: net.IP(s4.conf.RangeStart.AsSlice()),
Options: dhcpv4.OptionsFromList(
dhcpv4.OptMessageType(dhcpv4.MessageTypeAck),
),
}
require.Equal(t, wantResp, resp)
}
func TestV4Server_handleRelease(t *testing.T) {
const (
dynamicName = "dymamic-client"
anotherName = "another-client"
)
dynamicIP := netip.MustParseAddr("192.168.10.200")
dynamicMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
s := defaultSrv(t)
s4, ok := s.(*v4Server)
require.True(t, ok)
s4.leases = []*Lease{{
Hostname: dynamicName,
HWAddr: dynamicMAC,
IP: dynamicIP,
}}
req, err := dhcpv4.New(
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(net.IP(dynamicIP.AsSlice()))),
)
require.NoError(t, err)
req.ClientIPAddr = net.IP(dynamicIP.AsSlice())
req.ClientHWAddr = dynamicMAC
resp := &dhcpv4.DHCPv4{}
err = s4.handleRelease(req, resp)
require.NoError(t, err)
wantResp := &dhcpv4.DHCPv4{
Options: dhcpv4.OptionsFromList(
dhcpv4.OptMessageType(dhcpv4.MessageTypeAck),
),
}
require.Equal(t, wantResp, resp)
}

View File

@@ -61,13 +61,13 @@ func ip6InRange(start, ip net.IP) bool {
// ResetLeases resets leases.
func (s *v6Server) ResetLeases(leases []*Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
s.leases = nil
for _, l := range leases {
ip := net.IP(l.IP.AsSlice())
if l.Expiry.Unix() != leaseExpireStatic &&
!ip6InRange(s.conf.ipStart, l.IP) {
!ip6InRange(s.conf.ipStart, ip) {
log.Debug("dhcpv6: skipping a lease with IP %v: not within current IP range", l.IP)
@@ -119,10 +119,9 @@ func (s *v6Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
return nil
}
netIP := ip.AsSlice()
for _, l := range s.leases {
if l.IP.Equal(netIP) {
if l.Expiry.After(now) || l.IsStatic() {
if l.IP == ip {
if l.IsStatic || l.Expiry.After(now) {
return l.HWAddr
}
}
@@ -133,7 +132,8 @@ func (s *v6Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
// Remove (swap) lease by index
func (s *v6Server) leaseRemoveSwapByIndex(i int) {
s.ipAddrs[s.leases[i].IP[15]] = 0
leaseIP := s.leases[i].IP.As16()
s.ipAddrs[leaseIP[15]] = 0
log.Debug("dhcpv6: removed lease %s", s.leases[i].HWAddr)
n := len(s.leases)
@@ -162,7 +162,7 @@ func (s *v6Server) rmDynamicLease(lease *Lease) (err error) {
l = s.leases[i]
}
if net.IP.Equal(l.IP, lease.IP) {
if l.IP == lease.IP {
if l.Expiry.Unix() == leaseExpireStatic {
return fmt.Errorf("static lease already exists")
}
@@ -178,7 +178,7 @@ func (s *v6Server) rmDynamicLease(lease *Lease) (err error) {
func (s *v6Server) AddStaticLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if len(l.IP) != net.IPv6len {
if !l.IP.Is6() {
return fmt.Errorf("invalid IP")
}
@@ -210,7 +210,7 @@ func (s *v6Server) AddStaticLease(l *Lease) (err error) {
func (s *v6Server) RemoveStaticLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if len(l.IP) != 16 {
if !l.IP.Is6() {
return fmt.Errorf("invalid IP")
}
@@ -234,14 +234,15 @@ func (s *v6Server) RemoveStaticLease(l *Lease) (err error) {
// Add a lease
func (s *v6Server) addLease(l *Lease) {
s.leases = append(s.leases, l)
s.ipAddrs[l.IP[15]] = 1
ip := l.IP.As16()
s.ipAddrs[ip[15]] = 1
log.Debug("dhcpv6: added lease %s <-> %s", l.IP, l.HWAddr)
}
// Remove a lease with the same properties
func (s *v6Server) rmLease(lease *Lease) (err error) {
for i, l := range s.leases {
if net.IP.Equal(l.IP, lease.IP) {
if l.IP == lease.IP {
if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
l.Hostname != lease.Hostname {
return fmt.Errorf("lease not found")
@@ -308,18 +309,27 @@ func (s *v6Server) reserveLease(mac net.HardwareAddr) *Lease {
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
copy(l.IP, s.conf.ipStart)
l.IP = s.findFreeIP()
if l.IP == nil {
ip := s.findFreeIP()
if ip == nil {
i := s.findExpiredLease()
if i < 0 {
return nil
}
copy(s.leases[i].HWAddr, mac)
return s.leases[i]
}
netIP, ok := netip.AddrFromSlice(ip)
if !ok {
return nil
}
l.IP = netIP
s.addLease(&l)
return &l
}
@@ -388,7 +398,8 @@ func (s *v6Server) checkIA(msg *dhcpv6.Message, lease *Lease) error {
return fmt.Errorf("no IANA.Addr option in %s", msg.Type().String())
}
if !oiaAddr.IPv6Addr.Equal(lease.IP) {
leaseIP := net.IP(lease.IP.AsSlice())
if !oiaAddr.IPv6Addr.Equal(leaseIP) {
return fmt.Errorf("invalid IANA.Addr option in %s", msg.Type().String())
}
}
@@ -475,7 +486,7 @@ func (s *v6Server) process(msg *dhcpv6.Message, req, resp dhcpv6.DHCPv6) bool {
copy(oia.IaId[:], []byte(valueIAID))
}
oiaAddr := &dhcpv6.OptIAAddress{
IPv6Addr: lease.IP,
IPv6Addr: net.IP(lease.IP.AsSlice()),
PreferredLifetime: lifetime,
ValidLifetime: lifetime,
}

View File

@@ -4,7 +4,9 @@ package dhcpd
import (
"net"
"net/netip"
"testing"
"time"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/iana"
@@ -27,7 +29,7 @@ func TestV6_AddRemove_static(t *testing.T) {
// Add static lease.
l := &Lease{
IP: net.ParseIP("2001::1"),
IP: netip.MustParseAddr("2001::1"),
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}
err = s.AddStaticLease(l)
@@ -46,7 +48,7 @@ func TestV6_AddRemove_static(t *testing.T) {
// Try to remove non-existent static lease.
err = s.RemoveStaticLease(&Lease{
IP: net.ParseIP("2001::2"),
IP: netip.MustParseAddr("2001::2"),
HWAddr: l.HWAddr,
})
require.Error(t, err)
@@ -71,10 +73,10 @@ func TestV6_AddReplace(t *testing.T) {
// Add dynamic leases.
dynLeases := []*Lease{{
IP: net.ParseIP("2001::1"),
IP: netip.MustParseAddr("2001::1"),
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}, {
IP: net.ParseIP("2001::2"),
IP: netip.MustParseAddr("2001::2"),
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}}
@@ -83,10 +85,10 @@ func TestV6_AddReplace(t *testing.T) {
}
stLeases := []*Lease{{
IP: net.ParseIP("2001::1"),
IP: netip.MustParseAddr("2001::1"),
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}, {
IP: net.ParseIP("2001::3"),
IP: netip.MustParseAddr("2001::3"),
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}}
@@ -99,7 +101,7 @@ func TestV6_AddReplace(t *testing.T) {
require.Len(t, ls, 2)
for i, l := range ls {
assert.True(t, stLeases[i].IP.Equal(l.IP))
assert.Equal(t, stLeases[i].IP, l.IP)
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix())
}
@@ -126,7 +128,7 @@ func TestV6GetLease(t *testing.T) {
}
l := &Lease{
IP: net.ParseIP("2001::1"),
IP: netip.MustParseAddr("2001::1"),
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}
err = s.AddStaticLease(l)
@@ -158,7 +160,8 @@ func TestV6GetLease(t *testing.T) {
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
ip := net.IP(l.IP.AsSlice())
assert.Equal(t, ip, oiaAddr.IPv6Addr)
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
})
@@ -182,7 +185,8 @@ func TestV6GetLease(t *testing.T) {
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
ip := net.IP(l.IP.AsSlice())
assert.Equal(t, ip, oiaAddr.IPv6Addr)
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
})
@@ -308,3 +312,74 @@ func TestIP6InRange(t *testing.T) {
})
}
}
func TestV6_FindMACbyIP(t *testing.T) {
const (
staticName = "static-client"
anotherName = "another-client"
)
staticIP := netip.MustParseAddr("2001::1")
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
anotherIP := netip.MustParseAddr("2001::100")
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
s := &v6Server{
leases: []*Lease{{
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: staticName,
HWAddr: staticMAC,
IP: staticIP,
IsStatic: true,
}, {
Expiry: time.Unix(10, 0),
Hostname: anotherName,
HWAddr: anotherMAC,
IP: anotherIP,
}},
}
s.leases = []*Lease{{
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: staticName,
HWAddr: staticMAC,
IP: staticIP,
IsStatic: true,
}, {
Expiry: time.Unix(10, 0),
Hostname: anotherName,
HWAddr: anotherMAC,
IP: anotherIP,
}}
testCases := []struct {
want net.HardwareAddr
ip netip.Addr
name string
}{{
name: "basic",
ip: staticIP,
want: staticMAC,
}, {
name: "not_found",
ip: netip.MustParseAddr("ffff::1"),
want: nil,
}, {
name: "expired",
ip: anotherIP,
want: nil,
}, {
name: "v4",
ip: netip.MustParseAddr("1.2.3.4"),
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mac := s.FindMACbyIP(tc.ip)
require.Equal(t, tc.want, mac)
})
}
}

View File

@@ -81,6 +81,10 @@ type FilteringConfig struct {
// 0, then default value is used (3600).
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"`
// ProtectionDisabledUntil is the timestamp until when the protection is
// disabled.
ProtectionDisabledUntil *time.Time `yaml:"protection_disabled_until"`
// ParentalBlockHost is the IP (or domain name) which is used to respond to
// DNS requests blocked by parental control.
ParentalBlockHost string `yaml:"parental_block_host"`
@@ -195,12 +199,16 @@ type FilteringConfig struct {
// IpsetListFileName, if set, points to the file with ipset configuration.
// The format is the same as in [IpsetList].
IpsetListFileName string `yaml:"ipset_file"`
// BootstrapPreferIPv6, if true, instructs the bootstrapper to prefer IPv6
// addresses to IPv4 ones for DoH, DoQ, and DoT.
BootstrapPreferIPv6 bool `yaml:"bootstrap_prefer_ipv6"`
}
// EDNSClientSubnet is the settings list for EDNS Client Subnet.
type EDNSClientSubnet struct {
// CustomIP for EDNS Client Subnet.
CustomIP string `yaml:"custom_ip"`
CustomIP netip.Addr `yaml:"custom_ip"`
// Enabled defines if EDNS Client Subnet is enabled.
Enabled bool `yaml:"enabled"`
@@ -340,15 +348,8 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
}
if srvConf.EDNSClientSubnet.UseCustom {
// TODO(s.chzhen): Add wrapper around netip.Addr.
var ip net.IP
ip, err = netutil.ParseIP(srvConf.EDNSClientSubnet.CustomIP)
if err != nil {
return conf, fmt.Errorf("edns: %w", err)
}
// TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy.
conf.EDNSAddr = ip
conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice())
}
if srvConf.CacheSize != 0 {
@@ -377,7 +378,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
err = s.prepareTLS(&conf)
if err != nil {
return conf, fmt.Errorf("validating tls: %w", err)
return proxy.Config{}, fmt.Errorf("validating tls: %w", err)
}
if c := srvConf.DNSCryptConfig; c.Enabled {
@@ -388,7 +389,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
}
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
return conf, errors.Error("no default upstream servers configured")
return proxy.Config{}, errors.Error("no default upstream servers configured")
}
return conf, nil
@@ -482,6 +483,7 @@ func (s *Server) prepareUpstreamSettings() error {
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: httpVersions,
PreferIPv6: s.conf.BootstrapPreferIPv6,
},
)
if err != nil {
@@ -497,6 +499,7 @@ func (s *Server) prepareUpstreamSettings() error {
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: httpVersions,
PreferIPv6: s.conf.BootstrapPreferIPv6,
},
)
if err != nil {
@@ -584,11 +587,11 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) {
if s.conf.StrictSNICheck {
if len(cert.DNSNames) != 0 {
s.conf.dnsNames = cert.DNSNames
log.Debug("dnsforward: using certificate's SAN as DNS names: %v", cert.DNSNames)
log.Debug("dns: using certificate's SAN as DNS names: %v", cert.DNSNames)
slices.Sort(s.conf.dnsNames)
} else {
s.conf.dnsNames = append(s.conf.dnsNames, cert.Subject.CommonName)
log.Debug("dnsforward: using certificate's CN as DNS name: %s", cert.Subject.CommonName)
log.Debug("dns: using certificate's CN as DNS name: %s", cert.Subject.CommonName)
}
}
@@ -642,3 +645,49 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
}
return &s.conf.cert, nil
}
// UpdatedProtectionStatus updates protection state, if the protection was
// disabled temporarily. Returns the updated state of protection.
func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
disabledUntil = s.conf.ProtectionDisabledUntil
if disabledUntil == nil {
return s.conf.ProtectionEnabled, nil
}
if time.Now().Before(*disabledUntil) {
return false, disabledUntil
}
// Update the values in a separate goroutine, unless an update is already in
// progress. Since this method is called very often, and this update is a
// relatively rare situation, do not lock s.serverLock for writing, as that
// can lead to freezes.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/5661.
if s.protectionUpdateInProgress.CompareAndSwap(false, true) {
go s.enableProtectionAfterPause()
}
return true, nil
}
// enableProtectionAfterPause sets the protection configuration to enabled
// values. It is intended to be used as a goroutine.
func (s *Server) enableProtectionAfterPause() {
defer log.OnPanic("dns: enabling protection after pause")
defer s.protectionUpdateInProgress.Store(false)
defer s.conf.ConfigModified()
s.serverLock.Lock()
defer s.serverLock.Unlock()
s.conf.ProtectionEnabled = true
s.conf.ProtectionDisabledUntil = nil
log.Info("dns: protection is restarted after pause")
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/binary"
"net"
"net/netip"
"strconv"
"strings"
"time"
@@ -21,7 +22,7 @@ import (
// To transfer information between modules
//
// TODO(s.chzhen): Add lowercased, non-FQDN version of the hostname from the
// question of the request.
// question of the request. Add persistent client.
type dnsContext struct {
proxyCtx *proxy.DNSContext
@@ -37,6 +38,8 @@ type dnsContext struct {
// was parsed successfully and belongs to one of the locally served IP
// ranges. It is also filled with unmapped version of the address if it's
// within DNS64 prefixes.
//
// TODO(e.burkov): Use netip.Addr when we switch to netip more fully.
unreversedReqIP net.IP
// err is the error returned from a processing function.
@@ -181,6 +184,21 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
return resultCodeFinish
}
// Handle a reserved domain healthcheck.adguardhome.test.
//
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
// grant requests to register test names in the normal way to any person or
// entity, making domain names under test. TLD free to use in internal
// purposes.
//
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
if q.Name == "healthcheck.adguardhome.test." {
// Generate a NODATA negative response to make nslookup exit with 0.
pctx.Res = s.makeResponse(pctx.Req)
return resultCodeFinish
}
// Get the ClientID, if any, before getting client-specific filtering
// settings.
var key [8]byte
@@ -188,7 +206,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
dctx.clientID = string(s.clientIDCache.Get(key[:]))
// Get the client-specific filtering settings.
dctx.protectionEnabled = s.conf.ProtectionEnabled
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
dctx.setts = s.getClientRequestFilteringSettings(dctx)
return resultCodeSuccess
@@ -240,17 +258,16 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
// Assume that we only process IPv4 now.
//
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ip, err := netutil.IPToAddr(l.IP, netutil.AddrFamilyIPv4)
if err != nil {
log.Debug("dnsforward: skipping invalid ip %v from dhcp: %s", l.IP, err)
if !l.IP.Is4() {
log.Debug("dnsforward: skipping invalid ip from dhcp: bad ipv4 net.IP %v", l.IP)
continue
}
ipToHost[ip] = lowhost
hostToIP[lowhost] = ip
leaseIP := l.IP
ipToHost[leaseIP] = lowhost
hostToIP[lowhost] = leaseIP
}
s.setTableHostToIP(hostToIP)
@@ -442,6 +459,88 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
// indexFirstV4Label returns the index at which the reversed IPv4 address
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
// arpa labels removed.
func indexFirstV4Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
if parseErr != nil {
return idx
}
idx = curIdx
}
return idx
}
// indexFirstV6Label returns the index at which the reversed IPv6 address
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
// labels removed.
func indexFirstV6Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
curIdx := idx - len("a.")
if curIdx > 1 && domain[curIdx-1] != '.' {
return idx
}
nibble := domain[curIdx]
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
return idx
}
idx = curIdx
}
return idx
}
// extractARPASubnet tries to convert a reversed ARPA address being a part of
// domain to an IP network. domain must be an FQDN.
//
// TODO(e.burkov): Move to golibs.
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
if err != nil {
// Don't wrap the error since it's informative enough as is.
return netip.Prefix{}, err
}
const (
v4Suffix = "in-addr.arpa."
v6Suffix = "ip6.arpa."
)
domain = strings.ToLower(domain)
var idx int
switch {
case strings.HasSuffix(domain, v4Suffix):
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
case strings.HasSuffix(domain, v6Suffix):
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
default:
return netip.Prefix{}, &netutil.AddrError{
Err: netutil.ErrNotAReversedSubnet,
Kind: netutil.AddrKindARPA,
Addr: domain,
}
}
var subnet *net.IPNet
subnet, err = netutil.SubnetFromReversedAddr(domain[idx:])
if err != nil {
// Don't wrap the error since it's informative enough as is.
return netip.Prefix{}, err
}
return netutil.IPNetToPrefixNoMapped(subnet)
}
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
// in locally served network from external clients.
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
@@ -453,34 +552,29 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
ip, err := netutil.IPFromReversedAddr(q.Name)
subnet, err := extractARPASubnet(q.Name)
if err != nil {
log.Debug("dnsforward: parsing reversed addr: %s", err)
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
log.Debug("dnsforward: request is not for arpa domain")
// DNS-Based Service Discovery uses PTR records having not an ARPA
// format of the domain name in question. Those shouldn't be
// invalidated. See http://www.dns-sd.org/ServerStaticSetup.html and
// RFC 2782.
name := strings.TrimSuffix(q.Name, ".")
if err = netutil.ValidateSRVDomainName(name); err != nil {
log.Debug("dnsforward: validating service domain: %s", err)
return resultCodeError
return resultCodeSuccess
}
log.Debug("dnsforward: request is not for arpa domain")
log.Debug("dnsforward: parsing reversed addr: %s", err)
return resultCodeSuccess
return resultCodeError
}
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally served or at least
// shouldn't be accessible externally.
if !s.privateNets.Contains(ip) {
subnetAddr := subnet.Addr()
addrData := subnetAddr.AsSlice()
if !s.privateNets.Contains(addrData) {
return resultCodeSuccess
}
log.Debug("dnsforward: addr %s is from locally served network", ip)
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
if !dctx.isLocalClient {
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
@@ -491,7 +585,7 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
}
// Do not perform unreversing ever again.
dctx.unreversedReqIP = ip
dctx.unreversedReqIP = addrData
// There is no need to filter request from external addresses since this
// code is only executed when the request is for locally served ARPA

View File

@@ -36,8 +36,6 @@ func (s *Server) setupDNS64() {
// valid IPv4. It panics, if there are no configured DNS64 prefixes, because
// synthesis should not be performed unless DNS64 function enabled.
func (s *Server) mapDNS64(ip netip.Addr) (mapped net.IP) {
// Don't mask the address here since it should have already been masked on
// initialization stage.
pref := s.dns64Pref.Masked().Addr().As16()
ipData := ip.As4()

View File

@@ -605,3 +605,129 @@ func TestIPStringFromAddr(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil))
})
}
// TODO(e.burkov): Add fuzzing when moving to golibs.
func TestExtractARPASubnet(t *testing.T) {
const (
v4Suf = `in-addr.arpa.`
v4Part = `2.1.` + v4Suf
v4Whole = `4.3.` + v4Part
v6Suf = `ip6.arpa.`
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
)
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
testCases := []struct {
want netip.Prefix
name string
domain string
wantErr string
}{{
want: netip.Prefix{},
name: "not_an_arpa",
domain: "some.domain.name.",
wantErr: `bad arpa domain name "some.domain.name.": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "bad_domain_name",
domain: "abc.123.",
wantErr: `bad domain name "abc.123": ` +
`bad top-level domain name label "123": all octets are numeric`,
}, {
want: v4Pref,
name: "whole_v4",
domain: v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4",
domain: v4Part,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_within_domain",
domain: "a." + v4Whole,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_additional_label",
domain: "5." + v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4_within_domain",
domain: "a." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4",
domain: "256." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4_within_domain",
domain: "a.256." + v4Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v4",
domain: v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v4_within_domain",
domain: "a." + v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: v6Pref,
name: "whole_v6",
domain: v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6",
domain: v6Part,
}, {
want: v6Pref,
name: "whole_v6_within_domain",
domain: "g." + v6Whole,
wantErr: "",
}, {
want: v6Pref,
name: "whole_v6_additional_label",
domain: "1." + v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6_within_domain",
domain: "label." + v6Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v6",
domain: v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v6_within_domain",
domain: "g." + v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
subnet, err := extractARPASubnet(tc.domain)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, subnet)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
@@ -111,6 +112,10 @@ type Server struct {
isRunning bool
// protectionUpdateInProgress is used to make sure that only one goroutine
// updating the protection configuration after a pause is running at a time.
protectionUpdateInProgress atomic.Bool
conf ServerConfig
// serverLock protects Server.
serverLock sync.RWMutex
@@ -447,6 +452,8 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
},
)
if err != nil {

View File

@@ -23,6 +23,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
@@ -412,7 +413,7 @@ func TestServerRace(t *testing.T) {
filterConf := &filtering.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
@@ -440,12 +441,27 @@ func TestServerRace(t *testing.T) {
func TestSafeSearch(t *testing.T) {
resolver := &aghtest.TestResolver{}
safeSearchConf := filtering.SafeSearchConfig{
Enabled: true,
Google: true,
Yandex: true,
CustomResolver: resolver,
}
filterConf := &filtering.Config{
SafeSearchEnabled: true,
SafeSearchConf: safeSearchConf,
SafeSearchCacheSize: 1000,
CacheTime: 30,
CustomResolver: resolver,
}
safeSearch, err := safesearch.NewDefault(
safeSearchConf,
"",
filterConf.SafeSearchCacheSize,
time.Minute*time.Duration(filterConf.CacheTime),
)
require.NoError(t, err)
filterConf.SafeSearch = safeSearch
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
@@ -498,7 +514,8 @@ func TestSafeSearch(t *testing.T) {
t.Run(tc.host, func(t *testing.T) {
req := createTestMessage(tc.host)
reply, _, err := client.Exchange(req, addr)
var reply *dns.Msg
reply, _, err = client.Exchange(req, addr)
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
assertResponse(t, reply, tc.want)
})
@@ -1057,7 +1074,7 @@ var testDHCP = &dhcpd.MockInterface{
OnEnabled: func() (ok bool) { return true },
OnLeases: func(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
return []*dhcpd.Lease{{
IP: net.IP{192, 168, 12, 34},
IP: netip.MustParseAddr("192.168.12.34"),
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
Hostname: "myhost",
}}

View File

@@ -23,41 +23,101 @@ import (
)
// jsonDNSConfig is the JSON representation of the DNS server configuration.
//
// TODO(s.chzhen): Split it into smaller pieces. Use aghalg.NullBool instead
// of *bool.
type jsonDNSConfig struct {
Upstreams *[]string `json:"upstream_dns"`
UpstreamsFile *string `json:"upstream_dns_file"`
Bootstraps *[]string `json:"bootstrap_dns"`
ProtectionEnabled *bool `json:"protection_enabled"`
RateLimit *uint32 `json:"ratelimit"`
BlockingMode *BlockingMode `json:"blocking_mode"`
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
DNSSECEnabled *bool `json:"dnssec_enabled"`
DisableIPv6 *bool `json:"disable_ipv6"`
UpstreamMode *string `json:"upstream_mode"`
CacheSize *uint32 `json:"cache_size"`
CacheMinTTL *uint32 `json:"cache_ttl_min"`
CacheMaxTTL *uint32 `json:"cache_ttl_max"`
CacheOptimistic *bool `json:"cache_optimistic"`
ResolveClients *bool `json:"resolve_clients"`
UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"`
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
BlockingIPv4 net.IP `json:"blocking_ipv4"`
BlockingIPv6 net.IP `json:"blocking_ipv6"`
// Upstreams is the list of upstream DNS servers.
Upstreams *[]string `json:"upstream_dns"`
// UpstreamsFile is the file containing upstream DNS servers.
UpstreamsFile *string `json:"upstream_dns_file"`
// Bootstraps is the list of DNS servers resolving IP addresses of the
// upstream DoH/DoT resolvers.
Bootstraps *[]string `json:"bootstrap_dns"`
// ProtectionEnabled defines if protection is enabled.
ProtectionEnabled *bool `json:"protection_enabled"`
// RateLimit is the number of requests per second allowed per client.
RateLimit *uint32 `json:"ratelimit"`
// BlockingMode defines the way blocked responses are constructed.
BlockingMode *BlockingMode `json:"blocking_mode"`
// EDNSCSEnabled defines if EDNS Client Subnet is enabled.
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
// EDNSCSUseCustom defines if EDNSCSCustomIP should be used.
EDNSCSUseCustom *bool `json:"edns_cs_use_custom"`
// DNSSECEnabled defines if DNSSEC is enabled.
DNSSECEnabled *bool `json:"dnssec_enabled"`
// DisableIPv6 defines if IPv6 addresses should be dropped.
DisableIPv6 *bool `json:"disable_ipv6"`
// UpstreamMode defines the way DNS requests are constructed.
UpstreamMode *string `json:"upstream_mode"`
// CacheSize in bytes.
CacheSize *uint32 `json:"cache_size"`
// CacheMinTTL is custom minimum TTL for cached DNS responses.
CacheMinTTL *uint32 `json:"cache_ttl_min"`
// CacheMaxTTL is custom maximum TTL for cached DNS responses.
CacheMaxTTL *uint32 `json:"cache_ttl_max"`
// CacheOptimistic defines if expired entries should be served.
CacheOptimistic *bool `json:"cache_optimistic"`
// ResolveClients defines if clients IPs should be resolved into hostnames.
ResolveClients *bool `json:"resolve_clients"`
// UsePrivateRDNS defines if privates DNS resolvers should be used.
UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"`
// LocalPTRUpstreams is the list of local private DNS resolvers.
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
// BlockingIPv4 is custom IPv4 address for blocked A requests.
BlockingIPv4 net.IP `json:"blocking_ipv4"`
// BlockingIPv6 is custom IPv6 address for blocked AAAA requests.
BlockingIPv6 net.IP `json:"blocking_ipv6"`
// DisabledUntil is a timestamp until when the protection is disabled.
DisabledUntil *time.Time `json:"protection_disabled_until"`
// EDNSCSCustomIP is custom IP for EDNS Client Subnet.
EDNSCSCustomIP netip.Addr `json:"edns_cs_custom_ip"`
// DefaultLocalPTRUpstreams is used to pass the addresses from
// systemResolvers to the front-end. It's not a pointer to the slice since
// there is no need to omit it while decoding from JSON.
DefaultLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
}
func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus()
s.serverLock.RLock()
defer s.serverLock.RUnlock()
upstreams := stringutil.CloneSliceOrEmpty(s.conf.UpstreamDNS)
upstreamFile := s.conf.UpstreamDNSFileName
bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS)
protectionEnabled := s.conf.ProtectionEnabled
blockingMode := s.conf.BlockingMode
blockingIPv4 := s.conf.BlockingIPv4
blockingIPv6 := s.conf.BlockingIPv6
ratelimit := s.conf.Ratelimit
customIP := s.conf.EDNSClientSubnet.CustomIP
enableEDNSClientSubnet := s.conf.EDNSClientSubnet.Enabled
useCustom := s.conf.EDNSClientSubnet.UseCustom
enableDNSSEC := s.conf.EnableDNSSEC
aaaaDisabled := s.conf.AAAADisabled
cacheSize := s.conf.CacheSize
@@ -67,6 +127,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
resolveClients := s.conf.ResolveClients
usePrivateRDNS := s.conf.UsePrivateRDNS
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
var upstreamMode string
if s.conf.FastestAddr {
upstreamMode = "fastest_addr"
@@ -74,46 +135,41 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
upstreamMode = "parallel"
}
return &jsonDNSConfig{
Upstreams: &upstreams,
UpstreamsFile: &upstreamFile,
Bootstraps: &bootstraps,
ProtectionEnabled: &protectionEnabled,
BlockingMode: &blockingMode,
BlockingIPv4: blockingIPv4,
BlockingIPv6: blockingIPv6,
RateLimit: &ratelimit,
EDNSCSEnabled: &enableEDNSClientSubnet,
DNSSECEnabled: &enableDNSSEC,
DisableIPv6: &aaaaDisabled,
CacheSize: &cacheSize,
CacheMinTTL: &cacheMinTTL,
CacheMaxTTL: &cacheMaxTTL,
CacheOptimistic: &cacheOptimistic,
UpstreamMode: &upstreamMode,
ResolveClients: &resolveClients,
UsePrivateRDNS: &usePrivateRDNS,
LocalPTRUpstreams: &localPTRUpstreams,
}
}
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
defLocalPTRUps, err := s.filterOurDNSAddrs(s.sysResolvers.Get())
if err != nil {
log.Debug("getting dns configuration: %s", err)
}
resp := struct {
jsonDNSConfig
// DefautLocalPTRUpstreams is used to pass the addresses from
// systemResolvers to the front-end. It's not a pointer to the slice
// since there is no need to omit it while decoding from JSON.
DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
}{
jsonDNSConfig: *s.getDNSConfig(),
DefautLocalPTRUpstreams: defLocalPTRUps,
return &jsonDNSConfig{
Upstreams: &upstreams,
UpstreamsFile: &upstreamFile,
Bootstraps: &bootstraps,
ProtectionEnabled: &protectionEnabled,
BlockingMode: &blockingMode,
BlockingIPv4: blockingIPv4,
BlockingIPv6: blockingIPv6,
RateLimit: &ratelimit,
EDNSCSCustomIP: customIP,
EDNSCSEnabled: &enableEDNSClientSubnet,
EDNSCSUseCustom: &useCustom,
DNSSECEnabled: &enableDNSSEC,
DisableIPv6: &aaaaDisabled,
CacheSize: &cacheSize,
CacheMinTTL: &cacheMinTTL,
CacheMaxTTL: &cacheMaxTTL,
CacheOptimistic: &cacheOptimistic,
UpstreamMode: &upstreamMode,
ResolveClients: &resolveClients,
UsePrivateRDNS: &usePrivateRDNS,
LocalPTRUpstreams: &localPTRUpstreams,
DefaultLocalPTRUpstreams: defLocalPTRUps,
DisabledUntil: protectionDisabledUntil,
}
}
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
resp := s.getDNSConfig()
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
@@ -204,6 +260,7 @@ func (req *jsonDNSConfig) checkCacheTTL() bool {
return min <= max
}
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
req := &jsonDNSConfig{}
err := json.NewDecoder(r.Body).Decode(req)
@@ -231,8 +288,8 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
}
}
// setConfigRestartable sets the server parameters. shouldRestart is true if
// the server should be restarted to apply changes.
// setConfig sets the server parameters. shouldRestart is true if the server
// should be restarted to apply changes.
func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
s.serverLock.Lock()
defer s.serverLock.Unlock()
@@ -250,6 +307,10 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
s.conf.FastestAddr = *dc.UpstreamMode == "fastest_addr"
}
if dc.EDNSCSUseCustom != nil && *dc.EDNSCSUseCustom {
s.conf.EDNSClientSubnet.CustomIP = dc.EDNSCSCustomIP
}
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
@@ -281,6 +342,7 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
setIfNotNil(&s.conf.UpstreamDNSFileName, dc.UpstreamsFile),
setIfNotNil(&s.conf.BootstrapDNS, dc.Bootstraps),
setIfNotNil(&s.conf.EDNSClientSubnet.Enabled, dc.EDNSCSEnabled),
setIfNotNil(&s.conf.EDNSClientSubnet.UseCustom, dc.EDNSCSUseCustom),
setIfNotNil(&s.conf.CacheSize, dc.CacheSize),
setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL),
setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL),
@@ -388,15 +450,15 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
var errs []error
for _, domain := range keys {
var subnet *net.IPNet
subnet, err = netutil.SubnetFromReversedAddr(domain)
var subnet netip.Prefix
subnet, err = extractARPASubnet(domain)
if err != nil {
errs = append(errs, err)
continue
}
if !privateNets.Contains(subnet.IP) {
if !privateNets.Contains(subnet.Addr().AsSlice()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
@@ -577,6 +639,7 @@ func (err domainSpecificTestError) Error() (msg string) {
func checkDNS(
upstreamConfigStr string,
bootstrap []string,
bootstrapPrefIPv6 bool,
timeout time.Duration,
healthCheck healthCheckFunc,
) (err error) {
@@ -604,8 +667,9 @@ func checkDNS(
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
Bootstrap: bootstrap,
Timeout: timeout,
Bootstrap: bootstrap,
Timeout: timeout,
PreferIPv6: bootstrapPrefIPv6,
})
if err != nil {
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
@@ -637,6 +701,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
result := map[string]string{}
bootstraps := req.BootstrapDNS
bootstrapPrefIPv6 := s.conf.BootstrapPreferIPv6
timeout := s.conf.UpstreamTimeout
type upsCheckResult = struct {
@@ -653,7 +718,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
defer func() { resCh <- res }()
checkErr := checkDNS(ups, bootstraps, timeout, healthCheck)
checkErr := checkDNS(ups, bootstraps, bootstrapPrefIPv6, timeout, healthCheck)
if checkErr != nil {
res.res = checkErr.Error()
} else {
@@ -685,6 +750,52 @@ func (s *Server) handleCacheClear(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "OK")
}
// protectionJSON is an object for /control/protection endpoint.
type protectionJSON struct {
Enabled bool `json:"enabled"`
Duration uint `json:"duration"`
}
// handleSetProtection is a handler for the POST /control/protection HTTP API.
func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
protectionReq := &protectionJSON{}
err := json.NewDecoder(r.Body).Decode(protectionReq)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
var disabledUntil *time.Time
if protectionReq.Duration > 0 {
if protectionReq.Enabled {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Setting a duration is only allowed with protection disabling",
)
return
}
calcTime := time.Now().Add(time.Duration(protectionReq.Duration) * time.Millisecond)
disabledUntil = &calcTime
}
func() {
s.serverLock.Lock()
defer s.serverLock.Unlock()
s.conf.ProtectionEnabled = protectionReq.Enabled
s.conf.ProtectionDisabledUntil = disabledUntil
}()
s.conf.ConfigModified()
aghhttp.OK(w)
}
// handleDoH is the DNS-over-HTTPs handler.
//
// Control flow:
@@ -719,6 +830,7 @@ func (s *Server) registerHandlers() {
s.conf.HTTPRegister(http.MethodGet, "/control/dns_info", s.handleGetConfig)
s.conf.HTTPRegister(http.MethodPost, "/control/dns_config", s.handleSetConfig)
s.conf.HTTPRegister(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
s.conf.HTTPRegister(http.MethodPost, "/control/protection", s.handleSetProtection)
s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)

View File

@@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@@ -57,7 +58,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
filterConf := &filtering.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
@@ -122,7 +123,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
s.conf = tc.conf()
s.handleGetConfig(w, nil)
cType := w.Header().Get(aghhttp.HdrNameContentType)
cType := w.Header().Get(httphdr.ContentType)
assert.Equal(t, aghhttp.HdrValApplicationJSON, cType)
assert.JSONEq(t, string(caseWant), w.Body.String())
})
@@ -133,7 +134,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
filterConf := &filtering.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
@@ -181,6 +182,12 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
}, {
name: "edns_cs_enabled",
wantSet: "",
}, {
name: "edns_cs_use_custom",
wantSet: "",
}, {
name: "edns_cs_use_custom_bad_ip",
wantSet: "decoding request: ParseAddr(\"bad.ip\"): unexpected character (at \"bad.ip\")",
}, {
name: "dnssec_enabled",
wantSet: "",
@@ -212,7 +219,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
}, {
name: "local_ptr_upstreams_bad",
wantSet: `validating private upstream servers: checking domain-specific upstreams: ` +
`bad arpa domain name "non.arpa": not a reversed ip network`,
`bad arpa domain name "non.arpa.": not a reversed ip network`,
}, {
name: "local_ptr_upstreams_null",
wantSet: "",
@@ -222,16 +229,20 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
Req json.RawMessage `json:"req"`
Want json.RawMessage `json:"want"`
}
loadTestData(t, t.Name()+jsonExt, &data)
testData := t.Name() + jsonExt
loadTestData(t, testData, &data)
for _, tc := range testCases {
// NOTE: Do not use require.Contains, because the size of the data
// prevents it from printing a meaningful error message.
caseData, ok := data[tc.name]
require.True(t, ok)
require.Truef(t, ok, "%q does not contain test data for test case %s", testData, tc.name)
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
s.conf = defaultConf
s.conf.FilteringConfig.EDNSClientSubnet.Enabled = false
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{}
})
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
@@ -373,7 +384,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
}, {
name: "not_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`bad arpa domain name "hello.world": not a reversed ip network`,
`bad arpa domain name "hello.world.": not a reversed ip network`,
u: "[/hello.world/]#",
}, {
name: "non-private_arpa_address",
@@ -389,8 +400,12 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
name: "several_bad",
wantErr: `checking domain-specific upstreams: 2 errors: ` +
`"arpa domain \"1.2.3.4.in-addr.arpa.\" should point to a locally-served network", ` +
`"bad arpa domain name \"non.arpa\": not a reversed ip network"`,
`"bad arpa domain name \"non.arpa.\": not a reversed ip network"`,
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
}, {
name: "partial_good",
wantErr: "",
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
}}
for _, tc := range testCases {

View File

@@ -40,12 +40,17 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
log.Debug("client ip: %s", ip)
ipStr := ip.String()
ids := []string{ipStr, dctx.clientID}
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
// uninitialized while in use. This can happen after proxy server has been
// stopped, but its workers haven't yet exited.
if shouldLog &&
s.queryLog != nil &&
s.queryLog.ShouldLog(host, q.Qtype, q.Qclass) {
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
// containing persistent client.
s.queryLog.ShouldLog(host, q.Qtype, q.Qclass, ids) {
s.logQuery(dctx, pctx, elapsed, ip)
} else {
log.Debug(
@@ -56,8 +61,11 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
)
}
if s.stats != nil && s.stats.ShouldCount(host, q.Qtype, q.Qclass) {
s.updateStats(dctx, elapsed, *dctx.result, ip)
if s.stats != nil &&
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
// containing persistent client.
s.stats.ShouldCount(host, q.Qtype, q.Qclass, ids) {
s.updateStats(dctx, elapsed, *dctx.result, ipStr)
}
return resultCodeSuccess
@@ -110,7 +118,7 @@ func (s *Server) updateStats(
ctx *dnsContext,
elapsed time.Duration,
res filtering.Result,
clientIP net.IP,
clientIP string,
) {
pctx := ctx.proxyCtx
e := stats.Entry{}
@@ -119,8 +127,8 @@ func (s *Server) updateStats(
if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if clientIP != nil {
e.Client = clientIP.String()
} else {
e.Client = clientIP
}
e.Time = uint32(elapsed / 1000)

View File

@@ -31,7 +31,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
}
// ShouldLog implements the [querylog.QueryLog] interface for *testQueryLog.
func (l *testQueryLog) ShouldLog(string, uint16, uint16) bool {
func (l *testQueryLog) ShouldLog(string, uint16, uint16, []string) bool {
return true
}
@@ -50,7 +50,7 @@ func (l *testStats) Update(e stats.Entry) {
}
// ShouldCount implements the [stats.Interface] interface for *testStats.
func (l *testStats) ShouldCount(string, uint16, uint16) bool {
func (l *testStats) ShouldCount(string, uint16, uint16, []string) bool {
return true
}

View File

@@ -12,6 +12,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -26,7 +27,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
},
"fastest_addr": {
"upstream_dns": [
@@ -41,6 +44,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -55,7 +59,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
},
"parallel": {
"upstream_dns": [
@@ -70,6 +76,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -84,6 +91,8 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
}

View File

@@ -19,6 +19,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -33,7 +34,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"bootstraps": {
@@ -52,6 +55,7 @@
"9.9.9.10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -66,7 +70,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"blocking_mode_good": {
@@ -86,6 +92,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "refused",
"blocking_ipv4": "",
@@ -100,7 +107,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"blocking_mode_bad": {
@@ -120,6 +129,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -134,7 +144,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"ratelimit": {
@@ -154,6 +166,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 6,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -168,7 +181,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"edns_cs_enabled": {
@@ -188,6 +203,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -202,7 +218,87 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"edns_cs_use_custom": {
"req": {
"edns_cs_enabled": true,
"edns_cs_use_custom": true,
"edns_cs_custom_ip": "1.2.3.4"
},
"want": {
"upstream_dns": [
"8.8.8.8:53",
"8.8.4.4:53"
],
"upstream_dns_file": "",
"bootstrap_dns": [
"9.9.9.10",
"149.112.112.10",
"2620:fe::10",
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"edns_cs_enabled": true,
"dnssec_enabled": false,
"disable_ipv6": false,
"upstream_mode": "",
"cache_size": 0,
"cache_ttl_min": 0,
"cache_ttl_max": 0,
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": [],
"edns_cs_use_custom": true,
"edns_cs_custom_ip": "1.2.3.4"
}
},
"edns_cs_use_custom_bad_ip": {
"req": {
"edns_cs_enabled": true,
"edns_cs_use_custom": true,
"edns_cs_custom_ip": "bad.ip"
},
"want": {
"upstream_dns": [
"8.8.8.8:53",
"8.8.4.4:53"
],
"upstream_dns_file": "",
"bootstrap_dns": [
"9.9.9.10",
"149.112.112.10",
"2620:fe::10",
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
"upstream_mode": "",
"cache_size": 0,
"cache_ttl_min": 0,
"cache_ttl_max": 0,
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"dnssec_enabled": {
@@ -222,6 +318,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -236,7 +333,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"cache_size": {
@@ -256,6 +355,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -270,7 +370,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"upstream_mode_parallel": {
@@ -290,6 +392,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -304,7 +407,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"upstream_mode_fastest_addr": {
@@ -324,6 +429,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -338,7 +444,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"upstream_dns_bad": {
@@ -360,6 +468,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -374,7 +483,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"bootstraps_bad": {
@@ -396,6 +507,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -410,7 +522,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"cache_bad_ttl": {
@@ -431,6 +545,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -445,7 +560,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"upstream_mode_bad": {
@@ -465,6 +582,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -479,7 +597,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"local_ptr_upstreams_good": {
@@ -501,6 +621,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -517,7 +638,9 @@
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": [
"123.123.123.123"
]
],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"local_ptr_upstreams_bad": {
@@ -540,6 +663,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -554,7 +678,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"local_ptr_upstreams_null": {
@@ -574,6 +700,7 @@
"2620:fe::fe:10"
],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
@@ -588,7 +715,9 @@
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
}
}

View File

@@ -176,13 +176,16 @@ func (d *DNSFilter) filterExistsLocked(url string) (ok bool) {
// Add a filter
// Return FALSE if a filter with this URL exists
func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
func (d *DNSFilter) filterAdd(flt FilterYAML) (err error) {
// Defer annotating to unlock sooner.
defer func() { err = errors.Annotate(err, "adding filter: %w") }()
d.filtersMu.Lock()
defer d.filtersMu.Unlock()
// Check for duplicates
// Check for duplicates.
if d.filterExistsLocked(flt.URL) {
return false
return errFilterExists
}
if flt.white {
@@ -190,7 +193,8 @@ func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
} else {
d.Filters = append(d.Filters, flt)
}
return true
return nil
}
// Load filters from the disk
@@ -238,6 +242,7 @@ func updateUniqueFilterID(filters []FilterYAML) {
}
}
// TODO(e.burkov): Improve this inexhaustible source of races.
func assignUniqueFilterID() int64 {
value := nextFilterID
nextFilterID++
@@ -343,29 +348,31 @@ func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int,
}
updateCount := 0
d.filtersMu.Lock()
defer d.filtersMu.Unlock()
for i := range updateFilters {
uf := &updateFilters[i]
updated := updateFlags[i]
d.filtersMu.Lock()
for k := range *filters {
f := &(*filters)[k]
if f.ID != uf.ID || f.URL != uf.URL {
continue
}
f.LastUpdated = uf.LastUpdated
if !updated {
continue
}
log.Info("Updated filter #%d. Rules: %d -> %d",
f.ID, f.RulesCount, uf.RulesCount)
log.Info("Updated filter #%d. Rules: %d -> %d", f.ID, f.RulesCount, uf.RulesCount)
f.Name = uf.Name
f.RulesCount = uf.RulesCount
f.checksum = uf.checksum
updateCount++
}
d.filtersMu.Unlock()
}
return updateCount, updateFilters, updateFlags, false
@@ -421,11 +428,16 @@ func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
if !updated {
continue
}
_ = os.Remove(uf.Path(d.DataDir) + ".old")
p := uf.Path(d.DataDir)
err := os.Remove(p + ".old")
if err != nil {
log.Debug("filtering: removing old filter file %q: %s", p, err)
}
}
}
log.Debug("filtering: update finished")
log.Debug("filtering: update finished: %d lists updated", updNum)
return updNum, false
}
@@ -467,8 +479,8 @@ func scanLinesWithBreak(data []byte, atEOF bool) (advance int, token []byte, err
}
// parseFilter copies filter's content from src to dst and returns the number of
// rules, name, number of bytes written, checksum, and title of the parsed list.
// dst must not be nil.
// rules, number of bytes written, checksum, and title of the parsed list. dst
// must not be nil.
func (d *DNSFilter) parseFilter(
src io.Reader,
dst io.Writer,
@@ -550,14 +562,18 @@ func isHTML(line string) (ok bool) {
return strings.HasPrefix(line, "<html") || strings.HasPrefix(line, "<!doctype")
}
// Perform upgrade on a filter and update LastUpdated value
func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
b, err := d.updateIntl(filter)
// update refreshes filter's content and a/mtimes of it's file.
func (d *DNSFilter) update(filter *FilterYAML) (b bool, err error) {
b, err = d.updateIntl(filter)
filter.LastUpdated = time.Now()
if !b {
e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated)
if e != nil {
log.Error("os.Chtimes(): %v", e)
chErr := os.Chtimes(
filter.Path(d.DataDir),
filter.LastUpdated,
filter.LastUpdated,
)
if chErr != nil {
log.Error("os.Chtimes(): %v", chErr)
}
}
@@ -591,11 +607,13 @@ func (d *DNSFilter) finalizeUpdate(
return os.Remove(tmpFileName)
}
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
fltPath := flt.Path(d.DataDir)
log.Printf("saving contents of filter #%d into %s", flt.ID, fltPath)
// Don't use renamio or maybe packages, since those will require loading the
// whole filter content to the memory on Windows.
err = os.Rename(tmpFileName, flt.Path(d.DataDir))
err = os.Rename(tmpFileName, fltPath)
if err != nil {
return errors.WithDeferred(err, os.Remove(tmpFileName))
}
@@ -620,10 +638,14 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
return false, err
}
defer func() {
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
if ok && err == nil {
finErr := d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs)
if ok && finErr == nil {
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
return
}
err = errors.WithDeferred(err, finErr)
}()
// Change the default 0o600 permission to something more acceptable by end
@@ -634,7 +656,7 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
return false, fmt.Errorf("changing file mode: %w", err)
}
var rc io.ReadCloser
var r io.Reader
if !filepath.IsAbs(flt.URL) {
var resp *http.Response
resp, err = d.HTTPClient.Get(flt.URL)
@@ -651,16 +673,19 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
return false, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
}
rc = resp.Body
r = resp.Body
} else {
rc, err = os.Open(flt.URL)
var f *os.File
f, err = os.Open(flt.URL)
if err != nil {
return false, fmt.Errorf("open file: %w", err)
}
defer func() { err = errors.WithDeferred(err, rc.Close()) }()
defer func() { err = errors.WithDeferred(err, f.Close()) }()
r = f
}
rnum, n, cs, name, err = d.parseFilter(rc, tmpFile)
rnum, n, cs, name, err = d.parseFilter(r, tmpFile)
return cs != flt.checksum && err == nil, err
}
@@ -705,10 +730,11 @@ func (d *DNSFilter) EnableFilters(async bool) {
}
func (d *DNSFilter) enableFiltersLocked(async bool) {
filters := []Filter{{
filters := make([]Filter, 1, len(d.Filters)+len(d.WhitelistFilters)+1)
filters[0] = Filter{
ID: CustomListID,
Data: []byte(strings.Join(d.UserRules, "\n")),
}}
}
for _, filter := range d.Filters {
if !filter.Enabled {

View File

@@ -63,6 +63,9 @@ type Settings struct {
SafeSearchEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
// ClientSafeSearch is a client configured safe search.
ClientSafeSearch SafeSearch
}
// Resolver is the interface for net.Resolver to simplify testing.
@@ -83,13 +86,16 @@ type Config struct {
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
SafeBrowsingCacheSize uint `yaml:"safebrowsing_cache_size"` // (in bytes)
SafeSearchCacheSize uint `yaml:"safesearch_cache_size"` // (in bytes)
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
// TODO(a.garipov): Use timeutil.Duration
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
SafeSearchConf SafeSearchConfig `yaml:"safe_search"`
SafeSearch SafeSearch `yaml:"-"`
Rewrites []*LegacyRewrite `yaml:"rewrites"`
@@ -107,9 +113,6 @@ type Config struct {
// Register an HTTP handler
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
// CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver `yaml:"-"`
// HTTPClient is the client to use for updating the remote filters.
HTTPClient *http.Client `yaml:"-"`
@@ -172,7 +175,6 @@ type DNSFilter struct {
safebrowsingCache cache.Cache
parentalCache cache.Cache
safeSearchCache cache.Cache
Config // for direct access by library users, even a = assignment
// confLock protects Config.
@@ -182,11 +184,6 @@ type DNSFilter struct {
filtersInitializerChan chan filtersInitializerParams
filtersInitializerLock sync.Mutex
// resolver only looks up the IP address of the host while safe search.
//
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
resolver Resolver
refreshLock *sync.Mutex
// filterTitleRegexp is the regular expression to retrieve a name of a
@@ -195,6 +192,7 @@ type DNSFilter struct {
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
filterTitleRegexp *regexp.Regexp
safeSearch SafeSearch
hostCheckers []hostChecker
}
@@ -298,7 +296,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
return Settings{
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
SafeSearchEnabled: d.Config.SafeSearchEnabled,
SafeSearchEnabled: d.Config.SafeSearchConf.Enabled,
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
ParentalEnabled: d.Config.ParentalEnabled,
}
@@ -942,7 +940,6 @@ func InitModule() {
// be non-nil.
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
d = &DNSFilter{
resolver: net.DefaultResolver,
refreshLock: &sync.Mutex{},
filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`),
}
@@ -951,18 +948,12 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
EnableLRU: true,
MaxSize: c.SafeBrowsingCacheSize,
})
d.safeSearchCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.SafeSearchCacheSize,
})
d.parentalCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.ParentalCacheSize,
})
if r := c.CustomResolver; r != nil {
d.resolver = r
}
d.safeSearch = c.SafeSearch
d.hostCheckers = []hostChecker{{
check: d.matchSysHosts,

View File

@@ -2,10 +2,8 @@ package filtering
import (
"bytes"
"context"
"fmt"
"net"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
@@ -33,7 +31,6 @@ func purgeCaches(d *DNSFilter) {
for _, c := range []cache.Cache{
d.safebrowsingCache,
d.parentalCache,
d.safeSearchCache,
} {
if c != nil {
c.Clear()
@@ -51,7 +48,7 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
c.ParentalCacheSize = 10000
c.SafeSearchCacheSize = 1000
c.CacheTime = 30
setts.SafeSearchEnabled = c.SafeSearchEnabled
setts.SafeSearchEnabled = c.SafeSearchConf.Enabled
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
setts.ParentalEnabled = c.ParentalEnabled
} else {
@@ -216,164 +213,6 @@ func TestParallelSB(t *testing.T) {
})
}
// Safe Search.
func TestSafeSearch(t *testing.T) {
d, _ := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
val, ok := d.SafeSearchDomain("www.google.com")
require.True(t, ok)
assert.Equal(t, "forcesafesearch.google.com", val)
}
func TestCheckHostSafeSearchYandex(t *testing.T) {
d, setts := newForTest(t, &Config{
SafeSearchEnabled: true,
}, nil)
t.Cleanup(d.Close)
yandexIP := net.IPv4(213, 180, 193, 56)
// Check host for each domain.
for _, host := range []string{
"yAndeX.ru",
"YANdex.COM",
"yandex.ua",
"yandex.by",
"yandex.kz",
"www.yandex.com",
} {
t.Run(strings.ToLower(host), func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
})
}
}
func TestCheckHostSafeSearchGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d, setts := newForTest(t, &Config{
SafeSearchEnabled: true,
CustomResolver: resolver,
}, nil)
t.Cleanup(d.Close)
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
// Check host for each domain.
for _, host := range []string{
"www.google.com",
"www.google.im",
"www.google.co.in",
"www.google.iq",
"www.google.is",
"www.google.it",
"www.google.je",
} {
t.Run(host, func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, ip, res.Rules[0].IP)
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
})
}
}
func TestSafeSearchCacheYandex(t *testing.T) {
d, setts := newForTest(t, nil, nil)
t.Cleanup(d.Close)
const domain = "yandex.ru"
// Check host with disabled safesearch.
res, err := d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
require.Empty(t, res.Rules)
yandexIP := net.IPv4(213, 180, 193, 56)
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
res, err = d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err)
// For yandex we already know valid IP.
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d, setts := newForTest(t, &Config{
CustomResolver: resolver,
}, nil)
t.Cleanup(d.Close)
const domain = "www.google.ru"
res, err := d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
require.Empty(t, res.Rules)
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
d.resolver = resolver
// Lookup for safesearch domain.
safeDomain, ok := d.SafeSearchDomain(domain)
require.True(t, ok)
ips, err := resolver.LookupIP(context.Background(), "ip", safeDomain)
require.NoError(t, err)
var ip net.IP
for _, foundIP := range ips {
if foundIP.To4() != nil {
ip = foundIP
break
}
}
res, err = d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(ip))
// Check cache.
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(ip))
}
// Parental.
func TestParentalControl(t *testing.T) {
@@ -854,27 +693,3 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
}
})
}
func BenchmarkSafeSearch(b *testing.B) {
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close)
for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com")
require.True(b, ok)
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
}
}
func BenchmarkSafeSearchParallel(b *testing.B) {
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
val, ok := d.SafeSearchDomain("www.google.com")
require.True(b, ok)
assert.Equal(b, "forcesafesearch.google.com", val, "Expected safesearch for google.com to be forcesafesearch.google.com")
}
})
}

View File

@@ -14,26 +14,33 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// validateFilterURL validates the filter list URL or file name.
func validateFilterURL(urlStr string) (err error) {
defer func() { err = errors.Annotate(err, "checking filter: %w") }()
if filepath.IsAbs(urlStr) {
_, err = os.Stat(urlStr)
if err != nil {
return fmt.Errorf("checking filter file: %w", err)
// Don't wrap the error since it's informative enough as is.
return err
}
return nil
}
url, err := url.ParseRequestURI(urlStr)
u, err := url.ParseRequestURI(urlStr)
if err != nil {
return fmt.Errorf("checking filter url: %w", err)
}
if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
return fmt.Errorf("checking filter url: invalid scheme %q", s)
// Don't wrap the error since it's informative enough as is.
return err
} else if s := u.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
return &url.Error{
Op: "Check scheme",
URL: urlStr,
Err: fmt.Errorf("only %v allowed", []string{aghhttp.SchemeHTTP, aghhttp.SchemeHTTPS}),
}
}
return nil
@@ -63,7 +70,8 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// Check for duplicates
if d.filterExists(fj.URL) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
err = errFilterExists
aghhttp.Error(r, w, http.StatusBadRequest, "Filter with URL %q: %s", fj.URL, err)
return
}
@@ -99,7 +107,7 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
r,
w,
http.StatusBadRequest,
"Filter at the url %s is invalid (maybe it points to blank page?)",
"Filter with URL %q is invalid (maybe it points to blank page?)",
filt.URL,
)
@@ -108,8 +116,9 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// URL is assumed valid so append it to filters, update config, write new
// file and reload it to engines.
if !d.filterAdd(filt) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
err = d.filterAdd(filt)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter with URL %q: %s", filt.URL, err)
return
}
@@ -137,31 +146,38 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
return
}
d.filtersMu.Lock()
filters := &d.Filters
if req.Whitelist {
filters = &d.WhitelistFilters
}
var deleted FilterYAML
var newFilters []FilterYAML
for _, flt := range *filters {
if flt.URL != req.URL {
newFilters = append(newFilters, flt)
func() {
d.filtersMu.Lock()
defer d.filtersMu.Unlock()
continue
filters := &d.Filters
if req.Whitelist {
filters = &d.WhitelistFilters
}
deleted = flt
path := flt.Path(d.DataDir)
err = os.Rename(path, path+".old")
delIdx := slices.IndexFunc(*filters, func(flt FilterYAML) bool {
return flt.URL == req.URL
})
if delIdx == -1 {
log.Error("deleting filter with url %q: %s", req.URL, errFilterNotExist)
return
}
deleted = (*filters)[delIdx]
p := deleted.Path(d.DataDir)
err = os.Rename(p, p+".old")
if err != nil {
log.Error("deleting filter %q: %s", path, err)
}
}
log.Error("deleting filter %d: renaming file %q: %s", deleted.ID, p, err)
*filters = newFilters
d.filtersMu.Unlock()
return
}
*filters = slices.Delete(*filters, delIdx, delIdx+1)
log.Info("deleted filter %d", deleted.ID)
}()
d.ConfigModified()
d.EnableFilters(true)
@@ -258,10 +274,6 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
type Req struct {
White bool `json:"whitelist"`
}
type Resp struct {
Updated int `json:"updated"`
}
resp := Resp{}
var err error
req := Req{}
@@ -273,6 +285,9 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
}
var ok bool
resp := struct {
Updated int `json:"updated"`
}{}
resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
if !ok {
aghhttp.Error(
@@ -461,6 +476,7 @@ func (d *DNSFilter) RegisterFilteringHandlers() {
registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
registerHTTP(http.MethodPut, "/control/safesearch/settings", d.handleSafeSearchSettings)
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)

View File

@@ -1,34 +1,23 @@
package filtering
import (
"bytes"
"context"
"encoding/binary"
"encoding/gob"
"fmt"
"net"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
)
import "github.com/miekg/dns"
// SafeSearch interface describes a service for search engines hosts rewrites.
type SafeSearch interface {
// SearchHost returns a replacement address for the search engine host.
SearchHost(host string, qtype uint16) (res *rules.DNSRewrite)
// CheckHost checks host with safe search engine.
// CheckHost checks host with safe search filter. CheckHost must be safe
// for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].
CheckHost(host string, qtype uint16) (res Result, err error)
// Update updates the configuration of the safe search filter. Update must
// be safe for concurrent use. An implementation of Update may ignore some
// fields, but it must document which.
Update(conf SafeSearchConfig) (err error)
}
// SafeSearchConfig is a struct with safe search related settings.
type SafeSearchConfig struct {
// CustomResolver is the resolver used by safe search.
CustomResolver Resolver `yaml:"-"`
CustomResolver Resolver `yaml:"-" json:"-"`
// Enabled indicates if safe search is enabled entirely.
Enabled bool `yaml:"enabled" json:"enabled"`
@@ -44,358 +33,27 @@ type SafeSearchConfig struct {
YouTube bool `yaml:"youtube" json:"youtube"`
}
/*
expire byte[4]
res Result
*/
func (d *DNSFilter) setCacheResult(cache cache.Cache, host string, res Result) int {
var buf bytes.Buffer
expire := uint(time.Now().Unix()) + d.Config.CacheTime*60
exp := make([]byte, 4)
binary.BigEndian.PutUint32(exp, uint32(expire))
_, _ = buf.Write(exp)
enc := gob.NewEncoder(&buf)
err := enc.Encode(res)
if err != nil {
log.Error("gob.Encode(): %s", err)
return 0
}
val := buf.Bytes()
_ = cache.Set([]byte(host), val)
return len(val)
}
func getCachedResult(cache cache.Cache, host string) (Result, bool) {
data := cache.Get([]byte(host))
if data == nil {
return Result{}, false
}
exp := int(binary.BigEndian.Uint32(data[:4]))
if exp <= int(time.Now().Unix()) {
cache.Del([]byte(host))
return Result{}, false
}
var buf bytes.Buffer
buf.Write(data[4:])
dec := gob.NewDecoder(&buf)
r := Result{}
err := dec.Decode(&r)
if err != nil {
log.Debug("gob.Decode(): %s", err)
return Result{}, false
}
return r, true
}
// SafeSearchDomain returns replacement address for search engine
func (d *DNSFilter) SafeSearchDomain(host string) (string, bool) {
val, ok := safeSearchDomains[host]
return val, ok
}
// checkSafeSearch checks host with safe search engine. Matches
// [hostChecker.check].
func (d *DNSFilter) checkSafeSearch(
host string,
_ uint16,
qtype uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
if !setts.ProtectionEnabled ||
!setts.SafeSearchEnabled ||
(qtype != dns.TypeA && qtype != dns.TypeAAAA) {
return Result{}, nil
}
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeSearch: lookup for %s", host)
}
// Check cache. Return cached result if it was found
cachedValue, isFound := getCachedResult(d.safeSearchCache, host)
if isFound {
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("SafeSearch: found in cache: %s", host)
return cachedValue, nil
}
safeHost, ok := d.SafeSearchDomain(host)
if !ok {
if d.safeSearch == nil {
return Result{}, nil
}
res = Result{
Rules: []*ResultRule{{
FilterListID: SafeSearchListID,
}},
Reason: FilteredSafeSearch,
IsFiltered: true,
clientSafeSearch := setts.ClientSafeSearch
if clientSafeSearch != nil {
return clientSafeSearch.CheckHost(host, qtype)
}
if ip := net.ParseIP(safeHost); ip != nil {
res.Rules[0].IP = ip
valLen := d.setCacheResult(d.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
return res, nil
}
ips, err := d.resolver.LookupIP(context.Background(), "ip", safeHost)
if err != nil {
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
return Result{}, err
}
for _, ip := range ips {
if ip = ip.To4(); ip == nil {
continue
}
res.Rules[0].IP = ip
l := d.setCacheResult(d.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, l)
return res, nil
}
return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost)
}
func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.SafeSearchEnabled, true)
d.Config.ConfigModified()
}
func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.SafeSearchEnabled, false)
d.Config.ConfigModified()
}
func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
resp := &struct {
Enabled bool `json:"enabled"`
}{
Enabled: protectedBool(&d.confLock, &d.Config.SafeSearchEnabled),
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
var safeSearchDomains = map[string]string{
"yandex.com": "213.180.193.56",
"yandex.ru": "213.180.193.56",
"yandex.ua": "213.180.193.56",
"yandex.by": "213.180.193.56",
"yandex.kz": "213.180.193.56",
"www.yandex.com": "213.180.193.56",
"www.yandex.ru": "213.180.193.56",
"www.yandex.ua": "213.180.193.56",
"www.yandex.by": "213.180.193.56",
"www.yandex.kz": "213.180.193.56",
"www.bing.com": "strict.bing.com",
"duckduckgo.com": "safe.duckduckgo.com",
"www.duckduckgo.com": "safe.duckduckgo.com",
"start.duckduckgo.com": "safe.duckduckgo.com",
"www.google.com": "forcesafesearch.google.com",
"www.google.ad": "forcesafesearch.google.com",
"www.google.ae": "forcesafesearch.google.com",
"www.google.com.af": "forcesafesearch.google.com",
"www.google.com.ag": "forcesafesearch.google.com",
"www.google.com.ai": "forcesafesearch.google.com",
"www.google.al": "forcesafesearch.google.com",
"www.google.am": "forcesafesearch.google.com",
"www.google.co.ao": "forcesafesearch.google.com",
"www.google.com.ar": "forcesafesearch.google.com",
"www.google.as": "forcesafesearch.google.com",
"www.google.at": "forcesafesearch.google.com",
"www.google.com.au": "forcesafesearch.google.com",
"www.google.az": "forcesafesearch.google.com",
"www.google.ba": "forcesafesearch.google.com",
"www.google.com.bd": "forcesafesearch.google.com",
"www.google.be": "forcesafesearch.google.com",
"www.google.bf": "forcesafesearch.google.com",
"www.google.bg": "forcesafesearch.google.com",
"www.google.com.bh": "forcesafesearch.google.com",
"www.google.bi": "forcesafesearch.google.com",
"www.google.bj": "forcesafesearch.google.com",
"www.google.com.bn": "forcesafesearch.google.com",
"www.google.com.bo": "forcesafesearch.google.com",
"www.google.com.br": "forcesafesearch.google.com",
"www.google.bs": "forcesafesearch.google.com",
"www.google.bt": "forcesafesearch.google.com",
"www.google.co.bw": "forcesafesearch.google.com",
"www.google.by": "forcesafesearch.google.com",
"www.google.com.bz": "forcesafesearch.google.com",
"www.google.ca": "forcesafesearch.google.com",
"www.google.cd": "forcesafesearch.google.com",
"www.google.cf": "forcesafesearch.google.com",
"www.google.cg": "forcesafesearch.google.com",
"www.google.ch": "forcesafesearch.google.com",
"www.google.ci": "forcesafesearch.google.com",
"www.google.co.ck": "forcesafesearch.google.com",
"www.google.cl": "forcesafesearch.google.com",
"www.google.cm": "forcesafesearch.google.com",
"www.google.cn": "forcesafesearch.google.com",
"www.google.com.co": "forcesafesearch.google.com",
"www.google.co.cr": "forcesafesearch.google.com",
"www.google.com.cu": "forcesafesearch.google.com",
"www.google.cv": "forcesafesearch.google.com",
"www.google.com.cy": "forcesafesearch.google.com",
"www.google.cz": "forcesafesearch.google.com",
"www.google.de": "forcesafesearch.google.com",
"www.google.dj": "forcesafesearch.google.com",
"www.google.dk": "forcesafesearch.google.com",
"www.google.dm": "forcesafesearch.google.com",
"www.google.com.do": "forcesafesearch.google.com",
"www.google.dz": "forcesafesearch.google.com",
"www.google.com.ec": "forcesafesearch.google.com",
"www.google.ee": "forcesafesearch.google.com",
"www.google.com.eg": "forcesafesearch.google.com",
"www.google.es": "forcesafesearch.google.com",
"www.google.com.et": "forcesafesearch.google.com",
"www.google.fi": "forcesafesearch.google.com",
"www.google.com.fj": "forcesafesearch.google.com",
"www.google.fm": "forcesafesearch.google.com",
"www.google.fr": "forcesafesearch.google.com",
"www.google.ga": "forcesafesearch.google.com",
"www.google.ge": "forcesafesearch.google.com",
"www.google.gg": "forcesafesearch.google.com",
"www.google.com.gh": "forcesafesearch.google.com",
"www.google.com.gi": "forcesafesearch.google.com",
"www.google.gl": "forcesafesearch.google.com",
"www.google.gm": "forcesafesearch.google.com",
"www.google.gp": "forcesafesearch.google.com",
"www.google.gr": "forcesafesearch.google.com",
"www.google.com.gt": "forcesafesearch.google.com",
"www.google.gy": "forcesafesearch.google.com",
"www.google.com.hk": "forcesafesearch.google.com",
"www.google.hn": "forcesafesearch.google.com",
"www.google.hr": "forcesafesearch.google.com",
"www.google.ht": "forcesafesearch.google.com",
"www.google.hu": "forcesafesearch.google.com",
"www.google.co.id": "forcesafesearch.google.com",
"www.google.ie": "forcesafesearch.google.com",
"www.google.co.il": "forcesafesearch.google.com",
"www.google.im": "forcesafesearch.google.com",
"www.google.co.in": "forcesafesearch.google.com",
"www.google.iq": "forcesafesearch.google.com",
"www.google.is": "forcesafesearch.google.com",
"www.google.it": "forcesafesearch.google.com",
"www.google.je": "forcesafesearch.google.com",
"www.google.com.jm": "forcesafesearch.google.com",
"www.google.jo": "forcesafesearch.google.com",
"www.google.co.jp": "forcesafesearch.google.com",
"www.google.co.ke": "forcesafesearch.google.com",
"www.google.com.kh": "forcesafesearch.google.com",
"www.google.ki": "forcesafesearch.google.com",
"www.google.kg": "forcesafesearch.google.com",
"www.google.co.kr": "forcesafesearch.google.com",
"www.google.com.kw": "forcesafesearch.google.com",
"www.google.kz": "forcesafesearch.google.com",
"www.google.la": "forcesafesearch.google.com",
"www.google.com.lb": "forcesafesearch.google.com",
"www.google.li": "forcesafesearch.google.com",
"www.google.lk": "forcesafesearch.google.com",
"www.google.co.ls": "forcesafesearch.google.com",
"www.google.lt": "forcesafesearch.google.com",
"www.google.lu": "forcesafesearch.google.com",
"www.google.lv": "forcesafesearch.google.com",
"www.google.com.ly": "forcesafesearch.google.com",
"www.google.co.ma": "forcesafesearch.google.com",
"www.google.md": "forcesafesearch.google.com",
"www.google.me": "forcesafesearch.google.com",
"www.google.mg": "forcesafesearch.google.com",
"www.google.mk": "forcesafesearch.google.com",
"www.google.ml": "forcesafesearch.google.com",
"www.google.com.mm": "forcesafesearch.google.com",
"www.google.mn": "forcesafesearch.google.com",
"www.google.ms": "forcesafesearch.google.com",
"www.google.com.mt": "forcesafesearch.google.com",
"www.google.mu": "forcesafesearch.google.com",
"www.google.mv": "forcesafesearch.google.com",
"www.google.mw": "forcesafesearch.google.com",
"www.google.com.mx": "forcesafesearch.google.com",
"www.google.com.my": "forcesafesearch.google.com",
"www.google.co.mz": "forcesafesearch.google.com",
"www.google.com.na": "forcesafesearch.google.com",
"www.google.com.nf": "forcesafesearch.google.com",
"www.google.com.ng": "forcesafesearch.google.com",
"www.google.com.ni": "forcesafesearch.google.com",
"www.google.ne": "forcesafesearch.google.com",
"www.google.nl": "forcesafesearch.google.com",
"www.google.no": "forcesafesearch.google.com",
"www.google.com.np": "forcesafesearch.google.com",
"www.google.nr": "forcesafesearch.google.com",
"www.google.nu": "forcesafesearch.google.com",
"www.google.co.nz": "forcesafesearch.google.com",
"www.google.com.om": "forcesafesearch.google.com",
"www.google.com.pa": "forcesafesearch.google.com",
"www.google.com.pe": "forcesafesearch.google.com",
"www.google.com.pg": "forcesafesearch.google.com",
"www.google.com.ph": "forcesafesearch.google.com",
"www.google.com.pk": "forcesafesearch.google.com",
"www.google.pl": "forcesafesearch.google.com",
"www.google.pn": "forcesafesearch.google.com",
"www.google.com.pr": "forcesafesearch.google.com",
"www.google.ps": "forcesafesearch.google.com",
"www.google.pt": "forcesafesearch.google.com",
"www.google.com.py": "forcesafesearch.google.com",
"www.google.com.qa": "forcesafesearch.google.com",
"www.google.ro": "forcesafesearch.google.com",
"www.google.ru": "forcesafesearch.google.com",
"www.google.rw": "forcesafesearch.google.com",
"www.google.com.sa": "forcesafesearch.google.com",
"www.google.com.sb": "forcesafesearch.google.com",
"www.google.sc": "forcesafesearch.google.com",
"www.google.se": "forcesafesearch.google.com",
"www.google.com.sg": "forcesafesearch.google.com",
"www.google.sh": "forcesafesearch.google.com",
"www.google.si": "forcesafesearch.google.com",
"www.google.sk": "forcesafesearch.google.com",
"www.google.com.sl": "forcesafesearch.google.com",
"www.google.sn": "forcesafesearch.google.com",
"www.google.so": "forcesafesearch.google.com",
"www.google.sm": "forcesafesearch.google.com",
"www.google.sr": "forcesafesearch.google.com",
"www.google.st": "forcesafesearch.google.com",
"www.google.com.sv": "forcesafesearch.google.com",
"www.google.td": "forcesafesearch.google.com",
"www.google.tg": "forcesafesearch.google.com",
"www.google.co.th": "forcesafesearch.google.com",
"www.google.com.tj": "forcesafesearch.google.com",
"www.google.tk": "forcesafesearch.google.com",
"www.google.tl": "forcesafesearch.google.com",
"www.google.tm": "forcesafesearch.google.com",
"www.google.tn": "forcesafesearch.google.com",
"www.google.to": "forcesafesearch.google.com",
"www.google.com.tr": "forcesafesearch.google.com",
"www.google.tt": "forcesafesearch.google.com",
"www.google.com.tw": "forcesafesearch.google.com",
"www.google.co.tz": "forcesafesearch.google.com",
"www.google.com.ua": "forcesafesearch.google.com",
"www.google.co.ug": "forcesafesearch.google.com",
"www.google.co.uk": "forcesafesearch.google.com",
"www.google.com.uy": "forcesafesearch.google.com",
"www.google.co.uz": "forcesafesearch.google.com",
"www.google.com.vc": "forcesafesearch.google.com",
"www.google.co.ve": "forcesafesearch.google.com",
"www.google.vg": "forcesafesearch.google.com",
"www.google.co.vi": "forcesafesearch.google.com",
"www.google.com.vn": "forcesafesearch.google.com",
"www.google.vu": "forcesafesearch.google.com",
"www.google.ws": "forcesafesearch.google.com",
"www.google.rs": "forcesafesearch.google.com",
"www.youtube.com": "restrictmoderate.youtube.com",
"m.youtube.com": "restrictmoderate.youtube.com",
"youtubei.googleapis.com": "restrictmoderate.youtube.com",
"youtube.googleapis.com": "restrictmoderate.youtube.com",
"www.youtube-nocookie.com": "restrictmoderate.youtube.com",
"pixabay.com": "safesearch.pixabay.com",
return d.safeSearch.CheckHost(host, qtype)
}

View File

@@ -1 +1 @@
|www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com
|www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com

View File

@@ -1,3 +1,3 @@
|duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|start.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|www.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|www.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com

View File

@@ -188,4 +188,4 @@
|www.google.tt^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|www.google.vg^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|www.google.vu^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|www.google.ws^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|www.google.ws^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com

View File

@@ -1 +1 @@
|pixabay.com^$dnsrewrite=NOERROR;CNAME;safesearch.pixabay.com
|pixabay.com^$dnsrewrite=NOERROR;CNAME;safesearch.pixabay.com

View File

@@ -49,4 +49,4 @@
|yandex.ru^$dnsrewrite=NOERROR;A;213.180.193.56
|yandex.tj^$dnsrewrite=NOERROR;A;213.180.193.56
|yandex.tm^$dnsrewrite=NOERROR;A;213.180.193.56
|yandex.uz^$dnsrewrite=NOERROR;A;213.180.193.56
|yandex.uz^$dnsrewrite=NOERROR;A;213.180.193.56

View File

@@ -2,4 +2,4 @@
|m.youtube.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|youtubei.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|youtube.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|www.youtube-nocookie.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|www.youtube-nocookie.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -53,44 +54,85 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool)
}
}
// DefaultSafeSearch is the default safesearch struct.
type DefaultSafeSearch struct {
engine *urlfilter.DNSEngine
safeSearchCache cache.Cache
resolver filtering.Resolver
cacheTime time.Duration
// Default is the default safe search filter that uses filtering rules with the
// dnsrewrite modifier.
type Default struct {
// mu protects engine.
mu *sync.RWMutex
// engine is the filtering engine that contains the DNS rewrite rules.
// engine may be nil, which means that this safe search filter is disabled.
engine *urlfilter.DNSEngine
cache cache.Cache
resolver filtering.Resolver
logPrefix string
cacheTTL time.Duration
}
// NewDefaultSafeSearch returns new safesearch struct. CacheTime is an element
// TTL (in minutes).
func NewDefaultSafeSearch(
// NewDefault returns an initialized default safe search filter. name is used
// for logging.
func NewDefault(
conf filtering.SafeSearchConfig,
name string,
cacheSize uint,
cacheTime time.Duration,
) (ss *DefaultSafeSearch, err error) {
engine, err := newEngine(filtering.SafeSearchListID, conf)
if err != nil {
return nil, err
}
cacheTTL time.Duration,
) (ss *Default, err error) {
var resolver filtering.Resolver = net.DefaultResolver
if conf.CustomResolver != nil {
resolver = conf.CustomResolver
}
return &DefaultSafeSearch{
engine: engine,
safeSearchCache: cache.New(cache.Config{
ss = &Default{
mu: &sync.RWMutex{},
cache: cache.New(cache.Config{
EnableLRU: true,
MaxSize: cacheSize,
}),
cacheTime: cacheTime,
resolver: resolver,
}, nil
resolver: resolver,
// Use %s, because the client safe-search names already contain double
// quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name),
cacheTTL: cacheTTL,
}
err = ss.resetEngine(filtering.SafeSearchListID, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return ss, nil
}
// newEngine creates new engine for provided safe search configuration.
func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.DNSEngine, err error) {
// log is a helper for logging that includes the name of the safe search
// filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR].
func (ss *Default) log(level log.Level, msg string, args ...any) {
switch level {
case log.DEBUG:
log.Debug(ss.logPrefix+msg, args...)
case log.INFO:
log.Info(ss.logPrefix+msg, args...)
case log.ERROR:
log.Error(ss.logPrefix+msg, args...)
default:
panic(fmt.Errorf("safesearch: unsupported logging level %d", level))
}
}
// resetEngine creates new engine for provided safe search configuration and
// sets it in ss.
func (ss *Default) resetEngine(
listID int,
conf filtering.SafeSearchConfig,
) (err error) {
if !conf.Enabled {
ss.log(log.INFO, "disabled")
return nil
}
var sb strings.Builder
for service, serviceRules := range safeSearchRules {
if isServiceProtected(conf, service) {
@@ -106,20 +148,73 @@ func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.D
rs, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList})
if err != nil {
return nil, fmt.Errorf("creating rule storage: %w", err)
return fmt.Errorf("creating rule storage: %w", err)
}
engine = urlfilter.NewDNSEngine(rs)
log.Info("safesearch: filter %d: reset %d rules", listID, engine.RulesCount)
ss.engine = urlfilter.NewDNSEngine(rs)
return engine, nil
ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount)
return nil
}
// type check
var _ filtering.SafeSearch = (*DefaultSafeSearch)(nil)
var _ filtering.SafeSearch = (*Default)(nil)
// CheckHost implements the [filtering.SafeSearch] interface for
// *DefaultSafeSearch.
func (ss *Default) CheckHost(
host string,
qtype rules.RRType,
) (res filtering.Result, err error) {
start := time.Now()
defer func() {
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
}()
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype))
}
// Check cache. Return cached result if it was found
cachedValue, isFound := ss.getCachedResult(host, qtype)
if isFound {
ss.log(log.DEBUG, "found in cache: %q", host)
return cachedValue, nil
}
rewrite := ss.searchHost(host, qtype)
if rewrite == nil {
return filtering.Result{}, nil
}
fltRes, err := ss.newResult(rewrite, qtype)
if err != nil {
ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err)
return filtering.Result{}, err
}
if fltRes != nil {
res = *fltRes
ss.setCacheResult(host, qtype, res)
return res, nil
}
return filtering.Result{}, fmt.Errorf("no ipv4 addresses for %q", host)
}
// searchHost looks up DNS rewrites in the internal DNS filtering engine.
func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRewrite) {
ss.mu.RLock()
defer ss.mu.RUnlock()
if ss.engine == nil {
return nil
}
// SearchHost implements the [filtering.SafeSearch] interface for *DefaultSafeSearch.
func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.DNSRewrite) {
r, _ := ss.engine.MatchRequest(&urlfilter.DNSRequest{
Hostname: strings.ToLower(host),
DNSType: qtype,
@@ -133,51 +228,11 @@ func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.D
return nil
}
// CheckHost implements the [filtering.SafeSearch] interface for
// *DefaultSafeSearch.
func (ss *DefaultSafeSearch) CheckHost(
host string,
qtype uint16,
) (res filtering.Result, err error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("safesearch: lookup for %s", host)
}
// Check cache. Return cached result if it was found
cachedValue, isFound := ss.getCachedResult(host)
if isFound {
log.Debug("safesearch: found in cache: %s", host)
return cachedValue, nil
}
rewrite := ss.SearchHost(host, qtype)
if rewrite == nil {
return filtering.Result{}, nil
}
dRes, err := ss.newResult(rewrite, qtype)
if err != nil {
log.Debug("safesearch: failed to lookup addresses for %s: %s", host, err)
return filtering.Result{}, err
}
if dRes != nil {
res = *dRes
ss.setCacheResult(host, res)
return res, nil
}
return filtering.Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", host)
}
// newResult creates Result object from rewrite rule.
func (ss *DefaultSafeSearch) newResult(
// newResult creates Result object from rewrite rule. qtype must be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) newResult(
rewrite *rules.DNSRewrite,
qtype uint16,
qtype rules.RRType,
) (res *filtering.Result, err error) {
res = &filtering.Result{
Rules: []*filtering.ResultRule{{
@@ -187,7 +242,7 @@ func (ss *DefaultSafeSearch) newResult(
IsFiltered: true,
}
if rewrite.RRType == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
if rewrite.RRType == qtype {
ip, ok := rewrite.Value.(net.IP)
if !ok || ip == nil {
return nil, nil
@@ -198,17 +253,25 @@ func (ss *DefaultSafeSearch) newResult(
return res, nil
}
if rewrite.NewCNAME == "" {
host := rewrite.NewCNAME
if host == "" {
return nil, nil
}
ips, err := ss.resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
ss.log(log.DEBUG, "resolving %q", host)
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil {
return nil, err
}
ss.log(log.DEBUG, "resolved %s", ips)
for _, ip := range ips {
if ip = ip.To4(); ip == nil {
// TODO(a.garipov): Remove this filtering once the resolver we use
// actually learns about network.
ip = fitToProto(ip, qtype)
if ip == nil {
continue
}
@@ -220,38 +283,71 @@ func (ss *DefaultSafeSearch) newResult(
return nil, nil
}
// setCacheResult stores data in cache for host.
func (ss *DefaultSafeSearch) setCacheResult(host string, res filtering.Result) {
expire := uint32(time.Now().Add(ss.cacheTime).Unix())
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
// It panics for other types.
func qtypeToProto(qtype rules.RRType) (proto string) {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
}
}
// fitToProto returns a non-nil IP address if ip is the correct protocol version
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
func fitToProto(ip net.IP, qtype rules.RRType) (res net.IP) {
ip4 := ip.To4()
if qtype == dns.TypeA {
return ip4
}
if ip4 == nil {
return ip
}
return nil
}
// setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {
expire := uint32(time.Now().Add(ss.cacheTTL).Unix())
exp := make([]byte, 4)
binary.BigEndian.PutUint32(exp, expire)
buf := bytes.NewBuffer(exp)
err := gob.NewEncoder(buf).Encode(res)
if err != nil {
log.Error("safesearch: cache encoding: %s", err)
ss.log(log.ERROR, "cache encoding: %s", err)
return
}
val := buf.Bytes()
_ = ss.safeSearchCache.Set([]byte(host), val)
_ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val)
log.Debug("safesearch: stored in cache: %s (%d bytes)", host, len(val))
ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val))
}
// getCachedResult returns stored data from cache for host.
func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result, ok bool) {
// getCachedResult returns stored data from cache for host. qtype is expected
// to be either [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) getCachedResult(
host string,
qtype rules.RRType,
) (res filtering.Result, ok bool) {
res = filtering.Result{}
data := ss.safeSearchCache.Get([]byte(host))
data := ss.cache.Get([]byte(dns.Type(qtype).String() + " " + host))
if data == nil {
return res, false
}
exp := binary.BigEndian.Uint32(data[:4])
if exp <= uint32(time.Now().Unix()) {
ss.safeSearchCache.Del([]byte(host))
ss.cache.Del([]byte(host))
return res, false
}
@@ -260,10 +356,27 @@ func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result,
err := gob.NewDecoder(buf).Decode(&res)
if err != nil {
log.Debug("safesearch: cache decoding: %s", err)
ss.log(log.ERROR, "cache decoding: %s", err)
return filtering.Result{}, false
}
return res, true
}
// Update implements the [filtering.SafeSearch] interface for *Default. Update
// ignores the CustomResolver and Enabled fields.
func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) {
ss.mu.Lock()
defer ss.mu.Unlock()
err = ss.resetEngine(filtering.SafeSearchListID, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
ss.cache.Clear()
return nil
}

View File

@@ -0,0 +1,137 @@
package safesearch
import (
"context"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(a.garipov): Move as much of this as possible into proper external tests.
const (
// TODO(a.garipov): Add IPv6 tests.
testQType = dns.TypeA
testCacheSize = 5000
testCacheTTL = 30 * time.Minute
)
var defaultSafeSearchConf = filtering.SafeSearchConfig{
Enabled: true,
Bing: true,
DuckDuckGo: true,
Google: true,
Pixabay: true,
Yandex: true,
YouTube: true,
}
var yandexIP = net.IPv4(213, 180, 193, 56)
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) {
ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
return ss
}
func TestSafeSearch(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
val := ss.searchHost("www.google.com", testQType)
assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val)
}
func TestSafeSearchCacheYandex(t *testing.T) {
const domain = "yandex.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
// Check host with disabled safesearch.
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
ss = newForTest(t, defaultSafeSearchConf)
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
// For yandex we already know valid IP.
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.TestResolver{}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
require.NoError(t, err)
var foundIP net.IP
for _, ip := range ips {
if ip.To4() != nil {
foundIP = ip
break
}
}
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(foundIP))
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
}
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ {
dnsRewriteSink = ss.searchHost(googleHost, testQType)
}
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
}

View File

@@ -1,26 +1,37 @@
package safesearch
package safesearch_test
import (
"context"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// Common test constants.
const (
safeSearchCacheSize = 5000
cacheTime = 30 * time.Minute
// TODO(a.garipov): Add IPv6 tests.
testQType = dns.TypeA
testCacheSize = 5000
testCacheTTL = 30 * time.Minute
)
var defaultSafeSearchConf = filtering.SafeSearchConfig{
Enabled: true,
// testConf is the default safe search configuration for tests.
var testConf = filtering.SafeSearchConfig{
CustomResolver: nil,
Enabled: true,
Bing: true,
DuckDuckGo: true,
Google: true,
@@ -29,25 +40,15 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{
YouTube: true,
}
// yandexIP is the expected IP address of Yandex safe search results. Keep in
// sync with the rules data.
var yandexIP = net.IPv4(213, 180, 193, 56)
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *DefaultSafeSearch) {
ss, err := NewDefaultSafeSearch(ssConf, safeSearchCacheSize, cacheTime)
func TestDefault_CheckHost_yandex(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
return ss
}
func TestSafeSearch(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
val := ss.SearchHost("www.google.com", dns.TypeA)
assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val)
}
func TestCheckHostSafeSearchYandex(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
// Check host for each domain.
for _, host := range []string{
"yandex.ru",
@@ -57,7 +58,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
"yandex.kz",
"www.yandex.com",
} {
res, err := ss.CheckHost(host, dns.TypeA)
var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -69,12 +71,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
}
}
func TestCheckHostSafeSearchGoogle(t *testing.T) {
func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.TestResolver{}
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
ss := newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
conf := testConf
conf.CustomResolver = resolver
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// Check host for each domain.
for _, host := range []string{
@@ -87,7 +91,8 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
"www.google.je",
} {
t.Run(host, func(t *testing.T) {
res, err := ss.CheckHost(host, dns.TypeA)
var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -100,103 +105,35 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
}
}
func TestSafeSearchCacheYandex(t *testing.T) {
const domain = "yandex.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
// Check host with disabled safesearch.
res, err := ss.CheckHost(domain, dns.TypeA)
func TestDefault_Update(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
ss = newForTest(t, defaultSafeSearchConf)
res, err = ss.CheckHost(domain, dns.TypeA)
res, err := ss.CheckHost("www.yandex.com", testQType)
require.NoError(t, err)
// For yandex we already know valid IP.
require.Len(t, res.Rules, 1)
assert.True(t, res.IsFiltered)
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, dns.TypeA)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.TestResolver{}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.SearchHost(domain, dns.TypeA)
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
require.NoError(t, err)
var foundIP net.IP
for _, ip := range ips {
if ip.To4() != nil {
foundIP = ip
break
}
}
res, err = ss.CheckHost(domain, dns.TypeA)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(foundIP))
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
}
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ {
dnsRewriteSink = ss.SearchHost(googleHost, dns.TypeA)
}
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
}
var dnsRewriteParallelSink *rules.DNSRewrite
func BenchmarkSafeSearch_parallel(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
dnsRewriteParallelSink = ss.SearchHost(googleHost, dns.TypeA)
}
err = ss.Update(filtering.SafeSearchConfig{
Enabled: true,
Google: false,
})
require.NoError(t, err)
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteParallelSink.NewCNAME)
res, err = ss.CheckHost("www.yandex.com", testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
err = ss.Update(filtering.SafeSearchConfig{
Enabled: false,
Google: true,
})
require.NoError(t, err)
res, err = ss.CheckHost("www.yandex.com", testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
}

View File

@@ -0,0 +1,71 @@
package filtering
import (
"encoding/json"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
)
// handleSafeSearchEnable is the handler for POST /control/safesearch/enable
// HTTP API.
//
// Deprecated: Use handleSafeSearchSettings.
func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled, true)
d.Config.ConfigModified()
}
// handleSafeSearchDisable is the handler for POST /control/safesearch/disable
// HTTP API.
//
// Deprecated: Use handleSafeSearchSettings.
func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled, false)
d.Config.ConfigModified()
}
// handleSafeSearchStatus is the handler for GET /control/safesearch/status
// HTTP API.
func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
var resp SafeSearchConfig
func() {
d.confLock.RLock()
defer d.confLock.RUnlock()
resp = d.Config.SafeSearchConf
}()
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// handleSafeSearchSettings is the handler for PUT /control/safesearch/settings
// HTTP API.
func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Request) {
req := &SafeSearchConfig{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
conf := *req
err = d.safeSearch.Update(conf)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)
return
}
func() {
d.confLock.Lock()
defer d.confLock.Unlock()
d.Config.SafeSearchConf = conf
}()
d.Config.ConfigModified()
aghhttp.OK(w)
}

View File

@@ -311,6 +311,14 @@ var blockedServices = []blockedService{{
"||warp.plus^",
"||workers.dev^",
},
}, {
ID: "crunchyroll",
Name: "Crunchyroll",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M 25 3 C 12.85 3 3 12.85 3 25 C 3 40.188 13.387672 44.538609 20.388672 45.974609 C 20.427672 45.982609 20.465953 45.986328 20.501953 45.986328 C 21.006953 45.986328 21.206312 45.25525 20.695312 45.03125 C 13.285312 41.79025 8.0301562 34.327141 9.1601562 25.494141 C 10.256156 16.920141 17.244938 10.069141 25.835938 9.1191406 C 26.564937 9.0381406 27.287 9 28 9 C 35.541 9 42.044422 13.395672 45.107422 19.763672 C 45.206422 19.968672 45.382594 20.058594 45.558594 20.058594 C 45.853594 20.058594 46.144828 19.8075 46.048828 19.4375 C 44.302828 12.7105 39 3 25 3 z M 29 14 C 20.481 14 13.619625 21.101031 14.015625 29.707031 C 14.366625 37.346031 20.653016 43.631422 28.291016 43.982422 C 28.528016 43.994422 28.766 44 29 44 C 37.285 44 44 37.285 44 29 C 44 27.819 43.860563 26.670359 43.601562 25.568359 C 43.542563 25.319359 43.332234 25.183594 43.115234 25.183594 C 42.961234 25.183594 42.806266 25.251484 42.697266 25.396484 C 41.512266 26.976484 39.627 28 37.5 28 C 37.397 28 37.293453 27.997188 37.189453 27.992188 C 34.031453 27.845188 31.348203 25.317875 31.033203 22.171875 C 30.763203 19.477875 32.142297 17.082328 34.279297 15.861328 C 34.656297 15.646328 34.62475 15.100266 34.21875 14.947266 C 32.59375 14.340266 30.838 14 29 14 z M 44.296875 26.595703 L 44.300781 26.595703 L 44.296875 26.595703 z\"/></svg>"),
Rules: []string{
"||crunchyroll.com^",
"||gccrunchyroll.com^",
},
}, {
ID: "dailymotion",
Name: "Dailymotion",
@@ -1182,6 +1190,24 @@ var blockedServices = []blockedService{{
"||gog.com^",
"||gogalaxy.com^",
},
}, {
ID: "hbomax",
Name: "HBO Max",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M0 14v22h5v-9h3v9h5V14H8v8H5v-8H0zm15 0v22h8.4c3.1 0 5.7-2.3 6.2-5.4a11 11 0 1 0 0-11.2c-.5-3-3-5.4-6.2-5.4H15zm5 5h3a2 2 0 1 1 0 4h-3v-4zm19 0a6 6 0 1 1 0 12 6 6 0 0 1 0-12zm0 2a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4zm-11 2.8v2.4c-.4-.5-1-1-2-1.2 1-.3 1.7-.8 2-1.3zm-8 4h3a2 2 0 1 1 0 4h-3v-4z\"/></svg>"),
Rules: []string{
"||hbo.com^",
"||hbogo.co.th^",
"||hbogo.com^",
"||hbogo.eu^",
"||hbogoasia.com^",
"||hbogoasia.id^",
"||hbogoasia.ph^",
"||hbomax-images.warnermediacdn.com^",
"||hbomax.com^",
"||hbomaxcdn.com^",
"||hbonow.com^",
"||maxgo.com^",
},
}, {
ID: "hulu",
Name: "Hulu",
@@ -1299,6 +1325,21 @@ var blockedServices = []blockedService{{
"||kakao.com^",
"||kgslb.com^",
},
}, {
ID: "lazada",
Name: "Lazada",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 100 100\"><path d=\"M26 17a2 2 0 0 0-1.1.3L8.5 27.4A3 3 0 0 0 7 30v29.1c0 1 .6 2 1.5 2.6l40 24.8c.2.3.6.4 1 .5h1a3 3 0 0 0 1-.5l40-24.8a3 3 0 0 0 1.5-2.6v-29c0-1.1-.6-2.1-1.5-2.7l-16.4-10a2.1 2.1 0 0 0-2.2 0l-15.4 9.4a14.3 14.3 0 0 1-15 0L27 17.3c-.3-.2-.7-.3-1.1-.3zm0 2 15.4 9.5a16.3 16.3 0 0 0 17.2 0L74 19l16.5 10.1v.1L50 51.4 9.4 29.2h.1L26 19zm48 4c-.4 0-.9 0-1.3.3l-5.5 3.4a.5.5 0 1 0 .6.9l5.4-3.4c.5-.3 1.1-.3 1.6 0l9.4 5.7a.5.5 0 1 0 .5-.8l-9.4-5.8c-.4-.2-.8-.4-1.3-.4zm-8.7 5a.5.5 0 0 0-.3 0l-1.6 1a.5.5 0 0 0 .6 1l1.6-1a.5.5 0 0 0-.3-1zM9 30.1l40.5 22.2v32.6L9.5 60a1 1 0 0 1-.5-.9v-29zm82 0v29c0 .4-.2.7-.5 1l-40 24.7V52.4L91 30.1zM12.5 35a.5.5 0 0 0-.5.5v21.2c0 .8.4 1.6 1.2 2l16 10a.5.5 0 1 0 .6-.8l-16-10c-.5-.2-.8-.7-.8-1.2V35.5a.5.5 0 0 0-.5-.5zm24 37.2a.5.5 0 0 0-.3.9l4 2.5a.5.5 0 1 0 .6-.9l-4-2.4a.5.5 0 0 0-.3-.1zm7 4.3a.5.5 0 0 0-.3 1l1 .6a.5.5 0 1 0 .6-.9l-1-.6a.5.5 0 0 0-.3 0z\"/></svg>"),
Rules: []string{
"||k1-lazadasg-oversea.gslb.ksyuncdn.com^",
"||lazada.co.id^",
"||lazada.co.th^",
"||lazada.com.my^",
"||lazada.com.ph^",
"||lazada.com^",
"||lazada.sg^",
"||lazada.vn^",
"||slatic.net^",
},
}, {
ID: "leagueoflegends",
Name: "League of Legends",
@@ -1383,7 +1424,6 @@ var blockedServices = []blockedService{{
"||mastodon.sdf.org^",
"||mastodon.social^",
"||mastodon.social^",
"||mastodon.top^",
"||mastodon.uno^",
"||mastodon.world^",
"||mastodon.xyz^",
@@ -1400,6 +1440,7 @@ var blockedServices = []blockedService{{
"||mstdn.jp^",
"||mstdn.social^",
"||muenchen.social^",
"||muenster.im^",
"||newsie.social^",
"||noc.social^",
"||norden.social^",
@@ -1428,6 +1469,7 @@ var blockedServices = []blockedService{{
"||techhub.social^",
"||theblower.au^",
"||tkz.one^",
"||todon.eu^",
"||toot.aquilenet.fr^",
"||toot.community^",
"||toot.funami.tech^",
@@ -1438,7 +1480,6 @@ var blockedServices = []blockedService{{
"||union.place^",
"||universeodon.com^",
"||urbanists.social^",
"||vocalodon.net^",
"||wien.rocks^",
"||wxw.moe^",
},
@@ -1568,6 +1609,20 @@ var blockedServices = []blockedService{{
"||pinterest.vn^",
"||pinterestmail.com^",
},
}, {
ID: "playstation",
Name: "PlayStation",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 30 30\"><path d=\"M11.18 3.74v21.12l4.58 1.4V8.58c0-.51 0-.77.26-1.02.12-.26.38-.26.63-.13.64.26 1.02.76 1.02 1.78v7c1.53.76 2.8.76 3.81 0 1.02-.77 1.53-1.9 1.53-3.82 0-2.03-.38-3.3-1.27-4.32-.76-1.02-2.16-1.91-4.2-2.55-2.54-.76-4.7-1.4-6.36-1.78zM9.91 16.97l-5.85 2.04-.89.38c-1.4.63-2.16 1.27-2.16 1.9.12.77.38 1.79 2.29 2.42 1.78.64 3.18.9 6.74-.12v-2.3c-3.44 1.15-3.95 1.02-4.45.77-.51-.25-.51-.5-.39-.64.39-.25 1.78-.76 1.78-.76l2.93-1.02v-2.67zm12.94 1c-.41-.02-.82-.01-1.24.02-1.4 0-2.67.25-4.2.64v2.67l2.8-1.02 1.53-.51s.64-.13 1.02-.25c.63-.13 1.4.12 1.4.12.38 0 .63.13.63.38.13.26-.12.39-.76.64l-1.4.51-5.09 1.9v2.68l2.3-.77 6.35-2.28.77-.39c1.52-.5 2.16-1.14 2.03-1.9 0-.77-.89-1.28-2.42-1.79a14.28 14.28 0 0 0-3.72-.66z\"/></svg>"),
Rules: []string{
"||gaikai.com",
"||playstation-cloud.com",
"||playstation-cloud.net",
"||playstation.com",
"||playstation.net",
"||scea.com",
"||sonyentertainmentnetwork.com",
"||station.sony.com",
},
}, {
ID: "qq",
Name: "QQ",
@@ -1597,6 +1652,18 @@ var blockedServices = []blockedService{{
"||redditmedia.com^",
"||redditstatic.com^",
},
}, {
ID: "riot_games",
Name: "Riot Games",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 64 64\"><path d=\"M31.3 2.2a1 1 0 0 0-.8 1.4l.8 1.8a1 1 0 1 0 1.8-.8l-.8-1.8a1 1 0 0 0-1-.6zm-4.3 2a1 1 0 0 0-.9 1.5l1 1.8a1 1 0 1 0 1.7-.9L28 4.8a1 1 0 0 0-1-.6zm12 1a1 1 0 0 0-.4.1l-34 16.2a1 1 0 0 0-.6 1.1L9.3 47a1 1 0 0 0 1 .8h7a1 1 0 0 0 1-1l-1-12.2 3 12.4a1 1 0 0 0 .9.8h7.3a1 1 0 0 0 1-1l-.2-15.8L32 47a1 1 0 0 0 1 .8h7.6a1 1 0 0 0 1-1l1.3-19.4 1.4 19.5a1 1 0 0 0 1 1h10.2a1 1 0 0 0 1-1L60 11.2a1 1 0 0 0-.8-1l-20-5a1 1 0 0 0-.3 0zm-16.3 1a1 1 0 0 0-.9 1.5l.9 1.8a1 1 0 1 0 1.8-.8l-.9-1.8a1 1 0 0 0-1-.6zm16.4 1L57.9 12l-3.3 33.8h-8.3l-1.9-25a1 1 0 0 0-1.2-.8l-1.2.3a1 1 0 0 0-.7.9l-1.7 24.6h-5.8L30.3 25a1 1 0 0 0-1.3-.8l-1 .4a1 1 0 0 0-.8 1l.3 20.2H22l-4-17a1 1 0 0 0-1.2-.7l-1.1.3a1 1 0 0 0-.7 1l1.2 16.4H11L6.1 23l33-15.7zM18.5 8.4a1 1 0 0 0-1 1.5l.9 1.8a1 1 0 1 0 1.8-.9L19.3 9a1 1 0 0 0-.8-.6zm-4.3 2.1a1 1 0 0 0-.1 0 1 1 0 0 0-.9 1.4l.9 1.8a1 1 0 1 0 1.8-.8L15 11a1 1 0 0 0-.8-.6zm-4.4 2a1 1 0 0 0-.9 1.5l.9 1.8a1 1 0 1 0 1.8-.9l-.9-1.8a1 1 0 0 0-1-.5zm-4.3 2.1a1 1 0 0 0-.9 1.4l.9 1.9a1 1 0 1 0 1.8-1l-.9-1.7a1 1 0 0 0-.9-.6zM30.7 49a1 1 0 0 0-.9 1.4l2.5 6.5a1 1 0 0 0 .7.6l20.7 5.3a1 1 0 0 0 1.2-.9L56 51.4a1 1 0 0 0-1-1L30.9 49a1 1 0 0 0-.1 0zm1.5 2.1L54 52.3l-.9 8.2-19-4.8-1.8-4.6z\"/></svg>"),
Rules: []string{
"||dradis-prod.rdatasrv.net^",
"||pvp.net^",
"||rgpub.io^",
"||riotcdn.com^",
"||riotcdn.net^",
"||riotgames.com^",
},
}, {
ID: "roblox",
Name: "Roblox",
@@ -1610,6 +1677,32 @@ var blockedServices = []blockedService{{
"||robloxcdn.com^",
"||robloxdev.cn^",
},
}, {
ID: "shopee",
Name: "Shopee",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M25 1c-5.3 0-9.4 5-9.8 11H5a2 2 0 0 0-2 2.1l1.7 30.2a5 5 0 0 0 5 4.7h30.4a5 5 0 0 0 5-4.7L47 14a2 2 0 0 0-2-2.1H35C34.3 6 30.2 1 25 1zm0 2c4 0 7.4 3.9 7.8 9H17.2c.4-5.1 3.8-9 7.8-9zM5 14h10.8a1 1 0 0 0 .4 0h17.6a1 1 0 0 0 .4 0h10.7l-1.7 30.2a3 3 0 0 1-3 2.8H9.8a3 3 0 0 1-3-2.8L5 14zm20 4c-4.2 0-7.5 2.7-7.5 6.3 0 4 3.8 5.4 7 6.6 4 1.4 6.5 2.5 6.5 5.7 0 2.4-2.7 4.4-6 4.4-3.8 0-7-2.7-7-2.7l-1.2 1.6c.8.7 4.1 3.1 8.1 3.1 4.5 0 8-2.8 8-6.4 0-4.8-4-6.3-7.7-7.6-3.5-1.3-5.7-2.3-5.7-4.7 0-2.5 2.3-4.3 5.6-4.3a11 11 0 0 1 6 1.9l1-1.7c-.3-.1-3.2-2.2-7-2.2z\"/></svg>"),
Rules: []string{
"||shopee.cl^",
"||shopee.cn^",
"||shopee.co.id^",
"||shopee.co.th^",
"||shopee.com.br^",
"||shopee.com.co^",
"||shopee.com.mx^",
"||shopee.com.my^",
"||shopee.com^",
"||shopee.es^",
"||shopee.fr^",
"||shopee.id^",
"||shopee.in^",
"||shopee.io^",
"||shopee.ph^",
"||shopee.sg^",
"||shopee.tw^",
"||shopee.vn^",
"||shopeemobile.com^",
"||shp.ee^",
},
}, {
ID: "skype",
Name: "Skype",
@@ -1813,6 +1906,15 @@ var blockedServices = []blockedService{{
"||twvid.com^",
"||vine.co^",
},
}, {
ID: "valorant",
Name: "Valorant",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M4 6a1 1 0 0 0-1 1v18a1 1 0 0 0 .2.6l14 17a1 1 0 0 0 .8.4h14a1 1 0 0 0 .8-1.6l-28-35A1 1 0 0 0 4 6zm42 1a1 1 0 0 0-.8.4l-18 22A1 1 0 0 0 28 31h14a1 1 0 0 0 .8-.4l4-5a1 1 0 0 0 .2-.6V8a1 1 0 0 0-1-1zM5 9.9 30 41H18.4L5 24.6V10zm40 .9v13.8L41.5 29H30.1L45 10.8z\"/></svg>"),
Rules: []string{
"||playvalorant.com",
"||valorant.scd.riotcdn.net",
"||valorant.secure.dyn.riotcdn.net",
},
}, {
ID: "viber",
Name: "Viber",
@@ -1869,6 +1971,13 @@ var blockedServices = []blockedService{{
"||vkuservideo.com^",
"||vkuservideo.net^",
},
}, {
ID: "voot",
Name: "Voot",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 512 512\"><path d=\"M96 340c-1 4-4 6-7 6H48c-3 0-6-2-8-6L0 213c-1-2 0-5 2-5l2-1h30c4 0 7 3 8 6l25 87c1 3 2 3 3 0l25-87c1-3 4-6 7-6h31c2 0 4 2 4 4v2L96 340zm46-50v-32c0-29 14-56 63-56s63 27 63 56v32c0 28-14 56-63 56s-63-28-63-56zm85 1v-35c0-13-7-20-22-20s-22 7-22 20v35c0 13 7 20 22 20s22-7 22-20zm54-1v-32c0-29 14-56 63-56s63 27 63 56v32c0 28-14 56-63 56s-63-28-63-56zm85 1v-35c0-13-7-20-22-20s-21 7-21 20v35c0 13 6 20 21 20s22-7 22-20zm144 44-2-17-1-2c-1-3-3-5-6-5h-2l-10 1c-2 1-4 0-6-2l-2-11v-56c0-3 2-6 6-6h17c4 0 6-2 7-5l1-22c0-3-2-5-5-6h-21c-3 0-5-2-5-5v-28c0-3-2-5-5-5h-1l-30 4c-3 1-5 4-5 7v22c0 3-3 5-6 5h-7c-4 0-6 3-6 6v22c0 3 2 5 6 5h7c3 0 6 3 6 6v67c0 26 15 36 42 36 8 0 16-1 23-4h1c2 0 5-3 5-6l-1-1z\"/></svg>"),
Rules: []string{
"||voot.com^",
},
}, {
ID: "wechat",
Name: "WeChat",

View File

@@ -16,6 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
@@ -379,9 +380,9 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
// TODO(a.garipov): Support header Forwarded from RFC 7329.
func realIP(r *http.Request) (ip net.IP, err error) {
proxyHeaders := []string{
"CF-Connecting-IP",
"True-Client-IP",
"X-Real-IP",
httphdr.CFConnectingIP,
httphdr.TrueClientIP,
httphdr.XRealIP,
}
for _, h := range proxyHeaders {
@@ -394,7 +395,7 @@ func realIP(r *http.Request) (ip net.IP, err error) {
// If none of the above yielded any results, get the leftmost IP address
// from the X-Forwarded-For header.
s := r.Header.Get("X-Forwarded-For")
s := r.Header.Get(httphdr.XForwardedFor)
ipStrs := strings.SplitN(s, ", ", 2)
ip = net.ParseIP(ipStrs[0])
if ip != nil {
@@ -411,6 +412,21 @@ func realIP(r *http.Request) (ip net.IP, err error) {
return net.ParseIP(ipStr), nil
}
// writeErrorWithIP is like [aghhttp.Error], but includes the remote IP address
// when it writes to the log.
func writeErrorWithIP(
r *http.Request,
w http.ResponseWriter,
code int,
remoteIP string,
format string,
args ...any,
) {
text := fmt.Sprintf(format, args...)
log.Error("%s %s %s: from ip %s: %s", r.Method, r.Host, r.URL, remoteIP, text)
http.Error(w, text, code)
}
func handleLogin(w http.ResponseWriter, r *http.Request) {
req := loginJSON{}
err := json.NewDecoder(r.Body).Decode(&req)
@@ -420,31 +436,45 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
return
}
var remoteAddr string
var remoteIP string
// realIP cannot be used here without taking TrustedProxies into account due
// to security issues.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
//
// TODO(e.burkov): Use realIP when the issue will be fixed.
if remoteAddr, err = netutil.SplitHost(r.RemoteAddr); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "auth: getting remote address: %s", err)
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
writeErrorWithIP(
r,
w,
http.StatusBadRequest,
r.RemoteAddr,
"auth: getting remote address: %s",
err,
)
return
}
if rateLimiter := Context.auth.raleLimiter; rateLimiter != nil {
if left := rateLimiter.check(remoteAddr); left > 0 {
w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds())))
aghhttp.Error(r, w, http.StatusTooManyRequests, "auth: blocked for %s", left)
if left := rateLimiter.check(remoteIP); left > 0 {
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
writeErrorWithIP(
r,
w,
http.StatusTooManyRequests,
remoteIP,
"auth: blocked for %s",
left,
)
return
}
}
cookie, err := Context.auth.newCookie(req, remoteAddr)
cookie, err := Context.auth.newCookie(req, remoteIP)
if err != nil {
aghhttp.Error(r, w, http.StatusForbidden, "%s", err)
writeErrorWithIP(r, w, http.StatusForbidden, remoteIP, "%s", err)
return
}
@@ -452,10 +482,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
// Use realIP here, since this IP address is only used for logging.
ip, err := realIP(r)
if err != nil {
log.Error("auth: getting real ip from request: %s", err)
} else if ip == nil {
// Technically shouldn't happen.
log.Error("auth: unknown ip")
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
}
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
@@ -463,9 +490,9 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, cookie)
h := w.Header()
h.Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
h.Set("Pragma", "no-cache")
h.Set("Expires", "0")
h.Set(httphdr.CacheControl, "no-store, no-cache, must-revalidate, proxy-revalidate")
h.Set(httphdr.Pragma, "no-cache")
h.Set(httphdr.Expires, "0")
aghhttp.OK(w)
}
@@ -476,7 +503,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
if err != nil {
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
// The user is already logged out.
respHdr.Set("Location", "/login.html")
respHdr.Set(httphdr.Location, "/login.html")
w.WriteHeader(http.StatusFound)
return
@@ -494,8 +521,8 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
SameSite: http.SameSiteLaxMode,
}
respHdr.Set("Location", "/login.html")
respHdr.Set("Set-Cookie", c.String())
respHdr.Set(httphdr.Location, "/login.html")
respHdr.Set(httphdr.SetCookie, c.String())
w.WriteHeader(http.StatusFound)
}
@@ -543,8 +570,7 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
log.Debug("auth: redirected to login page by GL-Inet submodule")
} else {
log.Debug("auth: redirected to login page")
w.Header().Set("Location", "/login.html")
w.WriteHeader(http.StatusFound)
http.Redirect(w, r, "login.html", http.StatusFound)
}
} else {
log.Debug("auth: responded with forbidden to %s %s", r.Method, p)
@@ -569,8 +595,7 @@ func optionalAuth(
// Redirect to the dashboard if already authenticated.
res := Context.auth.checkSession(cookie.Value)
if res == checkSessionOK {
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound)
http.Redirect(w, r, "", http.StatusFound)
return
}

View File

@@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -135,11 +136,11 @@ func TestAuthHTTP(t *testing.T) {
handlerCalled = false
handler2(&w, &r)
assert.Equal(t, http.StatusFound, w.statusCode)
assert.NotEmpty(t, w.hdr.Get("Location"))
assert.NotEmpty(t, w.hdr.Get(httphdr.Location))
assert.False(t, handlerCalled)
// go to login page
loginURL := w.hdr.Get("Location")
loginURL := w.hdr.Get(httphdr.Location)
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
@@ -153,13 +154,13 @@ func TestAuthHTTP(t *testing.T) {
// get /
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set("Cookie", cookie.String())
r.Header.Set(httphdr.Cookie, cookie.String())
r.URL = &url.URL{Path: "/"}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del("Cookie")
r.Header.Del(httphdr.Cookie)
// get / with basic auth
handler2 = optionalAuth(handler)
@@ -169,28 +170,28 @@ func TestAuthHTTP(t *testing.T) {
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del("Authorization")
r.Header.Del(httphdr.Authorization)
// get login page with a valid cookie - we're redirected to /
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set("Cookie", cookie.String())
r.Header.Set(httphdr.Cookie, cookie.String())
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.NotEmpty(t, w.hdr.Get("Location"))
assert.NotEmpty(t, w.hdr.Get(httphdr.Location))
assert.False(t, handlerCalled)
r.Header.Del("Cookie")
r.Header.Del(httphdr.Cookie)
// get login page with an invalid cookie
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set("Cookie", "bad")
r.Header.Set(httphdr.Cookie, "bad")
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del("Cookie")
r.Header.Del(httphdr.Cookie)
Context.auth.Close()
}
@@ -213,7 +214,7 @@ func TestRealIP(t *testing.T) {
}, {
name: "success_proxy",
header: http.Header{
textproto.CanonicalMIMEHeaderKey("X-Real-IP"): []string{"1.2.3.5"},
textproto.CanonicalMIMEHeaderKey(httphdr.XRealIP): []string{"1.2.3.5"},
},
remoteAddr: remoteAddr,
wantErrMsg: "",
@@ -221,7 +222,7 @@ func TestRealIP(t *testing.T) {
}, {
name: "success_proxy_multiple",
header: http.Header{
textproto.CanonicalMIMEHeaderKey("X-Forwarded-For"): []string{
textproto.CanonicalMIMEHeaderKey(httphdr.XForwardedFor): []string{
"1.2.3.6, 1.2.3.5",
},
},

View File

@@ -10,8 +10,8 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log"
"github.com/josharian/native"
)
// GLMode - enable GL-Inet compatibility mode
@@ -102,7 +102,7 @@ func glGetTokenDate(file string) uint32 {
buf := bytes.NewBuffer(bs)
err = binary.Read(buf, aghos.NativeEndian, &dateToken)
err = binary.Read(buf, native.Endian, &dateToken)
if err != nil {
log.Error("decoding token: %s", err)

View File

@@ -6,7 +6,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/josharian/native"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -19,13 +19,13 @@ func TestAuthGL(t *testing.T) {
glFilePrefix = dir + "/gl_token_"
data := make([]byte, 4)
aghos.NativeEndian.PutUint32(data, 1)
native.Endian.PutUint32(data, 1)
require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
assert.False(t, glCheckToken("test"))
data = make([]byte, 4)
aghos.NativeEndian.PutUint32(data, uint32(time.Now().UTC().Unix()+60))
native.Endian.PutUint32(data, uint32(time.Now().UTC().Unix()+60))
require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)

View File

@@ -3,7 +3,10 @@ package home
import (
"encoding"
"fmt"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
)
@@ -15,6 +18,9 @@ type Client struct {
// these upstream must be used.
upstreamConfig *proxy.UpstreamConfig
safeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
Name string
IDs []string
@@ -24,10 +30,11 @@ type Client struct {
UseOwnSettings bool
FilteringEnabled bool
SafeSearchEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
UseOwnBlockedServices bool
IgnoreQueryLog bool
IgnoreStatistics bool
}
// closeUpstreams closes the client-specific upstream config of c if any.
@@ -42,6 +49,23 @@ func (c *Client) closeUpstreams() (err error) {
return nil
}
// setSafeSearch initializes and sets the safe search filter for this client.
func (c *Client) setSafeSearch(
conf filtering.SafeSearchConfig,
cacheSize uint,
cacheTTL time.Duration,
) (err error) {
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
c.SafeSearch = ss
return nil
}
// clientSource represents the source from which the information about the
// client has been obtained.
type clientSource uint

View File

@@ -18,7 +18,6 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
@@ -52,9 +51,17 @@ type clientsContainer struct {
// lock protects all fields.
//
// TODO(a.garipov): Use a pointer and describe which fields are protected in
// more detail.
// more detail. Use sync.RWMutex.
lock sync.Mutex
// safeSearchCacheSize is the size of the safe search cache to use for
// persistent clients.
safeSearchCacheSize uint
// safeSearchCacheTTL is the TTL of the safe search cache to use for
// persistent clients.
safeSearchCacheTTL time.Duration
// testing is a flag that disables some features for internal tests.
//
// TODO(a.garipov): Awful. Remove.
@@ -69,10 +76,12 @@ func (clients *clientsContainer) Init(
dhcpServer dhcpd.Interface,
etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB,
filteringConf *filtering.Config,
) {
if clients.list != nil {
log.Fatal("clients.list != nil")
}
clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client)
clients.ipToRC = map[netip.Addr]*RuntimeClient{}
@@ -82,7 +91,10 @@ func (clients *clientsContainer) Init(
clients.dhcpServer = dhcpServer
clients.etcHosts = etcHosts
clients.arpdb = arpdb
clients.addFromConfig(objects)
clients.addFromConfig(objects, filteringConf)
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
if clients.testing {
return
@@ -133,6 +145,8 @@ func (clients *clientsContainer) reloadARP() {
// clientObject is the YAML representation of a persistent client.
type clientObject struct {
SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"`
Name string `yaml:"name"`
Tags []string `yaml:"tags"`
@@ -143,14 +157,16 @@ type clientObject struct {
UseGlobalSettings bool `yaml:"use_global_settings"`
FilteringEnabled bool `yaml:"filtering_enabled"`
ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
UseGlobalBlockedServices bool `yaml:"use_global_blocked_services"`
IgnoreQueryLog bool `yaml:"ignore_querylog"`
IgnoreStatistics bool `yaml:"ignore_statistics"`
}
// addFromConfig initializes the clients container with objects from the
// configuration file.
func (clients *clientsContainer) addFromConfig(objects []*clientObject) {
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) {
for _, o := range objects {
cli := &Client{
Name: o.Name,
@@ -161,9 +177,26 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject) {
UseOwnSettings: !o.UseGlobalSettings,
FilteringEnabled: o.FilteringEnabled,
ParentalEnabled: o.ParentalEnabled,
SafeSearchEnabled: o.SafeSearchEnabled,
safeSearchConf: o.SafeSearchConf,
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
IgnoreQueryLog: o.IgnoreQueryLog,
IgnoreStatistics: o.IgnoreStatistics,
}
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err := cli.setSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime),
)
if err != nil {
log.Error("clients: init client safesearch %q: %s", cli.Name, err)
continue
}
}
for _, s := range o.BlockedServices {
@@ -210,9 +243,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
ParentalEnabled: cli.ParentalEnabled,
SafeSearchEnabled: cli.SafeSearchEnabled,
SafeSearchConf: cli.safeSearchConf,
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
IgnoreQueryLog: cli.IgnoreQueryLog,
IgnoreStatistics: cli.IgnoreStatistics,
}
objs = append(objs, o)
@@ -324,7 +359,8 @@ func (clients *clientsContainer) clientOrArtificial(
client, ok := clients.Find(id)
if ok {
return &querylog.Client{
Name: client.Name,
Name: client.Name,
IgnoreQueryLog: client.IgnoreQueryLog,
}, false
}
@@ -359,6 +395,20 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
return c, true
}
// shouldCountClient is a wrapper around Find to make it a valid client
// information finder for the statistics. If no information about the client
// is found, it returns true.
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
for _, id := range ids {
client, ok := clients.Find(id)
if ok {
return !client.IgnoreStatistics
}
}
return true
}
// findUpstreams returns upstreams configured for the client, identified either
// by its IP address or its ClientID. upsConf is nil if the client isn't found
// or if the client has no custom upstreams.
@@ -389,6 +439,7 @@ func (clients *clientsContainer) findUpstreams(
Bootstrap: config.DNS.BootstrapDNS,
Timeout: config.DNS.UpstreamTimeout.Duration,
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
PreferIPv6: config.DNS.BootstrapPreferIPv6,
},
)
if err != nil {
@@ -839,15 +890,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) {
continue
}
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(l.IP)
if err != nil {
log.Error("clients: bad client ip %v from dhcp: %s", l.IP, err)
continue
}
ok := clients.addHostLocked(ipAddr, l.Hostname, ClientSourceDHCP)
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
if ok {
n++
}

View File

@@ -9,17 +9,27 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClients(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
// newClientsContainer is a helper that creates a new clients container for
// tests.
func newClientsContainer() (c *clientsContainer) {
c = &clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil)
c.Init(nil, nil, nil, nil, &filtering.Config{})
return c
}
func TestClients(t *testing.T) {
clients := newClientsContainer()
t.Run("add_success", func(t *testing.T) {
var (
@@ -198,10 +208,7 @@ func TestClients(t *testing.T) {
}
func TestClientsWHOIS(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil)
clients := newClientsContainer()
whois := &RuntimeClientWHOISInfo{
Country: "AU",
Orgname: "Example Org",
@@ -247,10 +254,7 @@ func TestClientsWHOIS(t *testing.T) {
}
func TestClientsAddExisting(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil)
clients := newClientsContainer()
t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
@@ -275,7 +279,7 @@ func TestClientsAddExisting(t *testing.T) {
t.Skip("skipping dhcp test on windows")
}
ip := net.IP{1, 2, 3, 4}
ip := netip.MustParseAddr("1.2.3.4")
// First, init a DHCP server with a single static lease.
config := &dhcpd.ServerConfig{
@@ -325,10 +329,7 @@ func TestClientsAddExisting(t *testing.T) {
}
func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil)
clients := newClientsContainer()
// Add client with upstreams.
ok, err := clients.Add(&Client{

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
)
// clientJSON is a common structure used by several handlers to deal with
@@ -26,7 +27,8 @@ type clientJSON struct {
// the allowlist.
DisallowedRule *string `json:"disallowed_rule,omitempty"`
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
Name string `json:"name"`
@@ -35,9 +37,10 @@ type clientJSON struct {
Tags []string `json:"tags"`
Upstreams []string `json:"upstreams"`
FilteringEnabled bool `json:"filtering_enabled"`
ParentalEnabled bool `json:"parental_enabled"`
SafeBrowsingEnabled bool `json:"safebrowsing_enabled"`
FilteringEnabled bool `json:"filtering_enabled"`
ParentalEnabled bool `json:"parental_enabled"`
SafeBrowsingEnabled bool `json:"safebrowsing_enabled"`
// Deprecated: use safeSearchConf.
SafeSearchEnabled bool `json:"safesearch_enabled"`
UseGlobalBlockedServices bool `json:"use_global_blocked_services"`
UseGlobalSettings bool `json:"use_global_settings"`
@@ -46,8 +49,8 @@ type clientJSON struct {
type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"`
IP netip.Addr `json:"ip"`
Name string `json:"name"`
Source clientSource `json:"source"`
}
@@ -57,7 +60,7 @@ type clientListJSON struct {
Tags []string `json:"supported_tags"`
}
// respond with information about configured clients
// handleGetClients is the handler for GET /control/clients HTTP API.
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
data := clientListJSON{}
@@ -86,27 +89,67 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
_ = aghhttp.WriteJSONResponse(w, r, data)
}
// Convert JSON object to Client object
func jsonToClient(cj clientJSON) (c *Client) {
return &Client{
Name: cj.Name,
IDs: cj.IDs,
Tags: cj.Tags,
UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled,
ParentalEnabled: cj.ParentalEnabled,
SafeSearchEnabled: cj.SafeSearchEnabled,
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
// jsonToClient converts JSON object to Client object.
func (clients *clientsContainer) jsonToClient(cj clientJSON) (c *Client, err error) {
var safeSearchConf filtering.SafeSearchConfig
if cj.SafeSearchConf != nil {
safeSearchConf = *cj.SafeSearchConf
} else {
// TODO(d.kolyshev): Remove after cleaning the deprecated
// [clientJSON.SafeSearchEnabled] field.
safeSearchConf = filtering.SafeSearchConfig{
Enabled: cj.SafeSearchEnabled,
}
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
// Set default service flags for enabled safesearch.
if safeSearchConf.Enabled {
safeSearchConf.Bing = true
safeSearchConf.DuckDuckGo = true
safeSearchConf.Google = true
safeSearchConf.Pixabay = true
safeSearchConf.Yandex = true
safeSearchConf.YouTube = true
}
}
c = &Client{
safeSearchConf: safeSearchConf,
Name: cj.Name,
IDs: cj.IDs,
Tags: cj.Tags,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled,
ParentalEnabled: cj.ParentalEnabled,
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
}
if safeSearchConf.Enabled {
err = c.setSafeSearch(
safeSearchConf,
clients.safeSearchCacheSize,
clients.safeSearchCacheTTL,
)
if err != nil {
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
}
}
return c, nil
}
// Convert Client object to JSON
// clientToJSON converts Client object to JSON.
func clientToJSON(c *Client) (cj *clientJSON) {
// TODO(d.kolyshev): Remove after cleaning the deprecated
// [clientJSON.SafeSearchEnabled] field.
cloneVal := c.safeSearchConf
safeSearchConf := &cloneVal
return &clientJSON{
Name: c.Name,
IDs: c.IDs,
@@ -114,7 +157,8 @@ func clientToJSON(c *Client) (cj *clientJSON) {
UseGlobalSettings: !c.UseOwnSettings,
FilteringEnabled: c.FilteringEnabled,
ParentalEnabled: c.ParentalEnabled,
SafeSearchEnabled: c.SafeSearchEnabled,
SafeSearchEnabled: safeSearchConf.Enabled,
SafeSearchConf: safeSearchConf,
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
@@ -124,7 +168,7 @@ func clientToJSON(c *Client) (cj *clientJSON) {
}
}
// Add a new client
// handleAddClient is the handler for POST /control/clients/add HTTP API.
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj)
@@ -134,7 +178,13 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return
}
c := jsonToClient(cj)
c, err := clients.jsonToClient(cj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
ok, err := clients.Add(c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -151,7 +201,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
onConfigModified()
}
// Remove client
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj)
@@ -181,7 +231,7 @@ type updateJSON struct {
Data clientJSON `json:"data"`
}
// Update client's properties
// handleUpdateClient is the handler for POST /control/clients/update HTTP API.
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
dj := updateJSON{}
err := json.NewDecoder(r.Body).Decode(&dj)
@@ -197,7 +247,13 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
c := jsonToClient(dj.Data)
c, err := clients.jsonToClient(dj.Data)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = clients.Update(dj.Name, c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -208,7 +264,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
onConfigModified()
}
// Get the list of clients by IP address list
// handleFindClient is the handler for GET /control/clients/find HTTP API.
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
data := []map[string]*clientJSON{}

View File

@@ -228,34 +228,32 @@ type tlsConfigSettings struct {
}
type queryLogConfig struct {
// Ignored is the list of host names, which should not be written to log.
Ignored []string `yaml:"ignored"`
// Interval is the interval for query log's files rotation.
Interval timeutil.Duration `yaml:"interval"`
// MemSize is the number of entries kept in memory before they are flushed
// to disk.
MemSize uint32 `yaml:"size_memory"`
// Enabled defines if the query log is enabled.
Enabled bool `yaml:"enabled"`
// FileEnabled defines, if the query log is written to the file.
FileEnabled bool `yaml:"file_enabled"`
// Interval is the interval for query log's files rotation.
Interval timeutil.Duration `yaml:"interval"`
// MemSize is the number of entries kept in memory before they are
// flushed to disk.
MemSize uint32 `yaml:"size_memory"`
// Ignored is the list of host names, which should not be written to
// log.
Ignored []string `yaml:"ignored"`
}
type statsConfig struct {
// Enabled defines if the statistics are enabled.
Enabled bool `yaml:"enabled"`
// Interval is the time interval for flushing statistics to the disk in
// days.
Interval uint32 `yaml:"interval"`
// Ignored is the list of host names, which should not be counted.
Ignored []string `yaml:"ignored"`
// Interval is the retention interval for statistics.
Interval timeutil.Duration `yaml:"interval"`
// Enabled defines if the statistics are enabled.
Enabled bool `yaml:"enabled"`
}
// config is the global configuration structure.
@@ -286,7 +284,7 @@ var config = &configuration{
CacheSize: 4 * 1024 * 1024,
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{
CustomIP: "",
CustomIP: netip.Addr{},
Enabled: false,
UseCustom: false,
},
@@ -322,7 +320,7 @@ var config = &configuration{
},
Stats: statsConfig{
Enabled: true,
Interval: 1,
Interval: timeutil.Duration{Duration: 1 * timeutil.Day},
Ignored: []string{},
},
// NOTE: Keep these parameters in sync with the one put into
@@ -503,7 +501,7 @@ func (c *configuration) write() (err error) {
if Context.stats != nil {
statsConf := stats.Config{}
Context.stats.WriteDiskConfig(&statsConf)
config.Stats.Interval = statsConf.LimitDays
config.Stats.Interval = timeutil.Duration{Duration: statsConf.Limit}
config.Stats.Enabled = statsConf.Enabled
config.Stats.Ignored = statsConf.Ignored.Values()
slices.Sort(config.Stats.Ignored)

View File

@@ -13,7 +13,9 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/NYTimes/gziphandler"
)
@@ -97,12 +99,17 @@ func collectDNSAddresses() (addrs []string, err error) {
// statusResponse is a response for /control/status endpoint.
type statusResponse struct {
Version string `json:"version"`
Language string `json:"language"`
DNSAddrs []string `json:"dns_addresses"`
DNSPort int `json:"dns_port"`
HTTPPort int `json:"http_port"`
IsProtectionEnabled bool `json:"protection_enabled"`
Version string `json:"version"`
Language string `json:"language"`
DNSAddrs []string `json:"dns_addresses"`
DNSPort int `json:"dns_port"`
HTTPPort int `json:"http_port"`
// ProtectionDisabledDuration is the duration of the protection pause in
// milliseconds.
ProtectionDisabledDuration int64 `json:"protection_disabled_duration"`
ProtectionEnabled bool `json:"protection_enabled"`
// TODO(e.burkov): Inspect if front-end doesn't requires this field as
// openapi.yaml declares.
IsDHCPAvailable bool `json:"dhcp_available"`
@@ -119,28 +126,45 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
return
}
var (
fltConf *dnsforward.FilteringConfig
protectionDisabledUntil *time.Time
protectionEnabled bool
)
if Context.dnsServer != nil {
fltConf = &dnsforward.FilteringConfig{}
Context.dnsServer.WriteDiskConfig(fltConf)
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
}
var resp statusResponse
func() {
config.RLock()
defer config.RUnlock()
var protectionDisabledDuration int64
if protectionDisabledUntil != nil {
// Make sure that we don't send negative numbers to the frontend,
// since enough time might have passed to make the difference less
// than zero.
protectionDisabledDuration = mathutil.Max(
0,
time.Until(*protectionDisabledUntil).Milliseconds(),
)
}
resp = statusResponse{
Version: version.Version(),
DNSAddrs: dnsAddrs,
DNSPort: config.DNS.Port,
HTTPPort: config.BindPort,
Language: config.Language,
IsRunning: isRunning(),
Version: version.Version(),
Language: config.Language,
DNSAddrs: dnsAddrs,
DNSPort: config.DNS.Port,
HTTPPort: config.BindPort,
ProtectionDisabledDuration: protectionDisabledDuration,
ProtectionEnabled: protectionEnabled,
IsRunning: isRunning(),
}
}()
var c *dnsforward.FilteringConfig
if Context.dnsServer != nil {
c = &dnsforward.FilteringConfig{}
Context.dnsServer.WriteDiskConfig(c)
resp.IsProtectionEnabled = c.ProtectionEnabled
}
// IsDHCPAvailable field is now false by default for Windows.
if runtime.GOOS != "windows" {
resp.IsDHCPAvailable = Context.dhcpServer != nil
@@ -219,7 +243,7 @@ func modifiesData(m string) (ok bool) {
func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
const statusUnsup = http.StatusUnsupportedMediaType
cType := r.Header.Get(aghhttp.HdrNameContentType)
cType := r.Header.Get(httphdr.ContentType)
if r.ContentLength == 0 {
if cType == "" {
return true
@@ -308,13 +332,17 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
return false
}
var serveHTTP3 bool
var portHTTPS int
var (
forceHTTPS bool
serveHTTP3 bool
portHTTPS int
)
func() {
config.RLock()
defer config.RUnlock()
serveHTTP3, portHTTPS = config.DNS.ServeHTTP3, config.TLS.PortHTTPS
forceHTTPS = config.TLS.ForceHTTPS && config.TLS.Enabled && config.TLS.PortHTTPS != 0
}()
respHdr := w.Header()
@@ -327,13 +355,13 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
// default is 24 hours.
if serveHTTP3 {
altSvc := fmt.Sprintf(`h3=":%d"`, portHTTPS)
respHdr.Set(aghhttp.HdrNameAltSvc, altSvc)
respHdr.Set(httphdr.AltSvc, altSvc)
}
if r.TLS == nil && web.forceHTTPS {
if r.TLS == nil && forceHTTPS {
hostPort := host
if port := web.conf.PortHTTPS; port != defaultPortHTTPS {
hostPort = netutil.JoinHostPort(host, port)
if portHTTPS != defaultPortHTTPS {
hostPort = netutil.JoinHostPort(host, portHTTPS)
}
httpsURL := &url.URL{
@@ -357,8 +385,8 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
Host: r.Host,
}
respHdr.Set(aghhttp.HdrNameAccessControlAllowOrigin, originURL.String())
respHdr.Set(aghhttp.HdrNameVary, aghhttp.HdrNameOrigin)
respHdr.Set(httphdr.AccessControlAllowOrigin, originURL.String())
respHdr.Set(httphdr.Vary, httphdr.Origin)
return true
}
@@ -371,7 +399,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
path := r.URL.Path
if Context.firstRun && !strings.HasPrefix(path, "/install.") &&
!strings.HasPrefix(path, "/assets/") {
http.Redirect(w, r, "/install.html", http.StatusFound)
http.Redirect(w, r, "install.html", http.StatusFound)
return
}

View File

@@ -39,7 +39,7 @@ type getAddrsResponse struct {
}
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := getAddrsResponse{
Version: version.Version(),
@@ -167,7 +167,7 @@ func (req *checkConfReq) validateDNS(
}
// handleInstallCheckConfig handles the /check_config endpoint.
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
req := &checkConfReq{}
err := json.NewDecoder(r.Body).Decode(req)
@@ -375,7 +375,7 @@ func shutdownSrv3(srv *http3.Server) {
const PasswordMinRunes = 8
// Apply new configuration, start DNS server, restart Web server
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -503,7 +503,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
return req, restartHTTP, err
}
func (web *Web) registerInstallHandlers() {
func (web *webAPI) registerInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))

View File

@@ -1,13 +1,13 @@
package home
import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
@@ -21,7 +21,6 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/ameshkov/dnscrypt/v2"
yaml "gopkg.in/yaml.v3"
)
@@ -52,14 +51,15 @@ func initDNS() (err error) {
anonymizer := config.anonymizer()
statsConf := stats.Config{
Filename: filepath.Join(baseDir, "stats.db"),
LimitDays: config.Stats.Interval,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
Enabled: config.Stats.Enabled,
Filename: filepath.Join(baseDir, "stats.db"),
Limit: config.Stats.Interval.Duration,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
Enabled: config.Stats.Enabled,
ShouldCountClient: Context.clients.shouldCountClient,
}
set, err := nonDupEmptyHostNames(config.Stats.Ignored)
set, err := aghnet.NewDomainNameSet(config.Stats.Ignored)
if err != nil {
return fmt.Errorf("statistics: ignored list: %w", err)
}
@@ -83,13 +83,16 @@ func initDNS() (err error) {
FileEnabled: config.QueryLog.FileEnabled,
}
set, err = nonDupEmptyHostNames(config.QueryLog.Ignored)
set, err = aghnet.NewDomainNameSet(config.QueryLog.Ignored)
if err != nil {
return fmt.Errorf("querylog: ignored list: %w", err)
}
conf.Ignored = set
Context.queryLog = querylog.New(conf)
Context.queryLog, err = querylog.New(conf)
if err != nil {
return fmt.Errorf("init querylog: %w", err)
}
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
if err != nil {
@@ -426,7 +429,8 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
}
setts.FilteringEnabled = c.FilteringEnabled
setts.SafeSearchEnabled = c.SafeSearchEnabled
setts.SafeSearchEnabled = c.safeSearchConf.Enabled
setts.ClientSafeSearch = c.SafeSearch
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
setts.ParentalEnabled = c.ParentalEnabled
}
@@ -533,26 +537,30 @@ func closeDNSServer() {
log.Debug("all dns modules are closed")
}
// nonDupEmptyHostNames returns nil and error, if list has duplicate or empty
// host name. Otherwise returns a set, which contains lowercase host names
// without dot at the end, and nil error.
func nonDupEmptyHostNames(list []string) (set *stringutil.Set, err error) {
set = stringutil.NewSet()
// safeSearchResolver is a [filtering.Resolver] implementation used for safe
// search.
type safeSearchResolver struct{}
for _, v := range list {
host := strings.ToLower(strings.TrimSuffix(v, "."))
// TODO(a.garipov): Think about ignoring empty (".") names in
// the future.
if host == "" {
return nil, errors.Error("host name is empty")
}
// type check
var _ filtering.Resolver = safeSearchResolver{}
if set.Has(host) {
return nil, fmt.Errorf("duplicate host name %q", host)
}
set.Add(host)
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
// It returns the slice of net.IP with IPv4 and IPv6 instances.
//
// TODO(a.garipov): Support network.
func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
addrs, err := Context.dnsServer.Resolve(host)
if err != nil {
return nil, err
}
return set, nil
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %s", host)
}
for _, a := range addrs {
ips = append(ips, a.IP)
}
return ips, nil
}

View File

@@ -9,7 +9,6 @@ import (
"io/fs"
"net"
"net/http"
"net/http/pprof"
"net/netip"
"net/url"
"os"
@@ -28,6 +27,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
@@ -58,13 +58,12 @@ type homeContext struct {
dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
web *webAPI // Web (HTTP, HTTPS) module
tls *tlsManager // TLS module
// etcHosts is an IP-hostname pairs set taken from system configuration
// (e.g. /etc/hosts) files.
// etcHosts contains IP-hostname mappings taken from the OS-specific hosts
// configuration files, for example /etc/hosts.
etcHosts *aghnet.HostsContainer
// hostsWatcher is the watcher to detect changes in the hosts files.
hostsWatcher aghos.FSWatcher
updater *updater.Updater
@@ -79,7 +78,6 @@ type homeContext struct {
pidFileName string // PID file name. Empty if no PID file was created.
controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
transport *http.Transport
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
@@ -149,18 +147,17 @@ func setupContext(opts options) {
setupContextFlags(opts)
Context.tlsRoots = aghtls.SystemRootCAs()
Context.transport = &http.Transport{
DialContext: customDialContext,
Proxy: getHTTPProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
}
Context.client = &http.Client{
Timeout: time.Minute * 5,
Transport: Context.transport,
Timeout: time.Minute * 5,
Transport: &http.Transport{
DialContext: customDialContext,
Proxy: getHTTPProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
},
}
if !Context.firstRun {
@@ -263,7 +260,7 @@ func configureOS(conf *configuration) (err error) {
// setupHostsContainer initializes the structures to keep up-to-date the hosts
// provided by the OS.
func setupHostsContainer() (err error) {
Context.hostsWatcher, err = aghos.NewOSWritesWatcher()
hostsWatcher, err := aghos.NewOSWritesWatcher()
if err != nil {
return fmt.Errorf("initing hosts watcher: %w", err)
}
@@ -271,18 +268,18 @@ func setupHostsContainer() (err error) {
Context.etcHosts, err = aghnet.NewHostsContainer(
filtering.SysHostsListID,
aghos.RootDirFS(),
Context.hostsWatcher,
hostsWatcher,
aghnet.DefaultHostsPaths()...,
)
if err != nil {
cerr := Context.hostsWatcher.Close()
if errors.Is(err, aghnet.ErrNoHostsPaths) && cerr == nil {
closeErr := hostsWatcher.Close()
if errors.Is(err, aghnet.ErrNoHostsPaths) && closeErr == nil {
log.Info("warning: initing hosts container: %s", err)
return nil
}
return errors.WithDeferred(fmt.Errorf("initing hosts container: %w", err), cerr)
return errors.WithDeferred(fmt.Errorf("initing hosts container: %w", err), closeErr)
}
return nil
@@ -298,6 +295,17 @@ func setupConfig(opts options) (err error) {
config.DNS.DnsfilterConf.UserRules = slices.Clone(config.UserRules)
config.DNS.DnsfilterConf.HTTPClient = Context.client
config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{}
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault(
config.DNS.DnsfilterConf.SafeSearchConf,
"default",
config.DNS.DnsfilterConf.SafeSearchCacheSize,
time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime),
)
if err != nil {
return fmt.Errorf("initializing safesearch: %w", err)
}
config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
@@ -328,33 +336,16 @@ func setupConfig(opts options) (err error) {
arpdb = aghnet.NewARPDB()
}
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb, config.DNS.DnsfilterConf)
if opts.bindPort != 0 {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(opts.bindPort))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))
if config.TLS.Enabled {
addPorts(
tcpPorts,
tcpPort(config.TLS.PortHTTPS),
tcpPort(config.TLS.PortDNSOverTLS),
tcpPort(config.TLS.PortDNSCrypt),
)
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
}
if err = tcpPorts.Validate(); err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
} else if err = udpPorts.Validate(); err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
config.BindPort = opts.bindPort
err = checkPorts()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
}
// override bind host/port from the console
@@ -368,7 +359,35 @@ func setupConfig(opts options) (err error) {
return nil
}
func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
// checkPorts is a helper for ports validation in config.
func checkPorts() (err error) {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(config.BindPort))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))
if config.TLS.Enabled {
addPorts(
tcpPorts,
tcpPort(config.TLS.PortHTTPS),
tcpPort(config.TLS.PortDNSOverTLS),
tcpPort(config.TLS.PortDNSCrypt),
)
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
}
if err = tcpPorts.Validate(); err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
} else if err = udpPorts.Validate(); err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
return nil
}
func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
var clientFS fs.FS
if opts.localFrontend {
log.Info("warning: using local frontend files")
@@ -395,7 +414,7 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
serveHTTP3: config.DNS.ServeHTTP3,
}
web = newWeb(&webConf)
web = newWebAPI(&webConf)
if web == nil {
return nil, fmt.Errorf("initializing web: %w", err)
}
@@ -450,26 +469,8 @@ func run(opts options, clientBuildFS fs.FS) {
fatalOnError(err)
if config.DebugPProf {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
// See profileSupportsDelta in src/net/http/pprof/pprof.go.
mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
go func() {
log.Info("pprof: listening on localhost:6060")
lerr := http.ListenAndServe("localhost:6060", mux)
log.Error("Error while running the pprof server: %s", lerr)
}()
// TODO(a.garipov): Make the address configurable.
startPprof("localhost:6060")
}
}
@@ -532,7 +533,7 @@ func run(opts options, clientBuildFS fs.FS) {
}
}
Context.web.Start()
Context.web.start()
// wait indefinitely for other go-routines to complete their job
select {}
@@ -712,7 +713,7 @@ func cleanup(ctx context.Context) {
log.Info("stopping AdGuard Home")
if Context.web != nil {
Context.web.Close(ctx)
Context.web.close(ctx)
Context.web = nil
}
if Context.auth != nil {
@@ -733,13 +734,6 @@ func cleanup(ctx context.Context) {
}
if Context.etcHosts != nil {
// Currently Context.hostsWatcher is only used in Context.etcHosts and
// needs closing only in case of the successful initialization of
// Context.etcHosts.
if err = Context.hostsWatcher.Close(); err != nil {
log.Error("closing hosts watcher: %s", err)
}
if err = Context.etcHosts.Close(); err != nil {
log.Error("closing hosts container: %s", err)
}
@@ -857,8 +851,10 @@ func detectFirstRun() bool {
// Connect to a remote server resolving hostname using our own DNS server.
//
// TODO(e.burkov): This messy logic should be decomposed and clarified.
//
// TODO(a.garipov): Support network.
func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Tracef("network:%v addr:%v", network, addr)
log.Debug("home: customdial: dialing addr %q for network %s", addr, network)
host, port, err := net.SplitHostPort(addr)
if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/google/uuid"
"howett.net/plist"
@@ -170,7 +171,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
return
}
w.Header().Set("Content-Type", "application/xml")
w.Header().Set(httphdr.ContentType, "application/xml")
const (
dohContDisp = `attachment; filename=doh.mobileconfig`
@@ -182,7 +183,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
contDisp = dotContDisp
}
w.Header().Set("Content-Disposition", contDisp)
w.Header().Set(httphdr.ContentDisposition, contDisp)
_, _ = w.Write(mobileconfig)
}

39
internal/home/pprof.go Normal file
View File

@@ -0,0 +1,39 @@
package home
import (
"net/http"
"net/http/pprof"
"runtime"
"github.com/AdguardTeam/golibs/log"
)
// startPprof launches the debug and profiling server on addr.
func startPprof(addr string) {
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(1)
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
// See profileSupportsDelta in src/net/http/pprof/pprof.go.
mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
go func() {
defer log.OnPanic("pprof server")
log.Info("pprof: listening on %q", addr)
err := http.ListenAndServe(addr, mux)
log.Info("pprof server errors: %v", err)
}()
}

View File

@@ -108,7 +108,7 @@ func (m *tlsManager) start() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
Context.web.tlsConfigChanged(context.Background(), tlsConf)
}
// reload updates the configuration and restarts t.
@@ -156,7 +156,7 @@ func (m *tlsManager) reload() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
Context.web.tlsConfigChanged(context.Background(), tlsConf)
}
// loadTLSConf loads and validates the TLS configuration. The returned error is
@@ -454,7 +454,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
// same reason.
if restartHTTPS {
go func() {
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
}()
}
}

View File

@@ -22,7 +22,7 @@ import (
)
// currentSchemaVersion is the current schema version.
const currentSchemaVersion = 17
const currentSchemaVersion = 20
// These aliases are provided for convenience.
type (
@@ -90,6 +90,9 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
upgradeSchema14to15,
upgradeSchema15to16,
upgradeSchema16to17,
upgradeSchema17to18,
upgradeSchema18to19,
upgradeSchema19to20,
}
n := 0
@@ -836,9 +839,9 @@ func upgradeSchema14to15(diskConf yobj) (err error) {
}
type temp struct {
val any
from string
to string
val any
}
replaces := []temp{
{from: "querylog_enabled", to: "enabled", val: true},
@@ -873,6 +876,18 @@ func upgradeSchema14to15(diskConf yobj) (err error) {
// 'enabled': true
// 'interval': 1
// 'ignored': []
//
// If statistics were disabled:
//
// # BEFORE:
// 'dns':
// 'statistics_interval': 0
//
// # AFTER:
// 'statistics':
// 'enabled': false
// 'interval': 1
// 'ignored': []
func upgradeSchema15to16(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 15 to 16")
diskConf["schema_version"] = 16
@@ -894,10 +909,23 @@ func upgradeSchema15to16(diskConf yobj) (err error) {
}
const field = "statistics_interval"
v, has := dns[field]
statsIvlVal, has := dns[field]
if has {
stats["enabled"] = v != 0
stats["interval"] = v
var statsIvl int
statsIvl, ok = statsIvlVal.(int)
if !ok {
return fmt.Errorf("unexpected type of dns.statistics_interval: %T", statsIvlVal)
}
if statsIvl == 0 {
// Set the interval to the default value of one day to make sure
// that it passes the validations.
stats["interval"] = 1
stats["enabled"] = false
} else {
stats["interval"] = statsIvl
stats["enabled"] = true
}
}
delete(dns, field)
@@ -943,6 +971,172 @@ func upgradeSchema16to17(diskConf yobj) (err error) {
return nil
}
// upgradeSchema17to18 performs the following changes:
//
// # BEFORE:
// 'dns':
// 'safesearch_enabled': true
//
// # AFTER:
// 'dns':
// 'safe_search':
// 'enabled': true
// 'bing': true
// 'duckduckgo': true
// 'google': true
// 'pixabay': true
// 'yandex': true
// 'youtube': true
func upgradeSchema17to18(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 17 to 18")
diskConf["schema_version"] = 18
dnsVal, ok := diskConf["dns"]
if !ok {
return nil
}
dns, ok := dnsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
}
safeSearch := yobj{
"enabled": true,
"bing": true,
"duckduckgo": true,
"google": true,
"pixabay": true,
"yandex": true,
"youtube": true,
}
const safeSearchKey = "safesearch_enabled"
v, has := dns[safeSearchKey]
if has {
safeSearch["enabled"] = v
}
delete(dns, safeSearchKey)
dns["safe_search"] = safeSearch
return nil
}
// upgradeSchema18to19 performs the following changes:
//
// # BEFORE:
// 'clients':
// 'persistent':
// - 'name': 'client-name'
// 'safesearch_enabled': true
//
// # AFTER:
// 'clients':
// 'persistent':
// - 'name': 'client-name'
// 'safe_search':
// 'enabled': true
// 'bing': true
// 'duckduckgo': true
// 'google': true
// 'pixabay': true
// 'yandex': true
// 'youtube': true
func upgradeSchema18to19(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 18 to 19")
diskConf["schema_version"] = 19
clientsVal, ok := diskConf["clients"]
if !ok {
return nil
}
clients, ok := clientsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
}
persistent, ok := clients["persistent"].([]yobj)
if !ok {
return nil
}
const safeSearchKey = "safesearch_enabled"
for i := range persistent {
c := persistent[i]
safeSearch := yobj{
"enabled": true,
"bing": true,
"duckduckgo": true,
"google": true,
"pixabay": true,
"yandex": true,
"youtube": true,
}
v, has := c[safeSearchKey]
if has {
safeSearch["enabled"] = v
}
delete(c, safeSearchKey)
c["safe_search"] = safeSearch
}
return nil
}
// upgradeSchema19to20 performs the following changes:
//
// # BEFORE:
// 'statistics':
// 'interval': 1
//
// # AFTER:
// 'statistics':
// 'interval': 24h
func upgradeSchema19to20(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 19 to 20")
diskConf["schema_version"] = 20
statsVal, ok := diskConf["statistics"]
if !ok {
return nil
}
var stats yobj
stats, ok = statsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of stats: %T", statsVal)
}
const field = "interval"
// Set the initial value from the global configuration structure.
statsIvl := 1
statsIvlVal, ok := stats[field]
if ok {
statsIvl, ok = statsIvlVal.(int)
if !ok {
return fmt.Errorf("unexpected type of %s: %T", field, statsIvlVal)
}
// The initial version of upgradeSchema16to17 did not set the zero
// interval to a non-zero one. So, reset it now.
if statsIvl == 0 {
statsIvl = 1
}
}
stats[field] = timeutil.Duration{Duration: time.Duration(statsIvl) * timeutil.Day}
return nil
}
// TODO(a.garipov): Replace with log.Output when we port it to our logging
// package.
func funcName() string {

View File

@@ -729,7 +729,7 @@ func TestUpgradeSchema15to16(t *testing.T) {
want: yobj{
"statistics": map[string]any{
"enabled": false,
"interval": 0,
"interval": 1,
"ignored": []any{},
},
"dns": map[string]any{},
@@ -808,3 +808,246 @@ func TestUpgradeSchema16to17(t *testing.T) {
})
}
}
func TestUpgradeSchema17to18(t *testing.T) {
const newSchemaVer = 18
defaultWantObj := yobj{
"dns": yobj{
"safe_search": yobj{
"enabled": true,
"bing": true,
"duckduckgo": true,
"google": true,
"pixabay": true,
"yandex": true,
"youtube": true,
},
},
"schema_version": newSchemaVer,
}
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{"dns": yobj{}},
want: defaultWantObj,
name: "default_values",
}, {
in: yobj{"dns": yobj{"safesearch_enabled": true}},
want: defaultWantObj,
name: "enabled",
}, {
in: yobj{"dns": yobj{"safesearch_enabled": false}},
want: yobj{
"dns": yobj{
"safe_search": map[string]any{
"enabled": false,
"bing": true,
"duckduckgo": true,
"google": true,
"pixabay": true,
"yandex": true,
"youtube": true,
},
},
"schema_version": newSchemaVer,
},
name: "disabled",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema17to18(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema18to19(t *testing.T) {
const newSchemaVer = 19
defaultWantObj := yobj{
"clients": yobj{
"persistent": []yobj{{
"name": "localhost",
"safe_search": yobj{
"enabled": true,
"bing": true,
"duckduckgo": true,
"google": true,
"pixabay": true,
"yandex": true,
"youtube": true,
},
}},
},
"schema_version": newSchemaVer,
}
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{
"clients": yobj{},
},
want: yobj{
"clients": yobj{},
"schema_version": newSchemaVer,
},
name: "no_clients",
}, {
in: yobj{
"clients": yobj{
"persistent": []yobj{{"name": "localhost"}},
},
},
want: defaultWantObj,
name: "default_values",
}, {
in: yobj{
"clients": yobj{
"persistent": []yobj{{"name": "localhost", "safesearch_enabled": true}},
},
},
want: defaultWantObj,
name: "enabled",
}, {
in: yobj{
"clients": yobj{
"persistent": []yobj{{"name": "localhost", "safesearch_enabled": false}},
},
},
want: yobj{
"clients": yobj{"persistent": []yobj{{
"name": "localhost",
"safe_search": yobj{
"enabled": false,
"bing": true,
"duckduckgo": true,
"google": true,
"pixabay": true,
"yandex": true,
"youtube": true,
},
}}},
"schema_version": newSchemaVer,
},
name: "disabled",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema18to19(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema19to20(t *testing.T) {
testCases := []struct {
ivl any
want any
wantErr string
name string
}{{
ivl: 1,
want: timeutil.Duration{Duration: timeutil.Day},
wantErr: "",
name: "success",
}, {
ivl: 0,
want: timeutil.Duration{Duration: timeutil.Day},
wantErr: "",
name: "success",
}, {
ivl: 0.25,
want: 0,
wantErr: "unexpected type of interval: float64",
name: "fail",
}}
for _, tc := range testCases {
conf := yobj{
"statistics": yobj{
"interval": tc.ivl,
},
"schema_version": 19,
}
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema19to20(conf)
if tc.wantErr != "" {
require.Error(t, err)
assert.Equal(t, tc.wantErr, err.Error())
return
}
require.NoError(t, err)
require.Equal(t, conf["schema_version"], 20)
statsVal, ok := conf["statistics"]
require.True(t, ok)
var stats yobj
stats, ok = statsVal.(yobj)
require.True(t, ok)
var newIvl timeutil.Duration
newIvl, ok = stats["interval"].(timeutil.Duration)
require.True(t, ok)
assert.Equal(t, tc.want, newIvl)
})
}
t.Run("no_stats", func(t *testing.T) {
err := upgradeSchema19to20(yobj{})
assert.NoError(t, err)
})
t.Run("bad_stats", func(t *testing.T) {
err := upgradeSchema19to20(yobj{
"statistics": 0,
})
testutil.AssertErrorMsg(t, "unexpected type of stats: int", err)
})
t.Run("no_field", func(t *testing.T) {
conf := yobj{
"statistics": yobj{},
}
err := upgradeSchema19to20(conf)
require.NoError(t, err)
statsVal, ok := conf["statistics"]
require.True(t, ok)
var stats yobj
stats, ok = statsVal.(yobj)
require.True(t, ok)
var ivl any
ivl, ok = stats["interval"]
require.True(t, ok)
var ivlVal timeutil.Duration
ivlVal, ok = ivl.(timeutil.Duration)
require.True(t, ok)
assert.Equal(t, 24*time.Hour, ivlVal.Duration)
})
}

View File

@@ -35,9 +35,8 @@ const (
type webConfig struct {
clientFS fs.FS
BindHost netip.Addr
BindPort int
PortHTTPS int
BindHost netip.Addr
BindPort int
// ReadTimeout is an option to pass to http.Server for setting an
// appropriate field.
@@ -72,8 +71,8 @@ type httpsServer struct {
enabled bool
}
// Web is the web UI and API server.
type Web struct {
// webAPI is the web UI and API server.
type webAPI struct {
conf *webConfig
// TODO(a.garipov): Refactor all these servers.
@@ -82,15 +81,13 @@ type Web struct {
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
// [Web.http3Server] must also not be nil.
httpsServer httpsServer
forceHTTPS bool
}
// newWeb creates a new instance of the web UI and API server.
func newWeb(conf *webConfig) (w *Web) {
// newWebAPI creates a new instance of the web UI and API server.
func newWebAPI(conf *webConfig) (w *webAPI) {
log.Info("web: initializing")
w = &Web{
w = &webAPI{
conf: conf,
}
@@ -125,12 +122,10 @@ func webCheckPortAvailable(port int) (ok bool) {
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
}
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
// if necessary.
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
log.Debug("web: applying new tls configuration")
web.conf.PortHTTPS = tlsConf.PortHTTPS
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
enabled := tlsConf.Enabled &&
tlsConf.PortHTTPS != 0 &&
@@ -161,8 +156,8 @@ func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings)
web.httpsServer.cond.L.Unlock()
}
// Start - start serving HTTP requests
func (web *Web) Start() {
// start - start serving HTTP requests
func (web *webAPI) start() {
log.Println("AdGuard Home is available at the following addresses:")
// for https, we have a separate goroutine loop
@@ -203,8 +198,8 @@ func (web *Web) Start() {
}
}
// Close gracefully shuts down the HTTP servers.
func (web *Web) Close(ctx context.Context) {
// close gracefully shuts down the HTTP servers.
func (web *webAPI) close(ctx context.Context) {
log.Info("stopping http server...")
web.httpsServer.cond.L.Lock()
@@ -222,7 +217,7 @@ func (web *Web) Close(ctx context.Context) {
log.Info("stopped http server")
}
func (web *Web) tlsServerLoop() {
func (web *webAPI) tlsServerLoop() {
for {
web.httpsServer.cond.L.Lock()
if web.httpsServer.inShutdown {
@@ -241,7 +236,15 @@ func (web *Web) tlsServerLoop() {
web.httpsServer.cond.L.Unlock()
addr := netutil.JoinHostPort(web.conf.BindHost.String(), web.conf.PortHTTPS)
var portHTTPS int
func() {
config.RLock()
defer config.RUnlock()
portHTTPS = config.TLS.PortHTTPS
}()
addr := netutil.JoinHostPort(web.conf.BindHost.String(), portHTTPS)
web.httpsServer.server = &http.Server{
ErrorLog: log.StdLog("web: https", log.DEBUG),
Addr: addr,
@@ -272,7 +275,7 @@ func (web *Web) tlsServerLoop() {
}
}
func (web *Web) mustStartHTTP3(address string) {
func (web *webAPI) mustStartHTTP3(address string) {
defer log.OnPanic("web: http3")
web.httpsServer.server3 = &http3.Server{

View File

@@ -7,6 +7,7 @@ type Client struct {
Name string `json:"name"`
DisallowedRule string `json:"disallowed_rule"`
Disallowed bool `json:"disallowed"`
IgnoreQueryLog bool `json:"-"`
}
// ClientWHOIS is the filtered WHOIS data for the client.

View File

@@ -166,86 +166,6 @@ var logEntryHandlers = map[string]logEntryHandler{
},
}
var resultHandlers = map[string]logEntryHandler{
"IsFiltered": func(t json.Token, ent *logEntry) error {
v, ok := t.(bool)
if !ok {
return nil
}
ent.Result.IsFiltered = v
return nil
},
"Rule": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
l := len(ent.Result.Rules)
if l == 0 {
ent.Result.Rules = []*filtering.ResultRule{{}}
l++
}
ent.Result.Rules[l-1].Text = s
return nil
},
"FilterID": func(t json.Token, ent *logEntry) error {
n, ok := t.(json.Number)
if !ok {
return nil
}
i, err := n.Int64()
if err != nil {
return err
}
l := len(ent.Result.Rules)
if l == 0 {
ent.Result.Rules = []*filtering.ResultRule{{}}
l++
}
ent.Result.Rules[l-1].FilterListID = i
return nil
},
"Reason": func(t json.Token, ent *logEntry) error {
v, ok := t.(json.Number)
if !ok {
return nil
}
i, err := v.Int64()
if err != nil {
return err
}
ent.Result.Reason = filtering.Reason(i)
return nil
},
"ServiceName": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
ent.Result.ServiceName = s
return nil
},
"CanonName": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
ent.Result.CanonName = s
return nil
},
}
func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
var vToken json.Token
switch key {
@@ -582,25 +502,11 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
return
}
switch key {
case "ReverseHosts":
decodeResultReverseHosts(dec, ent)
decHandler, ok := resultDecHandlers[key]
if ok {
decHandler(dec, ent)
continue
case "IPList":
decodeResultIPList(dec, ent)
continue
case "Rules":
decodeResultRules(dec, ent)
continue
case "DNSRewriteResult":
decodeResultDNSRewriteResult(dec, ent)
continue
default:
// Go on.
}
handler, ok := resultHandlers[key]
@@ -621,6 +527,93 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
}
}
var resultHandlers = map[string]logEntryHandler{
"IsFiltered": func(t json.Token, ent *logEntry) error {
v, ok := t.(bool)
if !ok {
return nil
}
ent.Result.IsFiltered = v
return nil
},
"Rule": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
l := len(ent.Result.Rules)
if l == 0 {
ent.Result.Rules = []*filtering.ResultRule{{}}
l++
}
ent.Result.Rules[l-1].Text = s
return nil
},
"FilterID": func(t json.Token, ent *logEntry) error {
n, ok := t.(json.Number)
if !ok {
return nil
}
i, err := n.Int64()
if err != nil {
return err
}
l := len(ent.Result.Rules)
if l == 0 {
ent.Result.Rules = []*filtering.ResultRule{{}}
l++
}
ent.Result.Rules[l-1].FilterListID = i
return nil
},
"Reason": func(t json.Token, ent *logEntry) error {
v, ok := t.(json.Number)
if !ok {
return nil
}
i, err := v.Int64()
if err != nil {
return err
}
ent.Result.Reason = filtering.Reason(i)
return nil
},
"ServiceName": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
ent.Result.ServiceName = s
return nil
},
"CanonName": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
ent.Result.CanonName = s
return nil
},
}
var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){
"ReverseHosts": decodeResultReverseHosts,
"IPList": decodeResultIPList,
"Rules": decodeResultRules,
"DNSRewriteResult": decodeResultDNSRewriteResult,
}
func decodeLogEntry(ent *logEntry, str string) {
dec := json.NewDecoder(strings.NewReader(str))
dec.UseNumber()

View File

@@ -0,0 +1,70 @@
package querylog
import (
"net"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// logEntry represents a single entry in the file.
type logEntry struct {
// client is the found client information, if any.
client *Client
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
ReqECS string `json:"ECS,omitempty"`
ClientID string `json:"CID,omitempty"`
ClientProto ClientProto `json:"CP"`
Upstream string `json:",omitempty"`
Answer []byte `json:",omitempty"`
OrigAnswer []byte `json:",omitempty"`
IP net.IP `json:"IP"`
Result filtering.Result
Elapsed time.Duration
Cached bool `json:",omitempty"`
AuthenticatedData bool `json:"AD,omitempty"`
}
// shallowClone returns a shallow clone of e.
func (e *logEntry) shallowClone() (clone *logEntry) {
cloneVal := *e
return &cloneVal
}
// addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is
// true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any
// errors are logged.
func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
if resp == nil {
return
}
var err error
if isOrig {
e.Answer, err = resp.Pack()
err = errors.Annotate(err, "packing answer: %w")
} else {
e.OrigAnswer, err = resp.Pack()
err = errors.Annotate(err, "packing orig answer: %w")
}
if err != nil {
log.Error("querylog: %s", err)
}
}

View File

@@ -13,9 +13,11 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/exp/slices"
"golang.org/x/net/idna"
)
@@ -25,8 +27,8 @@ type configJSON struct {
// fractional numbers and not mess the API users by changing the units.
Interval float64 `json:"interval"`
// Enabled shows if the querylog is enabled. It is an [aghalg.NullBool]
// to be able to tell when it's set without using pointers.
// Enabled shows if the querylog is enabled. It is an aghalg.NullBool to
// be able to tell when it's set without using pointers.
Enabled aghalg.NullBool `json:"enabled"`
// AnonymizeClientIP shows if the clients' IP addresses must be anonymized.
@@ -35,44 +37,115 @@ type configJSON struct {
AnonymizeClientIP aghalg.NullBool `json:"anonymize_client_ip"`
}
// getConfigResp is the JSON structure for the querylog configuration.
type getConfigResp struct {
// Ignored is the list of host names, which should not be written to log.
Ignored []string `json:"ignored"`
// Interval is the querylog rotation interval in milliseconds.
Interval float64 `json:"interval"`
// Enabled shows if the querylog is enabled. It is an aghalg.NullBool to
// be able to tell when it's set without using pointers.
Enabled aghalg.NullBool `json:"enabled"`
// AnonymizeClientIP shows if the clients' IP addresses must be anonymized.
// It is an aghalg.NullBool to be able to tell when it's set without using
// pointers.
//
// TODO(a.garipov): Consider using separate setting for statistics.
AnonymizeClientIP aghalg.NullBool `json:"anonymize_client_ip"`
}
// Register web handlers
func (l *queryLog) initWeb() {
l.conf.HTTPRegister(http.MethodGet, "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister(http.MethodGet, "/control/querylog/config", l.handleGetQueryLogConfig)
l.conf.HTTPRegister(
http.MethodPut,
"/control/querylog/config/update",
l.handlePutQueryLogConfig,
)
// Deprecated handlers.
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
}
// handleQueryLog is the handler for the GET /control/querylog HTTP API.
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
l.lock.Lock()
defer l.lock.Unlock()
params, err := l.parseSearchParams(r)
params, err := parseSearchParams(r)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse params: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err)
return
}
entries, oldest := l.search(params)
data := l.entriesToJSON(entries, oldest)
var entries []*logEntry
var oldest time.Time
func() {
l.confMu.RLock()
defer l.confMu.RUnlock()
_ = aghhttp.WriteJSONResponse(w, r, data)
entries, oldest = l.search(params)
}()
resp := entriesToJSON(entries, oldest, l.anonymizer.Load())
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
// API.
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
l.clear()
}
// Get configuration
// handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP
// API.
//
// Deprecated: Remove it when migration to the new API is over.
func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
l.confMu.RLock()
defer l.confMu.RUnlock()
ivl := l.conf.RotationIvl
if !checkInterval(ivl) {
// NOTE: If interval is custom we set it to 90 days for compatibility
// with old API.
ivl = timeutil.Day * 90
}
_ = aghhttp.WriteJSONResponse(w, r, configJSON{
Enabled: aghalg.BoolToNullBool(l.conf.Enabled),
Interval: l.conf.RotationIvl.Hours() / 24,
Interval: ivl.Hours() / 24,
AnonymizeClientIP: aghalg.BoolToNullBool(l.conf.AnonymizeClientIP),
})
}
// handleGetQueryLogConfig is the handler for the GET /control/querylog/config
// HTTP API.
func (l *queryLog) handleGetQueryLogConfig(w http.ResponseWriter, r *http.Request) {
var resp *getConfigResp
func() {
l.confMu.RLock()
defer l.confMu.RUnlock()
resp = &getConfigResp{
Interval: float64(l.conf.RotationIvl.Milliseconds()),
Enabled: aghalg.BoolToNullBool(l.conf.Enabled),
AnonymizeClientIP: aghalg.BoolToNullBool(l.conf.AnonymizeClientIP),
Ignored: l.conf.Ignored.Values(),
}
}()
slices.Sort(resp.Ignored)
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// AnonymizeIP masks ip to anonymize the client if the ip is a valid one.
func AnonymizeIP(ip net.IP) {
// zeroes is a slice of zero bytes from which the IP address tail is copied.
@@ -87,7 +160,10 @@ func AnonymizeIP(ip net.IP) {
}
}
// handleQueryLogConfig handles the POST /control/querylog_config queries.
// handleQueryLogConfig is the handler for the POST /control/querylog_config
// HTTP API.
//
// Deprecated: Remove it when migration to the new API is over.
func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) {
// Set NaN as initial value to be able to know if it changed later by
// comparing it to NaN.
@@ -103,6 +179,7 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
}
ivl := time.Duration(float64(timeutil.Day) * newConf.Interval)
hasIvl := !math.IsNaN(newConf.Interval)
if hasIvl && !checkInterval(ivl) {
aghhttp.Error(r, w, http.StatusBadRequest, "unsupported interval")
@@ -112,11 +189,9 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
defer l.conf.ConfigModified()
l.lock.Lock()
defer l.lock.Unlock()
l.confMu.Lock()
defer l.confMu.Unlock()
// Copy data, modify it, then activate. Other threads (readers) don't need
// to use this lock.
conf := *l.conf
if newConf.Enabled != aghalg.NBNull {
conf.Enabled = newConf.Enabled == aghalg.NBTrue
@@ -138,6 +213,65 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
l.conf = &conf
}
// handlePutQueryLogConfig is the handler for the PUT
// /control/querylog/config/update HTTP API.
func (l *queryLog) handlePutQueryLogConfig(w http.ResponseWriter, r *http.Request) {
newConf := &getConfigResp{}
err := json.NewDecoder(r.Body).Decode(newConf)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
set, err := aghnet.NewDomainNameSet(newConf.Ignored)
if err != nil {
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "ignored: %s", err)
return
}
ivl := time.Duration(newConf.Interval) * time.Millisecond
err = validateIvl(ivl)
if err != nil {
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "unsupported interval: %s", err)
return
}
if newConf.Enabled == aghalg.NBNull {
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "enabled is null")
return
}
if newConf.AnonymizeClientIP == aghalg.NBNull {
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "anonymize_client_ip is null")
return
}
defer l.conf.ConfigModified()
l.confMu.Lock()
defer l.confMu.Unlock()
conf := *l.conf
conf.Ignored = set
conf.RotationIvl = ivl
conf.Enabled = newConf.Enabled == aghalg.NBTrue
conf.AnonymizeClientIP = newConf.AnonymizeClientIP == aghalg.NBTrue
if conf.AnonymizeClientIP {
l.anonymizer.Store(AnonymizeIP)
} else {
l.anonymizer.Store(nil)
}
l.conf = &conf
}
// "value" -> value, return TRUE
func getDoubleQuotesEnclosedValue(s *string) bool {
t := *s
@@ -149,7 +283,7 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
}
// parseSearchCriterion parses a search criterion from the query parameter.
func (l *queryLog) parseSearchCriterion(q url.Values, name string, ct criterionType) (
func parseSearchCriterion(q url.Values, name string, ct criterionType) (
ok bool,
sc searchCriterion,
err error,
@@ -198,8 +332,9 @@ func (l *queryLog) parseSearchCriterion(q url.Values, name string, ct criterionT
return true, sc, nil
}
// parseSearchParams - parses "searchParams" from the HTTP request's query string
func (l *queryLog) parseSearchParams(r *http.Request) (p *searchParams, err error) {
// parseSearchParams parses search parameters from the HTTP request's query
// string.
func parseSearchParams(r *http.Request) (p *searchParams, err error) {
p = newSearchParams()
q := r.URL.Query()
@@ -237,7 +372,7 @@ func (l *queryLog) parseSearchParams(r *http.Request) (p *searchParams, err erro
}} {
var ok bool
var c searchCriterion
ok, c, err = l.parseSearchCriterion(q, v.urlField, v.ct)
ok, c, err = parseSearchCriterion(q, v.urlField, v.ct)
if err != nil {
return nil, err
}

View File

@@ -19,12 +19,16 @@ import (
type jobject = map[string]any
// entriesToJSON converts query log entries to JSON.
func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) (res jobject) {
func entriesToJSON(
entries []*logEntry,
oldest time.Time,
anonFunc aghnet.IPMutFunc,
) (res jobject) {
data := make([]jobject, 0, len(entries))
// The elements order is already reversed to be from newer to older.
for _, entry := range entries {
jsonEntry := l.entryToJSON(entry, l.anonymizer.Load())
jsonEntry := entryToJSON(entry, anonFunc)
data = append(data, jsonEntry)
}
@@ -40,7 +44,7 @@ func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) (res job
}
// entryToJSON converts a log entry's data into an entry for the JSON API.
func (l *queryLog) entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) {
func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) {
hostname := entry.QHost
question := jobject{
"type": entry.QType,
@@ -92,14 +96,14 @@ func (l *queryLog) entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (json
jsonEntry["service_name"] = entry.Result.ServiceName
}
l.setMsgData(entry, jsonEntry)
l.setOrigAns(entry, jsonEntry)
setMsgData(entry, jsonEntry)
setOrigAns(entry, jsonEntry)
return jsonEntry
}
// setMsgData sets the message data in jsonEntry.
func (l *queryLog) setMsgData(entry *logEntry, jsonEntry jobject) {
func setMsgData(entry *logEntry, jsonEntry jobject) {
if len(entry.Answer) == 0 {
return
}
@@ -122,7 +126,7 @@ func (l *queryLog) setMsgData(entry *logEntry, jsonEntry jobject) {
}
// setOrigAns sets the original answer data in jsonEntry.
func (l *queryLog) setOrigAns(entry *logEntry, jsonEntry jobject) {
func setOrigAns(entry *logEntry, jsonEntry jobject) {
if len(entry.OrigAnswer) == 0 {
return
}

View File

@@ -3,7 +3,6 @@ package querylog
import (
"fmt"
"net"
"os"
"strings"
"sync"
@@ -25,9 +24,12 @@ const (
type queryLog struct {
findClient func(ids []string) (c *Client, err error)
conf *Config
lock sync.Mutex
logFile string // path to the log file
// confMu protects conf.
confMu *sync.RWMutex
conf *Config
// logFile is the path to the log file.
logFile string
// bufferLock protects buffer.
bufferLock sync.RWMutex
@@ -71,52 +73,24 @@ func NewClientProto(s string) (cp ClientProto, err error) {
}
}
// logEntry - represents a single log entry
type logEntry struct {
// client is the found client information, if any.
client *Client
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
ReqECS string `json:"ECS,omitempty"`
ClientID string `json:"CID,omitempty"`
ClientProto ClientProto `json:"CP"`
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
OrigAnswer []byte `json:",omitempty"`
Result filtering.Result
Upstream string `json:",omitempty"`
IP net.IP `json:"IP"`
Elapsed time.Duration
Cached bool `json:",omitempty"`
AuthenticatedData bool `json:"AD,omitempty"`
}
// shallowClone returns a shallow clone of e.
func (e *logEntry) shallowClone() (clone *logEntry) {
cloneVal := *e
return &cloneVal
}
func (l *queryLog) Start() {
if l.conf.HTTPRegister != nil {
l.initWeb()
}
go l.periodicRotate()
}
func (l *queryLog) Close() {
_ = l.flushLogBuffer(true)
l.confMu.RLock()
defer l.confMu.RUnlock()
if l.conf.FileEnabled {
err := l.flushLogBuffer()
if err != nil {
log.Error("querylog: closing: %s", err)
}
}
}
func checkInterval(ivl time.Duration) (ok bool) {
@@ -132,8 +106,26 @@ func checkInterval(ivl time.Duration) (ok bool) {
return ivl == quarterDay || ivl == day || ivl == week || ivl == month || ivl == threeMonths
}
// validateIvl returns an error if ivl is less than an hour or more than a
// year.
func validateIvl(ivl time.Duration) (err error) {
if ivl < time.Hour {
return errors.Error("less than an hour")
}
if ivl > timeutil.Day*365 {
return errors.Error("more than a year")
}
return nil
}
func (l *queryLog) WriteDiskConfig(c *Config) {
l.confMu.RLock()
defer l.confMu.RUnlock()
*c = *l.conf
c.Ignored = l.conf.Ignored.Clone()
}
// Clear memory buffer and remove log files
@@ -141,10 +133,13 @@ func (l *queryLog) clear() {
l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock()
l.bufferLock.Lock()
l.buffer = nil
l.flushPending = false
l.bufferLock.Unlock()
func() {
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
l.buffer = nil
l.flushPending = false
}()
oldLogFile := l.logFile + ".1"
err := os.Remove(oldLogFile)
@@ -161,7 +156,17 @@ func (l *queryLog) clear() {
}
func (l *queryLog) Add(params *AddParams) {
if !l.conf.Enabled {
var isEnabled, fileIsEnabled bool
var memSize uint32
func() {
l.confMu.RLock()
defer l.confMu.RUnlock()
isEnabled, fileIsEnabled = l.conf.Enabled, l.conf.FileEnabled
memSize = l.conf.MemSize
}()
if !isEnabled {
return
}
@@ -178,7 +183,7 @@ func (l *queryLog) Add(params *AddParams) {
now := time.Now()
q := params.Question.Question[0]
entry := logEntry{
entry := &logEntry{
Time: now,
QHost: strings.ToLower(q.Name[:len(q.Name)-1]),
@@ -203,65 +208,63 @@ func (l *queryLog) Add(params *AddParams) {
entry.ReqECS = params.ReqECS.String()
}
if params.Answer != nil {
var a []byte
a, err = params.Answer.Pack()
if err != nil {
log.Error("querylog: Answer.Pack(): %s", err)
entry.addResponse(params.Answer, false)
entry.addResponse(params.OrigAnswer, true)
return
}
entry.Answer = a
}
if params.OrigAnswer != nil {
var a []byte
a, err = params.OrigAnswer.Pack()
if err != nil {
log.Error("querylog: OrigAnswer.Pack(): %s", err)
return
}
entry.OrigAnswer = a
}
l.bufferLock.Lock()
l.buffer = append(l.buffer, &entry)
needFlush := false
func() {
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
if !l.conf.FileEnabled {
if len(l.buffer) > int(l.conf.MemSize) {
// writing to file is disabled - just remove the oldest entry from array
//
// TODO(a.garipov): This should be replaced by a proper ring buffer,
// but it's currently difficult to do that.
l.buffer[0] = nil
l.buffer = l.buffer[1:]
}
} else if !l.flushPending {
needFlush = len(l.buffer) >= int(l.conf.MemSize)
if needFlush {
l.flushPending = true
}
}
l.bufferLock.Unlock()
l.buffer = append(l.buffer, entry)
if !fileIsEnabled {
if len(l.buffer) > int(memSize) {
// Writing to file is disabled, so just remove the oldest entry
// from the slices.
//
// TODO(a.garipov): This should be replaced by a proper ring
// buffer, but it's currently difficult to do that.
l.buffer[0] = nil
l.buffer = l.buffer[1:]
}
} else if !l.flushPending {
needFlush = len(l.buffer) >= int(memSize)
if needFlush {
l.flushPending = true
}
}
}()
// if buffer needs to be flushed to disk, do it now
if needFlush {
go func() {
_ = l.flushLogBuffer(false)
flushErr := l.flushLogBuffer()
if flushErr != nil {
log.Error("querylog: flushing after adding: %s", err)
}
}()
}
}
// ShouldLog returns true if request for the host should be logged.
func (l *queryLog) ShouldLog(host string, _, _ uint16) bool {
func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool {
l.confMu.RLock()
defer l.confMu.RUnlock()
c, err := l.findClient(ids)
if err != nil {
log.Error("querylog: finding client: %s", err)
}
if c != nil && c.IgnoreQueryLog {
return false
}
return !l.isIgnored(host)
}
// isIgnored returns true if the host is in the Ignored list.
// isIgnored returns true if the host is in the ignored domains list. It
// assumes that l.confMu is locked for reading.
func (l *queryLog) isIgnored(host string) bool {
return l.conf.Ignored.Has(host)
}

View File

@@ -22,24 +22,25 @@ func TestMain(m *testing.M) {
// TestQueryLog tests adding and loading (with filtering) entries from disk and
// memory.
func TestQueryLog(t *testing.T) {
l := newQueryLog(Config{
l, err := newQueryLog(Config{
Enabled: true,
FileEnabled: true,
RotationIvl: timeutil.Day,
MemSize: 100,
BaseDir: t.TempDir(),
})
require.NoError(t, err)
// Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// Write to disk (first file).
require.NoError(t, l.flushLogBuffer(true))
require.NoError(t, l.flushLogBuffer())
// Start writing to the second file.
require.NoError(t, l.rotate())
// Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// Write to disk.
require.NoError(t, l.flushLogBuffer(true))
require.NoError(t, l.flushLogBuffer())
// Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
@@ -125,12 +126,13 @@ func TestQueryLog(t *testing.T) {
}
func TestQueryLogOffsetLimit(t *testing.T) {
l := newQueryLog(Config{
l, err := newQueryLog(Config{
Enabled: true,
RotationIvl: timeutil.Day,
MemSize: 100,
BaseDir: t.TempDir(),
})
require.NoError(t, err)
const (
entNum = 10
@@ -142,7 +144,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// Write them to the first file.
require.NoError(t, l.flushLogBuffer(true))
require.NoError(t, l.flushLogBuffer())
// Add more to the in-memory part of log.
for i := 0; i < entNum; i++ {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
@@ -199,13 +201,14 @@ func TestQueryLogOffsetLimit(t *testing.T) {
}
func TestQueryLogMaxFileScanEntries(t *testing.T) {
l := newQueryLog(Config{
l, err := newQueryLog(Config{
Enabled: true,
FileEnabled: true,
RotationIvl: timeutil.Day,
MemSize: 100,
BaseDir: t.TempDir(),
})
require.NoError(t, err)
const entNum = 10
// Add entries to the log.
@@ -213,7 +216,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// Write them to disk.
require.NoError(t, l.flushLogBuffer(true))
require.NoError(t, l.flushLogBuffer())
params := newSearchParams()
@@ -227,13 +230,14 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
}
func TestQueryLogFileDisabled(t *testing.T) {
l := newQueryLog(Config{
l, err := newQueryLog(Config{
Enabled: true,
FileEnabled: false,
RotationIvl: timeutil.Day,
MemSize: 2,
BaseDir: t.TempDir(),
})
require.NoError(t, err)
addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
@@ -254,35 +258,52 @@ func TestQueryLogShouldLog(t *testing.T) {
)
set := stringutil.NewSet(ignored1, ignored2)
l := newQueryLog(Config{
findClient := func(ids []string) (c *Client, err error) {
log := ids[0] == "no_log"
return &Client{IgnoreQueryLog: log}, nil
}
l, err := newQueryLog(Config{
Ignored: set,
Enabled: true,
RotationIvl: timeutil.Day,
MemSize: 100,
BaseDir: t.TempDir(),
Ignored: set,
FindClient: findClient,
})
require.NoError(t, err)
testCases := []struct {
name string
host string
ids []string
wantLog bool
}{{
name: "log",
host: "example.com",
ids: []string{"whatever"},
wantLog: true,
}, {
name: "no_log_ignored_1",
host: ignored1,
ids: []string{"whatever"},
wantLog: false,
}, {
name: "no_log_ignored_2",
host: ignored2,
ids: []string{"whatever"},
wantLog: false,
}, {
name: "no_log_client_ignore",
host: "example.com",
ids: []string{"no_log"},
wantLog: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := l.ShouldLog(tc.host, dns.TypeA, dns.ClassINET)
res := l.ShouldLog(tc.host, dns.TypeA, dns.ClassINET, tc.ids)
assert.Equal(t, tc.wantLog, res)
})

View File

@@ -106,6 +106,7 @@ func (r *QLogReader) SeekStart() error {
r.currentFile = len(r.qFiles) - 1
_, err := r.qFiles[r.currentFile].SeekStart()
return err
}

View File

@@ -1,17 +1,17 @@
package querylog
import (
"fmt"
"net"
"path/filepath"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
)
@@ -29,16 +29,21 @@ type QueryLog interface {
WriteDiskConfig(c *Config)
// ShouldLog returns true if request for the host should be logged.
ShouldLog(host string, qType, qClass uint16) bool
ShouldLog(host string, qType, qClass uint16, ids []string) bool
}
// Config is the query log configuration structure.
//
// Do not alter any fields of this structure after using it.
type Config struct {
// Ignored is the list of host names, which should not be written to log.
Ignored *stringutil.Set
// Anonymizer processes the IP addresses to anonymize those if needed.
Anonymizer *aghnet.IPMut
// ConfigModified is called when the configuration is changed, for
// example by HTTP requests.
// ConfigModified is called when the configuration is changed, for example
// by HTTP requests.
ConfigModified func()
// HTTPRegister registers an HTTP handler.
@@ -50,20 +55,13 @@ type Config struct {
// BaseDir is the base directory for log files.
BaseDir string
// RotationIvl is the interval for log rotation. After that period, the
// old log file will be renamed, NOT deleted, so the actual log
// retention time is twice the interval. The value must be one of:
//
// 6 * time.Hour
// 1 * timeutil.Day
// 7 * timeutil.Day
// 30 * timeutil.Day
// 90 * timeutil.Day
//
// RotationIvl is the interval for log rotation. After that period, the old
// log file will be renamed, NOT deleted, so the actual log retention time
// is twice the interval.
RotationIvl time.Duration
// MemSize is the number of entries kept in a memory buffer before they
// are flushed to disk.
// MemSize is the number of entries kept in a memory buffer before they are
// flushed to disk.
MemSize uint32
// Enabled tells if the query log is enabled.
@@ -75,10 +73,6 @@ type Config struct {
// AnonymizeClientIP tells if the query log should anonymize clients' IP
// addresses.
AnonymizeClientIP bool
// Ignored is the list of host names, which should not be written to
// log.
Ignored *stringutil.Set
}
// AddParams is the parameters for adding an entry.
@@ -135,12 +129,12 @@ func (p *AddParams) validate() (err error) {
}
// New creates a new instance of the query log.
func New(conf Config) (ql QueryLog) {
func New(conf Config) (ql QueryLog, err error) {
return newQueryLog(conf)
}
// newQueryLog crates a new queryLog.
func newQueryLog(conf Config) (l *queryLog) {
func newQueryLog(conf Config) (l *queryLog, err error) {
findClient := conf.FindClient
if findClient == nil {
findClient = func(_ []string) (_ *Client, _ error) {
@@ -151,20 +145,19 @@ func newQueryLog(conf Config) (l *queryLog) {
l = &queryLog{
findClient: findClient,
logFile: filepath.Join(conf.BaseDir, queryLogFileName),
conf: &Config{},
confMu: &sync.RWMutex{},
logFile: filepath.Join(conf.BaseDir, queryLogFileName),
anonymizer: conf.Anonymizer,
}
l.conf = &Config{}
*l.conf = conf
if !checkInterval(conf.RotationIvl) {
log.Info(
"querylog: warning: unsupported rotation interval %s, setting to 1 day",
conf.RotationIvl,
)
l.conf.RotationIvl = timeutil.Day
err = validateIvl(conf.RotationIvl)
if err != nil {
return nil, fmt.Errorf("unsupported interval: %w", err)
}
return l
return l, nil
}

View File

@@ -11,40 +11,35 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// flushLogBuffer flushes the current buffer to file and resets the current buffer
func (l *queryLog) flushLogBuffer(fullFlush bool) error {
if !l.conf.FileEnabled {
return nil
}
// flushLogBuffer flushes the current buffer to file and resets the current
// buffer.
func (l *queryLog) flushLogBuffer() (err error) {
l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock()
// flush remainder to file
l.bufferLock.Lock()
needFlush := len(l.buffer) >= int(l.conf.MemSize)
if !needFlush && !fullFlush {
l.bufferLock.Unlock()
return nil
}
flushBuffer := l.buffer
l.buffer = nil
l.flushPending = false
l.bufferLock.Unlock()
err := l.flushToFile(flushBuffer)
if err != nil {
log.Error("Saving querylog to file failed: %s", err)
return err
}
return nil
var flushBuffer []*logEntry
func() {
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
flushBuffer = l.buffer
l.buffer = nil
l.flushPending = false
}()
err = l.flushToFile(flushBuffer)
return errors.Annotate(err, "writing to file: %w")
}
// flushToFile saves the specified log entries to the query log file
func (l *queryLog) flushToFile(buffer []*logEntry) (err error) {
if len(buffer) == 0 {
log.Debug("querylog: there's nothing to write to a file")
log.Debug("querylog: nothing to write to a file")
return nil
}
start := time.Now()
var b bytes.Buffer
@@ -155,8 +150,13 @@ func (l *queryLog) periodicRotate() {
// checkAndRotate rotates log files if those are older than the specified
// rotation interval.
func (l *queryLog) checkAndRotate() {
l.lock.Lock()
defer l.lock.Unlock()
var rotationIvl time.Duration
func() {
l.confMu.RLock()
defer l.confMu.RUnlock()
rotationIvl = l.conf.RotationIvl
}()
oldest, err := l.readFileFirstTimeValue()
if err != nil && !errors.Is(err, os.ErrNotExist) {
@@ -165,11 +165,11 @@ func (l *queryLog) checkAndRotate() {
return
}
if rot, now := oldest.Add(l.conf.RotationIvl), time.Now(); rot.After(now) {
if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) {
log.Debug(
"querylog: %s <= %s, not rotating",
now.Format(time.RFC3339),
rot.Format(time.RFC3339),
rotTime.Format(time.RFC3339),
)
return

View File

@@ -76,15 +76,20 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
// search - searches log entries in the query log using specified parameters
// returns the list of entries found + time of the oldest entry
func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) {
now := time.Now()
start := time.Now()
if params.limit == 0 {
return []*logEntry{}, time.Time{}
}
cache := clientCache{}
fileEntries, oldest, total := l.searchFiles(params, cache)
memoryEntries, bufLen := l.searchMemory(params, cache)
log.Debug("querylog: got %d entries from memory", len(memoryEntries))
fileEntries, oldest, total := l.searchFiles(params, cache)
log.Debug("querylog: got %d entries from files", len(fileEntries))
total += bufLen
totalLimit := params.offset + params.limit
@@ -123,7 +128,7 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
len(entries),
total,
params.olderThan,
time.Since(now),
time.Since(start),
)
return entries, oldest
@@ -145,13 +150,14 @@ func (l *queryLog) searchFiles(
r, err := NewQLogReader(files)
if err != nil {
log.Error("querylog: failed to open qlog reader: %s", err)
log.Error("querylog: opening qlog reader: %s", err)
return entries, oldest, 0
}
defer func() {
derr := r.Close()
if derr != nil {
closeErr := r.Close()
if closeErr != nil {
log.Error("querylog: closing file: %s", err)
}
}()
@@ -161,8 +167,8 @@ func (l *queryLog) searchFiles(
} else {
err = r.seekTS(params.olderThan.UnixNano())
if err == nil {
// Read to the next record, because we only need the one
// that goes after it.
// Read to the next record, because we only need the one that goes
// after it.
_, err = r.ReadNext()
}
}
@@ -176,9 +182,9 @@ func (l *queryLog) searchFiles(
totalLimit := params.offset + params.limit
oldestNano := int64(0)
// By default, we do not scan more than maxFileScanEntries at once.
// The idea is to make search calls faster so that the UI could handle
// it and show something quicker. This behavior can be overridden if
// By default, we do not scan more than maxFileScanEntries at once. The
// idea is to make search calls faster so that the UI could handle it and
// show something quicker. This behavior can be overridden if
// maxFileScanEntries is set to 0.
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
var e *logEntry

View File

@@ -35,7 +35,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
return nil, nil
}
l := newQueryLog(Config{
l, err := newQueryLog(Config{
FindClient: findClient,
BaseDir: t.TempDir(),
RotationIvl: timeutil.Day,
@@ -44,6 +44,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
FileEnabled: true,
AnonymizeClientIP: false,
})
require.NoError(t, err)
t.Cleanup(l.Close)
q := &dns.Msg{

View File

@@ -7,8 +7,12 @@ import (
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/exp/slices"
)
// topAddrs is an alias for the types of the TopFoo fields of statsResponse.
@@ -38,13 +42,21 @@ type StatsResp struct {
AvgProcessingTime float64 `json:"avg_processing_time"`
}
// handleStats handles requests to the GET /control/stats endpoint.
// handleStats is the handler for the GET /control/stats HTTP API.
func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
s.lock.Lock()
defer s.lock.Unlock()
start := time.Now()
resp, ok := s.getData(s.limitHours)
var (
resp StatsResp
ok bool
)
func() {
s.confMu.RLock()
defer s.confMu.RUnlock()
resp, ok = s.getData(uint32(s.limit.Hours()))
}()
log.Debug("stats: prepared data in %v", time.Since(start))
if !ok {
@@ -63,20 +75,73 @@ type configResp struct {
IntervalDays uint32 `json:"interval"`
}
// handleStatsInfo handles requests to the GET /control/stats_info endpoint.
func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
s.lock.Lock()
defer s.lock.Unlock()
// getConfigResp is the response to the GET /control/stats_info.
type getConfigResp struct {
// Ignored is the list of host names, which should not be counted.
Ignored []string `json:"ignored"`
resp := configResp{IntervalDays: s.limitHours / 24}
if !s.enabled {
// Interval is the statistics rotation interval in milliseconds.
Interval float64 `json:"interval"`
// Enabled shows if statistics are enabled. It is an aghalg.NullBool to be
// able to tell when it's set without using pointers.
Enabled aghalg.NullBool `json:"enabled"`
}
// handleStatsInfo is the handler for the GET /control/stats_info HTTP API.
//
// Deprecated: Remove it when migration to the new API is over.
func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
var (
enabled bool
limit time.Duration
)
func() {
s.confMu.RLock()
defer s.confMu.RUnlock()
enabled, limit = s.enabled, s.limit
}()
days := uint32(limit / timeutil.Day)
ok := checkInterval(days)
if !ok || (enabled && days == 0) {
// NOTE: If interval is custom we set it to 90 days for compatibility
// with old API.
days = 90
}
resp := configResp{IntervalDays: days}
if !enabled {
resp.IntervalDays = 0
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// handleStatsConfig handles requests to the POST /control/stats_config
// endpoint.
// handleGetStatsConfig is the handler for the GET /control/stats/config HTTP
// API.
func (s *StatsCtx) handleGetStatsConfig(w http.ResponseWriter, r *http.Request) {
var resp *getConfigResp
func() {
s.confMu.RLock()
defer s.confMu.RUnlock()
resp = &getConfigResp{
Ignored: s.ignored.Values(),
Interval: float64(s.limit.Milliseconds()),
Enabled: aghalg.BoolToNullBool(s.enabled),
}
}()
slices.Sort(resp.Ignored)
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// handleStatsConfig is the handler for the POST /control/stats_config HTTP API.
//
// Deprecated: Remove it when migration to the new API is over.
func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
reqData := configResp{}
err := json.NewDecoder(r.Body).Decode(&reqData)
@@ -92,11 +157,59 @@ func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
return
}
s.setLimit(int(reqData.IntervalDays))
s.configModified()
limit := time.Duration(reqData.IntervalDays) * timeutil.Day
defer s.configModified()
s.confMu.Lock()
defer s.confMu.Unlock()
s.setLimit(limit)
}
// handleStatsReset handles requests to the POST /control/stats_reset endpoint.
// handlePutStatsConfig is the handler for the PUT /control/stats/config/update
// HTTP API.
func (s *StatsCtx) handlePutStatsConfig(w http.ResponseWriter, r *http.Request) {
reqData := getConfigResp{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
set, err := aghnet.NewDomainNameSet(reqData.Ignored)
if err != nil {
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "ignored: %s", err)
return
}
ivl := time.Duration(reqData.Interval) * time.Millisecond
err = validateIvl(ivl)
if err != nil {
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "unsupported interval: %s", err)
return
}
if reqData.Enabled == aghalg.NBNull {
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "enabled is null")
return
}
defer s.configModified()
s.confMu.Lock()
defer s.confMu.Unlock()
s.ignored = set
s.limit = ivl
s.enabled = reqData.Enabled == aghalg.NBTrue
}
// handleStatsReset is the handler for the POST /control/stats_reset HTTP API.
func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
err := s.clear()
if err != nil {
@@ -112,6 +225,10 @@ func (s *StatsCtx) initWeb() {
s.httpRegister(http.MethodGet, "/control/stats", s.handleStats)
s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
s.httpRegister(http.MethodGet, "/control/stats/config", s.handleGetStatsConfig)
s.httpRegister(http.MethodPut, "/control/stats/config/update", s.handlePutStatsConfig)
// Deprecated handlers.
s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
}

153
internal/stats/http_test.go Normal file
View File

@@ -0,0 +1,153 @@
package stats
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHandleStatsConfig(t *testing.T) {
const (
smallIvl = 1 * time.Minute
minIvl = 1 * time.Hour
maxIvl = 365 * timeutil.Day
)
conf := Config{
UnitID: func() (id uint32) { return 0 },
ConfigModified: func() {},
ShouldCountClient: func([]string) bool { return true },
Filename: filepath.Join(t.TempDir(), "stats.db"),
Limit: time.Hour * 24,
Enabled: true,
}
testCases := []struct {
name string
wantErr string
body getConfigResp
wantCode int
}{{
name: "set_ivl_1_minIvl",
body: getConfigResp{
Enabled: aghalg.NBTrue,
Interval: float64(minIvl.Milliseconds()),
Ignored: []string{},
},
wantCode: http.StatusOK,
wantErr: "",
}, {
name: "small_interval",
body: getConfigResp{
Enabled: aghalg.NBTrue,
Interval: float64(smallIvl.Milliseconds()),
Ignored: []string{},
},
wantCode: http.StatusUnprocessableEntity,
wantErr: "unsupported interval: less than an hour\n",
}, {
name: "big_interval",
body: getConfigResp{
Enabled: aghalg.NBTrue,
Interval: float64(maxIvl.Milliseconds() + minIvl.Milliseconds()),
Ignored: []string{},
},
wantCode: http.StatusUnprocessableEntity,
wantErr: "unsupported interval: more than a year\n",
}, {
name: "set_ignored_ivl_1_maxIvl",
body: getConfigResp{
Enabled: aghalg.NBTrue,
Interval: float64(maxIvl.Milliseconds()),
Ignored: []string{
"ignor.ed",
},
},
wantCode: http.StatusOK,
wantErr: "",
}, {
name: "ignored_duplicate",
body: getConfigResp{
Enabled: aghalg.NBTrue,
Interval: float64(minIvl.Milliseconds()),
Ignored: []string{
"ignor.ed",
"ignor.ed",
},
},
wantCode: http.StatusUnprocessableEntity,
wantErr: "ignored: duplicate host name \"ignor.ed\" at index 1\n",
}, {
name: "ignored_empty",
body: getConfigResp{
Enabled: aghalg.NBTrue,
Interval: float64(minIvl.Milliseconds()),
Ignored: []string{
"",
},
},
wantCode: http.StatusUnprocessableEntity,
wantErr: "ignored: host name is empty\n",
}, {
name: "enabled_is_null",
body: getConfigResp{
Enabled: aghalg.NBNull,
Interval: float64(minIvl.Milliseconds()),
Ignored: []string{},
},
wantCode: http.StatusUnprocessableEntity,
wantErr: "enabled is null\n",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s, err := New(conf)
require.NoError(t, err)
s.Start()
testutil.CleanupAndRequireSuccess(t, s.Close)
buf, err := json.Marshal(tc.body)
require.NoError(t, err)
const (
configGet = "/control/stats/config"
configPut = "/control/stats/config/update"
)
req := httptest.NewRequest(http.MethodPut, configPut, bytes.NewReader(buf))
rw := httptest.NewRecorder()
s.handlePutStatsConfig(rw, req)
require.Equal(t, tc.wantCode, rw.Code)
if tc.wantCode != http.StatusOK {
assert.Equal(t, tc.wantErr, rw.Body.String())
return
}
resp := httptest.NewRequest(http.MethodGet, configGet, nil)
rw = httptest.NewRecorder()
s.handleGetStatsConfig(rw, resp)
require.Equal(t, http.StatusOK, rw.Code)
ans := getConfigResp{}
err = json.Unmarshal(rw.Body.Bytes(), &ans)
require.NoError(t, err)
assert.Equal(t, tc.body, ans)
})
}
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil"
"go.etcd.io/bbolt"
)
@@ -25,7 +26,23 @@ func checkInterval(days uint32) (ok bool) {
return days == 0 || days == 1 || days == 7 || days == 30 || days == 90
}
// validateIvl returns an error if ivl is less than an hour or more than a
// year.
func validateIvl(ivl time.Duration) (err error) {
if ivl < time.Hour {
return errors.Error("less than an hour")
}
if ivl > timeutil.Day*365 {
return errors.Error("more than a year")
}
return nil
}
// Config is the configuration structure for the statistics collecting.
//
// Do not alter any fields of this structure after using it.
type Config struct {
// UnitID is the function to generate the identifier for current unit. If
// nil, the default function is used, see newUnitID.
@@ -35,22 +52,24 @@ type Config struct {
// interface.
ConfigModified func()
// ShouldCountClient returns client's ignore setting.
ShouldCountClient func([]string) bool
// HTTPRegister is the function that registers handlers for the stats
// endpoints.
HTTPRegister aghhttp.RegisterFunc
// Ignored is the list of host names, which should not be counted.
Ignored *stringutil.Set
// Filename is the name of the database file.
Filename string
// LimitDays is the maximum number of days to collect statistics into the
// current unit.
LimitDays uint32
// Limit is an upper limit for collecting statistics.
Limit time.Duration
// Enabled tells if the statistics are enabled.
Enabled bool
// Ignored is the list of host names, which should not be counted.
Ignored *stringutil.Set
}
// Interface is the statistics interface to be used by other packages.
@@ -71,7 +90,7 @@ type Interface interface {
WriteDiskConfig(dc *Config)
// ShouldCount returns true if request for the host should be counted.
ShouldCount(host string, qType, qClass uint16) bool
ShouldCount(host string, qType, qClass uint16, ids []string) bool
}
// StatsCtx collects the statistics and flushes it to the database. Its default
@@ -96,23 +115,23 @@ type StatsCtx struct {
// interface.
configModified func()
// filename is the name of database file.
filename string
// lock protects all the fields below.
lock sync.Mutex
// enabled tells if the statistics are enabled.
enabled bool
// limitHours is the maximum number of hours to collect statistics into the
// current unit.
//
// TODO(s.chzhen): Rewrite to use time.Duration.
limitHours uint32
// confMu protects ignored, limit, and enabled.
confMu *sync.RWMutex
// ignored is the list of host names, which should not be counted.
ignored *stringutil.Set
// shouldCountClient returns client's ignore setting.
shouldCountClient func([]string) bool
// filename is the name of database file.
filename string
// limit is an upper limit for collecting statistics.
limit time.Duration
// enabled tells if the statistics are enabled.
enabled bool
}
// New creates s from conf and properly initializes it. Don't use s before
@@ -120,17 +139,28 @@ type StatsCtx struct {
func New(conf Config) (s *StatsCtx, err error) {
defer withRecovered(&err)
err = validateIvl(conf.Limit)
if err != nil {
return nil, fmt.Errorf("unsupported interval: %w", err)
}
if conf.ShouldCountClient == nil {
return nil, errors.Error("should count client is unspecified")
}
s = &StatsCtx{
enabled: conf.Enabled,
currMu: &sync.RWMutex{},
filename: conf.Filename,
configModified: conf.ConfigModified,
httpRegister: conf.HTTPRegister,
ignored: conf.Ignored,
}
if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) {
s.limitHours = 24
configModified: conf.ConfigModified,
filename: conf.Filename,
confMu: &sync.RWMutex{},
ignored: conf.Ignored,
shouldCountClient: conf.ShouldCountClient,
limit: conf.Limit,
enabled: conf.Enabled,
}
if s.unitIDGen = newUnitID; conf.UnitID != nil {
s.unitIDGen = conf.UnitID
}
@@ -150,7 +180,7 @@ func New(conf Config) (s *StatsCtx, err error) {
return nil, fmt.Errorf("stats: opening a transaction: %w", err)
}
deleted := deleteOldUnits(tx, id-s.limitHours-1)
deleted := deleteOldUnits(tx, id-uint32(s.limit.Hours())-1)
udb = loadUnitFromDB(tx, id)
err = finishTxn(tx, deleted > 0)
@@ -228,10 +258,10 @@ func (s *StatsCtx) Close() (err error) {
// Update implements the Interface interface for *StatsCtx.
func (s *StatsCtx) Update(e Entry) {
s.lock.Lock()
defer s.lock.Unlock()
s.confMu.Lock()
defer s.confMu.Unlock()
if !s.enabled || s.limitHours == 0 {
if !s.enabled || s.limit == 0 {
return
}
@@ -260,20 +290,20 @@ func (s *StatsCtx) Update(e Entry) {
// WriteDiskConfig implements the Interface interface for *StatsCtx.
func (s *StatsCtx) WriteDiskConfig(dc *Config) {
s.lock.Lock()
defer s.lock.Unlock()
s.confMu.RLock()
defer s.confMu.RUnlock()
dc.LimitDays = s.limitHours / 24
dc.Ignored = s.ignored.Clone()
dc.Limit = s.limit
dc.Enabled = s.enabled
dc.Ignored = s.ignored
}
// TopClientsIP implements the [Interface] interface for *StatsCtx.
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []netip.Addr) {
s.lock.Lock()
defer s.lock.Unlock()
s.confMu.RLock()
defer s.confMu.RUnlock()
limit := s.limitHours
limit := uint32(s.limit.Hours())
if !s.enabled || limit == 0 {
return nil
}
@@ -366,8 +396,8 @@ func (s *StatsCtx) openDB() (err error) {
func (s *StatsCtx) flush() (cont bool, sleepFor time.Duration) {
id := s.unitIDGen()
s.lock.Lock()
defer s.lock.Unlock()
s.confMu.Lock()
defer s.confMu.Unlock()
s.currMu.Lock()
defer s.currMu.Unlock()
@@ -377,7 +407,7 @@ func (s *StatsCtx) flush() (cont bool, sleepFor time.Duration) {
return false, 0
}
limit := s.limitHours
limit := uint32(s.limit.Hours())
if limit == 0 || ptr.id == id {
return true, time.Second
}
@@ -436,14 +466,14 @@ func (s *StatsCtx) periodicFlush() {
log.Debug("periodic flushing finished")
}
func (s *StatsCtx) setLimit(limitDays int) {
s.lock.Lock()
defer s.lock.Unlock()
if limitDays != 0 {
// setLimit sets the limit. s.lock is expected to be locked.
//
// TODO(s.chzhen): Remove it when migration to the new API is over.
func (s *StatsCtx) setLimit(limit time.Duration) {
if limit != 0 {
s.enabled = true
s.limitHours = uint32(24 * limitDays)
log.Debug("stats: set limit: %d days", limitDays)
s.limit = limit
log.Debug("stats: set limit: %d days", limit/timeutil.Day)
return
}
@@ -558,11 +588,19 @@ func (s *StatsCtx) loadUnits(limit uint32) (units []*unitDB, firstID uint32) {
}
// ShouldCount returns true if request for the host should be counted.
func (s *StatsCtx) ShouldCount(host string, _, _ uint16) bool {
func (s *StatsCtx) ShouldCount(host string, _, _ uint16, ids []string) bool {
s.confMu.RLock()
defer s.confMu.RUnlock()
if !s.shouldCountClient(ids) {
return false
}
return !s.isIgnored(host)
}
// isIgnored returns true if the host is in the Ignored list.
// isIgnored returns true if the host is in the ignored domains list. It
// assumes that s.confMu is locked for reading.
func (s *StatsCtx) isIgnored(host string) bool {
return s.ignored.Has(host)
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -35,9 +36,10 @@ func TestStats_races(t *testing.T) {
var r uint32
idGen := func() (id uint32) { return atomic.LoadUint32(&r) }
conf := Config{
UnitID: idGen,
Filename: filepath.Join(t.TempDir(), "./stats.db"),
LimitDays: 1,
ShouldCountClient: func([]string) bool { return true },
UnitID: idGen,
Filename: filepath.Join(t.TempDir(), "./stats.db"),
Limit: timeutil.Day,
}
s, err := New(conf)

View File

@@ -12,7 +12,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -51,10 +54,11 @@ func TestStats(t *testing.T) {
handlers := map[string]http.Handler{}
conf := stats.Config{
Filename: filepath.Join(t.TempDir(), "stats.db"),
LimitDays: 1,
Enabled: true,
UnitID: constUnitID,
ShouldCountClient: func([]string) bool { return true },
Filename: filepath.Join(t.TempDir(), "stats.db"),
Limit: timeutil.Day,
Enabled: true,
UnitID: constUnitID,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
@@ -157,11 +161,12 @@ func TestLargeNumbers(t *testing.T) {
handlers := map[string]http.Handler{}
conf := stats.Config{
Filename: filepath.Join(t.TempDir(), "stats.db"),
LimitDays: 1,
Enabled: true,
UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) },
HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler },
ShouldCountClient: func([]string) bool { return true },
Filename: filepath.Join(t.TempDir(), "stats.db"),
Limit: timeutil.Day,
Enabled: true,
UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) },
HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler },
}
s, err := stats.New(conf)
@@ -196,3 +201,60 @@ func TestLargeNumbers(t *testing.T) {
assertSuccessAndUnmarshal(t, data, handlers["/control/stats"], req)
assert.Equal(t, hoursNum*cliNumPerHour, int(data.NumDNSQueries))
}
func TestShouldCount(t *testing.T) {
const (
ignored1 = "ignor.ed"
ignored2 = "ignored.to"
)
set := stringutil.NewSet(ignored1, ignored2)
s, err := stats.New(stats.Config{
Enabled: true,
Filename: filepath.Join(t.TempDir(), "stats.db"),
Limit: timeutil.Day,
Ignored: set,
ShouldCountClient: func(ids []string) (a bool) {
return ids[0] != "no_count"
},
})
require.NoError(t, err)
s.Start()
testutil.CleanupAndRequireSuccess(t, s.Close)
testCases := []struct {
wantCount assert.BoolAssertionFunc
name string
host string
ids []string
}{{
name: "count",
host: "example.com",
ids: []string{"whatever"},
wantCount: assert.True,
}, {
name: "no_count_ignored_1",
host: ignored1,
ids: []string{"whatever"},
wantCount: assert.False,
}, {
name: "no_count_ignored_2",
host: ignored2,
ids: []string{"whatever"},
wantCount: assert.False,
}, {
name: "no_count_client_ignore",
host: "example.com",
ids: []string{"no_count"},
wantCount: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := s.ShouldCount(tc.host, dns.TypeA, dns.ClassINET, tc.ids)
tc.wantCount(t, res)
})
}
}

View File

@@ -72,6 +72,19 @@ type Entry struct {
// unit collects the statistics data for a specific period of time.
type unit struct {
// domains stores the number of requests for each domain.
domains map[string]uint64
// blockedDomains stores the number of requests for each domain that has
// been blocked.
blockedDomains map[string]uint64
// clients stores the number of requests from each client.
clients map[string]uint64
// nResult stores the number of requests grouped by it's result.
nResult []uint64
// id is the unique unit's identifier. It's set to an absolute hour number
// since the beginning of UNIX time by the default ID generating function.
//
@@ -81,29 +94,20 @@ type unit struct {
// nTotal stores the total number of requests.
nTotal uint64
// nResult stores the number of requests grouped by it's result.
nResult []uint64
// timeSum stores the sum of processing time in milliseconds of each request
// written by the unit.
timeSum uint64
// domains stores the number of requests for each domain.
domains map[string]uint64
// blockedDomains stores the number of requests for each domain that has
// been blocked.
blockedDomains map[string]uint64
// clients stores the number of requests from each client.
clients map[string]uint64
}
// newUnit allocates the new *unit.
func newUnit(id uint32) (u *unit) {
return &unit{
id: id,
domains: map[string]uint64{},
blockedDomains: map[string]uint64{},
clients: map[string]uint64{},
nResult: make([]uint64, resultLast),
domains: make(map[string]uint64),
blockedDomains: make(map[string]uint64),
clients: make(map[string]uint64),
id: id,
}
}
@@ -115,19 +119,25 @@ type countPair struct {
}
// unitDB is the structure for serializing statistics data into the database.
//
// NOTE: Do not change the names or types of fields, as this structure is used
// for GOB encoding.
type unitDB struct {
// NTotal is the total number of requests.
NTotal uint64
// NResult is the number of requests by the result's kind.
NResult []uint64
// Domains is the number of requests for each domain name.
Domains []countPair
// BlockedDomains is the number of requests blocked for each domain name.
BlockedDomains []countPair
// Clients is the number of requests from each client.
Clients []countPair
// NTotal is the total number of requests.
NTotal uint64
// TimeAvg is the average of processing times in milliseconds of all the
// requests in the unit.
TimeAvg uint32