all: sync with master

This commit is contained in:
Ainar Garipov
2024-05-15 13:34:12 +03:00
parent 6318fc424b
commit 667263a3a8
82 changed files with 2356 additions and 1817 deletions

View File

@@ -91,8 +91,6 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
@@ -186,8 +184,6 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

View File

@@ -7,6 +7,7 @@ package client
import (
"encoding"
"fmt"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
)
@@ -56,6 +57,9 @@ func (cs Source) MarshalText() (text []byte, err error) {
// Runtime is a client information from different sources.
type Runtime struct {
// ip is an IP address of a client.
ip netip.Addr
// whois is the filtered WHOIS information of a client.
whois *whois.Info
@@ -80,6 +84,15 @@ type Runtime struct {
hostsFile []string
}
// NewRuntime constructs a new runtime client. ip must be valid IP address.
//
// TODO(s.chzhen): Validate IP address.
func NewRuntime(ip netip.Addr) (r *Runtime) {
return &Runtime{
ip: ip,
}
}
// Info returns a client information from the highest-priority source.
func (r *Runtime) Info() (cs Source, host string) {
info := []string{}
@@ -133,8 +146,8 @@ func (r *Runtime) SetWHOIS(info *whois.Info) {
r.whois = info
}
// Unset clears a cs information.
func (r *Runtime) Unset(cs Source) {
// unset clears a cs information.
func (r *Runtime) unset(cs Source) {
switch cs {
case SourceWHOIS:
r.whois = nil
@@ -149,11 +162,16 @@ func (r *Runtime) Unset(cs Source) {
}
}
// IsEmpty returns true if there is no information from any source.
func (r *Runtime) IsEmpty() (ok bool) {
// isEmpty returns true if there is no information from any source.
func (r *Runtime) isEmpty() (ok bool) {
return r.whois == nil &&
r.arp == nil &&
r.rdns == nil &&
r.dhcp == nil &&
r.hostsFile == nil
}
// Addr returns an IP address of the client.
func (r *Runtime) Addr() (ip netip.Addr) {
return r.ip
}

View File

@@ -4,8 +4,12 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/exp/maps"
)
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
@@ -28,6 +32,9 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
// Index stores all information about persistent clients.
type Index struct {
// nameToUID maps client name to UID.
nameToUID map[string]UID
// clientIDToUID maps client ID to UID.
clientIDToUID map[string]UID
@@ -47,6 +54,7 @@ type Index struct {
// NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) {
return &Index{
nameToUID: map[string]UID{},
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
@@ -62,6 +70,8 @@ func (ci *Index) Add(c *Persistent) {
panic("client must contain uid")
}
ci.nameToUID[c.Name] = c.UID
for _, id := range c.ClientIDs {
ci.clientIDToUID[id] = c.UID
}
@@ -82,15 +92,30 @@ func (ci *Index) Add(c *Persistent) {
ci.uidToClient[c.UID] = c
}
// ClashesUID returns existing persistent client with the same UID as c. Note
// that this is only possible when configuration contains duplicate fields.
func (ci *Index) ClashesUID(c *Persistent) (err error) {
p, ok := ci.uidToClient[c.UID]
if ok {
return fmt.Errorf("another client %q uses the same uid", p.Name)
}
return nil
}
// Clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *Index) Clashes(c *Persistent) (err error) {
if p := ci.clashesName(c); p != nil {
return fmt.Errorf("another client uses the same name %q", p.Name)
}
for _, id := range c.ClientIDs {
existing, ok := ci.clientIDToUID[id]
if ok && existing != c.UID {
p := ci.uidToClient[existing]
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
return fmt.Errorf("another client %q uses the same ClientID %q", p.Name, id)
}
}
@@ -112,6 +137,21 @@ func (ci *Index) Clashes(c *Persistent) (err error) {
return nil
}
// clashesName returns existing persistent client with the same name as c or
// nil. c must be non-nil.
func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.FindByName(c.Name)
if !ok {
return nil
}
if existing.UID != c.UID {
return existing
}
return nil
}
// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
@@ -184,21 +224,33 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) {
mac, err := net.ParseMAC(id)
if err == nil {
return ci.findByMAC(mac)
return ci.FindByMAC(mac)
}
return nil, false
}
// find finds persistent client by IP address.
// FindByName finds persistent client by name.
func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
uid, found := ci.nameToUID[name]
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// findByIP finds persistent client by IP address.
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
}
ipWithoutZone := ip.WithZone("")
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
if pref.Contains(ip) {
// Remove zone before checking because prefixes strip zones.
if pref.Contains(ipWithoutZone) {
uid, found = id, true
return false
@@ -214,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false
}
// find finds persistent client by MAC.
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
// FindByMAC finds persistent client by MAC.
func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
@@ -225,9 +277,31 @@ func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
return nil, false
}
// FindByIPWithoutZone finds a persistent client by IP address without zone. It
// strips the IPv6 zone index from the stored IP addresses before comparing,
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
if (ip == netip.Addr{}) {
return nil
}
for addr, uid := range ci.ipToUID {
if addr.WithZone("") == ip {
return ci.uidToClient[uid]
}
}
return nil
}
// Delete removes information about persistent client from the index. c must be
// non-nil.
func (ci *Index) Delete(c *Persistent) {
delete(ci.nameToUID, c.Name)
for _, id := range c.ClientIDs {
delete(ci.clientIDToUID, id)
}
@@ -247,3 +321,48 @@ func (ci *Index) Delete(c *Persistent) {
delete(ci.uidToClient, c.UID)
}
// Size returns the number of persistent clients.
func (ci *Index) Size() (n int) {
return len(ci.uidToClient)
}
// Range calls f for each persistent client, unless cont is false. The order is
// undefined.
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
for _, c := range ci.uidToClient {
if !f(c) {
return
}
}
}
// RangeByName is like [Index.Range] but sorts the persistent clients by name
// before iterating ensuring a predictable order.
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
cs := maps.Values(ci.uidToClient)
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
return strings.Compare(a.Name, b.Name)
})
for _, c := range cs {
if !f(c) {
break
}
}
}
// CloseUpstreams closes upstream configurations of persistent clients.
func (ci *Index) CloseUpstreams() (err error) {
var errs []error
ci.RangeByName(func(c *Persistent) (cont bool) {
err = c.CloseUpstreams()
if err != nil {
errs = append(errs, err)
}
return true
})
return errors.Join(errs...)
}

View File

@@ -22,7 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) {
return ci
}
func TestClientIndex(t *testing.T) {
func TestClientIndex_Find(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
cliIP1 = "1.1.1.1"
@@ -35,26 +35,49 @@ func TestClientIndex(t *testing.T) {
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0"
linkLocalSubnet = "fe80::/16"
)
clients := []*Persistent{{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}, {
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
var (
clientWithBothFams = &Persistent{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}
clientWithSubnet = &Persistent{
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}
clientWithMAC = &Persistent{
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}
clientWithID = &Persistent{
Name: "client_with_id",
ClientIDs: []string{cliID},
}
clientLinkLocal = &Persistent{
Name: "client_link_local",
Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)},
}
)
clients := []*Persistent{
clientWithBothFams,
clientWithSubnet,
clientWithMAC,
clientWithID,
clientLinkLocal,
}
ci := newIDIndex(clients)
testCases := []struct {
@@ -64,19 +87,23 @@ func TestClientIndex(t *testing.T) {
}{{
name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6},
want: clients[0],
want: clientWithBothFams,
}, {
name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP},
want: clients[1],
want: clientWithSubnet,
}, {
name: "mac",
ids: []string{cliMAC},
want: clients[2],
want: clientWithMAC,
}, {
name: "client_id",
ids: []string{cliID},
want: clients[3],
want: clientWithID,
}, {
name: "client_link_local_subnet",
ids: []string{linkLocalIP},
want: clientLinkLocal,
}}
for _, tc := range testCases {
@@ -221,3 +248,103 @@ func TestMACToKey(t *testing.T) {
_ = macToKey(mac)
})
}
func TestIndex_FindByIPWithoutZone(t *testing.T) {
var (
ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1")
ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2")
)
var (
clientNoZone = &Persistent{
Name: "client",
IPs: []netip.Addr{ip},
}
clientWithZone = &Persistent{
Name: "client_with_zone",
IPs: []netip.Addr{ipWithZone},
}
)
ci := newIDIndex([]*Persistent{
clientNoZone,
clientWithZone,
})
testCases := []struct {
ip netip.Addr
want *Persistent
name string
}{{
name: "without_zone",
ip: ip,
want: clientNoZone,
}, {
name: "with_zone",
ip: ipWithZone,
want: clientWithZone,
}, {
name: "zero_address",
ip: netip.Addr{},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := ci.FindByIPWithoutZone(tc.ip.WithZone(""))
require.Equal(t, tc.want, c)
})
}
}
func TestClientIndex_RangeByName(t *testing.T) {
sortedClients := []*Persistent{{
Name: "clientA",
ClientIDs: []string{"A"},
}, {
Name: "clientB",
ClientIDs: []string{"B"},
}, {
Name: "clientC",
ClientIDs: []string{"C"},
}, {
Name: "clientD",
ClientIDs: []string{"D"},
}, {
Name: "clientE",
ClientIDs: []string{"E"},
}}
testCases := []struct {
name string
want []*Persistent
}{{
name: "basic",
want: sortedClients,
}, {
name: "nil",
want: nil,
}, {
name: "one_element",
want: sortedClients[:1],
}, {
name: "two_elements",
want: sortedClients[:2],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ci := newIDIndex(tc.want)
var got []*Persistent
ci.RangeByName(func(c *Persistent) (cont bool) {
got = append(got, c)
return true
})
assert.Equal(t, tc.want, got)
})
}
}

