all: resync with master
This commit is contained in:
@@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
@@ -38,6 +40,10 @@ func (EmptyAddrProc) Close() (_ error) { return nil }
|
||||
|
||||
// DefaultAddrProcConfig is the configuration structure for address processors.
|
||||
type DefaultAddrProcConfig struct {
|
||||
// BaseLogger is used to create loggers with custom prefixes for sources of
|
||||
// information about runtime clients. It must not be nil.
|
||||
BaseLogger *slog.Logger
|
||||
|
||||
// DialContext is used to create TCP connections to WHOIS servers.
|
||||
// DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true.
|
||||
DialContext aghnet.DialContextFunc
|
||||
@@ -147,6 +153,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
|
||||
|
||||
if c.UseRDNS {
|
||||
p.rdns = rdns.New(&rdns.Config{
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "rdns"),
|
||||
Exchanger: c.Exchanger,
|
||||
CacheSize: defaultCacheSize,
|
||||
CacheTTL: defaultIPTTL,
|
||||
@@ -154,7 +161,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
|
||||
}
|
||||
|
||||
if c.UseWHOIS {
|
||||
p.whois = newWHOIS(c.DialContext)
|
||||
p.whois = newWHOIS(c.BaseLogger.With(slogutil.KeyPrefix, "whois"), c.DialContext)
|
||||
}
|
||||
|
||||
go p.process(c.CatchPanics)
|
||||
@@ -168,7 +175,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
|
||||
|
||||
// newWHOIS returns a whois.Interface instance using the given function for
|
||||
// dialing.
|
||||
func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
|
||||
func newWHOIS(logger *slog.Logger, dialFunc aghnet.DialContextFunc) (w whois.Interface) {
|
||||
// TODO(s.chzhen): Consider making configurable.
|
||||
const (
|
||||
// defaultTimeout is the timeout for WHOIS requests.
|
||||
@@ -186,6 +193,7 @@ func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
|
||||
)
|
||||
|
||||
return whois.New(&whois.Config{
|
||||
Logger: logger,
|
||||
DialContext: dialFunc,
|
||||
ServerAddr: whois.DefaultServer,
|
||||
Port: whois.DefaultPort,
|
||||
@@ -227,9 +235,11 @@ func (p *DefaultAddrProc) process(catchPanics bool) {
|
||||
|
||||
log.Info("clients: processing addresses")
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
for ip := range p.clientIPs {
|
||||
host := p.processRDNS(ip)
|
||||
info := p.processWHOIS(ip)
|
||||
host := p.processRDNS(ctx, ip)
|
||||
info := p.processWHOIS(ctx, ip)
|
||||
|
||||
p.addrUpdater.UpdateAddress(ip, host, info)
|
||||
}
|
||||
@@ -239,7 +249,7 @@ func (p *DefaultAddrProc) process(catchPanics bool) {
|
||||
|
||||
// processRDNS resolves the clients' IP addresses using reverse DNS. host is
|
||||
// empty if there were errors or if the information hasn't changed.
|
||||
func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) {
|
||||
func (p *DefaultAddrProc) processRDNS(ctx context.Context, ip netip.Addr) (host string) {
|
||||
start := time.Now()
|
||||
log.Debug("clients: processing %s with rdns", ip)
|
||||
defer func() {
|
||||
@@ -251,7 +261,7 @@ func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) {
|
||||
return
|
||||
}
|
||||
|
||||
host, changed := p.rdns.Process(ip)
|
||||
host, changed := p.rdns.Process(ctx, ip)
|
||||
if !changed {
|
||||
host = ""
|
||||
}
|
||||
@@ -268,7 +278,7 @@ func (p *DefaultAddrProc) shouldResolve(ip netip.Addr) (ok bool) {
|
||||
// processWHOIS looks up the information about clients' IP addresses in the
|
||||
// WHOIS databases. info is nil if there were errors or if the information
|
||||
// hasn't changed.
|
||||
func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) {
|
||||
func (p *DefaultAddrProc) processWHOIS(ctx context.Context, ip netip.Addr) (info *whois.Info) {
|
||||
start := time.Now()
|
||||
log.Debug("clients: processing %s with whois", ip)
|
||||
defer func() {
|
||||
@@ -277,7 +287,7 @@ func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) {
|
||||
|
||||
// TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the
|
||||
// context.
|
||||
info, changed := p.whois.Process(context.Background(), ip)
|
||||
info, changed := p.whois.Process(ctx, ip)
|
||||
if !changed {
|
||||
info = nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakenet"
|
||||
@@ -99,6 +100,7 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
|
||||
updInfoCh := make(chan *whois.Info, 1)
|
||||
|
||||
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
@@ -208,6 +210,7 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
|
||||
updInfoCh := make(chan *whois.Info, 1)
|
||||
|
||||
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
|
||||
return whoisConn, nil
|
||||
},
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
)
|
||||
@@ -118,8 +119,9 @@ func (r *Runtime) Info() (cs Source, host string) {
|
||||
return cs, info[0]
|
||||
}
|
||||
|
||||
// SetInfo sets a host as a client information from the cs.
|
||||
func (r *Runtime) SetInfo(cs Source, hosts []string) {
|
||||
// setInfo sets a host as a client information from the cs.
|
||||
func (r *Runtime) setInfo(cs Source, hosts []string) {
|
||||
// TODO(s.chzhen): Use contract where hosts must contain non-empty host.
|
||||
if len(hosts) == 1 && hosts[0] == "" {
|
||||
hosts = []string{}
|
||||
}
|
||||
@@ -136,13 +138,13 @@ func (r *Runtime) SetInfo(cs Source, hosts []string) {
|
||||
}
|
||||
}
|
||||
|
||||
// WHOIS returns a WHOIS client information.
|
||||
// WHOIS returns a copy of WHOIS client information.
|
||||
func (r *Runtime) WHOIS() (info *whois.Info) {
|
||||
return r.whois
|
||||
return r.whois.Clone()
|
||||
}
|
||||
|
||||
// SetWHOIS sets a WHOIS client information. info must be non-nil.
|
||||
func (r *Runtime) SetWHOIS(info *whois.Info) {
|
||||
// setWHOIS sets a WHOIS client information. info must be non-nil.
|
||||
func (r *Runtime) setWHOIS(info *whois.Info) {
|
||||
r.whois = info
|
||||
}
|
||||
|
||||
@@ -175,3 +177,15 @@ func (r *Runtime) isEmpty() (ok bool) {
|
||||
func (r *Runtime) Addr() (ip netip.Addr) {
|
||||
return r.ip
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the runtime client.
|
||||
func (r *Runtime) clone() (c *Runtime) {
|
||||
return &Runtime{
|
||||
ip: r.ip,
|
||||
whois: r.whois.Clone(),
|
||||
arp: slices.Clone(r.arp),
|
||||
rdns: slices.Clone(r.rdns),
|
||||
dhcp: slices.Clone(r.dhcp),
|
||||
hostsFile: slices.Clone(r.hostsFile),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -136,7 +135,8 @@ type Persistent struct {
|
||||
}
|
||||
|
||||
// validate returns an error if persistent client information contains errors.
|
||||
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
|
||||
// allTags must be sorted.
|
||||
func (c *Persistent) validate(allTags []string) (err error) {
|
||||
switch {
|
||||
case c.Name == "":
|
||||
return errors.Error("empty name")
|
||||
@@ -157,7 +157,8 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
|
||||
}
|
||||
|
||||
for _, t := range c.Tags {
|
||||
if !allTags.Has(t) {
|
||||
_, ok := slices.BinarySearch(allTags, t)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid tag: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -125,69 +122,3 @@ func TestPersistent_EqualIDs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistent_Validate(t *testing.T) {
|
||||
const (
|
||||
allowedTag = "allowed_tag"
|
||||
notAllowedTag = "not_allowed_tag"
|
||||
)
|
||||
|
||||
allowedTags := container.NewMapSet(allowedTag)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
cli *Persistent
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "success",
|
||||
cli: &Persistent{
|
||||
Name: "basic",
|
||||
IPs: []netip.Addr{
|
||||
netip.MustParseAddr("1.2.3.4"),
|
||||
},
|
||||
UID: MustNewUID(),
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "empty_name",
|
||||
cli: &Persistent{
|
||||
Name: "",
|
||||
},
|
||||
wantErrMsg: "empty name",
|
||||
}, {
|
||||
name: "no_id",
|
||||
cli: &Persistent{
|
||||
Name: "no_id",
|
||||
},
|
||||
wantErrMsg: "id required",
|
||||
}, {
|
||||
name: "no_uid",
|
||||
cli: &Persistent{
|
||||
Name: "no_uid",
|
||||
IPs: []netip.Addr{
|
||||
netip.MustParseAddr("1.2.3.4"),
|
||||
},
|
||||
},
|
||||
wantErrMsg: "uid required",
|
||||
}, {
|
||||
name: "not_allowed_tag",
|
||||
cli: &Persistent{
|
||||
Name: "basic",
|
||||
IPs: []netip.Addr{
|
||||
netip.MustParseAddr("1.2.3.4"),
|
||||
},
|
||||
UID: MustNewUID(),
|
||||
Tags: []string{
|
||||
notAllowedTag,
|
||||
},
|
||||
},
|
||||
wantErrMsg: `invalid tag: "` + notAllowedTag + `"`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.cli.validate(allowedTags)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,39 +2,34 @@ package client
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// RuntimeIndex stores information about runtime clients.
|
||||
type RuntimeIndex struct {
|
||||
// runtimeIndex stores information about runtime clients.
|
||||
type runtimeIndex struct {
|
||||
// index maps IP address to runtime client.
|
||||
index map[netip.Addr]*Runtime
|
||||
}
|
||||
|
||||
// NewRuntimeIndex returns initialized runtime index.
|
||||
func NewRuntimeIndex() (ri *RuntimeIndex) {
|
||||
return &RuntimeIndex{
|
||||
// newRuntimeIndex returns initialized runtime index.
|
||||
func newRuntimeIndex() (ri *runtimeIndex) {
|
||||
return &runtimeIndex{
|
||||
index: map[netip.Addr]*Runtime{},
|
||||
}
|
||||
}
|
||||
|
||||
// Client returns the saved runtime client by ip. If no such client exists,
|
||||
// client returns the saved runtime client by ip. If no such client exists,
|
||||
// returns nil.
|
||||
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) {
|
||||
func (ri *runtimeIndex) client(ip netip.Addr) (rc *Runtime) {
|
||||
return ri.index[ip]
|
||||
}
|
||||
|
||||
// Add saves the runtime client in the index. IP address of a client must be
|
||||
// add saves the runtime client in the index. IP address of a client must be
|
||||
// unique. See [Runtime.Client]. rc must not be nil.
|
||||
func (ri *RuntimeIndex) Add(rc *Runtime) {
|
||||
func (ri *runtimeIndex) add(rc *Runtime) {
|
||||
ip := rc.Addr()
|
||||
ri.index[ip] = rc
|
||||
}
|
||||
|
||||
// Size returns the number of the runtime clients.
|
||||
func (ri *RuntimeIndex) Size() (n int) {
|
||||
return len(ri.index)
|
||||
}
|
||||
|
||||
// Range calls f for each runtime client in an undefined order.
|
||||
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
|
||||
// rangeClients calls f for each runtime client in an undefined order.
|
||||
func (ri *runtimeIndex) rangeClients(f func(rc *Runtime) (cont bool)) {
|
||||
for _, rc := range ri.index {
|
||||
if !f(rc) {
|
||||
return
|
||||
@@ -42,17 +37,31 @@ func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes the runtime client by ip.
|
||||
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
|
||||
delete(ri.index, ip)
|
||||
// setInfo sets the client information from cs for runtime client stored by ip.
|
||||
// If no such client exists, it creates one.
|
||||
func (ri *runtimeIndex) setInfo(ip netip.Addr, cs Source, hosts []string) (rc *Runtime) {
|
||||
rc = ri.index[ip]
|
||||
if rc == nil {
|
||||
rc = NewRuntime(ip)
|
||||
ri.add(rc)
|
||||
}
|
||||
|
||||
rc.setInfo(cs, hosts)
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
// DeleteBySource removes all runtime clients that have information only from
|
||||
// the specified source and returns the number of removed clients.
|
||||
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) {
|
||||
for ip, rc := range ri.index {
|
||||
// clearSource removes information from the specified source from all clients.
|
||||
func (ri *runtimeIndex) clearSource(src Source) {
|
||||
for _, rc := range ri.index {
|
||||
rc.unset(src)
|
||||
}
|
||||
}
|
||||
|
||||
// removeEmpty removes empty runtime clients and returns the number of removed
|
||||
// clients.
|
||||
func (ri *runtimeIndex) removeEmpty() (n int) {
|
||||
for ip, rc := range ri.index {
|
||||
if rc.isEmpty() {
|
||||
delete(ri.index, ip)
|
||||
n++
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRuntimeIndex(t *testing.T) {
|
||||
const cliSrc = client.SourceARP
|
||||
|
||||
var (
|
||||
ip1 = netip.MustParseAddr("1.1.1.1")
|
||||
ip2 = netip.MustParseAddr("2.2.2.2")
|
||||
ip3 = netip.MustParseAddr("3.3.3.3")
|
||||
)
|
||||
|
||||
ri := client.NewRuntimeIndex()
|
||||
currentSize := 0
|
||||
|
||||
testCases := []struct {
|
||||
ip netip.Addr
|
||||
name string
|
||||
hosts []string
|
||||
src client.Source
|
||||
}{{
|
||||
src: cliSrc,
|
||||
ip: ip1,
|
||||
name: "1",
|
||||
hosts: []string{"host1"},
|
||||
}, {
|
||||
src: cliSrc,
|
||||
ip: ip2,
|
||||
name: "2",
|
||||
hosts: []string{"host2"},
|
||||
}, {
|
||||
src: cliSrc,
|
||||
ip: ip3,
|
||||
name: "3",
|
||||
hosts: []string{"host3"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rc := client.NewRuntime(tc.ip)
|
||||
rc.SetInfo(tc.src, tc.hosts)
|
||||
|
||||
ri.Add(rc)
|
||||
currentSize++
|
||||
|
||||
got := ri.Client(tc.ip)
|
||||
assert.Equal(t, rc, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("size", func(t *testing.T) {
|
||||
assert.Equal(t, currentSize, ri.Size())
|
||||
})
|
||||
|
||||
t.Run("range", func(t *testing.T) {
|
||||
s := 0
|
||||
|
||||
ri.Range(func(rc *client.Runtime) (cont bool) {
|
||||
s++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, currentSize, s)
|
||||
})
|
||||
|
||||
t.Run("delete", func(t *testing.T) {
|
||||
ri.Delete(ip1)
|
||||
currentSize--
|
||||
|
||||
assert.Equal(t, currentSize, ri.Size())
|
||||
})
|
||||
|
||||
t.Run("delete_by_src", func(t *testing.T) {
|
||||
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
|
||||
assert.Equal(t, 0, ri.Size())
|
||||
})
|
||||
}
|
||||
@@ -1,29 +1,113 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Config is the client storage configuration structure.
|
||||
//
|
||||
// TODO(s.chzhen): Expand.
|
||||
type Config struct {
|
||||
// AllowedTags is a list of all allowed client tags.
|
||||
AllowedTags []string
|
||||
// allowedTags is the list of available client tags.
|
||||
var allowedTags = []string{
|
||||
"device_audio",
|
||||
"device_camera",
|
||||
"device_gameconsole",
|
||||
"device_laptop",
|
||||
"device_nas", // Network-attached Storage
|
||||
"device_other",
|
||||
"device_pc",
|
||||
"device_phone",
|
||||
"device_printer",
|
||||
"device_securityalarm",
|
||||
"device_tablet",
|
||||
"device_tv",
|
||||
|
||||
"os_android",
|
||||
"os_ios",
|
||||
"os_linux",
|
||||
"os_macos",
|
||||
"os_other",
|
||||
"os_windows",
|
||||
|
||||
"user_admin",
|
||||
"user_child",
|
||||
"user_regular",
|
||||
}
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data the [Storage] needs.
|
||||
type DHCP interface {
|
||||
// Leases returns all the DHCP leases.
|
||||
Leases() (leases []*dhcpsvc.Lease)
|
||||
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
// address. host will be empty if there is no such client, due to an
|
||||
// assumption that a DHCP client must always have a hostname.
|
||||
HostByIP(ip netip.Addr) (host string)
|
||||
|
||||
// MACByIP returns the MAC address for the given IP address leased. It
|
||||
// returns nil if there is no such client, due to an assumption that a DHCP
|
||||
// client must always have a MAC address.
|
||||
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
||||
}
|
||||
|
||||
// EmptyDHCP is the empty [DHCP] implementation that does nothing.
|
||||
type EmptyDHCP struct{}
|
||||
|
||||
// type check
|
||||
var _ DHCP = EmptyDHCP{}
|
||||
|
||||
// Leases implements the [DHCP] interface for emptyDHCP.
|
||||
func (EmptyDHCP) Leases() (leases []*dhcpsvc.Lease) { return nil }
|
||||
|
||||
// HostByIP implements the [DHCP] interface for emptyDHCP.
|
||||
func (EmptyDHCP) HostByIP(_ netip.Addr) (host string) { return "" }
|
||||
|
||||
// MACByIP implements the [DHCP] interface for emptyDHCP.
|
||||
func (EmptyDHCP) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
|
||||
|
||||
// HostsContainer is an interface for receiving updates to the system hosts
|
||||
// file.
|
||||
type HostsContainer interface {
|
||||
Upd() (updates <-chan *hostsfile.DefaultStorage)
|
||||
}
|
||||
|
||||
// StorageConfig is the client storage configuration structure.
|
||||
type StorageConfig struct {
|
||||
// DHCP is used to match IPs against MACs of persistent clients and update
|
||||
// [SourceDHCP] runtime client information. It must not be nil.
|
||||
DHCP DHCP
|
||||
|
||||
// EtcHosts is used to update [SourceHostsFile] runtime client information.
|
||||
EtcHosts HostsContainer
|
||||
|
||||
// ARPDB is used to update [SourceARP] runtime client information.
|
||||
ARPDB arpdb.Interface
|
||||
|
||||
// InitialClients is a list of persistent clients parsed from the
|
||||
// configuration file. Each client must not be nil.
|
||||
InitialClients []*Persistent
|
||||
|
||||
// ARPClientsUpdatePeriod defines how often [SourceARP] runtime client
|
||||
// information is updated.
|
||||
ARPClientsUpdatePeriod time.Duration
|
||||
|
||||
// RuntimeSourceDHCP specifies whether to update [SourceDHCP] information
|
||||
// of runtime clients.
|
||||
RuntimeSourceDHCP bool
|
||||
}
|
||||
|
||||
// Storage contains information about persistent and runtime clients.
|
||||
type Storage struct {
|
||||
// allowedTags is a set of all allowed tags.
|
||||
allowedTags *container.MapSet[string]
|
||||
|
||||
// mu protects indexes of persistent and runtime clients.
|
||||
mu *sync.Mutex
|
||||
|
||||
@@ -31,21 +115,250 @@ type Storage struct {
|
||||
index *index
|
||||
|
||||
// runtimeIndex contains information about runtime clients.
|
||||
runtimeIndex *runtimeIndex
|
||||
|
||||
// dhcp is used to update [SourceDHCP] runtime client information.
|
||||
dhcp DHCP
|
||||
|
||||
// etcHosts is used to update [SourceHostsFile] runtime client information.
|
||||
etcHosts HostsContainer
|
||||
|
||||
// arpDB is used to update [SourceARP] runtime client information.
|
||||
arpDB arpdb.Interface
|
||||
|
||||
// done is the shutdown signaling channel.
|
||||
done chan struct{}
|
||||
|
||||
// allowedTags is a sorted list of all allowed tags. It must not be
|
||||
// modified after initialization.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
runtimeIndex *RuntimeIndex
|
||||
// TODO(s.chzhen): Use custom type.
|
||||
allowedTags []string
|
||||
|
||||
// arpClientsUpdatePeriod defines how often [SourceARP] runtime client
|
||||
// information is updated. It must be greater than zero.
|
||||
arpClientsUpdatePeriod time.Duration
|
||||
|
||||
// runtimeSourceDHCP specifies whether to update [SourceDHCP] information
|
||||
// of runtime clients.
|
||||
runtimeSourceDHCP bool
|
||||
}
|
||||
|
||||
// NewStorage returns initialized client storage. conf must not be nil.
|
||||
func NewStorage(conf *Config) (s *Storage) {
|
||||
allowedTags := container.NewMapSet(conf.AllowedTags...)
|
||||
func NewStorage(conf *StorageConfig) (s *Storage, err error) {
|
||||
tags := slices.Clone(allowedTags)
|
||||
slices.Sort(tags)
|
||||
|
||||
return &Storage{
|
||||
allowedTags: allowedTags,
|
||||
mu: &sync.Mutex{},
|
||||
index: newIndex(),
|
||||
runtimeIndex: NewRuntimeIndex(),
|
||||
s = &Storage{
|
||||
allowedTags: tags,
|
||||
mu: &sync.Mutex{},
|
||||
index: newIndex(),
|
||||
runtimeIndex: newRuntimeIndex(),
|
||||
dhcp: conf.DHCP,
|
||||
etcHosts: conf.EtcHosts,
|
||||
arpDB: conf.ARPDB,
|
||||
done: make(chan struct{}),
|
||||
arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod,
|
||||
runtimeSourceDHCP: conf.RuntimeSourceDHCP,
|
||||
}
|
||||
|
||||
for i, p := range conf.InitialClients {
|
||||
err = s.Add(p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding client %q at index %d: %w", p.Name, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
s.ReloadARP()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Start starts the goroutines for updating the runtime client information.
|
||||
//
|
||||
// TODO(s.chzhen): Pass context.
|
||||
func (s *Storage) Start(_ context.Context) (err error) {
|
||||
go s.periodicARPUpdate()
|
||||
go s.handleHostsUpdates()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the client storage.
|
||||
//
|
||||
// TODO(s.chzhen): Pass context.
|
||||
func (s *Storage) Shutdown(_ context.Context) (err error) {
|
||||
close(s.done)
|
||||
|
||||
return s.closeUpstreams()
|
||||
}
|
||||
|
||||
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
|
||||
// intended to be used as a goroutine.
|
||||
func (s *Storage) periodicARPUpdate() {
|
||||
defer log.OnPanic("storage")
|
||||
|
||||
t := time.NewTicker(s.arpClientsUpdatePeriod)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
s.ReloadARP()
|
||||
case <-s.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadARP reloads runtime clients from ARP, if configured.
|
||||
func (s *Storage) ReloadARP() {
|
||||
if s.arpDB != nil {
|
||||
s.addFromSystemARP()
|
||||
}
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
// command.
|
||||
func (s *Storage) addFromSystemARP() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.arpDB.Refresh(); err != nil {
|
||||
s.arpDB = arpdb.Empty{}
|
||||
log.Error("refreshing arp container: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ns := s.arpDB.Neighbors()
|
||||
if len(ns) == 0 {
|
||||
log.Debug("refreshing arp container: the update is empty")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
src := SourceARP
|
||||
s.runtimeIndex.clearSource(src)
|
||||
|
||||
for _, n := range ns {
|
||||
s.runtimeIndex.setInfo(n.IP, src, []string{n.Name})
|
||||
}
|
||||
|
||||
removed := s.runtimeIndex.removeEmpty()
|
||||
|
||||
log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", len(ns), removed)
|
||||
}
|
||||
|
||||
// handleHostsUpdates receives the updates from the hosts container and adds
|
||||
// them to the clients storage. It is intended to be used as a goroutine.
|
||||
func (s *Storage) handleHostsUpdates() {
|
||||
if s.etcHosts == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer log.OnPanic("storage")
|
||||
|
||||
for {
|
||||
select {
|
||||
case upd, ok := <-s.etcHosts.Upd():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
s.addFromHostsFile(upd)
|
||||
case <-s.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
src := SourceHostsFile
|
||||
s.runtimeIndex.clearSource(src)
|
||||
|
||||
added := 0
|
||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||
// Only the first name of the first record is considered a canonical
|
||||
// hostname for the IP address.
|
||||
//
|
||||
// TODO(e.burkov): Consider using all the names from all the records.
|
||||
s.runtimeIndex.setInfo(addr, src, []string{names[0]})
|
||||
added++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
removed := s.runtimeIndex.removeEmpty()
|
||||
log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ AddressUpdater = (*Storage)(nil)
|
||||
|
||||
// UpdateAddress implements the [AddressUpdater] interface for *Storage
|
||||
func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
||||
// Common fast path optimization.
|
||||
if host == "" && info == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if host != "" {
|
||||
s.runtimeIndex.setInfo(ip, SourceRDNS, []string{host})
|
||||
}
|
||||
|
||||
if info != nil {
|
||||
s.setWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateDHCP updates [SourceDHCP] runtime client information.
|
||||
func (s *Storage) UpdateDHCP() {
|
||||
if s.dhcp == nil || !s.runtimeSourceDHCP {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
src := SourceDHCP
|
||||
s.runtimeIndex.clearSource(src)
|
||||
|
||||
added := 0
|
||||
for _, l := range s.dhcp.Leases() {
|
||||
s.runtimeIndex.setInfo(l.IP, src, []string{l.Hostname})
|
||||
added++
|
||||
}
|
||||
|
||||
removed := s.runtimeIndex.removeEmpty()
|
||||
log.Debug("storage: added %d, removed %d client aliases from dhcp", added, removed)
|
||||
}
|
||||
|
||||
// setWHOISInfo sets the WHOIS information for a runtime client.
|
||||
func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
_, ok := s.index.findByIP(ip)
|
||||
if ok {
|
||||
log.Debug("storage: client for %s is already created, ignore whois info", ip)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rc := s.runtimeIndex.client(ip)
|
||||
if rc == nil {
|
||||
rc = NewRuntime(ip)
|
||||
s.runtimeIndex.add(rc)
|
||||
}
|
||||
|
||||
rc.setWHOIS(wi)
|
||||
|
||||
log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
}
|
||||
|
||||
// Add stores persistent client information or returns an error.
|
||||
@@ -95,6 +408,9 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
|
||||
|
||||
// Find finds persistent client by string representation of the client ID, IP
|
||||
// address, or MAC. And returns its shallow copy.
|
||||
//
|
||||
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
|
||||
// the parsed IP address, if any.
|
||||
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -104,6 +420,16 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
||||
return p.ShallowClone(), ok
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
foundMAC := s.dhcp.MACByIP(ip)
|
||||
if foundMAC != nil {
|
||||
return s.FindByMAC(foundMAC)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -131,11 +457,9 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindByMAC finds persistent client by MAC and returns its shallow copy.
|
||||
// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu
|
||||
// is expected to be locked.
|
||||
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
p, ok = s.index.findByMAC(mac)
|
||||
if ok {
|
||||
return p.ShallowClone(), ok
|
||||
@@ -217,8 +541,8 @@ func (s *Storage) Size() (n int) {
|
||||
return s.index.size()
|
||||
}
|
||||
|
||||
// CloseUpstreams closes upstream configurations of persistent clients.
|
||||
func (s *Storage) CloseUpstreams() (err error) {
|
||||
// closeUpstreams closes upstream configurations of persistent clients.
|
||||
func (s *Storage) closeUpstreams() (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -227,63 +551,39 @@ func (s *Storage) CloseUpstreams() (err error) {
|
||||
|
||||
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
|
||||
// client exists, returns nil.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return s.runtimeIndex.Client(ip)
|
||||
}
|
||||
rc = s.runtimeIndex.client(ip)
|
||||
if rc != nil {
|
||||
return rc.clone()
|
||||
}
|
||||
|
||||
// AddRuntime saves the runtime client information in the storage. IP address
|
||||
// of a client must be unique. rc must not be nil.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) AddRuntime(rc *Runtime) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.runtimeSourceDHCP {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.runtimeIndex.Add(rc)
|
||||
}
|
||||
host := s.dhcp.HostByIP(ip)
|
||||
if host == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SizeRuntime returns the number of the runtime clients.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) SizeRuntime() (n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
rc = s.runtimeIndex.setInfo(ip, SourceDHCP, []string{host})
|
||||
|
||||
return s.runtimeIndex.Size()
|
||||
return rc.clone()
|
||||
}
|
||||
|
||||
// RangeRuntime calls f for each runtime client in an undefined order.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.runtimeIndex.Range(f)
|
||||
s.runtimeIndex.rangeClients(f)
|
||||
}
|
||||
|
||||
// DeleteRuntime removes the runtime client by ip.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) DeleteRuntime(ip netip.Addr) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.runtimeIndex.Delete(ip)
|
||||
}
|
||||
|
||||
// DeleteBySource removes all runtime clients that have information only from
|
||||
// the specified source and returns the number of removed clients.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) DeleteBySource(src Source) (n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return s.runtimeIndex.DeleteBySource(src)
|
||||
// AllowedTags returns the list of available client tags. tags must not be
|
||||
// modified.
|
||||
func (s *Storage) AllowedTags() (tags []string) {
|
||||
return s.allowedTags
|
||||
}
|
||||
|
||||
@@ -3,28 +3,521 @@ package client_test
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testHostsContainer is a mock implementation of the [client.HostsContainer]
|
||||
// interface.
|
||||
type testHostsContainer struct {
|
||||
onUpd func() (updates <-chan *hostsfile.DefaultStorage)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ client.HostsContainer = (*testHostsContainer)(nil)
|
||||
|
||||
// Upd implements the [client.HostsContainer] interface for *testHostsContainer.
|
||||
func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) {
|
||||
return c.onUpd()
|
||||
}
|
||||
|
||||
// Interface stores and refreshes the network neighborhood reported by ARP
|
||||
// (Address Resolution Protocol).
|
||||
type Interface interface {
|
||||
// Refresh updates the stored data. It must be safe for concurrent use.
|
||||
Refresh() (err error)
|
||||
|
||||
// Neighbors returnes the last set of data reported by ARP. Both the method
|
||||
// and it's result must be safe for concurrent use.
|
||||
Neighbors() (ns []arpdb.Neighbor)
|
||||
}
|
||||
|
||||
// testARPDB is a mock implementation of the [arpdb.Interface].
|
||||
type testARPDB struct {
|
||||
onRefresh func() (err error)
|
||||
onNeighbors func() (ns []arpdb.Neighbor)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ arpdb.Interface = (*testARPDB)(nil)
|
||||
|
||||
// Refresh implements the [arpdb.Interface] interface for *testARP.
|
||||
func (c *testARPDB) Refresh() (err error) {
|
||||
return c.onRefresh()
|
||||
}
|
||||
|
||||
// Neighbors implements the [arpdb.Interface] interface for *testARP.
|
||||
func (c *testARPDB) Neighbors() (ns []arpdb.Neighbor) {
|
||||
return c.onNeighbors()
|
||||
}
|
||||
|
||||
// testDHCP is a mock implementation of the [client.DHCP].
|
||||
type testDHCP struct {
|
||||
OnLeases func() (leases []*dhcpsvc.Lease)
|
||||
OnHostBy func(ip netip.Addr) (host string)
|
||||
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ client.DHCP = (*testDHCP)(nil)
|
||||
|
||||
// Lease implements the [client.DHCP] interface for *testDHCP.
|
||||
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
|
||||
|
||||
// HostByIP implements the [client.DHCP] interface for *testDHCP.
|
||||
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
|
||||
|
||||
// MACByIP implements the [client.DHCP] interface for *testDHCP.
|
||||
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
|
||||
|
||||
// compareRuntimeInfo is a helper function that returns true if the runtime
|
||||
// client has provided info.
|
||||
func compareRuntimeInfo(rc *client.Runtime, src client.Source, host string) (ok bool) {
|
||||
s, h := rc.Info()
|
||||
if s != src {
|
||||
return false
|
||||
} else if h != host {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TestStorage_Add_hostsfile(t *testing.T) {
|
||||
var (
|
||||
cliIP1 = netip.MustParseAddr("1.1.1.1")
|
||||
cliName1 = "client_one"
|
||||
|
||||
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||
cliName2 = "client_two"
|
||||
)
|
||||
|
||||
hostCh := make(chan *hostsfile.DefaultStorage)
|
||||
h := &testHostsContainer{
|
||||
onUpd: func() (updates <-chan *hostsfile.DefaultStorage) { return hostCh },
|
||||
}
|
||||
|
||||
storage, err := client.NewStorage(&client.StorageConfig{
|
||||
DHCP: client.EmptyDHCP{},
|
||||
EtcHosts: h,
|
||||
ARPClientsUpdatePeriod: testTimeout / 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = storage.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
t.Run("add_hosts", func(t *testing.T) {
|
||||
var s *hostsfile.DefaultStorage
|
||||
s, err = hostsfile.NewDefaultStorage()
|
||||
require.NoError(t, err)
|
||||
|
||||
s.Add(&hostsfile.Record{
|
||||
Addr: cliIP1,
|
||||
Names: []string{cliName1},
|
||||
})
|
||||
|
||||
testutil.RequireSend(t, hostCh, s, testTimeout)
|
||||
|
||||
require.Eventually(t, func() (ok bool) {
|
||||
cli1 := storage.ClientRuntime(cliIP1)
|
||||
if cli1 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
assert.True(t, compareRuntimeInfo(cli1, client.SourceHostsFile, cliName1))
|
||||
|
||||
return true
|
||||
}, testTimeout, testTimeout/10)
|
||||
})
|
||||
|
||||
t.Run("update_hosts", func(t *testing.T) {
|
||||
var s *hostsfile.DefaultStorage
|
||||
s, err = hostsfile.NewDefaultStorage()
|
||||
require.NoError(t, err)
|
||||
|
||||
s.Add(&hostsfile.Record{
|
||||
Addr: cliIP2,
|
||||
Names: []string{cliName2},
|
||||
})
|
||||
|
||||
testutil.RequireSend(t, hostCh, s, testTimeout)
|
||||
|
||||
require.Eventually(t, func() (ok bool) {
|
||||
cli2 := storage.ClientRuntime(cliIP2)
|
||||
if cli2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
assert.True(t, compareRuntimeInfo(cli2, client.SourceHostsFile, cliName2))
|
||||
|
||||
cli1 := storage.ClientRuntime(cliIP1)
|
||||
require.Nil(t, cli1)
|
||||
|
||||
return true
|
||||
}, testTimeout, testTimeout/10)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorage_Add_arp(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
neighbors []arpdb.Neighbor
|
||||
|
||||
cliIP1 = netip.MustParseAddr("1.1.1.1")
|
||||
cliName1 = "client_one"
|
||||
|
||||
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||
cliName2 = "client_two"
|
||||
)
|
||||
|
||||
a := &testARPDB{
|
||||
onRefresh: func() (err error) { return nil },
|
||||
onNeighbors: func() (ns []arpdb.Neighbor) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
return neighbors
|
||||
},
|
||||
}
|
||||
|
||||
storage, err := client.NewStorage(&client.StorageConfig{
|
||||
DHCP: client.EmptyDHCP{},
|
||||
ARPDB: a,
|
||||
ARPClientsUpdatePeriod: testTimeout / 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = storage.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
t.Run("add_hosts", func(t *testing.T) {
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
neighbors = []arpdb.Neighbor{{
|
||||
Name: cliName1,
|
||||
IP: cliIP1,
|
||||
}}
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() (ok bool) {
|
||||
cli1 := storage.ClientRuntime(cliIP1)
|
||||
if cli1 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
assert.True(t, compareRuntimeInfo(cli1, client.SourceARP, cliName1))
|
||||
|
||||
return true
|
||||
}, testTimeout, testTimeout/10)
|
||||
})
|
||||
|
||||
t.Run("update_hosts", func(t *testing.T) {
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
neighbors = []arpdb.Neighbor{{
|
||||
Name: cliName2,
|
||||
IP: cliIP2,
|
||||
}}
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() (ok bool) {
|
||||
cli2 := storage.ClientRuntime(cliIP2)
|
||||
if cli2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
assert.True(t, compareRuntimeInfo(cli2, client.SourceARP, cliName2))
|
||||
|
||||
cli1 := storage.ClientRuntime(cliIP1)
|
||||
require.Nil(t, cli1)
|
||||
|
||||
return true
|
||||
}, testTimeout, testTimeout/10)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorage_Add_whois(t *testing.T) {
|
||||
var (
|
||||
cliIP1 = netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||
cliName2 = "client_two"
|
||||
|
||||
cliIP3 = netip.MustParseAddr("3.3.3.3")
|
||||
cliName3 = "client_three"
|
||||
)
|
||||
|
||||
storage, err := client.NewStorage(&client.StorageConfig{
|
||||
DHCP: client.EmptyDHCP{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
whois := &whois.Info{
|
||||
Country: "AU",
|
||||
Orgname: "Example Org",
|
||||
}
|
||||
|
||||
t.Run("new_client", func(t *testing.T) {
|
||||
storage.UpdateAddress(cliIP1, "", whois)
|
||||
cli1 := storage.ClientRuntime(cliIP1)
|
||||
require.NotNil(t, cli1)
|
||||
|
||||
assert.Equal(t, whois, cli1.WHOIS())
|
||||
})
|
||||
|
||||
t.Run("existing_runtime_client", func(t *testing.T) {
|
||||
storage.UpdateAddress(cliIP2, cliName2, nil)
|
||||
storage.UpdateAddress(cliIP2, "", whois)
|
||||
|
||||
cli2 := storage.ClientRuntime(cliIP2)
|
||||
require.NotNil(t, cli2)
|
||||
|
||||
assert.True(t, compareRuntimeInfo(cli2, client.SourceRDNS, cliName2))
|
||||
|
||||
assert.Equal(t, whois, cli2.WHOIS())
|
||||
})
|
||||
|
||||
t.Run("can't_set_persistent_client", func(t *testing.T) {
|
||||
err = storage.Add(&client.Persistent{
|
||||
Name: cliName3,
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{cliIP3},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
storage.UpdateAddress(cliIP3, "", whois)
|
||||
rc := storage.ClientRuntime(cliIP3)
|
||||
require.Nil(t, rc)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientsDHCP(t *testing.T) {
|
||||
var (
|
||||
cliIP1 = netip.MustParseAddr("1.1.1.1")
|
||||
cliName1 = "one.dhcp"
|
||||
|
||||
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||
cliMAC2 = mustParseMAC("22:22:22:22:22:22")
|
||||
cliName2 = "two.dhcp"
|
||||
|
||||
cliIP3 = netip.MustParseAddr("3.3.3.3")
|
||||
cliMAC3 = mustParseMAC("33:33:33:33:33:33")
|
||||
cliName3 = "three.dhcp"
|
||||
|
||||
prsCliIP = netip.MustParseAddr("4.3.2.1")
|
||||
prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
|
||||
prsCliName = "persistent.dhcp"
|
||||
)
|
||||
|
||||
ipToHost := map[netip.Addr]string{
|
||||
cliIP1: cliName1,
|
||||
}
|
||||
ipToMAC := map[netip.Addr]net.HardwareAddr{
|
||||
prsCliIP: prsCliMAC,
|
||||
}
|
||||
|
||||
leases := []*dhcpsvc.Lease{{
|
||||
IP: cliIP2,
|
||||
Hostname: cliName2,
|
||||
HWAddr: cliMAC2,
|
||||
}, {
|
||||
IP: cliIP3,
|
||||
Hostname: cliName3,
|
||||
HWAddr: cliMAC3,
|
||||
}}
|
||||
|
||||
d := &testDHCP{
|
||||
OnLeases: func() (ls []*dhcpsvc.Lease) {
|
||||
return leases
|
||||
},
|
||||
OnHostBy: func(ip netip.Addr) (host string) {
|
||||
return ipToHost[ip]
|
||||
},
|
||||
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) {
|
||||
return ipToMAC[ip]
|
||||
},
|
||||
}
|
||||
|
||||
storage, err := client.NewStorage(&client.StorageConfig{
|
||||
DHCP: d,
|
||||
RuntimeSourceDHCP: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("find_runtime", func(t *testing.T) {
|
||||
cli1 := storage.ClientRuntime(cliIP1)
|
||||
require.NotNil(t, cli1)
|
||||
|
||||
assert.True(t, compareRuntimeInfo(cli1, client.SourceDHCP, cliName1))
|
||||
})
|
||||
|
||||
t.Run("find_persistent", func(t *testing.T) {
|
||||
err = storage.Add(&client.Persistent{
|
||||
Name: prsCliName,
|
||||
UID: client.MustNewUID(),
|
||||
MACs: []net.HardwareAddr{prsCliMAC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
prsCli, ok := storage.Find(prsCliIP.String())
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, prsCliName, prsCli.Name)
|
||||
})
|
||||
|
||||
t.Run("leases", func(t *testing.T) {
|
||||
delete(ipToHost, cliIP1)
|
||||
storage.UpdateDHCP()
|
||||
|
||||
cli1 := storage.ClientRuntime(cliIP1)
|
||||
require.Nil(t, cli1)
|
||||
|
||||
for i, l := range leases {
|
||||
cli := storage.ClientRuntime(l.IP)
|
||||
require.NotNil(t, cli)
|
||||
|
||||
src, host := cli.Info()
|
||||
assert.Equal(t, client.SourceDHCP, src)
|
||||
assert.Equal(t, leases[i].Hostname, host)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("range", func(t *testing.T) {
|
||||
s := 0
|
||||
storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
||||
s++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, len(leases), s)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
storage, err := client.NewStorage(&client.StorageConfig{
|
||||
DHCP: client.EmptyDHCP{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Add a client.
|
||||
err = storage.Add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
|
||||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
storage.UpdateAddress(ip, "test", nil)
|
||||
rc := storage.ClientRuntime(ip)
|
||||
assert.True(t, compareRuntimeInfo(rc, client.SourceRDNS, "test"))
|
||||
})
|
||||
|
||||
t.Run("complicated", func(t *testing.T) {
|
||||
// TODO(a.garipov): Properly decouple the DHCP server from the client
|
||||
// storage.
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping dhcp test on windows")
|
||||
}
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
config := &dhcpd.ServerConfig{
|
||||
Enabled: true,
|
||||
DataDir: t.TempDir(),
|
||||
Conf4: dhcpd.V4ServerConf{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("1.2.3.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("1.2.3.2"),
|
||||
RangeEnd: netip.MustParseAddr("1.2.3.10"),
|
||||
},
|
||||
}
|
||||
|
||||
dhcpServer, err := dhcpd.Create(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
storage, err := client.NewStorage(&client.StorageConfig{
|
||||
DHCP: dhcpServer,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: ip,
|
||||
Hostname: "testhost",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
err = storage.Add(&client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the IP from the first client's IP range.
|
||||
err = storage.Add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// newStorage is a helper function that returns a client storage filled with
|
||||
// persistent clients from the m. It also generates a UID for each client.
|
||||
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
||||
tb.Helper()
|
||||
|
||||
s = client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
s, err := client.NewStorage(&client.StorageConfig{
|
||||
DHCP: client.EmptyDHCP{},
|
||||
})
|
||||
require.NoError(tb, err)
|
||||
|
||||
for _, c := range m {
|
||||
c.UID = client.MustNewUID()
|
||||
require.NoError(tb, s.Add(c))
|
||||
}
|
||||
|
||||
require.Equal(tb, len(m), s.Size())
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -43,6 +536,9 @@ func TestStorage_Add(t *testing.T) {
|
||||
const (
|
||||
existingName = "existing_name"
|
||||
existingClientID = "existing_client_id"
|
||||
|
||||
allowedTag = "user_admin"
|
||||
notAllowedTag = "not_allowed_tag"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -59,10 +555,20 @@ func TestStorage_Add(t *testing.T) {
|
||||
UID: existingClientUID,
|
||||
}
|
||||
|
||||
s := client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
})
|
||||
err := s.Add(existingClient)
|
||||
s, err := client.NewStorage(&client.StorageConfig{})
|
||||
require.NoError(t, err)
|
||||
|
||||
tags := s.AllowedTags()
|
||||
require.NotZero(t, len(tags))
|
||||
require.True(t, slices.IsSorted(tags))
|
||||
|
||||
_, ok := slices.BinarySearch(tags, allowedTag)
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = slices.BinarySearch(tags, notAllowedTag)
|
||||
require.False(t, ok)
|
||||
|
||||
err = s.Add(existingClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -119,6 +625,46 @@ func TestStorage_Add(t *testing.T) {
|
||||
},
|
||||
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||
`uses the same ClientID "existing_client_id"`,
|
||||
}, {
|
||||
name: "not_allowed_tag",
|
||||
cli: &client.Persistent{
|
||||
Name: "not_allowed_tag",
|
||||
Tags: []string{notAllowedTag},
|
||||
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
|
||||
UID: client.MustNewUID(),
|
||||
},
|
||||
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
|
||||
}, {
|
||||
name: "allowed_tag",
|
||||
cli: &client.Persistent{
|
||||
Name: "allowed_tag",
|
||||
Tags: []string{allowedTag},
|
||||
IPs: []netip.Addr{netip.MustParseAddr("5.5.5.5")},
|
||||
UID: client.MustNewUID(),
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "",
|
||||
cli: &client.Persistent{
|
||||
Name: "",
|
||||
IPs: []netip.Addr{netip.MustParseAddr("6.6.6.6")},
|
||||
UID: client.MustNewUID(),
|
||||
},
|
||||
wantErrMsg: "adding client: empty name",
|
||||
}, {
|
||||
name: "no_id",
|
||||
cli: &client.Persistent{
|
||||
Name: "no_id",
|
||||
UID: client.MustNewUID(),
|
||||
},
|
||||
wantErrMsg: "adding client: id required",
|
||||
}, {
|
||||
name: "no_uid",
|
||||
cli: &client.Persistent{
|
||||
Name: "no_uid",
|
||||
IPs: []netip.Addr{netip.MustParseAddr("7.7.7.7")},
|
||||
},
|
||||
wantErrMsg: "adding client: uid required",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -141,10 +687,10 @@ func TestStorage_RemoveByName(t *testing.T) {
|
||||
UID: client.MustNewUID(),
|
||||
}
|
||||
|
||||
s := client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
})
|
||||
err := s.Add(existingClient)
|
||||
s, err := client.NewStorage(&client.StorageConfig{})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Add(existingClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -168,9 +714,9 @@ func TestStorage_RemoveByName(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("duplicate_remove", func(t *testing.T) {
|
||||
s = client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
})
|
||||
s, err = client.NewStorage(&client.StorageConfig{})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Add(existingClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -341,6 +887,127 @@ func TestStorage_FindLoose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_FindByName(t *testing.T) {
|
||||
const (
|
||||
cliIP1 = "1.1.1.1"
|
||||
cliIP2 = "2.2.2.2"
|
||||
)
|
||||
|
||||
const (
|
||||
clientExistingName = "client_existing"
|
||||
clientAnotherExistingName = "client_another_existing"
|
||||
nonExistingClientName = "client_non_existing"
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &client.Persistent{
|
||||
Name: clientExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &client.Persistent{
|
||||
Name: clientAnotherExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
s := newStorage(t, clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *client.Persistent
|
||||
name string
|
||||
clientName string
|
||||
}{{
|
||||
name: "existing",
|
||||
clientName: clientExistingName,
|
||||
want: clientExisting,
|
||||
}, {
|
||||
name: "another_existing",
|
||||
clientName: clientAnotherExistingName,
|
||||
want: clientAnotherExisting,
|
||||
}, {
|
||||
name: "non_existing",
|
||||
clientName: nonExistingClientName,
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := s.FindByName(tc.clientName)
|
||||
if tc.want == nil {
|
||||
assert.False(t, ok)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_FindByMAC(t *testing.T) {
|
||||
var (
|
||||
cliMAC = mustParseMAC("11:11:11:11:11:11")
|
||||
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
|
||||
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &client.Persistent{
|
||||
Name: "client",
|
||||
MACs: []net.HardwareAddr{cliMAC},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &client.Persistent{
|
||||
Name: "another_client",
|
||||
MACs: []net.HardwareAddr{cliAnotherMAC},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
s := newStorage(t, clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *client.Persistent
|
||||
name string
|
||||
clientMAC net.HardwareAddr
|
||||
}{{
|
||||
name: "existing",
|
||||
clientMAC: cliMAC,
|
||||
want: clientExisting,
|
||||
}, {
|
||||
name: "another_existing",
|
||||
clientMAC: cliAnotherMAC,
|
||||
want: clientAnotherExisting,
|
||||
}, {
|
||||
name: "non_existing",
|
||||
clientMAC: nonExistingClientMAC,
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := s.FindByMAC(tc.clientMAC)
|
||||
if tc.want == nil {
|
||||
assert.False(t, ok)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_Update(t *testing.T) {
|
||||
const (
|
||||
clientName = "client_name"
|
||||
|
||||
Reference in New Issue
Block a user