Compare commits

...

2 Commits

Author SHA1 Message Date
Ainar Garipov
6d3b5c364b all: add tests; imp addrproc, docs 2023-07-14 17:39:42 +03:00
Ainar Garipov
9f93a21bf6 all: imp client resolving 2023-07-13 18:18:49 +03:00
12 changed files with 386 additions and 209 deletions

View File

@@ -25,6 +25,8 @@ NOTE: Add new changes BELOW THIS COMMENT.
### Fixed ### Fixed
- Occasional client information lookup failures that could lead to the DNS
server getting stuck ([#6006]).
- `bufio.Scanner: token too long` errors when trying to add filtering-rule lists - `bufio.Scanner: token too long` errors when trying to add filtering-rule lists
with lines over 1024 bytes long ([#6003]). with lines over 1024 bytes long ([#6003]).
@@ -34,6 +36,7 @@ NOTE: Add new changes BELOW THIS COMMENT.
the `Dockerfile`. the `Dockerfile`.
[#6003]: https://github.com/AdguardTeam/AdGuardHome/issues/6003 [#6003]: https://github.com/AdguardTeam/AdGuardHome/issues/6003
[#6006]: https://github.com/AdguardTeam/AdGuardHome/issues/6006
<!-- <!--
NOTE: Add new changes ABOVE THIS COMMENT. NOTE: Add new changes ABOVE THIS COMMENT.

View File

@@ -270,7 +270,10 @@ type ServerConfig struct {
UDPListenAddrs []*net.UDPAddr // UDP listen address UDPListenAddrs []*net.UDPAddr // UDP listen address
TCPListenAddrs []*net.TCPAddr // TCP listen address TCPListenAddrs []*net.TCPAddr // TCP listen address
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
OnDNSRequest func(d *proxy.DNSContext)
// ClientIPs, if not nil, is used to send clients' IP addresses to other
// parts of AdGuard Home that may use it for resolving rDNS, WHOIS, etc.
ClientIPs chan netip.Addr
FilteringConfig FilteringConfig
TLSConfig TLSConfig

View File

@@ -99,6 +99,10 @@ type Server struct {
// must be a valid domain name plus dots on each side. // must be a valid domain name plus dots on each side.
localDomainSuffix string localDomainSuffix string
// ClientIPs, if not nil, is used to send clients' IP addresses to other
// parts of AdGuard Home that may use it for resolving rDNS, WHOIS, etc.
clientIPs chan<- netip.Addr
ipset ipsetCtx ipset ipsetCtx
privateNets netutil.SubnetSet privateNets netutil.SubnetSet
localResolvers *proxy.Proxy localResolvers *proxy.Proxy
@@ -318,7 +322,8 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
Qclass: dns.ClassINET, Qclass: dns.ClassINET,
}}, }},
} }
ctx := &proxy.DNSContext{
dctx := &proxy.DNSContext{
Proto: "udp", Proto: "udp",
Req: req, Req: req,
StartTime: time.Now(), StartTime: time.Now(),
@@ -336,11 +341,11 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
resolver = s.internalProxy resolver = s.internalProxy
} }
if err = resolver.Resolve(ctx); err != nil { if err = resolver.Resolve(dctx); err != nil {
return "", err return "", err
} }
return hostFromPTR(ctx.Res) return hostFromPTR(dctx.Res)
} }
// hostFromPTR returns domain name from the PTR response or error. // hostFromPTR returns domain name from the PTR response or error.
@@ -555,6 +560,8 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.recDetector.clear() s.recDetector.clear()
s.clientIPs = s.conf.ClientIPs
return nil return nil
} }
@@ -696,6 +703,9 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
// TODO(a.garipov): This whole piece of API is weird and needs to be remade. // TODO(a.garipov): This whole piece of API is weird and needs to be remade.
if conf == nil { if conf == nil {
conf = &s.conf conf = &s.conf
} else if s.clientIPs != nil {
close(s.clientIPs)
s.clientIPs = nil
} }
err = s.Prepare(conf) err = s.Prepare(conf)

View File

