all: sync with master

This commit is contained in:
Ainar Garipov
2024-07-03 15:38:37 +03:00
parent f73717ec08
commit 158d4f0249
352 changed files with 33842 additions and 33276 deletions

View File

@@ -30,8 +30,8 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
}
}
// Index stores all information about persistent clients.
type Index struct {
// index stores all information about persistent clients.
type index struct {
// nameToUID maps client name to UID.
nameToUID map[string]UID
@@ -51,9 +51,9 @@ type Index struct {
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
}
// NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) {
return &Index{
// newIndex initializes the new instance of client index.
func newIndex() (ci *index) {
return &index{
nameToUID: map[string]UID{},
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
@@ -63,9 +63,9 @@ func NewIndex() (ci *Index) {
}
}
// Add stores information about a persistent client in the index. c must be
// non-nil and contain UID.
func (ci *Index) Add(c *Persistent) {
// add stores information about a persistent client in the index. c must be
// non-nil, have a UID, and contain at least one identifier.
func (ci *index) add(c *Persistent) {
if (c.UID == UID{}) {
panic("client must contain uid")
}
@@ -92,9 +92,9 @@ func (ci *Index) Add(c *Persistent) {
ci.uidToClient[c.UID] = c
}
// ClashesUID returns existing persistent client with the same UID as c. Note
// clashesUID returns existing persistent client with the same UID as c. Note
// that this is only possible when configuration contains duplicate fields.
func (ci *Index) ClashesUID(c *Persistent) (err error) {
func (ci *index) clashesUID(c *Persistent) (err error) {
p, ok := ci.uidToClient[c.UID]
if ok {
return fmt.Errorf("another client %q uses the same uid", p.Name)
@@ -103,9 +103,9 @@ func (ci *Index) ClashesUID(c *Persistent) (err error) {
return nil
}
// Clashes returns an error if the index contains a different persistent client
// clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *Index) Clashes(c *Persistent) (err error) {
func (ci *index) clashes(c *Persistent) (err error) {
if p := ci.clashesName(c); p != nil {
return fmt.Errorf("another client uses the same name %q", p.Name)
}
@@ -139,8 +139,8 @@ func (ci *Index) Clashes(c *Persistent) (err error) {
// clashesName returns existing persistent client with the same name as c or
// nil. c must be non-nil.
func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.FindByName(c.Name)
func (ci *index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.findByName(c.Name)
if !ok {
return nil
}
@@ -154,7 +154,7 @@ func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
func (ci *index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
for _, ip := range c.IPs {
existing, ok := ci.ipToUID[ip]
if ok && existing != c.UID {
@@ -167,7 +167,7 @@ func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
// clashesSubnet returns a previous client with the same subnet as c. c must be
// non-nil.
func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
func (ci *index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
for _, s = range c.Subnets {
var existing UID
var ok bool
@@ -193,7 +193,7 @@ func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
// clashesMAC returns a previous client with the same MAC address as c. c must
// be non-nil.
func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
for _, mac = range c.MACs {
k := macToKey(mac)
existing, ok := ci.macToUID[k]
@@ -205,9 +205,9 @@ func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
return nil, nil
}
// Find finds persistent client by string representation of the client ID, IP
// find finds persistent client by string representation of the client ID, IP
// address, or MAC.
func (ci *Index) Find(id string) (c *Persistent, ok bool) {
func (ci *index) find(id string) (c *Persistent, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
@@ -224,14 +224,14 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) {
mac, err := net.ParseMAC(id)
if err == nil {
return ci.FindByMAC(mac)
return ci.findByMAC(mac)
}
return nil, false
}
// FindByName finds persistent client by name.
func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
// findByName finds persistent client by name.
func (ci *index) findByName(name string) (c *Persistent, found bool) {
uid, found := ci.nameToUID[name]
if found {
return ci.uidToClient[uid], true
@@ -241,7 +241,7 @@ func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
}
// findByIP finds persistent client by IP address.
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
@@ -266,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false
}
// FindByMAC finds persistent client by MAC.
func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
// findByMAC finds persistent client by MAC.
func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
@@ -277,13 +277,13 @@ func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
return nil, false
}
// FindByIPWithoutZone finds a persistent client by IP address without zone. It
// findByIPWithoutZone finds a persistent client by IP address without zone. It
// strips the IPv6 zone index from the stored IP addresses before comparing,
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
func (ci *index) findByIPWithoutZone(ip netip.Addr) (c *Persistent) {
if (ip == netip.Addr{}) {
return nil
}
@@ -297,9 +297,9 @@ func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
return nil
}
// Delete removes information about persistent client from the index. c must be
// remove removes information about persistent client from the index. c must be
// non-nil.
func (ci *Index) Delete(c *Persistent) {
func (ci *index) remove(c *Persistent) {
delete(ci.nameToUID, c.Name)
for _, id := range c.ClientIDs {
@@ -322,24 +322,14 @@ func (ci *Index) Delete(c *Persistent) {
delete(ci.uidToClient, c.UID)
}
// Size returns the number of persistent clients.
func (ci *Index) Size() (n int) {
// size returns the number of persistent clients.
func (ci *index) size() (n int) {
return len(ci.uidToClient)
}
// Range calls f for each persistent client, unless cont is false. The order is
// undefined.
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
for _, c := range ci.uidToClient {
if !f(c) {
return
}
}
}
// RangeByName is like [Index.Range] but sorts the persistent clients by name
// rangeByName is like [Index.Range] but sorts the persistent clients by name
// before iterating ensuring a predictable order.
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) {
cs := maps.Values(ci.uidToClient)
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
return strings.Compare(a.Name, b.Name)
@@ -352,10 +342,10 @@ func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
}
}
// CloseUpstreams closes upstream configurations of persistent clients.
func (ci *Index) CloseUpstreams() (err error) {
// closeUpstreams closes upstream configurations of persistent clients.
func (ci *index) closeUpstreams() (err error) {
var errs []error
ci.RangeByName(func(c *Persistent) (cont bool) {
ci.rangeByName(func(c *Persistent) (cont bool) {
err = c.CloseUpstreams()
if err != nil {
errs = append(errs, err)

View File

@@ -11,17 +11,18 @@ import (
// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*Persistent) (ci *Index) {
ci = NewIndex()
func newIDIndex(m []*Persistent) (ci *index) {
ci = newIndex()
for _, c := range m {
c.UID = MustNewUID()
ci.Add(c)
ci.add(c)
}
return ci
}
// TODO(s.chzhen): Remove.
func TestClientIndex_Find(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
@@ -109,7 +110,7 @@ func TestClientIndex_Find(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := ci.Find(id)
c, ok := ci.find(id)
require.True(t, ok)
assert.Equal(t, tc.want, c)
@@ -118,7 +119,7 @@ func TestClientIndex_Find(t *testing.T) {
}
t.Run("not_found", func(t *testing.T) {
_, ok := ci.Find(cliIPNone)
_, ok := ci.find(cliIPNone)
assert.False(t, ok)
})
}
@@ -170,11 +171,11 @@ func TestClientIndex_Clashes(t *testing.T) {
clone := tc.client.ShallowClone()
clone.UID = MustNewUID()
err := ci.Clashes(clone)
err := ci.clashes(clone)
require.Error(t, err)
ci.Delete(tc.client)
err = ci.Clashes(clone)
ci.remove(tc.client)
err = ci.clashes(clone)
require.NoError(t, err)
})
}
@@ -292,7 +293,7 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := ci.FindByIPWithoutZone(tc.ip.WithZone(""))
c := ci.findByIPWithoutZone(tc.ip.WithZone(""))
require.Equal(t, tc.want, c)
})
}
@@ -338,7 +339,7 @@ func TestClientIndex_RangeByName(t *testing.T) {
ci := newIDIndex(tc.want)
var got []*Persistent
ci.RangeByName(func(c *Persistent) (cont bool) {
ci.rangeByName(func(c *Persistent) (cont bool) {
got = append(got, c)
return true

View File

@@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -64,53 +65,107 @@ type Persistent struct {
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig
// SafeSearch handles search engine hosts rewrites.
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
// BlockedServices is the configuration of blocked services of a client. It
// must not be nil after initialization.
BlockedServices *filtering.BlockedServices
// Name of the persistent client. Must not be empty.
Name string
Tags []string
// Tags is a list of client tags that categorize the client.
Tags []string
// Upstreams is a list of custom upstream DNS servers for the client.
Upstreams []string
// IPs is a list of IP addresses that identify the client. The client must
// have at least one ID (IP, subnet, MAC, or ClientID).
IPs []netip.Addr
// Subnets identifying the client. The client must have at least one ID
// (IP, subnet, MAC, or ClientID).
//
// TODO(s.chzhen): Use netutil.Prefix.
Subnets []netip.Prefix
MACs []net.HardwareAddr
Subnets []netip.Prefix
// MACs identifying the client. The client must have at least one ID (IP,
// subnet, MAC, or ClientID).
MACs []net.HardwareAddr
// ClientIDs identifying the client. The client must have at least one ID
// (IP, subnet, MAC, or ClientID).
ClientIDs []string
// UID is the unique identifier of the persistent client.
UID UID
UpstreamsCacheSize uint32
// UpstreamsCacheSize is the cache size for custom upstreams.
UpstreamsCacheSize uint32
// UpstreamsCacheEnabled specifies whether custom upstreams are used.
UpstreamsCacheEnabled bool
UseOwnSettings bool
FilteringEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
UseOwnBlockedServices bool
IgnoreQueryLog bool
IgnoreStatistics bool
// UseOwnSettings specifies whether custom filtering settings are used.
UseOwnSettings bool
// FilteringEnabled specifies whether filtering is enabled.
FilteringEnabled bool
// SafeBrowsingEnabled specifies whether safe browsing is enabled.
SafeBrowsingEnabled bool
// ParentalEnabled specifies whether parental control is enabled.
ParentalEnabled bool
// UseOwnBlockedServices specifies whether custom services are blocked.
UseOwnBlockedServices bool
// IgnoreQueryLog specifies whether the client requests are logged.
IgnoreQueryLog bool
// IgnoreStatistics specifies whether the client requests are counted.
IgnoreStatistics bool
// SafeSearchConf is the safe search filtering configuration.
//
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
}
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
for _, t := range tags {
if !known.Has(t) {
log.Info("skipping unknown tag %q", t)
continue
}
c.Tags = append(c.Tags, t)
// validate returns an error if persistent client information contains errors.
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
switch {
case c.Name == "":
return errors.Error("empty name")
case c.IDsLen() == 0:
return errors.Error("id required")
case c.UID == UID{}:
return errors.Error("uid required")
}
conf, err := proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}
err = conf.Close()
if err != nil {
log.Error("client: closing upstream config: %s", err)
}
for _, t := range c.Tags {
if !allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
}
// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)
return nil
}
// SetIDs parses a list of strings into typed fields and returns an error if

View File

@@ -1,13 +1,16 @@
package client
import (
"net/netip"
"testing"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPersistentClient_EqualIDs(t *testing.T) {
func TestPersistent_EqualIDs(t *testing.T) {
const (
ip = "0.0.0.0"
ip1 = "1.1.1.1"
@@ -122,3 +125,69 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
})
}
}
func TestPersistent_Validate(t *testing.T) {
const (
allowedTag = "allowed_tag"
notAllowedTag = "not_allowed_tag"
)
allowedTags := container.NewMapSet(allowedTag)
testCases := []struct {
name string
cli *Persistent
wantErrMsg string
}{{
name: "success",
cli: &Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: MustNewUID(),
},
wantErrMsg: "",
}, {
name: "empty_name",
cli: &Persistent{
Name: "",
},
wantErrMsg: "empty name",
}, {
name: "no_id",
cli: &Persistent{
Name: "no_id",
},
wantErrMsg: "id required",
}, {
name: "no_uid",
cli: &Persistent{
Name: "no_uid",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
},
wantErrMsg: "uid required",
}, {
name: "not_allowed_tag",
cli: &Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: MustNewUID(),
Tags: []string{
notAllowedTag,
},
},
wantErrMsg: `invalid tag: "` + notAllowedTag + `"`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.cli.validate(allowedTags)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

289
internal/client/storage.go Normal file
View File

@@ -0,0 +1,289 @@
package client
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// Config is the client storage configuration structure.
//
// TODO(s.chzhen): Expand.
type Config struct {
// AllowedTags is a list of all allowed client tags.
AllowedTags []string
}
// Storage contains information about persistent and runtime clients.
type Storage struct {
// allowedTags is a set of all allowed tags.
allowedTags *container.MapSet[string]
// mu protects indexes of persistent and runtime clients.
mu *sync.Mutex
// index contains information about persistent clients.
index *index
// runtimeIndex contains information about runtime clients.
//
// TODO(s.chzhen): Use it.
runtimeIndex *RuntimeIndex
}
// NewStorage returns initialized client storage. conf must not be nil.
func NewStorage(conf *Config) (s *Storage) {
allowedTags := container.NewMapSet(conf.AllowedTags...)
return &Storage{
allowedTags: allowedTags,
mu: &sync.Mutex{},
index: newIndex(),
runtimeIndex: NewRuntimeIndex(),
}
}
// Add stores persistent client information or returns an error.
func (s *Storage) Add(p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "adding client: %w") }()
err = p.validate(s.allowedTags)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
s.mu.Lock()
defer s.mu.Unlock()
err = s.index.clashesUID(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
err = s.index.clashes(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
s.index.add(p)
log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.size())
return nil
}
// FindByName finds persistent client by name. And returns its shallow copy.
func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.findByName(name)
if ok {
return p.ShallowClone(), ok
}
return nil, false
}
// Find finds persistent client by string representation of the client ID, IP
// address, or MAC. And returns its shallow copy.
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.find(id)
if ok {
return p.ShallowClone(), ok
}
return nil, false
}
// FindLoose is like [Storage.Find] but it also tries to find a persistent
// client by IP address without zone. It strips the IPv6 zone index from the
// stored IP addresses before comparing, because querylog entries don't have it.
// See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.find(id)
if ok {
return p.ShallowClone(), ok
}
p = s.index.findByIPWithoutZone(ip)
if p != nil {
return p.ShallowClone(), true
}
return nil, false
}
// FindByMAC finds persistent client by MAC and returns its shallow copy.
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.findByMAC(mac)
if ok {
return p.ShallowClone(), ok
}
return nil, false
}
// RemoveByName removes persistent client information. ok is false if no such
// client exists by that name.
func (s *Storage) RemoveByName(name string) (ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok := s.index.findByName(name)
if !ok {
return false
}
if err := p.CloseUpstreams(); err != nil {
log.Error("client storage: removing client %q: %s", p.Name, err)
}
s.index.remove(p)
return true
}
// Update finds the stored persistent client by its name and updates its
// information from p.
func (s *Storage) Update(name string, p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "updating client: %w") }()
err = p.validate(s.allowedTags)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
s.mu.Lock()
defer s.mu.Unlock()
stored, ok := s.index.findByName(name)
if !ok {
return fmt.Errorf("client %q is not found", name)
}
// Client p has a newly generated UID, so replace it with the stored one.
//
// TODO(s.chzhen): Remove when frontend starts handling UIDs.
p.UID = stored.UID
err = s.index.clashes(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
s.index.remove(stored)
s.index.add(p)
return nil
}
// RangeByName calls f for each persistent client sorted by name, unless cont is
// false.
func (s *Storage) RangeByName(f func(c *Persistent) (cont bool)) {
s.mu.Lock()
defer s.mu.Unlock()
s.index.rangeByName(f)
}
// Size returns the number of persistent clients.
func (s *Storage) Size() (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.index.size()
}
// CloseUpstreams closes upstream configurations of persistent clients.
func (s *Storage) CloseUpstreams() (err error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.index.closeUpstreams()
}
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
// client exists, returns nil.
//
// TODO(s.chzhen): Use it.
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.Client(ip)
}
// AddRuntime saves the runtime client information in the storage. IP address
// of a client must be unique. rc must not be nil.
//
// TODO(s.chzhen): Use it.
func (s *Storage) AddRuntime(rc *Runtime) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Add(rc)
}
// SizeRuntime returns the number of the runtime clients.
//
// TODO(s.chzhen): Use it.
func (s *Storage) SizeRuntime() (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.Size()
}
// RangeRuntime calls f for each runtime client in an undefined order.
//
// TODO(s.chzhen): Use it.
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Range(f)
}
// DeleteRuntime removes the runtime client by ip.
//
// TODO(s.chzhen): Use it.
func (s *Storage) DeleteRuntime(ip netip.Addr) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Delete(ip)
}
// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
//
// TODO(s.chzhen): Use it.
func (s *Storage) DeleteBySource(src Source) (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.DeleteBySource(src)
}

View File

@@ -0,0 +1,481 @@
package client_test
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newStorage is a helper function that returns a client storage filled with
// persistent clients from the m. It also generates a UID for each client.
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
tb.Helper()
s = client.NewStorage(&client.Config{
AllowedTags: nil,
})
for _, c := range m {
c.UID = client.MustNewUID()
require.NoError(tb, s.Add(c))
}
return s
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestStorage_Add(t *testing.T) {
const (
existingName = "existing_name"
existingClientID = "existing_client_id"
)
var (
existingClientUID = client.MustNewUID()
existingIP = netip.MustParseAddr("1.2.3.4")
existingSubnet = netip.MustParsePrefix("1.2.3.0/24")
)
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{existingIP},
Subnets: []netip.Prefix{existingSubnet},
ClientIDs: []string{existingClientID},
UID: existingClientUID,
}
s := client.NewStorage(&client.Config{
AllowedTags: nil,
})
err := s.Add(existingClient)
require.NoError(t, err)
testCases := []struct {
name string
cli *client.Persistent
wantErrMsg string
}{{
name: "basic",
cli: &client.Persistent{
Name: "basic",
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
UID: client.MustNewUID(),
},
wantErrMsg: "",
}, {
name: "duplicate_uid",
cli: &client.Persistent{
Name: "no_uid",
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
UID: existingClientUID,
},
wantErrMsg: `adding client: another client "existing_name" uses the same uid`,
}, {
name: "duplicate_name",
cli: &client.Persistent{
Name: existingName,
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client uses the same name "existing_name"`,
}, {
name: "duplicate_ip",
cli: &client.Persistent{
Name: "duplicate_ip",
IPs: []netip.Addr{existingIP},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" uses the same IP "1.2.3.4"`,
}, {
name: "duplicate_subnet",
cli: &client.Persistent{
Name: "duplicate_subnet",
Subnets: []netip.Prefix{existingSubnet},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" ` +
`uses the same subnet "1.2.3.0/24"`,
}, {
name: "duplicate_client_id",
cli: &client.Persistent{
Name: "duplicate_client_id",
ClientIDs: []string{existingClientID},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" ` +
`uses the same ClientID "existing_client_id"`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = s.Add(tc.cli)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
func TestStorage_RemoveByName(t *testing.T) {
const (
existingName = "existing_name"
)
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
UID: client.MustNewUID(),
}
s := client.NewStorage(&client.Config{
AllowedTags: nil,
})
err := s.Add(existingClient)
require.NoError(t, err)
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliName string
}{{
name: "existing_client",
cliName: existingName,
want: assert.True,
}, {
name: "non_existing_client",
cliName: "non_existing_client",
want: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.want(t, s.RemoveByName(tc.cliName))
})
}
t.Run("duplicate_remove", func(t *testing.T) {
s = client.NewStorage(&client.Config{
AllowedTags: nil,
})
err = s.Add(existingClient)
require.NoError(t, err)
assert.True(t, s.RemoveByName(existingName))
assert.False(t, s.RemoveByName(existingName))
})
}
func TestStorage_Find(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
cliIP1 = "1.1.1.1"
cliIP2 = "2.2.2.2"
cliIPv6 = "1:2:3::4"
cliSubnet = "2.2.2.0/24"
cliSubnetIP = "2.2.2.222"
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0"
linkLocalSubnet = "fe80::/16"
)
var (
clientWithBothFams = &client.Persistent{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}
clientWithSubnet = &client.Persistent{
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}
clientWithMAC = &client.Persistent{
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}
clientWithID = &client.Persistent{
Name: "client_with_id",
ClientIDs: []string{cliID},
}
clientLinkLocal = &client.Persistent{
Name: "client_link_local",
Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)},
}
)
clients := []*client.Persistent{
clientWithBothFams,
clientWithSubnet,
clientWithMAC,
clientWithID,
clientLinkLocal,
}
s := newStorage(t, clients)
testCases := []struct {
want *client.Persistent
name string
ids []string
}{{
name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6},
want: clientWithBothFams,
}, {
name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP},
want: clientWithSubnet,
}, {
name: "mac",
ids: []string{cliMAC},
want: clientWithMAC,
}, {
name: "client_id",
ids: []string{cliID},
want: clientWithID,
}, {
name: "client_link_local_subnet",
ids: []string{linkLocalIP},
want: clientLinkLocal,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := s.Find(id)
require.True(t, ok)
assert.Equal(t, tc.want, c)
}
})
}
t.Run("not_found", func(t *testing.T) {
_, ok := s.Find(cliIPNone)
assert.False(t, ok)
})
}
func TestStorage_FindLoose(t *testing.T) {
const (
nonExistingClientID = "client_id"
)
var (
ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1")
ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2")
)
var (
clientNoZone = &client.Persistent{
Name: "client",
IPs: []netip.Addr{ip},
}
clientWithZone = &client.Persistent{
Name: "client_with_zone",
IPs: []netip.Addr{ipWithZone},
}
)
s := newStorage(
t,
[]*client.Persistent{
clientNoZone,
clientWithZone,
},
)
testCases := []struct {
ip netip.Addr
want assert.BoolAssertionFunc
wantCli *client.Persistent
name string
}{{
name: "without_zone",
ip: ip,
wantCli: clientNoZone,
want: assert.True,
}, {
name: "with_zone",
ip: ipWithZone,
wantCli: clientWithZone,
want: assert.True,
}, {
name: "zero_address",
ip: netip.Addr{},
wantCli: nil,
want: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := s.FindLoose(tc.ip.WithZone(""), nonExistingClientID)
assert.Equal(t, tc.wantCli, c)
tc.want(t, ok)
})
}
}
func TestStorage_Update(t *testing.T) {
const (
clientName = "client_name"
obstructingName = "obstructing_name"
obstructingClientID = "obstructing_client_id"
)
var (
obstructingIP = netip.MustParseAddr("1.2.3.4")
obstructingSubnet = netip.MustParsePrefix("1.2.3.0/24")
)
obstructingClient := &client.Persistent{
Name: obstructingName,
IPs: []netip.Addr{obstructingIP},
Subnets: []netip.Prefix{obstructingSubnet},
ClientIDs: []string{obstructingClientID},
}
clientToUpdate := &client.Persistent{
Name: clientName,
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
}
testCases := []struct {
name string
cli *client.Persistent
wantErrMsg string
}{{
name: "basic",
cli: &client.Persistent{
Name: "basic",
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
UID: client.MustNewUID(),
},
wantErrMsg: "",
}, {
name: "duplicate_name",
cli: &client.Persistent{
Name: obstructingName,
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client uses the same name "obstructing_name"`,
}, {
name: "duplicate_ip",
cli: &client.Persistent{
Name: "duplicate_ip",
IPs: []netip.Addr{obstructingIP},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" uses the same IP "1.2.3.4"`,
}, {
name: "duplicate_subnet",
cli: &client.Persistent{
Name: "duplicate_subnet",
Subnets: []netip.Prefix{obstructingSubnet},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" ` +
`uses the same subnet "1.2.3.0/24"`,
}, {
name: "duplicate_client_id",
cli: &client.Persistent{
Name: "duplicate_client_id",
ClientIDs: []string{obstructingClientID},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" ` +
`uses the same ClientID "obstructing_client_id"`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := newStorage(
t,
[]*client.Persistent{
clientToUpdate,
obstructingClient,
},
)
err := s.Update(clientName, tc.cli)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
func TestStorage_RangeByName(t *testing.T) {
sortedClients := []*client.Persistent{{
Name: "clientA",
ClientIDs: []string{"A"},
}, {
Name: "clientB",
ClientIDs: []string{"B"},
}, {
Name: "clientC",
ClientIDs: []string{"C"},
}, {
Name: "clientD",
ClientIDs: []string{"D"},
}, {
Name: "clientE",
ClientIDs: []string{"E"},
}}
testCases := []struct {
name string
want []*client.Persistent
}{{
name: "basic",
want: sortedClients,
}, {
name: "nil",
want: nil,
}, {
name: "one_element",
want: sortedClients[:1],
}, {
name: "two_elements",
want: sortedClients[:2],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := newStorage(t, tc.want)
var got []*client.Persistent
s.RangeByName(func(c *client.Persistent) (cont bool) {
got = append(got, c)
return true
})
assert.Equal(t, tc.want, got)
})
}
}