all: sync with master, upd chlog

This commit is contained in:
Eugene Burkov
2025-03-11 13:36:04 +03:00
parent 805de59805
commit 474cba52f0
166 changed files with 8809 additions and 10440 deletions

View File

@@ -0,0 +1,24 @@
package aghnet
import "github.com/AdguardTeam/dnsproxy/upstream"
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
if !http3 {
return upstream.DefaultHTTPVersions
}
return []upstream.HTTPVersion{
upstream.HTTPVersion3,
upstream.HTTPVersion2,
upstream.HTTPVersion11,
}
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}

View File

@@ -0,0 +1,26 @@
package aghnet_test
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/stretchr/testify/assert"
)
func TestIsCommentOrEmpty(t *testing.T) {
for _, tc := range []struct {
want assert.BoolAssertionFunc
str string
}{{
want: assert.True,
str: "",
}, {
want: assert.True,
str: "# comment",
}, {
want: assert.False,
str: "1.2.3.4",
}} {
tc.want(t, aghnet.IsCommentOrEmpty(tc.str))
}
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
@@ -121,26 +120,6 @@ func (p *AddressUpdater) UpdateAddress(
p.OnUpdateAddress(ctx, ip, host, info)
}
// Package dnsforward
// ClientsContainer is a fake [dnsforward.ClientsContainer] implementation for
// tests.
type ClientsContainer struct {
OnUpstreamConfigByID func(
id string,
boot upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error)
}
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface
// for *ClientsContainer.
func (c *ClientsContainer) UpstreamConfigByID(
id string,
boot upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error) {
return c.OnUpstreamConfigByID(id, boot)
}
// Package filtering
// Resolver is a fake [filtering.Resolver] implementation for tests.

View File

@@ -3,7 +3,6 @@ package aghtest_test
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
)
@@ -12,9 +11,6 @@ import (
// type check
var _ filtering.Resolver = (*aghtest.Resolver)(nil)
// type check
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)
// type check
//
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.

View File

@@ -9,7 +9,6 @@ import (
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/golibs/errors"
)
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
@@ -35,7 +34,7 @@ type index struct {
// nameToUID maps client name to UID.
nameToUID map[string]UID
// clientIDToUID maps client ID to UID.
// clientIDToUID maps ClientID to UID.
clientIDToUID map[string]UID
// ipToUID maps IP address to UID.
@@ -205,19 +204,19 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
return nil, nil
}
// find finds persistent client by string representation of the client ID, IP
// find finds persistent client by string representation of the ClientID, IP
// address, or MAC.
func (ci *index) find(id string) (c *Persistent, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
c, ok = ci.findByClientID(id)
if ok {
return c, true
}
ip, err := netip.ParseAddr(id)
if err == nil {
// MAC addresses can be successfully parsed as IP addresses.
c, found = ci.findByIP(ip)
if found {
c, ok = ci.findByIP(ip)
if ok {
return c, true
}
}
@@ -230,6 +229,16 @@ func (ci *index) find(id string) (c *Persistent, ok bool) {
return nil, false
}
// findByClientID finds persistent client by ClientID.
func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) {
uid, ok := ci.clientIDToUID[clientID]
if ok {
return ci.uidToClient[uid], true
}
return nil, false
}
// findByName finds persistent client by name.
func (ci *index) findByName(name string) (c *Persistent, found bool) {
uid, found := ci.nameToUID[name]
@@ -343,18 +352,3 @@ func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) {
}
}
}
// closeUpstreams closes upstream configurations of persistent clients.
func (ci *index) closeUpstreams() (err error) {
var errs []error
ci.rangeByName(func(c *Persistent) (cont bool) {
err = c.CloseUpstreams()
if err != nil {
errs = append(errs, err)
}
return true
})
return errors.Join(errs...)
}

View File

@@ -58,12 +58,6 @@ func (uid *UID) UnmarshalText(data []byte) error {
// Persistent contains information about persistent clients.
type Persistent struct {
// UpstreamConfig is the custom upstream configuration for this client. If
// it's nil, it has not been initialized yet. If it's non-nil and empty,
// there are no valid upstreams. If it's non-nil and non-empty, these
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig
// SafeSearch handles search engine hosts rewrites.
SafeSearch filtering.SafeSearch
@@ -262,7 +256,7 @@ func ValidateClientID(id string) (err error) {
return nil
}
// IDs returns a list of client IDs containing at least one element.
// IDs returns a list of ClientIDs containing at least one element.
func (c *Persistent) IDs() (ids []string) {
ids = make([]string, 0, c.IDsLen())
@@ -281,7 +275,7 @@ func (c *Persistent) IDs() (ids []string) {
return append(ids, c.ClientIDs...)
}
// IDsLen returns a length of client ids.
// IDsLen returns a length of ClientIDs.
func (c *Persistent) IDsLen() (n int) {
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
}
@@ -312,14 +306,3 @@ func (c *Persistent) ShallowClone() (clone *Persistent) {
return clone
}
// CloseUpstreams closes the client-specific upstream config of c if any.
func (c *Persistent) CloseUpstreams() (err error) {
if c.UpstreamConfig != nil {
if err = c.UpstreamConfig.Close(); err != nil {
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
}
}
return nil
}

View File

@@ -13,9 +13,11 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil"
)
// allowedTags is the list of available client tags.
@@ -88,6 +90,10 @@ type StorageConfig struct {
// not be nil.
Logger *slog.Logger
// Clock is used by [upstreamManager] to retrieve the current time. It must
// not be nil.
Clock timeutil.Clock
// DHCP is used to match IPs against MACs of persistent clients and update
// [SourceDHCP] runtime client information. It must not be nil.
DHCP DHCP
@@ -126,6 +132,9 @@ type Storage struct {
// runtimeIndex contains information about runtime clients.
runtimeIndex *runtimeIndex
// upstreamManager stores and updates custom client upstream configurations.
upstreamManager *upstreamManager
// dhcp is used to update [SourceDHCP] runtime client information.
dhcp DHCP
@@ -163,6 +172,7 @@ func NewStorage(ctx context.Context, conf *StorageConfig) (s *Storage, err error
mu: &sync.Mutex{},
index: newIndex(),
runtimeIndex: newRuntimeIndex(),
upstreamManager: newUpstreamManager(conf.Logger, conf.Clock),
dhcp: conf.DHCP,
etcHosts: conf.EtcHosts,
arpDB: conf.ARPDB,
@@ -200,7 +210,7 @@ func (s *Storage) Start(ctx context.Context) (err error) {
func (s *Storage) Shutdown(_ context.Context) (err error) {
close(s.done)
return s.closeUpstreams()
return s.upstreamManager.close()
}
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
@@ -416,6 +426,7 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) {
}
s.index.add(p)
s.upstreamManager.updateCustomUpstreamConfig(p)
s.logger.DebugContext(
ctx,
@@ -441,7 +452,7 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
return nil, false
}
// Find finds persistent client by string representation of the client ID, IP
// Find finds persistent client by string representation of the ClientID, IP
// address, or MAC. And returns its shallow copy.
//
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
@@ -514,12 +525,13 @@ func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
return false
}
if err := p.CloseUpstreams(); err != nil {
s.logger.ErrorContext(ctx, "removing client", "name", p.Name, slogutil.KeyError, err)
}
s.index.remove(p)
err := s.upstreamManager.remove(p.UID)
if err != nil {
s.logger.DebugContext(ctx, "closing client upstreams", "name", name, slogutil.KeyError, err)
}
return true
}
@@ -556,6 +568,8 @@ func (s *Storage) Update(ctx context.Context, name string, p *Persistent) (err e
s.index.remove(stored)
s.index.add(p)
s.upstreamManager.updateCustomUpstreamConfig(p)
return nil
}
@@ -576,14 +590,6 @@ func (s *Storage) Size() (n int) {
return s.index.size()
}
// closeUpstreams closes upstream configurations of persistent clients.
func (s *Storage) closeUpstreams() (err error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.index.closeUpstreams()
}
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
// client exists, returns nil.
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
@@ -626,3 +632,42 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
func (s *Storage) AllowedTags() (tags []string) {
return s.allowedTags
}
// CustomUpstreamConfig implements the [dnsforward.ClientsContainer] interface
// for *Storage
func (s *Storage) CustomUpstreamConfig(
id string,
addr netip.Addr,
) (prxConf *proxy.CustomUpstreamConfig) {
s.mu.Lock()
defer s.mu.Unlock()
c, ok := s.index.findByClientID(id)
if !ok {
c, ok = s.index.findByIP(addr)
}
if !ok {
return nil
}
return s.upstreamManager.customUpstreamConfig(c.UID)
}
// UpdateCommonUpstreamConfig implements the [dnsforward.ClientsContainer]
// interface for *Storage
func (s *Storage) UpdateCommonUpstreamConfig(conf *CommonUpstreamConfig) {
s.mu.Lock()
defer s.mu.Unlock()
s.upstreamManager.updateCommonUpstreamConfig(conf)
}
// ClearUpstreamCache implements the [dnsforward.ClientsContainer] interface for
// *Storage
func (s *Storage) ClearUpstreamCache() {
s.mu.Lock()
defer s.mu.Unlock()
s.upstreamManager.clearUpstreamCache()
}

View File

@@ -13,27 +13,34 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/faketime"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestStorage is a helper function that returns initialized storage.
func newTestStorage(tb testing.TB) (s *client.Storage) {
func newTestStorage(tb testing.TB, clock timeutil.Clock) (s *client.Storage) {
tb.Helper()
ctx := testutil.ContextWithTimeout(tb, testTimeout)
s, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
Clock: clock,
})
require.NoError(tb, err)
return s
}
// type check
var _ dnsforward.ClientsContainer = (*client.Storage)(nil)
// testHostsContainer is a mock implementation of the [client.HostsContainer]
// interface.
type testHostsContainer struct {
@@ -691,7 +698,7 @@ func TestStorage_Add(t *testing.T) {
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
s := newTestStorage(t)
s := newTestStorage(t, timeutil.SystemClock{})
tags := s.AllowedTags()
require.NotZero(t, len(tags))
require.True(t, slices.IsSorted(tags))
@@ -822,7 +829,7 @@ func TestStorage_RemoveByName(t *testing.T) {
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
s := newTestStorage(t)
s := newTestStorage(t, timeutil.SystemClock{})
err := s.Add(ctx, existingClient)
require.NoError(t, err)
@@ -847,7 +854,7 @@ func TestStorage_RemoveByName(t *testing.T) {
}
t.Run("duplicate_remove", func(t *testing.T) {
s = newTestStorage(t)
s = newTestStorage(t, timeutil.SystemClock{})
err = s.Add(ctx, existingClient)
require.NoError(t, err)
@@ -1278,3 +1285,99 @@ func TestStorage_RangeByName(t *testing.T) {
})
}
}
func TestStorage_CustomUpstreamConfig(t *testing.T) {
const (
existingName = "existing_name"
existingClientID = "existing_client_id"
nonExistingClientID = "non_existing_client_id"
)
var (
existingClientUID = client.MustNewUID()
existingIP = netip.MustParseAddr("192.0.2.1")
nonExistingIP = netip.MustParseAddr("192.0.2.255")
testUpstreamTimeout = time.Second
)
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{existingIP},
ClientIDs: []string{existingClientID},
UID: existingClientUID,
Upstreams: []string{"192.0.2.0"},
}
date := time.Now()
clock := &faketime.Clock{
OnNow: func() (now time.Time) {
date = date.Add(time.Second)
return date
},
}
s := newTestStorage(t, clock)
s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
UpstreamTimeout: testUpstreamTimeout,
})
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return s.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
})
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := s.Add(ctx, existingClient)
require.NoError(t, err)
testCases := []struct {
cliAddr netip.Addr
wantNilConf assert.ValueAssertionFunc
name string
cliID string
}{{
name: "client_id",
cliID: existingClientID,
cliAddr: netip.Addr{},
wantNilConf: assert.NotNil,
}, {
name: "client_addr",
cliID: "",
cliAddr: existingIP,
wantNilConf: assert.NotNil,
}, {
name: "non_existing_client_id",
cliID: nonExistingClientID,
cliAddr: netip.Addr{},
wantNilConf: assert.Nil,
}, {
name: "non_existing_client_addr",
cliID: "",
cliAddr: nonExistingIP,
wantNilConf: assert.Nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := s.CustomUpstreamConfig(tc.cliID, tc.cliAddr)
tc.wantNilConf(t, conf)
})
}
t.Run("update_common_config", func(t *testing.T) {
conf := s.CustomUpstreamConfig(existingClientID, existingIP)
require.NotNil(t, conf)
s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
UpstreamTimeout: testUpstreamTimeout * 2,
})
updConf := s.CustomUpstreamConfig(existingClientID, existingIP)
require.NotNil(t, updConf)
assert.NotEqual(t, conf, updConf)
})
}

