all: sync with master, upd chlog
This commit is contained in:
24
internal/aghnet/upstream.go
Normal file
24
internal/aghnet/upstream.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package aghnet
|
||||
|
||||
import "github.com/AdguardTeam/dnsproxy/upstream"
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
if !http3 {
|
||||
return upstream.DefaultHTTPVersions
|
||||
}
|
||||
|
||||
return []upstream.HTTPVersion{
|
||||
upstream.HTTPVersion3,
|
||||
upstream.HTTPVersion2,
|
||||
upstream.HTTPVersion11,
|
||||
}
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
26
internal/aghnet/upstream_test.go
Normal file
26
internal/aghnet/upstream_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package aghnet_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsCommentOrEmpty(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
str string
|
||||
}{{
|
||||
want: assert.True,
|
||||
str: "",
|
||||
}, {
|
||||
want: assert.True,
|
||||
str: "# comment",
|
||||
}, {
|
||||
want: assert.False,
|
||||
str: "1.2.3.4",
|
||||
}} {
|
||||
tc.want(t, aghnet.IsCommentOrEmpty(tc.str))
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -121,26 +120,6 @@ func (p *AddressUpdater) UpdateAddress(
|
||||
p.OnUpdateAddress(ctx, ip, host, info)
|
||||
}
|
||||
|
||||
// Package dnsforward
|
||||
|
||||
// ClientsContainer is a fake [dnsforward.ClientsContainer] implementation for
|
||||
// tests.
|
||||
type ClientsContainer struct {
|
||||
OnUpstreamConfigByID func(
|
||||
id string,
|
||||
boot upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error)
|
||||
}
|
||||
|
||||
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface
|
||||
// for *ClientsContainer.
|
||||
func (c *ClientsContainer) UpstreamConfigByID(
|
||||
id string,
|
||||
boot upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
||||
return c.OnUpstreamConfigByID(id, boot)
|
||||
}
|
||||
|
||||
// Package filtering
|
||||
|
||||
// Resolver is a fake [filtering.Resolver] implementation for tests.
|
||||
|
||||
@@ -3,7 +3,6 @@ package aghtest_test
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
)
|
||||
|
||||
@@ -12,9 +11,6 @@ import (
|
||||
// type check
|
||||
var _ filtering.Resolver = (*aghtest.Resolver)(nil)
|
||||
|
||||
// type check
|
||||
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)
|
||||
|
||||
// type check
|
||||
//
|
||||
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
|
||||
@@ -35,7 +34,7 @@ type index struct {
|
||||
// nameToUID maps client name to UID.
|
||||
nameToUID map[string]UID
|
||||
|
||||
// clientIDToUID maps client ID to UID.
|
||||
// clientIDToUID maps ClientID to UID.
|
||||
clientIDToUID map[string]UID
|
||||
|
||||
// ipToUID maps IP address to UID.
|
||||
@@ -205,19 +204,19 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// find finds persistent client by string representation of the client ID, IP
|
||||
// find finds persistent client by string representation of the ClientID, IP
|
||||
// address, or MAC.
|
||||
func (ci *index) find(id string) (c *Persistent, ok bool) {
|
||||
uid, found := ci.clientIDToUID[id]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
c, ok = ci.findByClientID(id)
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(id)
|
||||
if err == nil {
|
||||
// MAC addresses can be successfully parsed as IP addresses.
|
||||
c, found = ci.findByIP(ip)
|
||||
if found {
|
||||
c, ok = ci.findByIP(ip)
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
@@ -230,6 +229,16 @@ func (ci *index) find(id string) (c *Persistent, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findByClientID finds persistent client by ClientID.
|
||||
func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) {
|
||||
uid, ok := ci.clientIDToUID[clientID]
|
||||
if ok {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findByName finds persistent client by name.
|
||||
func (ci *index) findByName(name string) (c *Persistent, found bool) {
|
||||
uid, found := ci.nameToUID[name]
|
||||
@@ -343,18 +352,3 @@ func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeUpstreams closes upstream configurations of persistent clients.
|
||||
func (ci *index) closeUpstreams() (err error) {
|
||||
var errs []error
|
||||
ci.rangeByName(func(c *Persistent) (cont bool) {
|
||||
err = c.CloseUpstreams()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
@@ -58,12 +58,6 @@ func (uid *UID) UnmarshalText(data []byte) error {
|
||||
|
||||
// Persistent contains information about persistent clients.
|
||||
type Persistent struct {
|
||||
// UpstreamConfig is the custom upstream configuration for this client. If
|
||||
// it's nil, it has not been initialized yet. If it's non-nil and empty,
|
||||
// there are no valid upstreams. If it's non-nil and non-empty, these
|
||||
// upstream must be used.
|
||||
UpstreamConfig *proxy.CustomUpstreamConfig
|
||||
|
||||
// SafeSearch handles search engine hosts rewrites.
|
||||
SafeSearch filtering.SafeSearch
|
||||
|
||||
@@ -262,7 +256,7 @@ func ValidateClientID(id string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IDs returns a list of client IDs containing at least one element.
|
||||
// IDs returns a list of ClientIDs containing at least one element.
|
||||
func (c *Persistent) IDs() (ids []string) {
|
||||
ids = make([]string, 0, c.IDsLen())
|
||||
|
||||
@@ -281,7 +275,7 @@ func (c *Persistent) IDs() (ids []string) {
|
||||
return append(ids, c.ClientIDs...)
|
||||
}
|
||||
|
||||
// IDsLen returns a length of client ids.
|
||||
// IDsLen returns a length of ClientIDs.
|
||||
func (c *Persistent) IDsLen() (n int) {
|
||||
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
||||
}
|
||||
@@ -312,14 +306,3 @@ func (c *Persistent) ShallowClone() (clone *Persistent) {
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// CloseUpstreams closes the client-specific upstream config of c if any.
|
||||
func (c *Persistent) CloseUpstreams() (err error) {
|
||||
if c.UpstreamConfig != nil {
|
||||
if err = c.UpstreamConfig.Close(); err != nil {
|
||||
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,9 +13,11 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
// allowedTags is the list of available client tags.
|
||||
@@ -88,6 +90,10 @@ type StorageConfig struct {
|
||||
// not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Clock is used by [upstreamManager] to retrieve the current time. It must
|
||||
// not be nil.
|
||||
Clock timeutil.Clock
|
||||
|
||||
// DHCP is used to match IPs against MACs of persistent clients and update
|
||||
// [SourceDHCP] runtime client information. It must not be nil.
|
||||
DHCP DHCP
|
||||
@@ -126,6 +132,9 @@ type Storage struct {
|
||||
// runtimeIndex contains information about runtime clients.
|
||||
runtimeIndex *runtimeIndex
|
||||
|
||||
// upstreamManager stores and updates custom client upstream configurations.
|
||||
upstreamManager *upstreamManager
|
||||
|
||||
// dhcp is used to update [SourceDHCP] runtime client information.
|
||||
dhcp DHCP
|
||||
|
||||
@@ -163,6 +172,7 @@ func NewStorage(ctx context.Context, conf *StorageConfig) (s *Storage, err error
|
||||
mu: &sync.Mutex{},
|
||||
index: newIndex(),
|
||||
runtimeIndex: newRuntimeIndex(),
|
||||
upstreamManager: newUpstreamManager(conf.Logger, conf.Clock),
|
||||
dhcp: conf.DHCP,
|
||||
etcHosts: conf.EtcHosts,
|
||||
arpDB: conf.ARPDB,
|
||||
@@ -200,7 +210,7 @@ func (s *Storage) Start(ctx context.Context) (err error) {
|
||||
func (s *Storage) Shutdown(_ context.Context) (err error) {
|
||||
close(s.done)
|
||||
|
||||
return s.closeUpstreams()
|
||||
return s.upstreamManager.close()
|
||||
}
|
||||
|
||||
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
|
||||
@@ -416,6 +426,7 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) {
|
||||
}
|
||||
|
||||
s.index.add(p)
|
||||
s.upstreamManager.updateCustomUpstreamConfig(p)
|
||||
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
@@ -441,7 +452,7 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Find finds persistent client by string representation of the client ID, IP
|
||||
// Find finds persistent client by string representation of the ClientID, IP
|
||||
// address, or MAC. And returns its shallow copy.
|
||||
//
|
||||
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
|
||||
@@ -514,12 +525,13 @@ func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := p.CloseUpstreams(); err != nil {
|
||||
s.logger.ErrorContext(ctx, "removing client", "name", p.Name, slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
s.index.remove(p)
|
||||
|
||||
err := s.upstreamManager.remove(p.UID)
|
||||
if err != nil {
|
||||
s.logger.DebugContext(ctx, "closing client upstreams", "name", name, slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -556,6 +568,8 @@ func (s *Storage) Update(ctx context.Context, name string, p *Persistent) (err e
|
||||
s.index.remove(stored)
|
||||
s.index.add(p)
|
||||
|
||||
s.upstreamManager.updateCustomUpstreamConfig(p)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -576,14 +590,6 @@ func (s *Storage) Size() (n int) {
|
||||
return s.index.size()
|
||||
}
|
||||
|
||||
// closeUpstreams closes upstream configurations of persistent clients.
|
||||
func (s *Storage) closeUpstreams() (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return s.index.closeUpstreams()
|
||||
}
|
||||
|
||||
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
|
||||
// client exists, returns nil.
|
||||
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
||||
@@ -626,3 +632,42 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||
func (s *Storage) AllowedTags() (tags []string) {
|
||||
return s.allowedTags
|
||||
}
|
||||
|
||||
// CustomUpstreamConfig implements the [dnsforward.ClientsContainer] interface
|
||||
// for *Storage
|
||||
func (s *Storage) CustomUpstreamConfig(
|
||||
id string,
|
||||
addr netip.Addr,
|
||||
) (prxConf *proxy.CustomUpstreamConfig) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
c, ok := s.index.findByClientID(id)
|
||||
if !ok {
|
||||
c, ok = s.index.findByIP(addr)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.upstreamManager.customUpstreamConfig(c.UID)
|
||||
}
|
||||
|
||||
// UpdateCommonUpstreamConfig implements the [dnsforward.ClientsContainer]
|
||||
// interface for *Storage
|
||||
func (s *Storage) UpdateCommonUpstreamConfig(conf *CommonUpstreamConfig) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.upstreamManager.updateCommonUpstreamConfig(conf)
|
||||
}
|
||||
|
||||
// ClearUpstreamCache implements the [dnsforward.ClientsContainer] interface for
|
||||
// *Storage
|
||||
func (s *Storage) ClearUpstreamCache() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.upstreamManager.clearUpstreamCache()
|
||||
}
|
||||
|
||||
@@ -13,27 +13,34 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/faketime"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newTestStorage is a helper function that returns initialized storage.
|
||||
func newTestStorage(tb testing.TB) (s *client.Storage) {
|
||||
func newTestStorage(tb testing.TB, clock timeutil.Clock) (s *client.Storage) {
|
||||
tb.Helper()
|
||||
|
||||
ctx := testutil.ContextWithTimeout(tb, testTimeout)
|
||||
s, err := client.NewStorage(ctx, &client.StorageConfig{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Clock: clock,
|
||||
})
|
||||
require.NoError(tb, err)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsforward.ClientsContainer = (*client.Storage)(nil)
|
||||
|
||||
// testHostsContainer is a mock implementation of the [client.HostsContainer]
|
||||
// interface.
|
||||
type testHostsContainer struct {
|
||||
@@ -691,7 +698,7 @@ func TestStorage_Add(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
s := newTestStorage(t)
|
||||
s := newTestStorage(t, timeutil.SystemClock{})
|
||||
tags := s.AllowedTags()
|
||||
require.NotZero(t, len(tags))
|
||||
require.True(t, slices.IsSorted(tags))
|
||||
@@ -822,7 +829,7 @@ func TestStorage_RemoveByName(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
s := newTestStorage(t)
|
||||
s := newTestStorage(t, timeutil.SystemClock{})
|
||||
err := s.Add(ctx, existingClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -847,7 +854,7 @@ func TestStorage_RemoveByName(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("duplicate_remove", func(t *testing.T) {
|
||||
s = newTestStorage(t)
|
||||
s = newTestStorage(t, timeutil.SystemClock{})
|
||||
err = s.Add(ctx, existingClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1278,3 +1285,99 @@ func TestStorage_RangeByName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_CustomUpstreamConfig(t *testing.T) {
|
||||
const (
|
||||
existingName = "existing_name"
|
||||
existingClientID = "existing_client_id"
|
||||
|
||||
nonExistingClientID = "non_existing_client_id"
|
||||
)
|
||||
|
||||
var (
|
||||
existingClientUID = client.MustNewUID()
|
||||
existingIP = netip.MustParseAddr("192.0.2.1")
|
||||
|
||||
nonExistingIP = netip.MustParseAddr("192.0.2.255")
|
||||
|
||||
testUpstreamTimeout = time.Second
|
||||
)
|
||||
|
||||
existingClient := &client.Persistent{
|
||||
Name: existingName,
|
||||
IPs: []netip.Addr{existingIP},
|
||||
ClientIDs: []string{existingClientID},
|
||||
UID: existingClientUID,
|
||||
Upstreams: []string{"192.0.2.0"},
|
||||
}
|
||||
|
||||
date := time.Now()
|
||||
clock := &faketime.Clock{
|
||||
OnNow: func() (now time.Time) {
|
||||
date = date.Add(time.Second)
|
||||
|
||||
return date
|
||||
},
|
||||
}
|
||||
|
||||
s := newTestStorage(t, clock)
|
||||
s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
|
||||
UpstreamTimeout: testUpstreamTimeout,
|
||||
})
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return s.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
err := s.Add(ctx, existingClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
cliAddr netip.Addr
|
||||
wantNilConf assert.ValueAssertionFunc
|
||||
name string
|
||||
cliID string
|
||||
}{{
|
||||
name: "client_id",
|
||||
cliID: existingClientID,
|
||||
cliAddr: netip.Addr{},
|
||||
wantNilConf: assert.NotNil,
|
||||
}, {
|
||||
name: "client_addr",
|
||||
cliID: "",
|
||||
cliAddr: existingIP,
|
||||
wantNilConf: assert.NotNil,
|
||||
}, {
|
||||
name: "non_existing_client_id",
|
||||
cliID: nonExistingClientID,
|
||||
cliAddr: netip.Addr{},
|
||||
wantNilConf: assert.Nil,
|
||||
}, {
|
||||
name: "non_existing_client_addr",
|
||||
cliID: "",
|
||||
cliAddr: nonExistingIP,
|
||||
wantNilConf: assert.Nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conf := s.CustomUpstreamConfig(tc.cliID, tc.cliAddr)
|
||||
tc.wantNilConf(t, conf)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("update_common_config", func(t *testing.T) {
|
||||
conf := s.CustomUpstreamConfig(existingClientID, existingIP)
|
||||
require.NotNil(t, conf)
|
||||
|
||||
s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
|
||||
UpstreamTimeout: testUpstreamTimeout * 2,
|
||||
})
|
||||
|
||||
updConf := s.CustomUpstreamConfig(existingClientID, existingIP)
|
||||
require.NotNil(t, updConf)
|
||||
|
||||
assert.NotEqual(t, conf, updConf)
|
||||
})
|
||||
}
|
||||
|
||||
224
internal/client/upstreammanager.go
Normal file
224
internal/client/upstreammanager.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
// CommonUpstreamConfig contains common settings for custom client upstream
|
||||
// configurations.
|
||||
type CommonUpstreamConfig struct {
|
||||
Bootstrap upstream.Resolver
|
||||
UpstreamTimeout time.Duration
|
||||
BootstrapPreferIPv6 bool
|
||||
EDNSClientSubnetEnabled bool
|
||||
UseHTTP3Upstreams bool
|
||||
}
|
||||
|
||||
// customUpstreamConfig contains custom client upstream configuration and the
|
||||
// timestamp of the latest configuration update.
|
||||
type customUpstreamConfig struct {
|
||||
// proxyConf is the constructed upstream configuration for the [proxy],
|
||||
// derived from the fields below. It is initialized on demand with
|
||||
// [newCustomUpstreamConfig].
|
||||
proxyConf *proxy.CustomUpstreamConfig
|
||||
|
||||
// commonConfUpdate is the timestamp of the latest configuration update,
|
||||
// used to check against [upstreamManager.confUpdate] to determine if the
|
||||
// configuration is up to date.
|
||||
commonConfUpdate time.Time
|
||||
|
||||
// upstreams is the cached list of custom upstream DNS servers used for the
|
||||
// configuration of proxyConf.
|
||||
upstreams []string
|
||||
|
||||
// upstreamsCacheSize is the cached value of the cache size of the
|
||||
// upstreams, used for the configuration of proxyConf.
|
||||
upstreamsCacheSize uint32
|
||||
|
||||
// upstreamsCacheEnabled is the cached value indicating whether the cache of
|
||||
// the upstreams is enabled for the configuration of proxyConf.
|
||||
upstreamsCacheEnabled bool
|
||||
|
||||
// isChanged indicates whether the proxyConf needs to be updated.
|
||||
isChanged bool
|
||||
}
|
||||
|
||||
// upstreamManager stores and updates custom client upstream configurations.
|
||||
type upstreamManager struct {
|
||||
// logger is used for logging the operation of the upstream manager. It
|
||||
// must not be nil.
|
||||
//
|
||||
// TODO(s.chzhen): Consider using a logger with its own prefix.
|
||||
logger *slog.Logger
|
||||
|
||||
// uidToCustomConf maps persistent client UID to the custom client upstream
|
||||
// configuration. Stored UIDs must be in sync with the [index.uidToClient].
|
||||
uidToCustomConf map[UID]*customUpstreamConfig
|
||||
|
||||
// commonConf is the common upstream configuration.
|
||||
commonConf *CommonUpstreamConfig
|
||||
|
||||
// clock is used to get the current time. It must not be nil.
|
||||
clock timeutil.Clock
|
||||
|
||||
// confUpdate is the timestamp of the latest common upstream configuration
|
||||
// update.
|
||||
confUpdate time.Time
|
||||
}
|
||||
|
||||
// newUpstreamManager returns the new properly initialized upstream manager.
|
||||
func newUpstreamManager(logger *slog.Logger, clock timeutil.Clock) (m *upstreamManager) {
|
||||
return &upstreamManager{
|
||||
logger: logger,
|
||||
uidToCustomConf: make(map[UID]*customUpstreamConfig),
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
// updateCommonUpstreamConfig updates the common upstream configuration and the
|
||||
// timestamp of the latest configuration update.
|
||||
func (m *upstreamManager) updateCommonUpstreamConfig(conf *CommonUpstreamConfig) {
|
||||
m.commonConf = conf
|
||||
m.confUpdate = m.clock.Now()
|
||||
}
|
||||
|
||||
// updateCustomUpstreamConfig updates the stored custom client upstream
|
||||
// configuration associated with the persistent client. It also sets
|
||||
// [customUpstreamConfig.isChanged] to true so [customUpstreamConfig.proxyConf]
|
||||
// can be updated later in [upstreamManager.customUpstreamConfig].
|
||||
func (m *upstreamManager) updateCustomUpstreamConfig(c *Persistent) {
|
||||
cliConf, ok := m.uidToCustomConf[c.UID]
|
||||
if !ok {
|
||||
cliConf = &customUpstreamConfig{
|
||||
commonConfUpdate: m.confUpdate,
|
||||
}
|
||||
|
||||
m.uidToCustomConf[c.UID] = cliConf
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Compare before cloning.
|
||||
cliConf.upstreams = slices.Clone(c.Upstreams)
|
||||
cliConf.upstreamsCacheSize = c.UpstreamsCacheSize
|
||||
cliConf.upstreamsCacheEnabled = c.UpstreamsCacheEnabled
|
||||
cliConf.isChanged = true
|
||||
}
|
||||
|
||||
// customUpstreamConfig returns the custom client upstream configuration.
|
||||
func (m *upstreamManager) customUpstreamConfig(uid UID) (proxyConf *proxy.CustomUpstreamConfig) {
|
||||
cliConf, ok := m.uidToCustomConf[uid]
|
||||
if !ok {
|
||||
// TODO(s.chzhen): Consider panic.
|
||||
m.logger.Error("no associated custom client upstream config")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if !m.isConfigChanged(cliConf) {
|
||||
return cliConf.proxyConf
|
||||
}
|
||||
|
||||
if cliConf.proxyConf != nil {
|
||||
err := cliConf.proxyConf.Close()
|
||||
if err != nil {
|
||||
// TODO(s.chzhen): Pass context.
|
||||
m.logger.Debug("closing custom upstream config", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
proxyConf = newCustomUpstreamConfig(cliConf, m.commonConf)
|
||||
cliConf.proxyConf = proxyConf
|
||||
cliConf.isChanged = false
|
||||
|
||||
return proxyConf
|
||||
}
|
||||
|
||||
// isConfigChanged returns true if the update is necessary for the custom client
|
||||
// upstream configuration.
|
||||
func (m *upstreamManager) isConfigChanged(cliConf *customUpstreamConfig) (ok bool) {
|
||||
return !m.confUpdate.Equal(cliConf.commonConfUpdate) || cliConf.isChanged
|
||||
}
|
||||
|
||||
// clearUpstreamCache clears the upstream cache for each stored custom client
|
||||
// upstream configuration.
|
||||
func (m *upstreamManager) clearUpstreamCache() {
|
||||
for _, c := range m.uidToCustomConf {
|
||||
c.proxyConf.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// remove deletes the custom client upstream configuration and closes
|
||||
// [customUpstreamConfig.proxyConf] if necessary.
|
||||
func (m *upstreamManager) remove(uid UID) (err error) {
|
||||
cliConf, ok := m.uidToCustomConf[uid]
|
||||
if !ok {
|
||||
// TODO(s.chzhen): Consider panic.
|
||||
return errors.Error("no associated custom client upstream config")
|
||||
}
|
||||
|
||||
delete(m.uidToCustomConf, uid)
|
||||
|
||||
if cliConf.proxyConf != nil {
|
||||
return cliConf.proxyConf.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// close shuts down each stored custom client upstream configuration.
|
||||
func (m *upstreamManager) close() (err error) {
|
||||
var errs []error
|
||||
for _, c := range m.uidToCustomConf {
|
||||
if c.proxyConf == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
errs = append(errs, c.proxyConf.Close())
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// newCustomUpstreamConfig returns the new properly initialized custom proxy
|
||||
// upstream configuration for the client.
|
||||
func newCustomUpstreamConfig(
|
||||
cliConf *customUpstreamConfig,
|
||||
conf *CommonUpstreamConfig,
|
||||
) (proxyConf *proxy.CustomUpstreamConfig) {
|
||||
upstreams := stringutil.FilterOut(cliConf.upstreams, aghnet.IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
upsConf, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: conf.Bootstrap,
|
||||
Timeout: time.Duration(conf.UpstreamTimeout),
|
||||
HTTPVersions: aghnet.UpstreamHTTPVersions(conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Should not happen because upstreams are already validated. See
|
||||
// [Persistent.validate].
|
||||
panic(fmt.Errorf("creating custom upstream config: %w", err))
|
||||
}
|
||||
|
||||
return proxy.NewCustomUpstreamConfig(
|
||||
upsConf,
|
||||
cliConf.upstreamsCacheEnabled,
|
||||
int(cliConf.upstreamsCacheSize),
|
||||
conf.EDNSClientSubnetEnabled,
|
||||
)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -14,18 +15,48 @@ import (
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
// raCtx is a context for the Router Advertisement logic.
|
||||
type raCtx struct {
|
||||
raAllowSLAAC bool // send RA packets without MO flags
|
||||
raSLAACOnly bool // send RA packets with MO flags
|
||||
ipAddr net.IP // source IP address (link-local-unicast)
|
||||
dnsIPAddr net.IP // IP address for DNS Server option
|
||||
prefixIPAddr net.IP // IP address for Prefix option
|
||||
ifaceName string
|
||||
iface *net.Interface
|
||||
packetSendPeriod time.Duration // how often RA packets are sent
|
||||
// raAllowSLAAC is used to determine if the ICMP Router Advertisement
|
||||
// messages should be sent.
|
||||
//
|
||||
// If both raAllowSLAAC and raSLAACOnly are false, the Router Advertisement
|
||||
// messages aren't sent.
|
||||
raAllowSLAAC bool
|
||||
|
||||
conn *icmp.PacketConn // ICMPv6 socket
|
||||
stop atomic.Value // stop the packet sending loop
|
||||
// raSLAACOnly is used to determine if the ICMP Router Advertisement
|
||||
// messages should set M and O flags, see RFC 4861, section 4.2.
|
||||
//
|
||||
// If both raAllowSLAAC and raSLAACOnly are false, the Router Advertisement
|
||||
// messages aren't sent.
|
||||
raSLAACOnly bool
|
||||
|
||||
// ipAddr is an IP address used within the Source Link-Layer Address option.
|
||||
// See RFC 4861, section 4.6.1.
|
||||
ipAddr net.IP
|
||||
|
||||
// dnsIPAddr is an IP address used within the DNS Server option.
|
||||
dnsIPAddr net.IP
|
||||
|
||||
// prefixIPAddr is an IP address used within the Prefix Information option.
|
||||
// See RFC 4861, section 4.6.2.
|
||||
prefixIPAddr net.IP
|
||||
|
||||
// ifaceName is the name of the interface used as a scope of the IP
|
||||
// addresses.
|
||||
ifaceName string
|
||||
|
||||
// iface is the network interface used to send the ICMPv6 packets.
|
||||
iface *net.Interface
|
||||
|
||||
// packetSendPeriod is the interval between sending the ICMPv6 packets.
|
||||
packetSendPeriod time.Duration
|
||||
|
||||
// conn is the ICMPv6 socket.
|
||||
conn *icmp.PacketConn
|
||||
|
||||
// stop is used to stop the packet sending loop.
|
||||
stop atomic.Value
|
||||
}
|
||||
|
||||
type icmpv6RA struct {
|
||||
@@ -38,10 +69,11 @@ type icmpv6RA struct {
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
// hwAddrToLinkLayerAddr converts a hardware address into a form required by
|
||||
// RFC4861. That is, a byte slice of length divisible by 8.
|
||||
// hwAddrToLinkLayerAddr clones the hardware address and returns it as a byte
|
||||
// slice suitable for the Source Link-Layer Address option in the ICMPv6
|
||||
// Router Advertisement packet.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc4861#section-4.6.1.
|
||||
// TODO(e.burkov): Check if it's safe to use the original slice.
|
||||
func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
|
||||
err = netutil.ValidateMAC(hwa)
|
||||
if err != nil {
|
||||
@@ -50,19 +82,7 @@ func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(hwa) == 6 || len(hwa) == 8 {
|
||||
lla = make([]byte, 8)
|
||||
copy(lla, hwa)
|
||||
|
||||
return lla, nil
|
||||
}
|
||||
|
||||
// Assume that netutil.ValidateMAC prevents lengths other than 20 by
|
||||
// now.
|
||||
lla = make([]byte, 24)
|
||||
copy(lla, hwa)
|
||||
|
||||
return lla, nil
|
||||
return slices.Clone(hwa), nil
|
||||
}
|
||||
|
||||
// Create an ICMPv6.RouterAdvertisement packet with all necessary options.
|
||||
@@ -103,15 +123,24 @@ func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
|
||||
//
|
||||
// TODO(a.garipov): Replace with an existing implementation from a dependency.
|
||||
func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
|
||||
var lla []byte
|
||||
lla, err = hwAddrToLinkLayerAddr(params.sourceLinkLayerAddress)
|
||||
lla, err := hwAddrToLinkLayerAddr(params.sourceLinkLayerAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting source link layer address: %w", err)
|
||||
return nil, fmt.Errorf("converting source link-layer address: %w", err)
|
||||
}
|
||||
|
||||
// Calculate length of the source link-layer address option. As per RFC
|
||||
// 4861, section 4.6.1, the length should be in units of 8 octets, including
|
||||
// the type and length fields.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc4861#section-4.6.1.
|
||||
srcLLAOptLen := len(lla) + 2
|
||||
// Make sure the value is rounded up to the nearest multiple of 8.
|
||||
srcLLAOptLenValue := (srcLLAOptLen + 7) / 8
|
||||
srcLLAPadLen := srcLLAOptLenValue*8 - srcLLAOptLen
|
||||
|
||||
// TODO(a.garipov): Don't use a magic constant here. Refactor the code
|
||||
// and make all constants named instead of all those comments..
|
||||
data = make([]byte, 82+len(lla))
|
||||
// and make all constants named instead of all those comments.
|
||||
data = make([]byte, 80+srcLLAOptLen+srcLLAPadLen)
|
||||
i := 0
|
||||
|
||||
// ICMPv6:
|
||||
@@ -175,12 +204,11 @@ func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
|
||||
|
||||
// Option=Source link-layer address:
|
||||
|
||||
data[i] = 1 // Type
|
||||
data[i+1] = 1 // Length
|
||||
data[i] = 1 // Type
|
||||
data[i+1] = byte(srcLLAOptLenValue) // Length
|
||||
i += 2
|
||||
|
||||
copy(data[i:], lla) // Link-Layer Address[8/24]
|
||||
i += len(lla)
|
||||
i += len(lla) + srcLLAPadLen
|
||||
|
||||
// Option=Recursive DNS Server:
|
||||
|
||||
|
||||
63
internal/dhcpd/routeradv_internal_test.go
Normal file
63
internal/dhcpd/routeradv_internal_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateICMPv6RAPacket(t *testing.T) {
|
||||
raConf := icmpv6RA{
|
||||
managedAddressConfiguration: false,
|
||||
otherConfiguration: true,
|
||||
mtu: 1500,
|
||||
prefix: net.ParseIP("1234::"),
|
||||
prefixLen: 64,
|
||||
recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"),
|
||||
sourceLinkLayerAddress: []byte{0x0A, 0x00, 0x27, 0x00, 0x00, 0x00},
|
||||
}
|
||||
|
||||
pkt, err := createICMPv6RAPacket(raConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
icmpPkt := &layers.ICMPv6{}
|
||||
err = icmpPkt.DecodeFromBytes(pkt, gopacket.NilDecodeFeedback)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, layers.LayerTypeICMPv6RouterAdvertisement, icmpPkt.NextLayerType())
|
||||
raPkt := &layers.ICMPv6RouterAdvertisement{}
|
||||
err = raPkt.DecodeFromBytes(icmpPkt.LayerPayload(), gopacket.NilDecodeFeedback)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, raConf.managedAddressConfiguration, raPkt.ManagedAddressConfig())
|
||||
assert.Equal(t, raConf.otherConfiguration, raPkt.OtherConfig())
|
||||
|
||||
wantOpts := layers.ICMPv6Options{{
|
||||
Type: layers.ICMPv6OptPrefixInfo,
|
||||
Data: []uint8{
|
||||
0x40, 0xC0, 0x00, 0x00, 0x0E, 0x10, 0x00, 0x00,
|
||||
0x0E, 0x10, 0x00, 0x00, 0x00, 0x00, 0x12, 0x34,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
}, {
|
||||
Type: layers.ICMPv6OptMTU,
|
||||
Data: []uint8{0x00, 0x00, 0x00, 0x00, 0x05, 0xDC},
|
||||
}, {
|
||||
Type: layers.ICMPv6OptSourceAddress,
|
||||
Data: []uint8{0x0A, 0x00, 0x27, 0x00, 0x00, 0x0},
|
||||
}, {
|
||||
// Package layers declares no constant for Recursive DNS Server option.
|
||||
Type: layers.ICMPv6Opt(25),
|
||||
Data: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0x0E, 0x10, 0xFE, 0x80,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
|
||||
0x27, 0xFF, 0xFE, 0x00, 0x00, 0x00,
|
||||
},
|
||||
}}
|
||||
assert.Equal(t, wantOpts, raPkt.Options)
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateICMPv6RAPacket(t *testing.T) {
|
||||
wantData := []byte{
|
||||
0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10,
|
||||
0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00,
|
||||
0x12, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0xdc,
|
||||
0x01, 0x01, 0x0a, 0x00, 0x27, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x19, 0x03, 0x00, 0x00, 0x00, 0x00,
|
||||
0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x08, 0x00, 0x27, 0xff, 0xfe, 0x00,
|
||||
0x00, 0x00,
|
||||
}
|
||||
|
||||
gotData, err := createICMPv6RAPacket(icmpv6RA{
|
||||
managedAddressConfiguration: false,
|
||||
otherConfiguration: true,
|
||||
mtu: 1500,
|
||||
prefix: net.ParseIP("1234::"),
|
||||
prefixLen: 64,
|
||||
recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"),
|
||||
sourceLinkLayerAddress: []byte{0x0a, 0x00, 0x27, 0x00, 0x00, 0x00},
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantData, gotData)
|
||||
}
|
||||
@@ -239,7 +239,7 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = validateAccessSet(list)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, err.Error())
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
var _ proxy.BeforeRequestHandler = (*Server)(nil)
|
||||
|
||||
// HandleBefore is the handler that is called before any other processing,
|
||||
// including logs. It performs access checks and puts the client ID, if there
|
||||
// including logs. It performs access checks and puts the ClientID, if there
|
||||
// is one, into the server's cache.
|
||||
//
|
||||
// TODO(d.kolyshev): Extract to separate package.
|
||||
|
||||
@@ -266,6 +266,7 @@ func TestServer_HandleBefore_udp(t *testing.T) {
|
||||
UpstreamDNS: []string{localUpsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
@@ -62,7 +62,7 @@ func clientIDFromClientServerName(
|
||||
return strings.ToLower(clientID), nil
|
||||
}
|
||||
|
||||
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
||||
// clientIDFromDNSContextHTTPS extracts the ClientID from the path of the
|
||||
// client's DNS-over-HTTPS request.
|
||||
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
r := pctx.HTTPRequest
|
||||
|
||||
46
internal/dnsforward/clientscontainer.go
Normal file
46
internal/dnsforward/clientscontainer.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
)
|
||||
|
||||
// ClientsContainer provides information about preconfigured DNS clients.
|
||||
type ClientsContainer interface {
|
||||
// CustomUpstreamConfig returns the custom client upstream configuration, if
|
||||
// any. It prioritizes ClientID over client IP address to identify the
|
||||
// client.
|
||||
CustomUpstreamConfig(clientID string, cliAddr netip.Addr) (conf *proxy.CustomUpstreamConfig)
|
||||
|
||||
// UpdateCommonUpstreamConfig updates the common upstream configuration.
|
||||
UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig)
|
||||
|
||||
// ClearUpstreamCache clears the upstream cache for each stored custom
|
||||
// client upstream configuration.
|
||||
ClearUpstreamCache()
|
||||
}
|
||||
|
||||
// EmptyClientsContainer is an [ClientsContainer] implementation that does nothing.
|
||||
type EmptyClientsContainer struct{}
|
||||
|
||||
// type check
|
||||
var _ ClientsContainer = EmptyClientsContainer{}
|
||||
|
||||
// CustomUpstreamConfig implements the [ClientsContainer] interface for
|
||||
// EmptyClientsContainer.
|
||||
func (EmptyClientsContainer) CustomUpstreamConfig(
|
||||
clientID string,
|
||||
cliAddr netip.Addr,
|
||||
) (conf *proxy.CustomUpstreamConfig) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for
|
||||
// EmptyClientsContainer.
|
||||
func (EmptyClientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) {}
|
||||
|
||||
// ClearUpstreamCache implements the [ClientsContainer] interface for
|
||||
// EmptyClientsContainer.
|
||||
func (EmptyClientsContainer) ClearUpstreamCache() {}
|
||||
@@ -29,19 +29,6 @@ import (
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
)
|
||||
|
||||
// ClientsContainer provides information about preconfigured DNS clients.
|
||||
type ClientsContainer interface {
|
||||
// UpstreamConfigByID returns the custom upstream configuration for the
|
||||
// client having id, using boot to initialize the one if necessary. It
|
||||
// returns nil if there is no custom upstream configuration for the client.
|
||||
// The id is expected to be either a string representation of an IP address
|
||||
// or the ClientID.
|
||||
UpstreamConfigByID(
|
||||
id string,
|
||||
boot upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error)
|
||||
}
|
||||
|
||||
// Config represents the DNS filtering configuration of AdGuard Home. The zero
|
||||
// Config is empty and ready for use.
|
||||
type Config struct {
|
||||
@@ -467,7 +454,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
|
||||
}
|
||||
|
||||
ipsets = stringutil.SplitTrimmed(string(data), "\n")
|
||||
ipsets = slices.DeleteFunc(ipsets, IsCommentOrEmpty)
|
||||
ipsets = slices.DeleteFunc(ipsets, aghnet.IsCommentOrEmpty)
|
||||
|
||||
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
|
||||
|
||||
@@ -478,7 +465,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
|
||||
// the configuration itself.
|
||||
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
|
||||
if conf.UpstreamDNSFileName == "" {
|
||||
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
|
||||
return stringutil.FilterOut(conf.UpstreamDNS, aghnet.IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
@@ -491,7 +478,7 @@ func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
|
||||
|
||||
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
|
||||
|
||||
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
|
||||
return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
// collectListenAddr adds addrPort to addrs. It also adds its port to
|
||||
|
||||
@@ -299,6 +299,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
@@ -337,6 +338,7 @@ func TestServer_dns64WithDisabledRDNS(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
},
|
||||
UsePrivateRDNS: false,
|
||||
|
||||
@@ -540,7 +540,7 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
||||
uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||
Bootstrap: boot,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||
// loading TLS roots does not always work properly on some routers so we're
|
||||
@@ -557,6 +557,13 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig = uc
|
||||
s.conf.ClientsContainer.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
|
||||
Bootstrap: boot,
|
||||
UpstreamTimeout: s.conf.UpstreamTimeout,
|
||||
BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled,
|
||||
UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -630,7 +637,7 @@ func (s *Server) prepareInternalDNS() (err error) {
|
||||
|
||||
bootOpts := &upstream.Options{
|
||||
Timeout: DefaultTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
}
|
||||
|
||||
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
|
||||
@@ -661,7 +668,7 @@ func (s *Server) prepareInternalDNS() (err error) {
|
||||
// setupFallbackDNS initializes the fallback DNS servers.
|
||||
func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
|
||||
fallbacks := s.conf.FallbackDNS
|
||||
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
|
||||
fallbacks = stringutil.FilterOut(fallbacks, aghnet.IsCommentOrEmpty)
|
||||
if len(fallbacks) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
@@ -61,6 +62,42 @@ const (
|
||||
// TODO(a.garipov): Use more.
|
||||
var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345")
|
||||
|
||||
// type check
|
||||
var _ ClientsContainer = (*clientsContainer)(nil)
|
||||
|
||||
// clientsContainer is a mock [ClientsContainer] implementation for tests.
|
||||
type clientsContainer struct {
|
||||
OnCustomUpstreamConfig func(
|
||||
clientID string,
|
||||
cliAddr netip.Addr,
|
||||
) (conf *proxy.CustomUpstreamConfig)
|
||||
|
||||
OnUpdateCommonUpstreamConfig func(conf *client.CommonUpstreamConfig)
|
||||
|
||||
OnClearUpstreamCache func()
|
||||
}
|
||||
|
||||
// CustomUpstreamConfig implements the [ClientsContainer] interface for
|
||||
// *clientsContainer.
|
||||
func (c *clientsContainer) CustomUpstreamConfig(
|
||||
clientID string,
|
||||
cliAddr netip.Addr,
|
||||
) (conf *proxy.CustomUpstreamConfig) {
|
||||
return c.OnCustomUpstreamConfig(clientID, cliAddr)
|
||||
}
|
||||
|
||||
// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for
|
||||
// *clientsContainer.
|
||||
func (c *clientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) {
|
||||
c.OnUpdateCommonUpstreamConfig(conf)
|
||||
}
|
||||
|
||||
// ClearUpstreamCache implements the [ClientsContainer] interface for
|
||||
// *clientsContainer.
|
||||
func (c *clientsContainer) ClearUpstreamCache() {
|
||||
c.OnClearUpstreamCache()
|
||||
}
|
||||
|
||||
func startDeferStop(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
|
||||
@@ -168,6 +205,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
@@ -297,6 +335,7 @@ func TestServer(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
@@ -337,6 +376,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -364,6 +404,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
}
|
||||
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||
err = s.Prepare(&s.conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -380,6 +421,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
|
||||
},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -405,6 +447,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
@@ -536,6 +579,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -629,6 +673,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
@@ -659,6 +704,7 @@ func TestBlockedRequest(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -696,6 +742,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -721,12 +768,12 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
forwardConf.EDNSClientSubnet.Enabled,
|
||||
)
|
||||
|
||||
s.conf.ClientsContainer = &aghtest.ClientsContainer{
|
||||
OnUpstreamConfigByID: func(
|
||||
s.conf.ClientsContainer = &clientsContainer{
|
||||
OnCustomUpstreamConfig: func(
|
||||
_ string,
|
||||
_ upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
||||
return customUpsConf, nil
|
||||
_ netip.Addr,
|
||||
) (conf *proxy.CustomUpstreamConfig) {
|
||||
return customUpsConf
|
||||
},
|
||||
}
|
||||
|
||||
@@ -774,6 +821,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
@@ -808,6 +856,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -884,6 +933,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -930,6 +980,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -998,6 +1049,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -1051,6 +1103,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -1103,6 +1156,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -1164,6 +1218,7 @@ func TestRewrite(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}))
|
||||
@@ -1290,6 +1345,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||
|
||||
err = s.Prepare(&s.conf)
|
||||
@@ -1375,6 +1431,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||
|
||||
err = s.Prepare(&s.conf)
|
||||
@@ -1643,6 +1700,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
UsePrivateRDNS: true,
|
||||
@@ -1665,6 +1723,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
LocalPTRResolvers: []string{},
|
||||
ServePlainDNS: true,
|
||||
|
||||
@@ -40,6 +40,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
@@ -36,6 +36,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
@@ -647,7 +648,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
||||
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, aghnet.IsCommentOrEmpty)
|
||||
|
||||
opts := &upstream.Options{
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
@@ -673,6 +674,8 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||
func (s *Server) handleCacheClear(w http.ResponseWriter, _ *http.Request) {
|
||||
s.dnsProxy.ClearCache()
|
||||
s.conf.ClientsContainer.ClearUpstreamCache()
|
||||
|
||||
_, _ = io.WriteString(w, "OK")
|
||||
}
|
||||
|
||||
|
||||
@@ -83,6 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
RatelimitSubnetLenIPv6: 56,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
@@ -164,6 +165,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
RatelimitSubnetLenIPv6: 56,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
@@ -299,24 +301,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCommentOrEmpty(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
str string
|
||||
}{{
|
||||
want: assert.True,
|
||||
str: "",
|
||||
}, {
|
||||
want: assert.True,
|
||||
str: "# comment",
|
||||
}, {
|
||||
want: assert.False,
|
||||
str: "1.2.3.4",
|
||||
}} {
|
||||
tc.want(t, IsCommentOrEmpty(tc.str))
|
||||
}
|
||||
}
|
||||
|
||||
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
|
||||
t.Helper()
|
||||
|
||||
@@ -388,6 +372,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
@@ -577,17 +576,14 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Use the ClientID first, since it has a higher priority.
|
||||
id := cmp.Or(clientID, pctx.Addr.Addr().String())
|
||||
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
cliAddr := pctx.Addr.Addr()
|
||||
upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr)
|
||||
if upsConf != nil {
|
||||
log.Debug("dnsforward: using custom upstreams for client %s", id)
|
||||
log.Debug(
|
||||
"dnsforward: using custom upstreams for client with ip %s and clientid %q",
|
||||
cliAddr,
|
||||
clientID,
|
||||
)
|
||||
|
||||
pctx.CustomUpstreamConfig = upsConf
|
||||
}
|
||||
|
||||
@@ -81,6 +81,7 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -180,6 +181,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
@@ -324,6 +326,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
HandleDDR: tc.ddrEnabled,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
TLSConfig: TLSConfig{
|
||||
ServerName: ddrTestDomainName,
|
||||
@@ -660,6 +663,7 @@ func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
|
||||
UpstreamDNS: []string{localUpsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
@@ -788,6 +792,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
@@ -816,6 +821,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
UsePrivateRDNS: false,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
|
||||
@@ -19,6 +19,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
@@ -94,7 +94,7 @@ func newPrivateConfig(
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
confNeedsFiltering := len(addrs) > 0
|
||||
if confNeedsFiltering {
|
||||
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
|
||||
addrs = stringutil.FilterOut(addrs, aghnet.IsCommentOrEmpty)
|
||||
} else {
|
||||
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
|
||||
addrs = make([]string, 0, len(sysResolvers))
|
||||
@@ -127,20 +127,6 @@ func newPrivateConfig(
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
if !http3 {
|
||||
return upstream.DefaultHTTPVersions
|
||||
}
|
||||
|
||||
return []upstream.HTTPVersion{
|
||||
upstream.HTTPVersion3,
|
||||
upstream.HTTPVersion2,
|
||||
upstream.HTTPVersion11,
|
||||
}
|
||||
}
|
||||
|
||||
// setProxyUpstreamMode sets the upstream mode and related settings in conf
|
||||
// based on provided parameters.
|
||||
func setProxyUpstreamMode(
|
||||
@@ -162,10 +148,3 @@ func setProxyUpstreamMode(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
@@ -661,8 +661,6 @@ func TestClientSettings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmarks.
|
||||
|
||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d, setts := newForTest(b, &Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
@@ -670,15 +668,26 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
|
||||
for range b.N {
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
var res Result
|
||||
var err error
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
res, err = d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||
}
|
||||
|
||||
require.NoError(b, err)
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
|
||||
// Most recent results:
|
||||
//
|
||||
// goos: darwin
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering
|
||||
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
|
||||
// BenchmarkSafeBrowsing-12 358934 2994 ns/op 1304 B/op 40 allocs/op
|
||||
}
|
||||
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
func BenchmarkSafeBrowsing_parallel(b *testing.B) {
|
||||
d, setts := newForTest(b, &Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingChecker: newChecker(sbBlocked),
|
||||
@@ -693,4 +702,12 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
}
|
||||
})
|
||||
|
||||
// Most recent results:
|
||||
//
|
||||
// goos: darwin
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering
|
||||
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
|
||||
// BenchmarkSafeBrowsing_parallel-12 507327 2382 ns/op 1352 B/op 42 allocs/op
|
||||
}
|
||||
|
||||
@@ -244,7 +244,7 @@ func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
|
||||
restart, err := d.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, err.Error())
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -202,34 +202,30 @@ func TestParser_Parse_checksums(t *testing.T) {
|
||||
assert.Equal(t, gotWithoutComments, gotWithComments)
|
||||
}
|
||||
|
||||
var (
|
||||
resSink *rulelist.ParseResult
|
||||
errSink error
|
||||
)
|
||||
|
||||
func BenchmarkParser_Parse(b *testing.B) {
|
||||
dst := &bytes.Buffer{}
|
||||
src := strings.NewReader(strings.Repeat(testRuleTextBlocked, 1000))
|
||||
buf := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
p := rulelist.NewParser()
|
||||
|
||||
var res *rulelist.ParseResult
|
||||
var err error
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
resSink, errSink = p.Parse(dst, src, buf)
|
||||
for b.Loop() {
|
||||
res, err = p.Parse(dst, src, buf)
|
||||
dst.Reset()
|
||||
}
|
||||
|
||||
require.NoError(b, errSink)
|
||||
require.NotNil(b, resSink)
|
||||
require.NoError(b, err)
|
||||
require.NotNil(b, res)
|
||||
|
||||
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||
// Most recent results:
|
||||
//
|
||||
// goos: linux
|
||||
// goos: darwin
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkParser_Parse-16 100000000 128.0 ns/op 48 B/op 1 allocs/op
|
||||
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
|
||||
// BenchmarkParser_Parse-12 19635926 53.70 ns/op 48 B/op 1 allocs/op
|
||||
}
|
||||
|
||||
func FuzzParser_Parse(f *testing.F) {
|
||||
|
||||
@@ -88,16 +88,25 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
|
||||
}
|
||||
|
||||
const googleHost = "www.google.com"
|
||||
func BenchmarkDefault_SearchHost(b *testing.B) {
|
||||
const googleHost = "www.google.com"
|
||||
|
||||
var dnsRewriteSink *rules.DNSRewrite
|
||||
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
ss := newForTest(b, defaultSafeSearchConf)
|
||||
|
||||
for range b.N {
|
||||
dnsRewriteSink = ss.searchHost(googleHost, testQType)
|
||||
var rewrite *rules.DNSRewrite
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
rewrite = ss.searchHost(googleHost, testQType)
|
||||
}
|
||||
|
||||
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
|
||||
require.NotNil(b, rewrite)
|
||||
assert.Equal(b, "forcesafesearch.google.com", rewrite.NewCNAME)
|
||||
|
||||
// Most recent results:
|
||||
//
|
||||
// goos: darwin
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch
|
||||
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
|
||||
// BenchmarkDefault_SearchHost-12 751882 1604 ns/op 129 B/op 5 allocs/op
|
||||
}
|
||||
|
||||
@@ -356,7 +356,7 @@ func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
|
||||
// There's no Cookie, check Basic authentication.
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
u, _ = Context.auth.findUser(user, pass)
|
||||
u, _ = globalContext.auth.findUser(user, pass)
|
||||
|
||||
return u
|
||||
}
|
||||
@@ -408,13 +408,12 @@ func (a *Auth) authRequired() bool {
|
||||
// bytes of sessionTokenSize length.
|
||||
//
|
||||
// TODO(e.burkov): Think about using byte array instead of byte slice.
|
||||
func newSessionToken() (data []byte, err error) {
|
||||
func newSessionToken() (data []byte) {
|
||||
randData := make([]byte, sessionTokenSize)
|
||||
|
||||
_, err = rand.Read(randData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Since Go 1.24, crypto/rand.Read doesn't return an error and crashes
|
||||
// unrecoverably instead.
|
||||
_, _ = rand.Read(randData)
|
||||
|
||||
return randData, nil
|
||||
return randData
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -12,23 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewSessionToken(t *testing.T) {
|
||||
// Successful case.
|
||||
token, err := newSessionToken()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, token, sessionTokenSize)
|
||||
|
||||
// Break the rand.Reader.
|
||||
prevReader := rand.Reader
|
||||
t.Cleanup(func() { rand.Reader = prevReader })
|
||||
rand.Reader = &bytes.Buffer{}
|
||||
|
||||
// Unsuccessful case.
|
||||
token, err = newSessionToken()
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
@@ -47,8 +28,7 @@ func TestAuth(t *testing.T) {
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||
a.removeSession("notfound")
|
||||
|
||||
sess, err := newSessionToken()
|
||||
require.NoError(t, err)
|
||||
sess := newSessionToken()
|
||||
sessStr := hex.EncodeToString(sess)
|
||||
|
||||
now := time.Now().UTC().Unix()
|
||||
|
||||
@@ -47,11 +47,7 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
|
||||
rateLimiter.remove(addr)
|
||||
}
|
||||
|
||||
sess, err := newSessionToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating token: %w", err)
|
||||
}
|
||||
|
||||
sess := newSessionToken()
|
||||
now := time.Now().UTC()
|
||||
|
||||
a.addSession(sess, &session{
|
||||
@@ -155,7 +151,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if rateLimiter := Context.auth.rateLimiter; rateLimiter != nil {
|
||||
if rateLimiter := globalContext.auth.rateLimiter; rateLimiter != nil {
|
||||
if left := rateLimiter.check(remoteIP); left > 0 {
|
||||
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
|
||||
writeErrorWithIP(
|
||||
@@ -176,10 +172,10 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
||||
}
|
||||
|
||||
cookie, err := Context.auth.newCookie(req, remoteIP)
|
||||
cookie, err := globalContext.auth.newCookie(req, remoteIP)
|
||||
if err != nil {
|
||||
logIP := remoteIP
|
||||
if Context.auth.trustedProxies.Contains(ip.Unmap()) {
|
||||
if globalContext.auth.trustedProxies.Contains(ip.Unmap()) {
|
||||
logIP = ip.String()
|
||||
}
|
||||
|
||||
@@ -213,7 +209,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
Context.auth.removeSession(c.Value)
|
||||
globalContext.auth.removeSession(c.Value)
|
||||
|
||||
c = &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
@@ -232,7 +228,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// RegisterAuthHandlers - register handlers
|
||||
func RegisterAuthHandlers() {
|
||||
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
||||
globalContext.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
||||
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
||||
}
|
||||
|
||||
@@ -254,13 +250,13 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
||||
// Check Basic authentication.
|
||||
user, pass, hasBasic := r.BasicAuth()
|
||||
if hasBasic {
|
||||
_, isAuthenticated = Context.auth.findUser(user, pass)
|
||||
_, isAuthenticated = globalContext.auth.findUser(user, pass)
|
||||
if !isAuthenticated {
|
||||
log.Info("%s: invalid basic authorization value", pref)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
res := globalContext.auth.checkSession(cookie.Value)
|
||||
isAuthenticated = res == checkSessionOK
|
||||
if !isAuthenticated {
|
||||
log.Debug("%s: invalid cookie value: %q", pref, cookie)
|
||||
@@ -294,12 +290,12 @@ func optionalAuth(
|
||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
p := r.URL.Path
|
||||
authRequired := Context.auth != nil && Context.auth.authRequired()
|
||||
authRequired := globalContext.auth != nil && globalContext.auth.authRequired()
|
||||
if p == "/login.html" {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
// Redirect to the dashboard if already authenticated.
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
res := globalContext.auth.checkSession(cookie.Value)
|
||||
if res == checkSessionOK {
|
||||
http.Redirect(w, r, "", http.StatusFound)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
users := []webUser{
|
||||
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||
}
|
||||
Context.auth = InitAuth(fn, users, 60, nil, nil)
|
||||
globalContext.auth = InitAuth(fn, users, 60, nil, nil)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
||||
@@ -68,7 +68,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
assert.True(t, handlerCalled)
|
||||
|
||||
// perform login
|
||||
cookie, err := Context.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
|
||||
cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cookie)
|
||||
|
||||
@@ -114,7 +114,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
assert.True(t, handlerCalled)
|
||||
r.Header.Del(httphdr.Cookie)
|
||||
|
||||
Context.auth.Close()
|
||||
globalContext.auth.Close()
|
||||
}
|
||||
|
||||
func TestRealIP(t *testing.T) {
|
||||
|
||||
@@ -12,17 +12,14 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
// clientsContainer is the storage of all runtime and persistent clients.
|
||||
@@ -75,6 +72,7 @@ func (clients *clientsContainer) Init(
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
arpDB arpdb.Interface,
|
||||
filteringConf *filtering.Config,
|
||||
sigHdlr *signalHandler,
|
||||
) (err error) {
|
||||
// TODO(s.chzhen): Refactor it.
|
||||
if clients.storage != nil {
|
||||
@@ -109,6 +107,7 @@ func (clients *clientsContainer) Init(
|
||||
|
||||
clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "client_storage"),
|
||||
Clock: timeutil.SystemClock{},
|
||||
InitialClients: confClients,
|
||||
DHCP: dhcpServer,
|
||||
EtcHosts: hosts,
|
||||
@@ -120,6 +119,8 @@ func (clients *clientsContainer) Init(
|
||||
return fmt.Errorf("init client storage: %w", err)
|
||||
}
|
||||
|
||||
sigHdlr.addClientStorage(clients.storage)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -370,63 +371,6 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsforward.ClientsContainer = (*clientsContainer)(nil)
|
||||
|
||||
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for
|
||||
// *clientsContainer. upsConf is nil if the client isn't found or if the client
|
||||
// has no custom upstreams.
|
||||
func (clients *clientsContainer) UpstreamConfigByID(
|
||||
id string,
|
||||
bootstrap upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.storage.Find(id)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
} else if c.UpstreamConfig != nil {
|
||||
return c.UpstreamConfig, nil
|
||||
}
|
||||
|
||||
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var upsConf *proxy.UpstreamConfig
|
||||
upsConf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: time.Duration(config.DNS.UpstreamTimeout),
|
||||
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
|
||||
PreferIPv6: config.DNS.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf = proxy.NewCustomUpstreamConfig(
|
||||
upsConf,
|
||||
c.UpstreamsCacheEnabled,
|
||||
int(c.UpstreamsCacheSize),
|
||||
config.DNS.EDNSClientSubnet.Enabled,
|
||||
)
|
||||
c.UpstreamConfig = conf
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
err = clients.storage.Update(context.TODO(), c.Name, c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting upstream config: %w", err)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
||||
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -31,34 +28,10 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
nil,
|
||||
nil,
|
||||
&filtering.Config{},
|
||||
newSignalHandler(nil, nil),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
// Add client with upstreams.
|
||||
err := clients.storage.Add(ctx, &client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
||||
Upstreams: []string{
|
||||
"1.1.1.1",
|
||||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
|
||||
assert.Nil(t, upsConf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
upsConf, err = clients.UpstreamConfigByID("1.1.1.1", net.DefaultResolver)
|
||||
require.NotNil(t, upsConf)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -486,9 +486,9 @@ var config = &configuration{
|
||||
// configFilePath returns the absolute path to the symlink-evaluated path to the
|
||||
// current config file.
|
||||
func configFilePath() (confPath string) {
|
||||
confPath, err := filepath.EvalSymlinks(Context.confFilePath)
|
||||
confPath, err := filepath.EvalSymlinks(globalContext.confFilePath)
|
||||
if err != nil {
|
||||
confPath = Context.confFilePath
|
||||
confPath = globalContext.confFilePath
|
||||
logFunc := log.Error
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
logFunc = log.Debug
|
||||
@@ -498,7 +498,7 @@ func configFilePath() (confPath string) {
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(confPath) {
|
||||
confPath = filepath.Join(Context.workDir, confPath)
|
||||
confPath = filepath.Join(globalContext.workDir, confPath)
|
||||
}
|
||||
|
||||
return confPath
|
||||
@@ -530,8 +530,8 @@ func parseConfig() (err error) {
|
||||
}
|
||||
|
||||
migrator := configmigrate.New(&configmigrate.Config{
|
||||
WorkingDir: Context.workDir,
|
||||
DataDir: Context.getDataDir(),
|
||||
WorkingDir: globalContext.workDir,
|
||||
DataDir: globalContext.getDataDir(),
|
||||
})
|
||||
|
||||
var upgraded bool
|
||||
@@ -640,31 +640,31 @@ func readConfigFile() (fileData []byte, err error) {
|
||||
}
|
||||
|
||||
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
||||
func (c *configuration) write() (err error) {
|
||||
func (c *configuration) write(tlsMgr *tlsManager) (err error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if Context.auth != nil {
|
||||
config.Users = Context.auth.usersList()
|
||||
if globalContext.auth != nil {
|
||||
config.Users = globalContext.auth.usersList()
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
if tlsMgr != nil {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
tlsMgr.WriteDiskConfig(&tlsConf)
|
||||
config.TLS = tlsConf
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
if globalContext.stats != nil {
|
||||
statsConf := stats.Config{}
|
||||
Context.stats.WriteDiskConfig(&statsConf)
|
||||
globalContext.stats.WriteDiskConfig(&statsConf)
|
||||
config.Stats.Interval = timeutil.Duration(statsConf.Limit)
|
||||
config.Stats.Enabled = statsConf.Enabled
|
||||
config.Stats.Ignored = statsConf.Ignored.Values()
|
||||
}
|
||||
|
||||
if Context.queryLog != nil {
|
||||
if globalContext.queryLog != nil {
|
||||
dc := querylog.Config{}
|
||||
Context.queryLog.WriteDiskConfig(&dc)
|
||||
globalContext.queryLog.WriteDiskConfig(&dc)
|
||||
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
||||
config.QueryLog.Enabled = dc.Enabled
|
||||
config.QueryLog.FileEnabled = dc.FileEnabled
|
||||
@@ -673,14 +673,14 @@ func (c *configuration) write() (err error) {
|
||||
config.QueryLog.Ignored = dc.Ignored.Values()
|
||||
}
|
||||
|
||||
if Context.filters != nil {
|
||||
Context.filters.WriteDiskConfig(config.Filtering)
|
||||
if globalContext.filters != nil {
|
||||
globalContext.filters.WriteDiskConfig(config.Filtering)
|
||||
config.Filters = config.Filtering.Filters
|
||||
config.WhitelistFilters = config.Filtering.WhitelistFilters
|
||||
config.UserRules = config.Filtering.UserRules
|
||||
}
|
||||
|
||||
if s := Context.dnsServer; s != nil {
|
||||
if s := globalContext.dnsServer; s != nil {
|
||||
c := dnsforward.Config{}
|
||||
s.WriteDiskConfig(&c)
|
||||
dns := &config.DNS
|
||||
@@ -695,11 +695,11 @@ func (c *configuration) write() (err error) {
|
||||
dns.UpstreamTimeout = timeutil.Duration(s.UpstreamTimeout())
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
Context.dhcpServer.WriteDiskConfig(config.DHCP)
|
||||
if globalContext.dhcpServer != nil {
|
||||
globalContext.dhcpServer.WriteDiskConfig(config.DHCP)
|
||||
}
|
||||
|
||||
config.Clients.Persistent = Context.clients.forConfig()
|
||||
config.Clients.Persistent = globalContext.clients.forConfig()
|
||||
|
||||
confPath := configFilePath()
|
||||
log.Debug("writing config file %q", confPath)
|
||||
@@ -726,14 +726,14 @@ func setContextTLSCipherIDs() (err error) {
|
||||
if len(config.TLS.OverrideTLSCiphers) == 0 {
|
||||
log.Info("tls: using default ciphers")
|
||||
|
||||
Context.tlsCipherIDs = aghtls.SaferCipherSuites()
|
||||
globalContext.tlsCipherIDs = aghtls.SaferCipherSuites()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers)
|
||||
|
||||
Context.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
|
||||
globalContext.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing override ciphers: %w", err)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
@@ -69,7 +68,8 @@ func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err
|
||||
|
||||
// collectDNSAddresses returns the list of DNS addresses the server is listening
|
||||
// on, including the addresses on all interfaces in cases of unspecified IPs.
|
||||
func collectDNSAddresses() (addrs []string, err error) {
|
||||
// tlsMgr must not be nil.
|
||||
func collectDNSAddresses(tlsMgr *tlsManager) (addrs []string, err error) {
|
||||
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
|
||||
addrs = appendDNSAddrs(addrs, netutil.IPv4Localhost())
|
||||
} else {
|
||||
@@ -79,7 +79,7 @@ func collectDNSAddresses() (addrs []string, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
de := getDNSEncryption()
|
||||
de := getDNSEncryption(tlsMgr)
|
||||
if de.https != "" {
|
||||
addrs = append(addrs, de.https)
|
||||
}
|
||||
@@ -114,8 +114,8 @@ type statusResponse struct {
|
||||
IsRunning bool `json:"running"`
|
||||
}
|
||||
|
||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
dnsAddrs, err := collectDNSAddresses()
|
||||
func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
dnsAddrs, err := collectDNSAddresses(web.tlsManager)
|
||||
if err != nil {
|
||||
// Don't add a lot of formatting, since the error is already
|
||||
// wrapped by collectDNSAddresses.
|
||||
@@ -129,10 +129,10 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
protectionDisabledUntil *time.Time
|
||||
protectionEnabled bool
|
||||
)
|
||||
if Context.dnsServer != nil {
|
||||
if globalContext.dnsServer != nil {
|
||||
fltConf = &dnsforward.Config{}
|
||||
Context.dnsServer.WriteDiskConfig(fltConf)
|
||||
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
|
||||
globalContext.dnsServer.WriteDiskConfig(fltConf)
|
||||
protectionEnabled, protectionDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus()
|
||||
}
|
||||
|
||||
var resp statusResponse
|
||||
@@ -162,42 +162,42 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// IsDHCPAvailable field is now false by default for Windows.
|
||||
if runtime.GOOS != "windows" {
|
||||
resp.IsDHCPAvailable = Context.dhcpServer != nil
|
||||
resp.IsDHCPAvailable = globalContext.dhcpServer != nil
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
// registration of handlers
|
||||
// ------------------------
|
||||
// registerControlHandlers sets up HTTP handlers for various control endpoints.
|
||||
// web must not be nil.
|
||||
func registerControlHandlers(web *webAPI) {
|
||||
Context.mux.HandleFunc(
|
||||
globalContext.mux.HandleFunc(
|
||||
"/control/version.json",
|
||||
postInstall(optionalAuth(web.handleVersionJSON)),
|
||||
)
|
||||
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
|
||||
|
||||
httpRegister(http.MethodGet, "/control/status", handleStatus)
|
||||
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
|
||||
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
|
||||
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
|
||||
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
|
||||
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
|
||||
|
||||
// No auth is necessary for DoH/DoT configurations
|
||||
Context.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
|
||||
Context.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
|
||||
globalContext.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
|
||||
globalContext.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
|
||||
RegisterAuthHandlers()
|
||||
}
|
||||
|
||||
// httpRegister registers an HTTP handler.
|
||||
func httpRegister(method, url string, handler http.HandlerFunc) {
|
||||
if method == "" {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
Context.mux.HandleFunc(url, postInstall(handler))
|
||||
globalContext.mux.HandleFunc(url, postInstall(handler))
|
||||
return
|
||||
}
|
||||
|
||||
Context.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||
globalContext.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||
}
|
||||
|
||||
// ensure returns a wrapped handler that makes sure that the request has the
|
||||
@@ -207,11 +207,7 @@ func ensure(
|
||||
handler func(http.ResponseWriter, *http.Request),
|
||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
m, u := r.Method, r.URL
|
||||
log.Debug("started %s %s %s", m, r.Host, u)
|
||||
defer func() { log.Debug("finished %s %s %s in %s", m, r.Host, u, time.Since(start)) }()
|
||||
|
||||
m := r.Method
|
||||
if m != method {
|
||||
aghhttp.Error(r, w, http.StatusMethodNotAllowed, "only method %s is allowed", method)
|
||||
|
||||
@@ -223,8 +219,8 @@ func ensure(
|
||||
return
|
||||
}
|
||||
|
||||
Context.controlLock.Lock()
|
||||
defer Context.controlLock.Unlock()
|
||||
globalContext.controlLock.Lock()
|
||||
defer globalContext.controlLock.Unlock()
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
@@ -293,7 +289,7 @@ func ensureHandler(method string, handler func(http.ResponseWriter, *http.Reques
|
||||
// preInstall lets the handler run only if firstRun is true, no redirects
|
||||
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !Context.firstRun {
|
||||
if !globalContext.firstRun {
|
||||
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
@@ -320,7 +316,7 @@ func preInstallHandler(handler http.Handler) http.Handler {
|
||||
// HTTPS-related headers. If proceed is true, the middleware must continue
|
||||
// handling the request.
|
||||
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
web := Context.web
|
||||
web := globalContext.web
|
||||
if web.httpsServer.server == nil {
|
||||
return true
|
||||
}
|
||||
@@ -409,7 +405,7 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
|
||||
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
if Context.firstRun && !strings.HasPrefix(path, "/install.") &&
|
||||
if globalContext.firstRun && !strings.HasPrefix(path, "/install.") &&
|
||||
!strings.HasPrefix(path, "/assets/") {
|
||||
http.Redirect(w, r, "install.html", http.StatusFound)
|
||||
|
||||
|
||||
@@ -428,20 +428,20 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
curConfig := &configuration{}
|
||||
copyInstallSettings(curConfig, config)
|
||||
|
||||
Context.firstRun = false
|
||||
globalContext.firstRun = false
|
||||
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
|
||||
config.DNS.Port = req.DNS.Port
|
||||
config.Filtering.SafeFSPatterns = []string{
|
||||
filepath.Join(Context.workDir, userFilterDataDir, "*"),
|
||||
filepath.Join(globalContext.workDir, userFilterDataDir, "*"),
|
||||
}
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
|
||||
|
||||
u := &webUser{
|
||||
Name: req.Username,
|
||||
}
|
||||
err = Context.auth.addUser(u, req.Password)
|
||||
err = globalContext.auth.addUser(u, req.Password)
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
globalContext.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err)
|
||||
|
||||
@@ -452,18 +452,18 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = startMods(web.baseLogger)
|
||||
err = startMods(r.Context(), web.baseLogger, web.tlsManager)
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
globalContext.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = config.write()
|
||||
err = config.write(web.tlsManager)
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
globalContext.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
|
||||
|
||||
@@ -527,8 +527,33 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
return req, restartHTTP, err
|
||||
}
|
||||
|
||||
func (web *webAPI) registerInstallHandlers() {
|
||||
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
|
||||
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
||||
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
// baseLogger and tlsMgr must not be nil.
|
||||
func startMods(ctx context.Context, baseLogger *slog.Logger, tlsMgr *tlsManager) (err error) {
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = initDNS(baseLogger, tlsMgr, statsDir, querylogDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tlsMgr.start(ctx)
|
||||
|
||||
err = startDNSServer()
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (web *webAPI) registerInstallHandlers() {
|
||||
globalContext.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
|
||||
globalContext.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
||||
globalContext.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = resp.setAllowedToAutoUpdate()
|
||||
err = resp.setAllowedToAutoUpdate(web.tlsManager)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
@@ -158,14 +158,14 @@ type versionResponse struct {
|
||||
}
|
||||
|
||||
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
|
||||
// allowed to perform an automatic update by the OS.
|
||||
func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
// allowed to perform an automatic update by the OS. tlsMgr must not be nil.
|
||||
func (vr *versionResponse) setAllowedToAutoUpdate(tlsMgr *tlsManager) (err error) {
|
||||
if vr.CanAutoUpdate != aghalg.NBTrue {
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
tlsMgr.WriteDiskConfig(tlsConf)
|
||||
|
||||
canUpdate := true
|
||||
if tlsConfUsesPrivilegedPorts(tlsConf) ||
|
||||
|
||||
@@ -39,16 +39,22 @@ const (
|
||||
|
||||
// Called by other modules when configuration is changed
|
||||
func onConfigModified() {
|
||||
err := config.write()
|
||||
err := config.write(globalContext.tls)
|
||||
if err != nil {
|
||||
log.Error("writing config: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// initDNS updates all the fields of the [Context] needed to initialize the DNS
|
||||
// server and initializes it at last. It also must not be called unless
|
||||
// [config] and [Context] are initialized. baseLogger must not be nil.
|
||||
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
|
||||
// initDNS updates all the fields of the [globalContext] needed to initialize
|
||||
// the DNS server and initializes it at last. It also must not be called unless
|
||||
// [config] and [globalContext] are initialized. baseLogger and tlsMgr must not
|
||||
// be nil.
|
||||
func initDNS(
|
||||
baseLogger *slog.Logger,
|
||||
tlsMgr *tlsManager,
|
||||
statsDir string,
|
||||
querylogDir string,
|
||||
) (err error) {
|
||||
anonymizer := config.anonymizer()
|
||||
|
||||
statsConf := stats.Config{
|
||||
@@ -58,7 +64,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
Enabled: config.Stats.Enabled,
|
||||
ShouldCountClient: Context.clients.shouldCountClient,
|
||||
ShouldCountClient: globalContext.clients.shouldCountClient,
|
||||
}
|
||||
|
||||
engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored)
|
||||
@@ -67,7 +73,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
||||
}
|
||||
|
||||
statsConf.Ignored = engine
|
||||
Context.stats, err = stats.New(statsConf)
|
||||
globalContext.stats, err = stats.New(statsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init stats: %w", err)
|
||||
}
|
||||
@@ -77,7 +83,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
||||
Anonymizer: anonymizer,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
FindClient: Context.clients.findMultiple,
|
||||
FindClient: globalContext.clients.findMultiple,
|
||||
BaseDir: querylogDir,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
RotationIvl: time.Duration(config.QueryLog.Interval),
|
||||
@@ -92,25 +98,25 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
||||
}
|
||||
|
||||
conf.Ignored = engine
|
||||
Context.queryLog, err = querylog.New(conf)
|
||||
globalContext.queryLog, err = querylog.New(conf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init querylog: %w", err)
|
||||
}
|
||||
|
||||
Context.filters, err = filtering.New(config.Filtering, nil)
|
||||
globalContext.filters, err = filtering.New(config.Filtering, nil)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
tlsMgr.WriteDiskConfig(tlsConf)
|
||||
|
||||
return initDNSServer(
|
||||
Context.filters,
|
||||
Context.stats,
|
||||
Context.queryLog,
|
||||
Context.dhcpServer,
|
||||
globalContext.filters,
|
||||
globalContext.stats,
|
||||
globalContext.queryLog,
|
||||
globalContext.dhcpServer,
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
tlsConf,
|
||||
@@ -121,7 +127,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
||||
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
||||
// proxy, none of the arguments are required, but tlsConf and l still must not
|
||||
// be nil, in other cases all the arguments also must not be nil. It also must
|
||||
// not be called unless [config] and [Context] are initialized.
|
||||
// not be called unless [config] and [globalContext] are initialized.
|
||||
//
|
||||
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
||||
func initDNSServer(
|
||||
@@ -134,7 +140,7 @@ func initDNSServer(
|
||||
tlsConf *tlsConfigSettings,
|
||||
l *slog.Logger,
|
||||
) (err error) {
|
||||
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||
Logger: l,
|
||||
DNSFilter: filters,
|
||||
Stats: sts,
|
||||
@@ -142,7 +148,7 @@ func initDNSServer(
|
||||
PrivateNets: parseSubnetSet(config.DNS.PrivateNets),
|
||||
Anonymizer: anonymizer,
|
||||
DHCPServer: dhcpSrv,
|
||||
EtcHosts: Context.etcHosts,
|
||||
EtcHosts: globalContext.etcHosts,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
})
|
||||
defer func() {
|
||||
@@ -154,21 +160,27 @@ func initDNSServer(
|
||||
return fmt.Errorf("dnsforward.NewServer: %w", err)
|
||||
}
|
||||
|
||||
Context.clients.clientChecker = Context.dnsServer
|
||||
globalContext.clients.clientChecker = globalContext.dnsServer
|
||||
|
||||
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
|
||||
dnsConf, err := newServerConfig(
|
||||
&config.DNS,
|
||||
config.Clients.Sources,
|
||||
tlsConf,
|
||||
httpReg,
|
||||
globalContext.clients.storage,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("newServerConfig: %w", err)
|
||||
}
|
||||
|
||||
// Try to prepare the server with disabled private RDNS resolution if it
|
||||
// failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError].
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
err = globalContext.dnsServer.Prepare(dnsConf)
|
||||
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
|
||||
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
|
||||
|
||||
dnsConf.UsePrivateRDNS = false
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
err = globalContext.dnsServer.Prepare(dnsConf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -194,7 +206,7 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
|
||||
}
|
||||
|
||||
func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
|
||||
@@ -230,12 +242,13 @@ func newServerConfig(
|
||||
clientSrcConf *clientSourcesConfig,
|
||||
tlsConf *tlsConfigSettings,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
clientsContainer dnsforward.ClientsContainer,
|
||||
) (newConf *dnsforward.ServerConfig, err error) {
|
||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||
|
||||
fwdConf := dnsConf.Config
|
||||
fwdConf.FilterHandler = applyAdditionalFiltering
|
||||
fwdConf.ClientsContainer = &Context.clients
|
||||
fwdConf.ClientsContainer = clientsContainer
|
||||
|
||||
newConf = &dnsforward.ServerConfig{
|
||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||
@@ -244,7 +257,7 @@ func newServerConfig(
|
||||
TLSConfig: newDNSTLSConfig(tlsConf, hosts),
|
||||
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
|
||||
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
|
||||
TLSv12Roots: Context.tlsRoots,
|
||||
TLSv12Roots: globalContext.tlsRoots,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpReg,
|
||||
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
|
||||
@@ -259,16 +272,16 @@ func newServerConfig(
|
||||
var initialAddresses []netip.Addr
|
||||
// Context.stats may be nil here if initDNSServer is called from
|
||||
// [cmdlineUpdate].
|
||||
if sts := Context.stats; sts != nil {
|
||||
if sts := globalContext.stats; sts != nil {
|
||||
const initialClientsNum = 100
|
||||
initialAddresses = Context.stats.TopClientsIP(initialClientsNum)
|
||||
initialAddresses = globalContext.stats.TopClientsIP(initialClientsNum)
|
||||
}
|
||||
|
||||
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
|
||||
// are set by [dnsforward.Server.Prepare].
|
||||
newConf.AddrProcConf = &client.DefaultAddrProcConfig{
|
||||
Exchanger: Context.dnsServer,
|
||||
AddressUpdater: &Context.clients,
|
||||
Exchanger: globalContext.dnsServer,
|
||||
AddressUpdater: &globalContext.clients,
|
||||
InitialAddresses: initialAddresses,
|
||||
CatchPanics: true,
|
||||
UseRDNS: clientSrcConf.RDNS,
|
||||
@@ -350,16 +363,18 @@ func newDNSCryptConfig(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dnsEncryption contains different types of TLS encryption addresses.
|
||||
type dnsEncryption struct {
|
||||
https string
|
||||
tls string
|
||||
quic string
|
||||
}
|
||||
|
||||
func getDNSEncryption() (de dnsEncryption) {
|
||||
// getDNSEncryption returns the TLS encryption addresses that AdGuard Home
|
||||
// listens on. tlsMgr must not be nil.
|
||||
func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
tlsMgr.WriteDiskConfig(&tlsConf)
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
|
||||
return dnsEncryption{}
|
||||
@@ -402,7 +417,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
||||
// pref is a prefix for logging messages around the scope.
|
||||
const pref = "applying filters"
|
||||
|
||||
Context.filters.ApplyBlockedServices(setts)
|
||||
globalContext.filters.ApplyBlockedServices(setts)
|
||||
|
||||
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
@@ -412,9 +427,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
||||
|
||||
setts.ClientIP = clientIP
|
||||
|
||||
c, ok := Context.clients.storage.Find(clientID)
|
||||
c, ok := globalContext.clients.storage.Find(clientID)
|
||||
if !ok {
|
||||
c, ok = Context.clients.storage.Find(clientIP.String())
|
||||
c, ok = globalContext.clients.storage.Find(clientIP.String())
|
||||
if !ok {
|
||||
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
@@ -429,7 +444,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
||||
setts.ServicesRules = nil
|
||||
svcs := c.BlockedServices.IDs
|
||||
if !c.BlockedServices.Schedule.Contains(time.Now()) {
|
||||
Context.filters.ApplyBlockedServicesList(setts, svcs)
|
||||
globalContext.filters.ApplyBlockedServicesList(setts, svcs)
|
||||
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
||||
}
|
||||
}
|
||||
@@ -455,24 +470,24 @@ func startDNSServer() error {
|
||||
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
||||
}
|
||||
|
||||
Context.filters.EnableFilters(false)
|
||||
globalContext.filters.EnableFilters(false)
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
err := Context.clients.Start(ctx)
|
||||
err := globalContext.clients.Start(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting clients container: %w", err)
|
||||
}
|
||||
|
||||
err = Context.dnsServer.Start()
|
||||
err = globalContext.dnsServer.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting dns server: %w", err)
|
||||
}
|
||||
|
||||
Context.filters.Start()
|
||||
Context.stats.Start()
|
||||
globalContext.filters.Start()
|
||||
globalContext.stats.Start()
|
||||
|
||||
err = Context.queryLog.Start(ctx)
|
||||
err = globalContext.queryLog.Start(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting query log: %w", err)
|
||||
}
|
||||
@@ -480,16 +495,24 @@ func startDNSServer() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func reconfigureDNSServer() (err error) {
|
||||
// reconfigureDNSServer updates the DNS server configuration using the provided
|
||||
// TLS settings. tlsMgr must not be nil.
|
||||
func reconfigureDNSServer(tlsMgr *tlsManager) (err error) {
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
tlsMgr.WriteDiskConfig(tlsConf)
|
||||
|
||||
newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister)
|
||||
newConf, err := newServerConfig(
|
||||
&config.DNS,
|
||||
config.Clients.Sources,
|
||||
tlsConf,
|
||||
httpRegister,
|
||||
globalContext.clients.storage,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
||||
}
|
||||
|
||||
err = Context.dnsServer.Reconfigure(newConf)
|
||||
err = globalContext.dnsServer.Reconfigure(newConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting forwarding dns server: %w", err)
|
||||
}
|
||||
@@ -502,12 +525,12 @@ func stopDNSServer() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = Context.dnsServer.Stop()
|
||||
err = globalContext.dnsServer.Stop()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping forwarding dns server: %w", err)
|
||||
}
|
||||
|
||||
err = Context.clients.close(context.TODO())
|
||||
err = globalContext.clients.close(context.TODO())
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing clients container: %w", err)
|
||||
}
|
||||
@@ -519,25 +542,25 @@ func stopDNSServer() (err error) {
|
||||
|
||||
func closeDNSServer() {
|
||||
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
|
||||
if Context.dnsServer != nil {
|
||||
Context.dnsServer.Close()
|
||||
Context.dnsServer = nil
|
||||
if globalContext.dnsServer != nil {
|
||||
globalContext.dnsServer.Close()
|
||||
globalContext.dnsServer = nil
|
||||
}
|
||||
|
||||
if Context.filters != nil {
|
||||
Context.filters.Close()
|
||||
if globalContext.filters != nil {
|
||||
globalContext.filters.Close()
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
err := Context.stats.Close()
|
||||
if globalContext.stats != nil {
|
||||
err := globalContext.stats.Close()
|
||||
if err != nil {
|
||||
log.Error("closing stats: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if Context.queryLog != nil {
|
||||
if globalContext.queryLog != nil {
|
||||
// TODO(s.chzhen): Pass context.
|
||||
err := Context.queryLog.Shutdown(context.TODO())
|
||||
err := globalContext.queryLog.Shutdown(context.TODO())
|
||||
if err != nil {
|
||||
log.Error("closing query log: %s", err)
|
||||
}
|
||||
|
||||
@@ -37,14 +37,14 @@ func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage)
|
||||
func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
var err error
|
||||
|
||||
Context.filters, err = filtering.New(&filtering.Config{
|
||||
globalContext.filters, err = filtering.New(&filtering.Config{
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.storage = newStorage(t, []*client.Persistent{{
|
||||
globalContext.clients.storage = newStorage(t, []*client.Persistent{{
|
||||
Name: "default",
|
||||
ClientIDs: []string{"default"},
|
||||
UseOwnSettings: false,
|
||||
@@ -124,7 +124,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
Context.filters, err = filtering.New(&filtering.Config{
|
||||
globalContext.filters, err = filtering.New(&filtering.Config{
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: globalBlockedServices,
|
||||
@@ -132,7 +132,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.storage = newStorage(t, []*client.Persistent{{
|
||||
globalContext.clients.storage = newStorage(t, []*client.Persistent{{
|
||||
Name: "default",
|
||||
ClientIDs: []string{"default"},
|
||||
UseOwnBlockedServices: false,
|
||||
|
||||
@@ -57,7 +57,12 @@ type homeContext struct {
|
||||
auth *Auth // HTTP authentication module
|
||||
filters *filtering.DNSFilter // DNS filtering module
|
||||
web *webAPI // Web (HTTP, HTTPS) module
|
||||
tls *tlsManager // TLS module
|
||||
|
||||
// tls contains the current configuration and state of TLS encryption.
|
||||
//
|
||||
// TODO(s.chzhen): Remove once it is no longer called from different
|
||||
// modules. See [onConfigModified].
|
||||
tls *tlsManager
|
||||
|
||||
// etcHosts contains IP-hostname mappings taken from the OS-specific hosts
|
||||
// configuration files, for example /etc/hosts.
|
||||
@@ -91,10 +96,10 @@ func (c *homeContext) getDataDir() string {
|
||||
return filepath.Join(c.workDir, dataDir)
|
||||
}
|
||||
|
||||
// Context - a global context object
|
||||
// globalContext is a global context object.
|
||||
//
|
||||
// TODO(a.garipov): Refactor.
|
||||
var Context homeContext
|
||||
var globalContext homeContext
|
||||
|
||||
// Main is the entry point
|
||||
func Main(clientBuildFS fs.FS) {
|
||||
@@ -113,40 +118,32 @@ func Main(clientBuildFS fs.FS) {
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
for {
|
||||
sig := <-signals
|
||||
log.Info("Received signal %q", sig)
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
Context.clients.storage.ReloadARP(ctx)
|
||||
Context.tls.reload()
|
||||
default:
|
||||
cleanup(ctx)
|
||||
cleanupAlways()
|
||||
close(done)
|
||||
}
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
sigHdlr := newSignalHandler(signals, func(ctx context.Context) {
|
||||
cleanup(ctx)
|
||||
cleanupAlways()
|
||||
close(done)
|
||||
})
|
||||
|
||||
go sigHdlr.handle(ctx)
|
||||
|
||||
if opts.serviceControlAction != "" {
|
||||
handleServiceControlAction(opts, clientBuildFS, signals, done)
|
||||
handleServiceControlAction(opts, clientBuildFS, signals, done, sigHdlr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// run the protection
|
||||
run(opts, clientBuildFS, done)
|
||||
run(opts, clientBuildFS, done, sigHdlr)
|
||||
}
|
||||
|
||||
// setupContext initializes [Context] fields. It also reads and upgrades
|
||||
// setupContext initializes [globalContext] fields. It also reads and upgrades
|
||||
// config file if necessary.
|
||||
func setupContext(opts options) (err error) {
|
||||
Context.firstRun = detectFirstRun()
|
||||
globalContext.firstRun = detectFirstRun()
|
||||
|
||||
Context.tlsRoots = aghtls.SystemRootCAs()
|
||||
Context.mux = http.NewServeMux()
|
||||
globalContext.tlsRoots = aghtls.SystemRootCAs()
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
if !opts.noEtcHosts {
|
||||
err = setupHostsContainer()
|
||||
@@ -156,7 +153,7 @@ func setupContext(opts options) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if Context.firstRun {
|
||||
if globalContext.firstRun {
|
||||
log.Info("This is the first time AdGuard Home is launched")
|
||||
checkNetworkPermissions()
|
||||
|
||||
@@ -247,7 +244,7 @@ func setupHostsContainer() (err error) {
|
||||
return fmt.Errorf("getting default system hosts paths: %w", err)
|
||||
}
|
||||
|
||||
Context.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
|
||||
globalContext.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
|
||||
if err != nil {
|
||||
closeErr := hostsWatcher.Close()
|
||||
if errors.Is(err, aghnet.ErrNoHostsPaths) {
|
||||
@@ -271,14 +268,18 @@ func setupOpts(opts options) (err error) {
|
||||
}
|
||||
|
||||
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
|
||||
Context.pidFileName = opts.pidFile
|
||||
globalContext.pidFileName = opts.pidFile
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initContextClients initializes Context clients and related fields.
|
||||
func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
|
||||
func initContextClients(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
sigHdlr *signalHandler,
|
||||
) (err error) {
|
||||
err = setupDNSFilteringConf(ctx, logger, config.Filtering)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
@@ -286,13 +287,13 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
|
||||
}
|
||||
|
||||
//lint:ignore SA1019 Migration is not over.
|
||||
config.DHCP.WorkDir = Context.workDir
|
||||
config.DHCP.DataDir = Context.getDataDir()
|
||||
config.DHCP.WorkDir = globalContext.workDir
|
||||
config.DHCP.DataDir = globalContext.getDataDir()
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
config.DHCP.ConfigModified = onConfigModified
|
||||
|
||||
Context.dhcpServer, err = dhcpd.Create(config.DHCP)
|
||||
if Context.dhcpServer == nil || err != nil {
|
||||
globalContext.dhcpServer, err = dhcpd.Create(config.DHCP)
|
||||
if globalContext.dhcpServer == nil || err != nil {
|
||||
// TODO(a.garipov): There are a lot of places in the code right
|
||||
// now which assume that the DHCP server can be nil despite this
|
||||
// condition. Inspect them and perhaps rewrite them to use
|
||||
@@ -305,14 +306,15 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
|
||||
arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb"))
|
||||
}
|
||||
|
||||
return Context.clients.Init(
|
||||
return globalContext.clients.Init(
|
||||
ctx,
|
||||
logger,
|
||||
config.Clients.Persistent,
|
||||
Context.dhcpServer,
|
||||
Context.etcHosts,
|
||||
globalContext.dhcpServer,
|
||||
globalContext.etcHosts,
|
||||
arpDB,
|
||||
config.Filtering,
|
||||
sigHdlr,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -374,15 +376,15 @@ func setupDNSFilteringConf(
|
||||
pcTXTSuffix = `pc.dns.adguard.com.`
|
||||
)
|
||||
|
||||
conf.EtcHosts = Context.etcHosts
|
||||
conf.EtcHosts = globalContext.etcHosts
|
||||
// TODO(s.chzhen): Use empty interface.
|
||||
if Context.etcHosts == nil || !config.DNS.HostsFileEnabled {
|
||||
if globalContext.etcHosts == nil || !config.DNS.HostsFileEnabled {
|
||||
conf.EtcHosts = nil
|
||||
}
|
||||
|
||||
conf.ConfigModified = onConfigModified
|
||||
conf.HTTPRegister = httpRegister
|
||||
conf.DataDir = Context.getDataDir()
|
||||
conf.DataDir = globalContext.getDataDir()
|
||||
conf.Filters = slices.Clone(config.Filters)
|
||||
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||
conf.UserRules = slices.Clone(config.UserRules)
|
||||
@@ -522,13 +524,15 @@ func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customU
|
||||
}
|
||||
}
|
||||
|
||||
// initWeb initializes the web module. upd and baseLogger must not be nil.
|
||||
// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be
|
||||
// nil.
|
||||
func initWeb(
|
||||
ctx context.Context,
|
||||
opts options,
|
||||
clientBuildFS fs.FS,
|
||||
upd *updater.Updater,
|
||||
baseLogger *slog.Logger,
|
||||
tlsMgr *tlsManager,
|
||||
customURL bool,
|
||||
) (web *webAPI, err error) {
|
||||
logger := baseLogger.With(slogutil.KeyPrefix, "webapi")
|
||||
@@ -551,6 +555,7 @@ func initWeb(
|
||||
updater: upd,
|
||||
logger: logger,
|
||||
baseLogger: baseLogger,
|
||||
tlsManager: tlsMgr,
|
||||
|
||||
clientFS: clientFS,
|
||||
|
||||
@@ -560,7 +565,7 @@ func initWeb(
|
||||
ReadHeaderTimeout: readHdrTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
|
||||
firstRun: Context.firstRun,
|
||||
firstRun: globalContext.firstRun,
|
||||
disableUpdate: disableUpdate,
|
||||
runningAsService: opts.runningAsService,
|
||||
serveHTTP3: config.DNS.ServeHTTP3,
|
||||
@@ -583,7 +588,7 @@ func fatalOnError(err error) {
|
||||
// run configures and starts AdGuard Home.
|
||||
//
|
||||
// TODO(e.burkov): Make opts a pointer.
|
||||
func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalHandler) {
|
||||
// Configure working dir.
|
||||
err := initWorkingDir(opts)
|
||||
fatalOnError(err)
|
||||
@@ -599,10 +604,11 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
|
||||
// TODO(a.garipov): Use slog everywhere.
|
||||
slogLogger := newSlogLogger(ls)
|
||||
sigHdlr.swapLogger(slogLogger)
|
||||
|
||||
// Print the first message after logger is configured.
|
||||
log.Info(version.Full())
|
||||
log.Debug("current working directory is %s", Context.workDir)
|
||||
log.Info("%s", version.Full())
|
||||
log.Debug("current working directory is %s", globalContext.workDir)
|
||||
if opts.runningAsService {
|
||||
log.Info("AdGuard Home is running as a service")
|
||||
}
|
||||
@@ -621,7 +627,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
// TODO(s.chzhen): Use it for the entire initialization process.
|
||||
ctx := context.Background()
|
||||
|
||||
err = initContextClients(ctx, slogLogger)
|
||||
err = initContextClients(ctx, slogLogger, sigHdlr)
|
||||
fatalOnError(err)
|
||||
|
||||
err = setupOpts(opts)
|
||||
@@ -632,15 +638,15 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
|
||||
confPath := configFilePath()
|
||||
|
||||
upd, customURL := newUpdater(ctx, slogLogger, Context.workDir, confPath, execPath, config)
|
||||
upd, customURL := newUpdater(ctx, slogLogger, globalContext.workDir, confPath, execPath, config)
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(ctx, slogLogger, opts, upd)
|
||||
|
||||
if !Context.firstRun {
|
||||
if !globalContext.firstRun {
|
||||
// Save the updated config.
|
||||
err = config.write()
|
||||
err = config.write(nil)
|
||||
fatalOnError(err)
|
||||
|
||||
if config.HTTPConfig.Pprof.Enabled {
|
||||
@@ -648,33 +654,36 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
}
|
||||
}
|
||||
|
||||
dataDir := Context.getDataDir()
|
||||
dataDir := globalContext.getDataDir()
|
||||
err = os.MkdirAll(dataDir, aghos.DefaultPermDir)
|
||||
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
|
||||
|
||||
GLMode = opts.glinetMode
|
||||
|
||||
// Init auth module.
|
||||
Context.auth, err = initUsers()
|
||||
globalContext.auth, err = initUsers()
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS)
|
||||
tlsMgr, err := newTLSManager(config.TLS, config.DNS.ServePlainDNS)
|
||||
if err != nil {
|
||||
log.Error("initializing tls: %s", err)
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
Context.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, customURL)
|
||||
globalContext.tls = tlsMgr
|
||||
sigHdlr.addTLSManager(tlsMgr)
|
||||
|
||||
globalContext.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
|
||||
fatalOnError(err)
|
||||
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
err = initDNS(slogLogger, statsDir, querylogDir)
|
||||
if !globalContext.firstRun {
|
||||
err = initDNS(slogLogger, tlsMgr, statsDir, querylogDir)
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls.start()
|
||||
tlsMgr.start(ctx)
|
||||
|
||||
go func() {
|
||||
startErr := startDNSServer()
|
||||
@@ -684,8 +693,8 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
}
|
||||
}()
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
err = Context.dhcpServer.Start()
|
||||
if globalContext.dhcpServer != nil {
|
||||
err = globalContext.dhcpServer.Start()
|
||||
if err != nil {
|
||||
log.Error("starting dhcp server: %s", err)
|
||||
}
|
||||
@@ -693,10 +702,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
}
|
||||
|
||||
if !opts.noPermCheck {
|
||||
checkPermissions(ctx, slogLogger, Context.workDir, confPath, dataDir, statsDir, querylogDir)
|
||||
checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir)
|
||||
}
|
||||
|
||||
Context.web.start(ctx)
|
||||
globalContext.web.start(ctx)
|
||||
|
||||
// Wait for other goroutines to complete their job.
|
||||
<-done
|
||||
@@ -775,7 +784,7 @@ func checkPermissions(
|
||||
|
||||
// initUsers initializes context auth module. Clears config users field.
|
||||
func initUsers() (auth *Auth, err error) {
|
||||
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
||||
sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db")
|
||||
|
||||
var rateLimiter *authRateLimiter
|
||||
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
|
||||
@@ -807,31 +816,6 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
|
||||
return aghnet.NewIPMut(anonFunc)
|
||||
}
|
||||
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
// baseLogger must not be nil.
|
||||
func startMods(baseLogger *slog.Logger) (err error) {
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = initDNS(baseLogger, statsDir, querylogDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Context.tls.start()
|
||||
|
||||
err = startDNSServer()
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkNetworkPermissions checks if the current user permissions are enough to
|
||||
// use the required networking functionality.
|
||||
func checkNetworkPermissions() {
|
||||
@@ -883,14 +867,14 @@ func writePIDFile(fn string) bool {
|
||||
func initConfigFilename(opts options) {
|
||||
confPath := opts.confFilename
|
||||
if confPath == "" {
|
||||
Context.confFilePath = filepath.Join(Context.workDir, "AdGuardHome.yaml")
|
||||
globalContext.confFilePath = filepath.Join(globalContext.workDir, "AdGuardHome.yaml")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("config path overridden to %q from cmdline", confPath)
|
||||
|
||||
Context.confFilePath = confPath
|
||||
globalContext.confFilePath = confPath
|
||||
}
|
||||
|
||||
// initWorkingDir initializes the workDir. If no command-line arguments are
|
||||
@@ -904,18 +888,18 @@ func initWorkingDir(opts options) (err error) {
|
||||
|
||||
if opts.workDir != "" {
|
||||
// If there is a custom config file, use it's directory as our working dir
|
||||
Context.workDir = opts.workDir
|
||||
globalContext.workDir = opts.workDir
|
||||
} else {
|
||||
Context.workDir = filepath.Dir(execPath)
|
||||
globalContext.workDir = filepath.Dir(execPath)
|
||||
}
|
||||
|
||||
workDir, err := filepath.EvalSymlinks(Context.workDir)
|
||||
workDir, err := filepath.EvalSymlinks(globalContext.workDir)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
Context.workDir = workDir
|
||||
globalContext.workDir = workDir
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -924,13 +908,13 @@ func initWorkingDir(opts options) (err error) {
|
||||
func cleanup(ctx context.Context) {
|
||||
log.Info("stopping AdGuard Home")
|
||||
|
||||
if Context.web != nil {
|
||||
Context.web.close(ctx)
|
||||
Context.web = nil
|
||||
if globalContext.web != nil {
|
||||
globalContext.web.close(ctx)
|
||||
globalContext.web = nil
|
||||
}
|
||||
if Context.auth != nil {
|
||||
Context.auth.Close()
|
||||
Context.auth = nil
|
||||
if globalContext.auth != nil {
|
||||
globalContext.auth.Close()
|
||||
globalContext.auth = nil
|
||||
}
|
||||
|
||||
err := stopDNSServer()
|
||||
@@ -938,28 +922,24 @@ func cleanup(ctx context.Context) {
|
||||
log.Error("stopping dns server: %s", err)
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
err = Context.dhcpServer.Stop()
|
||||
if globalContext.dhcpServer != nil {
|
||||
err = globalContext.dhcpServer.Stop()
|
||||
if err != nil {
|
||||
log.Error("stopping dhcp server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if Context.etcHosts != nil {
|
||||
if err = Context.etcHosts.Close(); err != nil {
|
||||
if globalContext.etcHosts != nil {
|
||||
if err = globalContext.etcHosts.Close(); err != nil {
|
||||
log.Error("closing hosts container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
Context.tls = nil
|
||||
}
|
||||
}
|
||||
|
||||
// This function is called before application exits
|
||||
func cleanupAlways() {
|
||||
if len(Context.pidFileName) != 0 {
|
||||
_ = os.Remove(Context.pidFileName)
|
||||
if len(globalContext.pidFileName) != 0 {
|
||||
_ = os.Remove(globalContext.pidFileName)
|
||||
}
|
||||
|
||||
log.Info("stopped")
|
||||
@@ -975,7 +955,7 @@ func exitWithError() {
|
||||
func loadCmdLineOpts() (opts options) {
|
||||
opts, eff, err := parseCmdOpts(os.Args[0], os.Args[1:])
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
log.Error("%s", err)
|
||||
printHelp(os.Args[0])
|
||||
|
||||
exitWithError()
|
||||
@@ -984,7 +964,7 @@ func loadCmdLineOpts() (opts options) {
|
||||
if eff != nil {
|
||||
err = eff()
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
log.Error("%s", err)
|
||||
exitWithError()
|
||||
}
|
||||
|
||||
@@ -1005,10 +985,12 @@ func printWebAddrs(proto, addr string, port uint16) {
|
||||
|
||||
// printHTTPAddresses prints the IP addresses which user can use to access the
|
||||
// admin interface. proto is either schemeHTTP or schemeHTTPS.
|
||||
func printHTTPAddresses(proto string) {
|
||||
//
|
||||
// TODO(s.chzhen): Implement separate functions for HTTP and HTTPS.
|
||||
func printHTTPAddresses(proto string, tlsMgr *tlsManager) {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
if Context.tls != nil {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsMgr != nil {
|
||||
tlsMgr.WriteDiskConfig(&tlsConf)
|
||||
}
|
||||
|
||||
port := config.HTTPConfig.Address.Port()
|
||||
@@ -1016,7 +998,6 @@ func printHTTPAddresses(proto string) {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
|
||||
if proto == urlutil.SchemeHTTPS && tlsConf.ServerName != "" {
|
||||
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS)
|
||||
|
||||
@@ -1050,9 +1031,9 @@ func printHTTPAddresses(proto string) {
|
||||
|
||||
// detectFirstRun returns true if this is the first run of AdGuard Home.
|
||||
func detectFirstRun() (ok bool) {
|
||||
confPath := Context.confFilePath
|
||||
confPath := globalContext.confFilePath
|
||||
if !filepath.IsAbs(confPath) {
|
||||
confPath = filepath.Join(Context.workDir, Context.confFilePath)
|
||||
confPath = filepath.Join(globalContext.workDir, globalContext.confFilePath)
|
||||
}
|
||||
|
||||
_, err := os.Stat(confPath)
|
||||
@@ -1105,7 +1086,7 @@ func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updat
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
err = upd.Update(Context.firstRun)
|
||||
err = upd.Update(globalContext.firstRun)
|
||||
fatalOnError(err)
|
||||
|
||||
err = restartService()
|
||||
|
||||
@@ -17,7 +17,7 @@ func httpClient() (c *http.Client) {
|
||||
// Do not use Context.dnsServer.DialContext directly in the struct literal
|
||||
// below, since Context.dnsServer may be nil when this function is called.
|
||||
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
return Context.dnsServer.DialContext(ctx, network, addr)
|
||||
return globalContext.dnsServer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
@@ -27,8 +27,8 @@ func httpClient() (c *http.Client) {
|
||||
DialContext: dialContext,
|
||||
Proxy: httpProxy,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCipherIDs,
|
||||
RootCAs: globalContext.tlsRoots,
|
||||
CipherSuites: globalContext.tlsCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -66,7 +66,7 @@ func configureLogger(ls *logSettings) (err error) {
|
||||
|
||||
logFilePath := ls.File
|
||||
if !filepath.IsAbs(logFilePath) {
|
||||
logFilePath = filepath.Join(Context.workDir, logFilePath)
|
||||
logFilePath = filepath.Join(globalContext.workDir, logFilePath)
|
||||
}
|
||||
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
|
||||
@@ -19,10 +19,8 @@ func setupDNSIPs(t testing.TB) {
|
||||
t.Helper()
|
||||
|
||||
prevConfig := config
|
||||
prevTLS := Context.tls
|
||||
t.Cleanup(func() {
|
||||
config = prevConfig
|
||||
Context.tls = prevTLS
|
||||
})
|
||||
|
||||
config = &configuration{
|
||||
@@ -31,8 +29,6 @@ func setupDNSIPs(t testing.TB) {
|
||||
Port: defaultPortDNS,
|
||||
},
|
||||
}
|
||||
|
||||
Context.tls = &tlsManager{}
|
||||
}
|
||||
|
||||
func TestHandleMobileConfigDoH(t *testing.T) {
|
||||
@@ -62,11 +58,6 @@ func TestHandleMobileConfigDoH(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("error_no_host", func(t *testing.T) {
|
||||
oldTLSConf := Context.tls
|
||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
||||
|
||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -134,11 +125,6 @@ func TestHandleMobileConfigDoT(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("error_no_host", func(t *testing.T) {
|
||||
oldTLSConf := Context.tls
|
||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
||||
|
||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -47,7 +47,7 @@ type profileJSON struct {
|
||||
|
||||
// handleGetProfile is the handler for GET /control/profile endpoint.
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
u := Context.auth.getCurrentUser(r)
|
||||
u := globalContext.auth.getCurrentUser(r)
|
||||
|
||||
var resp profileJSON
|
||||
func() {
|
||||
|
||||
@@ -36,6 +36,7 @@ type program struct {
|
||||
signals chan os.Signal
|
||||
done chan struct{}
|
||||
opts options
|
||||
sigHdlr *signalHandler
|
||||
}
|
||||
|
||||
// type check
|
||||
@@ -47,7 +48,7 @@ func (p *program) Start(_ service.Service) (err error) {
|
||||
args := p.opts
|
||||
args.runningAsService = true
|
||||
|
||||
go run(args, p.clientBuildFS, p.done)
|
||||
go run(args, p.clientBuildFS, p.done, p.sigHdlr)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -204,13 +205,14 @@ func handleServiceControlAction(
|
||||
clientBuildFS fs.FS,
|
||||
signals chan os.Signal,
|
||||
done chan struct{},
|
||||
sigHdlr *signalHandler,
|
||||
) {
|
||||
// Call chooseSystem explicitly to introduce OpenBSD support for service
|
||||
// package. It's a noop for other GOOS values.
|
||||
chooseSystem()
|
||||
|
||||
action := opts.serviceControlAction
|
||||
log.Info(version.Full())
|
||||
log.Info("%s", version.Full())
|
||||
log.Info("service: control action: %s", action)
|
||||
|
||||
if action == "reload" {
|
||||
@@ -244,6 +246,7 @@ func handleServiceControlAction(
|
||||
signals: signals,
|
||||
done: done,
|
||||
opts: runOpts,
|
||||
sigHdlr: sigHdlr,
|
||||
}, svcConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("service: initializing service: %s", err)
|
||||
@@ -336,7 +339,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
|
||||
There are a few more things that must be configured before you can use it.
|
||||
Click on the link below and follow the Installation Wizard steps to finish setup.
|
||||
AdGuard Home is now available at the following addresses:`)
|
||||
printHTTPAddresses(urlutil.SchemeHTTP)
|
||||
printHTTPAddresses(urlutil.SchemeHTTP, nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -392,7 +392,7 @@ type sysLogger struct{}
|
||||
|
||||
// Error implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Error(v ...any) error {
|
||||
log.Error(fmt.Sprint(v...))
|
||||
log.Error("%s", fmt.Sprint(v...))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -406,7 +406,7 @@ func (sysLogger) Warning(v ...any) error {
|
||||
|
||||
// Info implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Info(v ...any) error {
|
||||
log.Info(fmt.Sprint(v...))
|
||||
log.Info("%s", fmt.Sprint(v...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
121
internal/home/signal.go
Normal file
121
internal/home/signal.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
)
|
||||
|
||||
// signalHandler processes incoming signals. It reloads configurations of
|
||||
// stored entities on SIGHUP and performs cleanup on all other signals.
|
||||
type signalHandler struct {
|
||||
// logger is used to log the operation of the signal handler. Initially,
|
||||
// [slog.Default] is used, but it should be swapped later using
|
||||
// [signalHandler.swapLogger].
|
||||
logger *atomic.Pointer[slog.Logger]
|
||||
|
||||
// mu protects clientStorage and tlsManager.
|
||||
mu *sync.Mutex
|
||||
|
||||
// clientStorage is used to reload information about runtime clients with an
|
||||
// ARP source.
|
||||
clientStorage *client.Storage
|
||||
|
||||
// tlsManager is used to reload the TLS configuration.
|
||||
tlsManager *tlsManager
|
||||
|
||||
// signals receives incoming signals.
|
||||
signals <-chan os.Signal
|
||||
|
||||
// cleanup is called to perform cleanup on all incoming signals, except
|
||||
// SIGHUP.
|
||||
cleanup func(ctx context.Context)
|
||||
}
|
||||
|
||||
// newSignalHandler returns a new properly initialized *signalHandler.
|
||||
func newSignalHandler(
|
||||
signals <-chan os.Signal,
|
||||
cleanup func(ctx context.Context),
|
||||
) (h *signalHandler) {
|
||||
h = &signalHandler{
|
||||
logger: &atomic.Pointer[slog.Logger]{},
|
||||
mu: &sync.Mutex{},
|
||||
signals: signals,
|
||||
cleanup: cleanup,
|
||||
}
|
||||
|
||||
h.logger.Store(slog.Default())
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// swapLogger replaces the stored logger with the given logger.
|
||||
func (h *signalHandler) swapLogger(logger *slog.Logger) {
|
||||
h.logger.Swap(logger)
|
||||
}
|
||||
|
||||
// addClientStorage stores the client storage.
|
||||
func (h *signalHandler) addClientStorage(s *client.Storage) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.clientStorage = s
|
||||
}
|
||||
|
||||
// addTLSManager stores the TLS manager.
|
||||
func (h *signalHandler) addTLSManager(m *tlsManager) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.tlsManager = m
|
||||
}
|
||||
|
||||
// handle processes incoming signals. It blocks until a signal is received. It
|
||||
// reloads configurations of stored entities on SIGHUP, or performs cleanup on
|
||||
// all other signals. It is intended to be used as a goroutine.
|
||||
func (h *signalHandler) handle(ctx context.Context) {
|
||||
// NOTE: Avoid using [slogutil.RecoverAndExit] to prevent immediate
|
||||
// evaluation of the logger.
|
||||
defer func() {
|
||||
v := recover()
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
|
||||
slogutil.PrintRecovered(ctx, h.logger.Load(), v)
|
||||
|
||||
os.Exit(osutil.ExitCodeFailure)
|
||||
}()
|
||||
|
||||
for {
|
||||
sig := <-h.signals
|
||||
h.logger.Load().InfoContext(ctx, "received signal", "signal", sig)
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
h.reloadConfig(ctx)
|
||||
default:
|
||||
h.cleanup(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reloadConfig refreshes configurations of stored entities.
|
||||
func (h *signalHandler) reloadConfig(ctx context.Context) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.clientStorage != nil {
|
||||
h.clientStorage.ReloadARP(ctx)
|
||||
}
|
||||
|
||||
if h.tlsManager != nil {
|
||||
h.tlsManager.reload()
|
||||
}
|
||||
}
|
||||
@@ -102,7 +102,9 @@ func (m *tlsManager) setCertFileTime() {
|
||||
}
|
||||
|
||||
// start updates the configuration of t and starts it.
|
||||
func (m *tlsManager) start() {
|
||||
//
|
||||
// TODO(s.chzhen): Use context.
|
||||
func (m *tlsManager) start(_ context.Context) {
|
||||
m.registerWebHandlers()
|
||||
|
||||
m.confLock.Lock()
|
||||
@@ -112,7 +114,7 @@ func (m *tlsManager) start() {
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// reload updates the configuration and restarts t.
|
||||
@@ -151,7 +153,7 @@ func (m *tlsManager) reload() {
|
||||
|
||||
m.certLastMod = fi.ModTime().UTC()
|
||||
|
||||
_ = reconfigureDNSServer()
|
||||
_ = reconfigureDNSServer(m)
|
||||
|
||||
m.confLock.Lock()
|
||||
tlsConf = m.conf
|
||||
@@ -160,7 +162,7 @@ func (m *tlsManager) reload() {
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
||||
@@ -440,7 +442,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
onConfigModified()
|
||||
|
||||
err = reconfigureDNSServer()
|
||||
err = reconfigureDNSServer(m)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
@@ -463,7 +465,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
// same reason.
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
globalContext.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -539,7 +541,7 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) {
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
DNSName: srvName,
|
||||
Roots: Context.tlsRoots,
|
||||
Roots: globalContext.tlsRoots,
|
||||
Intermediates: pool,
|
||||
}
|
||||
_, err = main.Verify(opts)
|
||||
|
||||
@@ -49,6 +49,10 @@ type webConfig struct {
|
||||
// nil.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
// tlsManager contains the current configuration and state of TLS
|
||||
// encryption. It must not be nil.
|
||||
tlsManager *tlsManager
|
||||
|
||||
clientFS fs.FS
|
||||
|
||||
// BindAddr is the binding address with port for plain HTTP web interface.
|
||||
@@ -108,6 +112,10 @@ type webAPI struct {
|
||||
// nil.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
// tlsManager contains the current configuration and state of TLS
|
||||
// encryption.
|
||||
tlsManager *tlsManager
|
||||
|
||||
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
|
||||
// [Web.http3Server] must also not be nil.
|
||||
httpsServer httpsServer
|
||||
@@ -124,12 +132,13 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
conf: conf,
|
||||
logger: conf.logger,
|
||||
baseLogger: conf.baseLogger,
|
||||
tlsManager: conf.tlsManager,
|
||||
}
|
||||
|
||||
clientFS := http.FileServer(http.FS(conf.clientFS))
|
||||
|
||||
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
||||
Context.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
|
||||
globalContext.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
|
||||
|
||||
// add handlers for /install paths, we only need them when we're not configured yet
|
||||
if conf.firstRun {
|
||||
@@ -138,7 +147,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
"This is the first launch of AdGuard Home, redirecting everything to /install.html",
|
||||
)
|
||||
|
||||
Context.mux.Handle("/install.html", preInstallHandler(clientFS))
|
||||
globalContext.mux.Handle("/install.html", preInstallHandler(clientFS))
|
||||
w.registerInstallHandlers()
|
||||
} else {
|
||||
registerControlHandlers(w)
|
||||
@@ -154,7 +163,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
//
|
||||
// TODO(a.garipov): Adapt for HTTP/3.
|
||||
func webCheckPortAvailable(port uint16) (ok bool) {
|
||||
if Context.web.httpsServer.server != nil {
|
||||
if globalContext.web.httpsServer.server != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -220,14 +229,18 @@ func (web *webAPI) start(ctx context.Context) {
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !web.httpsServer.inShutdown {
|
||||
printHTTPAddresses(urlutil.SchemeHTTP)
|
||||
printHTTPAddresses(urlutil.SchemeHTTP, web.tlsManager)
|
||||
errs := make(chan error, 2)
|
||||
|
||||
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
||||
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
|
||||
hdlr := h2c.NewHandler(withMiddlewares(globalContext.mux, limitRequestBody), &http2.Server{})
|
||||
|
||||
logger := web.baseLogger.With(loggerKeyServer, "plain")
|
||||
|
||||
// TODO(a.garipov): Remove other logs like this in other code.
|
||||
logMw := httputil.NewLogMiddleware(logger, slog.LevelDebug)
|
||||
hdlr = logMw.Wrap(hdlr)
|
||||
|
||||
// Create a new instance, because the Web is not usable after Shutdown.
|
||||
web.httpServer = &http.Server{
|
||||
Addr: web.conf.BindAddr.String(),
|
||||
@@ -238,7 +251,9 @@ func (web *webAPI) start(ctx context.Context) {
|
||||
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
|
||||
}
|
||||
go func() {
|
||||
defer slogutil.RecoverAndLog(ctx, web.logger)
|
||||
defer slogutil.RecoverAndLog(ctx, logger)
|
||||
|
||||
logger.InfoContext(ctx, "starting plain server", "addr", web.httpServer.Addr)
|
||||
|
||||
errs <- web.httpServer.ListenAndServe()
|
||||
}()
|
||||
@@ -305,13 +320,17 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
|
||||
addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String()
|
||||
logger := web.baseLogger.With(loggerKeyServer, "https")
|
||||
|
||||
// TODO(a.garipov): Remove other logs like this in other code.
|
||||
logMw := httputil.NewLogMiddleware(logger, slog.LevelDebug)
|
||||
hdlr := logMw.Wrap(withMiddlewares(globalContext.mux, limitRequestBody))
|
||||
|
||||
web.httpsServer.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
||||
Handler: hdlr,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCipherIDs,
|
||||
RootCAs: globalContext.tlsRoots,
|
||||
CipherSuites: globalContext.tlsCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
@@ -320,13 +339,13 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
|
||||
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
|
||||
}
|
||||
|
||||
printHTTPAddresses(urlutil.SchemeHTTPS)
|
||||
printHTTPAddresses(urlutil.SchemeHTTPS, web.tlsManager)
|
||||
|
||||
if web.conf.serveHTTP3 {
|
||||
go web.mustStartHTTP3(ctx, addr)
|
||||
}
|
||||
|
||||
web.logger.DebugContext(ctx, "starting https server")
|
||||
logger.InfoContext(ctx, "starting https server")
|
||||
err := web.httpsServer.server.ListenAndServeTLS("", "")
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
cleanupAlways()
|
||||
@@ -344,11 +363,11 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
|
||||
Addr: address,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCipherIDs,
|
||||
RootCAs: globalContext.tlsRoots,
|
||||
CipherSuites: globalContext.tlsCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
||||
Handler: withMiddlewares(globalContext.mux, limitRequestBody),
|
||||
}
|
||||
|
||||
web.logger.DebugContext(ctx, "starting http/3 server")
|
||||
|
||||
@@ -134,9 +134,6 @@ func TestManager_Add(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// ipsetPropsSink is the typed sink for benchmark results.
|
||||
var ipsetPropsSink []props
|
||||
|
||||
func BenchmarkManager_LookupHost(b *testing.B) {
|
||||
propsLong := []props{{
|
||||
name: "example.com",
|
||||
@@ -155,9 +152,13 @@ func BenchmarkManager_LookupHost(b *testing.B) {
|
||||
},
|
||||
}
|
||||
|
||||
var ipsetPropsSink []props
|
||||
|
||||
b.Run("long", func(b *testing.B) {
|
||||
const name = "a.very.long.domain.name.inside.the.domain.example.com"
|
||||
for range b.N {
|
||||
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
ipsetPropsSink = m.lookupHost(name)
|
||||
}
|
||||
|
||||
@@ -166,10 +167,21 @@ func BenchmarkManager_LookupHost(b *testing.B) {
|
||||
|
||||
b.Run("short", func(b *testing.B) {
|
||||
const name = "example.net"
|
||||
for range b.N {
|
||||
|
||||
b.ReportAllocs()
|
||||
for b.Loop() {
|
||||
ipsetPropsSink = m.lookupHost(name)
|
||||
}
|
||||
|
||||
require.Equal(b, propsShort, ipsetPropsSink)
|
||||
})
|
||||
|
||||
// Most recent results:
|
||||
//
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/ipset
|
||||
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
|
||||
// BenchmarkManager_LookupHost/long-8 6562424 174.8 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkManager_LookupHost/short-8 100000000 10.72 ns/op 0 B/op 0 allocs/op
|
||||
}
|
||||
|
||||
@@ -283,6 +283,8 @@ func anonymizeIPSlow(ip net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Investigate the results, it seems that the slow version
|
||||
// isn't that slow.
|
||||
func BenchmarkAnonymizeIP(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
@@ -320,7 +322,7 @@ func BenchmarkAnonymizeIP(b *testing.B) {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
AnonymizeIP(bc.ip)
|
||||
}
|
||||
|
||||
@@ -330,11 +332,26 @@ func BenchmarkAnonymizeIP(b *testing.B) {
|
||||
b.Run(bc.name+"_slow", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
anonymizeIPSlow(bc.ip)
|
||||
}
|
||||
|
||||
assert.Equal(b, bc.want, bc.ip)
|
||||
})
|
||||
}
|
||||
|
||||
// Most recent results:
|
||||
//
|
||||
// goos: darwin
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/querylog
|
||||
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
|
||||
// BenchmarkAnonymizeIP/v4-12 426499675 2.687 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkAnonymizeIP/v4_slow-12 510082938 2.412 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkAnonymizeIP/v4_mapped-12 149121745 7.992 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkAnonymizeIP/v4_mapped_slow-12 178441804 6.698 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkAnonymizeIP/v6-12 346746447 3.436 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkAnonymizeIP/v6_slow-12 419062732 2.966 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkAnonymizeIP/invalid-12 316385232 3.941 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkAnonymizeIP/invalid_slow-12 456531592 2.760 ns/op 0 B/op 0 allocs/op
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user