View File

@@ -64,9 +64,7 @@ type Persistent struct {
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices
@@ -95,6 +93,9 @@ type Persistent struct {
UseOwnBlockedServices bool
IgnoreQueryLog bool
IgnoreStatistics bool
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
}
// SetTags sets the tags if they are known, otherwise logs an unknown tag.

View File

@@ -0,0 +1,63 @@
package client
import "net/netip"
// RuntimeIndex stores information about runtime clients.
type RuntimeIndex struct {
// index maps IP address to runtime client.
index map[netip.Addr]*Runtime
}
// NewRuntimeIndex returns initialized runtime index.
func NewRuntimeIndex() (ri *RuntimeIndex) {
return &RuntimeIndex{
index: map[netip.Addr]*Runtime{},
}
}
// Client returns the saved runtime client by ip. If no such client exists,
// returns nil.
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) {
return ri.index[ip]
}
// Add saves the runtime client in the index. IP address of a client must be
// unique. See [Runtime.Client]. rc must not be nil.
func (ri *RuntimeIndex) Add(rc *Runtime) {
ip := rc.Addr()
ri.index[ip] = rc
}
// Size returns the number of the runtime clients.
func (ri *RuntimeIndex) Size() (n int) {
return len(ri.index)
}
// Range calls f for each runtime client in an undefined order.
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
for _, rc := range ri.index {
if !f(rc) {
return
}
}
}
// Delete removes the runtime client by ip.
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
delete(ri.index, ip)
}
// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) {
for ip, rc := range ri.index {
rc.unset(src)
if rc.isEmpty() {
delete(ri.index, ip)
n++
}
}
return n
}