View File

@@ -0,0 +1,224 @@
package client
import (
"fmt"
"log/slog"
"slices"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil"
)
// CommonUpstreamConfig contains common settings for custom client upstream
// configurations.
type CommonUpstreamConfig struct {
Bootstrap upstream.Resolver
UpstreamTimeout time.Duration
BootstrapPreferIPv6 bool
EDNSClientSubnetEnabled bool
UseHTTP3Upstreams bool
}
// customUpstreamConfig contains custom client upstream configuration and the
// timestamp of the latest configuration update.
type customUpstreamConfig struct {
// proxyConf is the constructed upstream configuration for the [proxy],
// derived from the fields below. It is initialized on demand with
// [newCustomUpstreamConfig].
proxyConf *proxy.CustomUpstreamConfig
// commonConfUpdate is the timestamp of the latest configuration update,
// used to check against [upstreamManager.confUpdate] to determine if the
// configuration is up to date.
commonConfUpdate time.Time
// upstreams is the cached list of custom upstream DNS servers used for the
// configuration of proxyConf.
upstreams []string
// upstreamsCacheSize is the cached value of the cache size of the
// upstreams, used for the configuration of proxyConf.
upstreamsCacheSize uint32
// upstreamsCacheEnabled is the cached value indicating whether the cache of
// the upstreams is enabled for the configuration of proxyConf.
upstreamsCacheEnabled bool
// isChanged indicates whether the proxyConf needs to be updated.
isChanged bool
}
// upstreamManager stores and updates custom client upstream configurations.
type upstreamManager struct {
// logger is used for logging the operation of the upstream manager. It
// must not be nil.
//
// TODO(s.chzhen): Consider using a logger with its own prefix.
logger *slog.Logger
// uidToCustomConf maps persistent client UID to the custom client upstream
// configuration. Stored UIDs must be in sync with the [index.uidToClient].
uidToCustomConf map[UID]*customUpstreamConfig
// commonConf is the common upstream configuration.
commonConf *CommonUpstreamConfig
// clock is used to get the current time. It must not be nil.
clock timeutil.Clock
// confUpdate is the timestamp of the latest common upstream configuration
// update.
confUpdate time.Time
}
// newUpstreamManager returns the new properly initialized upstream manager.
func newUpstreamManager(logger *slog.Logger, clock timeutil.Clock) (m *upstreamManager) {
return &upstreamManager{
logger: logger,
uidToCustomConf: make(map[UID]*customUpstreamConfig),
clock: clock,
}
}
// updateCommonUpstreamConfig updates the common upstream configuration and the
// timestamp of the latest configuration update.
func (m *upstreamManager) updateCommonUpstreamConfig(conf *CommonUpstreamConfig) {
m.commonConf = conf
m.confUpdate = m.clock.Now()
}
// updateCustomUpstreamConfig updates the stored custom client upstream
// configuration associated with the persistent client. It also sets
// [customUpstreamConfig.isChanged] to true so [customUpstreamConfig.proxyConf]
// can be updated later in [upstreamManager.customUpstreamConfig].
func (m *upstreamManager) updateCustomUpstreamConfig(c *Persistent) {
cliConf, ok := m.uidToCustomConf[c.UID]
if !ok {
cliConf = &customUpstreamConfig{
commonConfUpdate: m.confUpdate,
}
m.uidToCustomConf[c.UID] = cliConf
}
// TODO(s.chzhen): Compare before cloning.
cliConf.upstreams = slices.Clone(c.Upstreams)
cliConf.upstreamsCacheSize = c.UpstreamsCacheSize
cliConf.upstreamsCacheEnabled = c.UpstreamsCacheEnabled
cliConf.isChanged = true
}
// customUpstreamConfig returns the custom client upstream configuration.
func (m *upstreamManager) customUpstreamConfig(uid UID) (proxyConf *proxy.CustomUpstreamConfig) {
cliConf, ok := m.uidToCustomConf[uid]
if !ok {
// TODO(s.chzhen): Consider panic.
m.logger.Error("no associated custom client upstream config")
return nil
}
if !m.isConfigChanged(cliConf) {
return cliConf.proxyConf
}
if cliConf.proxyConf != nil {
err := cliConf.proxyConf.Close()
if err != nil {
// TODO(s.chzhen): Pass context.
m.logger.Debug("closing custom upstream config", slogutil.KeyError, err)
}
}
proxyConf = newCustomUpstreamConfig(cliConf, m.commonConf)
cliConf.proxyConf = proxyConf
cliConf.isChanged = false
return proxyConf
}
// isConfigChanged returns true if the update is necessary for the custom client
// upstream configuration.
func (m *upstreamManager) isConfigChanged(cliConf *customUpstreamConfig) (ok bool) {
return !m.confUpdate.Equal(cliConf.commonConfUpdate) || cliConf.isChanged
}
// clearUpstreamCache clears the upstream cache for each stored custom client
// upstream configuration.
func (m *upstreamManager) clearUpstreamCache() {
for _, c := range m.uidToCustomConf {
c.proxyConf.ClearCache()
}
}
// remove deletes the custom client upstream configuration and closes
// [customUpstreamConfig.proxyConf] if necessary.
func (m *upstreamManager) remove(uid UID) (err error) {
cliConf, ok := m.uidToCustomConf[uid]
if !ok {
// TODO(s.chzhen): Consider panic.
return errors.Error("no associated custom client upstream config")
}
delete(m.uidToCustomConf, uid)
if cliConf.proxyConf != nil {
return cliConf.proxyConf.Close()
}
return nil
}
// close shuts down each stored custom client upstream configuration.
func (m *upstreamManager) close() (err error) {
var errs []error
for _, c := range m.uidToCustomConf {
if c.proxyConf == nil {
continue
}
errs = append(errs, c.proxyConf.Close())
}
return errors.Join(errs...)
}
// newCustomUpstreamConfig returns the new properly initialized custom proxy
// upstream configuration for the client.
func newCustomUpstreamConfig(
cliConf *customUpstreamConfig,
conf *CommonUpstreamConfig,
) (proxyConf *proxy.CustomUpstreamConfig) {
upstreams := stringutil.FilterOut(cliConf.upstreams, aghnet.IsCommentOrEmpty)
if len(upstreams) == 0 {
return nil
}
upsConf, err := proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: conf.Bootstrap,
Timeout: time.Duration(conf.UpstreamTimeout),
HTTPVersions: aghnet.UpstreamHTTPVersions(conf.UseHTTP3Upstreams),
PreferIPv6: conf.BootstrapPreferIPv6,
},
)
if err != nil {
// Should not happen because upstreams are already validated. See
// [Persistent.validate].
panic(fmt.Errorf("creating custom upstream config: %w", err))
}
return proxy.NewCustomUpstreamConfig(
upsConf,
cliConf.upstreamsCacheEnabled,
int(cliConf.upstreamsCacheSize),
conf.EDNSClientSubnetEnabled,
)
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/binary"
"fmt"
"net"
"slices"
"sync/atomic"
"time"
@@ -14,18 +15,48 @@ import (
"golang.org/x/net/ipv6"
)
// raCtx is a context for the Router Advertisement logic.
type raCtx struct {
raAllowSLAAC bool // send RA packets without MO flags
raSLAACOnly bool // send RA packets with MO flags
ipAddr net.IP // source IP address (link-local-unicast)
dnsIPAddr net.IP // IP address for DNS Server option
prefixIPAddr net.IP // IP address for Prefix option
ifaceName string
iface *net.Interface
packetSendPeriod time.Duration // how often RA packets are sent
// raAllowSLAAC is used to determine if the ICMP Router Advertisement
// messages should be sent.
//
// If both raAllowSLAAC and raSLAACOnly are false, the Router Advertisement
// messages aren't sent.
raAllowSLAAC bool
conn *icmp.PacketConn // ICMPv6 socket
stop atomic.Value // stop the packet sending loop
// raSLAACOnly is used to determine if the ICMP Router Advertisement
// messages should set M and O flags, see RFC 4861, section 4.2.
//
// If both raAllowSLAAC and raSLAACOnly are false, the Router Advertisement
// messages aren't sent.
raSLAACOnly bool
// ipAddr is an IP address used within the Source Link-Layer Address option.
// See RFC 4861, section 4.6.1.
ipAddr net.IP
// dnsIPAddr is an IP address used within the DNS Server option.
dnsIPAddr net.IP
// prefixIPAddr is an IP address used within the Prefix Information option.
// See RFC 4861, section 4.6.2.
prefixIPAddr net.IP
// ifaceName is the name of the interface used as a scope of the IP
// addresses.
ifaceName string
// iface is the network interface used to send the ICMPv6 packets.
iface *net.Interface
// packetSendPeriod is the interval between sending the ICMPv6 packets.
packetSendPeriod time.Duration
// conn is the ICMPv6 socket.
conn *icmp.PacketConn
// stop is used to stop the packet sending loop.
stop atomic.Value
}
type icmpv6RA struct {
@@ -38,10 +69,11 @@ type icmpv6RA struct {
mtu uint32
}
// hwAddrToLinkLayerAddr converts a hardware address into a form required by
// RFC4861. That is, a byte slice of length divisible by 8.
// hwAddrToLinkLayerAddr clones the hardware address and returns it as a byte
// slice suitable for the Source Link-Layer Address option in the ICMPv6
// Router Advertisement packet.
//
// See https://tools.ietf.org/html/rfc4861#section-4.6.1.
// TODO(e.burkov): Check if it's safe to use the original slice.
func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
err = netutil.ValidateMAC(hwa)
if err != nil {
@@ -50,19 +82,7 @@ func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
return nil, err
}
if len(hwa) == 6 || len(hwa) == 8 {
lla = make([]byte, 8)
copy(lla, hwa)
return lla, nil
}
// Assume that netutil.ValidateMAC prevents lengths other than 20 by
// now.
lla = make([]byte, 24)
copy(lla, hwa)
return lla, nil
return slices.Clone(hwa), nil
}
// Create an ICMPv6.RouterAdvertisement packet with all necessary options.
@@ -103,15 +123,24 @@ func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
//
// TODO(a.garipov): Replace with an existing implementation from a dependency.
func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
var lla []byte
lla, err = hwAddrToLinkLayerAddr(params.sourceLinkLayerAddress)
lla, err := hwAddrToLinkLayerAddr(params.sourceLinkLayerAddress)
if err != nil {
return nil, fmt.Errorf("converting source link layer address: %w", err)
return nil, fmt.Errorf("converting source link-layer address: %w", err)
}
// Calculate length of the source link-layer address option. As per RFC
// 4861, section 4.6.1, the length should be in units of 8 octets, including
// the type and length fields.
//
// See https://datatracker.ietf.org/doc/html/rfc4861#section-4.6.1.
srcLLAOptLen := len(lla) + 2
// Make sure the value is rounded up to the nearest multiple of 8.
srcLLAOptLenValue := (srcLLAOptLen + 7) / 8
srcLLAPadLen := srcLLAOptLenValue*8 - srcLLAOptLen
// TODO(a.garipov): Don't use a magic constant here. Refactor the code
// and make all constants named instead of all those comments..
data = make([]byte, 82+len(lla))
// and make all constants named instead of all those comments.
data = make([]byte, 80+srcLLAOptLen+srcLLAPadLen)
i := 0
// ICMPv6:
@@ -175,12 +204,11 @@ func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
// Option=Source link-layer address:
data[i] = 1 // Type
data[i+1] = 1 // Length
data[i] = 1 // Type
data[i+1] = byte(srcLLAOptLenValue) // Length
i += 2
copy(data[i:], lla) // Link-Layer Address[8/24]
i += len(lla)
i += len(lla) + srcLLAPadLen
// Option=Recursive DNS Server:

View File

