all: sync with master
This commit is contained in:
@@ -10,29 +10,8 @@ import (
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// Coalesce returns the first non-zero value. It is named after function
|
||||
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
|
||||
// value.
|
||||
//
|
||||
// T is comparable, because Go currently doesn't have a comparableWithZeroValue
|
||||
// constraint.
|
||||
//
|
||||
// TODO(a.garipov): Think of ways to merge with [CoalesceSlice].
|
||||
func Coalesce[T comparable](values ...T) (res T) {
|
||||
var zero T
|
||||
for _, v := range values {
|
||||
if v != zero {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
// CoalesceSlice returns the first non-zero value. It is named after function
|
||||
// COALESCE in SQL. If values or all its elements are empty, it returns nil.
|
||||
//
|
||||
// TODO(a.garipov): Think of ways to merge with [Coalesce].
|
||||
func CoalesceSlice[E any, S []E](values ...S) (res S) {
|
||||
for _, v := range values {
|
||||
if v != nil {
|
||||
|
||||
@@ -33,7 +33,7 @@ func elements(b *aghalg.RingBuffer[int], n uint, reverse bool) (es []int) {
|
||||
func TestNewRingBuffer(t *testing.T) {
|
||||
t.Run("success_and_clear", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](5)
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
b.Append(i)
|
||||
}
|
||||
assert.Equal(t, []int{5, 6, 7, 8, 9}, elements(b, b.Len(), false))
|
||||
@@ -44,7 +44,7 @@ func TestNewRingBuffer(t *testing.T) {
|
||||
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](0)
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
b.Append(i)
|
||||
bufLen := b.Len()
|
||||
assert.EqualValues(t, 0, bufLen)
|
||||
@@ -55,7 +55,7 @@ func TestNewRingBuffer(t *testing.T) {
|
||||
|
||||
t.Run("single", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](1)
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
b.Append(i)
|
||||
bufLen := b.Len()
|
||||
assert.EqualValues(t, 1, bufLen)
|
||||
@@ -94,7 +94,7 @@ func TestRingBuffer_Range(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for i := 0; i < tc.count; i++ {
|
||||
for i := range tc.count {
|
||||
b.Append(i)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ func TestNewSortedMap(t *testing.T) {
|
||||
var m SortedMap[string, int]
|
||||
|
||||
letters := []string{}
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
r := string('a' + rune(i))
|
||||
letters = append(letters, r)
|
||||
}
|
||||
|
||||
@@ -97,6 +97,8 @@ func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
|
||||
var filename string
|
||||
defer func() { err = errors.Annotate(err, "checking %q: %w", filename) }()
|
||||
|
||||
// TODO(e.burkov): Redo this loop, as it modifies the very same slice it
|
||||
// iterates over.
|
||||
for i := 0; i < len(src); i++ {
|
||||
var patterns []string
|
||||
var cont bool
|
||||
|
||||
@@ -159,21 +159,11 @@ func NotifyReconfigureSignal(c chan<- os.Signal) {
|
||||
notifyReconfigureSignal(c)
|
||||
}
|
||||
|
||||
// NotifyShutdownSignal notifies c on receiving shutdown signals.
|
||||
func NotifyShutdownSignal(c chan<- os.Signal) {
|
||||
notifyShutdownSignal(c)
|
||||
}
|
||||
|
||||
// IsReconfigureSignal returns true if sig is a reconfigure signal.
|
||||
func IsReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return isReconfigureSignal(sig)
|
||||
}
|
||||
|
||||
// IsShutdownSignal returns true if sig is a shutdown signal.
|
||||
func IsShutdownSignal(sig os.Signal) (ok bool) {
|
||||
return isShutdownSignal(sig)
|
||||
}
|
||||
|
||||
// SendShutdownSignal sends the shutdown signal to the channel.
|
||||
func SendShutdownSignal(c chan<- os.Signal) {
|
||||
sendShutdownSignal(c)
|
||||
|
||||
@@ -13,26 +13,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, unix.SIGHUP)
|
||||
}
|
||||
|
||||
func notifyShutdownSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
|
||||
}
|
||||
|
||||
func isReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return sig == unix.SIGHUP
|
||||
}
|
||||
|
||||
func isShutdownSignal(sig os.Signal) (ok bool) {
|
||||
switch sig {
|
||||
case
|
||||
unix.SIGINT,
|
||||
unix.SIGQUIT,
|
||||
unix.SIGTERM:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sendShutdownSignal(_ chan<- os.Signal) {
|
||||
// On Unix we are already notified by the system.
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ package aghos
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
@@ -43,25 +42,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, windows.SIGHUP)
|
||||
}
|
||||
|
||||
func notifyShutdownSignal(c chan<- os.Signal) {
|
||||
// syscall.SIGTERM is processed automatically. See go doc os/signal,
|
||||
// section Windows.
|
||||
signal.Notify(c, os.Interrupt)
|
||||
}
|
||||
|
||||
func isReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return sig == windows.SIGHUP
|
||||
}
|
||||
|
||||
func isShutdownSignal(sig os.Signal) (ok bool) {
|
||||
switch sig {
|
||||
case os.Interrupt, syscall.SIGTERM:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sendShutdownSignal(c chan<- os.Signal) {
|
||||
c <- os.Interrupt
|
||||
}
|
||||
|
||||
@@ -78,7 +78,6 @@ func TestWithDeferredCleanup(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -91,8 +91,6 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -186,8 +184,6 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ package client
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
)
|
||||
@@ -56,6 +57,9 @@ func (cs Source) MarshalText() (text []byte, err error) {
|
||||
|
||||
// Runtime is a client information from different sources.
|
||||
type Runtime struct {
|
||||
// ip is an IP address of a client.
|
||||
ip netip.Addr
|
||||
|
||||
// whois is the filtered WHOIS information of a client.
|
||||
whois *whois.Info
|
||||
|
||||
@@ -80,6 +84,15 @@ type Runtime struct {
|
||||
hostsFile []string
|
||||
}
|
||||
|
||||
// NewRuntime constructs a new runtime client. ip must be valid IP address.
|
||||
//
|
||||
// TODO(s.chzhen): Validate IP address.
|
||||
func NewRuntime(ip netip.Addr) (r *Runtime) {
|
||||
return &Runtime{
|
||||
ip: ip,
|
||||
}
|
||||
}
|
||||
|
||||
// Info returns a client information from the highest-priority source.
|
||||
func (r *Runtime) Info() (cs Source, host string) {
|
||||
info := []string{}
|
||||
@@ -133,8 +146,8 @@ func (r *Runtime) SetWHOIS(info *whois.Info) {
|
||||
r.whois = info
|
||||
}
|
||||
|
||||
// Unset clears a cs information.
|
||||
func (r *Runtime) Unset(cs Source) {
|
||||
// unset clears a cs information.
|
||||
func (r *Runtime) unset(cs Source) {
|
||||
switch cs {
|
||||
case SourceWHOIS:
|
||||
r.whois = nil
|
||||
@@ -149,11 +162,16 @@ func (r *Runtime) Unset(cs Source) {
|
||||
}
|
||||
}
|
||||
|
||||
// IsEmpty returns true if there is no information from any source.
|
||||
func (r *Runtime) IsEmpty() (ok bool) {
|
||||
// isEmpty returns true if there is no information from any source.
|
||||
func (r *Runtime) isEmpty() (ok bool) {
|
||||
return r.whois == nil &&
|
||||
r.arp == nil &&
|
||||
r.rdns == nil &&
|
||||
r.dhcp == nil &&
|
||||
r.hostsFile == nil
|
||||
}
|
||||
|
||||
// Addr returns an IP address of the client.
|
||||
func (r *Runtime) Addr() (ip netip.Addr) {
|
||||
return r.ip
|
||||
}
|
||||
|
||||
@@ -4,8 +4,12 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
|
||||
@@ -28,6 +32,9 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
|
||||
|
||||
// Index stores all information about persistent clients.
|
||||
type Index struct {
|
||||
// nameToUID maps client name to UID.
|
||||
nameToUID map[string]UID
|
||||
|
||||
// clientIDToUID maps client ID to UID.
|
||||
clientIDToUID map[string]UID
|
||||
|
||||
@@ -47,6 +54,7 @@ type Index struct {
|
||||
// NewIndex initializes the new instance of client index.
|
||||
func NewIndex() (ci *Index) {
|
||||
return &Index{
|
||||
nameToUID: map[string]UID{},
|
||||
clientIDToUID: map[string]UID{},
|
||||
ipToUID: map[netip.Addr]UID{},
|
||||
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
|
||||
@@ -62,6 +70,8 @@ func (ci *Index) Add(c *Persistent) {
|
||||
panic("client must contain uid")
|
||||
}
|
||||
|
||||
ci.nameToUID[c.Name] = c.UID
|
||||
|
||||
for _, id := range c.ClientIDs {
|
||||
ci.clientIDToUID[id] = c.UID
|
||||
}
|
||||
@@ -82,15 +92,30 @@ func (ci *Index) Add(c *Persistent) {
|
||||
ci.uidToClient[c.UID] = c
|
||||
}
|
||||
|
||||
// ClashesUID returns existing persistent client with the same UID as c. Note
|
||||
// that this is only possible when configuration contains duplicate fields.
|
||||
func (ci *Index) ClashesUID(c *Persistent) (err error) {
|
||||
p, ok := ci.uidToClient[c.UID]
|
||||
if ok {
|
||||
return fmt.Errorf("another client %q uses the same uid", p.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clashes returns an error if the index contains a different persistent client
|
||||
// with at least a single identifier contained by c. c must be non-nil.
|
||||
func (ci *Index) Clashes(c *Persistent) (err error) {
|
||||
if p := ci.clashesName(c); p != nil {
|
||||
return fmt.Errorf("another client uses the same name %q", p.Name)
|
||||
}
|
||||
|
||||
for _, id := range c.ClientIDs {
|
||||
existing, ok := ci.clientIDToUID[id]
|
||||
if ok && existing != c.UID {
|
||||
p := ci.uidToClient[existing]
|
||||
|
||||
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
|
||||
return fmt.Errorf("another client %q uses the same ClientID %q", p.Name, id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,6 +137,21 @@ func (ci *Index) Clashes(c *Persistent) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clashesName returns existing persistent client with the same name as c or
|
||||
// nil. c must be non-nil.
|
||||
func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
|
||||
existing, ok := ci.FindByName(c.Name)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if existing.UID != c.UID {
|
||||
return existing
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clashesIP returns a previous client with the same IP address as c. c must be
|
||||
// non-nil.
|
||||
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
|
||||
@@ -184,21 +224,33 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) {
|
||||
|
||||
mac, err := net.ParseMAC(id)
|
||||
if err == nil {
|
||||
return ci.findByMAC(mac)
|
||||
return ci.FindByMAC(mac)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// find finds persistent client by IP address.
|
||||
// FindByName finds persistent client by name.
|
||||
func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
|
||||
uid, found := ci.nameToUID[name]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findByIP finds persistent client by IP address.
|
||||
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
|
||||
uid, found := ci.ipToUID[ip]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
ipWithoutZone := ip.WithZone("")
|
||||
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
|
||||
if pref.Contains(ip) {
|
||||
// Remove zone before checking because prefixes strip zones.
|
||||
if pref.Contains(ipWithoutZone) {
|
||||
uid, found = id, true
|
||||
|
||||
return false
|
||||
@@ -214,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// find finds persistent client by MAC.
|
||||
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||
// FindByMAC finds persistent client by MAC.
|
||||
func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||
k := macToKey(mac)
|
||||
uid, found := ci.macToUID[k]
|
||||
if found {
|
||||
@@ -225,9 +277,31 @@ func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindByIPWithoutZone finds a persistent client by IP address without zone. It
|
||||
// strips the IPv6 zone index from the stored IP addresses before comparing,
|
||||
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
|
||||
//
|
||||
// Note that multiple clients can have the same IP address with different zones.
|
||||
// Therefore, the result of this method is indeterminate.
|
||||
func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
|
||||
if (ip == netip.Addr{}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
for addr, uid := range ci.ipToUID {
|
||||
if addr.WithZone("") == ip {
|
||||
return ci.uidToClient[uid]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes information about persistent client from the index. c must be
|
||||
// non-nil.
|
||||
func (ci *Index) Delete(c *Persistent) {
|
||||
delete(ci.nameToUID, c.Name)
|
||||
|
||||
for _, id := range c.ClientIDs {
|
||||
delete(ci.clientIDToUID, id)
|
||||
}
|
||||
@@ -247,3 +321,48 @@ func (ci *Index) Delete(c *Persistent) {
|
||||
|
||||
delete(ci.uidToClient, c.UID)
|
||||
}
|
||||
|
||||
// Size returns the number of persistent clients.
|
||||
func (ci *Index) Size() (n int) {
|
||||
return len(ci.uidToClient)
|
||||
}
|
||||
|
||||
// Range calls f for each persistent client, unless cont is false. The order is
|
||||
// undefined.
|
||||
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
|
||||
for _, c := range ci.uidToClient {
|
||||
if !f(c) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RangeByName is like [Index.Range] but sorts the persistent clients by name
|
||||
// before iterating ensuring a predictable order.
|
||||
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
|
||||
cs := maps.Values(ci.uidToClient)
|
||||
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
for _, c := range cs {
|
||||
if !f(c) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloseUpstreams closes upstream configurations of persistent clients.
|
||||
func (ci *Index) CloseUpstreams() (err error) {
|
||||
var errs []error
|
||||
ci.RangeByName(func(c *Persistent) (cont bool) {
|
||||
err = c.CloseUpstreams()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) {
|
||||
return ci
|
||||
}
|
||||
|
||||
func TestClientIndex(t *testing.T) {
|
||||
func TestClientIndex_Find(t *testing.T) {
|
||||
const (
|
||||
cliIPNone = "1.2.3.4"
|
||||
cliIP1 = "1.1.1.1"
|
||||
@@ -35,26 +35,49 @@ func TestClientIndex(t *testing.T) {
|
||||
|
||||
cliID = "client-id"
|
||||
cliMAC = "11:11:11:11:11:11"
|
||||
|
||||
linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0"
|
||||
linkLocalSubnet = "fe80::/16"
|
||||
)
|
||||
|
||||
clients := []*Persistent{{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{
|
||||
netip.MustParseAddr(cliIP1),
|
||||
netip.MustParseAddr(cliIPv6),
|
||||
},
|
||||
}, {
|
||||
Name: "client2",
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||
}, {
|
||||
Name: "client_with_mac",
|
||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||
}, {
|
||||
Name: "client_with_id",
|
||||
ClientIDs: []string{cliID},
|
||||
}}
|
||||
var (
|
||||
clientWithBothFams = &Persistent{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{
|
||||
netip.MustParseAddr(cliIP1),
|
||||
netip.MustParseAddr(cliIPv6),
|
||||
},
|
||||
}
|
||||
|
||||
clientWithSubnet = &Persistent{
|
||||
Name: "client2",
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||
}
|
||||
|
||||
clientWithMAC = &Persistent{
|
||||
Name: "client_with_mac",
|
||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||
}
|
||||
|
||||
clientWithID = &Persistent{
|
||||
Name: "client_with_id",
|
||||
ClientIDs: []string{cliID},
|
||||
}
|
||||
|
||||
clientLinkLocal = &Persistent{
|
||||
Name: "client_link_local",
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*Persistent{
|
||||
clientWithBothFams,
|
||||
clientWithSubnet,
|
||||
clientWithMAC,
|
||||
clientWithID,
|
||||
clientLinkLocal,
|
||||
}
|
||||
ci := newIDIndex(clients)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -64,19 +87,23 @@ func TestClientIndex(t *testing.T) {
|
||||
}{{
|
||||
name: "ipv4_ipv6",
|
||||
ids: []string{cliIP1, cliIPv6},
|
||||
want: clients[0],
|
||||
want: clientWithBothFams,
|
||||
}, {
|
||||
name: "ipv4_subnet",
|
||||
ids: []string{cliIP2, cliSubnetIP},
|
||||
want: clients[1],
|
||||
want: clientWithSubnet,
|
||||
}, {
|
||||
name: "mac",
|
||||
ids: []string{cliMAC},
|
||||
want: clients[2],
|
||||
want: clientWithMAC,
|
||||
}, {
|
||||
name: "client_id",
|
||||
ids: []string{cliID},
|
||||
want: clients[3],
|
||||
want: clientWithID,
|
||||
}, {
|
||||
name: "client_link_local_subnet",
|
||||
ids: []string{linkLocalIP},
|
||||
want: clientLinkLocal,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -221,3 +248,103 @@ func TestMACToKey(t *testing.T) {
|
||||
_ = macToKey(mac)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIndex_FindByIPWithoutZone(t *testing.T) {
|
||||
var (
|
||||
ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1")
|
||||
ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2")
|
||||
)
|
||||
|
||||
var (
|
||||
clientNoZone = &Persistent{
|
||||
Name: "client",
|
||||
IPs: []netip.Addr{ip},
|
||||
}
|
||||
|
||||
clientWithZone = &Persistent{
|
||||
Name: "client_with_zone",
|
||||
IPs: []netip.Addr{ipWithZone},
|
||||
}
|
||||
)
|
||||
|
||||
ci := newIDIndex([]*Persistent{
|
||||
clientNoZone,
|
||||
clientWithZone,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
ip netip.Addr
|
||||
want *Persistent
|
||||
name string
|
||||
}{{
|
||||
name: "without_zone",
|
||||
ip: ip,
|
||||
want: clientNoZone,
|
||||
}, {
|
||||
name: "with_zone",
|
||||
ip: ipWithZone,
|
||||
want: clientWithZone,
|
||||
}, {
|
||||
name: "zero_address",
|
||||
ip: netip.Addr{},
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := ci.FindByIPWithoutZone(tc.ip.WithZone(""))
|
||||
require.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIndex_RangeByName(t *testing.T) {
|
||||
sortedClients := []*Persistent{{
|
||||
Name: "clientA",
|
||||
ClientIDs: []string{"A"},
|
||||
}, {
|
||||
Name: "clientB",
|
||||
ClientIDs: []string{"B"},
|
||||
}, {
|
||||
Name: "clientC",
|
||||
ClientIDs: []string{"C"},
|
||||
}, {
|
||||
Name: "clientD",
|
||||
ClientIDs: []string{"D"},
|
||||
}, {
|
||||
Name: "clientE",
|
||||
ClientIDs: []string{"E"},
|
||||
}}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want []*Persistent
|
||||
}{{
|
||||
name: "basic",
|
||||
want: sortedClients,
|
||||
}, {
|
||||
name: "nil",
|
||||
want: nil,
|
||||
}, {
|
||||
name: "one_element",
|
||||
want: sortedClients[:1],
|
||||
}, {
|
||||
name: "two_elements",
|
||||
want: sortedClients[:2],
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ci := newIDIndex(tc.want)
|
||||
|
||||
var got []*Persistent
|
||||
ci.RangeByName(func(c *Persistent) (cont bool) {
|
||||
got = append(got, c)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,9 +64,7 @@ type Persistent struct {
|
||||
// upstream must be used.
|
||||
UpstreamConfig *proxy.CustomUpstreamConfig
|
||||
|
||||
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
|
||||
SafeSearchConf filtering.SafeSearchConfig
|
||||
SafeSearch filtering.SafeSearch
|
||||
SafeSearch filtering.SafeSearch
|
||||
|
||||
// BlockedServices is the configuration of blocked services of a client.
|
||||
BlockedServices *filtering.BlockedServices
|
||||
@@ -95,6 +93,9 @@ type Persistent struct {
|
||||
UseOwnBlockedServices bool
|
||||
IgnoreQueryLog bool
|
||||
IgnoreStatistics bool
|
||||
|
||||
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
|
||||
SafeSearchConf filtering.SafeSearchConfig
|
||||
}
|
||||
|
||||
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
|
||||
|
||||
63
internal/client/runtimeindex.go
Normal file
63
internal/client/runtimeindex.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package client
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// RuntimeIndex stores information about runtime clients.
|
||||
type RuntimeIndex struct {
|
||||
// index maps IP address to runtime client.
|
||||
index map[netip.Addr]*Runtime
|
||||
}
|
||||
|
||||
// NewRuntimeIndex returns initialized runtime index.
|
||||
func NewRuntimeIndex() (ri *RuntimeIndex) {
|
||||
return &RuntimeIndex{
|
||||
index: map[netip.Addr]*Runtime{},
|
||||
}
|
||||
}
|
||||
|
||||
// Client returns the saved runtime client by ip. If no such client exists,
|
||||
// returns nil.
|
||||
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) {
|
||||
return ri.index[ip]
|
||||
}
|
||||
|
||||
// Add saves the runtime client in the index. IP address of a client must be
|
||||
// unique. See [Runtime.Client]. rc must not be nil.
|
||||
func (ri *RuntimeIndex) Add(rc *Runtime) {
|
||||
ip := rc.Addr()
|
||||
ri.index[ip] = rc
|
||||
}
|
||||
|
||||
// Size returns the number of the runtime clients.
|
||||
func (ri *RuntimeIndex) Size() (n int) {
|
||||
return len(ri.index)
|
||||
}
|
||||
|
||||
// Range calls f for each runtime client in an undefined order.
|
||||
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
|
||||
for _, rc := range ri.index {
|
||||
if !f(rc) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes the runtime client by ip.
|
||||
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
|
||||
delete(ri.index, ip)
|
||||
}
|
||||
|
||||
// DeleteBySource removes all runtime clients that have information only from
|
||||
// the specified source and returns the number of removed clients.
|
||||
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) {
|
||||
for ip, rc := range ri.index {
|
||||
rc.unset(src)
|
||||
|
||||
if rc.isEmpty() {
|
||||
delete(ri.index, ip)
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
85
internal/client/runtimeindex_test.go
Normal file
85
internal/client/runtimeindex_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRuntimeIndex(t *testing.T) {
|
||||
const cliSrc = client.SourceARP
|
||||
|
||||
var (
|
||||
ip1 = netip.MustParseAddr("1.1.1.1")
|
||||
ip2 = netip.MustParseAddr("2.2.2.2")
|
||||
ip3 = netip.MustParseAddr("3.3.3.3")
|
||||
)
|
||||
|
||||
ri := client.NewRuntimeIndex()
|
||||
currentSize := 0
|
||||
|
||||
testCases := []struct {
|
||||
ip netip.Addr
|
||||
name string
|
||||
hosts []string
|
||||
src client.Source
|
||||
}{{
|
||||
src: cliSrc,
|
||||
ip: ip1,
|
||||
name: "1",
|
||||
hosts: []string{"host1"},
|
||||
}, {
|
||||
src: cliSrc,
|
||||
ip: ip2,
|
||||
name: "2",
|
||||
hosts: []string{"host2"},
|
||||
}, {
|
||||
src: cliSrc,
|
||||
ip: ip3,
|
||||
name: "3",
|
||||
hosts: []string{"host3"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rc := client.NewRuntime(tc.ip)
|
||||
rc.SetInfo(tc.src, tc.hosts)
|
||||
|
||||
ri.Add(rc)
|
||||
currentSize++
|
||||
|
||||
got := ri.Client(tc.ip)
|
||||
assert.Equal(t, rc, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("size", func(t *testing.T) {
|
||||
assert.Equal(t, currentSize, ri.Size())
|
||||
})
|
||||
|
||||
t.Run("range", func(t *testing.T) {
|
||||
s := 0
|
||||
|
||||
ri.Range(func(rc *client.Runtime) (cont bool) {
|
||||
s++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, currentSize, s)
|
||||
})
|
||||
|
||||
t.Run("delete", func(t *testing.T) {
|
||||
ri.Delete(ip1)
|
||||
currentSize--
|
||||
|
||||
assert.Equal(t, currentSize, ri.Size())
|
||||
})
|
||||
|
||||
t.Run("delete_by_src", func(t *testing.T) {
|
||||
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
|
||||
assert.Equal(t, 0, ri.Size())
|
||||
})
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package configmigrate
|
||||
|
||||
import "github.com/AdguardTeam/golibs/errors"
|
||||
|
||||
// migrateTo15 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
@@ -43,7 +45,7 @@ func migrateTo15(diskConf yobj) (err error) {
|
||||
}
|
||||
diskConf["querylog"] = qlog
|
||||
|
||||
return coalesceError(
|
||||
return errors.Join(
|
||||
moveVal[bool](dns, qlog, "querylog_enabled", "enabled"),
|
||||
moveVal[bool](dns, qlog, "querylog_file_enabled", "file_enabled"),
|
||||
moveVal[any](dns, qlog, "querylog_interval", "interval"),
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package configmigrate
|
||||
|
||||
import "github.com/AdguardTeam/golibs/errors"
|
||||
|
||||
// migrateTo24 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
@@ -28,7 +30,7 @@ func migrateTo24(diskConf yobj) (err error) {
|
||||
diskConf["schema_version"] = 24
|
||||
|
||||
logObj := yobj{}
|
||||
err = coalesceError(
|
||||
err = errors.Join(
|
||||
moveVal[string](diskConf, logObj, "log_file", "file"),
|
||||
moveVal[int](diskConf, logObj, "log_max_backups", "max_backups"),
|
||||
moveVal[int](diskConf, logObj, "log_max_size", "max_size"),
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package configmigrate
|
||||
|
||||
import "github.com/AdguardTeam/golibs/errors"
|
||||
|
||||
// migrateTo26 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
@@ -78,7 +80,7 @@ func migrateTo26(diskConf yobj) (err error) {
|
||||
}
|
||||
|
||||
filteringObj := yobj{}
|
||||
err = coalesceError(
|
||||
err = errors.Join(
|
||||
moveSameVal[bool](dns, filteringObj, "filtering_enabled"),
|
||||
moveSameVal[int](dns, filteringObj, "filters_update_interval"),
|
||||
moveSameVal[bool](dns, filteringObj, "parental_enabled"),
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package configmigrate
|
||||
|
||||
import "github.com/AdguardTeam/golibs/errors"
|
||||
|
||||
// migrateTo7 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
@@ -37,7 +39,7 @@ func migrateTo7(diskConf yobj) (err error) {
|
||||
}
|
||||
|
||||
dhcpv4 := yobj{}
|
||||
err = coalesceError(
|
||||
err = errors.Join(
|
||||
moveSameVal[string](dhcp, dhcpv4, "gateway_ip"),
|
||||
moveSameVal[string](dhcp, dhcpv4, "subnet_mask"),
|
||||
moveSameVal[string](dhcp, dhcpv4, "range_start"),
|
||||
|
||||
@@ -50,19 +50,3 @@ func moveVal[T any](src, dst yobj, srcKey, dstKey string) (err error) {
|
||||
func moveSameVal[T any](src, dst yobj, key string) (err error) {
|
||||
return moveVal[T](src, dst, key, key)
|
||||
}
|
||||
|
||||
// coalesceError returns the first non-nil error. It is named after function
|
||||
// COALESCE in SQL. If all errors are nil, it returns nil.
|
||||
//
|
||||
// TODO(e.burkov): Replace with [errors.Join].
|
||||
//
|
||||
// TODO(a.garipov): Think of ways to merge with [aghalg.Coalesce].
|
||||
func coalesceError(errors ...error) (res error) {
|
||||
for _, err := range errors {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -156,7 +156,10 @@ func (a *accessManager) isBlockedIP(ip netip.Addr) (blocked bool, rule string) {
|
||||
}
|
||||
|
||||
for _, ipnet := range ipnets {
|
||||
if ipnet.Contains(ip) {
|
||||
// Remove zone before checking because prefixes stip zones.
|
||||
//
|
||||
// TODO(d.kolyshev): Cover with tests.
|
||||
if ipnet.Contains(ip.WithZone("")) {
|
||||
return blocked, ipnet.String()
|
||||
}
|
||||
}
|
||||
|
||||
116
internal/dnsforward/beforerequest.go
Normal file
116
internal/dnsforward/beforerequest.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ proxy.BeforeRequestHandler = (*Server)(nil)
|
||||
|
||||
// HandleBefore is the handler that is called before any other processing,
|
||||
// including logs. It performs access checks and puts the client ID, if there
|
||||
// is one, into the server's cache.
|
||||
//
|
||||
// TODO(d.kolyshev): Extract to separate package.
|
||||
func (s *Server) HandleBefore(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (err error) {
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return &proxy.BeforeRequestError{
|
||||
Err: fmt.Errorf("getting clientid: %w", err),
|
||||
Response: s.NewMsgSERVFAIL(pctx.Req),
|
||||
}
|
||||
}
|
||||
|
||||
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
|
||||
if blocked {
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
|
||||
if len(pctx.Req.Question) == 1 {
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
host := aghnet.NormalizeDomain(q.Name)
|
||||
if s.access.isBlockedHost(host, qt) {
|
||||
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
|
||||
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
}
|
||||
|
||||
if clientID != "" {
|
||||
key := [8]byte{}
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||
// is not one of these, clientID is an empty string and err is nil.
|
||||
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
clientID, err = clientIDFromDNSContextHTTPS(pctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("checking url: %w", err)
|
||||
} else if clientID != "" {
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// Go on and check the domain name as well.
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
hostSrvName := s.conf.ServerName
|
||||
if hostSrvName == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName, err := clientServerName(pctx, proto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
hostSrvName,
|
||||
cliSrvName,
|
||||
s.conf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// errAccessBlocked is a sentinel error returned when a request is blocked by
|
||||
// access settings.
|
||||
var errAccessBlocked errors.Error = "blocked by access settings"
|
||||
|
||||
// preBlockedResponse returns a protocol-appropriate response for a request that
|
||||
// was blocked by access settings.
|
||||
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) {
|
||||
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
|
||||
// Return nil so that dnsproxy drops the connection and thus
|
||||
// prevent DNS amplification attacks.
|
||||
return errAccessBlocked
|
||||
}
|
||||
|
||||
return &proxy.BeforeRequestError{
|
||||
Err: errAccessBlocked,
|
||||
Response: s.makeResponseREFUSED(pctx.Req),
|
||||
}
|
||||
}
|
||||
299
internal/dnsforward/beforerequest_internal_test.go
Normal file
299
internal/dnsforward/beforerequest_internal_test.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
blockedHost = "blockedhost.org"
|
||||
testFQDN = "example.org."
|
||||
dnsClientTimeout = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
func TestServer_HandleBefore_tls(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const clientID = "client-1"
|
||||
|
||||
testCases := []struct {
|
||||
clientSrvName string
|
||||
name string
|
||||
host string
|
||||
allowedClients []string
|
||||
disallowedClients []string
|
||||
blockedHosts []string
|
||||
wantRCode int
|
||||
}{{
|
||||
clientSrvName: tlsServerName,
|
||||
name: "allow_all",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{},
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
clientSrvName: "%" + "." + tlsServerName,
|
||||
name: "invalid_client_id",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{},
|
||||
wantRCode: dns.RcodeServerFailure,
|
||||
}, {
|
||||
clientSrvName: clientID + "." + tlsServerName,
|
||||
name: "allowed_client_allowed",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{clientID},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{},
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
clientSrvName: "client-2." + tlsServerName,
|
||||
name: "allowed_client_rejected",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{clientID},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{},
|
||||
wantRCode: dns.RcodeRefused,
|
||||
}, {
|
||||
clientSrvName: tlsServerName,
|
||||
name: "disallowed_client_allowed",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{clientID},
|
||||
blockedHosts: []string{},
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
clientSrvName: clientID + "." + tlsServerName,
|
||||
name: "disallowed_client_rejected",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{clientID},
|
||||
blockedHosts: []string{},
|
||||
wantRCode: dns.RcodeRefused,
|
||||
}, {
|
||||
clientSrvName: tlsServerName,
|
||||
name: "blocked_hosts_allowed",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{blockedHost},
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
clientSrvName: tlsServerName,
|
||||
name: "blocked_hosts_rejected",
|
||||
host: dns.Fqdn(blockedHost),
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{blockedHost},
|
||||
wantRCode: dns.RcodeRefused,
|
||||
}}
|
||||
|
||||
localAns := []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: testFQDN,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
Rdlength: 4,
|
||||
},
|
||||
A: net.IP{1, 2, 3, 4},
|
||||
}}
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = localAns
|
||||
|
||||
require.NoError(t, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s, _ := createTestTLS(t, TLSConfig{
|
||||
TLSListenAddrs: []*net.TCPAddr{{}},
|
||||
ServerName: tlsServerName,
|
||||
})
|
||||
|
||||
s.conf.UpstreamDNS = []string{localUpsAddr}
|
||||
|
||||
s.conf.AllowedClients = tc.allowedClients
|
||||
s.conf.DisallowedClients = tc.disallowedClients
|
||||
s.conf.BlockedHosts = tc.blockedHosts
|
||||
|
||||
err := s.Prepare(&s.conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
startDeferStop(t, s)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: tc.clientSrvName,
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
TLSConfig: tlsConfig,
|
||||
Timeout: dnsClientTimeout,
|
||||
}
|
||||
|
||||
req := createTestMessage(tc.host)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoTLS).String()
|
||||
|
||||
reply, _, err := client.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantRCode, reply.Rcode)
|
||||
if tc.wantRCode == dns.RcodeSuccess {
|
||||
assert.Equal(t, localAns, reply.Answer)
|
||||
} else {
|
||||
assert.Empty(t, reply.Answer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleBefore_udp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
clientIPv4 = "127.0.0.1"
|
||||
clientIPv6 = "::1"
|
||||
)
|
||||
|
||||
clientIPs := []string{clientIPv4, clientIPv6}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
allowedClients []string
|
||||
disallowedClients []string
|
||||
blockedHosts []string
|
||||
wantTimeout bool
|
||||
}{{
|
||||
name: "allow_all",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{},
|
||||
wantTimeout: false,
|
||||
}, {
|
||||
name: "allowed_client_allowed",
|
||||
host: testFQDN,
|
||||
allowedClients: clientIPs,
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{},
|
||||
wantTimeout: false,
|
||||
}, {
|
||||
name: "allowed_client_rejected",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{"1:2:3::4"},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{},
|
||||
wantTimeout: true,
|
||||
}, {
|
||||
name: "disallowed_client_allowed",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{"1:2:3::4"},
|
||||
blockedHosts: []string{},
|
||||
wantTimeout: false,
|
||||
}, {
|
||||
name: "disallowed_client_rejected",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: clientIPs,
|
||||
blockedHosts: []string{},
|
||||
wantTimeout: true,
|
||||
}, {
|
||||
name: "blocked_hosts_allowed",
|
||||
host: testFQDN,
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{blockedHost},
|
||||
wantTimeout: false,
|
||||
}, {
|
||||
name: "blocked_hosts_rejected",
|
||||
host: dns.Fqdn(blockedHost),
|
||||
allowedClients: []string{},
|
||||
disallowedClients: []string{},
|
||||
blockedHosts: []string{blockedHost},
|
||||
wantTimeout: true,
|
||||
}}
|
||||
|
||||
localAns := []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: testFQDN,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
Rdlength: 4,
|
||||
},
|
||||
A: net.IP{1, 2, 3, 4},
|
||||
}}
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = localAns
|
||||
|
||||
require.NoError(t, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
AllowedClients: tc.allowedClients,
|
||||
DisallowedClients: tc.disallowedClients,
|
||||
BlockedHosts: tc.blockedHosts,
|
||||
UpstreamDNS: []string{localUpsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
startDeferStop(t, s)
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "udp",
|
||||
Timeout: dnsClientTimeout,
|
||||
}
|
||||
|
||||
req := createTestMessage(tc.host)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
|
||||
reply, _, err := client.Exchange(req, addr)
|
||||
if tc.wantTimeout {
|
||||
wantErr := &net.OpError{}
|
||||
require.ErrorAs(t, err, &wantErr)
|
||||
assert.True(t, wantErr.Timeout())
|
||||
|
||||
assert.Nil(t, reply)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reply)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
assert.Equal(t, localAns, reply.Answer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -110,46 +110,6 @@ type quicConnection interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||
// is not one of these, clientID is an empty string and err is nil.
|
||||
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
clientID, err = clientIDFromDNSContextHTTPS(pctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("checking url: %w", err)
|
||||
} else if clientID != "" {
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// Go on and check the domain name as well.
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
hostSrvName := s.conf.ServerName
|
||||
if hostSrvName == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName, err := clientServerName(pctx, proto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
hostSrvName,
|
||||
cliSrvName,
|
||||
s.conf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// clientServerName returns the TLS server name based on the protocol. For
|
||||
// DNS-over-HTTPS requests, it will return the hostname part of the Host header
|
||||
// if there is one.
|
||||
|
||||
@@ -235,9 +235,18 @@ type DNSCryptConfig struct {
|
||||
// ServerConfig represents server configuration.
|
||||
// The zero ServerConfig is empty and ready for use.
|
||||
type ServerConfig struct {
|
||||
UDPListenAddrs []*net.UDPAddr // UDP listen address
|
||||
TCPListenAddrs []*net.TCPAddr // TCP listen address
|
||||
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
|
||||
// UDPListenAddrs is the list of addresses to listen for DNS-over-UDP.
|
||||
UDPListenAddrs []*net.UDPAddr
|
||||
|
||||
// TCPListenAddrs is the list of addresses to listen for DNS-over-TCP.
|
||||
TCPListenAddrs []*net.TCPAddr
|
||||
|
||||
// UpstreamConfig is the general configuration of upstream DNS servers.
|
||||
UpstreamConfig *proxy.UpstreamConfig
|
||||
|
||||
// PrivateRDNSUpstreamConfig is the configuration of upstream DNS servers
|
||||
// for private reverse DNS.
|
||||
PrivateRDNSUpstreamConfig *proxy.UpstreamConfig
|
||||
|
||||
// AddrProcConf defines the configuration for the client IP processor.
|
||||
// If nil, [client.EmptyAddrProc] is used.
|
||||
@@ -306,24 +315,28 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
|
||||
|
||||
conf = &proxy.Config{
|
||||
HTTP3: srvConf.ServeHTTP3,
|
||||
Ratelimit: int(srvConf.Ratelimit),
|
||||
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
||||
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
|
||||
RatelimitWhitelist: srvConf.RatelimitWhitelist,
|
||||
RefuseAny: srvConf.RefuseAny,
|
||||
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
|
||||
CacheMinTTL: srvConf.CacheMinTTL,
|
||||
CacheMaxTTL: srvConf.CacheMaxTTL,
|
||||
CacheOptimistic: srvConf.CacheOptimistic,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
BeforeRequestHandler: s.beforeRequestHandler,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: srvConf.MaxGoroutines,
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
DNS64Prefs: srvConf.DNS64Prefixes,
|
||||
HTTP3: srvConf.ServeHTTP3,
|
||||
Ratelimit: int(srvConf.Ratelimit),
|
||||
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
||||
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
|
||||
RatelimitWhitelist: srvConf.RatelimitWhitelist,
|
||||
RefuseAny: srvConf.RefuseAny,
|
||||
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
|
||||
CacheMinTTL: srvConf.CacheMinTTL,
|
||||
CacheMaxTTL: srvConf.CacheMaxTTL,
|
||||
CacheOptimistic: srvConf.CacheOptimistic,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
||||
BeforeRequestHandler: s,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: srvConf.MaxGoroutines,
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
DNS64Prefs: srvConf.DNS64Prefixes,
|
||||
UsePrivateRDNS: srvConf.UsePrivateRDNS,
|
||||
PrivateSubnets: s.privateNets,
|
||||
MessageConstructor: s,
|
||||
}
|
||||
|
||||
if srvConf.EDNSClientSubnet.UseCustom {
|
||||
@@ -452,12 +465,33 @@ func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
}
|
||||
|
||||
ipsets := stringutil.SplitTrimmed(string(data), "\n")
|
||||
ipsets = stringutil.FilterOut(ipsets, IsCommentOrEmpty)
|
||||
|
||||
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
|
||||
|
||||
return s.ipset.init(ipsets)
|
||||
}
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
// the configuration itself.
|
||||
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
|
||||
if conf.UpstreamDNSFileName == "" {
|
||||
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
data, err = os.ReadFile(conf.UpstreamDNSFileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading upstream from file: %w", err)
|
||||
}
|
||||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
|
||||
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
|
||||
|
||||
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
// collectListenAddr adds addrPort to addrs. It also adds its port to
|
||||
// unspecPorts if its address is unspecified.
|
||||
func collectListenAddr(
|
||||
@@ -529,8 +563,8 @@ func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
|
||||
return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr())
|
||||
}
|
||||
|
||||
// filterOut filters out all the upstreams that match um. It returns all the
|
||||
// closing errors joined.
|
||||
// filterOutAddrs filters out all the upstreams that match um. It returns all
|
||||
// the closing errors joined.
|
||||
func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) {
|
||||
var errs []error
|
||||
delFunc := func(u upstream.Upstream) (ok bool) {
|
||||
|
||||
@@ -3,7 +3,6 @@ package dnsforward
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -64,6 +64,8 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns
|
||||
}
|
||||
|
||||
func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
ipv4Domain = "ipv4.only."
|
||||
ipv6Domain = "ipv6.only."
|
||||
@@ -252,33 +254,33 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
require.Len(pt, m.Question, 1)
|
||||
require.Equal(pt, m.Question[0].Name, ptr64Domain)
|
||||
resp := (&dns.Msg{
|
||||
Answer: []dns.RR{localRR},
|
||||
}).SetReply(m)
|
||||
|
||||
resp := (&dns.Msg{}).SetReply(m)
|
||||
resp.Answer = []dns.RR{localRR}
|
||||
|
||||
require.NoError(t, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: 1 * time.Second,
|
||||
Net: string(proxy.ProtoTCP),
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
q := req.Question[0]
|
||||
require.Contains(pt, tc.upsAns, q.Qtype)
|
||||
|
||||
require.Contains(pt, tc.upsAns, q.Qtype)
|
||||
answer := tc.upsAns[q.Qtype]
|
||||
|
||||
resp := (&dns.Msg{
|
||||
Answer: answer[sectionAnswer],
|
||||
Ns: answer[sectionAuthority],
|
||||
Extra: answer[sectionAdditional],
|
||||
}).SetReply(req)
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = answer[sectionAnswer]
|
||||
resp.Ns = answer[sectionAuthority]
|
||||
resp.Extra = answer[sectionAdditional]
|
||||
|
||||
require.NoError(pt, w.WriteMsg(resp))
|
||||
})
|
||||
@@ -308,10 +310,54 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
|
||||
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
|
||||
|
||||
resp, _, excErr := client.Exchange(req, s.dnsProxy.Addr(proxy.ProtoTCP).String())
|
||||
resp, _, excErr := client.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
|
||||
require.NoError(t, excErr)
|
||||
|
||||
require.Equal(t, tc.wantAns, resp.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_dns64WithDisabledRDNS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Shouldn't go to upstream at all.
|
||||
panicHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
panic("not implemented")
|
||||
})
|
||||
upsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UseDNS64: true,
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
},
|
||||
UsePrivateRDNS: false,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
startDeferStop(t, s)
|
||||
|
||||
mappedIPv6 := net.ParseIP("64:ff9b::102:304")
|
||||
arpa, err := netutil.IPToReversedAddr(mappedIPv6)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := (&dns.Msg{}).SetQuestion(dns.Fqdn(arpa), dns.TypePTR)
|
||||
|
||||
cli := &dns.Client{
|
||||
Net: string(proxy.ProtoTCP),
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
|
||||
resp, _, err := cli.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeNameError, resp.Rcode)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -135,12 +135,6 @@ type Server struct {
|
||||
// WHOIS, etc.
|
||||
addrProc client.AddressProcessor
|
||||
|
||||
// localResolvers is a DNS proxy instance used to resolve PTR records for
|
||||
// addresses considered private as per the [privateNets].
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
localResolvers *proxy.Proxy
|
||||
|
||||
// sysResolvers used to fetch system resolvers to use by default for private
|
||||
// PTR resolving.
|
||||
sysResolvers SystemResolvers
|
||||
@@ -158,12 +152,6 @@ type Server struct {
|
||||
// [upstream.Resolver] interface.
|
||||
bootResolvers []*upstream.UpstreamResolver
|
||||
|
||||
// recDetector is a cache for recursive requests. It is used to detect and
|
||||
// prevent recursive requests only for private upstreams.
|
||||
//
|
||||
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
|
||||
recDetector *recursionDetector
|
||||
|
||||
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
|
||||
// part of DNS64 happens inside the [proxy] package, but there still are
|
||||
// some places where response mapping is needed (e.g. DHCP).
|
||||
@@ -212,14 +200,6 @@ type DNSCreateParams struct {
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
const (
|
||||
// recursionTTL is the time recursive request is cached for.
|
||||
recursionTTL = 1 * time.Second
|
||||
// cachedRecurrentReqNum is the maximum number of cached recurrent
|
||||
// requests.
|
||||
cachedRecurrentReqNum = 1000
|
||||
)
|
||||
|
||||
// NewServer creates a new instance of the dnsforward.Server
|
||||
// Note: this function must be called only once
|
||||
//
|
||||
@@ -256,7 +236,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
// TODO(e.burkov): Use some case-insensitive string comparison.
|
||||
localDomainSuffix: strings.ToLower(localDomainSuffix),
|
||||
etcHosts: etcHosts,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: defaultClientIDCacheCount,
|
||||
@@ -366,6 +345,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
// TODO(e.burkov): Migrate to [netip.Addr] already.
|
||||
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("reversing ip: %w", err)
|
||||
@@ -386,25 +366,23 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
|
||||
}
|
||||
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: "udp",
|
||||
Req: req,
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: req,
|
||||
IsPrivateClient: true,
|
||||
}
|
||||
|
||||
var resolver *proxy.Proxy
|
||||
var errMsg string
|
||||
if s.privateNets.Contains(ip) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return "", 0, nil
|
||||
}
|
||||
|
||||
resolver = s.localResolvers
|
||||
errMsg = "resolving a private address: %w"
|
||||
s.recDetector.add(*req)
|
||||
dctx.RequestedPrivateRDNS = netip.PrefixFrom(ip, ip.BitLen())
|
||||
} else {
|
||||
resolver = s.internalProxy
|
||||
errMsg = "resolving an address: %w"
|
||||
}
|
||||
if err = resolver.Resolve(dctx); err != nil {
|
||||
if err = s.internalProxy.Resolve(dctx); err != nil {
|
||||
return "", 0, fmt.Errorf(errMsg, err)
|
||||
}
|
||||
|
||||
@@ -473,103 +451,6 @@ func (s *Server) startLocked() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// prepareLocalResolvers initializes the local upstreams configuration using
|
||||
// boot as bootstrap. It assumes that s.serverLock is locked or s not running.
|
||||
func (s *Server) prepareLocalResolvers(
|
||||
boot upstream.Resolver,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
set, err := s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resolvers := s.conf.LocalPTRResolvers
|
||||
confNeedsFiltering := len(resolvers) > 0
|
||||
if confNeedsFiltering {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
} else {
|
||||
sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has)
|
||||
resolvers = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
resolvers = append(resolvers, r.String())
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
|
||||
|
||||
uc, err = s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
|
||||
Bootstrap: boot,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
if confNeedsFiltering {
|
||||
err = filterOutAddrs(uc, set)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filtering private upstreams: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// LocalResolversError is an error type for errors during local resolvers setup.
|
||||
// This is only needed to distinguish these errors from errors returned by
|
||||
// creating the proxy.
|
||||
type LocalResolversError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ error = (*LocalResolversError)(nil)
|
||||
|
||||
// Error implements the error interface for *LocalResolversError.
|
||||
func (err *LocalResolversError) Error() (s string) {
|
||||
return fmt.Sprintf("creating local resolvers: %s", err.Err)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ errors.Wrapper = (*LocalResolversError)(nil)
|
||||
|
||||
// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError.
|
||||
func (err *LocalResolversError) Unwrap() error {
|
||||
return err.Err
|
||||
}
|
||||
|
||||
// setupLocalResolvers initializes and sets the resolvers for local addresses.
|
||||
// It assumes s.serverLock is locked or s not running. It returns the upstream
|
||||
// configuration used for private PTR resolving, or nil if it's disabled. Note,
|
||||
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
uc, err = s.prepareLocalResolvers(boot)
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
localResolvers, err := proxy.New(&proxy.Config{
|
||||
UpstreamConfig: uc,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &LocalResolversError{Err: err}
|
||||
}
|
||||
|
||||
s.localResolvers = localResolvers
|
||||
|
||||
// TODO(e.burkov): Should we also consider the DNS64 usage?
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// Prepare initializes parameters of s using data from conf. conf must not be
|
||||
// nil.
|
||||
func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
@@ -586,7 +467,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
s.initDefaultSettings()
|
||||
|
||||
boot, err := s.prepareInternalDNS()
|
||||
err = s.prepareInternalDNS()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
@@ -608,12 +489,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return fmt.Errorf("preparing access: %w", err)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up resolvers: %w", err)
|
||||
}
|
||||
|
||||
proxyConfig.Fallbacks, err = s.setupFallbackDNS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up fallback dns servers: %w", err)
|
||||
@@ -626,8 +501,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
s.dnsProxy = dnsProxy
|
||||
|
||||
s.recDetector.clear()
|
||||
|
||||
s.setupAddrProc()
|
||||
|
||||
s.registerHandlers()
|
||||
@@ -635,36 +508,127 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareInternalDNS initializes the internal state of s before initializing
|
||||
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
|
||||
err = s.prepareIpsetListSettings()
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
upstreams, err = s.conf.loadUpstreams()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing ipset settings: %w", err)
|
||||
return fmt.Errorf("loading upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{
|
||||
Timeout: DefaultTimeout,
|
||||
uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||
Bootstrap: boot,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||
// loading TLS roots does not always work properly on some routers so we're
|
||||
// loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
//
|
||||
// TODO(a.garipov): Investigate if that's true.
|
||||
RootCAs: s.conf.TLSv12Roots,
|
||||
CipherSuites: s.conf.TLSCiphers,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing upstream config: %w", err)
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig = uc
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrivateRDNSError is returned when the private rDNS upstreams are
|
||||
// invalid but enabled.
|
||||
//
|
||||
// TODO(e.burkov): Consider allowing to use incomplete private rDNS upstreams
|
||||
// configuration in proxy when the private rDNS function is enabled. In theory,
|
||||
// proxy supports the case when no upstreams provided to resolve the private
|
||||
// request, since it already supports this for DNS64-prefixed PTR requests.
|
||||
type PrivateRDNSError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// Error implements the [errors.Error] interface.
|
||||
func (e *PrivateRDNSError) Error() (s string) {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *PrivateRDNSError) Unwrap() (err error) {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// prepareLocalResolvers initializes the private RDNS upstream configuration
|
||||
// according to the server's settings. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var ownAddrs addrPortSet
|
||||
ownAddrs, err = s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: s.bootstrap,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
}
|
||||
|
||||
addrs := s.conf.LocalPTRResolvers
|
||||
uc, err = newPrivateConfig(addrs, ownAddrs, s.sysResolvers, s.privateNets, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing resolvers: %w", err)
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// prepareInternalDNS initializes the internal state of s before initializing
|
||||
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareInternalDNS() (err error) {
|
||||
err = s.prepareIpsetListSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing ipset settings: %w", err)
|
||||
}
|
||||
|
||||
bootOpts := &upstream.Options{
|
||||
Timeout: DefaultTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
}
|
||||
|
||||
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.prepareUpstreamSettings(s.bootstrap)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return s.bootstrap, err
|
||||
return err
|
||||
}
|
||||
|
||||
s.conf.PrivateRDNSUpstreamConfig, err = s.prepareLocalResolvers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.prepareInternalProxy()
|
||||
if err != nil {
|
||||
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err)
|
||||
return fmt.Errorf("preparing internal proxy: %w", err)
|
||||
}
|
||||
|
||||
return s.bootstrap, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupFallbackDNS initializes the fallback DNS servers.
|
||||
@@ -743,10 +707,16 @@ func validateBlockingMode(
|
||||
func (s *Server) prepareInternalProxy() (err error) {
|
||||
srvConf := s.conf
|
||||
conf := &proxy.Config{
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
MaxGoroutines: s.conf.MaxGoroutines,
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
MaxGoroutines: srvConf.MaxGoroutines,
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
DNS64Prefs: srvConf.DNS64Prefixes,
|
||||
UsePrivateRDNS: srvConf.UsePrivateRDNS,
|
||||
PrivateSubnets: s.privateNets,
|
||||
MessageConstructor: s,
|
||||
}
|
||||
|
||||
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
|
||||
@@ -782,11 +752,6 @@ func (s *Server) stopLocked() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
|
||||
if s.localResolvers != nil {
|
||||
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
|
||||
}
|
||||
|
||||
for _, b := range s.bootResolvers {
|
||||
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
|
||||
}
|
||||
@@ -908,5 +873,5 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool,
|
||||
blocked = true
|
||||
}
|
||||
|
||||
return blocked, aghalg.Coalesce(rule, clientID)
|
||||
return blocked, cmp.Or(rule, clientID)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"cmp"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -190,7 +189,7 @@ func newGoogleUpstream() (u upstream.Upstream) {
|
||||
return &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "google.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
return cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, googleDomainName, "8.8.8.8"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
@@ -253,7 +252,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < testMessagesCount; i++ {
|
||||
for range testMessagesCount {
|
||||
msg := createGoogleATestMessage()
|
||||
wg.Add(1)
|
||||
|
||||
@@ -276,7 +275,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
|
||||
func sendTestMessages(t *testing.T, conn *dns.Conn) {
|
||||
t.Helper()
|
||||
|
||||
for i := 0; i < testMessagesCount; i++ {
|
||||
for i := range testMessagesCount {
|
||||
req := createGoogleATestMessage()
|
||||
err := conn.WriteMsg(req)
|
||||
assert.NoErrorf(t, err, "cannot write message #%d: %s", i, err)
|
||||
@@ -491,19 +490,10 @@ func TestServerRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
resolver := &aghtest.Resolver{
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
safeSearchConf := filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Google: true,
|
||||
Yandex: true,
|
||||
CustomResolver: resolver,
|
||||
Enabled: true,
|
||||
Google: true,
|
||||
Yandex: true,
|
||||
}
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
@@ -540,7 +530,6 @@ func TestSafeSearch(t *testing.T) {
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})
|
||||
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
testCases := []struct {
|
||||
host string
|
||||
@@ -564,19 +553,19 @@ func TestSafeSearch(t *testing.T) {
|
||||
wantCNAME: "",
|
||||
}, {
|
||||
host: "www.google.com.",
|
||||
want: googleIP,
|
||||
want: netip.Addr{},
|
||||
wantCNAME: "forcesafesearch.google.com.",
|
||||
}, {
|
||||
host: "www.google.com.af.",
|
||||
want: googleIP,
|
||||
want: netip.Addr{},
|
||||
wantCNAME: "forcesafesearch.google.com.",
|
||||
}, {
|
||||
host: "www.google.be.",
|
||||
want: googleIP,
|
||||
want: netip.Addr{},
|
||||
wantCNAME: "forcesafesearch.google.com.",
|
||||
}, {
|
||||
host: "www.google.by.",
|
||||
want: googleIP,
|
||||
want: netip.Addr{},
|
||||
wantCNAME: "forcesafesearch.google.com.",
|
||||
}}
|
||||
|
||||
@@ -593,12 +582,15 @@ func TestSafeSearch(t *testing.T) {
|
||||
|
||||
cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
|
||||
assert.Equal(t, tc.wantCNAME, cname.Target)
|
||||
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[1])
|
||||
assert.NotEmpty(t, a.A)
|
||||
} else {
|
||||
require.Len(t, reply.Answer, 1)
|
||||
}
|
||||
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1])
|
||||
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
|
||||
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -691,7 +683,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
atomic.AddUint32(&upsCalledCounter, 1)
|
||||
|
||||
return aghalg.Coalesce(
|
||||
return cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
@@ -1152,7 +1144,7 @@ func TestRewrite(t *testing.T) {
|
||||
}))
|
||||
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
return cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, "example.org", "4.3.2.1"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
@@ -1481,7 +1473,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := aghalg.Coalesce(
|
||||
resp := cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
|
||||
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
@@ -1495,7 +1487,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := aghalg.Coalesce(
|
||||
resp := cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
@@ -143,7 +143,7 @@ func (s *Server) filterDNSRewrite(
|
||||
res *filtering.Result,
|
||||
pctx *proxy.DNSContext,
|
||||
) (err error) {
|
||||
resp := s.makeResponse(req)
|
||||
resp := s.replyCompressed(req)
|
||||
dnsrr := res.DNSRewriteResult
|
||||
if dnsrr == nil {
|
||||
return errors.Error("no dns rewrite rule content")
|
||||
|
||||
@@ -1,57 +1,17 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// beforeRequestHandler is the handler that is called before any other
|
||||
// processing, including logs. It performs access checks and puts the client
|
||||
// ID, if there is one, into the server's cache.
|
||||
func (s *Server) beforeRequestHandler(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (reply bool, err error) {
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("getting clientid: %w", err)
|
||||
}
|
||||
|
||||
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
|
||||
if blocked {
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
|
||||
if len(pctx.Req.Question) == 1 {
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
host := aghnet.NormalizeDomain(q.Name)
|
||||
if s.access.isBlockedHost(host, qt) {
|
||||
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
|
||||
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
}
|
||||
|
||||
if clientID != "" {
|
||||
key := [8]byte{}
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// clientRequestFilteringSettings looks up client filtering settings using the
|
||||
// client's IP address and ID, if any, from dctx.
|
||||
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
|
||||
@@ -71,6 +31,7 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
host := strings.TrimSuffix(q.Name, ".")
|
||||
|
||||
resVal, err := s.dnsFilter.CheckHost(host, q.Qtype, dctx.setts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("checking host %q: %w", host, err)
|
||||
@@ -79,22 +40,15 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
|
||||
// TODO(a.garipov): Make CheckHost return a pointer.
|
||||
res = &resVal
|
||||
switch {
|
||||
case res.IsFiltered:
|
||||
log.Debug(
|
||||
"dnsforward: host %q is filtered, reason: %q; rule: %q",
|
||||
host,
|
||||
res.Reason,
|
||||
res.Rules[0].Text,
|
||||
)
|
||||
pctx.Res = s.genDNSFilterMessage(pctx, res)
|
||||
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
|
||||
res.CanonName != "" &&
|
||||
len(res.IPList) == 0:
|
||||
case isRewrittenCNAME(res):
|
||||
// Resolve the new canonical name, not the original host name. The
|
||||
// original question is readded in processFilteringAfterResponse.
|
||||
dctx.origQuestion = q
|
||||
req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||
case res.Reason == filtering.Rewritten:
|
||||
case res.IsFiltered:
|
||||
log.Debug("dnsforward: host %q is filtered, reason: %q", host, res.Reason)
|
||||
pctx.Res = s.genDNSFilterMessage(pctx, res)
|
||||
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
|
||||
pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName)
|
||||
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
|
||||
if err = s.filterDNSRewrite(req, res, pctx); err != nil {
|
||||
@@ -105,6 +59,17 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
|
||||
return res, err
|
||||
}
|
||||
|
||||
// isRewrittenCNAME returns true if the request considered to be rewritten with
|
||||
// CNAME and has no resolved IPs.
|
||||
func isRewrittenCNAME(res *filtering.Result) (ok bool) {
|
||||
return res.Reason.In(
|
||||
filtering.Rewritten,
|
||||
filtering.RewrittenRule,
|
||||
filtering.FilteredSafeSearch) &&
|
||||
res.CanonName != "" &&
|
||||
len(res.IPList) == 0
|
||||
}
|
||||
|
||||
// checkHostRules checks the host against filters. It is safe for concurrent
|
||||
// use.
|
||||
func (s *Server) checkHostRules(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -261,55 +262,17 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// checkBootstrap returns an error if any bootstrap address is invalid.
|
||||
func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var b string
|
||||
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
|
||||
|
||||
for _, b = range *req.Bootstraps {
|
||||
if b == "" {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
var resolver *upstream.UpstreamResolver
|
||||
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = resolver.Close(); err != nil {
|
||||
return fmt.Errorf("closing %s: %w", b, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkFallbacks returns an error if any fallback address is invalid.
|
||||
func (req *jsonDNSConfig) checkFallbacks() (err error) {
|
||||
if req.Fallbacks == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback servers: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
//
|
||||
// TODO(s.chzhen): Parse, don't validate.
|
||||
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
func (req *jsonDNSConfig) validate(
|
||||
ownAddrs addrPortSet,
|
||||
sysResolvers SystemResolvers,
|
||||
privateNets netutil.SubnetSet,
|
||||
) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
|
||||
|
||||
err = req.validateUpstreamDNSServers(privateNets)
|
||||
err = req.validateUpstreamDNSServers(ownAddrs, sysResolvers, privateNets)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
@@ -342,20 +305,77 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBootstrap returns an error if any bootstrap address is invalid.
|
||||
func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var b string
|
||||
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
|
||||
|
||||
for _, b = range *req.Bootstraps {
|
||||
if b == "" {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
var resolver *upstream.UpstreamResolver
|
||||
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = resolver.Close(); err != nil {
|
||||
return fmt.Errorf("closing %s: %w", b, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPrivateRDNS returns an error if the configuration of the private RDNS is
|
||||
// not valid.
|
||||
func (req *jsonDNSConfig) checkPrivateRDNS(
|
||||
ownAddrs addrPortSet,
|
||||
sysResolvers SystemResolvers,
|
||||
privateNets netutil.SubnetSet,
|
||||
) (err error) {
|
||||
if (req.UsePrivateRDNS == nil || !*req.UsePrivateRDNS) && req.LocalPTRUpstreams == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
addrs := cmp.Or(req.LocalPTRUpstreams, &[]string{})
|
||||
|
||||
uc, err := newPrivateConfig(*addrs, ownAddrs, sysResolvers, privateNets, &upstream.Options{})
|
||||
err = errors.WithDeferred(err, uc.Close())
|
||||
if err != nil {
|
||||
return fmt.Errorf("private upstream servers: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(
|
||||
ownAddrs addrPortSet,
|
||||
sysResolvers SystemResolvers,
|
||||
privateNets netutil.SubnetSet,
|
||||
) (err error) {
|
||||
var uc *proxy.UpstreamConfig
|
||||
opts := &upstream.Options{}
|
||||
|
||||
if req.Upstreams != nil {
|
||||
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{})
|
||||
uc, err = proxy.ParseUpstreamsConfig(*req.Upstreams, opts)
|
||||
err = errors.WithDeferred(err, uc.Close())
|
||||
if err != nil {
|
||||
return fmt.Errorf("upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("private upstream servers: %w", err)
|
||||
}
|
||||
err = req.checkPrivateRDNS(ownAddrs, sysResolvers, privateNets)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkBootstrap()
|
||||
@@ -364,10 +384,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.checkFallbacks()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
if req.Fallbacks != nil {
|
||||
uc, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, opts)
|
||||
err = errors.WithDeferred(err, uc.Close())
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -436,7 +458,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = req.validate(s.privateNets)
|
||||
// TODO(e.burkov): Consider prebuilding this set on startup.
|
||||
ourAddrs, err := s.conf.ourAddrsSet()
|
||||
if err != nil {
|
||||
// TODO(e.burkov): Put into openapi.
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "getting our addresses: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = req.validate(ourAddrs, s.sysResolvers, s.privateNets)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -587,7 +618,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var boots []*upstream.UpstreamResolver
|
||||
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts)
|
||||
opts.Bootstrap, boots, err = newBootstrap(req.BootstrapDNS, s.etcHosts, opts)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)
|
||||
|
||||
|
||||
@@ -245,9 +245,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "local_ptr_upstreams_bad",
|
||||
wantSet: `validating dns config: ` +
|
||||
`private upstream servers: checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "non.arpa.": not a reversed ip network`,
|
||||
wantSet: `validating dns config: private upstream servers: ` +
|
||||
`bad arpa domain name "non.arpa": not a reversed ip network`,
|
||||
}, {
|
||||
name: "local_ptr_upstreams_null",
|
||||
wantSet: "",
|
||||
@@ -318,58 +317,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
u string
|
||||
}{{
|
||||
name: "success_address",
|
||||
wantErr: ``,
|
||||
u: "[/1.0.0.127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "success_subnet",
|
||||
wantErr: ``,
|
||||
u: "[/127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "not_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "hello.world.": not a reversed ip network`,
|
||||
u: "[/hello.world/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_address",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
|
||||
u: "[/1.2.3.4.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
|
||||
u: "[/128.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "several_bad",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
|
||||
`bad arpa domain name "non.arpa.": not a reversed ip network`,
|
||||
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "partial_good",
|
||||
wantErr: "",
|
||||
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
set := []string{"192.168.0.1", tc.u}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateUpstreamsPrivate(set, ss)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -11,17 +11,21 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// makeResponse creates a DNS response by req and sets necessary flags. It also
|
||||
// guarantees that req.Question will be not empty.
|
||||
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
RecursionAvailable: true,
|
||||
},
|
||||
Compress: true,
|
||||
}
|
||||
// TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
|
||||
// template. Also extract all the methods to a separate entity.
|
||||
|
||||
resp.SetReply(req)
|
||||
// reply creates a DNS response for req.
|
||||
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
|
||||
resp = (&dns.Msg{}).SetRcode(req, code)
|
||||
resp.RecursionAvailable = true
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// replyCompressed creates a DNS response for req and sets the compress flag.
|
||||
func (s *Server) replyCompressed(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.reply(req, dns.RcodeSuccess)
|
||||
resp.Compress = true
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -48,10 +52,10 @@ func (s *Server) genDNSFilterMessage(
|
||||
) (resp *dns.Msg) {
|
||||
req := dctx.Req
|
||||
qt := req.Question[0].Qtype
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA && qt != dns.TypeHTTPS {
|
||||
m, _, _ := s.dnsFilter.BlockingMode()
|
||||
if m == filtering.BlockingModeNullIP {
|
||||
return s.makeResponse(req)
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
|
||||
return s.newMsgNODATA(req)
|
||||
@@ -75,7 +79,7 @@ func (s *Server) genDNSFilterMessage(
|
||||
// getCNAMEWithIPs generates a filtered response to req for with CNAME record
|
||||
// and provided ips.
|
||||
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
|
||||
resp = s.makeResponse(req)
|
||||
resp = s.replyCompressed(req)
|
||||
|
||||
originalName := req.Question[0].Name
|
||||
|
||||
@@ -121,13 +125,13 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
case filtering.BlockingModeNullIP:
|
||||
return s.makeResponseNullIP(req)
|
||||
case filtering.BlockingModeNXDOMAIN:
|
||||
return s.genNXDomain(req)
|
||||
return s.NewMsgNXDOMAIN(req)
|
||||
case filtering.BlockingModeREFUSED:
|
||||
return s.makeResponseREFUSED(req)
|
||||
default:
|
||||
log.Error("dnsforward: invalid blocking mode %q", mode)
|
||||
|
||||
return s.makeResponse(req)
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,25 +152,18 @@ func (s *Server) makeResponseCustomIP(
|
||||
// genDNSFilterMessage.
|
||||
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
|
||||
|
||||
return s.makeResponse(req)
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeServerFailure)
|
||||
resp.RecursionAvailable = true
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
|
||||
resp := s.makeResponse(request)
|
||||
resp := s.replyCompressed(request)
|
||||
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
|
||||
resp := s.makeResponse(request)
|
||||
resp := s.replyCompressed(request)
|
||||
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
|
||||
return resp
|
||||
}
|
||||
@@ -252,7 +249,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
// Go on and return an empty response.
|
||||
}
|
||||
|
||||
resp = s.makeResponse(req)
|
||||
resp = s.replyCompressed(req)
|
||||
resp.Answer = ans
|
||||
|
||||
return resp
|
||||
@@ -288,7 +285,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||
case dns.TypeAAAA:
|
||||
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
|
||||
default:
|
||||
resp = s.makeResponse(req)
|
||||
resp = s.replyCompressed(req)
|
||||
}
|
||||
|
||||
return resp
|
||||
@@ -298,7 +295,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
if newAddr == "" {
|
||||
log.Info("dnsforward: block host is not specified")
|
||||
|
||||
return s.genServerFailure(request)
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(newAddr)
|
||||
@@ -321,17 +318,17 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
if prx == nil {
|
||||
log.Debug("dnsforward: %s", srvClosedErr)
|
||||
|
||||
return s.genServerFailure(request)
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
err = prx.Resolve(newContext)
|
||||
if err != nil {
|
||||
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
|
||||
|
||||
return s.genServerFailure(request)
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
resp := s.makeResponse(request)
|
||||
resp := s.replyCompressed(request)
|
||||
if newContext.Res != nil {
|
||||
for _, answer := range newContext.Res.Answer {
|
||||
answer.Header().Name = request.Question[0].Name
|
||||
@@ -342,48 +339,21 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
return resp
|
||||
}
|
||||
|
||||
// preBlockedResponse returns a protocol-appropriate response for a request that
|
||||
// was blocked by access settings.
|
||||
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (reply bool, err error) {
|
||||
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
|
||||
// Return nil so that dnsproxy drops the connection and thus
|
||||
// prevent DNS amplification attacks.
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pctx.Res = s.makeResponseREFUSED(pctx.Req)
|
||||
|
||||
// Return true so that dnsproxy responds with the REFUSED message.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Create REFUSED DNS response
|
||||
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeRefused)
|
||||
resp.RecursionAvailable = true
|
||||
return &resp
|
||||
func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
|
||||
return s.reply(req, dns.RcodeRefused)
|
||||
}
|
||||
|
||||
// newMsgNODATA returns a properly initialized NODATA response.
|
||||
//
|
||||
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
|
||||
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess)
|
||||
resp.RecursionAvailable = true
|
||||
resp = s.reply(req, dns.RcodeSuccess)
|
||||
resp.Ns = s.genSOA(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeNameError)
|
||||
resp.RecursionAvailable = true
|
||||
resp.Ns = s.genSOA(request)
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genSOA(request *dns.Msg) []dns.RR {
|
||||
zone := ""
|
||||
if len(request.Question) > 0 {
|
||||
@@ -415,5 +385,43 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
|
||||
if len(zone) > 0 && zone[0] != '.' {
|
||||
soa.Mbox += zone
|
||||
}
|
||||
|
||||
return []dns.RR{&soa}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ proxy.MessageConstructor = (*Server)(nil)
|
||||
|
||||
// NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for
|
||||
// *Server.
|
||||
func (s *Server) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.reply(req, dns.RcodeNameError)
|
||||
resp.Ns = s.genSOA(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for
|
||||
// *Server.
|
||||
func (s *Server) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
|
||||
return s.reply(req, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
// NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for
|
||||
// *Server.
|
||||
func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.reply(req, dns.RcodeNotImplemented)
|
||||
|
||||
// Most of the Internet and especially the inner core has an MTU of at least
|
||||
// 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet
|
||||
// is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)).
|
||||
//
|
||||
// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
|
||||
const maxUDPPayload = 1452
|
||||
|
||||
// NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so
|
||||
// explicitly set it.
|
||||
resp.SetEdns0(maxUDPPayload, false)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -34,11 +31,6 @@ type dnsContext struct {
|
||||
// response is modified by filters.
|
||||
origResp *dns.Msg
|
||||
|
||||
// unreversedReqIP stores an IP address obtained from a PTR request if it
|
||||
// was parsed successfully and belongs to one of the locally served IP
|
||||
// ranges.
|
||||
unreversedReqIP netip.Addr
|
||||
|
||||
// err is the error returned from a processing function.
|
||||
err error
|
||||
|
||||
@@ -63,10 +55,6 @@ type dnsContext struct {
|
||||
// responseAD shows if the response had the AD bit set.
|
||||
responseAD bool
|
||||
|
||||
// isLocalClient shows if client's IP address is from locally served
|
||||
// network.
|
||||
isLocalClient bool
|
||||
|
||||
// isDHCPHost is true if the request for a local domain name and the DHCP is
|
||||
// available for this request.
|
||||
isDHCPHost bool
|
||||
@@ -109,15 +97,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
|
||||
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
|
||||
// appropriate handler.
|
||||
mods := []modProcessFunc{
|
||||
s.processRecursion,
|
||||
s.processInitial,
|
||||
s.processDDRQuery,
|
||||
s.processDetermineLocal,
|
||||
s.processDHCPHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processDHCPAddrs,
|
||||
s.processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
s.processFilteringAfterResponse,
|
||||
s.ipset.process,
|
||||
@@ -145,24 +129,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// processRecursion checks the incoming request and halts its handling by
|
||||
// answering NXDOMAIN if s has tried to resolve it recently.
|
||||
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing recursion")
|
||||
defer log.Debug("dnsforward: finished processing recursion")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
|
||||
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
|
||||
// own DoH server.
|
||||
//
|
||||
@@ -199,14 +165,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
pctx.Res = s.NewMsgNXDOMAIN(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
if q.Name == healthcheckFQDN {
|
||||
// Generate a NODATA negative response to make nslookup exit with 0.
|
||||
pctx.Res = s.makeResponse(pctx.Req)
|
||||
pctx.Res = s.replyCompressed(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
@@ -272,7 +238,7 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
|
||||
//
|
||||
// [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
|
||||
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.makeResponse(req)
|
||||
resp = s.replyCompressed(req)
|
||||
if req.Question[0].Qtype != dns.TypeSVCB {
|
||||
return resp
|
||||
}
|
||||
@@ -339,19 +305,6 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
return resp
|
||||
}
|
||||
|
||||
// processDetermineLocal determines if the client's IP address is from locally
|
||||
// served network and saves the result into the context.
|
||||
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local detection")
|
||||
defer log.Debug("dnsforward: finished processing local detection")
|
||||
|
||||
rc = resultCodeSuccess
|
||||
|
||||
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr())
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||
// the server. It responds with a mapped IP address if the DNS64 is enabled and
|
||||
// the request is for AAAA.
|
||||
@@ -370,9 +323,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if !dctx.isLocalClient {
|
||||
if !pctx.IsPrivateClient {
|
||||
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
pctx.Res = s.NewMsgNXDOMAIN(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
@@ -389,7 +342,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
|
||||
|
||||
resp := s.makeResponse(req)
|
||||
resp := s.replyCompressed(req)
|
||||
switch q.Qtype {
|
||||
case dns.TypeA:
|
||||
a := &dns.A{
|
||||
@@ -416,141 +369,6 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// indexFirstV4Label returns the index at which the reversed IPv4 address
|
||||
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
|
||||
// arpa labels removed.
|
||||
func indexFirstV4Label(domain string) (idx int) {
|
||||
idx = len(domain)
|
||||
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
|
||||
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
|
||||
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
|
||||
if parseErr != nil {
|
||||
return idx
|
||||
}
|
||||
|
||||
idx = curIdx
|
||||
}
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
// indexFirstV6Label returns the index at which the reversed IPv6 address
|
||||
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
|
||||
// labels removed.
|
||||
func indexFirstV6Label(domain string) (idx int) {
|
||||
idx = len(domain)
|
||||
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
|
||||
curIdx := idx - len("a.")
|
||||
if curIdx > 1 && domain[curIdx-1] != '.' {
|
||||
return idx
|
||||
}
|
||||
|
||||
nibble := domain[curIdx]
|
||||
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
|
||||
return idx
|
||||
}
|
||||
|
||||
idx = curIdx
|
||||
}
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
// extractARPASubnet tries to convert a reversed ARPA address being a part of
|
||||
// domain to an IP network. domain must be an FQDN.
|
||||
//
|
||||
// TODO(e.burkov): Move to golibs.
|
||||
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
|
||||
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
|
||||
const (
|
||||
v4Suffix = "in-addr.arpa."
|
||||
v6Suffix = "ip6.arpa."
|
||||
)
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
var idx int
|
||||
switch {
|
||||
case strings.HasSuffix(domain, v4Suffix):
|
||||
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
|
||||
case strings.HasSuffix(domain, v6Suffix):
|
||||
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
|
||||
default:
|
||||
return netip.Prefix{}, &netutil.AddrError{
|
||||
Err: netutil.ErrNotAReversedSubnet,
|
||||
Kind: netutil.AddrKindARPA,
|
||||
Addr: domain,
|
||||
}
|
||||
}
|
||||
|
||||
return netutil.PrefixFromReversedAddr(domain[idx:])
|
||||
}
|
||||
|
||||
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
|
||||
// in locally served network from external clients.
|
||||
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local restriction")
|
||||
defer log.Debug("dnsforward: finished processing local restriction")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
if q.Qtype != dns.TypePTR {
|
||||
// No need for restriction.
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
subnet, err := extractARPASubnet(q.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
|
||||
log.Debug("dnsforward: request is not for arpa domain")
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: parsing reversed addr: %s", err)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally served or at least
|
||||
// shouldn't be accessible externally.
|
||||
subnetAddr := subnet.Addr()
|
||||
if !s.privateNets.Contains(subnetAddr) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
|
||||
|
||||
if !dctx.isLocalClient {
|
||||
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
// Do not perform unreversing ever again.
|
||||
dctx.unreversedReqIP = subnetAddr
|
||||
|
||||
// There is no need to filter request from external addresses since this
|
||||
// code is only executed when the request is for locally served ARPA
|
||||
// hostname so disable redundant filters.
|
||||
dctx.setts.ParentalEnabled = false
|
||||
dctx.setts.SafeBrowsingEnabled = false
|
||||
dctx.setts.SafeSearchEnabled = false
|
||||
dctx.setts.ServicesRules = nil
|
||||
|
||||
// Nothing to restrict.
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||
// DHCP server.
|
||||
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
@@ -562,23 +380,27 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
ipAddr := dctx.unreversedReqIP
|
||||
if ipAddr == (netip.Addr{}) {
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
pref := pctx.RequestedPrivateRDNS
|
||||
// TODO(e.burkov): Consider answering authoritatively for SOA and NS
|
||||
// queries.
|
||||
if pref == (netip.Prefix{}) || q.Qtype != dns.TypePTR {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
host := s.dhcpServer.HostByIP(ipAddr)
|
||||
addr := pref.Addr()
|
||||
host := s.dhcpServer.HostByIP(addr)
|
||||
if host == "" {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host)
|
||||
log.Debug("dnsforward: dhcp client %s is %q", addr, host)
|
||||
|
||||
req := pctx.Req
|
||||
resp := s.makeResponse(req)
|
||||
resp := s.replyCompressed(req)
|
||||
ptr := &dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
// TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See
|
||||
// https://github.com/AdguardTeam/AdGuardHome/issues/3932.
|
||||
@@ -593,62 +415,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// processLocalPTR responds to PTR requests if the target IP is detected to be
|
||||
// inside the local network and the query was not answered from DHCP.
|
||||
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local ptr")
|
||||
defer log.Debug("dnsforward: finished processing local ptr")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
ip := dctx.unreversedReqIP
|
||||
if ip == (netip.Addr{}) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if s.conf.UsePrivateRDNS {
|
||||
s.recDetector.add(*pctx.Req)
|
||||
if err := s.localResolvers.Resolve(pctx); err != nil {
|
||||
log.Debug("dnsforward: resolving private address: %s", err)
|
||||
|
||||
// Generate the server failure if the private upstream configuration
|
||||
// is empty.
|
||||
//
|
||||
// This is a crutch, see TODO at [Server.localResolvers].
|
||||
if errors.Is(err, upstream.ErrNoUpstreams) {
|
||||
pctx.Res = s.genServerFailure(pctx.Req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
dctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
}
|
||||
|
||||
if pctx.Res == nil {
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Apply filtering logic
|
||||
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering before req")
|
||||
defer log.Debug("dnsforward: finished processing filtering before req")
|
||||
|
||||
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
|
||||
// There is no need to filter request for locally served ARPA hostname
|
||||
// so disable redundant filters.
|
||||
dctx.setts.ParentalEnabled = false
|
||||
dctx.setts.SafeBrowsingEnabled = false
|
||||
dctx.setts.SafeSearchEnabled = false
|
||||
dctx.setts.ServicesRules = nil
|
||||
}
|
||||
|
||||
if dctx.proxyCtx.Res != nil {
|
||||
// Go on since the response is already set.
|
||||
return resultCodeSuccess
|
||||
@@ -695,7 +475,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
// local domain name if there is one.
|
||||
name := req.Question[0].Name
|
||||
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
pctx.Res = s.NewMsgNXDOMAIN(req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
@@ -712,21 +492,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
if err := prx.Resolve(pctx); err != nil {
|
||||
if errors.Is(err, upstream.ErrNoUpstreams) {
|
||||
// Do not even put into querylog. Currently this happens either
|
||||
// when the private resolvers enabled and the request is DNS64 PTR,
|
||||
// or when the client isn't considered local by prx.
|
||||
//
|
||||
// TODO(e.burkov): Make proxy detect local client the same way as
|
||||
// AGH does.
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
dctx.err = err
|
||||
|
||||
if dctx.err = prx.Resolve(pctx); dctx.err != nil {
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
@@ -810,7 +576,7 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
}
|
||||
|
||||
// Use the ClientID first, since it has a higher priority.
|
||||
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String())
|
||||
id := cmp.Or(clientID, pctx.Addr.Addr().String())
|
||||
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
|
||||
@@ -835,7 +601,8 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
|
||||
return resultCodeSuccess
|
||||
case
|
||||
filtering.Rewritten,
|
||||
filtering.RewrittenRule:
|
||||
filtering.RewrittenRule,
|
||||
filtering.FilteredSafeSearch:
|
||||
|
||||
if dctx.origQuestion.Name == "" {
|
||||
// origQuestion is set in case we get only CNAME without IP from
|
||||
@@ -845,11 +612,10 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
|
||||
if len(pctx.Res.Answer) > 0 {
|
||||
rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
|
||||
answer := append([]dns.RR{rr}, pctx.Res.Answer...)
|
||||
pctx.Res.Answer = answer
|
||||
}
|
||||
|
||||
rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
|
||||
answer := append([]dns.RR{rr}, pctx.Res.Answer...)
|
||||
pctx.Res.Answer = answer
|
||||
|
||||
return resultCodeSuccess
|
||||
default:
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
@@ -70,8 +71,6 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -171,8 +170,6 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -379,44 +376,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
|
||||
return f
|
||||
}
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
cliAddr netip.AddrPort
|
||||
}{{
|
||||
want: assert.True,
|
||||
name: "local",
|
||||
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "external",
|
||||
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "invalid",
|
||||
cliAddr: netip.AddrPort{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
proxyCtx := &proxy.DNSContext{
|
||||
Addr: tc.cliAddr,
|
||||
}
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: proxyCtx,
|
||||
}
|
||||
s.processDetermineLocal(dctx)
|
||||
|
||||
tc.want(t, dctx.isLocalClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
const (
|
||||
localDomainSuffix = "lan"
|
||||
@@ -486,9 +445,9 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
Req: req,
|
||||
IsPrivateClient: tc.isLocalCli,
|
||||
},
|
||||
isLocalClient: tc.isLocalCli,
|
||||
}
|
||||
|
||||
res := s.processDHCPHosts(dctx)
|
||||
@@ -621,9 +580,9 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
Req: req,
|
||||
IsPrivateClient: true,
|
||||
},
|
||||
isLocalClient: true,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -658,19 +617,28 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
// TODO(e.burkov): Rewrite this test to use the whole server instead of just
|
||||
// testing the [handleDNSRequest] method. See comment on
|
||||
// "from_external_for_local" test case.
|
||||
func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
|
||||
intAddr := netip.MustParseAddr("192.168.1.1")
|
||||
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
extAddr := netip.MustParseAddr("254.253.252.1")
|
||||
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
extPTRQuestion = "251.252.253.254.in-addr.arpa."
|
||||
extPTRAnswer = "host1.example.net."
|
||||
intPTRQuestion = "1.1.168.192.in-addr.arpa."
|
||||
intPTRAnswer = "some.local-client."
|
||||
extPTRAnswer = "host1.example.net."
|
||||
intPTRAnswer = "some.local-client."
|
||||
)
|
||||
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := aghalg.Coalesce(
|
||||
resp := cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
(&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
@@ -696,123 +664,165 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
question net.IP
|
||||
cliAddr netip.AddrPort
|
||||
wantLen int
|
||||
name string
|
||||
question string
|
||||
wantErr error
|
||||
wantAns []dns.RR
|
||||
isPrivate bool
|
||||
}{{
|
||||
name: "from_local_to_external",
|
||||
want: "host1.example.net.",
|
||||
question: net.IP{254, 253, 252, 251},
|
||||
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"),
|
||||
wantLen: 1,
|
||||
name: "from_local_for_external",
|
||||
question: extPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(extPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(extPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(extPTRAnswer),
|
||||
}},
|
||||
isPrivate: true,
|
||||
}, {
|
||||
name: "from_external_for_local",
|
||||
want: "",
|
||||
question: net.IP{192, 168, 1, 1},
|
||||
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"),
|
||||
wantLen: 0,
|
||||
// In theory this case is not reproducible because [proxy.Proxy] should
|
||||
// respond to such queries with NXDOMAIN before they reach
|
||||
// [Server.handleDNSRequest].
|
||||
name: "from_external_for_local",
|
||||
question: intPTRQuestion,
|
||||
wantErr: upstream.ErrNoUpstreams,
|
||||
wantAns: nil,
|
||||
isPrivate: false,
|
||||
}, {
|
||||
name: "from_local_for_local",
|
||||
want: "some.local-client.",
|
||||
question: net.IP{192, 168, 1, 1},
|
||||
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"),
|
||||
wantLen: 1,
|
||||
question: intPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(intPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(intPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(intPTRAnswer),
|
||||
}},
|
||||
isPrivate: true,
|
||||
}, {
|
||||
name: "from_external_for_external",
|
||||
want: "host1.example.net.",
|
||||
question: net.IP{254, 253, 252, 251},
|
||||
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"),
|
||||
wantLen: 1,
|
||||
question: extPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(extPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(extPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(extPTRAnswer),
|
||||
}},
|
||||
isPrivate: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
reqAddr, err := dns.ReverseAddr(tc.question.String())
|
||||
require.NoError(t, err)
|
||||
req := createTestMessageWithType(reqAddr, dns.TypePTR)
|
||||
pref, extErr := netutil.ExtractReversedAddr(tc.question)
|
||||
require.NoError(t, extErr)
|
||||
|
||||
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoTCP,
|
||||
Req: req,
|
||||
Addr: tc.cliAddr,
|
||||
Req: req,
|
||||
IsPrivateClient: tc.isPrivate,
|
||||
}
|
||||
// TODO(e.burkov): Configure the subnet set properly.
|
||||
if netutil.IsLocallyServed(pref.Addr()) {
|
||||
pctx.RequestedPrivateRDNS = pref
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.handleDNSRequest(nil, pctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pctx.Res)
|
||||
require.Len(t, pctx.Res.Answer, tc.wantLen)
|
||||
err = s.handleDNSRequest(s.dnsProxy, pctx)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
if tc.wantLen > 0 {
|
||||
assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr)
|
||||
}
|
||||
require.NotNil(t, pctx.Res)
|
||||
assert.Equal(t, tc.wantAns, pctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
const locDomain = "some.local."
|
||||
const reqAddr = "1.1.168.192.in-addr.arpa."
|
||||
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := aghalg.Coalesce(
|
||||
resp := cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
(&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
},
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
)
|
||||
|
||||
var proxyCtx *proxy.DNSContext
|
||||
var dnsCtx *dnsContext
|
||||
setup := func(use bool) {
|
||||
proxyCtx = &proxy.DNSContext{
|
||||
Addr: testClientAddrPort,
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
newPrxCtx := func() (prxCtx *proxy.DNSContext) {
|
||||
return &proxy.DNSContext{
|
||||
Addr: testClientAddrPort,
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
IsPrivateClient: true,
|
||||
RequestedPrivateRDNS: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
}
|
||||
dnsCtx = &dnsContext{
|
||||
proxyCtx: proxyCtx,
|
||||
unreversedReqIP: netip.MustParseAddr("192.168.1.1"),
|
||||
}
|
||||
s.conf.UsePrivateRDNS = use
|
||||
}
|
||||
|
||||
t.Run("enabled", func(t *testing.T) {
|
||||
setup(true)
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
},
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
)
|
||||
pctx := newPrxCtx()
|
||||
|
||||
rc := s.processLocalPTR(dnsCtx)
|
||||
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
|
||||
require.Equal(t, resultCodeSuccess, rc)
|
||||
require.NotEmpty(t, proxyCtx.Res.Answer)
|
||||
require.NotEmpty(t, pctx.Res.Answer)
|
||||
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
|
||||
|
||||
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr)
|
||||
assert.Equal(t, locDomain, ptr.Ptr)
|
||||
})
|
||||
|
||||
t.Run("disabled", func(t *testing.T) {
|
||||
setup(false)
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
},
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
UsePrivateRDNS: false,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
)
|
||||
pctx := newPrxCtx()
|
||||
|
||||
rc := s.processLocalPTR(dnsCtx)
|
||||
require.Equal(t, resultCodeFinish, rc)
|
||||
require.Empty(t, proxyCtx.Res.Answer)
|
||||
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
|
||||
require.Equal(t, resultCodeError, rc)
|
||||
require.Empty(t, pctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -830,129 +840,3 @@ func TestIPStringFromAddr(t *testing.T) {
|
||||
assert.Empty(t, ipStringFromAddr(nil))
|
||||
})
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Add fuzzing when moving to golibs.
|
||||
func TestExtractARPASubnet(t *testing.T) {
|
||||
const (
|
||||
v4Suf = `in-addr.arpa.`
|
||||
v4Part = `2.1.` + v4Suf
|
||||
v4Whole = `4.3.` + v4Part
|
||||
|
||||
v6Suf = `ip6.arpa.`
|
||||
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
|
||||
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
|
||||
)
|
||||
|
||||
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
|
||||
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
|
||||
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
|
||||
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
|
||||
|
||||
testCases := []struct {
|
||||
want netip.Prefix
|
||||
name string
|
||||
domain string
|
||||
wantErr string
|
||||
}{{
|
||||
want: netip.Prefix{},
|
||||
name: "not_an_arpa",
|
||||
domain: "some.domain.name.",
|
||||
wantErr: `bad arpa domain name "some.domain.name.": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "bad_domain_name",
|
||||
domain: "abc.123.",
|
||||
wantErr: `bad domain name "abc.123": ` +
|
||||
`bad top-level domain name label "123": all octets are numeric`,
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4",
|
||||
domain: v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "partial_v4",
|
||||
domain: v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4_within_domain",
|
||||
domain: "a." + v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4Pref,
|
||||
name: "whole_v4_additional_label",
|
||||
domain: "5." + v4Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "partial_v4_within_domain",
|
||||
domain: "a." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "overflow_v4",
|
||||
domain: "256." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v4PrefPart,
|
||||
name: "overflow_v4_within_domain",
|
||||
domain: "a.256." + v4Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v4",
|
||||
domain: v4Suf,
|
||||
wantErr: `bad arpa domain name "in-addr.arpa": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v4_within_domain",
|
||||
domain: "a." + v4Suf,
|
||||
wantErr: `bad arpa domain name "in-addr.arpa": ` +
|
||||
`not a reversed ip network`,
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6",
|
||||
domain: v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6PrefPart,
|
||||
name: "partial_v6",
|
||||
domain: v6Part,
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6_within_domain",
|
||||
domain: "g." + v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6Pref,
|
||||
name: "whole_v6_additional_label",
|
||||
domain: "1." + v6Whole,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: v6PrefPart,
|
||||
name: "partial_v6_within_domain",
|
||||
domain: "label." + v6Part,
|
||||
wantErr: "",
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v6",
|
||||
domain: v6Suf,
|
||||
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
|
||||
}, {
|
||||
want: netip.Prefix{},
|
||||
name: "empty_v6_within_domain",
|
||||
domain: "g." + v6Suf,
|
||||
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
subnet, err := extractARPASubnet(tc.domain)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
assert.Equal(t, tc.want, subnet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,13 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr)
|
||||
|
||||
ids := []string{ipStr, dctx.clientID}
|
||||
ids := []string{ipStr}
|
||||
if dctx.clientID != "" {
|
||||
// Use the ClientID first because it has a higher priority. Filters
|
||||
// have the same priority, see applyAdditionalFiltering.
|
||||
ids = []string{dctx.clientID, ipStr}
|
||||
}
|
||||
|
||||
qt, cl := q.Qtype, q.Qclass
|
||||
|
||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
|
||||
@@ -124,7 +130,7 @@ func (s *Server) logQuery(dctx *dnsContext, ip net.IP, processingTime time.Durat
|
||||
s.queryLog.Add(p)
|
||||
}
|
||||
|
||||
// updatesStats writes the request into statistics.
|
||||
// updateStats writes the request data into statistics.
|
||||
func (s *Server) updateStats(dctx *dnsContext, clientIP string, processingTime time.Duration) {
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
|
||||
@@ -2,90 +2,77 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
// the configuration itself.
|
||||
func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
if s.conf.UpstreamDNSFileName == "" {
|
||||
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil
|
||||
// newBootstrap returns a bootstrap resolver based on the configuration of s.
|
||||
// boots are the upstream resolvers that should be closed after use. r is the
|
||||
// actual bootstrap resolver, which may include the system hosts.
|
||||
//
|
||||
// TODO(e.burkov): This function currently returns a resolver and a slice of
|
||||
// the upstream resolvers, which are essentially the same. boots are returned
|
||||
// for being able to close them afterwards, but it introduces an implicit
|
||||
// contract that r could only be used before that. Anyway, this code should
|
||||
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
|
||||
// and be used here.
|
||||
func newBootstrap(
|
||||
addrs []string,
|
||||
etcHosts upstream.Resolver,
|
||||
opts *upstream.Options,
|
||||
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
|
||||
if len(addrs) == 0 {
|
||||
addrs = defaultBootstrap
|
||||
}
|
||||
|
||||
var data []byte
|
||||
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
|
||||
boots, err = aghnet.ParseBootstraps(addrs, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading upstream from file: %w", err)
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
var parallel upstream.ParallelResolver
|
||||
for _, b := range boots {
|
||||
parallel = append(parallel, upstream.NewCachingResolver(b))
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)
|
||||
if etcHosts != nil {
|
||||
r = upstream.ConsequentResolver{etcHosts, parallel}
|
||||
} else {
|
||||
r = parallel
|
||||
}
|
||||
|
||||
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
|
||||
return r, boots, nil
|
||||
}
|
||||
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
upstreams, err = s.loadUpstreams()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||
Bootstrap: boot,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||
// loading TLS roots does not always work properly on some routers so we're
|
||||
// loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
//
|
||||
// TODO(a.garipov): Investigate if that's true.
|
||||
RootCAs: s.conf.TLSv12Roots,
|
||||
CipherSuites: s.conf.TLSCiphers,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing upstream config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareUpstreamConfig returns the upstream configuration based on upstreams
|
||||
// and configuration of s.
|
||||
func (s *Server) prepareUpstreamConfig(
|
||||
// newUpstreamConfig returns the upstream configuration based on upstreams. If
|
||||
// upstreams slice specifies no default upstreams, defaultUpstreams are used to
|
||||
// create upstreams with no domain specifications. opts are used when creating
|
||||
// upstream configuration.
|
||||
func newUpstreamConfig(
|
||||
upstreams []string,
|
||||
defaultUpstreams []string,
|
||||
opts *upstream.Options,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing upstream config: %w", err)
|
||||
return uc, fmt.Errorf("parsing upstreams: %w", err)
|
||||
}
|
||||
|
||||
if len(uc.Upstreams) == 0 && defaultUpstreams != nil {
|
||||
if len(uc.Upstreams) == 0 && len(defaultUpstreams) > 0 {
|
||||
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
|
||||
|
||||
var defaultUpstreamConfig *proxy.UpstreamConfig
|
||||
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing default upstreams: %w", err)
|
||||
return uc, fmt.Errorf("parsing default upstreams: %w", err)
|
||||
}
|
||||
|
||||
uc.Upstreams = defaultUpstreamConfig.Upstreams
|
||||
@@ -94,6 +81,54 @@ func (s *Server) prepareUpstreamConfig(
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// newPrivateConfig creates an upstream configuration for resolving PTR records
|
||||
// for local addresses. The configuration is built either from the provided
|
||||
// addresses or from the system resolvers. unwanted filters the resulting
|
||||
// upstream configuration.
|
||||
func newPrivateConfig(
|
||||
addrs []string,
|
||||
unwanted addrPortSet,
|
||||
sysResolvers SystemResolvers,
|
||||
privateNets netutil.SubnetSet,
|
||||
opts *upstream.Options,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
confNeedsFiltering := len(addrs) > 0
|
||||
if confNeedsFiltering {
|
||||
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
|
||||
} else {
|
||||
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
|
||||
addrs = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
addrs = append(addrs, r.String())
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", addrs)
|
||||
|
||||
uc, err = proxy.ParseUpstreamsConfig(addrs, opts)
|
||||
if err != nil {
|
||||
return uc, fmt.Errorf("preparing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
if !confNeedsFiltering {
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
err = filterOutAddrs(uc, unwanted)
|
||||
if err != nil {
|
||||
return uc, fmt.Errorf("filtering private upstreams: %w", err)
|
||||
}
|
||||
|
||||
// Prevalidate the config to catch the exact error before creating proxy.
|
||||
// See TODO on [PrivateRDNSError].
|
||||
err = proxy.ValidatePrivateConfig(uc, privateNets)
|
||||
if err != nil {
|
||||
return uc, &PrivateRDNSError{err: err}
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
@@ -130,85 +165,9 @@ func setProxyUpstreamMode(
|
||||
return nil
|
||||
}
|
||||
|
||||
// createBootstrap returns a bootstrap resolver based on the configuration of s.
|
||||
// boots are the upstream resolvers that should be closed after use. r is the
|
||||
// actual bootstrap resolver, which may include the system hosts.
|
||||
//
|
||||
// TODO(e.burkov): This function currently returns a resolver and a slice of
|
||||
// the upstream resolvers, which are essentially the same. boots are returned
|
||||
// for being able to close them afterwards, but it introduces an implicit
|
||||
// contract that r could only be used before that. Anyway, this code should
|
||||
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
|
||||
// and be used here.
|
||||
func (s *Server) createBootstrap(
|
||||
addrs []string,
|
||||
opts *upstream.Options,
|
||||
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
|
||||
if len(addrs) == 0 {
|
||||
addrs = defaultBootstrap
|
||||
}
|
||||
|
||||
boots, err = aghnet.ParseBootstraps(addrs, opts)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var parallel upstream.ParallelResolver
|
||||
for _, b := range boots {
|
||||
parallel = append(parallel, upstream.NewCachingResolver(b))
|
||||
}
|
||||
|
||||
if s.etcHosts != nil {
|
||||
r = upstream.ConsequentResolver{s.etcHosts, parallel}
|
||||
} else {
|
||||
r = parallel
|
||||
}
|
||||
|
||||
return r, boots, nil
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating config: %w", err)
|
||||
}
|
||||
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||
slices.Sort(keys)
|
||||
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet netip.Prefix
|
||||
subnet, err = extractARPASubnet(domain)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !privateNets.Contains(subnet.Addr()) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||
}
|
||||
|
||||
@@ -559,6 +559,8 @@ type Result struct {
|
||||
Reason Reason `json:",omitempty"`
|
||||
|
||||
// IsFiltered is true if the request is filtered.
|
||||
//
|
||||
// TODO(d.kolyshev): Get rid of this flag.
|
||||
IsFiltered bool `json:",omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -200,7 +200,7 @@ func TestParallelSB(t *testing.T) {
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
d.checkMatch(t, sbBlocked, setts)
|
||||
@@ -670,7 +670,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
for range b.N {
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
|
||||
@@ -63,8 +63,6 @@ func TestIDGenerator_Fix(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
g := newIDGenerator(1)
|
||||
g.fix(tc.in)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package rulelist_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
@@ -28,14 +27,12 @@ func TestEngine_Refresh(t *testing.T) {
|
||||
require.NotNil(t, eng)
|
||||
testutil.CleanupAndRequireSuccess(t, eng.Close)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
buf := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
cli := &http.Client{
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package rulelist_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -67,14 +66,12 @@ func TestFilter_Refresh(t *testing.T) {
|
||||
|
||||
require.NotNil(t, f)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
buf := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
cli := &http.Client{
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
res, err := f.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -132,7 +132,6 @@ func TestParser_Parse(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -216,7 +215,7 @@ func BenchmarkParser_Parse(b *testing.B) {
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
resSink, errSink = p.Parse(dst, src, buf)
|
||||
dst.Reset()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package filtering
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// SafeSearch interface describes a service for search engines hosts rewrites.
|
||||
type SafeSearch interface {
|
||||
// CheckHost checks host with safe search filter. CheckHost must be safe
|
||||
@@ -16,9 +14,6 @@ type SafeSearch interface {
|
||||
|
||||
// SafeSearchConfig is a struct with safe search related settings.
|
||||
type SafeSearchConfig struct {
|
||||
// CustomResolver is the resolver used by safe search.
|
||||
CustomResolver Resolver `yaml:"-" json:"-"`
|
||||
|
||||
// Enabled indicates if safe search is enabled entirely.
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
|
||||
@@ -40,13 +35,7 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
qtype uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.ProtectionEnabled ||
|
||||
!setts.SafeSearchEnabled ||
|
||||
(qtype != dns.TypeA && qtype != dns.TypeAAAA) {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
if d.safeSearch == nil {
|
||||
if d.safeSearch == nil || !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,11 +3,9 @@ package safesearch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -67,7 +65,6 @@ type Default struct {
|
||||
engine *urlfilter.DNSEngine
|
||||
|
||||
cache cache.Cache
|
||||
resolver filtering.Resolver
|
||||
logPrefix string
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
@@ -80,11 +77,6 @@ func NewDefault(
|
||||
cacheSize uint,
|
||||
cacheTTL time.Duration,
|
||||
) (ss *Default, err error) {
|
||||
var resolver filtering.Resolver = net.DefaultResolver
|
||||
if conf.CustomResolver != nil {
|
||||
resolver = conf.CustomResolver
|
||||
}
|
||||
|
||||
ss = &Default{
|
||||
mu: &sync.RWMutex{},
|
||||
|
||||
@@ -92,7 +84,6 @@ func NewDefault(
|
||||
EnableLRU: true,
|
||||
MaxSize: cacheSize,
|
||||
}),
|
||||
resolver: resolver,
|
||||
// Use %s, because the client safe-search names already contain double
|
||||
// quotes.
|
||||
logPrefix: fmt.Sprintf("safesearch %s: ", name),
|
||||
@@ -170,8 +161,11 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
|
||||
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
|
||||
}()
|
||||
|
||||
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
|
||||
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype))
|
||||
switch qtype {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypeHTTPS:
|
||||
// Go on.
|
||||
default:
|
||||
return filtering.Result{}, nil
|
||||
}
|
||||
|
||||
// Check cache. Return cached result if it was found
|
||||
@@ -195,6 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
|
||||
}
|
||||
|
||||
res = *fltRes
|
||||
|
||||
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
|
||||
// saving results to cache.
|
||||
ss.setCacheResult(host, qtype, res)
|
||||
|
||||
return res, nil
|
||||
@@ -223,20 +220,13 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe
|
||||
}
|
||||
|
||||
// newResult creates Result object from rewrite rule. qtype must be either
|
||||
// [dns.TypeA] or [dns.TypeAAAA]. If err is nil, res is never nil, so that the
|
||||
// empty result is converted into a NODATA response.
|
||||
//
|
||||
// TODO(a.garipov): Use the main rewrite result mechanism used in
|
||||
// [dnsforward.Server.filterDNSRequest]. Now we resolve IPs for CNAME to save
|
||||
// them in the safe search cache.
|
||||
// [dns.TypeA] or [dns.TypeAAAA], or [dns.TypeHTTPS]. If err is nil, res is
|
||||
// never nil, so that the empty result is converted into a NODATA response.
|
||||
func (ss *Default) newResult(
|
||||
rewrite *rules.DNSRewrite,
|
||||
qtype rules.RRType,
|
||||
) (res *filtering.Result, err error) {
|
||||
res = &filtering.Result{
|
||||
Rules: []*filtering.ResultRule{{
|
||||
FilterListID: rulelist.URLFilterIDSafeSearch,
|
||||
}},
|
||||
Reason: filtering.FilteredSafeSearch,
|
||||
IsFiltered: true,
|
||||
}
|
||||
@@ -247,69 +237,19 @@ func (ss *Default) newResult(
|
||||
return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value)
|
||||
}
|
||||
|
||||
res.Rules[0].IP = ip
|
||||
res.Rules = []*filtering.ResultRule{{
|
||||
FilterListID: rulelist.URLFilterIDSafeSearch,
|
||||
IP: ip,
|
||||
}}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
host := rewrite.NewCNAME
|
||||
if host == "" {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
res.CanonName = host
|
||||
|
||||
ss.log(log.DEBUG, "resolving %q", host)
|
||||
|
||||
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving cname: %w", err)
|
||||
}
|
||||
|
||||
ss.log(log.DEBUG, "resolved %s", ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
// TODO(a.garipov): Remove this filtering once the resolver we use
|
||||
// actually learns about network.
|
||||
addr := fitToProto(ip, qtype)
|
||||
if addr == (netip.Addr{}) {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Rules[0]?
|
||||
res.Rules[0].IP = addr
|
||||
}
|
||||
res.CanonName = rewrite.NewCNAME
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
|
||||
// It panics for other types.
|
||||
func qtypeToProto(qtype rules.RRType) (proto string) {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
return "ip4"
|
||||
case dns.TypeAAAA:
|
||||
return "ip6"
|
||||
default:
|
||||
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
|
||||
}
|
||||
}
|
||||
|
||||
// fitToProto returns a non-nil IP address if ip is the correct protocol version
|
||||
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
|
||||
func fitToProto(ip net.IP, qtype rules.RRType) (res netip.Addr) {
|
||||
if ip4 := ip.To4(); qtype == dns.TypeA {
|
||||
if ip4 != nil {
|
||||
return netip.AddrFrom4([4]byte(ip4))
|
||||
}
|
||||
} else if ip = ip.To16(); ip != nil && qtype == dns.TypeAAAA {
|
||||
return netip.AddrFrom16([16]byte(ip))
|
||||
}
|
||||
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
// setCacheResult stores data in cache for host. qtype is expected to be either
|
||||
// [dns.TypeA] or [dns.TypeAAAA].
|
||||
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
package safesearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
@@ -79,47 +76,6 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
const domain = "www.google.ru"
|
||||
|
||||
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
|
||||
|
||||
res, err := ss.CheckHost(domain, testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Empty(t, res.Rules)
|
||||
|
||||
resolver := &aghtest.Resolver{
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ss = newForTest(t, defaultSafeSearchConf)
|
||||
ss.resolver = resolver
|
||||
|
||||
// Lookup for safesearch domain.
|
||||
rewrite := ss.searchHost(domain, testQType)
|
||||
|
||||
wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
|
||||
|
||||
res, err = ss.CheckHost(domain, testQType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, wantIP, res.Rules[0].IP)
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := ss.getCachedResult(domain, testQType)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
assert.Equal(t, wantIP, cachedValue.Rules[0].IP)
|
||||
}
|
||||
|
||||
const googleHost = "www.google.com"
|
||||
|
||||
var dnsRewriteSink *rules.DNSRewrite
|
||||
@@ -127,7 +83,7 @@ var dnsRewriteSink *rules.DNSRewrite
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
ss := newForTest(b, defaultSafeSearchConf)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
for range b.N {
|
||||
dnsRewriteSink = ss.searchHost(googleHost, testQType)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
@@ -31,8 +30,6 @@ const (
|
||||
|
||||
// testConf is the default safe search configuration for tests.
|
||||
var testConf = filtering.SafeSearchConfig{
|
||||
CustomResolver: nil,
|
||||
|
||||
Enabled: true,
|
||||
|
||||
Bing: true,
|
||||
@@ -52,61 +49,60 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check host for each domain.
|
||||
for _, host := range []string{
|
||||
hosts := []string{
|
||||
"yandex.ru",
|
||||
"yAndeX.ru",
|
||||
"YANdex.COM",
|
||||
"yandex.by",
|
||||
"yandex.kz",
|
||||
"www.yandex.com",
|
||||
} {
|
||||
var res filtering.Result
|
||||
res, err = ss.CheckHost(host, testQType)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
|
||||
conf := testConf
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
testCases := []struct {
|
||||
want netip.Addr
|
||||
name string
|
||||
qt uint16
|
||||
}{{
|
||||
want: yandexIP,
|
||||
name: "a",
|
||||
qt: dns.TypeA,
|
||||
}, {
|
||||
want: netip.Addr{},
|
||||
name: "aaaa",
|
||||
qt: dns.TypeAAAA,
|
||||
}, {
|
||||
want: netip.Addr{},
|
||||
name: "https",
|
||||
qt: dns.TypeHTTPS,
|
||||
}}
|
||||
|
||||
res, err := ss.CheckHost("www.yandex.ru", dns.TypeAAAA)
|
||||
require.NoError(t, err)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, host := range hosts {
|
||||
// Check host for each domain.
|
||||
var res filtering.Result
|
||||
res, err = ss.CheckHost(host, tc.qt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
assert.True(t, res.IsFiltered)
|
||||
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
|
||||
|
||||
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
|
||||
// with a nil IP address. This isn't really necessary and should be changed
|
||||
// once the TODO in [safesearch.Default.newResult] is resolved.
|
||||
require.Len(t, res.Rules, 1)
|
||||
if tc.want == (netip.Addr{}) {
|
||||
assert.Empty(t, res.Rules)
|
||||
} else {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
|
||||
rule := res.Rules[0]
|
||||
assert.Equal(t, tc.want, rule.IP)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, rule.FilterListID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefault_CheckHost_google(t *testing.T) {
|
||||
resolver := &aghtest.Resolver{
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
conf := testConf
|
||||
conf.CustomResolver = resolver
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check host for each domain.
|
||||
@@ -125,11 +121,9 @@ func TestDefault_CheckHost_google(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, wantIP, res.Rules[0].IP)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
|
||||
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
|
||||
assert.Equal(t, "forcesafesearch.google.com", res.CanonName)
|
||||
assert.Empty(t, res.Rules)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -154,17 +148,7 @@ func (r *testResolver) LookupIP(
|
||||
}
|
||||
|
||||
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
|
||||
conf := testConf
|
||||
conf.CustomResolver = &testResolver{
|
||||
OnLookupIP: func(_ context.Context, network, host string) (ips []net.IP, err error) {
|
||||
assert.Equal(t, "ip6", network)
|
||||
assert.Equal(t, "safe.duckduckgo.com", host)
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
|
||||
@@ -174,14 +158,9 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
|
||||
// with a nil IP address. This isn't really necessary and should be changed
|
||||
// once the TODO in [safesearch.Default.newResult] is resolved.
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
|
||||
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
|
||||
assert.Equal(t, "safe.duckduckgo.com", res.CanonName)
|
||||
assert.Empty(t, res.Rules)
|
||||
}
|
||||
|
||||
func TestDefault_Update(t *testing.T) {
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
||||
@@ -46,22 +45,20 @@ type DHCP interface {
|
||||
|
||||
// clientsContainer is the storage of all runtime and persistent clients.
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
||||
// types (string, netip.Addr, and so on).
|
||||
list map[string]*client.Persistent // name -> client
|
||||
|
||||
// clientIndex stores information about persistent clients.
|
||||
clientIndex *client.Index
|
||||
|
||||
// ipToRC maps IP addresses to runtime client information.
|
||||
ipToRC map[netip.Addr]*client.Runtime
|
||||
// runtimeIndex stores information about runtime clients.
|
||||
runtimeIndex *client.RuntimeIndex
|
||||
|
||||
allTags *container.MapSet[string]
|
||||
|
||||
// dhcp is the DHCP service implementation.
|
||||
dhcp DHCP
|
||||
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
// clientChecker checks if a client is blocked by the current access
|
||||
// settings.
|
||||
clientChecker BlockedClientChecker
|
||||
|
||||
// etcHosts contains list of rewrite rules taken from the operating system's
|
||||
// hosts database.
|
||||
@@ -90,6 +87,12 @@ type clientsContainer struct {
|
||||
testing bool
|
||||
}
|
||||
|
||||
// BlockedClientChecker checks if a client is blocked by the current access
|
||||
// settings.
|
||||
type BlockedClientChecker interface {
|
||||
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// Init initializes clients container
|
||||
// dhcpServer: optional
|
||||
// Note: this function must be called only once
|
||||
@@ -100,12 +103,12 @@ func (clients *clientsContainer) Init(
|
||||
arpDB arpdb.Interface,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
// TODO(s.chzhen): Refactor it.
|
||||
if clients.clientIndex != nil {
|
||||
return errors.Error("clients container already initialized")
|
||||
}
|
||||
|
||||
clients.list = map[string]*client.Persistent{}
|
||||
clients.ipToRC = map[netip.Addr]*client.Runtime{}
|
||||
clients.runtimeIndex = client.NewRuntimeIndex()
|
||||
|
||||
clients.clientIndex = client.NewIndex()
|
||||
|
||||
@@ -248,8 +251,6 @@ func (o *clientObject) toPersistent(
|
||||
}
|
||||
|
||||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
err = cli.SetSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
@@ -285,9 +286,17 @@ func (clients *clientsContainer) addFromConfig(
|
||||
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
_, err = clients.add(cli)
|
||||
// TODO(s.chzhen): Consider moving to the client index constructor.
|
||||
err = clients.clientIndex.ClashesUID(cli)
|
||||
if err != nil {
|
||||
log.Error("clients: adding client at index %d %s: %s", i, cli.Name, err)
|
||||
return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err)
|
||||
}
|
||||
|
||||
err = clients.add(cli)
|
||||
if err != nil {
|
||||
// TODO(s.chzhen): Return an error instead of logging if more
|
||||
// stringent requirements are implemented.
|
||||
log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,9 +309,9 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
objs = make([]*clientObject, 0, len(clients.list))
|
||||
for _, cli := range clients.list {
|
||||
o := &clientObject{
|
||||
objs = make([]*clientObject, 0, clients.clientIndex.Size())
|
||||
clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) {
|
||||
objs = append(objs, &clientObject{
|
||||
Name: cli.Name,
|
||||
|
||||
BlockedServices: cli.BlockedServices.Clone(),
|
||||
@@ -323,10 +332,10 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
IgnoreStatistics: cli.IgnoreStatistics,
|
||||
UpstreamsCacheEnabled: cli.UpstreamsCacheEnabled,
|
||||
UpstreamsCacheSize: cli.UpstreamsCacheSize,
|
||||
}
|
||||
})
|
||||
|
||||
objs = append(objs, o)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Maps aren't guaranteed to iterate in the same order each time, so the
|
||||
// above loop can generate different orderings when writing to the config
|
||||
@@ -363,8 +372,8 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
|
||||
return client.SourcePersistent
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc != nil {
|
||||
src, _ = rc.Info()
|
||||
}
|
||||
|
||||
@@ -406,23 +415,26 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
id string,
|
||||
) (c *querylog.Client, art bool) {
|
||||
defer func() {
|
||||
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id)
|
||||
c.Disallowed, c.DisallowedRule = clients.clientChecker.IsBlockedClient(ip, id)
|
||||
if c.WHOIS == nil {
|
||||
c.WHOIS = &whois.Info{}
|
||||
}
|
||||
}()
|
||||
|
||||
cli, ok := clients.find(id)
|
||||
if ok {
|
||||
if !ok {
|
||||
cli = clients.clientIndex.FindByIPWithoutZone(ip)
|
||||
}
|
||||
|
||||
if cli != nil {
|
||||
return &querylog.Client{
|
||||
Name: cli.Name,
|
||||
IgnoreQueryLog: cli.IgnoreQueryLog,
|
||||
}, false
|
||||
}
|
||||
|
||||
var rc *client.Runtime
|
||||
rc, ok = clients.findRuntimeClient(ip)
|
||||
if ok {
|
||||
rc := clients.findRuntimeClient(ip)
|
||||
if rc != nil {
|
||||
_, host := rc.Info()
|
||||
|
||||
return &querylog.Client{
|
||||
@@ -542,47 +554,38 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
_, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr])
|
||||
if found {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return clients.clientIndex.FindByMAC(foundMAC)
|
||||
}
|
||||
|
||||
// runtimeClient returns a runtime client from internal index. Note that it
|
||||
// doesn't include DHCP clients.
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
if ip == (netip.Addr{}) {
|
||||
return nil, false
|
||||
return nil
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
rc, ok = clients.ipToRC[ip]
|
||||
|
||||
return rc, ok
|
||||
return clients.runtimeIndex.Client(ip)
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
|
||||
rc, ok = clients.runtimeClient(ip)
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
rc = clients.runtimeClient(ip)
|
||||
host := clients.dhcp.HostByIP(ip)
|
||||
|
||||
if host != "" {
|
||||
if !ok {
|
||||
rc = &client.Runtime{}
|
||||
if rc == nil {
|
||||
rc = client.NewRuntime(ip)
|
||||
}
|
||||
|
||||
rc.SetInfo(client.SourceDHCP, []string{host})
|
||||
|
||||
return rc, true
|
||||
return rc
|
||||
}
|
||||
|
||||
return rc, ok
|
||||
return rc
|
||||
}
|
||||
|
||||
// check validates the client. It also sorts the client tags.
|
||||
@@ -615,43 +618,32 @@ func (clients *clientsContainer) check(c *client.Persistent) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// add adds a new client object. ok is false if such client already exists or
|
||||
// if an error occurred.
|
||||
func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) {
|
||||
// add adds a persistent client or returns an error.
|
||||
func (clients *clientsContainer) add(c *client.Persistent) (err error) {
|
||||
err = clients.check(c)
|
||||
if err != nil {
|
||||
return false, err
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// check Name index
|
||||
_, ok = clients.list[c.Name]
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// check ID index
|
||||
err = clients.clientIndex.Clashes(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
clients.addLocked(c)
|
||||
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list))
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size())
|
||||
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLocked c to the indexes. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) addLocked(c *client.Persistent) {
|
||||
// update Name index
|
||||
clients.list[c.Name] = c
|
||||
|
||||
// update ID index
|
||||
clients.clientIndex.Add(c)
|
||||
}
|
||||
|
||||
@@ -660,8 +652,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
var c *client.Persistent
|
||||
c, ok = clients.list[name]
|
||||
c, ok := clients.clientIndex.FindByName(name)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -678,9 +669,6 @@ func (clients *clientsContainer) removeLocked(c *client.Persistent) {
|
||||
log.Error("client container: removing client %s: %s", c.Name, err)
|
||||
}
|
||||
|
||||
// Update the name index.
|
||||
delete(clients.list, c.Name)
|
||||
|
||||
// Update the ID index.
|
||||
clients.clientIndex.Delete(c)
|
||||
}
|
||||
@@ -696,22 +684,6 @@ func (clients *clientsContainer) update(prev, c *client.Persistent) (err error)
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// Check the name index.
|
||||
if prev.Name != c.Name {
|
||||
_, ok := clients.list[c.Name]
|
||||
if ok {
|
||||
return errors.Error("client already exists")
|
||||
}
|
||||
}
|
||||
|
||||
if c.EqualIDs(prev) {
|
||||
clients.removeLocked(prev)
|
||||
clients.addLocked(c)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check the ID index.
|
||||
err = clients.clientIndex.Clashes(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
@@ -734,12 +706,12 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
return
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = &client.Runtime{}
|
||||
clients.ipToRC[ip] = rc
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
} else {
|
||||
@@ -798,61 +770,54 @@ func (clients *clientsContainer) addHostLocked(
|
||||
host string,
|
||||
src client.Source,
|
||||
) (ok bool) {
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
if src < client.SourceDHCP {
|
||||
if clients.dhcp.HostByIP(ip) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
rc = &client.Runtime{}
|
||||
clients.ipToRC[ip] = rc
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
}
|
||||
|
||||
rc.SetInfo(src, []string{host})
|
||||
|
||||
log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC))
|
||||
log.Debug(
|
||||
"clients: adding client info %s -> %q %q [%d]",
|
||||
ip,
|
||||
src,
|
||||
host,
|
||||
clients.runtimeIndex.Size(),
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
|
||||
n := 0
|
||||
for ip, rc := range clients.ipToRC {
|
||||
rc.Unset(src)
|
||||
if rc.IsEmpty() {
|
||||
delete(clients.ipToRC, ip)
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: removed %d client aliases", n)
|
||||
}
|
||||
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(client.SourceHostsFile)
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
|
||||
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
|
||||
|
||||
n := 0
|
||||
added := 0
|
||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||
// Only the first name of the first record is considered a canonical
|
||||
// hostname for the IP address.
|
||||
//
|
||||
// TODO(e.burkov): Consider using all the names from all the records.
|
||||
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
|
||||
n++
|
||||
added++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts file", n)
|
||||
log.Debug("clients: added %d client aliases from system hosts file", added)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
@@ -876,7 +841,8 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(client.SourceARP)
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
|
||||
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
|
||||
|
||||
added := 0
|
||||
for _, n := range ns {
|
||||
@@ -891,18 +857,5 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
// close gracefully closes all the client-specific upstream configurations of
|
||||
// the persistent clients.
|
||||
func (clients *clientsContainer) close() (err error) {
|
||||
persistent := maps.Values(clients.list)
|
||||
slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, cli := range persistent {
|
||||
if err = cli.CloseUpstreams(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
return clients.clientIndex.CloseUpstreams()
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
}
|
||||
|
||||
dhcp := &testDHCP{
|
||||
OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") },
|
||||
OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
|
||||
OnHostBy: func(ip netip.Addr) (host string) { return "" },
|
||||
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
|
||||
}
|
||||
@@ -72,23 +72,19 @@ func TestClients(t *testing.T) {
|
||||
IPs: []netip.Addr{cli1IP, cliIPv6},
|
||||
}
|
||||
|
||||
ok, err := clients.add(c)
|
||||
err := clients.add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{cli2IP},
|
||||
}
|
||||
|
||||
ok, err = clients.add(c)
|
||||
err = clients.add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.find(cli1)
|
||||
c, ok := clients.find(cli1)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
@@ -111,22 +107,20 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("add_fail_ip", func(t *testing.T) {
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
@@ -145,12 +139,13 @@ func TestClients(t *testing.T) {
|
||||
cliNewIP = netip.MustParseAddr(cliNew)
|
||||
)
|
||||
|
||||
prev, ok := clients.list["client1"]
|
||||
prev, ok := clients.clientIndex.FindByName("client1")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, prev)
|
||||
|
||||
err := clients.update(prev, &client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
UID: prev.UID,
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -160,12 +155,13 @@ func TestClients(t *testing.T) {
|
||||
|
||||
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
|
||||
|
||||
prev, ok = clients.list["client1"]
|
||||
prev, ok = clients.clientIndex.FindByName("client1")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, prev)
|
||||
|
||||
err = clients.update(prev, &client.Persistent{
|
||||
Name: "client1-renamed",
|
||||
UID: client.MustNewUID(),
|
||||
UID: prev.UID,
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
@@ -177,7 +173,7 @@ func TestClients(t *testing.T) {
|
||||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
assert.True(t, c.UseOwnSettings)
|
||||
|
||||
nilCli, ok := clients.list["client1"]
|
||||
nilCli, ok := clients.clientIndex.FindByName("client1")
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Nil(t, nilCli)
|
||||
@@ -244,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
t.Run("new_client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.255")
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
@@ -256,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
@@ -265,16 +261,15 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.2")
|
||||
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.Nil(t, rc)
|
||||
|
||||
assert.True(t, clients.remove("client1"))
|
||||
@@ -288,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Add a client.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||
@@ -296,10 +291,9 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok = clients.addHost(ip, "test", client.SourceRDNS)
|
||||
ok := clients.addHost(ip, "test", client.SourceRDNS)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
@@ -339,22 +333,20 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err = clients.add(&client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add a new client with the IP from the first client's IP range.
|
||||
ok, err = clients.add(&client.Persistent{
|
||||
err = clients.add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -362,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
||||
@@ -372,7 +364,6 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
|
||||
assert.Nil(t, upsConf)
|
||||
|
||||
@@ -96,22 +96,26 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
for _, c := range clients.list {
|
||||
clients.clientIndex.Range(func(c *client.Persistent) (cont bool) {
|
||||
cj := clientToJSON(c)
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
|
||||
for ip, rc := range clients.ipToRC {
|
||||
return true
|
||||
})
|
||||
|
||||
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
|
||||
src, host := rc.Info()
|
||||
cj := runtimeClientJSON{
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
Name: host,
|
||||
Source: src,
|
||||
IP: ip,
|
||||
IP: rc.Addr(),
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
for _, l := range clients.dhcp.Leases() {
|
||||
cj := runtimeClientJSON{
|
||||
@@ -332,20 +336,16 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := clients.add(c)
|
||||
err = clients.add(c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists")
|
||||
|
||||
return
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
|
||||
@@ -370,7 +370,9 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
}
|
||||
|
||||
// updateJSON contains the name and data of the updated persistent client.
|
||||
@@ -404,7 +406,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
prev, ok = clients.list[dj.Name]
|
||||
prev, ok = clients.clientIndex.FindByName(dj.Name)
|
||||
}()
|
||||
|
||||
if !ok {
|
||||
@@ -427,14 +429,16 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
}
|
||||
|
||||
// handleFindClient is the handler for GET /control/clients/find HTTP API.
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]*clientJSON{}
|
||||
for i := 0; i < len(q); i++ {
|
||||
for i := range len(q) {
|
||||
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
||||
if idStr == "" {
|
||||
break
|
||||
@@ -447,7 +451,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
cj = clients.findRuntime(ip, idStr)
|
||||
} else {
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
}
|
||||
|
||||
@@ -463,14 +467,14 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||
// non-nil.
|
||||
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
|
||||
rc, ok := clients.findRuntimeClient(ip)
|
||||
if !ok {
|
||||
rc := clients.findRuntimeClient(ip)
|
||||
if rc == nil {
|
||||
// It is still possible that the IP used to be in the runtime clients
|
||||
// list, but then the server was reloaded. So, check the DNS server's
|
||||
// blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj = &clientJSON{
|
||||
IDs: []string{idStr},
|
||||
Disallowed: &disallowed,
|
||||
@@ -488,7 +492,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
}
|
||||
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
|
||||
return cj
|
||||
|
||||
399
internal/home/clientshttp_internal_test.go
Normal file
399
internal/home/clientshttp_internal_test.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testClientIP1 = "1.1.1.1"
|
||||
testClientIP2 = "2.2.2.2"
|
||||
)
|
||||
|
||||
// testBlockedClientChecker is a mock implementation of the
|
||||
// [BlockedClientChecker] interface.
|
||||
type testBlockedClientChecker struct {
|
||||
onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ BlockedClientChecker = (*testBlockedClientChecker)(nil)
|
||||
|
||||
// IsBlockedClient implements the [BlockedClientChecker] interface for
|
||||
// *testBlockedClientChecker.
|
||||
func (c *testBlockedClientChecker) IsBlockedClient(
|
||||
ip netip.Addr,
|
||||
clientID string,
|
||||
) (blocked bool, rule string) {
|
||||
return c.onIsBlockedClient(ip, clientID)
|
||||
}
|
||||
|
||||
// newPersistentClient is a helper function that returns a persistent client
|
||||
// with the specified name and newly generated UID.
|
||||
func newPersistentClient(name string) (c *client.Persistent) {
|
||||
return &client.Persistent{
|
||||
Name: name,
|
||||
UID: client.MustNewUID(),
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: &schedule.Weekly{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newPersistentClientWithIDs is a helper function that returns a persistent
|
||||
// client with the specified name and ids.
|
||||
func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
c = newPersistentClient(name)
|
||||
err := c.SetIDs(ids)
|
||||
require.NoError(tb, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// assertClients is a helper function that compares lists of persistent clients.
|
||||
func assertClients(tb testing.TB, want, got []*client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
require.Len(tb, got, len(want))
|
||||
|
||||
sortFunc := func(a, b *client.Persistent) (n int) {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
}
|
||||
|
||||
slices.SortFunc(want, sortFunc)
|
||||
slices.SortFunc(got, sortFunc)
|
||||
|
||||
slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) {
|
||||
assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name)
|
||||
|
||||
return 0
|
||||
})
|
||||
}
|
||||
|
||||
// assertPersistentClients is a helper function that uses HTTP API to check
|
||||
// whether want persistent clients are the same as the persistent clients stored
|
||||
// in the clients container.
|
||||
func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleGetClients(rw, &http.Request{})
|
||||
|
||||
body, err := io.ReadAll(rw.Body)
|
||||
require.NoError(tb, err)
|
||||
|
||||
clientList := &clientListJSON{}
|
||||
err = json.Unmarshal(body, clientList)
|
||||
require.NoError(tb, err)
|
||||
|
||||
var got []*client.Persistent
|
||||
for _, cj := range clientList.Clients {
|
||||
var c *client.Persistent
|
||||
c, err = clients.jsonToClient(*cj, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got = append(got, c)
|
||||
}
|
||||
|
||||
assertClients(tb, want, got)
|
||||
}
|
||||
|
||||
// assertPersistentClientsData is a helper function that checks whether want
|
||||
// persistent clients are the same as the persistent clients stored in data.
|
||||
func assertPersistentClientsData(
|
||||
tb testing.TB,
|
||||
clients *clientsContainer,
|
||||
data []map[string]*clientJSON,
|
||||
want []*client.Persistent,
|
||||
) {
|
||||
tb.Helper()
|
||||
|
||||
var got []*client.Persistent
|
||||
for _, cm := range data {
|
||||
for _, cj := range cm {
|
||||
var c *client.Persistent
|
||||
c, err := clients.jsonToClient(*cj, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got = append(got, c)
|
||||
}
|
||||
}
|
||||
|
||||
assertClients(tb, want, got)
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleAddClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
|
||||
clientEmptyID := newPersistentClient("empty_client_id")
|
||||
clientEmptyID.ClientIDs = []string{""}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
client *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "add_one",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "add_two",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}, {
|
||||
name: "duplicate_client",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}, {
|
||||
name: "empty_client_id",
|
||||
client: clientEmptyID,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cj := clientToJSON(tc.client)
|
||||
|
||||
body, err := json.Marshal(cj)
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleAddClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleDelClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.add(clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
client *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "remove_one",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "duplicate_client",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "empty_client_name",
|
||||
client: newPersistentClient(""),
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "remove_two",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cj := clientToJSON(tc.client)
|
||||
|
||||
var body []byte
|
||||
body, err = json.Marshal(cj)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleDelClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleUpdateClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne})
|
||||
|
||||
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
|
||||
clientEmptyID := newPersistentClient("empty_client_id")
|
||||
clientEmptyID.ClientIDs = []string{""}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
clientName string
|
||||
modified *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "update_one",
|
||||
clientName: clientOne.Name,
|
||||
modified: clientModified,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "empty_name",
|
||||
clientName: "",
|
||||
modified: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "client_not_found",
|
||||
clientName: "client_not_found",
|
||||
modified: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "empty_client_id",
|
||||
clientName: clientModified.Name,
|
||||
modified: clientEmptyID,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "no_ids",
|
||||
clientName: clientModified.Name,
|
||||
modified: newPersistentClient("no_ids"),
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
uj := updateJSON{
|
||||
Name: tc.clientName,
|
||||
Data: *clientToJSON(tc.modified),
|
||||
}
|
||||
|
||||
var body []byte
|
||||
body, err = json.Marshal(uj)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleUpdateClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleFindClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
clients.clientChecker = &testBlockedClientChecker{
|
||||
onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) {
|
||||
return false, ""
|
||||
},
|
||||
}
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.add(clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
query url.Values
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "single",
|
||||
query: url.Values{
|
||||
"ip0": []string{testClientIP1},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "multiple",
|
||||
query: url.Values{
|
||||
"ip0": []string{testClientIP1},
|
||||
"ip1": []string{testClientIP2},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
r.URL.RawQuery = tc.query.Encode()
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleFindClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientData := []map[string]*clientJSON{}
|
||||
err = json.Unmarshal(body, &clientData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClientsData(t, clients, clientData, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -203,15 +203,24 @@ type dnsConfig struct {
|
||||
// resolver should be used.
|
||||
PrivateNets []netutil.Prefix `yaml:"private_networks"`
|
||||
|
||||
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
||||
// locally-served networks should be resolved via private PTR resolvers.
|
||||
// UsePrivateRDNS enables resolving requests containing a private IP address
|
||||
// using private reverse DNS resolvers. See PrivateRDNSResolvers.
|
||||
//
|
||||
// TODO(e.burkov): Rename in YAML.
|
||||
UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"`
|
||||
|
||||
// LocalPTRResolvers is the slice of addresses to be used as upstreams
|
||||
// for PTR queries for locally-served networks.
|
||||
LocalPTRResolvers []string `yaml:"local_ptr_upstreams"`
|
||||
// PrivateRDNSResolvers is the slice of addresses to be used as upstreams
|
||||
// for private requests. It's only used for PTR, SOA, and NS queries,
|
||||
// containing an ARPA subdomain, came from the the client with private
|
||||
// address. The address considered private according to PrivateNets.
|
||||
//
|
||||
// If empty, the OS-provided resolvers are used for private requests.
|
||||
PrivateRDNSResolvers []string `yaml:"local_ptr_upstreams"`
|
||||
|
||||
// UseDNS64 defines if DNS64 should be used for incoming requests.
|
||||
// UseDNS64 defines if DNS64 should be used for incoming requests. Requests
|
||||
// of type PTR for addresses within the configured prefixes will be resolved
|
||||
// via [PrivateRDNSResolvers], so those should be valid and UsePrivateRDNS
|
||||
// be set to true.
|
||||
UseDNS64 bool `yaml:"use_dns64"`
|
||||
|
||||
// DNS64Prefixes is the list of NAT64 prefixes to be used for DNS64.
|
||||
@@ -658,7 +667,7 @@ func (c *configuration) write() (err error) {
|
||||
dns := &config.DNS
|
||||
dns.Config = c
|
||||
|
||||
dns.LocalPTRResolvers = s.LocalPTRResolvers()
|
||||
dns.PrivateRDNSResolvers = s.LocalPTRResolvers()
|
||||
|
||||
addrProcConf := s.AddrProcConfig()
|
||||
config.Clients.Sources.RDNS = addrProcConf.UseRDNS
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -18,7 +17,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -150,21 +148,19 @@ func initDNSServer(
|
||||
return fmt.Errorf("dnsforward.NewServer: %w", err)
|
||||
}
|
||||
|
||||
Context.clients.dnsServer = Context.dnsServer
|
||||
Context.clients.clientChecker = Context.dnsServer
|
||||
|
||||
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("newServerConfig: %w", err)
|
||||
}
|
||||
|
||||
// Try to prepare the server with disabled private RDNS resolution if it
|
||||
// failed to prepare as is. See TODO on [ErrBadPrivateRDNSUpstreams].
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
|
||||
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
|
||||
|
||||
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
|
||||
// should go away once the private RDNS resolution is moved to the proxy.
|
||||
var locResErr *dnsforward.LocalResolversError
|
||||
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
|
||||
log.Info("WARNING: no local resolvers configured while private RDNS " +
|
||||
"resolution enabled, trying to disable")
|
||||
dnsConf.UsePrivateRDNS = false
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
}
|
||||
@@ -245,7 +241,7 @@ func newServerConfig(
|
||||
TLSv12Roots: Context.tlsRoots,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpReg,
|
||||
LocalPTRResolvers: dnsConf.LocalPTRResolvers,
|
||||
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
|
||||
UseDNS64: dnsConf.UseDNS64,
|
||||
DNS64Prefixes: dnsConf.DNS64Prefixes,
|
||||
UsePrivateRDNS: dnsConf.UsePrivateRDNS,
|
||||
@@ -531,36 +527,6 @@ func closeDNSServer() {
|
||||
log.Debug("all dns modules are closed")
|
||||
}
|
||||
|
||||
// safeSearchResolver is a [filtering.Resolver] implementation used for safe
|
||||
// search.
|
||||
type safeSearchResolver struct{}
|
||||
|
||||
// type check
|
||||
var _ filtering.Resolver = safeSearchResolver{}
|
||||
|
||||
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
|
||||
// It returns the slice of net.Addr with IPv4 and IPv6 instances.
|
||||
func (r safeSearchResolver) LookupIP(
|
||||
ctx context.Context,
|
||||
network string,
|
||||
host string,
|
||||
) (ips []net.IP, err error) {
|
||||
addrs, err := Context.dnsServer.Resolve(ctx, network, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("couldn't lookup host: %s", host)
|
||||
}
|
||||
|
||||
for _, a := range addrs {
|
||||
ips = append(ips, a.AsSlice())
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// checkStatsAndQuerylogDirs checks and returns directory paths to store
|
||||
// statistics and query log.
|
||||
func checkStatsAndQuerylogDirs(
|
||||
|
||||
@@ -439,7 +439,6 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
conf.ParentalBlockHost = host
|
||||
}
|
||||
|
||||
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
conf.SafeSearch, err = safesearch.NewDefault(
|
||||
conf.SafeSearchConf,
|
||||
"default",
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -76,8 +76,7 @@ func getLogSettings(opts options) (ls *logSettings) {
|
||||
ls.Verbose = true
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use cmp.Or in Go 1.22.
|
||||
ls.File = stringutil.Coalesce(opts.logFile, ls.File)
|
||||
ls.File = cmp.Or(opts.logFile, ls.File)
|
||||
|
||||
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if
|
||||
|
||||
@@ -306,7 +306,7 @@ func handleServiceStatusCommand(s service.Service) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceStatusCommand handles service "install" command
|
||||
// handleServiceInstallCommand handles service "install" command.
|
||||
func handleServiceInstallCommand(s service.Service) {
|
||||
err := svcAction(s, "install")
|
||||
if err != nil {
|
||||
@@ -340,7 +340,7 @@ AdGuard Home is now available at the following addresses:`)
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceStatusCommand handles service "uninstall" command
|
||||
// handleServiceUninstallCommand handles service "uninstall" command.
|
||||
func handleServiceUninstallCommand(s service.Service) {
|
||||
if aghos.IsOpenWrt() {
|
||||
// On OpenWrt it is important to run disable command first
|
||||
@@ -649,11 +649,6 @@ status() {
|
||||
|
||||
// freeBSDScript is the source of the daemon script for FreeBSD. Keep as close
|
||||
// as possible to the https://github.com/kardianos/service/blob/18c957a3dc1120a2efe77beb401d476bade9e577/service_freebsd.go#L204.
|
||||
//
|
||||
// TODO(a.garipov): Don't use .WorkingDirectory here. There are currently no
|
||||
// guarantees that it will actually be the required directory.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2614.
|
||||
const freeBSDScript = `#!/bin/sh
|
||||
# PROVIDE: {{.Name}}
|
||||
# REQUIRE: networking
|
||||
@@ -667,7 +662,9 @@ name="{{.Name}}"
|
||||
pidfile_child="/var/run/${name}.pid"
|
||||
pidfile="/var/run/${name}_daemon.pid"
|
||||
command="/usr/sbin/daemon"
|
||||
command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirectory}}/{{.Name}}"
|
||||
daemon_args="-P ${pidfile} -p ${pidfile_child} -r -t ${name}"
|
||||
command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}"
|
||||
|
||||
run_rc_command "$1"
|
||||
`
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -14,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ func (*openbsdRunComService) Platform() (p string) {
|
||||
|
||||
// String implements service.Service interface for *openbsdRunComService.
|
||||
func (s *openbsdRunComService) String() string {
|
||||
return stringutil.Coalesce(s.cfg.DisplayName, s.cfg.Name)
|
||||
return cmp.Or(s.cfg.DisplayName, s.cfg.Name)
|
||||
}
|
||||
|
||||
// getBool returns the value of the given name from kv, assuming the value is a
|
||||
|
||||
@@ -147,7 +147,7 @@ func BenchmarkManager_LookupHost(b *testing.B) {
|
||||
|
||||
b.Run("long", func(b *testing.B) {
|
||||
const name = "a.very.long.domain.name.inside.the.domain.example.com"
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
ipsetPropsSink = m.lookupHost(name)
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ func BenchmarkManager_LookupHost(b *testing.B) {
|
||||
|
||||
b.Run("short", func(b *testing.B) {
|
||||
const name = "example.net"
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
ipsetPropsSink = m.lookupHost(name)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
)
|
||||
|
||||
@@ -38,7 +39,7 @@ func (h *signalHandler) handle() {
|
||||
|
||||
if aghos.IsReconfigureSignal(sig) {
|
||||
h.reconfigure()
|
||||
} else if aghos.IsShutdownSignal(sig) {
|
||||
} else if osutil.IsShutdownSignal(sig) {
|
||||
status := h.shutdown()
|
||||
h.removePID()
|
||||
|
||||
@@ -122,7 +123,8 @@ func newSignalHandler(
|
||||
services: svcs,
|
||||
}
|
||||
|
||||
aghos.NotifyShutdownSignal(h.signal)
|
||||
notifier := osutil.DefaultSignalNotifier{}
|
||||
osutil.NotifyShutdownSignal(notifier, h.signal)
|
||||
aghos.NotifyReconfigureSignal(h.signal)
|
||||
|
||||
return h
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dnssvc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -94,10 +93,8 @@ func TestService(t *testing.T) {
|
||||
}},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
cli := &dns.Client{}
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
var resp *dns.Msg
|
||||
require.Eventually(t, func() (ok bool) {
|
||||
@@ -110,10 +107,8 @@ func TestService(t *testing.T) {
|
||||
assert.NotNil(t, resp)
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
err = svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||
|
||||
err = svc.Shutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = upstreamSrv.Shutdown()
|
||||
|
||||
@@ -109,12 +109,8 @@ func newTestServer(
|
||||
|
||||
err = svc.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
err = svc.Shutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
c = svc.Config()
|
||||
|
||||
@@ -303,7 +303,7 @@ func BenchmarkAnonymizeIP(b *testing.B) {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
AnonymizeIP(bc.ip)
|
||||
}
|
||||
|
||||
@@ -313,7 +313,7 @@ func BenchmarkAnonymizeIP(b *testing.B) {
|
||||
b.Run(bc.name+"_slow", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
anonymizeIPSlow(bc.ip)
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ type logEntry struct {
|
||||
Answer []byte `json:",omitempty"`
|
||||
OrigAnswer []byte `json:",omitempty"`
|
||||
|
||||
// TODO(s.chzhen): Use netip.Addr.
|
||||
IP net.IP `json:"IP"`
|
||||
|
||||
Result filtering.Result
|
||||
|
||||
@@ -143,13 +143,13 @@ func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
secondPageDomain = "second.example.org"
|
||||
)
|
||||
// Add entries to the log.
|
||||
for i := 0; i < entNum; i++ {
|
||||
for range entNum {
|
||||
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
// Write them to the first file.
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
// Add more to the in-memory part of log.
|
||||
for i := 0; i < entNum; i++ {
|
||||
for range entNum {
|
||||
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
|
||||
const entNum = 10
|
||||
// Add entries to the log.
|
||||
for i := 0; i < entNum; i++ {
|
||||
for range entNum {
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
// Write them to disk.
|
||||
|
||||
@@ -37,7 +37,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) {
|
||||
|
||||
var lineIP uint32
|
||||
lineTime := time.Date(2020, 2, 18, 19, 36, 35, 920973000, time.UTC)
|
||||
for i := 0; i < linesNum; i++ {
|
||||
for range linesNum {
|
||||
lineIP++
|
||||
lineTime = lineTime.Add(time.Second)
|
||||
|
||||
|
||||
@@ -68,13 +68,13 @@ func TestStats_races(t *testing.T) {
|
||||
startWG, finWG := &sync.WaitGroup{}, &sync.WaitGroup{}
|
||||
waitCh := make(chan unit)
|
||||
|
||||
for i := 0; i < writersNum; i++ {
|
||||
for i := range writersNum {
|
||||
startWG.Add(1)
|
||||
finWG.Add(1)
|
||||
go writeFunc(startWG, finWG, waitCh, i)
|
||||
}
|
||||
|
||||
for i := 0; i < readersNum; i++ {
|
||||
for range readersNum {
|
||||
startWG.Add(1)
|
||||
finWG.Add(1)
|
||||
go readFunc(startWG, finWG, waitCh)
|
||||
@@ -111,7 +111,7 @@ func TestStatsCtx_FillCollectedStats_daily(t *testing.T) {
|
||||
|
||||
dailyData := []*unitDB{}
|
||||
|
||||
for i := 0; i < daysCount*24; i++ {
|
||||
for i := range daysCount * 24 {
|
||||
n := uint64(i)
|
||||
nResult := make([]uint64, resultLast)
|
||||
nResult[RFiltered] = n
|
||||
|
||||
@@ -195,7 +195,7 @@ func TestLargeNumbers(t *testing.T) {
|
||||
for h := 0; h < hoursNum; h++ {
|
||||
atomic.AddUint32(&curHour, 1)
|
||||
|
||||
for i := 0; i < cliNumPerHour; i++ {
|
||||
for i := range cliNumPerHour {
|
||||
ip := net.IP{127, 0, byte((i & 0xff00) >> 8), byte(i & 0xff)}
|
||||
e := &stats.Entry{
|
||||
Domain: fmt.Sprintf("domain%d.hour%d", i, h),
|
||||
|
||||
@@ -525,9 +525,8 @@ func (s *StatsCtx) fillCollectedStatsDaily(
|
||||
hours := countHours(curHour, days)
|
||||
units = units[len(units)-hours:]
|
||||
|
||||
for i := 0; i < len(units); i++ {
|
||||
for i, u := range units {
|
||||
day := i / 24
|
||||
u := units[i]
|
||||
|
||||
data.DNSQueries[day] += u.NTotal
|
||||
data.BlockedFiltering[day] += u.NResult[RFiltered]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/AdguardTeam/AdGuardHome/internal/tools
|
||||
|
||||
go 1.22.2
|
||||
go 1.22.3
|
||||
|
||||
require (
|
||||
github.com/fzipp/gocyclo v0.6.0
|
||||
|
||||
@@ -3,6 +3,7 @@ package whois
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -17,7 +18,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/bluele/gcache"
|
||||
)
|
||||
|
||||
@@ -174,7 +174,7 @@ func whoisParse(data []byte, maxLen int) (info map[string]string) {
|
||||
val = trimValue(val, maxLen)
|
||||
case "descr", "netname":
|
||||
key = "orgname"
|
||||
val = stringutil.Coalesce(orgname, val)
|
||||
val = cmp.Or(orgname, val)
|
||||
orgname = val
|
||||
case "whois":
|
||||
key = "whois"
|
||||
@@ -232,7 +232,7 @@ func (w *Default) queryAll(ctx context.Context, target string) (info map[string]
|
||||
server := net.JoinHostPort(w.serverAddr, w.portStr)
|
||||
var data []byte
|
||||
|
||||
for i := 0; i < w.maxRedirects; i++ {
|
||||
for range w.maxRedirects {
|
||||
data, err = w.query(ctx, target, server)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
|
||||
Reference in New Issue
Block a user