Pull request 2265: AG-27492-client-runtime-storage
Squashed commit of the following:
commit a164bace2e0333cf95622f34df7b0e79eac69f41
Merge: 6567cd330 184f476bd
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date: Wed Aug 21 16:14:55 2024 +0300
Merge branch 'master' into AG-27492-client-runtime-storage
commit 6567cd330ce76d4744f5eb990c09efdb5481aea9
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date: Tue Aug 20 16:45:43 2024 +0300
all: imp code
commit 243123a404bb5279a27de18391fa58a9d3e6149b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date: Thu Aug 15 19:15:54 2024 +0300
all: add tests
commit 6489996878ab6bf2ec93c359599ae29e58e938a0
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date: Mon Aug 5 15:12:05 2024 +0300
all: client runtime storage
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
)
|
||||
@@ -120,6 +121,7 @@ func (r *Runtime) Info() (cs Source, host string) {
|
||||
|
||||
// SetInfo sets a host as a client information from the cs.
|
||||
func (r *Runtime) SetInfo(cs Source, hosts []string) {
|
||||
// TODO(s.chzhen): Use contract where hosts must contain non-empty host.
|
||||
if len(hosts) == 1 && hosts[0] == "" {
|
||||
hosts = []string{}
|
||||
}
|
||||
@@ -175,3 +177,15 @@ func (r *Runtime) isEmpty() (ok bool) {
|
||||
func (r *Runtime) Addr() (ip netip.Addr) {
|
||||
return r.ip
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of the runtime client.
|
||||
func (r *Runtime) Clone() (c *Runtime) {
|
||||
return &Runtime{
|
||||
ip: r.ip,
|
||||
whois: r.whois.Clone(),
|
||||
arp: slices.Clone(r.arp),
|
||||
rdns: slices.Clone(r.rdns),
|
||||
dhcp: slices.Clone(r.dhcp),
|
||||
hostsFile: slices.Clone(r.hostsFile),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
@@ -31,8 +32,6 @@ type Storage struct {
|
||||
index *index
|
||||
|
||||
// runtimeIndex contains information about runtime clients.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
runtimeIndex *RuntimeIndex
|
||||
}
|
||||
|
||||
@@ -236,20 +235,75 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
||||
return s.runtimeIndex.Client(ip)
|
||||
}
|
||||
|
||||
// AddRuntime saves the runtime client information in the storage. IP address
|
||||
// of a client must be unique. rc must not be nil.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) AddRuntime(rc *Runtime) {
|
||||
// UpdateRuntime updates the stored runtime client with information from rc. If
|
||||
// no such client exists, saves the copy of rc in storage. rc must not be nil.
|
||||
func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.runtimeIndex.Add(rc)
|
||||
return s.updateRuntimeLocked(rc)
|
||||
}
|
||||
|
||||
// updateRuntimeLocked updates the stored runtime client with information from
|
||||
// rc. rc must not be nil. Storage.mu is expected to be locked.
|
||||
func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) {
|
||||
stored := s.runtimeIndex.Client(rc.ip)
|
||||
if stored == nil {
|
||||
s.runtimeIndex.Add(rc.Clone())
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if rc.whois != nil {
|
||||
stored.whois = rc.whois.Clone()
|
||||
}
|
||||
|
||||
if rc.arp != nil {
|
||||
stored.arp = slices.Clone(rc.arp)
|
||||
}
|
||||
|
||||
if rc.rdns != nil {
|
||||
stored.rdns = slices.Clone(rc.rdns)
|
||||
}
|
||||
|
||||
if rc.dhcp != nil {
|
||||
stored.dhcp = slices.Clone(rc.dhcp)
|
||||
}
|
||||
|
||||
if rc.hostsFile != nil {
|
||||
stored.hostsFile = slices.Clone(rc.hostsFile)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// BatchUpdateBySource updates the stored runtime clients information from the
|
||||
// specified source and returns the number of added and removed clients.
|
||||
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, rc := range s.runtimeIndex.index {
|
||||
rc.unset(src)
|
||||
}
|
||||
|
||||
for _, rc := range rcs {
|
||||
if s.updateRuntimeLocked(rc) {
|
||||
added++
|
||||
}
|
||||
}
|
||||
|
||||
for ip, rc := range s.runtimeIndex.index {
|
||||
if rc.isEmpty() {
|
||||
delete(s.runtimeIndex.index, ip)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
return added, removed
|
||||
}
|
||||
|
||||
// SizeRuntime returns the number of the runtime clients.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) SizeRuntime() (n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -258,8 +312,6 @@ func (s *Storage) SizeRuntime() (n int) {
|
||||
}
|
||||
|
||||
// RangeRuntime calls f for each runtime client in an undefined order.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -267,16 +319,6 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||
s.runtimeIndex.Range(f)
|
||||
}
|
||||
|
||||
// DeleteRuntime removes the runtime client by ip.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) DeleteRuntime(ip netip.Addr) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.runtimeIndex.Delete(ip)
|
||||
}
|
||||
|
||||
// DeleteBySource removes all runtime clients that have information only from
|
||||
// the specified source and returns the number of removed clients.
|
||||
//
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -25,9 +26,19 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
||||
require.NoError(tb, s.Add(c))
|
||||
}
|
||||
|
||||
require.Equal(tb, len(m), s.Size())
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// newRuntimeClient is a helper function that returns a new runtime client.
|
||||
func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) {
|
||||
rc = client.NewRuntime(ip)
|
||||
rc.SetInfo(source, []string{host})
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||
// error.
|
||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||
@@ -43,6 +54,9 @@ func TestStorage_Add(t *testing.T) {
|
||||
const (
|
||||
existingName = "existing_name"
|
||||
existingClientID = "existing_client_id"
|
||||
|
||||
allowedTag = "tag"
|
||||
notAllowedTag = "not_allowed_tag"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -60,7 +74,7 @@ func TestStorage_Add(t *testing.T) {
|
||||
}
|
||||
|
||||
s := client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
AllowedTags: []string{allowedTag},
|
||||
})
|
||||
err := s.Add(existingClient)
|
||||
require.NoError(t, err)
|
||||
@@ -119,6 +133,15 @@ func TestStorage_Add(t *testing.T) {
|
||||
},
|
||||
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||
`uses the same ClientID "existing_client_id"`,
|
||||
}, {
|
||||
name: "not_allowed_tag",
|
||||
cli: &client.Persistent{
|
||||
Name: "nont_allowed_tag",
|
||||
Tags: []string{notAllowedTag},
|
||||
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
|
||||
UID: client.MustNewUID(),
|
||||
},
|
||||
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -341,6 +364,127 @@ func TestStorage_FindLoose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_FindByName(t *testing.T) {
|
||||
const (
|
||||
cliIP1 = "1.1.1.1"
|
||||
cliIP2 = "2.2.2.2"
|
||||
)
|
||||
|
||||
const (
|
||||
clientExistingName = "client_existing"
|
||||
clientAnotherExistingName = "client_another_existing"
|
||||
nonExistingClientName = "client_non_existing"
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &client.Persistent{
|
||||
Name: clientExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &client.Persistent{
|
||||
Name: clientAnotherExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
s := newStorage(t, clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *client.Persistent
|
||||
name string
|
||||
clientName string
|
||||
}{{
|
||||
name: "existing",
|
||||
clientName: clientExistingName,
|
||||
want: clientExisting,
|
||||
}, {
|
||||
name: "another_existing",
|
||||
clientName: clientAnotherExistingName,
|
||||
want: clientAnotherExisting,
|
||||
}, {
|
||||
name: "non_existing",
|
||||
clientName: nonExistingClientName,
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := s.FindByName(tc.clientName)
|
||||
if tc.want == nil {
|
||||
assert.False(t, ok)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_FindByMAC(t *testing.T) {
|
||||
var (
|
||||
cliMAC = mustParseMAC("11:11:11:11:11:11")
|
||||
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
|
||||
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &client.Persistent{
|
||||
Name: "client",
|
||||
MACs: []net.HardwareAddr{cliMAC},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &client.Persistent{
|
||||
Name: "another_client",
|
||||
MACs: []net.HardwareAddr{cliAnotherMAC},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
s := newStorage(t, clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *client.Persistent
|
||||
name string
|
||||
clientMAC net.HardwareAddr
|
||||
}{{
|
||||
name: "existing",
|
||||
clientMAC: cliMAC,
|
||||
want: clientExisting,
|
||||
}, {
|
||||
name: "another_existing",
|
||||
clientMAC: cliAnotherMAC,
|
||||
want: clientAnotherExisting,
|
||||
}, {
|
||||
name: "non_existing",
|
||||
clientMAC: nonExistingClientMAC,
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := s.FindByMAC(tc.clientMAC)
|
||||
if tc.want == nil {
|
||||
assert.False(t, ok)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_Update(t *testing.T) {
|
||||
const (
|
||||
clientName = "client_name"
|
||||
@@ -479,3 +623,157 @@ func TestStorage_RangeByName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_UpdateRuntime(t *testing.T) {
|
||||
const (
|
||||
addedARP = "added_arp"
|
||||
addedSecondARP = "added_arp"
|
||||
|
||||
updatedARP = "updated_arp"
|
||||
|
||||
cliCity = "City"
|
||||
cliCountry = "Country"
|
||||
cliOrgname = "Orgname"
|
||||
)
|
||||
|
||||
var (
|
||||
ip = netip.MustParseAddr("1.1.1.1")
|
||||
ip2 = netip.MustParseAddr("2.2.2.2")
|
||||
)
|
||||
|
||||
updated := client.NewRuntime(ip)
|
||||
updated.SetInfo(client.SourceARP, []string{updatedARP})
|
||||
|
||||
info := &whois.Info{
|
||||
City: cliCity,
|
||||
Country: cliCountry,
|
||||
Orgname: cliOrgname,
|
||||
}
|
||||
updated.SetWHOIS(info)
|
||||
|
||||
s := client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
})
|
||||
|
||||
t.Run("add_arp_client", func(t *testing.T) {
|
||||
added := client.NewRuntime(ip)
|
||||
added.SetInfo(client.SourceARP, []string{addedARP})
|
||||
|
||||
require.True(t, s.UpdateRuntime(added))
|
||||
require.Equal(t, 1, s.SizeRuntime())
|
||||
|
||||
got := s.ClientRuntime(ip)
|
||||
source, host := got.Info()
|
||||
assert.Equal(t, client.SourceARP, source)
|
||||
assert.Equal(t, addedARP, host)
|
||||
})
|
||||
|
||||
t.Run("add_second_arp_client", func(t *testing.T) {
|
||||
added := client.NewRuntime(ip2)
|
||||
added.SetInfo(client.SourceARP, []string{addedSecondARP})
|
||||
|
||||
require.True(t, s.UpdateRuntime(added))
|
||||
require.Equal(t, 2, s.SizeRuntime())
|
||||
|
||||
got := s.ClientRuntime(ip2)
|
||||
source, host := got.Info()
|
||||
assert.Equal(t, client.SourceARP, source)
|
||||
assert.Equal(t, addedSecondARP, host)
|
||||
})
|
||||
|
||||
t.Run("update_first_client", func(t *testing.T) {
|
||||
require.False(t, s.UpdateRuntime(updated))
|
||||
got := s.ClientRuntime(ip)
|
||||
require.Equal(t, 2, s.SizeRuntime())
|
||||
|
||||
source, host := got.Info()
|
||||
assert.Equal(t, client.SourceARP, source)
|
||||
assert.Equal(t, updatedARP, host)
|
||||
})
|
||||
|
||||
t.Run("remove_arp_info", func(t *testing.T) {
|
||||
n := s.DeleteBySource(client.SourceARP)
|
||||
require.Equal(t, 1, n)
|
||||
require.Equal(t, 1, s.SizeRuntime())
|
||||
|
||||
got := s.ClientRuntime(ip)
|
||||
source, _ := got.Info()
|
||||
assert.Equal(t, client.SourceWHOIS, source)
|
||||
assert.Equal(t, info, got.WHOIS())
|
||||
})
|
||||
|
||||
t.Run("remove_whois_info", func(t *testing.T) {
|
||||
n := s.DeleteBySource(client.SourceWHOIS)
|
||||
require.Equal(t, 1, n)
|
||||
require.Equal(t, 0, s.SizeRuntime())
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorage_BatchUpdateBySource(t *testing.T) {
|
||||
const (
|
||||
defSrc = client.SourceARP
|
||||
|
||||
cliFirstHost1 = "host1"
|
||||
cliFirstHost2 = "host2"
|
||||
cliUpdatedHost3 = "host3"
|
||||
cliUpdatedHost4 = "host4"
|
||||
cliUpdatedHost5 = "host5"
|
||||
)
|
||||
|
||||
var (
|
||||
cliFirstIP1 = netip.MustParseAddr("1.1.1.1")
|
||||
cliFirstIP2 = netip.MustParseAddr("2.2.2.2")
|
||||
cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3")
|
||||
cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4")
|
||||
cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5")
|
||||
)
|
||||
|
||||
firstClients := []*client.Runtime{
|
||||
newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1),
|
||||
newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2),
|
||||
}
|
||||
|
||||
updatedClients := []*client.Runtime{
|
||||
newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3),
|
||||
newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4),
|
||||
newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5),
|
||||
}
|
||||
|
||||
s := client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
})
|
||||
|
||||
t.Run("populate_storage_with_first_clients", func(t *testing.T) {
|
||||
added, removed := s.BatchUpdateBySource(defSrc, firstClients)
|
||||
require.Equal(t, len(firstClients), added)
|
||||
require.Equal(t, 0, removed)
|
||||
require.Equal(t, len(firstClients), s.SizeRuntime())
|
||||
|
||||
rc := s.ClientRuntime(cliFirstIP1)
|
||||
src, host := rc.Info()
|
||||
assert.Equal(t, defSrc, src)
|
||||
assert.Equal(t, cliFirstHost1, host)
|
||||
})
|
||||
|
||||
t.Run("update_storage", func(t *testing.T) {
|
||||
added, removed := s.BatchUpdateBySource(defSrc, updatedClients)
|
||||
require.Equal(t, len(updatedClients), added)
|
||||
require.Equal(t, len(firstClients), removed)
|
||||
require.Equal(t, len(updatedClients), s.SizeRuntime())
|
||||
|
||||
rc := s.ClientRuntime(cliUpdatedIP3)
|
||||
src, host := rc.Info()
|
||||
assert.Equal(t, defSrc, src)
|
||||
assert.Equal(t, cliUpdatedHost3, host)
|
||||
|
||||
rc = s.ClientRuntime(cliFirstIP1)
|
||||
assert.Nil(t, rc)
|
||||
})
|
||||
|
||||
t.Run("remove_all", func(t *testing.T) {
|
||||
added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{})
|
||||
require.Equal(t, 0, added)
|
||||
require.Equal(t, len(updatedClients), removed)
|
||||
require.Equal(t, 0, s.SizeRuntime())
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user