View File

@@ -0,0 +1,85 @@
package client_test
import (
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/stretchr/testify/assert"
)
func TestRuntimeIndex(t *testing.T) {
const cliSrc = client.SourceARP
var (
ip1 = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
ip3 = netip.MustParseAddr("3.3.3.3")
)
ri := client.NewRuntimeIndex()
currentSize := 0
testCases := []struct {
ip netip.Addr
name string
hosts []string
src client.Source
}{{
src: cliSrc,
ip: ip1,
name: "1",
hosts: []string{"host1"},
}, {
src: cliSrc,
ip: ip2,
name: "2",
hosts: []string{"host2"},
}, {
src: cliSrc,
ip: ip3,
name: "3",
hosts: []string{"host3"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rc := client.NewRuntime(tc.ip)
rc.SetInfo(tc.src, tc.hosts)
ri.Add(rc)
currentSize++
got := ri.Client(tc.ip)
assert.Equal(t, rc, got)
})
}
t.Run("size", func(t *testing.T) {
assert.Equal(t, currentSize, ri.Size())
})
t.Run("range", func(t *testing.T) {
s := 0
ri.Range(func(rc *client.Runtime) (cont bool) {
s++
return true
})
assert.Equal(t, currentSize, s)
})
t.Run("delete", func(t *testing.T) {
ri.Delete(ip1)
currentSize--
assert.Equal(t, currentSize, ri.Size())
})
t.Run("delete_by_src", func(t *testing.T) {
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
assert.Equal(t, 0, ri.Size())
})
}