all: sync with master

This commit is contained in:
Ainar Garipov
2024-05-15 13:34:12 +03:00
parent 6318fc424b
commit 667263a3a8
82 changed files with 2356 additions and 1817 deletions

View File

@@ -10,29 +10,8 @@ import (
"golang.org/x/exp/constraints"
)
// Coalesce returns the first non-zero value. It is named after function
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
// value.
//
// T is comparable, because Go currently doesn't have a comparableWithZeroValue
// constraint.
//
// TODO(a.garipov): Think of ways to merge with [CoalesceSlice].
func Coalesce[T comparable](values ...T) (res T) {
var zero T
for _, v := range values {
if v != zero {
return v
}
}
return zero
}
// CoalesceSlice returns the first non-zero value. It is named after function
// COALESCE in SQL. If values or all its elements are empty, it returns nil.
//
// TODO(a.garipov): Think of ways to merge with [Coalesce].
func CoalesceSlice[E any, S []E](values ...S) (res S) {
for _, v := range values {
if v != nil {

View File

@@ -33,7 +33,7 @@ func elements(b *aghalg.RingBuffer[int], n uint, reverse bool) (es []int) {
func TestNewRingBuffer(t *testing.T) {
t.Run("success_and_clear", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](5)
for i := 0; i < 10; i++ {
for i := range 10 {
b.Append(i)
}
assert.Equal(t, []int{5, 6, 7, 8, 9}, elements(b, b.Len(), false))
@@ -44,7 +44,7 @@ func TestNewRingBuffer(t *testing.T) {
t.Run("zero", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](0)
for i := 0; i < 10; i++ {
for i := range 10 {
b.Append(i)
bufLen := b.Len()
assert.EqualValues(t, 0, bufLen)
@@ -55,7 +55,7 @@ func TestNewRingBuffer(t *testing.T) {
t.Run("single", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](1)
for i := 0; i < 10; i++ {
for i := range 10 {
b.Append(i)
bufLen := b.Len()
assert.EqualValues(t, 1, bufLen)
@@ -94,7 +94,7 @@ func TestRingBuffer_Range(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for i := 0; i < tc.count; i++ {
for i := range tc.count {
b.Append(i)
}

View File

@@ -11,7 +11,7 @@ func TestNewSortedMap(t *testing.T) {
var m SortedMap[string, int]
letters := []string{}
for i := 0; i < 10; i++ {
for i := range 10 {
r := string('a' + rune(i))
letters = append(letters, r)
}

View File

@@ -97,6 +97,8 @@ func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
var filename string
defer func() { err = errors.Annotate(err, "checking %q: %w", filename) }()
// TODO(e.burkov): Redo this loop, as it modifies the very same slice it
// iterates over.
for i := 0; i < len(src); i++ {
var patterns []string
var cont bool

View File

@@ -159,21 +159,11 @@ func NotifyReconfigureSignal(c chan<- os.Signal) {
notifyReconfigureSignal(c)
}
// NotifyShutdownSignal notifies c on receiving shutdown signals.
func NotifyShutdownSignal(c chan<- os.Signal) {
notifyShutdownSignal(c)
}
// IsReconfigureSignal returns true if sig is a reconfigure signal.
func IsReconfigureSignal(sig os.Signal) (ok bool) {
return isReconfigureSignal(sig)
}
// IsShutdownSignal returns true if sig is a shutdown signal.
func IsShutdownSignal(sig os.Signal) (ok bool) {
return isShutdownSignal(sig)
}
// SendShutdownSignal sends the shutdown signal to the channel.
func SendShutdownSignal(c chan<- os.Signal) {
sendShutdownSignal(c)

View File

@@ -13,26 +13,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGHUP)
}
func notifyShutdownSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
}
func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == unix.SIGHUP
}
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case
unix.SIGINT,
unix.SIGQUIT,
unix.SIGTERM:
return true
default:
return false
}
}
func sendShutdownSignal(_ chan<- os.Signal) {
// On Unix we are already notified by the system.
}

View File

@@ -5,7 +5,6 @@ package aghos
import (
"os"
"os/signal"
"syscall"
"golang.org/x/sys/windows"
)
@@ -43,25 +42,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, windows.SIGHUP)
}
func notifyShutdownSignal(c chan<- os.Signal) {
// syscall.SIGTERM is processed automatically. See go doc os/signal,
// section Windows.
signal.Notify(c, os.Interrupt)
}
func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == windows.SIGHUP
}
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case os.Interrupt, syscall.SIGTERM:
return true
default:
return false
}
}
func sendShutdownSignal(c chan<- os.Signal) {
c <- os.Interrupt
}

View File

@@ -78,7 +78,6 @@ func TestWithDeferredCleanup(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

View File

@@ -91,8 +91,6 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
@@ -186,8 +184,6 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

View File

@@ -7,6 +7,7 @@ package client
import (
"encoding"
"fmt"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
)
@@ -56,6 +57,9 @@ func (cs Source) MarshalText() (text []byte, err error) {
// Runtime is a client information from different sources.
type Runtime struct {
// ip is an IP address of a client.
ip netip.Addr
// whois is the filtered WHOIS information of a client.
whois *whois.Info
@@ -80,6 +84,15 @@ type Runtime struct {
hostsFile []string
}
// NewRuntime constructs a new runtime client. ip must be valid IP address.
//
// TODO(s.chzhen): Validate IP address.
func NewRuntime(ip netip.Addr) (r *Runtime) {
return &Runtime{
ip: ip,
}
}
// Info returns a client information from the highest-priority source.
func (r *Runtime) Info() (cs Source, host string) {
info := []string{}
@@ -133,8 +146,8 @@ func (r *Runtime) SetWHOIS(info *whois.Info) {
r.whois = info
}
// Unset clears a cs information.
func (r *Runtime) Unset(cs Source) {
// unset clears a cs information.
func (r *Runtime) unset(cs Source) {
switch cs {
case SourceWHOIS:
r.whois = nil
@@ -149,11 +162,16 @@ func (r *Runtime) Unset(cs Source) {
}
}
// IsEmpty returns true if there is no information from any source.
func (r *Runtime) IsEmpty() (ok bool) {
// isEmpty returns true if there is no information from any source.
func (r *Runtime) isEmpty() (ok bool) {
return r.whois == nil &&
r.arp == nil &&
r.rdns == nil &&
r.dhcp == nil &&
r.hostsFile == nil
}
// Addr returns an IP address of the client.
func (r *Runtime) Addr() (ip netip.Addr) {
return r.ip
}

View File

@@ -4,8 +4,12 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/exp/maps"
)
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
@@ -28,6 +32,9 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
// Index stores all information about persistent clients.
type Index struct {
// nameToUID maps client name to UID.
nameToUID map[string]UID
// clientIDToUID maps client ID to UID.
clientIDToUID map[string]UID
@@ -47,6 +54,7 @@ type Index struct {
// NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) {
return &Index{
nameToUID: map[string]UID{},
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
@@ -62,6 +70,8 @@ func (ci *Index) Add(c *Persistent) {
panic("client must contain uid")
}
ci.nameToUID[c.Name] = c.UID
for _, id := range c.ClientIDs {
ci.clientIDToUID[id] = c.UID
}
@@ -82,15 +92,30 @@ func (ci *Index) Add(c *Persistent) {
ci.uidToClient[c.UID] = c
}
// ClashesUID returns existing persistent client with the same UID as c. Note
// that this is only possible when configuration contains duplicate fields.
func (ci *Index) ClashesUID(c *Persistent) (err error) {
p, ok := ci.uidToClient[c.UID]
if ok {
return fmt.Errorf("another client %q uses the same uid", p.Name)
}
return nil
}
// Clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *Index) Clashes(c *Persistent) (err error) {
if p := ci.clashesName(c); p != nil {
return fmt.Errorf("another client uses the same name %q", p.Name)
}
for _, id := range c.ClientIDs {
existing, ok := ci.clientIDToUID[id]
if ok && existing != c.UID {
p := ci.uidToClient[existing]
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
return fmt.Errorf("another client %q uses the same ClientID %q", p.Name, id)
}
}
@@ -112,6 +137,21 @@ func (ci *Index) Clashes(c *Persistent) (err error) {
return nil
}
// clashesName returns existing persistent client with the same name as c or
// nil. c must be non-nil.
func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.FindByName(c.Name)
if !ok {
return nil
}
if existing.UID != c.UID {
return existing
}
return nil
}
// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
@@ -184,21 +224,33 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) {
mac, err := net.ParseMAC(id)
if err == nil {
return ci.findByMAC(mac)
return ci.FindByMAC(mac)
}
return nil, false
}
// find finds persistent client by IP address.
// FindByName finds persistent client by name.
func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
uid, found := ci.nameToUID[name]
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// findByIP finds persistent client by IP address.
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
}
ipWithoutZone := ip.WithZone("")
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
if pref.Contains(ip) {
// Remove zone before checking because prefixes strip zones.
if pref.Contains(ipWithoutZone) {
uid, found = id, true
return false
@@ -214,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false
}
// find finds persistent client by MAC.
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
// FindByMAC finds persistent client by MAC.
func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
@@ -225,9 +277,31 @@ func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
return nil, false
}
// FindByIPWithoutZone finds a persistent client by IP address without zone. It
// strips the IPv6 zone index from the stored IP addresses before comparing,
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
if (ip == netip.Addr{}) {
return nil
}
for addr, uid := range ci.ipToUID {
if addr.WithZone("") == ip {
return ci.uidToClient[uid]
}
}
return nil
}
// Delete removes information about persistent client from the index. c must be
// non-nil.
func (ci *Index) Delete(c *Persistent) {
delete(ci.nameToUID, c.Name)
for _, id := range c.ClientIDs {
delete(ci.clientIDToUID, id)
}
@@ -247,3 +321,48 @@ func (ci *Index) Delete(c *Persistent) {
delete(ci.uidToClient, c.UID)
}
// Size returns the number of persistent clients.
func (ci *Index) Size() (n int) {
return len(ci.uidToClient)
}
// Range calls f for each persistent client, unless cont is false. The order is
// undefined.
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
for _, c := range ci.uidToClient {
if !f(c) {
return
}
}
}
// RangeByName is like [Index.Range] but sorts the persistent clients by name
// before iterating ensuring a predictable order.
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
cs := maps.Values(ci.uidToClient)
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
return strings.Compare(a.Name, b.Name)
})
for _, c := range cs {
if !f(c) {
break
}
}
}
// CloseUpstreams closes upstream configurations of persistent clients.
func (ci *Index) CloseUpstreams() (err error) {
var errs []error
ci.RangeByName(func(c *Persistent) (cont bool) {
err = c.CloseUpstreams()
if err != nil {
errs = append(errs, err)
}
return true
})
return errors.Join(errs...)
}

View File

@@ -22,7 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) {
return ci
}
func TestClientIndex(t *testing.T) {
func TestClientIndex_Find(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
cliIP1 = "1.1.1.1"
@@ -35,26 +35,49 @@ func TestClientIndex(t *testing.T) {
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0"
linkLocalSubnet = "fe80::/16"
)
clients := []*Persistent{{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}, {
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
var (
clientWithBothFams = &Persistent{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}
clientWithSubnet = &Persistent{
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}
clientWithMAC = &Persistent{
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}
clientWithID = &Persistent{
Name: "client_with_id",
ClientIDs: []string{cliID},
}
clientLinkLocal = &Persistent{
Name: "client_link_local",
Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)},
}
)
clients := []*Persistent{
clientWithBothFams,
clientWithSubnet,
clientWithMAC,
clientWithID,
clientLinkLocal,
}
ci := newIDIndex(clients)
testCases := []struct {
@@ -64,19 +87,23 @@ func TestClientIndex(t *testing.T) {
}{{
name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6},
want: clients[0],
want: clientWithBothFams,
}, {
name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP},
want: clients[1],
want: clientWithSubnet,
}, {
name: "mac",
ids: []string{cliMAC},
want: clients[2],
want: clientWithMAC,
}, {
name: "client_id",
ids: []string{cliID},
want: clients[3],
want: clientWithID,
}, {
name: "client_link_local_subnet",
ids: []string{linkLocalIP},
want: clientLinkLocal,
}}
for _, tc := range testCases {
@@ -221,3 +248,103 @@ func TestMACToKey(t *testing.T) {
_ = macToKey(mac)
})
}
func TestIndex_FindByIPWithoutZone(t *testing.T) {
var (
ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1")
ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2")
)
var (
clientNoZone = &Persistent{
Name: "client",
IPs: []netip.Addr{ip},
}
clientWithZone = &Persistent{
Name: "client_with_zone",
IPs: []netip.Addr{ipWithZone},
}
)
ci := newIDIndex([]*Persistent{
clientNoZone,
clientWithZone,
})
testCases := []struct {
ip netip.Addr
want *Persistent
name string
}{{
name: "without_zone",
ip: ip,
want: clientNoZone,
}, {
name: "with_zone",
ip: ipWithZone,
want: clientWithZone,
}, {
name: "zero_address",
ip: netip.Addr{},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := ci.FindByIPWithoutZone(tc.ip.WithZone(""))
require.Equal(t, tc.want, c)
})
}
}
func TestClientIndex_RangeByName(t *testing.T) {
sortedClients := []*Persistent{{
Name: "clientA",
ClientIDs: []string{"A"},
}, {
Name: "clientB",
ClientIDs: []string{"B"},
}, {
Name: "clientC",
ClientIDs: []string{"C"},
}, {
Name: "clientD",
ClientIDs: []string{"D"},
}, {
Name: "clientE",
ClientIDs: []string{"E"},
}}
testCases := []struct {
name string
want []*Persistent
}{{
name: "basic",
want: sortedClients,
}, {
name: "nil",
want: nil,
}, {
name: "one_element",
want: sortedClients[:1],
}, {
name: "two_elements",
want: sortedClients[:2],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ci := newIDIndex(tc.want)
var got []*Persistent
ci.RangeByName(func(c *Persistent) (cont bool) {
got = append(got, c)
return true
})
assert.Equal(t, tc.want, got)
})
}
}

View File

@@ -64,9 +64,7 @@ type Persistent struct {
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices
@@ -95,6 +93,9 @@ type Persistent struct {
UseOwnBlockedServices bool
IgnoreQueryLog bool
IgnoreStatistics bool
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
}
// SetTags sets the tags if they are known, otherwise logs an unknown tag.

View File

@@ -0,0 +1,63 @@
package client
import "net/netip"
// RuntimeIndex stores information about runtime clients.
type RuntimeIndex struct {
// index maps IP address to runtime client.
index map[netip.Addr]*Runtime
}
// NewRuntimeIndex returns initialized runtime index.
func NewRuntimeIndex() (ri *RuntimeIndex) {
return &RuntimeIndex{
index: map[netip.Addr]*Runtime{},
}
}
// Client returns the saved runtime client by ip. If no such client exists,
// returns nil.
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) {
return ri.index[ip]
}
// Add saves the runtime client in the index. IP address of a client must be
// unique. See [Runtime.Client]. rc must not be nil.
func (ri *RuntimeIndex) Add(rc *Runtime) {
ip := rc.Addr()
ri.index[ip] = rc
}
// Size returns the number of the runtime clients.
func (ri *RuntimeIndex) Size() (n int) {
return len(ri.index)
}
// Range calls f for each runtime client in an undefined order.
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
for _, rc := range ri.index {
if !f(rc) {
return
}
}
}
// Delete removes the runtime client by ip.
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
delete(ri.index, ip)
}
// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) {
for ip, rc := range ri.index {
rc.unset(src)
if rc.isEmpty() {
delete(ri.index, ip)
n++
}
}
return n
}

View File

@@ -0,0 +1,85 @@
package client_test
import (
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/stretchr/testify/assert"
)
func TestRuntimeIndex(t *testing.T) {
const cliSrc = client.SourceARP
var (
ip1 = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
ip3 = netip.MustParseAddr("3.3.3.3")
)
ri := client.NewRuntimeIndex()
currentSize := 0
testCases := []struct {
ip netip.Addr
name string
hosts []string
src client.Source
}{{
src: cliSrc,
ip: ip1,
name: "1",
hosts: []string{"host1"},
}, {
src: cliSrc,
ip: ip2,
name: "2",
hosts: []string{"host2"},
}, {
src: cliSrc,
ip: ip3,
name: "3",
hosts: []string{"host3"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rc := client.NewRuntime(tc.ip)
rc.SetInfo(tc.src, tc.hosts)
ri.Add(rc)
currentSize++
got := ri.Client(tc.ip)
assert.Equal(t, rc, got)
})
}
t.Run("size", func(t *testing.T) {
assert.Equal(t, currentSize, ri.Size())
})
t.Run("range", func(t *testing.T) {
s := 0
ri.Range(func(rc *client.Runtime) (cont bool) {
s++
return true
})
assert.Equal(t, currentSize, s)
})
t.Run("delete", func(t *testing.T) {
ri.Delete(ip1)
currentSize--
assert.Equal(t, currentSize, ri.Size())
})
t.Run("delete_by_src", func(t *testing.T) {
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
assert.Equal(t, 0, ri.Size())
})
}

View File

@@ -1,5 +1,7 @@
package configmigrate
import "github.com/AdguardTeam/golibs/errors"
// migrateTo15 performs the following changes:
//
// # BEFORE:
@@ -43,7 +45,7 @@ func migrateTo15(diskConf yobj) (err error) {
}
diskConf["querylog"] = qlog
return coalesceError(
return errors.Join(
moveVal[bool](dns, qlog, "querylog_enabled", "enabled"),
moveVal[bool](dns, qlog, "querylog_file_enabled", "file_enabled"),
moveVal[any](dns, qlog, "querylog_interval", "interval"),

View File

@@ -1,5 +1,7 @@
package configmigrate
import "github.com/AdguardTeam/golibs/errors"
// migrateTo24 performs the following changes:
//
// # BEFORE:
@@ -28,7 +30,7 @@ func migrateTo24(diskConf yobj) (err error) {
diskConf["schema_version"] = 24
logObj := yobj{}
err = coalesceError(
err = errors.Join(
moveVal[string](diskConf, logObj, "log_file", "file"),
moveVal[int](diskConf, logObj, "log_max_backups", "max_backups"),
moveVal[int](diskConf, logObj, "log_max_size", "max_size"),

View File

@@ -1,5 +1,7 @@
package configmigrate
import "github.com/AdguardTeam/golibs/errors"
// migrateTo26 performs the following changes:
//
// # BEFORE:
@@ -78,7 +80,7 @@ func migrateTo26(diskConf yobj) (err error) {
}
filteringObj := yobj{}
err = coalesceError(
err = errors.Join(
moveSameVal[bool](dns, filteringObj, "filtering_enabled"),
moveSameVal[int](dns, filteringObj, "filters_update_interval"),
moveSameVal[bool](dns, filteringObj, "parental_enabled"),

View File

@@ -1,5 +1,7 @@
package configmigrate
import "github.com/AdguardTeam/golibs/errors"
// migrateTo7 performs the following changes:
//
// # BEFORE:
@@ -37,7 +39,7 @@ func migrateTo7(diskConf yobj) (err error) {
}
dhcpv4 := yobj{}
err = coalesceError(
err = errors.Join(
moveSameVal[string](dhcp, dhcpv4, "gateway_ip"),
moveSameVal[string](dhcp, dhcpv4, "subnet_mask"),
moveSameVal[string](dhcp, dhcpv4, "range_start"),

View File

@@ -50,19 +50,3 @@ func moveVal[T any](src, dst yobj, srcKey, dstKey string) (err error) {
func moveSameVal[T any](src, dst yobj, key string) (err error) {
return moveVal[T](src, dst, key, key)
}
// coalesceError returns the first non-nil error. It is named after function
// COALESCE in SQL. If all errors are nil, it returns nil.
//
// TODO(e.burkov): Replace with [errors.Join].
//
// TODO(a.garipov): Think of ways to merge with [aghalg.Coalesce].
func coalesceError(errors ...error) (res error) {
for _, err := range errors {
if err != nil {
return err
}
}
return nil
}

View File

@@ -156,7 +156,10 @@ func (a *accessManager) isBlockedIP(ip netip.Addr) (blocked bool, rule string) {
}
for _, ipnet := range ipnets {
if ipnet.Contains(ip) {
// Remove zone before checking because prefixes stip zones.
//
// TODO(d.kolyshev): Cover with tests.
if ipnet.Contains(ip.WithZone("")) {
return blocked, ipnet.String()
}
}

View File

@@ -0,0 +1,116 @@
package dnsforward
import (
"encoding/binary"
"fmt"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// type check
var _ proxy.BeforeRequestHandler = (*Server)(nil)
// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
//
// TODO(d.kolyshev): Extract to separate package.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return &proxy.BeforeRequestError{
Err: fmt.Errorf("getting clientid: %w", err),
Response: s.NewMsgSERVFAIL(pctx.Req),
}
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return nil
}
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}
cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// errAccessBlocked is a sentinel error returned when a request is blocked by
// access settings.
var errAccessBlocked errors.Error = "blocked by access settings"
// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return errAccessBlocked
}
return &proxy.BeforeRequestError{
Err: errAccessBlocked,
Response: s.makeResponseREFUSED(pctx.Req),
}
}

View File

@@ -0,0 +1,299 @@
package dnsforward
import (
"crypto/tls"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
blockedHost = "blockedhost.org"
testFQDN = "example.org."
dnsClientTimeout = 200 * time.Millisecond
)
func TestServer_HandleBefore_tls(t *testing.T) {
t.Parallel()
const clientID = "client-1"
testCases := []struct {
clientSrvName string
name string
host string
allowedClients []string
disallowedClients []string
blockedHosts []string
wantRCode int
}{{
clientSrvName: tlsServerName,
name: "allow_all",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: "%" + "." + tlsServerName,
name: "invalid_client_id",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeServerFailure,
}, {
clientSrvName: clientID + "." + tlsServerName,
name: "allowed_client_allowed",
host: testFQDN,
allowedClients: []string{clientID},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: "client-2." + tlsServerName,
name: "allowed_client_rejected",
host: testFQDN,
allowedClients: []string{clientID},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeRefused,
}, {
clientSrvName: tlsServerName,
name: "disallowed_client_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{clientID},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: clientID + "." + tlsServerName,
name: "disallowed_client_rejected",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{clientID},
blockedHosts: []string{},
wantRCode: dns.RcodeRefused,
}, {
clientSrvName: tlsServerName,
name: "blocked_hosts_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: tlsServerName,
name: "blocked_hosts_rejected",
host: dns.Fqdn(blockedHost),
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantRCode: dns.RcodeRefused,
}}
localAns := []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: testFQDN,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 3600,
Rdlength: 4,
},
A: net.IP{1, 2, 3, 4},
}}
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := (&dns.Msg{}).SetReply(req)
resp.Answer = localAns
require.NoError(t, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
s, _ := createTestTLS(t, TLSConfig{
TLSListenAddrs: []*net.TCPAddr{{}},
ServerName: tlsServerName,
})
s.conf.UpstreamDNS = []string{localUpsAddr}
s.conf.AllowedClients = tc.allowedClients
s.conf.DisallowedClients = tc.disallowedClients
s.conf.BlockedHosts = tc.blockedHosts
err := s.Prepare(&s.conf)
require.NoError(t, err)
startDeferStop(t, s)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: tc.clientSrvName,
}
client := &dns.Client{
Net: "tcp-tls",
TLSConfig: tlsConfig,
Timeout: dnsClientTimeout,
}
req := createTestMessage(tc.host)
addr := s.dnsProxy.Addr(proxy.ProtoTLS).String()
reply, _, err := client.Exchange(req, addr)
require.NoError(t, err)
assert.Equal(t, tc.wantRCode, reply.Rcode)
if tc.wantRCode == dns.RcodeSuccess {
assert.Equal(t, localAns, reply.Answer)
} else {
assert.Empty(t, reply.Answer)
}
})
}
}
func TestServer_HandleBefore_udp(t *testing.T) {
t.Parallel()
const (
clientIPv4 = "127.0.0.1"
clientIPv6 = "::1"
)
clientIPs := []string{clientIPv4, clientIPv6}
testCases := []struct {
name string
host string
allowedClients []string
disallowedClients []string
blockedHosts []string
wantTimeout bool
}{{
name: "allow_all",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "allowed_client_allowed",
host: testFQDN,
allowedClients: clientIPs,
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "allowed_client_rejected",
host: testFQDN,
allowedClients: []string{"1:2:3::4"},
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: true,
}, {
name: "disallowed_client_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{"1:2:3::4"},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "disallowed_client_rejected",
host: testFQDN,
allowedClients: []string{},
disallowedClients: clientIPs,
blockedHosts: []string{},
wantTimeout: true,
}, {
name: "blocked_hosts_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantTimeout: false,
}, {
name: "blocked_hosts_rejected",
host: dns.Fqdn(blockedHost),
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantTimeout: true,
}}
localAns := []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: testFQDN,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 3600,
Rdlength: 4,
},
A: net.IP{1, 2, 3, 4},
}}
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := (&dns.Msg{}).SetReply(req)
resp.Answer = localAns
require.NoError(t, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
AllowedClients: tc.allowedClients,
DisallowedClients: tc.disallowedClients,
BlockedHosts: tc.blockedHosts,
UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
})
startDeferStop(t, s)
client := &dns.Client{
Net: "udp",
Timeout: dnsClientTimeout,
}
req := createTestMessage(tc.host)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
reply, _, err := client.Exchange(req, addr)
if tc.wantTimeout {
wantErr := &net.OpError{}
require.ErrorAs(t, err, &wantErr)
assert.True(t, wantErr.Timeout())
assert.Nil(t, reply)
} else {
require.NoError(t, err)
require.NotNil(t, reply)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.Equal(t, localAns, reply.Answer)
}
})
}
}