@@ -39,11 +39,29 @@ func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m) testutil.DiscardLogOutput(m)
} }
// testTimeout is the common timeout for tests.
//
// TODO(a.garipov): Use more.
const testTimeout = 1 * time.Second
// testQuestionTarget is the common question target for tests.
//
// TODO(a.garipov): Use more.
const testQuestionTarget = "target.example"
const ( const (
tlsServerName = "testdns.adguard.com" tlsServerName = "testdns.adguard.com"
testMessagesCount = 10 testMessagesCount = 10
) )
// testClientAddr is the common net.Addr for tests.
//
// TODO(a.garipov): Use more.
var testClientAddr net.Addr = &net.TCPAddr{
IP: net.IP{1, 2, 3, 4},
Port: 12345,
}
func startDeferStop(t *testing.T, s *Server) { func startDeferStop(t *testing.T, s *Server) {
t.Helper() t.Helper()
@@ -53,6 +71,13 @@ func startDeferStop(t *testing.T, s *Server) {
testutil.CleanupAndRequireSuccess(t, s.Stop) testutil.CleanupAndRequireSuccess(t, s.Stop)
} }
// packageUpstreamVariableMu is used to serialize access to the package-level
// variables of package upstream.
//
// TODO(s.chzhen): Move these parameters to upstream options and remove this
// crutch.
var packageUpstreamVariableMu = &sync.Mutex{}
func createTestServer( func createTestServer(
t *testing.T, t *testing.T,
filterConf *filtering.Config, filterConf *filtering.Config,
@@ -61,6 +86,9 @@ func createTestServer(
) (s *Server) { ) (s *Server) {
t.Helper() t.Helper()
packageUpstreamVariableMu.Lock()
defer packageUpstreamVariableMu.Unlock()
rules := `||nxdomain.example.org rules := `||nxdomain.example.org
||NULL.example.org^ ||NULL.example.org^
127.0.0.1 host.example.org 127.0.0.1 host.example.org
@@ -307,11 +335,9 @@ func TestServer(t *testing.T) {
} }
func TestServer_timeout(t *testing.T) { func TestServer_timeout(t *testing.T) {
const timeout time.Duration = time.Second
t.Run("custom", func(t *testing.T) { t.Run("custom", func(t *testing.T) {
srvConf := &ServerConfig{ srvConf := &ServerConfig{
UpstreamTimeout: timeout, UpstreamTimeout: testTimeout,
FilteringConfig: FilteringConfig{ FilteringConfig: FilteringConfig{
BlockingMode: BlockingModeDefault, BlockingMode: BlockingModeDefault,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
@@ -324,7 +350,7 @@ func TestServer_timeout(t *testing.T) {
err = s.Prepare(srvConf) err = s.Prepare(srvConf)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, timeout, s.conf.UpstreamTimeout) assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
}) })
t.Run("default", func(t *testing.T) { t.Run("default", func(t *testing.T) {
@@ -545,7 +571,7 @@ func TestInvalidRequest(t *testing.T) {
// Send a DNS request without question. // Send a DNS request without question.
_, _, err := (&dns.Client{ _, _, err := (&dns.Client{
Timeout: 500 * time.Millisecond, Timeout: testTimeout,
}).Exchange(&req, addr) }).Exchange(&req, addr)
assert.NoErrorf(t, err, "got a response to an invalid query") assert.NoErrorf(t, err, "got a response to an invalid query")

View File

@@ -50,10 +50,10 @@ func (s *Server) beforeRequestHandler(
return true, nil return true, nil
} }
// getClientRequestFilteringSettings looks up client filtering settings using // clientRequestFilteringSettings looks up client filtering settings using the
// the client's IP address and ID, if any, from dctx. // client's IP address and ID, if any, from dctx.
func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings { func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
setts := s.dnsFilter.Settings() setts = s.dnsFilter.Settings()
setts.ProtectionEnabled = dctx.protectionEnabled setts.ProtectionEnabled = dctx.protectionEnabled
if s.conf.FilterHandler != nil { if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr) ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)

View File

@@ -30,6 +30,7 @@ type dnsContext struct {
setts *filtering.Settings setts *filtering.Settings
result *filtering.Result result *filtering.Result
// origResp is the response received from upstream. It is set when the // origResp is the response received from upstream. It is set when the
// response is modified by filters. // response is modified by filters.
origResp *dns.Msg origResp *dns.Msg
@@ -48,13 +49,13 @@ type dnsContext struct {
// clientID is the ClientID from DoH, DoQ, or DoT, if provided. // clientID is the ClientID from DoH, DoQ, or DoT, if provided.
clientID string clientID string
// startTime is the time at which the processing of the request has started.
startTime time.Time
// origQuestion is the question received from the client. It is set // origQuestion is the question received from the client. It is set
// when the request is modified by rewrites. // when the request is modified by rewrites.
origQuestion dns.Question origQuestion dns.Question
// startTime is the time at which the processing of the request has started.
startTime time.Time
// protectionEnabled shows if the filtering is enabled, and if the // protectionEnabled shows if the filtering is enabled, and if the
// server's DNS filter is ready. // server's DNS filter is ready.
protectionEnabled bool protectionEnabled bool
@@ -160,6 +161,22 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
// own DoH server.
//
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
const mozillaFQDN = "use-application-dns.net."
// healthcheckFQDN is a reserved domain-name used for healthchecking.
//
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
// grant requests to register test names in the normal way to any person or
// entity, making domain names under the test. TLD free to use in internal
// purposes.
//
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
const healthcheckFQDN = "healthcheck.adguardhome.test."
// processInitial terminates the following processing for some requests if // processInitial terminates the following processing for some requests if
// needed and enriches dctx with some client-specific information. // needed and enriches dctx with some client-specific information.
// //
@@ -169,6 +186,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
defer log.Debug("dnsforward: finished processing initial") defer log.Debug("dnsforward: finished processing initial")
pctx := dctx.proxyCtx pctx := dctx.proxyCtx
s.processClientIP(pctx.Addr)
q := pctx.Req.Question[0] q := pctx.Req.Question[0]
qt := q.Qtype qt := q.Qtype
if s.conf.AAAADisabled && qt == dns.TypeAAAA { if s.conf.AAAADisabled && qt == dns.TypeAAAA {
@@ -177,28 +196,13 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
return resultCodeFinish return resultCodeFinish
} }
if s.conf.OnDNSRequest != nil { if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
s.conf.OnDNSRequest(pctx)
}
// Disable Mozilla DoH.
//
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == "use-application-dns.net." {
pctx.Res = s.genNXDomain(pctx.Req) pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
// Handle a reserved domain healthcheck.adguardhome.test. if q.Name == healthcheckFQDN {
//
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
// grant requests to register test names in the normal way to any person or
// entity, making domain names under test. TLD free to use in internal
// purposes.
//
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
if q.Name == "healthcheck.adguardhome.test." {
// Generate a NODATA negative response to make nslookup exit with 0. // Generate a NODATA negative response to make nslookup exit with 0.
pctx.Res = s.makeResponse(pctx.Req) pctx.Res = s.makeResponse(pctx.Req)
@@ -213,11 +217,33 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
// Get the client-specific filtering settings. // Get the client-specific filtering settings.
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus() dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
dctx.setts = s.getClientRequestFilteringSettings(dctx) dctx.setts = s.clientRequestFilteringSettings(dctx)
return resultCodeSuccess return resultCodeSuccess
} }
// processClientIP sends the client IP address to s.clientIPs, if needed.
func (s *Server) processClientIP(addr net.Addr) {
clientIP := netutil.NetAddrToAddrPort(addr).Addr()
if clientIP == (netip.Addr{}) {
log.Info("dnsforward: warning: bad client addr %q", addr)
return
}
// Do not assign s.clientIPs to a local variable to then use, since this
// lock also serializes the closure of s.clientIPs.
s.serverLock.RLock()
defer s.serverLock.RUnlock()
select {
case s.clientIPs <- clientIP:
// Go on.
default:
log.Debug("dnsforward: client ip channel is nil or full; len: %d", len(s.clientIPs))
}
}
func (s *Server) setTableHostToIP(t hostToIPTable) { func (s *Server) setTableHostToIP(t hostToIPTable) {
s.tableHostToIPLock.Lock() s.tableHostToIPLock.Lock()
defer s.tableHostToIPLock.Unlock() defer s.tableHostToIPLock.Unlock()

View File

@@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -22,6 +23,95 @@ const (
ddrTestFQDN = ddrTestDomainName + "." ddrTestFQDN = ddrTestDomainName + "."
) )
func TestServer_ProcessInitial(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
target string
wantRCode rules.RCode
qType rules.RRType
aaaaDisabled bool
wantRC resultCode
}{{
name: "success",
target: testQuestionTarget,
wantRCode: -1,
qType: dns.TypeA,
aaaaDisabled: false,
wantRC: resultCodeSuccess,
}, {
name: "aaaa_disabled",
target: testQuestionTarget,
wantRCode: dns.RcodeSuccess,
qType: dns.TypeAAAA,
aaaaDisabled: true,
wantRC: resultCodeFinish,
}, {
name: "aaaa_disabled_a",
target: testQuestionTarget,
wantRCode: -1,
qType: dns.TypeA,
aaaaDisabled: true,
wantRC: resultCodeSuccess,
}, {
name: "mozilla_canary",
target: mozillaFQDN,
wantRCode: dns.RcodeNameError,
qType: dns.TypeA,
aaaaDisabled: false,
wantRC: resultCodeFinish,
}, {
name: "adguardhome_healthcheck",
target: healthcheckFQDN,
wantRCode: dns.RcodeSuccess,
qType: dns.TypeA,
aaaaDisabled: false,
wantRC: resultCodeFinish,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
clientIPs := make(chan netip.Addr, 1)
c := ServerConfig{
FilteringConfig: FilteringConfig{
AAAADisabled: tc.aaaaDisabled,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ClientIPs: clientIPs,
}
s := createTestServer(t, &filtering.Config{}, c, nil)
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: createTestMessageWithType(tc.target, tc.qType),
Addr: testClientAddr,
RequestID: 1234,
},
}
gotRC := s.processInitial(dctx)
assert.Equal(t, tc.wantRC, gotRC)
gotAddr, _ := testutil.RequireReceive(t, clientIPs, testTimeout)
assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr)
if tc.wantRCode > 0 {
gotResp := dctx.proxyCtx.Res
require.NotNil(t, gotResp)
assert.Equal(t, tc.wantRCode, gotResp.Rcode)
}
})
}
}
func TestServer_ProcessDDRQuery(t *testing.T) { func TestServer_ProcessDDRQuery(t *testing.T) {
dohSVCB := &dns.SVCB{ dohSVCB := &dns.SVCB{
Priority: 1, Priority: 1,
@@ -64,7 +154,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
}{{ }{{
name: "pass_host", name: "pass_host",
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
host: "example.net.", host: testQuestionTarget,
qtype: dns.TypeSVCB, qtype: dns.TypeSVCB,
ddrEnabled: true, ddrEnabled: true,
portDoH: 8043, portDoH: 8043,
@@ -234,33 +324,33 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
knownIP := netip.MustParseAddr("1.2.3.4") knownIP := netip.MustParseAddr("1.2.3.4")
testCases := []struct { testCases := []struct {
wantIP netip.Addr
name string name string
host string host string
wantIP netip.Addr
wantRes resultCode wantRes resultCode
isLocalCli bool isLocalCli bool
}{{ }{{
wantIP: knownIP,
name: "local_client_success", name: "local_client_success",
host: "example.lan", host: "example.lan",
wantIP: knownIP,
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
isLocalCli: true, isLocalCli: true,
}, { }, {
wantIP: netip.Addr{},
name: "local_client_unknown_host", name: "local_client_unknown_host",
host: "wronghost.lan", host: "wronghost.lan",
wantIP: netip.Addr{},
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
isLocalCli: true, isLocalCli: true,
}, { }, {
wantIP: netip.Addr{},
name: "external_client_known_host", name: "external_client_known_host",
host: "example.lan", host: "example.lan",
wantIP: netip.Addr{},
wantRes: resultCodeFinish, wantRes: resultCodeFinish,
isLocalCli: false, isLocalCli: false,
}, { }, {
wantIP: netip.Addr{},
name: "external_client_unknown_host", name: "external_client_unknown_host",
host: "wronghost.lan", host: "wronghost.lan",
wantIP: netip.Addr{},
wantRes: resultCodeFinish, wantRes: resultCodeFinish,
isLocalCli: false, isLocalCli: false,
}} }}
@@ -332,52 +422,52 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
knownIP := netip.MustParseAddr("1.2.3.4") knownIP := netip.MustParseAddr("1.2.3.4")
testCases := []struct { testCases := []struct {
wantIP netip.Addr
name string name string
host string host string
suffix string suffix string
wantIP netip.Addr
wantRes resultCode wantRes resultCode
qtyp uint16 qtyp uint16
}{{ }{{
wantIP: netip.Addr{},
name: "success_external", name: "success_external",
host: examplecom, host: examplecom,
suffix: defaultLocalDomainSuffix, suffix: defaultLocalDomainSuffix,
wantIP: netip.Addr{},
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
qtyp: dns.TypeA, qtyp: dns.TypeA,
}, { }, {
wantIP: netip.Addr{},
name: "success_external_non_a", name: "success_external_non_a",
host: examplecom, host: examplecom,
suffix: defaultLocalDomainSuffix, suffix: defaultLocalDomainSuffix,
wantIP: netip.Addr{},
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
qtyp: dns.TypeCNAME, qtyp: dns.TypeCNAME,
}, { }, {
wantIP: knownIP,
name: "success_internal", name: "success_internal",
host: examplelan, host: examplelan,
suffix: defaultLocalDomainSuffix, suffix: defaultLocalDomainSuffix,
wantIP: knownIP,
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
qtyp: dns.TypeA, qtyp: dns.TypeA,
}, { }, {
wantIP: netip.Addr{},
name: "success_internal_unknown", name: "success_internal_unknown",
host: "example-new.lan", host: "example-new.lan",
suffix: defaultLocalDomainSuffix, suffix: defaultLocalDomainSuffix,
wantIP: netip.Addr{},
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
qtyp: dns.TypeA, qtyp: dns.TypeA,
}, { }, {
wantIP: netip.Addr{},
name: "success_internal_aaaa", name: "success_internal_aaaa",
host: examplelan, host: examplelan,
suffix: defaultLocalDomainSuffix, suffix: defaultLocalDomainSuffix,
wantIP: netip.Addr{},
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
qtyp: dns.TypeAAAA, qtyp: dns.TypeAAAA,
}, { }, {
wantIP: knownIP,
name: "success_custom_suffix", name: "success_custom_suffix",
host: "example.custom", host: "example.custom",
suffix: "custom", suffix: "custom",
wantIP: knownIP,
wantRes: resultCodeSuccess, wantRes: resultCodeSuccess,
qtyp: dns.TypeA, qtyp: dns.TypeA,
}} }}
@@ -560,10 +650,8 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
var dnsCtx *dnsContext var dnsCtx *dnsContext
setup := func(use bool) { setup := func(use bool) {
proxyCtx = &proxy.DNSContext{ proxyCtx = &proxy.DNSContext{
Addr: &net.TCPAddr{ Addr: testClientAddr,
IP: net.IP{127, 0, 0, 1}, Req: createTestMessageWithType(reqAddr, dns.TypePTR),
},
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
} }
dnsCtx = &dnsContext{ dnsCtx = &dnsContext{
proxyCtx: proxyCtx, proxyCtx: proxyCtx,

View File

@@ -42,11 +42,13 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) {
// prepareUpstreamSettings sets upstream DNS server settings. // prepareUpstreamSettings sets upstream DNS server settings.
func (s *Server) prepareUpstreamSettings() (err error) { func (s *Server) prepareUpstreamSettings() (err error) {
// We're setting a customized set of RootCAs. The reason is that Go default // Use a customized set of RootCAs, because Go's default mechanism of
// mechanism of loading TLS roots does not always work properly on some // loading TLS roots does not always work properly on some routers so we're
// routers so we're loading roots manually and pass it here. // loading roots manually and pass it here.
// //
// See [aghtls.SystemRootCAs]. // See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
upstream.RootCAs = s.conf.TLSv12Roots upstream.RootCAs = s.conf.TLSv12Roots
upstream.CipherSuites = s.conf.TLSCiphers upstream.CipherSuites = s.conf.TLSCiphers
@@ -190,7 +192,7 @@ func (s *Server) resolveUpstreamsWithHosts(
// extractUpstreamHost returns the hostname of addr without port with an // extractUpstreamHost returns the hostname of addr without port with an
// assumption that any address passed here has already been successfully parsed // assumption that any address passed here has already been successfully parsed
// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic // by [upstream.AddressToUpstream]. This function essentially mirrors the logic
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts]. // of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
func extractUpstreamHost(addr string) (host string) { func extractUpstreamHost(addr string) (host string) {
var err error var err error

145
internal/home/clientaddr.go Normal file
View File

@@ -0,0 +1,145 @@
package home
import (
"context"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/log"
)
// TODO(a.garipov): It is currently hard to add tests for this structure due to
// strong coupling between it and Context.dnsServer with Context.clients.
// Resolve this coupling and add proper testing.
// clientAddrProcessor processes incoming client addresses with rDNS and WHOIS,
// if configured.
type clientAddrProcessor struct {
rdns rdns.Interface
whois whois.Interface
}
const (
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
// processing.
defaultQueueSize = 255
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
// processing. It must be greater than zero.
defaultCacheSize = 10_000
// defaultIPTTL is the Time to Live duration for IP addresses cached by
// rDNS and WHOIS.
defaultIPTTL = 1 * time.Hour
)
// newClientAddrProcessor returns a new client address processor. c must not be
// nil.
func newClientAddrProcessor(c *clientSourcesConfig) (p *clientAddrProcessor) {
p = &clientAddrProcessor{
rdns: &rdns.Empty{},
whois: &whois.Empty{},
}
if c.RDNS {
p.rdns = rdns.New(&rdns.Config{
Exchanger: Context.dnsServer,
CacheSize: defaultCacheSize,
CacheTTL: defaultIPTTL,
})
}
if c.WHOIS {
// TODO(s.chzhen): Consider making configurable.
const (
// defaultTimeout is the timeout for WHOIS requests.
defaultTimeout = 5 * time.Second
// defaultMaxConnReadSize is an upper limit in bytes for reading from a
// net.Conn.
defaultMaxConnReadSize = 64 * 1024
// defaultMaxRedirects is the maximum redirects count.
defaultMaxRedirects = 5
// defaultMaxInfoLen is the maximum length of whois.Info fields.
defaultMaxInfoLen = 250
)
p.whois = whois.New(&whois.Config{
DialContext: customDialContext,
ServerAddr: whois.DefaultServer,
Port: whois.DefaultPort,
Timeout: defaultTimeout,
CacheSize: defaultCacheSize,
MaxConnReadSize: defaultMaxConnReadSize,
MaxRedirects: defaultMaxRedirects,
MaxInfoLen: defaultMaxInfoLen,
CacheTTL: defaultIPTTL,
})
}
return p
}
// process processes the incoming client IP-address information. It is intended
// to be used as a goroutine. Once clientIPs is closed, process exits.
func (p *clientAddrProcessor) process(clientIPs <-chan netip.Addr) {
defer log.OnPanic("clientAddrProcessor.process")
log.Info("home: processing client addresses")
for ip := range clientIPs {
p.processRDNS(ip)
p.processWHOIS(ip)
}
log.Info("home: finished processing client addresses")
}
// processRDNS resolves the clients' IP addresses using reverse DNS.
func (p *clientAddrProcessor) processRDNS(ip netip.Addr) {
start := time.Now()
log.Debug("home: processing client %s with rdns", ip)
defer func() {
log.Debug("home: finished processing client %s with rdns in %s", ip, time.Since(start))
}()
ok := Context.dnsServer.ShouldResolveClient(ip)
if !ok {
return
}
host, changed := p.rdns.Process(ip)
if host == "" || !changed {
return
}
ok = Context.clients.AddHost(ip, host, ClientSourceRDNS)
if ok {
return
}
log.Debug("dns: setting rdns info for client %q: already set with higher priority source", ip)
}
// processWHOIS looks up the information aobut clients' IP addresses in the
// WHOIS databases.
func (p *clientAddrProcessor) processWHOIS(ip netip.Addr) {
start := time.Now()
log.Debug("home: processing client %s with whois", ip)
defer func() {
log.Debug("home: finished processing client %s with whois in %s", ip, time.Since(start))
}()
// TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the
// context.
info, changed := p.whois.Process(context.Background(), ip)
if info == nil || !changed {
return
}
Context.clients.setWHOISInfo(ip, info)
}

View File

@@ -141,7 +141,7 @@ func (clients *clientsContainer) handleHostsUpdates() {
} }
} }
// webHandlersRegistered prevents a [clientsContainer] from regisering its web // webHandlersRegistered prevents a [clientsContainer] from registering its web
// handlers more than once. // handlers more than once.
// //
// TODO(a.garipov): Refactor HTTP handler registration logic. // TODO(a.garipov): Refactor HTTP handler registration logic.

View File

@@ -17,10 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@@ -154,134 +151,32 @@ func initDNSServer(
Context.clients.dnsServer = Context.dnsServer Context.clients.dnsServer = Context.dnsServer
dnsConf, err := generateServerConfig(tlsConf, httpReg) dnsConf, err := newServerConfig(tlsConf, httpReg)
if err != nil { if err != nil {
closeDNSServer() closeDNSServer()
return fmt.Errorf("generateServerConfig: %w", err) return fmt.Errorf("newServerConfig: %w", err)
} }
err = Context.dnsServer.Prepare(&dnsConf) err = Context.dnsServer.Prepare(dnsConf)
if err != nil { if err != nil {
closeDNSServer() closeDNSServer()
return fmt.Errorf("dnsServer.Prepare: %w", err) return fmt.Errorf("dnsServer.Prepare: %w", err)
} }
initRDNS() clientIPs := dnsConf.ClientIPs
initWHOIS() addrProc := newClientAddrProcessor(config.Clients.Sources)
go addrProc.process(clientIPs)
const topClientsNumber = 100
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
clientIPs <- ip
}
return nil return nil
} }
const (
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
// processing.
defaultQueueSize = 255
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
// processing. It must be greater than zero.
defaultCacheSize = 10_000
// defaultIPTTL is the Time to Live duration for IP addresses cached by
// rDNS and WHOIS.
defaultIPTTL = 1 * time.Hour
)
// initRDNS initializes the rDNS.
func initRDNS() {
Context.rdnsCh = make(chan netip.Addr, defaultQueueSize)
// TODO(s.chzhen): Add ability to disable it on dns server configuration
// update in [dnsforward] package.
r := rdns.New(&rdns.Config{
Exchanger: Context.dnsServer,
CacheSize: defaultCacheSize,
CacheTTL: defaultIPTTL,
})
go processRDNS(r)
}
// processRDNS processes reverse DNS lookup queries. It is intended to be used
// as a goroutine.
func processRDNS(r rdns.Interface) {
defer log.OnPanic("rdns")
for ip := range Context.rdnsCh {
ok := Context.dnsServer.ShouldResolveClient(ip)
if !ok {
continue
}
host, changed := r.Process(ip)
if host == "" || !changed {
continue
}
ok = Context.clients.AddHost(ip, host, ClientSourceRDNS)
if ok {
continue
}
log.Debug(
"dns: can't set rdns info for client %q: already set with higher priority source",
ip,
)
}
}
// initWHOIS initializes the WHOIS.
//
// TODO(s.chzhen): Consider making configurable.
func initWHOIS() {
const (
// defaultTimeout is the timeout for WHOIS requests.
defaultTimeout = 5 * time.Second
// defaultMaxConnReadSize is an upper limit in bytes for reading from
// net.Conn.
defaultMaxConnReadSize = 64 * 1024
// defaultMaxRedirects is the maximum redirects count.
defaultMaxRedirects = 5
// defaultMaxInfoLen is the maximum length of whois.Info fields.
defaultMaxInfoLen = 250
)
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
var w whois.Interface
if config.Clients.Sources.WHOIS {
w = whois.New(&whois.Config{
DialContext: customDialContext,
ServerAddr: whois.DefaultServer,
Port: whois.DefaultPort,
Timeout: defaultTimeout,
CacheSize: defaultCacheSize,
MaxConnReadSize: defaultMaxConnReadSize,
MaxRedirects: defaultMaxRedirects,
MaxInfoLen: defaultMaxInfoLen,
CacheTTL: defaultIPTTL,
})
} else {
w = whois.Empty{}
}
go func() {
defer log.OnPanic("whois")
for ip := range Context.whoisCh {
info, changed := w.Process(context.Background(), ip)
if info != nil && changed {
Context.clients.setWHOISInfo(ip, info)
}
}
}()
}
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns // parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
// a subnet set that matches all locally served networks, see // a subnet set that matches all locally served networks, see
// [netutil.IsLocallyServed]. // [netutil.IsLocallyServed].
@@ -312,17 +207,6 @@ func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning() return Context.dnsServer != nil && Context.dnsServer.IsRunning()
} }
func onDNSRequest(pctx *proxy.DNSContext) {
ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr()
if ip == (netip.Addr{}) {
// This would be quite weird if we get here.
return
}
Context.rdnsCh <- ip
Context.whoisCh <- ip
}
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) { func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
if ips == nil { if ips == nil {
return nil return nil
@@ -349,19 +233,20 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
return udpAddrs return udpAddrs
} }
func generateServerConfig( func newServerConfig(
tlsConf *tlsConfigSettings, tlsConf *tlsConfigSettings,
httpReg aghhttp.RegisterFunc, httpReg aghhttp.RegisterFunc,
) (newConf dnsforward.ServerConfig, err error) { ) (newConf *dnsforward.ServerConfig, err error) {
dnsConf := config.DNS dnsConf := config.DNS
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()}) hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
newConf = dnsforward.ServerConfig{ clientIPs := make(chan netip.Addr, defaultQueueSize)
newConf = &dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port), UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port), TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
FilteringConfig: dnsConf.FilteringConfig, FilteringConfig: dnsConf.FilteringConfig,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpReg, HTTPRegister: httpReg,
OnDNSRequest: onDNSRequest, ClientIPs: clientIPs,
UseDNS64: config.DNS.UseDNS64, UseDNS64: config.DNS.UseDNS64,
DNS64Prefixes: config.DNS.DNS64Prefixes, DNS64Prefixes: config.DNS.DNS64Prefixes,
} }
@@ -385,9 +270,9 @@ func generateServerConfig(
if tlsConf.PortDNSCrypt != 0 { if tlsConf.PortDNSCrypt != 0 {
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf) newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf)
if err != nil { if err != nil {
// Don't wrap the error, because it's already // Don't wrap the error, because it's already wrapped by
// wrapped by newDNSCrypt. // newDNSCrypt.
return dnsforward.ServerConfig{}, err return nil, err
} }
} }
} }
@@ -556,31 +441,26 @@ func startDNSServer() error {
Context.stats.Start() Context.stats.Start()
Context.queryLog.Start() Context.queryLog.Start()
const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
Context.rdnsCh <- ip
Context.whoisCh <- ip
}
return nil return nil
} }
func reconfigureDNSServer() (err error) { func reconfigureDNSServer() (err error) {
var newConf dnsforward.ServerConfig
tlsConf := &tlsConfigSettings{} tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf) Context.tls.WriteDiskConfig(tlsConf)
newConf, err = generateServerConfig(tlsConf, httpRegister) newConf, err := newServerConfig(tlsConf, httpRegister)
if err != nil { if err != nil {
return fmt.Errorf("generating forwarding dns server config: %w", err) return fmt.Errorf("generating forwarding dns server config: %w", err)
} }
err = Context.dnsServer.Reconfigure(&newConf) err = Context.dnsServer.Reconfigure(newConf)
if err != nil { if err != nil {
return fmt.Errorf("starting forwarding dns server: %w", err) return fmt.Errorf("starting forwarding dns server: %w", err)
} }
addrProc := newClientAddrProcessor(config.Clients.Sources)
go addrProc.process(newConf.ClientIPs)
return nil return nil
} }

View File

@@ -82,12 +82,6 @@ type homeContext struct {
client *http.Client client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
// rdnsCh is the channel for receiving IPs for rDNS processing.
rdnsCh chan netip.Addr
// whoisCh is the channel for receiving IPs for WHOIS processing.
whoisCh chan netip.Addr
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use. // tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
tlsCipherIDs []uint16 tlsCipherIDs []uint16
@@ -634,10 +628,10 @@ func run(opts options, clientBuildFS fs.FS) {
Context.tls.start() Context.tls.start()
go func() { go func() {
sErr := startDNSServer() startErr := startDNSServer()
if sErr != nil { if startErr != nil {
closeDNSServer() closeDNSServer()
fatalOnError(sErr) fatalOnError(startErr)
} }
}() }()