all: sync with master

This commit is contained in:
Ainar Garipov
2024-07-03 15:38:37 +03:00
parent f73717ec08
commit 158d4f0249
352 changed files with 33842 additions and 33276 deletions

View File

@@ -161,7 +161,8 @@ func (hc *HostsContainer) handleEvents() {
defer close(hc.updates)
ok, eventsCh := true, hc.watcher.Events()
eventsCh := hc.watcher.Events()
ok := eventsCh != nil
for ok {
select {
case _, ok = <-eventsCh:

View File

@@ -160,3 +160,34 @@ func (w *osWatcher) handleErrors() {
log.Error("%s: %s", osWatcherPref, err)
}
}
// EmptyFSWatcher is a no-op implementation of the [FSWatcher] interface. It
// may be used on systems not supporting filesystem events.
type EmptyFSWatcher struct{}
// type check
var _ FSWatcher = EmptyFSWatcher{}
// Start implements the [FSWatcher] interface for EmptyFSWatcher. It always
// returns nil error.
func (EmptyFSWatcher) Start() (err error) {
return nil
}
// Close implements the [FSWatcher] interface for EmptyFSWatcher. It always
// returns nil error.
func (EmptyFSWatcher) Close() (err error) {
return nil
}
// Events implements the [FSWatcher] interface for EmptyFSWatcher. It always
// returns nil channel.
func (EmptyFSWatcher) Events() (e <-chan event) {
return nil
}
// Add implements the [FSWatcher] interface for EmptyFSWatcher. It always
// returns nil error.
func (EmptyFSWatcher) Add(_ string) (err error) {
return nil
}

View File

@@ -19,25 +19,9 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// UnsupportedError is returned by functions and methods when a particular
// operation Op cannot be performed on the current OS.
type UnsupportedError struct {
Op string
OS string
}
// Error implements the error interface for *UnsupportedError.
func (err *UnsupportedError) Error() (msg string) {
return fmt.Sprintf("%s is unsupported on %s", err.Op, err.OS)
}
// Unsupported is a helper that returns an *UnsupportedError with the Op field
// set to op and the OS field set to the current OS.
// Unsupported is a helper that returns a wrapped [errors.ErrUnsupported].
func Unsupported(op string) (err error) {
return &UnsupportedError{
Op: op,
OS: runtime.GOOS,
}
return fmt.Errorf("%s: not supported on %s: %w", op, runtime.GOOS, errors.ErrUnsupported)
}
// SetRlimit sets user-specified limit of how many fd's we can use.

View File

@@ -30,8 +30,8 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
}
}
// Index stores all information about persistent clients.
type Index struct {
// index stores all information about persistent clients.
type index struct {
// nameToUID maps client name to UID.
nameToUID map[string]UID
@@ -51,9 +51,9 @@ type Index struct {
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
}
// NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) {
return &Index{
// newIndex initializes the new instance of client index.
func newIndex() (ci *index) {
return &index{
nameToUID: map[string]UID{},
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
@@ -63,9 +63,9 @@ func NewIndex() (ci *Index) {
}
}
// Add stores information about a persistent client in the index. c must be
// non-nil and contain UID.
func (ci *Index) Add(c *Persistent) {
// add stores information about a persistent client in the index. c must be
// non-nil, have a UID, and contain at least one identifier.
func (ci *index) add(c *Persistent) {
if (c.UID == UID{}) {
panic("client must contain uid")
}
@@ -92,9 +92,9 @@ func (ci *Index) Add(c *Persistent) {
ci.uidToClient[c.UID] = c
}
// ClashesUID returns existing persistent client with the same UID as c. Note
// clashesUID returns existing persistent client with the same UID as c. Note
// that this is only possible when configuration contains duplicate fields.
func (ci *Index) ClashesUID(c *Persistent) (err error) {
func (ci *index) clashesUID(c *Persistent) (err error) {
p, ok := ci.uidToClient[c.UID]
if ok {
return fmt.Errorf("another client %q uses the same uid", p.Name)
@@ -103,9 +103,9 @@ func (ci *Index) ClashesUID(c *Persistent) (err error) {
return nil
}
// Clashes returns an error if the index contains a different persistent client
// clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *Index) Clashes(c *Persistent) (err error) {
func (ci *index) clashes(c *Persistent) (err error) {
if p := ci.clashesName(c); p != nil {
return fmt.Errorf("another client uses the same name %q", p.Name)
}
@@ -139,8 +139,8 @@ func (ci *Index) Clashes(c *Persistent) (err error) {
// clashesName returns existing persistent client with the same name as c or
// nil. c must be non-nil.
func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.FindByName(c.Name)
func (ci *index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.findByName(c.Name)
if !ok {
return nil
}
@@ -154,7 +154,7 @@ func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
func (ci *index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
for _, ip := range c.IPs {
existing, ok := ci.ipToUID[ip]
if ok && existing != c.UID {
@@ -167,7 +167,7 @@ func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
// clashesSubnet returns a previous client with the same subnet as c. c must be
// non-nil.
func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
func (ci *index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
for _, s = range c.Subnets {
var existing UID
var ok bool
@@ -193,7 +193,7 @@ func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
// clashesMAC returns a previous client with the same MAC address as c. c must
// be non-nil.
func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
for _, mac = range c.MACs {
k := macToKey(mac)
existing, ok := ci.macToUID[k]
@@ -205,9 +205,9 @@ func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
return nil, nil
}
// Find finds persistent client by string representation of the client ID, IP
// find finds persistent client by string representation of the client ID, IP
// address, or MAC.
func (ci *Index) Find(id string) (c *Persistent, ok bool) {
func (ci *index) find(id string) (c *Persistent, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
@@ -224,14 +224,14 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) {
mac, err := net.ParseMAC(id)
if err == nil {
return ci.FindByMAC(mac)
return ci.findByMAC(mac)
}
return nil, false
}
// FindByName finds persistent client by name.
func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
// findByName finds persistent client by name.
func (ci *index) findByName(name string) (c *Persistent, found bool) {
uid, found := ci.nameToUID[name]
if found {
return ci.uidToClient[uid], true
@@ -241,7 +241,7 @@ func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
}
// findByIP finds persistent client by IP address.
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
@@ -266,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false
}
// FindByMAC finds persistent client by MAC.
func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
// findByMAC finds persistent client by MAC.
func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
@@ -277,13 +277,13 @@ func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
return nil, false
}
// FindByIPWithoutZone finds a persistent client by IP address without zone. It
// findByIPWithoutZone finds a persistent client by IP address without zone. It
// strips the IPv6 zone index from the stored IP addresses before comparing,
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
func (ci *index) findByIPWithoutZone(ip netip.Addr) (c *Persistent) {
if (ip == netip.Addr{}) {
return nil
}
@@ -297,9 +297,9 @@ func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
return nil
}
// Delete removes information about persistent client from the index. c must be
// remove removes information about persistent client from the index. c must be
// non-nil.
func (ci *Index) Delete(c *Persistent) {
func (ci *index) remove(c *Persistent) {
delete(ci.nameToUID, c.Name)
for _, id := range c.ClientIDs {
@@ -322,24 +322,14 @@ func (ci *Index) Delete(c *Persistent) {
delete(ci.uidToClient, c.UID)
}
// Size returns the number of persistent clients.
func (ci *Index) Size() (n int) {
// size returns the number of persistent clients.
func (ci *index) size() (n int) {
return len(ci.uidToClient)
}
// Range calls f for each persistent client, unless cont is false. The order is
// undefined.
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
for _, c := range ci.uidToClient {
if !f(c) {
return
}
}
}
// RangeByName is like [Index.Range] but sorts the persistent clients by name
// rangeByName is like [Index.Range] but sorts the persistent clients by name
// before iterating ensuring a predictable order.
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) {
cs := maps.Values(ci.uidToClient)
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
return strings.Compare(a.Name, b.Name)
@@ -352,10 +342,10 @@ func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
}
}
// CloseUpstreams closes upstream configurations of persistent clients.
func (ci *Index) CloseUpstreams() (err error) {
// closeUpstreams closes upstream configurations of persistent clients.
func (ci *index) closeUpstreams() (err error) {
var errs []error
ci.RangeByName(func(c *Persistent) (cont bool) {
ci.rangeByName(func(c *Persistent) (cont bool) {
err = c.CloseUpstreams()
if err != nil {
errs = append(errs, err)

View File

@@ -11,17 +11,18 @@ import (
// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*Persistent) (ci *Index) {
ci = NewIndex()
func newIDIndex(m []*Persistent) (ci *index) {
ci = newIndex()
for _, c := range m {
c.UID = MustNewUID()
ci.Add(c)
ci.add(c)
}
return ci
}
// TODO(s.chzhen): Remove.
func TestClientIndex_Find(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
@@ -109,7 +110,7 @@ func TestClientIndex_Find(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := ci.Find(id)
c, ok := ci.find(id)
require.True(t, ok)
assert.Equal(t, tc.want, c)
@@ -118,7 +119,7 @@ func TestClientIndex_Find(t *testing.T) {
}
t.Run("not_found", func(t *testing.T) {
_, ok := ci.Find(cliIPNone)
_, ok := ci.find(cliIPNone)
assert.False(t, ok)
})
}
@@ -170,11 +171,11 @@ func TestClientIndex_Clashes(t *testing.T) {
clone := tc.client.ShallowClone()
clone.UID = MustNewUID()
err := ci.Clashes(clone)
err := ci.clashes(clone)
require.Error(t, err)
ci.Delete(tc.client)
err = ci.Clashes(clone)
ci.remove(tc.client)
err = ci.clashes(clone)
require.NoError(t, err)
})
}
@@ -292,7 +293,7 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := ci.FindByIPWithoutZone(tc.ip.WithZone(""))
c := ci.findByIPWithoutZone(tc.ip.WithZone(""))
require.Equal(t, tc.want, c)
})
}
@@ -338,7 +339,7 @@ func TestClientIndex_RangeByName(t *testing.T) {
ci := newIDIndex(tc.want)
var got []*Persistent
ci.RangeByName(func(c *Persistent) (cont bool) {
ci.rangeByName(func(c *Persistent) (cont bool) {
got = append(got, c)
return true

View File

@@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -64,53 +65,107 @@ type Persistent struct {
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig
// SafeSearch handles search engine hosts rewrites.
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
// BlockedServices is the configuration of blocked services of a client. It
// must not be nil after initialization.
BlockedServices *filtering.BlockedServices
// Name of the persistent client. Must not be empty.
Name string
Tags []string
// Tags is a list of client tags that categorize the client.
Tags []string
// Upstreams is a list of custom upstream DNS servers for the client.
Upstreams []string
// IPs is a list of IP addresses that identify the client. The client must
// have at least one ID (IP, subnet, MAC, or ClientID).
IPs []netip.Addr
// Subnets identifying the client. The client must have at least one ID
// (IP, subnet, MAC, or ClientID).
//
// TODO(s.chzhen): Use netutil.Prefix.
Subnets []netip.Prefix
MACs []net.HardwareAddr
Subnets []netip.Prefix
// MACs identifying the client. The client must have at least one ID (IP,
// subnet, MAC, or ClientID).
MACs []net.HardwareAddr
// ClientIDs identifying the client. The client must have at least one ID
// (IP, subnet, MAC, or ClientID).
ClientIDs []string
// UID is the unique identifier of the persistent client.
UID UID
UpstreamsCacheSize uint32
// UpstreamsCacheSize is the cache size for custom upstreams.
UpstreamsCacheSize uint32
// UpstreamsCacheEnabled specifies whether custom upstreams are used.
UpstreamsCacheEnabled bool
UseOwnSettings bool
FilteringEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
UseOwnBlockedServices bool
IgnoreQueryLog bool
IgnoreStatistics bool
// UseOwnSettings specifies whether custom filtering settings are used.
UseOwnSettings bool
// FilteringEnabled specifies whether filtering is enabled.
FilteringEnabled bool
// SafeBrowsingEnabled specifies whether safe browsing is enabled.
SafeBrowsingEnabled bool
// ParentalEnabled specifies whether parental control is enabled.
ParentalEnabled bool
// UseOwnBlockedServices specifies whether custom services are blocked.
UseOwnBlockedServices bool
// IgnoreQueryLog specifies whether the client requests are logged.
IgnoreQueryLog bool
// IgnoreStatistics specifies whether the client requests are counted.
IgnoreStatistics bool
// SafeSearchConf is the safe search filtering configuration.
//
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
}
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
for _, t := range tags {
if !known.Has(t) {
log.Info("skipping unknown tag %q", t)
continue
}
c.Tags = append(c.Tags, t)
// validate returns an error if persistent client information contains errors.
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
switch {
case c.Name == "":
return errors.Error("empty name")
case c.IDsLen() == 0:
return errors.Error("id required")
case c.UID == UID{}:
return errors.Error("uid required")
}
conf, err := proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}
err = conf.Close()
if err != nil {
log.Error("client: closing upstream config: %s", err)
}
for _, t := range c.Tags {
if !allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
}
// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)
return nil
}
// SetIDs parses a list of strings into typed fields and returns an error if

View File

@@ -1,13 +1,16 @@
package client
import (
"net/netip"
"testing"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPersistentClient_EqualIDs(t *testing.T) {
func TestPersistent_EqualIDs(t *testing.T) {
const (
ip = "0.0.0.0"
ip1 = "1.1.1.1"
@@ -122,3 +125,69 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
})
}
}
func TestPersistent_Validate(t *testing.T) {
const (
allowedTag = "allowed_tag"
notAllowedTag = "not_allowed_tag"
)
allowedTags := container.NewMapSet(allowedTag)
testCases := []struct {
name string
cli *Persistent
wantErrMsg string
}{{
name: "success",
cli: &Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: MustNewUID(),
},
wantErrMsg: "",
}, {
name: "empty_name",
cli: &Persistent{
Name: "",
},
wantErrMsg: "empty name",
}, {
name: "no_id",
cli: &Persistent{
Name: "no_id",
},
wantErrMsg: "id required",
}, {
name: "no_uid",
cli: &Persistent{
Name: "no_uid",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
},
wantErrMsg: "uid required",
}, {
name: "not_allowed_tag",
cli: &Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: MustNewUID(),
Tags: []string{
notAllowedTag,
},
},
wantErrMsg: `invalid tag: "` + notAllowedTag + `"`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.cli.validate(allowedTags)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

289
internal/client/storage.go Normal file
View File

@@ -0,0 +1,289 @@
package client
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// Config is the client storage configuration structure.
//
// TODO(s.chzhen): Expand.
type Config struct {
// AllowedTags is a list of all allowed client tags.
AllowedTags []string
}
// Storage contains information about persistent and runtime clients.
type Storage struct {
// allowedTags is a set of all allowed tags.
allowedTags *container.MapSet[string]
// mu protects indexes of persistent and runtime clients.
mu *sync.Mutex
// index contains information about persistent clients.
index *index
// runtimeIndex contains information about runtime clients.
//
// TODO(s.chzhen): Use it.
runtimeIndex *RuntimeIndex
}
// NewStorage returns initialized client storage. conf must not be nil.
func NewStorage(conf *Config) (s *Storage) {
allowedTags := container.NewMapSet(conf.AllowedTags...)
return &Storage{
allowedTags: allowedTags,
mu: &sync.Mutex{},
index: newIndex(),
runtimeIndex: NewRuntimeIndex(),
}
}
// Add stores persistent client information or returns an error.
func (s *Storage) Add(p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "adding client: %w") }()
err = p.validate(s.allowedTags)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
s.mu.Lock()
defer s.mu.Unlock()
err = s.index.clashesUID(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
err = s.index.clashes(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
s.index.add(p)
log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.size())
return nil
}
// FindByName finds persistent client by name. And returns its shallow copy.
func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.findByName(name)
if ok {
return p.ShallowClone(), ok
}
return nil, false
}
// Find finds persistent client by string representation of the client ID, IP
// address, or MAC. And returns its shallow copy.
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.find(id)
if ok {
return p.ShallowClone(), ok
}
return nil, false
}
// FindLoose is like [Storage.Find] but it also tries to find a persistent
// client by IP address without zone. It strips the IPv6 zone index from the
// stored IP addresses before comparing, because querylog entries don't have it.
// See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.find(id)
if ok {
return p.ShallowClone(), ok
}
p = s.index.findByIPWithoutZone(ip)
if p != nil {
return p.ShallowClone(), true
}
return nil, false
}
// FindByMAC finds persistent client by MAC and returns its shallow copy.
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.findByMAC(mac)
if ok {
return p.ShallowClone(), ok
}
return nil, false
}
// RemoveByName removes persistent client information. ok is false if no such
// client exists by that name.
func (s *Storage) RemoveByName(name string) (ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok := s.index.findByName(name)
if !ok {
return false
}
if err := p.CloseUpstreams(); err != nil {
log.Error("client storage: removing client %q: %s", p.Name, err)
}
s.index.remove(p)
return true
}
// Update finds the stored persistent client by its name and updates its
// information from p.
func (s *Storage) Update(name string, p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "updating client: %w") }()
err = p.validate(s.allowedTags)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
s.mu.Lock()
defer s.mu.Unlock()
stored, ok := s.index.findByName(name)
if !ok {
return fmt.Errorf("client %q is not found", name)
}
// Client p has a newly generated UID, so replace it with the stored one.
//
// TODO(s.chzhen): Remove when frontend starts handling UIDs.
p.UID = stored.UID
err = s.index.clashes(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
s.index.remove(stored)
s.index.add(p)
return nil
}
// RangeByName calls f for each persistent client sorted by name, unless cont is
// false.
func (s *Storage) RangeByName(f func(c *Persistent) (cont bool)) {
s.mu.Lock()
defer s.mu.Unlock()
s.index.rangeByName(f)
}
// Size returns the number of persistent clients.
func (s *Storage) Size() (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.index.size()
}
// CloseUpstreams closes upstream configurations of persistent clients.
func (s *Storage) CloseUpstreams() (err error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.index.closeUpstreams()
}
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
// client exists, returns nil.
//
// TODO(s.chzhen): Use it.
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.Client(ip)
}
// AddRuntime saves the runtime client information in the storage. IP address
// of a client must be unique. rc must not be nil.
//
// TODO(s.chzhen): Use it.
func (s *Storage) AddRuntime(rc *Runtime) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Add(rc)
}
// SizeRuntime returns the number of the runtime clients.
//
// TODO(s.chzhen): Use it.
func (s *Storage) SizeRuntime() (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.Size()
}
// RangeRuntime calls f for each runtime client in an undefined order.
//
// TODO(s.chzhen): Use it.
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Range(f)
}
// DeleteRuntime removes the runtime client by ip.
//
// TODO(s.chzhen): Use it.
func (s *Storage) DeleteRuntime(ip netip.Addr) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Delete(ip)
}
// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
//
// TODO(s.chzhen): Use it.
func (s *Storage) DeleteBySource(src Source) (n int) {
s.mu.Lock()
defer s.mu.Unlock()
return s.runtimeIndex.DeleteBySource(src)
}

View File

@@ -0,0 +1,481 @@
package client_test
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newStorage is a helper function that returns a client storage filled with
// persistent clients from the m. It also generates a UID for each client.
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
tb.Helper()
s = client.NewStorage(&client.Config{
AllowedTags: nil,
})
for _, c := range m {
c.UID = client.MustNewUID()
require.NoError(tb, s.Add(c))
}
return s
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestStorage_Add(t *testing.T) {
const (
existingName = "existing_name"
existingClientID = "existing_client_id"
)
var (
existingClientUID = client.MustNewUID()
existingIP = netip.MustParseAddr("1.2.3.4")
existingSubnet = netip.MustParsePrefix("1.2.3.0/24")
)
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{existingIP},
Subnets: []netip.Prefix{existingSubnet},
ClientIDs: []string{existingClientID},
UID: existingClientUID,
}
s := client.NewStorage(&client.Config{
AllowedTags: nil,
})
err := s.Add(existingClient)
require.NoError(t, err)
testCases := []struct {
name string
cli *client.Persistent
wantErrMsg string
}{{
name: "basic",
cli: &client.Persistent{
Name: "basic",
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
UID: client.MustNewUID(),
},
wantErrMsg: "",
}, {
name: "duplicate_uid",
cli: &client.Persistent{
Name: "no_uid",
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
UID: existingClientUID,
},
wantErrMsg: `adding client: another client "existing_name" uses the same uid`,
}, {
name: "duplicate_name",
cli: &client.Persistent{
Name: existingName,
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client uses the same name "existing_name"`,
}, {
name: "duplicate_ip",
cli: &client.Persistent{
Name: "duplicate_ip",
IPs: []netip.Addr{existingIP},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" uses the same IP "1.2.3.4"`,
}, {
name: "duplicate_subnet",
cli: &client.Persistent{
Name: "duplicate_subnet",
Subnets: []netip.Prefix{existingSubnet},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" ` +
`uses the same subnet "1.2.3.0/24"`,
}, {
name: "duplicate_client_id",
cli: &client.Persistent{
Name: "duplicate_client_id",
ClientIDs: []string{existingClientID},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" ` +
`uses the same ClientID "existing_client_id"`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = s.Add(tc.cli)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
func TestStorage_RemoveByName(t *testing.T) {
const (
existingName = "existing_name"
)
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
UID: client.MustNewUID(),
}
s := client.NewStorage(&client.Config{
AllowedTags: nil,
})
err := s.Add(existingClient)
require.NoError(t, err)
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliName string
}{{
name: "existing_client",
cliName: existingName,
want: assert.True,
}, {
name: "non_existing_client",
cliName: "non_existing_client",
want: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.want(t, s.RemoveByName(tc.cliName))
})
}
t.Run("duplicate_remove", func(t *testing.T) {
s = client.NewStorage(&client.Config{
AllowedTags: nil,
})
err = s.Add(existingClient)
require.NoError(t, err)
assert.True(t, s.RemoveByName(existingName))
assert.False(t, s.RemoveByName(existingName))
})
}
func TestStorage_Find(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
cliIP1 = "1.1.1.1"
cliIP2 = "2.2.2.2"
cliIPv6 = "1:2:3::4"
cliSubnet = "2.2.2.0/24"
cliSubnetIP = "2.2.2.222"
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0"
linkLocalSubnet = "fe80::/16"
)
var (
clientWithBothFams = &client.Persistent{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}
clientWithSubnet = &client.Persistent{
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}
clientWithMAC = &client.Persistent{
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}
clientWithID = &client.Persistent{
Name: "client_with_id",
ClientIDs: []string{cliID},
}
clientLinkLocal = &client.Persistent{
Name: "client_link_local",
Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)},
}
)
clients := []*client.Persistent{
clientWithBothFams,
clientWithSubnet,
clientWithMAC,
clientWithID,
clientLinkLocal,
}
s := newStorage(t, clients)
testCases := []struct {
want *client.Persistent
name string
ids []string
}{{
name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6},
want: clientWithBothFams,
}, {
name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP},
want: clientWithSubnet,
}, {
name: "mac",
ids: []string{cliMAC},
want: clientWithMAC,
}, {
name: "client_id",
ids: []string{cliID},
want: clientWithID,
}, {
name: "client_link_local_subnet",
ids: []string{linkLocalIP},
want: clientLinkLocal,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := s.Find(id)
require.True(t, ok)
assert.Equal(t, tc.want, c)
}
})
}
t.Run("not_found", func(t *testing.T) {
_, ok := s.Find(cliIPNone)
assert.False(t, ok)
})
}
func TestStorage_FindLoose(t *testing.T) {
const (
nonExistingClientID = "client_id"
)
var (
ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1")
ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2")
)
var (
clientNoZone = &client.Persistent{
Name: "client",
IPs: []netip.Addr{ip},
}
clientWithZone = &client.Persistent{
Name: "client_with_zone",
IPs: []netip.Addr{ipWithZone},
}
)
s := newStorage(
t,
[]*client.Persistent{
clientNoZone,
clientWithZone,
},
)
testCases := []struct {
ip netip.Addr
want assert.BoolAssertionFunc
wantCli *client.Persistent
name string
}{{
name: "without_zone",
ip: ip,
wantCli: clientNoZone,
want: assert.True,
}, {
name: "with_zone",
ip: ipWithZone,
wantCli: clientWithZone,
want: assert.True,
}, {
name: "zero_address",
ip: netip.Addr{},
wantCli: nil,
want: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := s.FindLoose(tc.ip.WithZone(""), nonExistingClientID)
assert.Equal(t, tc.wantCli, c)
tc.want(t, ok)
})
}
}
func TestStorage_Update(t *testing.T) {
const (
clientName = "client_name"
obstructingName = "obstructing_name"
obstructingClientID = "obstructing_client_id"
)
var (
obstructingIP = netip.MustParseAddr("1.2.3.4")
obstructingSubnet = netip.MustParsePrefix("1.2.3.0/24")
)
obstructingClient := &client.Persistent{
Name: obstructingName,
IPs: []netip.Addr{obstructingIP},
Subnets: []netip.Prefix{obstructingSubnet},
ClientIDs: []string{obstructingClientID},
}
clientToUpdate := &client.Persistent{
Name: clientName,
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
}
testCases := []struct {
name string
cli *client.Persistent
wantErrMsg string
}{{
name: "basic",
cli: &client.Persistent{
Name: "basic",
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
UID: client.MustNewUID(),
},
wantErrMsg: "",
}, {
name: "duplicate_name",
cli: &client.Persistent{
Name: obstructingName,
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client uses the same name "obstructing_name"`,
}, {
name: "duplicate_ip",
cli: &client.Persistent{
Name: "duplicate_ip",
IPs: []netip.Addr{obstructingIP},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" uses the same IP "1.2.3.4"`,
}, {
name: "duplicate_subnet",
cli: &client.Persistent{
Name: "duplicate_subnet",
Subnets: []netip.Prefix{obstructingSubnet},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" ` +
`uses the same subnet "1.2.3.0/24"`,
}, {
name: "duplicate_client_id",
cli: &client.Persistent{
Name: "duplicate_client_id",
ClientIDs: []string{obstructingClientID},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" ` +
`uses the same ClientID "obstructing_client_id"`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := newStorage(
t,
[]*client.Persistent{
clientToUpdate,
obstructingClient,
},
)
err := s.Update(clientName, tc.cli)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
func TestStorage_RangeByName(t *testing.T) {
sortedClients := []*client.Persistent{{
Name: "clientA",
ClientIDs: []string{"A"},
}, {
Name: "clientB",
ClientIDs: []string{"B"},
}, {
Name: "clientC",
ClientIDs: []string{"C"},
}, {
Name: "clientD",
ClientIDs: []string{"D"},
}, {
Name: "clientE",
ClientIDs: []string{"E"},
}}
testCases := []struct {
name string
want []*client.Persistent
}{{
name: "basic",
want: sortedClients,
}, {
name: "nil",
want: nil,
}, {
name: "one_element",
want: sortedClients[:1],
}, {
name: "two_elements",
want: sortedClients[:2],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := newStorage(t, tc.want)
var got []*client.Persistent
s.RangeByName(func(c *client.Persistent) (cont bool) {
got = append(got, c)
return true
})
assert.Equal(t, tc.want, got)
})
}
}

View File

@@ -2,11 +2,12 @@ package dhcpsvc
import (
"fmt"
"slices"
"log/slog"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mapsutil"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/maps"
)
// Config is the configuration for the DHCP service.
@@ -15,6 +16,9 @@ type Config struct {
// interface identified by its name.
Interfaces map[string]*InterfaceConfig
// Logger will be used to log the DHCP events.
Logger *slog.Logger
// LocalDomainName is the top-level domain name to use for resolving DHCP
// clients' hostnames.
LocalDomainName string
@@ -38,36 +42,44 @@ type InterfaceConfig struct {
}
// Validate returns an error in conf if any.
//
// TODO(e.burkov): Unexport and rewrite the test.
func (conf *Config) Validate() (err error) {
switch {
case conf == nil:
return errNilConfig
case !conf.Enabled:
return nil
case conf.ICMPTimeout < 0:
return newMustErr("icmp timeout", "be non-negative", conf.ICMPTimeout)
}
var errs []error
if conf.ICMPTimeout < 0 {
err = newMustErr("icmp timeout", "be non-negative", conf.ICMPTimeout)
errs = append(errs, err)
}
err = netutil.ValidateDomainName(conf.LocalDomainName)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
errs = append(errs, err)
}
if len(conf.Interfaces) == 0 {
return errNoInterfaces
errs = append(errs, errNoInterfaces)
return errors.Join(errs...)
}
ifaces := maps.Keys(conf.Interfaces)
slices.Sort(ifaces)
for _, iface := range ifaces {
if err = conf.Interfaces[iface].validate(); err != nil {
return fmt.Errorf("interface %q: %w", iface, err)
mapsutil.SortedRange(conf.Interfaces, func(iface string, ic *InterfaceConfig) (ok bool) {
err = ic.validate()
if err != nil {
errs = append(errs, fmt.Errorf("interface %q: %w", iface, err))
}
}
return nil
return true
})
return errors.Join(errs...)
}
// validate returns an error in ic, if any.

View File

@@ -23,7 +23,8 @@ func TestConfig_Validate(t *testing.T) {
}, {
name: "empty",
conf: &dhcpsvc.Config{
Enabled: true,
Enabled: true,
Interfaces: testInterfaceConf,
},
wantErrMsg: `bad domain name "": domain name is empty`,
}, {

View File

@@ -11,6 +11,14 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
)
const (
// keyInterface is the key for logging the network interface name.
keyInterface = "iface"
// keyFamily is the key for logging the handled address family.
keyFamily = "family"
)
// Interface is a DHCP service.
//
// TODO(e.burkov): Separate HostByIP, MACByIP, IPByHost into a separate
@@ -50,21 +58,21 @@ type Interface interface {
// AddLease adds a new DHCP lease. l must be valid. It returns an error if
// l already exists.
AddLease(l *Lease) (err error)
AddLease(ctx context.Context, l *Lease) (err error)
// UpdateStaticLease replaces an existing static DHCP lease. l must be
// valid. It returns an error if the lease with the given hardware address
// doesn't exist or if other values match another existing lease.
UpdateStaticLease(l *Lease) (err error)
UpdateStaticLease(ctx context.Context, l *Lease) (err error)
// RemoveLease removes an existing DHCP lease. l must be valid. It returns
// an error if there is no lease equal to l.
RemoveLease(l *Lease) (err error)
RemoveLease(ctx context.Context, l *Lease) (err error)
// Reset removes all the DHCP leases.
//
// TODO(e.burkov): If it's really needed?
Reset() (err error)
Reset(ctx context.Context) (err error)
}
// Empty is an [Interface] implementation that does nothing.
@@ -101,13 +109,13 @@ func (Empty) IPByHost(_ string) (ip netip.Addr) { return netip.Addr{} }
func (Empty) Leases() (leases []*Lease) { return nil }
// AddLease implements the [Interface] interface for Empty.
func (Empty) AddLease(_ *Lease) (err error) { return nil }
func (Empty) AddLease(_ context.Context, _ *Lease) (err error) { return nil }
// UpdateStaticLease implements the [Interface] interface for Empty.
func (Empty) UpdateStaticLease(_ *Lease) (err error) { return nil }
func (Empty) UpdateStaticLease(_ context.Context, _ *Lease) (err error) { return nil }
// RemoveLease implements the [Interface] interface for Empty.
func (Empty) RemoveLease(_ *Lease) (err error) { return nil }
func (Empty) RemoveLease(_ context.Context, _ *Lease) (err error) { return nil }
// Reset implements the [Interface] interface for Empty.
func (Empty) Reset() (err error) { return nil }
func (Empty) Reset(_ context.Context) (err error) { return nil }

View File

@@ -2,6 +2,7 @@ package dhcpsvc
import (
"fmt"
"log/slog"
"slices"
"time"
)
@@ -11,6 +12,9 @@ import (
//
// TODO(e.burkov): Add other methods as [DHCPServer] evolves.
type netInterface struct {
// logger logs the events related to the network interface.
logger *slog.Logger
// name is the name of the network interface.
name string

View File

@@ -1,16 +1,17 @@
package dhcpsvc
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/exp/maps"
"github.com/AdguardTeam/golibs/mapsutil"
)
// DHCPServer is a DHCP server for both IPv4 and IPv6 address families.
@@ -19,6 +20,9 @@ type DHCPServer struct {
// information about its clients.
enabled *atomic.Bool
// logger logs common DHCP events.
logger *slog.Logger
// localTLD is the top-level domain name to use for resolving DHCP clients'
// hostnames.
localTLD string
@@ -43,8 +47,11 @@ type DHCPServer struct {
// error if the given configuration can't be used.
//
// TODO(e.burkov): Use.
func New(conf *Config) (srv *DHCPServer, err error) {
func New(ctx context.Context, conf *Config) (srv *DHCPServer, err error) {
l := conf.Logger
if !conf.Enabled {
l.DebugContext(ctx, "disabled")
// TODO(e.burkov): Perhaps return [Empty]?
return nil, nil
}
@@ -52,27 +59,26 @@ func New(conf *Config) (srv *DHCPServer, err error) {
// TODO(e.burkov): Add validations scoped to the network interfaces set.
ifaces4 := make(netInterfacesV4, 0, len(conf.Interfaces))
ifaces6 := make(netInterfacesV6, 0, len(conf.Interfaces))
var errs []error
ifaceNames := maps.Keys(conf.Interfaces)
slices.Sort(ifaceNames)
var i4 *netInterfaceV4
var i6 *netInterfaceV6
for _, ifaceName := range ifaceNames {
iface := conf.Interfaces[ifaceName]
i4, err = newNetInterfaceV4(ifaceName, iface.IPv4)
mapsutil.SortedRange(conf.Interfaces, func(name string, iface *InterfaceConfig) (cont bool) {
var i4 *netInterfaceV4
i4, err = newNetInterfaceV4(ctx, l, name, iface.IPv4)
if err != nil {
return nil, fmt.Errorf("interface %q: ipv4: %w", ifaceName, err)
errs = append(errs, fmt.Errorf("interface %q: ipv4: %w", name, err))
} else if i4 != nil {
ifaces4 = append(ifaces4, i4)
}
i6 = newNetInterfaceV6(ifaceName, iface.IPv6)
i6 := newNetInterfaceV6(ctx, l, name, iface.IPv6)
if i6 != nil {
ifaces6 = append(ifaces6, i6)
}
return true
})
if err = errors.Join(errs...); err != nil {
return nil, err
}
enabled := &atomic.Bool{}
@@ -80,6 +86,7 @@ func New(conf *Config) (srv *DHCPServer, err error) {
srv = &DHCPServer{
enabled: enabled,
logger: l,
localTLD: conf.LocalDomainName,
leasesMu: &sync.RWMutex{},
leases: newLeaseIndex(),
@@ -159,7 +166,7 @@ func (srv *DHCPServer) IPByHost(host string) (ip netip.Addr) {
}
// Reset implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) Reset() (err error) {
func (srv *DHCPServer) Reset(ctx context.Context) (err error) {
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
@@ -171,11 +178,13 @@ func (srv *DHCPServer) Reset() (err error) {
}
srv.leases.clear()
srv.logger.DebugContext(ctx, "reset leases")
return nil
}
// AddLease implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) AddLease(l *Lease) (err error) {
func (srv *DHCPServer) AddLease(ctx context.Context, l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "adding lease: %w") }()
addr := l.IP
@@ -188,13 +197,27 @@ func (srv *DHCPServer) AddLease(l *Lease) (err error) {
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
return srv.leases.add(l, iface)
err = srv.leases.add(l, iface)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
iface.logger.DebugContext(
ctx, "added lease",
"hostname", l.Hostname,
"ip", l.IP,
"mac", l.HWAddr,
"static", l.IsStatic,
)
return nil
}
// UpdateStaticLease implements the [Interface] interface for *DHCPServer.
//
// TODO(e.burkov): Support moving leases between interfaces.
func (srv *DHCPServer) UpdateStaticLease(l *Lease) (err error) {
func (srv *DHCPServer) UpdateStaticLease(ctx context.Context, l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "updating static lease: %w") }()
addr := l.IP
@@ -207,11 +230,25 @@ func (srv *DHCPServer) UpdateStaticLease(l *Lease) (err error) {
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
return srv.leases.update(l, iface)
err = srv.leases.update(l, iface)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
iface.logger.DebugContext(
ctx, "updated lease",
"hostname", l.Hostname,
"ip", l.IP,
"mac", l.HWAddr,
"static", l.IsStatic,
)
return nil
}
// RemoveLease implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) RemoveLease(l *Lease) (err error) {
func (srv *DHCPServer) RemoveLease(ctx context.Context, l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "removing lease: %w") }()
addr := l.IP
@@ -224,7 +261,21 @@ func (srv *DHCPServer) RemoveLease(l *Lease) (err error) {
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
return srv.leases.remove(l, iface)
err = srv.leases.remove(l, iface)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
iface.logger.DebugContext(
ctx, "removed lease",
"hostname", l.Hostname,
"ip", l.IP,
"mac", l.HWAddr,
"static", l.IsStatic,
)
return nil
}
// ifaceForAddr returns the handled network interface for the given IP address,

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -16,6 +17,12 @@ import (
// testLocalTLD is a common local TLD for tests.
const testLocalTLD = "local"
// testTimeout is a common timeout for tests and contexts.
const testTimeout time.Duration = 10 * time.Second
// discardLog is a logger to discard test output.
var discardLog = slogutil.NewDiscardLogger()
// testInterfaceConf is a common set of interface configurations for tests.
var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
"eth0": {
@@ -103,6 +110,7 @@ func TestNew(t *testing.T) {
}{{
conf: &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
"eth0": {
@@ -116,6 +124,7 @@ func TestNew(t *testing.T) {
}, {
conf: &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
"eth0": {
@@ -129,6 +138,7 @@ func TestNew(t *testing.T) {
}, {
conf: &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
"eth0": {
@@ -143,6 +153,7 @@ func TestNew(t *testing.T) {
}, {
conf: &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
"eth0": {
@@ -156,17 +167,22 @@ func TestNew(t *testing.T) {
`range start 127.0.0.1 is not within 192.168.0.1/24`,
}}
ctx := testutil.ContextWithTimeout(t, testTimeout)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := dhcpsvc.New(tc.conf)
_, err := dhcpsvc.New(ctx, tc.conf)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
func TestDHCPServer_AddLease(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
ctx := testutil.ContextWithTimeout(t, testTimeout)
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
@@ -186,7 +202,7 @@ func TestDHCPServer_AddLease(t *testing.T) {
mac2 := mustParseMAC(t, "06:05:04:03:02:01")
mac3 := mustParseMAC(t, "02:03:04:05:06:07")
require.NoError(t, srv.AddLease(&dhcpsvc.Lease{
require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{
Hostname: host1,
IP: ip1,
HWAddr: mac1,
@@ -261,14 +277,17 @@ func TestDHCPServer_AddLease(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.AddLease(tc.lease))
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.AddLease(ctx, tc.lease))
})
}
}
func TestDHCPServer_index(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
ctx := testutil.ContextWithTimeout(t, testTimeout)
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
@@ -313,7 +332,7 @@ func TestDHCPServer_index(t *testing.T) {
IsStatic: true,
}}
for _, l := range leases {
require.NoError(t, srv.AddLease(l))
require.NoError(t, srv.AddLease(ctx, l))
}
t.Run("ip_idx", func(t *testing.T) {
@@ -342,8 +361,11 @@ func TestDHCPServer_index(t *testing.T) {
}
func TestDHCPServer_UpdateStaticLease(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
ctx := testutil.ContextWithTimeout(t, testTimeout)
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
@@ -386,7 +408,7 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
IsStatic: true,
}}
for _, l := range leases {
require.NoError(t, srv.AddLease(l))
require.NoError(t, srv.AddLease(ctx, l))
}
testCases := []struct {
@@ -456,14 +478,17 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.UpdateStaticLease(tc.lease))
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.UpdateStaticLease(ctx, tc.lease))
})
}
}
func TestDHCPServer_RemoveLease(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
ctx := testutil.ContextWithTimeout(t, testTimeout)
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
@@ -495,7 +520,7 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
IsStatic: true,
}}
for _, l := range leases {
require.NoError(t, srv.AddLease(l))
require.NoError(t, srv.AddLease(ctx, l))
}
testCases := []struct {
@@ -546,7 +571,7 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.RemoveLease(tc.lease))
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.RemoveLease(ctx, tc.lease))
})
}
@@ -554,8 +579,11 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
}
func TestDHCPServer_Reset(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
ctx := testutil.ContextWithTimeout(t, testTimeout)
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
Enabled: true,
Logger: discardLog,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
@@ -584,12 +612,12 @@ func TestDHCPServer_Reset(t *testing.T) {
}}
for _, l := range leases {
require.NoError(t, srv.AddLease(l))
require.NoError(t, srv.AddLease(ctx, l))
}
require.Len(t, srv.Leases(), len(leases))
require.NoError(t, srv.Reset())
require.NoError(t, srv.Reset(ctx))
assert.Empty(t, srv.Leases())
}

View File

@@ -1,13 +1,15 @@
package dhcpsvc
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/gopacket/layers"
)
@@ -43,25 +45,133 @@ type IPv4Config struct {
}
// validate returns an error in conf if any.
func (conf *IPv4Config) validate() (err error) {
switch {
case conf == nil:
func (c *IPv4Config) validate() (err error) {
if c == nil {
return errNilConfig
case !conf.Enabled:
return nil
case !conf.GatewayIP.Is4():
return newMustErr("gateway ip", "be a valid ipv4", conf.GatewayIP)
case !conf.SubnetMask.Is4():
return newMustErr("subnet mask", "be a valid ipv4 cidr mask", conf.SubnetMask)
case !conf.RangeStart.Is4():
return newMustErr("range start", "be a valid ipv4", conf.RangeStart)
case !conf.RangeEnd.Is4():
return newMustErr("range end", "be a valid ipv4", conf.RangeEnd)
case conf.LeaseDuration <= 0:
return newMustErr("lease duration", "be less than %d", conf.LeaseDuration)
default:
} else if !c.Enabled {
return nil
}
var errs []error
if !c.GatewayIP.Is4() {
err = newMustErr("gateway ip", "be a valid ipv4", c.GatewayIP)
errs = append(errs, err)
}
if !c.SubnetMask.Is4() {
err = newMustErr("subnet mask", "be a valid ipv4 cidr mask", c.SubnetMask)
errs = append(errs, err)
}
if !c.RangeStart.Is4() {
err = newMustErr("range start", "be a valid ipv4", c.RangeStart)
errs = append(errs, err)
}
if !c.RangeEnd.Is4() {
err = newMustErr("range end", "be a valid ipv4", c.RangeEnd)
errs = append(errs, err)
}
if c.LeaseDuration <= 0 {
err = newMustErr("icmp timeout", "be positive", c.LeaseDuration)
errs = append(errs, err)
}
return errors.Join(errs...)
}
// netInterfaceV4 is a DHCP interface for IPv4 address family.
type netInterfaceV4 struct {
// gateway is the IP address of the network gateway.
gateway netip.Addr
// subnet is the network subnet.
subnet netip.Prefix
// addrSpace is the IPv4 address space allocated for leasing.
addrSpace ipRange
// implicitOpts are the options listed in Appendix A of RFC 2131 and
// initialized with default values. It must not have intersections with
// explicitOpts.
implicitOpts layers.DHCPOptions
// explicitOpts are the user-configured options. It must not have
// intersections with implicitOpts.
explicitOpts layers.DHCPOptions
// netInterface is embedded here to provide some common network interface
// logic.
netInterface
}
// newNetInterfaceV4 creates a new DHCP interface for IPv4 address family with
// the given configuration. It returns an error if the given configuration
// can't be used.
func newNetInterfaceV4(
ctx context.Context,
l *slog.Logger,
name string,
conf *IPv4Config,
) (i *netInterfaceV4, err error) {
l = l.With(
keyInterface, name,
keyFamily, netutil.AddrFamilyIPv4,
)
if !conf.Enabled {
l.DebugContext(ctx, "disabled")
return nil, nil
}
maskLen, _ := net.IPMask(conf.SubnetMask.AsSlice()).Size()
subnet := netip.PrefixFrom(conf.GatewayIP, maskLen)
switch {
case !subnet.Contains(conf.RangeStart):
return nil, fmt.Errorf("range start %s is not within %s", conf.RangeStart, subnet)
case !subnet.Contains(conf.RangeEnd):
return nil, fmt.Errorf("range end %s is not within %s", conf.RangeEnd, subnet)
}
addrSpace, err := newIPRange(conf.RangeStart, conf.RangeEnd)
if err != nil {
return nil, err
} else if addrSpace.contains(conf.GatewayIP) {
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
}
i = &netInterfaceV4{
gateway: conf.GatewayIP,
subnet: subnet,
addrSpace: addrSpace,
netInterface: netInterface{
name: name,
leaseTTL: conf.LeaseDuration,
logger: l,
},
}
i.implicitOpts, i.explicitOpts = conf.options(ctx, l)
return i, nil
}
// netInterfacesV4 is a slice of network interfaces of IPv4 address family.
type netInterfacesV4 []*netInterfaceV4
// find returns the first network interface within ifaces containing ip. It
// returns false if there is no such interface.
func (ifaces netInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) {
i := slices.IndexFunc(ifaces, func(iface *netInterfaceV4) (contains bool) {
return iface.subnet.Contains(ip)
})
if i < 0 {
return nil, false
}
return &ifaces[i].netInterface, true
}
// options returns the implicit and explicit options for the interface. The two
@@ -69,14 +179,14 @@ func (conf *IPv4Config) validate() (err error) {
// values.
//
// TODO(e.burkov): DRY with the IPv6 version.
func (conf *IPv4Config) options() (implicit, explicit layers.DHCPOptions) {
func (c *IPv4Config) options(ctx context.Context, l *slog.Logger) (imp, exp layers.DHCPOptions) {
// Set default values of host configuration parameters listed in Appendix A
// of RFC-2131.
implicit = layers.DHCPOptions{
imp = layers.DHCPOptions{
// Values From Configuration
layers.NewDHCPOption(layers.DHCPOptSubnetMask, conf.SubnetMask.AsSlice()),
layers.NewDHCPOption(layers.DHCPOptRouter, conf.GatewayIP.AsSlice()),
layers.NewDHCPOption(layers.DHCPOptSubnetMask, c.SubnetMask.AsSlice()),
layers.NewDHCPOption(layers.DHCPOptRouter, c.GatewayIP.AsSlice()),
// IP-Layer Per Host
@@ -228,110 +338,29 @@ func (conf *IPv4Config) options() (implicit, explicit layers.DHCPOptions) {
// See https://datatracker.ietf.org/doc/html/rfc1122#section-4.2.3.6.
layers.NewDHCPOption(layers.DHCPOptTCPKeepAliveGarbage, []byte{0x1}),
}
slices.SortFunc(implicit, compareV4OptionCodes)
slices.SortFunc(imp, compareV4OptionCodes)
// Set values for explicitly configured options.
for _, exp := range conf.Options {
i, found := slices.BinarySearchFunc(implicit, exp, compareV4OptionCodes)
for _, o := range c.Options {
i, found := slices.BinarySearchFunc(imp, o, compareV4OptionCodes)
if found {
implicit = slices.Delete(implicit, i, i+1)
imp = slices.Delete(imp, i, i+1)
}
i, found = slices.BinarySearchFunc(explicit, exp, compareV4OptionCodes)
if exp.Length > 0 {
explicit = slices.Insert(explicit, i, exp)
i, found = slices.BinarySearchFunc(exp, o, compareV4OptionCodes)
if o.Length > 0 {
exp = slices.Insert(exp, i, o)
} else if found {
explicit = slices.Delete(explicit, i, i+1)
exp = slices.Delete(exp, i, i+1)
}
}
log.Debug("dhcpsvc: v4: implicit options: %s", implicit)
log.Debug("dhcpsvc: v4: explicit options: %s", explicit)
l.DebugContext(ctx, "options", "implicit", imp, "explicit", exp)
return implicit, explicit
return imp, exp
}
// compareV4OptionCodes compares option codes of a and b.
func compareV4OptionCodes(a, b layers.DHCPOption) (res int) {
return int(a.Type) - int(b.Type)
}
// netInterfaceV4 is a DHCP interface for IPv4 address family.
type netInterfaceV4 struct {
// gateway is the IP address of the network gateway.
gateway netip.Addr
// subnet is the network subnet.
subnet netip.Prefix
// addrSpace is the IPv4 address space allocated for leasing.
addrSpace ipRange
// implicitOpts are the options listed in Appendix A of RFC 2131 and
// initialized with default values. It must not have intersections with
// explicitOpts.
implicitOpts layers.DHCPOptions
// explicitOpts are the user-configured options. It must not have
// intersections with implicitOpts.
explicitOpts layers.DHCPOptions
// netInterface is embedded here to provide some common network interface
// logic.
netInterface
}
// newNetInterfaceV4 creates a new DHCP interface for IPv4 address family with
// the given configuration. It returns an error if the given configuration
// can't be used.
func newNetInterfaceV4(name string, conf *IPv4Config) (i *netInterfaceV4, err error) {
if !conf.Enabled {
return nil, nil
}
maskLen, _ := net.IPMask(conf.SubnetMask.AsSlice()).Size()
subnet := netip.PrefixFrom(conf.GatewayIP, maskLen)
switch {
case !subnet.Contains(conf.RangeStart):
return nil, fmt.Errorf("range start %s is not within %s", conf.RangeStart, subnet)
case !subnet.Contains(conf.RangeEnd):
return nil, fmt.Errorf("range end %s is not within %s", conf.RangeEnd, subnet)
}
addrSpace, err := newIPRange(conf.RangeStart, conf.RangeEnd)
if err != nil {
return nil, err
} else if addrSpace.contains(conf.GatewayIP) {
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
}
i = &netInterfaceV4{
gateway: conf.GatewayIP,
subnet: subnet,
addrSpace: addrSpace,
netInterface: netInterface{
name: name,
leaseTTL: conf.LeaseDuration,
},
}
i.implicitOpts, i.explicitOpts = conf.options()
return i, nil
}
// netInterfacesV4 is a slice of network interfaces of IPv4 address family.
type netInterfacesV4 []*netInterfaceV4
// find returns the first network interface within ifaces containing ip. It
// returns false if there is no such interface.
func (ifaces netInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) {
i := slices.IndexFunc(ifaces, func(iface *netInterfaceV4) (contains bool) {
return iface.subnet.Contains(ip)
})
if i < 0 {
return nil, false
}
return &ifaces[i].netInterface, true
}

View File

@@ -3,7 +3,10 @@ package dhcpsvc
import (
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
)
@@ -75,9 +78,12 @@ func TestIPv4Config_Options(t *testing.T) {
wantExplicit: layers.DHCPOptions{opt1},
}}
ctx := testutil.ContextWithTimeout(t, time.Second)
l := slogutil.NewDiscardLogger()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
imp, exp := tc.conf.options()
imp, exp := tc.conf.options(ctx, l)
assert.Equal(t, tc.wantExplicit, exp)
for c := range exp {

View File

@@ -1,12 +1,14 @@
package dhcpsvc
import (
"context"
"fmt"
"log/slog"
"net/netip"
"slices"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/gopacket/layers"
)
@@ -38,50 +40,26 @@ type IPv6Config struct {
}
// validate returns an error in conf if any.
func (conf *IPv6Config) validate() (err error) {
switch {
case conf == nil:
func (c *IPv6Config) validate() (err error) {
if c == nil {
return errNilConfig
case !conf.Enabled:
return nil
case !conf.RangeStart.Is6():
return fmt.Errorf("range start %s should be a valid ipv6", conf.RangeStart)
case conf.LeaseDuration <= 0:
return fmt.Errorf("lease duration %s must be positive", conf.LeaseDuration)
default:
} else if !c.Enabled {
return nil
}
}
// options returns the implicit and explicit options for the interface. The two
// lists are disjoint and the implicit options are initialized with default
// values.
//
// TODO(e.burkov): Add implicit options according to RFC.
func (conf *IPv6Config) options() (implicit, explicit layers.DHCPv6Options) {
// Set default values of host configuration parameters listed in RFC 8415.
implicit = layers.DHCPv6Options{}
slices.SortFunc(implicit, compareV6OptionCodes)
var errs []error
// Set values for explicitly configured options.
for _, exp := range conf.Options {
i, found := slices.BinarySearchFunc(implicit, exp, compareV6OptionCodes)
if found {
implicit = slices.Delete(implicit, i, i+1)
}
explicit = append(explicit, exp)
if !c.RangeStart.Is6() {
err = fmt.Errorf("range start %s should be a valid ipv6", c.RangeStart)
errs = append(errs, err)
}
log.Debug("dhcpsvc: v6: implicit options: %s", implicit)
log.Debug("dhcpsvc: v6: explicit options: %s", explicit)
if c.LeaseDuration <= 0 {
err = fmt.Errorf("lease duration %s must be positive", c.LeaseDuration)
errs = append(errs, err)
}
return implicit, explicit
}
// compareV6OptionCodes compares option codes of a and b.
func compareV6OptionCodes(a, b layers.DHCPv6Option) (res int) {
return int(a.Code) - int(b.Code)
return errors.Join(errs...)
}
// netInterfaceV6 is a DHCP interface for IPv6 address family.
@@ -116,8 +94,16 @@ type netInterfaceV6 struct {
// the given configuration.
//
// TODO(e.burkov): Validate properly.
func newNetInterfaceV6(name string, conf *IPv6Config) (i *netInterfaceV6) {
func newNetInterfaceV6(
ctx context.Context,
l *slog.Logger,
name string,
conf *IPv6Config,
) (i *netInterfaceV6) {
l = l.With(keyInterface, name, keyFamily, netutil.AddrFamilyIPv6)
if !conf.Enabled {
l.DebugContext(ctx, "disabled")
return nil
}
@@ -126,11 +112,12 @@ func newNetInterfaceV6(name string, conf *IPv6Config) (i *netInterfaceV6) {
netInterface: netInterface{
name: name,
leaseTTL: conf.LeaseDuration,
logger: l,
},
raSLAACOnly: conf.RASLAACOnly,
raAllowSLAAC: conf.RAAllowSLAAC,
}
i.implicitOpts, i.explicitOpts = conf.options()
i.implicitOpts, i.explicitOpts = conf.options(ctx, l)
return i
}
@@ -159,3 +146,33 @@ func (ifaces netInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool
return &ifaces[i].netInterface, true
}
// options returns the implicit and explicit options for the interface. The two
// lists are disjoint and the implicit options are initialized with default
// values.
//
// TODO(e.burkov): Add implicit options according to RFC.
func (c *IPv6Config) options(ctx context.Context, l *slog.Logger) (imp, exp layers.DHCPv6Options) {
// Set default values of host configuration parameters listed in RFC 8415.
imp = layers.DHCPv6Options{}
slices.SortFunc(imp, compareV6OptionCodes)
// Set values for explicitly configured options.
for _, e := range c.Options {
i, found := slices.BinarySearchFunc(imp, e, compareV6OptionCodes)
if found {
imp = slices.Delete(imp, i, i+1)
}
exp = append(exp, e)
}
l.DebugContext(ctx, "options", "implicit", imp, "explicit", exp)
return imp, exp
}
// compareV6OptionCodes compares option codes of a and b.
func compareV6OptionCodes(a, b layers.DHCPv6Option) (res int) {
return int(a.Code) - int(b.Code)
}

View File

@@ -6,7 +6,6 @@ import (
"os"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/ipset"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -35,7 +34,7 @@ func (c *ipsetCtx) init(ipsetConf []string) (err error) {
log.Info("ipset: warning: cannot initialize: %s", err)
return nil
} else if unsupErr := (&aghos.UnsupportedError{}); errors.As(err, &unsupErr) {
} else if errors.Is(err, errors.ErrUnsupported) {
log.Info("ipset: warning: %s", err)
return nil

View File

@@ -28,7 +28,7 @@ type URLFilterID = int
// The IDs of built-in filter lists.
//
// NOTE: Do not change without the need for it and keep in sync with
// client/src/helpers/constants.js.
// client/src/helpers/constants.ts.
//
// TODO(a.garipov): Add type [URLFilterID] once it is used consistently in
// package filtering.

View File

@@ -532,7 +532,7 @@ var blockedServices = []blockedService{{
},
}, {
ID: "cloudflare",
Name: "CloudFlare",
Name: "Cloudflare",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M38 10.813l-.906 3.78-1.907-3.405v1.718c2.899 2.301 4.926 5.79 5.126 9.688.699-.2 1.3-.188 2-.188 1.374 0 2.667.297 3.812.875l-1.031-.593 3.812-.875-3.812-.907L48.5 19h-3.813l2.813-2.688-3.688 1.094 2-3.312-3.312 2 1.094-3.688-2.688 2.781.094-3.874-2 3.28zM27 11c-5 0-9.414 2.992-11.313 7.594-.699-.399-1.687-.688-2.687-.688-3.2 0-5.906 2.606-5.906 5.907v.5c-3.899.3-7.094 3.68-7.094 7.78 0 .802.113 1.52.313 2.22.101.398.5.687 1 .687h47c.398 0 .675-.195.874-.594.5-1.101.813-2.207.813-3.406 0-4.2-3.488-7.594-7.688-7.594-.8 0-1.511.082-2.312.282l4.906 6.625-5.5-4.5L22 29.593l15.094-4.905L28.5 21.5l10.688 1.813v-.125C39.188 16.488 33.699 11 27 11zm19.781 12.656c.434.274.844.586 1.219.938h.5z\" /></svg>"),
Rules: []string{
"||argotunnel.com^",
@@ -709,7 +709,7 @@ var blockedServices = []blockedService{{
},
}, {
ID: "ebay",
Name: "EBay",
Name: "eBay",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M 12.601563 13.671875 L 12.636719 24.058594 C 12.632813 22.457031 12.128906 18.101563 6.464844 18.097656 C 0.210938 18.097656 -0.03125 22.964844 0.00390625 24.230469 C 0.00390625 24.230469 -0.304688 29.917969 6.3125 29.917969 C 11.996094 29.917969 12.277344 26.347656 12.277344 26.347656 L 9.664063 26.355469 C 9.664063 26.355469 9.152344 28.320313 6.320313 28.265625 C 2.683594 28.199219 2.546875 24.675781 2.546875 24.675781 L 12.621094 24.675781 C 12.621094 24.675781 12.628906 24.566406 12.636719 24.425781 L 12.644531 26.960938 C 12.644531 26.960938 12.628906 28.507813 12.535156 29.53125 L 14.984375 29.53125 L 15.089844 28.039063 C 15.089844 28.039063 16.230469 29.917969 19.566406 29.917969 C 22.902344 29.917969 25.535156 27.863281 25.609375 24.050781 C 25.675781 20.242188 22.761719 18.117188 19.617188 18.097656 C 16.472656 18.082031 15.121094 19.960938 15.121094 19.960938 L 15.121094 13.671875 Z M 31.054688 18.046875 C 29.566406 18.097656 26.539063 18.558594 26.132813 21.460938 L 28.796875 21.460938 C 28.796875 21.460938 29 19.6875 31.703125 19.738281 C 34.257813 19.785156 34.722656 21.039063 34.707031 22.578125 C 34.707031 22.578125 32.519531 22.585938 31.785156 22.59375 C 30.46875 22.597656 25.863281 22.742188 25.433594 25.550781 C 24.917969 28.890625 27.898438 29.933594 30.230469 29.917969 C 32.5625 29.90625 33.890625 29.207031 34.878906 27.953125 L 34.984375 29.511719 L 37.300781 29.496094 C 37.300781 29.496094 37.242188 28.628906 37.25 26.90625 C 37.257813 25.1875 37.308594 23.65625 37.25 22.574219 C 37.183594 21.316406 37.304688 18.285156 31.875 18.0625 C 31.875 18.0625 31.550781 18.03125 31.054688 18.046875 Z M 35.871094 18.519531 L 41.675781 29.496094 L 39.4375 33.71875 L 42.265625 33.71875 L 50 18.519531 L 47.359375 18.519531 L 43.074219 27.046875 L 38.796875 18.519531 Z M 6.402344 19.765625 C 9.984375 19.761719 9.984375 22.949219 9.984375 22.949219 L 2.628906 22.949219 C 2.628906 22.949219 2.804688 19.765625 6.402344 19.765625 Z M 19.035156 19.800781 C 23.078125 19.699219 22.949219 24.097656 22.949219 24.097656 C 22.949219 24.097656 23.011719 28.167969 19.042969 28.21875 C 15.070313 28.269531 15.136719 24.011719 15.136719 24.011719 C 15.136719 24.011719 14.992188 19.90625 19.035156 19.800781 Z M 34.734375 24.265625 C 34.734375 24.269531 35.195313 28.371094 30.664063 28.3125 C 30.664063 28.3125 28.136719 28.3125 27.988281 26.296875 C 27.832031 24.140625 31.875 24.269531 31.875 24.269531 Z\" /></svg>"),
Rules: []string{
"|ebay-*.s3-us-west-1.amazonaws.com^",

View File

@@ -5,7 +5,6 @@ import (
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
@@ -16,10 +15,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
@@ -45,14 +44,12 @@ type DHCP interface {
// clientsContainer is the storage of all runtime and persistent clients.
type clientsContainer struct {
// clientIndex stores information about persistent clients.
clientIndex *client.Index
// storage stores information about persistent clients.
storage *client.Storage
// runtimeIndex stores information about runtime clients.
runtimeIndex *client.RuntimeIndex
allTags *container.MapSet[string]
// dhcp is the DHCP service implementation.
dhcp DHCP
@@ -104,15 +101,15 @@ func (clients *clientsContainer) Init(
filteringConf *filtering.Config,
) (err error) {
// TODO(s.chzhen): Refactor it.
if clients.clientIndex != nil {
if clients.storage != nil {
return errors.Error("clients container already initialized")
}
clients.runtimeIndex = client.NewRuntimeIndex()
clients.clientIndex = client.NewIndex()
clients.allTags = container.NewMapSet(clientTags...)
clients.storage = client.NewStorage(&client.Config{
AllowedTags: clientTags,
})
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
clients.dhcp = dhcpServer
@@ -217,7 +214,6 @@ type clientObject struct {
// toPersistent returns an initialized persistent client if there are no errors.
func (o *clientObject) toPersistent(
filteringConf *filtering.Config,
allTags *container.MapSet[string],
) (cli *client.Persistent, err error) {
cli = &client.Persistent{
Name: o.Name,
@@ -261,6 +257,12 @@ func (o *clientObject) toPersistent(
}
}
if o.BlockedServices == nil {
o.BlockedServices = &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
}
}
err = o.BlockedServices.Validate()
if err != nil {
return nil, fmt.Errorf("init blocked services %q: %w", cli.Name, err)
@@ -268,7 +270,7 @@ func (o *clientObject) toPersistent(
cli.BlockedServices = o.BlockedServices.Clone()
cli.SetTags(o.Tags, allTags)
cli.Tags = slices.Clone(o.Tags)
return cli, nil
}
@@ -281,22 +283,14 @@ func (clients *clientsContainer) addFromConfig(
) (err error) {
for i, o := range objects {
var cli *client.Persistent
cli, err = o.toPersistent(filteringConf, clients.allTags)
cli, err = o.toPersistent(filteringConf)
if err != nil {
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
}
// TODO(s.chzhen): Consider moving to the client index constructor.
err = clients.clientIndex.ClashesUID(cli)
err = clients.storage.Add(cli)
if err != nil {
return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err)
}
err = clients.add(cli)
if err != nil {
// TODO(s.chzhen): Return an error instead of logging if more
// stringent requirements are implemented.
log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err)
return fmt.Errorf("adding client %q at index %d: %w", cli.Name, i, err)
}
}
@@ -309,8 +303,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
clients.lock.Lock()
defer clients.lock.Unlock()
objs = make([]*clientObject, 0, clients.clientIndex.Size())
clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) {
objs = make([]*clientObject, 0, clients.storage.Size())
clients.storage.RangeByName(func(cli *client.Persistent) (cont bool) {
objs = append(objs, &clientObject{
Name: cli.Name,
@@ -337,14 +331,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
return true
})
// Maps aren't guaranteed to iterate in the same order each time, so the
// above loop can generate different orderings when writing to the config
// file: this produces lots of diffs in config files, so sort objects by
// name before writing.
slices.SortStableFunc(objs, func(a, b *clientObject) (res int) {
return strings.Compare(a.Name, b.Name)
})
return objs
}
@@ -362,7 +348,7 @@ func (clients *clientsContainer) periodicUpdate() {
// clientSource checks if client with this IP address already exists and returns
// the source which updated it last. It returns [client.SourceNone] if the
// client doesn't exist.
// client doesn't exist. Note that it is only used in tests.
func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -421,12 +407,8 @@ func (clients *clientsContainer) clientOrArtificial(
}
}()
cli, ok := clients.find(id)
if !ok {
cli = clients.clientIndex.FindByIPWithoutZone(ip)
}
if cli != nil {
cli, ok := clients.storage.FindLoose(ip, id)
if ok {
return &querylog.Client{
Name: cli.Name,
IgnoreQueryLog: cli.IgnoreQueryLog,
@@ -458,7 +440,7 @@ func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool)
return nil, false
}
return c.ShallowClone(), true
return c, true
}
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
@@ -532,7 +514,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
// findLocked searches for a client by its ID. clients.lock is expected to be
// locked.
func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) {
c, ok = clients.clientIndex.Find(id)
c, ok = clients.storage.Find(id)
if ok {
return c, true
}
@@ -554,7 +536,7 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
return nil, false
}
return clients.clientIndex.FindByMAC(foundMAC)
return clients.storage.FindByMAC(foundMAC)
}
// runtimeClient returns a runtime client from internal index. Note that it
@@ -588,114 +570,6 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Ru
return rc
}
// check validates the client. It also sorts the client tags.
func (clients *clientsContainer) check(c *client.Persistent) (err error) {
switch {
case c == nil:
return errors.Error("client is nil")
case c.Name == "":
return errors.Error("invalid name")
case c.IDsLen() == 0:
return errors.Error("id required")
default:
// Go on.
}
for _, t := range c.Tags {
if !clients.allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
}
// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)
_, err = proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}
return nil
}
// add adds a persistent client or returns an error.
func (clients *clientsContainer) add(c *client.Persistent) (err error) {
err = clients.check(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
clients.lock.Lock()
defer clients.lock.Unlock()
err = clients.clientIndex.Clashes(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
clients.addLocked(c)
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size())
return nil
}
// addLocked c to the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) addLocked(c *client.Persistent) {
clients.clientIndex.Add(c)
}
// remove removes a client. ok is false if there is no such client.
func (clients *clientsContainer) remove(name string) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.clientIndex.FindByName(name)
if !ok {
return false
}
clients.removeLocked(c)
return true
}
// removeLocked removes c from the indexes. clients.lock is expected to be
// locked.
func (clients *clientsContainer) removeLocked(c *client.Persistent) {
if err := c.CloseUpstreams(); err != nil {
log.Error("client container: removing client %s: %s", c.Name, err)
}
// Update the ID index.
clients.clientIndex.Delete(c)
}
// update updates a client by its name.
func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) {
err = clients.check(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
clients.lock.Lock()
defer clients.lock.Unlock()
err = clients.clientIndex.Clashes(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
clients.removeLocked(prev)
clients.addLocked(c)
return nil
}
// setWHOISInfo sets the WHOIS information for a client. clients.lock is
// expected to be locked.
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
@@ -857,5 +731,5 @@ func (clients *clientsContainer) addFromSystemARP() {
// close gracefully closes all the client-specific upstream configurations of
// the persistent clients.
func (clients *clientsContainer) close() (err error) {
return clients.clientIndex.CloseUpstreams()
return clients.storage.CloseUpstreams()
}

View File

@@ -72,7 +72,7 @@ func TestClients(t *testing.T) {
IPs: []netip.Addr{cli1IP, cliIPv6},
}
err := clients.add(c)
err := clients.storage.Add(c)
require.NoError(t, err)
c = &client.Persistent{
@@ -81,7 +81,7 @@ func TestClients(t *testing.T) {
IPs: []netip.Addr{cli2IP},
}
err = clients.add(c)
err = clients.storage.Add(c)
require.NoError(t, err)
c, ok := clients.find(cli1)
@@ -107,7 +107,7 @@ func TestClients(t *testing.T) {
})
t.Run("add_fail_name", func(t *testing.T) {
err := clients.add(&client.Persistent{
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
@@ -116,7 +116,7 @@ func TestClients(t *testing.T) {
})
t.Run("add_fail_ip", func(t *testing.T) {
err := clients.add(&client.Persistent{
err := clients.storage.Add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
})
@@ -124,7 +124,7 @@ func TestClients(t *testing.T) {
})
t.Run("update_fail_ip", func(t *testing.T) {
err := clients.update(&client.Persistent{Name: "client1"}, &client.Persistent{
err := clients.storage.Update("client1", &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
})
@@ -139,11 +139,11 @@ func TestClients(t *testing.T) {
cliNewIP = netip.MustParseAddr(cliNew)
)
prev, ok := clients.clientIndex.FindByName("client1")
prev, ok := clients.storage.FindByName("client1")
require.True(t, ok)
require.NotNil(t, prev)
err := clients.update(prev, &client.Persistent{
err := clients.storage.Update("client1", &client.Persistent{
Name: "client1",
UID: prev.UID,
IPs: []netip.Addr{cliNewIP},
@@ -155,11 +155,11 @@ func TestClients(t *testing.T) {
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
prev, ok = clients.clientIndex.FindByName("client1")
prev, ok = clients.storage.FindByName("client1")
require.True(t, ok)
require.NotNil(t, prev)
err = clients.update(prev, &client.Persistent{
err = clients.storage.Update("client1", &client.Persistent{
Name: "client1-renamed",
UID: prev.UID,
IPs: []netip.Addr{cliNewIP},
@@ -173,7 +173,7 @@ func TestClients(t *testing.T) {
assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings)
nilCli, ok := clients.clientIndex.FindByName("client1")
nilCli, ok := clients.storage.FindByName("client1")
require.False(t, ok)
assert.Nil(t, nilCli)
@@ -184,7 +184,7 @@ func TestClients(t *testing.T) {
})
t.Run("del_success", func(t *testing.T) {
ok := clients.remove("client1-renamed")
ok := clients.storage.RemoveByName("client1-renamed")
require.True(t, ok)
_, ok = clients.find("1.1.1.2")
@@ -192,7 +192,7 @@ func TestClients(t *testing.T) {
})
t.Run("del_fail", func(t *testing.T) {
ok := clients.remove("client3")
ok := clients.storage.RemoveByName("client3")
assert.False(t, ok)
})
@@ -261,7 +261,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("can't_set_manually-added", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.2")
err := clients.add(&client.Persistent{
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
@@ -272,7 +272,7 @@ func TestClientsWHOIS(t *testing.T) {
rc := clients.runtimeIndex.Client(ip)
require.Nil(t, rc)
assert.True(t, clients.remove("client1"))
assert.True(t, clients.storage.RemoveByName("client1"))
})
}
@@ -283,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
err := clients.add(&client.Persistent{
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
@@ -333,7 +333,7 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
err = clients.add(&client.Persistent{
err = clients.storage.Add(&client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip},
@@ -341,7 +341,7 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
// Add a new client with the IP from the first client's IP range.
err = clients.add(&client.Persistent{
err = clients.storage.Add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
@@ -354,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
// Add client with upstreams.
err := clients.add(&client.Persistent{
err := clients.storage.Add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},

View File

@@ -96,7 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
clients.lock.Lock()
defer clients.lock.Unlock()
clients.clientIndex.Range(func(c *client.Persistent) (cont bool) {
clients.storage.RangeByName(func(c *client.Persistent) (cont bool) {
cj := clientToJSON(c)
data.Clients = append(data.Clients, cj)
@@ -267,7 +267,7 @@ func copyBlockedServices(
var weekly *schedule.Weekly
if sch != nil {
weekly = sch.Clone()
} else if prev != nil && prev.BlockedServices != nil {
} else if prev != nil {
weekly = prev.BlockedServices.Schedule.Clone()
} else {
weekly = schedule.EmptyWeekly()
@@ -336,7 +336,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return
}
err = clients.add(c)
err = clients.storage.Add(c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -364,7 +364,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
return
}
if !clients.remove(cj.Name) {
if !clients.storage.RemoveByName(cj.Name) {
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
return
@@ -399,30 +399,14 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
var prev *client.Persistent
var ok bool
func() {
clients.lock.Lock()
defer clients.lock.Unlock()
prev, ok = clients.clientIndex.FindByName(dj.Name)
}()
if !ok {
aghhttp.Error(r, w, http.StatusBadRequest, "client not found")
return
}
c, err := clients.jsonToClient(dj.Data, prev)
c, err := clients.jsonToClient(dj.Data, nil)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = clients.update(prev, c)
err = clients.storage.Update(dj.Name, c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

View File

@@ -49,7 +49,7 @@ func newPersistentClient(name string) (c *client.Persistent) {
Name: name,
UID: client.MustNewUID(),
BlockedServices: &filtering.BlockedServices{
Schedule: &schedule.Weekly{},
Schedule: schedule.EmptyWeekly(),
},
}
}
@@ -198,11 +198,11 @@ func TestClientsContainer_HandleDelClient(t *testing.T) {
clients := newClientsContainer(t)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
err := clients.storage.Add(clientOne)
require.NoError(t, err)
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
err = clients.add(clientTwo)
err = clients.storage.Add(clientTwo)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
@@ -260,7 +260,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) {
clients := newClientsContainer(t)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
err := clients.storage.Add(clientOne)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne})
@@ -342,11 +342,11 @@ func TestClientsContainer_HandleFindClient(t *testing.T) {
}
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
err := clients.storage.Add(clientOne)
require.NoError(t, err)
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
err = clients.add(clientTwo)
err = clients.storage.Add(clientTwo)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})

View File

@@ -32,6 +32,9 @@ const dataDir = "data"
// logSettings are the logging settings part of the configuration file.
type logSettings struct {
// Enabled indicates whether logging is enabled.
Enabled bool `yaml:"enabled"`
// File is the path to the log file. If empty, logs are written to stdout.
// If "syslog", logs are written to syslog.
File string `yaml:"file"`
@@ -385,7 +388,7 @@ var config = &configuration{
Ignored: []string{},
},
// NOTE: Keep these parameters in sync with the one put into
// client/src/helpers/filters/filters.js by scripts/vetted-filters.
// client/src/helpers/filters/filters.ts by scripts/vetted-filters.
//
// TODO(a.garipov): Think of a way to make scripts/vetted-filters update
// these as well if necessary.
@@ -454,11 +457,14 @@ var config = &configuration{
},
},
Log: logSettings{
Compress: false,
LocalTime: false,
Enabled: true,
File: "",
MaxBackups: 0,
MaxSize: 100,
MaxAge: 3,
Compress: false,
LocalTime: false,
Verbose: false,
},
OSConfig: &osConfig{},
SchemaVersion: configmigrate.LastSchemaVersion,

View File

@@ -13,17 +13,21 @@ import (
var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*client.Persistent) (ci *client.Index) {
ci = client.NewIndex()
// newStorage is a helper function that returns a client storage filled with
// persistent clients. It also generates a UID for each client.
func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) {
tb.Helper()
for _, c := range m {
c.UID = client.MustNewUID()
ci.Add(c)
s = client.NewStorage(&client.Config{
AllowedTags: nil,
})
for _, p := range clients {
p.UID = client.MustNewUID()
require.NoError(tb, s.Add(p))
}
return ci
return s
}
func TestApplyAdditionalFiltering(t *testing.T) {
@@ -36,7 +40,8 @@ func TestApplyAdditionalFiltering(t *testing.T) {
}, nil)
require.NoError(t, err)
Context.clients.clientIndex = newIDIndex([]*client.Persistent{{
Context.clients.storage = newStorage(t, []*client.Persistent{{
Name: "default",
ClientIDs: []string{"default"},
UseOwnSettings: false,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: false},
@@ -44,6 +49,7 @@ func TestApplyAdditionalFiltering(t *testing.T) {
SafeBrowsingEnabled: false,
ParentalEnabled: false,
}, {
Name: "custom_filtering",
ClientIDs: []string{"custom_filtering"},
UseOwnSettings: true,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
@@ -51,6 +57,7 @@ func TestApplyAdditionalFiltering(t *testing.T) {
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}, {
Name: "partial_custom_filtering",
ClientIDs: []string{"partial_custom_filtering"},
UseOwnSettings: true,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
@@ -121,16 +128,19 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
}, nil)
require.NoError(t, err)
Context.clients.clientIndex = newIDIndex([]*client.Persistent{{
Context.clients.storage = newStorage(t, []*client.Persistent{{
Name: "default",
ClientIDs: []string{"default"},
UseOwnBlockedServices: false,
}, {
Name: "no_services",
ClientIDs: []string{"no_services"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
UseOwnBlockedServices: true,
}, {
Name: "services",
ClientIDs: []string{"services"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
@@ -138,6 +148,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
},
UseOwnBlockedServices: true,
}, {
Name: "invalid_services",
ClientIDs: []string{"invalid_services"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
@@ -145,6 +156,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
},
UseOwnBlockedServices: true,
}, {
Name: "allow_all",
ClientIDs: []string{"allow_all"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.FullWeekly(),

View File

@@ -178,7 +178,7 @@ func setupContext(opts options) (err error) {
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns
// nil. Otherwise, it returns err.
func logIfUnsupported(msg string, err error) (outErr error) {
if errors.As(err, new(*aghos.UnsupportedError)) {
if errors.Is(err, errors.ErrUnsupported) {
log.Debug(msg, err)
return nil
@@ -232,7 +232,9 @@ func configureOS(conf *configuration) (err error) {
func setupHostsContainer() (err error) {
hostsWatcher, err := aghos.NewOSWritesWatcher()
if err != nil {
return fmt.Errorf("initing hosts watcher: %w", err)
log.Info("WARNING: initializing filesystem watcher: %s; not watching for changes", err)
hostsWatcher = aghos.EmptyFSWatcher{}
}
paths, err := hostsfile.DefaultHostsPaths()

View File

@@ -21,7 +21,9 @@ func configureLogger(opts options) (err error) {
ls := getLogSettings(opts)
// Configure logger level.
if ls.Verbose {
if !ls.Enabled {
log.SetLevel(log.OFF)
} else if ls.Verbose {
log.SetLevel(log.DEBUG)
}
@@ -91,7 +93,14 @@ func getLogSettings(opts options) (ls *logSettings) {
// separate method in order to configure logger before the actual configuration
// is parsed and applied.
func readLogSettings() (ls *logSettings) {
conf := &configuration{}
// TODO(s.chzhen): Add a helper function that returns default parameters
// for this structure and for the global configuration structure [config].
conf := &configuration{
Log: logSettings{
// By default, it is true if the property does not exist.
Enabled: true,
},
}
yamlFile, err := readConfigFile()
if err != nil {

View File

@@ -14,7 +14,7 @@ type Theme string
// Allowed [Theme] values.
//
// Keep in sync with client/src/helpers/constants.js.
// Keep in sync with client/src/helpers/constants.ts.
const (
ThemeAuto Theme = "auto"
ThemeLight Theme = "light"

View File

@@ -460,8 +460,9 @@ var launchdConfig = `<?xml version='1.0' encoding='UTF-8'?>
// 1. The RestartSec setting is set to a lower value of 10 to make sure we
// always restart quickly.
//
// 2. The ExecStartPre setting is added to make sure that the log directory is
// always created to prevent the 209/STDOUT errors.
// 2. The StandardOutput and StandardError settings are set to redirect the
// output to the systemd journal, see
// https://man7.org/linux/man-pages/man5/systemd.exec.5.html#LOGGING_AND_STANDARD_INPUT/OUTPUT.
const systemdScript = `[Unit]
Description={{.Description}}
ConditionFileIsExecutable={{.Path|cmdEscape}}
@@ -471,7 +472,6 @@ ConditionFileIsExecutable={{.Path|cmdEscape}}
[Service]
StartLimitInterval=5
StartLimitBurst=10
ExecStartPre=/bin/mkdir -p /var/log/
ExecStart={{.Path|cmdEscape}}{{range .Arguments}} {{.|cmd}}{{end}}
{{if .ChRoot}}RootDirectory={{.ChRoot|cmd}}{{end}}
{{if .WorkingDirectory}}WorkingDirectory={{.WorkingDirectory|cmdEscape}}{{end}}
@@ -479,8 +479,8 @@ ExecStart={{.Path|cmdEscape}}{{range .Arguments}} {{.|cmd}}{{end}}
{{if .ReloadSignal}}ExecReload=/bin/kill -{{.ReloadSignal}} "$MAINPID"{{end}}
{{if .PIDFile}}PIDFile={{.PIDFile|cmd}}{{end}}
{{if and .LogOutput .HasOutputFileSupport -}}
StandardOutput=file:/var/log/{{.Name}}.out
StandardError=file:/var/log/{{.Name}}.err
StandardOutput=journal
StandardError=journal
{{- end}}
{{if gt .LimitNOFILE -1 }}LimitNOFILE={{.LimitNOFILE}}{{end}}
{{if .Restart}}Restart={{.Restart}}{{end}}

View File

@@ -133,7 +133,14 @@ func webCheckPortAvailable(port uint16) (ok bool) {
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), port)
return aghnet.CheckPort("tcp", addrPort) == nil
err := aghnet.CheckPort("tcp", addrPort)
if err != nil {
log.Info("web: warning: checking https port: %s", err)
return false
}
return true
}
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server

View File

@@ -22,8 +22,8 @@ type Manager interface {
//
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
//
// If ipsetConf is empty, msg and err are nil. The error is of type
// *aghos.UnsupportedError if the OS is not supported.
// If ipsetConf is empty, msg and err are nil. The error's chain contains
// [errors.ErrUnsupported] if current OS is not supported.
func NewManager(ipsetConf []string) (mgr Manager, err error) {
if len(ipsetConf) == 0 {
return nil, nil

View File

@@ -4,14 +4,14 @@ go 1.22.4
require (
github.com/fzipp/gocyclo v0.6.0
github.com/golangci/misspell v0.5.1
github.com/golangci/misspell v0.6.0
github.com/gordonklaus/ineffassign v0.1.0
github.com/kisielk/errcheck v1.7.0
github.com/kyoh86/looppointer v0.2.1
github.com/securego/gosec/v2 v2.20.0
github.com/uudashr/gocognit v1.1.2
golang.org/x/tools v0.22.0
golang.org/x/vuln v1.1.1
golang.org/x/vuln v1.1.2
honnef.co/go/tools v0.4.7
mvdan.cc/gofumpt v0.6.0
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f
@@ -26,9 +26,10 @@ require (
github.com/kyoh86/nolint v0.0.1 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/exp/typeparams v0.0.0-20240604190554-fc45aab8b7f8 // indirect
golang.org/x/exp/typeparams v0.0.0-20240613232115-7f521ea00fb8 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/telemetry v0.0.0-20240701175443-4e29c7872ac1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -12,8 +12,8 @@ github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/golangci/misspell v0.5.1 h1:/SjR1clj5uDjNLwYzCahHwIOPmQgoH04AyQIiWGbhCM=
github.com/golangci/misspell v0.5.1/go.mod h1:keMNyY6R9isGaSAu+4Q8NMBwMPkh15Gtc8UCVoDtAWo=
github.com/golangci/misspell v0.6.0 h1:JCle2HUTNWirNlDIAUO44hUsKhOFqGPoC4LZxlaSXDs=
github.com/golangci/misspell v0.6.0/go.mod h1:keMNyY6R9isGaSAu+4Q8NMBwMPkh15Gtc8UCVoDtAWo=
github.com/google/go-cmdtest v0.4.1-0.20220921163831-55ab3332a786 h1:rcv+Ippz6RAtvaGgKxc+8FQIpxHgsF+HBzPyYL2cyVU=
github.com/google/go-cmdtest v0.4.1-0.20220921163831-55ab3332a786/go.mod h1:apVn/GCasLZUVpAJ6oWAuyP7Ne7CEsQbTnc0plM3m+o=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
@@ -63,8 +63,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp/typeparams v0.0.0-20240604190554-fc45aab8b7f8 h1:WKP3FgLqWfVutBnw/dr+LNg4fzjyTQP5o+ELTIyoBrs=
golang.org/x/exp/typeparams v0.0.0-20240604190554-fc45aab8b7f8/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/exp/typeparams v0.0.0-20240613232115-7f521ea00fb8 h1:+ZJmEdDFzH5H0CnzOrwgbH3elHctfTecW9X0k2tkn5M=
golang.org/x/exp/typeparams v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
@@ -95,6 +95,8 @@ golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220702020025-31831981b65f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/telemetry v0.0.0-20240701175443-4e29c7872ac1 h1:jveUVYFLPlIma1aZBg9rrUN+Dqk4e6QbVSGiZGwA/2Y=
golang.org/x/telemetry v0.0.0-20240701175443-4e29c7872ac1/go.mod h1:n38mvGdgc4dA684EC4NwQwoPKSw4jyKw8/DgZHDA1Dk=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -109,8 +111,8 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
golang.org/x/vuln v1.1.1 h1:4nYQg4OSr7uYQMtjuuYqLAEVuTjY4k/CPMYqvv5OPcI=
golang.org/x/vuln v1.1.1/go.mod h1:hNgE+SKMSp2wHVUpW0Ow2ejgKpNJePdML+4YjxrVxik=
golang.org/x/vuln v1.1.2 h1:UkLxe+kAMcrNBpGrFbU0Mc5l7cX97P2nhy21wx5+Qbk=
golang.org/x/vuln v1.1.2/go.mod h1:2o3fRKD8Uz9AraAL3lwd/grWBv+t+SeJnPcqBUJrY24=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=