Pull request 2378: AGDNS-2750-find-client

Merge in DNS/adguard-home from AGDNS-2750-find-client to master

Squashed commit of the following:

commit 98f1a8ca4622b6f502a5092273b9724203fe0bd8
Merge: 9270222d8 4ccc2a213
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 23 17:53:20 2025 +0300

    Merge branch 'master' into AGDNS-2750-find-client

commit 9270222d8e9e03038e9434b54496cbb6164463cd
Merge: 6468ceec8 c7c62ad3b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 21 19:40:58 2025 +0300

    Merge branch 'master' into AGDNS-2750-find-client

commit 6468ceec82d30084771a53ff6720a8c11c68bf2f
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 21 19:40:52 2025 +0300

    home: imp docs

commit 3fd4735a0d6db4fdf2d46f3da9794a687fdcaa8b
Merge: 1311a5869 a8fdf1c55
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 18 19:43:36 2025 +0300

    Merge branch 'master' into AGDNS-2750-find-client

commit 1311a58695de00f20c9704378ee6e964a44d1c59
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 18 19:42:41 2025 +0300

    home: imp code

commit b1f2c4c883c9476c5135140abac31f8ae6609b4f
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 16 16:47:59 2025 +0300

    home: imp code

commit d0a5abd66587c1ad602c2ccf6c8a45a3dfe39a5c
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Apr 15 14:58:31 2025 +0300

    client: imp naming

commit 5accdca325551237f003f1c416891b488fe5290b
Merge: 6a00232f7 4d258972d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 14 19:40:40 2025 +0300

    Merge branch 'master' into AGDNS-2750-find-client

commit 6a00232f76a0fe5ce781aa01637b6e04ace7250d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 14 19:30:32 2025 +0300

    home: imp code

commit 8633886457c6aab75f5676494b1f49d9811e9ab9
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 11 15:29:25 2025 +0300

    all: imp code

commit d6f16879e7b054a5ffac59131d2a6eff1da659c0
Merge: 58236fdec 6d282ae71
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 10 21:35:23 2025 +0300

    Merge branch 'master' into AGDNS-2750-find-client

commit 58236fdec5b64e83a44680ff8a89badc18ec81f1
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 10 21:23:01 2025 +0300

    all: upd ci

commit 3c4d946d7970987677d4ac984394e18987a29f9a
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 10 21:16:03 2025 +0300

    all: upd go

commit cc1c97734506a9ffbe70fd3c676284e58a21ba46
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 10 20:58:56 2025 +0300

    all: imp code

commit 8f061c933152481a4c80eef2af575efd4919d82b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 9 16:49:11 2025 +0300

    all: imp docs

commit 8d19355f1c519211a56cec3f23d527922d4f2ee0
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 7 21:35:06 2025 +0300

    all: imp code

commit f1e853f57e5d54d13bedcdab4f8e21e112f3a356
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 2 14:57:40 2025 +0300

    all: imp code

commit 6a6ac7f899f29ddc90a583c80562233e646ba1d6
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Apr 1 19:51:56 2025 +0300

    client: imp tests

commit 52040ee7393d0483c682f2f37d7b70f12f9cf621
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Apr 1 19:28:18 2025 +0300

    all: imp code

commit 1e09208dbd2d35c3f6b2ade169324e23d1a643a5
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Mar 26 15:33:02 2025 +0300

    all: imp code

... and 2 more commits
This commit is contained in:
Stanislav Chzhen
2025-04-23 18:10:52 +03:00
parent 4ccc2a2138
commit 61a1403e4e
16 changed files with 705 additions and 338 deletions

View File

@@ -11,8 +11,34 @@ import (
"slices"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
)
// ClientID is a unique identifier for a persistent client used in
// DNS-over-HTTPS, DNS-over-TLS, and DNS-over-QUIC queries.
//
// TODO(s.chzhen): Use everywhere.
type ClientID string
// ValidateClientID returns an error if id is not a valid ClientID.
//
// TODO(s.chzhen): Consider implementing [validate.Interface] for ClientID.
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
}
// isValidClientID returns false if id is not a valid ClientID.
func isValidClientID(id string) (ok bool) {
return netutil.IsValidHostnameLabel(id)
}
// Source represents the source from which the information about the client has
// been obtained.
type Source uint8

View File

