diff --git a/CHANGELOG.md b/CHANGELOG.md index f1e5d85b..6ec11c7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ NOTE: Add new changes BELOW THIS COMMENT. ### Fixed +- Changes to global upstream DNS settings not applying to custom client upstream configurations. - The formatting of large numbers in the clients tables on the *Client settings* page ([#7583]). [#7583]: https://github.com/AdguardTeam/AdGuardHome/issues/7583 diff --git a/internal/aghnet/upstream.go b/internal/aghnet/upstream.go new file mode 100644 index 00000000..e61c47e8 --- /dev/null +++ b/internal/aghnet/upstream.go @@ -0,0 +1,24 @@ +package aghnet + +import "github.com/AdguardTeam/dnsproxy/upstream" + +// UpstreamHTTPVersions returns the HTTP versions for upstream configuration +// depending on configuration. +func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { + if !http3 { + return upstream.DefaultHTTPVersions + } + + return []upstream.HTTPVersion{ + upstream.HTTPVersion3, + upstream.HTTPVersion2, + upstream.HTTPVersion11, + } +} + +// IsCommentOrEmpty returns true if s starts with a "#" character or is empty. +// This function is useful for filtering out non-upstream lines from upstream +// configs. +func IsCommentOrEmpty(s string) (ok bool) { + return len(s) == 0 || s[0] == '#' +} diff --git a/internal/aghnet/upstream_test.go b/internal/aghnet/upstream_test.go new file mode 100644 index 00000000..1c0cd1c0 --- /dev/null +++ b/internal/aghnet/upstream_test.go @@ -0,0 +1,26 @@ +package aghnet_test + +import ( + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/stretchr/testify/assert" +) + +func TestIsCommentOrEmpty(t *testing.T) { + for _, tc := range []struct { + want assert.BoolAssertionFunc + str string + }{{ + want: assert.True, + str: "", + }, { + want: assert.True, + str: "# comment", + }, { + want: assert.False, + str: "1.2.3.4", + }} { + tc.want(t, aghnet.IsCommentOrEmpty(tc.str)) + } +} diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 8db2882b..00d537bd 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -10,7 +10,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/whois" - "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" ) @@ -121,26 +120,6 @@ func (p *AddressUpdater) UpdateAddress( p.OnUpdateAddress(ctx, ip, host, info) } -// Package dnsforward - -// ClientsContainer is a fake [dnsforward.ClientsContainer] implementation for -// tests. -type ClientsContainer struct { - OnUpstreamConfigByID func( - id string, - boot upstream.Resolver, - ) (conf *proxy.CustomUpstreamConfig, err error) -} - -// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface -// for *ClientsContainer. -func (c *ClientsContainer) UpstreamConfigByID( - id string, - boot upstream.Resolver, -) (conf *proxy.CustomUpstreamConfig, err error) { - return c.OnUpstreamConfigByID(id, boot) -} - // Package filtering // Resolver is a fake [filtering.Resolver] implementation for tests. diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go index f0f55451..93866c2b 100644 --- a/internal/aghtest/interface_test.go +++ b/internal/aghtest/interface_test.go @@ -3,7 +3,6 @@ package aghtest_test import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" ) @@ -12,9 +11,6 @@ import ( // type check var _ filtering.Resolver = (*aghtest.Resolver)(nil) -// type check -var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil) - // type check // // TODO(s.chzhen): It's here to avoid the import cycle. Remove it. diff --git a/internal/client/index.go b/internal/client/index.go index 2eb7411b..d34e0e51 100644 --- a/internal/client/index.go +++ b/internal/client/index.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" - "github.com/AdguardTeam/golibs/errors" ) // macKey contains MAC as byte array of 6, 8, or 20 bytes. @@ -35,7 +34,7 @@ type index struct { // nameToUID maps client name to UID. nameToUID map[string]UID - // clientIDToUID maps client ID to UID. + // clientIDToUID maps ClientID to UID. clientIDToUID map[string]UID // ipToUID maps IP address to UID. @@ -205,19 +204,19 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) return nil, nil } -// find finds persistent client by string representation of the client ID, IP +// find finds persistent client by string representation of the ClientID, IP // address, or MAC. func (ci *index) find(id string) (c *Persistent, ok bool) { - uid, found := ci.clientIDToUID[id] - if found { - return ci.uidToClient[uid], true + c, ok = ci.findByClientID(id) + if ok { + return c, true } ip, err := netip.ParseAddr(id) if err == nil { // MAC addresses can be successfully parsed as IP addresses. - c, found = ci.findByIP(ip) - if found { + c, ok = ci.findByIP(ip) + if ok { return c, true } } @@ -230,6 +229,16 @@ func (ci *index) find(id string) (c *Persistent, ok bool) { return nil, false } +// findByClientID finds persistent client by ClientID. +func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) { + uid, ok := ci.clientIDToUID[clientID] + if ok { + return ci.uidToClient[uid], true + } + + return nil, false +} + // findByName finds persistent client by name. func (ci *index) findByName(name string) (c *Persistent, found bool) { uid, found := ci.nameToUID[name] @@ -343,18 +352,3 @@ func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) { } } } - -// closeUpstreams closes upstream configurations of persistent clients. -func (ci *index) closeUpstreams() (err error) { - var errs []error - ci.rangeByName(func(c *Persistent) (cont bool) { - err = c.CloseUpstreams() - if err != nil { - errs = append(errs, err) - } - - return true - }) - - return errors.Join(errs...) -} diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 1cea335b..4ec3695e 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -58,12 +58,6 @@ func (uid *UID) UnmarshalText(data []byte) error { // Persistent contains information about persistent clients. type Persistent struct { - // UpstreamConfig is the custom upstream configuration for this client. If - // it's nil, it has not been initialized yet. If it's non-nil and empty, - // there are no valid upstreams. If it's non-nil and non-empty, these - // upstream must be used. - UpstreamConfig *proxy.CustomUpstreamConfig - // SafeSearch handles search engine hosts rewrites. SafeSearch filtering.SafeSearch @@ -262,7 +256,7 @@ func ValidateClientID(id string) (err error) { return nil } -// IDs returns a list of client IDs containing at least one element. +// IDs returns a list of ClientIDs containing at least one element. func (c *Persistent) IDs() (ids []string) { ids = make([]string, 0, c.IDsLen()) @@ -281,7 +275,7 @@ func (c *Persistent) IDs() (ids []string) { return append(ids, c.ClientIDs...) } -// IDsLen returns a length of client ids. +// IDsLen returns a length of ClientIDs. func (c *Persistent) IDsLen() (n int) { return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs) } @@ -312,14 +306,3 @@ func (c *Persistent) ShallowClone() (clone *Persistent) { return clone } - -// CloseUpstreams closes the client-specific upstream config of c if any. -func (c *Persistent) CloseUpstreams() (err error) { - if c.UpstreamConfig != nil { - if err = c.UpstreamConfig.Close(); err != nil { - return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err) - } - } - - return nil -} diff --git a/internal/client/storage.go b/internal/client/storage.go index 455abb9b..fbbfd1b8 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/whois" + "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/logutil/slogutil" @@ -126,6 +127,9 @@ type Storage struct { // runtimeIndex contains information about runtime clients. runtimeIndex *runtimeIndex + // upstreamManager stores and updates custom client upstream configurations. + upstreamManager *upstreamManager + // dhcp is used to update [SourceDHCP] runtime client information. dhcp DHCP @@ -163,6 +167,7 @@ func NewStorage(ctx context.Context, conf *StorageConfig) (s *Storage, err error mu: &sync.Mutex{}, index: newIndex(), runtimeIndex: newRuntimeIndex(), + upstreamManager: newUpstreamManager(conf.Logger), dhcp: conf.DHCP, etcHosts: conf.EtcHosts, arpDB: conf.ARPDB, @@ -200,7 +205,7 @@ func (s *Storage) Start(ctx context.Context) (err error) { func (s *Storage) Shutdown(_ context.Context) (err error) { close(s.done) - return s.closeUpstreams() + return s.upstreamManager.close() } // periodicARPUpdate periodically reloads runtime clients from ARP. It is @@ -416,6 +421,7 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) { } s.index.add(p) + s.upstreamManager.updateCustomUpstreamConfig(p) s.logger.DebugContext( ctx, @@ -441,7 +447,7 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) { return nil, false } -// Find finds persistent client by string representation of the client ID, IP +// Find finds persistent client by string representation of the ClientID, IP // address, or MAC. And returns its shallow copy. // // TODO(s.chzhen): Accept ClientIDData structure instead, which will contain @@ -514,12 +520,13 @@ func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) { return false } - if err := p.CloseUpstreams(); err != nil { - s.logger.ErrorContext(ctx, "removing client", "name", p.Name, slogutil.KeyError, err) - } - s.index.remove(p) + err := s.upstreamManager.remove(p.UID) + if err != nil { + s.logger.DebugContext(ctx, "closing client upstreams", "name", name, slogutil.KeyError, err) + } + return true } @@ -556,6 +563,8 @@ func (s *Storage) Update(ctx context.Context, name string, p *Persistent) (err e s.index.remove(stored) s.index.add(p) + s.upstreamManager.updateCustomUpstreamConfig(p) + return nil } @@ -576,14 +585,6 @@ func (s *Storage) Size() (n int) { return s.index.size() } -// closeUpstreams closes upstream configurations of persistent clients. -func (s *Storage) closeUpstreams() (err error) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.index.closeUpstreams() -} - // ClientRuntime returns a copy of the saved runtime client by ip. If no such // client exists, returns nil. func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { @@ -626,3 +627,42 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { func (s *Storage) AllowedTags() (tags []string) { return s.allowedTags } + +// CustomUpstreamConfig implements the [dnsforward.ClientsContainer] interface +// for *Storage +func (s *Storage) CustomUpstreamConfig( + id string, + addr netip.Addr, +) (prxConf *proxy.CustomUpstreamConfig) { + s.mu.Lock() + defer s.mu.Unlock() + + c, ok := s.index.findByClientID(id) + if !ok { + c, ok = s.index.findByIP(addr) + } + + if !ok { + return nil + } + + return s.upstreamManager.customUpstreamConfig(c.UID) +} + +// UpdateCommonUpstreamConfig implements the [dnsforward.ClientsContainer] +// interface for *Storage +func (s *Storage) UpdateCommonUpstreamConfig(conf *CommonUpstreamConfig) { + s.mu.Lock() + defer s.mu.Unlock() + + s.upstreamManager.updateCommonUpstreamConfig(conf) +} + +// ClearUpstreamCache implements the [dnsforward.ClientsContainer] interface for +// *Storage +func (s *Storage) ClearUpstreamCache() { + s.mu.Lock() + defer s.mu.Unlock() + + s.upstreamManager.clearUpstreamCache() +} diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index eb69b9fe..4a981148 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/logutil/slogutil" @@ -34,6 +35,9 @@ func newTestStorage(tb testing.TB) (s *client.Storage) { return s } +// type check +var _ dnsforward.ClientsContainer = (*client.Storage)(nil) + // testHostsContainer is a mock implementation of the [client.HostsContainer] // interface. type testHostsContainer struct { @@ -1278,3 +1282,90 @@ func TestStorage_RangeByName(t *testing.T) { }) } } + +func TestStorage_CustomUpstreamConfig(t *testing.T) { + const ( + existingName = "existing_name" + existingClientID = "existing_client_id" + + nonExistingClientID = "non_existing_client_id" + ) + + var ( + existingClientUID = client.MustNewUID() + existingIP = netip.MustParseAddr("192.0.2.1") + + nonExistingIP = netip.MustParseAddr("192.0.2.255") + + testUpstreamTimeout = time.Second + ) + + existingClient := &client.Persistent{ + Name: existingName, + IPs: []netip.Addr{existingIP}, + ClientIDs: []string{existingClientID}, + UID: existingClientUID, + Upstreams: []string{"192.0.2.0"}, + } + + s := newTestStorage(t) + s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{ + UpstreamTimeout: testUpstreamTimeout, + }) + + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return s.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) + }) + + ctx := testutil.ContextWithTimeout(t, testTimeout) + err := s.Add(ctx, existingClient) + require.NoError(t, err) + + testCases := []struct { + cliAddr netip.Addr + wantNilConf assert.ValueAssertionFunc + name string + cliID string + }{{ + name: "client_id", + cliID: existingClientID, + cliAddr: netip.Addr{}, + wantNilConf: assert.NotNil, + }, { + name: "client_addr", + cliID: "", + cliAddr: existingIP, + wantNilConf: assert.NotNil, + }, { + name: "non_existing_client_id", + cliID: nonExistingClientID, + cliAddr: netip.Addr{}, + wantNilConf: assert.Nil, + }, { + name: "non_existing_client_addr", + cliID: "", + cliAddr: nonExistingIP, + wantNilConf: assert.Nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + conf := s.CustomUpstreamConfig(tc.cliID, tc.cliAddr) + tc.wantNilConf(t, conf) + }) + } + + t.Run("update_common_config", func(t *testing.T) { + conf := s.CustomUpstreamConfig(existingClientID, existingIP) + require.NotNil(t, conf) + + s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{ + UpstreamTimeout: testUpstreamTimeout * 2, + }) + + updConf := s.CustomUpstreamConfig(existingClientID, existingIP) + require.NotNil(t, updConf) + + assert.NotEqual(t, conf, updConf) + }) +} diff --git a/internal/client/upstreammanager.go b/internal/client/upstreammanager.go new file mode 100644 index 00000000..bc2c4362 --- /dev/null +++ b/internal/client/upstreammanager.go @@ -0,0 +1,219 @@ +package client + +import ( + "fmt" + "log/slog" + "slices" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/stringutil" +) + +// CommonUpstreamConfig contains common settings for custom client upstream +// configurations. +type CommonUpstreamConfig struct { + Bootstrap upstream.Resolver + UpstreamTimeout time.Duration + BootstrapPreferIPv6 bool + EDNSClientSubnetEnabled bool + UseHTTP3Upstreams bool +} + +// customUpstreamConfig contains custom client upstream configuration and the +// timestamp of the latest configuration update. +type customUpstreamConfig struct { + // proxyConf is the constructed upstream configuration for the [proxy], + // derived from the fields below. It is initialized on demand with + // [newCustomUpstreamConfig]. + proxyConf *proxy.CustomUpstreamConfig + + // commonConfUpdate is the timestamp of the latest configuration update, + // used to check against [upstreamManager.confUpdate] to determine if the + // configuration is up to date. + commonConfUpdate time.Time + + // upstreams is the cached list of custom upstream DNS servers used for the + // configuration of proxyConf. + upstreams []string + + // upstreamsCacheSize is the cached value of the cache size of the + // upstreams, used for the configuration of proxyConf. + upstreamsCacheSize uint32 + + // upstreamsCacheEnabled is the cached value indicating whether the cache of + // the upstreams is enabled for the configuration of proxyConf. + upstreamsCacheEnabled bool + + // isChanged indicates whether the proxyConf needs to be updated. + isChanged bool +} + +// upstreamManager stores and updates custom client upstream configurations. +type upstreamManager struct { + // logger is used for logging the operation of the upstream manager. It + // must not be nil. + // + // TODO(s.chzhen): Consider using a logger with its own prefix. + logger *slog.Logger + + // uidToCustomConf maps persistent client UID to the custom client upstream + // configuration. Stored UIDs must be in sync with the [index.uidToClient]. + uidToCustomConf map[UID]*customUpstreamConfig + + // commonConf is the common upstream configuration. + commonConf *CommonUpstreamConfig + + // confUpdate is the timestamp of the latest common upstream configuration + // update. + confUpdate time.Time +} + +// newUpstreamManager returns the new properly initialized upstream manager. +func newUpstreamManager(logger *slog.Logger) (m *upstreamManager) { + return &upstreamManager{ + logger: logger, + uidToCustomConf: make(map[UID]*customUpstreamConfig), + } +} + +// updateCommonUpstreamConfig updates the common upstream configuration and the +// timestamp of the latest configuration update. +func (m *upstreamManager) updateCommonUpstreamConfig(conf *CommonUpstreamConfig) { + m.commonConf = conf + m.confUpdate = time.Now() +} + +// updateCustomUpstreamConfig updates the stored custom client upstream +// configuration associated with the persistent client. It also sets +// [customUpstreamConfig.isChanged] to true so [customUpstreamConfig.proxyConf] +// can be updated later in [upstreamManager.customUpstreamConfig]. +func (m *upstreamManager) updateCustomUpstreamConfig(c *Persistent) { + cliConf, ok := m.uidToCustomConf[c.UID] + if !ok { + cliConf = &customUpstreamConfig{ + commonConfUpdate: m.confUpdate, + } + + m.uidToCustomConf[c.UID] = cliConf + } + + // TODO(s.chzhen): Compare before cloning. + cliConf.upstreams = slices.Clone(c.Upstreams) + cliConf.upstreamsCacheSize = c.UpstreamsCacheSize + cliConf.upstreamsCacheEnabled = c.UpstreamsCacheEnabled + cliConf.isChanged = true +} + +// customUpstreamConfig returns the custom client upstream configuration. +func (m *upstreamManager) customUpstreamConfig(uid UID) (proxyConf *proxy.CustomUpstreamConfig) { + cliConf, ok := m.uidToCustomConf[uid] + if !ok { + // TODO(s.chzhen): Consider panic. + m.logger.Error("no associated custom client upstream config") + + return nil + } + + if !m.isConfigChanged(cliConf) { + return cliConf.proxyConf + } + + if cliConf.proxyConf != nil { + err := cliConf.proxyConf.Close() + if err != nil { + // TODO(s.chzhen): Pass context. + m.logger.Debug("closing custom upstream config", slogutil.KeyError, err) + } + } + + proxyConf = newCustomUpstreamConfig(cliConf, m.commonConf) + cliConf.proxyConf = proxyConf + cliConf.isChanged = false + + return proxyConf +} + +// isConfigChanged returns true if the update is necessary for the custom client +// upstream configuration. +func (m *upstreamManager) isConfigChanged(cliConf *customUpstreamConfig) (ok bool) { + return !m.confUpdate.Equal(cliConf.commonConfUpdate) || cliConf.isChanged +} + +// clearUpstreamCache clears the upstream cache for each stored custom client +// upstream configuration. +func (m *upstreamManager) clearUpstreamCache() { + for _, c := range m.uidToCustomConf { + c.proxyConf.ClearCache() + } +} + +// remove deletes the custom client upstream configuration and closes +// [customUpstreamConfig.proxyConf] if necessary. +func (m *upstreamManager) remove(uid UID) (err error) { + cliConf, ok := m.uidToCustomConf[uid] + if !ok { + // TODO(s.chzhen): Consider panic. + return errors.Error("no associated custom client upstream config") + } + + delete(m.uidToCustomConf, uid) + + if cliConf.proxyConf != nil { + return cliConf.proxyConf.Close() + } + + return nil +} + +// close shuts down each stored custom client upstream configuration. +func (m *upstreamManager) close() (err error) { + var errs []error + for _, c := range m.uidToCustomConf { + if c.proxyConf == nil { + continue + } + + errs = append(errs, c.proxyConf.Close()) + } + + return errors.Join(errs...) +} + +// newCustomUpstreamConfig returns the new properly initialized custom proxy +// upstream configuration for the client. +func newCustomUpstreamConfig( + cliConf *customUpstreamConfig, + conf *CommonUpstreamConfig, +) (proxyConf *proxy.CustomUpstreamConfig) { + upstreams := stringutil.FilterOut(cliConf.upstreams, aghnet.IsCommentOrEmpty) + if len(upstreams) == 0 { + return nil + } + + upsConf, err := proxy.ParseUpstreamsConfig( + upstreams, + &upstream.Options{ + Bootstrap: conf.Bootstrap, + Timeout: time.Duration(conf.UpstreamTimeout), + HTTPVersions: aghnet.UpstreamHTTPVersions(conf.UseHTTP3Upstreams), + PreferIPv6: conf.BootstrapPreferIPv6, + }, + ) + if err != nil { + // Should not happen because upstreams are already validated. See + // [Persistent.validate]. + panic(fmt.Errorf("creating custom upstream config: %w", err)) + } + + return proxy.NewCustomUpstreamConfig( + upsConf, + cliConf.upstreamsCacheEnabled, + int(cliConf.upstreamsCacheSize), + conf.EDNSClientSubnetEnabled, + ) +} diff --git a/internal/dnsforward/beforerequest.go b/internal/dnsforward/beforerequest.go index 5d09c2e5..469af019 100644 --- a/internal/dnsforward/beforerequest.go +++ b/internal/dnsforward/beforerequest.go @@ -15,7 +15,7 @@ import ( var _ proxy.BeforeRequestHandler = (*Server)(nil) // HandleBefore is the handler that is called before any other processing, -// including logs. It performs access checks and puts the client ID, if there +// including logs. It performs access checks and puts the ClientID, if there // is one, into the server's cache. // // TODO(d.kolyshev): Extract to separate package. diff --git a/internal/dnsforward/beforerequest_internal_test.go b/internal/dnsforward/beforerequest_internal_test.go index 7e0d6e9b..d4be8782 100644 --- a/internal/dnsforward/beforerequest_internal_test.go +++ b/internal/dnsforward/beforerequest_internal_test.go @@ -266,6 +266,7 @@ func TestServer_HandleBefore_udp(t *testing.T) { UpstreamDNS: []string{localUpsAddr}, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 352db4f2..9c18b342 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -62,7 +62,7 @@ func clientIDFromClientServerName( return strings.ToLower(clientID), nil } -// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the +// clientIDFromDNSContextHTTPS extracts the ClientID from the path of the // client's DNS-over-HTTPS request. func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) { r := pctx.HTTPRequest diff --git a/internal/dnsforward/clientscontainer.go b/internal/dnsforward/clientscontainer.go new file mode 100644 index 00000000..a3e39163 --- /dev/null +++ b/internal/dnsforward/clientscontainer.go @@ -0,0 +1,46 @@ +package dnsforward + +import ( + "net/netip" + + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/dnsproxy/proxy" +) + +// ClientsContainer provides information about preconfigured DNS clients. +type ClientsContainer interface { + // CustomUpstreamConfig returns the custom client upstream configuration, if + // any. It prioritizes ClientID over client IP address to identify the + // client. + CustomUpstreamConfig(clientID string, cliAddr netip.Addr) (conf *proxy.CustomUpstreamConfig) + + // UpdateCommonUpstreamConfig updates the common upstream configuration. + UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) + + // ClearUpstreamCache clears the upstream cache for each stored custom + // client upstream configuration. + ClearUpstreamCache() +} + +// EmptyClientsContainer is an [ClientsContainer] implementation that does nothing. +type EmptyClientsContainer struct{} + +// type check +var _ ClientsContainer = EmptyClientsContainer{} + +// CustomUpstreamConfig implements the [ClientsContainer] interface for +// EmptyClientsContainer. +func (EmptyClientsContainer) CustomUpstreamConfig( + clientID string, + cliAddr netip.Addr, +) (conf *proxy.CustomUpstreamConfig) { + return nil +} + +// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for +// EmptyClientsContainer. +func (EmptyClientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) {} + +// ClearUpstreamCache implements the [ClientsContainer] interface for +// EmptyClientsContainer. +func (EmptyClientsContainer) ClearUpstreamCache() {} diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index e91657ed..c549a07e 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -29,19 +29,6 @@ import ( "github.com/ameshkov/dnscrypt/v2" ) -// ClientsContainer provides information about preconfigured DNS clients. -type ClientsContainer interface { - // UpstreamConfigByID returns the custom upstream configuration for the - // client having id, using boot to initialize the one if necessary. It - // returns nil if there is no custom upstream configuration for the client. - // The id is expected to be either a string representation of an IP address - // or the ClientID. - UpstreamConfigByID( - id string, - boot upstream.Resolver, - ) (conf *proxy.CustomUpstreamConfig, err error) -} - // Config represents the DNS filtering configuration of AdGuard Home. The zero // Config is empty and ready for use. type Config struct { @@ -467,7 +454,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) { } ipsets = stringutil.SplitTrimmed(string(data), "\n") - ipsets = slices.DeleteFunc(ipsets, IsCommentOrEmpty) + ipsets = slices.DeleteFunc(ipsets, aghnet.IsCommentOrEmpty) log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn) @@ -478,7 +465,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) { // the configuration itself. func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) { if conf.UpstreamDNSFileName == "" { - return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil + return stringutil.FilterOut(conf.UpstreamDNS, aghnet.IsCommentOrEmpty), nil } var data []byte @@ -491,7 +478,7 @@ func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) { log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName) - return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil + return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil } // collectListenAddr adds addrPort to addrs. It also adds its port to diff --git a/internal/dnsforward/dns64_test.go b/internal/dnsforward/dns64_test.go index 18bc348f..205cbe7e 100644 --- a/internal/dnsforward/dns64_test.go +++ b/internal/dnsforward/dns64_test.go @@ -299,6 +299,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, UpstreamDNS: []string{upsAddr}, }, UsePrivateRDNS: true, @@ -337,6 +338,7 @@ func TestServer_dns64WithDisabledRDNS(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, UpstreamDNS: []string{upsAddr}, }, UsePrivateRDNS: false, diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 759e5c25..00cfbe30 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -540,7 +540,7 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) { uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{ Bootstrap: boot, Timeout: s.conf.UpstreamTimeout, - HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), + HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), PreferIPv6: s.conf.BootstrapPreferIPv6, // Use a customized set of RootCAs, because Go's default mechanism of // loading TLS roots does not always work properly on some routers so we're @@ -557,6 +557,13 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) { } s.conf.UpstreamConfig = uc + s.conf.ClientsContainer.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{ + Bootstrap: boot, + UpstreamTimeout: s.conf.UpstreamTimeout, + BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6, + EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled, + UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams, + }) return nil } @@ -630,7 +637,7 @@ func (s *Server) prepareInternalDNS() (err error) { bootOpts := &upstream.Options{ Timeout: DefaultTimeout, - HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), + HTTPVersions: aghnet.UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), } s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts) @@ -661,7 +668,7 @@ func (s *Server) prepareInternalDNS() (err error) { // setupFallbackDNS initializes the fallback DNS servers. func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) { fallbacks := s.conf.FallbackDNS - fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty) + fallbacks = stringutil.FilterOut(fallbacks, aghnet.IsCommentOrEmpty) if len(fallbacks) == 0 { return nil, nil } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 6227dd09..0ced288d 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -23,6 +23,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" @@ -61,6 +62,42 @@ const ( // TODO(a.garipov): Use more. var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345") +// type check +var _ ClientsContainer = (*clientsContainer)(nil) + +// clientsContainer is a mock [ClientsContainer] implementation for tests. +type clientsContainer struct { + OnCustomUpstreamConfig func( + clientID string, + cliAddr netip.Addr, + ) (conf *proxy.CustomUpstreamConfig) + + OnUpdateCommonUpstreamConfig func(conf *client.CommonUpstreamConfig) + + OnClearUpstreamCache func() +} + +// CustomUpstreamConfig implements the [ClientsContainer] interface for +// *clientsContainer. +func (c *clientsContainer) CustomUpstreamConfig( + clientID string, + cliAddr netip.Addr, +) (conf *proxy.CustomUpstreamConfig) { + return c.OnCustomUpstreamConfig(clientID, cliAddr) +} + +// UpdateCommonUpstreamConfig implements the [ClientsContainer] interface for +// *clientsContainer. +func (c *clientsContainer) UpdateCommonUpstreamConfig(conf *client.CommonUpstreamConfig) { + c.OnUpdateCommonUpstreamConfig(conf) +} + +// ClearUpstreamCache implements the [ClientsContainer] interface for +// *clientsContainer. +func (c *clientsContainer) ClearUpstreamCache() { + c.OnClearUpstreamCache() +} + func startDeferStop(t *testing.T, s *Server) { t.Helper() @@ -168,6 +205,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) @@ -297,6 +335,7 @@ func TestServer(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) @@ -337,6 +376,7 @@ func TestServer_timeout(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -364,6 +404,7 @@ func TestServer_timeout(t *testing.T) { s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{ Enabled: false, } + s.conf.Config.ClientsContainer = EmptyClientsContainer{} err = s.Prepare(&s.conf) require.NoError(t, err) @@ -380,6 +421,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) { }, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -405,6 +447,7 @@ func TestServerWithProtectionDisabled(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) @@ -536,6 +579,7 @@ func TestSafeSearch(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -629,6 +673,7 @@ func TestInvalidRequest(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) @@ -659,6 +704,7 @@ func TestBlockedRequest(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -696,6 +742,7 @@ func TestServerCustomClientUpstream(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -721,12 +768,12 @@ func TestServerCustomClientUpstream(t *testing.T) { forwardConf.EDNSClientSubnet.Enabled, ) - s.conf.ClientsContainer = &aghtest.ClientsContainer{ - OnUpstreamConfigByID: func( + s.conf.ClientsContainer = &clientsContainer{ + OnCustomUpstreamConfig: func( _ string, - _ upstream.Resolver, - ) (conf *proxy.CustomUpstreamConfig, err error) { - return customUpsConf, nil + _ netip.Addr, + ) (conf *proxy.CustomUpstreamConfig) { + return customUpsConf }, } @@ -774,6 +821,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) @@ -808,6 +856,7 @@ func TestBlockCNAME(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -884,6 +933,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -930,6 +980,7 @@ func TestNullBlockedRequest(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -998,6 +1049,7 @@ func TestBlockedCustomIP(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -1051,6 +1103,7 @@ func TestBlockedByHosts(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -1103,6 +1156,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -1164,6 +1218,7 @@ func TestRewrite(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, })) @@ -1290,6 +1345,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false} + s.conf.Config.ClientsContainer = EmptyClientsContainer{} s.conf.Config.UpstreamMode = UpstreamModeLoadBalance err = s.Prepare(&s.conf) @@ -1375,6 +1431,7 @@ func TestPTRResponseFromHosts(t *testing.T) { s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false} + s.conf.Config.ClientsContainer = EmptyClientsContainer{} s.conf.Config.UpstreamMode = UpstreamModeLoadBalance err = s.Prepare(&s.conf) @@ -1643,6 +1700,7 @@ func TestServer_Exchange(t *testing.T) { UpstreamDNS: []string{upsAddr}, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, LocalPTRResolvers: []string{localUpsAddr}, UsePrivateRDNS: true, @@ -1665,6 +1723,7 @@ func TestServer_Exchange(t *testing.T) { UpstreamDNS: []string{upsAddr}, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, LocalPTRResolvers: []string{}, ServePlainDNS: true, diff --git a/internal/dnsforward/dnsrewrite_test.go b/internal/dnsforward/dnsrewrite_test.go index 8f26ac85..56043b2e 100644 --- a/internal/dnsforward/dnsrewrite_test.go +++ b/internal/dnsforward/dnsrewrite_test.go @@ -40,6 +40,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go index 57d265f7..922213c4 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_test.go @@ -36,6 +36,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, }, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index ab12524f..cfa428cb 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -11,6 +11,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" @@ -647,7 +648,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { return } - req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty) + req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, aghnet.IsCommentOrEmpty) opts := &upstream.Options{ Timeout: s.conf.UpstreamTimeout, @@ -673,6 +674,8 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { // handleCacheClear is the handler for the POST /control/cache_clear HTTP API. func (s *Server) handleCacheClear(w http.ResponseWriter, _ *http.Request) { s.dnsProxy.ClearCache() + s.conf.ClientsContainer.ClearUpstreamCache() + _, _ = io.WriteString(w, "OK") } diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index e92da018..bf04ee1b 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -83,6 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) { RatelimitSubnetLenIPv6: 56, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ConfigModified: func() {}, ServePlainDNS: true, @@ -164,6 +165,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { RatelimitSubnetLenIPv6: 56, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ConfigModified: func() {}, ServePlainDNS: true, @@ -299,24 +301,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { } } -func TestIsCommentOrEmpty(t *testing.T) { - for _, tc := range []struct { - want assert.BoolAssertionFunc - str string - }{{ - want: assert.True, - str: "", - }, { - want: assert.True, - str: "# comment", - }, { - want: assert.False, - str: "1.2.3.4", - }} { - tc.want(t, IsCommentOrEmpty(tc.str)) - } -} - func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) { t.Helper() @@ -388,6 +372,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 66baf368..259aeff2 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -1,7 +1,6 @@ package dnsforward import ( - "cmp" "context" "encoding/binary" "net" @@ -577,17 +576,14 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) { return } - // Use the ClientID first, since it has a higher priority. - id := cmp.Or(clientID, pctx.Addr.Addr().String()) - upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap) - if err != nil { - log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err) - - return - } - + cliAddr := pctx.Addr.Addr() + upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr) if upsConf != nil { - log.Debug("dnsforward: using custom upstreams for client %s", id) + log.Debug( + "dnsforward: using custom upstreams for client with ip %s and clientid %q", + cliAddr, + clientID, + ) pctx.CustomUpstreamConfig = upsConf } diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index 2ba42fa2..dcff3b9f 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -81,6 +81,7 @@ func TestServer_ProcessInitial(t *testing.T) { AAAADisabled: tc.aaaaDisabled, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -180,6 +181,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) { AAAADisabled: tc.aaaaDisabled, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, } @@ -324,6 +326,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) { HandleDDR: tc.ddrEnabled, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, TLSConfig: TLSConfig{ ServerName: ddrTestDomainName, @@ -660,6 +663,7 @@ func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) { UpstreamDNS: []string{localUpsAddr}, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, UsePrivateRDNS: true, LocalPTRResolvers: []string{localUpsAddr}, @@ -788,6 +792,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, UsePrivateRDNS: true, LocalPTRResolvers: []string{localUpsAddr}, @@ -816,6 +821,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, UsePrivateRDNS: false, LocalPTRResolvers: []string{localUpsAddr}, diff --git a/internal/dnsforward/svcbmsg_test.go b/internal/dnsforward/svcbmsg_test.go index c5dbff6f..611549db 100644 --- a/internal/dnsforward/svcbmsg_test.go +++ b/internal/dnsforward/svcbmsg_test.go @@ -19,6 +19,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) { Config: Config{ UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + ClientsContainer: EmptyClientsContainer{}, }, ServePlainDNS: true, }) diff --git a/internal/dnsforward/upstreams.go b/internal/dnsforward/upstreams.go index 00e10125..618601bd 100644 --- a/internal/dnsforward/upstreams.go +++ b/internal/dnsforward/upstreams.go @@ -94,7 +94,7 @@ func newPrivateConfig( ) (uc *proxy.UpstreamConfig, err error) { confNeedsFiltering := len(addrs) > 0 if confNeedsFiltering { - addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty) + addrs = stringutil.FilterOut(addrs, aghnet.IsCommentOrEmpty) } else { sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has) addrs = make([]string, 0, len(sysResolvers)) @@ -127,20 +127,6 @@ func newPrivateConfig( return uc, nil } -// UpstreamHTTPVersions returns the HTTP versions for upstream configuration -// depending on configuration. -func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { - if !http3 { - return upstream.DefaultHTTPVersions - } - - return []upstream.HTTPVersion{ - upstream.HTTPVersion3, - upstream.HTTPVersion2, - upstream.HTTPVersion11, - } -} - // setProxyUpstreamMode sets the upstream mode and related settings in conf // based on provided parameters. func setProxyUpstreamMode( @@ -162,10 +148,3 @@ func setProxyUpstreamMode( return nil } - -// IsCommentOrEmpty returns true if s starts with a "#" character or is empty. -// This function is useful for filtering out non-upstream lines from upstream -// configs. -func IsCommentOrEmpty(s string) (ok bool) { - return len(s) == 0 || s[0] == '#' -} diff --git a/internal/home/clients.go b/internal/home/clients.go index 23958cc9..e2fd62fb 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -12,17 +12,13 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/whois" - "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" - "github.com/AdguardTeam/golibs/stringutil" ) // clientsContainer is the storage of all runtime and persistent clients. @@ -373,63 +369,6 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) { return true } -// type check -var _ dnsforward.ClientsContainer = (*clientsContainer)(nil) - -// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for -// *clientsContainer. upsConf is nil if the client isn't found or if the client -// has no custom upstreams. -func (clients *clientsContainer) UpstreamConfigByID( - id string, - bootstrap upstream.Resolver, -) (conf *proxy.CustomUpstreamConfig, err error) { - clients.lock.Lock() - defer clients.lock.Unlock() - - c, ok := clients.storage.Find(id) - if !ok { - return nil, nil - } else if c.UpstreamConfig != nil { - return c.UpstreamConfig, nil - } - - upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty) - if len(upstreams) == 0 { - return nil, nil - } - - var upsConf *proxy.UpstreamConfig - upsConf, err = proxy.ParseUpstreamsConfig( - upstreams, - &upstream.Options{ - Bootstrap: bootstrap, - Timeout: time.Duration(config.DNS.UpstreamTimeout), - HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams), - PreferIPv6: config.DNS.BootstrapPreferIPv6, - }, - ) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return nil, err - } - - conf = proxy.NewCustomUpstreamConfig( - upsConf, - c.UpstreamsCacheEnabled, - int(c.UpstreamsCacheSize), - config.DNS.EDNSClientSubnet.Enabled, - ) - c.UpstreamConfig = conf - - // TODO(s.chzhen): Pass context. - err = clients.storage.Update(context.TODO(), c.Name, c) - if err != nil { - return nil, fmt.Errorf("setting upstream config: %w", err) - } - - return conf, nil -} - // type check var _ client.AddressUpdater = (*clientsContainer)(nil) diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 0f80604b..92d563f6 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -1,15 +1,12 @@ package home import ( - "net" - "net/netip" "testing" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -38,28 +35,3 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { return c } - -func TestClientsCustomUpstream(t *testing.T) { - clients := newClientsContainer(t) - ctx := testutil.ContextWithTimeout(t, testTimeout) - - // Add client with upstreams. - err := clients.storage.Add(ctx, &client.Persistent{ - Name: "client1", - UID: client.MustNewUID(), - IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, - Upstreams: []string{ - "1.1.1.1", - "[/example.org/]8.8.8.8", - }, - }) - require.NoError(t, err) - - upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver) - assert.Nil(t, upsConf) - assert.NoError(t, err) - - upsConf, err = clients.UpstreamConfigByID("1.1.1.1", net.DefaultResolver) - require.NotNil(t, upsConf) - assert.NoError(t, err) -} diff --git a/internal/home/dns.go b/internal/home/dns.go index 7b2815f1..e4af268f 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -156,7 +156,13 @@ func initDNSServer( globalContext.clients.clientChecker = globalContext.dnsServer - dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg) + dnsConf, err := newServerConfig( + &config.DNS, + config.Clients.Sources, + tlsConf, + httpReg, + globalContext.clients.storage, + ) if err != nil { return fmt.Errorf("newServerConfig: %w", err) } @@ -230,12 +236,13 @@ func newServerConfig( clientSrcConf *clientSourcesConfig, tlsConf *tlsConfigSettings, httpReg aghhttp.RegisterFunc, + clientsContainer dnsforward.ClientsContainer, ) (newConf *dnsforward.ServerConfig, err error) { hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()}) fwdConf := dnsConf.Config fwdConf.FilterHandler = applyAdditionalFiltering - fwdConf.ClientsContainer = &globalContext.clients + fwdConf.ClientsContainer = clientsContainer newConf = &dnsforward.ServerConfig{ UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port), @@ -484,7 +491,13 @@ func reconfigureDNSServer() (err error) { tlsConf := &tlsConfigSettings{} globalContext.tls.WriteDiskConfig(tlsConf) - newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister) + newConf, err := newServerConfig( + &config.DNS, + config.Clients.Sources, + tlsConf, + httpRegister, + globalContext.clients.storage, + ) if err != nil { return fmt.Errorf("generating forwarding dns server config: %w", err) }