all: sync with master; upd chlog
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
@@ -83,50 +83,6 @@ func (c *Client) setSafeSearch(
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientSource represents the source from which the information about the
|
||||
// client has been obtained.
|
||||
type clientSource uint
|
||||
|
||||
// Clients information sources. The order determines the priority.
|
||||
const (
|
||||
ClientSourceNone clientSource = iota
|
||||
ClientSourceWHOIS
|
||||
ClientSourceARP
|
||||
ClientSourceRDNS
|
||||
ClientSourceDHCP
|
||||
ClientSourceHostsFile
|
||||
ClientSourcePersistent
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ fmt.Stringer = clientSource(0)
|
||||
|
||||
// String returns a human-readable name of cs.
|
||||
func (cs clientSource) String() (s string) {
|
||||
switch cs {
|
||||
case ClientSourceWHOIS:
|
||||
return "WHOIS"
|
||||
case ClientSourceARP:
|
||||
return "ARP"
|
||||
case ClientSourceRDNS:
|
||||
return "rDNS"
|
||||
case ClientSourceDHCP:
|
||||
return "DHCP"
|
||||
case ClientSourceHostsFile:
|
||||
return "etc/hosts"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ encoding.TextMarshaler = clientSource(0)
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler for the clientSource.
|
||||
func (cs clientSource) MarshalText() (text []byte, err error) {
|
||||
return []byte(cs.String()), nil
|
||||
}
|
||||
|
||||
// RuntimeClient is a client information about which has been obtained using the
|
||||
// source described in the Source field.
|
||||
type RuntimeClient struct {
|
||||
@@ -138,5 +94,5 @@ type RuntimeClient struct {
|
||||
|
||||
// Source is the source from which the information about the client has
|
||||
// been obtained.
|
||||
Source clientSource
|
||||
Source client.Source
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -34,7 +34,7 @@ type DHCP interface {
|
||||
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
// address. The address will be netip.Addr{} if there is no such client,
|
||||
// due to an assumption that a DHCP client must always have an IP address.
|
||||
// due to an assumption that a DHCP client must always have a hostname.
|
||||
HostByIP(ip netip.Addr) (host string)
|
||||
|
||||
// MACByIP returns the MAC address for the given IP address leased. It
|
||||
@@ -55,8 +55,8 @@ type clientsContainer struct {
|
||||
|
||||
allTags *stringutil.Set
|
||||
|
||||
// dhcpServer is used for looking up clients IP addresses by MAC addresses
|
||||
dhcpServer dhcpd.Interface
|
||||
// dhcp is the DHCP service implementation.
|
||||
dhcp DHCP
|
||||
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
@@ -65,8 +65,8 @@ type clientsContainer struct {
|
||||
// hosts database.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
|
||||
// arpdb stores the neighbors retrieved from ARP.
|
||||
arpdb aghnet.ARPDB
|
||||
// arpDB stores the neighbors retrieved from ARP.
|
||||
arpDB arpdb.Interface
|
||||
|
||||
// lock protects all fields.
|
||||
//
|
||||
@@ -93,9 +93,9 @@ type clientsContainer struct {
|
||||
// Note: this function must be called only once
|
||||
func (clients *clientsContainer) Init(
|
||||
objects []*clientObject,
|
||||
dhcpServer dhcpd.Interface,
|
||||
dhcpServer DHCP,
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
arpdb aghnet.ARPDB,
|
||||
arpDB arpdb.Interface,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
if clients.list != nil {
|
||||
@@ -108,9 +108,11 @@ func (clients *clientsContainer) Init(
|
||||
|
||||
clients.allTags = stringutil.NewSet(clientTags...)
|
||||
|
||||
clients.dhcpServer = dhcpServer
|
||||
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
|
||||
clients.dhcp = dhcpServer
|
||||
|
||||
clients.etcHosts = etcHosts
|
||||
clients.arpdb = arpdb
|
||||
clients.arpDB = arpDB
|
||||
err = clients.addFromConfig(objects, filteringConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
@@ -124,11 +126,6 @@ func (clients *clientsContainer) Init(
|
||||
return nil
|
||||
}
|
||||
|
||||
if clients.dhcpServer != nil {
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
clients.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
|
||||
}
|
||||
|
||||
if clients.etcHosts != nil {
|
||||
go clients.handleHostsUpdates()
|
||||
}
|
||||
@@ -164,7 +161,7 @@ func (clients *clientsContainer) Start() {
|
||||
|
||||
// reloadARP reloads runtime clients from ARP, if configured.
|
||||
func (clients *clientsContainer) reloadARP() {
|
||||
if clients.arpdb != nil {
|
||||
if clients.arpDB != nil {
|
||||
clients.addFromSystemARP()
|
||||
}
|
||||
}
|
||||
@@ -290,8 +287,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
// above loop can generate different orderings when writing to the config
|
||||
// file: this produces lots of diffs in config files, so sort objects by
|
||||
// name before writing.
|
||||
slices.SortStableFunc(objs, func(a, b *clientObject) (sortsBefore bool) {
|
||||
return a.Name < b.Name
|
||||
slices.SortStableFunc(objs, func(a, b *clientObject) (res int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return objs
|
||||
@@ -309,58 +306,28 @@ func (clients *clientsContainer) periodicUpdate() {
|
||||
}
|
||||
}
|
||||
|
||||
// onDHCPLeaseChanged is a callback for the DHCP server. It updates the list of
|
||||
// runtime clients using the DHCP server's leases.
|
||||
//
|
||||
// TODO(e.burkov): Remove when switched to dhcpsvc.
|
||||
func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceDHCP)
|
||||
|
||||
if flags == dhcpd.LeaseChangedRemovedAll {
|
||||
return
|
||||
}
|
||||
|
||||
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
n := 0
|
||||
for _, l := range leases {
|
||||
if l.Hostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: added %d client aliases from dhcp", n)
|
||||
}
|
||||
|
||||
// clientSource checks if client with this IP address already exists and returns
|
||||
// the source which updated it last. It returns [ClientSourceNone] if the
|
||||
// the source which updated it last. It returns [client.SourceNone] if the
|
||||
// client doesn't exist.
|
||||
func (clients *clientsContainer) clientSource(ip netip.Addr) (src clientSource) {
|
||||
func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findLocked(ip.String())
|
||||
if ok {
|
||||
return ClientSourcePersistent
|
||||
return client.SourcePersistent
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
return rc.Source
|
||||
src = rc.Source
|
||||
}
|
||||
|
||||
return ClientSourceNone
|
||||
if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" {
|
||||
src = client.SourceDHCP
|
||||
}
|
||||
|
||||
return src
|
||||
}
|
||||
|
||||
// findMultiple is a wrapper around Find to make it a valid client finder for
|
||||
@@ -521,17 +488,14 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
}
|
||||
}
|
||||
|
||||
if clients.dhcpServer != nil {
|
||||
return clients.findDHCP(ip)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
// TODO(e.burkov): Iterate through clients.list only once.
|
||||
return clients.findDHCP(ip)
|
||||
}
|
||||
|
||||
// findDHCP searches for a client by its MAC, if the DHCP server is active and
|
||||
// there is such client. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) {
|
||||
foundMAC := clients.dhcpServer.FindMACbyIP(ip)
|
||||
foundMAC := clients.dhcp.MACByIP(ip)
|
||||
if foundMAC == nil {
|
||||
return nil, false
|
||||
}
|
||||
@@ -552,8 +516,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
|
||||
// runtimeClient returns a runtime client from internal index. Note that it
|
||||
// doesn't include DHCP clients.
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
|
||||
if ip == (netip.Addr{}) {
|
||||
return nil, false
|
||||
}
|
||||
@@ -566,6 +531,24 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeCl
|
||||
return rc, ok
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
|
||||
if rc, ok = clients.runtimeClient(ip); ok && rc.Source > client.SourceDHCP {
|
||||
return rc, ok
|
||||
}
|
||||
|
||||
host := clients.dhcp.HostByIP(ip)
|
||||
if host == "" {
|
||||
return rc, ok
|
||||
}
|
||||
|
||||
return &RuntimeClient{
|
||||
Host: host,
|
||||
Source: client.SourceDHCP,
|
||||
WHOIS: &whois.Info{},
|
||||
}, true
|
||||
}
|
||||
|
||||
// check validates the client.
|
||||
func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
switch {
|
||||
@@ -761,7 +744,7 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = &RuntimeClient{
|
||||
Source: ClientSourceWHOIS,
|
||||
Source: client.SourceWHOIS,
|
||||
}
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
@@ -780,7 +763,7 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
func (clients *clientsContainer) addHost(
|
||||
ip netip.Addr,
|
||||
host string,
|
||||
src clientSource,
|
||||
src client.Source,
|
||||
) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
@@ -803,7 +786,7 @@ func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
if host != "" {
|
||||
ok := clients.addHostLocked(ip, host, ClientSourceRDNS)
|
||||
ok := clients.addHostLocked(ip, host, client.SourceRDNS)
|
||||
if !ok {
|
||||
log.Debug("clients: host for client %q already set with higher priority source", ip)
|
||||
}
|
||||
@@ -819,14 +802,19 @@ func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info
|
||||
func (clients *clientsContainer) addHostLocked(
|
||||
ip netip.Addr,
|
||||
host string,
|
||||
src clientSource,
|
||||
src client.Source,
|
||||
) (ok bool) {
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
if src < client.SourceDHCP {
|
||||
if clients.dhcp.HostByIP(ip) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
rc = &RuntimeClient{
|
||||
WHOIS: &whois.Info{},
|
||||
}
|
||||
|
||||
clients.ipToRC[ip] = rc
|
||||
} else if src < rc.Source {
|
||||
return false
|
||||
@@ -841,7 +829,7 @@ func (clients *clientsContainer) addHostLocked(
|
||||
}
|
||||
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
|
||||
n := 0
|
||||
for ip, rc := range clients.ipToRC {
|
||||
if rc.Source == src {
|
||||
@@ -855,15 +843,19 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts aghnet.HostsRecords) {
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts aghnet.Hosts) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceHostsFile)
|
||||
clients.rmHostsBySrc(client.SourceHostsFile)
|
||||
|
||||
n := 0
|
||||
for ip, rec := range hosts {
|
||||
clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile)
|
||||
for addr, rec := range hosts {
|
||||
// Only the first name of the first record is considered a canonical
|
||||
// hostname for the IP address.
|
||||
//
|
||||
// TODO(e.burkov): Consider using all the names from all the records.
|
||||
clients.addHostLocked(addr, rec[0].Names[0], client.SourceHostsFile)
|
||||
n++
|
||||
}
|
||||
|
||||
@@ -873,15 +865,15 @@ func (clients *clientsContainer) addFromHostsFile(hosts aghnet.HostsRecords) {
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
// command.
|
||||
func (clients *clientsContainer) addFromSystemARP() {
|
||||
if err := clients.arpdb.Refresh(); err != nil {
|
||||
if err := clients.arpDB.Refresh(); err != nil {
|
||||
log.Error("refreshing arp container: %s", err)
|
||||
|
||||
clients.arpdb = aghnet.EmptyARPDB{}
|
||||
clients.arpDB = arpdb.Empty{}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ns := clients.arpdb.Neighbors()
|
||||
ns := clients.arpDB.Neighbors()
|
||||
if len(ns) == 0 {
|
||||
log.Debug("refreshing arp container: the update is empty")
|
||||
|
||||
@@ -891,11 +883,11 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceARP)
|
||||
clients.rmHostsBySrc(client.SourceARP)
|
||||
|
||||
added := 0
|
||||
for _, n := range ns {
|
||||
if clients.addHostLocked(n.IP, n.Name, ClientSourceARP) {
|
||||
if clients.addHostLocked(n.IP, n.Name, client.SourceARP) {
|
||||
added++
|
||||
}
|
||||
}
|
||||
@@ -907,7 +899,9 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
// the persistent clients.
|
||||
func (clients *clientsContainer) close() (err error) {
|
||||
persistent := maps.Values(clients.list)
|
||||
slices.SortFunc(persistent, func(a, b *Client) (less bool) { return a.Name < b.Name })
|
||||
slices.SortFunc(persistent, func(a, b *Client) (res int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var errs []error
|
||||
|
||||
@@ -917,9 +911,5 @@ func (clients *clientsContainer) close() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errors.List("closing client specific upstreams", errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
@@ -7,22 +7,46 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testDHCP struct {
|
||||
OnLeases func() (leases []*dhcpsvc.Lease)
|
||||
OnHostBy func(ip netip.Addr) (host string)
|
||||
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
|
||||
}
|
||||
|
||||
// Lease implements the [DHCP] interface for testDHCP.
|
||||
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
|
||||
|
||||
// HostByIP implements the [DHCP] interface for testDHCP.
|
||||
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
|
||||
|
||||
// MACByIP implements the [DHCP] interface for testDHCP.
|
||||
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
|
||||
|
||||
// newClientsContainer is a helper that creates a new clients container for
|
||||
// tests.
|
||||
func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
t.Helper()
|
||||
|
||||
c = &clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
|
||||
err := c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
require.NoError(t, err)
|
||||
dhcp := &testDHCP{
|
||||
OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") },
|
||||
OnHostBy: func(ip netip.Addr) (host string) { return "" },
|
||||
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
|
||||
}
|
||||
|
||||
require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{}))
|
||||
|
||||
return c
|
||||
}
|
||||
@@ -76,9 +100,9 @@ func TestClients(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "client2", c.Name)
|
||||
|
||||
assert.Equal(t, clients.clientSource(cliNoneIP), ClientSourceNone)
|
||||
assert.Equal(t, clients.clientSource(cli1IP), ClientSourcePersistent)
|
||||
assert.Equal(t, clients.clientSource(cli2IP), ClientSourcePersistent)
|
||||
assert.Equal(t, clients.clientSource(cliNoneIP), client.SourceNone)
|
||||
assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent)
|
||||
assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent)
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
@@ -125,8 +149,8 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, clients.clientSource(cliOldIP), ClientSourceNone)
|
||||
assert.Equal(t, clients.clientSource(cliNewIP), ClientSourcePersistent)
|
||||
assert.Equal(t, clients.clientSource(cliOldIP), client.SourceNone)
|
||||
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
|
||||
|
||||
prev, ok = clients.list["client1"]
|
||||
require.True(t, ok)
|
||||
@@ -158,7 +182,7 @@ func TestClients(t *testing.T) {
|
||||
ok := clients.Del("client1-renamed")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(netip.MustParseAddr("1.1.1.2")), ClientSourceNone)
|
||||
assert.Equal(t, clients.clientSource(netip.MustParseAddr("1.1.1.2")), client.SourceNone)
|
||||
})
|
||||
|
||||
t.Run("del_fail", func(t *testing.T) {
|
||||
@@ -168,32 +192,32 @@ func TestClients(t *testing.T) {
|
||||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.addHost(ip, "host", ClientSourceARP)
|
||||
ok := clients.addHost(ip, "host", client.SourceARP)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok = clients.addHost(ip, "host2", ClientSourceARP)
|
||||
ok = clients.addHost(ip, "host2", client.SourceARP)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok = clients.addHost(ip, "host3", ClientSourceHostsFile)
|
||||
ok = clients.addHost(ip, "host3", client.SourceHostsFile)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(ip), ClientSourceHostsFile)
|
||||
assert.Equal(t, clients.clientSource(ip), client.SourceHostsFile)
|
||||
})
|
||||
|
||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
ok := clients.addHost(ip, "from_arp", ClientSourceARP)
|
||||
ok := clients.addHost(ip, "from_arp", client.SourceARP)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, clients.clientSource(ip), ClientSourceARP)
|
||||
assert.Equal(t, clients.clientSource(ip), client.SourceARP)
|
||||
|
||||
ok = clients.addHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||
ok = clients.addHost(ip, "from_dhcp", client.SourceDHCP)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, clients.clientSource(ip), ClientSourceDHCP)
|
||||
assert.Equal(t, clients.clientSource(ip), client.SourceDHCP)
|
||||
})
|
||||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.addHost(ip, "host1", ClientSourceRDNS)
|
||||
ok := clients.addHost(ip, "host1", client.SourceRDNS)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -216,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.addHost(ip, "host", ClientSourceRDNS)
|
||||
ok := clients.addHost(ip, "host", client.SourceRDNS)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
@@ -259,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok = clients.addHost(ip, "test", ClientSourceRDNS)
|
||||
ok = clients.addHost(ip, "test", client.SourceRDNS)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
@@ -288,7 +312,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
dhcpServer, err := dhcpd.Create(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
clients.dhcpServer = dhcpServer
|
||||
clients.dhcp = dhcpServer
|
||||
|
||||
err = dhcpServer.AddStaticLease(&dhcpd.Lease{
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
@@ -34,8 +35,12 @@ type clientJSON struct {
|
||||
WHOIS *whois.Info `json:"whois_info,omitempty"`
|
||||
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
|
||||
|
||||
// Schedule is blocked services schedule for every day of the week.
|
||||
Schedule *schedule.Weekly `json:"blocked_services_schedule"`
|
||||
|
||||
Name string `json:"name"`
|
||||
|
||||
// BlockedServices is the names of blocked services.
|
||||
BlockedServices []string `json:"blocked_services"`
|
||||
IDs []string `json:"ids"`
|
||||
Tags []string `json:"tags"`
|
||||
@@ -53,12 +58,40 @@ type clientJSON struct {
|
||||
IgnoreStatistics aghalg.NullBool `json:"ignore_statistics"`
|
||||
}
|
||||
|
||||
// copySettings returns a copy of specific settings from JSON or a previous
|
||||
// client.
|
||||
func (j *clientJSON) copySettings(
|
||||
prev *Client,
|
||||
) (weekly *schedule.Weekly, ignoreQueryLog, ignoreStatistics bool) {
|
||||
if j.Schedule != nil {
|
||||
weekly = j.Schedule.Clone()
|
||||
} else if prev != nil && prev.BlockedServices != nil {
|
||||
weekly = prev.BlockedServices.Schedule.Clone()
|
||||
} else {
|
||||
weekly = schedule.EmptyWeekly()
|
||||
}
|
||||
|
||||
if j.IgnoreQueryLog != aghalg.NBNull {
|
||||
ignoreQueryLog = j.IgnoreQueryLog == aghalg.NBTrue
|
||||
} else if prev != nil {
|
||||
ignoreQueryLog = prev.IgnoreQueryLog
|
||||
}
|
||||
|
||||
if j.IgnoreStatistics != aghalg.NBNull {
|
||||
ignoreStatistics = j.IgnoreStatistics == aghalg.NBTrue
|
||||
} else if prev != nil {
|
||||
ignoreStatistics = prev.IgnoreStatistics
|
||||
}
|
||||
|
||||
return weekly, ignoreQueryLog, ignoreStatistics
|
||||
}
|
||||
|
||||
type runtimeClientJSON struct {
|
||||
WHOIS *whois.Info `json:"whois_info"`
|
||||
|
||||
IP netip.Addr `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source clientSource `json:"source"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source client.Source `json:"source"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
@@ -91,9 +124,20 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
}
|
||||
|
||||
for _, l := range clients.dhcp.Leases() {
|
||||
cj := runtimeClientJSON{
|
||||
Name: l.Hostname,
|
||||
Source: client.SourceDHCP,
|
||||
IP: l.IP,
|
||||
WHOIS: &whois.Info{},
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
}
|
||||
|
||||
data.Tags = clientTags
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
}
|
||||
|
||||
// jsonToClient converts JSON object to Client object.
|
||||
@@ -119,9 +163,15 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
|
||||
}
|
||||
}
|
||||
|
||||
weekly := schedule.EmptyWeekly()
|
||||
if prev != nil {
|
||||
weekly = prev.BlockedServices.Schedule.Clone()
|
||||
weekly, ignoreQueryLog, ignoreStatistics := cj.copySettings(prev)
|
||||
|
||||
bs := &filtering.BlockedServices{
|
||||
Schedule: weekly,
|
||||
IDs: cj.BlockedServices,
|
||||
}
|
||||
err = bs.Validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating blocked services: %w", err)
|
||||
}
|
||||
|
||||
c = &Client{
|
||||
@@ -129,10 +179,7 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
|
||||
|
||||
Name: cj.Name,
|
||||
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: weekly,
|
||||
IDs: cj.BlockedServices,
|
||||
},
|
||||
BlockedServices: bs,
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
@@ -143,18 +190,8 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
|
||||
}
|
||||
|
||||
if cj.IgnoreQueryLog != aghalg.NBNull {
|
||||
c.IgnoreQueryLog = cj.IgnoreQueryLog == aghalg.NBTrue
|
||||
} else if prev != nil {
|
||||
c.IgnoreQueryLog = prev.IgnoreQueryLog
|
||||
}
|
||||
|
||||
if cj.IgnoreStatistics != aghalg.NBNull {
|
||||
c.IgnoreStatistics = cj.IgnoreStatistics == aghalg.NBTrue
|
||||
} else if prev != nil {
|
||||
c.IgnoreStatistics = prev.IgnoreStatistics
|
||||
IgnoreQueryLog: ignoreQueryLog,
|
||||
IgnoreStatistics: ignoreStatistics,
|
||||
}
|
||||
|
||||
if safeSearchConf.Enabled {
|
||||
@@ -191,6 +228,7 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
|
||||
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
|
||||
|
||||
Schedule: c.BlockedServices.Schedule,
|
||||
BlockedServices: c.BlockedServices.IDs,
|
||||
|
||||
Upstreams: c.Upstreams,
|
||||
@@ -338,7 +376,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
})
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
}
|
||||
|
||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/confmigrate"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -21,7 +22,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
"golang.org/x/exp/slices"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -114,8 +114,6 @@ type configuration struct {
|
||||
Language string `yaml:"language"`
|
||||
// Theme is a UI theme for current user.
|
||||
Theme Theme `yaml:"theme"`
|
||||
// DebugPProf defines if the profiling HTTP handler will listen on :6060.
|
||||
DebugPProf bool `yaml:"debug_pprof"`
|
||||
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
@@ -133,7 +131,8 @@ type configuration struct {
|
||||
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
|
||||
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
Filtering *filtering.Config `yaml:"filtering"`
|
||||
|
||||
// Clients contains the YAML representations of the persistent clients.
|
||||
// This field is only used for reading and writing persistent client data.
|
||||
@@ -147,7 +146,9 @@ type configuration struct {
|
||||
|
||||
sync.RWMutex `yaml:"-"`
|
||||
|
||||
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
|
||||
// SchemaVersion is the version of the configuration schema. See
|
||||
// [confmigrate.LastSchemaVersion].
|
||||
SchemaVersion uint `yaml:"schema_version"`
|
||||
}
|
||||
|
||||
// httpConfig is a block with HTTP configuration params.
|
||||
@@ -155,6 +156,9 @@ type configuration struct {
|
||||
// Field ordering is important, YAML fields better not to be reordered, if it's
|
||||
// not absolutely necessary.
|
||||
type httpConfig struct {
|
||||
// Pprof defines the profiling HTTP handler.
|
||||
Pprof *httpPprofConfig `yaml:"pprof"`
|
||||
|
||||
// Address is the address to serve the web UI on.
|
||||
Address netip.AddrPort
|
||||
|
||||
@@ -163,6 +167,15 @@ type httpConfig struct {
|
||||
SessionTTL timeutil.Duration `yaml:"session_ttl"`
|
||||
}
|
||||
|
||||
// httpPprofConfig is the block with pprof HTTP configuration.
|
||||
type httpPprofConfig struct {
|
||||
// Port for the profiling handler.
|
||||
Port uint16 `yaml:"port"`
|
||||
|
||||
// Enabled defines if the profiling handler is enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// dnsConfig is a block with DNS configuration params.
|
||||
//
|
||||
// Field ordering is important, YAML fields better not to be reordered, if it's
|
||||
@@ -175,9 +188,10 @@ type dnsConfig struct {
|
||||
// in query log and statistics.
|
||||
AnonymizeClientIP bool `yaml:"anonymize_client_ip"`
|
||||
|
||||
dnsforward.FilteringConfig `yaml:",inline"`
|
||||
|
||||
DnsfilterConf *filtering.Config `yaml:",inline"`
|
||||
// Config is the embed configuration with DNS params.
|
||||
//
|
||||
// TODO(a.garipov): Remove embed.
|
||||
dnsforward.Config `yaml:",inline"`
|
||||
|
||||
// UpstreamTimeout is the timeout for querying upstream servers.
|
||||
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
||||
@@ -277,18 +291,19 @@ var config = &configuration{
|
||||
HTTPConfig: httpConfig{
|
||||
Address: netip.AddrPortFrom(netip.IPv4Unspecified(), 3000),
|
||||
SessionTTL: timeutil.Duration{Duration: 30 * timeutil.Day},
|
||||
Pprof: &httpPprofConfig{
|
||||
Enabled: false,
|
||||
Port: 6060,
|
||||
},
|
||||
},
|
||||
DNS: dnsConfig{
|
||||
BindHosts: []netip.Addr{netip.IPv4Unspecified()},
|
||||
Port: defaultPortDNS,
|
||||
FilteringConfig: dnsforward.FilteringConfig{
|
||||
ProtectionEnabled: true, // whether or not use any of filtering features
|
||||
BlockingMode: dnsforward.BlockingModeDefault,
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
Ratelimit: 20,
|
||||
RefuseAny: true,
|
||||
AllServers: false,
|
||||
HandleDDR: true,
|
||||
Config: dnsforward.Config{
|
||||
Ratelimit: 20,
|
||||
RefuseAny: true,
|
||||
AllServers: false,
|
||||
HandleDDR: true,
|
||||
FastestTimeout: timeutil.Duration{
|
||||
Duration: fastip.DefaultPingWaitTimeout,
|
||||
},
|
||||
@@ -308,33 +323,6 @@ var config = &configuration{
|
||||
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
|
||||
MaxGoroutines: 300,
|
||||
},
|
||||
DnsfilterConf: &filtering.Config{
|
||||
FilteringEnabled: true,
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
|
||||
ParentalEnabled: false,
|
||||
SafeBrowsingEnabled: false,
|
||||
|
||||
SafeBrowsingCacheSize: 1 * 1024 * 1024,
|
||||
SafeSearchCacheSize: 1 * 1024 * 1024,
|
||||
ParentalCacheSize: 1 * 1024 * 1024,
|
||||
CacheTime: 30,
|
||||
|
||||
SafeSearchConf: filtering.SafeSearchConfig{
|
||||
Enabled: false,
|
||||
Bing: true,
|
||||
DuckDuckGo: true,
|
||||
Google: true,
|
||||
Pixabay: true,
|
||||
Yandex: true,
|
||||
YouTube: true,
|
||||
},
|
||||
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: []string{},
|
||||
},
|
||||
},
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
UsePrivateRDNS: true,
|
||||
},
|
||||
@@ -371,6 +359,37 @@ var config = &configuration{
|
||||
URL: "https://adguardteam.github.io/HostlistsRegistry/assets/filter_2.txt",
|
||||
Name: "AdAway Default Blocklist",
|
||||
}},
|
||||
Filtering: &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
|
||||
FilteringEnabled: true,
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
|
||||
ParentalEnabled: false,
|
||||
SafeBrowsingEnabled: false,
|
||||
|
||||
SafeBrowsingCacheSize: 1 * 1024 * 1024,
|
||||
SafeSearchCacheSize: 1 * 1024 * 1024,
|
||||
ParentalCacheSize: 1 * 1024 * 1024,
|
||||
CacheTime: 30,
|
||||
|
||||
SafeSearchConf: filtering.SafeSearchConfig{
|
||||
Enabled: false,
|
||||
Bing: true,
|
||||
DuckDuckGo: true,
|
||||
Google: true,
|
||||
Pixabay: true,
|
||||
Yandex: true,
|
||||
YouTube: true,
|
||||
},
|
||||
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: []string{},
|
||||
},
|
||||
},
|
||||
DHCP: &dhcpd.ServerConfig{
|
||||
LocalDomainName: "lan",
|
||||
Conf4: dhcpd.V4ServerConf{
|
||||
@@ -398,7 +417,7 @@ var config = &configuration{
|
||||
MaxAge: 3,
|
||||
},
|
||||
OSConfig: &osConfig{},
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
SchemaVersion: confmigrate.LastSchemaVersion,
|
||||
Theme: ThemeAuto,
|
||||
}
|
||||
|
||||
@@ -414,28 +433,10 @@ func (c *configuration) getConfigFilename() string {
|
||||
if !filepath.IsAbs(configFile) {
|
||||
configFile = filepath.Join(Context.workDir, configFile)
|
||||
}
|
||||
|
||||
return configFile
|
||||
}
|
||||
|
||||
// readLogSettings reads logging settings from the config file. We do it in a
|
||||
// separate method in order to configure logger before the actual configuration
|
||||
// is parsed and applied.
|
||||
func readLogSettings() (ls *logSettings) {
|
||||
conf := &configuration{}
|
||||
|
||||
yamlFile, err := readConfigFile()
|
||||
if err != nil {
|
||||
return &logSettings{}
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(yamlFile, conf)
|
||||
if err != nil {
|
||||
log.Error("Couldn't get logging settings from the configuration: %s", err)
|
||||
}
|
||||
|
||||
return &conf.Log
|
||||
}
|
||||
|
||||
// validateBindHosts returns error if any of binding hosts from configuration is
|
||||
// not a valid IP address.
|
||||
func validateBindHosts(conf *configuration) (err error) {
|
||||
@@ -452,21 +453,59 @@ func validateBindHosts(conf *configuration) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseConfig loads configuration from the YAML file
|
||||
// parseConfig loads configuration from the YAML file, upgrading it if
|
||||
// necessary.
|
||||
func parseConfig() (err error) {
|
||||
var fileData []byte
|
||||
fileData, err = readConfigFile()
|
||||
// Do the upgrade if necessary.
|
||||
config.fileData, err = readConfigFile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.fileData = nil
|
||||
err = yaml.Unmarshal(fileData, &config)
|
||||
migrator := confmigrate.New(&confmigrate.Config{
|
||||
WorkingDir: Context.workDir,
|
||||
})
|
||||
|
||||
var upgraded bool
|
||||
config.fileData, upgraded, err = migrator.Migrate(
|
||||
config.fileData,
|
||||
confmigrate.LastSchemaVersion,
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
} else if upgraded {
|
||||
err = maybe.WriteFile(config.getConfigFilename(), config.fileData, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing new config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(config.fileData, &config)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.DNS.UpstreamTimeout.Duration == 0 {
|
||||
config.DNS.UpstreamTimeout = timeutil.Duration{Duration: dnsforward.DefaultTimeout}
|
||||
}
|
||||
|
||||
err = setContextTLSCipherIDs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateConfig returns error if the configuration is invalid.
|
||||
func validateConfig() (err error) {
|
||||
err = validateBindHosts(config)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
@@ -498,17 +537,8 @@ func parseConfig() (err error) {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) {
|
||||
config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24
|
||||
}
|
||||
|
||||
if config.DNS.UpstreamTimeout.Duration == 0 {
|
||||
config.DNS.UpstreamTimeout = timeutil.Duration{Duration: dnsforward.DefaultTimeout}
|
||||
}
|
||||
|
||||
err = setContextTLSCipherIDs()
|
||||
if err != nil {
|
||||
return err
|
||||
if !filtering.ValidateUpdateIvl(config.Filtering.FiltersUpdateIntervalHours) {
|
||||
config.Filtering.FiltersUpdateIntervalHours = 24
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -563,7 +593,6 @@ func (c *configuration) write() (err error) {
|
||||
config.Stats.Interval = timeutil.Duration{Duration: statsConf.Limit}
|
||||
config.Stats.Enabled = statsConf.Enabled
|
||||
config.Stats.Ignored = statsConf.Ignored.Values()
|
||||
slices.Sort(config.Stats.Ignored)
|
||||
}
|
||||
|
||||
if Context.queryLog != nil {
|
||||
@@ -575,21 +604,20 @@ func (c *configuration) write() (err error) {
|
||||
config.QueryLog.Interval = timeutil.Duration{Duration: dc.RotationIvl}
|
||||
config.QueryLog.MemSize = dc.MemSize
|
||||
config.QueryLog.Ignored = dc.Ignored.Values()
|
||||
slices.Sort(config.Stats.Ignored)
|
||||
}
|
||||
|
||||
if Context.filters != nil {
|
||||
Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf)
|
||||
config.Filters = config.DNS.DnsfilterConf.Filters
|
||||
config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters
|
||||
config.UserRules = config.DNS.DnsfilterConf.UserRules
|
||||
Context.filters.WriteDiskConfig(config.Filtering)
|
||||
config.Filters = config.Filtering.Filters
|
||||
config.WhitelistFilters = config.Filtering.WhitelistFilters
|
||||
config.UserRules = config.Filtering.UserRules
|
||||
}
|
||||
|
||||
if s := Context.dnsServer; s != nil {
|
||||
c := dnsforward.FilteringConfig{}
|
||||
c := dnsforward.Config{}
|
||||
s.WriteDiskConfig(&c)
|
||||
dns := &config.DNS
|
||||
dns.FilteringConfig = c
|
||||
dns.Config = c
|
||||
|
||||
dns.LocalPTRResolvers = s.LocalPTRResolvers()
|
||||
|
||||
|
||||
@@ -127,12 +127,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var (
|
||||
fltConf *dnsforward.FilteringConfig
|
||||
fltConf *dnsforward.Config
|
||||
protectionDisabledUntil *time.Time
|
||||
protectionEnabled bool
|
||||
)
|
||||
if Context.dnsServer != nil {
|
||||
fltConf = &dnsforward.FilteringConfig{}
|
||||
fltConf = &dnsforward.Config{}
|
||||
Context.dnsServer.WriteDiskConfig(fltConf)
|
||||
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp.IsDHCPAvailable = Context.dhcpServer != nil
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
@@ -321,9 +321,10 @@ func preInstallHandler(handler http.Handler) http.Handler {
|
||||
return &preInstallHandlerStruct{handler}
|
||||
}
|
||||
|
||||
// handleHTTPSRedirect redirects the request to HTTPS, if needed. If ok is
|
||||
// true, the middleware must continue handling the request.
|
||||
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
// handleHTTPSRedirect redirects the request to HTTPS, if needed, and adds some
|
||||
// HTTPS-related headers. If proceed is true, the middleware must continue
|
||||
// handling the request.
|
||||
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
web := Context.web
|
||||
if web.httpsServer.server == nil {
|
||||
return true
|
||||
@@ -362,21 +363,17 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
respHdr.Set(httphdr.AltSvc, altSvc)
|
||||
}
|
||||
|
||||
if r.TLS == nil && forceHTTPS {
|
||||
hostPort := host
|
||||
if portHTTPS != defaultPortHTTPS {
|
||||
hostPort = netutil.JoinHostPort(host, portHTTPS)
|
||||
if forceHTTPS {
|
||||
if r.TLS == nil {
|
||||
u := httpsURL(r.URL, host, portHTTPS)
|
||||
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
httpsURL := &url.URL{
|
||||
Scheme: aghhttp.SchemeHTTPS,
|
||||
Host: hostPort,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(w, r, httpsURL.String(), http.StatusTemporaryRedirect)
|
||||
|
||||
return false
|
||||
// TODO(a.garipov): Consider adding a configurable max-age. Currently,
|
||||
// the default is 365 days.
|
||||
respHdr.Set(httphdr.StrictTransportSecurity, aghhttp.HdrValStrictTransportSecurity)
|
||||
}
|
||||
|
||||
// Allow the frontend from the HTTP origin to send requests to the HTTPS
|
||||
@@ -395,6 +392,22 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
// httpsURL returns a copy of u for redirection to the HTTPS version, taking the
|
||||
// hostname and the HTTPS port into account.
|
||||
func httpsURL(u *url.URL, host string, portHTTPS int) (redirectURL *url.URL) {
|
||||
hostPort := host
|
||||
if portHTTPS != defaultPortHTTPS {
|
||||
hostPort = netutil.JoinHostPort(host, portHTTPS)
|
||||
}
|
||||
|
||||
return &url.URL{
|
||||
Scheme: aghhttp.SchemeHTTPS,
|
||||
Host: hostPort,
|
||||
Path: u.Path,
|
||||
RawQuery: u.RawQuery,
|
||||
}
|
||||
}
|
||||
|
||||
// postInstall lets the handler to run only if firstRun is false. Otherwise, it
|
||||
// redirects to /install.html. It also enforces HTTPS if it is enabled and
|
||||
// configured and sets appropriate access control headers.
|
||||
@@ -408,11 +421,10 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
|
||||
return
|
||||
}
|
||||
|
||||
if !handleHTTPSRedirect(w, r) {
|
||||
return
|
||||
proceed := handleHTTPSRedirect(w, r)
|
||||
if proceed {
|
||||
handler(w, r)
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
|
||||
data.Interfaces[iface.Name] = iface
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
}
|
||||
|
||||
type checkConfReqEnt struct {
|
||||
@@ -190,7 +190,7 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
|
||||
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// handleStaticIP - handles static IP request
|
||||
|
||||
@@ -33,7 +33,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
resp := &versionResponse{}
|
||||
if web.conf.disableUpdate {
|
||||
resp.Disabled = true
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
@@ -60,12 +59,12 @@ func initDNS() (err error) {
|
||||
ShouldCountClient: Context.clients.shouldCountClient,
|
||||
}
|
||||
|
||||
set, err := aghnet.NewDomainNameSet(config.Stats.Ignored)
|
||||
engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored)
|
||||
if err != nil {
|
||||
return fmt.Errorf("statistics: ignored list: %w", err)
|
||||
}
|
||||
|
||||
statsConf.Ignored = set
|
||||
statsConf.Ignored = engine
|
||||
Context.stats, err = stats.New(statsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init stats: %w", err)
|
||||
@@ -84,18 +83,18 @@ func initDNS() (err error) {
|
||||
FileEnabled: config.QueryLog.FileEnabled,
|
||||
}
|
||||
|
||||
set, err = aghnet.NewDomainNameSet(config.QueryLog.Ignored)
|
||||
engine, err = aghnet.NewIgnoreEngine(config.QueryLog.Ignored)
|
||||
if err != nil {
|
||||
return fmt.Errorf("querylog: ignored list: %w", err)
|
||||
}
|
||||
|
||||
conf.Ignored = set
|
||||
conf.Ignored = engine
|
||||
Context.queryLog, err = querylog.New(conf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init querylog: %w", err)
|
||||
}
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||
Context.filters, err = filtering.New(config.Filtering, nil)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
@@ -123,7 +122,7 @@ func initDNSServer(
|
||||
filters *filtering.DNSFilter,
|
||||
sts stats.Interface,
|
||||
qlog querylog.QueryLog,
|
||||
dhcpSrv dhcpd.Interface,
|
||||
dhcpSrv dnsforward.DHCP,
|
||||
anonymizer *aghnet.IPMut,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
tlsConf *tlsConfigSettings,
|
||||
@@ -231,13 +230,13 @@ func newServerConfig(
|
||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||
|
||||
newConf = &dnsforward.ServerConfig{
|
||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||
FilteringConfig: dnsConf.FilteringConfig,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpReg,
|
||||
UseDNS64: config.DNS.UseDNS64,
|
||||
DNS64Prefixes: config.DNS.DNS64Prefixes,
|
||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||
Config: dnsConf.Config,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpReg,
|
||||
UseDNS64: config.DNS.UseDNS64,
|
||||
DNS64Prefixes: config.DNS.DNS64Prefixes,
|
||||
}
|
||||
|
||||
var initialAddresses []netip.Addr
|
||||
@@ -378,7 +377,7 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
|
||||
// applyAdditionalFiltering adds additional client information and settings if
|
||||
// the client has them.
|
||||
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
|
||||
func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filtering.Settings) {
|
||||
// pref is a prefix for logging messages around the scope.
|
||||
const pref = "applying filters"
|
||||
|
||||
@@ -386,7 +385,7 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
||||
|
||||
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
if clientIP == nil {
|
||||
if !clientIP.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -502,14 +501,10 @@ func closeDNSServer() {
|
||||
if err != nil {
|
||||
log.Debug("closing stats: %s", err)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Find out if it's safe.
|
||||
Context.stats = nil
|
||||
}
|
||||
|
||||
if Context.queryLog != nil {
|
||||
Context.queryLog.Close()
|
||||
Context.queryLog = nil
|
||||
}
|
||||
|
||||
log.Debug("all dns modules are closed")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
|
||||
func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
var err error
|
||||
|
||||
@@ -78,7 +80,7 @@ func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setts := &filtering.Settings{}
|
||||
|
||||
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
|
||||
applyAdditionalFiltering(testIPv4, tc.id, setts)
|
||||
tc.FilteringEnabled(t, setts.FilteringEnabled)
|
||||
tc.SafeSearchEnabled(t, setts.SafeSearchEnabled)
|
||||
tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled)
|
||||
@@ -169,7 +171,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setts := &filtering.Settings{}
|
||||
|
||||
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
|
||||
applyAdditionalFiltering(testIPv4, tc.id, setts)
|
||||
require.Len(t, setts.ServicesRules, tc.wantLen)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -37,12 +38,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/slices"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
|
||||
configSyslog = "syslog"
|
||||
)
|
||||
|
||||
// Global context
|
||||
@@ -104,8 +99,11 @@ func Main(clientBuildFS fs.FS) {
|
||||
// package flag.
|
||||
opts := loadCmdLineOpts()
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
sig := <-signals
|
||||
@@ -117,19 +115,19 @@ func Main(clientBuildFS fs.FS) {
|
||||
default:
|
||||
cleanup(context.Background())
|
||||
cleanupAlways()
|
||||
os.Exit(0)
|
||||
close(done)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.serviceControlAction != "" {
|
||||
handleServiceControlAction(opts, clientBuildFS, signals)
|
||||
handleServiceControlAction(opts, clientBuildFS, signals, done)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// run the protection
|
||||
run(opts, clientBuildFS)
|
||||
run(opts, clientBuildFS, done)
|
||||
}
|
||||
|
||||
// setupContext initializes [Context] fields. It also reads and upgrades
|
||||
@@ -147,14 +145,8 @@ func setupContext(opts options) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do the upgrade if necessary.
|
||||
err = upgradeConfig()
|
||||
err = parseConfig()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = parseConfig(); err != nil {
|
||||
log.Error("parsing configuration file: %s", err)
|
||||
|
||||
os.Exit(1)
|
||||
@@ -239,7 +231,6 @@ func setupHostsContainer() (err error) {
|
||||
}
|
||||
|
||||
Context.etcHosts, err = aghnet.NewHostsContainer(
|
||||
filtering.SysHostsListID,
|
||||
aghos.RootDirFS(),
|
||||
hostsWatcher,
|
||||
aghnet.DefaultHostsPaths()...,
|
||||
@@ -275,7 +266,7 @@ func setupOpts(opts options) (err error) {
|
||||
|
||||
// initContextClients initializes Context clients and related fields.
|
||||
func initContextClients() (err error) {
|
||||
err = setupDNSFilteringConf(config.DNS.DnsfilterConf)
|
||||
err = setupDNSFilteringConf(config.Filtering)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
@@ -296,17 +287,17 @@ func initContextClients() (err error) {
|
||||
return fmt.Errorf("initing dhcp: %w", err)
|
||||
}
|
||||
|
||||
var arpdb aghnet.ARPDB
|
||||
var arpDB arpdb.Interface
|
||||
if config.Clients.Sources.ARP {
|
||||
arpdb = aghnet.NewARPDB()
|
||||
arpDB = arpdb.New()
|
||||
}
|
||||
|
||||
err = Context.clients.Init(
|
||||
config.Clients.Persistent,
|
||||
Context.dhcpServer,
|
||||
Context.etcHosts,
|
||||
arpdb,
|
||||
config.DNS.DnsfilterConf,
|
||||
arpDB,
|
||||
config.Filtering,
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
@@ -368,6 +359,9 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
pcService = "parental control"
|
||||
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
|
||||
pcTXTSuffix = `pc.dns.adguard.com.`
|
||||
|
||||
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
defaultParentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
conf.EtcHosts = Context.etcHosts
|
||||
@@ -404,6 +398,10 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
CacheSize: conf.SafeBrowsingCacheSize,
|
||||
})
|
||||
|
||||
if conf.SafeBrowsingBlockHost != "" {
|
||||
conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
|
||||
}
|
||||
|
||||
parUps, err := upstream.AddressToUpstream(defaultParentalServer, upsOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting parental server: %w", err)
|
||||
@@ -417,6 +415,10 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
CacheSize: conf.ParentalCacheSize,
|
||||
})
|
||||
|
||||
if conf.ParentalBlockHost != "" {
|
||||
conf.ParentalBlockHost = defaultParentalBlockHost
|
||||
}
|
||||
|
||||
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
conf.SafeSearch, err = safesearch.NewDefault(
|
||||
conf.SafeSearchConf,
|
||||
@@ -510,7 +512,7 @@ func fatalOnError(err error) {
|
||||
}
|
||||
|
||||
// run configures and starts AdGuard Home.
|
||||
func run(opts options, clientBuildFS fs.FS) {
|
||||
func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
// Configure config filename.
|
||||
initConfigFilename(opts)
|
||||
|
||||
@@ -547,7 +549,7 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
fatalOnError(err)
|
||||
|
||||
upd := updater.NewUpdater(&updater.Config{
|
||||
Client: config.DNS.DnsfilterConf.HTTPClient,
|
||||
Client: config.Filtering.HTTPClient,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
@@ -567,9 +569,8 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
err = config.write()
|
||||
fatalOnError(err)
|
||||
|
||||
if config.DebugPProf {
|
||||
// TODO(a.garipov): Make the address configurable.
|
||||
startPprof("localhost:6060")
|
||||
if config.HTTPConfig.Pprof.Enabled {
|
||||
startPprof(config.HTTPConfig.Pprof.Port)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -616,8 +617,8 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
|
||||
Context.web.start()
|
||||
|
||||
// Wait indefinitely for other goroutines to complete their job.
|
||||
select {}
|
||||
// Wait for other goroutines to complete their job.
|
||||
<-done
|
||||
}
|
||||
|
||||
// initUsers initializes context auth module. Clears config users field.
|
||||
@@ -748,79 +749,6 @@ func initWorkingDir(opts options) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureLogger configures logger level and output.
|
||||
func configureLogger(opts options) (err error) {
|
||||
ls := getLogSettings(opts)
|
||||
|
||||
// Configure logger level.
|
||||
if ls.Verbose {
|
||||
log.SetLevel(log.DEBUG)
|
||||
}
|
||||
|
||||
// Make sure that we see the microseconds in logs, as networking stuff can
|
||||
// happen pretty quickly.
|
||||
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
|
||||
|
||||
// Write logs to stdout by default.
|
||||
if ls.File == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ls.File == configSyslog {
|
||||
// Use syslog where it is possible and eventlog on Windows.
|
||||
err = aghos.ConfigureSyslog(serviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot initialize syslog: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
logFilePath := ls.File
|
||||
if !filepath.IsAbs(logFilePath) {
|
||||
logFilePath = filepath.Join(Context.workDir, logFilePath)
|
||||
}
|
||||
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
Filename: logFilePath,
|
||||
Compress: ls.Compress,
|
||||
LocalTime: ls.LocalTime,
|
||||
MaxBackups: ls.MaxBackups,
|
||||
MaxSize: ls.MaxSize,
|
||||
MaxAge: ls.MaxAge,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLogSettings returns a log settings object properly initialized from opts.
|
||||
func getLogSettings(opts options) (ls *logSettings) {
|
||||
ls = readLogSettings()
|
||||
configLogSettings := config.Log
|
||||
|
||||
// Command-line arguments can override config settings.
|
||||
if opts.verbose || configLogSettings.Verbose {
|
||||
ls.Verbose = true
|
||||
}
|
||||
|
||||
ls.File = stringutil.Coalesce(opts.logFile, configLogSettings.File, ls.File)
|
||||
|
||||
// Handle default log settings overrides.
|
||||
ls.Compress = configLogSettings.Compress
|
||||
ls.LocalTime = configLogSettings.LocalTime
|
||||
ls.MaxBackups = configLogSettings.MaxBackups
|
||||
ls.MaxSize = configLogSettings.MaxSize
|
||||
ls.MaxAge = configLogSettings.MaxAge
|
||||
|
||||
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if
|
||||
// nothing else is configured. Otherwise, we'll lose the log output.
|
||||
ls.File = configSyslog
|
||||
}
|
||||
|
||||
return ls
|
||||
}
|
||||
|
||||
// cleanup stops and resets all the modules.
|
||||
func cleanup(ctx context.Context) {
|
||||
log.Info("stopping AdGuard Home")
|
||||
|
||||
@@ -7,6 +7,6 @@ import (
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
initCmdLineOpts()
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ type languageJSON struct {
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("home: language is %s", config.Language)
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, &languageJSON{
|
||||
aghhttp.WriteJSONResponseOK(w, r, &languageJSON{
|
||||
Language: config.Language,
|
||||
})
|
||||
}
|
||||
|
||||
106
internal/home/log.go
Normal file
106
internal/home/log.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// configSyslog is used to indicate that syslog or eventlog (win) should be used
|
||||
// for logger output.
|
||||
const configSyslog = "syslog"
|
||||
|
||||
// configureLogger configures logger level and output.
|
||||
func configureLogger(opts options) (err error) {
|
||||
ls := getLogSettings(opts)
|
||||
|
||||
// Configure logger level.
|
||||
if ls.Verbose {
|
||||
log.SetLevel(log.DEBUG)
|
||||
}
|
||||
|
||||
// Make sure that we see the microseconds in logs, as networking stuff can
|
||||
// happen pretty quickly.
|
||||
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
|
||||
|
||||
// Write logs to stdout by default.
|
||||
if ls.File == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ls.File == configSyslog {
|
||||
// Use syslog where it is possible and eventlog on Windows.
|
||||
err = aghos.ConfigureSyslog(serviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot initialize syslog: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
logFilePath := ls.File
|
||||
if !filepath.IsAbs(logFilePath) {
|
||||
logFilePath = filepath.Join(Context.workDir, logFilePath)
|
||||
}
|
||||
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
Filename: logFilePath,
|
||||
Compress: ls.Compress,
|
||||
LocalTime: ls.LocalTime,
|
||||
MaxBackups: ls.MaxBackups,
|
||||
MaxSize: ls.MaxSize,
|
||||
MaxAge: ls.MaxAge,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLogSettings returns a log settings object properly initialized from opts.
|
||||
func getLogSettings(opts options) (ls *logSettings) {
|
||||
configLogSettings := config.Log
|
||||
|
||||
ls = readLogSettings()
|
||||
if ls == nil {
|
||||
// Use default log settings.
|
||||
ls = &configLogSettings
|
||||
}
|
||||
|
||||
// Command-line arguments can override config settings.
|
||||
if opts.verbose {
|
||||
ls.Verbose = true
|
||||
}
|
||||
ls.File = stringutil.Coalesce(opts.logFile, ls.File)
|
||||
|
||||
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if
|
||||
// nothing else is configured. Otherwise, we'll lose the log output.
|
||||
ls.File = configSyslog
|
||||
}
|
||||
|
||||
return ls
|
||||
}
|
||||
|
||||
// readLogSettings reads logging settings from the config file. We do it in a
|
||||
// separate method in order to configure logger before the actual configuration
|
||||
// is parsed and applied.
|
||||
func readLogSettings() (ls *logSettings) {
|
||||
conf := &configuration{}
|
||||
|
||||
yamlFile, err := readConfigFile()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(yamlFile, conf)
|
||||
if err != nil {
|
||||
log.Error("Couldn't get logging settings from the configuration: %s", err)
|
||||
}
|
||||
|
||||
return &conf.Log
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// handlePutProfile is the handler for PUT /control/profile/update endpoint.
|
||||
|
||||
@@ -34,6 +34,7 @@ const (
|
||||
type program struct {
|
||||
clientBuildFS fs.FS
|
||||
signals chan os.Signal
|
||||
done chan struct{}
|
||||
opts options
|
||||
}
|
||||
|
||||
@@ -46,19 +47,19 @@ func (p *program) Start(_ service.Service) (err error) {
|
||||
args := p.opts
|
||||
args.runningAsService = true
|
||||
|
||||
go run(args, p.clientBuildFS)
|
||||
go run(args, p.clientBuildFS, p.done)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop implements service.Interface interface for *program.
|
||||
func (p *program) Stop(_ service.Service) (err error) {
|
||||
select {
|
||||
case p.signals <- syscall.SIGINT:
|
||||
// Go on.
|
||||
default:
|
||||
// Stop should not block.
|
||||
}
|
||||
log.Info("service: stopping: waiting for cleanup")
|
||||
|
||||
aghos.SendShutdownSignal(p.signals)
|
||||
|
||||
// Wait for other goroutines to complete their job.
|
||||
<-p.done
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -198,7 +199,12 @@ func restartService() (err error) {
|
||||
// - run: This is a special command that is not supposed to be used directly
|
||||
// it is specified when we register a service, and it indicates to the app
|
||||
// that it is being run as a service/daemon.
|
||||
func handleServiceControlAction(opts options, clientBuildFS fs.FS, signals chan os.Signal) {
|
||||
func handleServiceControlAction(
|
||||
opts options,
|
||||
clientBuildFS fs.FS,
|
||||
signals chan os.Signal,
|
||||
done chan struct{},
|
||||
) {
|
||||
// Call chooseSystem explicitly to introduce OpenBSD support for service
|
||||
// package. It's a noop for other GOOS values.
|
||||
chooseSystem()
|
||||
@@ -233,6 +239,7 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS, signals chan
|
||||
s, err := service.New(&program{
|
||||
clientBuildFS: clientBuildFS,
|
||||
signals: signals,
|
||||
done: done,
|
||||
opts: runOpts,
|
||||
}, svcConfig)
|
||||
if err != nil {
|
||||
|
||||
@@ -770,7 +770,7 @@ func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
|
||||
data.PrivateKey = ""
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
}
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -312,8 +312,10 @@ func (web *webAPI) mustStartHTTP3(address string) {
|
||||
}
|
||||
}
|
||||
|
||||
// startPprof launches the debug and profiling server on addr.
|
||||
func startPprof(addr string) {
|
||||
// startPprof launches the debug and profiling server on the provided port.
|
||||
func startPprof(port uint16) {
|
||||
addr := netip.AddrPortFrom(netutil.IPv4Localhost(), port)
|
||||
|
||||
runtime.SetBlockProfileRate(1)
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
|
||||
@@ -324,7 +326,7 @@ func startPprof(addr string) {
|
||||
defer log.OnPanic("pprof server")
|
||||
|
||||
log.Info("pprof: listening on %q", addr)
|
||||
err := http.ListenAndServe(addr, mux)
|
||||
err := http.ListenAndServe(addr.String(), mux)
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Error("pprof: shutting down: %s", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user