Pull request 2378: AGDNS-2750-find-client
Merge in DNS/adguard-home from AGDNS-2750-find-client to master Squashed commit of the following: commit 98f1a8ca4622b6f502a5092273b9724203fe0bd8 Merge: 9270222d84ccc2a213Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 23 17:53:20 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 9270222d8e9e03038e9434b54496cbb6164463cd Merge: 6468ceec8c7c62ad3bAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 21 19:40:58 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 6468ceec82d30084771a53ff6720a8c11c68bf2f Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 21 19:40:52 2025 +0300 home: imp docs commit 3fd4735a0d6db4fdf2d46f3da9794a687fdcaa8b Merge: 1311a5869a8fdf1c55Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 18 19:43:36 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 1311a58695de00f20c9704378ee6e964a44d1c59 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 18 19:42:41 2025 +0300 home: imp code commit b1f2c4c883c9476c5135140abac31f8ae6609b4f Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 16 16:47:59 2025 +0300 home: imp code commit d0a5abd66587c1ad602c2ccf6c8a45a3dfe39a5c Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Apr 15 14:58:31 2025 +0300 client: imp naming commit 5accdca325551237f003f1c416891b488fe5290b Merge: 6a00232f74d258972dAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 14 19:40:40 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 6a00232f76a0fe5ce781aa01637b6e04ace7250d Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 14 19:30:32 2025 +0300 home: imp code commit 8633886457c6aab75f5676494b1f49d9811e9ab9 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 11 15:29:25 2025 +0300 all: imp code commit d6f16879e7b054a5ffac59131d2a6eff1da659c0 Merge: 58236fdec6d282ae71Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 21:35:23 2025 +0300 Merge branch 'master' into AGDNS-2750-find-client commit 58236fdec5b64e83a44680ff8a89badc18ec81f1 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 21:23:01 2025 +0300 all: upd ci commit 3c4d946d7970987677d4ac984394e18987a29f9a Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 21:16:03 2025 +0300 all: upd go commit cc1c97734506a9ffbe70fd3c676284e58a21ba46 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 20:58:56 2025 +0300 all: imp code commit 8f061c933152481a4c80eef2af575efd4919d82b Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 9 16:49:11 2025 +0300 all: imp docs commit 8d19355f1c519211a56cec3f23d527922d4f2ee0 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 7 21:35:06 2025 +0300 all: imp code commit f1e853f57e5d54d13bedcdab4f8e21e112f3a356 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 2 14:57:40 2025 +0300 all: imp code commit 6a6ac7f899f29ddc90a583c80562233e646ba1d6 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Apr 1 19:51:56 2025 +0300 client: imp tests commit 52040ee7393d0483c682f2f37d7b70f12f9cf621 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Apr 1 19:28:18 2025 +0300 all: imp code commit 1e09208dbd2d35c3f6b2ade169324e23d1a643a5 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Mar 26 15:33:02 2025 +0300 all: imp code ... and 2 more commits
This commit is contained in:
@@ -36,6 +36,10 @@ See also the [v0.107.61 GitHub milestone][ms-v0.107.61].
|
|||||||
|
|
||||||
**NOTE:** We thank [Xiang Li][mr-xiang-li] for reporting this security issue. It's strongly recommended to leave it enabled, otherwise AdGuard Home will be vulnerable to untrusted clients.
|
**NOTE:** We thank [Xiang Li][mr-xiang-li] for reporting this security issue. It's strongly recommended to leave it enabled, otherwise AdGuard Home will be vulnerable to untrusted clients.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Searching for persistent clients using an exact match for CIDR in the `POST /clients/search HTTP API`.
|
||||||
|
|
||||||
[mr-xiang-li]: https://lixiang521.com/
|
[mr-xiang-li]: https://lixiang521.com/
|
||||||
[ms-v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/milestone/96?closed=1
|
[ms-v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/milestone/96?closed=1
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,34 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ClientID is a unique identifier for a persistent client used in
|
||||||
|
// DNS-over-HTTPS, DNS-over-TLS, and DNS-over-QUIC queries.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Use everywhere.
|
||||||
|
type ClientID string
|
||||||
|
|
||||||
|
// ValidateClientID returns an error if id is not a valid ClientID.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Consider implementing [validate.Interface] for ClientID.
|
||||||
|
func ValidateClientID(id string) (err error) {
|
||||||
|
err = netutil.ValidateHostnameLabel(id)
|
||||||
|
if err != nil {
|
||||||
|
// Replace the domain name label wrapper with our own.
|
||||||
|
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidClientID returns false if id is not a valid ClientID.
|
||||||
|
func isValidClientID(id string) (ok bool) {
|
||||||
|
return netutil.IsValidHostnameLabel(id)
|
||||||
|
}
|
||||||
|
|
||||||
// Source represents the source from which the information about the client has
|
// Source represents the source from which the information about the client has
|
||||||
// been obtained.
|
// been obtained.
|
||||||
type Source uint8
|
type Source uint8
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ type index struct {
|
|||||||
nameToUID map[string]UID
|
nameToUID map[string]UID
|
||||||
|
|
||||||
// clientIDToUID maps ClientID to UID.
|
// clientIDToUID maps ClientID to UID.
|
||||||
clientIDToUID map[string]UID
|
clientIDToUID map[ClientID]UID
|
||||||
|
|
||||||
// ipToUID maps IP address to UID.
|
// ipToUID maps IP address to UID.
|
||||||
ipToUID map[netip.Addr]UID
|
ipToUID map[netip.Addr]UID
|
||||||
@@ -54,7 +54,7 @@ type index struct {
|
|||||||
func newIndex() (ci *index) {
|
func newIndex() (ci *index) {
|
||||||
return &index{
|
return &index{
|
||||||
nameToUID: map[string]UID{},
|
nameToUID: map[string]UID{},
|
||||||
clientIDToUID: map[string]UID{},
|
clientIDToUID: map[ClientID]UID{},
|
||||||
ipToUID: map[netip.Addr]UID{},
|
ipToUID: map[netip.Addr]UID{},
|
||||||
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
|
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
|
||||||
macToUID: map[macKey]UID{},
|
macToUID: map[macKey]UID{},
|
||||||
@@ -207,7 +207,7 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
|
|||||||
// find finds persistent client by string representation of the ClientID, IP
|
// find finds persistent client by string representation of the ClientID, IP
|
||||||
// address, or MAC.
|
// address, or MAC.
|
||||||
func (ci *index) find(id string) (c *Persistent, ok bool) {
|
func (ci *index) find(id string) (c *Persistent, ok bool) {
|
||||||
c, ok = ci.findByClientID(id)
|
c, ok = ci.findByClientID(ClientID(id))
|
||||||
if ok {
|
if ok {
|
||||||
return c, true
|
return c, true
|
||||||
}
|
}
|
||||||
@@ -230,7 +230,7 @@ func (ci *index) find(id string) (c *Persistent, ok bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// findByClientID finds persistent client by ClientID.
|
// findByClientID finds persistent client by ClientID.
|
||||||
func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) {
|
func (ci *index) findByClientID(clientID ClientID) (c *Persistent, ok bool) {
|
||||||
uid, ok := ci.clientIDToUID[clientID]
|
uid, ok := ci.clientIDToUID[clientID]
|
||||||
if ok {
|
if ok {
|
||||||
return ci.uidToClient[uid], true
|
return ci.uidToClient[uid], true
|
||||||
@@ -275,6 +275,26 @@ func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findByCIDR searches for a persistent client with the provided subnet as an
|
||||||
|
// identifier. Note that this function looks for an exact match of subnets,
|
||||||
|
// rather than checking if one subnet contains another.
|
||||||
|
func (ci *index) findByCIDR(subnet netip.Prefix) (c *Persistent, ok bool) {
|
||||||
|
var uid UID
|
||||||
|
for pref, id := range ci.subnetToUID.Range {
|
||||||
|
if subnet == pref {
|
||||||
|
uid, ok = id, true
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
return ci.uidToClient[uid], true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
// findByMAC finds persistent client by MAC.
|
// findByMAC finds persistent client by MAC.
|
||||||
func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||||
k := macToKey(mac)
|
k := macToKey(mac)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -58,12 +59,12 @@ func TestClientIndex_Find(t *testing.T) {
|
|||||||
|
|
||||||
clientWithMAC = &Persistent{
|
clientWithMAC = &Persistent{
|
||||||
Name: "client_with_mac",
|
Name: "client_with_mac",
|
||||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
|
||||||
}
|
}
|
||||||
|
|
||||||
clientWithID = &Persistent{
|
clientWithID = &Persistent{
|
||||||
Name: "client_with_id",
|
Name: "client_with_id",
|
||||||
ClientIDs: []string{cliID},
|
ClientIDs: []ClientID{cliID},
|
||||||
}
|
}
|
||||||
|
|
||||||
clientLinkLocal = &Persistent{
|
clientLinkLocal = &Persistent{
|
||||||
@@ -141,10 +142,10 @@ func TestClientIndex_Clashes(t *testing.T) {
|
|||||||
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||||
}, {
|
}, {
|
||||||
Name: "client_with_mac",
|
Name: "client_with_mac",
|
||||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
|
||||||
}, {
|
}, {
|
||||||
Name: "client_with_id",
|
Name: "client_with_id",
|
||||||
ClientIDs: []string{cliID},
|
ClientIDs: []ClientID{cliID},
|
||||||
}}
|
}}
|
||||||
|
|
||||||
ci := newIDIndex(clients)
|
ci := newIDIndex(clients)
|
||||||
@@ -181,17 +182,6 @@ func TestClientIndex_Clashes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
|
||||||
// error.
|
|
||||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
|
||||||
mac, err := net.ParseMAC(s)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return mac
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMACToKey(t *testing.T) {
|
func TestMACToKey(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
want any
|
want any
|
||||||
@@ -200,44 +190,44 @@ func TestMACToKey(t *testing.T) {
|
|||||||
}{{
|
}{{
|
||||||
name: "column6",
|
name: "column6",
|
||||||
in: "00:00:5e:00:53:01",
|
in: "00:00:5e:00:53:01",
|
||||||
want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
|
want: [6]byte(errors.Must(net.ParseMAC("00:00:5e:00:53:01"))),
|
||||||
}, {
|
}, {
|
||||||
name: "column8",
|
name: "column8",
|
||||||
in: "02:00:5e:10:00:00:00:01",
|
in: "02:00:5e:10:00:00:00:01",
|
||||||
want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
|
want: [8]byte(errors.Must(net.ParseMAC("02:00:5e:10:00:00:00:01"))),
|
||||||
}, {
|
}, {
|
||||||
name: "column20",
|
name: "column20",
|
||||||
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
|
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
|
||||||
want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
|
want: [20]byte(errors.Must(net.ParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01"))),
|
||||||
}, {
|
}, {
|
||||||
name: "hyphen6",
|
name: "hyphen6",
|
||||||
in: "00-00-5e-00-53-01",
|
in: "00-00-5e-00-53-01",
|
||||||
want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
|
want: [6]byte(errors.Must(net.ParseMAC("00-00-5e-00-53-01"))),
|
||||||
}, {
|
}, {
|
||||||
name: "hyphen8",
|
name: "hyphen8",
|
||||||
in: "02-00-5e-10-00-00-00-01",
|
in: "02-00-5e-10-00-00-00-01",
|
||||||
want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
|
want: [8]byte(errors.Must(net.ParseMAC("02-00-5e-10-00-00-00-01"))),
|
||||||
}, {
|
}, {
|
||||||
name: "hyphen20",
|
name: "hyphen20",
|
||||||
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
|
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
|
||||||
want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
|
want: [20]byte(errors.Must(net.ParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01"))),
|
||||||
}, {
|
}, {
|
||||||
name: "dot6",
|
name: "dot6",
|
||||||
in: "0000.5e00.5301",
|
in: "0000.5e00.5301",
|
||||||
want: [6]byte(mustParseMAC("0000.5e00.5301")),
|
want: [6]byte(errors.Must(net.ParseMAC("0000.5e00.5301"))),
|
||||||
}, {
|
}, {
|
||||||
name: "dot8",
|
name: "dot8",
|
||||||
in: "0200.5e10.0000.0001",
|
in: "0200.5e10.0000.0001",
|
||||||
want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
|
want: [8]byte(errors.Must(net.ParseMAC("0200.5e10.0000.0001"))),
|
||||||
}, {
|
}, {
|
||||||
name: "dot20",
|
name: "dot20",
|
||||||
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
|
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
|
||||||
want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
|
want: [20]byte(errors.Must(net.ParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001"))),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
mac := mustParseMAC(tc.in)
|
mac := errors.Must(net.ParseMAC(tc.in))
|
||||||
|
|
||||||
key := macToKey(mac)
|
key := macToKey(mac)
|
||||||
assert.Equal(t, tc.want, key)
|
assert.Equal(t, tc.want, key)
|
||||||
@@ -302,19 +292,19 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) {
|
|||||||
func TestClientIndex_RangeByName(t *testing.T) {
|
func TestClientIndex_RangeByName(t *testing.T) {
|
||||||
sortedClients := []*Persistent{{
|
sortedClients := []*Persistent{{
|
||||||
Name: "clientA",
|
Name: "clientA",
|
||||||
ClientIDs: []string{"A"},
|
ClientIDs: []ClientID{"A"},
|
||||||
}, {
|
}, {
|
||||||
Name: "clientB",
|
Name: "clientB",
|
||||||
ClientIDs: []string{"B"},
|
ClientIDs: []ClientID{"B"},
|
||||||
}, {
|
}, {
|
||||||
Name: "clientC",
|
Name: "clientC",
|
||||||
ClientIDs: []string{"C"},
|
ClientIDs: []ClientID{"C"},
|
||||||
}, {
|
}, {
|
||||||
Name: "clientD",
|
Name: "clientD",
|
||||||
ClientIDs: []string{"D"},
|
ClientIDs: []ClientID{"D"},
|
||||||
}, {
|
}, {
|
||||||
Name: "clientE",
|
Name: "clientE",
|
||||||
ClientIDs: []string{"E"},
|
ClientIDs: []ClientID{"E"},
|
||||||
}}
|
}}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -349,3 +339,115 @@ func TestClientIndex_RangeByName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIndex_FindByName(t *testing.T) {
|
||||||
|
const (
|
||||||
|
clientExistingName = "client_existing"
|
||||||
|
clientAnotherExistingName = "client_another_existing"
|
||||||
|
nonExistingClientName = "client_non_existing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clientExisting = &Persistent{
|
||||||
|
Name: clientExistingName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("192.0.2.1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAnotherExisting = &Persistent{
|
||||||
|
Name: clientAnotherExistingName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("192.0.2.2")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
clients := []*Persistent{
|
||||||
|
clientExisting,
|
||||||
|
clientAnotherExisting,
|
||||||
|
}
|
||||||
|
ci := newIDIndex(clients)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
want *Persistent
|
||||||
|
found assert.BoolAssertionFunc
|
||||||
|
name string
|
||||||
|
clientName string
|
||||||
|
}{{
|
||||||
|
want: clientExisting,
|
||||||
|
found: assert.True,
|
||||||
|
name: "existing",
|
||||||
|
clientName: clientExistingName,
|
||||||
|
}, {
|
||||||
|
want: clientAnotherExisting,
|
||||||
|
found: assert.True,
|
||||||
|
name: "another_existing",
|
||||||
|
clientName: clientAnotherExistingName,
|
||||||
|
}, {
|
||||||
|
want: nil,
|
||||||
|
found: assert.False,
|
||||||
|
name: "non_existing",
|
||||||
|
clientName: nonExistingClientName,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
c, ok := ci.findByName(tc.clientName)
|
||||||
|
assert.Equal(t, tc.want, c)
|
||||||
|
tc.found(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIndex_FindByMAC(t *testing.T) {
|
||||||
|
var (
|
||||||
|
cliMAC = errors.Must(net.ParseMAC("11:11:11:11:11:11"))
|
||||||
|
cliAnotherMAC = errors.Must(net.ParseMAC("22:22:22:22:22:22"))
|
||||||
|
nonExistingClientMAC = errors.Must(net.ParseMAC("33:33:33:33:33:33"))
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clientExisting = &Persistent{
|
||||||
|
Name: "client",
|
||||||
|
MACs: []net.HardwareAddr{cliMAC},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAnotherExisting = &Persistent{
|
||||||
|
Name: "another_client",
|
||||||
|
MACs: []net.HardwareAddr{cliAnotherMAC},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
clients := []*Persistent{
|
||||||
|
clientExisting,
|
||||||
|
clientAnotherExisting,
|
||||||
|
}
|
||||||
|
ci := newIDIndex(clients)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
want *Persistent
|
||||||
|
found assert.BoolAssertionFunc
|
||||||
|
name string
|
||||||
|
clientMAC net.HardwareAddr
|
||||||
|
}{{
|
||||||
|
want: clientExisting,
|
||||||
|
found: assert.True,
|
||||||
|
name: "existing",
|
||||||
|
clientMAC: cliMAC,
|
||||||
|
}, {
|
||||||
|
want: clientAnotherExisting,
|
||||||
|
found: assert.True,
|
||||||
|
name: "another_existing",
|
||||||
|
clientMAC: cliAnotherMAC,
|
||||||
|
}, {
|
||||||
|
want: nil,
|
||||||
|
found: assert.False,
|
||||||
|
name: "non_existing",
|
||||||
|
clientMAC: nonExistingClientMAC,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
c, ok := ci.findByMAC(tc.clientMAC)
|
||||||
|
assert.Equal(t, tc.want, c)
|
||||||
|
tc.found(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,7 +89,7 @@ type Persistent struct {
|
|||||||
|
|
||||||
// ClientIDs identifying the client. The client must have at least one ID
|
// ClientIDs identifying the client. The client must have at least one ID
|
||||||
// (IP, subnet, MAC, or ClientID).
|
// (IP, subnet, MAC, or ClientID).
|
||||||
ClientIDs []string
|
ClientIDs []ClientID
|
||||||
|
|
||||||
// UID is the unique identifier of the persistent client.
|
// UID is the unique identifier of the persistent client.
|
||||||
UID UID
|
UID UID
|
||||||
@@ -134,7 +133,7 @@ func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []str
|
|||||||
switch {
|
switch {
|
||||||
case c.Name == "":
|
case c.Name == "":
|
||||||
return errors.Error("empty name")
|
return errors.Error("empty name")
|
||||||
case c.IDsLen() == 0:
|
case c.idendifiersLen() == 0:
|
||||||
return errors.Error("id required")
|
return errors.Error("id required")
|
||||||
case c.UID == UID{}:
|
case c.UID == UID{}:
|
||||||
return errors.Error("uid required")
|
return errors.Error("uid required")
|
||||||
@@ -237,28 +236,15 @@ func (c *Persistent) setID(id string) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.ClientIDs = append(c.ClientIDs, strings.ToLower(id))
|
c.ClientIDs = append(c.ClientIDs, ClientID(strings.ToLower(id)))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateClientID returns an error if id is not a valid ClientID.
|
// Identifiers returns a list of client identifiers containing at least one
|
||||||
//
|
// element.
|
||||||
// TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to
|
func (c *Persistent) Identifiers() (ids []string) {
|
||||||
// avoid the import cycle. Remove it.
|
ids = make([]string, 0, c.idendifiersLen())
|
||||||
func ValidateClientID(id string) (err error) {
|
|
||||||
err = netutil.ValidateHostnameLabel(id)
|
|
||||||
if err != nil {
|
|
||||||
// Replace the domain name label wrapper with our own.
|
|
||||||
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IDs returns a list of ClientIDs containing at least one element.
|
|
||||||
func (c *Persistent) IDs() (ids []string) {
|
|
||||||
ids = make([]string, 0, c.IDsLen())
|
|
||||||
|
|
||||||
for _, ip := range c.IPs {
|
for _, ip := range c.IPs {
|
||||||
ids = append(ids, ip.String())
|
ids = append(ids, ip.String())
|
||||||
@@ -272,11 +258,15 @@ func (c *Persistent) IDs() (ids []string) {
|
|||||||
ids = append(ids, mac.String())
|
ids = append(ids, mac.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(ids, c.ClientIDs...)
|
for _, cid := range c.ClientIDs {
|
||||||
|
ids = append(ids, string(cid))
|
||||||
|
}
|
||||||
|
|
||||||
|
return ids
|
||||||
}
|
}
|
||||||
|
|
||||||
// IDsLen returns a length of ClientIDs.
|
// identifiersLen returns the number of client identifiers.
|
||||||
func (c *Persistent) IDsLen() (n int) {
|
func (c *Persistent) idendifiersLen() (n int) {
|
||||||
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -18,6 +19,7 @@ import (
|
|||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/hostsfile"
|
"github.com/AdguardTeam/golibs/hostsfile"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -433,48 +435,186 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) {
|
|||||||
ctx,
|
ctx,
|
||||||
"client added",
|
"client added",
|
||||||
"name", p.Name,
|
"name", p.Name,
|
||||||
"ids", p.IDs(),
|
"ids", p.Identifiers(),
|
||||||
"clients_count", s.index.size(),
|
"clients_count", s.index.size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindByName finds persistent client by name. And returns its shallow copy.
|
// FindParams represents the parameters for searching a client. At least one
|
||||||
func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
|
// field must be non-empty.
|
||||||
s.mu.Lock()
|
type FindParams struct {
|
||||||
defer s.mu.Unlock()
|
// ClientID is a unique identifier for the client used in DoH, DoT, and DoQ
|
||||||
|
// DNS queries.
|
||||||
|
ClientID ClientID
|
||||||
|
|
||||||
p, ok = s.index.findByName(name)
|
// RemoteIP is the IP address used as a client search parameter.
|
||||||
if ok {
|
RemoteIP netip.Addr
|
||||||
return p.ShallowClone(), ok
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
// Subnet is the CIDR used as a client search parameter.
|
||||||
|
Subnet netip.Prefix
|
||||||
|
|
||||||
|
// MAC is the physical hardware address used as a client search parameter.
|
||||||
|
MAC net.HardwareAddr
|
||||||
|
|
||||||
|
// UID is the unique ID of persistent client used as a search parameter.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Use this.
|
||||||
|
UID UID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find finds persistent client by string representation of the ClientID, IP
|
// ErrBadIdentifier is returned by [FindParams.Set] when it cannot parse the
|
||||||
// address, or MAC. And returns its shallow copy.
|
// provided client identifier.
|
||||||
|
const ErrBadIdentifier errors.Error = "bad client identifier"
|
||||||
|
|
||||||
|
// Set clears the stored search parameters and parses the string representation
|
||||||
|
// of the search parameter into typed parameter, storing it. In some cases, it
|
||||||
|
// may result in storing both an IP address and a MAC address because they might
|
||||||
|
// have identical string representations. It returns [ErrBadIdentifier] if id
|
||||||
|
// cannot be parsed.
|
||||||
//
|
//
|
||||||
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
|
// TODO(s.chzhen): Add support for UID.
|
||||||
// the parsed IP address, if any.
|
func (p *FindParams) Set(id string) (err error) {
|
||||||
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
*p = FindParams{}
|
||||||
|
|
||||||
|
isClientID := true
|
||||||
|
|
||||||
|
if netutil.IsValidIPString(id) {
|
||||||
|
// It is safe to use [netip.MustParseAddr] because it has already been
|
||||||
|
// validated that id contains the string representation of the IP
|
||||||
|
// address.
|
||||||
|
p.RemoteIP = netip.MustParseAddr(id)
|
||||||
|
|
||||||
|
// Even if id can be parsed as an IP address, it may be a MAC address.
|
||||||
|
// So do not return prematurely, continue parsing.
|
||||||
|
isClientID = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if canBeValidIPPrefixString(id) {
|
||||||
|
p.Subnet, err = netip.ParsePrefix(id)
|
||||||
|
if err == nil {
|
||||||
|
isClientID = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if canBeMACString(id) {
|
||||||
|
p.MAC, err = net.ParseMAC(id)
|
||||||
|
if err == nil {
|
||||||
|
isClientID = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isClientID {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidClientID(id) {
|
||||||
|
return ErrBadIdentifier
|
||||||
|
}
|
||||||
|
|
||||||
|
p.ClientID = ClientID(id)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// canBeValidIPPrefixString is a best-effort check to determine if s is a valid
|
||||||
|
// CIDR before using [netip.ParsePrefix], aimed at reducing allocations.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Replace this implementation with the more robust version
|
||||||
|
// from golibs.
|
||||||
|
func canBeValidIPPrefixString(s string) (ok bool) {
|
||||||
|
ipStr, bitStr, ok := strings.Cut(s, "/")
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if bitStr == "" || len(bitStr) > 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bits := 0
|
||||||
|
for _, c := range bitStr {
|
||||||
|
if c < '0' || c > '9' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bits = bits*10 + int(c-'0')
|
||||||
|
}
|
||||||
|
|
||||||
|
if bits > 128 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return netutil.IsValidIPString(ipStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// canBeMACString is a best-effort check to determine if s is a valid MAC
|
||||||
|
// address before using [net.ParseMAC], aimed at reducing allocations.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Replace this implementation with the more robust version
|
||||||
|
// from golibs.
|
||||||
|
func canBeMACString(s string) (ok bool) {
|
||||||
|
switch len(s) {
|
||||||
|
case
|
||||||
|
len("0000.0000.0000"),
|
||||||
|
len("00:00:00:00:00:00"),
|
||||||
|
len("0000.0000.0000.0000"),
|
||||||
|
len("00:00:00:00:00:00:00:00"),
|
||||||
|
len("0000.0000.0000.0000.0000.0000.0000.0000.0000.0000"),
|
||||||
|
len("00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"):
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find represents the parameters for searching a client. params must not be
|
||||||
|
// nil and must have at least one non-empty field.
|
||||||
|
func (s *Storage) Find(params *FindParams) (p *Persistent, ok bool) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
p, ok = s.index.find(id)
|
isClientID := params.ClientID != ""
|
||||||
|
isRemoteIP := params.RemoteIP != (netip.Addr{})
|
||||||
|
isSubnet := params.Subnet != (netip.Prefix{})
|
||||||
|
isMAC := params.MAC != nil
|
||||||
|
|
||||||
|
for {
|
||||||
|
switch {
|
||||||
|
case isClientID:
|
||||||
|
isClientID = false
|
||||||
|
p, ok = s.index.findByClientID(params.ClientID)
|
||||||
|
case isRemoteIP:
|
||||||
|
isRemoteIP = false
|
||||||
|
p, ok = s.findByIP(params.RemoteIP)
|
||||||
|
case isSubnet:
|
||||||
|
isSubnet = false
|
||||||
|
p, ok = s.index.findByCIDR(params.Subnet)
|
||||||
|
case isMAC:
|
||||||
|
isMAC = false
|
||||||
|
p, ok = s.index.findByMAC(params.MAC)
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
return p.ShallowClone(), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findByIP finds persistent client by IP address. s.mu is expected to be
|
||||||
|
// locked.
|
||||||
|
func (s *Storage) findByIP(addr netip.Addr) (p *Persistent, ok bool) {
|
||||||
|
p, ok = s.index.findByIP(addr)
|
||||||
if ok {
|
if ok {
|
||||||
return p.ShallowClone(), ok
|
return p, true
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, err := netip.ParseAddr(id)
|
foundMAC := s.dhcp.MACByIP(addr)
|
||||||
if err != nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
foundMAC := s.dhcp.MACByIP(ip)
|
|
||||||
if foundMAC != nil {
|
if foundMAC != nil {
|
||||||
return s.FindByMAC(foundMAC)
|
return s.index.findByMAC(foundMAC)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
@@ -487,6 +627,8 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
|||||||
//
|
//
|
||||||
// Note that multiple clients can have the same IP address with different zones.
|
// Note that multiple clients can have the same IP address with different zones.
|
||||||
// Therefore, the result of this method is indeterminate.
|
// Therefore, the result of this method is indeterminate.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Consider accepting [FindParams].
|
||||||
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -498,7 +640,7 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
|||||||
|
|
||||||
foundMAC := s.dhcp.MACByIP(ip)
|
foundMAC := s.dhcp.MACByIP(ip)
|
||||||
if foundMAC != nil {
|
if foundMAC != nil {
|
||||||
return s.FindByMAC(foundMAC)
|
return s.index.findByMAC(foundMAC)
|
||||||
}
|
}
|
||||||
|
|
||||||
p = s.index.findByIPWithoutZone(ip)
|
p = s.index.findByIPWithoutZone(ip)
|
||||||
@@ -509,17 +651,6 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu
|
|
||||||
// is expected to be locked.
|
|
||||||
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
|
|
||||||
p, ok = s.index.findByMAC(mac)
|
|
||||||
if ok {
|
|
||||||
return p.ShallowClone(), ok
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveByName removes persistent client information. ok is false if no such
|
// RemoveByName removes persistent client information. ok is false if no such
|
||||||
// client exists by that name.
|
// client exists by that name.
|
||||||
func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
|
func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
|
||||||
@@ -648,7 +779,7 @@ func (s *Storage) CustomUpstreamConfig(
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
c, ok := s.index.findByClientID(id)
|
c, ok := s.index.findByClientID(ClientID(id))
|
||||||
if !ok {
|
if !ok {
|
||||||
c, ok = s.index.findByIP(addr)
|
c, ok = s.index.findByIP(addr)
|
||||||
}
|
}
|
||||||
@@ -682,7 +813,7 @@ func (s *Storage) ClearUpstreamCache() {
|
|||||||
// ClientID or client IP address, and applies it to the filtering settings.
|
// ClientID or client IP address, and applies it to the filtering settings.
|
||||||
// setts must not be nil.
|
// setts must not be nil.
|
||||||
func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) {
|
func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) {
|
||||||
c, ok := s.index.findByClientID(id)
|
c, ok := s.index.findByClientID(ClientID(id))
|
||||||
if !ok {
|
if !ok {
|
||||||
c, ok = s.index.findByIP(addr)
|
c, ok = s.index.findByIP(addr)
|
||||||
}
|
}
|
||||||
@@ -690,7 +821,7 @@ func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filter
|
|||||||
if !ok {
|
if !ok {
|
||||||
foundMAC := s.dhcp.MACByIP(addr)
|
foundMAC := s.dhcp.MACByIP(addr)
|
||||||
if foundMAC != nil {
|
if foundMAC != nil {
|
||||||
c, ok = s.FindByMAC(foundMAC)
|
c, ok = s.index.findByMAC(foundMAC)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/hostsfile"
|
"github.com/AdguardTeam/golibs/hostsfile"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
@@ -350,15 +351,15 @@ func TestClientsDHCP(t *testing.T) {
|
|||||||
cliName1 = "one.dhcp"
|
cliName1 = "one.dhcp"
|
||||||
|
|
||||||
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||||
cliMAC2 = mustParseMAC("22:22:22:22:22:22")
|
cliMAC2 = errors.Must(net.ParseMAC("22:22:22:22:22:22"))
|
||||||
cliName2 = "two.dhcp"
|
cliName2 = "two.dhcp"
|
||||||
|
|
||||||
cliIP3 = netip.MustParseAddr("3.3.3.3")
|
cliIP3 = netip.MustParseAddr("3.3.3.3")
|
||||||
cliMAC3 = mustParseMAC("33:33:33:33:33:33")
|
cliMAC3 = errors.Must(net.ParseMAC("33:33:33:33:33:33"))
|
||||||
cliName3 = "three.dhcp"
|
cliName3 = "three.dhcp"
|
||||||
|
|
||||||
prsCliIP = netip.MustParseAddr("4.3.2.1")
|
prsCliIP = netip.MustParseAddr("4.3.2.1")
|
||||||
prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
|
prsCliMAC = errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA"))
|
||||||
prsCliName = "persistent.dhcp"
|
prsCliName = "persistent.dhcp"
|
||||||
|
|
||||||
otherARPCliName = "other.arp"
|
otherARPCliName = "other.arp"
|
||||||
@@ -519,7 +520,11 @@ func TestClientsDHCP(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
prsCli, ok := storage.Find(prsCliIP.String())
|
params := &client.FindParams{}
|
||||||
|
err = params.Set(prsCliIP.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
prsCli, ok := storage.Find(params)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, prsCliName, prsCli.Name)
|
assert.Equal(t, prsCliName, prsCli.Name)
|
||||||
@@ -663,17 +668,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
|
||||||
// error.
|
|
||||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
|
||||||
mac, err := net.ParseMAC(s)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return mac
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStorage_Add(t *testing.T) {
|
func TestStorage_Add(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
existingName = "existing_name"
|
existingName = "existing_name"
|
||||||
@@ -693,7 +687,7 @@ func TestStorage_Add(t *testing.T) {
|
|||||||
Name: existingName,
|
Name: existingName,
|
||||||
IPs: []netip.Addr{existingIP},
|
IPs: []netip.Addr{existingIP},
|
||||||
Subnets: []netip.Prefix{existingSubnet},
|
Subnets: []netip.Prefix{existingSubnet},
|
||||||
ClientIDs: []string{existingClientID},
|
ClientIDs: []client.ClientID{existingClientID},
|
||||||
UID: existingClientUID,
|
UID: existingClientUID,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -761,7 +755,7 @@ func TestStorage_Add(t *testing.T) {
|
|||||||
name: "duplicate_client_id",
|
name: "duplicate_client_id",
|
||||||
cli: &client.Persistent{
|
cli: &client.Persistent{
|
||||||
Name: "duplicate_client_id",
|
Name: "duplicate_client_id",
|
||||||
ClientIDs: []string{existingClientID},
|
ClientIDs: []client.ClientID{existingClientID},
|
||||||
UID: client.MustNewUID(),
|
UID: client.MustNewUID(),
|
||||||
},
|
},
|
||||||
wantErrMsg: `adding client: another client "existing_name" ` +
|
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||||
@@ -898,12 +892,12 @@ func TestStorage_Find(t *testing.T) {
|
|||||||
|
|
||||||
clientWithMAC = &client.Persistent{
|
clientWithMAC = &client.Persistent{
|
||||||
Name: "client_with_mac",
|
Name: "client_with_mac",
|
||||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
|
||||||
}
|
}
|
||||||
|
|
||||||
clientWithID = &client.Persistent{
|
clientWithID = &client.Persistent{
|
||||||
Name: "client_with_id",
|
Name: "client_with_id",
|
||||||
ClientIDs: []string{cliID},
|
ClientIDs: []client.ClientID{cliID},
|
||||||
}
|
}
|
||||||
|
|
||||||
clientLinkLocal = &client.Persistent{
|
clientLinkLocal = &client.Persistent{
|
||||||
@@ -950,7 +944,11 @@ func TestStorage_Find(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
for _, id := range tc.ids {
|
for _, id := range tc.ids {
|
||||||
c, ok := s.Find(id)
|
params := &client.FindParams{}
|
||||||
|
err := params.Set(id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c, ok := s.Find(params)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, tc.want, c)
|
assert.Equal(t, tc.want, c)
|
||||||
@@ -959,7 +957,11 @@ func TestStorage_Find(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Run("not_found", func(t *testing.T) {
|
t.Run("not_found", func(t *testing.T) {
|
||||||
_, ok := s.Find(cliIPNone)
|
params := &client.FindParams{}
|
||||||
|
err := params.Set(cliIPNone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, ok := s.Find(params)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1025,127 +1027,6 @@ func TestStorage_FindLoose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStorage_FindByName(t *testing.T) {
|
|
||||||
const (
|
|
||||||
cliIP1 = "1.1.1.1"
|
|
||||||
cliIP2 = "2.2.2.2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
clientExistingName = "client_existing"
|
|
||||||
clientAnotherExistingName = "client_another_existing"
|
|
||||||
nonExistingClientName = "client_non_existing"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
clientExisting = &client.Persistent{
|
|
||||||
Name: clientExistingName,
|
|
||||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
|
|
||||||
}
|
|
||||||
|
|
||||||
clientAnotherExisting = &client.Persistent{
|
|
||||||
Name: clientAnotherExistingName,
|
|
||||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
clients := []*client.Persistent{
|
|
||||||
clientExisting,
|
|
||||||
clientAnotherExisting,
|
|
||||||
}
|
|
||||||
s := newStorage(t, clients)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
want *client.Persistent
|
|
||||||
name string
|
|
||||||
clientName string
|
|
||||||
}{{
|
|
||||||
name: "existing",
|
|
||||||
clientName: clientExistingName,
|
|
||||||
want: clientExisting,
|
|
||||||
}, {
|
|
||||||
name: "another_existing",
|
|
||||||
clientName: clientAnotherExistingName,
|
|
||||||
want: clientAnotherExisting,
|
|
||||||
}, {
|
|
||||||
name: "non_existing",
|
|
||||||
clientName: nonExistingClientName,
|
|
||||||
want: nil,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
c, ok := s.FindByName(tc.clientName)
|
|
||||||
if tc.want == nil {
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, tc.want, c)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStorage_FindByMAC(t *testing.T) {
|
|
||||||
var (
|
|
||||||
cliMAC = mustParseMAC("11:11:11:11:11:11")
|
|
||||||
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
|
|
||||||
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
clientExisting = &client.Persistent{
|
|
||||||
Name: "client",
|
|
||||||
MACs: []net.HardwareAddr{cliMAC},
|
|
||||||
}
|
|
||||||
|
|
||||||
clientAnotherExisting = &client.Persistent{
|
|
||||||
Name: "another_client",
|
|
||||||
MACs: []net.HardwareAddr{cliAnotherMAC},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
clients := []*client.Persistent{
|
|
||||||
clientExisting,
|
|
||||||
clientAnotherExisting,
|
|
||||||
}
|
|
||||||
s := newStorage(t, clients)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
want *client.Persistent
|
|
||||||
name string
|
|
||||||
clientMAC net.HardwareAddr
|
|
||||||
}{{
|
|
||||||
name: "existing",
|
|
||||||
clientMAC: cliMAC,
|
|
||||||
want: clientExisting,
|
|
||||||
}, {
|
|
||||||
name: "another_existing",
|
|
||||||
clientMAC: cliAnotherMAC,
|
|
||||||
want: clientAnotherExisting,
|
|
||||||
}, {
|
|
||||||
name: "non_existing",
|
|
||||||
clientMAC: nonExistingClientMAC,
|
|
||||||
want: nil,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
c, ok := s.FindByMAC(tc.clientMAC)
|
|
||||||
if tc.want == nil {
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, tc.want, c)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStorage_Update(t *testing.T) {
|
func TestStorage_Update(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
clientName = "client_name"
|
clientName = "client_name"
|
||||||
@@ -1162,7 +1043,7 @@ func TestStorage_Update(t *testing.T) {
|
|||||||
Name: obstructingName,
|
Name: obstructingName,
|
||||||
IPs: []netip.Addr{obstructingIP},
|
IPs: []netip.Addr{obstructingIP},
|
||||||
Subnets: []netip.Prefix{obstructingSubnet},
|
Subnets: []netip.Prefix{obstructingSubnet},
|
||||||
ClientIDs: []string{obstructingClientID},
|
ClientIDs: []client.ClientID{obstructingClientID},
|
||||||
}
|
}
|
||||||
|
|
||||||
clientToUpdate := &client.Persistent{
|
clientToUpdate := &client.Persistent{
|
||||||
@@ -1211,7 +1092,7 @@ func TestStorage_Update(t *testing.T) {
|
|||||||
name: "duplicate_client_id",
|
name: "duplicate_client_id",
|
||||||
cli: &client.Persistent{
|
cli: &client.Persistent{
|
||||||
Name: "duplicate_client_id",
|
Name: "duplicate_client_id",
|
||||||
ClientIDs: []string{obstructingClientID},
|
ClientIDs: []client.ClientID{obstructingClientID},
|
||||||
UID: client.MustNewUID(),
|
UID: client.MustNewUID(),
|
||||||
},
|
},
|
||||||
wantErrMsg: `updating client: another client "obstructing_name" ` +
|
wantErrMsg: `updating client: another client "obstructing_name" ` +
|
||||||
@@ -1238,19 +1119,19 @@ func TestStorage_Update(t *testing.T) {
|
|||||||
func TestStorage_RangeByName(t *testing.T) {
|
func TestStorage_RangeByName(t *testing.T) {
|
||||||
sortedClients := []*client.Persistent{{
|
sortedClients := []*client.Persistent{{
|
||||||
Name: "clientA",
|
Name: "clientA",
|
||||||
ClientIDs: []string{"A"},
|
ClientIDs: []client.ClientID{"A"},
|
||||||
}, {
|
}, {
|
||||||
Name: "clientB",
|
Name: "clientB",
|
||||||
ClientIDs: []string{"B"},
|
ClientIDs: []client.ClientID{"B"},
|
||||||
}, {
|
}, {
|
||||||
Name: "clientC",
|
Name: "clientC",
|
||||||
ClientIDs: []string{"C"},
|
ClientIDs: []client.ClientID{"C"},
|
||||||
}, {
|
}, {
|
||||||
Name: "clientD",
|
Name: "clientD",
|
||||||
ClientIDs: []string{"D"},
|
ClientIDs: []client.ClientID{"D"},
|
||||||
}, {
|
}, {
|
||||||
Name: "clientE",
|
Name: "clientE",
|
||||||
ClientIDs: []string{"E"},
|
ClientIDs: []client.ClientID{"E"},
|
||||||
}}
|
}}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -1306,7 +1187,7 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
|
|||||||
existingClient := &client.Persistent{
|
existingClient := &client.Persistent{
|
||||||
Name: existingName,
|
Name: existingName,
|
||||||
IPs: []netip.Addr{existingIP},
|
IPs: []netip.Addr{existingIP},
|
||||||
ClientIDs: []string{existingClientID},
|
ClientIDs: []client.ClientID{existingClientID},
|
||||||
UID: existingClientUID,
|
UID: existingClientUID,
|
||||||
Upstreams: []string{"192.0.2.0"},
|
Upstreams: []string{"192.0.2.0"},
|
||||||
}
|
}
|
||||||
@@ -1381,3 +1262,182 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
|
|||||||
assert.NotEqual(t, conf, updConf)
|
assert.NotEqual(t, conf, updConf)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkFindParams_Set(b *testing.B) {
|
||||||
|
const (
|
||||||
|
testIPStr = "192.0.2.1"
|
||||||
|
testCIDRStr = "192.0.2.0/24"
|
||||||
|
testMACStr = "02:00:00:00:00:00"
|
||||||
|
testClientID = "clientid"
|
||||||
|
)
|
||||||
|
|
||||||
|
benchCases := []struct {
|
||||||
|
wantErr error
|
||||||
|
params *client.FindParams
|
||||||
|
name string
|
||||||
|
id string
|
||||||
|
}{{
|
||||||
|
wantErr: nil,
|
||||||
|
params: &client.FindParams{
|
||||||
|
ClientID: testClientID,
|
||||||
|
},
|
||||||
|
name: "client_id",
|
||||||
|
id: testClientID,
|
||||||
|
}, {
|
||||||
|
wantErr: nil,
|
||||||
|
params: &client.FindParams{
|
||||||
|
RemoteIP: netip.MustParseAddr(testIPStr),
|
||||||
|
},
|
||||||
|
name: "ip_address",
|
||||||
|
id: testIPStr,
|
||||||
|
}, {
|
||||||
|
wantErr: nil,
|
||||||
|
params: &client.FindParams{
|
||||||
|
Subnet: netip.MustParsePrefix(testCIDRStr),
|
||||||
|
},
|
||||||
|
name: "subnet",
|
||||||
|
id: testCIDRStr,
|
||||||
|
}, {
|
||||||
|
wantErr: nil,
|
||||||
|
params: &client.FindParams{
|
||||||
|
MAC: errors.Must(net.ParseMAC(testMACStr)),
|
||||||
|
},
|
||||||
|
name: "mac_address",
|
||||||
|
id: testMACStr,
|
||||||
|
}, {
|
||||||
|
wantErr: client.ErrBadIdentifier,
|
||||||
|
params: &client.FindParams{},
|
||||||
|
name: "bad_id",
|
||||||
|
id: "!@#$%^&*()_+",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
params := &client.FindParams{}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
for b.Loop() {
|
||||||
|
err = params.Set(bc.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.ErrorIs(b, err, bc.wantErr)
|
||||||
|
assert.Equal(b, bc.params, params)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most recent results:
|
||||||
|
//
|
||||||
|
// goos: linux
|
||||||
|
// goarch: amd64
|
||||||
|
// pkg: github.com/AdguardTeam/AdGuardHome/internal/client
|
||||||
|
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
|
||||||
|
// BenchmarkFindParams_Set/client_id-8 49463488 24.27 ns/op 0 B/op 0 allocs/op
|
||||||
|
// BenchmarkFindParams_Set/ip_address-8 18740977 62.22 ns/op 0 B/op 0 allocs/op
|
||||||
|
// BenchmarkFindParams_Set/subnet-8 10848192 110.0 ns/op 0 B/op 0 allocs/op
|
||||||
|
// BenchmarkFindParams_Set/mac_address-8 8148494 133.2 ns/op 8 B/op 1 allocs/op
|
||||||
|
// BenchmarkFindParams_Set/bad_id-8 73894278 16.29 ns/op 0 B/op 0 allocs/op
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStorage_Find(b *testing.B) {
|
||||||
|
const (
|
||||||
|
cliID = "cid"
|
||||||
|
cliMAC = "02:00:00:00:00:00"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cliNameWithID = "client_with_id"
|
||||||
|
cliNameWithIP = "client_with_ip"
|
||||||
|
cliNameWithCIDR = "client_with_cidr"
|
||||||
|
cliNameWithMAC = "client_with_mac"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cliIP = netip.MustParseAddr("192.0.2.1")
|
||||||
|
cliCIDR = netip.MustParsePrefix("192.0.2.0/24")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clientWithID = &client.Persistent{
|
||||||
|
Name: cliNameWithID,
|
||||||
|
ClientIDs: []client.ClientID{cliID},
|
||||||
|
}
|
||||||
|
clientWithIP = &client.Persistent{
|
||||||
|
Name: cliNameWithIP,
|
||||||
|
IPs: []netip.Addr{cliIP},
|
||||||
|
}
|
||||||
|
clientWithCIDR = &client.Persistent{
|
||||||
|
Name: cliNameWithCIDR,
|
||||||
|
Subnets: []netip.Prefix{cliCIDR},
|
||||||
|
}
|
||||||
|
clientWithMAC = &client.Persistent{
|
||||||
|
Name: cliNameWithMAC,
|
||||||
|
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
clients := []*client.Persistent{
|
||||||
|
clientWithID,
|
||||||
|
clientWithIP,
|
||||||
|
clientWithCIDR,
|
||||||
|
clientWithMAC,
|
||||||
|
}
|
||||||
|
s := newStorage(b, clients)
|
||||||
|
|
||||||
|
benchCases := []struct {
|
||||||
|
params *client.FindParams
|
||||||
|
name string
|
||||||
|
wantName string
|
||||||
|
}{{
|
||||||
|
params: &client.FindParams{
|
||||||
|
ClientID: cliID,
|
||||||
|
},
|
||||||
|
name: "client_id",
|
||||||
|
wantName: cliNameWithID,
|
||||||
|
}, {
|
||||||
|
params: &client.FindParams{
|
||||||
|
RemoteIP: cliIP,
|
||||||
|
},
|
||||||
|
name: "ip_address",
|
||||||
|
wantName: cliNameWithIP,
|
||||||
|
}, {
|
||||||
|
params: &client.FindParams{
|
||||||
|
Subnet: cliCIDR,
|
||||||
|
},
|
||||||
|
name: "subnet",
|
||||||
|
wantName: cliNameWithCIDR,
|
||||||
|
}, {
|
||||||
|
params: &client.FindParams{
|
||||||
|
MAC: errors.Must(net.ParseMAC(cliMAC)),
|
||||||
|
},
|
||||||
|
name: "mac_address",
|
||||||
|
wantName: cliNameWithMAC,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
var p *client.Persistent
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
for b.Loop() {
|
||||||
|
p, ok = s.Find(bc.params)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(b, ok)
|
||||||
|
assert.NotNil(b, p)
|
||||||
|
assert.Equal(b, bc.wantName, p.Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most recent results:
|
||||||
|
//
|
||||||
|
// goos: linux
|
||||||
|
// goarch: amd64
|
||||||
|
// pkg: github.com/AdguardTeam/AdGuardHome/internal/client
|
||||||
|
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
|
||||||
|
// BenchmarkStorage_Find/client_id-8 7070107 154.4 ns/op 240 B/op 2 allocs/op
|
||||||
|
// BenchmarkStorage_Find/ip_address-8 6831823 168.6 ns/op 248 B/op 2 allocs/op
|
||||||
|
// BenchmarkStorage_Find/subnet-8 7209050 167.5 ns/op 256 B/op 2 allocs/op
|
||||||
|
// BenchmarkStorage_Find/mac_address-8 5776131 199.7 ns/op 256 B/op 3 allocs/op
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
package dhcpsvc_test
|
package dhcpsvc_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// testLocalTLD is a common local TLD for tests.
|
// testLocalTLD is a common local TLD for tests.
|
||||||
@@ -56,11 +54,3 @@ var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// mustParseMAC parses a hardware address from s and requires no errors.
|
|
||||||
func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) {
|
|
||||||
mac, err := net.ParseMAC(s)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return mac
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package dhcpsvc_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -176,9 +178,9 @@ func TestDHCPServer_AddLease(t *testing.T) {
|
|||||||
newIP = netip.MustParseAddr("192.168.0.3")
|
newIP = netip.MustParseAddr("192.168.0.3")
|
||||||
newIPv6 = netip.MustParseAddr("2001:db8::2")
|
newIPv6 = netip.MustParseAddr("2001:db8::2")
|
||||||
|
|
||||||
existMAC = mustParseMAC(t, "01:02:03:04:05:06")
|
existMAC = errors.Must(net.ParseMAC("01:02:03:04:05:06"))
|
||||||
newMAC = mustParseMAC(t, "06:05:04:03:02:01")
|
newMAC = errors.Must(net.ParseMAC("06:05:04:03:02:01"))
|
||||||
ipv6MAC = mustParseMAC(t, "02:03:04:05:06:07")
|
ipv6MAC = errors.Must(net.ParseMAC("02:03:04:05:06:07"))
|
||||||
)
|
)
|
||||||
|
|
||||||
require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{
|
require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{
|
||||||
@@ -291,9 +293,9 @@ func TestDHCPServer_index(t *testing.T) {
|
|||||||
ip3 = netip.MustParseAddr("172.16.0.3")
|
ip3 = netip.MustParseAddr("172.16.0.3")
|
||||||
ip4 = netip.MustParseAddr("172.16.0.4")
|
ip4 = netip.MustParseAddr("172.16.0.4")
|
||||||
|
|
||||||
mac1 = mustParseMAC(t, "01:02:03:04:05:06")
|
mac1 = errors.Must(net.ParseMAC("01:02:03:04:05:06"))
|
||||||
mac2 = mustParseMAC(t, "06:05:04:03:02:01")
|
mac2 = errors.Must(net.ParseMAC("06:05:04:03:02:01"))
|
||||||
mac3 = mustParseMAC(t, "02:03:04:05:06:07")
|
mac3 = errors.Must(net.ParseMAC("02:03:04:05:06:07"))
|
||||||
)
|
)
|
||||||
|
|
||||||
t.Run("ip_idx", func(t *testing.T) {
|
t.Run("ip_idx", func(t *testing.T) {
|
||||||
@@ -349,9 +351,9 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
|
|||||||
ip3 = netip.MustParseAddr("192.168.0.4")
|
ip3 = netip.MustParseAddr("192.168.0.4")
|
||||||
ip4 = netip.MustParseAddr("2001:db8::3")
|
ip4 = netip.MustParseAddr("2001:db8::3")
|
||||||
|
|
||||||
mac1 = mustParseMAC(t, "01:02:03:04:05:06")
|
mac1 = errors.Must(net.ParseMAC("01:02:03:04:05:06"))
|
||||||
mac2 = mustParseMAC(t, "06:05:04:03:02:01")
|
mac2 = errors.Must(net.ParseMAC("06:05:04:03:02:01"))
|
||||||
mac3 = mustParseMAC(t, "06:05:04:03:02:02")
|
mac3 = errors.Must(net.ParseMAC("06:05:04:03:02:02"))
|
||||||
)
|
)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -452,9 +454,9 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
|
|||||||
newIP = netip.MustParseAddr("192.168.0.3")
|
newIP = netip.MustParseAddr("192.168.0.3")
|
||||||
newIPv6 = netip.MustParseAddr("2001:db8::2")
|
newIPv6 = netip.MustParseAddr("2001:db8::2")
|
||||||
|
|
||||||
existMAC = mustParseMAC(t, "01:02:03:04:05:06")
|
existMAC = errors.Must(net.ParseMAC("01:02:03:04:05:06"))
|
||||||
newMAC = mustParseMAC(t, "02:03:04:05:06:07")
|
newMAC = errors.Must(net.ParseMAC("02:03:04:05:06:07"))
|
||||||
ipv6MAC = mustParseMAC(t, "06:05:04:03:02:01")
|
ipv6MAC = errors.Must(net.ParseMAC("06:05:04:03:02:01"))
|
||||||
)
|
)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -559,13 +561,13 @@ func TestServer_Leases(t *testing.T) {
|
|||||||
Expiry: expiry,
|
Expiry: expiry,
|
||||||
IP: netip.MustParseAddr("192.168.0.3"),
|
IP: netip.MustParseAddr("192.168.0.3"),
|
||||||
Hostname: "example.host",
|
Hostname: "example.host",
|
||||||
HWAddr: mustParseMAC(t, "AA:AA:AA:AA:AA:AA"),
|
HWAddr: errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA")),
|
||||||
IsStatic: false,
|
IsStatic: false,
|
||||||
}, {
|
}, {
|
||||||
Expiry: time.Time{},
|
Expiry: time.Time{},
|
||||||
IP: netip.MustParseAddr("192.168.0.4"),
|
IP: netip.MustParseAddr("192.168.0.4"),
|
||||||
Hostname: "example.static.host",
|
Hostname: "example.static.host",
|
||||||
HWAddr: mustParseMAC(t, "BB:BB:BB:BB:BB:BB"),
|
HWAddr: errors.Must(net.ParseMAC("BB:BB:BB:BB:BB:BB")),
|
||||||
IsStatic: true,
|
IsStatic: true,
|
||||||
}}
|
}}
|
||||||
assert.ElementsMatch(t, wantLeases, srv.Leases())
|
assert.ElementsMatch(t, wantLeases, srv.Leases())
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/golibs/container"
|
"github.com/AdguardTeam/golibs/container"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
@@ -51,7 +52,7 @@ func processAccessClients(
|
|||||||
} else if ipnet, err = netip.ParsePrefix(s); err == nil {
|
} else if ipnet, err = netip.ParsePrefix(s); err == nil {
|
||||||
*nets = append(*nets, ipnet)
|
*nets = append(*nets, ipnet)
|
||||||
} else {
|
} else {
|
||||||
err = ValidateClientID(s)
|
err = client.ValidateClientID(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("value %q at index %d: bad ip, cidr, or clientid", s, i)
|
return fmt.Errorf("value %q at index %d: bad ip, cidr, or clientid", s, i)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,26 +7,13 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateClientID returns an error if id is not a valid ClientID.
|
|
||||||
//
|
|
||||||
// Keep in sync with [client.ValidateClientID].
|
|
||||||
func ValidateClientID(id string) (err error) {
|
|
||||||
err = netutil.ValidateHostnameLabel(id)
|
|
||||||
if err != nil {
|
|
||||||
// Replace the domain name label wrapper with our own.
|
|
||||||
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
|
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
|
||||||
// is the server name of the host. cliSrvName is the server name as sent by the
|
// is the server name of the host. cliSrvName is the server name as sent by the
|
||||||
// client. When strict is true, and client and host server name don't match,
|
// client. When strict is true, and client and host server name don't match,
|
||||||
@@ -53,7 +40,7 @@ func clientIDFromClientServerName(
|
|||||||
}
|
}
|
||||||
|
|
||||||
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
|
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
|
||||||
err = ValidateClientID(clientID)
|
err = client.ValidateClientID(clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
return "", err
|
return "", err
|
||||||
@@ -93,7 +80,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
|
|||||||
return "", fmt.Errorf("clientid check: invalid path %q: extra parts", origPath)
|
return "", fmt.Errorf("clientid check: invalid path %q: extra parts", origPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ValidateClientID(clientID)
|
err = client.ValidateClientID(clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("clientid check: %w", err)
|
return "", fmt.Errorf("clientid check: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ type clientsContainer struct {
|
|||||||
// filter. It must not be nil.
|
// filter. It must not be nil.
|
||||||
baseLogger *slog.Logger
|
baseLogger *slog.Logger
|
||||||
|
|
||||||
|
// logger is used for logging the operation of the client container. It
|
||||||
|
// must not be nil.
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
// storage stores information about persistent clients.
|
// storage stores information about persistent clients.
|
||||||
storage *client.Storage
|
storage *client.Storage
|
||||||
|
|
||||||
@@ -58,6 +62,7 @@ type clientsContainer struct {
|
|||||||
// BlockedClientChecker checks if a client is blocked by the current access
|
// BlockedClientChecker checks if a client is blocked by the current access
|
||||||
// settings.
|
// settings.
|
||||||
type BlockedClientChecker interface {
|
type BlockedClientChecker interface {
|
||||||
|
// TODO(s.chzhen): Accept [client.FindParams].
|
||||||
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
|
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +85,7 @@ func (clients *clientsContainer) Init(
|
|||||||
}
|
}
|
||||||
|
|
||||||
clients.baseLogger = baseLogger
|
clients.baseLogger = baseLogger
|
||||||
|
clients.logger = baseLogger.With(slogutil.KeyPrefix, "client_container")
|
||||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||||
|
|
||||||
@@ -269,7 +275,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
|||||||
|
|
||||||
BlockedServices: cli.BlockedServices.Clone(),
|
BlockedServices: cli.BlockedServices.Clone(),
|
||||||
|
|
||||||
IDs: cli.IDs(),
|
IDs: cli.Identifiers(),
|
||||||
Tags: slices.Clone(cli.Tags),
|
Tags: slices.Clone(cli.Tags),
|
||||||
Upstreams: slices.Clone(cli.Upstreams),
|
Upstreams: slices.Clone(cli.Upstreams),
|
||||||
|
|
||||||
@@ -356,15 +362,27 @@ func (clients *clientsContainer) clientOrArtificial(
|
|||||||
}, true
|
}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
|
// shouldCountClient is a wrapper around [client.Storage.Find] to make it a
|
||||||
// valid client information finder for the statistics. If no information about
|
// valid client information finder for the statistics. If no information about
|
||||||
// the client is found, it returns true.
|
// the client is found, it returns true. Values of ids must be either a valid
|
||||||
|
// ClientID or a valid IP address.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Accept [client.FindParams].
|
||||||
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
params := &client.FindParams{}
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
client, ok := clients.storage.Find(id)
|
err := params.Set(id)
|
||||||
|
if err != nil {
|
||||||
|
// Should not happen.
|
||||||
|
clients.logger.Warn("parsing find params", slogutil.KeyError, err)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
client, ok := clients.storage.Find(params)
|
||||||
if ok {
|
if ok {
|
||||||
return !client.IgnoreStatistics
|
return !client.IgnoreStatistics
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -300,7 +300,7 @@ func clientToJSON(c *client.Persistent) (cj *clientJSON) {
|
|||||||
|
|
||||||
return &clientJSON{
|
return &clientJSON{
|
||||||
Name: c.Name,
|
Name: c.Name,
|
||||||
IDs: c.IDs(),
|
IDs: c.Identifiers(),
|
||||||
Tags: c.Tags,
|
Tags: c.Tags,
|
||||||
UseGlobalSettings: !c.UseOwnSettings,
|
UseGlobalSettings: !c.UseOwnSettings,
|
||||||
FilteringEnabled: c.FilteringEnabled,
|
FilteringEnabled: c.FilteringEnabled,
|
||||||
@@ -428,32 +428,53 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
|||||||
// Deprecated: Remove it when migration to the new API is over.
|
// Deprecated: Remove it when migration to the new API is over.
|
||||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
data := []map[string]*clientJSON{}
|
data := make([]map[string]*clientJSON, 0, len(q))
|
||||||
|
params := &client.FindParams{}
|
||||||
|
var err error
|
||||||
|
|
||||||
for i := range len(q) {
|
for i := range len(q) {
|
||||||
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
||||||
if idStr == "" {
|
if idStr == "" {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = params.Set(idStr)
|
||||||
|
if err != nil {
|
||||||
|
clients.logger.DebugContext(
|
||||||
|
r.Context(),
|
||||||
|
"finding client",
|
||||||
|
"id", idStr,
|
||||||
|
slogutil.KeyError, err,
|
||||||
|
)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
data = append(data, map[string]*clientJSON{
|
data = append(data, map[string]*clientJSON{
|
||||||
idStr: clients.findClient(idStr),
|
idStr: clients.findClient(idStr, params),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// findClient returns available information about a client by idStr from the
|
// findClient returns available information about a client by params from the
|
||||||
// client's storage or access settings. cj is guaranteed to be non-nil.
|
// client's storage or access settings. idStr is the string representation of
|
||||||
func (clients *clientsContainer) findClient(idStr string) (cj *clientJSON) {
|
// typed params. params must not be nil. cj is guaranteed to be non-nil.
|
||||||
ip, _ := netip.ParseAddr(idStr)
|
func (clients *clientsContainer) findClient(
|
||||||
c, ok := clients.storage.Find(idStr)
|
idStr string,
|
||||||
|
params *client.FindParams,
|
||||||
|
) (cj *clientJSON) {
|
||||||
|
c, ok := clients.storage.Find(params)
|
||||||
if !ok {
|
if !ok {
|
||||||
return clients.findRuntime(ip, idStr)
|
return clients.findRuntime(idStr, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
cj = clientToJSON(c)
|
cj = clientToJSON(c)
|
||||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
disallowed, rule := clients.clientChecker.IsBlockedClient(
|
||||||
|
params.RemoteIP,
|
||||||
|
string(params.ClientID),
|
||||||
|
)
|
||||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||||
|
|
||||||
return cj
|
return cj
|
||||||
@@ -472,7 +493,8 @@ type searchClientJSON struct {
|
|||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleSearchClient is the handler for the POST /control/clients/search HTTP API.
|
// handleSearchClient is the handler for the POST /control/clients/search HTTP
|
||||||
|
// API.
|
||||||
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
|
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
|
||||||
q := searchQueryJSON{}
|
q := searchQueryJSON{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&q)
|
err := json.NewDecoder(r.Body).Decode(&q)
|
||||||
@@ -482,11 +504,25 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data := []map[string]*clientJSON{}
|
data := make([]map[string]*clientJSON, 0, len(q.Clients))
|
||||||
|
params := &client.FindParams{}
|
||||||
|
|
||||||
for _, c := range q.Clients {
|
for _, c := range q.Clients {
|
||||||
idStr := c.ID
|
idStr := c.ID
|
||||||
|
err = params.Set(idStr)
|
||||||
|
if err != nil {
|
||||||
|
clients.logger.DebugContext(
|
||||||
|
r.Context(),
|
||||||
|
"searching client",
|
||||||
|
"id", idStr,
|
||||||
|
slogutil.KeyError, err,
|
||||||
|
)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
data = append(data, map[string]*clientJSON{
|
data = append(data, map[string]*clientJSON{
|
||||||
idStr: clients.findClient(idStr),
|
idStr: clients.findClient(idStr, params),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -494,38 +530,37 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
|
|||||||
}
|
}
|
||||||
|
|
||||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
// /etc/hosts tables, DHCP leases, or blocklists. params must not be nil. cj
|
||||||
// non-nil.
|
// is guaranteed to be non-nil.
|
||||||
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
|
func (clients *clientsContainer) findRuntime(
|
||||||
|
idStr string,
|
||||||
|
params *client.FindParams,
|
||||||
|
) (cj *clientJSON) {
|
||||||
|
var host string
|
||||||
|
whois := &whois.Info{}
|
||||||
|
|
||||||
|
ip := params.RemoteIP
|
||||||
rc := clients.storage.ClientRuntime(ip)
|
rc := clients.storage.ClientRuntime(ip)
|
||||||
if rc == nil {
|
if rc != nil {
|
||||||
// It is still possible that the IP used to be in the runtime clients
|
_, host = rc.Info()
|
||||||
// list, but then the server was reloaded. So, check the DNS server's
|
whois = whoisOrEmpty(rc)
|
||||||
// blocked IP list.
|
|
||||||
//
|
|
||||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
|
||||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
|
||||||
cj = &clientJSON{
|
|
||||||
IDs: []string{idStr},
|
|
||||||
Disallowed: &disallowed,
|
|
||||||
DisallowedRule: &rule,
|
|
||||||
WHOIS: &whois.Info{},
|
|
||||||
}
|
|
||||||
|
|
||||||
return cj
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, host := rc.Info()
|
// Check the DNS server's blocked IP list regardless of whether a runtime
|
||||||
cj = &clientJSON{
|
// client was found or not. This is because it's still possible that the
|
||||||
Name: host,
|
// runtime client associated with the IP address was stored previously, but
|
||||||
IDs: []string{idStr},
|
// then the server was reloaded.
|
||||||
WHOIS: whoisOrEmpty(rc),
|
//
|
||||||
|
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||||
|
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, string(params.ClientID))
|
||||||
|
|
||||||
|
return &clientJSON{
|
||||||
|
Name: host,
|
||||||
|
IDs: []string{idStr},
|
||||||
|
WHOIS: whois,
|
||||||
|
Disallowed: &disallowed,
|
||||||
|
DisallowedRule: &rule,
|
||||||
}
|
}
|
||||||
|
|
||||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
|
||||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
|
||||||
|
|
||||||
return cj
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterClientsHandlers registers HTTP handlers
|
// RegisterClientsHandlers registers HTTP handlers
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ func TestClientsContainer_HandleAddClient(t *testing.T) {
|
|||||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||||
|
|
||||||
clientEmptyID := newPersistentClient("empty_client_id")
|
clientEmptyID := newPersistentClient("empty_client_id")
|
||||||
clientEmptyID.ClientIDs = []string{""}
|
clientEmptyID.ClientIDs = []client.ClientID{""}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -278,7 +278,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) {
|
|||||||
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||||
|
|
||||||
clientEmptyID := newPersistentClient("empty_client_id")
|
clientEmptyID := newPersistentClient("empty_client_id")
|
||||||
clientEmptyID.ClientIDs = []string{""}
|
clientEmptyID.ClientIDs = []client.ClientID{""}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/httphdr"
|
"github.com/AdguardTeam/golibs/httphdr"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
@@ -151,7 +151,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
|||||||
|
|
||||||
clientID := q.Get("client_id")
|
clientID := q.Get("client_id")
|
||||||
if clientID != "" {
|
if clientID != "" {
|
||||||
err = dnsforward.ValidateClientID(clientID)
|
err = client.ValidateClientID(clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondJSONError(w, http.StatusBadRequest, err.Error())
|
respondJSONError(w, http.StatusBadRequest, err.Error())
|
||||||
|
|
||||||
|
|||||||
@@ -980,7 +980,8 @@
|
|||||||
- 'clients'
|
- 'clients'
|
||||||
'operationId': 'clientsSearch'
|
'operationId': 'clientsSearch'
|
||||||
'summary': >
|
'summary': >
|
||||||
Get information about clients by their IP addresses, CIDRs, MAC addresses, or ClientIDs.
|
Retrieve information about clients by performing an exact match search
|
||||||
|
using IP addresses, CIDRs, MAC addresses, or ClientIDs.
|
||||||
'requestBody':
|
'requestBody':
|
||||||
'content':
|
'content':
|
||||||
'application/json':
|
'application/json':
|
||||||
|
|||||||
Reference in New Issue
Block a user