View File

@@ -110,46 +110,6 @@ type quicConnection interface {
ConnectionState() (cs quic.ConnectionState)
}
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}
cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// clientServerName returns the TLS server name based on the protocol. For
// DNS-over-HTTPS requests, it will return the hostname part of the Host header
// if there is one.

View File

@@ -235,9 +235,18 @@ type DNSCryptConfig struct {
// ServerConfig represents server configuration.
// The zero ServerConfig is empty and ready for use.
type ServerConfig struct {
UDPListenAddrs []*net.UDPAddr // UDP listen address
TCPListenAddrs []*net.TCPAddr // TCP listen address
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
// UDPListenAddrs is the list of addresses to listen for DNS-over-UDP.
UDPListenAddrs []*net.UDPAddr
// TCPListenAddrs is the list of addresses to listen for DNS-over-TCP.
TCPListenAddrs []*net.TCPAddr
// UpstreamConfig is the general configuration of upstream DNS servers.
UpstreamConfig *proxy.UpstreamConfig
// PrivateRDNSUpstreamConfig is the configuration of upstream DNS servers
// for private reverse DNS.
PrivateRDNSUpstreamConfig *proxy.UpstreamConfig
// AddrProcConf defines the configuration for the client IP processor.
// If nil, [client.EmptyAddrProc] is used.
@@ -306,24 +315,28 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
conf = &proxy.Config{
HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny,
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler,
RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny,
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig,
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
BeforeRequestHandler: s,
RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
MessageConstructor: s,
}
if srvConf.EDNSClientSubnet.UseCustom {
@@ -452,12 +465,33 @@ func (s *Server) prepareIpsetListSettings() (err error) {
}
ipsets := stringutil.SplitTrimmed(string(data), "\n")
ipsets = stringutil.FilterOut(ipsets, IsCommentOrEmpty)
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
return s.ipset.init(ipsets)
}
// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
if conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
}
var data []byte
data, err = os.ReadFile(conf.UpstreamDNSFileName)
if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err)
}
upstreams = stringutil.SplitTrimmed(string(data), "\n")
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
}
// collectListenAddr adds addrPort to addrs. It also adds its port to
// unspecPorts if its address is unspecified.
func collectListenAddr(
@@ -529,8 +563,8 @@ func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr())
}
// filterOut filters out all the upstreams that match um. It returns all the
// closing errors joined.
// filterOutAddrs filters out all the upstreams that match um. It returns all
// the closing errors joined.
func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) {
var errs []error
delFunc := func(u upstream.Upstream) (ok bool) {

View File

@@ -3,7 +3,6 @@ package dnsforward
import (
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -11,6 +10,7 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -64,6 +64,8 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns
}
func TestServer_HandleDNSRequest_dns64(t *testing.T) {
t.Parallel()
const (
ipv4Domain = "ipv4.only."
ipv6Domain = "ipv6.only."
@@ -252,33 +254,33 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
require.Len(pt, m.Question, 1)
require.Equal(pt, m.Question[0].Name, ptr64Domain)
resp := (&dns.Msg{
Answer: []dns.RR{localRR},
}).SetReply(m)
resp := (&dns.Msg{}).SetReply(m)
resp.Answer = []dns.RR{localRR}
require.NoError(t, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
client := &dns.Client{
Net: "tcp",
Timeout: 1 * time.Second,
Net: string(proxy.ProtoTCP),
Timeout: testTimeout,
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0]
require.Contains(pt, tc.upsAns, q.Qtype)
require.Contains(pt, tc.upsAns, q.Qtype)
answer := tc.upsAns[q.Qtype]
resp := (&dns.Msg{
Answer: answer[sectionAnswer],
Ns: answer[sectionAuthority],
Extra: answer[sectionAdditional],
}).SetReply(req)
resp := (&dns.Msg{}).SetReply(req)
resp.Answer = answer[sectionAnswer]
resp.Ns = answer[sectionAuthority]
resp.Extra = answer[sectionAdditional]
require.NoError(pt, w.WriteMsg(resp))
})
@@ -308,10 +310,54 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
resp, _, excErr := client.Exchange(req, s.dnsProxy.Addr(proxy.ProtoTCP).String())
resp, _, excErr := client.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, excErr)
require.Equal(t, tc.wantAns, resp.Answer)
})
}
}
func TestServer_dns64WithDisabledRDNS(t *testing.T) {
t.Parallel()
// Shouldn't go to upstream at all.
panicHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
panic("not implemented")
})
upsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
localUpsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
})
startDeferStop(t, s)
mappedIPv6 := net.ParseIP("64:ff9b::102:304")
arpa, err := netutil.IPToReversedAddr(mappedIPv6)
require.NoError(t, err)
req := (&dns.Msg{}).SetQuestion(dns.Fqdn(arpa), dns.TypePTR)
cli := &dns.Client{
Net: string(proxy.ProtoTCP),
Timeout: testTimeout,
}
resp, _, err := cli.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeNameError, resp.Rcode)
}

