From 6d3b5c364bbdfd5667bcc265835af30446d7634f Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Fri, 14 Jul 2023 17:39:42 +0300 Subject: [PATCH] all: add tests; imp addrproc, docs --- CHANGELOG.md | 4 +- internal/dnsforward/dnsforward_test.go | 36 +++++- internal/dnsforward/filter.go | 8 +- internal/dnsforward/process.go | 37 +++--- internal/dnsforward/process_internal_test.go | 122 ++++++++++++++++--- internal/dnsforward/upstreams.go | 10 +- internal/home/clientaddr.go | 11 +- 7 files changed, 174 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f2360b4..c586c68b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,8 +25,8 @@ NOTE: Add new changes BELOW THIS COMMENT. ### Fixed -- Occasional client information lookup failures leading to DNS resolving getting - stuck ([#6006]). +- Occasional client information lookup failures that could lead to the DNS + server getting stuck ([#6006]). - `bufio.Scanner: token too long` errors when trying to add filtering-rule lists with lines over 1024 bytes long ([#6003]). diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 705227a1..3ed0ade0 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -39,11 +39,29 @@ func TestMain(m *testing.M) { testutil.DiscardLogOutput(m) } +// testTimeout is the common timeout for tests. +// +// TODO(a.garipov): Use more. +const testTimeout = 1 * time.Second + +// testQuestionTarget is the common question target for tests. +// +// TODO(a.garipov): Use more. +const testQuestionTarget = "target.example" + const ( tlsServerName = "testdns.adguard.com" testMessagesCount = 10 ) +// testClientAddr is the common net.Addr for tests. +// +// TODO(a.garipov): Use more. +var testClientAddr net.Addr = &net.TCPAddr{ + IP: net.IP{1, 2, 3, 4}, + Port: 12345, +} + func startDeferStop(t *testing.T, s *Server) { t.Helper() @@ -53,6 +71,13 @@ func startDeferStop(t *testing.T, s *Server) { testutil.CleanupAndRequireSuccess(t, s.Stop) } +// packageUpstreamVariableMu is used to serialize access to the package-level +// variables of package upstream. +// +// TODO(s.chzhen): Move these parameters to upstream options and remove this +// crutch. +var packageUpstreamVariableMu = &sync.Mutex{} + func createTestServer( t *testing.T, filterConf *filtering.Config, @@ -61,6 +86,9 @@ func createTestServer( ) (s *Server) { t.Helper() + packageUpstreamVariableMu.Lock() + defer packageUpstreamVariableMu.Unlock() + rules := `||nxdomain.example.org ||NULL.example.org^ 127.0.0.1 host.example.org @@ -307,11 +335,9 @@ func TestServer(t *testing.T) { } func TestServer_timeout(t *testing.T) { - const timeout time.Duration = time.Second - t.Run("custom", func(t *testing.T) { srvConf := &ServerConfig{ - UpstreamTimeout: timeout, + UpstreamTimeout: testTimeout, FilteringConfig: FilteringConfig{ BlockingMode: BlockingModeDefault, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, @@ -324,7 +350,7 @@ func TestServer_timeout(t *testing.T) { err = s.Prepare(srvConf) require.NoError(t, err) - assert.Equal(t, timeout, s.conf.UpstreamTimeout) + assert.Equal(t, testTimeout, s.conf.UpstreamTimeout) }) t.Run("default", func(t *testing.T) { @@ -545,7 +571,7 @@ func TestInvalidRequest(t *testing.T) { // Send a DNS request without question. _, _, err := (&dns.Client{ - Timeout: 500 * time.Millisecond, + Timeout: testTimeout, }).Exchange(&req, addr) assert.NoErrorf(t, err, "got a response to an invalid query") diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index f55e3059..3f35afc2 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -50,10 +50,10 @@ func (s *Server) beforeRequestHandler( return true, nil } -// getClientRequestFilteringSettings looks up client filtering settings using -// the client's IP address and ID, if any, from dctx. -func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings { - setts := s.dnsFilter.Settings() +// clientRequestFilteringSettings looks up client filtering settings using the +// client's IP address and ID, if any, from dctx. +func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) { + setts = s.dnsFilter.Settings() setts.ProtectionEnabled = dctx.protectionEnabled if s.conf.FilterHandler != nil { ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr) diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 4891b144..079cc928 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -161,6 +161,22 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } +// mozillaFQDN is the domain used to signal the Firefox browser to not use its +// own DoH server. +// +// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet. +const mozillaFQDN = "use-application-dns.net." + +// healthcheckFQDN is a reserved domain-name used for healthchecking. +// +// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not +// grant requests to register test names in the normal way to any person or +// entity, making domain names under the test. TLD free to use in internal +// purposes. +// +// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2 +const healthcheckFQDN = "healthcheck.adguardhome.test." + // processInitial terminates the following processing for some requests if // needed and enriches dctx with some client-specific information. // @@ -170,6 +186,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) { defer log.Debug("dnsforward: finished processing initial") pctx := dctx.proxyCtx + s.processClientIP(pctx.Addr) + q := pctx.Req.Question[0] qt := q.Qtype if s.conf.AAAADisabled && qt == dns.TypeAAAA { @@ -178,26 +196,13 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) { return resultCodeFinish } - s.processClientIP(pctx.Addr) - - // Disable Mozilla DoH. - // - // See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet. - if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == "use-application-dns.net." { + if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN { pctx.Res = s.genNXDomain(pctx.Req) return resultCodeFinish } - // Handle a reserved domain healthcheck.adguardhome.test. - // - // [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not - // grant requests to register test names in the normal way to any person or - // entity, making domain names under test. TLD free to use in internal - // purposes. - // - // [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2 - if q.Name == "healthcheck.adguardhome.test." { + if q.Name == healthcheckFQDN { // Generate a NODATA negative response to make nslookup exit with 0. pctx.Res = s.makeResponse(pctx.Req) @@ -212,7 +217,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) { // Get the client-specific filtering settings. dctx.protectionEnabled, _ = s.UpdatedProtectionStatus() - dctx.setts = s.getClientRequestFilteringSettings(dctx) + dctx.setts = s.clientRequestFilteringSettings(dctx) return resultCodeSuccess } diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index 1bcca756..5f2f7d4b 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,6 +23,95 @@ const ( ddrTestFQDN = ddrTestDomainName + "." ) +func TestServer_ProcessInitial(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + target string + wantRCode rules.RCode + qType rules.RRType + aaaaDisabled bool + wantRC resultCode + }{{ + name: "success", + target: testQuestionTarget, + wantRCode: -1, + qType: dns.TypeA, + aaaaDisabled: false, + wantRC: resultCodeSuccess, + }, { + name: "aaaa_disabled", + target: testQuestionTarget, + wantRCode: dns.RcodeSuccess, + qType: dns.TypeAAAA, + aaaaDisabled: true, + wantRC: resultCodeFinish, + }, { + name: "aaaa_disabled_a", + target: testQuestionTarget, + wantRCode: -1, + qType: dns.TypeA, + aaaaDisabled: true, + wantRC: resultCodeSuccess, + }, { + name: "mozilla_canary", + target: mozillaFQDN, + wantRCode: dns.RcodeNameError, + qType: dns.TypeA, + aaaaDisabled: false, + wantRC: resultCodeFinish, + }, { + name: "adguardhome_healthcheck", + target: healthcheckFQDN, + wantRCode: dns.RcodeSuccess, + qType: dns.TypeA, + aaaaDisabled: false, + wantRC: resultCodeFinish, + }} + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + clientIPs := make(chan netip.Addr, 1) + + c := ServerConfig{ + FilteringConfig: FilteringConfig{ + AAAADisabled: tc.aaaaDisabled, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + ClientIPs: clientIPs, + } + + s := createTestServer(t, &filtering.Config{}, c, nil) + + dctx := &dnsContext{ + proxyCtx: &proxy.DNSContext{ + Req: createTestMessageWithType(tc.target, tc.qType), + Addr: testClientAddr, + RequestID: 1234, + }, + } + + gotRC := s.processInitial(dctx) + assert.Equal(t, tc.wantRC, gotRC) + + gotAddr, _ := testutil.RequireReceive(t, clientIPs, testTimeout) + assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr) + + if tc.wantRCode > 0 { + gotResp := dctx.proxyCtx.Res + require.NotNil(t, gotResp) + + assert.Equal(t, tc.wantRCode, gotResp.Rcode) + } + }) + } +} + func TestServer_ProcessDDRQuery(t *testing.T) { dohSVCB := &dns.SVCB{ Priority: 1, @@ -64,7 +154,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) { }{{ name: "pass_host", wantRes: resultCodeSuccess, - host: "example.net.", + host: testQuestionTarget, qtype: dns.TypeSVCB, ddrEnabled: true, portDoH: 8043, @@ -234,33 +324,33 @@ func TestServer_ProcessDetermineLocal(t *testing.T) { func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { knownIP := netip.MustParseAddr("1.2.3.4") testCases := []struct { + wantIP netip.Addr name string host string - wantIP netip.Addr wantRes resultCode isLocalCli bool }{{ + wantIP: knownIP, name: "local_client_success", host: "example.lan", - wantIP: knownIP, wantRes: resultCodeSuccess, isLocalCli: true, }, { + wantIP: netip.Addr{}, name: "local_client_unknown_host", host: "wronghost.lan", - wantIP: netip.Addr{}, wantRes: resultCodeSuccess, isLocalCli: true, }, { + wantIP: netip.Addr{}, name: "external_client_known_host", host: "example.lan", - wantIP: netip.Addr{}, wantRes: resultCodeFinish, isLocalCli: false, }, { + wantIP: netip.Addr{}, name: "external_client_unknown_host", host: "wronghost.lan", - wantIP: netip.Addr{}, wantRes: resultCodeFinish, isLocalCli: false, }} @@ -332,52 +422,52 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { knownIP := netip.MustParseAddr("1.2.3.4") testCases := []struct { + wantIP netip.Addr name string host string suffix string - wantIP netip.Addr wantRes resultCode qtyp uint16 }{{ + wantIP: netip.Addr{}, name: "success_external", host: examplecom, suffix: defaultLocalDomainSuffix, - wantIP: netip.Addr{}, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }, { + wantIP: netip.Addr{}, name: "success_external_non_a", host: examplecom, suffix: defaultLocalDomainSuffix, - wantIP: netip.Addr{}, wantRes: resultCodeSuccess, qtyp: dns.TypeCNAME, }, { + wantIP: knownIP, name: "success_internal", host: examplelan, suffix: defaultLocalDomainSuffix, - wantIP: knownIP, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }, { + wantIP: netip.Addr{}, name: "success_internal_unknown", host: "example-new.lan", suffix: defaultLocalDomainSuffix, - wantIP: netip.Addr{}, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }, { + wantIP: netip.Addr{}, name: "success_internal_aaaa", host: examplelan, suffix: defaultLocalDomainSuffix, - wantIP: netip.Addr{}, wantRes: resultCodeSuccess, qtyp: dns.TypeAAAA, }, { + wantIP: knownIP, name: "success_custom_suffix", host: "example.custom", suffix: "custom", - wantIP: knownIP, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }} @@ -560,10 +650,8 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { var dnsCtx *dnsContext setup := func(use bool) { proxyCtx = &proxy.DNSContext{ - Addr: &net.TCPAddr{ - IP: net.IP{127, 0, 0, 1}, - }, - Req: createTestMessageWithType(reqAddr, dns.TypePTR), + Addr: testClientAddr, + Req: createTestMessageWithType(reqAddr, dns.TypePTR), } dnsCtx = &dnsContext{ proxyCtx: proxyCtx, diff --git a/internal/dnsforward/upstreams.go b/internal/dnsforward/upstreams.go index cbd92b36..ceec1cb7 100644 --- a/internal/dnsforward/upstreams.go +++ b/internal/dnsforward/upstreams.go @@ -42,11 +42,13 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) { // prepareUpstreamSettings sets upstream DNS server settings. func (s *Server) prepareUpstreamSettings() (err error) { - // We're setting a customized set of RootCAs. The reason is that Go default - // mechanism of loading TLS roots does not always work properly on some - // routers so we're loading roots manually and pass it here. + // Use a customized set of RootCAs, because Go's default mechanism of + // loading TLS roots does not always work properly on some routers so we're + // loading roots manually and pass it here. // // See [aghtls.SystemRootCAs]. + // + // TODO(a.garipov): Investigate if that's true. upstream.RootCAs = s.conf.TLSv12Roots upstream.CipherSuites = s.conf.TLSCiphers @@ -190,7 +192,7 @@ func (s *Server) resolveUpstreamsWithHosts( // extractUpstreamHost returns the hostname of addr without port with an // assumption that any address passed here has already been successfully parsed -// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic +// by [upstream.AddressToUpstream]. This function essentially mirrors the logic // of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts]. func extractUpstreamHost(addr string) (host string) { var err error diff --git a/internal/home/clientaddr.go b/internal/home/clientaddr.go index d2e4f021..6684d4e4 100644 --- a/internal/home/clientaddr.go +++ b/internal/home/clientaddr.go @@ -38,7 +38,10 @@ const ( // newClientAddrProcessor returns a new client address processor. c must not be // nil. func newClientAddrProcessor(c *clientSourcesConfig) (p *clientAddrProcessor) { - p = &clientAddrProcessor{} + p = &clientAddrProcessor{ + rdns: &rdns.Empty{}, + whois: &whois.Empty{}, + } if c.RDNS { p.rdns = rdns.New(&rdns.Config{ @@ -46,8 +49,6 @@ func newClientAddrProcessor(c *clientSourcesConfig) (p *clientAddrProcessor) { CacheSize: defaultCacheSize, CacheTTL: defaultIPTTL, }) - } else { - p.rdns = rdns.Empty{} } if c.WHOIS { @@ -78,15 +79,13 @@ func newClientAddrProcessor(c *clientSourcesConfig) (p *clientAddrProcessor) { MaxInfoLen: defaultMaxInfoLen, CacheTTL: defaultIPTTL, }) - } else { - p.whois = whois.Empty{} } return p } // process processes the incoming client IP-address information. It is intended -// to be used as a goroutine. +// to be used as a goroutine. Once clientIPs is closed, process exits. func (p *clientAddrProcessor) process(clientIPs <-chan netip.Addr) { defer log.OnPanic("clientAddrProcessor.process")