Pull request 2155: AG-27492-client-persistent-index

Squashed commit of the following:

commit 1f99640f9f0a24ade7d2325737edf83ad0da3895
Merge: 5a9211e8c 9276afd79
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Feb 27 13:13:03 2024 +0300

    Merge branch 'master' into AG-27492-client-persistent-index

commit 5a9211e8c7832700ae4f58cea25ad38ccba98efa
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Feb 22 19:08:35 2024 +0300

    all: add todo

commit a4fc94904b0b05ed5ca5ba270125a7d7fb1e6817
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Feb 21 14:48:33 2024 +0300

    all: client persistent index
This commit is contained in:
Stanislav Chzhen
2024-02-27 13:48:11 +03:00
parent 9276afd79d
commit 31c3d7d302
12 changed files with 203 additions and 169 deletions

249
internal/client/index.go Normal file
View File

@@ -0,0 +1,249 @@
package client
import (
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
)
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
type macKey any
// macToKey converts mac into key of type macKey, which is used as the key of
// the [clientIndex.macToUID]. mac must be valid MAC address.
func macToKey(mac net.HardwareAddr) (key macKey) {
switch len(mac) {
case 6:
return [6]byte(mac)
case 8:
return [8]byte(mac)
case 20:
return [20]byte(mac)
default:
panic(fmt.Errorf("invalid mac address %#v", mac))
}
}
// Index stores all information about persistent clients.
type Index struct {
// clientIDToUID maps client ID to UID.
clientIDToUID map[string]UID
// ipToUID maps IP address to UID.
ipToUID map[netip.Addr]UID
// macToUID maps MAC address to UID.
macToUID map[macKey]UID
// uidToClient maps UID to the persistent client.
uidToClient map[UID]*Persistent
// subnetToUID maps subnet to UID.
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
}
// NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) {
return &Index{
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
macToUID: map[macKey]UID{},
uidToClient: map[UID]*Persistent{},
}
}
// Add stores information about a persistent client in the index. c must be
// non-nil and contain UID.
func (ci *Index) Add(c *Persistent) {
if (c.UID == UID{}) {
panic("client must contain uid")
}
for _, id := range c.ClientIDs {
ci.clientIDToUID[id] = c.UID
}
for _, ip := range c.IPs {
ci.ipToUID[ip] = c.UID
}
for _, pref := range c.Subnets {
ci.subnetToUID.Set(pref, c.UID)
}
for _, mac := range c.MACs {
k := macToKey(mac)
ci.macToUID[k] = c.UID
}
ci.uidToClient[c.UID] = c
}
// Clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *Index) Clashes(c *Persistent) (err error) {
for _, id := range c.ClientIDs {
existing, ok := ci.clientIDToUID[id]
if ok && existing != c.UID {
p := ci.uidToClient[existing]
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
}
}
p, ip := ci.clashesIP(c)
if p != nil {
return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip)
}
p, s := ci.clashesSubnet(c)
if p != nil {
return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s)
}
p, mac := ci.clashesMAC(c)
if p != nil {
return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac)
}
return nil
}
// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
for _, ip := range c.IPs {
existing, ok := ci.ipToUID[ip]
if ok && existing != c.UID {
return ci.uidToClient[existing], ip
}
}
return nil, netip.Addr{}
}
// clashesSubnet returns a previous client with the same subnet as c. c must be
// non-nil.
func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
for _, s = range c.Subnets {
var existing UID
var ok bool
ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) {
if s == p {
existing = uid
ok = true
return false
}
return true
})
if ok && existing != c.UID {
return ci.uidToClient[existing], s
}
}
return nil, netip.Prefix{}
}
// clashesMAC returns a previous client with the same MAC address as c. c must
// be non-nil.
func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
for _, mac = range c.MACs {
k := macToKey(mac)
existing, ok := ci.macToUID[k]
if ok && existing != c.UID {
return ci.uidToClient[existing], mac
}
}
return nil, nil
}
// Find finds persistent client by string representation of the client ID, IP
// address, or MAC.
func (ci *Index) Find(id string) (c *Persistent, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
}
ip, err := netip.ParseAddr(id)
if err == nil {
// MAC addresses can be successfully parsed as IP addresses.
c, found = ci.findByIP(ip)
if found {
return c, true
}
}
mac, err := net.ParseMAC(id)
if err == nil {
return ci.findByMAC(mac)
}
return nil, false
}
// find finds persistent client by IP address.
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
}
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
if pref.Contains(ip) {
uid, found = id, true
return false
}
return true
})
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// find finds persistent client by MAC.
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// Delete removes information about persistent client from the index. c must be
// non-nil.
func (ci *Index) Delete(c *Persistent) {
for _, id := range c.ClientIDs {
delete(ci.clientIDToUID, id)
}
for _, ip := range c.IPs {
delete(ci.ipToUID, ip)
}
for _, pref := range c.Subnets {
ci.subnetToUID.Del(pref)
}
for _, mac := range c.MACs {
k := macToKey(mac)
delete(ci.macToUID, k)
}
delete(ci.uidToClient, c.UID)
}

