Pull request 2378: AGDNS-2750-find-client
Merge in DNS/adguard-home from AGDNS-2750-find-client to master Squashed commit of the following: commit 98f1a8ca4622b6f502a5092273b9724203fe0bd8 Merge: 9270222d84ccc2a213Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 23 17:53:20 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 9270222d8e9e03038e9434b54496cbb6164463cd Merge: 6468ceec8c7c62ad3bAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 21 19:40:58 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 6468ceec82d30084771a53ff6720a8c11c68bf2f Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 21 19:40:52 2025 +0300 home: imp docs commit 3fd4735a0d6db4fdf2d46f3da9794a687fdcaa8b Merge: 1311a5869a8fdf1c55Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 18 19:43:36 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 1311a58695de00f20c9704378ee6e964a44d1c59 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 18 19:42:41 2025 +0300 home: imp code commit b1f2c4c883c9476c5135140abac31f8ae6609b4f Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 16 16:47:59 2025 +0300 home: imp code commit d0a5abd66587c1ad602c2ccf6c8a45a3dfe39a5c Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Apr 15 14:58:31 2025 +0300 client: imp naming commit 5accdca325551237f003f1c416891b488fe5290b Merge: 6a00232f74d258972dAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 14 19:40:40 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 6a00232f76a0fe5ce781aa01637b6e04ace7250d Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 14 19:30:32 2025 +0300 home: imp code commit 8633886457c6aab75f5676494b1f49d9811e9ab9 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 11 15:29:25 2025 +0300 all: imp code commit d6f16879e7b054a5ffac59131d2a6eff1da659c0 Merge: 58236fdec6d282ae71Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 21:35:23 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 58236fdec5b64e83a44680ff8a89badc18ec81f1 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 21:23:01 2025 +0300 all: upd ci commit 3c4d946d7970987677d4ac984394e18987a29f9a Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 21:16:03 2025 +0300 all: upd go commit cc1c97734506a9ffbe70fd3c676284e58a21ba46 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 20:58:56 2025 +0300 all: imp code commit 8f061c933152481a4c80eef2af575efd4919d82b Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 9 16:49:11 2025 +0300 all: imp docs commit 8d19355f1c519211a56cec3f23d527922d4f2ee0 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 7 21:35:06 2025 +0300 all: imp code commit f1e853f57e5d54d13bedcdab4f8e21e112f3a356 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 2 14:57:40 2025 +0300 all: imp code commit 6a6ac7f899f29ddc90a583c80562233e646ba1d6 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Apr 1 19:51:56 2025 +0300 client: imp tests commit 52040ee7393d0483c682f2f37d7b70f12f9cf621 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Apr 1 19:28:18 2025 +0300 all: imp code commit 1e09208dbd2d35c3f6b2ade169324e23d1a643a5 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Mar 26 15:33:02 2025 +0300 all: imp code ... and 2 more commits
This commit is contained in:
@@ -11,8 +11,34 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// ClientID is a unique identifier for a persistent client used in
|
||||
// DNS-over-HTTPS, DNS-over-TLS, and DNS-over-QUIC queries.
|
||||
//
|
||||
// TODO(s.chzhen): Use everywhere.
|
||||
type ClientID string
|
||||
|
||||
// ValidateClientID returns an error if id is not a valid ClientID.
|
||||
//
|
||||
// TODO(s.chzhen): Consider implementing [validate.Interface] for ClientID.
|
||||
func ValidateClientID(id string) (err error) {
|
||||
err = netutil.ValidateHostnameLabel(id)
|
||||
if err != nil {
|
||||
// Replace the domain name label wrapper with our own.
|
||||
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidClientID returns false if id is not a valid ClientID.
|
||||
func isValidClientID(id string) (ok bool) {
|
||||
return netutil.IsValidHostnameLabel(id)
|
||||
}
|
||||
|
||||
// Source represents the source from which the information about the client has
|
||||
// been obtained.
|
||||
type Source uint8
|
||||
|
||||
@@ -35,7 +35,7 @@ type index struct {
|
||||
nameToUID map[string]UID
|
||||
|
||||
// clientIDToUID maps ClientID to UID.
|
||||
clientIDToUID map[string]UID
|
||||
clientIDToUID map[ClientID]UID
|
||||
|
||||
// ipToUID maps IP address to UID.
|
||||
ipToUID map[netip.Addr]UID
|
||||
@@ -54,7 +54,7 @@ type index struct {
|
||||
func newIndex() (ci *index) {
|
||||
return &index{
|
||||
nameToUID: map[string]UID{},
|
||||
clientIDToUID: map[string]UID{},
|
||||
clientIDToUID: map[ClientID]UID{},
|
||||
ipToUID: map[netip.Addr]UID{},
|
||||
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
|
||||
macToUID: map[macKey]UID{},
|
||||
@@ -207,7 +207,7 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
|
||||
// find finds persistent client by string representation of the ClientID, IP
|
||||
// address, or MAC.
|
||||
func (ci *index) find(id string) (c *Persistent, ok bool) {
|
||||
c, ok = ci.findByClientID(id)
|
||||
c, ok = ci.findByClientID(ClientID(id))
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
@@ -230,7 +230,7 @@ func (ci *index) find(id string) (c *Persistent, ok bool) {
|
||||
}
|
||||
|
||||
// findByClientID finds persistent client by ClientID.
|
||||
func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) {
|
||||
func (ci *index) findByClientID(clientID ClientID) (c *Persistent, ok bool) {
|
||||
uid, ok := ci.clientIDToUID[clientID]
|
||||
if ok {
|
||||
return ci.uidToClient[uid], true
|
||||
@@ -275,6 +275,26 @@ func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findByCIDR searches for a persistent client with the provided subnet as an
|
||||
// identifier. Note that this function looks for an exact match of subnets,
|
||||
// rather than checking if one subnet contains another.
|
||||
func (ci *index) findByCIDR(subnet netip.Prefix) (c *Persistent, ok bool) {
|
||||
var uid UID
|
||||
for pref, id := range ci.subnetToUID.Range {
|
||||
if subnet == pref {
|
||||
uid, ok = id, true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ok {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findByMAC finds persistent client by MAC.
|
||||
func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||
k := macToKey(mac)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -58,12 +59,12 @@ func TestClientIndex_Find(t *testing.T) {
|
||||
|
||||
clientWithMAC = &Persistent{
|
||||
Name: "client_with_mac",
|
||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
|
||||
}
|
||||
|
||||
clientWithID = &Persistent{
|
||||
Name: "client_with_id",
|
||||
ClientIDs: []string{cliID},
|
||||
ClientIDs: []ClientID{cliID},
|
||||
}
|
||||
|
||||
clientLinkLocal = &Persistent{
|
||||
@@ -141,10 +142,10 @@ func TestClientIndex_Clashes(t *testing.T) {
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||
}, {
|
||||
Name: "client_with_mac",
|
||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
|
||||
}, {
|
||||
Name: "client_with_id",
|
||||
ClientIDs: []string{cliID},
|
||||
ClientIDs: []ClientID{cliID},
|
||||
}}
|
||||
|
||||
ci := newIDIndex(clients)
|
||||
@@ -181,17 +182,6 @@ func TestClientIndex_Clashes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||
// error.
|
||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||
mac, err := net.ParseMAC(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return mac
|
||||
}
|
||||
|
||||
func TestMACToKey(t *testing.T) {
|
||||
testCases := []struct {
|
||||
want any
|
||||
@@ -200,44 +190,44 @@ func TestMACToKey(t *testing.T) {
|
||||
}{{
|
||||
name: "column6",
|
||||
in: "00:00:5e:00:53:01",
|
||||
want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
|
||||
want: [6]byte(errors.Must(net.ParseMAC("00:00:5e:00:53:01"))),
|
||||
}, {
|
||||
name: "column8",
|
||||
in: "02:00:5e:10:00:00:00:01",
|
||||
want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
|
||||
want: [8]byte(errors.Must(net.ParseMAC("02:00:5e:10:00:00:00:01"))),
|
||||
}, {
|
||||
name: "column20",
|
||||
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
|
||||
want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
|
||||
want: [20]byte(errors.Must(net.ParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01"))),
|
||||
}, {
|
||||
name: "hyphen6",
|
||||
in: "00-00-5e-00-53-01",
|
||||
want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
|
||||
want: [6]byte(errors.Must(net.ParseMAC("00-00-5e-00-53-01"))),
|
||||
}, {
|
||||
name: "hyphen8",
|
||||
in: "02-00-5e-10-00-00-00-01",
|
||||
want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
|
||||
want: [8]byte(errors.Must(net.ParseMAC("02-00-5e-10-00-00-00-01"))),
|
||||
}, {
|
||||
name: "hyphen20",
|
||||
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
|
||||
want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
|
||||
want: [20]byte(errors.Must(net.ParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01"))),
|
||||
}, {
|
||||
name: "dot6",
|
||||
in: "0000.5e00.5301",
|
||||
want: [6]byte(mustParseMAC("0000.5e00.5301")),
|
||||
want: [6]byte(errors.Must(net.ParseMAC("0000.5e00.5301"))),
|
||||
}, {
|
||||
name: "dot8",
|
||||
in: "0200.5e10.0000.0001",
|
||||
want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
|
||||
want: [8]byte(errors.Must(net.ParseMAC("0200.5e10.0000.0001"))),
|
||||
}, {
|
||||
name: "dot20",
|
||||
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
|
||||
want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
|
||||
want: [20]byte(errors.Must(net.ParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001"))),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mac := mustParseMAC(tc.in)
|
||||
mac := errors.Must(net.ParseMAC(tc.in))
|
||||
|
||||
key := macToKey(mac)
|
||||
assert.Equal(t, tc.want, key)
|
||||
@@ -302,19 +292,19 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) {
|
||||
func TestClientIndex_RangeByName(t *testing.T) {
|
||||
sortedClients := []*Persistent{{
|
||||
Name: "clientA",
|
||||
ClientIDs: []string{"A"},
|
||||
ClientIDs: []ClientID{"A"},
|
||||
}, {
|
||||
Name: "clientB",
|
||||
ClientIDs: []string{"B"},
|
||||
ClientIDs: []ClientID{"B"},
|
||||
}, {
|
||||
Name: "clientC",
|
||||
ClientIDs: []string{"C"},
|
||||
ClientIDs: []ClientID{"C"},
|
||||
}, {
|
||||
Name: "clientD",
|
||||
ClientIDs: []string{"D"},
|
||||
ClientIDs: []ClientID{"D"},
|
||||
}, {
|
||||
Name: "clientE",
|
||||
ClientIDs: []string{"E"},
|
||||
ClientIDs: []ClientID{"E"},
|
||||
}}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -349,3 +339,115 @@ func TestClientIndex_RangeByName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndex_FindByName(t *testing.T) {
|
||||
const (
|
||||
clientExistingName = "client_existing"
|
||||
clientAnotherExistingName = "client_another_existing"
|
||||
nonExistingClientName = "client_non_existing"
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &Persistent{
|
||||
Name: clientExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr("192.0.2.1")},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &Persistent{
|
||||
Name: clientAnotherExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr("192.0.2.2")},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
ci := newIDIndex(clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *Persistent
|
||||
found assert.BoolAssertionFunc
|
||||
name string
|
||||
clientName string
|
||||
}{{
|
||||
want: clientExisting,
|
||||
found: assert.True,
|
||||
name: "existing",
|
||||
clientName: clientExistingName,
|
||||
}, {
|
||||
want: clientAnotherExisting,
|
||||
found: assert.True,
|
||||
name: "another_existing",
|
||||
clientName: clientAnotherExistingName,
|
||||
}, {
|
||||
want: nil,
|
||||
found: assert.False,
|
||||
name: "non_existing",
|
||||
clientName: nonExistingClientName,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := ci.findByName(tc.clientName)
|
||||
assert.Equal(t, tc.want, c)
|
||||
tc.found(t, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndex_FindByMAC(t *testing.T) {
|
||||
var (
|
||||
cliMAC = errors.Must(net.ParseMAC("11:11:11:11:11:11"))
|
||||
cliAnotherMAC = errors.Must(net.ParseMAC("22:22:22:22:22:22"))
|
||||
nonExistingClientMAC = errors.Must(net.ParseMAC("33:33:33:33:33:33"))
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &Persistent{
|
||||
Name: "client",
|
||||
MACs: []net.HardwareAddr{cliMAC},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &Persistent{
|
||||
Name: "another_client",
|
||||
MACs: []net.HardwareAddr{cliAnotherMAC},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
ci := newIDIndex(clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *Persistent
|
||||
found assert.BoolAssertionFunc
|
||||
name string
|
||||
clientMAC net.HardwareAddr
|
||||
}{{
|
||||
want: clientExisting,
|
||||
found: assert.True,
|
||||
name: "existing",
|
||||
clientMAC: cliMAC,
|
||||
}, {
|
||||
want: clientAnotherExisting,
|
||||
found: assert.True,
|
||||
name: "another_existing",
|
||||
clientMAC: cliAnotherMAC,
|
||||
}, {
|
||||
want: nil,
|
||||
found: assert.False,
|
||||
name: "non_existing",
|
||||
clientMAC: nonExistingClientMAC,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := ci.findByMAC(tc.clientMAC)
|
||||
assert.Equal(t, tc.want, c)
|
||||
tc.found(t, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -90,7 +89,7 @@ type Persistent struct {
|
||||
|
||||
// ClientIDs identifying the client. The client must have at least one ID
|
||||
// (IP, subnet, MAC, or ClientID).
|
||||
ClientIDs []string
|
||||
ClientIDs []ClientID
|
||||
|
||||
// UID is the unique identifier of the persistent client.
|
||||
UID UID
|
||||
@@ -134,7 +133,7 @@ func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []str
|
||||
switch {
|
||||
case c.Name == "":
|
||||
return errors.Error("empty name")
|
||||
case c.IDsLen() == 0:
|
||||
case c.idendifiersLen() == 0:
|
||||
return errors.Error("id required")
|
||||
case c.UID == UID{}:
|
||||
return errors.Error("uid required")
|
||||
@@ -237,28 +236,15 @@ func (c *Persistent) setID(id string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
c.ClientIDs = append(c.ClientIDs, strings.ToLower(id))
|
||||
c.ClientIDs = append(c.ClientIDs, ClientID(strings.ToLower(id)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateClientID returns an error if id is not a valid ClientID.
|
||||
//
|
||||
// TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to
|
||||
// avoid the import cycle. Remove it.
|
||||
func ValidateClientID(id string) (err error) {
|
||||
err = netutil.ValidateHostnameLabel(id)
|
||||
if err != nil {
|
||||
// Replace the domain name label wrapper with our own.
|
||||
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IDs returns a list of ClientIDs containing at least one element.
|
||||
func (c *Persistent) IDs() (ids []string) {
|
||||
ids = make([]string, 0, c.IDsLen())
|
||||
// Identifiers returns a list of client identifiers containing at least one
|
||||
// element.
|
||||
func (c *Persistent) Identifiers() (ids []string) {
|
||||
ids = make([]string, 0, c.idendifiersLen())
|
||||
|
||||
for _, ip := range c.IPs {
|
||||
ids = append(ids, ip.String())
|
||||
@@ -272,11 +258,15 @@ func (c *Persistent) IDs() (ids []string) {
|
||||
ids = append(ids, mac.String())
|
||||
}
|
||||
|
||||
return append(ids, c.ClientIDs...)
|
||||
for _, cid := range c.ClientIDs {
|
||||
ids = append(ids, string(cid))
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// IDsLen returns a length of ClientIDs.
|
||||
func (c *Persistent) IDsLen() (n int) {
|
||||
// identifiersLen returns the number of client identifiers.
|
||||
func (c *Persistent) idendifiersLen() (n int) {
|
||||
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
@@ -433,48 +435,186 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) {
|
||||
ctx,
|
||||
"client added",
|
||||
"name", p.Name,
|
||||
"ids", p.IDs(),
|
||||
"ids", p.Identifiers(),
|
||||
"clients_count", s.index.size(),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByName finds persistent client by name. And returns its shallow copy.
|
||||
func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// FindParams represents the parameters for searching a client. At least one
|
||||
// field must be non-empty.
|
||||
type FindParams struct {
|
||||
// ClientID is a unique identifier for the client used in DoH, DoT, and DoQ
|
||||
// DNS queries.
|
||||
ClientID ClientID
|
||||
|
||||
p, ok = s.index.findByName(name)
|
||||
if ok {
|
||||
return p.ShallowClone(), ok
|
||||
}
|
||||
// RemoteIP is the IP address used as a client search parameter.
|
||||
RemoteIP netip.Addr
|
||||
|
||||
return nil, false
|
||||
// Subnet is the CIDR used as a client search parameter.
|
||||
Subnet netip.Prefix
|
||||
|
||||
// MAC is the physical hardware address used as a client search parameter.
|
||||
MAC net.HardwareAddr
|
||||
|
||||
// UID is the unique ID of persistent client used as a search parameter.
|
||||
//
|
||||
// TODO(s.chzhen): Use this.
|
||||
UID UID
|
||||
}
|
||||
|
||||
// Find finds persistent client by string representation of the ClientID, IP
|
||||
// address, or MAC. And returns its shallow copy.
|
||||
// ErrBadIdentifier is returned by [FindParams.Set] when it cannot parse the
|
||||
// provided client identifier.
|
||||
const ErrBadIdentifier errors.Error = "bad client identifier"
|
||||
|
||||
// Set clears the stored search parameters and parses the string representation
|
||||
// of the search parameter into typed parameter, storing it. In some cases, it
|
||||
// may result in storing both an IP address and a MAC address because they might
|
||||
// have identical string representations. It returns [ErrBadIdentifier] if id
|
||||
// cannot be parsed.
|
||||
//
|
||||
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
|
||||
// the parsed IP address, if any.
|
||||
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
||||
// TODO(s.chzhen): Add support for UID.
|
||||
func (p *FindParams) Set(id string) (err error) {
|
||||
*p = FindParams{}
|
||||
|
||||
isClientID := true
|
||||
|
||||
if netutil.IsValidIPString(id) {
|
||||
// It is safe to use [netip.MustParseAddr] because it has already been
|
||||
// validated that id contains the string representation of the IP
|
||||
// address.
|
||||
p.RemoteIP = netip.MustParseAddr(id)
|
||||
|
||||
// Even if id can be parsed as an IP address, it may be a MAC address.
|
||||
// So do not return prematurely, continue parsing.
|
||||
isClientID = false
|
||||
}
|
||||
|
||||
if canBeValidIPPrefixString(id) {
|
||||
p.Subnet, err = netip.ParsePrefix(id)
|
||||
if err == nil {
|
||||
isClientID = false
|
||||
}
|
||||
}
|
||||
|
||||
if canBeMACString(id) {
|
||||
p.MAC, err = net.ParseMAC(id)
|
||||
if err == nil {
|
||||
isClientID = false
|
||||
}
|
||||
}
|
||||
|
||||
if !isClientID {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isValidClientID(id) {
|
||||
return ErrBadIdentifier
|
||||
}
|
||||
|
||||
p.ClientID = ClientID(id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// canBeValidIPPrefixString is a best-effort check to determine if s is a valid
|
||||
// CIDR before using [netip.ParsePrefix], aimed at reducing allocations.
|
||||
//
|
||||
// TODO(s.chzhen): Replace this implementation with the more robust version
|
||||
// from golibs.
|
||||
func canBeValidIPPrefixString(s string) (ok bool) {
|
||||
ipStr, bitStr, ok := strings.Cut(s, "/")
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if bitStr == "" || len(bitStr) > 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
bits := 0
|
||||
for _, c := range bitStr {
|
||||
if c < '0' || c > '9' {
|
||||
return false
|
||||
}
|
||||
|
||||
bits = bits*10 + int(c-'0')
|
||||
}
|
||||
|
||||
if bits > 128 {
|
||||
return false
|
||||
}
|
||||
|
||||
return netutil.IsValidIPString(ipStr)
|
||||
}
|
||||
|
||||
// canBeMACString is a best-effort check to determine if s is a valid MAC
|
||||
// address before using [net.ParseMAC], aimed at reducing allocations.
|
||||
//
|
||||
// TODO(s.chzhen): Replace this implementation with the more robust version
|
||||
// from golibs.
|
||||
func canBeMACString(s string) (ok bool) {
|
||||
switch len(s) {
|
||||
case
|
||||
len("0000.0000.0000"),
|
||||
len("00:00:00:00:00:00"),
|
||||
len("0000.0000.0000.0000"),
|
||||
len("00:00:00:00:00:00:00:00"),
|
||||
len("0000.0000.0000.0000.0000.0000.0000.0000.0000.0000"),
|
||||
len("00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Find represents the parameters for searching a client. params must not be
|
||||
// nil and must have at least one non-empty field.
|
||||
func (s *Storage) Find(params *FindParams) (p *Persistent, ok bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
p, ok = s.index.find(id)
|
||||
isClientID := params.ClientID != ""
|
||||
isRemoteIP := params.RemoteIP != (netip.Addr{})
|
||||
isSubnet := params.Subnet != (netip.Prefix{})
|
||||
isMAC := params.MAC != nil
|
||||
|
||||
for {
|
||||
switch {
|
||||
case isClientID:
|
||||
isClientID = false
|
||||
p, ok = s.index.findByClientID(params.ClientID)
|
||||
case isRemoteIP:
|
||||
isRemoteIP = false
|
||||
p, ok = s.findByIP(params.RemoteIP)
|
||||
case isSubnet:
|
||||
isSubnet = false
|
||||
p, ok = s.index.findByCIDR(params.Subnet)
|
||||
case isMAC:
|
||||
isMAC = false
|
||||
p, ok = s.index.findByMAC(params.MAC)
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if ok {
|
||||
return p.ShallowClone(), true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findByIP finds persistent client by IP address. s.mu is expected to be
|
||||
// locked.
|
||||
func (s *Storage) findByIP(addr netip.Addr) (p *Persistent, ok bool) {
|
||||
p, ok = s.index.findByIP(addr)
|
||||
if ok {
|
||||
return p.ShallowClone(), ok
|
||||
return p, true
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
foundMAC := s.dhcp.MACByIP(ip)
|
||||
foundMAC := s.dhcp.MACByIP(addr)
|
||||
if foundMAC != nil {
|
||||
return s.FindByMAC(foundMAC)
|
||||
return s.index.findByMAC(foundMAC)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
@@ -487,6 +627,8 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
||||
//
|
||||
// Note that multiple clients can have the same IP address with different zones.
|
||||
// Therefore, the result of this method is indeterminate.
|
||||
//
|
||||
// TODO(s.chzhen): Consider accepting [FindParams].
|
||||
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -498,7 +640,7 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
||||
|
||||
foundMAC := s.dhcp.MACByIP(ip)
|
||||
if foundMAC != nil {
|
||||
return s.FindByMAC(foundMAC)
|
||||
return s.index.findByMAC(foundMAC)
|
||||
}
|
||||
|
||||
p = s.index.findByIPWithoutZone(ip)
|
||||
@@ -509,17 +651,6 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu
|
||||
// is expected to be locked.
|
||||
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
|
||||
p, ok = s.index.findByMAC(mac)
|
||||
if ok {
|
||||
return p.ShallowClone(), ok
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// RemoveByName removes persistent client information. ok is false if no such
|
||||
// client exists by that name.
|
||||
func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
|
||||
@@ -648,7 +779,7 @@ func (s *Storage) CustomUpstreamConfig(
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
c, ok := s.index.findByClientID(id)
|
||||
c, ok := s.index.findByClientID(ClientID(id))
|
||||
if !ok {
|
||||
c, ok = s.index.findByIP(addr)
|
||||
}
|
||||
@@ -682,7 +813,7 @@ func (s *Storage) ClearUpstreamCache() {
|
||||
// ClientID or client IP address, and applies it to the filtering settings.
|
||||
// setts must not be nil.
|
||||
func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) {
|
||||
c, ok := s.index.findByClientID(id)
|
||||
c, ok := s.index.findByClientID(ClientID(id))
|
||||
if !ok {
|
||||
c, ok = s.index.findByIP(addr)
|
||||
}
|
||||
@@ -690,7 +821,7 @@ func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filter
|
||||
if !ok {
|
||||
foundMAC := s.dhcp.MACByIP(addr)
|
||||
if foundMAC != nil {
|
||||
c, ok = s.FindByMAC(foundMAC)
|
||||
c, ok = s.index.findByMAC(foundMAC)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@@ -350,15 +351,15 @@ func TestClientsDHCP(t *testing.T) {
|
||||
cliName1 = "one.dhcp"
|
||||
|
||||
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||
cliMAC2 = mustParseMAC("22:22:22:22:22:22")
|
||||
cliMAC2 = errors.Must(net.ParseMAC("22:22:22:22:22:22"))
|
||||
cliName2 = "two.dhcp"
|
||||
|
||||
cliIP3 = netip.MustParseAddr("3.3.3.3")
|
||||
cliMAC3 = mustParseMAC("33:33:33:33:33:33")
|
||||
cliMAC3 = errors.Must(net.ParseMAC("33:33:33:33:33:33"))
|
||||
cliName3 = "three.dhcp"
|
||||
|
||||
prsCliIP = netip.MustParseAddr("4.3.2.1")
|
||||
prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
|
||||
prsCliMAC = errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA"))
|
||||
prsCliName = "persistent.dhcp"
|
||||
|
||||
otherARPCliName = "other.arp"
|
||||
@@ -519,7 +520,11 @@ func TestClientsDHCP(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
prsCli, ok := storage.Find(prsCliIP.String())
|
||||
params := &client.FindParams{}
|
||||
err = params.Set(prsCliIP.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
prsCli, ok := storage.Find(params)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, prsCliName, prsCli.Name)
|
||||
@@ -663,17 +668,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
||||
return s
|
||||
}
|
||||
|
||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||
// error.
|
||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||
mac, err := net.ParseMAC(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return mac
|
||||
}
|
||||
|
||||
func TestStorage_Add(t *testing.T) {
|
||||
const (
|
||||
existingName = "existing_name"
|
||||
@@ -693,7 +687,7 @@ func TestStorage_Add(t *testing.T) {
|
||||
Name: existingName,
|
||||
IPs: []netip.Addr{existingIP},
|
||||
Subnets: []netip.Prefix{existingSubnet},
|
||||
ClientIDs: []string{existingClientID},
|
||||
ClientIDs: []client.ClientID{existingClientID},
|
||||
UID: existingClientUID,
|
||||
}
|
||||
|
||||
@@ -761,7 +755,7 @@ func TestStorage_Add(t *testing.T) {
|
||||
name: "duplicate_client_id",
|
||||
cli: &client.Persistent{
|
||||
Name: "duplicate_client_id",
|
||||
ClientIDs: []string{existingClientID},
|
||||
ClientIDs: []client.ClientID{existingClientID},
|
||||
UID: client.MustNewUID(),
|
||||
},
|
||||
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||
@@ -898,12 +892,12 @@ func TestStorage_Find(t *testing.T) {
|
||||
|
||||
clientWithMAC = &client.Persistent{
|
||||
Name: "client_with_mac",
|
||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
|
||||
}
|
||||
|
||||
clientWithID = &client.Persistent{
|
||||
Name: "client_with_id",
|
||||
ClientIDs: []string{cliID},
|
||||
ClientIDs: []client.ClientID{cliID},
|
||||
}
|
||||
|
||||
clientLinkLocal = &client.Persistent{
|
||||
@@ -950,7 +944,11 @@ func TestStorage_Find(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, id := range tc.ids {
|
||||
c, ok := s.Find(id)
|
||||
params := &client.FindParams{}
|
||||
err := params.Set(id)
|
||||
require.NoError(t, err)
|
||||
|
||||
c, ok := s.Find(params)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, tc.want, c)
|
||||
@@ -959,7 +957,11 @@ func TestStorage_Find(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("not_found", func(t *testing.T) {
|
||||
_, ok := s.Find(cliIPNone)
|
||||
params := &client.FindParams{}
|
||||
err := params.Set(cliIPNone)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, ok := s.Find(params)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -1025,127 +1027,6 @@ func TestStorage_FindLoose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_FindByName(t *testing.T) {
|
||||
const (
|
||||
cliIP1 = "1.1.1.1"
|
||||
cliIP2 = "2.2.2.2"
|
||||
)
|
||||
|
||||
const (
|
||||
clientExistingName = "client_existing"
|
||||
clientAnotherExistingName = "client_another_existing"
|
||||
nonExistingClientName = "client_non_existing"
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &client.Persistent{
|
||||
Name: clientExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &client.Persistent{
|
||||
Name: clientAnotherExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
s := newStorage(t, clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *client.Persistent
|
||||
name string
|
||||
clientName string
|
||||
}{{
|
||||
name: "existing",
|
||||
clientName: clientExistingName,
|
||||
want: clientExisting,
|
||||
}, {
|
||||
name: "another_existing",
|
||||
clientName: clientAnotherExistingName,
|
||||
want: clientAnotherExisting,
|
||||
}, {
|
||||
name: "non_existing",
|
||||
clientName: nonExistingClientName,
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := s.FindByName(tc.clientName)
|
||||
if tc.want == nil {
|
||||
assert.False(t, ok)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_FindByMAC(t *testing.T) {
|
||||
var (
|
||||
cliMAC = mustParseMAC("11:11:11:11:11:11")
|
||||
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
|
||||
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &client.Persistent{
|
||||
Name: "client",
|
||||
MACs: []net.HardwareAddr{cliMAC},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &client.Persistent{
|
||||
Name: "another_client",
|
||||
MACs: []net.HardwareAddr{cliAnotherMAC},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
s := newStorage(t, clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *client.Persistent
|
||||
name string
|
||||
clientMAC net.HardwareAddr
|
||||
}{{
|
||||
name: "existing",
|
||||
clientMAC: cliMAC,
|
||||
want: clientExisting,
|
||||
}, {
|
||||
name: "another_existing",
|
||||
clientMAC: cliAnotherMAC,
|
||||
want: clientAnotherExisting,
|
||||
}, {
|
||||
name: "non_existing",
|
||||
clientMAC: nonExistingClientMAC,
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := s.FindByMAC(tc.clientMAC)
|
||||
if tc.want == nil {
|
||||
assert.False(t, ok)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_Update(t *testing.T) {
|
||||
const (
|
||||
clientName = "client_name"
|
||||
@@ -1162,7 +1043,7 @@ func TestStorage_Update(t *testing.T) {
|
||||
Name: obstructingName,
|
||||
IPs: []netip.Addr{obstructingIP},
|
||||
Subnets: []netip.Prefix{obstructingSubnet},
|
||||
ClientIDs: []string{obstructingClientID},
|
||||
ClientIDs: []client.ClientID{obstructingClientID},
|
||||
}
|
||||
|
||||
clientToUpdate := &client.Persistent{
|
||||
@@ -1211,7 +1092,7 @@ func TestStorage_Update(t *testing.T) {
|
||||
name: "duplicate_client_id",
|
||||
cli: &client.Persistent{
|
||||
Name: "duplicate_client_id",
|
||||
ClientIDs: []string{obstructingClientID},
|
||||
ClientIDs: []client.ClientID{obstructingClientID},
|
||||
UID: client.MustNewUID(),
|
||||
},
|
||||
wantErrMsg: `updating client: another client "obstructing_name" ` +
|
||||
@@ -1238,19 +1119,19 @@ func TestStorage_Update(t *testing.T) {
|
||||
func TestStorage_RangeByName(t *testing.T) {
|
||||
sortedClients := []*client.Persistent{{
|
||||
Name: "clientA",
|
||||
ClientIDs: []string{"A"},
|
||||
ClientIDs: []client.ClientID{"A"},
|
||||
}, {
|
||||
Name: "clientB",
|
||||
ClientIDs: []string{"B"},
|
||||
ClientIDs: []client.ClientID{"B"},
|
||||
}, {
|
||||
Name: "clientC",
|
||||
ClientIDs: []string{"C"},
|
||||
ClientIDs: []client.ClientID{"C"},
|
||||
}, {
|
||||
Name: "clientD",
|
||||
ClientIDs: []string{"D"},
|
||||
ClientIDs: []client.ClientID{"D"},
|
||||
}, {
|
||||
Name: "clientE",
|
||||
ClientIDs: []string{"E"},
|
||||
ClientIDs: []client.ClientID{"E"},
|
||||
}}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -1306,7 +1187,7 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
|
||||
existingClient := &client.Persistent{
|
||||
Name: existingName,
|
||||
IPs: []netip.Addr{existingIP},
|
||||
ClientIDs: []string{existingClientID},
|
||||
ClientIDs: []client.ClientID{existingClientID},
|
||||
UID: existingClientUID,
|
||||
Upstreams: []string{"192.0.2.0"},
|
||||
}
|
||||
@@ -1381,3 +1262,182 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
|
||||
assert.NotEqual(t, conf, updConf)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkFindParams_Set(b *testing.B) {
|
||||
const (
|
||||
testIPStr = "192.0.2.1"
|
||||
testCIDRStr = "192.0.2.0/24"
|
||||
testMACStr = "02:00:00:00:00:00"
|
||||
testClientID = "clientid"
|
||||
)
|
||||
|
||||
benchCases := []struct {
|
||||
wantErr error
|
||||
params *client.FindParams
|
||||
name string
|
||||
id string
|
||||
}{{
|
||||
wantErr: nil,
|
||||
params: &client.FindParams{
|
||||
ClientID: testClientID,
|
||||
},
|
||||
name: "client_id",
|
||||
id: testClientID,
|
||||
}, {
|
||||
wantErr: nil,
|
||||
params: &client.FindParams{
|
||||
RemoteIP: netip.MustParseAddr(testIPStr),
|
||||
},
|
||||
name: "ip_address",
|
||||
id: testIPStr,
|
||||
}, {
|
||||
wantErr: nil,
|
||||
params: &client.FindParams{
|
||||
Subnet: netip.MustParsePrefix(testCIDRStr),
|
||||
},
|
||||
name: "subnet",
|
||||
id: testCIDRStr,
|
||||
}, {
|
||||
wantErr: nil,
|
||||
params: &client.FindParams{
|
||||
MAC: errors.Must(net.ParseMAC(testMACStr)),
|
||||
},
|
||||
name: "mac_address",
|
||||
id: testMACStr,
|
||||
}, {
|
||||
wantErr: client.ErrBadIdentifier,
|
||||
params: &client.FindParams{},
|
||||
name: "bad_id",
|
||||
id: "!@#$%^&*()_+",
|
||||
}}
|
||||
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
params := &client.FindParams{}
|
||||
var err error
|
||||
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
err = params.Set(bc.id)
|
||||
}
|
||||
|
||||
assert.ErrorIs(b, err, bc.wantErr)
|
||||
assert.Equal(b, bc.params, params)
|
||||
})
|
||||
}
|
||||
|
||||
// Most recent results:
|
||||
//
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/client
|
||||
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
|
||||
// BenchmarkFindParams_Set/client_id-8 49463488 24.27 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkFindParams_Set/ip_address-8 18740977 62.22 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkFindParams_Set/subnet-8 10848192 110.0 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkFindParams_Set/mac_address-8 8148494 133.2 ns/op 8 B/op 1 allocs/op
|
||||
// BenchmarkFindParams_Set/bad_id-8 73894278 16.29 ns/op 0 B/op 0 allocs/op
|
||||
}
|
||||
|
||||
func BenchmarkStorage_Find(b *testing.B) {
|
||||
const (
|
||||
cliID = "cid"
|
||||
cliMAC = "02:00:00:00:00:00"
|
||||
)
|
||||
|
||||
const (
|
||||
cliNameWithID = "client_with_id"
|
||||
cliNameWithIP = "client_with_ip"
|
||||
cliNameWithCIDR = "client_with_cidr"
|
||||
cliNameWithMAC = "client_with_mac"
|
||||
)
|
||||
|
||||
var (
|
||||
cliIP = netip.MustParseAddr("192.0.2.1")
|
||||
cliCIDR = netip.MustParsePrefix("192.0.2.0/24")
|
||||
)
|
||||
|
||||
var (
|
||||
clientWithID = &client.Persistent{
|
||||
Name: cliNameWithID,
|
||||
ClientIDs: []client.ClientID{cliID},
|
||||
}
|
||||
clientWithIP = &client.Persistent{
|
||||
Name: cliNameWithIP,
|
||||
IPs: []netip.Addr{cliIP},
|
||||
}
|
||||
clientWithCIDR = &client.Persistent{
|
||||
Name: cliNameWithCIDR,
|
||||
Subnets: []netip.Prefix{cliCIDR},
|
||||
}
|
||||
clientWithMAC = &client.Persistent{
|
||||
Name: cliNameWithMAC,
|
||||
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientWithID,
|
||||
clientWithIP,
|
||||
clientWithCIDR,
|
||||
clientWithMAC,
|
||||
}
|
||||
s := newStorage(b, clients)
|
||||
|
||||
benchCases := []struct {
|
||||
params *client.FindParams
|
||||
name string
|
||||
wantName string
|
||||
}{{
|
||||
params: &client.FindParams{
|
||||
ClientID: cliID,
|
||||
},
|
||||
name: "client_id",
|
||||
wantName: cliNameWithID,
|
||||
}, {
|
||||
params: &client.FindParams{
|
||||
RemoteIP: cliIP,
|
||||
},
|
||||
name: "ip_address",
|
||||
wantName: cliNameWithIP,
|
||||
}, {
|
||||
params: &client.FindParams{
|
||||
Subnet: cliCIDR,
|
||||
},
|
||||
name: "subnet",
|
||||
wantName: cliNameWithCIDR,
|
||||
}, {
|
||||
params: &client.FindParams{
|
||||
MAC: errors.Must(net.ParseMAC(cliMAC)),
|
||||
},
|
||||
name: "mac_address",
|
||||
wantName: cliNameWithMAC,
|
||||
}}
|
||||
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
var p *client.Persistent
|
||||
var ok bool
|
||||
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
p, ok = s.Find(bc.params)
|
||||
}
|
||||
|
||||
assert.True(b, ok)
|
||||
assert.NotNil(b, p)
|
||||
assert.Equal(b, bc.wantName, p.Name)
|
||||
})
|
||||
}
|
||||
|
||||
// Most recent results:
|
||||
//
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/client
|
||||
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
|
||||
// BenchmarkStorage_Find/client_id-8 7070107 154.4 ns/op 240 B/op 2 allocs/op
|
||||
// BenchmarkStorage_Find/ip_address-8 6831823 168.6 ns/op 248 B/op 2 allocs/op
|
||||
// BenchmarkStorage_Find/subnet-8 7209050 167.5 ns/op 256 B/op 2 allocs/op
|
||||
// BenchmarkStorage_Find/mac_address-8 5776131 199.7 ns/op 256 B/op 3 allocs/op
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user