@@ -0,0 +1,63 @@
package dhcpd
import (
"net"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateICMPv6RAPacket(t *testing.T) {
raConf := icmpv6RA{
managedAddressConfiguration: false,
otherConfiguration: true,
mtu: 1500,
prefix: net.ParseIP("1234::"),
prefixLen: 64,
recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"),
sourceLinkLayerAddress: []byte{0x0A, 0x00, 0x27, 0x00, 0x00, 0x00},
}
pkt, err := createICMPv6RAPacket(raConf)
require.NoError(t, err)
icmpPkt := &layers.ICMPv6{}
err = icmpPkt.DecodeFromBytes(pkt, gopacket.NilDecodeFeedback)
require.NoError(t, err)
require.Equal(t, layers.LayerTypeICMPv6RouterAdvertisement, icmpPkt.NextLayerType())
raPkt := &layers.ICMPv6RouterAdvertisement{}
err = raPkt.DecodeFromBytes(icmpPkt.LayerPayload(), gopacket.NilDecodeFeedback)
require.NoError(t, err)
assert.Equal(t, raConf.managedAddressConfiguration, raPkt.ManagedAddressConfig())
assert.Equal(t, raConf.otherConfiguration, raPkt.OtherConfig())
wantOpts := layers.ICMPv6Options{{
Type: layers.ICMPv6OptPrefixInfo,
Data: []uint8{
0x40, 0xC0, 0x00, 0x00, 0x0E, 0x10, 0x00, 0x00,
0x0E, 0x10, 0x00, 0x00, 0x00, 0x00, 0x12, 0x34,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
}, {
Type: layers.ICMPv6OptMTU,
Data: []uint8{0x00, 0x00, 0x00, 0x00, 0x05, 0xDC},
}, {
Type: layers.ICMPv6OptSourceAddress,
Data: []uint8{0x0A, 0x00, 0x27, 0x00, 0x00, 0x0},
}, {
// Package layers declares no constant for Recursive DNS Server option.
Type: layers.ICMPv6Opt(25),
Data: []uint8{
0x00, 0x00, 0x00, 0x00, 0x0E, 0x10, 0xFE, 0x80,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
0x27, 0xFF, 0xFE, 0x00, 0x00, 0x00,
},
}}
assert.Equal(t, wantOpts, raPkt.Options)
}

View File

@@ -1,38 +0,0 @@
package dhcpd
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCreateICMPv6RAPacket(t *testing.T) {
wantData := []byte{
0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10,
0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00,
0x12, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0xdc,
0x01, 0x01, 0x0a, 0x00, 0x27, 0x00, 0x00, 0x00,
0x00, 0x00, 0x19, 0x03, 0x00, 0x00, 0x00, 0x00,
0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x08, 0x00, 0x27, 0xff, 0xfe, 0x00,
0x00, 0x00,
}
gotData, err := createICMPv6RAPacket(icmpv6RA{
managedAddressConfiguration: false,
otherConfiguration: true,
mtu: 1500,
prefix: net.ParseIP("1234::"),
prefixLen: 64,
recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"),
sourceLinkLayerAddress: []byte{0x0a, 0x00, 0x27, 0x00, 0x00, 0x00},
})
assert.NoError(t, err)
assert.Equal(t, wantData, gotData)
}

View File

@@ -239,7 +239,7 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
err = validateAccessSet(list)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, err.Error())
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}

View File

@@ -15,7 +15,7 @@ import (
var _ proxy.BeforeRequestHandler = (*Server)(nil)
// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// including logs. It performs access checks and puts the ClientID, if there
// is one, into the server's cache.
//
// TODO(d.kolyshev): Extract to separate package.

View File

@@ -266,6 +266,7 @@ func TestServer_HandleBefore_udp(t *testing.T) {
UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})

View File

@@ -62,7 +62,7 @@ func clientIDFromClientServerName(
return strings.ToLower(clientID), nil
}
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
// clientIDFromDNSContextHTTPS extracts the ClientID from the path of the
// client's DNS-over-HTTPS request.
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
r := pctx.HTTPRequest

View File

@@ -0,0 +1,46 @@
package dnsforward
import (
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/dnsproxy/proxy"
)
// ClientsContainer provides information about preconfigured DNS clients.
type ClientsContainer interface {
// CustomUpstreamConfig returns the custom client upstream configuration, if
// any. It prioritizes ClientID over client IP address to identify the
// client.
CustomUpstreamConfig(clientID string, cliAddr netip.Addr) (conf *proxy.CustomUpstreamConfig)
// UpdateCommonUpstreamConfig updates the common upstream configuration.
UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig)
// ClearUpstreamCache clears the upstream cache for each stored custom
// client upstream configuration.
ClearUpstreamCache()
}
// EmptyClientsContainer is an [ClientsContainer] implementation that does nothing.
type EmptyClientsContainer struct{}
// type check
var _ ClientsContainer = EmptyClientsContainer{}
// CustomUpstreamConfig implements the [ClientsContainer] interface for
// EmptyClientsContainer.
func (EmptyClientsContainer) CustomUpstreamConfig(
clientID string,
cliAddr netip.Addr,
) (conf *proxy.CustomUpstreamConfig) {
return nil
}
// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for
// EmptyClientsContainer.
func (EmptyClientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) {}
// ClearUpstreamCache implements the [ClientsContainer] interface for
// EmptyClientsContainer.
func (EmptyClientsContainer) ClearUpstreamCache() {}

View File

@@ -29,19 +29,6 @@ import (
"github.com/ameshkov/dnscrypt/v2"
)
// ClientsContainer provides information about preconfigured DNS clients.
type ClientsContainer interface {
// UpstreamConfigByID returns the custom upstream configuration for the
// client having id, using boot to initialize the one if necessary. It
// returns nil if there is no custom upstream configuration for the client.
// The id is expected to be either a string representation of an IP address
// or the ClientID.
UpstreamConfigByID(
id string,
boot upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error)
}
// Config represents the DNS filtering configuration of AdGuard Home. The zero
// Config is empty and ready for use.
type Config struct {
@@ -467,7 +454,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
}
ipsets = stringutil.SplitTrimmed(string(data), "\n")
ipsets = slices.DeleteFunc(ipsets, IsCommentOrEmpty)
ipsets = slices.DeleteFunc(ipsets, aghnet.IsCommentOrEmpty)
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
@@ -478,7 +465,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
// the configuration itself.
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
if conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
return stringutil.FilterOut(conf.UpstreamDNS, aghnet.IsCommentOrEmpty), nil
}
var data []byte
@@ -491,7 +478,7 @@ func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil
}
// collectListenAddr adds addrPort to addrs. It also adds its port to

View File

@@ -299,6 +299,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: true,
@@ -337,6 +338,7 @@ func TestServer_dns64WithDisabledRDNS(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: false,

View File

@@ -540,7 +540,7 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
@@ -557,6 +557,13 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
}
s.conf.UpstreamConfig = uc
s.conf.ClientsContainer.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
Bootstrap: boot,
UpstreamTimeout: s.conf.UpstreamTimeout,
BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6,
EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled,
UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams,
})
return nil
}
@@ -630,7 +637,7 @@ func (s *Server) prepareInternalDNS() (err error) {
bootOpts := &upstream.Options{
Timeout: DefaultTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
}
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
@@ -661,7 +668,7 @@ func (s *Server) prepareInternalDNS() (err error) {
// setupFallbackDNS initializes the fallback DNS servers.
func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
fallbacks := s.conf.FallbackDNS
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
fallbacks = stringutil.FilterOut(fallbacks, aghnet.IsCommentOrEmpty)
if len(fallbacks) == 0 {
return nil, nil
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
@@ -61,6 +62,42 @@ const (
// TODO(a.garipov): Use more.
var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345")
// type check
var _ ClientsContainer = (*clientsContainer)(nil)
// clientsContainer is a mock [ClientsContainer] implementation for tests.
type clientsContainer struct {
OnCustomUpstreamConfig func(
clientID string,
cliAddr netip.Addr,
) (conf *proxy.CustomUpstreamConfig)
OnUpdateCommonUpstreamConfig func(conf *client.CommonUpstreamConfig)
OnClearUpstreamCache func()
}
// CustomUpstreamConfig implements the [ClientsContainer] interface for
// *clientsContainer.
func (c *clientsContainer) CustomUpstreamConfig(
clientID string,
cliAddr netip.Addr,
) (conf *proxy.CustomUpstreamConfig) {
return c.OnCustomUpstreamConfig(clientID, cliAddr)
}
// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for
// *clientsContainer.
func (c *clientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) {
c.OnUpdateCommonUpstreamConfig(conf)
}
// ClearUpstreamCache implements the [ClientsContainer] interface for
// *clientsContainer.
func (c *clientsContainer) ClearUpstreamCache() {
c.OnClearUpstreamCache()
}
func startDeferStop(t *testing.T, s *Server) {
t.Helper()
@@ -168,6 +205,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})
@@ -297,6 +335,7 @@ func TestServer(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})
@@ -337,6 +376,7 @@ func TestServer_timeout(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -364,6 +404,7 @@ func TestServer_timeout(t *testing.T) {
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
Enabled: false,
}
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
err = s.Prepare(&s.conf)
require.NoError(t, err)
@@ -380,6 +421,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -405,6 +447,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})
@@ -536,6 +579,7 @@ func TestSafeSearch(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -629,6 +673,7 @@ func TestInvalidRequest(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})
@@ -659,6 +704,7 @@ func TestBlockedRequest(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -696,6 +742,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -721,12 +768,12 @@ func TestServerCustomClientUpstream(t *testing.T) {
forwardConf.EDNSClientSubnet.Enabled,
)
s.conf.ClientsContainer = &aghtest.ClientsContainer{
OnUpstreamConfigByID: func(
s.conf.ClientsContainer = &clientsContainer{
OnCustomUpstreamConfig: func(
_ string,
_ upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error) {
return customUpsConf, nil
_ netip.Addr,
) (conf *proxy.CustomUpstreamConfig) {
return customUpsConf
},
}
@@ -774,6 +821,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})
@@ -808,6 +856,7 @@ func TestBlockCNAME(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -884,6 +933,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -930,6 +980,7 @@ func TestNullBlockedRequest(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -998,6 +1049,7 @@ func TestBlockedCustomIP(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -1051,6 +1103,7 @@ func TestBlockedByHosts(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -1103,6 +1156,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -1164,6 +1218,7 @@ func TestRewrite(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}))
@@ -1290,6 +1345,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf)
@@ -1375,6 +1431,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf)
@@ -1643,6 +1700,7 @@ func TestServer_Exchange(t *testing.T) {
UpstreamDNS: []string{upsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
LocalPTRResolvers: []string{localUpsAddr},
UsePrivateRDNS: true,
@@ -1665,6 +1723,7 @@ func TestServer_Exchange(t *testing.T) {
UpstreamDNS: []string{upsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
LocalPTRResolvers: []string{},
ServePlainDNS: true,

View File

@@ -40,6 +40,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})

View File

@@ -36,6 +36,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
@@ -647,7 +648,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
return
}
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, aghnet.IsCommentOrEmpty)
opts := &upstream.Options{
Timeout: s.conf.UpstreamTimeout,
@@ -673,6 +674,8 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
func (s *Server) handleCacheClear(w http.ResponseWriter, _ *http.Request) {
s.dnsProxy.ClearCache()
s.conf.ClientsContainer.ClearUpstreamCache()
_, _ = io.WriteString(w, "OK")
}

View File

@@ -83,6 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
RatelimitSubnetLenIPv6: 56,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ConfigModified: func() {},
ServePlainDNS: true,
@@ -164,6 +165,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
RatelimitSubnetLenIPv6: 56,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ConfigModified: func() {},
ServePlainDNS: true,
@@ -299,24 +301,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
}
}
func TestIsCommentOrEmpty(t *testing.T) {
for _, tc := range []struct {
want assert.BoolAssertionFunc
str string
}{{
want: assert.True,
str: "",
}, {
want: assert.True,
str: "# comment",
}, {
want: assert.False,
str: "1.2.3.4",
}} {
tc.want(t, IsCommentOrEmpty(tc.str))
}
}
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
t.Helper()
@@ -388,6 +372,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})

View File

@@ -1,7 +1,6 @@
package dnsforward
import (
"cmp"
"context"
"encoding/binary"
"net"
@@ -577,17 +576,14 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
return
}
// Use the ClientID first, since it has a higher priority.
id := cmp.Or(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
return
}
cliAddr := pctx.Addr.Addr()
upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr)
if upsConf != nil {
log.Debug("dnsforward: using custom upstreams for client %s", id)
log.Debug(
"dnsforward: using custom upstreams for client with ip %s and clientid %q",
cliAddr,
clientID,
)
pctx.CustomUpstreamConfig = upsConf
}

View File

@@ -81,6 +81,7 @@ func TestServer_ProcessInitial(t *testing.T) {
AAAADisabled: tc.aaaaDisabled,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -180,6 +181,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
AAAADisabled: tc.aaaaDisabled,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
}
@@ -324,6 +326,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
HandleDDR: tc.ddrEnabled,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
@@ -660,6 +663,7 @@ func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
@@ -788,6 +792,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
@@ -816,6 +821,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},

View File

@@ -19,6 +19,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
ClientsContainer: EmptyClientsContainer{},
},
ServePlainDNS: true,
})

View File

