Pull request 2346: AGDNS-2686-client-upstream-manager
Merge in DNS/adguard-home from AGDNS-2686-client-upstream-manager to master Squashed commit of the following: commit 563cb583f01c26434fa04d0e37dcbe2ba15c0912 Merge: f4b0caf5c61fe269cbAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Mar 3 19:07:35 2025 +0300 Merge branch 'master' into AGDNS-2686-client-upstream-manager commit f4b0caf5c8bc48ee8be97f031cd1aa1399eb461c Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Feb 27 21:52:51 2025 +0300 client: imp docs commit e7d74931b1cc9b62eeadbe1168ae5781d57d6c73 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 26 21:44:04 2025 +0300 client: imp code commit 1cba38c1bc3b6b5afb7829c230c4e831f789647e Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 26 18:06:17 2025 +0300 client: fix typo commit 65b6b1e8c0fde47f367c428a78fefc4c63bc45f9 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 26 17:52:02 2025 +0300 all: imp code, docs commit ed158ef09fc26bc9c57c91dbfa04d89fede583d0 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 26 14:34:50 2025 +0300 client: imp code commit ab897f64c8751ea158408521116d5b689e6d39a9 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Feb 25 18:26:16 2025 +0300 all: upd chlog commit a2c30e3ede6fb61f6d23fd392cc3035dc96f77af Merge: bdb08ee0ed8ce5b453Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Feb 25 17:40:32 2025 +0300 Merge branch 'master' into AGDNS-2686-client-upstream-manager commit bdb08ee0e6122de727f2749a44f5df7e29d0eee2 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Feb 25 17:16:31 2025 +0300 all: imp tests commit 00f0eb60474a2297567acf5a3a27e8b5c2d99229 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Feb 20 21:37:58 2025 +0300 all: imp code, docs commit 13934176636dd70a17e53bc1956d6cf51602760a Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 19 15:58:11 2025 +0300 all: client upstream manager
This commit is contained in:
@@ -20,6 +20,7 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
|||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
- Changes to global upstream DNS settings not applying to custom client upstream configurations.
|
||||||
- The formatting of large numbers in the clients tables on the *Client settings* page ([#7583]).
|
- The formatting of large numbers in the clients tables on the *Client settings* page ([#7583]).
|
||||||
|
|
||||||
[#7583]: https://github.com/AdguardTeam/AdGuardHome/issues/7583
|
[#7583]: https://github.com/AdguardTeam/AdGuardHome/issues/7583
|
||||||
|
|||||||
24
internal/aghnet/upstream.go
Normal file
24
internal/aghnet/upstream.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package aghnet
|
||||||
|
|
||||||
|
import "github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
|
||||||
|
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||||
|
// depending on configuration.
|
||||||
|
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||||
|
if !http3 {
|
||||||
|
return upstream.DefaultHTTPVersions
|
||||||
|
}
|
||||||
|
|
||||||
|
return []upstream.HTTPVersion{
|
||||||
|
upstream.HTTPVersion3,
|
||||||
|
upstream.HTTPVersion2,
|
||||||
|
upstream.HTTPVersion11,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||||
|
// This function is useful for filtering out non-upstream lines from upstream
|
||||||
|
// configs.
|
||||||
|
func IsCommentOrEmpty(s string) (ok bool) {
|
||||||
|
return len(s) == 0 || s[0] == '#'
|
||||||
|
}
|
||||||
26
internal/aghnet/upstream_test.go
Normal file
26
internal/aghnet/upstream_test.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package aghnet_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsCommentOrEmpty(t *testing.T) {
|
||||||
|
for _, tc := range []struct {
|
||||||
|
want assert.BoolAssertionFunc
|
||||||
|
str string
|
||||||
|
}{{
|
||||||
|
want: assert.True,
|
||||||
|
str: "",
|
||||||
|
}, {
|
||||||
|
want: assert.True,
|
||||||
|
str: "# comment",
|
||||||
|
}, {
|
||||||
|
want: assert.False,
|
||||||
|
str: "1.2.3.4",
|
||||||
|
}} {
|
||||||
|
tc.want(t, aghnet.IsCommentOrEmpty(tc.str))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
@@ -121,26 +120,6 @@ func (p *AddressUpdater) UpdateAddress(
|
|||||||
p.OnUpdateAddress(ctx, ip, host, info)
|
p.OnUpdateAddress(ctx, ip, host, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Package dnsforward
|
|
||||||
|
|
||||||
// ClientsContainer is a fake [dnsforward.ClientsContainer] implementation for
|
|
||||||
// tests.
|
|
||||||
type ClientsContainer struct {
|
|
||||||
OnUpstreamConfigByID func(
|
|
||||||
id string,
|
|
||||||
boot upstream.Resolver,
|
|
||||||
) (conf *proxy.CustomUpstreamConfig, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface
|
|
||||||
// for *ClientsContainer.
|
|
||||||
func (c *ClientsContainer) UpstreamConfigByID(
|
|
||||||
id string,
|
|
||||||
boot upstream.Resolver,
|
|
||||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
|
||||||
return c.OnUpstreamConfigByID(id, boot)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Package filtering
|
// Package filtering
|
||||||
|
|
||||||
// Resolver is a fake [filtering.Resolver] implementation for tests.
|
// Resolver is a fake [filtering.Resolver] implementation for tests.
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package aghtest_test
|
|||||||
import (
|
import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -12,9 +11,6 @@ import (
|
|||||||
// type check
|
// type check
|
||||||
var _ filtering.Resolver = (*aghtest.Resolver)(nil)
|
var _ filtering.Resolver = (*aghtest.Resolver)(nil)
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)
|
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
//
|
//
|
||||||
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.
|
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
|
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
|
||||||
@@ -35,7 +34,7 @@ type index struct {
|
|||||||
// nameToUID maps client name to UID.
|
// nameToUID maps client name to UID.
|
||||||
nameToUID map[string]UID
|
nameToUID map[string]UID
|
||||||
|
|
||||||
// clientIDToUID maps client ID to UID.
|
// clientIDToUID maps ClientID to UID.
|
||||||
clientIDToUID map[string]UID
|
clientIDToUID map[string]UID
|
||||||
|
|
||||||
// ipToUID maps IP address to UID.
|
// ipToUID maps IP address to UID.
|
||||||
@@ -205,19 +204,19 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// find finds persistent client by string representation of the client ID, IP
|
// find finds persistent client by string representation of the ClientID, IP
|
||||||
// address, or MAC.
|
// address, or MAC.
|
||||||
func (ci *index) find(id string) (c *Persistent, ok bool) {
|
func (ci *index) find(id string) (c *Persistent, ok bool) {
|
||||||
uid, found := ci.clientIDToUID[id]
|
c, ok = ci.findByClientID(id)
|
||||||
if found {
|
if ok {
|
||||||
return ci.uidToClient[uid], true
|
return c, true
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, err := netip.ParseAddr(id)
|
ip, err := netip.ParseAddr(id)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// MAC addresses can be successfully parsed as IP addresses.
|
// MAC addresses can be successfully parsed as IP addresses.
|
||||||
c, found = ci.findByIP(ip)
|
c, ok = ci.findByIP(ip)
|
||||||
if found {
|
if ok {
|
||||||
return c, true
|
return c, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -230,6 +229,16 @@ func (ci *index) find(id string) (c *Persistent, ok bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findByClientID finds persistent client by ClientID.
|
||||||
|
func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) {
|
||||||
|
uid, ok := ci.clientIDToUID[clientID]
|
||||||
|
if ok {
|
||||||
|
return ci.uidToClient[uid], true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
// findByName finds persistent client by name.
|
// findByName finds persistent client by name.
|
||||||
func (ci *index) findByName(name string) (c *Persistent, found bool) {
|
func (ci *index) findByName(name string) (c *Persistent, found bool) {
|
||||||
uid, found := ci.nameToUID[name]
|
uid, found := ci.nameToUID[name]
|
||||||
@@ -343,18 +352,3 @@ func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeUpstreams closes upstream configurations of persistent clients.
|
|
||||||
func (ci *index) closeUpstreams() (err error) {
|
|
||||||
var errs []error
|
|
||||||
ci.rangeByName(func(c *Persistent) (cont bool) {
|
|
||||||
err = c.CloseUpstreams()
|
|
||||||
if err != nil {
|
|
||||||
errs = append(errs, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
return errors.Join(errs...)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -58,12 +58,6 @@ func (uid *UID) UnmarshalText(data []byte) error {
|
|||||||
|
|
||||||
// Persistent contains information about persistent clients.
|
// Persistent contains information about persistent clients.
|
||||||
type Persistent struct {
|
type Persistent struct {
|
||||||
// UpstreamConfig is the custom upstream configuration for this client. If
|
|
||||||
// it's nil, it has not been initialized yet. If it's non-nil and empty,
|
|
||||||
// there are no valid upstreams. If it's non-nil and non-empty, these
|
|
||||||
// upstream must be used.
|
|
||||||
UpstreamConfig *proxy.CustomUpstreamConfig
|
|
||||||
|
|
||||||
// SafeSearch handles search engine hosts rewrites.
|
// SafeSearch handles search engine hosts rewrites.
|
||||||
SafeSearch filtering.SafeSearch
|
SafeSearch filtering.SafeSearch
|
||||||
|
|
||||||
@@ -262,7 +256,7 @@ func ValidateClientID(id string) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IDs returns a list of client IDs containing at least one element.
|
// IDs returns a list of ClientIDs containing at least one element.
|
||||||
func (c *Persistent) IDs() (ids []string) {
|
func (c *Persistent) IDs() (ids []string) {
|
||||||
ids = make([]string, 0, c.IDsLen())
|
ids = make([]string, 0, c.IDsLen())
|
||||||
|
|
||||||
@@ -281,7 +275,7 @@ func (c *Persistent) IDs() (ids []string) {
|
|||||||
return append(ids, c.ClientIDs...)
|
return append(ids, c.ClientIDs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IDsLen returns a length of client ids.
|
// IDsLen returns a length of ClientIDs.
|
||||||
func (c *Persistent) IDsLen() (n int) {
|
func (c *Persistent) IDsLen() (n int) {
|
||||||
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
||||||
}
|
}
|
||||||
@@ -312,14 +306,3 @@ func (c *Persistent) ShallowClone() (clone *Persistent) {
|
|||||||
|
|
||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseUpstreams closes the client-specific upstream config of c if any.
|
|
||||||
func (c *Persistent) CloseUpstreams() (err error) {
|
|
||||||
if c.UpstreamConfig != nil {
|
|
||||||
if err = c.UpstreamConfig.Close(); err != nil {
|
|
||||||
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/hostsfile"
|
"github.com/AdguardTeam/golibs/hostsfile"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
@@ -126,6 +127,9 @@ type Storage struct {
|
|||||||
// runtimeIndex contains information about runtime clients.
|
// runtimeIndex contains information about runtime clients.
|
||||||
runtimeIndex *runtimeIndex
|
runtimeIndex *runtimeIndex
|
||||||
|
|
||||||
|
// upstreamManager stores and updates custom client upstream configurations.
|
||||||
|
upstreamManager *upstreamManager
|
||||||
|
|
||||||
// dhcp is used to update [SourceDHCP] runtime client information.
|
// dhcp is used to update [SourceDHCP] runtime client information.
|
||||||
dhcp DHCP
|
dhcp DHCP
|
||||||
|
|
||||||
@@ -163,6 +167,7 @@ func NewStorage(ctx context.Context, conf *StorageConfig) (s *Storage, err error
|
|||||||
mu: &sync.Mutex{},
|
mu: &sync.Mutex{},
|
||||||
index: newIndex(),
|
index: newIndex(),
|
||||||
runtimeIndex: newRuntimeIndex(),
|
runtimeIndex: newRuntimeIndex(),
|
||||||
|
upstreamManager: newUpstreamManager(conf.Logger),
|
||||||
dhcp: conf.DHCP,
|
dhcp: conf.DHCP,
|
||||||
etcHosts: conf.EtcHosts,
|
etcHosts: conf.EtcHosts,
|
||||||
arpDB: conf.ARPDB,
|
arpDB: conf.ARPDB,
|
||||||
@@ -200,7 +205,7 @@ func (s *Storage) Start(ctx context.Context) (err error) {
|
|||||||
func (s *Storage) Shutdown(_ context.Context) (err error) {
|
func (s *Storage) Shutdown(_ context.Context) (err error) {
|
||||||
close(s.done)
|
close(s.done)
|
||||||
|
|
||||||
return s.closeUpstreams()
|
return s.upstreamManager.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
|
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
|
||||||
@@ -416,6 +421,7 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.index.add(p)
|
s.index.add(p)
|
||||||
|
s.upstreamManager.updateCustomUpstreamConfig(p)
|
||||||
|
|
||||||
s.logger.DebugContext(
|
s.logger.DebugContext(
|
||||||
ctx,
|
ctx,
|
||||||
@@ -441,7 +447,7 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find finds persistent client by string representation of the client ID, IP
|
// Find finds persistent client by string representation of the ClientID, IP
|
||||||
// address, or MAC. And returns its shallow copy.
|
// address, or MAC. And returns its shallow copy.
|
||||||
//
|
//
|
||||||
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
|
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
|
||||||
@@ -514,12 +520,13 @@ func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.CloseUpstreams(); err != nil {
|
|
||||||
s.logger.ErrorContext(ctx, "removing client", "name", p.Name, slogutil.KeyError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.index.remove(p)
|
s.index.remove(p)
|
||||||
|
|
||||||
|
err := s.upstreamManager.remove(p.UID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.DebugContext(ctx, "closing client upstreams", "name", name, slogutil.KeyError, err)
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -556,6 +563,8 @@ func (s *Storage) Update(ctx context.Context, name string, p *Persistent) (err e
|
|||||||
s.index.remove(stored)
|
s.index.remove(stored)
|
||||||
s.index.add(p)
|
s.index.add(p)
|
||||||
|
|
||||||
|
s.upstreamManager.updateCustomUpstreamConfig(p)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -576,14 +585,6 @@ func (s *Storage) Size() (n int) {
|
|||||||
return s.index.size()
|
return s.index.size()
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeUpstreams closes upstream configurations of persistent clients.
|
|
||||||
func (s *Storage) closeUpstreams() (err error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
return s.index.closeUpstreams()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
|
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
|
||||||
// client exists, returns nil.
|
// client exists, returns nil.
|
||||||
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
||||||
@@ -626,3 +627,42 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
|||||||
func (s *Storage) AllowedTags() (tags []string) {
|
func (s *Storage) AllowedTags() (tags []string) {
|
||||||
return s.allowedTags
|
return s.allowedTags
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CustomUpstreamConfig implements the [dnsforward.ClientsContainer] interface
|
||||||
|
// for *Storage
|
||||||
|
func (s *Storage) CustomUpstreamConfig(
|
||||||
|
id string,
|
||||||
|
addr netip.Addr,
|
||||||
|
) (prxConf *proxy.CustomUpstreamConfig) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
c, ok := s.index.findByClientID(id)
|
||||||
|
if !ok {
|
||||||
|
c, ok = s.index.findByIP(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.upstreamManager.customUpstreamConfig(c.UID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCommonUpstreamConfig implements the [dnsforward.ClientsContainer]
|
||||||
|
// interface for *Storage
|
||||||
|
func (s *Storage) UpdateCommonUpstreamConfig(conf *CommonUpstreamConfig) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.upstreamManager.updateCommonUpstreamConfig(conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamCache implements the [dnsforward.ClientsContainer] interface for
|
||||||
|
// *Storage
|
||||||
|
func (s *Storage) ClearUpstreamCache() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.upstreamManager.clearUpstreamCache()
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
"github.com/AdguardTeam/golibs/hostsfile"
|
"github.com/AdguardTeam/golibs/hostsfile"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
@@ -34,6 +35,9 @@ func newTestStorage(tb testing.TB) (s *client.Storage) {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ dnsforward.ClientsContainer = (*client.Storage)(nil)
|
||||||
|
|
||||||
// testHostsContainer is a mock implementation of the [client.HostsContainer]
|
// testHostsContainer is a mock implementation of the [client.HostsContainer]
|
||||||
// interface.
|
// interface.
|
||||||
type testHostsContainer struct {
|
type testHostsContainer struct {
|
||||||
@@ -1278,3 +1282,90 @@ func TestStorage_RangeByName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStorage_CustomUpstreamConfig(t *testing.T) {
|
||||||
|
const (
|
||||||
|
existingName = "existing_name"
|
||||||
|
existingClientID = "existing_client_id"
|
||||||
|
|
||||||
|
nonExistingClientID = "non_existing_client_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
existingClientUID = client.MustNewUID()
|
||||||
|
existingIP = netip.MustParseAddr("192.0.2.1")
|
||||||
|
|
||||||
|
nonExistingIP = netip.MustParseAddr("192.0.2.255")
|
||||||
|
|
||||||
|
testUpstreamTimeout = time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
existingClient := &client.Persistent{
|
||||||
|
Name: existingName,
|
||||||
|
IPs: []netip.Addr{existingIP},
|
||||||
|
ClientIDs: []string{existingClientID},
|
||||||
|
UID: existingClientUID,
|
||||||
|
Upstreams: []string{"192.0.2.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := newTestStorage(t)
|
||||||
|
s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
|
||||||
|
UpstreamTimeout: testUpstreamTimeout,
|
||||||
|
})
|
||||||
|
|
||||||
|
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||||
|
return s.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
err := s.Add(ctx, existingClient)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
cliAddr netip.Addr
|
||||||
|
wantNilConf assert.ValueAssertionFunc
|
||||||
|
name string
|
||||||
|
cliID string
|
||||||
|
}{{
|
||||||
|
name: "client_id",
|
||||||
|
cliID: existingClientID,
|
||||||
|
cliAddr: netip.Addr{},
|
||||||
|
wantNilConf: assert.NotNil,
|
||||||
|
}, {
|
||||||
|
name: "client_addr",
|
||||||
|
cliID: "",
|
||||||
|
cliAddr: existingIP,
|
||||||
|
wantNilConf: assert.NotNil,
|
||||||
|
}, {
|
||||||
|
name: "non_existing_client_id",
|
||||||
|
cliID: nonExistingClientID,
|
||||||
|
cliAddr: netip.Addr{},
|
||||||
|
wantNilConf: assert.Nil,
|
||||||
|
}, {
|
||||||
|
name: "non_existing_client_addr",
|
||||||
|
cliID: "",
|
||||||
|
cliAddr: nonExistingIP,
|
||||||
|
wantNilConf: assert.Nil,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
conf := s.CustomUpstreamConfig(tc.cliID, tc.cliAddr)
|
||||||
|
tc.wantNilConf(t, conf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("update_common_config", func(t *testing.T) {
|
||||||
|
conf := s.CustomUpstreamConfig(existingClientID, existingIP)
|
||||||
|
require.NotNil(t, conf)
|
||||||
|
|
||||||
|
s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
|
||||||
|
UpstreamTimeout: testUpstreamTimeout * 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
updConf := s.CustomUpstreamConfig(existingClientID, existingIP)
|
||||||
|
require.NotNil(t, updConf)
|
||||||
|
|
||||||
|
assert.NotEqual(t, conf, updConf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
219
internal/client/upstreammanager.go
Normal file
219
internal/client/upstreammanager.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CommonUpstreamConfig contains common settings for custom client upstream
|
||||||
|
// configurations.
|
||||||
|
type CommonUpstreamConfig struct {
|
||||||
|
Bootstrap upstream.Resolver
|
||||||
|
UpstreamTimeout time.Duration
|
||||||
|
BootstrapPreferIPv6 bool
|
||||||
|
EDNSClientSubnetEnabled bool
|
||||||
|
UseHTTP3Upstreams bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// customUpstreamConfig contains custom client upstream configuration and the
|
||||||
|
// timestamp of the latest configuration update.
|
||||||
|
type customUpstreamConfig struct {
|
||||||
|
// proxyConf is the constructed upstream configuration for the [proxy],
|
||||||
|
// derived from the fields below. It is initialized on demand with
|
||||||
|
// [newCustomUpstreamConfig].
|
||||||
|
proxyConf *proxy.CustomUpstreamConfig
|
||||||
|
|
||||||
|
// commonConfUpdate is the timestamp of the latest configuration update,
|
||||||
|
// used to check against [upstreamManager.confUpdate] to determine if the
|
||||||
|
// configuration is up to date.
|
||||||
|
commonConfUpdate time.Time
|
||||||
|
|
||||||
|
// upstreams is the cached list of custom upstream DNS servers used for the
|
||||||
|
// configuration of proxyConf.
|
||||||
|
upstreams []string
|
||||||
|
|
||||||
|
// upstreamsCacheSize is the cached value of the cache size of the
|
||||||
|
// upstreams, used for the configuration of proxyConf.
|
||||||
|
upstreamsCacheSize uint32
|
||||||
|
|
||||||
|
// upstreamsCacheEnabled is the cached value indicating whether the cache of
|
||||||
|
// the upstreams is enabled for the configuration of proxyConf.
|
||||||
|
upstreamsCacheEnabled bool
|
||||||
|
|
||||||
|
// isChanged indicates whether the proxyConf needs to be updated.
|
||||||
|
isChanged bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// upstreamManager stores and updates custom client upstream configurations.
|
||||||
|
type upstreamManager struct {
|
||||||
|
// logger is used for logging the operation of the upstream manager. It
|
||||||
|
// must not be nil.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Consider using a logger with its own prefix.
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
|
// uidToCustomConf maps persistent client UID to the custom client upstream
|
||||||
|
// configuration. Stored UIDs must be in sync with the [index.uidToClient].
|
||||||
|
uidToCustomConf map[UID]*customUpstreamConfig
|
||||||
|
|
||||||
|
// commonConf is the common upstream configuration.
|
||||||
|
commonConf *CommonUpstreamConfig
|
||||||
|
|
||||||
|
// confUpdate is the timestamp of the latest common upstream configuration
|
||||||
|
// update.
|
||||||
|
confUpdate time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// newUpstreamManager returns the new properly initialized upstream manager.
|
||||||
|
func newUpstreamManager(logger *slog.Logger) (m *upstreamManager) {
|
||||||
|
return &upstreamManager{
|
||||||
|
logger: logger,
|
||||||
|
uidToCustomConf: make(map[UID]*customUpstreamConfig),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateCommonUpstreamConfig updates the common upstream configuration and the
|
||||||
|
// timestamp of the latest configuration update.
|
||||||
|
func (m *upstreamManager) updateCommonUpstreamConfig(conf *CommonUpstreamConfig) {
|
||||||
|
m.commonConf = conf
|
||||||
|
m.confUpdate = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateCustomUpstreamConfig updates the stored custom client upstream
|
||||||
|
// configuration associated with the persistent client. It also sets
|
||||||
|
// [customUpstreamConfig.isChanged] to true so [customUpstreamConfig.proxyConf]
|
||||||
|
// can be updated later in [upstreamManager.customUpstreamConfig].
|
||||||
|
func (m *upstreamManager) updateCustomUpstreamConfig(c *Persistent) {
|
||||||
|
cliConf, ok := m.uidToCustomConf[c.UID]
|
||||||
|
if !ok {
|
||||||
|
cliConf = &customUpstreamConfig{
|
||||||
|
commonConfUpdate: m.confUpdate,
|
||||||
|
}
|
||||||
|
|
||||||
|
m.uidToCustomConf[c.UID] = cliConf
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(s.chzhen): Compare before cloning.
|
||||||
|
cliConf.upstreams = slices.Clone(c.Upstreams)
|
||||||
|
cliConf.upstreamsCacheSize = c.UpstreamsCacheSize
|
||||||
|
cliConf.upstreamsCacheEnabled = c.UpstreamsCacheEnabled
|
||||||
|
cliConf.isChanged = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// customUpstreamConfig returns the custom client upstream configuration.
|
||||||
|
func (m *upstreamManager) customUpstreamConfig(uid UID) (proxyConf *proxy.CustomUpstreamConfig) {
|
||||||
|
cliConf, ok := m.uidToCustomConf[uid]
|
||||||
|
if !ok {
|
||||||
|
// TODO(s.chzhen): Consider panic.
|
||||||
|
m.logger.Error("no associated custom client upstream config")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.isConfigChanged(cliConf) {
|
||||||
|
return cliConf.proxyConf
|
||||||
|
}
|
||||||
|
|
||||||
|
if cliConf.proxyConf != nil {
|
||||||
|
err := cliConf.proxyConf.Close()
|
||||||
|
if err != nil {
|
||||||
|
// TODO(s.chzhen): Pass context.
|
||||||
|
m.logger.Debug("closing custom upstream config", slogutil.KeyError, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyConf = newCustomUpstreamConfig(cliConf, m.commonConf)
|
||||||
|
cliConf.proxyConf = proxyConf
|
||||||
|
cliConf.isChanged = false
|
||||||
|
|
||||||
|
return proxyConf
|
||||||
|
}
|
||||||
|
|
||||||
|
// isConfigChanged returns true if the update is necessary for the custom client
|
||||||
|
// upstream configuration.
|
||||||
|
func (m *upstreamManager) isConfigChanged(cliConf *customUpstreamConfig) (ok bool) {
|
||||||
|
return !m.confUpdate.Equal(cliConf.commonConfUpdate) || cliConf.isChanged
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearUpstreamCache clears the upstream cache for each stored custom client
|
||||||
|
// upstream configuration.
|
||||||
|
func (m *upstreamManager) clearUpstreamCache() {
|
||||||
|
for _, c := range m.uidToCustomConf {
|
||||||
|
c.proxyConf.ClearCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove deletes the custom client upstream configuration and closes
|
||||||
|
// [customUpstreamConfig.proxyConf] if necessary.
|
||||||
|
func (m *upstreamManager) remove(uid UID) (err error) {
|
||||||
|
cliConf, ok := m.uidToCustomConf[uid]
|
||||||
|
if !ok {
|
||||||
|
// TODO(s.chzhen): Consider panic.
|
||||||
|
return errors.Error("no associated custom client upstream config")
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.uidToCustomConf, uid)
|
||||||
|
|
||||||
|
if cliConf.proxyConf != nil {
|
||||||
|
return cliConf.proxyConf.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// close shuts down each stored custom client upstream configuration.
|
||||||
|
func (m *upstreamManager) close() (err error) {
|
||||||
|
var errs []error
|
||||||
|
for _, c := range m.uidToCustomConf {
|
||||||
|
if c.proxyConf == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
errs = append(errs, c.proxyConf.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCustomUpstreamConfig returns the new properly initialized custom proxy
|
||||||
|
// upstream configuration for the client.
|
||||||
|
func newCustomUpstreamConfig(
|
||||||
|
cliConf *customUpstreamConfig,
|
||||||
|
conf *CommonUpstreamConfig,
|
||||||
|
) (proxyConf *proxy.CustomUpstreamConfig) {
|
||||||
|
upstreams := stringutil.FilterOut(cliConf.upstreams, aghnet.IsCommentOrEmpty)
|
||||||
|
if len(upstreams) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
upsConf, err := proxy.ParseUpstreamsConfig(
|
||||||
|
upstreams,
|
||||||
|
&upstream.Options{
|
||||||
|
Bootstrap: conf.Bootstrap,
|
||||||
|
Timeout: time.Duration(conf.UpstreamTimeout),
|
||||||
|
HTTPVersions: aghnet.UpstreamHTTPVersions(conf.UseHTTP3Upstreams),
|
||||||
|
PreferIPv6: conf.BootstrapPreferIPv6,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
// Should not happen because upstreams are already validated. See
|
||||||
|
// [Persistent.validate].
|
||||||
|
panic(fmt.Errorf("creating custom upstream config: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxy.NewCustomUpstreamConfig(
|
||||||
|
upsConf,
|
||||||
|
cliConf.upstreamsCacheEnabled,
|
||||||
|
int(cliConf.upstreamsCacheSize),
|
||||||
|
conf.EDNSClientSubnetEnabled,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
var _ proxy.BeforeRequestHandler = (*Server)(nil)
|
var _ proxy.BeforeRequestHandler = (*Server)(nil)
|
||||||
|
|
||||||
// HandleBefore is the handler that is called before any other processing,
|
// HandleBefore is the handler that is called before any other processing,
|
||||||
// including logs. It performs access checks and puts the client ID, if there
|
// including logs. It performs access checks and puts the ClientID, if there
|
||||||
// is one, into the server's cache.
|
// is one, into the server's cache.
|
||||||
//
|
//
|
||||||
// TODO(d.kolyshev): Extract to separate package.
|
// TODO(d.kolyshev): Extract to separate package.
|
||||||
|
|||||||
@@ -266,6 +266,7 @@ func TestServer_HandleBefore_udp(t *testing.T) {
|
|||||||
UpstreamDNS: []string{localUpsAddr},
|
UpstreamDNS: []string{localUpsAddr},
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func clientIDFromClientServerName(
|
|||||||
return strings.ToLower(clientID), nil
|
return strings.ToLower(clientID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
// clientIDFromDNSContextHTTPS extracts the ClientID from the path of the
|
||||||
// client's DNS-over-HTTPS request.
|
// client's DNS-over-HTTPS request.
|
||||||
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
|
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||||
r := pctx.HTTPRequest
|
r := pctx.HTTPRequest
|
||||||
|
|||||||
46
internal/dnsforward/clientscontainer.go
Normal file
46
internal/dnsforward/clientscontainer.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientsContainer provides information about preconfigured DNS clients.
|
||||||
|
type ClientsContainer interface {
|
||||||
|
// CustomUpstreamConfig returns the custom client upstream configuration, if
|
||||||
|
// any. It prioritizes ClientID over client IP address to identify the
|
||||||
|
// client.
|
||||||
|
CustomUpstreamConfig(clientID string, cliAddr netip.Addr) (conf *proxy.CustomUpstreamConfig)
|
||||||
|
|
||||||
|
// UpdateCommonUpstreamConfig updates the common upstream configuration.
|
||||||
|
UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig)
|
||||||
|
|
||||||
|
// ClearUpstreamCache clears the upstream cache for each stored custom
|
||||||
|
// client upstream configuration.
|
||||||
|
ClearUpstreamCache()
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmptyClientsContainer is an [ClientsContainer] implementation that does nothing.
|
||||||
|
type EmptyClientsContainer struct{}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ ClientsContainer = EmptyClientsContainer{}
|
||||||
|
|
||||||
|
// CustomUpstreamConfig implements the [ClientsContainer] interface for
|
||||||
|
// EmptyClientsContainer.
|
||||||
|
func (EmptyClientsContainer) CustomUpstreamConfig(
|
||||||
|
clientID string,
|
||||||
|
cliAddr netip.Addr,
|
||||||
|
) (conf *proxy.CustomUpstreamConfig) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for
|
||||||
|
// EmptyClientsContainer.
|
||||||
|
func (EmptyClientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) {}
|
||||||
|
|
||||||
|
// ClearUpstreamCache implements the [ClientsContainer] interface for
|
||||||
|
// EmptyClientsContainer.
|
||||||
|
func (EmptyClientsContainer) ClearUpstreamCache() {}
|
||||||
@@ -29,19 +29,6 @@ import (
|
|||||||
"github.com/ameshkov/dnscrypt/v2"
|
"github.com/ameshkov/dnscrypt/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientsContainer provides information about preconfigured DNS clients.
|
|
||||||
type ClientsContainer interface {
|
|
||||||
// UpstreamConfigByID returns the custom upstream configuration for the
|
|
||||||
// client having id, using boot to initialize the one if necessary. It
|
|
||||||
// returns nil if there is no custom upstream configuration for the client.
|
|
||||||
// The id is expected to be either a string representation of an IP address
|
|
||||||
// or the ClientID.
|
|
||||||
UpstreamConfigByID(
|
|
||||||
id string,
|
|
||||||
boot upstream.Resolver,
|
|
||||||
) (conf *proxy.CustomUpstreamConfig, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config represents the DNS filtering configuration of AdGuard Home. The zero
|
// Config represents the DNS filtering configuration of AdGuard Home. The zero
|
||||||
// Config is empty and ready for use.
|
// Config is empty and ready for use.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -467,7 +454,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ipsets = stringutil.SplitTrimmed(string(data), "\n")
|
ipsets = stringutil.SplitTrimmed(string(data), "\n")
|
||||||
ipsets = slices.DeleteFunc(ipsets, IsCommentOrEmpty)
|
ipsets = slices.DeleteFunc(ipsets, aghnet.IsCommentOrEmpty)
|
||||||
|
|
||||||
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
|
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
|
||||||
|
|
||||||
@@ -478,7 +465,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
|
|||||||
// the configuration itself.
|
// the configuration itself.
|
||||||
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
|
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
|
||||||
if conf.UpstreamDNSFileName == "" {
|
if conf.UpstreamDNSFileName == "" {
|
||||||
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
|
return stringutil.FilterOut(conf.UpstreamDNS, aghnet.IsCommentOrEmpty), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var data []byte
|
var data []byte
|
||||||
@@ -491,7 +478,7 @@ func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
|
|||||||
|
|
||||||
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
|
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
|
||||||
|
|
||||||
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
|
return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// collectListenAddr adds addrPort to addrs. It also adds its port to
|
// collectListenAddr adds addrPort to addrs. It also adds its port to
|
||||||
|
|||||||
@@ -299,6 +299,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
UpstreamDNS: []string{upsAddr},
|
UpstreamDNS: []string{upsAddr},
|
||||||
},
|
},
|
||||||
UsePrivateRDNS: true,
|
UsePrivateRDNS: true,
|
||||||
@@ -337,6 +338,7 @@ func TestServer_dns64WithDisabledRDNS(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
UpstreamDNS: []string{upsAddr},
|
UpstreamDNS: []string{upsAddr},
|
||||||
},
|
},
|
||||||
UsePrivateRDNS: false,
|
UsePrivateRDNS: false,
|
||||||
|
|||||||
@@ -540,7 +540,7 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
|||||||
uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||||
Bootstrap: boot,
|
Bootstrap: boot,
|
||||||
Timeout: s.conf.UpstreamTimeout,
|
Timeout: s.conf.UpstreamTimeout,
|
||||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||||
// Use a customized set of RootCAs, because Go's default mechanism of
|
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||||
// loading TLS roots does not always work properly on some routers so we're
|
// loading TLS roots does not always work properly on some routers so we're
|
||||||
@@ -557,6 +557,13 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.conf.UpstreamConfig = uc
|
s.conf.UpstreamConfig = uc
|
||||||
|
s.conf.ClientsContainer.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
|
||||||
|
Bootstrap: boot,
|
||||||
|
UpstreamTimeout: s.conf.UpstreamTimeout,
|
||||||
|
BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||||
|
EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled,
|
||||||
|
UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams,
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -630,7 +637,7 @@ func (s *Server) prepareInternalDNS() (err error) {
|
|||||||
|
|
||||||
bootOpts := &upstream.Options{
|
bootOpts := &upstream.Options{
|
||||||
Timeout: DefaultTimeout,
|
Timeout: DefaultTimeout,
|
||||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
|
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
|
||||||
@@ -661,7 +668,7 @@ func (s *Server) prepareInternalDNS() (err error) {
|
|||||||
// setupFallbackDNS initializes the fallback DNS servers.
|
// setupFallbackDNS initializes the fallback DNS servers.
|
||||||
func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
|
func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
|
||||||
fallbacks := s.conf.FallbackDNS
|
fallbacks := s.conf.FallbackDNS
|
||||||
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
|
fallbacks = stringutil.FilterOut(fallbacks, aghnet.IsCommentOrEmpty)
|
||||||
if len(fallbacks) == 0 {
|
if len(fallbacks) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
@@ -61,6 +62,42 @@ const (
|
|||||||
// TODO(a.garipov): Use more.
|
// TODO(a.garipov): Use more.
|
||||||
var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345")
|
var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345")
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ ClientsContainer = (*clientsContainer)(nil)
|
||||||
|
|
||||||
|
// clientsContainer is a mock [ClientsContainer] implementation for tests.
|
||||||
|
type clientsContainer struct {
|
||||||
|
OnCustomUpstreamConfig func(
|
||||||
|
clientID string,
|
||||||
|
cliAddr netip.Addr,
|
||||||
|
) (conf *proxy.CustomUpstreamConfig)
|
||||||
|
|
||||||
|
OnUpdateCommonUpstreamConfig func(conf *client.CommonUpstreamConfig)
|
||||||
|
|
||||||
|
OnClearUpstreamCache func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CustomUpstreamConfig implements the [ClientsContainer] interface for
|
||||||
|
// *clientsContainer.
|
||||||
|
func (c *clientsContainer) CustomUpstreamConfig(
|
||||||
|
clientID string,
|
||||||
|
cliAddr netip.Addr,
|
||||||
|
) (conf *proxy.CustomUpstreamConfig) {
|
||||||
|
return c.OnCustomUpstreamConfig(clientID, cliAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for
|
||||||
|
// *clientsContainer.
|
||||||
|
func (c *clientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) {
|
||||||
|
c.OnUpdateCommonUpstreamConfig(conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUpstreamCache implements the [ClientsContainer] interface for
|
||||||
|
// *clientsContainer.
|
||||||
|
func (c *clientsContainer) ClearUpstreamCache() {
|
||||||
|
c.OnClearUpstreamCache()
|
||||||
|
}
|
||||||
|
|
||||||
func startDeferStop(t *testing.T, s *Server) {
|
func startDeferStop(t *testing.T, s *Server) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -168,6 +205,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
@@ -297,6 +335,7 @@ func TestServer(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
@@ -337,6 +376,7 @@ func TestServer_timeout(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -364,6 +404,7 @@ func TestServer_timeout(t *testing.T) {
|
|||||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
|
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
}
|
}
|
||||||
|
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||||
err = s.Prepare(&s.conf)
|
err = s.Prepare(&s.conf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -380,6 +421,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
|
|||||||
},
|
},
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -405,6 +447,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
@@ -536,6 +579,7 @@ func TestSafeSearch(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -629,6 +673,7 @@ func TestInvalidRequest(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
@@ -659,6 +704,7 @@ func TestBlockedRequest(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -696,6 +742,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -721,12 +768,12 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
|||||||
forwardConf.EDNSClientSubnet.Enabled,
|
forwardConf.EDNSClientSubnet.Enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
s.conf.ClientsContainer = &aghtest.ClientsContainer{
|
s.conf.ClientsContainer = &clientsContainer{
|
||||||
OnUpstreamConfigByID: func(
|
OnCustomUpstreamConfig: func(
|
||||||
_ string,
|
_ string,
|
||||||
_ upstream.Resolver,
|
_ netip.Addr,
|
||||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
) (conf *proxy.CustomUpstreamConfig) {
|
||||||
return customUpsConf, nil
|
return customUpsConf
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -774,6 +821,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
@@ -808,6 +856,7 @@ func TestBlockCNAME(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -884,6 +933,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -930,6 +980,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -998,6 +1049,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -1051,6 +1103,7 @@ func TestBlockedByHosts(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -1103,6 +1156,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -1164,6 +1218,7 @@ func TestRewrite(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}))
|
}))
|
||||||
@@ -1290,6 +1345,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
|||||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||||
|
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||||
|
|
||||||
err = s.Prepare(&s.conf)
|
err = s.Prepare(&s.conf)
|
||||||
@@ -1375,6 +1431,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
|||||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||||
|
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||||
|
|
||||||
err = s.Prepare(&s.conf)
|
err = s.Prepare(&s.conf)
|
||||||
@@ -1643,6 +1700,7 @@ func TestServer_Exchange(t *testing.T) {
|
|||||||
UpstreamDNS: []string{upsAddr},
|
UpstreamDNS: []string{upsAddr},
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
LocalPTRResolvers: []string{localUpsAddr},
|
LocalPTRResolvers: []string{localUpsAddr},
|
||||||
UsePrivateRDNS: true,
|
UsePrivateRDNS: true,
|
||||||
@@ -1665,6 +1723,7 @@ func TestServer_Exchange(t *testing.T) {
|
|||||||
UpstreamDNS: []string{upsAddr},
|
UpstreamDNS: []string{upsAddr},
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
LocalPTRResolvers: []string{},
|
LocalPTRResolvers: []string{},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
|||||||
EDNSClientSubnet: &EDNSClientSubnet{
|
EDNSClientSubnet: &EDNSClientSubnet{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
@@ -647,7 +648,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, aghnet.IsCommentOrEmpty)
|
||||||
|
|
||||||
opts := &upstream.Options{
|
opts := &upstream.Options{
|
||||||
Timeout: s.conf.UpstreamTimeout,
|
Timeout: s.conf.UpstreamTimeout,
|
||||||
@@ -673,6 +674,8 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||||
func (s *Server) handleCacheClear(w http.ResponseWriter, _ *http.Request) {
|
func (s *Server) handleCacheClear(w http.ResponseWriter, _ *http.Request) {
|
||||||
s.dnsProxy.ClearCache()
|
s.dnsProxy.ClearCache()
|
||||||
|
s.conf.ClientsContainer.ClearUpstreamCache()
|
||||||
|
|
||||||
_, _ = io.WriteString(w, "OK")
|
_, _ = io.WriteString(w, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
|||||||
RatelimitSubnetLenIPv6: 56,
|
RatelimitSubnetLenIPv6: 56,
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ConfigModified: func() {},
|
ConfigModified: func() {},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
@@ -164,6 +165,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
|||||||
RatelimitSubnetLenIPv6: 56,
|
RatelimitSubnetLenIPv6: 56,
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ConfigModified: func() {},
|
ConfigModified: func() {},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
@@ -299,24 +301,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsCommentOrEmpty(t *testing.T) {
|
|
||||||
for _, tc := range []struct {
|
|
||||||
want assert.BoolAssertionFunc
|
|
||||||
str string
|
|
||||||
}{{
|
|
||||||
want: assert.True,
|
|
||||||
str: "",
|
|
||||||
}, {
|
|
||||||
want: assert.True,
|
|
||||||
str: "# comment",
|
|
||||||
}, {
|
|
||||||
want: assert.False,
|
|
||||||
str: "1.2.3.4",
|
|
||||||
}} {
|
|
||||||
tc.want(t, IsCommentOrEmpty(tc.str))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
|
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -388,6 +372,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package dnsforward
|
package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
@@ -577,17 +576,14 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use the ClientID first, since it has a higher priority.
|
cliAddr := pctx.Addr.Addr()
|
||||||
id := cmp.Or(clientID, pctx.Addr.Addr().String())
|
upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr)
|
||||||
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if upsConf != nil {
|
if upsConf != nil {
|
||||||
log.Debug("dnsforward: using custom upstreams for client %s", id)
|
log.Debug(
|
||||||
|
"dnsforward: using custom upstreams for client with ip %s and clientid %q",
|
||||||
|
cliAddr,
|
||||||
|
clientID,
|
||||||
|
)
|
||||||
|
|
||||||
pctx.CustomUpstreamConfig = upsConf
|
pctx.CustomUpstreamConfig = upsConf
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ func TestServer_ProcessInitial(t *testing.T) {
|
|||||||
AAAADisabled: tc.aaaaDisabled,
|
AAAADisabled: tc.aaaaDisabled,
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -180,6 +181,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
|||||||
AAAADisabled: tc.aaaaDisabled,
|
AAAADisabled: tc.aaaaDisabled,
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
@@ -324,6 +326,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
|||||||
HandleDDR: tc.ddrEnabled,
|
HandleDDR: tc.ddrEnabled,
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
TLSConfig: TLSConfig{
|
TLSConfig: TLSConfig{
|
||||||
ServerName: ddrTestDomainName,
|
ServerName: ddrTestDomainName,
|
||||||
@@ -660,6 +663,7 @@ func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
|
|||||||
UpstreamDNS: []string{localUpsAddr},
|
UpstreamDNS: []string{localUpsAddr},
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
UsePrivateRDNS: true,
|
UsePrivateRDNS: true,
|
||||||
LocalPTRResolvers: []string{localUpsAddr},
|
LocalPTRResolvers: []string{localUpsAddr},
|
||||||
@@ -788,6 +792,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
UsePrivateRDNS: true,
|
UsePrivateRDNS: true,
|
||||||
LocalPTRResolvers: []string{localUpsAddr},
|
LocalPTRResolvers: []string{localUpsAddr},
|
||||||
@@ -816,6 +821,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
UsePrivateRDNS: false,
|
UsePrivateRDNS: false,
|
||||||
LocalPTRResolvers: []string{localUpsAddr},
|
LocalPTRResolvers: []string{localUpsAddr},
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
|||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
ClientsContainer: EmptyClientsContainer{},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func newPrivateConfig(
|
|||||||
) (uc *proxy.UpstreamConfig, err error) {
|
) (uc *proxy.UpstreamConfig, err error) {
|
||||||
confNeedsFiltering := len(addrs) > 0
|
confNeedsFiltering := len(addrs) > 0
|
||||||
if confNeedsFiltering {
|
if confNeedsFiltering {
|
||||||
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
|
addrs = stringutil.FilterOut(addrs, aghnet.IsCommentOrEmpty)
|
||||||
} else {
|
} else {
|
||||||
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
|
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
|
||||||
addrs = make([]string, 0, len(sysResolvers))
|
addrs = make([]string, 0, len(sysResolvers))
|
||||||
@@ -127,20 +127,6 @@ func newPrivateConfig(
|
|||||||
return uc, nil
|
return uc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
|
||||||
// depending on configuration.
|
|
||||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
|
||||||
if !http3 {
|
|
||||||
return upstream.DefaultHTTPVersions
|
|
||||||
}
|
|
||||||
|
|
||||||
return []upstream.HTTPVersion{
|
|
||||||
upstream.HTTPVersion3,
|
|
||||||
upstream.HTTPVersion2,
|
|
||||||
upstream.HTTPVersion11,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setProxyUpstreamMode sets the upstream mode and related settings in conf
|
// setProxyUpstreamMode sets the upstream mode and related settings in conf
|
||||||
// based on provided parameters.
|
// based on provided parameters.
|
||||||
func setProxyUpstreamMode(
|
func setProxyUpstreamMode(
|
||||||
@@ -162,10 +148,3 @@ func setProxyUpstreamMode(
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
|
||||||
// This function is useful for filtering out non-upstream lines from upstream
|
|
||||||
// configs.
|
|
||||||
func IsCommentOrEmpty(s string) (ok bool) {
|
|
||||||
return len(s) == 0 || s[0] == '#'
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -12,17 +12,13 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// clientsContainer is the storage of all runtime and persistent clients.
|
// clientsContainer is the storage of all runtime and persistent clients.
|
||||||
@@ -373,63 +369,6 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ dnsforward.ClientsContainer = (*clientsContainer)(nil)
|
|
||||||
|
|
||||||
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for
|
|
||||||
// *clientsContainer. upsConf is nil if the client isn't found or if the client
|
|
||||||
// has no custom upstreams.
|
|
||||||
func (clients *clientsContainer) UpstreamConfigByID(
|
|
||||||
id string,
|
|
||||||
bootstrap upstream.Resolver,
|
|
||||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
c, ok := clients.storage.Find(id)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
|
||||||
} else if c.UpstreamConfig != nil {
|
|
||||||
return c.UpstreamConfig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
|
|
||||||
if len(upstreams) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var upsConf *proxy.UpstreamConfig
|
|
||||||
upsConf, err = proxy.ParseUpstreamsConfig(
|
|
||||||
upstreams,
|
|
||||||
&upstream.Options{
|
|
||||||
Bootstrap: bootstrap,
|
|
||||||
Timeout: time.Duration(config.DNS.UpstreamTimeout),
|
|
||||||
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
|
|
||||||
PreferIPv6: config.DNS.BootstrapPreferIPv6,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error since it's informative enough as is.
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conf = proxy.NewCustomUpstreamConfig(
|
|
||||||
upsConf,
|
|
||||||
c.UpstreamsCacheEnabled,
|
|
||||||
int(c.UpstreamsCacheSize),
|
|
||||||
config.DNS.EDNSClientSubnet.Enabled,
|
|
||||||
)
|
|
||||||
c.UpstreamConfig = conf
|
|
||||||
|
|
||||||
// TODO(s.chzhen): Pass context.
|
|
||||||
err = clients.storage.Update(context.TODO(), c.Name, c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setting upstream config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return conf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
|
|||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,28 +35,3 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
|||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientsCustomUpstream(t *testing.T) {
|
|
||||||
clients := newClientsContainer(t)
|
|
||||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
|
||||||
|
|
||||||
// Add client with upstreams.
|
|
||||||
err := clients.storage.Add(ctx, &client.Persistent{
|
|
||||||
Name: "client1",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
|
||||||
Upstreams: []string{
|
|
||||||
"1.1.1.1",
|
|
||||||
"[/example.org/]8.8.8.8",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
|
|
||||||
assert.Nil(t, upsConf)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
upsConf, err = clients.UpstreamConfigByID("1.1.1.1", net.DefaultResolver)
|
|
||||||
require.NotNil(t, upsConf)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -156,7 +156,13 @@ func initDNSServer(
|
|||||||
|
|
||||||
globalContext.clients.clientChecker = globalContext.dnsServer
|
globalContext.clients.clientChecker = globalContext.dnsServer
|
||||||
|
|
||||||
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
|
dnsConf, err := newServerConfig(
|
||||||
|
&config.DNS,
|
||||||
|
config.Clients.Sources,
|
||||||
|
tlsConf,
|
||||||
|
httpReg,
|
||||||
|
globalContext.clients.storage,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("newServerConfig: %w", err)
|
return fmt.Errorf("newServerConfig: %w", err)
|
||||||
}
|
}
|
||||||
@@ -230,12 +236,13 @@ func newServerConfig(
|
|||||||
clientSrcConf *clientSourcesConfig,
|
clientSrcConf *clientSourcesConfig,
|
||||||
tlsConf *tlsConfigSettings,
|
tlsConf *tlsConfigSettings,
|
||||||
httpReg aghhttp.RegisterFunc,
|
httpReg aghhttp.RegisterFunc,
|
||||||
|
clientsContainer dnsforward.ClientsContainer,
|
||||||
) (newConf *dnsforward.ServerConfig, err error) {
|
) (newConf *dnsforward.ServerConfig, err error) {
|
||||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||||
|
|
||||||
fwdConf := dnsConf.Config
|
fwdConf := dnsConf.Config
|
||||||
fwdConf.FilterHandler = applyAdditionalFiltering
|
fwdConf.FilterHandler = applyAdditionalFiltering
|
||||||
fwdConf.ClientsContainer = &globalContext.clients
|
fwdConf.ClientsContainer = clientsContainer
|
||||||
|
|
||||||
newConf = &dnsforward.ServerConfig{
|
newConf = &dnsforward.ServerConfig{
|
||||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||||
@@ -484,7 +491,13 @@ func reconfigureDNSServer() (err error) {
|
|||||||
tlsConf := &tlsConfigSettings{}
|
tlsConf := &tlsConfigSettings{}
|
||||||
globalContext.tls.WriteDiskConfig(tlsConf)
|
globalContext.tls.WriteDiskConfig(tlsConf)
|
||||||
|
|
||||||
newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister)
|
newConf, err := newServerConfig(
|
||||||
|
&config.DNS,
|
||||||
|
config.Clients.Sources,
|
||||||
|
tlsConf,
|
||||||
|
httpRegister,
|
||||||
|
globalContext.clients.storage,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user