all: sync with master
This commit is contained in:
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
||||
@@ -46,22 +45,20 @@ type DHCP interface {
|
||||
|
||||
// clientsContainer is the storage of all runtime and persistent clients.
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
||||
// types (string, netip.Addr, and so on).
|
||||
list map[string]*client.Persistent // name -> client
|
||||
|
||||
// clientIndex stores information about persistent clients.
|
||||
clientIndex *client.Index
|
||||
|
||||
// ipToRC maps IP addresses to runtime client information.
|
||||
ipToRC map[netip.Addr]*client.Runtime
|
||||
// runtimeIndex stores information about runtime clients.
|
||||
runtimeIndex *client.RuntimeIndex
|
||||
|
||||
allTags *container.MapSet[string]
|
||||
|
||||
// dhcp is the DHCP service implementation.
|
||||
dhcp DHCP
|
||||
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
// clientChecker checks if a client is blocked by the current access
|
||||
// settings.
|
||||
clientChecker BlockedClientChecker
|
||||
|
||||
// etcHosts contains list of rewrite rules taken from the operating system's
|
||||
// hosts database.
|
||||
@@ -90,6 +87,12 @@ type clientsContainer struct {
|
||||
testing bool
|
||||
}
|
||||
|
||||
// BlockedClientChecker checks if a client is blocked by the current access
|
||||
// settings.
|
||||
type BlockedClientChecker interface {
|
||||
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// Init initializes clients container
|
||||
// dhcpServer: optional
|
||||
// Note: this function must be called only once
|
||||
@@ -100,12 +103,12 @@ func (clients *clientsContainer) Init(
|
||||
arpDB arpdb.Interface,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
// TODO(s.chzhen): Refactor it.
|
||||
if clients.clientIndex != nil {
|
||||
return errors.Error("clients container already initialized")
|
||||
}
|
||||
|
||||
clients.list = map[string]*client.Persistent{}
|
||||
clients.ipToRC = map[netip.Addr]*client.Runtime{}
|
||||
clients.runtimeIndex = client.NewRuntimeIndex()
|
||||
|
||||
clients.clientIndex = client.NewIndex()
|
||||
|
||||
@@ -248,8 +251,6 @@ func (o *clientObject) toPersistent(
|
||||
}
|
||||
|
||||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
err = cli.SetSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
@@ -285,9 +286,17 @@ func (clients *clientsContainer) addFromConfig(
|
||||
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
_, err = clients.add(cli)
|
||||
// TODO(s.chzhen): Consider moving to the client index constructor.
|
||||
err = clients.clientIndex.ClashesUID(cli)
|
||||
if err != nil {
|
||||
log.Error("clients: adding client at index %d %s: %s", i, cli.Name, err)
|
||||
return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err)
|
||||
}
|
||||
|
||||
err = clients.add(cli)
|
||||
if err != nil {
|
||||
// TODO(s.chzhen): Return an error instead of logging if more
|
||||
// stringent requirements are implemented.
|
||||
log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,9 +309,9 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
objs = make([]*clientObject, 0, len(clients.list))
|
||||
for _, cli := range clients.list {
|
||||
o := &clientObject{
|
||||
objs = make([]*clientObject, 0, clients.clientIndex.Size())
|
||||
clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) {
|
||||
objs = append(objs, &clientObject{
|
||||
Name: cli.Name,
|
||||
|
||||
BlockedServices: cli.BlockedServices.Clone(),
|
||||
@@ -323,10 +332,10 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
IgnoreStatistics: cli.IgnoreStatistics,
|
||||
UpstreamsCacheEnabled: cli.UpstreamsCacheEnabled,
|
||||
UpstreamsCacheSize: cli.UpstreamsCacheSize,
|
||||
}
|
||||
})
|
||||
|
||||
objs = append(objs, o)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Maps aren't guaranteed to iterate in the same order each time, so the
|
||||
// above loop can generate different orderings when writing to the config
|
||||
@@ -363,8 +372,8 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
|
||||
return client.SourcePersistent
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc != nil {
|
||||
src, _ = rc.Info()
|
||||
}
|
||||
|
||||
@@ -406,23 +415,26 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
id string,
|
||||
) (c *querylog.Client, art bool) {
|
||||
defer func() {
|
||||
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id)
|
||||
c.Disallowed, c.DisallowedRule = clients.clientChecker.IsBlockedClient(ip, id)
|
||||
if c.WHOIS == nil {
|
||||
c.WHOIS = &whois.Info{}
|
||||
}
|
||||
}()
|
||||
|
||||
cli, ok := clients.find(id)
|
||||
if ok {
|
||||
if !ok {
|
||||
cli = clients.clientIndex.FindByIPWithoutZone(ip)
|
||||
}
|
||||
|
||||
if cli != nil {
|
||||
return &querylog.Client{
|
||||
Name: cli.Name,
|
||||
IgnoreQueryLog: cli.IgnoreQueryLog,
|
||||
}, false
|
||||
}
|
||||
|
||||
var rc *client.Runtime
|
||||
rc, ok = clients.findRuntimeClient(ip)
|
||||
if ok {
|
||||
rc := clients.findRuntimeClient(ip)
|
||||
if rc != nil {
|
||||
_, host := rc.Info()
|
||||
|
||||
return &querylog.Client{
|
||||
@@ -542,47 +554,38 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
_, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr])
|
||||
if found {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return clients.clientIndex.FindByMAC(foundMAC)
|
||||
}
|
||||
|
||||
// runtimeClient returns a runtime client from internal index. Note that it
|
||||
// doesn't include DHCP clients.
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
if ip == (netip.Addr{}) {
|
||||
return nil, false
|
||||
return nil
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
rc, ok = clients.ipToRC[ip]
|
||||
|
||||
return rc, ok
|
||||
return clients.runtimeIndex.Client(ip)
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
|
||||
rc, ok = clients.runtimeClient(ip)
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
rc = clients.runtimeClient(ip)
|
||||
host := clients.dhcp.HostByIP(ip)
|
||||
|
||||
if host != "" {
|
||||
if !ok {
|
||||
rc = &client.Runtime{}
|
||||
if rc == nil {
|
||||
rc = client.NewRuntime(ip)
|
||||
}
|
||||
|
||||
rc.SetInfo(client.SourceDHCP, []string{host})
|
||||
|
||||
return rc, true
|
||||
return rc
|
||||
}
|
||||
|
||||
return rc, ok
|
||||
return rc
|
||||
}
|
||||
|
||||
// check validates the client. It also sorts the client tags.
|
||||
@@ -615,43 +618,32 @@ func (clients *clientsContainer) check(c *client.Persistent) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// add adds a new client object. ok is false if such client already exists or
|
||||
// if an error occurred.
|
||||
func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) {
|
||||
// add adds a persistent client or returns an error.
|
||||
func (clients *clientsContainer) add(c *client.Persistent) (err error) {
|
||||
err = clients.check(c)
|
||||
if err != nil {
|
||||
return false, err
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// check Name index
|
||||
_, ok = clients.list[c.Name]
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// check ID index
|
||||
err = clients.clientIndex.Clashes(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
clients.addLocked(c)
|
||||
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list))
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size())
|
||||
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLocked c to the indexes. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) addLocked(c *client.Persistent) {
|
||||
// update Name index
|
||||
clients.list[c.Name] = c
|
||||
|
||||
// update ID index
|
||||
clients.clientIndex.Add(c)
|
||||
}
|
||||
|
||||
@@ -660,8 +652,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
var c *client.Persistent
|
||||
c, ok = clients.list[name]
|
||||
c, ok := clients.clientIndex.FindByName(name)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -678,9 +669,6 @@ func (clients *clientsContainer) removeLocked(c *client.Persistent) {
|
||||
log.Error("client container: removing client %s: %s", c.Name, err)
|
||||
}
|
||||
|
||||
// Update the name index.
|
||||
delete(clients.list, c.Name)
|
||||
|
||||
// Update the ID index.
|
||||
clients.clientIndex.Delete(c)
|
||||
}
|
||||
@@ -696,22 +684,6 @@ func (clients *clientsContainer) update(prev, c *client.Persistent) (err error)
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// Check the name index.
|
||||
if prev.Name != c.Name {
|
||||
_, ok := clients.list[c.Name]
|
||||
if ok {
|
||||
return errors.Error("client already exists")
|
||||
}
|
||||
}
|
||||
|
||||
if c.EqualIDs(prev) {
|
||||
clients.removeLocked(prev)
|
||||
clients.addLocked(c)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check the ID index.
|
||||
err = clients.clientIndex.Clashes(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
@@ -734,12 +706,12 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
return
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = &client.Runtime{}
|
||||
clients.ipToRC[ip] = rc
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
} else {
|
||||
@@ -798,61 +770,54 @@ func (clients *clientsContainer) addHostLocked(
|
||||
host string,
|
||||
src client.Source,
|
||||
) (ok bool) {
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
if src < client.SourceDHCP {
|
||||
if clients.dhcp.HostByIP(ip) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
rc = &client.Runtime{}
|
||||
clients.ipToRC[ip] = rc
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
}
|
||||
|
||||
rc.SetInfo(src, []string{host})
|
||||
|
||||
log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC))
|
||||
log.Debug(
|
||||
"clients: adding client info %s -> %q %q [%d]",
|
||||
ip,
|
||||
src,
|
||||
host,
|
||||
clients.runtimeIndex.Size(),
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
|
||||
n := 0
|
||||
for ip, rc := range clients.ipToRC {
|
||||
rc.Unset(src)
|
||||
if rc.IsEmpty() {
|
||||
delete(clients.ipToRC, ip)
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: removed %d client aliases", n)
|
||||
}
|
||||
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(client.SourceHostsFile)
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
|
||||
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
|
||||
|
||||
n := 0
|
||||
added := 0
|
||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||
// Only the first name of the first record is considered a canonical
|
||||
// hostname for the IP address.
|
||||
//
|
||||
// TODO(e.burkov): Consider using all the names from all the records.
|
||||
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
|
||||
n++
|
||||
added++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts file", n)
|
||||
log.Debug("clients: added %d client aliases from system hosts file", added)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
@@ -876,7 +841,8 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(client.SourceARP)
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
|
||||
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
|
||||
|
||||
added := 0
|
||||
for _, n := range ns {
|
||||
@@ -891,18 +857,5 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
// close gracefully closes all the client-specific upstream configurations of
|
||||
// the persistent clients.
|
||||
func (clients *clientsContainer) close() (err error) {
|
||||
persistent := maps.Values(clients.list)
|
||||
slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, cli := range persistent {
|
||||
if err = cli.CloseUpstreams(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
return clients.clientIndex.CloseUpstreams()
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
}
|
||||
|
||||
dhcp := &testDHCP{
|
||||
OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") },
|
||||
OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
|
||||
OnHostBy: func(ip netip.Addr) (host string) { return "" },
|
||||
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
|
||||
}
|
||||
@@ -72,23 +72,19 @@ func TestClients(t *testing.T) {
|
||||
IPs: []netip.Addr{cli1IP, cliIPv6},
|
||||
}
|
||||
|
||||
ok, err := clients.add(c)
|
||||
err := clients.add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{cli2IP},
|
||||
}
|
||||
|
||||
ok, err = clients.add(c)
|
||||
err = clients.add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.find(cli1)
|
||||
c, ok := clients.find(cli1)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
@@ -111,22 +107,20 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("add_fail_ip", func(t *testing.T) {
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
@@ -145,12 +139,13 @@ func TestClients(t *testing.T) {
|
||||
cliNewIP = netip.MustParseAddr(cliNew)
|
||||
)
|
||||
|
||||
prev, ok := clients.list["client1"]
|
||||
prev, ok := clients.clientIndex.FindByName("client1")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, prev)
|
||||
|
||||
err := clients.update(prev, &client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
UID: prev.UID,
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -160,12 +155,13 @@ func TestClients(t *testing.T) {
|
||||
|
||||
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
|
||||
|
||||
prev, ok = clients.list["client1"]
|
||||
prev, ok = clients.clientIndex.FindByName("client1")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, prev)
|
||||
|
||||
err = clients.update(prev, &client.Persistent{
|
||||
Name: "client1-renamed",
|
||||
UID: client.MustNewUID(),
|
||||
UID: prev.UID,
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
@@ -177,7 +173,7 @@ func TestClients(t *testing.T) {
|
||||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
assert.True(t, c.UseOwnSettings)
|
||||
|
||||
nilCli, ok := clients.list["client1"]
|
||||
nilCli, ok := clients.clientIndex.FindByName("client1")
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Nil(t, nilCli)
|
||||
@@ -244,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
t.Run("new_client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.255")
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
@@ -256,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
@@ -265,16 +261,15 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.2")
|
||||
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.Nil(t, rc)
|
||||
|
||||
assert.True(t, clients.remove("client1"))
|
||||
@@ -288,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Add a client.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||
@@ -296,10 +291,9 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok = clients.addHost(ip, "test", client.SourceRDNS)
|
||||
ok := clients.addHost(ip, "test", client.SourceRDNS)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
@@ -339,22 +333,20 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err = clients.add(&client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add a new client with the IP from the first client's IP range.
|
||||
ok, err = clients.add(&client.Persistent{
|
||||
err = clients.add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -362,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.add(&client.Persistent{
|
||||
err := clients.add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
||||
@@ -372,7 +364,6 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
|
||||
assert.Nil(t, upsConf)
|
||||
|
||||
@@ -96,22 +96,26 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
for _, c := range clients.list {
|
||||
clients.clientIndex.Range(func(c *client.Persistent) (cont bool) {
|
||||
cj := clientToJSON(c)
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
|
||||
for ip, rc := range clients.ipToRC {
|
||||
return true
|
||||
})
|
||||
|
||||
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
|
||||
src, host := rc.Info()
|
||||
cj := runtimeClientJSON{
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
Name: host,
|
||||
Source: src,
|
||||
IP: ip,
|
||||
IP: rc.Addr(),
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
for _, l := range clients.dhcp.Leases() {
|
||||
cj := runtimeClientJSON{
|
||||
@@ -332,20 +336,16 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := clients.add(c)
|
||||
err = clients.add(c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists")
|
||||
|
||||
return
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
|
||||
@@ -370,7 +370,9 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
}
|
||||
|
||||
// updateJSON contains the name and data of the updated persistent client.
|
||||
@@ -404,7 +406,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
prev, ok = clients.list[dj.Name]
|
||||
prev, ok = clients.clientIndex.FindByName(dj.Name)
|
||||
}()
|
||||
|
||||
if !ok {
|
||||
@@ -427,14 +429,16 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
}
|
||||
|
||||
// handleFindClient is the handler for GET /control/clients/find HTTP API.
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]*clientJSON{}
|
||||
for i := 0; i < len(q); i++ {
|
||||
for i := range len(q) {
|
||||
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
||||
if idStr == "" {
|
||||
break
|
||||
@@ -447,7 +451,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
cj = clients.findRuntime(ip, idStr)
|
||||
} else {
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
}
|
||||
|
||||
@@ -463,14 +467,14 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||
// non-nil.
|
||||
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
|
||||
rc, ok := clients.findRuntimeClient(ip)
|
||||
if !ok {
|
||||
rc := clients.findRuntimeClient(ip)
|
||||
if rc == nil {
|
||||
// It is still possible that the IP used to be in the runtime clients
|
||||
// list, but then the server was reloaded. So, check the DNS server's
|
||||
// blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj = &clientJSON{
|
||||
IDs: []string{idStr},
|
||||
Disallowed: &disallowed,
|
||||
@@ -488,7 +492,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
}
|
||||
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
|
||||
return cj
|
||||
|
||||
399
internal/home/clientshttp_internal_test.go
Normal file
399
internal/home/clientshttp_internal_test.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testClientIP1 = "1.1.1.1"
|
||||
testClientIP2 = "2.2.2.2"
|
||||
)
|
||||
|
||||
// testBlockedClientChecker is a mock implementation of the
|
||||
// [BlockedClientChecker] interface.
|
||||
type testBlockedClientChecker struct {
|
||||
onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ BlockedClientChecker = (*testBlockedClientChecker)(nil)
|
||||
|
||||
// IsBlockedClient implements the [BlockedClientChecker] interface for
|
||||
// *testBlockedClientChecker.
|
||||
func (c *testBlockedClientChecker) IsBlockedClient(
|
||||
ip netip.Addr,
|
||||
clientID string,
|
||||
) (blocked bool, rule string) {
|
||||
return c.onIsBlockedClient(ip, clientID)
|
||||
}
|
||||
|
||||
// newPersistentClient is a helper function that returns a persistent client
|
||||
// with the specified name and newly generated UID.
|
||||
func newPersistentClient(name string) (c *client.Persistent) {
|
||||
return &client.Persistent{
|
||||
Name: name,
|
||||
UID: client.MustNewUID(),
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: &schedule.Weekly{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newPersistentClientWithIDs is a helper function that returns a persistent
|
||||
// client with the specified name and ids.
|
||||
func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
c = newPersistentClient(name)
|
||||
err := c.SetIDs(ids)
|
||||
require.NoError(tb, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// assertClients is a helper function that compares lists of persistent clients.
|
||||
func assertClients(tb testing.TB, want, got []*client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
require.Len(tb, got, len(want))
|
||||
|
||||
sortFunc := func(a, b *client.Persistent) (n int) {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
}
|
||||
|
||||
slices.SortFunc(want, sortFunc)
|
||||
slices.SortFunc(got, sortFunc)
|
||||
|
||||
slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) {
|
||||
assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name)
|
||||
|
||||
return 0
|
||||
})
|
||||
}
|
||||
|
||||
// assertPersistentClients is a helper function that uses HTTP API to check
|
||||
// whether want persistent clients are the same as the persistent clients stored
|
||||
// in the clients container.
|
||||
func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) {
|
||||
tb.Helper()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleGetClients(rw, &http.Request{})
|
||||
|
||||
body, err := io.ReadAll(rw.Body)
|
||||
require.NoError(tb, err)
|
||||
|
||||
clientList := &clientListJSON{}
|
||||
err = json.Unmarshal(body, clientList)
|
||||
require.NoError(tb, err)
|
||||
|
||||
var got []*client.Persistent
|
||||
for _, cj := range clientList.Clients {
|
||||
var c *client.Persistent
|
||||
c, err = clients.jsonToClient(*cj, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got = append(got, c)
|
||||
}
|
||||
|
||||
assertClients(tb, want, got)
|
||||
}
|
||||
|
||||
// assertPersistentClientsData is a helper function that checks whether want
|
||||
// persistent clients are the same as the persistent clients stored in data.
|
||||
func assertPersistentClientsData(
|
||||
tb testing.TB,
|
||||
clients *clientsContainer,
|
||||
data []map[string]*clientJSON,
|
||||
want []*client.Persistent,
|
||||
) {
|
||||
tb.Helper()
|
||||
|
||||
var got []*client.Persistent
|
||||
for _, cm := range data {
|
||||
for _, cj := range cm {
|
||||
var c *client.Persistent
|
||||
c, err := clients.jsonToClient(*cj, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
got = append(got, c)
|
||||
}
|
||||
}
|
||||
|
||||
assertClients(tb, want, got)
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleAddClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
|
||||
clientEmptyID := newPersistentClient("empty_client_id")
|
||||
clientEmptyID.ClientIDs = []string{""}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
client *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "add_one",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "add_two",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}, {
|
||||
name: "duplicate_client",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}, {
|
||||
name: "empty_client_id",
|
||||
client: clientEmptyID,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cj := clientToJSON(tc.client)
|
||||
|
||||
body, err := json.Marshal(cj)
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleAddClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleDelClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.add(clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
client *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "remove_one",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "duplicate_client",
|
||||
client: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "empty_client_name",
|
||||
client: newPersistentClient(""),
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientTwo},
|
||||
}, {
|
||||
name: "remove_two",
|
||||
client: clientTwo,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cj := clientToJSON(tc.client)
|
||||
|
||||
var body []byte
|
||||
body, err = json.Marshal(cj)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleDelClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleUpdateClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne})
|
||||
|
||||
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
|
||||
clientEmptyID := newPersistentClient("empty_client_id")
|
||||
clientEmptyID.ClientIDs = []string{""}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
clientName string
|
||||
modified *client.Persistent
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "update_one",
|
||||
clientName: clientOne.Name,
|
||||
modified: clientModified,
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "empty_name",
|
||||
clientName: "",
|
||||
modified: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "client_not_found",
|
||||
clientName: "client_not_found",
|
||||
modified: clientOne,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "empty_client_id",
|
||||
clientName: clientModified.Name,
|
||||
modified: clientEmptyID,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}, {
|
||||
name: "no_ids",
|
||||
clientName: clientModified.Name,
|
||||
modified: newPersistentClient("no_ids"),
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantClient: []*client.Persistent{clientModified},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
uj := updateJSON{
|
||||
Name: tc.clientName,
|
||||
Data: *clientToJSON(tc.modified),
|
||||
}
|
||||
|
||||
var body []byte
|
||||
body, err = json.Marshal(uj)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleUpdateClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
assertPersistentClients(t, clients, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleFindClient(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
clients.clientChecker = &testBlockedClientChecker{
|
||||
onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) {
|
||||
return false, ""
|
||||
},
|
||||
}
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.add(clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.add(clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
query url.Values
|
||||
wantCode int
|
||||
wantClient []*client.Persistent
|
||||
}{{
|
||||
name: "single",
|
||||
query: url.Values{
|
||||
"ip0": []string{testClientIP1},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "multiple",
|
||||
query: url.Values{
|
||||
"ip0": []string{testClientIP1},
|
||||
"ip1": []string{testClientIP2},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantClient: []*client.Persistent{clientOne, clientTwo},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
r.URL.RawQuery = tc.query.Encode()
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleFindClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientData := []map[string]*clientJSON{}
|
||||
err = json.Unmarshal(body, &clientData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClientsData(t, clients, clientData, tc.wantClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -203,15 +203,24 @@ type dnsConfig struct {
|
||||
// resolver should be used.
|
||||
PrivateNets []netutil.Prefix `yaml:"private_networks"`
|
||||
|
||||
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
||||
// locally-served networks should be resolved via private PTR resolvers.
|
||||
// UsePrivateRDNS enables resolving requests containing a private IP address
|
||||
// using private reverse DNS resolvers. See PrivateRDNSResolvers.
|
||||
//
|
||||
// TODO(e.burkov): Rename in YAML.
|
||||
UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"`
|
||||
|
||||
// LocalPTRResolvers is the slice of addresses to be used as upstreams
|
||||
// for PTR queries for locally-served networks.
|
||||
LocalPTRResolvers []string `yaml:"local_ptr_upstreams"`
|
||||
// PrivateRDNSResolvers is the slice of addresses to be used as upstreams
|
||||
// for private requests. It's only used for PTR, SOA, and NS queries,
|
||||
// containing an ARPA subdomain, came from the the client with private
|
||||
// address. The address considered private according to PrivateNets.
|
||||
//
|
||||
// If empty, the OS-provided resolvers are used for private requests.
|
||||
PrivateRDNSResolvers []string `yaml:"local_ptr_upstreams"`
|
||||
|
||||
// UseDNS64 defines if DNS64 should be used for incoming requests.
|
||||
// UseDNS64 defines if DNS64 should be used for incoming requests. Requests
|
||||
// of type PTR for addresses within the configured prefixes will be resolved
|
||||
// via [PrivateRDNSResolvers], so those should be valid and UsePrivateRDNS
|
||||
// be set to true.
|
||||
UseDNS64 bool `yaml:"use_dns64"`
|
||||
|
||||
// DNS64Prefixes is the list of NAT64 prefixes to be used for DNS64.
|
||||
@@ -658,7 +667,7 @@ func (c *configuration) write() (err error) {
|
||||
dns := &config.DNS
|
||||
dns.Config = c
|
||||
|
||||
dns.LocalPTRResolvers = s.LocalPTRResolvers()
|
||||
dns.PrivateRDNSResolvers = s.LocalPTRResolvers()
|
||||
|
||||
addrProcConf := s.AddrProcConfig()
|
||||
config.Clients.Sources.RDNS = addrProcConf.UseRDNS
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -18,7 +17,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -150,21 +148,19 @@ func initDNSServer(
|
||||
return fmt.Errorf("dnsforward.NewServer: %w", err)
|
||||
}
|
||||
|
||||
Context.clients.dnsServer = Context.dnsServer
|
||||
Context.clients.clientChecker = Context.dnsServer
|
||||
|
||||
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("newServerConfig: %w", err)
|
||||
}
|
||||
|
||||
// Try to prepare the server with disabled private RDNS resolution if it
|
||||
// failed to prepare as is. See TODO on [ErrBadPrivateRDNSUpstreams].
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
|
||||
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
|
||||
|
||||
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
|
||||
// should go away once the private RDNS resolution is moved to the proxy.
|
||||
var locResErr *dnsforward.LocalResolversError
|
||||
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
|
||||
log.Info("WARNING: no local resolvers configured while private RDNS " +
|
||||
"resolution enabled, trying to disable")
|
||||
dnsConf.UsePrivateRDNS = false
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
}
|
||||
@@ -245,7 +241,7 @@ func newServerConfig(
|
||||
TLSv12Roots: Context.tlsRoots,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpReg,
|
||||
LocalPTRResolvers: dnsConf.LocalPTRResolvers,
|
||||
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
|
||||
UseDNS64: dnsConf.UseDNS64,
|
||||
DNS64Prefixes: dnsConf.DNS64Prefixes,
|
||||
UsePrivateRDNS: dnsConf.UsePrivateRDNS,
|
||||
@@ -531,36 +527,6 @@ func closeDNSServer() {
|
||||
log.Debug("all dns modules are closed")
|
||||
}
|
||||
|
||||
// safeSearchResolver is a [filtering.Resolver] implementation used for safe
|
||||
// search.
|
||||
type safeSearchResolver struct{}
|
||||
|
||||
// type check
|
||||
var _ filtering.Resolver = safeSearchResolver{}
|
||||
|
||||
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
|
||||
// It returns the slice of net.Addr with IPv4 and IPv6 instances.
|
||||
func (r safeSearchResolver) LookupIP(
|
||||
ctx context.Context,
|
||||
network string,
|
||||
host string,
|
||||
) (ips []net.IP, err error) {
|
||||
addrs, err := Context.dnsServer.Resolve(ctx, network, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("couldn't lookup host: %s", host)
|
||||
}
|
||||
|
||||
for _, a := range addrs {
|
||||
ips = append(ips, a.AsSlice())
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// checkStatsAndQuerylogDirs checks and returns directory paths to store
|
||||
// statistics and query log.
|
||||
func checkStatsAndQuerylogDirs(
|
||||
|
||||
@@ -439,7 +439,6 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
conf.ParentalBlockHost = host
|
||||
}
|
||||
|
||||
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
conf.SafeSearch, err = safesearch.NewDefault(
|
||||
conf.SafeSearchConf,
|
||||
"default",
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -76,8 +76,7 @@ func getLogSettings(opts options) (ls *logSettings) {
|
||||
ls.Verbose = true
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use cmp.Or in Go 1.22.
|
||||
ls.File = stringutil.Coalesce(opts.logFile, ls.File)
|
||||
ls.File = cmp.Or(opts.logFile, ls.File)
|
||||
|
||||
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if
|
||||
|
||||
@@ -306,7 +306,7 @@ func handleServiceStatusCommand(s service.Service) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceStatusCommand handles service "install" command
|
||||
// handleServiceInstallCommand handles service "install" command.
|
||||
func handleServiceInstallCommand(s service.Service) {
|
||||
err := svcAction(s, "install")
|
||||
if err != nil {
|
||||
@@ -340,7 +340,7 @@ AdGuard Home is now available at the following addresses:`)
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceStatusCommand handles service "uninstall" command
|
||||
// handleServiceUninstallCommand handles service "uninstall" command.
|
||||
func handleServiceUninstallCommand(s service.Service) {
|
||||
if aghos.IsOpenWrt() {
|
||||
// On OpenWrt it is important to run disable command first
|
||||
@@ -649,11 +649,6 @@ status() {
|
||||
|
||||
// freeBSDScript is the source of the daemon script for FreeBSD. Keep as close
|
||||
// as possible to the https://github.com/kardianos/service/blob/18c957a3dc1120a2efe77beb401d476bade9e577/service_freebsd.go#L204.
|
||||
//
|
||||
// TODO(a.garipov): Don't use .WorkingDirectory here. There are currently no
|
||||
// guarantees that it will actually be the required directory.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2614.
|
||||
const freeBSDScript = `#!/bin/sh
|
||||
# PROVIDE: {{.Name}}
|
||||
# REQUIRE: networking
|
||||
@@ -667,7 +662,9 @@ name="{{.Name}}"
|
||||
pidfile_child="/var/run/${name}.pid"
|
||||
pidfile="/var/run/${name}_daemon.pid"
|
||||
command="/usr/sbin/daemon"
|
||||
command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirectory}}/{{.Name}}"
|
||||
daemon_args="-P ${pidfile} -p ${pidfile_child} -r -t ${name}"
|
||||
command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}"
|
||||
|
||||
run_rc_command "$1"
|
||||
`
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -14,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ func (*openbsdRunComService) Platform() (p string) {
|
||||
|
||||
// String implements service.Service interface for *openbsdRunComService.
|
||||
func (s *openbsdRunComService) String() string {
|
||||
return stringutil.Coalesce(s.cfg.DisplayName, s.cfg.Name)
|
||||
return cmp.Or(s.cfg.DisplayName, s.cfg.Name)
|
||||
}
|
||||
|
||||
// getBool returns the value of the given name from kv, assuming the value is a
|
||||
|
||||
Reference in New Issue
Block a user