@@ -35,7 +35,7 @@ type index struct {
nameToUID map[string]UID
// clientIDToUID maps ClientID to UID.
clientIDToUID map[string]UID
clientIDToUID map[ClientID]UID
// ipToUID maps IP address to UID.
ipToUID map[netip.Addr]UID
@@ -54,7 +54,7 @@ type index struct {
func newIndex() (ci *index) {
return &index{
nameToUID: map[string]UID{},
clientIDToUID: map[string]UID{},
clientIDToUID: map[ClientID]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
macToUID: map[macKey]UID{},
@@ -207,7 +207,7 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
// find finds persistent client by string representation of the ClientID, IP
// address, or MAC.
func (ci *index) find(id string) (c *Persistent, ok bool) {
c, ok = ci.findByClientID(id)
c, ok = ci.findByClientID(ClientID(id))
if ok {
return c, true
}
@@ -230,7 +230,7 @@ func (ci *index) find(id string) (c *Persistent, ok bool) {
}
// findByClientID finds persistent client by ClientID.
func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) {
func (ci *index) findByClientID(clientID ClientID) (c *Persistent, ok bool) {
uid, ok := ci.clientIDToUID[clientID]
if ok {
return ci.uidToClient[uid], true
@@ -275,6 +275,26 @@ func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false
}
// findByCIDR searches for a persistent client with the provided subnet as an
// identifier. Note that this function looks for an exact match of subnets,
// rather than checking if one subnet contains another.
func (ci *index) findByCIDR(subnet netip.Prefix) (c *Persistent, ok bool) {
var uid UID
for pref, id := range ci.subnetToUID.Range {
if subnet == pref {
uid, ok = id, true
break
}
}
if ok {
return ci.uidToClient[uid], true
}
return nil, false
}
// findByMAC finds persistent client by MAC.
func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)

View File

@@ -5,6 +5,7 @@ import (
"net/netip"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -58,12 +59,12 @@ func TestClientIndex_Find(t *testing.T) {
clientWithMAC = &Persistent{
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
}
clientWithID = &Persistent{
Name: "client_with_id",
ClientIDs: []string{cliID},
ClientIDs: []ClientID{cliID},
}
clientLinkLocal = &Persistent{
@@ -141,10 +142,10 @@ func TestClientIndex_Clashes(t *testing.T) {
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
ClientIDs: []ClientID{cliID},
}}
ci := newIDIndex(clients)
@@ -181,17 +182,6 @@ func TestClientIndex_Clashes(t *testing.T) {
}
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestMACToKey(t *testing.T) {
testCases := []struct {
want any
@@ -200,44 +190,44 @@ func TestMACToKey(t *testing.T) {
}{{
name: "column6",
in: "00:00:5e:00:53:01",
want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
want: [6]byte(errors.Must(net.ParseMAC("00:00:5e:00:53:01"))),
}, {
name: "column8",
in: "02:00:5e:10:00:00:00:01",
want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
want: [8]byte(errors.Must(net.ParseMAC("02:00:5e:10:00:00:00:01"))),
}, {
name: "column20",
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
want: [20]byte(errors.Must(net.ParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01"))),
}, {
name: "hyphen6",
in: "00-00-5e-00-53-01",
want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
want: [6]byte(errors.Must(net.ParseMAC("00-00-5e-00-53-01"))),
}, {
name: "hyphen8",
in: "02-00-5e-10-00-00-00-01",
want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
want: [8]byte(errors.Must(net.ParseMAC("02-00-5e-10-00-00-00-01"))),
}, {
name: "hyphen20",
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
want: [20]byte(errors.Must(net.ParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01"))),
}, {
name: "dot6",
in: "0000.5e00.5301",
want: [6]byte(mustParseMAC("0000.5e00.5301")),
want: [6]byte(errors.Must(net.ParseMAC("0000.5e00.5301"))),
}, {
name: "dot8",
in: "0200.5e10.0000.0001",
want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
want: [8]byte(errors.Must(net.ParseMAC("0200.5e10.0000.0001"))),
}, {
name: "dot20",
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
want: [20]byte(errors.Must(net.ParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001"))),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mac := mustParseMAC(tc.in)
mac := errors.Must(net.ParseMAC(tc.in))
key := macToKey(mac)
assert.Equal(t, tc.want, key)
@@ -302,19 +292,19 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) {
func TestClientIndex_RangeByName(t *testing.T) {
sortedClients := []*Persistent{{
Name: "clientA",
ClientIDs: []string{"A"},
ClientIDs: []ClientID{"A"},
}, {
Name: "clientB",
ClientIDs: []string{"B"},
ClientIDs: []ClientID{"B"},
}, {
Name: "clientC",
ClientIDs: []string{"C"},
ClientIDs: []ClientID{"C"},
}, {
Name: "clientD",
ClientIDs: []string{"D"},
ClientIDs: []ClientID{"D"},
}, {
Name: "clientE",
ClientIDs: []string{"E"},
ClientIDs: []ClientID{"E"},
}}
testCases := []struct {
@@ -349,3 +339,115 @@ func TestClientIndex_RangeByName(t *testing.T) {
})
}
}
func TestIndex_FindByName(t *testing.T) {
const (
clientExistingName = "client_existing"
clientAnotherExistingName = "client_another_existing"
nonExistingClientName = "client_non_existing"
)
var (
clientExisting = &Persistent{
Name: clientExistingName,
IPs: []netip.Addr{netip.MustParseAddr("192.0.2.1")},
}
clientAnotherExisting = &Persistent{
Name: clientAnotherExistingName,
IPs: []netip.Addr{netip.MustParseAddr("192.0.2.2")},
}
)
clients := []*Persistent{
clientExisting,
clientAnotherExisting,
}
ci := newIDIndex(clients)
testCases := []struct {
want *Persistent
found assert.BoolAssertionFunc
name string
clientName string
}{{
want: clientExisting,
found: assert.True,
name: "existing",
clientName: clientExistingName,
}, {
want: clientAnotherExisting,
found: assert.True,
name: "another_existing",
clientName: clientAnotherExistingName,
}, {
want: nil,
found: assert.False,
name: "non_existing",
clientName: nonExistingClientName,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := ci.findByName(tc.clientName)
assert.Equal(t, tc.want, c)
tc.found(t, ok)
})
}
}
func TestIndex_FindByMAC(t *testing.T) {
var (
cliMAC = errors.Must(net.ParseMAC("11:11:11:11:11:11"))
cliAnotherMAC = errors.Must(net.ParseMAC("22:22:22:22:22:22"))
nonExistingClientMAC = errors.Must(net.ParseMAC("33:33:33:33:33:33"))
)
var (
clientExisting = &Persistent{
Name: "client",
MACs: []net.HardwareAddr{cliMAC},
}
clientAnotherExisting = &Persistent{
Name: "another_client",
MACs: []net.HardwareAddr{cliAnotherMAC},
}
)
clients := []*Persistent{
clientExisting,
clientAnotherExisting,
}
ci := newIDIndex(clients)
testCases := []struct {
want *Persistent
found assert.BoolAssertionFunc
name string
clientMAC net.HardwareAddr
}{{
want: clientExisting,
found: assert.True,
name: "existing",
clientMAC: cliMAC,
}, {
want: clientAnotherExisting,
found: assert.True,
name: "another_existing",
clientMAC: cliAnotherMAC,
}, {
want: nil,
found: assert.False,
name: "non_existing",
clientMAC: nonExistingClientMAC,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := ci.findByMAC(tc.clientMAC)
assert.Equal(t, tc.want, c)
tc.found(t, ok)
})
}
}

View File

@@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/uuid"
)
@@ -90,7 +89,7 @@ type Persistent struct {
// ClientIDs identifying the client. The client must have at least one ID
// (IP, subnet, MAC, or ClientID).
ClientIDs []string
ClientIDs []ClientID
// UID is the unique identifier of the persistent client.
UID UID
@@ -134,7 +133,7 @@ func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []str
switch {
case c.Name == "":
return errors.Error("empty name")
case c.IDsLen() == 0:
case c.idendifiersLen() == 0:
return errors.Error("id required")
case c.UID == UID{}:
return errors.Error("uid required")
@@ -237,28 +236,15 @@ func (c *Persistent) setID(id string) (err error) {
return err
}
c.ClientIDs = append(c.ClientIDs, strings.ToLower(id))
c.ClientIDs = append(c.ClientIDs, ClientID(strings.ToLower(id)))
return nil
}
// ValidateClientID returns an error if id is not a valid ClientID.
//
// TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to
// avoid the import cycle. Remove it.
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
}
// IDs returns a list of ClientIDs containing at least one element.
func (c *Persistent) IDs() (ids []string) {
ids = make([]string, 0, c.IDsLen())
// Identifiers returns a list of client identifiers containing at least one
// element.
func (c *Persistent) Identifiers() (ids []string) {
ids = make([]string, 0, c.idendifiersLen())
for _, ip := range c.IPs {
ids = append(ids, ip.String())
@@ -272,11 +258,15 @@ func (c *Persistent) IDs() (ids []string) {
ids = append(ids, mac.String())
}
return append(ids, c.ClientIDs...)
for _, cid := range c.ClientIDs {
ids = append(ids, string(cid))
}
return ids
}
// IDsLen returns a length of ClientIDs.
func (c *Persistent) IDsLen() (n int) {
// identifiersLen returns the number of client identifiers.
func (c *Persistent) idendifiersLen() (n int) {
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
}

View File

@@ -7,6 +7,7 @@ import (
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
@@ -18,6 +19,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
)
@@ -433,48 +435,186 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) {
ctx,
"client added",
"name", p.Name,
"ids", p.IDs(),
"ids", p.Identifiers(),
"clients_count", s.index.size(),
)
return nil
}
// FindByName finds persistent client by name. And returns its shallow copy.
func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
// FindParams represents the parameters for searching a client. At least one
// field must be non-empty.
type FindParams struct {
// ClientID is a unique identifier for the client used in DoH, DoT, and DoQ
// DNS queries.
ClientID ClientID
p, ok = s.index.findByName(name)
if ok {
return p.ShallowClone(), ok
}
// RemoteIP is the IP address used as a client search parameter.
RemoteIP netip.Addr
return nil, false
// Subnet is the CIDR used as a client search parameter.
Subnet netip.Prefix
// MAC is the physical hardware address used as a client search parameter.
MAC net.HardwareAddr
// UID is the unique ID of persistent client used as a search parameter.
//
// TODO(s.chzhen): Use this.
UID UID
}
// Find finds persistent client by string representation of the ClientID, IP
// address, or MAC. And returns its shallow copy.
// ErrBadIdentifier is returned by [FindParams.Set] when it cannot parse the
// provided client identifier.
const ErrBadIdentifier errors.Error = "bad client identifier"
// Set clears the stored search parameters and parses the string representation
// of the search parameter into typed parameter, storing it. In some cases, it
// may result in storing both an IP address and a MAC address because they might
// have identical string representations. It returns [ErrBadIdentifier] if id
// cannot be parsed.
//
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
// the parsed IP address, if any.
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
// TODO(s.chzhen): Add support for UID.
func (p *FindParams) Set(id string) (err error) {
*p = FindParams{}
isClientID := true
if netutil.IsValidIPString(id) {
// It is safe to use [netip.MustParseAddr] because it has already been
// validated that id contains the string representation of the IP
// address.
p.RemoteIP = netip.MustParseAddr(id)
// Even if id can be parsed as an IP address, it may be a MAC address.
// So do not return prematurely, continue parsing.
isClientID = false
}
if canBeValidIPPrefixString(id) {
p.Subnet, err = netip.ParsePrefix(id)
if err == nil {
isClientID = false
}
}
if canBeMACString(id) {
p.MAC, err = net.ParseMAC(id)
if err == nil {
isClientID = false
}
}
if !isClientID {
return nil
}
if !isValidClientID(id) {
return ErrBadIdentifier
}
p.ClientID = ClientID(id)
return nil
}
// canBeValidIPPrefixString is a best-effort check to determine if s is a valid
// CIDR before using [netip.ParsePrefix], aimed at reducing allocations.
//
// TODO(s.chzhen): Replace this implementation with the more robust version
// from golibs.
func canBeValidIPPrefixString(s string) (ok bool) {
ipStr, bitStr, ok := strings.Cut(s, "/")
if !ok {
return false
}
if bitStr == "" || len(bitStr) > 3 {
return false
}
bits := 0
for _, c := range bitStr {
if c < '0' || c > '9' {
return false
}
bits = bits*10 + int(c-'0')
}
if bits > 128 {
return false
}
return netutil.IsValidIPString(ipStr)
}
// canBeMACString is a best-effort check to determine if s is a valid MAC
// address before using [net.ParseMAC], aimed at reducing allocations.
//
// TODO(s.chzhen): Replace this implementation with the more robust version
// from golibs.
func canBeMACString(s string) (ok bool) {
switch len(s) {
case
len("0000.0000.0000"),
len("00:00:00:00:00:00"),
len("0000.0000.0000.0000"),
len("00:00:00:00:00:00:00:00"),
len("0000.0000.0000.0000.0000.0000.0000.0000.0000.0000"),
len("00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"):
return true
default:
return false
}
}
// Find represents the parameters for searching a client. params must not be
// nil and must have at least one non-empty field.
func (s *Storage) Find(params *FindParams) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.find(id)
isClientID := params.ClientID != ""
isRemoteIP := params.RemoteIP != (netip.Addr{})
isSubnet := params.Subnet != (netip.Prefix{})
isMAC := params.MAC != nil
for {
switch {
case isClientID:
isClientID = false
p, ok = s.index.findByClientID(params.ClientID)
case isRemoteIP:
isRemoteIP = false
p, ok = s.findByIP(params.RemoteIP)
case isSubnet:
isSubnet = false
p, ok = s.index.findByCIDR(params.Subnet)
case isMAC:
isMAC = false
p, ok = s.index.findByMAC(params.MAC)
default:
return nil, false
}
if ok {
return p.ShallowClone(), true
}
}
}
// findByIP finds persistent client by IP address. s.mu is expected to be
// locked.
func (s *Storage) findByIP(addr netip.Addr) (p *Persistent, ok bool) {
p, ok = s.index.findByIP(addr)
if ok {
return p.ShallowClone(), ok
return p, true
}
ip, err := netip.ParseAddr(id)
if err != nil {
return nil, false
}
foundMAC := s.dhcp.MACByIP(ip)
foundMAC := s.dhcp.MACByIP(addr)
if foundMAC != nil {
return s.FindByMAC(foundMAC)
return s.index.findByMAC(foundMAC)
}
return nil, false
@@ -487,6 +627,8 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) {
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
//
// TODO(s.chzhen): Consider accepting [FindParams].
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -498,7 +640,7 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
foundMAC := s.dhcp.MACByIP(ip)
if foundMAC != nil {
return s.FindByMAC(foundMAC)
return s.index.findByMAC(foundMAC)
}
p = s.index.findByIPWithoutZone(ip)
@@ -509,17 +651,6 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
return nil, false
}
// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu
// is expected to be locked.
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
p, ok = s.index.findByMAC(mac)
if ok {
return p.ShallowClone(), ok
}
return nil, false
}
// RemoveByName removes persistent client information. ok is false if no such
// client exists by that name.
func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
@@ -648,7 +779,7 @@ func (s *Storage) CustomUpstreamConfig(
s.mu.Lock()
defer s.mu.Unlock()
c, ok := s.index.findByClientID(id)
c, ok := s.index.findByClientID(ClientID(id))
if !ok {
c, ok = s.index.findByIP(addr)
}
@@ -682,7 +813,7 @@ func (s *Storage) ClearUpstreamCache() {
// ClientID or client IP address, and applies it to the filtering settings.
// setts must not be nil.
func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) {
c, ok := s.index.findByClientID(id)
c, ok := s.index.findByClientID(ClientID(id))
if !ok {
c, ok = s.index.findByIP(addr)
}
@@ -690,7 +821,7 @@ func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filter
if !ok {
foundMAC := s.dhcp.MACByIP(addr)
if foundMAC != nil {
c, ok = s.FindByMAC(foundMAC)
c, ok = s.index.findByMAC(foundMAC)
}
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
@@ -350,15 +351,15 @@ func TestClientsDHCP(t *testing.T) {
cliName1 = "one.dhcp"
cliIP2 = netip.MustParseAddr("2.2.2.2")
cliMAC2 = mustParseMAC("22:22:22:22:22:22")
cliMAC2 = errors.Must(net.ParseMAC("22:22:22:22:22:22"))
cliName2 = "two.dhcp"
cliIP3 = netip.MustParseAddr("3.3.3.3")
cliMAC3 = mustParseMAC("33:33:33:33:33:33")
cliMAC3 = errors.Must(net.ParseMAC("33:33:33:33:33:33"))
cliName3 = "three.dhcp"
prsCliIP = netip.MustParseAddr("4.3.2.1")
prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
prsCliMAC = errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA"))
prsCliName = "persistent.dhcp"
otherARPCliName = "other.arp"
@@ -519,7 +520,11 @@ func TestClientsDHCP(t *testing.T) {
})
require.NoError(t, err)
prsCli, ok := storage.Find(prsCliIP.String())
params := &client.FindParams{}
err = params.Set(prsCliIP.String())
require.NoError(t, err)
prsCli, ok := storage.Find(params)
require.True(t, ok)
assert.Equal(t, prsCliName, prsCli.Name)
@@ -663,17 +668,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
return s
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestStorage_Add(t *testing.T) {
const (
existingName = "existing_name"
@@ -693,7 +687,7 @@ func TestStorage_Add(t *testing.T) {
Name: existingName,
IPs: []netip.Addr{existingIP},
Subnets: []netip.Prefix{existingSubnet},
ClientIDs: []string{existingClientID},
ClientIDs: []client.ClientID{existingClientID},
UID: existingClientUID,
}
@@ -761,7 +755,7 @@ func TestStorage_Add(t *testing.T) {
name: "duplicate_client_id",
cli: &client.Persistent{
Name: "duplicate_client_id",
ClientIDs: []string{existingClientID},
ClientIDs: []client.ClientID{existingClientID},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" ` +
@@ -898,12 +892,12 @@ func TestStorage_Find(t *testing.T) {
clientWithMAC = &client.Persistent{
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
}
clientWithID = &client.Persistent{
Name: "client_with_id",
ClientIDs: []string{cliID},
ClientIDs: []client.ClientID{cliID},
}
clientLinkLocal = &client.Persistent{
@@ -950,7 +944,11 @@ func TestStorage_Find(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := s.Find(id)
params := &client.FindParams{}
err := params.Set(id)
require.NoError(t, err)
c, ok := s.Find(params)
require.True(t, ok)
assert.Equal(t, tc.want, c)
@@ -959,7 +957,11 @@ func TestStorage_Find(t *testing.T) {
}
t.Run("not_found", func(t *testing.T) {
_, ok := s.Find(cliIPNone)
params := &client.FindParams{}
err := params.Set(cliIPNone)
require.NoError(t, err)
_, ok := s.Find(params)
assert.False(t, ok)
})
}
@@ -1025,127 +1027,6 @@ func TestStorage_FindLoose(t *testing.T) {
}
}
func TestStorage_FindByName(t *testing.T) {
const (
cliIP1 = "1.1.1.1"
cliIP2 = "2.2.2.2"
)
const (
clientExistingName = "client_existing"
clientAnotherExistingName = "client_another_existing"
nonExistingClientName = "client_non_existing"
)
var (
clientExisting = &client.Persistent{
Name: clientExistingName,
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
}
clientAnotherExisting = &client.Persistent{
Name: clientAnotherExistingName,
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
}
)
clients := []*client.Persistent{
clientExisting,
clientAnotherExisting,
}
s := newStorage(t, clients)
testCases := []struct {
want *client.Persistent
name string
clientName string
}{{
name: "existing",
clientName: clientExistingName,
want: clientExisting,
}, {
name: "another_existing",
clientName: clientAnotherExistingName,
want: clientAnotherExisting,
}, {
name: "non_existing",
clientName: nonExistingClientName,
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := s.FindByName(tc.clientName)
if tc.want == nil {
assert.False(t, ok)
return
}
assert.True(t, ok)
assert.Equal(t, tc.want, c)
})
}
}
func TestStorage_FindByMAC(t *testing.T) {
var (
cliMAC = mustParseMAC("11:11:11:11:11:11")
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
)
var (
clientExisting = &client.Persistent{
Name: "client",
MACs: []net.HardwareAddr{cliMAC},
}
clientAnotherExisting = &client.Persistent{
Name: "another_client",
MACs: []net.HardwareAddr{cliAnotherMAC},
}
)
clients := []*client.Persistent{
clientExisting,
clientAnotherExisting,
}
s := newStorage(t, clients)
testCases := []struct {
want *client.Persistent
name string
clientMAC net.HardwareAddr
}{{
name: "existing",
clientMAC: cliMAC,
want: clientExisting,
}, {
name: "another_existing",
clientMAC: cliAnotherMAC,
want: clientAnotherExisting,
}, {
name: "non_existing",
clientMAC: nonExistingClientMAC,
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := s.FindByMAC(tc.clientMAC)
if tc.want == nil {
assert.False(t, ok)
return
}
assert.True(t, ok)
assert.Equal(t, tc.want, c)
})
}
}
func TestStorage_Update(t *testing.T) {
const (
clientName = "client_name"
@@ -1162,7 +1043,7 @@ func TestStorage_Update(t *testing.T) {
Name: obstructingName,
IPs: []netip.Addr{obstructingIP},
Subnets: []netip.Prefix{obstructingSubnet},
ClientIDs: []string{obstructingClientID},
ClientIDs: []client.ClientID{obstructingClientID},
}
clientToUpdate := &client.Persistent{
@@ -1211,7 +1092,7 @@ func TestStorage_Update(t *testing.T) {
name: "duplicate_client_id",
cli: &client.Persistent{
Name: "duplicate_client_id",
ClientIDs: []string{obstructingClientID},
ClientIDs: []client.ClientID{obstructingClientID},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" ` +
@@ -1238,19 +1119,19 @@ func TestStorage_Update(t *testing.T) {
func TestStorage_RangeByName(t *testing.T) {
sortedClients := []*client.Persistent{{
Name: "clientA",
ClientIDs: []string{"A"},
ClientIDs: []client.ClientID{"A"},
}, {
Name: "clientB",
ClientIDs: []string{"B"},
ClientIDs: []client.ClientID{"B"},
}, {
Name: "clientC",
ClientIDs: []string{"C"},
ClientIDs: []client.ClientID{"C"},
}, {
Name: "clientD",
ClientIDs: []string{"D"},
ClientIDs: []client.ClientID{"D"},
}, {
Name: "clientE",
ClientIDs: []string{"E"},
ClientIDs: []client.ClientID{"E"},
}}
testCases := []struct {
@@ -1306,7 +1187,7 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{existingIP},
ClientIDs: []string{existingClientID},
ClientIDs: []client.ClientID{existingClientID},
UID: existingClientUID,
Upstreams: []string{"192.0.2.0"},
}
@@ -1381,3 +1262,182 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
assert.NotEqual(t, conf, updConf)
})
}
func BenchmarkFindParams_Set(b *testing.B) {
const (
testIPStr = "192.0.2.1"
testCIDRStr = "192.0.2.0/24"
testMACStr = "02:00:00:00:00:00"
testClientID = "clientid"
)
benchCases := []struct {
wantErr error
params *client.FindParams
name string
id string
}{{
wantErr: nil,
params: &client.FindParams{
ClientID: testClientID,
},
name: "client_id",
id: testClientID,
}, {
wantErr: nil,
params: &client.FindParams{
RemoteIP: netip.MustParseAddr(testIPStr),
},
name: "ip_address",
id: testIPStr,
}, {
wantErr: nil,
params: &client.FindParams{
Subnet: netip.MustParsePrefix(testCIDRStr),
},
name: "subnet",
id: testCIDRStr,
}, {
wantErr: nil,
params: &client.FindParams{
MAC: errors.Must(net.ParseMAC(testMACStr)),
},
name: "mac_address",
id: testMACStr,
}, {
wantErr: client.ErrBadIdentifier,
params: &client.FindParams{},
name: "bad_id",
id: "!@#$%^&*()_+",
}}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
params := &client.FindParams{}
var err error
b.ReportAllocs()
for b.Loop() {
err = params.Set(bc.id)
}
assert.ErrorIs(b, err, bc.wantErr)
assert.Equal(b, bc.params, params)
})
}
// Most recent results:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/client
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
// BenchmarkFindParams_Set/client_id-8 49463488 24.27 ns/op 0 B/op 0 allocs/op
// BenchmarkFindParams_Set/ip_address-8 18740977 62.22 ns/op 0 B/op 0 allocs/op
// BenchmarkFindParams_Set/subnet-8 10848192 110.0 ns/op 0 B/op 0 allocs/op
// BenchmarkFindParams_Set/mac_address-8 8148494 133.2 ns/op 8 B/op 1 allocs/op
// BenchmarkFindParams_Set/bad_id-8 73894278 16.29 ns/op 0 B/op 0 allocs/op
}
func BenchmarkStorage_Find(b *testing.B) {
const (
cliID = "cid"
cliMAC = "02:00:00:00:00:00"
)
const (
cliNameWithID = "client_with_id"
cliNameWithIP = "client_with_ip"
cliNameWithCIDR = "client_with_cidr"
cliNameWithMAC = "client_with_mac"
)
var (
cliIP = netip.MustParseAddr("192.0.2.1")
cliCIDR = netip.MustParsePrefix("192.0.2.0/24")
)
var (
clientWithID = &client.Persistent{
Name: cliNameWithID,
ClientIDs: []client.ClientID{cliID},
}
clientWithIP = &client.Persistent{
Name: cliNameWithIP,
IPs: []netip.Addr{cliIP},
}
clientWithCIDR = &client.Persistent{
Name: cliNameWithCIDR,
Subnets: []netip.Prefix{cliCIDR},
}
clientWithMAC = &client.Persistent{
Name: cliNameWithMAC,
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
}
)
clients := []*client.Persistent{
clientWithID,
clientWithIP,
clientWithCIDR,
clientWithMAC,
}
s := newStorage(b, clients)
benchCases := []struct {
params *client.FindParams
name string
wantName string
}{{
params: &client.FindParams{
ClientID: cliID,
},
name: "client_id",
wantName: cliNameWithID,
}, {
params: &client.FindParams{
RemoteIP: cliIP,
},
name: "ip_address",
wantName: cliNameWithIP,
}, {
params: &client.FindParams{
Subnet: cliCIDR,
},
name: "subnet",
wantName: cliNameWithCIDR,
}, {
params: &client.FindParams{
MAC: errors.Must(net.ParseMAC(cliMAC)),
},
name: "mac_address",
wantName: cliNameWithMAC,
}}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
var p *client.Persistent
var ok bool
b.ReportAllocs()
for b.Loop() {
p, ok = s.Find(bc.params)
}
assert.True(b, ok)
assert.NotNil(b, p)
assert.Equal(b, bc.wantName, p.Name)
})
}
// Most recent results:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/client
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
// BenchmarkStorage_Find/client_id-8 7070107 154.4 ns/op 240 B/op 2 allocs/op
// BenchmarkStorage_Find/ip_address-8 6831823 168.6 ns/op 248 B/op 2 allocs/op
// BenchmarkStorage_Find/subnet-8 7209050 167.5 ns/op 256 B/op 2 allocs/op
// BenchmarkStorage_Find/mac_address-8 5776131 199.7 ns/op 256 B/op 3 allocs/op
}

View File

@@ -1,13 +1,11 @@
package dhcpsvc_test
import (
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/stretchr/testify/require"
)
// testLocalTLD is a common local TLD for tests.
@@ -56,11 +54,3 @@ var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
},
},
}
// mustParseMAC parses a hardware address from s and requires no errors.
func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
require.NoError(t, err)
return mac
}

View File

@@ -2,6 +2,7 @@ package dhcpsvc_test
import (
"io/fs"
"net"
"net/netip"
"os"
"path"
@@ -11,6 +12,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -176,9 +178,9 @@ func TestDHCPServer_AddLease(t *testing.T) {
newIP = netip.MustParseAddr("192.168.0.3")
newIPv6 = netip.MustParseAddr("2001:db8::2")
existMAC = mustParseMAC(t, "01:02:03:04:05:06")
newMAC = mustParseMAC(t, "06:05:04:03:02:01")
ipv6MAC = mustParseMAC(t, "02:03:04:05:06:07")
existMAC = errors.Must(net.ParseMAC("01:02:03:04:05:06"))
newMAC = errors.Must(net.ParseMAC("06:05:04:03:02:01"))
ipv6MAC = errors.Must(net.ParseMAC("02:03:04:05:06:07"))
)
require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{
@@ -291,9 +293,9 @@ func TestDHCPServer_index(t *testing.T) {
ip3 = netip.MustParseAddr("172.16.0.3")
ip4 = netip.MustParseAddr("172.16.0.4")
mac1 = mustParseMAC(t, "01:02:03:04:05:06")
mac2 = mustParseMAC(t, "06:05:04:03:02:01")
mac3 = mustParseMAC(t, "02:03:04:05:06:07")
mac1 = errors.Must(net.ParseMAC("01:02:03:04:05:06"))
mac2 = errors.Must(net.ParseMAC("06:05:04:03:02:01"))
mac3 = errors.Must(net.ParseMAC("02:03:04:05:06:07"))
)
t.Run("ip_idx", func(t *testing.T) {
@@ -349,9 +351,9 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
ip3 = netip.MustParseAddr("192.168.0.4")
ip4 = netip.MustParseAddr("2001:db8::3")
mac1 = mustParseMAC(t, "01:02:03:04:05:06")
mac2 = mustParseMAC(t, "06:05:04:03:02:01")
mac3 = mustParseMAC(t, "06:05:04:03:02:02")
mac1 = errors.Must(net.ParseMAC("01:02:03:04:05:06"))
mac2 = errors.Must(net.ParseMAC("06:05:04:03:02:01"))
mac3 = errors.Must(net.ParseMAC("06:05:04:03:02:02"))
)
testCases := []struct {
@@ -452,9 +454,9 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
newIP = netip.MustParseAddr("192.168.0.3")
newIPv6 = netip.MustParseAddr("2001:db8::2")
existMAC = mustParseMAC(t, "01:02:03:04:05:06")
newMAC = mustParseMAC(t, "02:03:04:05:06:07")
ipv6MAC = mustParseMAC(t, "06:05:04:03:02:01")
existMAC = errors.Must(net.ParseMAC("01:02:03:04:05:06"))
newMAC = errors.Must(net.ParseMAC("02:03:04:05:06:07"))
ipv6MAC = errors.Must(net.ParseMAC("06:05:04:03:02:01"))
)
testCases := []struct {
@@ -559,13 +561,13 @@ func TestServer_Leases(t *testing.T) {
Expiry: expiry,
IP: netip.MustParseAddr("192.168.0.3"),
Hostname: "example.host",
HWAddr: mustParseMAC(t, "AA:AA:AA:AA:AA:AA"),
HWAddr: errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA")),
IsStatic: false,
}, {
Expiry: time.Time{},
IP: netip.MustParseAddr("192.168.0.4"),
Hostname: "example.static.host",
HWAddr: mustParseMAC(t, "BB:BB:BB:BB:BB:BB"),
HWAddr: errors.Must(net.ParseMAC("BB:BB:BB:BB:BB:BB")),
IsStatic: true,
}}
assert.ElementsMatch(t, wantLeases, srv.Leases())

View File

@@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
@@ -51,7 +52,7 @@ func processAccessClients(
} else if ipnet, err = netip.ParsePrefix(s); err == nil {
*nets = append(*nets, ipnet)
} else {
err = ValidateClientID(s)
err = client.ValidateClientID(s)
if err != nil {
return fmt.Errorf("value %q at index %d: bad ip, cidr, or clientid", s, i)
}

View File

@@ -7,26 +7,13 @@ import (
"path"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/quic-go/quic-go"
)
// ValidateClientID returns an error if id is not a valid ClientID.
//
// Keep in sync with [client.ValidateClientID].
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
}
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the
// client. When strict is true, and client and host server name don't match,
@@ -53,7 +40,7 @@ func clientIDFromClientServerName(
}
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
err = ValidateClientID(clientID)
err = client.ValidateClientID(clientID)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return "", err
@@ -93,7 +80,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
return "", fmt.Errorf("clientid check: invalid path %q: extra parts", origPath)
}
err = ValidateClientID(clientID)
err = client.ValidateClientID(clientID)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}

View File

@@ -28,6 +28,10 @@ type clientsContainer struct {
// filter. It must not be nil.
baseLogger *slog.Logger
// logger is used for logging the operation of the client container. It
// must not be nil.
logger *slog.Logger
// storage stores information about persistent clients.
storage *client.Storage
@@ -58,6 +62,7 @@ type clientsContainer struct {
// BlockedClientChecker checks if a client is blocked by the current access
// settings.
type BlockedClientChecker interface {
// TODO(s.chzhen): Accept [client.FindParams].
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
}
@@ -80,6 +85,7 @@ func (clients *clientsContainer) Init(
}
clients.baseLogger = baseLogger
clients.logger = baseLogger.With(slogutil.KeyPrefix, "client_container")
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
@@ -269,7 +275,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
BlockedServices: cli.BlockedServices.Clone(),
IDs: cli.IDs(),
IDs: cli.Identifiers(),
Tags: slices.Clone(cli.Tags),
Upstreams: slices.Clone(cli.Upstreams),
@@ -356,15 +362,27 @@ func (clients *clientsContainer) clientOrArtificial(
}, true
}
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
// shouldCountClient is a wrapper around [client.Storage.Find] to make it a
// valid client information finder for the statistics. If no information about
// the client is found, it returns true.
// the client is found, it returns true. Values of ids must be either a valid
// ClientID or a valid IP address.
//
// TODO(s.chzhen): Accept [client.FindParams].
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
params := &client.FindParams{}
for _, id := range ids {
client, ok := clients.storage.Find(id)
err := params.Set(id)
if err != nil {
// Should not happen.
clients.logger.Warn("parsing find params", slogutil.KeyError, err)
continue
}
client, ok := clients.storage.Find(params)
if ok {
return !client.IgnoreStatistics
}

View File

@@ -300,7 +300,7 @@ func clientToJSON(c *client.Persistent) (cj *clientJSON) {
return &clientJSON{
Name: c.Name,
IDs: c.IDs(),
IDs: c.Identifiers(),
Tags: c.Tags,
UseGlobalSettings: !c.UseOwnSettings,
FilteringEnabled: c.FilteringEnabled,
@@ -428,32 +428,53 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
// Deprecated: Remove it when migration to the new API is over.
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
data := []map[string]*clientJSON{}
data := make([]map[string]*clientJSON, 0, len(q))
params := &client.FindParams{}
var err error
for i := range len(q) {
idStr := q.Get(fmt.Sprintf("ip%d", i))
if idStr == "" {
break
}
err = params.Set(idStr)
if err != nil {
clients.logger.DebugContext(
r.Context(),
"finding client",
"id", idStr,
slogutil.KeyError, err,
)
continue
}
data = append(data, map[string]*clientJSON{
idStr: clients.findClient(idStr),
idStr: clients.findClient(idStr, params),
})
}
aghhttp.WriteJSONResponseOK(w, r, data)
}
// findClient returns available information about a client by idStr from the
// client's storage or access settings. cj is guaranteed to be non-nil.
func (clients *clientsContainer) findClient(idStr string) (cj *clientJSON) {
ip, _ := netip.ParseAddr(idStr)
c, ok := clients.storage.Find(idStr)
// findClient returns available information about a client by params from the
// client's storage or access settings. idStr is the string representation of
// typed params. params must not be nil. cj is guaranteed to be non-nil.
func (clients *clientsContainer) findClient(
idStr string,
params *client.FindParams,
) (cj *clientJSON) {
c, ok := clients.storage.Find(params)
if !ok {
return clients.findRuntime(ip, idStr)
return clients.findRuntime(idStr, params)
}
cj = clientToJSON(c)
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
disallowed, rule := clients.clientChecker.IsBlockedClient(
params.RemoteIP,
string(params.ClientID),
)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
return cj
@@ -472,7 +493,8 @@ type searchClientJSON struct {
ID string `json:"id"`
}
// handleSearchClient is the handler for the POST /control/clients/search HTTP API.
// handleSearchClient is the handler for the POST /control/clients/search HTTP
// API.
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
q := searchQueryJSON{}
err := json.NewDecoder(r.Body).Decode(&q)
@@ -482,11 +504,25 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
return
}
data := []map[string]*clientJSON{}
data := make([]map[string]*clientJSON, 0, len(q.Clients))
params := &client.FindParams{}
for _, c := range q.Clients {
idStr := c.ID
err = params.Set(idStr)
if err != nil {
clients.logger.DebugContext(
r.Context(),
"searching client",
"id", idStr,
slogutil.KeyError, err,
)
continue
}
data = append(data, map[string]*clientJSON{
idStr: clients.findClient(idStr),
idStr: clients.findClient(idStr, params),
})
}
@@ -494,38 +530,37 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
}
// findRuntime looks up the IP in runtime and temporary storages, like
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// non-nil.
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
// /etc/hosts tables, DHCP leases, or blocklists. params must not be nil. cj
// is guaranteed to be non-nil.
func (clients *clientsContainer) findRuntime(
idStr string,
params *client.FindParams,
) (cj *clientJSON) {
var host string
whois := &whois.Info{}
ip := params.RemoteIP
rc := clients.storage.ClientRuntime(ip)
if rc == nil {
// It is still possible that the IP used to be in the runtime clients
// list, but then the server was reloaded. So, check the DNS server's
// blocked IP list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj = &clientJSON{
IDs: []string{idStr},
Disallowed: &disallowed,
DisallowedRule: &rule,
WHOIS: &whois.Info{},
}
return cj
if rc != nil {
_, host = rc.Info()
whois = whoisOrEmpty(rc)
}
_, host := rc.Info()
cj = &clientJSON{
Name: host,
IDs: []string{idStr},
WHOIS: whoisOrEmpty(rc),
// Check the DNS server's blocked IP list regardless of whether a runtime
// client was found or not. This is because it's still possible that the
// runtime client associated with the IP address was stored previously, but
// then the server was reloaded.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, string(params.ClientID))
return &clientJSON{
Name: host,
IDs: []string{idStr},
WHOIS: whois,
Disallowed: &disallowed,
DisallowedRule: &rule,
}
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
return cj
}
// RegisterClientsHandlers registers HTTP handlers

View File

@@ -153,7 +153,7 @@ func TestClientsContainer_HandleAddClient(t *testing.T) {
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
clientEmptyID := newPersistentClient("empty_client_id")
clientEmptyID.ClientIDs = []string{""}
clientEmptyID.ClientIDs = []client.ClientID{""}
testCases := []struct {
name string
@@ -278,7 +278,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) {
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
clientEmptyID := newPersistentClient("empty_client_id")
clientEmptyID.ClientIDs = []string{""}
clientEmptyID.ClientIDs = []client.ClientID{""}
testCases := []struct {
name string

View File

@@ -8,7 +8,7 @@ import (
"net/url"
"path"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
@@ -151,7 +151,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
clientID := q.Get("client_id")
if clientID != "" {
err = dnsforward.ValidateClientID(clientID)
err = client.ValidateClientID(clientID)
if err != nil {
respondJSONError(w, http.StatusBadRequest, err.Error())