View File

@@ -2,6 +2,7 @@
package dnsforward
import (
"cmp"
"context"
"fmt"
"io"
@@ -15,7 +16,6 @@ import (
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -135,12 +135,6 @@ type Server struct {
// WHOIS, etc.
addrProc client.AddressProcessor
// localResolvers is a DNS proxy instance used to resolve PTR records for
// addresses considered private as per the [privateNets].
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
localResolvers *proxy.Proxy
// sysResolvers used to fetch system resolvers to use by default for private
// PTR resolving.
sysResolvers SystemResolvers
@@ -158,12 +152,6 @@ type Server struct {
// [upstream.Resolver] interface.
bootResolvers []*upstream.UpstreamResolver
// recDetector is a cache for recursive requests. It is used to detect and
// prevent recursive requests only for private upstreams.
//
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
recDetector *recursionDetector
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
// part of DNS64 happens inside the [proxy] package, but there still are
// some places where response mapping is needed (e.g. DHCP).
@@ -212,14 +200,6 @@ type DNSCreateParams struct {
LocalDomain string
}
const (
// recursionTTL is the time recursive request is cached for.
recursionTTL = 1 * time.Second
// cachedRecurrentReqNum is the maximum number of cached recurrent
// requests.
cachedRecurrentReqNum = 1000
)
// NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once
//
@@ -256,7 +236,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix),
etcHosts: etcHosts,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
EnableLRU: true,
MaxCount: defaultClientIDCacheCount,
@@ -366,6 +345,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
s.serverLock.RLock()
defer s.serverLock.RUnlock()
// TODO(e.burkov): Migrate to [netip.Addr] already.
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil {
return "", 0, fmt.Errorf("reversing ip: %w", err)
@@ -386,25 +366,23 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
}
dctx := &proxy.DNSContext{
Proto: "udp",
Req: req,
Proto: proxy.ProtoUDP,
Req: req,
IsPrivateClient: true,
}
var resolver *proxy.Proxy
var errMsg string
if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS {
return "", 0, nil
}
resolver = s.localResolvers
errMsg = "resolving a private address: %w"
s.recDetector.add(*req)
dctx.RequestedPrivateRDNS = netip.PrefixFrom(ip, ip.BitLen())
} else {
resolver = s.internalProxy
errMsg = "resolving an address: %w"
}
if err = resolver.Resolve(dctx); err != nil {
if err = s.internalProxy.Resolve(dctx); err != nil {
return "", 0, fmt.Errorf(errMsg, err)
}
@@ -473,103 +451,6 @@ func (s *Server) startLocked() error {
return err
}
// prepareLocalResolvers initializes the local upstreams configuration using
// boot as bootstrap. It assumes that s.serverLock is locked or s not running.
func (s *Server) prepareLocalResolvers(
boot upstream.Resolver,
) (uc *proxy.UpstreamConfig, err error) {
set, err := s.conf.ourAddrsSet()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
resolvers := s.conf.LocalPTRResolvers
confNeedsFiltering := len(resolvers) > 0
if confNeedsFiltering {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has)
resolvers = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
resolvers = append(resolvers, r.String())
}
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
uc, err = s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
Bootstrap: boot,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
return nil, fmt.Errorf("preparing private upstreams: %w", err)
}
if confNeedsFiltering {
err = filterOutAddrs(uc, set)
if err != nil {
return nil, fmt.Errorf("filtering private upstreams: %w", err)
}
}
return uc, nil
}
// LocalResolversError is an error type for errors during local resolvers setup.
// This is only needed to distinguish these errors from errors returned by
// creating the proxy.
type LocalResolversError struct {
Err error
}
// type check
var _ error = (*LocalResolversError)(nil)
// Error implements the error interface for *LocalResolversError.
func (err *LocalResolversError) Error() (s string) {
return fmt.Sprintf("creating local resolvers: %s", err.Err)
}
// type check
var _ errors.Wrapper = (*LocalResolversError)(nil)
// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError.
func (err *LocalResolversError) Unwrap() error {
return err.Err
}
// setupLocalResolvers initializes and sets the resolvers for local addresses.
// It assumes s.serverLock is locked or s not running. It returns the upstream
// configuration used for private PTR resolving, or nil if it's disabled. Note,
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
return nil, nil
}
uc, err = s.prepareLocalResolvers(boot)
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
localResolvers, err := proxy.New(&proxy.Config{
UpstreamConfig: uc,
})
if err != nil {
return nil, &LocalResolversError{Err: err}
}
s.localResolvers = localResolvers
// TODO(e.burkov): Should we also consider the DNS64 usage?
return uc, nil
}
// Prepare initializes parameters of s using data from conf. conf must not be
// nil.
func (s *Server) Prepare(conf *ServerConfig) (err error) {
@@ -586,7 +467,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.initDefaultSettings()
boot, err := s.prepareInternalDNS()
err = s.prepareInternalDNS()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
@@ -608,12 +489,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("preparing access: %w", err)
}
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
proxyConfig.Fallbacks, err = s.setupFallbackDNS()
if err != nil {
return fmt.Errorf("setting up fallback dns servers: %w", err)
@@ -626,8 +501,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.dnsProxy = dnsProxy
s.recDetector.clear()
s.setupAddrProc()
s.registerHandlers()
@@ -635,36 +508,127 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return nil
}
// prepareInternalDNS initializes the internal state of s before initializing
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
err = s.prepareIpsetListSettings()
// prepareUpstreamSettings sets upstream DNS server settings.
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
// Load upstreams either from the file, or from the settings
var upstreams []string
upstreams, err = s.conf.loadUpstreams()
if err != nil {
return nil, fmt.Errorf("preparing ipset settings: %w", err)
return fmt.Errorf("loading upstreams: %w", err)
}
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{
Timeout: DefaultTimeout,
uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
RootCAs: s.conf.TLSv12Roots,
CipherSuites: s.conf.TLSCiphers,
})
if err != nil {
return fmt.Errorf("preparing upstream config: %w", err)
}
s.conf.UpstreamConfig = uc
return nil
}
// PrivateRDNSError is returned when the private rDNS upstreams are
// invalid but enabled.
//
// TODO(e.burkov): Consider allowing to use incomplete private rDNS upstreams
// configuration in proxy when the private rDNS function is enabled. In theory,
// proxy supports the case when no upstreams provided to resolve the private
// request, since it already supports this for DNS64-prefixed PTR requests.
type PrivateRDNSError struct {
err error
}
// Error implements the [errors.Error] interface.
func (e *PrivateRDNSError) Error() (s string) {
return e.err.Error()
}
func (e *PrivateRDNSError) Unwrap() (err error) {
return e.err
}
// prepareLocalResolvers initializes the private RDNS upstream configuration
// according to the server's settings. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
return nil, nil
}
var ownAddrs addrPortSet
ownAddrs, err = s.conf.ourAddrsSet()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
opts := &upstream.Options{
Bootstrap: s.bootstrap,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
}
addrs := s.conf.LocalPTRResolvers
uc, err = newPrivateConfig(addrs, ownAddrs, s.sysResolvers, s.privateNets, opts)
if err != nil {
return nil, fmt.Errorf("preparing resolvers: %w", err)
}
return uc, nil
}
// prepareInternalDNS initializes the internal state of s before initializing
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareInternalDNS() (err error) {
err = s.prepareIpsetListSettings()
if err != nil {
return fmt.Errorf("preparing ipset settings: %w", err)
}
bootOpts := &upstream.Options{
Timeout: DefaultTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
}
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = s.prepareUpstreamSettings(s.bootstrap)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return s.bootstrap, err
return err
}
s.conf.PrivateRDNSUpstreamConfig, err = s.prepareLocalResolvers()
if err != nil {
return err
}
err = s.prepareInternalProxy()
if err != nil {
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err)
return fmt.Errorf("preparing internal proxy: %w", err)
}
return s.bootstrap, nil
return nil
}
// setupFallbackDNS initializes the fallback DNS servers.
@@ -743,10 +707,16 @@ func validateBlockingMode(
func (s *Server) prepareInternalProxy() (err error) {
srvConf := s.conf
conf := &proxy.Config{
CacheEnabled: true,
CacheSizeBytes: 4096,
UpstreamConfig: srvConf.UpstreamConfig,
MaxGoroutines: s.conf.MaxGoroutines,
CacheEnabled: true,
CacheSizeBytes: 4096,
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
UpstreamConfig: srvConf.UpstreamConfig,
MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
MessageConstructor: s,
}
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
@@ -782,11 +752,6 @@ func (s *Server) stopLocked() (err error) {
}
}
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
if s.localResolvers != nil {
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
}
for _, b := range s.bootResolvers {
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
}
@@ -908,5 +873,5 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool,
blocked = true
}
return blocked, aghalg.Coalesce(rule, clientID)
return blocked, cmp.Or(rule, clientID)
}

View File

