all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-09-07 17:13:48 +03:00
parent 3be7676970
commit 7b93f5d7cf
306 changed files with 19770 additions and 4916 deletions

View File

@@ -1,10 +1,10 @@
package home
import (
"encoding"
"fmt"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
@@ -83,50 +83,6 @@ func (c *Client) setSafeSearch(
return nil
}
// clientSource represents the source from which the information about the
// client has been obtained.
type clientSource uint
// Clients information sources. The order determines the priority.
const (
ClientSourceNone clientSource = iota
ClientSourceWHOIS
ClientSourceARP
ClientSourceRDNS
ClientSourceDHCP
ClientSourceHostsFile
ClientSourcePersistent
)
// type check
var _ fmt.Stringer = clientSource(0)
// String returns a human-readable name of cs.
func (cs clientSource) String() (s string) {
switch cs {
case ClientSourceWHOIS:
return "WHOIS"
case ClientSourceARP:
return "ARP"
case ClientSourceRDNS:
return "rDNS"
case ClientSourceDHCP:
return "DHCP"
case ClientSourceHostsFile:
return "etc/hosts"
default:
return ""
}
}
// type check
var _ encoding.TextMarshaler = clientSource(0)
// MarshalText implements encoding.TextMarshaler for the clientSource.
func (cs clientSource) MarshalText() (text []byte, err error) {
return []byte(cs.String()), nil
}
// RuntimeClient is a client information about which has been obtained using the
// source described in the Source field.
type RuntimeClient struct {
@@ -138,5 +94,5 @@ type RuntimeClient struct {
// Source is the source from which the information about the client has
// been obtained.
Source clientSource
Source client.Source
}

View File

@@ -10,8 +10,8 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -34,7 +34,7 @@ type DHCP interface {
// HostByIP returns the hostname of the DHCP client with the given IP
// address. The address will be netip.Addr{} if there is no such client,
// due to an assumption that a DHCP client must always have an IP address.
// due to an assumption that a DHCP client must always have a hostname.
HostByIP(ip netip.Addr) (host string)
// MACByIP returns the MAC address for the given IP address leased. It
@@ -55,8 +55,8 @@ type clientsContainer struct {
allTags *stringutil.Set
// dhcpServer is used for looking up clients IP addresses by MAC addresses
dhcpServer dhcpd.Interface
// dhcp is the DHCP service implementation.
dhcp DHCP
// dnsServer is used for checking clients IP status access list status
dnsServer *dnsforward.Server
@@ -65,8 +65,8 @@ type clientsContainer struct {
// hosts database.
etcHosts *aghnet.HostsContainer
// arpdb stores the neighbors retrieved from ARP.
arpdb aghnet.ARPDB
// arpDB stores the neighbors retrieved from ARP.
arpDB arpdb.Interface
// lock protects all fields.
//
@@ -93,9 +93,9 @@ type clientsContainer struct {
// Note: this function must be called only once
func (clients *clientsContainer) Init(
objects []*clientObject,
dhcpServer dhcpd.Interface,
dhcpServer DHCP,
etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB,
arpDB arpdb.Interface,
filteringConf *filtering.Config,
) (err error) {
if clients.list != nil {
@@ -108,9 +108,11 @@ func (clients *clientsContainer) Init(
clients.allTags = stringutil.NewSet(clientTags...)
clients.dhcpServer = dhcpServer
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
clients.dhcp = dhcpServer
clients.etcHosts = etcHosts
clients.arpdb = arpdb
clients.arpDB = arpDB
err = clients.addFromConfig(objects, filteringConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
@@ -124,11 +126,6 @@ func (clients *clientsContainer) Init(
return nil
}
if clients.dhcpServer != nil {
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
clients.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
}
if clients.etcHosts != nil {
go clients.handleHostsUpdates()
}
@@ -164,7 +161,7 @@ func (clients *clientsContainer) Start() {
// reloadARP reloads runtime clients from ARP, if configured.
func (clients *clientsContainer) reloadARP() {
if clients.arpdb != nil {
if clients.arpDB != nil {
clients.addFromSystemARP()
}
}
@@ -290,8 +287,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
// above loop can generate different orderings when writing to the config
// file: this produces lots of diffs in config files, so sort objects by
// name before writing.
slices.SortStableFunc(objs, func(a, b *clientObject) (sortsBefore bool) {
return a.Name < b.Name
slices.SortStableFunc(objs, func(a, b *clientObject) (res int) {
return strings.Compare(a.Name, b.Name)
})
return objs
@@ -309,58 +306,28 @@ func (clients *clientsContainer) periodicUpdate() {
}
}
// onDHCPLeaseChanged is a callback for the DHCP server. It updates the list of
// runtime clients using the DHCP server's leases.
//
// TODO(e.burkov): Remove when switched to dhcpsvc.
func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(ClientSourceDHCP)
if flags == dhcpd.LeaseChangedRemovedAll {
return
}
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
n := 0
for _, l := range leases {
if l.Hostname == "" {
continue
}
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
if ok {
n++
}
}
log.Debug("clients: added %d client aliases from dhcp", n)
}
// clientSource checks if client with this IP address already exists and returns
// the source which updated it last. It returns [ClientSourceNone] if the
// the source which updated it last. It returns [client.SourceNone] if the
// client doesn't exist.
func (clients *clientsContainer) clientSource(ip netip.Addr) (src clientSource) {
func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findLocked(ip.String())
if ok {
return ClientSourcePersistent
return client.SourcePersistent
}
rc, ok := clients.ipToRC[ip]
if ok {
return rc.Source
src = rc.Source
}
return ClientSourceNone
if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" {
src = client.SourceDHCP
}
return src
}
// findMultiple is a wrapper around Find to make it a valid client finder for
@@ -521,17 +488,14 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
}
}
if clients.dhcpServer != nil {
return clients.findDHCP(ip)
}
return nil, false
// TODO(e.burkov): Iterate through clients.list only once.
return clients.findDHCP(ip)
}
// findDHCP searches for a client by its MAC, if the DHCP server is active and
// there is such client. clients.lock is expected to be locked.
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) {
foundMAC := clients.dhcpServer.FindMACbyIP(ip)
foundMAC := clients.dhcp.MACByIP(ip)
if foundMAC == nil {
return nil, false
}
@@ -552,8 +516,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) {
return nil, false
}
// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
// runtimeClient returns a runtime client from internal index. Note that it
// doesn't include DHCP clients.
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
if ip == (netip.Addr{}) {
return nil, false
}
@@ -566,6 +531,24 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeCl
return rc, ok
}
// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
if rc, ok = clients.runtimeClient(ip); ok && rc.Source > client.SourceDHCP {
return rc, ok
}
host := clients.dhcp.HostByIP(ip)
if host == "" {
return rc, ok
}
return &RuntimeClient{
Host: host,
Source: client.SourceDHCP,
WHOIS: &whois.Info{},
}, true
}
// check validates the client.
func (clients *clientsContainer) check(c *Client) (err error) {
switch {
@@ -761,7 +744,7 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = &RuntimeClient{
Source: ClientSourceWHOIS,
Source: client.SourceWHOIS,
}
clients.ipToRC[ip] = rc
@@ -780,7 +763,7 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
func (clients *clientsContainer) addHost(
ip netip.Addr,
host string,
src clientSource,
src client.Source,
) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -803,7 +786,7 @@ func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info
defer clients.lock.Unlock()
if host != "" {
ok := clients.addHostLocked(ip, host, ClientSourceRDNS)
ok := clients.addHostLocked(ip, host, client.SourceRDNS)
if !ok {
log.Debug("clients: host for client %q already set with higher priority source", ip)
}
@@ -819,14 +802,19 @@ func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info
func (clients *clientsContainer) addHostLocked(
ip netip.Addr,
host string,
src clientSource,
src client.Source,
) (ok bool) {
rc, ok := clients.ipToRC[ip]
if !ok {
if src < client.SourceDHCP {
if clients.dhcp.HostByIP(ip) != "" {
return false
}
}
rc = &RuntimeClient{
WHOIS: &whois.Info{},
}
clients.ipToRC[ip] = rc
} else if src < rc.Source {
return false
@@ -841,7 +829,7 @@ func (clients *clientsContainer) addHostLocked(
}
// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
n := 0
for ip, rc := range clients.ipToRC {
if rc.Source == src {
@@ -855,15 +843,19 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (clients *clientsContainer) addFromHostsFile(hosts aghnet.HostsRecords) {
func (clients *clientsContainer) addFromHostsFile(hosts aghnet.Hosts) {
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(ClientSourceHostsFile)
clients.rmHostsBySrc(client.SourceHostsFile)
n := 0
for ip, rec := range hosts {
clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile)
for addr, rec := range hosts {
// Only the first name of the first record is considered a canonical
// hostname for the IP address.
//
// TODO(e.burkov): Consider using all the names from all the records.
clients.addHostLocked(addr, rec[0].Names[0], client.SourceHostsFile)
n++
}
@@ -873,15 +865,15 @@ func (clients *clientsContainer) addFromHostsFile(hosts aghnet.HostsRecords) {
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (clients *clientsContainer) addFromSystemARP() {
if err := clients.arpdb.Refresh(); err != nil {
if err := clients.arpDB.Refresh(); err != nil {
log.Error("refreshing arp container: %s", err)
clients.arpdb = aghnet.EmptyARPDB{}
clients.arpDB = arpdb.Empty{}
return
}
ns := clients.arpdb.Neighbors()
ns := clients.arpDB.Neighbors()
if len(ns) == 0 {
log.Debug("refreshing arp container: the update is empty")
@@ -891,11 +883,11 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(ClientSourceARP)
clients.rmHostsBySrc(client.SourceARP)
added := 0
for _, n := range ns {
if clients.addHostLocked(n.IP, n.Name, ClientSourceARP) {
if clients.addHostLocked(n.IP, n.Name, client.SourceARP) {
added++
}
}
@@ -907,7 +899,9 @@ func (clients *clientsContainer) addFromSystemARP() {
// the persistent clients.
func (clients *clientsContainer) close() (err error) {
persistent := maps.Values(clients.list)
slices.SortFunc(persistent, func(a, b *Client) (less bool) { return a.Name < b.Name })
slices.SortFunc(persistent, func(a, b *Client) (res int) {
return strings.Compare(a.Name, b.Name)
})
var errs []error
@@ -917,9 +911,5 @@ func (clients *clientsContainer) close() (err error) {
}
}
if len(errs) > 0 {
return errors.List("closing client specific upstreams", errs...)
}
return nil
return errors.Join(errs...)
}

View File

@@ -7,22 +7,46 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type testDHCP struct {
OnLeases func() (leases []*dhcpsvc.Lease)
OnHostBy func(ip netip.Addr) (host string)
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
}
// Lease implements the [DHCP] interface for testDHCP.
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
// HostByIP implements the [DHCP] interface for testDHCP.
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
// MACByIP implements the [DHCP] interface for testDHCP.
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
// newClientsContainer is a helper that creates a new clients container for
// tests.
func newClientsContainer(t *testing.T) (c *clientsContainer) {
t.Helper()
c = &clientsContainer{
testing: true,
}
err := c.Init(nil, nil, nil, nil, &filtering.Config{})
require.NoError(t, err)
dhcp := &testDHCP{
OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") },
OnHostBy: func(ip netip.Addr) (host string) { return "" },
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
}
require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{}))
return c
}
@@ -76,9 +100,9 @@ func TestClients(t *testing.T) {
assert.Equal(t, "client2", c.Name)
assert.Equal(t, clients.clientSource(cliNoneIP), ClientSourceNone)
assert.Equal(t, clients.clientSource(cli1IP), ClientSourcePersistent)
assert.Equal(t, clients.clientSource(cli2IP), ClientSourcePersistent)
assert.Equal(t, clients.clientSource(cliNoneIP), client.SourceNone)
assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent)
assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent)
})
t.Run("add_fail_name", func(t *testing.T) {
@@ -125,8 +149,8 @@ func TestClients(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, clients.clientSource(cliOldIP), ClientSourceNone)
assert.Equal(t, clients.clientSource(cliNewIP), ClientSourcePersistent)
assert.Equal(t, clients.clientSource(cliOldIP), client.SourceNone)
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
prev, ok = clients.list["client1"]
require.True(t, ok)
@@ -158,7 +182,7 @@ func TestClients(t *testing.T) {
ok := clients.Del("client1-renamed")
require.True(t, ok)
assert.Equal(t, clients.clientSource(netip.MustParseAddr("1.1.1.2")), ClientSourceNone)
assert.Equal(t, clients.clientSource(netip.MustParseAddr("1.1.1.2")), client.SourceNone)
})
t.Run("del_fail", func(t *testing.T) {
@@ -168,32 +192,32 @@ func TestClients(t *testing.T) {
t.Run("addhost_success", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host", ClientSourceARP)
ok := clients.addHost(ip, "host", client.SourceARP)
assert.True(t, ok)
ok = clients.addHost(ip, "host2", ClientSourceARP)
ok = clients.addHost(ip, "host2", client.SourceARP)
assert.True(t, ok)
ok = clients.addHost(ip, "host3", ClientSourceHostsFile)
ok = clients.addHost(ip, "host3", client.SourceHostsFile)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), ClientSourceHostsFile)
assert.Equal(t, clients.clientSource(ip), client.SourceHostsFile)
})
t.Run("dhcp_replaces_arp", func(t *testing.T) {
ip := netip.MustParseAddr("1.2.3.4")
ok := clients.addHost(ip, "from_arp", ClientSourceARP)
ok := clients.addHost(ip, "from_arp", client.SourceARP)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), ClientSourceARP)
assert.Equal(t, clients.clientSource(ip), client.SourceARP)
ok = clients.addHost(ip, "from_dhcp", ClientSourceDHCP)
ok = clients.addHost(ip, "from_dhcp", client.SourceDHCP)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), ClientSourceDHCP)
assert.Equal(t, clients.clientSource(ip), client.SourceDHCP)
})
t.Run("addhost_fail", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host1", ClientSourceRDNS)
ok := clients.addHost(ip, "host1", client.SourceRDNS)
assert.False(t, ok)
})
}
@@ -216,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("existing_auto-client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host", ClientSourceRDNS)
ok := clients.addHost(ip, "host", client.SourceRDNS)
assert.True(t, ok)
clients.setWHOISInfo(ip, whois)
@@ -259,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) {
assert.True(t, ok)
// Now add an auto-client with the same IP.
ok = clients.addHost(ip, "test", ClientSourceRDNS)
ok = clients.addHost(ip, "test", client.SourceRDNS)
assert.True(t, ok)
})
@@ -288,7 +312,7 @@ func TestClientsAddExisting(t *testing.T) {
dhcpServer, err := dhcpd.Create(config)
require.NoError(t, err)
clients.dhcpServer = dhcpServer
clients.dhcp = dhcpServer
err = dhcpServer.AddStaticLease(&dhcpd.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},

View File

@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
@@ -34,8 +35,12 @@ type clientJSON struct {
WHOIS *whois.Info `json:"whois_info,omitempty"`
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
// Schedule is blocked services schedule for every day of the week.
Schedule *schedule.Weekly `json:"blocked_services_schedule"`
Name string `json:"name"`
// BlockedServices is the names of blocked services.
BlockedServices []string `json:"blocked_services"`
IDs []string `json:"ids"`
Tags []string `json:"tags"`
@@ -53,12 +58,40 @@ type clientJSON struct {
IgnoreStatistics aghalg.NullBool `json:"ignore_statistics"`
}
// copySettings returns a copy of specific settings from JSON or a previous
// client.
func (j *clientJSON) copySettings(
prev *Client,
) (weekly *schedule.Weekly, ignoreQueryLog, ignoreStatistics bool) {
if j.Schedule != nil {
weekly = j.Schedule.Clone()
} else if prev != nil && prev.BlockedServices != nil {
weekly = prev.BlockedServices.Schedule.Clone()
} else {
weekly = schedule.EmptyWeekly()
}
if j.IgnoreQueryLog != aghalg.NBNull {
ignoreQueryLog = j.IgnoreQueryLog == aghalg.NBTrue
} else if prev != nil {
ignoreQueryLog = prev.IgnoreQueryLog
}
if j.IgnoreStatistics != aghalg.NBNull {
ignoreStatistics = j.IgnoreStatistics == aghalg.NBTrue
} else if prev != nil {
ignoreStatistics = prev.IgnoreStatistics
}
return weekly, ignoreQueryLog, ignoreStatistics
}
type runtimeClientJSON struct {
WHOIS *whois.Info `json:"whois_info"`
IP netip.Addr `json:"ip"`
Name string `json:"name"`
Source clientSource `json:"source"`
IP netip.Addr `json:"ip"`
Name string `json:"name"`
Source client.Source `json:"source"`
}
type clientListJSON struct {
@@ -91,9 +124,20 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
data.RuntimeClients = append(data.RuntimeClients, cj)
}
for _, l := range clients.dhcp.Leases() {
cj := runtimeClientJSON{
Name: l.Hostname,
Source: client.SourceDHCP,
IP: l.IP,
WHOIS: &whois.Info{},
}
data.RuntimeClients = append(data.RuntimeClients, cj)
}
data.Tags = clientTags
_ = aghhttp.WriteJSONResponse(w, r, data)
aghhttp.WriteJSONResponseOK(w, r, data)
}
// jsonToClient converts JSON object to Client object.
@@ -119,9 +163,15 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
}
}
weekly := schedule.EmptyWeekly()
if prev != nil {
weekly = prev.BlockedServices.Schedule.Clone()
weekly, ignoreQueryLog, ignoreStatistics := cj.copySettings(prev)
bs := &filtering.BlockedServices{
Schedule: weekly,
IDs: cj.BlockedServices,
}
err = bs.Validate()
if err != nil {
return nil, fmt.Errorf("validating blocked services: %w", err)
}
c = &Client{
@@ -129,10 +179,7 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
Name: cj.Name,
BlockedServices: &filtering.BlockedServices{
Schedule: weekly,
IDs: cj.BlockedServices,
},
BlockedServices: bs,
IDs: cj.IDs,
Tags: cj.Tags,
@@ -143,18 +190,8 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
ParentalEnabled: cj.ParentalEnabled,
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
}
if cj.IgnoreQueryLog != aghalg.NBNull {
c.IgnoreQueryLog = cj.IgnoreQueryLog == aghalg.NBTrue
} else if prev != nil {
c.IgnoreQueryLog = prev.IgnoreQueryLog
}
if cj.IgnoreStatistics != aghalg.NBNull {
c.IgnoreStatistics = cj.IgnoreStatistics == aghalg.NBTrue
} else if prev != nil {
c.IgnoreStatistics = prev.IgnoreStatistics
IgnoreQueryLog: ignoreQueryLog,
IgnoreStatistics: ignoreStatistics,
}
if safeSearchConf.Enabled {
@@ -191,6 +228,7 @@ func clientToJSON(c *Client) (cj *clientJSON) {
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
Schedule: c.BlockedServices.Schedule,
BlockedServices: c.BlockedServices.IDs,
Upstreams: c.Upstreams,
@@ -338,7 +376,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
})
}
_ = aghhttp.WriteJSONResponse(w, r, data)
aghhttp.WriteJSONResponseOK(w, r, data)
}
// findRuntime looks up the IP in runtime and temporary storages, like

View File

@@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/confmigrate"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -21,7 +22,6 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/v2/maybe"
"golang.org/x/exp/slices"
yaml "gopkg.in/yaml.v3"
)
@@ -114,8 +114,6 @@ type configuration struct {
Language string `yaml:"language"`
// Theme is a UI theme for current user.
Theme Theme `yaml:"theme"`
// DebugPProf defines if the profiling HTTP handler will listen on :6060.
DebugPProf bool `yaml:"debug_pprof"`
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
@@ -133,7 +131,8 @@ type configuration struct {
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
Filtering *filtering.Config `yaml:"filtering"`
// Clients contains the YAML representations of the persistent clients.
// This field is only used for reading and writing persistent client data.
@@ -147,7 +146,9 @@ type configuration struct {
sync.RWMutex `yaml:"-"`
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
// SchemaVersion is the version of the configuration schema. See
// [confmigrate.LastSchemaVersion].
SchemaVersion uint `yaml:"schema_version"`
}
// httpConfig is a block with HTTP configuration params.
@@ -155,6 +156,9 @@ type configuration struct {
// Field ordering is important, YAML fields better not to be reordered, if it's
// not absolutely necessary.
type httpConfig struct {
// Pprof defines the profiling HTTP handler.
Pprof *httpPprofConfig `yaml:"pprof"`
// Address is the address to serve the web UI on.
Address netip.AddrPort
@@ -163,6 +167,15 @@ type httpConfig struct {
SessionTTL timeutil.Duration `yaml:"session_ttl"`
}
// httpPprofConfig is the block with pprof HTTP configuration.
type httpPprofConfig struct {
// Port for the profiling handler.
Port uint16 `yaml:"port"`
// Enabled defines if the profiling handler is enabled.
Enabled bool `yaml:"enabled"`
}
// dnsConfig is a block with DNS configuration params.
//
// Field ordering is important, YAML fields better not to be reordered, if it's
@@ -175,9 +188,10 @@ type dnsConfig struct {
// in query log and statistics.
AnonymizeClientIP bool `yaml:"anonymize_client_ip"`
dnsforward.FilteringConfig `yaml:",inline"`
DnsfilterConf *filtering.Config `yaml:",inline"`
// Config is the embed configuration with DNS params.
//
// TODO(a.garipov): Remove embed.
dnsforward.Config `yaml:",inline"`
// UpstreamTimeout is the timeout for querying upstream servers.
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
@@ -277,18 +291,19 @@ var config = &configuration{
HTTPConfig: httpConfig{
Address: netip.AddrPortFrom(netip.IPv4Unspecified(), 3000),
SessionTTL: timeutil.Duration{Duration: 30 * timeutil.Day},
Pprof: &httpPprofConfig{
Enabled: false,
Port: 6060,
},
},
DNS: dnsConfig{
BindHosts: []netip.Addr{netip.IPv4Unspecified()},
Port: defaultPortDNS,
FilteringConfig: dnsforward.FilteringConfig{
ProtectionEnabled: true, // whether or not use any of filtering features
BlockingMode: dnsforward.BlockingModeDefault,
BlockedResponseTTL: 10, // in seconds
Ratelimit: 20,
RefuseAny: true,
AllServers: false,
HandleDDR: true,
Config: dnsforward.Config{
Ratelimit: 20,
RefuseAny: true,
AllServers: false,
HandleDDR: true,
FastestTimeout: timeutil.Duration{
Duration: fastip.DefaultPingWaitTimeout,
},
@@ -308,33 +323,6 @@ var config = &configuration{
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
MaxGoroutines: 300,
},
DnsfilterConf: &filtering.Config{
FilteringEnabled: true,
FiltersUpdateIntervalHours: 24,
ParentalEnabled: false,
SafeBrowsingEnabled: false,
SafeBrowsingCacheSize: 1 * 1024 * 1024,
SafeSearchCacheSize: 1 * 1024 * 1024,
ParentalCacheSize: 1 * 1024 * 1024,
CacheTime: 30,
SafeSearchConf: filtering.SafeSearchConfig{
Enabled: false,
Bing: true,
DuckDuckGo: true,
Google: true,
Pixabay: true,
Yandex: true,
YouTube: true,
},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: []string{},
},
},
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
},
@@ -371,6 +359,37 @@ var config = &configuration{
URL: "https://adguardteam.github.io/HostlistsRegistry/assets/filter_2.txt",
Name: "AdAway Default Blocklist",
}},
Filtering: &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
BlockedResponseTTL: 10, // in seconds
FilteringEnabled: true,
FiltersUpdateIntervalHours: 24,
ParentalEnabled: false,
SafeBrowsingEnabled: false,
SafeBrowsingCacheSize: 1 * 1024 * 1024,
SafeSearchCacheSize: 1 * 1024 * 1024,
ParentalCacheSize: 1 * 1024 * 1024,
CacheTime: 30,
SafeSearchConf: filtering.SafeSearchConfig{
Enabled: false,
Bing: true,
DuckDuckGo: true,
Google: true,
Pixabay: true,
Yandex: true,
YouTube: true,
},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: []string{},
},
},
DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan",
Conf4: dhcpd.V4ServerConf{
@@ -398,7 +417,7 @@ var config = &configuration{
MaxAge: 3,
},
OSConfig: &osConfig{},
SchemaVersion: currentSchemaVersion,
SchemaVersion: confmigrate.LastSchemaVersion,
Theme: ThemeAuto,
}
@@ -414,28 +433,10 @@ func (c *configuration) getConfigFilename() string {
if !filepath.IsAbs(configFile) {
configFile = filepath.Join(Context.workDir, configFile)
}
return configFile
}
// readLogSettings reads logging settings from the config file. We do it in a
// separate method in order to configure logger before the actual configuration
// is parsed and applied.
func readLogSettings() (ls *logSettings) {
conf := &configuration{}
yamlFile, err := readConfigFile()
if err != nil {
return &logSettings{}
}
err = yaml.Unmarshal(yamlFile, conf)
if err != nil {
log.Error("Couldn't get logging settings from the configuration: %s", err)
}
return &conf.Log
}
// validateBindHosts returns error if any of binding hosts from configuration is
// not a valid IP address.
func validateBindHosts(conf *configuration) (err error) {
@@ -452,21 +453,59 @@ func validateBindHosts(conf *configuration) (err error) {
return nil
}
// parseConfig loads configuration from the YAML file
// parseConfig loads configuration from the YAML file, upgrading it if
// necessary.
func parseConfig() (err error) {
var fileData []byte
fileData, err = readConfigFile()
// Do the upgrade if necessary.
config.fileData, err = readConfigFile()
if err != nil {
return err
}
config.fileData = nil
err = yaml.Unmarshal(fileData, &config)
migrator := confmigrate.New(&confmigrate.Config{
WorkingDir: Context.workDir,
})
var upgraded bool
config.fileData, upgraded, err = migrator.Migrate(
config.fileData,
confmigrate.LastSchemaVersion,
)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
} else if upgraded {
err = maybe.WriteFile(config.getConfigFilename(), config.fileData, 0o644)
if err != nil {
return fmt.Errorf("writing new config: %w", err)
}
}
err = yaml.Unmarshal(config.fileData, &config)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = validateConfig()
if err != nil {
return err
}
if config.DNS.UpstreamTimeout.Duration == 0 {
config.DNS.UpstreamTimeout = timeutil.Duration{Duration: dnsforward.DefaultTimeout}
}
err = setContextTLSCipherIDs()
if err != nil {
return err
}
return nil
}
// validateConfig returns error if the configuration is invalid.
func validateConfig() (err error) {
err = validateBindHosts(config)
if err != nil {
// Don't wrap the error since it's informative enough as is.
@@ -498,17 +537,8 @@ func parseConfig() (err error) {
return fmt.Errorf("validating udp ports: %w", err)
}
if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) {
config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24
}
if config.DNS.UpstreamTimeout.Duration == 0 {
config.DNS.UpstreamTimeout = timeutil.Duration{Duration: dnsforward.DefaultTimeout}
}
err = setContextTLSCipherIDs()
if err != nil {
return err
if !filtering.ValidateUpdateIvl(config.Filtering.FiltersUpdateIntervalHours) {
config.Filtering.FiltersUpdateIntervalHours = 24
}
return nil
@@ -563,7 +593,6 @@ func (c *configuration) write() (err error) {
config.Stats.Interval = timeutil.Duration{Duration: statsConf.Limit}
config.Stats.Enabled = statsConf.Enabled
config.Stats.Ignored = statsConf.Ignored.Values()
slices.Sort(config.Stats.Ignored)
}
if Context.queryLog != nil {
@@ -575,21 +604,20 @@ func (c *configuration) write() (err error) {
config.QueryLog.Interval = timeutil.Duration{Duration: dc.RotationIvl}
config.QueryLog.MemSize = dc.MemSize
config.QueryLog.Ignored = dc.Ignored.Values()
slices.Sort(config.Stats.Ignored)
}
if Context.filters != nil {
Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf)
config.Filters = config.DNS.DnsfilterConf.Filters
config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters
config.UserRules = config.DNS.DnsfilterConf.UserRules
Context.filters.WriteDiskConfig(config.Filtering)
config.Filters = config.Filtering.Filters
config.WhitelistFilters = config.Filtering.WhitelistFilters
config.UserRules = config.Filtering.UserRules
}
if s := Context.dnsServer; s != nil {
c := dnsforward.FilteringConfig{}
c := dnsforward.Config{}
s.WriteDiskConfig(&c)
dns := &config.DNS
dns.FilteringConfig = c
dns.Config = c
dns.LocalPTRResolvers = s.LocalPTRResolvers()

View File

@@ -127,12 +127,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
}
var (
fltConf *dnsforward.FilteringConfig
fltConf *dnsforward.Config
protectionDisabledUntil *time.Time
protectionEnabled bool
)
if Context.dnsServer != nil {
fltConf = &dnsforward.FilteringConfig{}
fltConf = &dnsforward.Config{}
Context.dnsServer.WriteDiskConfig(fltConf)
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
}
@@ -170,7 +170,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
resp.IsDHCPAvailable = Context.dhcpServer != nil
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// ------------------------
@@ -321,9 +321,10 @@ func preInstallHandler(handler http.Handler) http.Handler {
return &preInstallHandlerStruct{handler}
}
// handleHTTPSRedirect redirects the request to HTTPS, if needed. If ok is
// true, the middleware must continue handling the request.
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
// handleHTTPSRedirect redirects the request to HTTPS, if needed, and adds some
// HTTPS-related headers. If proceed is true, the middleware must continue
// handling the request.
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
web := Context.web
if web.httpsServer.server == nil {
return true
@@ -362,21 +363,17 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
respHdr.Set(httphdr.AltSvc, altSvc)
}
if r.TLS == nil && forceHTTPS {
hostPort := host
if portHTTPS != defaultPortHTTPS {
hostPort = netutil.JoinHostPort(host, portHTTPS)
if forceHTTPS {
if r.TLS == nil {
u := httpsURL(r.URL, host, portHTTPS)
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
return false
}
httpsURL := &url.URL{
Scheme: aghhttp.SchemeHTTPS,
Host: hostPort,
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
http.Redirect(w, r, httpsURL.String(), http.StatusTemporaryRedirect)
return false
// TODO(a.garipov): Consider adding a configurable max-age. Currently,
// the default is 365 days.
respHdr.Set(httphdr.StrictTransportSecurity, aghhttp.HdrValStrictTransportSecurity)
}
// Allow the frontend from the HTTP origin to send requests to the HTTPS
@@ -395,6 +392,22 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
return true
}
// httpsURL returns a copy of u for redirection to the HTTPS version, taking the
// hostname and the HTTPS port into account.
func httpsURL(u *url.URL, host string, portHTTPS int) (redirectURL *url.URL) {
hostPort := host
if portHTTPS != defaultPortHTTPS {
hostPort = netutil.JoinHostPort(host, portHTTPS)
}
return &url.URL{
Scheme: aghhttp.SchemeHTTPS,
Host: hostPort,
Path: u.Path,
RawQuery: u.RawQuery,
}
}
// postInstall lets the handler to run only if firstRun is false. Otherwise, it
// redirects to /install.html. It also enforces HTTPS if it is enabled and
// configured and sets appropriate access control headers.
@@ -408,11 +421,10 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
return
}
if !handleHTTPSRedirect(w, r) {
return
proceed := handleHTTPSRedirect(w, r)
if proceed {
handler(w, r)
}
handler(w, r)
}
}

View File

@@ -59,7 +59,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
data.Interfaces[iface.Name] = iface
}
_ = aghhttp.WriteJSONResponse(w, r, data)
aghhttp.WriteJSONResponseOK(w, r, data)
}
type checkConfReqEnt struct {
@@ -190,7 +190,7 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// handleStaticIP - handles static IP request

View File

@@ -33,7 +33,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
resp := &versionResponse{}
if web.conf.disableUpdate {
resp.Disabled = true
_ = aghhttp.WriteJSONResponse(w, r, resp)
aghhttp.WriteJSONResponseOK(w, r, resp)
return
}
@@ -68,7 +68,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
return
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// requestVersionInfo sets the VersionInfo field of resp if it can reach the

View File

@@ -14,7 +14,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
@@ -60,12 +59,12 @@ func initDNS() (err error) {
ShouldCountClient: Context.clients.shouldCountClient,
}
set, err := aghnet.NewDomainNameSet(config.Stats.Ignored)
engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored)
if err != nil {
return fmt.Errorf("statistics: ignored list: %w", err)
}
statsConf.Ignored = set
statsConf.Ignored = engine
Context.stats, err = stats.New(statsConf)
if err != nil {
return fmt.Errorf("init stats: %w", err)
@@ -84,18 +83,18 @@ func initDNS() (err error) {
FileEnabled: config.QueryLog.FileEnabled,
}
set, err = aghnet.NewDomainNameSet(config.QueryLog.Ignored)
engine, err = aghnet.NewIgnoreEngine(config.QueryLog.Ignored)
if err != nil {
return fmt.Errorf("querylog: ignored list: %w", err)
}
conf.Ignored = set
conf.Ignored = engine
Context.queryLog, err = querylog.New(conf)
if err != nil {
return fmt.Errorf("init querylog: %w", err)
}
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
Context.filters, err = filtering.New(config.Filtering, nil)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
@@ -123,7 +122,7 @@ func initDNSServer(
filters *filtering.DNSFilter,
sts stats.Interface,
qlog querylog.QueryLog,
dhcpSrv dhcpd.Interface,
dhcpSrv dnsforward.DHCP,
anonymizer *aghnet.IPMut,
httpReg aghhttp.RegisterFunc,
tlsConf *tlsConfigSettings,
@@ -231,13 +230,13 @@ func newServerConfig(
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
newConf = &dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
FilteringConfig: dnsConf.FilteringConfig,
ConfigModified: onConfigModified,
HTTPRegister: httpReg,
UseDNS64: config.DNS.UseDNS64,
DNS64Prefixes: config.DNS.DNS64Prefixes,
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
Config: dnsConf.Config,
ConfigModified: onConfigModified,
HTTPRegister: httpReg,
UseDNS64: config.DNS.UseDNS64,
DNS64Prefixes: config.DNS.DNS64Prefixes,
}
var initialAddresses []netip.Addr
@@ -378,7 +377,7 @@ func getDNSEncryption() (de dnsEncryption) {
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filtering.Settings) {
// pref is a prefix for logging messages around the scope.
const pref = "applying filters"
@@ -386,7 +385,7 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
if clientIP == nil {
if !clientIP.IsValid() {
return
}
@@ -502,14 +501,10 @@ func closeDNSServer() {
if err != nil {
log.Debug("closing stats: %s", err)
}
// TODO(e.burkov): Find out if it's safe.
Context.stats = nil
}
if Context.queryLog != nil {
Context.queryLog.Close()
Context.queryLog = nil
}
log.Debug("all dns modules are closed")

View File

@@ -1,7 +1,7 @@
package home
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/require"
)
var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
func TestApplyAdditionalFiltering(t *testing.T) {
var err error
@@ -78,7 +80,7 @@ func TestApplyAdditionalFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
setts := &filtering.Settings{}
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
applyAdditionalFiltering(testIPv4, tc.id, setts)
tc.FilteringEnabled(t, setts.FilteringEnabled)
tc.SafeSearchEnabled(t, setts.SafeSearchEnabled)
tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled)
@@ -169,7 +171,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
setts := &filtering.Settings{}
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
applyAdditionalFiltering(testIPv4, tc.id, setts)
require.Len(t, setts.ServicesRules, tc.wantLen)
})
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -37,12 +38,6 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices"
"gopkg.in/natefinch/lumberjack.v2"
)
const (
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
configSyslog = "syslog"
)
// Global context
@@ -104,8 +99,11 @@ func Main(clientBuildFS fs.FS) {
// package flag.
opts := loadCmdLineOpts()
done := make(chan struct{})
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
for {
sig := <-signals
@@ -117,19 +115,19 @@ func Main(clientBuildFS fs.FS) {
default:
cleanup(context.Background())
cleanupAlways()
os.Exit(0)
close(done)
}
}
}()
if opts.serviceControlAction != "" {
handleServiceControlAction(opts, clientBuildFS, signals)
handleServiceControlAction(opts, clientBuildFS, signals, done)
return
}
// run the protection
run(opts, clientBuildFS)
run(opts, clientBuildFS, done)
}
// setupContext initializes [Context] fields. It also reads and upgrades
@@ -147,14 +145,8 @@ func setupContext(opts options) (err error) {
return nil
}
// Do the upgrade if necessary.
err = upgradeConfig()
err = parseConfig()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
if err = parseConfig(); err != nil {
log.Error("parsing configuration file: %s", err)
os.Exit(1)
@@ -239,7 +231,6 @@ func setupHostsContainer() (err error) {
}
Context.etcHosts, err = aghnet.NewHostsContainer(
filtering.SysHostsListID,
aghos.RootDirFS(),
hostsWatcher,
aghnet.DefaultHostsPaths()...,
@@ -275,7 +266,7 @@ func setupOpts(opts options) (err error) {
// initContextClients initializes Context clients and related fields.
func initContextClients() (err error) {
err = setupDNSFilteringConf(config.DNS.DnsfilterConf)
err = setupDNSFilteringConf(config.Filtering)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
@@ -296,17 +287,17 @@ func initContextClients() (err error) {
return fmt.Errorf("initing dhcp: %w", err)
}
var arpdb aghnet.ARPDB
var arpDB arpdb.Interface
if config.Clients.Sources.ARP {
arpdb = aghnet.NewARPDB()
arpDB = arpdb.New()
}
err = Context.clients.Init(
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
arpdb,
config.DNS.DnsfilterConf,
arpDB,
config.Filtering,
)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
@@ -368,6 +359,9 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
pcService = "parental control"
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
pcTXTSuffix = `pc.dns.adguard.com.`
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
defaultParentalBlockHost = "family-block.dns.adguard.com"
)
conf.EtcHosts = Context.etcHosts
@@ -404,6 +398,10 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
CacheSize: conf.SafeBrowsingCacheSize,
})
if conf.SafeBrowsingBlockHost != "" {
conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
}
parUps, err := upstream.AddressToUpstream(defaultParentalServer, upsOpts)
if err != nil {
return fmt.Errorf("converting parental server: %w", err)
@@ -417,6 +415,10 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
CacheSize: conf.ParentalCacheSize,
})
if conf.ParentalBlockHost != "" {
conf.ParentalBlockHost = defaultParentalBlockHost
}
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
conf.SafeSearch, err = safesearch.NewDefault(
conf.SafeSearchConf,
@@ -510,7 +512,7 @@ func fatalOnError(err error) {
}
// run configures and starts AdGuard Home.
func run(opts options, clientBuildFS fs.FS) {
func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
// Configure config filename.
initConfigFilename(opts)
@@ -547,7 +549,7 @@ func run(opts options, clientBuildFS fs.FS) {
fatalOnError(err)
upd := updater.NewUpdater(&updater.Config{
Client: config.DNS.DnsfilterConf.HTTPClient,
Client: config.Filtering.HTTPClient,
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
@@ -567,9 +569,8 @@ func run(opts options, clientBuildFS fs.FS) {
err = config.write()
fatalOnError(err)
if config.DebugPProf {
// TODO(a.garipov): Make the address configurable.
startPprof("localhost:6060")
if config.HTTPConfig.Pprof.Enabled {
startPprof(config.HTTPConfig.Pprof.Port)
}
}
@@ -616,8 +617,8 @@ func run(opts options, clientBuildFS fs.FS) {
Context.web.start()
// Wait indefinitely for other goroutines to complete their job.
select {}
// Wait for other goroutines to complete their job.
<-done
}
// initUsers initializes context auth module. Clears config users field.
@@ -748,79 +749,6 @@ func initWorkingDir(opts options) (err error) {
return nil
}
// configureLogger configures logger level and output.
func configureLogger(opts options) (err error) {
ls := getLogSettings(opts)
// Configure logger level.
if ls.Verbose {
log.SetLevel(log.DEBUG)
}
// Make sure that we see the microseconds in logs, as networking stuff can
// happen pretty quickly.
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
// Write logs to stdout by default.
if ls.File == "" {
return nil
}
if ls.File == configSyslog {
// Use syslog where it is possible and eventlog on Windows.
err = aghos.ConfigureSyslog(serviceName)
if err != nil {
return fmt.Errorf("cannot initialize syslog: %w", err)
}
return nil
}
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
}
log.SetOutput(&lumberjack.Logger{
Filename: logFilePath,
Compress: ls.Compress,
LocalTime: ls.LocalTime,
MaxBackups: ls.MaxBackups,
MaxSize: ls.MaxSize,
MaxAge: ls.MaxAge,
})
return nil
}
// getLogSettings returns a log settings object properly initialized from opts.
func getLogSettings(opts options) (ls *logSettings) {
ls = readLogSettings()
configLogSettings := config.Log
// Command-line arguments can override config settings.
if opts.verbose || configLogSettings.Verbose {
ls.Verbose = true
}
ls.File = stringutil.Coalesce(opts.logFile, configLogSettings.File, ls.File)
// Handle default log settings overrides.
ls.Compress = configLogSettings.Compress
ls.LocalTime = configLogSettings.LocalTime
ls.MaxBackups = configLogSettings.MaxBackups
ls.MaxSize = configLogSettings.MaxSize
ls.MaxAge = configLogSettings.MaxAge
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if
// nothing else is configured. Otherwise, we'll lose the log output.
ls.File = configSyslog
}
return ls
}
// cleanup stops and resets all the modules.
func cleanup(ctx context.Context) {
log.Info("stopping AdGuard Home")

View File

@@ -7,6 +7,6 @@ import (
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
initCmdLineOpts()
testutil.DiscardLogOutput(m)
}

View File

@@ -58,7 +58,7 @@ type languageJSON struct {
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
log.Printf("home: language is %s", config.Language)
_ = aghhttp.WriteJSONResponse(w, r, &languageJSON{
aghhttp.WriteJSONResponseOK(w, r, &languageJSON{
Language: config.Language,
})
}

106
internal/home/log.go Normal file
View File

@@ -0,0 +1,106 @@
package home
import (
"fmt"
"path/filepath"
"runtime"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/yaml.v3"
)
// configSyslog is used to indicate that syslog or eventlog (win) should be used
// for logger output.
const configSyslog = "syslog"
// configureLogger configures logger level and output.
func configureLogger(opts options) (err error) {
ls := getLogSettings(opts)
// Configure logger level.
if ls.Verbose {
log.SetLevel(log.DEBUG)
}
// Make sure that we see the microseconds in logs, as networking stuff can
// happen pretty quickly.
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
// Write logs to stdout by default.
if ls.File == "" {
return nil
}
if ls.File == configSyslog {
// Use syslog where it is possible and eventlog on Windows.
err = aghos.ConfigureSyslog(serviceName)
if err != nil {
return fmt.Errorf("cannot initialize syslog: %w", err)
}
return nil
}
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
}
log.SetOutput(&lumberjack.Logger{
Filename: logFilePath,
Compress: ls.Compress,
LocalTime: ls.LocalTime,
MaxBackups: ls.MaxBackups,
MaxSize: ls.MaxSize,
MaxAge: ls.MaxAge,
})
return nil
}
// getLogSettings returns a log settings object properly initialized from opts.
func getLogSettings(opts options) (ls *logSettings) {
configLogSettings := config.Log
ls = readLogSettings()
if ls == nil {
// Use default log settings.
ls = &configLogSettings
}
// Command-line arguments can override config settings.
if opts.verbose {
ls.Verbose = true
}
ls.File = stringutil.Coalesce(opts.logFile, ls.File)
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if
// nothing else is configured. Otherwise, we'll lose the log output.
ls.File = configSyslog
}
return ls
}
// readLogSettings reads logging settings from the config file. We do it in a
// separate method in order to configure logger before the actual configuration
// is parsed and applied.
func readLogSettings() (ls *logSettings) {
conf := &configuration{}
yamlFile, err := readConfigFile()
if err != nil {
return nil
}
err = yaml.Unmarshal(yamlFile, conf)
if err != nil {
log.Error("Couldn't get logging settings from the configuration: %s", err)
}
return &conf.Log
}

View File

@@ -61,7 +61,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
}
}()
_ = aghhttp.WriteJSONResponse(w, r, resp)
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// handlePutProfile is the handler for PUT /control/profile/update endpoint.

View File

@@ -34,6 +34,7 @@ const (
type program struct {
clientBuildFS fs.FS
signals chan os.Signal
done chan struct{}
opts options
}
@@ -46,19 +47,19 @@ func (p *program) Start(_ service.Service) (err error) {
args := p.opts
args.runningAsService = true
go run(args, p.clientBuildFS)
go run(args, p.clientBuildFS, p.done)
return nil
}
// Stop implements service.Interface interface for *program.
func (p *program) Stop(_ service.Service) (err error) {
select {
case p.signals <- syscall.SIGINT:
// Go on.
default:
// Stop should not block.
}
log.Info("service: stopping: waiting for cleanup")
aghos.SendShutdownSignal(p.signals)
// Wait for other goroutines to complete their job.
<-p.done
return nil
}
@@ -198,7 +199,12 @@ func restartService() (err error) {
// - run: This is a special command that is not supposed to be used directly
// it is specified when we register a service, and it indicates to the app
// that it is being run as a service/daemon.
func handleServiceControlAction(opts options, clientBuildFS fs.FS, signals chan os.Signal) {
func handleServiceControlAction(
opts options,
clientBuildFS fs.FS,
signals chan os.Signal,
done chan struct{},
) {
// Call chooseSystem explicitly to introduce OpenBSD support for service
// package. It's a noop for other GOOS values.
chooseSystem()
@@ -233,6 +239,7 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS, signals chan
s, err := service.New(&program{
clientBuildFS: clientBuildFS,
signals: signals,
done: done,
opts: runOpts,
}, svcConfig)
if err != nil {

View File

@@ -770,7 +770,7 @@ func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
data.PrivateKey = ""
}
_ = aghhttp.WriteJSONResponse(w, r, data)
aghhttp.WriteJSONResponseOK(w, r, data)
}
// registerWebHandlers registers HTTP handlers for TLS configuration.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -312,8 +312,10 @@ func (web *webAPI) mustStartHTTP3(address string) {
}
}
// startPprof launches the debug and profiling server on addr.
func startPprof(addr string) {
// startPprof launches the debug and profiling server on the provided port.
func startPprof(port uint16) {
addr := netip.AddrPortFrom(netutil.IPv4Localhost(), port)
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(1)
@@ -324,7 +326,7 @@ func startPprof(addr string) {
defer log.OnPanic("pprof server")
log.Info("pprof: listening on %q", addr)
err := http.ListenAndServe(addr, mux)
err := http.ListenAndServe(addr.String(), mux)
if !errors.Is(err, http.ErrServerClosed) {
log.Error("pprof: shutting down: %s", err)
}