all: resync with master
This commit is contained in:
@@ -97,6 +97,7 @@ func glGetTokenDate(file string) uint32 {
|
||||
|
||||
buf := bytes.NewBuffer(bs)
|
||||
|
||||
// TODO(a.garipov): Get rid of github.com/josharian/native dependency.
|
||||
err = binary.Read(buf, native.Endian, &dateToken)
|
||||
if err != nil {
|
||||
log.Error("decoding token: %s", err)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
@@ -20,50 +19,18 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
||||
// needs.
|
||||
type DHCP interface {
|
||||
// Leases returns all the DHCP leases.
|
||||
Leases() (leases []*dhcpsvc.Lease)
|
||||
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
// address. The address will be netip.Addr{} if there is no such client,
|
||||
// due to an assumption that a DHCP client must always have a hostname.
|
||||
HostByIP(ip netip.Addr) (host string)
|
||||
|
||||
// MACByIP returns the MAC address for the given IP address leased. It
|
||||
// returns nil if there is no such client, due to an assumption that a DHCP
|
||||
// client must always have a MAC address.
|
||||
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
||||
}
|
||||
|
||||
// clientsContainer is the storage of all runtime and persistent clients.
|
||||
type clientsContainer struct {
|
||||
// storage stores information about persistent clients.
|
||||
storage *client.Storage
|
||||
|
||||
// runtimeIndex stores information about runtime clients.
|
||||
runtimeIndex *client.RuntimeIndex
|
||||
|
||||
// dhcp is the DHCP service implementation.
|
||||
dhcp DHCP
|
||||
|
||||
// clientChecker checks if a client is blocked by the current access
|
||||
// settings.
|
||||
clientChecker BlockedClientChecker
|
||||
|
||||
// etcHosts contains list of rewrite rules taken from the operating system's
|
||||
// hosts database.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
|
||||
// arpDB stores the neighbors retrieved from ARP.
|
||||
arpDB arpdb.Interface
|
||||
|
||||
// lock protects all fields.
|
||||
//
|
||||
// TODO(a.garipov): Use a pointer and describe which fields are protected in
|
||||
@@ -95,7 +62,7 @@ type BlockedClientChecker interface {
|
||||
// Note: this function must be called only once
|
||||
func (clients *clientsContainer) Init(
|
||||
objects []*clientObject,
|
||||
dhcpServer DHCP,
|
||||
dhcpServer client.DHCP,
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
arpDB arpdb.Interface,
|
||||
filteringConf *filtering.Config,
|
||||
@@ -105,28 +72,15 @@ func (clients *clientsContainer) Init(
|
||||
return errors.Error("clients container already initialized")
|
||||
}
|
||||
|
||||
clients.runtimeIndex = client.NewRuntimeIndex()
|
||||
confClients := make([]*client.Persistent, 0, len(objects))
|
||||
for i, o := range objects {
|
||||
var p *client.Persistent
|
||||
p, err = o.toPersistent(filteringConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init persistent client at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
clients.storage = client.NewStorage(&client.Config{
|
||||
AllowedTags: clientTags,
|
||||
})
|
||||
|
||||
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
|
||||
clients.dhcp = dhcpServer
|
||||
|
||||
clients.etcHosts = etcHosts
|
||||
clients.arpDB = arpDB
|
||||
err = clients.addFromConfig(objects, filteringConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||
|
||||
if clients.testing {
|
||||
return nil
|
||||
confClients = append(confClients, p)
|
||||
}
|
||||
|
||||
// The clients.etcHosts may be nil even if config.Clients.Sources.HostsFile
|
||||
@@ -135,21 +89,26 @@ func (clients *clientsContainer) Init(
|
||||
// TODO(e.burkov): The option should probably be returned, since hosts file
|
||||
// currently used not only for clients' information enrichment, but also in
|
||||
// the filtering module and upstream addresses resolution.
|
||||
if config.Clients.Sources.HostsFile && clients.etcHosts != nil {
|
||||
go clients.handleHostsUpdates()
|
||||
var hosts client.HostsContainer = etcHosts
|
||||
if !config.Clients.Sources.HostsFile {
|
||||
hosts = nil
|
||||
}
|
||||
|
||||
clients.storage, err = client.NewStorage(&client.StorageConfig{
|
||||
InitialClients: confClients,
|
||||
DHCP: dhcpServer,
|
||||
EtcHosts: hosts,
|
||||
ARPDB: arpDB,
|
||||
ARPClientsUpdatePeriod: arpClientsUpdatePeriod,
|
||||
RuntimeSourceDHCP: config.Clients.Sources.DHCP,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("init client storage: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleHostsUpdates receives the updates from the hosts container and adds
|
||||
// them to the clients container. It is intended to be used as a goroutine.
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
for upd := range clients.etcHosts.Upd() {
|
||||
clients.addFromHostsFile(upd)
|
||||
}
|
||||
}
|
||||
|
||||
// webHandlersRegistered prevents a [clientsContainer] from registering its web
|
||||
// handlers more than once.
|
||||
//
|
||||
@@ -157,7 +116,7 @@ func (clients *clientsContainer) handleHostsUpdates() {
|
||||
var webHandlersRegistered = false
|
||||
|
||||
// Start starts the clients container.
|
||||
func (clients *clientsContainer) Start() {
|
||||
func (clients *clientsContainer) Start(ctx context.Context) (err error) {
|
||||
if clients.testing {
|
||||
return
|
||||
}
|
||||
@@ -167,14 +126,7 @@ func (clients *clientsContainer) Start() {
|
||||
clients.registerWebHandlers()
|
||||
}
|
||||
|
||||
go clients.periodicUpdate()
|
||||
}
|
||||
|
||||
// reloadARP reloads runtime clients from ARP, if configured.
|
||||
func (clients *clientsContainer) reloadARP() {
|
||||
if clients.arpDB != nil {
|
||||
clients.addFromSystemARP()
|
||||
}
|
||||
return clients.storage.Start(ctx)
|
||||
}
|
||||
|
||||
// clientObject is the YAML representation of a persistent client.
|
||||
@@ -275,28 +227,6 @@ func (o *clientObject) toPersistent(
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
// addFromConfig initializes the clients container with objects from the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) addFromConfig(
|
||||
objects []*clientObject,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
for i, o := range objects {
|
||||
var cli *client.Persistent
|
||||
cli, err = o.toPersistent(filteringConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
err = clients.storage.Add(cli)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding client %q at index %d: %w", cli.Name, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// forConfig returns all currently known persistent clients as objects for the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
@@ -337,39 +267,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
// arpClientsUpdatePeriod defines how often ARP clients are updated.
|
||||
const arpClientsUpdatePeriod = 10 * time.Minute
|
||||
|
||||
func (clients *clientsContainer) periodicUpdate() {
|
||||
defer log.OnPanic("clients container")
|
||||
|
||||
for {
|
||||
clients.reloadARP()
|
||||
time.Sleep(arpClientsUpdatePeriod)
|
||||
}
|
||||
}
|
||||
|
||||
// clientSource checks if client with this IP address already exists and returns
|
||||
// the source which updated it last. It returns [client.SourceNone] if the
|
||||
// client doesn't exist. Note that it is only used in tests.
|
||||
func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findLocked(ip.String())
|
||||
if ok {
|
||||
return client.SourcePersistent
|
||||
}
|
||||
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc != nil {
|
||||
src, _ = rc.Info()
|
||||
}
|
||||
|
||||
if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" {
|
||||
src = client.SourceDHCP
|
||||
}
|
||||
|
||||
return src
|
||||
}
|
||||
|
||||
// findMultiple is a wrapper around [clientsContainer.find] to make it a valid
|
||||
// client finder for the query log. c is never nil; if no information about the
|
||||
// client is found, it returns an artificial client record by only setting the
|
||||
@@ -415,7 +312,7 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
}, false
|
||||
}
|
||||
|
||||
rc := clients.findRuntimeClient(ip)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
if rc != nil {
|
||||
_, host := rc.Info()
|
||||
|
||||
@@ -430,19 +327,6 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
}, true
|
||||
}
|
||||
|
||||
// find returns a shallow copy of the client if there is one found.
|
||||
func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok = clients.findLocked(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return c, true
|
||||
}
|
||||
|
||||
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
|
||||
// valid client information finder for the statistics. If no information about
|
||||
// the client is found, it returns true.
|
||||
@@ -451,7 +335,7 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
for _, id := range ids {
|
||||
client, ok := clients.findLocked(id)
|
||||
client, ok := clients.storage.Find(id)
|
||||
if ok {
|
||||
return !client.IgnoreStatistics
|
||||
}
|
||||
@@ -473,7 +357,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.findLocked(id)
|
||||
c, ok := clients.storage.Find(id)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
} else if c.UpstreamConfig != nil {
|
||||
@@ -511,225 +395,17 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// findLocked searches for a client by its ID. clients.lock is expected to be
|
||||
// locked.
|
||||
func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) {
|
||||
c, ok = clients.storage.Find(id)
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Iterate through clients.list only once.
|
||||
return clients.findDHCP(ip)
|
||||
}
|
||||
|
||||
// findDHCP searches for a client by its MAC, if the DHCP server is active and
|
||||
// there is such client. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, ok bool) {
|
||||
foundMAC := clients.dhcp.MACByIP(ip)
|
||||
if foundMAC == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return clients.storage.FindByMAC(foundMAC)
|
||||
}
|
||||
|
||||
// runtimeClient returns a runtime client from internal index. Note that it
|
||||
// doesn't include DHCP clients.
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
if ip == (netip.Addr{}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
return clients.runtimeIndex.Client(ip)
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
rc = clients.runtimeClient(ip)
|
||||
host := clients.dhcp.HostByIP(ip)
|
||||
|
||||
if host != "" {
|
||||
if rc == nil {
|
||||
rc = client.NewRuntime(ip)
|
||||
}
|
||||
|
||||
rc.SetInfo(client.SourceDHCP, []string{host})
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
// setWHOISInfo sets the WHOIS information for a client. clients.lock is
|
||||
// expected to be locked.
|
||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
_, ok := clients.findLocked(ip.String())
|
||||
if ok {
|
||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
} else {
|
||||
host, _ := rc.Info()
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", host, wi)
|
||||
}
|
||||
|
||||
rc.SetWHOIS(wi)
|
||||
}
|
||||
|
||||
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
// taken into account. ok is true if the pairing was added.
|
||||
//
|
||||
// TODO(a.garipov): Only used in internal tests. Consider removing.
|
||||
func (clients *clientsContainer) addHost(
|
||||
ip netip.Addr,
|
||||
host string,
|
||||
src client.Source,
|
||||
) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
return clients.addHostLocked(ip, host, src)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
||||
|
||||
// UpdateAddress implements the [client.AddressUpdater] interface for
|
||||
// *clientsContainer
|
||||
func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
||||
// Common fast path optimization.
|
||||
if host == "" && info == nil {
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
if host != "" {
|
||||
ok := clients.addHostLocked(ip, host, client.SourceRDNS)
|
||||
if !ok {
|
||||
log.Debug("clients: host for client %q already set with higher priority source", ip)
|
||||
}
|
||||
}
|
||||
|
||||
if info != nil {
|
||||
clients.setWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
|
||||
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
|
||||
// locked.
|
||||
func (clients *clientsContainer) addHostLocked(
|
||||
ip netip.Addr,
|
||||
host string,
|
||||
src client.Source,
|
||||
) (ok bool) {
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
if src < client.SourceDHCP {
|
||||
if clients.dhcp.HostByIP(ip) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
}
|
||||
|
||||
rc.SetInfo(src, []string{host})
|
||||
|
||||
log.Debug(
|
||||
"clients: adding client info %s -> %q %q [%d]",
|
||||
ip,
|
||||
src,
|
||||
host,
|
||||
clients.runtimeIndex.Size(),
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
|
||||
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
|
||||
|
||||
added := 0
|
||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||
// Only the first name of the first record is considered a canonical
|
||||
// hostname for the IP address.
|
||||
//
|
||||
// TODO(e.burkov): Consider using all the names from all the records.
|
||||
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
|
||||
added++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts file", added)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
// command.
|
||||
func (clients *clientsContainer) addFromSystemARP() {
|
||||
if err := clients.arpDB.Refresh(); err != nil {
|
||||
log.Error("refreshing arp container: %s", err)
|
||||
|
||||
clients.arpDB = arpdb.Empty{}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ns := clients.arpDB.Neighbors()
|
||||
if len(ns) == 0 {
|
||||
log.Debug("refreshing arp container: the update is empty")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
|
||||
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
|
||||
|
||||
added := 0
|
||||
for _, n := range ns {
|
||||
if clients.addHostLocked(n.IP, n.Name, client.SourceARP) {
|
||||
added++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: added %d client aliases from arp neighborhood", added)
|
||||
clients.storage.UpdateAddress(ip, host, info)
|
||||
}
|
||||
|
||||
// close gracefully closes all the client-specific upstream configurations of
|
||||
// the persistent clients.
|
||||
func (clients *clientsContainer) close() (err error) {
|
||||
return clients.storage.CloseUpstreams()
|
||||
func (clients *clientsContainer) close(ctx context.Context) (err error) {
|
||||
return clients.storage.Shutdown(ctx)
|
||||
}
|
||||
|
||||
@@ -3,34 +3,14 @@ package home
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testDHCP struct {
|
||||
OnLeases func() (leases []*dhcpsvc.Lease)
|
||||
OnHostBy func(ip netip.Addr) (host string)
|
||||
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
|
||||
}
|
||||
|
||||
// Lease implements the [DHCP] interface for testDHCP.
|
||||
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
|
||||
|
||||
// HostByIP implements the [DHCP] interface for testDHCP.
|
||||
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
|
||||
|
||||
// MACByIP implements the [DHCP] interface for testDHCP.
|
||||
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
|
||||
|
||||
// newClientsContainer is a helper that creates a new clients container for
|
||||
// tests.
|
||||
func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
@@ -40,316 +20,11 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
testing: true,
|
||||
}
|
||||
|
||||
dhcp := &testDHCP{
|
||||
OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
|
||||
OnHostBy: func(ip netip.Addr) (host string) { return "" },
|
||||
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
|
||||
}
|
||||
|
||||
require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{}))
|
||||
require.NoError(t, c.Init(nil, client.EmptyDHCP{}, nil, nil, &filtering.Config{}))
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
var (
|
||||
cliNone = "1.2.3.4"
|
||||
cli1 = "1.1.1.1"
|
||||
cli2 = "2.2.2.2"
|
||||
|
||||
cli1IP = netip.MustParseAddr(cli1)
|
||||
cli2IP = netip.MustParseAddr(cli2)
|
||||
|
||||
cliIPv6 = netip.MustParseAddr("1:2:3::4")
|
||||
)
|
||||
|
||||
c := &client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{cli1IP, cliIPv6},
|
||||
}
|
||||
|
||||
err := clients.storage.Add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
c = &client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{cli2IP},
|
||||
}
|
||||
|
||||
err = clients.storage.Add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
c, ok := clients.find(cli1)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.find("1:2:3::4")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.find(cli2)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client2", c.Name)
|
||||
|
||||
_, ok = clients.find(cliNone)
|
||||
assert.False(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent)
|
||||
assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent)
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
err := clients.storage.Add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
||||
})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("add_fail_ip", func(t *testing.T) {
|
||||
err := clients.storage.Add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
err := clients.storage.Update("client1", &client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("update_success", func(t *testing.T) {
|
||||
var (
|
||||
cliOld = "1.1.1.1"
|
||||
cliNew = "1.1.1.2"
|
||||
|
||||
cliNewIP = netip.MustParseAddr(cliNew)
|
||||
)
|
||||
|
||||
prev, ok := clients.storage.FindByName("client1")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, prev)
|
||||
|
||||
err := clients.storage.Update("client1", &client.Persistent{
|
||||
Name: "client1",
|
||||
UID: prev.UID,
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, ok = clients.find(cliOld)
|
||||
assert.False(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
|
||||
|
||||
prev, ok = clients.storage.FindByName("client1")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, prev)
|
||||
|
||||
err = clients.storage.Update("client1", &client.Persistent{
|
||||
Name: "client1-renamed",
|
||||
UID: prev.UID,
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
c, ok := clients.find(cliNew)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
assert.True(t, c.UseOwnSettings)
|
||||
|
||||
nilCli, ok := clients.storage.FindByName("client1")
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Nil(t, nilCli)
|
||||
|
||||
require.Len(t, c.IDs(), 1)
|
||||
|
||||
assert.Equal(t, cliNewIP, c.IPs[0])
|
||||
})
|
||||
|
||||
t.Run("del_success", func(t *testing.T) {
|
||||
ok := clients.storage.RemoveByName("client1-renamed")
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = clients.find("1.1.1.2")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("del_fail", func(t *testing.T) {
|
||||
ok := clients.storage.RemoveByName("client3")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.addHost(ip, "host", client.SourceARP)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok = clients.addHost(ip, "host2", client.SourceARP)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok = clients.addHost(ip, "host3", client.SourceHostsFile)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(ip), client.SourceHostsFile)
|
||||
})
|
||||
|
||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
ok := clients.addHost(ip, "from_arp", client.SourceARP)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, clients.clientSource(ip), client.SourceARP)
|
||||
|
||||
ok = clients.addHost(ip, "from_dhcp", client.SourceDHCP)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, clients.clientSource(ip), client.SourceDHCP)
|
||||
})
|
||||
|
||||
t.Run("addhost_priority", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.addHost(ip, "host1", client.SourceRDNS)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.Equal(t, client.SourceHostsFile, clients.clientSource(ip))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientsWHOIS(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
whois := &whois.Info{
|
||||
Country: "AU",
|
||||
Orgname: "Example Org",
|
||||
}
|
||||
|
||||
t.Run("new_client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.255")
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
})
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.addHost(ip, "host", client.SourceRDNS)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
})
|
||||
|
||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.2")
|
||||
|
||||
err := clients.storage.Add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.Nil(t, rc)
|
||||
|
||||
assert.True(t, clients.storage.RemoveByName("client1"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Add a client.
|
||||
err := clients.storage.Add(&client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
|
||||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok := clients.addHost(ip, "test", client.SourceRDNS)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("complicated", func(t *testing.T) {
|
||||
// TODO(a.garipov): Properly decouple the DHCP server from the client
|
||||
// storage.
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping dhcp test on windows")
|
||||
}
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
config := &dhcpd.ServerConfig{
|
||||
Enabled: true,
|
||||
DataDir: t.TempDir(),
|
||||
Conf4: dhcpd.V4ServerConf{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("1.2.3.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("1.2.3.2"),
|
||||
RangeEnd: netip.MustParseAddr("1.2.3.10"),
|
||||
},
|
||||
}
|
||||
|
||||
dhcpServer, err := dhcpd.Create(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
clients.dhcp = dhcpServer
|
||||
|
||||
err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: ip,
|
||||
Hostname: "testhost",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
err = clients.storage.Add(&client.Persistent{
|
||||
Name: "client2",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{ip},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the IP from the first client's IP range.
|
||||
err = clients.storage.Add(&client.Persistent{
|
||||
Name: "client3",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
|
||||
@@ -103,7 +103,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
return true
|
||||
})
|
||||
|
||||
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
|
||||
clients.storage.UpdateDHCP()
|
||||
|
||||
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
||||
src, host := rc.Info()
|
||||
cj := runtimeClientJSON{
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
@@ -117,18 +119,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
return true
|
||||
})
|
||||
|
||||
for _, l := range clients.dhcp.Leases() {
|
||||
cj := runtimeClientJSON{
|
||||
Name: l.Hostname,
|
||||
Source: client.SourceDHCP,
|
||||
IP: l.IP,
|
||||
WHOIS: &whois.Info{},
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
}
|
||||
|
||||
data.Tags = clientTags
|
||||
data.Tags = clients.storage.AllowedTags()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
}
|
||||
@@ -248,6 +239,7 @@ func copySafeSearch(
|
||||
if conf.Enabled {
|
||||
conf.Bing = true
|
||||
conf.DuckDuckGo = true
|
||||
conf.Ecosia = true
|
||||
conf.Google = true
|
||||
conf.Pixabay = true
|
||||
conf.Yandex = true
|
||||
@@ -429,7 +421,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
ip, _ := netip.ParseAddr(idStr)
|
||||
c, ok := clients.find(idStr)
|
||||
c, ok := clients.storage.Find(idStr)
|
||||
var cj *clientJSON
|
||||
if !ok {
|
||||
cj = clients.findRuntime(ip, idStr)
|
||||
@@ -451,7 +443,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||
// non-nil.
|
||||
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
|
||||
rc := clients.findRuntimeClient(ip)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
if rc == nil {
|
||||
// It is still possible that the IP used to be in the runtime clients
|
||||
// list, but then the server was reloaded. So, check the DNS server's
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package home
|
||||
|
||||
var clientTags = []string{
|
||||
"device_audio",
|
||||
"device_camera",
|
||||
"device_gameconsole",
|
||||
"device_laptop",
|
||||
"device_nas", // Network-attached Storage
|
||||
"device_other",
|
||||
"device_pc",
|
||||
"device_phone",
|
||||
"device_printer",
|
||||
"device_securityalarm",
|
||||
"device_tablet",
|
||||
"device_tv",
|
||||
|
||||
"os_android",
|
||||
"os_ios",
|
||||
"os_linux",
|
||||
"os_macos",
|
||||
"os_other",
|
||||
"os_windows",
|
||||
|
||||
"user_admin",
|
||||
"user_child",
|
||||
"user_regular",
|
||||
}
|
||||
@@ -423,6 +423,7 @@ var config = &configuration{
|
||||
Enabled: false,
|
||||
Bing: true,
|
||||
DuckDuckGo: true,
|
||||
Ecosia: true,
|
||||
Google: true,
|
||||
Pixabay: true,
|
||||
Yandex: true,
|
||||
|
||||
@@ -433,7 +433,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = startMods()
|
||||
err = startMods(web.logger)
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -19,6 +21,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
@@ -43,8 +46,8 @@ func onConfigModified() {
|
||||
|
||||
// initDNS updates all the fields of the [Context] needed to initialize the DNS
|
||||
// server and initializes it at last. It also must not be called unless
|
||||
// [config] and [Context] are initialized.
|
||||
func initDNS() (err error) {
|
||||
// [config] and [Context] are initialized. l must not be nil.
|
||||
func initDNS(l *slog.Logger) (err error) {
|
||||
anonymizer := config.anonymizer()
|
||||
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
||||
@@ -53,6 +56,7 @@ func initDNS() (err error) {
|
||||
}
|
||||
|
||||
statsConf := stats.Config{
|
||||
Logger: l.With(slogutil.KeyPrefix, "stats"),
|
||||
Filename: filepath.Join(statsDir, "stats.db"),
|
||||
Limit: config.Stats.Interval.Duration,
|
||||
ConfigModified: onConfigModified,
|
||||
@@ -113,13 +117,16 @@ func initDNS() (err error) {
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
tlsConf,
|
||||
l,
|
||||
)
|
||||
}
|
||||
|
||||
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
||||
// proxy, none of the arguments are required, but tlsConf still must not be nil,
|
||||
// in other cases all the arguments also must not be nil. It also must not be
|
||||
// called unless [config] and [Context] are initialized.
|
||||
// proxy, none of the arguments are required, but tlsConf and l still must not
|
||||
// be nil, in other cases all the arguments also must not be nil. It also must
|
||||
// not be called unless [config] and [Context] are initialized.
|
||||
//
|
||||
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
||||
func initDNSServer(
|
||||
filters *filtering.DNSFilter,
|
||||
sts stats.Interface,
|
||||
@@ -128,8 +135,10 @@ func initDNSServer(
|
||||
anonymizer *aghnet.IPMut,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
tlsConf *tlsConfigSettings,
|
||||
l *slog.Logger,
|
||||
) (err error) {
|
||||
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||
Logger: l,
|
||||
DNSFilter: filters,
|
||||
Stats: sts,
|
||||
QueryLog: qlog,
|
||||
@@ -406,9 +415,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
||||
|
||||
setts.ClientIP = clientIP
|
||||
|
||||
c, ok := Context.clients.find(clientID)
|
||||
c, ok := Context.clients.storage.Find(clientID)
|
||||
if !ok {
|
||||
c, ok = Context.clients.find(clientIP.String())
|
||||
c, ok = Context.clients.storage.Find(clientIP.String())
|
||||
if !ok {
|
||||
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
@@ -451,11 +460,15 @@ func startDNSServer() error {
|
||||
|
||||
Context.filters.EnableFilters(false)
|
||||
|
||||
Context.clients.Start()
|
||||
|
||||
err := Context.dnsServer.Start()
|
||||
// TODO(s.chzhen): Pass context.
|
||||
err := Context.clients.Start(context.TODO())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
|
||||
return fmt.Errorf("starting clients container: %w", err)
|
||||
}
|
||||
|
||||
err = Context.dnsServer.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting dns server: %w", err)
|
||||
}
|
||||
|
||||
Context.filters.Start()
|
||||
@@ -492,7 +505,7 @@ func stopDNSServer() (err error) {
|
||||
return fmt.Errorf("stopping forwarding dns server: %w", err)
|
||||
}
|
||||
|
||||
err = Context.clients.close()
|
||||
err = Context.clients.close(context.TODO())
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing clients container: %w", err)
|
||||
}
|
||||
|
||||
@@ -18,9 +18,8 @@ var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) {
|
||||
tb.Helper()
|
||||
|
||||
s = client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
})
|
||||
s, err := client.NewStorage(&client.StorageConfig{})
|
||||
require.NoError(tb, err)
|
||||
|
||||
for _, p := range clients {
|
||||
p.UID = client.MustNewUID()
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -38,6 +39,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
)
|
||||
@@ -90,6 +92,8 @@ func (c *homeContext) getDataDir() string {
|
||||
}
|
||||
|
||||
// Context - a global context object
|
||||
//
|
||||
// TODO(a.garipov): Refactor.
|
||||
var Context homeContext
|
||||
|
||||
// Main is the entry point
|
||||
@@ -115,7 +119,7 @@ func Main(clientBuildFS fs.FS) {
|
||||
log.Info("Received signal %q", sig)
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
Context.clients.reloadARP()
|
||||
Context.clients.storage.ReloadARP()
|
||||
Context.tls.reload()
|
||||
default:
|
||||
cleanup(context.Background())
|
||||
@@ -273,7 +277,7 @@ func setupOpts(opts options) (err error) {
|
||||
}
|
||||
|
||||
// initContextClients initializes Context clients and related fields.
|
||||
func initContextClients() (err error) {
|
||||
func initContextClients(logger *slog.Logger) (err error) {
|
||||
err = setupDNSFilteringConf(config.Filtering)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
@@ -297,7 +301,7 @@ func initContextClients() (err error) {
|
||||
|
||||
var arpDB arpdb.Interface
|
||||
if config.Clients.Sources.ARP {
|
||||
arpDB = arpdb.New()
|
||||
arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb"))
|
||||
}
|
||||
|
||||
return Context.clients.Init(
|
||||
@@ -482,7 +486,12 @@ func checkPorts() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webAPI, err error) {
|
||||
func initWeb(
|
||||
opts options,
|
||||
clientBuildFS fs.FS,
|
||||
upd *updater.Updater,
|
||||
l *slog.Logger,
|
||||
) (web *webAPI, err error) {
|
||||
var clientFS fs.FS
|
||||
if opts.localFrontend {
|
||||
log.Info("warning: using local frontend files")
|
||||
@@ -524,7 +533,7 @@ func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webA
|
||||
serveHTTP3: config.DNS.ServeHTTP3,
|
||||
}
|
||||
|
||||
web = newWebAPI(webConf)
|
||||
web = newWebAPI(webConf, l)
|
||||
if web == nil {
|
||||
return nil, fmt.Errorf("initializing web: %w", err)
|
||||
}
|
||||
@@ -547,10 +556,15 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
// Configure config filename.
|
||||
initConfigFilename(opts)
|
||||
|
||||
ls := getLogSettings(opts)
|
||||
|
||||
// Configure log level and output.
|
||||
err = configureLogger(opts)
|
||||
err = configureLogger(ls)
|
||||
fatalOnError(err)
|
||||
|
||||
// TODO(a.garipov): Use slog everywhere.
|
||||
slogLogger := newSlogLogger(ls)
|
||||
|
||||
// Print the first message after logger is configured.
|
||||
log.Info(version.Full())
|
||||
log.Debug("current working directory is %s", Context.workDir)
|
||||
@@ -569,7 +583,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
// data first, but also to avoid relying on automatic Go init() function.
|
||||
filtering.InitModule()
|
||||
|
||||
err = initContextClients()
|
||||
err = initContextClients(slogLogger)
|
||||
fatalOnError(err)
|
||||
|
||||
err = setupOpts(opts)
|
||||
@@ -604,7 +618,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(opts, upd)
|
||||
cmdlineUpdate(opts, upd, slogLogger)
|
||||
|
||||
if !Context.firstRun {
|
||||
// Save the updated config.
|
||||
@@ -632,11 +646,11 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
Context.web, err = initWeb(opts, clientBuildFS, upd)
|
||||
Context.web, err = initWeb(opts, clientBuildFS, upd, slogLogger)
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
err = initDNS()
|
||||
err = initDNS(slogLogger)
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls.start()
|
||||
@@ -697,9 +711,10 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
|
||||
return aghnet.NewIPMut(anonFunc)
|
||||
}
|
||||
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
func startMods() (err error) {
|
||||
err = initDNS()
|
||||
// startMods initializes and starts the DNS server after installation. l must
|
||||
// not be nil.
|
||||
func startMods(l *slog.Logger) (err error) {
|
||||
err = initDNS(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -959,8 +974,8 @@ type jsonError struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// cmdlineUpdate updates current application and exits.
|
||||
func cmdlineUpdate(opts options, upd *updater.Updater) {
|
||||
// cmdlineUpdate updates current application and exits. l must not be nil.
|
||||
func cmdlineUpdate(opts options, upd *updater.Updater, l *slog.Logger) {
|
||||
if !opts.performUpdate {
|
||||
return
|
||||
}
|
||||
@@ -970,7 +985,7 @@ func cmdlineUpdate(opts options, upd *updater.Updater) {
|
||||
//
|
||||
// TODO(e.burkov): We could probably initialize the internal resolver
|
||||
// separately.
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{})
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, l)
|
||||
fatalOnError(err)
|
||||
|
||||
log.Info("cmdline update: performing update")
|
||||
|
||||
@@ -3,11 +3,13 @@ package home
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -16,10 +18,21 @@ import (
|
||||
// for logger output.
|
||||
const configSyslog = "syslog"
|
||||
|
||||
// configureLogger configures logger level and output.
|
||||
func configureLogger(opts options) (err error) {
|
||||
ls := getLogSettings(opts)
|
||||
// newSlogLogger returns new [*slog.Logger] configured with the given settings.
|
||||
func newSlogLogger(ls *logSettings) (l *slog.Logger) {
|
||||
if !ls.Enabled {
|
||||
return slogutil.NewDiscardLogger()
|
||||
}
|
||||
|
||||
return slogutil.New(&slogutil.Config{
|
||||
Format: slogutil.FormatAdGuardLegacy,
|
||||
AddTimestamp: true,
|
||||
Verbose: ls.Verbose,
|
||||
})
|
||||
}
|
||||
|
||||
// configureLogger configures logger level and output.
|
||||
func configureLogger(ls *logSettings) (err error) {
|
||||
// Configure logger level.
|
||||
if !ls.Enabled {
|
||||
log.SetLevel(log.OFF)
|
||||
@@ -60,7 +73,7 @@ func configureLogger(opts options) (err error) {
|
||||
MaxAge: ls.MaxAge,
|
||||
})
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// getLogSettings returns a log settings object properly initialized from opts.
|
||||
|
||||
@@ -5,12 +5,15 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
// middlerware is a wrapper function signature.
|
||||
type middleware func(http.Handler) http.Handler
|
||||
|
||||
// withMiddlewares consequently wraps h with all the middlewares.
|
||||
//
|
||||
// TODO(e.burkov): Use [httputil.Wrap].
|
||||
func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Handler) {
|
||||
wrapped = h
|
||||
|
||||
@@ -23,11 +26,11 @@ func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Ha
|
||||
|
||||
const (
|
||||
// defaultReqBodySzLim is the default maximum request body size.
|
||||
defaultReqBodySzLim = 64 * 1024
|
||||
defaultReqBodySzLim datasize.ByteSize = 64 * datasize.KB
|
||||
|
||||
// largerReqBodySzLim is the maximum request body size for APIs expecting
|
||||
// larger requests.
|
||||
largerReqBodySzLim = 4 * 1024 * 1024
|
||||
largerReqBodySzLim datasize.ByteSize = 4 * datasize.MB
|
||||
)
|
||||
|
||||
// expectsLargerRequests shows if this request should use a larger body size
|
||||
@@ -38,26 +41,28 @@ const (
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2666 and
|
||||
// https://github.com/AdguardTeam/AdGuardHome/issues/2675.
|
||||
func expectsLargerRequests(r *http.Request) (ok bool) {
|
||||
m := r.Method
|
||||
if m != http.MethodPost {
|
||||
if r.Method != http.MethodPost {
|
||||
return false
|
||||
}
|
||||
|
||||
p := r.URL.Path
|
||||
return p == "/control/access/set" ||
|
||||
p == "/control/filtering/set_rules"
|
||||
switch r.URL.Path {
|
||||
case "/control/access/set", "/control/filtering/set_rules":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// limitRequestBody wraps underlying handler h, making it's request's body Read
|
||||
// method limited.
|
||||
func limitRequestBody(h http.Handler) (limited http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var szLim uint64 = defaultReqBodySzLim
|
||||
szLim := defaultReqBodySzLim
|
||||
if expectsLargerRequests(r) {
|
||||
szLim = largerReqBodySzLim
|
||||
}
|
||||
|
||||
reader := ioutil.LimitReader(r.Body, szLim)
|
||||
reader := ioutil.LimitReader(r.Body, szLim.Bytes())
|
||||
|
||||
// HTTP handlers aren't supposed to call r.Body.Close(), so just
|
||||
// replace the body in a clone.
|
||||
|
||||
@@ -14,29 +14,29 @@ import (
|
||||
|
||||
func TestLimitRequestBody(t *testing.T) {
|
||||
errReqLimitReached := &ioutil.LimitError{
|
||||
Limit: defaultReqBodySzLim,
|
||||
Limit: defaultReqBodySzLim.Bytes(),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
wantErr error
|
||||
name string
|
||||
body string
|
||||
want []byte
|
||||
wantErr error
|
||||
}{{
|
||||
wantErr: nil,
|
||||
name: "not_so_big",
|
||||
body: "somestr",
|
||||
want: []byte("somestr"),
|
||||
wantErr: nil,
|
||||
}, {
|
||||
wantErr: errReqLimitReached,
|
||||
name: "so_big",
|
||||
body: string(make([]byte, defaultReqBodySzLim+1)),
|
||||
want: make([]byte, defaultReqBodySzLim),
|
||||
wantErr: errReqLimitReached,
|
||||
}, {
|
||||
wantErr: nil,
|
||||
name: "empty",
|
||||
body: "",
|
||||
want: []byte(nil),
|
||||
wantErr: nil,
|
||||
}}
|
||||
|
||||
makeHandler := func(t *testing.T, err *error) http.HandlerFunc {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/pprofutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"golang.org/x/net/http2"
|
||||
@@ -90,17 +91,22 @@ type webAPI struct {
|
||||
// TODO(a.garipov): Refactor all these servers.
|
||||
httpServer *http.Server
|
||||
|
||||
// logger is a slog logger used in webAPI. It must not be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
|
||||
// [Web.http3Server] must also not be nil.
|
||||
httpsServer httpsServer
|
||||
}
|
||||
|
||||
// newWebAPI creates a new instance of the web UI and API server.
|
||||
func newWebAPI(conf *webConfig) (w *webAPI) {
|
||||
// newWebAPI creates a new instance of the web UI and API server. l must not be
|
||||
// nil.
|
||||
func newWebAPI(conf *webConfig, l *slog.Logger) (w *webAPI) {
|
||||
log.Info("web: initializing")
|
||||
|
||||
w = &webAPI{
|
||||
conf: conf,
|
||||
conf: conf,
|
||||
logger: l,
|
||||
}
|
||||
|
||||
clientFS := http.FileServer(http.FS(conf.clientFS))
|
||||
@@ -327,7 +333,7 @@ func startPprof(port uint16) {
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
pprofutil.RoutePprof(mux)
|
||||
httputil.RoutePprof(mux)
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("pprof server")
|
||||
|
||||
Reference in New Issue
Block a user