@@ -1,7 +1,7 @@
package dnsforward
import (
"context"
"cmp"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
@@ -21,7 +21,6 @@ import (
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -190,7 +189,7 @@ func newGoogleUpstream() (u upstream.Upstream) {
return &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "google.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
return cmp.Or(
aghtest.MatchedResponse(req, dns.TypeA, googleDomainName, "8.8.8.8"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
@@ -253,7 +252,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
wg := &sync.WaitGroup{}
for i := 0; i < testMessagesCount; i++ {
for range testMessagesCount {
msg := createGoogleATestMessage()
wg.Add(1)
@@ -276,7 +275,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
func sendTestMessages(t *testing.T, conn *dns.Conn) {
t.Helper()
for i := 0; i < testMessagesCount; i++ {
for i := range testMessagesCount {
req := createGoogleATestMessage()
err := conn.WriteMsg(req)
assert.NoErrorf(t, err, "cannot write message #%d: %s", i, err)
@@ -491,19 +490,10 @@ func TestServerRace(t *testing.T) {
}
func TestSafeSearch(t *testing.T) {
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
safeSearchConf := filtering.SafeSearchConfig{
Enabled: true,
Google: true,
Yandex: true,
CustomResolver: resolver,
Enabled: true,
Google: true,
Yandex: true,
}
filterConf := &filtering.Config{
@@ -540,7 +530,6 @@ func TestSafeSearch(t *testing.T) {
client := &dns.Client{}
yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
testCases := []struct {
host string
@@ -564,19 +553,19 @@ func TestSafeSearch(t *testing.T) {
wantCNAME: "",
}, {
host: "www.google.com.",
want: googleIP,
want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.com.af.",
want: googleIP,
want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.be.",
want: googleIP,
want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.by.",
want: googleIP,
want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.",
}}
@@ -593,12 +582,15 @@ func TestSafeSearch(t *testing.T) {
cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
assert.Equal(t, tc.wantCNAME, cname.Target)
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[1])
assert.NotEmpty(t, a.A)
} else {
require.Len(t, reply.Answer, 1)
}
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1])
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
}
})
}
}
@@ -691,7 +683,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
atomic.AddUint32(&upsCalledCounter, 1)
return aghalg.Coalesce(
return cmp.Or(
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
@@ -1152,7 +1144,7 @@ func TestRewrite(t *testing.T) {
}))
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
return cmp.Or(
aghtest.MatchedResponse(req, dns.TypeA, "example.org", "4.3.2.1"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
@@ -1481,7 +1473,7 @@ func TestServer_Exchange(t *testing.T) {
require.NoError(t, err)
extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
@@ -1495,7 +1487,7 @@ func TestServer_Exchange(t *testing.T) {
require.NoError(t, err)
locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)

View File

@@ -143,7 +143,7 @@ func (s *Server) filterDNSRewrite(
res *filtering.Result,
pctx *proxy.DNSContext,
) (err error) {
resp := s.makeResponse(req)
resp := s.replyCompressed(req)
dnsrr := res.DNSRewriteResult
if dnsrr == nil {
return errors.Error("no dns rewrite rule content")

View File

@@ -1,57 +1,17 @@
package dnsforward
import (
"encoding/binary"
"fmt"
"net"
"slices"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)
// beforeRequestHandler is the handler that is called before any other
// processing, including logs. It performs access checks and puts the client
// ID, if there is one, into the server's cache.
func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return true, nil
}
// clientRequestFilteringSettings looks up client filtering settings using the
// client's IP address and ID, if any, from dctx.
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
@@ -71,6 +31,7 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
req := pctx.Req
q := req.Question[0]
host := strings.TrimSuffix(q.Name, ".")
resVal, err := s.dnsFilter.CheckHost(host, q.Qtype, dctx.setts)
if err != nil {
return nil, fmt.Errorf("checking host %q: %w", host, err)
@@ -79,22 +40,15 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
// TODO(a.garipov): Make CheckHost return a pointer.
res = &resVal
switch {
case res.IsFiltered:
log.Debug(
"dnsforward: host %q is filtered, reason: %q; rule: %q",
host,
res.Reason,
res.Rules[0].Text,
)
pctx.Res = s.genDNSFilterMessage(pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
res.CanonName != "" &&
len(res.IPList) == 0:
case isRewrittenCNAME(res):
// Resolve the new canonical name, not the original host name. The
// original question is readded in processFilteringAfterResponse.
dctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName)
case res.Reason == filtering.Rewritten:
case res.IsFiltered:
log.Debug("dnsforward: host %q is filtered, reason: %q", host, res.Reason)
pctx.Res = s.genDNSFilterMessage(pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName)
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
if err = s.filterDNSRewrite(req, res, pctx); err != nil {
@@ -105,6 +59,17 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
return res, err
}
// isRewrittenCNAME returns true if the request considered to be rewritten with
// CNAME and has no resolved IPs.
func isRewrittenCNAME(res *filtering.Result) (ok bool) {
return res.Reason.In(
filtering.Rewritten,
filtering.RewrittenRule,
filtering.FilteredSafeSearch) &&
res.CanonName != "" &&
len(res.IPList) == 0
}
// checkHostRules checks the host against filters. It is safe for concurrent
// use.
func (s *Server) checkHostRules(

View File

@@ -1,6 +1,7 @@
package dnsforward
import (
"cmp"
"encoding/json"
"fmt"
"io"
@@ -261,55 +262,17 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
}
}
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// checkFallbacks returns an error if any fallback address is invalid.
func (req *jsonDNSConfig) checkFallbacks() (err error) {
if req.Fallbacks == nil {
return nil
}
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
return nil
}
// validate returns an error if any field of req is invalid.
//
// TODO(s.chzhen): Parse, don't validate.
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
func (req *jsonDNSConfig) validate(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
err = req.validateUpstreamDNSServers(privateNets)
err = req.validateUpstreamDNSServers(ownAddrs, sysResolvers, privateNets)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
@@ -342,20 +305,77 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return nil
}
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// checkPrivateRDNS returns an error if the configuration of the private RDNS is
// not valid.
func (req *jsonDNSConfig) checkPrivateRDNS(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
if (req.UsePrivateRDNS == nil || !*req.UsePrivateRDNS) && req.LocalPTRUpstreams == nil {
return nil
}
addrs := cmp.Or(req.LocalPTRUpstreams, &[]string{})
uc, err := newPrivateConfig(*addrs, ownAddrs, sysResolvers, privateNets, &upstream.Options{})
err = errors.WithDeferred(err, uc.Close())
if err != nil {
return fmt.Errorf("private upstream servers: %w", err)
}
return nil
}
// validateUpstreamDNSServers returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
func (req *jsonDNSConfig) validateUpstreamDNSServers(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
var uc *proxy.UpstreamConfig
opts := &upstream.Options{}
if req.Upstreams != nil {
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{})
uc, err = proxy.ParseUpstreamsConfig(*req.Upstreams, opts)
err = errors.WithDeferred(err, uc.Close())
if err != nil {
return fmt.Errorf("upstream servers: %w", err)
}
}
if req.LocalPTRUpstreams != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
if err != nil {
return fmt.Errorf("private upstream servers: %w", err)
}
err = req.checkPrivateRDNS(ownAddrs, sysResolvers, privateNets)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = req.checkBootstrap()
@@ -364,10 +384,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
return err
}
err = req.checkFallbacks()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
if req.Fallbacks != nil {
uc, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, opts)
err = errors.WithDeferred(err, uc.Close())
if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
}
return nil
@@ -436,7 +458,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return
}
err = req.validate(s.privateNets)
// TODO(e.burkov): Consider prebuilding this set on startup.
ourAddrs, err := s.conf.ourAddrsSet()
if err != nil {
// TODO(e.burkov): Put into openapi.
aghhttp.Error(r, w, http.StatusInternalServerError, "getting our addresses: %s", err)
return
}
err = req.validate(ourAddrs, s.sysResolvers, s.privateNets)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -587,7 +618,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
var boots []*upstream.UpstreamResolver
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts)
opts.Bootstrap, boots, err = newBootstrap(req.BootstrapDNS, s.etcHosts, opts)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)

View File

@@ -245,9 +245,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "",
}, {
name: "local_ptr_upstreams_bad",
wantSet: `validating dns config: ` +
`private upstream servers: checking domain-specific upstreams: ` +
`bad arpa domain name "non.arpa.": not a reversed ip network`,
wantSet: `validating dns config: private upstream servers: ` +
`bad arpa domain name "non.arpa": not a reversed ip network`,
}, {
name: "local_ptr_upstreams_null",
wantSet: "",
@@ -318,58 +317,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
}
}
func TestValidateUpstreamsPrivate(t *testing.T) {
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
wantErr string
u string
}{{
name: "success_address",
wantErr: ``,
u: "[/1.0.0.127.in-addr.arpa/]#",
}, {
name: "success_subnet",
wantErr: ``,
u: "[/127.in-addr.arpa/]#",
}, {
name: "not_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`bad arpa domain name "hello.world.": not a reversed ip network`,
u: "[/hello.world/]#",
}, {
name: "non-private_arpa_address",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
u: "[/1.2.3.4.in-addr.arpa/]#",
}, {
name: "non-private_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
u: "[/128.in-addr.arpa/]#",
}, {
name: "several_bad",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
`bad arpa domain name "non.arpa.": not a reversed ip network`,
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
}, {
name: "partial_good",
wantErr: "",
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
}}
for _, tc := range testCases {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
t.Helper()

View File

@@ -11,17 +11,21 @@ import (
"github.com/miekg/dns"
)
// makeResponse creates a DNS response by req and sets necessary flags. It also
// guarantees that req.Question will be not empty.
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
resp = &dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionAvailable: true,
},
Compress: true,
}
// TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
// template. Also extract all the methods to a separate entity.
resp.SetReply(req)
// reply creates a DNS response for req.
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, code)
resp.RecursionAvailable = true
return resp
}
// replyCompressed creates a DNS response for req and sets the compress flag.
func (s *Server) replyCompressed(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeSuccess)
resp.Compress = true
return resp
}
@@ -48,10 +52,10 @@ func (s *Server) genDNSFilterMessage(
) (resp *dns.Msg) {
req := dctx.Req
qt := req.Question[0].Qtype
if qt != dns.TypeA && qt != dns.TypeAAAA {
if qt != dns.TypeA && qt != dns.TypeAAAA && qt != dns.TypeHTTPS {
m, _, _ := s.dnsFilter.BlockingMode()
if m == filtering.BlockingModeNullIP {
return s.makeResponse(req)
return s.replyCompressed(req)
}
return s.newMsgNODATA(req)
@@ -75,7 +79,7 @@ func (s *Server) genDNSFilterMessage(
// getCNAMEWithIPs generates a filtered response to req for with CNAME record
// and provided ips.
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
resp = s.makeResponse(req)
resp = s.replyCompressed(req)
originalName := req.Question[0].Name
@@ -121,13 +125,13 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
case filtering.BlockingModeNullIP:
return s.makeResponseNullIP(req)
case filtering.BlockingModeNXDOMAIN:
return s.genNXDomain(req)
return s.NewMsgNXDOMAIN(req)
case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req)
default:
log.Error("dnsforward: invalid blocking mode %q", mode)
return s.makeResponse(req)
return s.replyCompressed(req)
}
}
@@ -148,25 +152,18 @@ func (s *Server) makeResponseCustomIP(
// genDNSFilterMessage.
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
return s.makeResponse(req)
return s.replyCompressed(req)
}
}
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeServerFailure)
resp.RecursionAvailable = true
return &resp
}
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request)
resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
return resp
}
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request)
resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
return resp
}
@@ -252,7 +249,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
// Go on and return an empty response.
}
resp = s.makeResponse(req)
resp = s.replyCompressed(req)
resp.Answer = ans
return resp
@@ -288,7 +285,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
case dns.TypeAAAA:
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
default:
resp = s.makeResponse(req)
resp = s.replyCompressed(req)
}
return resp
@@ -298,7 +295,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if newAddr == "" {
log.Info("dnsforward: block host is not specified")
return s.genServerFailure(request)
return s.NewMsgSERVFAIL(request)
}
ip, err := netip.ParseAddr(newAddr)
@@ -321,17 +318,17 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if prx == nil {
log.Debug("dnsforward: %s", srvClosedErr)
return s.genServerFailure(request)
return s.NewMsgSERVFAIL(request)
}
err = prx.Resolve(newContext)
if err != nil {
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
return s.genServerFailure(request)
return s.NewMsgSERVFAIL(request)
}
resp := s.makeResponse(request)
resp := s.replyCompressed(request)
if newContext.Res != nil {
for _, answer := range newContext.Res.Answer {
answer.Header().Name = request.Question[0].Name
@@ -342,48 +339,21 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
return resp
}
// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (reply bool, err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return false, nil
}
pctx.Res = s.makeResponseREFUSED(pctx.Req)
// Return true so that dnsproxy responds with the REFUSED message.
return true, nil
}
// Create REFUSED DNS response
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeRefused)
resp.RecursionAvailable = true
return &resp
func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
return s.reply(req, dns.RcodeRefused)
}
// newMsgNODATA returns a properly initialized NODATA response.
//
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess)
resp.RecursionAvailable = true
resp = s.reply(req, dns.RcodeSuccess)
resp.Ns = s.genSOA(req)
return resp
}
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(request)
return &resp
}
func (s *Server) genSOA(request *dns.Msg) []dns.RR {
zone := ""
if len(request.Question) > 0 {
@@ -415,5 +385,43 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
if len(zone) > 0 && zone[0] != '.' {
soa.Mbox += zone
}
return []dns.RR{&soa}
}
// type check
var _ proxy.MessageConstructor = (*Server)(nil)
// NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNameError)
resp.Ns = s.genSOA(req)
return resp
}
// NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
return s.reply(req, dns.RcodeServerFailure)
}
// NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNotImplemented)
// Most of the Internet and especially the inner core has an MTU of at least
// 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet
// is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)).
//
// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
const maxUDPPayload = 1452
// NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so
// explicitly set it.
resp.SetEdns0(maxUDPPayload, false)
return resp
}

View File