View File

@@ -0,0 +1,223 @@
package client
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*Persistent) (ci *Index) {
ci = NewIndex()
for _, c := range m {
c.UID = MustNewUID()
ci.Add(c)
}
return ci
}
func TestClientIndex(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
cliIP1 = "1.1.1.1"
cliIP2 = "2.2.2.2"
cliIPv6 = "1:2:3::4"
cliSubnet = "2.2.2.0/24"
cliSubnetIP = "2.2.2.222"
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
)
clients := []*Persistent{{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}, {
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
ci := newIDIndex(clients)
testCases := []struct {
want *Persistent
name string
ids []string
}{{
name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6},
want: clients[0],
}, {
name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP},
want: clients[1],
}, {
name: "mac",
ids: []string{cliMAC},
want: clients[2],
}, {
name: "client_id",
ids: []string{cliID},
want: clients[3],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := ci.Find(id)
require.True(t, ok)
assert.Equal(t, tc.want, c)
}
})
}
t.Run("not_found", func(t *testing.T) {
_, ok := ci.Find(cliIPNone)
assert.False(t, ok)
})
}
func TestClientIndex_Clashes(t *testing.T) {
const (
cliIP1 = "1.1.1.1"
cliSubnet = "2.2.2.0/24"
cliSubnetIP = "2.2.2.222"
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
)
clients := []*Persistent{{
Name: "client_with_ip",
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
}, {
Name: "client_with_subnet",
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
ci := newIDIndex(clients)
testCases := []struct {
client *Persistent
name string
}{{
name: "ipv4",
client: clients[0],
}, {
name: "subnet",
client: clients[1],
}, {
name: "mac",
client: clients[2],
}, {
name: "client_id",
client: clients[3],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
clone := tc.client.ShallowClone()
clone.UID = MustNewUID()
err := ci.Clashes(clone)
require.Error(t, err)
ci.Delete(tc.client)
err = ci.Clashes(clone)
require.NoError(t, err)
})
}
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestMACToKey(t *testing.T) {
testCases := []struct {
want any
name string
in string
}{{
name: "column6",
in: "00:00:5e:00:53:01",
want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
}, {
name: "column8",
in: "02:00:5e:10:00:00:00:01",
want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
}, {
name: "column20",
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
}, {
name: "hyphen6",
in: "00-00-5e-00-53-01",
want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
}, {
name: "hyphen8",
in: "02-00-5e-10-00-00-00-01",
want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
}, {
name: "hyphen20",
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
}, {
name: "dot6",
in: "0000.5e00.5301",
want: [6]byte(mustParseMAC("0000.5e00.5301")),
}, {
name: "dot8",
in: "0200.5e10.0000.0001",
want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
}, {
name: "dot20",
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mac := mustParseMAC(tc.in)
key := macToKey(mac)
assert.Equal(t, tc.want, key)
})
}
assert.Panics(t, func() {
mac := net.HardwareAddr([]byte{1, 2, 3})
_ = macToKey(mac)
})
}

View File

@@ -0,0 +1,285 @@
package client
import (
"encoding"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/uuid"
)
// UID is the type for the unique IDs of persistent clients.
type UID uuid.UUID
// NewUID returns a new persistent client UID. Any error returned is an error
// from the cryptographic randomness reader.
func NewUID() (uid UID, err error) {
uuidv7, err := uuid.NewV7()
return UID(uuidv7), err
}
// MustNewUID is a wrapper around [NewUID] that panics if there is an error.
func MustNewUID() (uid UID) {
uid, err := NewUID()
if err != nil {
panic(fmt.Errorf("unexpected uuidv7 error: %w", err))
}
return uid
}
// type check
var _ encoding.TextMarshaler = UID{}
// MarshalText implements the [encoding.TextMarshaler] for UID.
func (uid UID) MarshalText() ([]byte, error) {
return uuid.UUID(uid).MarshalText()
}
// type check
var _ encoding.TextUnmarshaler = (*UID)(nil)
// UnmarshalText implements the [encoding.TextUnmarshaler] interface for UID.
func (uid *UID) UnmarshalText(data []byte) error {
return (*uuid.UUID)(uid).UnmarshalText(data)
}
// Persistent contains information about persistent clients.
type Persistent struct {
// UpstreamConfig is the custom upstream configuration for this client. If
// it's nil, it has not been initialized yet. If it's non-nil and empty,
// there are no valid upstreams. If it's non-nil and non-empty, these
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices
Name string
Tags []string
Upstreams []string
IPs []netip.Addr
// TODO(s.chzhen): Use netutil.Prefix.
Subnets []netip.Prefix
MACs []net.HardwareAddr
ClientIDs []string
// UID is the unique identifier of the persistent client.
UID UID
UpstreamsCacheSize uint32
UpstreamsCacheEnabled bool
UseOwnSettings bool
FilteringEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
UseOwnBlockedServices bool
IgnoreQueryLog bool
IgnoreStatistics bool
}
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *Persistent) SetTags(tags []string, known *stringutil.Set) {
for _, t := range tags {
if !known.Has(t) {
log.Info("skipping unknown tag %q", t)
continue
}
c.Tags = append(c.Tags, t)
}
slices.Sort(c.Tags)
}
// SetIDs parses a list of strings into typed fields and returns an error if
// there is one.
func (c *Persistent) SetIDs(ids []string) (err error) {
for _, id := range ids {
err = c.setID(id)
if err != nil {
return err
}
}
slices.SortFunc(c.IPs, netip.Addr.Compare)
// TODO(s.chzhen): Use netip.PrefixCompare in Go 1.23.
slices.SortFunc(c.Subnets, subnetCompare)
slices.SortFunc(c.MACs, slices.Compare[net.HardwareAddr])
slices.Sort(c.ClientIDs)
return nil
}
// subnetCompare is a comparison function for the two subnets. It returns -1 if
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
// position is the same.
func subnetCompare(x, y netip.Prefix) (cmp int) {
if x == y {
return 0
}
xAddr, xBits := x.Addr(), x.Bits()
yAddr, yBits := y.Addr(), y.Bits()
if xBits == yBits {
return xAddr.Compare(yAddr)
}
if xBits > yBits {
return -1
} else {
return 1
}
}
// setID parses id into typed field if there is no error.
func (c *Persistent) setID(id string) (err error) {
if id == "" {
return errors.Error("clientid is empty")
}
var ip netip.Addr
if ip, err = netip.ParseAddr(id); err == nil {
c.IPs = append(c.IPs, ip)
return nil
}
var subnet netip.Prefix
if subnet, err = netip.ParsePrefix(id); err == nil {
c.Subnets = append(c.Subnets, subnet)
return nil
}
var mac net.HardwareAddr
if mac, err = net.ParseMAC(id); err == nil {
c.MACs = append(c.MACs, mac)
return nil
}
err = ValidateClientID(id)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
c.ClientIDs = append(c.ClientIDs, strings.ToLower(id))
return nil
}
// ValidateClientID returns an error if id is not a valid ClientID.
//
// TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to
// avoid the import cycle. Remove it.
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
}
// IDs returns a list of client IDs containing at least one element.
func (c *Persistent) IDs() (ids []string) {
ids = make([]string, 0, c.IDsLen())
for _, ip := range c.IPs {
ids = append(ids, ip.String())
}
for _, subnet := range c.Subnets {
ids = append(ids, subnet.String())
}
for _, mac := range c.MACs {
ids = append(ids, mac.String())
}
return append(ids, c.ClientIDs...)
}
// IDsLen returns a length of client ids.
func (c *Persistent) IDsLen() (n int) {
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
}
// EqualIDs returns true if the ids of the current and previous clients are the
// same.
func (c *Persistent) EqualIDs(prev *Persistent) (equal bool) {
return slices.Equal(c.IPs, prev.IPs) &&
slices.Equal(c.Subnets, prev.Subnets) &&
slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) &&
slices.Equal(c.ClientIDs, prev.ClientIDs)
}
// ShallowClone returns a deep copy of the client, except upstreamConfig,
// safeSearchConf, SafeSearch fields, because it's difficult to copy them.
func (c *Persistent) ShallowClone() (clone *Persistent) {
clone = &Persistent{}
*clone = *c
clone.BlockedServices = c.BlockedServices.Clone()
clone.Tags = slices.Clone(c.Tags)
clone.Upstreams = slices.Clone(c.Upstreams)
clone.IPs = slices.Clone(c.IPs)
clone.Subnets = slices.Clone(c.Subnets)
clone.MACs = slices.Clone(c.MACs)
clone.ClientIDs = slices.Clone(c.ClientIDs)
return clone
}
// CloseUpstreams closes the client-specific upstream config of c if any.
func (c *Persistent) CloseUpstreams() (err error) {
if c.UpstreamConfig != nil {
if err = c.UpstreamConfig.Close(); err != nil {
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
}
}
return nil
}
// SetSafeSearch initializes and sets the safe search filter for this client.
func (c *Persistent) SetSafeSearch(
conf filtering.SafeSearchConfig,
cacheSize uint,
cacheTTL time.Duration,
) (err error) {
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
c.SafeSearch = ss
return nil
}