@@ -94,7 +94,7 @@ func newPrivateConfig(
) (uc *proxy.UpstreamConfig, err error) {
confNeedsFiltering := len(addrs) > 0
if confNeedsFiltering {
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
addrs = stringutil.FilterOut(addrs, aghnet.IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
addrs = make([]string, 0, len(sysResolvers))
@@ -127,20 +127,6 @@ func newPrivateConfig(
return uc, nil
}
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
if !http3 {
return upstream.DefaultHTTPVersions
}
return []upstream.HTTPVersion{
upstream.HTTPVersion3,
upstream.HTTPVersion2,
upstream.HTTPVersion11,
}
}
// setProxyUpstreamMode sets the upstream mode and related settings in conf
// based on provided parameters.
func setProxyUpstreamMode(
@@ -162,10 +148,3 @@ func setProxyUpstreamMode(
return nil
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}

View File

@@ -661,8 +661,6 @@ func TestClientSettings(t *testing.T) {
}
}
// Benchmarks.
func BenchmarkSafeBrowsing(b *testing.B) {
d, setts := newForTest(b, &Config{
SafeBrowsingEnabled: true,
@@ -670,15 +668,26 @@ func BenchmarkSafeBrowsing(b *testing.B) {
}, nil)
b.Cleanup(d.Close)
for range b.N {
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
require.NoError(b, err)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
var res Result
var err error
b.ReportAllocs()
for b.Loop() {
res, err = d.CheckHost(sbBlocked, dns.TypeA, setts)
}
require.NoError(b, err)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
// Most recent results:
//
// goos: darwin
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
// BenchmarkSafeBrowsing-12 358934 2994 ns/op 1304 B/op 40 allocs/op
}
func BenchmarkSafeBrowsingParallel(b *testing.B) {
func BenchmarkSafeBrowsing_parallel(b *testing.B) {
d, setts := newForTest(b, &Config{
SafeBrowsingEnabled: true,
SafeBrowsingChecker: newChecker(sbBlocked),
@@ -693,4 +702,12 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
}
})
// Most recent results:
//
// goos: darwin
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
// BenchmarkSafeBrowsing_parallel-12 507327 2382 ns/op 1352 B/op 42 allocs/op
}

View File

@@ -244,7 +244,7 @@ func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
restart, err := d.filterSetProperties(fj.URL, filt, fj.Whitelist)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, err.Error())
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}

View File

@@ -202,34 +202,30 @@ func TestParser_Parse_checksums(t *testing.T) {
assert.Equal(t, gotWithoutComments, gotWithComments)
}
var (
resSink *rulelist.ParseResult
errSink error
)
func BenchmarkParser_Parse(b *testing.B) {
dst := &bytes.Buffer{}
src := strings.NewReader(strings.Repeat(testRuleTextBlocked, 1000))
buf := make([]byte, rulelist.DefaultRuleBufSize)
p := rulelist.NewParser()
var res *rulelist.ParseResult
var err error
b.ReportAllocs()
b.ResetTimer()
for range b.N {
resSink, errSink = p.Parse(dst, src, buf)
for b.Loop() {
res, err = p.Parse(dst, src, buf)
dst.Reset()
}
require.NoError(b, errSink)
require.NotNil(b, resSink)
require.NoError(b, err)
require.NotNil(b, res)
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
// Most recent results:
//
// goos: linux
// goos: darwin
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkParser_Parse-16 100000000 128.0 ns/op 48 B/op 1 allocs/op
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
// BenchmarkParser_Parse-12 19635926 53.70 ns/op 48 B/op 1 allocs/op
}
func FuzzParser_Parse(f *testing.F) {

View File

@@ -88,16 +88,25 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
const googleHost = "www.google.com"
func BenchmarkDefault_SearchHost(b *testing.B) {
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for range b.N {
dnsRewriteSink = ss.searchHost(googleHost, testQType)
var rewrite *rules.DNSRewrite
b.ReportAllocs()
for b.Loop() {
rewrite = ss.searchHost(googleHost, testQType)
}
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
require.NotNil(b, rewrite)
assert.Equal(b, "forcesafesearch.google.com", rewrite.NewCNAME)
// Most recent results:
//
// goos: darwin
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
// BenchmarkDefault_SearchHost-12 751882 1604 ns/op 129 B/op 5 allocs/op
}

View File

@@ -356,7 +356,7 @@ func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
// There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth()
if ok {
u, _ = Context.auth.findUser(user, pass)
u, _ = globalContext.auth.findUser(user, pass)
return u
}
@@ -408,13 +408,12 @@ func (a *Auth) authRequired() bool {
// bytes of sessionTokenSize length.
//
// TODO(e.burkov): Think about using byte array instead of byte slice.
func newSessionToken() (data []byte, err error) {
func newSessionToken() (data []byte) {
randData := make([]byte, sessionTokenSize)
_, err = rand.Read(randData)
if err != nil {
return nil, err
}
// Since Go 1.24, crypto/rand.Read doesn't return an error and crashes
// unrecoverably instead.
_, _ = rand.Read(randData)
return randData, nil
return randData
}

View File

@@ -1,8 +1,6 @@
package home
import (
"bytes"
"crypto/rand"
"encoding/hex"
"path/filepath"
"testing"
@@ -12,23 +10,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestNewSessionToken(t *testing.T) {
// Successful case.
token, err := newSessionToken()
require.NoError(t, err)
assert.Len(t, token, sessionTokenSize)
// Break the rand.Reader.
prevReader := rand.Reader
t.Cleanup(func() { rand.Reader = prevReader })
rand.Reader = &bytes.Buffer{}
// Unsuccessful case.
token, err = newSessionToken()
require.Error(t, err)
assert.Empty(t, token)
}
func TestAuth(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
@@ -47,8 +28,7 @@ func TestAuth(t *testing.T) {
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.removeSession("notfound")
sess, err := newSessionToken()
require.NoError(t, err)
sess := newSessionToken()
sessStr := hex.EncodeToString(sess)
now := time.Now().UTC().Unix()

View File

@@ -47,11 +47,7 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
rateLimiter.remove(addr)
}
sess, err := newSessionToken()
if err != nil {
return nil, fmt.Errorf("generating token: %w", err)
}
sess := newSessionToken()
now := time.Now().UTC()
a.addSession(sess, &session{
@@ -155,7 +151,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
return
}
if rateLimiter := Context.auth.rateLimiter; rateLimiter != nil {
if rateLimiter := globalContext.auth.rateLimiter; rateLimiter != nil {
if left := rateLimiter.check(remoteIP); left > 0 {
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
writeErrorWithIP(
@@ -176,10 +172,10 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
}
cookie, err := Context.auth.newCookie(req, remoteIP)
cookie, err := globalContext.auth.newCookie(req, remoteIP)
if err != nil {
logIP := remoteIP
if Context.auth.trustedProxies.Contains(ip.Unmap()) {
if globalContext.auth.trustedProxies.Contains(ip.Unmap()) {
logIP = ip.String()
}
@@ -213,7 +209,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
return
}
Context.auth.removeSession(c.Value)
globalContext.auth.removeSession(c.Value)
c = &http.Cookie{
Name: sessionCookieName,
@@ -232,7 +228,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
// RegisterAuthHandlers - register handlers
func RegisterAuthHandlers() {
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
globalContext.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
httpRegister(http.MethodGet, "/control/logout", handleLogout)
}
@@ -254,13 +250,13 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
// Check Basic authentication.
user, pass, hasBasic := r.BasicAuth()
if hasBasic {
_, isAuthenticated = Context.auth.findUser(user, pass)
_, isAuthenticated = globalContext.auth.findUser(user, pass)
if !isAuthenticated {
log.Info("%s: invalid basic authorization value", pref)
}
}
} else {
res := Context.auth.checkSession(cookie.Value)
res := globalContext.auth.checkSession(cookie.Value)
isAuthenticated = res == checkSessionOK
if !isAuthenticated {
log.Debug("%s: invalid cookie value: %q", pref, cookie)
@@ -294,12 +290,12 @@ func optionalAuth(
) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) {
p := r.URL.Path
authRequired := Context.auth != nil && Context.auth.authRequired()
authRequired := globalContext.auth != nil && globalContext.auth.authRequired()
if p == "/login.html" {
cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil {
// Redirect to the dashboard if already authenticated.
res := Context.auth.checkSession(cookie.Value)
res := globalContext.auth.checkSession(cookie.Value)
if res == checkSessionOK {
http.Redirect(w, r, "", http.StatusFound)

View File

@@ -39,7 +39,7 @@ func TestAuthHTTP(t *testing.T) {
users := []webUser{
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
Context.auth = InitAuth(fn, users, 60, nil, nil)
globalContext.auth = InitAuth(fn, users, 60, nil, nil)
handlerCalled := false
handler := func(_ http.ResponseWriter, _ *http.Request) {
@@ -68,7 +68,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
// perform login
cookie, err := Context.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
require.NoError(t, err)
require.NotNil(t, cookie)
@@ -114,7 +114,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
r.Header.Del(httphdr.Cookie)
Context.auth.Close()
globalContext.auth.Close()
}
func TestRealIP(t *testing.T) {

View File

@@ -12,17 +12,14 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil"
)
// clientsContainer is the storage of all runtime and persistent clients.
@@ -75,6 +72,7 @@ func (clients *clientsContainer) Init(
etcHosts *aghnet.HostsContainer,
arpDB arpdb.Interface,
filteringConf *filtering.Config,
sigHdlr *signalHandler,
) (err error) {
// TODO(s.chzhen): Refactor it.
if clients.storage != nil {
@@ -109,6 +107,7 @@ func (clients *clientsContainer) Init(
clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{
Logger: baseLogger.With(slogutil.KeyPrefix, "client_storage"),
Clock: timeutil.SystemClock{},
InitialClients: confClients,
DHCP: dhcpServer,
EtcHosts: hosts,
@@ -120,6 +119,8 @@ func (clients *clientsContainer) Init(
return fmt.Errorf("init client storage: %w", err)
}
sigHdlr.addClientStorage(clients.storage)
return nil
}
@@ -370,63 +371,6 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
return true
}
// type check
var _ dnsforward.ClientsContainer = (*clientsContainer)(nil)
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for
// *clientsContainer. upsConf is nil if the client isn't found or if the client
// has no custom upstreams.
func (clients *clientsContainer) UpstreamConfigByID(
id string,
bootstrap upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.storage.Find(id)
if !ok {
return nil, nil
} else if c.UpstreamConfig != nil {
return c.UpstreamConfig, nil
}
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
if len(upstreams) == 0 {
return nil, nil
}
var upsConf *proxy.UpstreamConfig
upsConf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: bootstrap,
Timeout: time.Duration(config.DNS.UpstreamTimeout),
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
PreferIPv6: config.DNS.BootstrapPreferIPv6,
},
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
conf = proxy.NewCustomUpstreamConfig(
upsConf,
c.UpstreamsCacheEnabled,
int(c.UpstreamsCacheSize),
config.DNS.EDNSClientSubnet.Enabled,
)
c.UpstreamConfig = conf
// TODO(s.chzhen): Pass context.
err = clients.storage.Update(context.TODO(), c.Name, c)
if err != nil {
return nil, fmt.Errorf("setting upstream config: %w", err)
}
return conf, nil
}
// type check
var _ client.AddressUpdater = (*clientsContainer)(nil)

View File

@@ -1,15 +1,12 @@
package home
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -31,34 +28,10 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
nil,
nil,
&filtering.Config{},
newSignalHandler(nil, nil),
)
require.NoError(t, err)
return c
}
func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add client with upstreams.
err := clients.storage.Add(ctx, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
})
require.NoError(t, err)
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
assert.Nil(t, upsConf)
assert.NoError(t, err)
upsConf, err = clients.UpstreamConfigByID("1.1.1.1", net.DefaultResolver)
require.NotNil(t, upsConf)
assert.NoError(t, err)
}

View File

@@ -486,9 +486,9 @@ var config = &configuration{
// configFilePath returns the absolute path to the symlink-evaluated path to the
// current config file.
func configFilePath() (confPath string) {
confPath, err := filepath.EvalSymlinks(Context.confFilePath)
confPath, err := filepath.EvalSymlinks(globalContext.confFilePath)
if err != nil {
confPath = Context.confFilePath
confPath = globalContext.confFilePath
logFunc := log.Error
if errors.Is(err, os.ErrNotExist) {
logFunc = log.Debug
@@ -498,7 +498,7 @@ func configFilePath() (confPath string) {
}
if !filepath.IsAbs(confPath) {
confPath = filepath.Join(Context.workDir, confPath)
confPath = filepath.Join(globalContext.workDir, confPath)
}
return confPath
@@ -530,8 +530,8 @@ func parseConfig() (err error) {
}
migrator := configmigrate.New(&configmigrate.Config{
WorkingDir: Context.workDir,
DataDir: Context.getDataDir(),
WorkingDir: globalContext.workDir,
DataDir: globalContext.getDataDir(),
})
var upgraded bool
@@ -640,31 +640,31 @@ func readConfigFile() (fileData []byte, err error) {
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
func (c *configuration) write() (err error) {
func (c *configuration) write(tlsMgr *tlsManager) (err error) {
c.Lock()
defer c.Unlock()
if Context.auth != nil {
config.Users = Context.auth.usersList()
if globalContext.auth != nil {
config.Users = globalContext.auth.usersList()
}
if Context.tls != nil {
if tlsMgr != nil {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
tlsMgr.WriteDiskConfig(&tlsConf)
config.TLS = tlsConf
}
if Context.stats != nil {
if globalContext.stats != nil {
statsConf := stats.Config{}
Context.stats.WriteDiskConfig(&statsConf)
globalContext.stats.WriteDiskConfig(&statsConf)
config.Stats.Interval = timeutil.Duration(statsConf.Limit)
config.Stats.Enabled = statsConf.Enabled
config.Stats.Ignored = statsConf.Ignored.Values()
}
if Context.queryLog != nil {
if globalContext.queryLog != nil {
dc := querylog.Config{}
Context.queryLog.WriteDiskConfig(&dc)
globalContext.queryLog.WriteDiskConfig(&dc)
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
config.QueryLog.Enabled = dc.Enabled
config.QueryLog.FileEnabled = dc.FileEnabled
@@ -673,14 +673,14 @@ func (c *configuration) write() (err error) {
config.QueryLog.Ignored = dc.Ignored.Values()
}
if Context.filters != nil {
Context.filters.WriteDiskConfig(config.Filtering)
if globalContext.filters != nil {
globalContext.filters.WriteDiskConfig(config.Filtering)
config.Filters = config.Filtering.Filters
config.WhitelistFilters = config.Filtering.WhitelistFilters
config.UserRules = config.Filtering.UserRules
}
if s := Context.dnsServer; s != nil {
if s := globalContext.dnsServer; s != nil {
c := dnsforward.Config{}
s.WriteDiskConfig(&c)
dns := &config.DNS
@@ -695,11 +695,11 @@ func (c *configuration) write() (err error) {
dns.UpstreamTimeout = timeutil.Duration(s.UpstreamTimeout())
}
if Context.dhcpServer != nil {
Context.dhcpServer.WriteDiskConfig(config.DHCP)
if globalContext.dhcpServer != nil {
globalContext.dhcpServer.WriteDiskConfig(config.DHCP)
}
config.Clients.Persistent = Context.clients.forConfig()
config.Clients.Persistent = globalContext.clients.forConfig()
confPath := configFilePath()
log.Debug("writing config file %q", confPath)
@@ -726,14 +726,14 @@ func setContextTLSCipherIDs() (err error) {
if len(config.TLS.OverrideTLSCiphers) == 0 {
log.Info("tls: using default ciphers")
Context.tlsCipherIDs = aghtls.SaferCipherSuites()
globalContext.tlsCipherIDs = aghtls.SaferCipherSuites()
return nil
}
log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers)
Context.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
globalContext.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
if err != nil {
return fmt.Errorf("parsing override ciphers: %w", err)
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/NYTimes/gziphandler"
@@ -69,7 +68,8 @@ func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err
// collectDNSAddresses returns the list of DNS addresses the server is listening
// on, including the addresses on all interfaces in cases of unspecified IPs.
func collectDNSAddresses() (addrs []string, err error) {
// tlsMgr must not be nil.
func collectDNSAddresses(tlsMgr *tlsManager) (addrs []string, err error) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
addrs = appendDNSAddrs(addrs, netutil.IPv4Localhost())
} else {
@@ -79,7 +79,7 @@ func collectDNSAddresses() (addrs []string, err error) {
}
}
de := getDNSEncryption()
de := getDNSEncryption(tlsMgr)
if de.https != "" {
addrs = append(addrs, de.https)
}
@@ -114,8 +114,8 @@ type statusResponse struct {
IsRunning bool `json:"running"`
}
func handleStatus(w http.ResponseWriter, r *http.Request) {
dnsAddrs, err := collectDNSAddresses()
func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
dnsAddrs, err := collectDNSAddresses(web.tlsManager)
if err != nil {
// Don't add a lot of formatting, since the error is already
// wrapped by collectDNSAddresses.
@@ -129,10 +129,10 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
protectionDisabledUntil *time.Time
protectionEnabled bool
)
if Context.dnsServer != nil {
if globalContext.dnsServer != nil {
fltConf = &dnsforward.Config{}
Context.dnsServer.WriteDiskConfig(fltConf)
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
globalContext.dnsServer.WriteDiskConfig(fltConf)
protectionEnabled, protectionDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus()
}
var resp statusResponse
@@ -162,42 +162,42 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
// IsDHCPAvailable field is now false by default for Windows.
if runtime.GOOS != "windows" {
resp.IsDHCPAvailable = Context.dhcpServer != nil
resp.IsDHCPAvailable = globalContext.dhcpServer != nil
}
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// ------------------------
// registration of handlers
// ------------------------
// registerControlHandlers sets up HTTP handlers for various control endpoints.
// web must not be nil.
func registerControlHandlers(web *webAPI) {
Context.mux.HandleFunc(
globalContext.mux.HandleFunc(
"/control/version.json",
postInstall(optionalAuth(web.handleVersionJSON)),
)
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
httpRegister(http.MethodGet, "/control/status", handleStatus)
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
// No auth is necessary for DoH/DoT configurations
Context.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
Context.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
globalContext.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
globalContext.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
RegisterAuthHandlers()
}
// httpRegister registers an HTTP handler.
func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
Context.mux.HandleFunc(url, postInstall(handler))
globalContext.mux.HandleFunc(url, postInstall(handler))
return
}
Context.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
globalContext.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
}
// ensure returns a wrapped handler that makes sure that the request has the
@@ -207,11 +207,7 @@ func ensure(
handler func(http.ResponseWriter, *http.Request),
) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
m, u := r.Method, r.URL
log.Debug("started %s %s %s", m, r.Host, u)
defer func() { log.Debug("finished %s %s %s in %s", m, r.Host, u, time.Since(start)) }()
m := r.Method
if m != method {
aghhttp.Error(r, w, http.StatusMethodNotAllowed, "only method %s is allowed", method)
@@ -223,8 +219,8 @@ func ensure(
return
}
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
globalContext.controlLock.Lock()
defer globalContext.controlLock.Unlock()
}
handler(w, r)
@@ -293,7 +289,7 @@ func ensureHandler(method string, handler func(http.ResponseWriter, *http.Reques
// preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !Context.firstRun {
if !globalContext.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
@@ -320,7 +316,7 @@ func preInstallHandler(handler http.Handler) http.Handler {
// HTTPS-related headers. If proceed is true, the middleware must continue
// handling the request.
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
web := Context.web
web := globalContext.web
if web.httpsServer.server == nil {
return true
}
@@ -409,7 +405,7 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if Context.firstRun && !strings.HasPrefix(path, "/install.") &&
if globalContext.firstRun && !strings.HasPrefix(path, "/install.") &&
!strings.HasPrefix(path, "/assets/") {
http.Redirect(w, r, "install.html", http.StatusFound)

View File

@@ -428,20 +428,20 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
curConfig := &configuration{}
copyInstallSettings(curConfig, config)
Context.firstRun = false
globalContext.firstRun = false
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
config.DNS.Port = req.DNS.Port
config.Filtering.SafeFSPatterns = []string{
filepath.Join(Context.workDir, userFilterDataDir, "*"),
filepath.Join(globalContext.workDir, userFilterDataDir, "*"),
}
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
u := &webUser{
Name: req.Username,
}
err = Context.auth.addUser(u, req.Password)
err = globalContext.auth.addUser(u, req.Password)
if err != nil {
Context.firstRun = true
globalContext.firstRun = true
copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err)
@@ -452,18 +452,18 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
// moment we'll allow setting up TLS in the initial configuration or the
// configuration itself will use HTTPS protocol, because the underlying
// functions potentially restart the HTTPS server.
err = startMods(web.baseLogger)
err = startMods(r.Context(), web.baseLogger, web.tlsManager)
if err != nil {
Context.firstRun = true
globalContext.firstRun = true
copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
err = config.write()
err = config.write(web.tlsManager)
if err != nil {
Context.firstRun = true
globalContext.firstRun = true
copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
@@ -527,8 +527,33 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
return req, restartHTTP, err
}
func (web *webAPI) registerInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
// startMods initializes and starts the DNS server after installation.
// baseLogger and tlsMgr must not be nil.
func startMods(ctx context.Context, baseLogger *slog.Logger, tlsMgr *tlsManager) (err error) {
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
if err != nil {
return err
}
err = initDNS(baseLogger, tlsMgr, statsDir, querylogDir)
if err != nil {
return err
}
tlsMgr.start(ctx)
err = startDNSServer()
if err != nil {
closeDNSServer()
return err
}
return nil
}
func (web *webAPI) registerInstallHandlers() {
globalContext.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
globalContext.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
globalContext.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
}

View File

@@ -62,7 +62,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
return
}
err = resp.setAllowedToAutoUpdate()
err = resp.setAllowedToAutoUpdate(web.tlsManager)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@@ -158,14 +158,14 @@ type versionResponse struct {
}
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
// allowed to perform an automatic update by the OS.
func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
// allowed to perform an automatic update by the OS. tlsMgr must not be nil.
func (vr *versionResponse) setAllowedToAutoUpdate(tlsMgr *tlsManager) (err error) {
if vr.CanAutoUpdate != aghalg.NBTrue {
return nil
}
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
tlsMgr.WriteDiskConfig(tlsConf)
canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsConf) ||

View File

@@ -39,16 +39,22 @@ const (
// Called by other modules when configuration is changed
func onConfigModified() {
err := config.write()
err := config.write(globalContext.tls)
if err != nil {
log.Error("writing config: %s", err)
}
}
// initDNS updates all the fields of the [Context] needed to initialize the DNS
// server and initializes it at last. It also must not be called unless
// [config] and [Context] are initialized. baseLogger must not be nil.
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
// initDNS updates all the fields of the [globalContext] needed to initialize
// the DNS server and initializes it at last. It also must not be called unless
// [config] and [globalContext] are initialized. baseLogger and tlsMgr must not
// be nil.
func initDNS(
baseLogger *slog.Logger,
tlsMgr *tlsManager,
statsDir string,
querylogDir string,
) (err error) {
anonymizer := config.anonymizer()
statsConf := stats.Config{
@@ -58,7 +64,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
Enabled: config.Stats.Enabled,
ShouldCountClient: Context.clients.shouldCountClient,
ShouldCountClient: globalContext.clients.shouldCountClient,
}
engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored)
@@ -67,7 +73,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
}
statsConf.Ignored = engine
Context.stats, err = stats.New(statsConf)
globalContext.stats, err = stats.New(statsConf)
if err != nil {
return fmt.Errorf("init stats: %w", err)
}
@@ -77,7 +83,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
Anonymizer: anonymizer,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
FindClient: Context.clients.findMultiple,
FindClient: globalContext.clients.findMultiple,
BaseDir: querylogDir,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
RotationIvl: time.Duration(config.QueryLog.Interval),
@@ -92,25 +98,25 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
}
conf.Ignored = engine
Context.queryLog, err = querylog.New(conf)
globalContext.queryLog, err = querylog.New(conf)
if err != nil {
return fmt.Errorf("init querylog: %w", err)
}
Context.filters, err = filtering.New(config.Filtering, nil)
globalContext.filters, err = filtering.New(config.Filtering, nil)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
tlsMgr.WriteDiskConfig(tlsConf)
return initDNSServer(
Context.filters,
Context.stats,
Context.queryLog,
Context.dhcpServer,
globalContext.filters,
globalContext.stats,
globalContext.queryLog,
globalContext.dhcpServer,
anonymizer,
httpRegister,
tlsConf,
@@ -121,7 +127,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
// initDNSServer initializes the [context.dnsServer]. To only use the internal
// proxy, none of the arguments are required, but tlsConf and l still must not
// be nil, in other cases all the arguments also must not be nil. It also must
// not be called unless [config] and [Context] are initialized.
// not be called unless [config] and [globalContext] are initialized.
//
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
func initDNSServer(
@@ -134,7 +140,7 @@ func initDNSServer(
tlsConf *tlsConfigSettings,
l *slog.Logger,
) (err error) {
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
Logger: l,
DNSFilter: filters,
Stats: sts,
@@ -142,7 +148,7 @@ func initDNSServer(
PrivateNets: parseSubnetSet(config.DNS.PrivateNets),
Anonymizer: anonymizer,
DHCPServer: dhcpSrv,
EtcHosts: Context.etcHosts,
EtcHosts: globalContext.etcHosts,
LocalDomain: config.DHCP.LocalDomainName,
})
defer func() {
@@ -154,21 +160,27 @@ func initDNSServer(
return fmt.Errorf("dnsforward.NewServer: %w", err)
}
Context.clients.clientChecker = Context.dnsServer
globalContext.clients.clientChecker = globalContext.dnsServer
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
dnsConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
tlsConf,
httpReg,
globalContext.clients.storage,
)
if err != nil {
return fmt.Errorf("newServerConfig: %w", err)
}
// Try to prepare the server with disabled private RDNS resolution if it
// failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError].
err = Context.dnsServer.Prepare(dnsConf)
err = globalContext.dnsServer.Prepare(dnsConf)
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf)
err = globalContext.dnsServer.Prepare(dnsConf)
}
if err != nil {
@@ -194,7 +206,7 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
}
func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning()
}
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
@@ -230,12 +242,13 @@ func newServerConfig(
clientSrcConf *clientSourcesConfig,
tlsConf *tlsConfigSettings,
httpReg aghhttp.RegisterFunc,
clientsContainer dnsforward.ClientsContainer,
) (newConf *dnsforward.ServerConfig, err error) {
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
fwdConf := dnsConf.Config
fwdConf.FilterHandler = applyAdditionalFiltering
fwdConf.ClientsContainer = &Context.clients
fwdConf.ClientsContainer = clientsContainer
newConf = &dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
@@ -244,7 +257,7 @@ func newServerConfig(
TLSConfig: newDNSTLSConfig(tlsConf, hosts),
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
TLSv12Roots: Context.tlsRoots,
TLSv12Roots: globalContext.tlsRoots,
ConfigModified: onConfigModified,
HTTPRegister: httpReg,
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
@@ -259,16 +272,16 @@ func newServerConfig(
var initialAddresses []netip.Addr
// Context.stats may be nil here if initDNSServer is called from
// [cmdlineUpdate].
if sts := Context.stats; sts != nil {
if sts := globalContext.stats; sts != nil {
const initialClientsNum = 100
initialAddresses = Context.stats.TopClientsIP(initialClientsNum)
initialAddresses = globalContext.stats.TopClientsIP(initialClientsNum)
}
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
// are set by [dnsforward.Server.Prepare].
newConf.AddrProcConf = &client.DefaultAddrProcConfig{
Exchanger: Context.dnsServer,
AddressUpdater: &Context.clients,
Exchanger: globalContext.dnsServer,
AddressUpdater: &globalContext.clients,
InitialAddresses: initialAddresses,
CatchPanics: true,
UseRDNS: clientSrcConf.RDNS,
@@ -350,16 +363,18 @@ func newDNSCryptConfig(
}, nil
}
// dnsEncryption contains different types of TLS encryption addresses.
type dnsEncryption struct {
https string
tls string
quic string
}
func getDNSEncryption() (de dnsEncryption) {
// getDNSEncryption returns the TLS encryption addresses that AdGuard Home
// listens on. tlsMgr must not be nil.
func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
tlsMgr.WriteDiskConfig(&tlsConf)
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
return dnsEncryption{}
@@ -402,7 +417,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
// pref is a prefix for logging messages around the scope.
const pref = "applying filters"
Context.filters.ApplyBlockedServices(setts)
globalContext.filters.ApplyBlockedServices(setts)
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
@@ -412,9 +427,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
setts.ClientIP = clientIP
c, ok := Context.clients.storage.Find(clientID)
c, ok := globalContext.clients.storage.Find(clientID)
if !ok {
c, ok = Context.clients.storage.Find(clientIP.String())
c, ok = globalContext.clients.storage.Find(clientIP.String())
if !ok {
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
@@ -429,7 +444,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
setts.ServicesRules = nil
svcs := c.BlockedServices.IDs
if !c.BlockedServices.Schedule.Contains(time.Now()) {
Context.filters.ApplyBlockedServicesList(setts, svcs)
globalContext.filters.ApplyBlockedServicesList(setts, svcs)
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
}
}
@@ -455,24 +470,24 @@ func startDNSServer() error {
return fmt.Errorf("unable to start forwarding DNS server: Already running")
}
Context.filters.EnableFilters(false)
globalContext.filters.EnableFilters(false)
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
err := Context.clients.Start(ctx)
err := globalContext.clients.Start(ctx)
if err != nil {
return fmt.Errorf("starting clients container: %w", err)
}
err = Context.dnsServer.Start()
err = globalContext.dnsServer.Start()
if err != nil {
return fmt.Errorf("starting dns server: %w", err)
}
Context.filters.Start()
Context.stats.Start()
globalContext.filters.Start()
globalContext.stats.Start()
err = Context.queryLog.Start(ctx)
err = globalContext.queryLog.Start(ctx)
if err != nil {
return fmt.Errorf("starting query log: %w", err)
}
@@ -480,16 +495,24 @@ func startDNSServer() error {
return nil
}
func reconfigureDNSServer() (err error) {
// reconfigureDNSServer updates the DNS server configuration using the provided
// TLS settings. tlsMgr must not be nil.
func reconfigureDNSServer(tlsMgr *tlsManager) (err error) {
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
tlsMgr.WriteDiskConfig(tlsConf)
newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister)
newConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
tlsConf,
httpRegister,
globalContext.clients.storage,
)
if err != nil {
return fmt.Errorf("generating forwarding dns server config: %w", err)
}
err = Context.dnsServer.Reconfigure(newConf)
err = globalContext.dnsServer.Reconfigure(newConf)
if err != nil {
return fmt.Errorf("starting forwarding dns server: %w", err)
}
@@ -502,12 +525,12 @@ func stopDNSServer() (err error) {
return nil
}
err = Context.dnsServer.Stop()
err = globalContext.dnsServer.Stop()
if err != nil {
return fmt.Errorf("stopping forwarding dns server: %w", err)
}
err = Context.clients.close(context.TODO())
err = globalContext.clients.close(context.TODO())
if err != nil {
return fmt.Errorf("closing clients container: %w", err)
}
@@ -519,25 +542,25 @@ func stopDNSServer() (err error) {
func closeDNSServer() {
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
if Context.dnsServer != nil {
Context.dnsServer.Close()
Context.dnsServer = nil
if globalContext.dnsServer != nil {
globalContext.dnsServer.Close()
globalContext.dnsServer = nil
}
if Context.filters != nil {
Context.filters.Close()
if globalContext.filters != nil {
globalContext.filters.Close()
}
if Context.stats != nil {
err := Context.stats.Close()
if globalContext.stats != nil {
err := globalContext.stats.Close()
if err != nil {
log.Error("closing stats: %s", err)
}
}
if Context.queryLog != nil {
if globalContext.queryLog != nil {
// TODO(s.chzhen): Pass context.
err := Context.queryLog.Shutdown(context.TODO())
err := globalContext.queryLog.Shutdown(context.TODO())
if err != nil {
log.Error("closing query log: %s", err)
}

View File

@@ -37,14 +37,14 @@ func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage)
func TestApplyAdditionalFiltering(t *testing.T) {
var err error
Context.filters, err = filtering.New(&filtering.Config{
globalContext.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
}, nil)
require.NoError(t, err)
Context.clients.storage = newStorage(t, []*client.Persistent{{
globalContext.clients.storage = newStorage(t, []*client.Persistent{{
Name: "default",
ClientIDs: []string{"default"},
UseOwnSettings: false,
@@ -124,7 +124,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
err error
)
Context.filters, err = filtering.New(&filtering.Config{
globalContext.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: globalBlockedServices,
@@ -132,7 +132,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
}, nil)
require.NoError(t, err)
Context.clients.storage = newStorage(t, []*client.Persistent{{
globalContext.clients.storage = newStorage(t, []*client.Persistent{{
Name: "default",
ClientIDs: []string{"default"},
UseOwnBlockedServices: false,

View File

@@ -57,7 +57,12 @@ type homeContext struct {
auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module
web *webAPI // Web (HTTP, HTTPS) module
tls *tlsManager // TLS module
// tls contains the current configuration and state of TLS encryption.
//
// TODO(s.chzhen): Remove once it is no longer called from different
// modules. See [onConfigModified].
tls *tlsManager
// etcHosts contains IP-hostname mappings taken from the OS-specific hosts
// configuration files, for example /etc/hosts.
@@ -91,10 +96,10 @@ func (c *homeContext) getDataDir() string {
return filepath.Join(c.workDir, dataDir)
}
// Context - a global context object
// globalContext is a global context object.
//
// TODO(a.garipov): Refactor.
var Context homeContext
var globalContext homeContext
// Main is the entry point
func Main(clientBuildFS fs.FS) {
@@ -113,40 +118,32 @@ func Main(clientBuildFS fs.FS) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
ctx := context.Background()
for {
sig := <-signals
log.Info("Received signal %q", sig)
switch sig {
case syscall.SIGHUP:
Context.clients.storage.ReloadARP(ctx)
Context.tls.reload()
default:
cleanup(ctx)
cleanupAlways()
close(done)
}
}
}()
ctx := context.Background()
sigHdlr := newSignalHandler(signals, func(ctx context.Context) {
cleanup(ctx)
cleanupAlways()
close(done)
})
go sigHdlr.handle(ctx)
if opts.serviceControlAction != "" {
handleServiceControlAction(opts, clientBuildFS, signals, done)
handleServiceControlAction(opts, clientBuildFS, signals, done, sigHdlr)
return
}
// run the protection
run(opts, clientBuildFS, done)
run(opts, clientBuildFS, done, sigHdlr)
}
// setupContext initializes [Context] fields. It also reads and upgrades
// setupContext initializes [globalContext] fields. It also reads and upgrades
// config file if necessary.
func setupContext(opts options) (err error) {
Context.firstRun = detectFirstRun()
globalContext.firstRun = detectFirstRun()
Context.tlsRoots = aghtls.SystemRootCAs()
Context.mux = http.NewServeMux()
globalContext.tlsRoots = aghtls.SystemRootCAs()
globalContext.mux = http.NewServeMux()
if !opts.noEtcHosts {
err = setupHostsContainer()
@@ -156,7 +153,7 @@ func setupContext(opts options) (err error) {
}
}
if Context.firstRun {
if globalContext.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkNetworkPermissions()
@@ -247,7 +244,7 @@ func setupHostsContainer() (err error) {
return fmt.Errorf("getting default system hosts paths: %w", err)
}
Context.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
globalContext.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
if err != nil {
closeErr := hostsWatcher.Close()
if errors.Is(err, aghnet.ErrNoHostsPaths) {
@@ -271,14 +268,18 @@ func setupOpts(opts options) (err error) {
}
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
Context.pidFileName = opts.pidFile
globalContext.pidFileName = opts.pidFile
}
return nil
}
// initContextClients initializes Context clients and related fields.
func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
func initContextClients(
ctx context.Context,
logger *slog.Logger,
sigHdlr *signalHandler,
) (err error) {
err = setupDNSFilteringConf(ctx, logger, config.Filtering)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
@@ -286,13 +287,13 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
}
//lint:ignore SA1019 Migration is not over.
config.DHCP.WorkDir = Context.workDir
config.DHCP.DataDir = Context.getDataDir()
config.DHCP.WorkDir = globalContext.workDir
config.DHCP.DataDir = globalContext.getDataDir()
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
Context.dhcpServer, err = dhcpd.Create(config.DHCP)
if Context.dhcpServer == nil || err != nil {
globalContext.dhcpServer, err = dhcpd.Create(config.DHCP)
if globalContext.dhcpServer == nil || err != nil {
// TODO(a.garipov): There are a lot of places in the code right
// now which assume that the DHCP server can be nil despite this
// condition. Inspect them and perhaps rewrite them to use
@@ -305,14 +306,15 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb"))
}
return Context.clients.Init(
return globalContext.clients.Init(
ctx,
logger,
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
globalContext.dhcpServer,
globalContext.etcHosts,
arpDB,
config.Filtering,
sigHdlr,
)
}
@@ -374,15 +376,15 @@ func setupDNSFilteringConf(
pcTXTSuffix = `pc.dns.adguard.com.`
)
conf.EtcHosts = Context.etcHosts
conf.EtcHosts = globalContext.etcHosts
// TODO(s.chzhen): Use empty interface.
if Context.etcHosts == nil || !config.DNS.HostsFileEnabled {
if globalContext.etcHosts == nil || !config.DNS.HostsFileEnabled {
conf.EtcHosts = nil
}
conf.ConfigModified = onConfigModified
conf.HTTPRegister = httpRegister
conf.DataDir = Context.getDataDir()
conf.DataDir = globalContext.getDataDir()
conf.Filters = slices.Clone(config.Filters)
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
conf.UserRules = slices.Clone(config.UserRules)
@@ -522,13 +524,15 @@ func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customU
}
}
// initWeb initializes the web module. upd and baseLogger must not be nil.
// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be
// nil.
func initWeb(
ctx context.Context,
opts options,
clientBuildFS fs.FS,
upd *updater.Updater,
baseLogger *slog.Logger,
tlsMgr *tlsManager,
customURL bool,
) (web *webAPI, err error) {
logger := baseLogger.With(slogutil.KeyPrefix, "webapi")
@@ -551,6 +555,7 @@ func initWeb(
updater: upd,
logger: logger,
baseLogger: baseLogger,
tlsManager: tlsMgr,
clientFS: clientFS,
@@ -560,7 +565,7 @@ func initWeb(
ReadHeaderTimeout: readHdrTimeout,
WriteTimeout: writeTimeout,
firstRun: Context.firstRun,
firstRun: globalContext.firstRun,
disableUpdate: disableUpdate,
runningAsService: opts.runningAsService,
serveHTTP3: config.DNS.ServeHTTP3,
@@ -583,7 +588,7 @@ func fatalOnError(err error) {
// run configures and starts AdGuard Home.
//
// TODO(e.burkov): Make opts a pointer.
func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalHandler) {
// Configure working dir.
err := initWorkingDir(opts)
fatalOnError(err)
@@ -599,10 +604,11 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
// TODO(a.garipov): Use slog everywhere.
slogLogger := newSlogLogger(ls)
sigHdlr.swapLogger(slogLogger)
// Print the first message after logger is configured.
log.Info(version.Full())
log.Debug("current working directory is %s", Context.workDir)
log.Info("%s", version.Full())
log.Debug("current working directory is %s", globalContext.workDir)
if opts.runningAsService {
log.Info("AdGuard Home is running as a service")
}
@@ -621,7 +627,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
// TODO(s.chzhen): Use it for the entire initialization process.
ctx := context.Background()
err = initContextClients(ctx, slogLogger)
err = initContextClients(ctx, slogLogger, sigHdlr)
fatalOnError(err)
err = setupOpts(opts)
@@ -632,15 +638,15 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
confPath := configFilePath()
upd, customURL := newUpdater(ctx, slogLogger, Context.workDir, confPath, execPath, config)
upd, customURL := newUpdater(ctx, slogLogger, globalContext.workDir, confPath, execPath, config)
// TODO(e.burkov): This could be made earlier, probably as the option's
// effect.
cmdlineUpdate(ctx, slogLogger, opts, upd)
if !Context.firstRun {
if !globalContext.firstRun {
// Save the updated config.
err = config.write()
err = config.write(nil)
fatalOnError(err)
if config.HTTPConfig.Pprof.Enabled {
@@ -648,33 +654,36 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
}
}
dataDir := Context.getDataDir()
dataDir := globalContext.getDataDir()
err = os.MkdirAll(dataDir, aghos.DefaultPermDir)
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
GLMode = opts.glinetMode
// Init auth module.
Context.auth, err = initUsers()
globalContext.auth, err = initUsers()
fatalOnError(err)
Context.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS)
tlsMgr, err := newTLSManager(config.TLS, config.DNS.ServePlainDNS)
if err != nil {
log.Error("initializing tls: %s", err)
onConfigModified()
}
Context.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, customURL)
globalContext.tls = tlsMgr
sigHdlr.addTLSManager(tlsMgr)
globalContext.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
fatalOnError(err)
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
fatalOnError(err)
if !Context.firstRun {
err = initDNS(slogLogger, statsDir, querylogDir)
if !globalContext.firstRun {
err = initDNS(slogLogger, tlsMgr, statsDir, querylogDir)
fatalOnError(err)
Context.tls.start()
tlsMgr.start(ctx)
go func() {
startErr := startDNSServer()
@@ -684,8 +693,8 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
}
}()
if Context.dhcpServer != nil {
err = Context.dhcpServer.Start()
if globalContext.dhcpServer != nil {
err = globalContext.dhcpServer.Start()
if err != nil {
log.Error("starting dhcp server: %s", err)
}
@@ -693,10 +702,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
}
if !opts.noPermCheck {
checkPermissions(ctx, slogLogger, Context.workDir, confPath, dataDir, statsDir, querylogDir)
checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir)
}
Context.web.start(ctx)
globalContext.web.start(ctx)
// Wait for other goroutines to complete their job.
<-done
@@ -775,7 +784,7 @@ func checkPermissions(
// initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db")
var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
@@ -807,31 +816,6 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
return aghnet.NewIPMut(anonFunc)
}
// startMods initializes and starts the DNS server after installation.
// baseLogger must not be nil.
func startMods(baseLogger *slog.Logger) (err error) {
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
if err != nil {
return err
}
err = initDNS(baseLogger, statsDir, querylogDir)
if err != nil {
return err
}
Context.tls.start()
err = startDNSServer()
if err != nil {
closeDNSServer()
return err
}
return nil
}
// checkNetworkPermissions checks if the current user permissions are enough to
// use the required networking functionality.
func checkNetworkPermissions() {
@@ -883,14 +867,14 @@ func writePIDFile(fn string) bool {
func initConfigFilename(opts options) {
confPath := opts.confFilename
if confPath == "" {
Context.confFilePath = filepath.Join(Context.workDir, "AdGuardHome.yaml")
globalContext.confFilePath = filepath.Join(globalContext.workDir, "AdGuardHome.yaml")
return
}
log.Debug("config path overridden to %q from cmdline", confPath)
Context.confFilePath = confPath
globalContext.confFilePath = confPath
}
// initWorkingDir initializes the workDir. If no command-line arguments are
@@ -904,18 +888,18 @@ func initWorkingDir(opts options) (err error) {
if opts.workDir != "" {
// If there is a custom config file, use it's directory as our working dir
Context.workDir = opts.workDir
globalContext.workDir = opts.workDir
} else {
Context.workDir = filepath.Dir(execPath)
globalContext.workDir = filepath.Dir(execPath)
}
workDir, err := filepath.EvalSymlinks(Context.workDir)
workDir, err := filepath.EvalSymlinks(globalContext.workDir)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
Context.workDir = workDir
globalContext.workDir = workDir
return nil
}
@@ -924,13 +908,13 @@ func initWorkingDir(opts options) (err error) {
func cleanup(ctx context.Context) {
log.Info("stopping AdGuard Home")
if Context.web != nil {
Context.web.close(ctx)
Context.web = nil
if globalContext.web != nil {
globalContext.web.close(ctx)
globalContext.web = nil
}
if Context.auth != nil {
Context.auth.Close()
Context.auth = nil
if globalContext.auth != nil {
globalContext.auth.Close()
globalContext.auth = nil
}
err := stopDNSServer()
@@ -938,28 +922,24 @@ func cleanup(ctx context.Context) {
log.Error("stopping dns server: %s", err)
}
if Context.dhcpServer != nil {
err = Context.dhcpServer.Stop()
if globalContext.dhcpServer != nil {
err = globalContext.dhcpServer.Stop()
if err != nil {
log.Error("stopping dhcp server: %s", err)
}
}
if Context.etcHosts != nil {
if err = Context.etcHosts.Close(); err != nil {
if globalContext.etcHosts != nil {
if err = globalContext.etcHosts.Close(); err != nil {
log.Error("closing hosts container: %s", err)
}
}
if Context.tls != nil {
Context.tls = nil
}
}
// This function is called before application exits
func cleanupAlways() {
if len(Context.pidFileName) != 0 {
_ = os.Remove(Context.pidFileName)
if len(globalContext.pidFileName) != 0 {
_ = os.Remove(globalContext.pidFileName)
}
log.Info("stopped")
@@ -975,7 +955,7 @@ func exitWithError() {
func loadCmdLineOpts() (opts options) {
opts, eff, err := parseCmdOpts(os.Args[0], os.Args[1:])
if err != nil {
log.Error(err.Error())
log.Error("%s", err)
printHelp(os.Args[0])
exitWithError()
@@ -984,7 +964,7 @@ func loadCmdLineOpts() (opts options) {
if eff != nil {
err = eff()
if err != nil {
log.Error(err.Error())
log.Error("%s", err)
exitWithError()
}
@@ -1005,10 +985,12 @@ func printWebAddrs(proto, addr string, port uint16) {
// printHTTPAddresses prints the IP addresses which user can use to access the
// admin interface. proto is either schemeHTTP or schemeHTTPS.
func printHTTPAddresses(proto string) {
//
// TODO(s.chzhen): Implement separate functions for HTTP and HTTPS.
func printHTTPAddresses(proto string, tlsMgr *tlsManager) {
tlsConf := tlsConfigSettings{}
if Context.tls != nil {
Context.tls.WriteDiskConfig(&tlsConf)
if tlsMgr != nil {
tlsMgr.WriteDiskConfig(&tlsConf)
}
port := config.HTTPConfig.Address.Port()
@@ -1016,7 +998,6 @@ func printHTTPAddresses(proto string) {
port = tlsConf.PortHTTPS
}
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
if proto == urlutil.SchemeHTTPS && tlsConf.ServerName != "" {
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS)
@@ -1050,9 +1031,9 @@ func printHTTPAddresses(proto string) {
// detectFirstRun returns true if this is the first run of AdGuard Home.
func detectFirstRun() (ok bool) {
confPath := Context.confFilePath
confPath := globalContext.confFilePath
if !filepath.IsAbs(confPath) {
confPath = filepath.Join(Context.workDir, Context.confFilePath)
confPath = filepath.Join(globalContext.workDir, globalContext.confFilePath)
}
_, err := os.Stat(confPath)
@@ -1105,7 +1086,7 @@ func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updat
os.Exit(osutil.ExitCodeSuccess)
}
err = upd.Update(Context.firstRun)
err = upd.Update(globalContext.firstRun)
fatalOnError(err)
err = restartService()

View File

@@ -17,7 +17,7 @@ func httpClient() (c *http.Client) {
// Do not use Context.dnsServer.DialContext directly in the struct literal
// below, since Context.dnsServer may be nil when this function is called.
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
return Context.dnsServer.DialContext(ctx, network, addr)
return globalContext.dnsServer.DialContext(ctx, network, addr)
}
return &http.Client{
@@ -27,8 +27,8 @@ func httpClient() (c *http.Client) {
DialContext: dialContext,
Proxy: httpProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
RootCAs: globalContext.tlsRoots,
CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
},

View File

@@ -66,7 +66,7 @@ func configureLogger(ls *logSettings) (err error) {
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
logFilePath = filepath.Join(globalContext.workDir, logFilePath)
}
log.SetOutput(&lumberjack.Logger{

View File

@@ -19,10 +19,8 @@ func setupDNSIPs(t testing.TB) {
t.Helper()
prevConfig := config
prevTLS := Context.tls
t.Cleanup(func() {
config = prevConfig
Context.tls = prevTLS
})
config = &configuration{
@@ -31,8 +29,6 @@ func setupDNSIPs(t testing.TB) {
Port: defaultPortDNS,
},
}
Context.tls = &tlsManager{}
}
func TestHandleMobileConfigDoH(t *testing.T) {
@@ -62,11 +58,6 @@ func TestHandleMobileConfigDoH(t *testing.T) {
})
t.Run("error_no_host", func(t *testing.T) {
oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
require.NoError(t, err)
@@ -134,11 +125,6 @@ func TestHandleMobileConfigDoT(t *testing.T) {
})
t.Run("error_no_host", func(t *testing.T) {
oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
require.NoError(t, err)

View File

@@ -47,7 +47,7 @@ type profileJSON struct {
// handleGetProfile is the handler for GET /control/profile endpoint.
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
u := Context.auth.getCurrentUser(r)
u := globalContext.auth.getCurrentUser(r)
var resp profileJSON
func() {

View File

@@ -36,6 +36,7 @@ type program struct {
signals chan os.Signal
done chan struct{}
opts options
sigHdlr *signalHandler
}
// type check
@@ -47,7 +48,7 @@ func (p *program) Start(_ service.Service) (err error) {
args := p.opts
args.runningAsService = true
go run(args, p.clientBuildFS, p.done)
go run(args, p.clientBuildFS, p.done, p.sigHdlr)
return nil
}
@@ -204,13 +205,14 @@ func handleServiceControlAction(
clientBuildFS fs.FS,
signals chan os.Signal,
done chan struct{},
sigHdlr *signalHandler,
) {
// Call chooseSystem explicitly to introduce OpenBSD support for service
// package. It's a noop for other GOOS values.
chooseSystem()
action := opts.serviceControlAction
log.Info(version.Full())
log.Info("%s", version.Full())
log.Info("service: control action: %s", action)
if action == "reload" {
@@ -244,6 +246,7 @@ func handleServiceControlAction(
signals: signals,
done: done,
opts: runOpts,
sigHdlr: sigHdlr,
}, svcConfig)
if err != nil {
log.Fatalf("service: initializing service: %s", err)
@@ -336,7 +339,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup.
AdGuard Home is now available at the following addresses:`)
printHTTPAddresses(urlutil.SchemeHTTP)
printHTTPAddresses(urlutil.SchemeHTTP, nil)
}
}

View File

@@ -392,7 +392,7 @@ type sysLogger struct{}
// Error implements service.Logger interface for sysLogger.
func (sysLogger) Error(v ...any) error {
log.Error(fmt.Sprint(v...))
log.Error("%s", fmt.Sprint(v...))
return nil
}
@@ -406,7 +406,7 @@ func (sysLogger) Warning(v ...any) error {
// Info implements service.Logger interface for sysLogger.
func (sysLogger) Info(v ...any) error {
log.Info(fmt.Sprint(v...))
log.Info("%s", fmt.Sprint(v...))
return nil
}

121
internal/home/signal.go Normal file
View File

@@ -0,0 +1,121 @@
package home
import (
"context"
"log/slog"
"os"
"sync"
"sync/atomic"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
)
// signalHandler processes incoming signals. It reloads configurations of
// stored entities on SIGHUP and performs cleanup on all other signals.
type signalHandler struct {
// logger is used to log the operation of the signal handler. Initially,
// [slog.Default] is used, but it should be swapped later using
// [signalHandler.swapLogger].
logger *atomic.Pointer[slog.Logger]
// mu protects clientStorage and tlsManager.
mu *sync.Mutex
// clientStorage is used to reload information about runtime clients with an
// ARP source.
clientStorage *client.Storage
// tlsManager is used to reload the TLS configuration.
tlsManager *tlsManager
// signals receives incoming signals.
signals <-chan os.Signal
// cleanup is called to perform cleanup on all incoming signals, except
// SIGHUP.
cleanup func(ctx context.Context)
}
// newSignalHandler returns a new properly initialized *signalHandler.
func newSignalHandler(
signals <-chan os.Signal,
cleanup func(ctx context.Context),
) (h *signalHandler) {
h = &signalHandler{
logger: &atomic.Pointer[slog.Logger]{},
mu: &sync.Mutex{},
signals: signals,
cleanup: cleanup,
}
h.logger.Store(slog.Default())
return h
}
// swapLogger replaces the stored logger with the given logger.
func (h *signalHandler) swapLogger(logger *slog.Logger) {
h.logger.Swap(logger)
}
// addClientStorage stores the client storage.
func (h *signalHandler) addClientStorage(s *client.Storage) {
h.mu.Lock()
defer h.mu.Unlock()
h.clientStorage = s
}
// addTLSManager stores the TLS manager.
func (h *signalHandler) addTLSManager(m *tlsManager) {
h.mu.Lock()
defer h.mu.Unlock()
h.tlsManager = m
}
// handle processes incoming signals. It blocks until a signal is received. It
// reloads configurations of stored entities on SIGHUP, or performs cleanup on
// all other signals. It is intended to be used as a goroutine.
func (h *signalHandler) handle(ctx context.Context) {
// NOTE: Avoid using [slogutil.RecoverAndExit] to prevent immediate
// evaluation of the logger.
defer func() {
v := recover()
if v == nil {
return
}
slogutil.PrintRecovered(ctx, h.logger.Load(), v)
os.Exit(osutil.ExitCodeFailure)
}()
for {
sig := <-h.signals
h.logger.Load().InfoContext(ctx, "received signal", "signal", sig)
switch sig {
case syscall.SIGHUP:
h.reloadConfig(ctx)
default:
h.cleanup(ctx)
}
}
}
// reloadConfig refreshes configurations of stored entities.
func (h *signalHandler) reloadConfig(ctx context.Context) {
h.mu.Lock()
defer h.mu.Unlock()
if h.clientStorage != nil {
h.clientStorage.ReloadARP(ctx)
}
if h.tlsManager != nil {
h.tlsManager.reload()
}
}

View File

@@ -102,7 +102,9 @@ func (m *tlsManager) setCertFileTime() {
}
// start updates the configuration of t and starts it.
func (m *tlsManager) start() {
//
// TODO(s.chzhen): Use context.
func (m *tlsManager) start(_ context.Context) {
m.registerWebHandlers()
m.confLock.Lock()
@@ -112,7 +114,7 @@ func (m *tlsManager) start() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.tlsConfigChanged(context.Background(), tlsConf)
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
}
// reload updates the configuration and restarts t.
@@ -151,7 +153,7 @@ func (m *tlsManager) reload() {
m.certLastMod = fi.ModTime().UTC()
_ = reconfigureDNSServer()
_ = reconfigureDNSServer(m)
m.confLock.Lock()
tlsConf = m.conf
@@ -160,7 +162,7 @@ func (m *tlsManager) reload() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.tlsConfigChanged(context.Background(), tlsConf)
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
}
// loadTLSConf loads and validates the TLS configuration. The returned error is
@@ -440,7 +442,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
onConfigModified()
err = reconfigureDNSServer()
err = reconfigureDNSServer(m)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@@ -463,7 +465,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
// same reason.
if restartHTTPS {
go func() {
Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
globalContext.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
}()
}
}
@@ -539,7 +541,7 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) {
opts := x509.VerifyOptions{
DNSName: srvName,
Roots: Context.tlsRoots,
Roots: globalContext.tlsRoots,
Intermediates: pool,
}
_, err = main.Verify(opts)

View File

@@ -49,6 +49,10 @@ type webConfig struct {
// nil.
baseLogger *slog.Logger
// tlsManager contains the current configuration and state of TLS
// encryption. It must not be nil.
tlsManager *tlsManager
clientFS fs.FS
// BindAddr is the binding address with port for plain HTTP web interface.
@@ -108,6 +112,10 @@ type webAPI struct {
// nil.
baseLogger *slog.Logger
// tlsManager contains the current configuration and state of TLS
// encryption.
tlsManager *tlsManager
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
// [Web.http3Server] must also not be nil.
httpsServer httpsServer
@@ -124,12 +132,13 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
conf: conf,
logger: conf.logger,
baseLogger: conf.baseLogger,
tlsManager: conf.tlsManager,
}
clientFS := http.FileServer(http.FS(conf.clientFS))
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
Context.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
globalContext.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
// add handlers for /install paths, we only need them when we're not configured yet
if conf.firstRun {
@@ -138,7 +147,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
"This is the first launch of AdGuard Home, redirecting everything to /install.html",
)
Context.mux.Handle("/install.html", preInstallHandler(clientFS))
globalContext.mux.Handle("/install.html", preInstallHandler(clientFS))
w.registerInstallHandlers()
} else {
registerControlHandlers(w)
@@ -154,7 +163,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
//
// TODO(a.garipov): Adapt for HTTP/3.
func webCheckPortAvailable(port uint16) (ok bool) {
if Context.web.httpsServer.server != nil {
if globalContext.web.httpsServer.server != nil {
return true
}
@@ -220,14 +229,18 @@ func (web *webAPI) start(ctx context.Context) {
// this loop is used as an ability to change listening host and/or port
for !web.httpsServer.inShutdown {
printHTTPAddresses(urlutil.SchemeHTTP)
printHTTPAddresses(urlutil.SchemeHTTP, web.tlsManager)
errs := make(chan error, 2)
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
hdlr := h2c.NewHandler(withMiddlewares(globalContext.mux, limitRequestBody), &http2.Server{})
logger := web.baseLogger.With(loggerKeyServer, "plain")
// TODO(a.garipov): Remove other logs like this in other code.
logMw := httputil.NewLogMiddleware(logger, slog.LevelDebug)
hdlr = logMw.Wrap(hdlr)
// Create a new instance, because the Web is not usable after Shutdown.
web.httpServer = &http.Server{
Addr: web.conf.BindAddr.String(),
@@ -238,7 +251,9 @@ func (web *webAPI) start(ctx context.Context) {
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
}
go func() {
defer slogutil.RecoverAndLog(ctx, web.logger)
defer slogutil.RecoverAndLog(ctx, logger)
logger.InfoContext(ctx, "starting plain server", "addr", web.httpServer.Addr)
errs <- web.httpServer.ListenAndServe()
}()
@@ -305,13 +320,17 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String()
logger := web.baseLogger.With(loggerKeyServer, "https")
// TODO(a.garipov): Remove other logs like this in other code.
logMw := httputil.NewLogMiddleware(logger, slog.LevelDebug)
hdlr := logMw.Wrap(withMiddlewares(globalContext.mux, limitRequestBody))
web.httpsServer.server = &http.Server{
Addr: addr,
Handler: withMiddlewares(Context.mux, limitRequestBody),
Handler: hdlr,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
RootCAs: globalContext.tlsRoots,
CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
ReadTimeout: web.conf.ReadTimeout,
@@ -320,13 +339,13 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
}
printHTTPAddresses(urlutil.SchemeHTTPS)
printHTTPAddresses(urlutil.SchemeHTTPS, web.tlsManager)
if web.conf.serveHTTP3 {
go web.mustStartHTTP3(ctx, addr)
}
web.logger.DebugContext(ctx, "starting https server")
logger.InfoContext(ctx, "starting https server")
err := web.httpsServer.server.ListenAndServeTLS("", "")
if !errors.Is(err, http.ErrServerClosed) {
cleanupAlways()
@@ -344,11 +363,11 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
RootCAs: globalContext.tlsRoots,
CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
Handler: withMiddlewares(Context.mux, limitRequestBody),
Handler: withMiddlewares(globalContext.mux, limitRequestBody),
}
web.logger.DebugContext(ctx, "starting http/3 server")

View File

@@ -134,9 +134,6 @@ func TestManager_Add(t *testing.T) {
assert.NoError(t, err)
}
// ipsetPropsSink is the typed sink for benchmark results.
var ipsetPropsSink []props
func BenchmarkManager_LookupHost(b *testing.B) {
propsLong := []props{{
name: "example.com",
@@ -155,9 +152,13 @@ func BenchmarkManager_LookupHost(b *testing.B) {
},
}
var ipsetPropsSink []props
b.Run("long", func(b *testing.B) {
const name = "a.very.long.domain.name.inside.the.domain.example.com"
for range b.N {
b.ReportAllocs()
for b.Loop() {
ipsetPropsSink = m.lookupHost(name)
}
@@ -166,10 +167,21 @@ func BenchmarkManager_LookupHost(b *testing.B) {
b.Run("short", func(b *testing.B) {
const name = "example.net"
for range b.N {
b.ReportAllocs()
for b.Loop() {
ipsetPropsSink = m.lookupHost(name)
}
require.Equal(b, propsShort, ipsetPropsSink)
})
// Most recent results:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/ipset
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
// BenchmarkManager_LookupHost/long-8 6562424 174.8 ns/op 0 B/op 0 allocs/op
// BenchmarkManager_LookupHost/short-8 100000000 10.72 ns/op 0 B/op 0 allocs/op
}

View File

@@ -283,6 +283,8 @@ func anonymizeIPSlow(ip net.IP) {
}
}
// TODO(e.burkov): Investigate the results, it seems that the slow version
// isn't that slow.
func BenchmarkAnonymizeIP(b *testing.B) {
benchCases := []struct {
name string
@@ -320,7 +322,7 @@ func BenchmarkAnonymizeIP(b *testing.B) {
b.Run(bc.name, func(b *testing.B) {
b.ReportAllocs()
for range b.N {
for b.Loop() {
AnonymizeIP(bc.ip)
}
@@ -330,11 +332,26 @@ func BenchmarkAnonymizeIP(b *testing.B) {
b.Run(bc.name+"_slow", func(b *testing.B) {
b.ReportAllocs()
for range b.N {
for b.Loop() {
anonymizeIPSlow(bc.ip)
}
assert.Equal(b, bc.want, bc.ip)
})
}
// Most recent results:
//
// goos: darwin
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/querylog
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
// BenchmarkAnonymizeIP/v4-12 426499675 2.687 ns/op 0 B/op 0 allocs/op
// BenchmarkAnonymizeIP/v4_slow-12 510082938 2.412 ns/op 0 B/op 0 allocs/op
// BenchmarkAnonymizeIP/v4_mapped-12 149121745 7.992 ns/op 0 B/op 0 allocs/op
// BenchmarkAnonymizeIP/v4_mapped_slow-12 178441804 6.698 ns/op 0 B/op 0 allocs/op
// BenchmarkAnonymizeIP/v6-12 346746447 3.436 ns/op 0 B/op 0 allocs/op
// BenchmarkAnonymizeIP/v6_slow-12 419062732 2.966 ns/op 0 B/op 0 allocs/op
// BenchmarkAnonymizeIP/invalid-12 316385232 3.941 ns/op 0 B/op 0 allocs/op
// BenchmarkAnonymizeIP/invalid_slow-12 456531592 2.760 ns/op 0 B/op 0 allocs/op
}