@@ -1,20 +1,17 @@
package dnsforward
import (
"cmp"
"encoding/binary"
"net"
"net/netip"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
)
@@ -34,11 +31,6 @@ type dnsContext struct {
// response is modified by filters.
origResp *dns.Msg
// unreversedReqIP stores an IP address obtained from a PTR request if it
// was parsed successfully and belongs to one of the locally served IP
// ranges.
unreversedReqIP netip.Addr
// err is the error returned from a processing function.
err error
@@ -63,10 +55,6 @@ type dnsContext struct {
// responseAD shows if the response had the AD bit set.
responseAD bool
// isLocalClient shows if client's IP address is from locally served
// network.
isLocalClient bool
// isDHCPHost is true if the request for a local domain name and the DHCP is
// available for this request.
isDHCPHost bool
@@ -109,15 +97,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
// appropriate handler.
mods := []modProcessFunc{
s.processRecursion,
s.processInitial,
s.processDDRQuery,
s.processDetermineLocal,
s.processDHCPHosts,
s.processRestrictLocal,
s.processDHCPAddrs,
s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream,
s.processFilteringAfterResponse,
s.ipset.process,
@@ -145,24 +129,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
return nil
}
// processRecursion checks the incoming request and halts its handling by
// answering NXDOMAIN if s has tried to resolve it recently.
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing recursion")
defer log.Debug("dnsforward: finished processing recursion")
pctx := dctx.proxyCtx
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
// own DoH server.
//
@@ -199,14 +165,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
}
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
pctx.Res = s.genNXDomain(pctx.Req)
pctx.Res = s.NewMsgNXDOMAIN(pctx.Req)
return resultCodeFinish
}
if q.Name == healthcheckFQDN {
// Generate a NODATA negative response to make nslookup exit with 0.
pctx.Res = s.makeResponse(pctx.Req)
pctx.Res = s.replyCompressed(pctx.Req)
return resultCodeFinish
}
@@ -272,7 +238,7 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
//
// [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
resp = s.makeResponse(req)
resp = s.replyCompressed(req)
if req.Question[0].Qtype != dns.TypeSVCB {
return resp
}
@@ -339,19 +305,6 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
return resp
}
// processDetermineLocal determines if the client's IP address is from locally
// served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local detection")
defer log.Debug("dnsforward: finished processing local detection")
rc = resultCodeSuccess
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr())
return rc
}
// processDHCPHosts respond to A requests if the target hostname is known to
// the server. It responds with a mapped IP address if the DNS64 is enabled and
// the request is for AAAA.
@@ -370,9 +323,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
if !dctx.isLocalClient {
if !pctx.IsPrivateClient {
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
pctx.Res = s.genNXDomain(req)
pctx.Res = s.NewMsgNXDOMAIN(req)
// Do not even put into query log.
return resultCodeFinish
@@ -389,7 +342,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
resp := s.makeResponse(req)
resp := s.replyCompressed(req)
switch q.Qtype {
case dns.TypeA:
a := &dns.A{
@@ -416,141 +369,6 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
// indexFirstV4Label returns the index at which the reversed IPv4 address
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
// arpa labels removed.
func indexFirstV4Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
if parseErr != nil {
return idx
}
idx = curIdx
}
return idx
}
// indexFirstV6Label returns the index at which the reversed IPv6 address
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
// labels removed.
func indexFirstV6Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
curIdx := idx - len("a.")
if curIdx > 1 && domain[curIdx-1] != '.' {
return idx
}
nibble := domain[curIdx]
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
return idx
}
idx = curIdx
}
return idx
}
// extractARPASubnet tries to convert a reversed ARPA address being a part of
// domain to an IP network. domain must be an FQDN.
//
// TODO(e.burkov): Move to golibs.
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
if err != nil {
// Don't wrap the error since it's informative enough as is.
return netip.Prefix{}, err
}
const (
v4Suffix = "in-addr.arpa."
v6Suffix = "ip6.arpa."
)
domain = strings.ToLower(domain)
var idx int
switch {
case strings.HasSuffix(domain, v4Suffix):
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
case strings.HasSuffix(domain, v6Suffix):
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
default:
return netip.Prefix{}, &netutil.AddrError{
Err: netutil.ErrNotAReversedSubnet,
Kind: netutil.AddrKindARPA,
Addr: domain,
}
}
return netutil.PrefixFromReversedAddr(domain[idx:])
}
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
// in locally served network from external clients.
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local restriction")
defer log.Debug("dnsforward: finished processing local restriction")
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
if q.Qtype != dns.TypePTR {
// No need for restriction.
return resultCodeSuccess
}
subnet, err := extractARPASubnet(q.Name)
if err != nil {
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
log.Debug("dnsforward: request is not for arpa domain")
return resultCodeSuccess
}
log.Debug("dnsforward: parsing reversed addr: %s", err)
return resultCodeError
}
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally served or at least
// shouldn't be accessible externally.
subnetAddr := subnet.Addr()
if !s.privateNets.Contains(subnetAddr) {
return resultCodeSuccess
}
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
if !dctx.isLocalClient {
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
pctx.Res = s.genNXDomain(req)
// Do not even put into query log.
return resultCodeFinish
}
// Do not perform unreversing ever again.
dctx.unreversedReqIP = subnetAddr
// There is no need to filter request from external addresses since this
// code is only executed when the request is for locally served ARPA
// hostname so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
// Nothing to restrict.
return resultCodeSuccess
}
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server.
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
@@ -562,23 +380,27 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
ipAddr := dctx.unreversedReqIP
if ipAddr == (netip.Addr{}) {
req := pctx.Req
q := req.Question[0]
pref := pctx.RequestedPrivateRDNS
// TODO(e.burkov): Consider answering authoritatively for SOA and NS
// queries.
if pref == (netip.Prefix{}) || q.Qtype != dns.TypePTR {
return resultCodeSuccess
}
host := s.dhcpServer.HostByIP(ipAddr)
addr := pref.Addr()
host := s.dhcpServer.HostByIP(addr)
if host == "" {
return resultCodeSuccess
}
log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host)
log.Debug("dnsforward: dhcp client %s is %q", addr, host)
req := pctx.Req
resp := s.makeResponse(req)
resp := s.replyCompressed(req)
ptr := &dns.PTR{
Hdr: dns.RR_Header{
Name: req.Question[0].Name,
Name: q.Name,
Rrtype: dns.TypePTR,
// TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See
// https://github.com/AdguardTeam/AdGuardHome/issues/3932.
@@ -593,62 +415,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
// processLocalPTR responds to PTR requests if the target IP is detected to be
// inside the local network and the query was not answered from DHCP.
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local ptr")
defer log.Debug("dnsforward: finished processing local ptr")
pctx := dctx.proxyCtx
if pctx.Res != nil {
return resultCodeSuccess
}
ip := dctx.unreversedReqIP
if ip == (netip.Addr{}) {
return resultCodeSuccess
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if s.conf.UsePrivateRDNS {
s.recDetector.add(*pctx.Req)
if err := s.localResolvers.Resolve(pctx); err != nil {
log.Debug("dnsforward: resolving private address: %s", err)
// Generate the server failure if the private upstream configuration
// is empty.
//
// This is a crutch, see TODO at [Server.localResolvers].
if errors.Is(err, upstream.ErrNoUpstreams) {
pctx.Res = s.genServerFailure(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
dctx.err = err
return resultCodeError
}
}
if pctx.Res == nil {
pctx.Res = s.genNXDomain(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
return resultCodeSuccess
}
// Apply filtering logic
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req")
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
// There is no need to filter request for locally served ARPA hostname
// so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
}
if dctx.proxyCtx.Res != nil {
// Go on since the response is already set.
return resultCodeSuccess
@@ -695,7 +475,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
// local domain name if there is one.
name := req.Question[0].Name
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
pctx.Res = s.genNXDomain(req)
pctx.Res = s.NewMsgNXDOMAIN(req)
return resultCodeFinish
}
@@ -712,21 +492,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
return resultCodeError
}
if err := prx.Resolve(pctx); err != nil {
if errors.Is(err, upstream.ErrNoUpstreams) {
// Do not even put into querylog. Currently this happens either
// when the private resolvers enabled and the request is DNS64 PTR,
// or when the client isn't considered local by prx.
//
// TODO(e.burkov): Make proxy detect local client the same way as
// AGH does.
pctx.Res = s.genNXDomain(req)
return resultCodeFinish
}
dctx.err = err
if dctx.err = prx.Resolve(pctx); dctx.err != nil {
return resultCodeError
}
@@ -810,7 +576,7 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
}
// Use the ClientID first, since it has a higher priority.
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String())
id := cmp.Or(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
@@ -835,7 +601,8 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess
case
filtering.Rewritten,
filtering.RewrittenRule:
filtering.RewrittenRule,
filtering.FilteredSafeSearch:
if dctx.origQuestion.Name == "" {
// origQuestion is set in case we get only CNAME without IP from
@@ -845,11 +612,10 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
pctx := dctx.proxyCtx
pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
if len(pctx.Res.Answer) > 0 {
rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
answer := append([]dns.RR{rr}, pctx.Res.Answer...)
pctx.Res.Answer = answer
}
rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
answer := append([]dns.RR{rr}, pctx.Res.Answer...)
pctx.Res.Answer = answer
return resultCodeSuccess
default:

View File

@@ -1,14 +1,15 @@
package dnsforward
import (
"cmp"
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
@@ -70,8 +71,6 @@ func TestServer_ProcessInitial(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
@@ -171,8 +170,6 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
@@ -379,44 +376,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
return f
}
func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliAddr netip.AddrPort
}{{
want: assert.True,
name: "local",
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
}, {
want: assert.False,
name: "external",
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
}, {
want: assert.False,
name: "invalid",
cliAddr: netip.AddrPort{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
proxyCtx := &proxy.DNSContext{
Addr: tc.cliAddr,
}
dctx := &dnsContext{
proxyCtx: proxyCtx,
}
s.processDetermineLocal(dctx)
tc.want(t, dctx.isLocalClient)
})
}
}
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
const (
localDomainSuffix = "lan"
@@ -486,9 +445,9 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
Req: req,
IsPrivateClient: tc.isLocalCli,
},
isLocalClient: tc.isLocalCli,
}
res := s.processDHCPHosts(dctx)
@@ -621,9 +580,9 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
Req: req,
IsPrivateClient: true,
},
isLocalClient: true,
}
t.Run(tc.name, func(t *testing.T) {
@@ -658,19 +617,28 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
}
}
func TestServer_ProcessRestrictLocal(t *testing.T) {
// TODO(e.burkov): Rewrite this test to use the whole server instead of just
// testing the [handleDNSRequest] method. See comment on
// "from_external_for_local" test case.
func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
intAddr := netip.MustParseAddr("192.168.1.1")
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
require.NoError(t, err)
extAddr := netip.MustParseAddr("254.253.252.1")
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
require.NoError(t, err)
const (
extPTRQuestion = "251.252.253.254.in-addr.arpa."
extPTRAnswer = "host1.example.net."
intPTRQuestion = "1.1.168.192.in-addr.arpa."
intPTRAnswer = "some.local-client."
extPTRAnswer = "host1.example.net."
intPTRAnswer = "some.local-client."
)
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
(&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
)
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
@@ -696,123 +664,165 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
startDeferStop(t, s)
testCases := []struct {
name string
want string
question net.IP
cliAddr netip.AddrPort
wantLen int
name string
question string
wantErr error
wantAns []dns.RR
isPrivate bool
}{{
name: "from_local_to_external",
want: "host1.example.net.",
question: net.IP{254, 253, 252, 251},
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"),
wantLen: 1,
name: "from_local_for_external",
question: extPTRQuestion,
wantErr: nil,
wantAns: []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: true,
}, {
name: "from_external_for_local",
want: "",
question: net.IP{192, 168, 1, 1},
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"),
wantLen: 0,
// In theory this case is not reproducible because [proxy.Proxy] should
// respond to such queries with NXDOMAIN before they reach
// [Server.handleDNSRequest].
name: "from_external_for_local",
question: intPTRQuestion,
wantErr: upstream.ErrNoUpstreams,
wantAns: nil,
isPrivate: false,
}, {
name: "from_local_for_local",
want: "some.local-client.",
question: net.IP{192, 168, 1, 1},
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"),
wantLen: 1,
question: intPTRQuestion,
wantErr: nil,
wantAns: []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: dns.Fqdn(intPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(intPTRAnswer) + 1),
},
Ptr: dns.Fqdn(intPTRAnswer),
}},
isPrivate: true,
}, {
name: "from_external_for_external",
want: "host1.example.net.",
question: net.IP{254, 253, 252, 251},
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"),
wantLen: 1,
question: extPTRQuestion,
wantErr: nil,
wantAns: []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: false,
}}
for _, tc := range testCases {
reqAddr, err := dns.ReverseAddr(tc.question.String())
require.NoError(t, err)
req := createTestMessageWithType(reqAddr, dns.TypePTR)
pref, extErr := netutil.ExtractReversedAddr(tc.question)
require.NoError(t, extErr)
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
pctx := &proxy.DNSContext{
Proto: proxy.ProtoTCP,
Req: req,
Addr: tc.cliAddr,
Req: req,
IsPrivateClient: tc.isPrivate,
}
// TODO(e.burkov): Configure the subnet set properly.
if netutil.IsLocallyServed(pref.Addr()) {
pctx.RequestedPrivateRDNS = pref
}
t.Run(tc.name, func(t *testing.T) {
err = s.handleDNSRequest(nil, pctx)
require.NoError(t, err)
require.NotNil(t, pctx.Res)
require.Len(t, pctx.Res.Answer, tc.wantLen)
err = s.handleDNSRequest(s.dnsProxy, pctx)
require.ErrorIs(t, err, tc.wantErr)
if tc.wantLen > 0 {
assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr)
}
require.NotNil(t, pctx.Res)
assert.Equal(t, tc.wantAns, pctx.Res.Answer)
})
}
}
func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
func TestServer_ProcessUpstream_localPTR(t *testing.T) {
const locDomain = "some.local."
const reqAddr = "1.1.168.192.in-addr.arpa."
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
(&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
)
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
var proxyCtx *proxy.DNSContext
var dnsCtx *dnsContext
setup := func(use bool) {
proxyCtx = &proxy.DNSContext{
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
newPrxCtx := func() (prxCtx *proxy.DNSContext) {
return &proxy.DNSContext{
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
IsPrivateClient: true,
RequestedPrivateRDNS: netip.MustParsePrefix("192.168.1.1/32"),
}
dnsCtx = &dnsContext{
proxyCtx: proxyCtx,
unreversedReqIP: netip.MustParseAddr("192.168.1.1"),
}
s.conf.UsePrivateRDNS = use
}
t.Run("enabled", func(t *testing.T) {
setup(true)
s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx)
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeSuccess, rc)
require.NotEmpty(t, proxyCtx.Res.Answer)
require.NotEmpty(t, pctx.Res.Answer)
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr)
assert.Equal(t, locDomain, ptr.Ptr)
})
t.Run("disabled", func(t *testing.T) {
setup(false)
s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx)
require.Equal(t, resultCodeFinish, rc)
require.Empty(t, proxyCtx.Res.Answer)
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeError, rc)
require.Empty(t, pctx.Res.Answer)
})
}
@@ -830,129 +840,3 @@ func TestIPStringFromAddr(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil))
})
}
// TODO(e.burkov): Add fuzzing when moving to golibs.
func TestExtractARPASubnet(t *testing.T) {
const (
v4Suf = `in-addr.arpa.`
v4Part = `2.1.` + v4Suf
v4Whole = `4.3.` + v4Part
v6Suf = `ip6.arpa.`
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
)
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
testCases := []struct {
want netip.Prefix
name string
domain string
wantErr string
}{{
want: netip.Prefix{},
name: "not_an_arpa",
domain: "some.domain.name.",
wantErr: `bad arpa domain name "some.domain.name.": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "bad_domain_name",
domain: "abc.123.",
wantErr: `bad domain name "abc.123": ` +
`bad top-level domain name label "123": all octets are numeric`,
}, {
want: v4Pref,
name: "whole_v4",
domain: v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4",
domain: v4Part,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_within_domain",
domain: "a." + v4Whole,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_additional_label",
domain: "5." + v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4_within_domain",
domain: "a." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4",
domain: "256." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4_within_domain",
domain: "a.256." + v4Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v4",
domain: v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v4_within_domain",
domain: "a." + v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: v6Pref,
name: "whole_v6",
domain: v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6",
domain: v6Part,
}, {
want: v6Pref,
name: "whole_v6_within_domain",
domain: "g." + v6Whole,
wantErr: "",
}, {
want: v6Pref,
name: "whole_v6_additional_label",
domain: "1." + v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6_within_domain",
domain: "label." + v6Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v6",
domain: v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v6_within_domain",
domain: "g." + v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
subnet, err := extractARPASubnet(tc.domain)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, subnet)
})
}
}

View File

@@ -29,7 +29,13 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr)
ids := []string{ipStr, dctx.clientID}
ids := []string{ipStr}
if dctx.clientID != "" {
// Use the ClientID first because it has a higher priority. Filters
// have the same priority, see applyAdditionalFiltering.
ids = []string{dctx.clientID, ipStr}
}
qt, cl := q.Qtype, q.Qclass
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
@@ -124,7 +130,7 @@ func (s *Server) logQuery(dctx *dnsContext, ip net.IP, processingTime time.Durat
s.queryLog.Add(p)
}
// updatesStats writes the request into statistics.
// updateStats writes the request data into statistics.
func (s *Server) updateStats(dctx *dnsContext, clientIP string, processingTime time.Duration) {
pctx := dctx.proxyCtx

View File

@@ -2,90 +2,77 @@ package dnsforward
import (
"fmt"
"net/netip"
"os"
"slices"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
)
// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (s *Server) loadUpstreams() (upstreams []string, err error) {
if s.conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil
// newBootstrap returns a bootstrap resolver based on the configuration of s.
// boots are the upstream resolvers that should be closed after use. r is the
// actual bootstrap resolver, which may include the system hosts.
//
// TODO(e.burkov): This function currently returns a resolver and a slice of
// the upstream resolvers, which are essentially the same. boots are returned
// for being able to close them afterwards, but it introduces an implicit
// contract that r could only be used before that. Anyway, this code should
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
// and be used here.
func newBootstrap(
addrs []string,
etcHosts upstream.Resolver,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
}
var data []byte
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
boots, err = aghnet.ParseBootstraps(addrs, opts)
if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err)
// Don't wrap the error, since it's informative enough as is.
return nil, nil, err
}
upstreams = stringutil.SplitTrimmed(string(data), "\n")
var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, upstream.NewCachingResolver(b))
}
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)
if etcHosts != nil {
r = upstream.ConsequentResolver{etcHosts, parallel}
} else {
r = parallel
}
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
return r, boots, nil
}
// prepareUpstreamSettings sets upstream DNS server settings.
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
// Load upstreams either from the file, or from the settings
var upstreams []string
upstreams, err = s.loadUpstreams()
if err != nil {
return fmt.Errorf("loading upstreams: %w", err)
}
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
RootCAs: s.conf.TLSv12Roots,
CipherSuites: s.conf.TLSCiphers,
})
if err != nil {
return fmt.Errorf("preparing upstream config: %w", err)
}
return nil
}
// prepareUpstreamConfig returns the upstream configuration based on upstreams
// and configuration of s.
func (s *Server) prepareUpstreamConfig(
// newUpstreamConfig returns the upstream configuration based on upstreams. If
// upstreams slice specifies no default upstreams, defaultUpstreams are used to
// create upstreams with no domain specifications. opts are used when creating
// upstream configuration.
func newUpstreamConfig(
upstreams []string,
defaultUpstreams []string,
opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) {
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
if err != nil {
return nil, fmt.Errorf("parsing upstream config: %w", err)
return uc, fmt.Errorf("parsing upstreams: %w", err)
}
if len(uc.Upstreams) == 0 && defaultUpstreams != nil {
if len(uc.Upstreams) == 0 && len(defaultUpstreams) > 0 {
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
var defaultUpstreamConfig *proxy.UpstreamConfig
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
if err != nil {
return nil, fmt.Errorf("parsing default upstreams: %w", err)
return uc, fmt.Errorf("parsing default upstreams: %w", err)
}
uc.Upstreams = defaultUpstreamConfig.Upstreams
@@ -94,6 +81,54 @@ func (s *Server) prepareUpstreamConfig(
return uc, nil
}
// newPrivateConfig creates an upstream configuration for resolving PTR records
// for local addresses. The configuration is built either from the provided
// addresses or from the system resolvers. unwanted filters the resulting
// upstream configuration.
func newPrivateConfig(
addrs []string,
unwanted addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) {
confNeedsFiltering := len(addrs) > 0
if confNeedsFiltering {
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
addrs = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
addrs = append(addrs, r.String())
}
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", addrs)
uc, err = proxy.ParseUpstreamsConfig(addrs, opts)
if err != nil {
return uc, fmt.Errorf("preparing private upstreams: %w", err)
}
if !confNeedsFiltering {
return uc, nil
}
err = filterOutAddrs(uc, unwanted)
if err != nil {
return uc, fmt.Errorf("filtering private upstreams: %w", err)
}
// Prevalidate the config to catch the exact error before creating proxy.
// See TODO on [PrivateRDNSError].
err = proxy.ValidatePrivateConfig(uc, privateNets)
if err != nil {
return uc, &PrivateRDNSError{err: err}
}
return uc, nil
}
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
@@ -130,85 +165,9 @@ func setProxyUpstreamMode(
return nil
}
// createBootstrap returns a bootstrap resolver based on the configuration of s.
// boots are the upstream resolvers that should be closed after use. r is the
// actual bootstrap resolver, which may include the system hosts.
//
// TODO(e.burkov): This function currently returns a resolver and a slice of
// the upstream resolvers, which are essentially the same. boots are returned
// for being able to close them afterwards, but it introduces an implicit
// contract that r could only be used before that. Anyway, this code should
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
// and be used here.
func (s *Server) createBootstrap(
addrs []string,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
}
boots, err = aghnet.ParseBootstraps(addrs, opts)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return nil, nil, err
}
var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, upstream.NewCachingResolver(b))
}
if s.etcHosts != nil {
r = upstream.ConsequentResolver{s.etcHosts, parallel}
} else {
r = parallel
}
return r, boots, nil
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
if conf == nil {
return nil
}
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
var errs []error
for _, domain := range keys {
var subnet netip.Prefix
subnet, err = extractARPASubnet(domain)
if err != nil {
errs = append(errs, err)
continue
}
if !privateNets.Contains(subnet.Addr()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
)
}
}
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}

View File

@@ -559,6 +559,8 @@ type Result struct {
Reason Reason `json:",omitempty"`
// IsFiltered is true if the request is filtered.
//
// TODO(d.kolyshev): Get rid of this flag.
IsFiltered bool `json:",omitempty"`
}

View File

@@ -200,7 +200,7 @@ func TestParallelSB(t *testing.T) {
t.Cleanup(d.Close)
t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ {
for i := range 100 {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
t.Parallel()
d.checkMatch(t, sbBlocked, setts)
@@ -670,7 +670,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
}, nil)
b.Cleanup(d.Close)
for n := 0; n < b.N; n++ {
for range b.N {
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
require.NoError(b, err)

View File

@@ -63,8 +63,6 @@ func TestIDGenerator_Fix(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
g := newIDGenerator(1)
g.fix(tc.in)

View File

@@ -1,7 +1,6 @@
package rulelist_test
import (
"context"
"net/http"
"testing"
@@ -28,14 +27,12 @@ func TestEngine_Refresh(t *testing.T) {
require.NotNil(t, eng)
testutil.CleanupAndRequireSuccess(t, eng.Close)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
buf := make([]byte, rulelist.DefaultRuleBufSize)
cli := &http.Client{
Timeout: testTimeout,
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
require.NoError(t, err)

View File

@@ -1,7 +1,6 @@
package rulelist_test
import (
"context"
"net/http"
"net/url"
"os"
@@ -67,14 +66,12 @@ func TestFilter_Refresh(t *testing.T) {
require.NotNil(t, f)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
buf := make([]byte, rulelist.DefaultRuleBufSize)
cli := &http.Client{
Timeout: testTimeout,
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
res, err := f.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
require.NoError(t, err)

View File

@@ -132,7 +132,6 @@ func TestParser_Parse(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
@@ -216,7 +215,7 @@ func BenchmarkParser_Parse(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
for range b.N {
resSink, errSink = p.Parse(dst, src, buf)
dst.Reset()
}

View File

@@ -1,7 +1,5 @@
package filtering
import "github.com/miekg/dns"
// SafeSearch interface describes a service for search engines hosts rewrites.
type SafeSearch interface {
// CheckHost checks host with safe search filter. CheckHost must be safe
@@ -16,9 +14,6 @@ type SafeSearch interface {
// SafeSearchConfig is a struct with safe search related settings.
type SafeSearchConfig struct {
// CustomResolver is the resolver used by safe search.
CustomResolver Resolver `yaml:"-" json:"-"`
// Enabled indicates if safe search is enabled entirely.
Enabled bool `yaml:"enabled" json:"enabled"`
@@ -40,13 +35,7 @@ func (d *DNSFilter) checkSafeSearch(
qtype uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled ||
!setts.SafeSearchEnabled ||
(qtype != dns.TypeA && qtype != dns.TypeAAAA) {
return Result{}, nil
}
if d.safeSearch == nil {
if d.safeSearch == nil || !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
return Result{}, nil
}

View File

@@ -3,11 +3,9 @@ package safesearch
import (
"bytes"
"context"
"encoding/binary"
"encoding/gob"
"fmt"
"net"
"net/netip"
"strings"
"sync"
@@ -67,7 +65,6 @@ type Default struct {
engine *urlfilter.DNSEngine
cache cache.Cache
resolver filtering.Resolver
logPrefix string
cacheTTL time.Duration
}
@@ -80,11 +77,6 @@ func NewDefault(
cacheSize uint,
cacheTTL time.Duration,
) (ss *Default, err error) {
var resolver filtering.Resolver = net.DefaultResolver
if conf.CustomResolver != nil {
resolver = conf.CustomResolver
}
ss = &Default{
mu: &sync.RWMutex{},
@@ -92,7 +84,6 @@ func NewDefault(
EnableLRU: true,
MaxSize: cacheSize,
}),
resolver: resolver,
// Use %s, because the client safe-search names already contain double
// quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name),
@@ -170,8 +161,11 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
}()
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype))
switch qtype {
case dns.TypeA, dns.TypeAAAA, dns.TypeHTTPS:
// Go on.
default:
return filtering.Result{}, nil
}
// Check cache. Return cached result if it was found
@@ -195,6 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
}
res = *fltRes
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
// saving results to cache.
ss.setCacheResult(host, qtype, res)
return res, nil
@@ -223,20 +220,13 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe
}
// newResult creates Result object from rewrite rule. qtype must be either
// [dns.TypeA] or [dns.TypeAAAA]. If err is nil, res is never nil, so that the
// empty result is converted into a NODATA response.
//
// TODO(a.garipov): Use the main rewrite result mechanism used in
// [dnsforward.Server.filterDNSRequest]. Now we resolve IPs for CNAME to save
// them in the safe search cache.
// [dns.TypeA] or [dns.TypeAAAA], or [dns.TypeHTTPS]. If err is nil, res is
// never nil, so that the empty result is converted into a NODATA response.
func (ss *Default) newResult(
rewrite *rules.DNSRewrite,
qtype rules.RRType,
) (res *filtering.Result, err error) {
res = &filtering.Result{
Rules: []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
}},
Reason: filtering.FilteredSafeSearch,
IsFiltered: true,
}
@@ -247,69 +237,19 @@ func (ss *Default) newResult(
return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value)
}
res.Rules[0].IP = ip
res.Rules = []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
IP: ip,
}}
return res, nil
}
host := rewrite.NewCNAME
if host == "" {
return res, nil
}
res.CanonName = host
ss.log(log.DEBUG, "resolving %q", host)
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil {
return nil, fmt.Errorf("resolving cname: %w", err)
}
ss.log(log.DEBUG, "resolved %s", ips)
for _, ip := range ips {
// TODO(a.garipov): Remove this filtering once the resolver we use
// actually learns about network.
addr := fitToProto(ip, qtype)
if addr == (netip.Addr{}) {
continue
}
// TODO(e.burkov): Rules[0]?
res.Rules[0].IP = addr
}
res.CanonName = rewrite.NewCNAME
return res, nil
}
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
// It panics for other types.
func qtypeToProto(qtype rules.RRType) (proto string) {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
}
}
// fitToProto returns a non-nil IP address if ip is the correct protocol version
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
func fitToProto(ip net.IP, qtype rules.RRType) (res netip.Addr) {
if ip4 := ip.To4(); qtype == dns.TypeA {
if ip4 != nil {
return netip.AddrFrom4([4]byte(ip4))
}
} else if ip = ip.To16(); ip != nil && qtype == dns.TypeAAAA {
return netip.AddrFrom16([16]byte(ip))
}
return netip.Addr{}
}
// setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {

View File

@@ -1,13 +1,10 @@
package safesearch
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
@@ -79,47 +76,6 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)
wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.Equal(t, wantIP, res.Rules[0].IP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, wantIP, cachedValue.Rules[0].IP)
}
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
@@ -127,7 +83,7 @@ var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ {
for range b.N {
dnsRewriteSink = ss.searchHost(googleHost, testQType)
}

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
@@ -31,8 +30,6 @@ const (
// testConf is the default safe search configuration for tests.
var testConf = filtering.SafeSearchConfig{
CustomResolver: nil,
Enabled: true,
Bing: true,
@@ -52,61 +49,60 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// Check host for each domain.
for _, host := range []string{
hosts := []string{
"yandex.ru",
"yAndeX.ru",
"YANdex.COM",
"yandex.by",
"yandex.kz",
"www.yandex.com",
} {
var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
}
}
func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
testCases := []struct {
want netip.Addr
name string
qt uint16
}{{
want: yandexIP,
name: "a",
qt: dns.TypeA,
}, {
want: netip.Addr{},
name: "aaaa",
qt: dns.TypeAAAA,
}, {
want: netip.Addr{},
name: "https",
qt: dns.TypeHTTPS,
}}
res, err := ss.CheckHost("www.yandex.ru", dns.TypeAAAA)
require.NoError(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, host := range hosts {
// Check host for each domain.
var res filtering.Result
res, err = ss.CheckHost(host, tc.qt)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
assert.True(t, res.IsFiltered)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
// with a nil IP address. This isn't really necessary and should be changed
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
if tc.want == (netip.Addr{}) {
assert.Empty(t, res.Rules)
} else {
require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
rule := res.Rules[0]
assert.Equal(t, tc.want, rule.IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, rule.FilterListID)
}
}
})
}
}
func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
conf := testConf
conf.CustomResolver = resolver
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// Check host for each domain.
@@ -125,11 +121,9 @@ func TestDefault_CheckHost_google(t *testing.T) {
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, wantIP, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
assert.Equal(t, "forcesafesearch.google.com", res.CanonName)
assert.Empty(t, res.Rules)
})
}
}
@@ -154,17 +148,7 @@ func (r *testResolver) LookupIP(
}
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
conf := testConf
conf.CustomResolver = &testResolver{
OnLookupIP: func(_ context.Context, network, host string) (ips []net.IP, err error) {
assert.Equal(t, "ip6", network)
assert.Equal(t, "safe.duckduckgo.com", host)
return nil, nil
},
}
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
@@ -174,14 +158,9 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
require.NoError(t, err)
assert.True(t, res.IsFiltered)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
// with a nil IP address. This isn't really necessary and should be changed
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
assert.Equal(t, "safe.duckduckgo.com", res.CanonName)
assert.Empty(t, res.Rules)
}
func TestDefault_Update(t *testing.T) {

View File

@@ -24,7 +24,6 @@ import (
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
)
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
@@ -46,22 +45,20 @@ type DHCP interface {
// clientsContainer is the storage of all runtime and persistent clients.
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for different
// types (string, netip.Addr, and so on).
list map[string]*client.Persistent // name -> client
// clientIndex stores information about persistent clients.
clientIndex *client.Index
// ipToRC maps IP addresses to runtime client information.
ipToRC map[netip.Addr]*client.Runtime
// runtimeIndex stores information about runtime clients.
runtimeIndex *client.RuntimeIndex
allTags *container.MapSet[string]
// dhcp is the DHCP service implementation.
dhcp DHCP
// dnsServer is used for checking clients IP status access list status
dnsServer *dnsforward.Server
// clientChecker checks if a client is blocked by the current access
// settings.
clientChecker BlockedClientChecker
// etcHosts contains list of rewrite rules taken from the operating system's
// hosts database.
@@ -90,6 +87,12 @@ type clientsContainer struct {
testing bool
}
// BlockedClientChecker checks if a client is blocked by the current access
// settings.
type BlockedClientChecker interface {
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
}
// Init initializes clients container
// dhcpServer: optional
// Note: this function must be called only once
@@ -100,12 +103,12 @@ func (clients *clientsContainer) Init(
arpDB arpdb.Interface,
filteringConf *filtering.Config,
) (err error) {
if clients.list != nil {
log.Fatal("clients.list != nil")
// TODO(s.chzhen): Refactor it.
if clients.clientIndex != nil {
return errors.Error("clients container already initialized")
}
clients.list = map[string]*client.Persistent{}
clients.ipToRC = map[netip.Addr]*client.Runtime{}
clients.runtimeIndex = client.NewRuntimeIndex()
clients.clientIndex = client.NewIndex()
@@ -248,8 +251,6 @@ func (o *clientObject) toPersistent(
}
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err = cli.SetSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
@@ -285,9 +286,17 @@ func (clients *clientsContainer) addFromConfig(
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
}
_, err = clients.add(cli)
// TODO(s.chzhen): Consider moving to the client index constructor.
err = clients.clientIndex.ClashesUID(cli)
if err != nil {
log.Error("clients: adding client at index %d %s: %s", i, cli.Name, err)
return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err)
}
err = clients.add(cli)
if err != nil {
// TODO(s.chzhen): Return an error instead of logging if more
// stringent requirements are implemented.
log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err)
}
}
@@ -300,9 +309,9 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
clients.lock.Lock()
defer clients.lock.Unlock()
objs = make([]*clientObject, 0, len(clients.list))
for _, cli := range clients.list {
o := &clientObject{
objs = make([]*clientObject, 0, clients.clientIndex.Size())
clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) {
objs = append(objs, &clientObject{
Name: cli.Name,
BlockedServices: cli.BlockedServices.Clone(),
@@ -323,10 +332,10 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
IgnoreStatistics: cli.IgnoreStatistics,
UpstreamsCacheEnabled: cli.UpstreamsCacheEnabled,
UpstreamsCacheSize: cli.UpstreamsCacheSize,
}
})
objs = append(objs, o)
}
return true
})
// Maps aren't guaranteed to iterate in the same order each time, so the
// above loop can generate different orderings when writing to the config
@@ -363,8 +372,8 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
return client.SourcePersistent
}
rc, ok := clients.ipToRC[ip]
if ok {
rc := clients.runtimeIndex.Client(ip)
if rc != nil {
src, _ = rc.Info()
}
@@ -406,23 +415,26 @@ func (clients *clientsContainer) clientOrArtificial(
id string,
) (c *querylog.Client, art bool) {
defer func() {
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id)
c.Disallowed, c.DisallowedRule = clients.clientChecker.IsBlockedClient(ip, id)
if c.WHOIS == nil {
c.WHOIS = &whois.Info{}
}
}()
cli, ok := clients.find(id)
if ok {
if !ok {
cli = clients.clientIndex.FindByIPWithoutZone(ip)
}
if cli != nil {
return &querylog.Client{
Name: cli.Name,
IgnoreQueryLog: cli.IgnoreQueryLog,
}, false
}
var rc *client.Runtime
rc, ok = clients.findRuntimeClient(ip)
if ok {
rc := clients.findRuntimeClient(ip)
if rc != nil {
_, host := rc.Info()
return &querylog.Client{
@@ -542,47 +554,38 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
return nil, false
}
for _, c = range clients.list {
_, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr])
if found {
return c, true
}
}
return nil, false
return clients.clientIndex.FindByMAC(foundMAC)
}
// runtimeClient returns a runtime client from internal index. Note that it
// doesn't include DHCP clients.
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
if ip == (netip.Addr{}) {
return nil, false
return nil
}
clients.lock.Lock()
defer clients.lock.Unlock()
rc, ok = clients.ipToRC[ip]
return rc, ok
return clients.runtimeIndex.Client(ip)
}
// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
rc, ok = clients.runtimeClient(ip)
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
rc = clients.runtimeClient(ip)
host := clients.dhcp.HostByIP(ip)
if host != "" {
if !ok {
rc = &client.Runtime{}
if rc == nil {
rc = client.NewRuntime(ip)
}
rc.SetInfo(client.SourceDHCP, []string{host})
return rc, true
return rc
}
return rc, ok
return rc
}
// check validates the client. It also sorts the client tags.
@@ -615,43 +618,32 @@ func (clients *clientsContainer) check(c *client.Persistent) (err error) {
return nil
}
// add adds a new client object. ok is false if such client already exists or
// if an error occurred.
func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) {
// add adds a persistent client or returns an error.
func (clients *clientsContainer) add(c *client.Persistent) (err error) {
err = clients.check(c)
if err != nil {
return false, err
// Don't wrap the error since it's informative enough as is.
return err
}
clients.lock.Lock()
defer clients.lock.Unlock()
// check Name index
_, ok = clients.list[c.Name]
if ok {
return false, nil
}
// check ID index
err = clients.clientIndex.Clashes(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return false, err
return err
}
clients.addLocked(c)
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list))
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size())
return true, nil
return nil
}
// addLocked c to the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) addLocked(c *client.Persistent) {
// update Name index
clients.list[c.Name] = c
// update ID index
clients.clientIndex.Add(c)
}
@@ -660,8 +652,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
var c *client.Persistent
c, ok = clients.list[name]
c, ok := clients.clientIndex.FindByName(name)
if !ok {
return false
}
@@ -678,9 +669,6 @@ func (clients *clientsContainer) removeLocked(c *client.Persistent) {
log.Error("client container: removing client %s: %s", c.Name, err)
}
// Update the name index.
delete(clients.list, c.Name)
// Update the ID index.
clients.clientIndex.Delete(c)
}
@@ -696,22 +684,6 @@ func (clients *clientsContainer) update(prev, c *client.Persistent) (err error)
clients.lock.Lock()
defer clients.lock.Unlock()
// Check the name index.
if prev.Name != c.Name {
_, ok := clients.list[c.Name]
if ok {
return errors.Error("client already exists")
}
}
if c.EqualIDs(prev) {
clients.removeLocked(prev)
clients.addLocked(c)
return nil
}
// Check the ID index.
err = clients.clientIndex.Clashes(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
@@ -734,12 +706,12 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
return
}
rc, ok := clients.ipToRC[ip]
if !ok {
rc := clients.runtimeIndex.Client(ip)
if rc == nil {
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = &client.Runtime{}
clients.ipToRC[ip] = rc
rc = client.NewRuntime(ip)
clients.runtimeIndex.Add(rc)
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} else {
@@ -798,61 +770,54 @@ func (clients *clientsContainer) addHostLocked(
host string,
src client.Source,
) (ok bool) {
rc, ok := clients.ipToRC[ip]
if !ok {
rc := clients.runtimeIndex.Client(ip)
if rc == nil {
if src < client.SourceDHCP {
if clients.dhcp.HostByIP(ip) != "" {
return false
}
}
rc = &client.Runtime{}
clients.ipToRC[ip] = rc
rc = client.NewRuntime(ip)
clients.runtimeIndex.Add(rc)
}
rc.SetInfo(src, []string{host})
log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC))
log.Debug(
"clients: adding client info %s -> %q %q [%d]",
ip,
src,
host,
clients.runtimeIndex.Size(),
)
return true
}
// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
n := 0
for ip, rc := range clients.ipToRC {
rc.Unset(src)
if rc.IsEmpty() {
delete(clients.ipToRC, ip)
n++
}
}
log.Debug("clients: removed %d client aliases", n)
}
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(client.SourceHostsFile)
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
n := 0
added := 0
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical
// hostname for the IP address.
//
// TODO(e.burkov): Consider using all the names from all the records.
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
n++
added++
}
return true
})
log.Debug("clients: added %d client aliases from system hosts file", n)
log.Debug("clients: added %d client aliases from system hosts file", added)
}
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
@@ -876,7 +841,8 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(client.SourceARP)
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
added := 0
for _, n := range ns {
@@ -891,18 +857,5 @@ func (clients *clientsContainer) addFromSystemARP() {
// close gracefully closes all the client-specific upstream configurations of
// the persistent clients.
func (clients *clientsContainer) close() (err error) {
persistent := maps.Values(clients.list)
slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) {
return strings.Compare(a.Name, b.Name)
})
var errs []error
for _, cli := range persistent {
if err = cli.CloseUpstreams(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
return clients.clientIndex.CloseUpstreams()
}

View File

@@ -41,7 +41,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
}
dhcp := &testDHCP{
OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") },
OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
OnHostBy: func(ip netip.Addr) (host string) { return "" },
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
}
@@ -72,23 +72,19 @@ func TestClients(t *testing.T) {
IPs: []netip.Addr{cli1IP, cliIPv6},
}
ok, err := clients.add(c)
err := clients.add(c)
require.NoError(t, err)
assert.True(t, ok)
c = &client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{cli2IP},
}
ok, err = clients.add(c)
err = clients.add(c)
require.NoError(t, err)
assert.True(t, ok)
c, ok = clients.find(cli1)
c, ok := clients.find(cli1)
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
@@ -111,22 +107,20 @@ func TestClients(t *testing.T) {
})
t.Run("add_fail_name", func(t *testing.T) {
ok, err := clients.add(&client.Persistent{
err := clients.add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
})
require.NoError(t, err)
assert.False(t, ok)
require.Error(t, err)
})
t.Run("add_fail_ip", func(t *testing.T) {
ok, err := clients.add(&client.Persistent{
err := clients.add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
})
require.Error(t, err)
assert.False(t, ok)
})
t.Run("update_fail_ip", func(t *testing.T) {
@@ -145,12 +139,13 @@ func TestClients(t *testing.T) {
cliNewIP = netip.MustParseAddr(cliNew)
)
prev, ok := clients.list["client1"]
prev, ok := clients.clientIndex.FindByName("client1")
require.True(t, ok)
require.NotNil(t, prev)
err := clients.update(prev, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
UID: prev.UID,
IPs: []netip.Addr{cliNewIP},
})
require.NoError(t, err)
@@ -160,12 +155,13 @@ func TestClients(t *testing.T) {
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
prev, ok = clients.list["client1"]
prev, ok = clients.clientIndex.FindByName("client1")
require.True(t, ok)
require.NotNil(t, prev)
err = clients.update(prev, &client.Persistent{
Name: "client1-renamed",
UID: client.MustNewUID(),
UID: prev.UID,
IPs: []netip.Addr{cliNewIP},
UseOwnSettings: true,
})
@@ -177,7 +173,7 @@ func TestClients(t *testing.T) {
assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings)
nilCli, ok := clients.list["client1"]
nilCli, ok := clients.clientIndex.FindByName("client1")
require.False(t, ok)
assert.Nil(t, nilCli)
@@ -244,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
rc := clients.runtimeIndex.Client(ip)
require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS())
@@ -256,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
assert.True(t, ok)
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
rc := clients.runtimeIndex.Client(ip)
require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS())
@@ -265,16 +261,15 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("can't_set_manually-added", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.2")
ok, err := clients.add(&client.Persistent{
err := clients.add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
})
require.NoError(t, err)
assert.True(t, ok)
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
rc := clients.runtimeIndex.Client(ip)
require.Nil(t, rc)
assert.True(t, clients.remove("client1"))
@@ -288,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
ok, err := clients.add(&client.Persistent{
err := clients.add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
@@ -296,10 +291,9 @@ func TestClientsAddExisting(t *testing.T) {
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
})
require.NoError(t, err)
assert.True(t, ok)
// Now add an auto-client with the same IP.
ok = clients.addHost(ip, "test", client.SourceRDNS)
ok := clients.addHost(ip, "test", client.SourceRDNS)
assert.True(t, ok)
})
@@ -339,22 +333,20 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
ok, err := clients.add(&client.Persistent{
err = clients.add(&client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip},
})
require.NoError(t, err)
assert.True(t, ok)
// Add a new client with the IP from the first client's IP range.
ok, err = clients.add(&client.Persistent{
err = clients.add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
})
require.NoError(t, err)
assert.True(t, ok)
})
}
@@ -362,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
// Add client with upstreams.
ok, err := clients.add(&client.Persistent{
err := clients.add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
@@ -372,7 +364,6 @@ func TestClientsCustomUpstream(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, ok)
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
assert.Nil(t, upsConf)

View File

@@ -96,22 +96,26 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
clients.lock.Lock()
defer clients.lock.Unlock()
for _, c := range clients.list {
clients.clientIndex.Range(func(c *client.Persistent) (cont bool) {
cj := clientToJSON(c)
data.Clients = append(data.Clients, cj)
}
for ip, rc := range clients.ipToRC {
return true
})
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
src, host := rc.Info()
cj := runtimeClientJSON{
WHOIS: whoisOrEmpty(rc),
Name: host,
Source: src,
IP: ip,
IP: rc.Addr(),
}
data.RuntimeClients = append(data.RuntimeClients, cj)
}
return true
})
for _, l := range clients.dhcp.Leases() {
cj := runtimeClientJSON{
@@ -332,20 +336,16 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return
}
ok, err := clients.add(c)
err = clients.add(c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
if !ok {
aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists")
return
if !clients.testing {
onConfigModified()
}
onConfigModified()
}
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
@@ -370,7 +370,9 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
return
}
onConfigModified()
if !clients.testing {
onConfigModified()
}
}
// updateJSON contains the name and data of the updated persistent client.
@@ -404,7 +406,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
clients.lock.Lock()
defer clients.lock.Unlock()
prev, ok = clients.list[dj.Name]
prev, ok = clients.clientIndex.FindByName(dj.Name)
}()
if !ok {
@@ -427,14 +429,16 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
onConfigModified()
if !clients.testing {
onConfigModified()
}
}
// handleFindClient is the handler for GET /control/clients/find HTTP API.
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
data := []map[string]*clientJSON{}
for i := 0; i < len(q); i++ {
for i := range len(q) {
idStr := q.Get(fmt.Sprintf("ip%d", i))
if idStr == "" {
break
@@ -447,7 +451,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
cj = clients.findRuntime(ip, idStr)
} else {
cj = clientToJSON(c)
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
}
@@ -463,14 +467,14 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// non-nil.
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
rc, ok := clients.findRuntimeClient(ip)
if !ok {
rc := clients.findRuntimeClient(ip)
if rc == nil {
// It is still possible that the IP used to be in the runtime clients
// list, but then the server was reloaded. So, check the DNS server's
// blocked IP list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj = &clientJSON{
IDs: []string{idStr},
Disallowed: &disallowed,
@@ -488,7 +492,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
WHOIS: whoisOrEmpty(rc),
}
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
return cj

View File

@@ -0,0 +1,399 @@
package home
import (
"bytes"
"cmp"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"slices"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
testClientIP1 = "1.1.1.1"
testClientIP2 = "2.2.2.2"
)
// testBlockedClientChecker is a mock implementation of the
// [BlockedClientChecker] interface.
type testBlockedClientChecker struct {
onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string)
}
// type check
var _ BlockedClientChecker = (*testBlockedClientChecker)(nil)
// IsBlockedClient implements the [BlockedClientChecker] interface for
// *testBlockedClientChecker.
func (c *testBlockedClientChecker) IsBlockedClient(
ip netip.Addr,
clientID string,
) (blocked bool, rule string) {
return c.onIsBlockedClient(ip, clientID)
}
// newPersistentClient is a helper function that returns a persistent client
// with the specified name and newly generated UID.
func newPersistentClient(name string) (c *client.Persistent) {
return &client.Persistent{
Name: name,
UID: client.MustNewUID(),
BlockedServices: &filtering.BlockedServices{
Schedule: &schedule.Weekly{},
},
}
}
// newPersistentClientWithIDs is a helper function that returns a persistent
// client with the specified name and ids.
func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *client.Persistent) {
tb.Helper()
c = newPersistentClient(name)
err := c.SetIDs(ids)
require.NoError(tb, err)
return c
}
// assertClients is a helper function that compares lists of persistent clients.
func assertClients(tb testing.TB, want, got []*client.Persistent) {
tb.Helper()
require.Len(tb, got, len(want))
sortFunc := func(a, b *client.Persistent) (n int) {
return cmp.Compare(a.Name, b.Name)
}
slices.SortFunc(want, sortFunc)
slices.SortFunc(got, sortFunc)
slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) {
assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name)
return 0
})
}
// assertPersistentClients is a helper function that uses HTTP API to check
// whether want persistent clients are the same as the persistent clients stored
// in the clients container.
func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) {
tb.Helper()
rw := httptest.NewRecorder()
clients.handleGetClients(rw, &http.Request{})
body, err := io.ReadAll(rw.Body)
require.NoError(tb, err)
clientList := &clientListJSON{}
err = json.Unmarshal(body, clientList)
require.NoError(tb, err)
var got []*client.Persistent
for _, cj := range clientList.Clients {
var c *client.Persistent
c, err = clients.jsonToClient(*cj, nil)
require.NoError(tb, err)
got = append(got, c)
}
assertClients(tb, want, got)
}
// assertPersistentClientsData is a helper function that checks whether want
// persistent clients are the same as the persistent clients stored in data.
func assertPersistentClientsData(
tb testing.TB,
clients *clientsContainer,
data []map[string]*clientJSON,
want []*client.Persistent,
) {
tb.Helper()
var got []*client.Persistent
for _, cm := range data {
for _, cj := range cm {
var c *client.Persistent
c, err := clients.jsonToClient(*cj, nil)
require.NoError(tb, err)
got = append(got, c)
}
}
assertClients(tb, want, got)
}
func TestClientsContainer_HandleAddClient(t *testing.T) {
clients := newClientsContainer(t)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
clientEmptyID := newPersistentClient("empty_client_id")
clientEmptyID.ClientIDs = []string{""}
testCases := []struct {
name string
client *client.Persistent
wantCode int
wantClient []*client.Persistent
}{{
name: "add_one",
client: clientOne,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientOne},
}, {
name: "add_two",
client: clientTwo,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientOne, clientTwo},
}, {
name: "duplicate_client",
client: clientTwo,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientOne, clientTwo},
}, {
name: "empty_client_id",
client: clientEmptyID,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientOne, clientTwo},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cj := clientToJSON(tc.client)
body, err := json.Marshal(cj)
require.NoError(t, err)
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
require.NoError(t, err)
rw := httptest.NewRecorder()
clients.handleAddClient(rw, r)
require.NoError(t, err)
require.Equal(t, tc.wantCode, rw.Code)
assertPersistentClients(t, clients, tc.wantClient)
})
}
}
func TestClientsContainer_HandleDelClient(t *testing.T) {
clients := newClientsContainer(t)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
require.NoError(t, err)
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
err = clients.add(clientTwo)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
testCases := []struct {
name string
client *client.Persistent
wantCode int
wantClient []*client.Persistent
}{{
name: "remove_one",
client: clientOne,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientTwo},
}, {
name: "duplicate_client",
client: clientOne,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientTwo},
}, {
name: "empty_client_name",
client: newPersistentClient(""),
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientTwo},
}, {
name: "remove_two",
client: clientTwo,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cj := clientToJSON(tc.client)
var body []byte
body, err = json.Marshal(cj)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
require.NoError(t, err)
rw := httptest.NewRecorder()
clients.handleDelClient(rw, r)
require.NoError(t, err)
require.Equal(t, tc.wantCode, rw.Code)
assertPersistentClients(t, clients, tc.wantClient)
})
}
}
func TestClientsContainer_HandleUpdateClient(t *testing.T) {
clients := newClientsContainer(t)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne})
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
clientEmptyID := newPersistentClient("empty_client_id")
clientEmptyID.ClientIDs = []string{""}
testCases := []struct {
name string
clientName string
modified *client.Persistent
wantCode int
wantClient []*client.Persistent
}{{
name: "update_one",
clientName: clientOne.Name,
modified: clientModified,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientModified},
}, {
name: "empty_name",
clientName: "",
modified: clientOne,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientModified},
}, {
name: "client_not_found",
clientName: "client_not_found",
modified: clientOne,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientModified},
}, {
name: "empty_client_id",
clientName: clientModified.Name,
modified: clientEmptyID,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientModified},
}, {
name: "no_ids",
clientName: clientModified.Name,
modified: newPersistentClient("no_ids"),
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientModified},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
uj := updateJSON{
Name: tc.clientName,
Data: *clientToJSON(tc.modified),
}
var body []byte
body, err = json.Marshal(uj)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
require.NoError(t, err)
rw := httptest.NewRecorder()
clients.handleUpdateClient(rw, r)
require.NoError(t, err)
require.Equal(t, tc.wantCode, rw.Code)
assertPersistentClients(t, clients, tc.wantClient)
})
}
}
func TestClientsContainer_HandleFindClient(t *testing.T) {
clients := newClientsContainer(t)
clients.clientChecker = &testBlockedClientChecker{
onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) {
return false, ""
},
}
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
require.NoError(t, err)
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
err = clients.add(clientTwo)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
testCases := []struct {
name string
query url.Values
wantCode int
wantClient []*client.Persistent
}{{
name: "single",
query: url.Values{
"ip0": []string{testClientIP1},
},
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientOne},
}, {
name: "multiple",
query: url.Values{
"ip0": []string{testClientIP1},
"ip1": []string{testClientIP2},
},
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientOne, clientTwo},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var r *http.Request
r, err = http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
r.URL.RawQuery = tc.query.Encode()
rw := httptest.NewRecorder()
clients.handleFindClient(rw, r)
require.NoError(t, err)
require.Equal(t, tc.wantCode, rw.Code)
var body []byte
body, err = io.ReadAll(rw.Body)
require.NoError(t, err)
clientData := []map[string]*clientJSON{}
err = json.Unmarshal(body, &clientData)
require.NoError(t, err)
assertPersistentClientsData(t, clients, clientData, tc.wantClient)
})
}
}

View File

@@ -203,15 +203,24 @@ type dnsConfig struct {
// resolver should be used.
PrivateNets []netutil.Prefix `yaml:"private_networks"`
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
// locally-served networks should be resolved via private PTR resolvers.
// UsePrivateRDNS enables resolving requests containing a private IP address
// using private reverse DNS resolvers. See PrivateRDNSResolvers.
//
// TODO(e.burkov): Rename in YAML.
UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"`
// LocalPTRResolvers is the slice of addresses to be used as upstreams
// for PTR queries for locally-served networks.
LocalPTRResolvers []string `yaml:"local_ptr_upstreams"`
// PrivateRDNSResolvers is the slice of addresses to be used as upstreams
// for private requests. It's only used for PTR, SOA, and NS queries,
// containing an ARPA subdomain, came from the the client with private
// address. The address considered private according to PrivateNets.
//
// If empty, the OS-provided resolvers are used for private requests.
PrivateRDNSResolvers []string `yaml:"local_ptr_upstreams"`
// UseDNS64 defines if DNS64 should be used for incoming requests.
// UseDNS64 defines if DNS64 should be used for incoming requests. Requests
// of type PTR for addresses within the configured prefixes will be resolved
// via [PrivateRDNSResolvers], so those should be valid and UsePrivateRDNS
// be set to true.
UseDNS64 bool `yaml:"use_dns64"`
// DNS64Prefixes is the list of NAT64 prefixes to be used for DNS64.
@@ -658,7 +667,7 @@ func (c *configuration) write() (err error) {
dns := &config.DNS
dns.Config = c
dns.LocalPTRResolvers = s.LocalPTRResolvers()
dns.PrivateRDNSResolvers = s.LocalPTRResolvers()
addrProcConf := s.AddrProcConfig()
config.Clients.Sources.RDNS = addrProcConf.UseRDNS

View File

@@ -1,7 +1,6 @@
package home
import (
"context"
"fmt"
"net"
"net/netip"
@@ -18,7 +17,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@@ -150,21 +148,19 @@ func initDNSServer(
return fmt.Errorf("dnsforward.NewServer: %w", err)
}
Context.clients.dnsServer = Context.dnsServer
Context.clients.clientChecker = Context.dnsServer
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
if err != nil {
return fmt.Errorf("newServerConfig: %w", err)
}
// Try to prepare the server with disabled private RDNS resolution if it
// failed to prepare as is. See TODO on [ErrBadPrivateRDNSUpstreams].
err = Context.dnsServer.Prepare(dnsConf)
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
// should go away once the private RDNS resolution is moved to the proxy.
var locResErr *dnsforward.LocalResolversError
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
log.Info("WARNING: no local resolvers configured while private RDNS " +
"resolution enabled, trying to disable")
dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf)
}
@@ -245,7 +241,7 @@ func newServerConfig(
TLSv12Roots: Context.tlsRoots,
ConfigModified: onConfigModified,
HTTPRegister: httpReg,
LocalPTRResolvers: dnsConf.LocalPTRResolvers,
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
UseDNS64: dnsConf.UseDNS64,
DNS64Prefixes: dnsConf.DNS64Prefixes,
UsePrivateRDNS: dnsConf.UsePrivateRDNS,
@@ -531,36 +527,6 @@ func closeDNSServer() {
log.Debug("all dns modules are closed")
}
// safeSearchResolver is a [filtering.Resolver] implementation used for safe
// search.
type safeSearchResolver struct{}
// type check
var _ filtering.Resolver = safeSearchResolver{}
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
// It returns the slice of net.Addr with IPv4 and IPv6 instances.
func (r safeSearchResolver) LookupIP(
ctx context.Context,
network string,
host string,
) (ips []net.IP, err error) {
addrs, err := Context.dnsServer.Resolve(ctx, network, host)
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %s", host)
}
for _, a := range addrs {
ips = append(ips, a.AsSlice())
}
return ips, nil
}
// checkStatsAndQuerylogDirs checks and returns directory paths to store
// statistics and query log.
func checkStatsAndQuerylogDirs(

View File

@@ -439,7 +439,6 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
conf.ParentalBlockHost = host
}
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
conf.SafeSearch, err = safesearch.NewDefault(
conf.SafeSearchConf,
"default",

View File

@@ -1,13 +1,13 @@
package home
import (
"cmp"
"fmt"
"path/filepath"
"runtime"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/yaml.v3"
)
@@ -76,8 +76,7 @@ func getLogSettings(opts options) (ls *logSettings) {
ls.Verbose = true
}
// TODO(a.garipov): Use cmp.Or in Go 1.22.
ls.File = stringutil.Coalesce(opts.logFile, ls.File)
ls.File = cmp.Or(opts.logFile, ls.File)
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if

View File

@@ -306,7 +306,7 @@ func handleServiceStatusCommand(s service.Service) {
}
}
// handleServiceStatusCommand handles service "install" command
// handleServiceInstallCommand handles service "install" command.
func handleServiceInstallCommand(s service.Service) {
err := svcAction(s, "install")
if err != nil {
@@ -340,7 +340,7 @@ AdGuard Home is now available at the following addresses:`)
}
}
// handleServiceStatusCommand handles service "uninstall" command
// handleServiceUninstallCommand handles service "uninstall" command.
func handleServiceUninstallCommand(s service.Service) {
if aghos.IsOpenWrt() {
// On OpenWrt it is important to run disable command first
@@ -649,11 +649,6 @@ status() {
// freeBSDScript is the source of the daemon script for FreeBSD. Keep as close
// as possible to the https://github.com/kardianos/service/blob/18c957a3dc1120a2efe77beb401d476bade9e577/service_freebsd.go#L204.
//
// TODO(a.garipov): Don't use .WorkingDirectory here. There are currently no
// guarantees that it will actually be the required directory.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2614.
const freeBSDScript = `#!/bin/sh
# PROVIDE: {{.Name}}
# REQUIRE: networking
@@ -667,7 +662,9 @@ name="{{.Name}}"
pidfile_child="/var/run/${name}.pid"
pidfile="/var/run/${name}_daemon.pid"
command="/usr/sbin/daemon"
command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirectory}}/{{.Name}}"
daemon_args="-P ${pidfile} -p ${pidfile_child} -r -t ${name}"
command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}"
run_rc_command "$1"
`

View File

@@ -3,6 +3,7 @@
package home
import (
"cmp"
"fmt"
"os"
"os/signal"
@@ -14,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/kardianos/service"
)
@@ -76,7 +76,7 @@ func (*openbsdRunComService) Platform() (p string) {
// String implements service.Service interface for *openbsdRunComService.
func (s *openbsdRunComService) String() string {
return stringutil.Coalesce(s.cfg.DisplayName, s.cfg.Name)
return cmp.Or(s.cfg.DisplayName, s.cfg.Name)
}
// getBool returns the value of the given name from kv, assuming the value is a

View File

@@ -147,7 +147,7 @@ func BenchmarkManager_LookupHost(b *testing.B) {
b.Run("long", func(b *testing.B) {
const name = "a.very.long.domain.name.inside.the.domain.example.com"
for i := 0; i < b.N; i++ {
for range b.N {
ipsetPropsSink = m.lookupHost(name)
}
@@ -156,7 +156,7 @@ func BenchmarkManager_LookupHost(b *testing.B) {
b.Run("short", func(b *testing.B) {
const name = "example.net"
for i := 0; i < b.N; i++ {
for range b.N {
ipsetPropsSink = m.lookupHost(name)
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil"
"github.com/google/renameio/v2/maybe"
)
@@ -38,7 +39,7 @@ func (h *signalHandler) handle() {
if aghos.IsReconfigureSignal(sig) {
h.reconfigure()
} else if aghos.IsShutdownSignal(sig) {
} else if osutil.IsShutdownSignal(sig) {
status := h.shutdown()
h.removePID()
@@ -122,7 +123,8 @@ func newSignalHandler(
services: svcs,
}
aghos.NotifyShutdownSignal(h.signal)
notifier := osutil.DefaultSignalNotifier{}
osutil.NotifyShutdownSignal(notifier, h.signal)
aghos.NotifyReconfigureSignal(h.signal)
return h

View File

@@ -1,7 +1,6 @@
package dnssvc_test
import (
"context"
"net/netip"
"testing"
"time"
@@ -94,10 +93,8 @@ func TestService(t *testing.T) {
}},
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
cli := &dns.Client{}
ctx := testutil.ContextWithTimeout(t, testTimeout)
var resp *dns.Msg
require.Eventually(t, func() (ok bool) {
@@ -110,10 +107,8 @@ func TestService(t *testing.T) {
assert.NotNil(t, resp)
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
err = svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
err = svc.Shutdown(ctx)
require.NoError(t, err)
err = upstreamSrv.Shutdown()

View File

@@ -109,12 +109,8 @@ func newTestServer(
err = svc.Start()
require.NoError(t, err)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
err = svc.Shutdown(ctx)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
})
c = svc.Config()

View File

@@ -303,7 +303,7 @@ func BenchmarkAnonymizeIP(b *testing.B) {
b.Run(bc.name, func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for range b.N {
AnonymizeIP(bc.ip)
}
@@ -313,7 +313,7 @@ func BenchmarkAnonymizeIP(b *testing.B) {
b.Run(bc.name+"_slow", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for range b.N {
anonymizeIPSlow(bc.ip)
}

View File

@@ -31,6 +31,7 @@ type logEntry struct {
Answer []byte `json:",omitempty"`
OrigAnswer []byte `json:",omitempty"`
// TODO(s.chzhen): Use netip.Addr.
IP net.IP `json:"IP"`
Result filtering.Result

View File

@@ -143,13 +143,13 @@ func TestQueryLogOffsetLimit(t *testing.T) {
secondPageDomain = "second.example.org"
)
// Add entries to the log.
for i := 0; i < entNum; i++ {
for range entNum {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// Write them to the first file.
require.NoError(t, l.flushLogBuffer())
// Add more to the in-memory part of log.
for i := 0; i < entNum; i++ {
for range entNum {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
@@ -215,7 +215,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
const entNum = 10
// Add entries to the log.
for i := 0; i < entNum; i++ {
for range entNum {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// Write them to disk.

View File

@@ -37,7 +37,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) {
var lineIP uint32
lineTime := time.Date(2020, 2, 18, 19, 36, 35, 920973000, time.UTC)
for i := 0; i < linesNum; i++ {
for range linesNum {
lineIP++
lineTime = lineTime.Add(time.Second)

View File

@@ -68,13 +68,13 @@ func TestStats_races(t *testing.T) {
startWG, finWG := &sync.WaitGroup{}, &sync.WaitGroup{}
waitCh := make(chan unit)
for i := 0; i < writersNum; i++ {
for i := range writersNum {
startWG.Add(1)
finWG.Add(1)
go writeFunc(startWG, finWG, waitCh, i)
}
for i := 0; i < readersNum; i++ {
for range readersNum {
startWG.Add(1)
finWG.Add(1)
go readFunc(startWG, finWG, waitCh)
@@ -111,7 +111,7 @@ func TestStatsCtx_FillCollectedStats_daily(t *testing.T) {
dailyData := []*unitDB{}
for i := 0; i < daysCount*24; i++ {
for i := range daysCount * 24 {
n := uint64(i)
nResult := make([]uint64, resultLast)
nResult[RFiltered] = n

View File

@@ -195,7 +195,7 @@ func TestLargeNumbers(t *testing.T) {
for h := 0; h < hoursNum; h++ {
atomic.AddUint32(&curHour, 1)
for i := 0; i < cliNumPerHour; i++ {
for i := range cliNumPerHour {
ip := net.IP{127, 0, byte((i & 0xff00) >> 8), byte(i & 0xff)}
e := &stats.Entry{
Domain: fmt.Sprintf("domain%d.hour%d", i, h),

View File

@@ -525,9 +525,8 @@ func (s *StatsCtx) fillCollectedStatsDaily(
hours := countHours(curHour, days)
units = units[len(units)-hours:]
for i := 0; i < len(units); i++ {
for i, u := range units {
day := i / 24
u := units[i]
data.DNSQueries[day] += u.NTotal
data.BlockedFiltering[day] += u.NResult[RFiltered]

View File

@@ -1,6 +1,6 @@
module github.com/AdguardTeam/AdGuardHome/internal/tools
go 1.22.2
go 1.22.3
require (
github.com/fzipp/gocyclo v0.6.0

View File

@@ -3,6 +3,7 @@ package whois
import (
"bytes"
"cmp"
"context"
"fmt"
"io"
@@ -17,7 +18,6 @@ import (
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/bluele/gcache"
)
@@ -174,7 +174,7 @@ func whoisParse(data []byte, maxLen int) (info map[string]string) {
val = trimValue(val, maxLen)
case "descr", "netname":
key = "orgname"
val = stringutil.Coalesce(orgname, val)
val = cmp.Or(orgname, val)
orgname = val
case "whois":
key = "whois"
@@ -232,7 +232,7 @@ func (w *Default) queryAll(ctx context.Context, target string) (info map[string]
server := net.JoinHostPort(w.serverAddr, w.portStr)
var data []byte
for i := 0; i < w.maxRedirects; i++ {
for range w.maxRedirects {
data, err = w.query(ctx, target, server)
if err != nil {
// Don't wrap the error since it's informative enough as is.