View File

@@ -0,0 +1,124 @@
package client
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPersistentClient_EqualIDs(t *testing.T) {
const (
ip = "0.0.0.0"
ip1 = "1.1.1.1"
ip2 = "2.2.2.2"
cidr = "0.0.0.0/0"
cidr1 = "1.1.1.1/11"
cidr2 = "2.2.2.2/22"
mac = "00-00-00-00-00-00"
mac1 = "11-11-11-11-11-11"
mac2 = "22-22-22-22-22-22"
cli = "client0"
cli1 = "client1"
cli2 = "client2"
)
testCases := []struct {
want assert.BoolAssertionFunc
name string
ids []string
prevIDs []string
}{{
name: "single_ip",
ids: []string{ip1},
prevIDs: []string{ip1},
want: assert.True,
}, {
name: "single_ip_not_equal",
ids: []string{ip1},
prevIDs: []string{ip2},
want: assert.False,
}, {
name: "ips_not_equal",
ids: []string{ip1, ip2},
prevIDs: []string{ip1, ip},
want: assert.False,
}, {
name: "ips_mixed_equal",
ids: []string{ip1, ip2},
prevIDs: []string{ip2, ip1},
want: assert.True,
}, {
name: "single_subnet",
ids: []string{cidr1},
prevIDs: []string{cidr1},
want: assert.True,
}, {
name: "subnets_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2},
prevIDs: []string{ip1, ip2, cidr1, cidr},
want: assert.False,
}, {
name: "subnets_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2},
prevIDs: []string{cidr2, cidr1, ip2, ip1},
want: assert.True,
}, {
name: "single_mac",
ids: []string{mac1},
prevIDs: []string{mac1},
want: assert.True,
}, {
name: "single_mac_not_equal",
ids: []string{mac1},
prevIDs: []string{mac2},
want: assert.False,
}, {
name: "macs_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac},
want: assert.False,
}, {
name: "macs_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
prevIDs: []string{mac2, mac1, cidr2, cidr1, ip2, ip1},
want: assert.True,
}, {
name: "single_client_id",
ids: []string{cli1},
prevIDs: []string{cli1},
want: assert.True,
}, {
name: "single_client_id_not_equal",
ids: []string{cli1},
prevIDs: []string{cli2},
want: assert.False,
}, {
name: "client_ids_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli},
want: assert.False,
}, {
name: "client_ids_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
prevIDs: []string{cli2, cli1, mac2, mac1, cidr2, cidr1, ip2, ip1},
want: assert.True,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := &Persistent{}
err := c.SetIDs(tc.ids)
require.NoError(t, err)
prev := &Persistent{}
err = prev.SetIDs(tc.prevIDs)
require.NoError(t, err)
tc.want(t, c.EqualIDs(prev))
})
}
}